diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 5b871f69f9..6ce72e80df 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -189,6 +189,11 @@ class PluginConfig(BaseSettings): default="plugin-api-key", ) + PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field( + description="Timeout in seconds for requests to the plugin daemon (set to None to disable)", + default=300.0, + ) + INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key") PLUGIN_REMOTE_INSTALL_HOST: str = Field( diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index b1e3813f33..efab1bc03c 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -1,4 +1,5 @@ import flask_restx +from flask import Response from flask_restx import Resource, fields, marshal_with from flask_restx._http import HTTPStatus from sqlalchemy import select @@ -7,8 +8,7 @@ from werkzeug.exceptions import Forbidden from extensions.ext_database import db from libs.helper import TimestampField -from libs.login import current_user, login_required -from models.account import Account +from libs.login import current_account_with_tenant, login_required from models.dataset import Dataset from models.model import ApiToken, App @@ -57,9 +57,9 @@ class BaseApiKeyListResource(Resource): def get(self, resource_id): assert self.resource_id_field is not None, "resource_id_field must be set" resource_id = str(resource_id) - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None - _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) + _, current_tenant_id = current_account_with_tenant() + + _get_resource(resource_id, current_tenant_id, self.resource_model) keys = db.session.scalars( select(ApiToken).where( ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id @@ -71,9 +71,8 @@ class BaseApiKeyListResource(Resource): def post(self, resource_id): assert self.resource_id_field is not None, "resource_id_field must be set" resource_id = str(resource_id) - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None - _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) + current_user, current_tenant_id = current_account_with_tenant() + _get_resource(resource_id, current_tenant_id, self.resource_model) if not current_user.has_edit_permission: raise Forbidden() @@ -93,7 +92,7 @@ class BaseApiKeyListResource(Resource): key = ApiToken.generate_api_key(self.token_prefix or "", 24) api_token = ApiToken() setattr(api_token, self.resource_id_field, resource_id) - api_token.tenant_id = current_user.current_tenant_id + api_token.tenant_id = current_tenant_id api_token.token = key api_token.type = self.resource_type db.session.add(api_token) @@ -112,9 +111,8 @@ class BaseApiKeyResource(Resource): assert self.resource_id_field is not None, "resource_id_field must be set" resource_id = str(resource_id) api_key_id = str(api_key_id) - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None - _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) + current_user, current_tenant_id = current_account_with_tenant() + _get_resource(resource_id, current_tenant_id, self.resource_model) # The role of the current user in the ta table must be admin or owner if not current_user.is_admin_or_owner: @@ -158,7 +156,7 @@ class AppApiKeyListResource(BaseApiKeyListResource): """Create a new API key for an app""" return super().post(resource_id) - def after_request(self, resp): + def after_request(self, resp: Response): resp.headers["Access-Control-Allow-Origin"] = "*" resp.headers["Access-Control-Allow-Credentials"] = "true" return resp @@ -208,7 +206,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource): """Create a new API key for a dataset""" return super().post(resource_id) - def after_request(self, resp): + def after_request(self, resp: Response): resp.headers["Access-Control-Allow-Origin"] = "*" resp.headers["Access-Control-Allow-Credentials"] = "true" return resp @@ -229,7 +227,7 @@ class DatasetApiKeyResource(BaseApiKeyResource): """Delete an API key for a dataset""" return super().delete(resource_id, api_key_id) - def after_request(self, resp): + def after_request(self, resp: Response): resp.headers["Access-Control-Allow-Origin"] = "*" resp.headers["Access-Control-Allow-Credentials"] = "true" return resp diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index d0ee11fe75..589955738c 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -1,7 +1,6 @@ from typing import Literal from flask import request -from flask_login import current_user from flask_restx import Resource, fields, marshal, marshal_with, reqparse from werkzeug.exceptions import Forbidden @@ -17,7 +16,7 @@ from fields.annotation_fields import ( annotation_fields, annotation_hit_history_fields, ) -from libs.login import login_required +from libs.login import current_account_with_tenant, login_required from services.annotation_service import AppAnnotationService @@ -43,7 +42,9 @@ class AnnotationReplyActionApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("annotation") def post(self, app_id, action: Literal["enable", "disable"]): - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() app_id = str(app_id) @@ -70,7 +71,9 @@ class AppAnnotationSettingDetailApi(Resource): @login_required @account_initialization_required def get(self, app_id): - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() app_id = str(app_id) @@ -99,7 +102,9 @@ class AppAnnotationSettingUpdateApi(Resource): @login_required @account_initialization_required def post(self, app_id, annotation_setting_id): - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() app_id = str(app_id) @@ -125,7 +130,9 @@ class AnnotationReplyActionStatusApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("annotation") def get(self, app_id, job_id, action): - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() job_id = str(job_id) @@ -160,7 +167,9 @@ class AnnotationApi(Resource): @login_required @account_initialization_required def get(self, app_id): - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() page = request.args.get("page", default=1, type=int) @@ -199,7 +208,9 @@ class AnnotationApi(Resource): @cloud_edition_billing_resource_check("annotation") @marshal_with(annotation_fields) def post(self, app_id): - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() app_id = str(app_id) @@ -214,7 +225,9 @@ class AnnotationApi(Resource): @login_required @account_initialization_required def delete(self, app_id): - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() app_id = str(app_id) @@ -250,7 +263,9 @@ class AnnotationExportApi(Resource): @login_required @account_initialization_required def get(self, app_id): - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() app_id = str(app_id) @@ -273,7 +288,9 @@ class AnnotationUpdateDeleteApi(Resource): @cloud_edition_billing_resource_check("annotation") @marshal_with(annotation_fields) def post(self, app_id, annotation_id): - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() app_id = str(app_id) @@ -289,7 +306,9 @@ class AnnotationUpdateDeleteApi(Resource): @login_required @account_initialization_required def delete(self, app_id, annotation_id): - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() app_id = str(app_id) @@ -311,7 +330,9 @@ class AnnotationBatchImportApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("annotation") def post(self, app_id): - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() app_id = str(app_id) @@ -342,7 +363,9 @@ class AnnotationBatchImportStatusApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("annotation") def get(self, app_id, job_id): - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() job_id = str(job_id) @@ -377,7 +400,9 @@ class AnnotationHitHistoryListApi(Resource): @login_required @account_initialization_required def get(self, app_id, annotation_id): - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() page = request.args.get("page", default=1, type=int) diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py index 037561cfed..5751ff1f86 100644 --- a/api/controllers/console/app/app_import.py +++ b/api/controllers/console/app/app_import.py @@ -1,6 +1,3 @@ -from typing import cast - -from flask_login import current_user from flask_restx import Resource, marshal_with, reqparse from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden @@ -13,8 +10,7 @@ from controllers.console.wraps import ( ) from extensions.ext_database import db from fields.app_fields import app_import_check_dependencies_fields, app_import_fields -from libs.login import login_required -from models import Account +from libs.login import current_account_with_tenant, login_required from models.model import App from services.app_dsl_service import AppDslService, ImportStatus from services.enterprise.enterprise_service import EnterpriseService @@ -32,7 +28,8 @@ class AppImportApi(Resource): @cloud_edition_billing_resource_check("apps") def post(self): # Check user role first - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -51,7 +48,7 @@ class AppImportApi(Resource): with Session(db.engine) as session: import_service = AppDslService(session) # Import app - account = cast(Account, current_user) + account = current_user result = import_service.import_app( account=account, import_mode=args["mode"], @@ -85,14 +82,15 @@ class AppImportConfirmApi(Resource): @marshal_with(app_import_fields) def post(self, import_id): # Check user role first - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + if not current_user.has_edit_permission: raise Forbidden() # Create service with session with Session(db.engine) as session: import_service = AppDslService(session) # Confirm import - account = cast(Account, current_user) + account = current_user result = import_service.confirm_import(import_id=import_id, account=account) session.commit() @@ -110,7 +108,8 @@ class AppImportCheckDependenciesApi(Resource): @account_initialization_required @marshal_with(app_import_check_dependencies_fields) def get(self, app_model: App): - if not current_user.is_editor: + current_user, _ = current_account_with_tenant() + if not current_user.has_edit_permission: raise Forbidden() with Session(db.engine) as session: diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 230ccdca15..4a9b6e7801 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,6 +1,5 @@ from collections.abc import Sequence -from flask_login import current_user from flask_restx import Resource, fields, reqparse from controllers.console import api, console_ns @@ -17,7 +16,7 @@ from core.helper.code_executor.python3.python3_code_provider import Python3CodeP from core.llm_generator.llm_generator import LLMGenerator from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db -from libs.login import login_required +from libs.login import current_account_with_tenant, login_required from models import App from services.workflow_service import WorkflowService @@ -48,11 +47,11 @@ class RuleGenerateApi(Resource): parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") parser.add_argument("no_variable", type=bool, required=True, default=False, location="json") args = parser.parse_args() + _, current_tenant_id = current_account_with_tenant() - account = current_user try: rules = LLMGenerator.generate_rule_config( - tenant_id=account.current_tenant_id, + tenant_id=current_tenant_id, instruction=args["instruction"], model_config=args["model_config"], no_variable=args["no_variable"], @@ -99,11 +98,11 @@ class RuleCodeGenerateApi(Resource): parser.add_argument("no_variable", type=bool, required=True, default=False, location="json") parser.add_argument("code_language", type=str, required=False, default="javascript", location="json") args = parser.parse_args() + _, current_tenant_id = current_account_with_tenant() - account = current_user try: code_result = LLMGenerator.generate_code( - tenant_id=account.current_tenant_id, + tenant_id=current_tenant_id, instruction=args["instruction"], model_config=args["model_config"], code_language=args["code_language"], @@ -144,11 +143,11 @@ class RuleStructuredOutputGenerateApi(Resource): parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() + _, current_tenant_id = current_account_with_tenant() - account = current_user try: structured_output = LLMGenerator.generate_structured_output( - tenant_id=account.current_tenant_id, + tenant_id=current_tenant_id, instruction=args["instruction"], model_config=args["model_config"], ) @@ -198,6 +197,7 @@ class InstructionGenerateApi(Resource): parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") parser.add_argument("ideal_output", type=str, required=False, default="", location="json") args = parser.parse_args() + _, current_tenant_id = current_account_with_tenant() code_template = ( Python3CodeProvider.get_default_code() if args["language"] == "python" @@ -222,21 +222,21 @@ class InstructionGenerateApi(Resource): match node_type: case "llm": return LLMGenerator.generate_rule_config( - current_user.current_tenant_id, + current_tenant_id, instruction=args["instruction"], model_config=args["model_config"], no_variable=True, ) case "agent": return LLMGenerator.generate_rule_config( - current_user.current_tenant_id, + current_tenant_id, instruction=args["instruction"], model_config=args["model_config"], no_variable=True, ) case "code": return LLMGenerator.generate_code( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, instruction=args["instruction"], model_config=args["model_config"], code_language=args["language"], @@ -245,7 +245,7 @@ class InstructionGenerateApi(Resource): return {"error": f"invalid node type: {node_type}"} if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow return LLMGenerator.instruction_modify_legacy( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, flow_id=args["flow_id"], current=args["current"], instruction=args["instruction"], @@ -254,7 +254,7 @@ class InstructionGenerateApi(Resource): ) if args["node_id"] != "" and args["current"] != "": # For workflow node return LLMGenerator.instruction_modify_workflow( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, flow_id=args["flow_id"], node_id=args["node_id"], current=args["current"], diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index fa6e3f8738..72ce8a7ddf 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -2,7 +2,6 @@ import json from typing import cast from flask import request -from flask_login import current_user from flask_restx import Resource, fields from werkzeug.exceptions import Forbidden @@ -15,8 +14,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_model_config_was_updated from extensions.ext_database import db from libs.datetime_utils import naive_utc_now -from libs.login import login_required -from models.account import Account +from libs.login import current_account_with_tenant, login_required from models.model import AppMode, AppModelConfig from services.app_model_config_service import AppModelConfigService @@ -54,16 +52,14 @@ class ModelConfigResource(Resource): @get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]) def post(self, app_model): """Modify app model config""" - if not isinstance(current_user, Account): - raise Forbidden() + current_user, current_tenant_id = current_account_with_tenant() if not current_user.has_edit_permission: raise Forbidden() - assert current_user.current_tenant_id is not None, "The tenant information should be loaded." # validate config model_configuration = AppModelConfigService.validate_configuration( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, config=cast(dict, request.json), app_mode=AppMode.value_of(app_model.mode), ) @@ -95,12 +91,12 @@ class ModelConfigResource(Resource): # get tool try: tool_runtime = ToolManager.get_agent_tool_runtime( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, app_id=app_model.id, agent_tool=agent_tool_entity, ) manager = ToolParameterConfigurationManager( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, tool_runtime=tool_runtime, provider_name=agent_tool_entity.provider_id, provider_type=agent_tool_entity.provider_type, @@ -134,7 +130,7 @@ class ModelConfigResource(Resource): else: try: tool_runtime = ToolManager.get_agent_tool_runtime( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, app_id=app_model.id, agent_tool=agent_tool_entity, ) @@ -142,7 +138,7 @@ class ModelConfigResource(Resource): continue manager = ToolParameterConfigurationManager( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, tool_runtime=tool_runtime, provider_name=agent_tool_entity.provider_id, provider_type=agent_tool_entity.provider_type, diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 95befc5df9..6537ea5a50 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -1,4 +1,3 @@ -from flask_login import current_user from flask_restx import Resource, fields, marshal_with, reqparse from werkzeug.exceptions import Forbidden, NotFound @@ -9,7 +8,7 @@ from controllers.console.wraps import account_initialization_required, setup_req from extensions.ext_database import db from fields.app_fields import app_site_fields from libs.datetime_utils import naive_utc_now -from libs.login import login_required +from libs.login import current_account_with_tenant, login_required from models import Account, Site @@ -76,9 +75,10 @@ class AppSite(Resource): @marshal_with(app_site_fields) def post(self, app_model): args = parse_app_site_args() + current_user, _ = current_account_with_tenant() # The role of the current user in the ta table must be editor, admin, or owner - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() site = db.session.query(Site).where(Site.app_id == app_model.id).first() @@ -131,6 +131,8 @@ class AppSiteAccessTokenReset(Resource): @marshal_with(app_site_fields) def post(self, app_model): # The role of the current user in the ta table must be admin or owner + current_user, _ = current_account_with_tenant() + if not current_user.is_admin_or_owner: raise Forbidden() diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py index 207303b212..d9ab7de29b 100644 --- a/api/controllers/console/auth/data_source_bearer_auth.py +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -1,10 +1,9 @@ -from flask_login import current_user from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden from controllers.console import console_ns from controllers.console.auth.error import ApiKeyAuthFailedError -from libs.login import login_required +from libs.login import current_account_with_tenant, login_required from services.auth.api_key_auth_service import ApiKeyAuthService from ..wraps import account_initialization_required, setup_required @@ -16,7 +15,8 @@ class ApiKeyAuthDataSource(Resource): @login_required @account_initialization_required def get(self): - data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id) + _, current_tenant_id = current_account_with_tenant() + data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id) if data_source_api_key_bindings: return { "sources": [ @@ -41,6 +41,8 @@ class ApiKeyAuthDataSourceBinding(Resource): @account_initialization_required def post(self): # The role of the current user in the table must be admin or owner + current_user, current_tenant_id = current_account_with_tenant() + if not current_user.is_admin_or_owner: raise Forbidden() parser = reqparse.RequestParser() @@ -50,7 +52,7 @@ class ApiKeyAuthDataSourceBinding(Resource): args = parser.parse_args() ApiKeyAuthService.validate_api_key_auth_args(args) try: - ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args) + ApiKeyAuthService.create_provider_auth(current_tenant_id, args) except Exception as e: raise ApiKeyAuthFailedError(str(e)) return {"result": "success"}, 200 @@ -63,9 +65,11 @@ class ApiKeyAuthDataSourceBindingDelete(Resource): @account_initialization_required def delete(self, binding_id): # The role of the current user in the table must be admin or owner + current_user, current_tenant_id = current_account_with_tenant() + if not current_user.is_admin_or_owner: raise Forbidden() - ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id) + ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id) return {"result": "success"}, 204 diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 6d9d675e87..058ef4408a 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -3,7 +3,6 @@ from collections.abc import Generator from typing import cast from flask import request -from flask_login import current_user from flask_restx import Resource, marshal_with, reqparse from sqlalchemy import select from sqlalchemy.orm import Session @@ -20,7 +19,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor from extensions.ext_database import db from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields from libs.datetime_utils import naive_utc_now -from libs.login import login_required +from libs.login import current_account_with_tenant, login_required from models import DataSourceOauthBinding, Document from services.dataset_service import DatasetService, DocumentService from services.datasource_provider_service import DatasourceProviderService @@ -37,10 +36,12 @@ class DataSourceApi(Resource): @account_initialization_required @marshal_with(integrate_list_fields) def get(self): + _, current_tenant_id = current_account_with_tenant() + # get workspace data source integrates data_source_integrates = db.session.scalars( select(DataSourceOauthBinding).where( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.tenant_id == current_tenant_id, DataSourceOauthBinding.disabled == False, ) ).all() @@ -120,13 +121,15 @@ class DataSourceNotionListApi(Resource): @account_initialization_required @marshal_with(integrate_notion_info_list_fields) def get(self): + current_user, current_tenant_id = current_account_with_tenant() + dataset_id = request.args.get("dataset_id", default=None, type=str) credential_id = request.args.get("credential_id", default=None, type=str) if not credential_id: raise ValueError("Credential id is required.") datasource_provider_service = DatasourceProviderService() credential = datasource_provider_service.get_datasource_credentials( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, credential_id=credential_id, provider="notion_datasource", plugin_id="langgenius/notion_datasource", @@ -146,7 +149,7 @@ class DataSourceNotionListApi(Resource): documents = session.scalars( select(Document).filter_by( dataset_id=dataset_id, - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, data_source_type="notion_import", enabled=True, ) @@ -161,7 +164,7 @@ class DataSourceNotionListApi(Resource): datasource_runtime = DatasourceManager.get_datasource_runtime( provider_id="langgenius/notion_datasource/notion_datasource", datasource_name="notion_datasource", - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, datasource_type=DatasourceProviderType.ONLINE_DOCUMENT, ) datasource_provider_service = DatasourceProviderService() @@ -210,12 +213,14 @@ class DataSourceNotionApi(Resource): @login_required @account_initialization_required def get(self, workspace_id, page_id, page_type): + _, current_tenant_id = current_account_with_tenant() + credential_id = request.args.get("credential_id", default=None, type=str) if not credential_id: raise ValueError("Credential id is required.") datasource_provider_service = DatasourceProviderService() credential = datasource_provider_service.get_datasource_credentials( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, credential_id=credential_id, provider="notion_datasource", plugin_id="langgenius/notion_datasource", @@ -229,7 +234,7 @@ class DataSourceNotionApi(Resource): notion_obj_id=page_id, notion_page_type=page_type, notion_access_token=credential.get("integration_secret"), - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, ) text_docs = extractor.extract() @@ -239,6 +244,8 @@ class DataSourceNotionApi(Resource): @login_required @account_initialization_required def post(self): + _, current_tenant_id = current_account_with_tenant() + parser = reqparse.RequestParser() parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json") parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json") @@ -263,7 +270,7 @@ class DataSourceNotionApi(Resource): "notion_workspace_id": workspace_id, "notion_obj_id": page["page_id"], "notion_page_type": page["type"], - "tenant_id": current_user.current_tenant_id, + "tenant_id": current_tenant_id, } ), document_model=args["doc_form"], @@ -271,7 +278,7 @@ class DataSourceNotionApi(Resource): extract_settings.append(extract_setting) indexing_runner = IndexingRunner() response = indexing_runner.indexing_estimate( - current_user.current_tenant_id, + current_tenant_id, extract_settings, args["process_rule"], args["doc_form"], diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index d6bd02483d..d4d484a2e2 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -1,7 +1,6 @@ import uuid from flask import request -from flask_login import current_user from flask_restx import Resource, marshal, reqparse from sqlalchemy import select from werkzeug.exceptions import Forbidden, NotFound @@ -27,7 +26,7 @@ from core.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.segment_fields import child_chunk_fields, segment_fields -from libs.login import login_required +from libs.login import current_account_with_tenant, login_required from models.dataset import ChildChunk, DocumentSegment from models.model import UploadFile from services.dataset_service import DatasetService, DocumentService, SegmentService @@ -43,6 +42,8 @@ class DatasetDocumentSegmentListApi(Resource): @login_required @account_initialization_required def get(self, dataset_id, document_id): + current_user, current_tenant_id = current_account_with_tenant() + dataset_id = str(dataset_id) document_id = str(document_id) dataset = DatasetService.get_dataset(dataset_id) @@ -79,7 +80,7 @@ class DatasetDocumentSegmentListApi(Resource): select(DocumentSegment) .where( DocumentSegment.document_id == str(document_id), - DocumentSegment.tenant_id == current_user.current_tenant_id, + DocumentSegment.tenant_id == current_tenant_id, ) .order_by(DocumentSegment.position.asc()) ) @@ -115,6 +116,8 @@ class DatasetDocumentSegmentListApi(Resource): @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") def delete(self, dataset_id, document_id): + current_user, _ = current_account_with_tenant() + # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) @@ -148,6 +151,8 @@ class DatasetDocumentSegmentApi(Resource): @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") def patch(self, dataset_id, document_id, action): + current_user, current_tenant_id = current_account_with_tenant() + dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: @@ -171,7 +176,7 @@ class DatasetDocumentSegmentApi(Resource): try: model_manager = ModelManager() model_manager.get_model_instance( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, @@ -204,6 +209,8 @@ class DatasetDocumentSegmentAddApi(Resource): @cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_rate_limit_check("knowledge") def post(self, dataset_id, document_id): + current_user, current_tenant_id = current_account_with_tenant() + # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) @@ -221,7 +228,7 @@ class DatasetDocumentSegmentAddApi(Resource): try: model_manager = ModelManager() model_manager.get_model_instance( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, @@ -255,6 +262,8 @@ class DatasetDocumentSegmentUpdateApi(Resource): @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") def patch(self, dataset_id, document_id, segment_id): + current_user, current_tenant_id = current_account_with_tenant() + # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) @@ -272,7 +281,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): try: model_manager = ModelManager() model_manager.get_model_instance( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, @@ -287,7 +296,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): segment_id = str(segment_id) segment = ( db.session.query(DocumentSegment) - .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) .first() ) if not segment: @@ -317,6 +326,8 @@ class DatasetDocumentSegmentUpdateApi(Resource): @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") def delete(self, dataset_id, document_id, segment_id): + current_user, current_tenant_id = current_account_with_tenant() + # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) @@ -333,7 +344,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): segment_id = str(segment_id) segment = ( db.session.query(DocumentSegment) - .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) .first() ) if not segment: @@ -361,6 +372,8 @@ class DatasetDocumentSegmentBatchImportApi(Resource): @cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_rate_limit_check("knowledge") def post(self, dataset_id, document_id): + current_user, current_tenant_id = current_account_with_tenant() + # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) @@ -396,7 +409,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource): upload_file_id, dataset_id, document_id, - current_user.current_tenant_id, + current_tenant_id, current_user.id, ) except Exception as e: @@ -427,6 +440,8 @@ class ChildChunkAddApi(Resource): @cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_rate_limit_check("knowledge") def post(self, dataset_id, document_id, segment_id): + current_user, current_tenant_id = current_account_with_tenant() + # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) @@ -441,7 +456,7 @@ class ChildChunkAddApi(Resource): segment_id = str(segment_id) segment = ( db.session.query(DocumentSegment) - .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) .first() ) if not segment: @@ -453,7 +468,7 @@ class ChildChunkAddApi(Resource): try: model_manager = ModelManager() model_manager.get_model_instance( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, @@ -483,6 +498,8 @@ class ChildChunkAddApi(Resource): @login_required @account_initialization_required def get(self, dataset_id, document_id, segment_id): + _, current_tenant_id = current_account_with_tenant() + # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) @@ -499,7 +516,7 @@ class ChildChunkAddApi(Resource): segment_id = str(segment_id) segment = ( db.session.query(DocumentSegment) - .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) .first() ) if not segment: @@ -530,6 +547,8 @@ class ChildChunkAddApi(Resource): @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") def patch(self, dataset_id, document_id, segment_id): + current_user, current_tenant_id = current_account_with_tenant() + # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) @@ -546,7 +565,7 @@ class ChildChunkAddApi(Resource): segment_id = str(segment_id) segment = ( db.session.query(DocumentSegment) - .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) .first() ) if not segment: @@ -580,6 +599,8 @@ class ChildChunkUpdateApi(Resource): @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") def delete(self, dataset_id, document_id, segment_id, child_chunk_id): + current_user, current_tenant_id = current_account_with_tenant() + # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) @@ -596,7 +617,7 @@ class ChildChunkUpdateApi(Resource): segment_id = str(segment_id) segment = ( db.session.query(DocumentSegment) - .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) .first() ) if not segment: @@ -607,7 +628,7 @@ class ChildChunkUpdateApi(Resource): db.session.query(ChildChunk) .where( ChildChunk.id == str(child_chunk_id), - ChildChunk.tenant_id == current_user.current_tenant_id, + ChildChunk.tenant_id == current_tenant_id, ChildChunk.segment_id == segment.id, ChildChunk.document_id == document_id, ) @@ -634,6 +655,8 @@ class ChildChunkUpdateApi(Resource): @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") def patch(self, dataset_id, document_id, segment_id, child_chunk_id): + current_user, current_tenant_id = current_account_with_tenant() + # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) @@ -650,7 +673,7 @@ class ChildChunkUpdateApi(Resource): segment_id = str(segment_id) segment = ( db.session.query(DocumentSegment) - .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) .first() ) if not segment: @@ -661,7 +684,7 @@ class ChildChunkUpdateApi(Resource): db.session.query(ChildChunk) .where( ChildChunk.id == str(child_chunk_id), - ChildChunk.tenant_id == current_user.current_tenant_id, + ChildChunk.tenant_id == current_tenant_id, ChildChunk.segment_id == segment.id, ChildChunk.document_id == document_id, ) diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index 53b5a0d965..4c34df7eff 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -1,5 +1,4 @@ from flask import make_response, redirect, request -from flask_login import current_user from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden, NotFound @@ -13,7 +12,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.impl.oauth import OAuthHandler from libs.helper import StrLen -from libs.login import login_required +from libs.login import current_account_with_tenant, login_required from models.provider_ids import DatasourceProviderID from services.datasource_provider_service import DatasourceProviderService from services.plugin.oauth_service import OAuthProxyService @@ -25,9 +24,10 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource): @login_required @account_initialization_required def get(self, provider_id: str): - user = current_user - tenant_id = user.current_tenant_id - if not current_user.is_editor: + current_user, current_tenant_id = current_account_with_tenant() + + tenant_id = current_tenant_id + if not current_user.has_edit_permission: raise Forbidden() credential_id = request.args.get("credential_id") @@ -52,7 +52,7 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource): redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback" authorization_url_response = oauth_handler.get_authorization_url( tenant_id=tenant_id, - user_id=user.id, + user_id=current_user.id, plugin_id=plugin_id, provider=provider_name, redirect_uri=redirect_uri, @@ -131,7 +131,9 @@ class DatasourceAuth(Resource): @login_required @account_initialization_required def post(self, provider_id: str): - if not current_user.is_editor: + current_user, current_tenant_id = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() @@ -145,7 +147,7 @@ class DatasourceAuth(Resource): try: datasource_provider_service.add_datasource_api_key_provider( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider_id=datasource_provider_id, credentials=args["credentials"], name=args["name"], @@ -160,8 +162,10 @@ class DatasourceAuth(Resource): def get(self, provider_id: str): datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_service = DatasourceProviderService() + _, current_tenant_id = current_account_with_tenant() + datasources = datasource_provider_service.list_datasource_credentials( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id, ) @@ -174,17 +178,19 @@ class DatasourceAuthDeleteApi(Resource): @login_required @account_initialization_required def post(self, provider_id: str): + current_user, current_tenant_id = current_account_with_tenant() + datasource_provider_id = DatasourceProviderID(provider_id) plugin_id = datasource_provider_id.plugin_id provider_name = datasource_provider_id.provider_name - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") args = parser.parse_args() datasource_provider_service = DatasourceProviderService() datasource_provider_service.remove_datasource_credentials( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, auth_id=args["credential_id"], provider=provider_name, plugin_id=plugin_id, @@ -198,17 +204,19 @@ class DatasourceAuthUpdateApi(Resource): @login_required @account_initialization_required def post(self, provider_id: str): + current_user, current_tenant_id = current_account_with_tenant() + datasource_provider_id = DatasourceProviderID(provider_id) parser = reqparse.RequestParser() parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") parser.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json") parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") args = parser.parse_args() - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() datasource_provider_service = DatasourceProviderService() datasource_provider_service.update_datasource_credentials( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, auth_id=args["credential_id"], provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id, @@ -224,10 +232,10 @@ class DatasourceAuthListApi(Resource): @login_required @account_initialization_required def get(self): + _, current_tenant_id = current_account_with_tenant() + datasource_provider_service = DatasourceProviderService() - datasources = datasource_provider_service.get_all_datasource_credentials( - tenant_id=current_user.current_tenant_id - ) + datasources = datasource_provider_service.get_all_datasource_credentials(tenant_id=current_tenant_id) return {"result": jsonable_encoder(datasources)}, 200 @@ -237,10 +245,10 @@ class DatasourceHardCodeAuthListApi(Resource): @login_required @account_initialization_required def get(self): + _, current_tenant_id = current_account_with_tenant() + datasource_provider_service = DatasourceProviderService() - datasources = datasource_provider_service.get_hard_code_datasource_credentials( - tenant_id=current_user.current_tenant_id - ) + datasources = datasource_provider_service.get_hard_code_datasource_credentials(tenant_id=current_tenant_id) return {"result": jsonable_encoder(datasources)}, 200 @@ -250,7 +258,9 @@ class DatasourceAuthOauthCustomClient(Resource): @login_required @account_initialization_required def post(self, provider_id: str): - if not current_user.is_editor: + current_user, current_tenant_id = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json") @@ -259,7 +269,7 @@ class DatasourceAuthOauthCustomClient(Resource): datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_service = DatasourceProviderService() datasource_provider_service.setup_oauth_custom_client_params( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, datasource_provider_id=datasource_provider_id, client_params=args.get("client_params", {}), enabled=args.get("enable_oauth_custom_client", False), @@ -270,10 +280,12 @@ class DatasourceAuthOauthCustomClient(Resource): @login_required @account_initialization_required def delete(self, provider_id: str): + _, current_tenant_id = current_account_with_tenant() + datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_service = DatasourceProviderService() datasource_provider_service.remove_oauth_custom_client_params( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, datasource_provider_id=datasource_provider_id, ) return {"result": "success"}, 200 @@ -285,7 +297,9 @@ class DatasourceAuthDefaultApi(Resource): @login_required @account_initialization_required def post(self, provider_id: str): - if not current_user.is_editor: + current_user, current_tenant_id = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("id", type=str, required=True, nullable=False, location="json") @@ -293,7 +307,7 @@ class DatasourceAuthDefaultApi(Resource): datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_service = DatasourceProviderService() datasource_provider_service.set_default_datasource_provider( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, datasource_provider_id=datasource_provider_id, credential_id=args["id"], ) @@ -306,7 +320,9 @@ class DatasourceUpdateProviderNameApi(Resource): @login_required @account_initialization_required def post(self, provider_id: str): - if not current_user.is_editor: + current_user, current_tenant_id = current_account_with_tenant() + + if not current_user.has_edit_permission: raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json") @@ -315,7 +331,7 @@ class DatasourceUpdateProviderNameApi(Resource): datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_service = DatasourceProviderService() datasource_provider_service.update_datasource_provider_name( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, datasource_provider_id=datasource_provider_id, name=args["name"], credential_id=args["credential_id"], diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py index 404aa42073..b394887783 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py @@ -1,4 +1,3 @@ -from flask_login import current_user from flask_restx import Resource, marshal, reqparse from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden @@ -13,7 +12,7 @@ from controllers.console.wraps import ( ) from extensions.ext_database import db from fields.dataset_fields import dataset_detail_fields -from libs.login import login_required +from libs.login import current_account_with_tenant, login_required from models.dataset import DatasetPermissionEnum from services.dataset_service import DatasetPermissionService, DatasetService from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity @@ -38,7 +37,7 @@ class CreateRagPipelineDatasetApi(Resource): ) args = parser.parse_args() - + current_user, current_tenant_id = current_account_with_tenant() # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator if not current_user.is_dataset_editor: raise Forbidden() @@ -58,12 +57,12 @@ class CreateRagPipelineDatasetApi(Resource): with Session(db.engine) as session: rag_pipeline_dsl_service = RagPipelineDslService(session) import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity, ) if rag_pipeline_dataset_create_entity.permission == "partial_members": DatasetPermissionService.update_partial_member_list( - current_user.current_tenant_id, + current_tenant_id, import_info["dataset_id"], rag_pipeline_dataset_create_entity.partial_member_list, ) @@ -81,10 +80,12 @@ class CreateEmptyRagPipelineDatasetApi(Resource): @cloud_edition_billing_rate_limit_check("knowledge") def post(self): # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator + current_user, current_tenant_id = current_account_with_tenant() + if not current_user.is_dataset_editor: raise Forbidden() dataset = DatasetService.create_empty_rag_pipeline_dataset( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity( name="", description="", diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index c86c243c9b..7ead93a1b6 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -12,8 +12,8 @@ from controllers.console.wraps import account_initialization_required, cloud_edi from extensions.ext_database import db from fields.installed_app_fields import installed_app_list_fields from libs.datetime_utils import naive_utc_now -from libs.login import current_user, login_required -from models import Account, App, InstalledApp, RecommendedApp +from libs.login import current_account_with_tenant, login_required +from models import App, InstalledApp, RecommendedApp from services.account_service import TenantService from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService @@ -29,9 +29,7 @@ class InstalledAppsListApi(Resource): @marshal_with(installed_app_list_fields) def get(self): app_id = request.args.get("app_id", default=None, type=str) - if not isinstance(current_user, Account): - raise ValueError("current_user must be an Account instance") - current_tenant_id = current_user.current_tenant_id + current_user, current_tenant_id = current_account_with_tenant() if app_id: installed_apps = db.session.scalars( @@ -121,9 +119,8 @@ class InstalledAppsListApi(Resource): if recommended_app is None: raise NotFound("App not found") - if not isinstance(current_user, Account): - raise ValueError("current_user must be an Account instance") - current_tenant_id = current_user.current_tenant_id + _, current_tenant_id = current_account_with_tenant() + app = db.session.query(App).where(App.id == args["app_id"]).first() if app is None: @@ -163,9 +160,8 @@ class InstalledAppApi(InstalledAppResource): """ def delete(self, installed_app): - if not isinstance(current_user, Account): - raise ValueError("current_user must be an Account instance") - if installed_app.app_owner_tenant_id == current_user.current_tenant_id: + _, current_tenant_id = current_account_with_tenant() + if installed_app.app_owner_tenant_id == current_tenant_id: raise BadRequest("You can't uninstall an app owned by the current tenant") db.session.delete(installed_app) diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index c6b3cf7515..dd618307e9 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -4,7 +4,7 @@ from constants import HIDDEN_VALUE from controllers.console import api, console_ns from controllers.console.wraps import account_initialization_required, setup_required from fields.api_based_extension_fields import api_based_extension_fields -from libs.login import current_user, login_required +from libs.login import current_account_with_tenant, current_user, login_required from models.account import Account from models.api_based_extension import APIBasedExtension from services.api_based_extension_service import APIBasedExtensionService @@ -47,9 +47,7 @@ class APIBasedExtensionAPI(Resource): @account_initialization_required @marshal_with(api_based_extension_fields) def get(self): - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() return APIBasedExtensionService.get_all_by_tenant_id(tenant_id) @api.doc("create_api_based_extension") @@ -77,9 +75,10 @@ class APIBasedExtensionAPI(Resource): parser.add_argument("api_endpoint", type=str, required=True, location="json") parser.add_argument("api_key", type=str, required=True, location="json") args = parser.parse_args() + _, current_tenant_id = current_account_with_tenant() extension_data = APIBasedExtension( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, name=args["name"], api_endpoint=args["api_endpoint"], api_key=args["api_key"], @@ -102,7 +101,7 @@ class APIBasedExtensionDetailAPI(Resource): assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None api_based_extension_id = str(id) - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) @@ -128,9 +127,9 @@ class APIBasedExtensionDetailAPI(Resource): assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None api_based_extension_id = str(id) - tenant_id = current_user.current_tenant_id + _, current_tenant_id = current_account_with_tenant() - extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) + extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id) parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, location="json") @@ -157,9 +156,9 @@ class APIBasedExtensionDetailAPI(Resource): assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None api_based_extension_id = str(id) - tenant_id = current_user.current_tenant_id + _, current_tenant_id = current_account_with_tenant() - extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) + extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id) APIBasedExtensionService.delete(extension_data_from_db) diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 80847b8fef..39bcf3424c 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -1,7 +1,6 @@ from flask_restx import Resource, fields -from libs.login import current_user, login_required -from models.account import Account +from libs.login import current_account_with_tenant, login_required from services.feature_service import FeatureService from . import api, console_ns @@ -23,9 +22,9 @@ class FeatureApi(Resource): @cloud_utm_record def get(self): """Get feature configuration for current tenant""" - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None - return FeatureService.get_features(current_user.current_tenant_id).model_dump() + _, current_tenant_id = current_account_with_tenant() + + return FeatureService.get_features(current_tenant_id).model_dump() @console_ns.route("/system-features") diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py index 34f186e2f0..2b63f6febc 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -108,4 +108,4 @@ class FileSupportTypeApi(Resource): @login_required @account_initialization_required def get(self): - return {"allowed_extensions": DOCUMENT_EXTENSIONS} + return {"allowed_extensions": list(DOCUMENT_EXTENSIONS)} diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index 4d4bb5d779..b053f222df 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -14,8 +14,7 @@ from core.file import helpers as file_helpers from core.helper import ssrf_proxy from extensions.ext_database import db from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields -from libs.login import current_user -from models.account import Account +from libs.login import current_account_with_tenant from services.file_service import FileService from . import console_ns @@ -64,8 +63,7 @@ class RemoteFileUploadApi(Resource): content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content try: - assert isinstance(current_user, Account) - user = current_user + user, _ = current_account_with_tenant() upload_file = FileService(db.engine).upload_file( filename=file_info.filename, content=content, diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index 782bd72565..b31011b4a3 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -5,18 +5,10 @@ from controllers.console import api, console_ns from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.impl.exc import PluginPermissionDeniedError -from libs.login import current_user, login_required -from models.account import Account +from libs.login import current_account_with_tenant, login_required from services.plugin.endpoint_service import EndpointService -def _current_account_with_tenant() -> tuple[Account, str]: - assert isinstance(current_user, Account) - tenant_id = current_user.current_tenant_id - assert tenant_id is not None - return current_user, tenant_id - - @console_ns.route("/workspaces/current/endpoints/create") class EndpointCreateApi(Resource): @api.doc("create_endpoint") @@ -41,7 +33,7 @@ class EndpointCreateApi(Resource): @login_required @account_initialization_required def post(self): - user, tenant_id = _current_account_with_tenant() + user, tenant_id = current_account_with_tenant() if not user.is_admin_or_owner: raise Forbidden() @@ -87,7 +79,7 @@ class EndpointListApi(Resource): @login_required @account_initialization_required def get(self): - user, tenant_id = _current_account_with_tenant() + user, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("page", type=int, required=True, location="args") @@ -130,7 +122,7 @@ class EndpointListForSinglePluginApi(Resource): @login_required @account_initialization_required def get(self): - user, tenant_id = _current_account_with_tenant() + user, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("page", type=int, required=True, location="args") @@ -172,7 +164,7 @@ class EndpointDeleteApi(Resource): @login_required @account_initialization_required def post(self): - user, tenant_id = _current_account_with_tenant() + user, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("endpoint_id", type=str, required=True) @@ -212,7 +204,7 @@ class EndpointUpdateApi(Resource): @login_required @account_initialization_required def post(self): - user, tenant_id = _current_account_with_tenant() + user, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("endpoint_id", type=str, required=True) @@ -255,7 +247,7 @@ class EndpointEnableApi(Resource): @login_required @account_initialization_required def post(self): - user, tenant_id = _current_account_with_tenant() + user, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("endpoint_id", type=str, required=True) @@ -288,7 +280,7 @@ class EndpointDisableApi(Resource): @login_required @account_initialization_required def post(self): - user, tenant_id = _current_account_with_tenant() + user, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("endpoint_id", type=str, required=True) diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index dd6a878d87..4f080708cc 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -25,7 +25,7 @@ from controllers.console.wraps import ( from extensions.ext_database import db from fields.member_fields import account_with_role_list_fields from libs.helper import extract_remote_ip -from libs.login import current_user, login_required +from libs.login import current_account_with_tenant, login_required from models.account import Account, TenantAccountRole from services.account_service import AccountService, RegisterService, TenantService from services.errors.account import AccountAlreadyInTenantError @@ -41,8 +41,7 @@ class MemberListApi(Resource): @account_initialization_required @marshal_with(account_with_role_list_fields) def get(self): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, _ = current_account_with_tenant() if not current_user.current_tenant: raise ValueError("No current tenant") members = TenantService.get_tenant_members(current_user.current_tenant) @@ -69,9 +68,7 @@ class MemberInviteEmailApi(Resource): interface_language = args["language"] if not TenantAccountRole.is_non_owner_role(invitee_role): return {"code": "invalid-role", "message": "Invalid role"}, 400 - - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, _ = current_account_with_tenant() inviter = current_user if not inviter.current_tenant: raise ValueError("No current tenant") @@ -120,8 +117,7 @@ class MemberCancelInviteApi(Resource): @login_required @account_initialization_required def delete(self, member_id): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, _ = current_account_with_tenant() if not current_user.current_tenant: raise ValueError("No current tenant") member = db.session.query(Account).where(Account.id == str(member_id)).first() @@ -160,9 +156,7 @@ class MemberUpdateRoleApi(Resource): if not TenantAccountRole.is_valid_role(new_role): return {"code": "invalid-role", "message": "Invalid role"}, 400 - - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, _ = current_account_with_tenant() if not current_user.current_tenant: raise ValueError("No current tenant") member = db.session.get(Account, str(member_id)) @@ -189,8 +183,7 @@ class DatasetOperatorMemberListApi(Resource): @account_initialization_required @marshal_with(account_with_role_list_fields) def get(self): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, _ = current_account_with_tenant() if not current_user.current_tenant: raise ValueError("No current tenant") members = TenantService.get_dataset_operator_members(current_user.current_tenant) @@ -212,10 +205,8 @@ class SendOwnerTransferEmailApi(Resource): ip_address = extract_remote_ip(request) if AccountService.is_email_send_ip_limit(ip_address): raise EmailSendIpLimitError() - + current_user, _ = current_account_with_tenant() # check if the current user is the owner of the workspace - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") if not current_user.current_tenant: raise ValueError("No current tenant") if not TenantService.is_owner(current_user, current_user.current_tenant): @@ -250,8 +241,7 @@ class OwnerTransferCheckApi(Resource): parser.add_argument("token", type=str, required=True, nullable=False, location="json") args = parser.parse_args() # check if the current user is the owner of the workspace - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, _ = current_account_with_tenant() if not current_user.current_tenant: raise ValueError("No current tenant") if not TenantService.is_owner(current_user, current_user.current_tenant): @@ -296,8 +286,7 @@ class OwnerTransfer(Resource): args = parser.parse_args() # check if the current user is the owner of the workspace - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, _ = current_account_with_tenant() if not current_user.current_tenant: raise ValueError("No current tenant") if not TenantService.is_owner(current_user, current_user.current_tenant): diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index 7012580362..acdd467b30 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -1,7 +1,6 @@ import io from flask import send_file -from flask_login import current_user from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden @@ -11,8 +10,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder from libs.helper import StrLen, uuid_value -from libs.login import login_required -from models.account import Account +from libs.login import current_account_with_tenant, login_required from services.billing_service import BillingService from services.model_provider_service import ModelProviderService @@ -23,11 +21,8 @@ class ModelProviderListApi(Resource): @login_required @account_initialization_required def get(self): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") - if not current_user.current_tenant_id: - raise ValueError("No current tenant") - tenant_id = current_user.current_tenant_id + _, current_tenant_id = current_account_with_tenant() + tenant_id = current_tenant_id parser = reqparse.RequestParser() parser.add_argument( @@ -52,11 +47,8 @@ class ModelProviderCredentialApi(Resource): @login_required @account_initialization_required def get(self, provider: str): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") - if not current_user.current_tenant_id: - raise ValueError("No current tenant") - tenant_id = current_user.current_tenant_id + _, current_tenant_id = current_account_with_tenant() + tenant_id = current_tenant_id # if credential_id is not provided, return current used credential parser = reqparse.RequestParser() parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args") @@ -73,8 +65,7 @@ class ModelProviderCredentialApi(Resource): @login_required @account_initialization_required def post(self, provider: str): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, current_tenant_id = current_account_with_tenant() if not current_user.is_admin_or_owner: raise Forbidden() @@ -85,11 +76,9 @@ class ModelProviderCredentialApi(Resource): model_provider_service = ModelProviderService() - if not current_user.current_tenant_id: - raise ValueError("No current tenant") try: model_provider_service.create_provider_credential( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider=provider, credentials=args["credentials"], credential_name=args["name"], @@ -103,8 +92,7 @@ class ModelProviderCredentialApi(Resource): @login_required @account_initialization_required def put(self, provider: str): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, current_tenant_id = current_account_with_tenant() if not current_user.is_admin_or_owner: raise Forbidden() @@ -116,11 +104,9 @@ class ModelProviderCredentialApi(Resource): model_provider_service = ModelProviderService() - if not current_user.current_tenant_id: - raise ValueError("No current tenant") try: model_provider_service.update_provider_credential( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider=provider, credentials=args["credentials"], credential_id=args["credential_id"], @@ -135,19 +121,16 @@ class ModelProviderCredentialApi(Resource): @login_required @account_initialization_required def delete(self, provider: str): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, current_tenant_id = current_account_with_tenant() if not current_user.is_admin_or_owner: raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") args = parser.parse_args() - if not current_user.current_tenant_id: - raise ValueError("No current tenant") model_provider_service = ModelProviderService() model_provider_service.remove_provider_credential( - tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"] + tenant_id=current_tenant_id, provider=provider, credential_id=args["credential_id"] ) return {"result": "success"}, 204 @@ -159,19 +142,16 @@ class ModelProviderCredentialSwitchApi(Resource): @login_required @account_initialization_required def post(self, provider: str): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, current_tenant_id = current_account_with_tenant() if not current_user.is_admin_or_owner: raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") args = parser.parse_args() - if not current_user.current_tenant_id: - raise ValueError("No current tenant") service = ModelProviderService() service.switch_active_provider_credential( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider=provider, credential_id=args["credential_id"], ) @@ -184,15 +164,12 @@ class ModelProviderValidateApi(Resource): @login_required @account_initialization_required def post(self, provider: str): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + _, current_tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() - if not current_user.current_tenant_id: - raise ValueError("No current tenant") - tenant_id = current_user.current_tenant_id + tenant_id = current_tenant_id model_provider_service = ModelProviderService() @@ -240,14 +217,11 @@ class PreferredProviderTypeUpdateApi(Resource): @login_required @account_initialization_required def post(self, provider: str): - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, current_tenant_id = current_account_with_tenant() if not current_user.is_admin_or_owner: raise Forbidden() - if not current_user.current_tenant_id: - raise ValueError("No current tenant") - tenant_id = current_user.current_tenant_id + tenant_id = current_tenant_id parser = reqparse.RequestParser() parser.add_argument( @@ -276,14 +250,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource): def get(self, provider: str): if provider != "anthropic": raise ValueError(f"provider name {provider} is invalid") - if not isinstance(current_user, Account): - raise ValueError("Invalid user account") + current_user, current_tenant_id = current_account_with_tenant() BillingService.is_tenant_owner_or_admin(current_user) - if not current_user.current_tenant_id: - raise ValueError("No current tenant") data = BillingService.get_model_provider_payment_link( provider_name=provider, - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, account_id=current_user.id, prefilled_email=current_user.email, ) diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index d38bb16ea7..d5d1aed00e 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -1,6 +1,5 @@ import logging -from flask_login import current_user from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden @@ -10,7 +9,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder from libs.helper import StrLen, uuid_value -from libs.login import login_required +from libs.login import current_account_with_tenant, login_required from services.model_load_balancing_service import ModelLoadBalancingService from services.model_provider_service import ModelProviderService @@ -23,6 +22,8 @@ class DefaultModelApi(Resource): @login_required @account_initialization_required def get(self): + _, tenant_id = current_account_with_tenant() + parser = reqparse.RequestParser() parser.add_argument( "model_type", @@ -34,8 +35,6 @@ class DefaultModelApi(Resource): ) args = parser.parse_args() - tenant_id = current_user.current_tenant_id - model_provider_service = ModelProviderService() default_model_entity = model_provider_service.get_default_model_of_model_type( tenant_id=tenant_id, model_type=args["model_type"] @@ -47,15 +46,14 @@ class DefaultModelApi(Resource): @login_required @account_initialization_required def post(self): + current_user, tenant_id = current_account_with_tenant() + if not current_user.is_admin_or_owner: raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json") args = parser.parse_args() - - tenant_id = current_user.current_tenant_id - model_provider_service = ModelProviderService() model_settings = args["model_settings"] for model_setting in model_settings: @@ -92,7 +90,7 @@ class ModelProviderModelApi(Resource): @login_required @account_initialization_required def get(self, provider): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() model_provider_service = ModelProviderService() models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider) @@ -104,11 +102,11 @@ class ModelProviderModelApi(Resource): @account_initialization_required def post(self, provider: str): # To save the model's load balance configs + current_user, tenant_id = current_account_with_tenant() + if not current_user.is_admin_or_owner: raise Forbidden() - tenant_id = current_user.current_tenant_id - parser = reqparse.RequestParser() parser.add_argument("model", type=str, required=True, nullable=False, location="json") parser.add_argument( @@ -129,7 +127,7 @@ class ModelProviderModelApi(Resource): raise ValueError("credential_id is required when configuring a custom-model") service = ModelProviderService() service.switch_active_custom_model_credential( - tenant_id=current_user.current_tenant_id, + tenant_id=tenant_id, provider=provider, model_type=args["model_type"], model=args["model"], @@ -164,11 +162,11 @@ class ModelProviderModelApi(Resource): @login_required @account_initialization_required def delete(self, provider: str): + current_user, tenant_id = current_account_with_tenant() + if not current_user.is_admin_or_owner: raise Forbidden() - tenant_id = current_user.current_tenant_id - parser = reqparse.RequestParser() parser.add_argument("model", type=str, required=True, nullable=False, location="json") parser.add_argument( @@ -195,7 +193,7 @@ class ModelProviderModelCredentialApi(Resource): @login_required @account_initialization_required def get(self, provider: str): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("model", type=str, required=True, nullable=False, location="args") @@ -257,6 +255,8 @@ class ModelProviderModelCredentialApi(Resource): @login_required @account_initialization_required def post(self, provider: str): + current_user, tenant_id = current_account_with_tenant() + if not current_user.is_admin_or_owner: raise Forbidden() @@ -274,7 +274,6 @@ class ModelProviderModelCredentialApi(Resource): parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() - tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() try: @@ -301,6 +300,8 @@ class ModelProviderModelCredentialApi(Resource): @login_required @account_initialization_required def put(self, provider: str): + current_user, current_tenant_id = current_account_with_tenant() + if not current_user.is_admin_or_owner: raise Forbidden() @@ -323,7 +324,7 @@ class ModelProviderModelCredentialApi(Resource): try: model_provider_service.update_model_credential( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider=provider, model_type=args["model_type"], model=args["model"], @@ -340,6 +341,8 @@ class ModelProviderModelCredentialApi(Resource): @login_required @account_initialization_required def delete(self, provider: str): + current_user, current_tenant_id = current_account_with_tenant() + if not current_user.is_admin_or_owner: raise Forbidden() parser = reqparse.RequestParser() @@ -357,7 +360,7 @@ class ModelProviderModelCredentialApi(Resource): model_provider_service = ModelProviderService() model_provider_service.remove_model_credential( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider=provider, model_type=args["model_type"], model=args["model"], @@ -373,6 +376,8 @@ class ModelProviderModelCredentialSwitchApi(Resource): @login_required @account_initialization_required def post(self, provider: str): + current_user, current_tenant_id = current_account_with_tenant() + if not current_user.is_admin_or_owner: raise Forbidden() parser = reqparse.RequestParser() @@ -390,7 +395,7 @@ class ModelProviderModelCredentialSwitchApi(Resource): service = ModelProviderService() service.add_model_credential_to_model_list( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider=provider, model_type=args["model_type"], model=args["model"], @@ -407,7 +412,7 @@ class ModelProviderModelEnableApi(Resource): @login_required @account_initialization_required def patch(self, provider: str): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("model", type=str, required=True, nullable=False, location="json") @@ -437,7 +442,7 @@ class ModelProviderModelDisableApi(Resource): @login_required @account_initialization_required def patch(self, provider: str): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("model", type=str, required=True, nullable=False, location="json") @@ -465,7 +470,7 @@ class ModelProviderModelValidateApi(Resource): @login_required @account_initialization_required def post(self, provider: str): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("model", type=str, required=True, nullable=False, location="json") @@ -514,8 +519,7 @@ class ModelProviderModelParameterRuleApi(Resource): parser = reqparse.RequestParser() parser.add_argument("model", type=str, required=True, nullable=False, location="args") args = parser.parse_args() - - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() model_provider_service = ModelProviderService() parameter_rules = model_provider_service.get_model_parameter_rules( @@ -531,8 +535,7 @@ class ModelProviderAvailableModelApi(Resource): @login_required @account_initialization_required def get(self, model_type): - tenant_id = current_user.current_tenant_id - + _, tenant_id = current_account_with_tenant() model_provider_service = ModelProviderService() models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type) diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index 7c70fb8aa0..ed5426376f 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -1,7 +1,6 @@ import io from flask import request, send_file -from flask_login import current_user from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden @@ -11,7 +10,7 @@ from controllers.console.workspace import plugin_permission_required from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.impl.exc import PluginDaemonClientSideError -from libs.login import login_required +from libs.login import current_account_with_tenant, login_required from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService from services.plugin.plugin_parameter_service import PluginParameterService @@ -26,7 +25,7 @@ class PluginDebuggingKeyApi(Resource): @account_initialization_required @plugin_permission_required(debug_required=True) def get(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() try: return { @@ -44,7 +43,7 @@ class PluginListApi(Resource): @login_required @account_initialization_required def get(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("page", type=int, required=False, location="args", default=1) parser.add_argument("page_size", type=int, required=False, location="args", default=256) @@ -81,7 +80,7 @@ class PluginListInstallationsFromIdsApi(Resource): @login_required @account_initialization_required def post(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("plugin_ids", type=list, required=True, location="json") @@ -120,7 +119,7 @@ class PluginUploadFromPkgApi(Resource): @account_initialization_required @plugin_permission_required(install_required=True) def post(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() file = request.files["pkg"] @@ -144,7 +143,7 @@ class PluginUploadFromGithubApi(Resource): @account_initialization_required @plugin_permission_required(install_required=True) def post(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("repo", type=str, required=True, location="json") @@ -167,7 +166,7 @@ class PluginUploadFromBundleApi(Resource): @account_initialization_required @plugin_permission_required(install_required=True) def post(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() file = request.files["bundle"] @@ -191,7 +190,7 @@ class PluginInstallFromPkgApi(Resource): @account_initialization_required @plugin_permission_required(install_required=True) def post(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json") @@ -217,7 +216,7 @@ class PluginInstallFromGithubApi(Resource): @account_initialization_required @plugin_permission_required(install_required=True) def post(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("repo", type=str, required=True, location="json") @@ -247,7 +246,7 @@ class PluginInstallFromMarketplaceApi(Resource): @account_initialization_required @plugin_permission_required(install_required=True) def post(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json") @@ -273,7 +272,7 @@ class PluginFetchMarketplacePkgApi(Resource): @account_initialization_required @plugin_permission_required(install_required=True) def get(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args") @@ -299,7 +298,7 @@ class PluginFetchManifestApi(Resource): @account_initialization_required @plugin_permission_required(install_required=True) def get(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args") @@ -324,7 +323,7 @@ class PluginFetchInstallTasksApi(Resource): @account_initialization_required @plugin_permission_required(install_required=True) def get(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("page", type=int, required=True, location="args") @@ -346,7 +345,7 @@ class PluginFetchInstallTaskApi(Resource): @account_initialization_required @plugin_permission_required(install_required=True) def get(self, task_id: str): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() try: return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)}) @@ -361,7 +360,7 @@ class PluginDeleteInstallTaskApi(Resource): @account_initialization_required @plugin_permission_required(install_required=True) def post(self, task_id: str): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() try: return {"success": PluginService.delete_install_task(tenant_id, task_id)} @@ -376,7 +375,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource): @account_initialization_required @plugin_permission_required(install_required=True) def post(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() try: return {"success": PluginService.delete_all_install_task_items(tenant_id)} @@ -391,7 +390,7 @@ class PluginDeleteInstallTaskItemApi(Resource): @account_initialization_required @plugin_permission_required(install_required=True) def post(self, task_id: str, identifier: str): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() try: return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)} @@ -406,7 +405,7 @@ class PluginUpgradeFromMarketplaceApi(Resource): @account_initialization_required @plugin_permission_required(install_required=True) def post(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json") @@ -430,7 +429,7 @@ class PluginUpgradeFromGithubApi(Resource): @account_initialization_required @plugin_permission_required(install_required=True) def post(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json") @@ -466,7 +465,7 @@ class PluginUninstallApi(Resource): req.add_argument("plugin_installation_id", type=str, required=True, location="json") args = req.parse_args() - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() try: return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])} @@ -480,6 +479,7 @@ class PluginChangePermissionApi(Resource): @login_required @account_initialization_required def post(self): + current_user, current_tenant_id = current_account_with_tenant() user = current_user if not user.is_admin_or_owner: raise Forbidden() @@ -492,7 +492,7 @@ class PluginChangePermissionApi(Resource): install_permission = TenantPluginPermission.InstallPermission(args["install_permission"]) debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"]) - tenant_id = user.current_tenant_id + tenant_id = current_tenant_id return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)} @@ -503,7 +503,7 @@ class PluginFetchPermissionApi(Resource): @login_required @account_initialization_required def get(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() permission = PluginPermissionService.get_permission(tenant_id) if not permission: @@ -529,10 +529,10 @@ class PluginFetchDynamicSelectOptionsApi(Resource): @account_initialization_required def get(self): # check if the user is admin or owner + current_user, tenant_id = current_account_with_tenant() if not current_user.is_admin_or_owner: raise Forbidden() - tenant_id = current_user.current_tenant_id user_id = current_user.id parser = reqparse.RequestParser() @@ -565,7 +565,7 @@ class PluginChangePreferencesApi(Resource): @login_required @account_initialization_required def post(self): - user = current_user + user, tenant_id = current_account_with_tenant() if not user.is_admin_or_owner: raise Forbidden() @@ -574,8 +574,6 @@ class PluginChangePreferencesApi(Resource): req.add_argument("auto_upgrade", type=dict, required=True, location="json") args = req.parse_args() - tenant_id = user.current_tenant_id - permission = args["permission"] install_permission = TenantPluginPermission.InstallPermission(permission.get("install_permission", "everyone")) @@ -621,7 +619,7 @@ class PluginFetchPreferencesApi(Resource): @login_required @account_initialization_required def get(self): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() permission = PluginPermissionService.get_permission(tenant_id) permission_dict = { @@ -661,7 +659,7 @@ class PluginAutoUpgradeExcludePluginApi(Resource): @account_initialization_required def post(self): # exclude one single plugin - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() req = reqparse.RequestParser() req.add_argument("plugin_id", type=str, required=True, location="json") diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 34474ef506..5d91b37702 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -2,7 +2,6 @@ import io from urllib.parse import urlparse from flask import make_response, redirect, request, send_file -from flask_login import current_user from flask_restx import ( Resource, reqparse, @@ -26,7 +25,7 @@ from core.plugin.impl.oauth import OAuthHandler from core.tools.entities.tool_entities import CredentialType from extensions.ext_database import db from libs.helper import StrLen, alphanumeric, uuid_value -from libs.login import login_required +from libs.login import current_account_with_tenant, login_required from models.provider_ids import ToolProviderID from services.plugin.oauth_service import OAuthProxyService from services.tools.api_tools_manage_service import ApiToolManageService @@ -55,10 +54,9 @@ class ToolProviderListApi(Resource): @login_required @account_initialization_required def get(self): - user = current_user + user, tenant_id = current_account_with_tenant() user_id = user.id - tenant_id = user.current_tenant_id req = reqparse.RequestParser() req.add_argument( @@ -80,9 +78,7 @@ class ToolBuiltinProviderListToolsApi(Resource): @login_required @account_initialization_required def get(self, provider): - user = current_user - - tenant_id = user.current_tenant_id + _, tenant_id = current_account_with_tenant() return jsonable_encoder( BuiltinToolManageService.list_builtin_tool_provider_tools( @@ -98,9 +94,7 @@ class ToolBuiltinProviderInfoApi(Resource): @login_required @account_initialization_required def get(self, provider): - user = current_user - - tenant_id = user.current_tenant_id + _, tenant_id = current_account_with_tenant() return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider)) @@ -111,11 +105,10 @@ class ToolBuiltinProviderDeleteApi(Resource): @login_required @account_initialization_required def post(self, provider): - user = current_user + user, tenant_id = current_account_with_tenant() if not user.is_admin_or_owner: raise Forbidden() - tenant_id = user.current_tenant_id req = reqparse.RequestParser() req.add_argument("credential_id", type=str, required=True, nullable=False, location="json") args = req.parse_args() @@ -133,10 +126,9 @@ class ToolBuiltinProviderAddApi(Resource): @login_required @account_initialization_required def post(self, provider): - user = current_user + user, tenant_id = current_account_with_tenant() user_id = user.id - tenant_id = user.current_tenant_id parser = reqparse.RequestParser() parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") @@ -163,13 +155,12 @@ class ToolBuiltinProviderUpdateApi(Resource): @login_required @account_initialization_required def post(self, provider): - user = current_user + user, tenant_id = current_account_with_tenant() if not user.is_admin_or_owner: raise Forbidden() user_id = user.id - tenant_id = user.current_tenant_id parser = reqparse.RequestParser() parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") @@ -195,7 +186,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource): @login_required @account_initialization_required def get(self, provider): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() return jsonable_encoder( BuiltinToolManageService.get_builtin_tool_provider_credentials( @@ -220,13 +211,12 @@ class ToolApiProviderAddApi(Resource): @login_required @account_initialization_required def post(self): - user = current_user + user, tenant_id = current_account_with_tenant() if not user.is_admin_or_owner: raise Forbidden() user_id = user.id - tenant_id = user.current_tenant_id parser = reqparse.RequestParser() parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") @@ -260,10 +250,9 @@ class ToolApiProviderGetRemoteSchemaApi(Resource): @login_required @account_initialization_required def get(self): - user = current_user + user, tenant_id = current_account_with_tenant() user_id = user.id - tenant_id = user.current_tenant_id parser = reqparse.RequestParser() @@ -284,10 +273,9 @@ class ToolApiProviderListToolsApi(Resource): @login_required @account_initialization_required def get(self): - user = current_user + user, tenant_id = current_account_with_tenant() user_id = user.id - tenant_id = user.current_tenant_id parser = reqparse.RequestParser() @@ -310,13 +298,12 @@ class ToolApiProviderUpdateApi(Resource): @login_required @account_initialization_required def post(self): - user = current_user + user, tenant_id = current_account_with_tenant() if not user.is_admin_or_owner: raise Forbidden() user_id = user.id - tenant_id = user.current_tenant_id parser = reqparse.RequestParser() parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") @@ -352,13 +339,12 @@ class ToolApiProviderDeleteApi(Resource): @login_required @account_initialization_required def post(self): - user = current_user + user, tenant_id = current_account_with_tenant() if not user.is_admin_or_owner: raise Forbidden() user_id = user.id - tenant_id = user.current_tenant_id parser = reqparse.RequestParser() @@ -379,10 +365,9 @@ class ToolApiProviderGetApi(Resource): @login_required @account_initialization_required def get(self): - user = current_user + user, tenant_id = current_account_with_tenant() user_id = user.id - tenant_id = user.current_tenant_id parser = reqparse.RequestParser() @@ -403,8 +388,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): @login_required @account_initialization_required def get(self, provider, credential_type): - user = current_user - tenant_id = user.current_tenant_id + _, tenant_id = current_account_with_tenant() return jsonable_encoder( BuiltinToolManageService.list_builtin_provider_credentials_schema( @@ -446,9 +430,9 @@ class ToolApiProviderPreviousTestApi(Resource): parser.add_argument("schema", type=str, required=True, nullable=False, location="json") args = parser.parse_args() - + _, current_tenant_id = current_account_with_tenant() return ApiToolManageService.test_api_tool_preview( - current_user.current_tenant_id, + current_tenant_id, args["provider_name"] or "", args["tool_name"], args["credentials"], @@ -464,13 +448,12 @@ class ToolWorkflowProviderCreateApi(Resource): @login_required @account_initialization_required def post(self): - user = current_user + user, tenant_id = current_account_with_tenant() if not user.is_admin_or_owner: raise Forbidden() user_id = user.id - tenant_id = user.current_tenant_id reqparser = reqparse.RequestParser() reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json") @@ -504,13 +487,12 @@ class ToolWorkflowProviderUpdateApi(Resource): @login_required @account_initialization_required def post(self): - user = current_user + user, tenant_id = current_account_with_tenant() if not user.is_admin_or_owner: raise Forbidden() user_id = user.id - tenant_id = user.current_tenant_id reqparser = reqparse.RequestParser() reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") @@ -547,13 +529,12 @@ class ToolWorkflowProviderDeleteApi(Resource): @login_required @account_initialization_required def post(self): - user = current_user + user, tenant_id = current_account_with_tenant() if not user.is_admin_or_owner: raise Forbidden() user_id = user.id - tenant_id = user.current_tenant_id reqparser = reqparse.RequestParser() reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") @@ -573,10 +554,9 @@ class ToolWorkflowProviderGetApi(Resource): @login_required @account_initialization_required def get(self): - user = current_user + user, tenant_id = current_account_with_tenant() user_id = user.id - tenant_id = user.current_tenant_id parser = reqparse.RequestParser() parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args") @@ -608,10 +588,9 @@ class ToolWorkflowProviderListToolApi(Resource): @login_required @account_initialization_required def get(self): - user = current_user + user, tenant_id = current_account_with_tenant() user_id = user.id - tenant_id = user.current_tenant_id parser = reqparse.RequestParser() parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args") @@ -633,10 +612,9 @@ class ToolBuiltinListApi(Resource): @login_required @account_initialization_required def get(self): - user = current_user + user, tenant_id = current_account_with_tenant() user_id = user.id - tenant_id = user.current_tenant_id return jsonable_encoder( [ @@ -655,8 +633,7 @@ class ToolApiListApi(Resource): @login_required @account_initialization_required def get(self): - user = current_user - tenant_id = user.current_tenant_id + _, tenant_id = current_account_with_tenant() return jsonable_encoder( [ @@ -674,10 +651,9 @@ class ToolWorkflowListApi(Resource): @login_required @account_initialization_required def get(self): - user = current_user + user, tenant_id = current_account_with_tenant() user_id = user.id - tenant_id = user.current_tenant_id return jsonable_encoder( [ @@ -711,19 +687,18 @@ class ToolPluginOAuthApi(Resource): provider_name = tool_provider.provider_name # todo check permission - user = current_user + user, tenant_id = current_account_with_tenant() if not user.is_admin_or_owner: raise Forbidden() - tenant_id = user.current_tenant_id oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider) if oauth_client_params is None: raise Forbidden("no oauth available client config found for this tool provider") oauth_handler = OAuthHandler() context_id = OAuthProxyService.create_proxy_context( - user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name + user_id=user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name ) redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback" authorization_url_response = oauth_handler.get_authorization_url( @@ -802,11 +777,12 @@ class ToolBuiltinProviderSetDefaultApi(Resource): @login_required @account_initialization_required def post(self, provider): + current_user, current_tenant_id = current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("id", type=str, required=True, nullable=False, location="json") args = parser.parse_args() return BuiltinToolManageService.set_default_provider( - tenant_id=current_user.current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"] + tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"] ) @@ -821,13 +797,13 @@ class ToolOAuthCustomClient(Resource): parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json") args = parser.parse_args() - user = current_user + user, tenant_id = current_account_with_tenant() if not user.is_admin_or_owner: raise Forbidden() return BuiltinToolManageService.save_custom_oauth_client_params( - tenant_id=user.current_tenant_id, + tenant_id=tenant_id, provider=provider, client_params=args.get("client_params", {}), enable_oauth_custom_client=args.get("enable_oauth_custom_client", True), @@ -837,20 +813,18 @@ class ToolOAuthCustomClient(Resource): @login_required @account_initialization_required def get(self, provider): + _, current_tenant_id = current_account_with_tenant() return jsonable_encoder( - BuiltinToolManageService.get_custom_oauth_client_params( - tenant_id=current_user.current_tenant_id, provider=provider - ) + BuiltinToolManageService.get_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider) ) @setup_required @login_required @account_initialization_required def delete(self, provider): + _, current_tenant_id = current_account_with_tenant() return jsonable_encoder( - BuiltinToolManageService.delete_custom_oauth_client_params( - tenant_id=current_user.current_tenant_id, provider=provider - ) + BuiltinToolManageService.delete_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider) ) @@ -860,9 +834,10 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource): @login_required @account_initialization_required def get(self, provider): + _, current_tenant_id = current_account_with_tenant() return jsonable_encoder( BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema( - tenant_id=current_user.current_tenant_id, provider_name=provider + tenant_id=current_tenant_id, provider_name=provider ) ) @@ -873,7 +848,7 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource): @login_required @account_initialization_required def get(self, provider): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() return jsonable_encoder( BuiltinToolManageService.get_builtin_tool_provider_credential_info( @@ -900,6 +875,7 @@ class ToolProviderMCPApi(Resource): parser.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={}) parser.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={}) args = parser.parse_args() + user, tenant_id = current_account_with_tenant() # Validate server URL if not is_valid_url(args["server_url"]): @@ -913,8 +889,8 @@ class ToolProviderMCPApi(Resource): with Session(db.engine) as session: service = MCPToolManageService(session=session) result = service.create_provider( - tenant_id=current_user.current_tenant_id, - user_id=current_user.id, + tenant_id=tenant_id, + user_id=user.id, server_url=args["server_url"], name=args["name"], icon=args["icon"], @@ -954,10 +930,12 @@ class ToolProviderMCPApi(Resource): raise ValueError("Server URL is not valid.") configuration = MCPConfiguration.model_validate(args["configuration"]) authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None + _, current_tenant_id = current_account_with_tenant() + with Session(db.engine) as session: service = MCPToolManageService(session=session) service.update_provider( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider_id=args["provider_id"], server_url=args["server_url"], name=args["name"], @@ -982,9 +960,11 @@ class ToolProviderMCPApi(Resource): parser = reqparse.RequestParser() parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json") args = parser.parse_args() + _, current_tenant_id = current_account_with_tenant() + with Session(db.engine) as session: service = MCPToolManageService(session=session) - service.delete_provider(tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"]) + service.delete_provider(tenant_id=current_tenant_id.current_tenant_id, provider_id=args["provider_id"]) session.commit() return {"result": "success"} @@ -1000,7 +980,7 @@ class ToolMCPAuthApi(Resource): parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="json") args = parser.parse_args() provider_id = args["provider_id"] - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() with Session(db.engine) as session: service = MCPToolManageService(session=session) @@ -1049,10 +1029,10 @@ class ToolMCPDetailApi(Resource): @login_required @account_initialization_required def get(self, provider_id): - user = current_user + _, tenant_id = current_account_with_tenant() with Session(db.engine) as session: service = MCPToolManageService(session=session) - provider = service.get_provider(provider_id=provider_id, tenant_id=user.current_tenant_id) + provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id) return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True)) @@ -1062,8 +1042,7 @@ class ToolMCPListAllApi(Resource): @login_required @account_initialization_required def get(self): - user = current_user - tenant_id = user.current_tenant_id + _, tenant_id = current_account_with_tenant() with Session(db.engine) as session: service = MCPToolManageService(session=session) @@ -1078,7 +1057,7 @@ class ToolMCPUpdateApi(Resource): @login_required @account_initialization_required def get(self, provider_id): - tenant_id = current_user.current_tenant_id + _, tenant_id = current_account_with_tenant() with Session(db.engine) as session: service = MCPToolManageService(session=session) tools = service.list_provider_tools( diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 9e903d9286..4158f0524f 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -12,8 +12,8 @@ from configs import dify_config from controllers.console.workspace.error import AccountNotInitializedError from extensions.ext_database import db from extensions.ext_redis import redis_client -from libs.login import current_user -from models.account import Account, AccountStatus +from libs.login import current_account_with_tenant +from models.account import AccountStatus from models.dataset import RateLimitLog from models.model import DifySetup from services.feature_service import FeatureService, LicenseStatus @@ -25,16 +25,13 @@ P = ParamSpec("P") R = TypeVar("R") -def _current_account() -> Account: - assert isinstance(current_user, Account) - return current_user - - def account_initialization_required(view: Callable[P, R]): @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs): # check account initialization - account = _current_account() + current_user, _ = current_account_with_tenant() + + account = current_user if account.status == AccountStatus.UNINITIALIZED: raise AccountNotInitializedError() @@ -80,9 +77,8 @@ def only_edition_self_hosted(view: Callable[P, R]): def cloud_edition_billing_enabled(view: Callable[P, R]): @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs): - account = _current_account() - assert account.current_tenant_id is not None - features = FeatureService.get_features(account.current_tenant_id) + _, current_tenant_id = current_account_with_tenant() + features = FeatureService.get_features(current_tenant_id) if not features.billing.enabled: abort(403, "Billing feature is not enabled.") return view(*args, **kwargs) @@ -94,10 +90,8 @@ def cloud_edition_billing_resource_check(resource: str): def interceptor(view: Callable[P, R]): @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs): - account = _current_account() - assert account.current_tenant_id is not None - tenant_id = account.current_tenant_id - features = FeatureService.get_features(tenant_id) + _, current_tenant_id = current_account_with_tenant() + features = FeatureService.get_features(current_tenant_id) if features.billing.enabled: members = features.members apps = features.apps @@ -138,9 +132,8 @@ def cloud_edition_billing_knowledge_limit_check(resource: str): def interceptor(view: Callable[P, R]): @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs): - account = _current_account() - assert account.current_tenant_id is not None - features = FeatureService.get_features(account.current_tenant_id) + _, current_tenant_id = current_account_with_tenant() + features = FeatureService.get_features(current_tenant_id) if features.billing.enabled: if resource == "add_segment": if features.billing.subscription.plan == "sandbox": @@ -163,13 +156,11 @@ def cloud_edition_billing_rate_limit_check(resource: str): @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs): if resource == "knowledge": - account = _current_account() - assert account.current_tenant_id is not None - tenant_id = account.current_tenant_id - knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(tenant_id) + _, current_tenant_id = current_account_with_tenant() + knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_tenant_id) if knowledge_rate_limit.enabled: current_time = int(time.time() * 1000) - key = f"rate_limit_{tenant_id}" + key = f"rate_limit_{current_tenant_id}" redis_client.zadd(key, {current_time: current_time}) @@ -180,7 +171,7 @@ def cloud_edition_billing_rate_limit_check(resource: str): if request_count > knowledge_rate_limit.limit: # add ratelimit record rate_limit_log = RateLimitLog( - tenant_id=tenant_id, + tenant_id=current_tenant_id, subscription_plan=knowledge_rate_limit.subscription_plan, operation="knowledge", ) @@ -200,17 +191,15 @@ def cloud_utm_record(view: Callable[P, R]): @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs): with contextlib.suppress(Exception): - account = _current_account() - assert account.current_tenant_id is not None - tenant_id = account.current_tenant_id - features = FeatureService.get_features(tenant_id) + _, current_tenant_id = current_account_with_tenant() + features = FeatureService.get_features(current_tenant_id) if features.billing.enabled: utm_info = request.cookies.get("utm_info") if utm_info: utm_info_dict: dict = json.loads(utm_info) - OperationService.record_utm(tenant_id, utm_info_dict) + OperationService.record_utm(current_tenant_id, utm_info_dict) return view(*args, **kwargs) @@ -289,9 +278,8 @@ def enable_change_email(view: Callable[P, R]): def is_allow_transfer_owner(view: Callable[P, R]): @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs): - account = _current_account() - assert account.current_tenant_id is not None - features = FeatureService.get_features(account.current_tenant_id) + _, current_tenant_id = current_account_with_tenant() + features = FeatureService.get_features(current_tenant_id) if features.is_allow_transfer_workspace: return view(*args, **kwargs) @@ -301,12 +289,11 @@ def is_allow_transfer_owner(view: Callable[P, R]): return decorated -def knowledge_pipeline_publish_enabled(view): +def knowledge_pipeline_publish_enabled(view: Callable[P, R]): @wraps(view) - def decorated(*args, **kwargs): - account = _current_account() - assert account.current_tenant_id is not None - features = FeatureService.get_features(account.current_tenant_id) + def decorated(*args: P.args, **kwargs: P.kwargs): + _, current_tenant_id = current_account_with_tenant() + features = FeatureService.get_features(current_tenant_id) if features.knowledge_pipeline.publish_enabled: return view(*args, **kwargs) abort(403) diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index d674c7467d..acbbf4531b 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -1,5 +1,4 @@ from flask import request -from flask_login import current_user from flask_restx import marshal, reqparse from werkzeug.exceptions import NotFound @@ -16,6 +15,7 @@ from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from fields.segment_fields import child_chunk_fields, segment_fields +from libs.login import current_account_with_tenant from models.dataset import Dataset from services.dataset_service import DatasetService, DocumentService, SegmentService from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs @@ -66,6 +66,7 @@ class SegmentApi(DatasetApiResource): @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id: str, dataset_id: str, document_id: str): + _, current_tenant_id = current_account_with_tenant() """Create single segment.""" # check dataset dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() @@ -84,7 +85,7 @@ class SegmentApi(DatasetApiResource): try: model_manager = ModelManager() model_manager.get_model_instance( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, @@ -117,6 +118,7 @@ class SegmentApi(DatasetApiResource): } ) def get(self, tenant_id: str, dataset_id: str, document_id: str): + _, current_tenant_id = current_account_with_tenant() """Get segments.""" # check dataset page = request.args.get("page", default=1, type=int) @@ -133,7 +135,7 @@ class SegmentApi(DatasetApiResource): try: model_manager = ModelManager() model_manager.get_model_instance( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, @@ -149,7 +151,7 @@ class SegmentApi(DatasetApiResource): segments, total = SegmentService.get_segments( document_id=document_id, - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, status_list=args["status"], keyword=args["keyword"], page=page, @@ -184,6 +186,7 @@ class DatasetSegmentApi(DatasetApiResource): ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): + _, current_tenant_id = current_account_with_tenant() # check dataset dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: @@ -195,7 +198,7 @@ class DatasetSegmentApi(DatasetApiResource): if not document: raise NotFound("Document not found.") # check segment - segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) + segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id) if not segment: raise NotFound("Segment not found.") SegmentService.delete_segment(segment, document, dataset) @@ -217,6 +220,7 @@ class DatasetSegmentApi(DatasetApiResource): @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): + _, current_tenant_id = current_account_with_tenant() # check dataset dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: @@ -232,7 +236,7 @@ class DatasetSegmentApi(DatasetApiResource): try: model_manager = ModelManager() model_manager.get_model_instance( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, @@ -244,7 +248,7 @@ class DatasetSegmentApi(DatasetApiResource): except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) # check segment - segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) + segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id) if not segment: raise NotFound("Segment not found.") @@ -266,6 +270,7 @@ class DatasetSegmentApi(DatasetApiResource): } ) def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): + _, current_tenant_id = current_account_with_tenant() # check dataset dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: @@ -277,7 +282,7 @@ class DatasetSegmentApi(DatasetApiResource): if not document: raise NotFound("Document not found.") # check segment - segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) + segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id) if not segment: raise NotFound("Segment not found.") @@ -307,6 +312,7 @@ class ChildChunkApi(DatasetApiResource): @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): + _, current_tenant_id = current_account_with_tenant() """Create child chunk.""" # check dataset dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() @@ -319,7 +325,7 @@ class ChildChunkApi(DatasetApiResource): raise NotFound("Document not found.") # check segment - segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) + segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id) if not segment: raise NotFound("Segment not found.") @@ -328,7 +334,7 @@ class ChildChunkApi(DatasetApiResource): try: model_manager = ModelManager() model_manager.get_model_instance( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, @@ -364,6 +370,7 @@ class ChildChunkApi(DatasetApiResource): } ) def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): + _, current_tenant_id = current_account_with_tenant() """Get child chunks.""" # check dataset dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() @@ -376,7 +383,7 @@ class ChildChunkApi(DatasetApiResource): raise NotFound("Document not found.") # check segment - segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) + segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id) if not segment: raise NotFound("Segment not found.") @@ -423,6 +430,7 @@ class DatasetChildChunkApi(DatasetApiResource): @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str): + _, current_tenant_id = current_account_with_tenant() """Delete child chunk.""" # check dataset dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() @@ -435,7 +443,7 @@ class DatasetChildChunkApi(DatasetApiResource): raise NotFound("Document not found.") # check segment - segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) + segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id) if not segment: raise NotFound("Segment not found.") @@ -444,9 +452,7 @@ class DatasetChildChunkApi(DatasetApiResource): raise NotFound("Document not found.") # check child chunk - child_chunk = SegmentService.get_child_chunk_by_id( - child_chunk_id=child_chunk_id, tenant_id=current_user.current_tenant_id - ) + child_chunk = SegmentService.get_child_chunk_by_id(child_chunk_id=child_chunk_id, tenant_id=current_tenant_id) if not child_chunk: raise NotFound("Child chunk not found.") @@ -483,6 +489,7 @@ class DatasetChildChunkApi(DatasetApiResource): @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def patch(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str): + _, current_tenant_id = current_account_with_tenant() """Update child chunk.""" # check dataset dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() @@ -495,7 +502,7 @@ class DatasetChildChunkApi(DatasetApiResource): raise NotFound("Document not found.") # get segment - segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) + segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id) if not segment: raise NotFound("Segment not found.") @@ -504,9 +511,7 @@ class DatasetChildChunkApi(DatasetApiResource): raise NotFound("Segment not found.") # get child chunk - child_chunk = SegmentService.get_child_chunk_by_id( - child_chunk_id=child_chunk_id, tenant_id=current_user.current_tenant_id - ) + child_chunk = SegmentService.get_child_chunk_by_id(child_chunk_id=child_chunk_id, tenant_id=current_tenant_id) if not child_chunk: raise NotFound("Child chunk not found.") diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index 4b246a53d3..074555e31b 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -1,10 +1,12 @@ import logging import queue +import threading import time from abc import abstractmethod from enum import IntEnum, auto from typing import Any +from cachetools import TTLCache, cachedmethod from redis.exceptions import RedisError from sqlalchemy.orm import DeclarativeMeta @@ -45,6 +47,8 @@ class AppQueueManager: q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue() self._q = q + self._stopped_cache: TTLCache[tuple, bool] = TTLCache(maxsize=1, ttl=1) + self._cache_lock = threading.Lock() def listen(self): """ @@ -157,6 +161,7 @@ class AppQueueManager: stopped_cache_key = cls._generate_stopped_cache_key(task_id) redis_client.setex(stopped_cache_key, 600, 1) + @cachedmethod(lambda self: self._stopped_cache, lock=lambda self: self._cache_lock) def _is_stopped(self) -> bool: """ Check if task is stopped diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py index 952fefdbbc..5095b46432 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -2,7 +2,7 @@ import inspect import json import logging from collections.abc import Callable, Generator -from typing import Any, TypeVar +from typing import Any, TypeVar, cast import httpx from pydantic import BaseModel @@ -31,6 +31,17 @@ from core.plugin.impl.exc import ( ) plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL)) +_plugin_daemon_timeout_config = cast( + float | httpx.Timeout | None, + getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 300.0), +) +plugin_daemon_request_timeout: httpx.Timeout | None +if _plugin_daemon_timeout_config is None: + plugin_daemon_request_timeout = None +elif isinstance(_plugin_daemon_timeout_config, httpx.Timeout): + plugin_daemon_request_timeout = _plugin_daemon_timeout_config +else: + plugin_daemon_request_timeout = httpx.Timeout(_plugin_daemon_timeout_config) T = TypeVar("T", bound=(BaseModel | dict | list | bool | str)) @@ -58,6 +69,7 @@ class BasePluginClient: "headers": headers, "params": params, "files": files, + "timeout": plugin_daemon_request_timeout, } if isinstance(prepared_data, dict): request_kwargs["data"] = prepared_data @@ -116,6 +128,7 @@ class BasePluginClient: "headers": headers, "params": params, "files": files, + "timeout": plugin_daemon_request_timeout, } if isinstance(prepared_data, dict): stream_kwargs["data"] = prepared_data diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index c2f17cd148..937b8f033c 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -43,8 +43,7 @@ class CacheEmbedding(Embeddings): else: embedding_queue_indices.append(i) - # release database connection, because embedding may take a long time - db.session.close() + # NOTE: avoid closing the shared scoped session here; downstream code may still have pending work if embedding_queue_indices: embedding_queue_texts = [texts[i] for i in embedding_queue_indices] diff --git a/api/libs/login.py b/api/libs/login.py index 0535f52ea1..24b8c4011a 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -13,6 +13,15 @@ from models.model import EndUser #: A proxy for the current user. If no user is logged in, this will be an #: anonymous user current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user())) + + +def current_account_with_tenant(): + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") + assert current_user.current_tenant_id is not None, "The tenant information should be loaded." + return current_user, current_user.current_tenant_id + + from typing import ParamSpec, TypeVar P = ParamSpec("P") diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 9feca7337f..c0d26cdd27 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -8,8 +8,7 @@ from werkzeug.exceptions import NotFound from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now -from libs.login import current_user -from models.account import Account +from libs.login import current_account_with_tenant from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation from services.feature_service import FeatureService from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task @@ -24,10 +23,10 @@ class AppAnnotationService: @classmethod def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation: # get app info - assert isinstance(current_user, Account) + current_user, current_tenant_id = current_account_with_tenant() app = ( db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") .first() ) @@ -63,12 +62,12 @@ class AppAnnotationService: db.session.commit() # if annotation reply is enabled , add annotation to index annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() - assert current_user.current_tenant_id is not None + assert current_tenant_id is not None if annotation_setting: add_annotation_to_index_task.delay( annotation.id, args["question"], - current_user.current_tenant_id, + current_tenant_id, app_id, annotation_setting.collection_binding_id, ) @@ -86,13 +85,12 @@ class AppAnnotationService: enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}" # send batch add segments task redis_client.setnx(enable_app_annotation_job_key, "waiting") - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + current_user, current_tenant_id = current_account_with_tenant() enable_annotation_reply_task.delay( str(job_id), app_id, current_user.id, - current_user.current_tenant_id, + current_tenant_id, args["score_threshold"], args["embedding_provider_name"], args["embedding_model_name"], @@ -101,8 +99,7 @@ class AppAnnotationService: @classmethod def disable_app_annotation(cls, app_id: str): - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + _, current_tenant_id = current_account_with_tenant() disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}" cache_result = redis_client.get(disable_app_annotation_key) if cache_result is not None: @@ -113,17 +110,16 @@ class AppAnnotationService: disable_app_annotation_job_key = f"disable_app_annotation_job_{str(job_id)}" # send batch add segments task redis_client.setnx(disable_app_annotation_job_key, "waiting") - disable_annotation_reply_task.delay(str(job_id), app_id, current_user.current_tenant_id) + disable_annotation_reply_task.delay(str(job_id), app_id, current_tenant_id) return {"job_id": job_id, "job_status": "waiting"} @classmethod def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str): # get app info - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + _, current_tenant_id = current_account_with_tenant() app = ( db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") .first() ) @@ -153,11 +149,10 @@ class AppAnnotationService: @classmethod def export_annotation_list_by_app_id(cls, app_id: str): # get app info - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + _, current_tenant_id = current_account_with_tenant() app = ( db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") .first() ) @@ -174,11 +169,10 @@ class AppAnnotationService: @classmethod def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation: # get app info - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + current_user, current_tenant_id = current_account_with_tenant() app = ( db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") .first() ) @@ -196,7 +190,7 @@ class AppAnnotationService: add_annotation_to_index_task.delay( annotation.id, args["question"], - current_user.current_tenant_id, + current_tenant_id, app_id, annotation_setting.collection_binding_id, ) @@ -205,11 +199,10 @@ class AppAnnotationService: @classmethod def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str): # get app info - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + _, current_tenant_id = current_account_with_tenant() app = ( db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") .first() ) @@ -234,7 +227,7 @@ class AppAnnotationService: update_annotation_to_index_task.delay( annotation.id, annotation.question, - current_user.current_tenant_id, + current_tenant_id, app_id, app_annotation_setting.collection_binding_id, ) @@ -244,11 +237,10 @@ class AppAnnotationService: @classmethod def delete_app_annotation(cls, app_id: str, annotation_id: str): # get app info - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + _, current_tenant_id = current_account_with_tenant() app = ( db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") .first() ) @@ -277,17 +269,16 @@ class AppAnnotationService: if app_annotation_setting: delete_annotation_index_task.delay( - annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id + annotation.id, app_id, current_tenant_id, app_annotation_setting.collection_binding_id ) @classmethod def delete_app_annotations_in_batch(cls, app_id: str, annotation_ids: list[str]): # get app info - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + _, current_tenant_id = current_account_with_tenant() app = ( db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") .first() ) @@ -317,7 +308,7 @@ class AppAnnotationService: for annotation, annotation_setting in annotations_to_delete: if annotation_setting: delete_annotation_index_task.delay( - annotation.id, app_id, current_user.current_tenant_id, annotation_setting.collection_binding_id + annotation.id, app_id, current_tenant_id, annotation_setting.collection_binding_id ) # Step 4: Bulk delete annotations in a single query @@ -333,11 +324,10 @@ class AppAnnotationService: @classmethod def batch_import_app_annotations(cls, app_id, file: FileStorage): # get app info - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + current_user, current_tenant_id = current_account_with_tenant() app = ( db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") .first() ) @@ -354,7 +344,7 @@ class AppAnnotationService: if len(result) == 0: raise ValueError("The CSV file is empty.") # check annotation limit - features = FeatureService.get_features(current_user.current_tenant_id) + features = FeatureService.get_features(current_tenant_id) if features.billing.enabled: annotation_quota_limit = features.annotation_quota_limit if annotation_quota_limit.limit < len(result) + annotation_quota_limit.size: @@ -364,21 +354,18 @@ class AppAnnotationService: indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}" # send batch add segments task redis_client.setnx(indexing_cache_key, "waiting") - batch_import_annotations_task.delay( - str(job_id), result, app_id, current_user.current_tenant_id, current_user.id - ) + batch_import_annotations_task.delay(str(job_id), result, app_id, current_tenant_id, current_user.id) except Exception as e: return {"error_msg": str(e)} return {"job_id": job_id, "job_status": "waiting"} @classmethod def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit): - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + _, current_tenant_id = current_account_with_tenant() # get app info app = ( db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") .first() ) @@ -445,12 +432,11 @@ class AppAnnotationService: @classmethod def get_app_annotation_setting_by_app_id(cls, app_id: str): - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + _, current_tenant_id = current_account_with_tenant() # get app info app = ( db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") .first() ) @@ -481,12 +467,11 @@ class AppAnnotationService: @classmethod def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict): - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + current_user, current_tenant_id = current_account_with_tenant() # get app info app = ( db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") .first() ) @@ -531,11 +516,10 @@ class AppAnnotationService: @classmethod def clear_all_annotations(cls, app_id: str): - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None + _, current_tenant_id = current_account_with_tenant() app = ( db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") .first() ) @@ -558,7 +542,7 @@ class AppAnnotationService: # if annotation reply is enabled, delete annotation index if app_annotation_setting: delete_annotation_index_task.delay( - annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id + annotation.id, app_id, current_tenant_id, app_annotation_setting.collection_binding_id ) db.session.delete(annotation) diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 36b7084973..fcb6ab1d40 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -3,7 +3,6 @@ import time from collections.abc import Mapping from typing import Any -from flask_login import current_user from sqlalchemy.orm import Session from configs import dify_config @@ -18,6 +17,7 @@ from core.tools.entities.tool_entities import CredentialType from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter from extensions.ext_database import db from extensions.ext_redis import redis_client +from libs.login import current_account_with_tenant from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider from models.provider_ids import DatasourceProviderID from services.plugin.plugin_service import PluginService @@ -93,6 +93,8 @@ class DatasourceProviderService: """ get credential by id """ + current_user, _ = current_account_with_tenant() + with Session(db.engine) as session: if credential_id: datasource_provider = ( @@ -157,6 +159,8 @@ class DatasourceProviderService: """ get all datasource credentials by provider """ + current_user, _ = current_account_with_tenant() + with Session(db.engine) as session: datasource_providers = ( session.query(DatasourceProvider) @@ -604,6 +608,8 @@ class DatasourceProviderService: """ provider_name = provider_id.provider_name plugin_id = provider_id.plugin_id + current_user, _ = current_account_with_tenant() + with Session(db.engine) as session: lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}" with redis_client.lock(lock, timeout=20): @@ -901,6 +907,8 @@ class DatasourceProviderService: """ update datasource credentials. """ + current_user, _ = current_account_with_tenant() + with Session(db.engine) as session: datasource_provider = ( session.query(DatasourceProvider) diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 951b9e5653..b528728364 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -8,7 +8,6 @@ import click import pandas as pd from celery import shared_task from sqlalchemy import func -from sqlalchemy.orm import Session from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType @@ -50,54 +49,48 @@ def batch_create_segment_to_index_task( indexing_cache_key = f"segment_batch_import_{job_id}" try: - with Session(db.engine) as session: - dataset = session.get(Dataset, dataset_id) - if not dataset: - raise ValueError("Dataset not exist.") + dataset = db.session.get(Dataset, dataset_id) + if not dataset: + raise ValueError("Dataset not exist.") - dataset_document = session.get(Document, document_id) - if not dataset_document: - raise ValueError("Document not exist.") + dataset_document = db.session.get(Document, document_id) + if not dataset_document: + raise ValueError("Document not exist.") - if ( - not dataset_document.enabled - or dataset_document.archived - or dataset_document.indexing_status != "completed" - ): - raise ValueError("Document is not available.") + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + raise ValueError("Document is not available.") - upload_file = session.get(UploadFile, upload_file_id) - if not upload_file: - raise ValueError("UploadFile not found.") + upload_file = db.session.get(UploadFile, upload_file_id) + if not upload_file: + raise ValueError("UploadFile not found.") - with tempfile.TemporaryDirectory() as temp_dir: - suffix = Path(upload_file.key).suffix - # FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here - file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore - storage.download(upload_file.key, file_path) + with tempfile.TemporaryDirectory() as temp_dir: + suffix = Path(upload_file.key).suffix + file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore + storage.download(upload_file.key, file_path) - # Skip the first row - df = pd.read_csv(file_path) - content = [] - for _, row in df.iterrows(): - if dataset_document.doc_form == "qa_model": - data = {"content": row.iloc[0], "answer": row.iloc[1]} - else: - data = {"content": row.iloc[0]} - content.append(data) - if len(content) == 0: - raise ValueError("The CSV file is empty.") + df = pd.read_csv(file_path) + content = [] + for _, row in df.iterrows(): + if dataset_document.doc_form == "qa_model": + data = {"content": row.iloc[0], "answer": row.iloc[1]} + else: + data = {"content": row.iloc[0]} + content.append(data) + if len(content) == 0: + raise ValueError("The CSV file is empty.") + + document_segments = [] + embedding_model = None + if dataset.indexing_technique == "high_quality": + model_manager = ModelManager() + embedding_model = model_manager.get_model_instance( + tenant_id=dataset.tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) - document_segments = [] - embedding_model = None - if dataset.indexing_technique == "high_quality": - model_manager = ModelManager() - embedding_model = model_manager.get_model_instance( - tenant_id=dataset.tenant_id, - provider=dataset.embedding_model_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model, - ) word_count_change = 0 if embedding_model: tokens_list = embedding_model.get_text_embedding_num_tokens( @@ -105,6 +98,7 @@ def batch_create_segment_to_index_task( ) else: tokens_list = [0] * len(content) + for segment, tokens in zip(content, tokens_list): content = segment["content"] doc_id = str(uuid.uuid4()) @@ -135,11 +129,11 @@ def batch_create_segment_to_index_task( word_count_change += segment_document.word_count db.session.add(segment_document) document_segments.append(segment_document) - # update document word count + assert dataset_document.word_count is not None dataset_document.word_count += word_count_change db.session.add(dataset_document) - # add index to db + VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form) db.session.commit() redis_client.setex(indexing_cache_key, 600, "completed") diff --git a/api/tests/test_containers_integration_tests/services/test_annotation_service.py b/api/tests/test_containers_integration_tests/services/test_annotation_service.py index 3cb7424df8..4768b981cc 100644 --- a/api/tests/test_containers_integration_tests/services/test_annotation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_annotation_service.py @@ -25,9 +25,7 @@ class TestAnnotationService: patch("services.annotation_service.enable_annotation_reply_task") as mock_enable_task, patch("services.annotation_service.disable_annotation_reply_task") as mock_disable_task, patch("services.annotation_service.batch_import_annotations_task") as mock_batch_import_task, - patch( - "services.annotation_service.current_user", create_autospec(Account, instance=True) - ) as mock_current_user, + patch("services.annotation_service.current_account_with_tenant") as mock_current_account_with_tenant, ): # Setup default mock returns mock_account_feature_service.get_features.return_value.billing.enabled = False @@ -38,6 +36,9 @@ class TestAnnotationService: mock_disable_task.delay.return_value = None mock_batch_import_task.delay.return_value = None + # Create mock user that will be returned by current_account_with_tenant + mock_user = create_autospec(Account, instance=True) + yield { "account_feature_service": mock_account_feature_service, "feature_service": mock_feature_service, @@ -47,7 +48,8 @@ class TestAnnotationService: "enable_task": mock_enable_task, "disable_task": mock_disable_task, "batch_import_task": mock_batch_import_task, - "current_user": mock_current_user, + "current_account_with_tenant": mock_current_account_with_tenant, + "current_user": mock_user, } def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): @@ -107,6 +109,11 @@ class TestAnnotationService: """ mock_external_service_dependencies["current_user"].id = account_id mock_external_service_dependencies["current_user"].current_tenant_id = tenant_id + # Configure current_account_with_tenant to return (user, tenant_id) + mock_external_service_dependencies["current_account_with_tenant"].return_value = ( + mock_external_service_dependencies["current_user"], + tenant_id, + ) def _create_test_conversation(self, app, account, fake): """ diff --git a/api/tests/unit_tests/controllers/console/test_wraps.py b/api/tests/unit_tests/controllers/console/test_wraps.py index 5d132cb787..6777077de8 100644 --- a/api/tests/unit_tests/controllers/console/test_wraps.py +++ b/api/tests/unit_tests/controllers/console/test_wraps.py @@ -60,7 +60,7 @@ class TestAccountInitialization: return "success" # Act - with patch("controllers.console.wraps._current_account", return_value=mock_user): + with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_user, "tenant123")): result = protected_view() # Assert @@ -77,7 +77,7 @@ class TestAccountInitialization: return "success" # Act & Assert - with patch("controllers.console.wraps._current_account", return_value=mock_user): + with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_user, "tenant123")): with pytest.raises(AccountNotInitializedError): protected_view() @@ -163,7 +163,9 @@ class TestBillingResourceLimits: return "member_added" # Act - with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")): + with patch( + "controllers.console.wraps.current_account_with_tenant", return_value=(MockUser("test_user"), "tenant123") + ): with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features): result = add_member() @@ -185,7 +187,10 @@ class TestBillingResourceLimits: # Act & Assert with app.test_request_context(): - with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")): + with patch( + "controllers.console.wraps.current_account_with_tenant", + return_value=(MockUser("test_user"), "tenant123"), + ): with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features): with pytest.raises(Exception) as exc_info: add_member() @@ -207,7 +212,10 @@ class TestBillingResourceLimits: # Test 1: Should reject when source is datasets with app.test_request_context("/?source=datasets"): - with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")): + with patch( + "controllers.console.wraps.current_account_with_tenant", + return_value=(MockUser("test_user"), "tenant123"), + ): with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features): with pytest.raises(Exception) as exc_info: upload_document() @@ -215,7 +223,10 @@ class TestBillingResourceLimits: # Test 2: Should allow when source is not datasets with app.test_request_context("/?source=other"): - with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")): + with patch( + "controllers.console.wraps.current_account_with_tenant", + return_value=(MockUser("test_user"), "tenant123"), + ): with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features): result = upload_document() assert result == "document_uploaded" @@ -239,7 +250,9 @@ class TestRateLimiting: return "knowledge_success" # Act - with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")): + with patch( + "controllers.console.wraps.current_account_with_tenant", return_value=(MockUser("test_user"), "tenant123") + ): with patch( "controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit ): @@ -271,7 +284,10 @@ class TestRateLimiting: # Act & Assert with app.test_request_context(): - with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")): + with patch( + "controllers.console.wraps.current_account_with_tenant", + return_value=(MockUser("test_user"), "tenant123"), + ): with patch( "controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit ): diff --git a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx index cc5e46deeb..29b27a60ad 100644 --- a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx +++ b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx @@ -53,9 +53,6 @@ const ChatWrapper = () => { initUserVariables, } = useChatWithHistoryContext() - // Semantic variable for better code readability - const isHistoryConversation = !!currentConversationId - const appConfig = useMemo(() => { const config = appParams || {} @@ -66,9 +63,9 @@ const ChatWrapper = () => { fileUploadConfig: (config as any).system_parameters, }, supportFeedback: true, - opening_statement: isHistoryConversation ? currentConversationItem?.introduction : (config as any).opening_statement, + opening_statement: currentConversationItem?.introduction || (config as any).opening_statement, } as ChatConfig - }, [appParams, currentConversationItem?.introduction, isHistoryConversation]) + }, [appParams, currentConversationItem?.introduction]) const { chatList, setTargetMessageId, @@ -79,7 +76,7 @@ const ChatWrapper = () => { } = useChat( appConfig, { - inputs: (isHistoryConversation ? currentConversationInputs : newConversationInputs) as any, + inputs: (currentConversationId ? currentConversationInputs : newConversationInputs) as any, inputsForm: inputsForms, }, appPrevChatTree, @@ -87,7 +84,7 @@ const ChatWrapper = () => { clearChatList, setClearChatList, ) - const inputsFormValue = isHistoryConversation ? currentConversationInputs : newConversationInputsRef?.current + const inputsFormValue = currentConversationId ? currentConversationInputs : newConversationInputsRef?.current const inputDisabled = useMemo(() => { if (allInputsHidden) return false @@ -136,7 +133,7 @@ const ChatWrapper = () => { const data: any = { query: message, files, - inputs: formatBooleanInputs(inputsForms, isHistoryConversation ? currentConversationInputs : newConversationInputs), + inputs: formatBooleanInputs(inputsForms, currentConversationId ? currentConversationInputs : newConversationInputs), conversation_id: currentConversationId, parent_message_id: (isRegenerate ? parentAnswer?.id : getLastAnswer(chatList)?.id) || null, } @@ -146,11 +143,11 @@ const ChatWrapper = () => { data, { onGetSuggestedQuestions: responseItemId => fetchSuggestedQuestions(responseItemId, isInstalledApp, appId), - onConversationComplete: isHistoryConversation ? undefined : handleNewConversationCompleted, + onConversationComplete: currentConversationId ? undefined : handleNewConversationCompleted, isPublicAPI: !isInstalledApp, }, ) - }, [chatList, handleNewConversationCompleted, handleSend, isHistoryConversation, currentConversationInputs, newConversationInputs, isInstalledApp, appId]) + }, [chatList, handleNewConversationCompleted, handleSend, currentConversationId, currentConversationInputs, newConversationInputs, isInstalledApp, appId]) const doRegenerate = useCallback((chatItem: ChatItemInTree, editedQuestion?: { message: string, files?: FileEntity[] }) => { const question = editedQuestion ? chatItem : chatList.find(item => item.id === chatItem.parentMessageId)! @@ -163,30 +160,38 @@ const ChatWrapper = () => { }, [chatList, doSend]) const messageList = useMemo(() => { - // Always filter out opening statement from message list as it's handled separately in welcome component + if (currentConversationId || chatList.length > 1) + return chatList + // Without messages we are in the welcome screen, so hide the opening statement from chatlist return chatList.filter(item => !item.isOpeningStatement) }, [chatList]) - const [collapsed, setCollapsed] = useState(isHistoryConversation) + const [collapsed, setCollapsed] = useState(!!currentConversationId) const chatNode = useMemo(() => { if (allInputsHidden || !inputsForms.length) return null if (isMobile) { - if (!isHistoryConversation) + if (!currentConversationId) return return null } else { return } - }, [inputsForms.length, isMobile, isHistoryConversation, collapsed, allInputsHidden]) + }, + [ + inputsForms.length, + isMobile, + currentConversationId, + collapsed, allInputsHidden, + ]) const welcome = useMemo(() => { const welcomeMessage = chatList.find(item => item.isOpeningStatement) if (respondingState) return null - if (isHistoryConversation) + if (currentConversationId) return null if (!welcomeMessage) return null @@ -227,7 +232,18 @@ const ChatWrapper = () => { ) - }, [appData?.site.icon, appData?.site.icon_background, appData?.site.icon_type, appData?.site.icon_url, chatList, collapsed, isHistoryConversation, inputsForms.length, respondingState, allInputsHidden]) + }, + [ + appData?.site.icon, + appData?.site.icon_background, + appData?.site.icon_type, + appData?.site.icon_url, + chatList, collapsed, + currentConversationId, + inputsForms.length, + respondingState, + allInputsHidden, + ]) const answerIcon = (appData?.site && appData.site.use_icon_as_answer_icon) ? { chatFooterClassName='pb-4' chatFooterInnerClassName={`mx-auto w-full max-w-[768px] ${isMobile ? 'px-2' : 'px-4'}`} onSend={doSend} - inputs={isHistoryConversation ? currentConversationInputs as any : newConversationInputs} + inputs={currentConversationId ? currentConversationInputs as any : newConversationInputs} inputsForm={inputsForms} onRegenerate={doRegenerate} onStopResponding={handleStop} diff --git a/web/app/components/base/chat/chat/answer/index.tsx b/web/app/components/base/chat/chat/answer/index.tsx index 51eb00cfc5..a1b458ba9a 100644 --- a/web/app/components/base/chat/chat/answer/index.tsx +++ b/web/app/components/base/chat/chat/answer/index.tsx @@ -111,6 +111,8 @@ const Answer: FC = ({ } }, [switchSibling, item.prevSibling, item.nextSibling]) + const contentIsEmpty = content.trim() === '' + return (
@@ -153,14 +155,14 @@ const Answer: FC = ({ ) } { - responding && !content && !hasAgentThoughts && ( + responding && contentIsEmpty && !hasAgentThoughts && (
) } { - content && !hasAgentThoughts && ( + !contentIsEmpty && !hasAgentThoughts && ( ) } diff --git a/web/app/components/base/chat/chat/chat-input-area/index.tsx b/web/app/components/base/chat/chat/chat-input-area/index.tsx index cbfa3168e9..a1144d5537 100644 --- a/web/app/components/base/chat/chat/chat-input-area/index.tsx +++ b/web/app/components/base/chat/chat/chat-input-area/index.tsx @@ -83,6 +83,15 @@ const ChatInputArea = ({ const historyRef = useRef(['']) const [currentIndex, setCurrentIndex] = useState(-1) const isComposingRef = useRef(false) + + const handleQueryChange = useCallback( + (value: string) => { + setQuery(value) + setTimeout(handleTextareaResize, 0) + }, + [handleTextareaResize], + ) + const handleSend = () => { if (isResponding) { notify({ type: 'info', message: t('appDebug.errorMessage.waitForResponse') }) @@ -101,7 +110,7 @@ const ChatInputArea = ({ } if (checkInputsForm(inputs, inputsForm)) { onSend(query, files) - setQuery('') + handleQueryChange('') setFiles([]) } } @@ -131,19 +140,19 @@ const ChatInputArea = ({ // When the cmd + up key is pressed, output the previous element if (currentIndex > 0) { setCurrentIndex(currentIndex - 1) - setQuery(historyRef.current[currentIndex - 1]) + handleQueryChange(historyRef.current[currentIndex - 1]) } } else if (e.key === 'ArrowDown' && !e.shiftKey && !e.nativeEvent.isComposing && e.metaKey) { // When the cmd + down key is pressed, output the next element if (currentIndex < historyRef.current.length - 1) { setCurrentIndex(currentIndex + 1) - setQuery(historyRef.current[currentIndex + 1]) + handleQueryChange(historyRef.current[currentIndex + 1]) } else if (currentIndex === historyRef.current.length - 1) { // If it is the last element, clear the input box setCurrentIndex(historyRef.current.length) - setQuery('') + handleQueryChange('') } } } @@ -171,7 +180,7 @@ const ChatInputArea = ({ <>
{ - setQuery(e.target.value) - setTimeout(handleTextareaResize, 0) - }} + onChange={e => handleQueryChange(e.target.value)} onKeyDown={handleKeyDown} onCompositionStart={handleCompositionStart} onCompositionEnd={handleCompositionEnd} @@ -221,7 +226,7 @@ const ChatInputArea = ({ showVoiceInput && ( setShowVoiceInput(false)} - onConverted={text => setQuery(text)} + onConverted={text => handleQueryChange(text)} /> ) } diff --git a/web/app/components/base/chat/chat/chat-input-area/operation.tsx b/web/app/components/base/chat/chat/chat-input-area/operation.tsx index 122dfcb6fb..014ca6651f 100644 --- a/web/app/components/base/chat/chat/chat-input-area/operation.tsx +++ b/web/app/components/base/chat/chat/chat-input-area/operation.tsx @@ -1,3 +1,4 @@ +import type { FC, Ref } from 'react' import { memo } from 'react' import { RiMicLine, @@ -18,20 +19,17 @@ type OperationProps = { speechToTextConfig?: EnableType onShowVoiceInput?: () => void onSend: () => void - theme?: Theme | null + theme?: Theme | null, + ref?: Ref; } -const Operation = ( - { - ref, - fileConfig, - speechToTextConfig, - onShowVoiceInput, - onSend, - theme, - }: OperationProps & { - ref: React.RefObject; - }, -) => { +const Operation: FC = ({ + ref, + fileConfig, + speechToTextConfig, + onShowVoiceInput, + onSend, + theme, +}) => { return (
{ fileUploadConfig: (config as any).system_parameters, }, supportFeedback: true, - opening_statement: currentConversationId ? currentConversationItem?.introduction : (config as any).opening_statement, + opening_statement: currentConversationItem?.introduction || (config as any).opening_statement, } as ChatConfig - }, [appParams, currentConversationItem?.introduction, currentConversationId]) + }, [appParams, currentConversationItem?.introduction]) const { chatList, setTargetMessageId, @@ -158,8 +158,9 @@ const ChatWrapper = () => { }, [chatList, doSend]) const messageList = useMemo(() => { - if (currentConversationId) + if (currentConversationId || chatList.length > 1) return chatList + // Without messages we are in the welcome screen, so hide the opening statement from chatlist return chatList.filter(item => !item.isOpeningStatement) }, [chatList, currentConversationId]) @@ -240,7 +241,7 @@ const ChatWrapper = () => { config={appConfig} chatList={messageList} isResponding={respondingState} - chatContainerInnerClassName={cn('mx-auto w-full max-w-full pt-4 tablet:px-4', isMobile && 'px-4')} + chatContainerInnerClassName={cn('mx-auto w-full max-w-full px-4', messageList.length && 'pt-4')} chatFooterClassName={cn('pb-4', !isMobile && 'rounded-b-2xl')} chatFooterInnerClassName={cn('mx-auto w-full max-w-full px-4', isMobile && 'px-2')} onSend={doSend} diff --git a/web/app/components/base/chat/embedded-chatbot/context.tsx b/web/app/components/base/chat/embedded-chatbot/context.tsx index 544da253af..d6c35e172e 100644 --- a/web/app/components/base/chat/embedded-chatbot/context.tsx +++ b/web/app/components/base/chat/embedded-chatbot/context.tsx @@ -17,12 +17,9 @@ import type { import { noop } from 'lodash-es' export type EmbeddedChatbotContextValue = { - userCanAccess?: boolean - appInfoError?: any - appInfoLoading?: boolean - appMeta?: AppMeta - appData?: AppData - appParams?: ChatConfig + appMeta: AppMeta | null + appData: AppData | null + appParams: ChatConfig | null appChatListDataLoading?: boolean currentConversationId: string currentConversationItem?: ConversationItem @@ -59,7 +56,10 @@ export type EmbeddedChatbotContextValue = { } export const EmbeddedChatbotContext = createContext({ - userCanAccess: false, + appData: null, + appMeta: null, + appParams: null, + appChatListDataLoading: false, currentConversationId: '', appPrevChatList: [], pinnedConversationList: [], diff --git a/web/app/components/base/chat/embedded-chatbot/header/index.tsx b/web/app/components/base/chat/embedded-chatbot/header/index.tsx index 95975e29e7..904506f2bc 100644 --- a/web/app/components/base/chat/embedded-chatbot/header/index.tsx +++ b/web/app/components/base/chat/embedded-chatbot/header/index.tsx @@ -135,7 +135,7 @@ const Header: FC = ({ return (
{customerIcon} diff --git a/web/app/components/base/chat/embedded-chatbot/hooks.tsx b/web/app/components/base/chat/embedded-chatbot/hooks.tsx index aa7006db25..6987955b74 100644 --- a/web/app/components/base/chat/embedded-chatbot/hooks.tsx +++ b/web/app/components/base/chat/embedded-chatbot/hooks.tsx @@ -18,9 +18,6 @@ import { CONVERSATION_ID_INFO } from '../constants' import { buildChatItemTree, getProcessedInputsFromUrlParams, getProcessedSystemVariablesFromUrlParams, getProcessedUserVariablesFromUrlParams } from '../utils' import { getProcessedFilesFromResponse } from '../../file-uploader/utils' import { - fetchAppInfo, - fetchAppMeta, - fetchAppParams, fetchChatList, fetchConversations, generationConversationName, @@ -36,8 +33,7 @@ import { InputVarType } from '@/app/components/workflow/types' import { TransferMethod } from '@/types/app' import { addFileInfos, sortAgentSorts } from '@/app/components/tools/utils' import { noop } from 'lodash-es' -import { useGetUserCanAccessApp } from '@/service/access-control' -import { useGlobalPublicStore } from '@/context/global-public-context' +import { useWebAppStore } from '@/context/web-app-context' function getFormattedChatList(messages: any[]) { const newChatList: ChatItem[] = [] @@ -67,18 +63,10 @@ function getFormattedChatList(messages: any[]) { export const useEmbeddedChatbot = () => { const isInstalledApp = false - const systemFeatures = useGlobalPublicStore(s => s.systemFeatures) - const { data: appInfo, isLoading: appInfoLoading, error: appInfoError } = useSWR('appInfo', fetchAppInfo) - const { isPending: isCheckingPermission, data: userCanAccessResult } = useGetUserCanAccessApp({ - appId: appInfo?.app_id, - isInstalledApp, - enabled: systemFeatures.webapp_auth.enabled, - }) - - const appData = useMemo(() => { - return appInfo - }, [appInfo]) - const appId = useMemo(() => appData?.app_id, [appData]) + const appInfo = useWebAppStore(s => s.appInfo) + const appMeta = useWebAppStore(s => s.appMeta) + const appParams = useWebAppStore(s => s.appParams) + const appId = useMemo(() => appInfo?.app_id, [appInfo]) const [userId, setUserId] = useState() const [conversationId, setConversationId] = useState() @@ -145,8 +133,6 @@ export const useEmbeddedChatbot = () => { return currentConversationId }, [currentConversationId, newConversationId]) - const { data: appParams } = useSWR(['appParams', isInstalledApp, appId], () => fetchAppParams(isInstalledApp, appId)) - const { data: appMeta } = useSWR(['appMeta', isInstalledApp, appId], () => fetchAppMeta(isInstalledApp, appId)) const { data: appPinnedConversationData } = useSWR(['appConversationData', isInstalledApp, appId, true], () => fetchConversations(isInstalledApp, appId, undefined, true, 100)) const { data: appConversationData, isLoading: appConversationDataLoading, mutate: mutateAppConversationData } = useSWR(['appConversationData', isInstalledApp, appId, false], () => fetchConversations(isInstalledApp, appId, undefined, false, 100)) const { data: appChatListData, isLoading: appChatListDataLoading } = useSWR(chatShouldReloadKey ? ['appChatList', chatShouldReloadKey, isInstalledApp, appId] : null, () => fetchChatList(chatShouldReloadKey, isInstalledApp, appId)) @@ -398,16 +384,13 @@ export const useEmbeddedChatbot = () => { }, [isInstalledApp, appId, t, notify]) return { - appInfoError, - appInfoLoading: appInfoLoading || (systemFeatures.webapp_auth.enabled && isCheckingPermission), - userCanAccess: systemFeatures.webapp_auth.enabled ? userCanAccessResult?.result : true, isInstalledApp, allowResetChat, appId, currentConversationId, currentConversationItem, handleConversationIdInfoChange, - appData, + appData: appInfo, appParams: appParams || {} as ChatConfig, appMeta, appPinnedConversationData, diff --git a/web/app/components/base/chat/embedded-chatbot/index.tsx b/web/app/components/base/chat/embedded-chatbot/index.tsx index 4c8c0a2455..1553d1f153 100644 --- a/web/app/components/base/chat/embedded-chatbot/index.tsx +++ b/web/app/components/base/chat/embedded-chatbot/index.tsx @@ -49,8 +49,8 @@ const Chatbot = () => {
@@ -62,7 +62,7 @@ const Chatbot = () => { theme={themeBuilder?.theme} onCreateNewChat={handleNewConversation} /> -
+
{appChatListDataLoading && ( )} @@ -101,7 +101,6 @@ const EmbeddedChatbotWrapper = () => { const { appData, - userCanAccess, appParams, appMeta, appChatListDataLoading, @@ -135,7 +134,6 @@ const EmbeddedChatbotWrapper = () => { } = useEmbeddedChatbot() return