From 566e0fd3e5b51941b2249c5c652ad2b2144d4af6 Mon Sep 17 00:00:00 2001 From: Novice Date: Tue, 9 Sep 2025 13:47:29 +0800 Subject: [PATCH 01/18] fix(container-test): batch create segment position sort (#25394) --- .../tasks/test_batch_create_segment_to_index_task.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index b77975c032..065bcc2cd7 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -296,7 +296,12 @@ class TestBatchCreateSegmentToIndexTask: from extensions.ext_database import db # Check that segments were created - segments = db.session.query(DocumentSegment).filter_by(document_id=document.id).all() + segments = ( + db.session.query(DocumentSegment) + .filter_by(document_id=document.id) + .order_by(DocumentSegment.position) + .all() + ) assert len(segments) == 3 # Verify segment content and metadata From 64c9a2f678414ee9614a1467ad967d198236e617 Mon Sep 17 00:00:00 2001 From: Xiyuan Chen <52963600+GareArc@users.noreply.github.com> Date: Mon, 8 Sep 2025 23:45:05 -0700 Subject: [PATCH 02/18] Feat/credential policy (#25151) Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/app/workflow.py | 6 +- api/core/entities/provider_configuration.py | 28 +- api/core/entities/provider_entities.py | 1 + api/core/helper/credential_utils.py | 75 +++++ api/core/model_manager.py | 36 +++ api/core/provider_manager.py | 1 + api/core/tools/errors.py | 4 + api/core/tools/tool_manager.py | 15 +- api/services/enterprise/base.py | 22 +- .../enterprise/plugin_manager_service.py | 52 ++++ api/services/feature_service.py | 6 + api/services/workflow_service.py | 274 +++++++++++++++++- 12 files changed, 495 insertions(+), 25 deletions(-) create mode 100644 api/core/helper/credential_utils.py create mode 100644 api/services/enterprise/plugin_manager_service.py diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index bf20a5ae62..05178328fe 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -11,11 +11,7 @@ from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services from configs import dify_config from controllers.console import api -from controllers.console.app.error import ( - ConversationCompletedError, - DraftWorkflowNotExist, - DraftWorkflowNotSync, -) +from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 61a960c3d4..9cf35e559d 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -42,6 +42,7 @@ from models.provider import ( ProviderType, TenantPreferredModelProvider, ) +from services.enterprise.plugin_manager_service import PluginCredentialType logger = logging.getLogger(__name__) @@ -129,14 +130,38 @@ class ProviderConfiguration(BaseModel): return copy_credentials else: credentials = None + current_credential_id = None + if self.custom_configuration.models: for model_configuration in self.custom_configuration.models: if model_configuration.model_type == model_type and model_configuration.model == model: credentials = model_configuration.credentials + current_credential_id = model_configuration.current_credential_id break if not credentials and self.custom_configuration.provider: credentials = self.custom_configuration.provider.credentials + current_credential_id = self.custom_configuration.provider.current_credential_id + + if current_credential_id: + from core.helper.credential_utils import check_credential_policy_compliance + + check_credential_policy_compliance( + credential_id=current_credential_id, + provider=self.provider.provider, + credential_type=PluginCredentialType.MODEL, + ) + else: + # no current credential id, check all available credentials + if self.custom_configuration.provider: + for credential_configuration in self.custom_configuration.provider.available_credentials: + from core.helper.credential_utils import check_credential_policy_compliance + + check_credential_policy_compliance( + credential_id=credential_configuration.credential_id, + provider=self.provider.provider, + credential_type=PluginCredentialType.MODEL, + ) return credentials @@ -266,7 +291,6 @@ class ProviderConfiguration(BaseModel): :param credential_id: if provided, return the specified credential :return: """ - if credential_id: return self._get_specific_provider_credential(credential_id) @@ -738,6 +762,7 @@ class ProviderConfiguration(BaseModel): current_credential_id = credential_record.id current_credential_name = credential_record.credential_name + credentials = self.obfuscated_credentials( credentials=credentials, credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas @@ -792,6 +817,7 @@ class ProviderConfiguration(BaseModel): ): current_credential_id = model_configuration.current_credential_id current_credential_name = model_configuration.current_credential_name + credentials = self.obfuscated_credentials( credentials=model_configuration.credentials, credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 79a7514bbc..9b8baf1973 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -145,6 +145,7 @@ class ModelLoadBalancingConfiguration(BaseModel): name: str credentials: dict credential_source_type: str | None = None + credential_id: str | None = None class ModelSettings(BaseModel): diff --git a/api/core/helper/credential_utils.py b/api/core/helper/credential_utils.py new file mode 100644 index 0000000000..240f498181 --- /dev/null +++ b/api/core/helper/credential_utils.py @@ -0,0 +1,75 @@ +""" +Credential utility functions for checking credential existence and policy compliance. +""" + +from services.enterprise.plugin_manager_service import PluginCredentialType + + +def is_credential_exists(credential_id: str, credential_type: "PluginCredentialType") -> bool: + """ + Check if the credential still exists in the database. + + :param credential_id: The credential ID to check + :param credential_type: The type of credential (MODEL or TOOL) + :return: True if credential exists, False otherwise + """ + from sqlalchemy import select + from sqlalchemy.orm import Session + + from extensions.ext_database import db + from models.provider import ProviderCredential, ProviderModelCredential + from models.tools import BuiltinToolProvider + + with Session(db.engine) as session: + if credential_type == PluginCredentialType.MODEL: + # Check both pre-defined and custom model credentials using a single UNION query + stmt = ( + select(ProviderCredential.id) + .where(ProviderCredential.id == credential_id) + .union(select(ProviderModelCredential.id).where(ProviderModelCredential.id == credential_id)) + ) + return session.scalar(stmt) is not None + + if credential_type == PluginCredentialType.TOOL: + return ( + session.scalar(select(BuiltinToolProvider.id).where(BuiltinToolProvider.id == credential_id)) + is not None + ) + + return False + + +def check_credential_policy_compliance( + credential_id: str, provider: str, credential_type: "PluginCredentialType", check_existence: bool = True +) -> None: + """ + Check credential policy compliance for the given credential ID. + + :param credential_id: The credential ID to check + :param provider: The provider name + :param credential_type: The type of credential (MODEL or TOOL) + :param check_existence: Whether to check if credential exists in database first + :raises ValueError: If credential policy compliance check fails + """ + from services.enterprise.plugin_manager_service import ( + CheckCredentialPolicyComplianceRequest, + PluginManagerService, + ) + from services.feature_service import FeatureService + + if not FeatureService.get_system_features().plugin_manager.enabled or not credential_id: + return + + # Check if credential exists in database first (if requested) + if check_existence: + if not is_credential_exists(credential_id, credential_type): + raise ValueError(f"Credential with id {credential_id} for provider {provider} not found.") + + # Check policy compliance + PluginManagerService.check_credential_policy_compliance( + CheckCredentialPolicyComplianceRequest( + dify_credential_id=credential_id, + provider=provider, + credential_type=credential_type, + ) + ) diff --git a/api/core/model_manager.py b/api/core/model_manager.py index a59b0ae826..10df2ad79e 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -23,6 +23,7 @@ from core.model_runtime.model_providers.__base.tts_model import TTSModel from core.provider_manager import ProviderManager from extensions.ext_redis import redis_client from models.provider import ProviderType +from services.enterprise.plugin_manager_service import PluginCredentialType logger = logging.getLogger(__name__) @@ -362,6 +363,23 @@ class ModelInstance: else: raise last_exception + # Additional policy compliance check as fallback (in case fetch_next didn't catch it) + try: + from core.helper.credential_utils import check_credential_policy_compliance + + if lb_config.credential_id: + check_credential_policy_compliance( + credential_id=lb_config.credential_id, + provider=self.provider, + credential_type=PluginCredentialType.MODEL, + ) + except Exception as e: + logger.warning( + "Load balancing config %s failed policy compliance check in round-robin: %s", lb_config.id, str(e) + ) + self.load_balancing_manager.cooldown(lb_config, expire=60) + continue + try: if "credentials" in kwargs: del kwargs["credentials"] @@ -515,6 +533,24 @@ class LBModelManager: continue + # Check policy compliance for the selected configuration + try: + from core.helper.credential_utils import check_credential_policy_compliance + + if config.credential_id: + check_credential_policy_compliance( + credential_id=config.credential_id, + provider=self._provider, + credential_type=PluginCredentialType.MODEL, + ) + except Exception as e: + logger.warning("Load balancing config %s failed policy compliance check: %s", config.id, str(e)) + cooldown_load_balancing_configs.append(config) + if len(cooldown_load_balancing_configs) >= len(self._load_balancing_configs): + # all configs are in cooldown or failed policy compliance + return None + continue + if dify_config.DEBUG: logger.info( """Model LB diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 13dcef1a1f..e4e8b09a04 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -1129,6 +1129,7 @@ class ProviderManager: name=load_balancing_model_config.name, credentials=provider_model_credentials, credential_source_type=load_balancing_model_config.credential_source_type, + credential_id=load_balancing_model_config.credential_id, ) ) diff --git a/api/core/tools/errors.py b/api/core/tools/errors.py index c5f9ca4774..b0c2232857 100644 --- a/api/core/tools/errors.py +++ b/api/core/tools/errors.py @@ -29,6 +29,10 @@ class ToolApiSchemaError(ValueError): pass +class ToolCredentialPolicyViolationError(ValueError): + pass + + class ToolEngineInvokeError(Exception): meta: ToolInvokeMeta diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 00fc57a3f1..bc1f09a2fc 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -27,6 +27,7 @@ from core.tools.plugin_tool.tool import PluginTool from core.tools.utils.uuid_utils import is_valid_uuid from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.workflow.entities.variable_pool import VariablePool +from services.enterprise.plugin_manager_service import PluginCredentialType from services.tools.mcp_tools_manage_service import MCPToolManageService if TYPE_CHECKING: @@ -55,9 +56,7 @@ from core.tools.entities.tool_entities import ( ) from core.tools.errors import ToolProviderNotFoundError from core.tools.tool_label_manager import ToolLabelManager -from core.tools.utils.configuration import ( - ToolParameterConfigurationManager, -) +from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db @@ -237,6 +236,16 @@ class ToolManager: if builtin_provider is None: raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") + # check if the credential is allowed to be used + from core.helper.credential_utils import check_credential_policy_compliance + + check_credential_policy_compliance( + credential_id=builtin_provider.id, + provider=provider_id, + credential_type=PluginCredentialType.TOOL, + check_existence=False, + ) + encrypter, cache = create_provider_encrypter( tenant_id=tenant_id, config=[ diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py index 3c3f970444..edb76408e8 100644 --- a/api/services/enterprise/base.py +++ b/api/services/enterprise/base.py @@ -3,18 +3,30 @@ import os import requests -class EnterpriseRequest: - base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL") - secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY") - +class BaseRequest: proxies = { "http": "", "https": "", } + base_url = "" + secret_key = "" + secret_key_header = "" @classmethod def send_request(cls, method, endpoint, json=None, params=None): - headers = {"Content-Type": "application/json", "Enterprise-Api-Secret-Key": cls.secret_key} + headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key} url = f"{cls.base_url}{endpoint}" response = requests.request(method, url, json=json, params=params, headers=headers, proxies=cls.proxies) return response.json() + + +class EnterpriseRequest(BaseRequest): + base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL") + secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY") + secret_key_header = "Enterprise-Api-Secret-Key" + + +class EnterprisePluginManagerRequest(BaseRequest): + base_url = os.environ.get("ENTERPRISE_PLUGIN_MANAGER_API_URL", "ENTERPRISE_PLUGIN_MANAGER_API_URL") + secret_key = os.environ.get("ENTERPRISE_PLUGIN_MANAGER_API_SECRET_KEY", "ENTERPRISE_PLUGIN_MANAGER_API_SECRET_KEY") + secret_key_header = "Plugin-Manager-Inner-Api-Secret-Key" diff --git a/api/services/enterprise/plugin_manager_service.py b/api/services/enterprise/plugin_manager_service.py new file mode 100644 index 0000000000..cfcc39416a --- /dev/null +++ b/api/services/enterprise/plugin_manager_service.py @@ -0,0 +1,52 @@ +import enum +import logging + +from pydantic import BaseModel + +from services.enterprise.base import EnterprisePluginManagerRequest +from services.errors.base import BaseServiceError + + +class PluginCredentialType(enum.Enum): + MODEL = 0 + TOOL = 1 + + def to_number(self): + return self.value + + +class CheckCredentialPolicyComplianceRequest(BaseModel): + dify_credential_id: str + provider: str + credential_type: PluginCredentialType + + def model_dump(self, **kwargs): + data = super().model_dump(**kwargs) + data["credential_type"] = self.credential_type.to_number() + return data + + +class CredentialPolicyViolationError(BaseServiceError): + pass + + +class PluginManagerService: + @classmethod + def check_credential_policy_compliance(cls, body: CheckCredentialPolicyComplianceRequest): + try: + ret = EnterprisePluginManagerRequest.send_request( + "POST", "/check-credential-policy-compliance", json=body.model_dump() + ) + if not isinstance(ret, dict) or "result" not in ret: + raise ValueError("Invalid response format from plugin manager API") + except Exception as e: + raise CredentialPolicyViolationError( + f"error occurred while checking credential policy compliance: {e}" + ) from e + + if not ret.get("result", False): + raise CredentialPolicyViolationError("Credentials not available: Please use ENTERPRISE global credentials") + + logging.debug( + f"Credential policy compliance checked for {body.provider} with credential {body.dify_credential_id}, result: {ret.get('result', False)}" + ) diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 1441e6ce16..c27c0b0d58 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -134,6 +134,10 @@ class KnowledgeRateLimitModel(BaseModel): subscription_plan: str = "" +class PluginManagerModel(BaseModel): + enabled: bool = False + + class SystemFeatureModel(BaseModel): sso_enforced_for_signin: bool = False sso_enforced_for_signin_protocol: str = "" @@ -150,6 +154,7 @@ class SystemFeatureModel(BaseModel): webapp_auth: WebAppAuthModel = WebAppAuthModel() plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel() enable_change_email: bool = True + plugin_manager: PluginManagerModel = PluginManagerModel() class FeatureService: @@ -188,6 +193,7 @@ class FeatureService: system_features.branding.enabled = True system_features.webapp_auth.enabled = True system_features.enable_change_email = False + system_features.plugin_manager.enabled = True cls._fulfill_params_from_enterprise(system_features) if dify_config.MARKETPLACE_ENABLED: diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 350e52e438..0a14007349 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -36,22 +36,14 @@ from libs.datetime_utils import naive_utc_now from models.account import Account from models.model import App, AppMode from models.tools import WorkflowToolProvider -from models.workflow import ( - Workflow, - WorkflowNodeExecutionModel, - WorkflowNodeExecutionTriggeredFrom, - WorkflowType, -) +from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType from repositories.factory import DifyAPIRepositoryFactory +from services.enterprise.plugin_manager_service import PluginCredentialType from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError from services.workflow.workflow_converter import WorkflowConverter from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError -from .workflow_draft_variable_service import ( - DraftVariableSaver, - DraftVarLoader, - WorkflowDraftVariableService, -) +from .workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader, WorkflowDraftVariableService class WorkflowService: @@ -271,6 +263,12 @@ class WorkflowService: if not draft_workflow: raise ValueError("No valid workflow found.") + # Validate credentials before publishing, for credential policy check + from services.feature_service import FeatureService + + if FeatureService.get_system_features().plugin_manager.enabled: + self._validate_workflow_credentials(draft_workflow) + # create new workflow workflow = Workflow.new( tenant_id=app_model.tenant_id, @@ -295,6 +293,260 @@ class WorkflowService: # return new workflow return workflow + def _validate_workflow_credentials(self, workflow: Workflow) -> None: + """ + Validate all credentials in workflow nodes before publishing. + + :param workflow: The workflow to validate + :raises ValueError: If any credentials violate policy compliance + """ + graph_dict = workflow.graph_dict + nodes = graph_dict.get("nodes", []) + + for node in nodes: + node_data = node.get("data", {}) + node_type = node_data.get("type") + node_id = node.get("id", "unknown") + + try: + # Extract and validate credentials based on node type + if node_type == "tool": + credential_id = node_data.get("credential_id") + provider = node_data.get("provider_id") + if provider: + if credential_id: + # Check specific credential + from core.helper.credential_utils import check_credential_policy_compliance + + check_credential_policy_compliance( + credential_id=credential_id, + provider=provider, + credential_type=PluginCredentialType.TOOL, + ) + else: + # Check default workspace credential for this provider + self._check_default_tool_credential(workflow.tenant_id, provider) + + elif node_type == "agent": + agent_params = node_data.get("agent_parameters", {}) + + model_config = agent_params.get("model", {}).get("value", {}) + if model_config.get("provider") and model_config.get("model"): + self._validate_llm_model_config( + workflow.tenant_id, model_config["provider"], model_config["model"] + ) + + # Validate load balancing credentials for agent model if load balancing is enabled + agent_model_node_data = {"model": model_config} + self._validate_load_balancing_credentials(workflow, agent_model_node_data, node_id) + + # Validate agent tools + tools = agent_params.get("tools", {}).get("value", []) + for tool in tools: + # Agent tools store provider in provider_name field + provider = tool.get("provider_name") + credential_id = tool.get("credential_id") + if provider: + if credential_id: + from core.helper.credential_utils import check_credential_policy_compliance + + check_credential_policy_compliance(credential_id, provider, PluginCredentialType.TOOL) + else: + self._check_default_tool_credential(workflow.tenant_id, provider) + + elif node_type in ["llm", "knowledge_retrieval", "parameter_extractor", "question_classifier"]: + model_config = node_data.get("model", {}) + provider = model_config.get("provider") + model_name = model_config.get("name") + + if provider and model_name: + # Validate that the provider+model combination can fetch valid credentials + self._validate_llm_model_config(workflow.tenant_id, provider, model_name) + # Validate load balancing credentials if load balancing is enabled + self._validate_load_balancing_credentials(workflow, node_data, node_id) + else: + raise ValueError(f"Node {node_id} ({node_type}): Missing provider or model configuration") + + except Exception as e: + if isinstance(e, ValueError): + raise e + else: + raise ValueError(f"Node {node_id} ({node_type}): {str(e)}") + + def _validate_llm_model_config(self, tenant_id: str, provider: str, model_name: str) -> None: + """ + Validate that an LLM model configuration can fetch valid credentials. + + This method attempts to get the model instance and validates that: + 1. The provider exists and is configured + 2. The model exists in the provider + 3. Credentials can be fetched for the model + 4. The credentials pass policy compliance checks + + :param tenant_id: The tenant ID + :param provider: The provider name + :param model_name: The model name + :raises ValueError: If the model configuration is invalid or credentials fail policy checks + """ + try: + from core.model_manager import ModelManager + from core.model_runtime.entities.model_entities import ModelType + + # Get model instance to validate provider+model combination + model_manager = ModelManager() + model_manager.get_model_instance( + tenant_id=tenant_id, provider=provider, model_type=ModelType.LLM, model=model_name + ) + + # The ModelInstance constructor will automatically check credential policy compliance + # via ProviderConfiguration.get_current_credentials() -> _check_credential_policy_compliance() + # If it fails, an exception will be raised + + except Exception as e: + raise ValueError( + f"Failed to validate LLM model configuration (provider: {provider}, model: {model_name}): {str(e)}" + ) + + def _check_default_tool_credential(self, tenant_id: str, provider: str) -> None: + """ + Check credential policy compliance for the default workspace credential of a tool provider. + + This method finds the default credential for the given provider and validates it. + Uses the same fallback logic as runtime to handle deauthorized credentials. + + :param tenant_id: The tenant ID + :param provider: The tool provider name + :raises ValueError: If no default credential exists or if it fails policy compliance + """ + try: + from models.tools import BuiltinToolProvider + + # Use the same fallback logic as runtime: get the first available credential + # ordered by is_default DESC, created_at ASC (same as tool_manager.py) + default_provider = ( + db.session.query(BuiltinToolProvider) + .where( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider, + ) + .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) + .first() + ) + + if not default_provider: + raise ValueError("No default credential found") + + # Check credential policy compliance using the default credential ID + from core.helper.credential_utils import check_credential_policy_compliance + + check_credential_policy_compliance( + credential_id=default_provider.id, + provider=provider, + credential_type=PluginCredentialType.TOOL, + check_existence=False, + ) + + except Exception as e: + raise ValueError(f"Failed to validate default credential for tool provider {provider}: {str(e)}") + + def _validate_load_balancing_credentials(self, workflow: Workflow, node_data: dict, node_id: str) -> None: + """ + Validate load balancing credentials for a workflow node. + + :param workflow: The workflow being validated + :param node_data: The node data containing model configuration + :param node_id: The node ID for error reporting + :raises ValueError: If load balancing credentials violate policy compliance + """ + # Extract model configuration + model_config = node_data.get("model", {}) + provider = model_config.get("provider") + model_name = model_config.get("name") + + if not provider or not model_name: + return # No model config to validate + + # Check if this model has load balancing enabled + if self._is_load_balancing_enabled(workflow.tenant_id, provider, model_name): + # Get all load balancing configurations for this model + load_balancing_configs = self._get_load_balancing_configs(workflow.tenant_id, provider, model_name) + # Validate each load balancing configuration + try: + for config in load_balancing_configs: + if config.get("credential_id"): + from core.helper.credential_utils import check_credential_policy_compliance + + check_credential_policy_compliance( + config["credential_id"], provider, PluginCredentialType.MODEL + ) + except Exception as e: + raise ValueError(f"Invalid load balancing credentials for {provider}/{model_name}: {str(e)}") + + def _is_load_balancing_enabled(self, tenant_id: str, provider: str, model_name: str) -> bool: + """ + Check if load balancing is enabled for a specific model. + + :param tenant_id: The tenant ID + :param provider: The provider name + :param model_name: The model name + :return: True if load balancing is enabled, False otherwise + """ + try: + from core.model_runtime.entities.model_entities import ModelType + from core.provider_manager import ProviderManager + + # Get provider configurations + provider_manager = ProviderManager() + provider_configurations = provider_manager.get_configurations(tenant_id) + provider_configuration = provider_configurations.get(provider) + + if not provider_configuration: + return False + + # Get provider model setting + provider_model_setting = provider_configuration.get_provider_model_setting( + model_type=ModelType.LLM, + model=model_name, + ) + return provider_model_setting is not None and provider_model_setting.load_balancing_enabled + + except Exception: + # If we can't determine the status, assume load balancing is not enabled + return False + + def _get_load_balancing_configs(self, tenant_id: str, provider: str, model_name: str) -> list[dict]: + """ + Get all load balancing configurations for a model. + + :param tenant_id: The tenant ID + :param provider: The provider name + :param model_name: The model name + :return: List of load balancing configuration dictionaries + """ + try: + from services.model_load_balancing_service import ModelLoadBalancingService + + model_load_balancing_service = ModelLoadBalancingService() + _, configs = model_load_balancing_service.get_load_balancing_configs( + tenant_id=tenant_id, + provider=provider, + model=model_name, + model_type="llm", # Load balancing is primarily used for LLM models + config_from="predefined-model", # Check both predefined and custom models + ) + + _, custom_configs = model_load_balancing_service.get_load_balancing_configs( + tenant_id=tenant_id, provider=provider, model=model_name, model_type="llm", config_from="custom-model" + ) + all_configs = configs + custom_configs + + return [config for config in all_configs if config.get("credential_id")] + + except Exception: + # If we can't get the configurations, return empty list + # This will prevent validation errors from breaking the workflow + return [] + def get_default_block_configs(self) -> list[dict]: """ Get default block configs From c595c03452b25dac7d3fd6b872b14ef7149fa59e Mon Sep 17 00:00:00 2001 From: zxhlyh Date: Tue, 9 Sep 2025 14:52:50 +0800 Subject: [PATCH 03/18] fix: credential not allow to use in load balancing (#25401) --- .../provider-added-card/model-load-balancing-configs.tsx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-configs.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-configs.tsx index 900ca1b392..29da0ffc0c 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-configs.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-configs.tsx @@ -196,7 +196,7 @@ const ModelLoadBalancingConfigs = ({ ) : ( - + )} @@ -232,7 +232,7 @@ const ModelLoadBalancingConfigs = ({ <> toggleConfigEntryEnabled(index, value)} From e180c19cca9aadfef04c1c27ff5947c06c028ec0 Mon Sep 17 00:00:00 2001 From: Novice Date: Tue, 9 Sep 2025 14:58:14 +0800 Subject: [PATCH 04/18] fix(mcp): current_user not being set in MCP requests (#25393) --- api/extensions/ext_login.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index cd01a31068..5571c0d9ba 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -86,9 +86,7 @@ def load_user_from_request(request_from_flask_login): if not app_mcp_server: raise NotFound("App MCP server not found.") end_user = ( - db.session.query(EndUser) - .where(EndUser.external_user_id == app_mcp_server.id, EndUser.type == "mcp") - .first() + db.session.query(EndUser).where(EndUser.session_id == app_mcp_server.id, EndUser.type == "mcp").first() ) if not end_user: raise NotFound("End user not found.") From 4aba570fa849cbe0138ef7abe5f9fe3b611ddb89 Mon Sep 17 00:00:00 2001 From: Yongtao Huang Date: Tue, 9 Sep 2025 15:06:18 +0800 Subject: [PATCH 05/18] Fix flask response: 200 -> {}, 200 (#25404) --- api/controllers/console/datasets/data_source.py | 4 ++-- api/controllers/console/datasets/metadata.py | 4 ++-- api/controllers/console/tag/tags.py | 4 ++-- api/controllers/service_api/dataset/metadata.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index e4d5f1be6e..45c647659b 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -249,7 +249,7 @@ class DataSourceNotionDatasetSyncApi(Resource): documents = DocumentService.get_document_by_dataset_id(dataset_id_str) for document in documents: document_indexing_sync_task.delay(dataset_id_str, document.id) - return 200 + return {"result": "success"}, 200 class DataSourceNotionDocumentSyncApi(Resource): @@ -267,7 +267,7 @@ class DataSourceNotionDocumentSyncApi(Resource): if document is None: raise NotFound("Document not found.") document_indexing_sync_task.delay(dataset_id_str, document_id_str) - return 200 + return {"result": "success"}, 200 api.add_resource(DataSourceApi, "/data-source/integrates", "/data-source/integrates//") diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py index 6aa309f930..21ab5e4fe1 100644 --- a/api/controllers/console/datasets/metadata.py +++ b/api/controllers/console/datasets/metadata.py @@ -113,7 +113,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource): MetadataService.enable_built_in_field(dataset) elif action == "disable": MetadataService.disable_built_in_field(dataset) - return 200 + return {"result": "success"}, 200 class DocumentMetadataEditApi(Resource): @@ -135,7 +135,7 @@ class DocumentMetadataEditApi(Resource): MetadataService.update_documents_metadata(dataset, metadata_args) - return 200 + return {"result": "success"}, 200 api.add_resource(DatasetMetadataCreateApi, "/datasets//metadata") diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index c45e7dbb26..da236ee5af 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -111,7 +111,7 @@ class TagBindingCreateApi(Resource): args = parser.parse_args() TagService.save_tag_binding(args) - return 200 + return {"result": "success"}, 200 class TagBindingDeleteApi(Resource): @@ -132,7 +132,7 @@ class TagBindingDeleteApi(Resource): args = parser.parse_args() TagService.delete_tag_binding(args) - return 200 + return {"result": "success"}, 200 api.add_resource(TagListApi, "/tags") diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index 444a791c01..c2df97eaec 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -174,7 +174,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): MetadataService.enable_built_in_field(dataset) elif action == "disable": MetadataService.disable_built_in_field(dataset) - return 200 + return {"result": "success"}, 200 @service_api_ns.route("/datasets//documents/metadata") @@ -204,4 +204,4 @@ class DocumentMetadataEditServiceApi(DatasetApiResource): MetadataService.update_documents_metadata(dataset, metadata_args) - return 200 + return {"result": "success"}, 200 From 37975319f288c1cbc4f500d9d13309cb2cfa4797 Mon Sep 17 00:00:00 2001 From: Wu Tianwei <30284043+WTW0313@users.noreply.github.com> Date: Tue, 9 Sep 2025 15:15:32 +0800 Subject: [PATCH 06/18] feat: Add customized json schema validation (#25408) --- .../error-message.tsx | 2 +- .../components/workflow/nodes/llm/utils.ts | 200 ++------------ web/pnpm-lock.yaml | 2 +- web/utils/draft-07.json | 245 ++++++++++++++++++ web/utils/validators.ts | 27 ++ 5 files changed, 289 insertions(+), 187 deletions(-) create mode 100644 web/utils/draft-07.json create mode 100644 web/utils/validators.ts diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/error-message.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/error-message.tsx index c21aa1405e..6e8a2b2fad 100644 --- a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/error-message.tsx +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/error-message.tsx @@ -17,7 +17,7 @@ const ErrorMessage: FC = ({ className, )}> -
+
{message}
diff --git a/web/app/components/workflow/nodes/llm/utils.ts b/web/app/components/workflow/nodes/llm/utils.ts index 045acf3993..7f13998cd7 100644 --- a/web/app/components/workflow/nodes/llm/utils.ts +++ b/web/app/components/workflow/nodes/llm/utils.ts @@ -1,9 +1,8 @@ +import { z } from 'zod' import { ArrayType, Type } from './types' import type { ArrayItems, Field, LLMNodeType } from './types' -import type { Schema, ValidationError } from 'jsonschema' -import { Validator } from 'jsonschema' -import produce from 'immer' -import { z } from 'zod' +import { draft07Validator, forbidBooleanProperties } from '@/utils/validators' +import type { ValidationError } from 'jsonschema' export const checkNodeValid = (_payload: LLMNodeType) => { return true @@ -14,7 +13,7 @@ export const getFieldType = (field: Field) => { if (type !== Type.array || !items) return type - return ArrayType[items.type] + return ArrayType[items.type as keyof typeof ArrayType] } export const getHasChildren = (schema: Field) => { @@ -115,191 +114,22 @@ export const findPropertyWithPath = (target: any, path: string[]) => { return current } -const draft07MetaSchema = { - $schema: 'http://json-schema.org/draft-07/schema#', - $id: 'http://json-schema.org/draft-07/schema#', - title: 'Core schema meta-schema', - definitions: { - schemaArray: { - type: 'array', - minItems: 1, - items: { $ref: '#' }, - }, - nonNegativeInteger: { - type: 'integer', - minimum: 0, - }, - nonNegativeIntegerDefault0: { - allOf: [ - { $ref: '#/definitions/nonNegativeInteger' }, - { default: 0 }, - ], - }, - simpleTypes: { - enum: [ - 'array', - 'boolean', - 'integer', - 'null', - 'number', - 'object', - 'string', - ], - }, - stringArray: { - type: 'array', - items: { type: 'string' }, - uniqueItems: true, - default: [], - }, - }, - type: ['object', 'boolean'], - properties: { - $id: { - type: 'string', - format: 'uri-reference', - }, - $schema: { - type: 'string', - format: 'uri', - }, - $ref: { - type: 'string', - format: 'uri-reference', - }, - title: { - type: 'string', - }, - description: { - type: 'string', - }, - default: true, - readOnly: { - type: 'boolean', - default: false, - }, - examples: { - type: 'array', - items: true, - }, - multipleOf: { - type: 'number', - exclusiveMinimum: 0, - }, - maximum: { - type: 'number', - }, - exclusiveMaximum: { - type: 'number', - }, - minimum: { - type: 'number', - }, - exclusiveMinimum: { - type: 'number', - }, - maxLength: { $ref: '#/definitions/nonNegativeInteger' }, - minLength: { $ref: '#/definitions/nonNegativeIntegerDefault0' }, - pattern: { - type: 'string', - format: 'regex', - }, - additionalItems: { $ref: '#' }, - items: { - anyOf: [ - { $ref: '#' }, - { $ref: '#/definitions/schemaArray' }, - ], - default: true, - }, - maxItems: { $ref: '#/definitions/nonNegativeInteger' }, - minItems: { $ref: '#/definitions/nonNegativeIntegerDefault0' }, - uniqueItems: { - type: 'boolean', - default: false, - }, - contains: { $ref: '#' }, - maxProperties: { $ref: '#/definitions/nonNegativeInteger' }, - minProperties: { $ref: '#/definitions/nonNegativeIntegerDefault0' }, - required: { $ref: '#/definitions/stringArray' }, - additionalProperties: { $ref: '#' }, - definitions: { - type: 'object', - additionalProperties: { $ref: '#' }, - default: {}, - }, - properties: { - type: 'object', - additionalProperties: { $ref: '#' }, - default: {}, - }, - patternProperties: { - type: 'object', - additionalProperties: { $ref: '#' }, - propertyNames: { format: 'regex' }, - default: {}, - }, - dependencies: { - type: 'object', - additionalProperties: { - anyOf: [ - { $ref: '#' }, - { $ref: '#/definitions/stringArray' }, - ], - }, - }, - propertyNames: { $ref: '#' }, - const: true, - enum: { - type: 'array', - items: true, - minItems: 1, - uniqueItems: true, - }, - type: { - anyOf: [ - { $ref: '#/definitions/simpleTypes' }, - { - type: 'array', - items: { $ref: '#/definitions/simpleTypes' }, - minItems: 1, - uniqueItems: true, - }, - ], - }, - format: { type: 'string' }, - allOf: { $ref: '#/definitions/schemaArray' }, - anyOf: { $ref: '#/definitions/schemaArray' }, - oneOf: { $ref: '#/definitions/schemaArray' }, - not: { $ref: '#' }, - }, - default: true, -} as unknown as Schema - -const validator = new Validator() - export const validateSchemaAgainstDraft7 = (schemaToValidate: any) => { - const schema = produce(schemaToValidate, (draft: any) => { - // Make sure the schema has the $schema property for draft-07 - if (!draft.$schema) - draft.$schema = 'http://json-schema.org/draft-07/schema#' - }) + // First check against Draft-07 + const result = draft07Validator(schemaToValidate) + // Then apply custom rule + const customErrors = forbidBooleanProperties(schemaToValidate) - const result = validator.validate(schema, draft07MetaSchema, { - nestedErrors: true, - throwError: false, - }) - - // Access errors from the validation result - const errors = result.valid ? [] : result.errors || [] - - return errors + return [...result.errors, ...customErrors] } -export const getValidationErrorMessage = (errors: ValidationError[]) => { +export const getValidationErrorMessage = (errors: Array) => { const message = errors.map((error) => { - return `Error: ${error.path.join('.')} ${error.message} Details: ${JSON.stringify(error.stack)}` - }).join('; ') + if (typeof error === 'string') + return error + else + return `Error: ${error.stack}\n` + }).join('') return message } diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index 694b7fb2da..c815ecb5e7 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -12603,7 +12603,7 @@ snapshots: '@vue/compiler-sfc@3.5.17': dependencies: - '@babel/parser': 7.28.0 + '@babel/parser': 7.28.3 '@vue/compiler-core': 3.5.17 '@vue/compiler-dom': 3.5.17 '@vue/compiler-ssr': 3.5.17 diff --git a/web/utils/draft-07.json b/web/utils/draft-07.json new file mode 100644 index 0000000000..99389d7ab4 --- /dev/null +++ b/web/utils/draft-07.json @@ -0,0 +1,245 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "$id": "http://json-schema.org/draft-07/schema#", + "title": "Core schema meta-schema", + "definitions": { + "schemaArray": { + "type": "array", + "minItems": 1, + "items": { + "$ref": "#" + } + }, + "nonNegativeInteger": { + "type": "integer", + "minimum": 0 + }, + "nonNegativeIntegerDefault0": { + "allOf": [ + { + "$ref": "#/definitions/nonNegativeInteger" + }, + { + "default": 0 + } + ] + }, + "simpleTypes": { + "enum": [ + "array", + "boolean", + "integer", + "null", + "number", + "object", + "string" + ] + }, + "stringArray": { + "type": "array", + "items": { + "type": "string" + }, + "uniqueItems": true, + "default": [] + } + }, + "type": [ + "object", + "boolean" + ], + "properties": { + "$id": { + "type": "string", + "format": "uri-reference" + }, + "$schema": { + "type": "string", + "format": "uri" + }, + "$ref": { + "type": "string", + "format": "uri-reference" + }, + "$comment": { + "type": "string" + }, + "title": { + "type": "string" + }, + "description": { + "type": "string" + }, + "default": true, + "readOnly": { + "type": "boolean", + "default": false + }, + "writeOnly": { + "type": "boolean", + "default": false + }, + "examples": { + "type": "array", + "items": true + }, + "multipleOf": { + "type": "number", + "exclusiveMinimum": 0 + }, + "maximum": { + "type": "number" + }, + "exclusiveMaximum": { + "type": "number" + }, + "minimum": { + "type": "number" + }, + "exclusiveMinimum": { + "type": "number" + }, + "maxLength": { + "$ref": "#/definitions/nonNegativeInteger" + }, + "minLength": { + "$ref": "#/definitions/nonNegativeIntegerDefault0" + }, + "pattern": { + "type": "string", + "format": "regex" + }, + "additionalItems": { + "$ref": "#" + }, + "items": { + "anyOf": [ + { + "$ref": "#" + }, + { + "$ref": "#/definitions/schemaArray" + } + ], + "default": true + }, + "maxItems": { + "$ref": "#/definitions/nonNegativeInteger" + }, + "minItems": { + "$ref": "#/definitions/nonNegativeIntegerDefault0" + }, + "uniqueItems": { + "type": "boolean", + "default": false + }, + "contains": { + "$ref": "#" + }, + "maxProperties": { + "$ref": "#/definitions/nonNegativeInteger" + }, + "minProperties": { + "$ref": "#/definitions/nonNegativeIntegerDefault0" + }, + "required": { + "$ref": "#/definitions/stringArray" + }, + "additionalProperties": { + "$ref": "#" + }, + "definitions": { + "type": "object", + "additionalProperties": { + "$ref": "#" + }, + "default": {} + }, + "properties": { + "type": "object", + "additionalProperties": { + "$ref": "#" + }, + "default": {} + }, + "patternProperties": { + "type": "object", + "additionalProperties": { + "$ref": "#" + }, + "propertyNames": { + "format": "regex" + }, + "default": {} + }, + "dependencies": { + "type": "object", + "additionalProperties": { + "anyOf": [ + { + "$ref": "#" + }, + { + "$ref": "#/definitions/stringArray" + } + ] + } + }, + "propertyNames": { + "$ref": "#" + }, + "const": true, + "enum": { + "type": "array", + "items": true, + "minItems": 1, + "uniqueItems": true + }, + "type": { + "anyOf": [ + { + "$ref": "#/definitions/simpleTypes" + }, + { + "type": "array", + "items": { + "$ref": "#/definitions/simpleTypes" + }, + "minItems": 1, + "uniqueItems": true + } + ] + }, + "format": { + "type": "string" + }, + "contentMediaType": { + "type": "string" + }, + "contentEncoding": { + "type": "string" + }, + "if": { + "$ref": "#" + }, + "then": { + "$ref": "#" + }, + "else": { + "$ref": "#" + }, + "allOf": { + "$ref": "#/definitions/schemaArray" + }, + "anyOf": { + "$ref": "#/definitions/schemaArray" + }, + "oneOf": { + "$ref": "#/definitions/schemaArray" + }, + "not": { + "$ref": "#" + } + }, + "default": true +} diff --git a/web/utils/validators.ts b/web/utils/validators.ts new file mode 100644 index 0000000000..51b47feddf --- /dev/null +++ b/web/utils/validators.ts @@ -0,0 +1,27 @@ +import type { Schema } from 'jsonschema' +import { Validator } from 'jsonschema' +import draft07Schema from './draft-07.json' + +const validator = new Validator() + +export const draft07Validator = (schema: any) => { + return validator.validate(schema, draft07Schema as unknown as Schema) +} + +export const forbidBooleanProperties = (schema: any, path: string[] = []): string[] => { + let errors: string[] = [] + + if (schema && typeof schema === 'object' && schema.properties) { + for (const [key, val] of Object.entries(schema.properties)) { + if (typeof val === 'boolean') { + errors.push( + `Error: Property '${[...path, key].join('.')}' must not be a boolean schema`, + ) + } + else if (typeof val === 'object') { + errors = errors.concat(forbidBooleanProperties(val, [...path, key])) + } + } + } + return errors +} From d2e50a508c73f405812693488179b2932329c53f Mon Sep 17 00:00:00 2001 From: ttz12345 <160324589+ttz12345@users.noreply.github.com> Date: Tue, 9 Sep 2025 15:18:31 +0800 Subject: [PATCH 07/18] Fix:About the error problem of creating an empty knowledge base interface in service_api (#25398) Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> --- api/services/dataset_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 2b151f9a8e..65dc673100 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -217,7 +217,7 @@ class DatasetService: and retrieval_model.reranking_model.reranking_model_name ): # check if reranking model setting is valid - DatasetService.check_embedding_model_setting( + DatasetService.check_reranking_model_setting( tenant_id, retrieval_model.reranking_model.reranking_provider_name, retrieval_model.reranking_model.reranking_model_name, From ac2aa967c4a748598375cefeb376427b98addec4 Mon Sep 17 00:00:00 2001 From: XiamuSanhua <91169172+AllesOderNicht@users.noreply.github.com> Date: Tue, 9 Sep 2025 15:18:42 +0800 Subject: [PATCH 08/18] feat: change history by supplementary node information (#25294) Co-authored-by: alleschen Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> --- .../components/workflow/candidate-node.tsx | 4 +-- .../workflow/header/view-workflow-history.tsx | 27 ++++++++++++++++--- .../workflow/hooks/use-nodes-interactions.ts | 16 +++++------ .../workflow/hooks/use-workflow-history.ts | 10 ++++--- .../_base/components/workflow-panel/index.tsx | 4 +-- .../components/workflow/note-node/hooks.ts | 4 +-- .../workflow/workflow-history-store.tsx | 8 ++++++ 7 files changed, 52 insertions(+), 21 deletions(-) diff --git a/web/app/components/workflow/candidate-node.tsx b/web/app/components/workflow/candidate-node.tsx index eb59a4618c..35bcd5c201 100644 --- a/web/app/components/workflow/candidate-node.tsx +++ b/web/app/components/workflow/candidate-node.tsx @@ -62,9 +62,9 @@ const CandidateNode = () => { }) setNodes(newNodes) if (candidateNode.type === CUSTOM_NOTE_NODE) - saveStateToHistory(WorkflowHistoryEvent.NoteAdd) + saveStateToHistory(WorkflowHistoryEvent.NoteAdd, { nodeId: candidateNode.id }) else - saveStateToHistory(WorkflowHistoryEvent.NodeAdd) + saveStateToHistory(WorkflowHistoryEvent.NodeAdd, { nodeId: candidateNode.id }) workflowStore.setState({ candidateNode: undefined }) diff --git a/web/app/components/workflow/header/view-workflow-history.tsx b/web/app/components/workflow/header/view-workflow-history.tsx index 5c31677f5e..42afd18d25 100644 --- a/web/app/components/workflow/header/view-workflow-history.tsx +++ b/web/app/components/workflow/header/view-workflow-history.tsx @@ -89,10 +89,19 @@ const ViewWorkflowHistory = () => { const calculateChangeList: ChangeHistoryList = useMemo(() => { const filterList = (list: any, startIndex = 0, reverse = false) => list.map((state: Partial, index: number) => { + const nodes = (state.nodes || store.getState().nodes) || [] + const nodeId = state?.workflowHistoryEventMeta?.nodeId + const targetTitle = nodes.find(n => n.id === nodeId)?.data?.title ?? '' return { label: state.workflowHistoryEvent && getHistoryLabel(state.workflowHistoryEvent), index: reverse ? list.length - 1 - index - startIndex : index - startIndex, - state, + state: { + ...state, + workflowHistoryEventMeta: state.workflowHistoryEventMeta ? { + ...state.workflowHistoryEventMeta, + nodeTitle: state.workflowHistoryEventMeta.nodeTitle || targetTitle, + } : undefined, + }, } }).filter(Boolean) @@ -110,6 +119,12 @@ const ViewWorkflowHistory = () => { } }, [futureStates, getHistoryLabel, pastStates, store]) + const composeHistoryItemLabel = useCallback((nodeTitle: string | undefined, baseLabel: string) => { + if (!nodeTitle) + return baseLabel + return `${nodeTitle} ${baseLabel}` + }, []) + return ( ( { 'flex items-center text-[13px] font-medium leading-[18px] text-text-secondary', )} > - {item?.label || t('workflow.changeHistory.sessionStart')} ({calculateStepLabel(item?.index)}{item?.index === currentHistoryStateIndex && t('workflow.changeHistory.currentState')}) + {composeHistoryItemLabel( + item?.state?.workflowHistoryEventMeta?.nodeTitle, + item?.label || t('workflow.changeHistory.sessionStart'), + )} ({calculateStepLabel(item?.index)}{item?.index === currentHistoryStateIndex && t('workflow.changeHistory.currentState')}) @@ -222,7 +240,10 @@ const ViewWorkflowHistory = () => { 'flex items-center text-[13px] font-medium leading-[18px] text-text-secondary', )} > - {item?.label || t('workflow.changeHistory.sessionStart')} ({calculateStepLabel(item?.index)}) + {composeHistoryItemLabel( + item?.state?.workflowHistoryEventMeta?.nodeTitle, + item?.label || t('workflow.changeHistory.sessionStart'), + )} ({calculateStepLabel(item?.index)}) diff --git a/web/app/components/workflow/hooks/use-nodes-interactions.ts b/web/app/components/workflow/hooks/use-nodes-interactions.ts index 7046d1a93a..60549c870e 100644 --- a/web/app/components/workflow/hooks/use-nodes-interactions.ts +++ b/web/app/components/workflow/hooks/use-nodes-interactions.ts @@ -174,7 +174,7 @@ export const useNodesInteractions = () => { if (x !== 0 && y !== 0) { // selecting a note will trigger a drag stop event with x and y as 0 - saveStateToHistory(WorkflowHistoryEvent.NodeDragStop) + saveStateToHistory(WorkflowHistoryEvent.NodeDragStop, { nodeId: node.id }) } } }, [workflowStore, getNodesReadOnly, saveStateToHistory, handleSyncWorkflowDraft]) @@ -423,7 +423,7 @@ export const useNodesInteractions = () => { setEdges(newEdges) handleSyncWorkflowDraft() - saveStateToHistory(WorkflowHistoryEvent.NodeConnect) + saveStateToHistory(WorkflowHistoryEvent.NodeConnect, { nodeId: targetNode?.id }) } else { const { @@ -659,10 +659,10 @@ export const useNodesInteractions = () => { handleSyncWorkflowDraft() if (currentNode.type === CUSTOM_NOTE_NODE) - saveStateToHistory(WorkflowHistoryEvent.NoteDelete) + saveStateToHistory(WorkflowHistoryEvent.NoteDelete, { nodeId: currentNode.id }) else - saveStateToHistory(WorkflowHistoryEvent.NodeDelete) + saveStateToHistory(WorkflowHistoryEvent.NodeDelete, { nodeId: currentNode.id }) }, [getNodesReadOnly, store, deleteNodeInspectorVars, handleSyncWorkflowDraft, saveStateToHistory, workflowStore, t]) const handleNodeAdd = useCallback(( @@ -1100,7 +1100,7 @@ export const useNodesInteractions = () => { setEdges(newEdges) } handleSyncWorkflowDraft() - saveStateToHistory(WorkflowHistoryEvent.NodeAdd) + saveStateToHistory(WorkflowHistoryEvent.NodeAdd, { nodeId: newNode.id }) }, [getNodesReadOnly, store, t, handleSyncWorkflowDraft, saveStateToHistory, workflowStore, getAfterNodesInSameBranch, checkNestedParallelLimit]) const handleNodeChange = useCallback(( @@ -1182,7 +1182,7 @@ export const useNodesInteractions = () => { setEdges(newEdges) handleSyncWorkflowDraft() - saveStateToHistory(WorkflowHistoryEvent.NodeChange) + saveStateToHistory(WorkflowHistoryEvent.NodeChange, { nodeId: currentNodeId }) }, [getNodesReadOnly, store, t, handleSyncWorkflowDraft, saveStateToHistory]) const handleNodesCancelSelected = useCallback(() => { @@ -1404,7 +1404,7 @@ export const useNodesInteractions = () => { setNodes([...nodes, ...nodesToPaste]) setEdges([...edges, ...edgesToPaste]) - saveStateToHistory(WorkflowHistoryEvent.NodePaste) + saveStateToHistory(WorkflowHistoryEvent.NodePaste, { nodeId: nodesToPaste?.[0]?.id }) handleSyncWorkflowDraft() } }, [getNodesReadOnly, workflowStore, store, reactflow, saveStateToHistory, handleSyncWorkflowDraft, handleNodeIterationChildrenCopy, handleNodeLoopChildrenCopy]) @@ -1501,7 +1501,7 @@ export const useNodesInteractions = () => { }) setNodes(newNodes) handleSyncWorkflowDraft() - saveStateToHistory(WorkflowHistoryEvent.NodeResize) + saveStateToHistory(WorkflowHistoryEvent.NodeResize, { nodeId }) }, [getNodesReadOnly, store, handleSyncWorkflowDraft, saveStateToHistory]) const handleNodeDisconnect = useCallback((nodeId: string) => { diff --git a/web/app/components/workflow/hooks/use-workflow-history.ts b/web/app/components/workflow/hooks/use-workflow-history.ts index 592c0b01cd..b7338dc4f8 100644 --- a/web/app/components/workflow/hooks/use-workflow-history.ts +++ b/web/app/components/workflow/hooks/use-workflow-history.ts @@ -8,6 +8,7 @@ import { } from 'reactflow' import { useTranslation } from 'react-i18next' import { useWorkflowHistoryStore } from '../workflow-history-store' +import type { WorkflowHistoryEventMeta } from '../workflow-history-store' /** * All supported Events that create a new history state. @@ -64,20 +65,21 @@ export const useWorkflowHistory = () => { // Some events may be triggered multiple times in a short period of time. // We debounce the history state update to avoid creating multiple history states // with minimal changes. - const saveStateToHistoryRef = useRef(debounce((event: WorkflowHistoryEvent) => { + const saveStateToHistoryRef = useRef(debounce((event: WorkflowHistoryEvent, meta?: WorkflowHistoryEventMeta) => { workflowHistoryStore.setState({ workflowHistoryEvent: event, + workflowHistoryEventMeta: meta, nodes: store.getState().getNodes(), edges: store.getState().edges, }) }, 500)) - const saveStateToHistory = useCallback((event: WorkflowHistoryEvent) => { + const saveStateToHistory = useCallback((event: WorkflowHistoryEvent, meta?: WorkflowHistoryEventMeta) => { switch (event) { case WorkflowHistoryEvent.NoteChange: // Hint: Note change does not trigger when note text changes, // because the note editors have their own history states. - saveStateToHistoryRef.current(event) + saveStateToHistoryRef.current(event, meta) break case WorkflowHistoryEvent.NodeTitleChange: case WorkflowHistoryEvent.NodeDescriptionChange: @@ -93,7 +95,7 @@ export const useWorkflowHistory = () => { case WorkflowHistoryEvent.NoteAdd: case WorkflowHistoryEvent.LayoutOrganize: case WorkflowHistoryEvent.NoteDelete: - saveStateToHistoryRef.current(event) + saveStateToHistoryRef.current(event, meta) break default: // We do not create a history state for every event. diff --git a/web/app/components/workflow/nodes/_base/components/workflow-panel/index.tsx b/web/app/components/workflow/nodes/_base/components/workflow-panel/index.tsx index 3594b8fdbc..a5bf1befbd 100644 --- a/web/app/components/workflow/nodes/_base/components/workflow-panel/index.tsx +++ b/web/app/components/workflow/nodes/_base/components/workflow-panel/index.tsx @@ -154,11 +154,11 @@ const BasePanel: FC = ({ const handleTitleBlur = useCallback((title: string) => { handleNodeDataUpdateWithSyncDraft({ id, data: { title } }) - saveStateToHistory(WorkflowHistoryEvent.NodeTitleChange) + saveStateToHistory(WorkflowHistoryEvent.NodeTitleChange, { nodeId: id }) }, [handleNodeDataUpdateWithSyncDraft, id, saveStateToHistory]) const handleDescriptionChange = useCallback((desc: string) => { handleNodeDataUpdateWithSyncDraft({ id, data: { desc } }) - saveStateToHistory(WorkflowHistoryEvent.NodeDescriptionChange) + saveStateToHistory(WorkflowHistoryEvent.NodeDescriptionChange, { nodeId: id }) }, [handleNodeDataUpdateWithSyncDraft, id, saveStateToHistory]) const isChildNode = !!(data.isInIteration || data.isInLoop) diff --git a/web/app/components/workflow/note-node/hooks.ts b/web/app/components/workflow/note-node/hooks.ts index 04e8081692..29642f90df 100644 --- a/web/app/components/workflow/note-node/hooks.ts +++ b/web/app/components/workflow/note-node/hooks.ts @@ -9,7 +9,7 @@ export const useNote = (id: string) => { const handleThemeChange = useCallback((theme: NoteTheme) => { handleNodeDataUpdateWithSyncDraft({ id, data: { theme } }) - saveStateToHistory(WorkflowHistoryEvent.NoteChange) + saveStateToHistory(WorkflowHistoryEvent.NoteChange, { nodeId: id }) }, [handleNodeDataUpdateWithSyncDraft, id, saveStateToHistory]) const handleEditorChange = useCallback((editorState: EditorState) => { @@ -21,7 +21,7 @@ export const useNote = (id: string) => { const handleShowAuthorChange = useCallback((showAuthor: boolean) => { handleNodeDataUpdateWithSyncDraft({ id, data: { showAuthor } }) - saveStateToHistory(WorkflowHistoryEvent.NoteChange) + saveStateToHistory(WorkflowHistoryEvent.NoteChange, { nodeId: id }) }, [handleNodeDataUpdateWithSyncDraft, id, saveStateToHistory]) return { diff --git a/web/app/components/workflow/workflow-history-store.tsx b/web/app/components/workflow/workflow-history-store.tsx index 52132f3657..c250708177 100644 --- a/web/app/components/workflow/workflow-history-store.tsx +++ b/web/app/components/workflow/workflow-history-store.tsx @@ -51,6 +51,7 @@ export function useWorkflowHistoryStore() { setState: (state: WorkflowHistoryState) => { store.setState({ workflowHistoryEvent: state.workflowHistoryEvent, + workflowHistoryEventMeta: state.workflowHistoryEventMeta, nodes: state.nodes.map((node: Node) => ({ ...node, data: { ...node.data, selected: false } })), edges: state.edges.map((edge: Edge) => ({ ...edge, selected: false }) as Edge), }) @@ -76,6 +77,7 @@ function createStore({ (set, get) => { return { workflowHistoryEvent: undefined, + workflowHistoryEventMeta: undefined, nodes: storeNodes, edges: storeEdges, getNodes: () => get().nodes, @@ -97,6 +99,7 @@ export type WorkflowHistoryStore = { nodes: Node[] edges: Edge[] workflowHistoryEvent: WorkflowHistoryEvent | undefined + workflowHistoryEventMeta?: WorkflowHistoryEventMeta } export type WorkflowHistoryActions = { @@ -119,3 +122,8 @@ export type WorkflowWithHistoryProviderProps = { edges: Edge[] children: ReactNode } + +export type WorkflowHistoryEventMeta = { + nodeId?: string + nodeTitle?: string +} From 4c92e63b0b95deb3ff1b5ee09e8b8ebe198aef8f Mon Sep 17 00:00:00 2001 From: Joel Date: Tue, 9 Sep 2025 16:00:50 +0800 Subject: [PATCH 09/18] fix: avatar is not updated after setted (#25414) --- .../(commonLayout)/account-page/AvatarWithEdit.tsx | 2 +- web/app/components/base/avatar/index.tsx | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx b/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx index 5890c2ea92..f3dbc9421c 100644 --- a/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx +++ b/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx @@ -43,9 +43,9 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => { const handleSaveAvatar = useCallback(async (uploadedFileId: string) => { try { await updateUserProfile({ url: 'account/avatar', body: { avatar: uploadedFileId } }) - notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) setIsShowAvatarPicker(false) onSave?.() + notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) } catch (e) { notify({ type: 'error', message: (e as Error).message }) diff --git a/web/app/components/base/avatar/index.tsx b/web/app/components/base/avatar/index.tsx index a6e04a0755..89019a19b0 100644 --- a/web/app/components/base/avatar/index.tsx +++ b/web/app/components/base/avatar/index.tsx @@ -1,5 +1,5 @@ 'use client' -import { useState } from 'react' +import { useEffect, useState } from 'react' import cn from '@/utils/classnames' export type AvatarProps = { @@ -27,6 +27,12 @@ const Avatar = ({ onError?.(true) } + // after uploaded, api would first return error imgs url: '.../files//file-preview/...'. Then return the right url, Which caused not show the avatar + useEffect(() => { + if(avatar && imgError) + setImgError(false) + }, [avatar]) + if (avatar && !imgError) { return ( Date: Tue, 9 Sep 2025 16:23:44 +0800 Subject: [PATCH 10/18] Revert "example of remove useEffect" (#25418) --- .../variable-inspect/value-content.tsx | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/web/app/components/workflow/variable-inspect/value-content.tsx b/web/app/components/workflow/variable-inspect/value-content.tsx index 2b28cd8ef4..a3ede311c4 100644 --- a/web/app/components/workflow/variable-inspect/value-content.tsx +++ b/web/app/components/workflow/variable-inspect/value-content.tsx @@ -60,18 +60,22 @@ const ValueContent = ({ const [fileValue, setFileValue] = useState(formatFileValue(currentVar)) const { run: debounceValueChange } = useDebounceFn(handleValueChange, { wait: 500 }) - if (showTextEditor) { - if (currentVar.value_type === 'number') - setValue(JSON.stringify(currentVar.value)) - if (!currentVar.value) - setValue('') - setValue(currentVar.value) - } - if (showJSONEditor) - setJson(currentVar.value ? JSON.stringify(currentVar.value, null, 2) : '') - if (showFileEditor) - setFileValue(formatFileValue(currentVar)) + // update default value when id changed + useEffect(() => { + if (showTextEditor) { + if (currentVar.value_type === 'number') + return setValue(JSON.stringify(currentVar.value)) + if (!currentVar.value) + return setValue('') + setValue(currentVar.value) + } + if (showJSONEditor) + setJson(currentVar.value ? JSON.stringify(currentVar.value, null, 2) : '') + + if (showFileEditor) + setFileValue(formatFileValue(currentVar)) + }, [currentVar.id, currentVar.value]) const handleTextChange = (value: string) => { if (currentVar.value_type === 'string') From 38057b1b0ed4398970dc34c78d4d67dec02b84c9 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Tue, 9 Sep 2025 17:48:33 +0900 Subject: [PATCH 11/18] add typing to all wraps (#25405) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/app/wraps.py | 11 +++++--- api/controllers/inner_api/plugin/wraps.py | 23 +++++++++------- api/controllers/inner_api/wraps.py | 4 +-- .../service_api/workspace/models.py | 2 +- api/controllers/service_api/wraps.py | 15 ++++++----- api/controllers/web/wraps.py | 10 +++---- .../vdb/matrixone/matrixone_vector.py | 27 ++++++++++--------- .../enterprise/plugin_manager_service.py | 15 +++++++---- 8 files changed, 61 insertions(+), 46 deletions(-) diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index c7e300279a..5a871f896a 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -1,6 +1,6 @@ from collections.abc import Callable from functools import wraps -from typing import Optional, Union +from typing import Optional, ParamSpec, TypeVar, Union from controllers.console.app.error import AppNotFoundError from extensions.ext_database import db @@ -8,6 +8,9 @@ from libs.login import current_user from models import App, AppMode from models.account import Account +P = ParamSpec("P") +R = TypeVar("R") + def _load_app_model(app_id: str) -> Optional[App]: assert isinstance(current_user, Account) @@ -19,10 +22,10 @@ def _load_app_model(app_id: str) -> Optional[App]: return app_model -def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None): - def decorator(view_func): +def get_app_model(view: Optional[Callable[P, R]] = None, *, mode: Union[AppMode, list[AppMode], None] = None): + def decorator(view_func: Callable[P, R]): @wraps(view_func) - def decorated_view(*args, **kwargs): + def decorated_view(*args: P.args, **kwargs: P.kwargs): if not kwargs.get("app_id"): raise ValueError("missing app_id in path parameters") diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index f751e06ddf..68711f7257 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -1,6 +1,6 @@ from collections.abc import Callable from functools import wraps -from typing import Optional +from typing import Optional, ParamSpec, TypeVar from flask import current_app, request from flask_login import user_logged_in @@ -14,6 +14,9 @@ from libs.login import _get_user from models.account import Tenant from models.model import EndUser +P = ParamSpec("P") +R = TypeVar("R") + def get_user(tenant_id: str, user_id: str | None) -> EndUser: """ @@ -52,19 +55,19 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: return user_model -def get_user_tenant(view: Optional[Callable] = None): - def decorator(view_func): +def get_user_tenant(view: Optional[Callable[P, R]] = None): + def decorator(view_func: Callable[P, R]): @wraps(view_func) - def decorated_view(*args, **kwargs): + def decorated_view(*args: P.args, **kwargs: P.kwargs): # fetch json body parser = reqparse.RequestParser() parser.add_argument("tenant_id", type=str, required=True, location="json") parser.add_argument("user_id", type=str, required=True, location="json") - kwargs = parser.parse_args() + p = parser.parse_args() - user_id = kwargs.get("user_id") - tenant_id = kwargs.get("tenant_id") + user_id: Optional[str] = p.get("user_id") + tenant_id: str = p.get("tenant_id") if not tenant_id: raise ValueError("tenant_id is required") @@ -107,9 +110,9 @@ def get_user_tenant(view: Optional[Callable] = None): return decorator(view) -def plugin_data(view: Optional[Callable] = None, *, payload_type: type[BaseModel]): - def decorator(view_func): - def decorated_view(*args, **kwargs): +def plugin_data(view: Optional[Callable[P, R]] = None, *, payload_type: type[BaseModel]): + def decorator(view_func: Callable[P, R]): + def decorated_view(*args: P.args, **kwargs: P.kwargs): try: data = request.get_json() except Exception: diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py index de4f1da801..4bdcc6832a 100644 --- a/api/controllers/inner_api/wraps.py +++ b/api/controllers/inner_api/wraps.py @@ -46,9 +46,9 @@ def enterprise_inner_api_only(view: Callable[P, R]): return decorated -def enterprise_inner_api_user_auth(view): +def enterprise_inner_api_user_auth(view: Callable[P, R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): if not dify_config.INNER_API: return view(*args, **kwargs) diff --git a/api/controllers/service_api/workspace/models.py b/api/controllers/service_api/workspace/models.py index 536cf81a2f..fffcb47bd4 100644 --- a/api/controllers/service_api/workspace/models.py +++ b/api/controllers/service_api/workspace/models.py @@ -19,7 +19,7 @@ class ModelProviderAvailableModelApi(Resource): } ) @validate_dataset_token - def get(self, _, model_type): + def get(self, _, model_type: str): """Get available models by model type. Returns a list of available models for the specified model type. diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 14291578d5..4394e64ad9 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -3,7 +3,7 @@ from collections.abc import Callable from datetime import timedelta from enum import StrEnum, auto from functools import wraps -from typing import Optional, ParamSpec, TypeVar +from typing import Concatenate, Optional, ParamSpec, TypeVar from flask import current_app, request from flask_login import user_logged_in @@ -25,6 +25,7 @@ from services.feature_service import FeatureService P = ParamSpec("P") R = TypeVar("R") +T = TypeVar("T") class WhereisUserArg(StrEnum): @@ -42,10 +43,10 @@ class FetchUserArg(BaseModel): required: bool = False -def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optional[FetchUserArg] = None): - def decorator(view_func): +def validate_app_token(view: Optional[Callable[P, R]] = None, *, fetch_user_arg: Optional[FetchUserArg] = None): + def decorator(view_func: Callable[P, R]): @wraps(view_func) - def decorated_view(*args, **kwargs): + def decorated_view(*args: P.args, **kwargs: P.kwargs): api_token = validate_and_get_api_token("app") app_model = db.session.query(App).where(App.id == api_token.app_id).first() @@ -189,10 +190,10 @@ def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str): return interceptor -def validate_dataset_token(view=None): - def decorator(view): +def validate_dataset_token(view: Optional[Callable[Concatenate[T, P], R]] = None): + def decorator(view: Callable[Concatenate[T, P], R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): api_token = validate_and_get_api_token("dataset") tenant_account_join = ( db.session.query(Tenant, TenantAccountJoin) diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index 1fbb2c165f..e79456535a 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -1,6 +1,7 @@ +from collections.abc import Callable from datetime import UTC, datetime from functools import wraps -from typing import ParamSpec, TypeVar +from typing import Concatenate, Optional, ParamSpec, TypeVar from flask import request from flask_restx import Resource @@ -20,12 +21,11 @@ P = ParamSpec("P") R = TypeVar("R") -def validate_jwt_token(view=None): - def decorator(view): +def validate_jwt_token(view: Optional[Callable[Concatenate[App, EndUser, P], R]] = None): + def decorator(view: Callable[Concatenate[App, EndUser, P], R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): app_model, end_user = decode_jwt_token() - return view(app_model, end_user, *args, **kwargs) return decorated diff --git a/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py b/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py index 7da830f643..3dd073ce50 100644 --- a/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py +++ b/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py @@ -1,8 +1,9 @@ import json import logging import uuid +from collections.abc import Callable from functools import wraps -from typing import Any, Optional +from typing import Any, Concatenate, Optional, ParamSpec, TypeVar from mo_vector.client import MoVectorClient # type: ignore from pydantic import BaseModel, model_validator @@ -17,7 +18,6 @@ from extensions.ext_redis import redis_client from models.dataset import Dataset logger = logging.getLogger(__name__) -from typing import ParamSpec, TypeVar P = ParamSpec("P") R = TypeVar("R") @@ -47,16 +47,6 @@ class MatrixoneConfig(BaseModel): return values -def ensure_client(func): - @wraps(func) - def wrapper(self, *args, **kwargs): - if self.client is None: - self.client = self._get_client(None, False) - return func(self, *args, **kwargs) - - return wrapper - - class MatrixoneVector(BaseVector): """ Matrixone vector storage implementation. @@ -216,6 +206,19 @@ class MatrixoneVector(BaseVector): self.client.delete() +T = TypeVar("T", bound=MatrixoneVector) + + +def ensure_client(func: Callable[Concatenate[T, P], R]): + @wraps(func) + def wrapper(self: T, *args: P.args, **kwargs: P.kwargs): + if self.client is None: + self.client = self._get_client(None, False) + return func(self, *args, **kwargs) + + return wrapper + + class MatrixoneVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MatrixoneVector: if dataset.index_struct_dict: diff --git a/api/services/enterprise/plugin_manager_service.py b/api/services/enterprise/plugin_manager_service.py index cfcc39416a..ee8a932ded 100644 --- a/api/services/enterprise/plugin_manager_service.py +++ b/api/services/enterprise/plugin_manager_service.py @@ -6,10 +6,12 @@ from pydantic import BaseModel from services.enterprise.base import EnterprisePluginManagerRequest from services.errors.base import BaseServiceError +logger = logging.getLogger(__name__) -class PluginCredentialType(enum.Enum): - MODEL = 0 - TOOL = 1 + +class PluginCredentialType(enum.IntEnum): + MODEL = enum.auto() + TOOL = enum.auto() def to_number(self): return self.value @@ -47,6 +49,9 @@ class PluginManagerService: if not ret.get("result", False): raise CredentialPolicyViolationError("Credentials not available: Please use ENTERPRISE global credentials") - logging.debug( - f"Credential policy compliance checked for {body.provider} with credential {body.dify_credential_id}, result: {ret.get('result', False)}" + logger.debug( + "Credential policy compliance checked for %s with credential %s, result: %s", + body.provider, + body.dify_credential_id, + ret.get("result", False), ) From 22cd97e2e0563ee0269d64e46ed267b47722e556 Mon Sep 17 00:00:00 2001 From: KVOJJJin Date: Tue, 9 Sep 2025 16:49:22 +0800 Subject: [PATCH 12/18] Fix: judgement of open in explore (#25420) --- web/app/components/apps/app-card.tsx | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/web/app/components/apps/app-card.tsx b/web/app/components/apps/app-card.tsx index d0d42dc32c..e9a64d8867 100644 --- a/web/app/components/apps/app-card.tsx +++ b/web/app/components/apps/app-card.tsx @@ -279,12 +279,21 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { )} { - (isGettingUserCanAccessApp || !userCanAccessApp?.result) ? null : <> - - - + (!systemFeatures.webapp_auth.enabled) + ? <> + + + + : !(isGettingUserCanAccessApp || !userCanAccessApp?.result) && ( + <> + + + + ) } { From e5122945fe1fdb4584ac367729182da38530dadb Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 9 Sep 2025 17:00:00 +0800 Subject: [PATCH 13/18] Fix: Use --fix flag instead of --fix-only in autofix workflow (#25425) --- .github/workflows/autofix.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index 82ba95444f..be6ce80dfc 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -20,7 +20,7 @@ jobs: cd api uv sync --dev # Fix lint errors - uv run ruff check --fix-only . + uv run ruff check --fix . # Format code uv run ruff format . - name: ast-grep From a1cf48f84e7af7c792f456d1caec8b2be271868a Mon Sep 17 00:00:00 2001 From: GuanMu Date: Tue, 9 Sep 2025 17:11:49 +0800 Subject: [PATCH 14/18] Add lib test (#25410) --- api/tests/unit_tests/libs/test_file_utils.py | 55 ++++++++++++ .../unit_tests/libs/test_json_in_md_parser.py | 88 +++++++++++++++++++ api/tests/unit_tests/libs/test_orjson.py | 25 ++++++ 3 files changed, 168 insertions(+) create mode 100644 api/tests/unit_tests/libs/test_file_utils.py create mode 100644 api/tests/unit_tests/libs/test_json_in_md_parser.py create mode 100644 api/tests/unit_tests/libs/test_orjson.py diff --git a/api/tests/unit_tests/libs/test_file_utils.py b/api/tests/unit_tests/libs/test_file_utils.py new file mode 100644 index 0000000000..8d9b4e803a --- /dev/null +++ b/api/tests/unit_tests/libs/test_file_utils.py @@ -0,0 +1,55 @@ +from pathlib import Path + +import pytest + +from libs.file_utils import search_file_upwards + + +def test_search_file_upwards_found_in_parent(tmp_path: Path): + base = tmp_path / "a" / "b" / "c" + base.mkdir(parents=True) + + target = tmp_path / "a" / "target.txt" + target.write_text("ok", encoding="utf-8") + + found = search_file_upwards(base, "target.txt", max_search_parent_depth=5) + assert found == target + + +def test_search_file_upwards_found_in_current(tmp_path: Path): + base = tmp_path / "x" + base.mkdir() + target = base / "here.txt" + target.write_text("x", encoding="utf-8") + + found = search_file_upwards(base, "here.txt", max_search_parent_depth=1) + assert found == target + + +def test_search_file_upwards_not_found_raises(tmp_path: Path): + base = tmp_path / "m" / "n" + base.mkdir(parents=True) + with pytest.raises(ValueError) as exc: + search_file_upwards(base, "missing.txt", max_search_parent_depth=3) + # error message should contain file name and base path + msg = str(exc.value) + assert "missing.txt" in msg + assert str(base) in msg + + +def test_search_file_upwards_root_breaks_and_raises(): + # Using filesystem root triggers the 'break' branch (parent == current) + with pytest.raises(ValueError): + search_file_upwards(Path("/"), "__definitely_not_exists__.txt", max_search_parent_depth=1) + + +def test_search_file_upwards_depth_limit_raises(tmp_path: Path): + base = tmp_path / "a" / "b" / "c" + base.mkdir(parents=True) + target = tmp_path / "a" / "target.txt" + target.write_text("ok", encoding="utf-8") + # The file is 2 levels up from `c` (in `a`), but search depth is only 2. + # The search path is `c` (depth 1) -> `b` (depth 2). The file is in `a` (would need depth 3). + # So, this should not find the file and should raise an error. + with pytest.raises(ValueError): + search_file_upwards(base, "target.txt", max_search_parent_depth=2) diff --git a/api/tests/unit_tests/libs/test_json_in_md_parser.py b/api/tests/unit_tests/libs/test_json_in_md_parser.py new file mode 100644 index 0000000000..53fd0bea16 --- /dev/null +++ b/api/tests/unit_tests/libs/test_json_in_md_parser.py @@ -0,0 +1,88 @@ +import pytest + +from core.llm_generator.output_parser.errors import OutputParserError +from libs.json_in_md_parser import ( + parse_and_check_json_markdown, + parse_json_markdown, +) + + +def test_parse_json_markdown_triple_backticks_json(): + src = """ + ```json + {"a": 1, "b": "x"} + ``` + """ + assert parse_json_markdown(src) == {"a": 1, "b": "x"} + + +def test_parse_json_markdown_triple_backticks_generic(): + src = """ + ``` + {"k": [1, 2, 3]} + ``` + """ + assert parse_json_markdown(src) == {"k": [1, 2, 3]} + + +def test_parse_json_markdown_single_backticks(): + src = '`{"x": true}`' + assert parse_json_markdown(src) == {"x": True} + + +def test_parse_json_markdown_braces_only(): + src = ' {\n \t"ok": "yes"\n} ' + assert parse_json_markdown(src) == {"ok": "yes"} + + +def test_parse_json_markdown_not_found(): + with pytest.raises(ValueError): + parse_json_markdown("no json here") + + +def test_parse_and_check_json_markdown_missing_key(): + src = """ + ``` + {"present": 1} + ``` + """ + with pytest.raises(OutputParserError) as exc: + parse_and_check_json_markdown(src, ["present", "missing"]) + assert "expected key `missing`" in str(exc.value) + + +def test_parse_and_check_json_markdown_invalid_json(): + src = """ + ```json + {invalid json} + ``` + """ + with pytest.raises(OutputParserError) as exc: + parse_and_check_json_markdown(src, []) + assert "got invalid json object" in str(exc.value) + + +def test_parse_and_check_json_markdown_success(): + src = """ + ```json + {"present": 1, "other": 2} + ``` + """ + obj = parse_and_check_json_markdown(src, ["present"]) + assert obj == {"present": 1, "other": 2} + + +def test_parse_and_check_json_markdown_multiple_blocks_fails(): + src = """ + ```json + {"a": 1} + ``` + Some text + ```json + {"b": 2} + ``` + """ + # The current implementation is greedy and will match from the first + # opening fence to the last closing fence, causing JSON decode failure. + with pytest.raises(OutputParserError): + parse_and_check_json_markdown(src, []) diff --git a/api/tests/unit_tests/libs/test_orjson.py b/api/tests/unit_tests/libs/test_orjson.py new file mode 100644 index 0000000000..6df1d077df --- /dev/null +++ b/api/tests/unit_tests/libs/test_orjson.py @@ -0,0 +1,25 @@ +import orjson +import pytest + +from libs.orjson import orjson_dumps + + +def test_orjson_dumps_round_trip_basic(): + obj = {"a": 1, "b": [1, 2, 3], "c": {"d": True}} + s = orjson_dumps(obj) + assert orjson.loads(s) == obj + + +def test_orjson_dumps_with_unicode_and_indent(): + obj = {"msg": "你好,Dify"} + s = orjson_dumps(obj, option=orjson.OPT_INDENT_2) + # contains indentation newline/spaces + assert "\n" in s + assert orjson.loads(s) == obj + + +def test_orjson_dumps_non_utf8_encoding_fails(): + obj = {"msg": "你好"} + # orjson.dumps() always produces UTF-8 bytes; decoding with non-UTF8 fails. + with pytest.raises(UnicodeDecodeError): + orjson_dumps(obj, encoding="ascii") From 7443c5a6fcb7af3e8d7b723a29d0ceeb00cef242 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 9 Sep 2025 17:12:45 +0800 Subject: [PATCH 15/18] refactor: update pyrightconfig to scan all API files (#25429) --- api/pyrightconfig.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/pyrightconfig.json b/api/pyrightconfig.json index a3a5f2044e..352161523f 100644 --- a/api/pyrightconfig.json +++ b/api/pyrightconfig.json @@ -1,5 +1,5 @@ { - "include": ["models", "configs"], + "include": ["."], "exclude": [".venv", "tests/", "migrations/"], "ignore": [ "core/", From 240b65b980cbc3d679d348ee852c9e7246b4979e Mon Sep 17 00:00:00 2001 From: Novice Date: Tue, 9 Sep 2025 20:06:35 +0800 Subject: [PATCH 16/18] fix(mcp): properly handle arrays containing both numbers and strings (#25430) Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- api/core/tools/mcp_tool/tool.py | 50 +++++++++++++++++++++++---------- 1 file changed, 35 insertions(+), 15 deletions(-) diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index 6810ac683d..21d256ae03 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -67,22 +67,42 @@ class MCPTool(Tool): for content in result.content: if isinstance(content, TextContent): - try: - content_json = json.loads(content.text) - if isinstance(content_json, dict): - yield self.create_json_message(content_json) - elif isinstance(content_json, list): - for item in content_json: - yield self.create_json_message(item) - else: - yield self.create_text_message(content.text) - except json.JSONDecodeError: - yield self.create_text_message(content.text) - + yield from self._process_text_content(content) elif isinstance(content, ImageContent): - yield self.create_blob_message( - blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType} - ) + yield self._process_image_content(content) + + def _process_text_content(self, content: TextContent) -> Generator[ToolInvokeMessage, None, None]: + """Process text content and yield appropriate messages.""" + try: + content_json = json.loads(content.text) + yield from self._process_json_content(content_json) + except json.JSONDecodeError: + yield self.create_text_message(content.text) + + def _process_json_content(self, content_json: Any) -> Generator[ToolInvokeMessage, None, None]: + """Process JSON content based on its type.""" + if isinstance(content_json, dict): + yield self.create_json_message(content_json) + elif isinstance(content_json, list): + yield from self._process_json_list(content_json) + else: + # For primitive types (str, int, bool, etc.), convert to string + yield self.create_text_message(str(content_json)) + + def _process_json_list(self, json_list: list) -> Generator[ToolInvokeMessage, None, None]: + """Process a list of JSON items.""" + if any(not isinstance(item, dict) for item in json_list): + # If the list contains any non-dict item, treat the entire list as a text message. + yield self.create_text_message(str(json_list)) + return + + # Otherwise, process each dictionary as a separate JSON message. + for item in json_list: + yield self.create_json_message(item) + + def _process_image_content(self, content: ImageContent) -> ToolInvokeMessage: + """Process image content and return a blob message.""" + return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType}) def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool": return MCPTool( From 2ac7a9c8fc586c0895ec329cca005a41e6700922 Mon Sep 17 00:00:00 2001 From: Yongtao Huang Date: Tue, 9 Sep 2025 20:07:17 +0800 Subject: [PATCH 17/18] Chore: thanks to bump-pydantic (#25437) --- api/core/app/entities/app_invoke_entities.py | 2 +- api/core/app/entities/queue_entities.py | 4 ++-- api/core/app/entities/task_entities.py | 4 ++-- api/core/entities/provider_entities.py | 2 +- api/core/mcp/types.py | 2 +- api/core/ops/entities/trace_entity.py | 10 +++++----- api/core/plugin/entities/plugin_daemon.py | 2 +- .../datasource/vdb/huawei/huawei_cloud_vector.py | 4 ++-- .../rag/datasource/vdb/tencent/tencent_vector.py | 6 +++--- api/core/variables/segments.py | 2 +- api/core/workflow/nodes/base/entities.py | 2 +- .../nodes/variable_assigner/common/helpers.py | 2 +- api/services/app_dsl_service.py | 14 +++++++------- 13 files changed, 28 insertions(+), 28 deletions(-) diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 72b62eb67c..9151137fe8 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -95,7 +95,7 @@ class AppGenerateEntity(BaseModel): task_id: str # app config - app_config: Any + app_config: Any = None file_upload_config: Optional[FileUploadConfig] = None inputs: Mapping[str, Any] diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index db0297c352..fc04e60836 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -432,8 +432,8 @@ class QueueAgentLogEvent(AppQueueEvent): id: str label: str node_execution_id: str - parent_id: str | None - error: str | None + parent_id: str | None = None + error: str | None = None status: str data: Mapping[str, Any] metadata: Optional[Mapping[str, Any]] = None diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index a1c0368354..29f3e3427e 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -828,8 +828,8 @@ class AgentLogStreamResponse(StreamResponse): node_execution_id: str id: str label: str - parent_id: str | None - error: str | None + parent_id: str | None = None + error: str | None = None status: str data: Mapping[str, Any] metadata: Optional[Mapping[str, Any]] = None diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 9b8baf1973..52acbc1eef 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -107,7 +107,7 @@ class CustomModelConfiguration(BaseModel): model: str model_type: ModelType - credentials: dict | None + credentials: dict | None = None current_credential_id: Optional[str] = None current_credential_name: Optional[str] = None available_model_credentials: list[CredentialConfiguration] = [] diff --git a/api/core/mcp/types.py b/api/core/mcp/types.py index 49aa8e4498..a2c3157b3b 100644 --- a/api/core/mcp/types.py +++ b/api/core/mcp/types.py @@ -809,7 +809,7 @@ class LoggingMessageNotificationParams(NotificationParams): """The severity of this log message.""" logger: str | None = None """An optional name of the logger issuing this message.""" - data: Any + data: Any = None """ The data to be logged, such as a string message or an object. Any JSON serializable type is allowed here. diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index 3bad5c92fb..1870da3781 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -35,7 +35,7 @@ class BaseTraceInfo(BaseModel): class WorkflowTraceInfo(BaseTraceInfo): - workflow_data: Any + workflow_data: Any = None conversation_id: Optional[str] = None workflow_app_log_id: Optional[str] = None workflow_id: str @@ -89,7 +89,7 @@ class SuggestedQuestionTraceInfo(BaseTraceInfo): class DatasetRetrievalTraceInfo(BaseTraceInfo): - documents: Any + documents: Any = None class ToolTraceInfo(BaseTraceInfo): @@ -97,12 +97,12 @@ class ToolTraceInfo(BaseTraceInfo): tool_inputs: dict[str, Any] tool_outputs: str metadata: dict[str, Any] - message_file_data: Any + message_file_data: Any = None error: Optional[str] = None tool_config: dict[str, Any] time_cost: Union[int, float] tool_parameters: dict[str, Any] - file_url: Union[str, None, list] + file_url: Union[str, None, list] = None class GenerateNameTraceInfo(BaseTraceInfo): @@ -113,7 +113,7 @@ class GenerateNameTraceInfo(BaseTraceInfo): class TaskData(BaseModel): app_id: str trace_info_type: str - trace_info: Any + trace_info: Any = None trace_info_info_map = { diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 16ab661092..f1d6860bb4 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -24,7 +24,7 @@ class PluginDaemonBasicResponse(BaseModel, Generic[T]): code: int message: str - data: Optional[T] + data: Optional[T] = None class InstallPluginMessage(BaseModel): diff --git a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py index 107ea75e6a..0eca37a129 100644 --- a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py +++ b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py @@ -28,8 +28,8 @@ def create_ssl_context() -> ssl.SSLContext: class HuaweiCloudVectorConfig(BaseModel): hosts: str - username: str | None - password: str | None + username: str | None = None + password: str | None = None @model_validator(mode="before") @classmethod diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index 4af34bbb2d..2485857070 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -24,10 +24,10 @@ logger = logging.getLogger(__name__) class TencentConfig(BaseModel): url: str - api_key: Optional[str] + api_key: Optional[str] = None timeout: float = 30 - username: Optional[str] - database: Optional[str] + username: Optional[str] = None + database: Optional[str] = None index_type: str = "HNSW" metric_type: str = "IP" shard: int = 1 diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index cfef193633..7da43a6504 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -19,7 +19,7 @@ class Segment(BaseModel): model_config = ConfigDict(frozen=True) value_type: SegmentType - value: Any + value: Any = None @field_validator("value_type") @classmethod diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index 708da21177..90e45e9d25 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -23,7 +23,7 @@ NumberType = Union[int, float] class DefaultValue(BaseModel): - value: Any + value: Any = None type: DefaultValueType key: str diff --git a/api/core/workflow/nodes/variable_assigner/common/helpers.py b/api/core/workflow/nodes/variable_assigner/common/helpers.py index 8caee27363..04a7323739 100644 --- a/api/core/workflow/nodes/variable_assigner/common/helpers.py +++ b/api/core/workflow/nodes/variable_assigner/common/helpers.py @@ -16,7 +16,7 @@ class UpdatedVariable(BaseModel): name: str selector: Sequence[str] value_type: SegmentType - new_value: Any + new_value: Any = None _T = TypeVar("_T", bound=MutableMapping[str, Any]) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 2ed73ffec1..49ff28d191 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -99,17 +99,17 @@ def _check_version_compatibility(imported_version: str) -> ImportStatus: class PendingData(BaseModel): import_mode: str yaml_content: str - name: str | None - description: str | None - icon_type: str | None - icon: str | None - icon_background: str | None - app_id: str | None + name: str | None = None + description: str | None = None + icon_type: str | None = None + icon: str | None = None + icon_background: str | None = None + app_id: str | None = None class CheckDependenciesPendingData(BaseModel): dependencies: list[PluginDependency] - app_id: str | None + app_id: str | None = None class AppDslService: From 08dd3f7b5079fe9171351ea79054302c915e42d1 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 10 Sep 2025 01:54:26 +0800 Subject: [PATCH 18/18] Fix basedpyright type errors (#25435) Signed-off-by: -LAN- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/commands.py | 18 +++- api/constants/__init__.py | 12 +-- api/contexts/__init__.py | 1 - api/controllers/console/__init__.py | 100 ++++++++++-------- api/controllers/console/apikey.py | 13 +-- api/controllers/console/app/app.py | 30 ++++-- api/controllers/console/app/audio.py | 4 +- api/controllers/console/app/completion.py | 28 ++--- api/controllers/console/app/conversation.py | 6 +- api/controllers/console/app/message.py | 13 ++- api/controllers/console/app/site.py | 6 +- api/controllers/console/app/statistic.py | 12 +-- .../console/app/workflow_statistic.py | 6 +- api/controllers/console/auth/oauth.py | 5 +- api/controllers/console/explore/completion.py | 11 +- .../console/explore/conversation.py | 13 ++- .../console/explore/installed_app.py | 13 ++- api/controllers/console/explore/message.py | 11 +- .../console/explore/recommended_app.py | 8 +- .../console/explore/saved_message.py | 9 +- api/controllers/console/files.py | 3 + api/controllers/console/version.py | 6 +- api/controllers/console/workspace/account.py | 32 ++++++ api/controllers/console/workspace/members.py | 59 +++++++++-- .../console/workspace/model_providers.py | 37 +++++++ .../console/workspace/workspace.py | 24 ++++- api/controllers/files/__init__.py | 2 +- api/controllers/inner_api/__init__.py | 6 +- api/controllers/inner_api/plugin/plugin.py | 30 +++--- api/controllers/inner_api/plugin/wraps.py | 10 +- api/controllers/mcp/__init__.py | 2 +- api/controllers/service_api/__init__.py | 26 ++++- .../service_api/app/conversation.py | 3 +- .../service_api/dataset/document.py | 6 ++ api/controllers/service_api/wraps.py | 4 +- api/controllers/web/__init__.py | 28 ++--- api/core/__init__.py | 1 - api/core/agent/cot_agent_runner.py | 2 + api/core/agent/fc_agent_runner.py | 1 + .../sensitive_word_avoidance/manager.py | 11 +- .../prompt_template/manager.py | 10 +- .../generate_response_converter.py | 12 +-- .../advanced_chat/generate_task_pipeline.py | 24 ++--- .../app/apps/agent_chat/app_config_manager.py | 34 +++--- .../agent_chat/generate_response_converter.py | 11 +- api/core/app/apps/base_app_queue_manager.py | 1 + .../apps/chat/generate_response_converter.py | 11 +- api/core/app/apps/completion/app_generator.py | 2 + .../completion/generate_response_converter.py | 13 ++- .../workflow/generate_response_converter.py | 10 +- .../apps/workflow/generate_task_pipeline.py | 10 +- api/core/app/entities/app_invoke_entities.py | 6 +- api/core/app/entities/task_entities.py | 7 -- .../annotation_reply/annotation_reply.py | 3 + .../app/features/rate_limiting/__init__.py | 2 + .../app/features/rate_limiting/rate_limit.py | 2 +- .../based_generate_task_pipeline.py | 22 ++-- .../easy_ui_based_generate_task_pipeline.py | 22 ++-- .../base/tts/app_generator_tts_publisher.py | 6 +- api/core/entities/provider_configuration.py | 8 +- api/core/file/file_manager.py | 6 +- api/core/file/models.py | 8 ++ api/core/helper/ssrf_proxy.py | 14 +-- api/core/indexing_runner.py | 7 +- api/core/llm_generator/llm_generator.py | 12 ++- .../output_parser/structured_output.py | 14 ++- api/core/mcp/client/sse_client.py | 8 +- api/core/mcp/server/streamable_http.py | 28 ++--- api/core/mcp/session/base_session.py | 12 +-- .../__base/large_language_model.py | 2 +- api/core/plugin/entities/parameters.py | 5 +- api/core/plugin/utils/chunk_merger.py | 4 +- api/core/prompt/simple_prompt_transform.py | 32 ++++-- .../datasource/vdb/qdrant/qdrant_vector.py | 35 ++++-- ...lery_workflow_node_execution_repository.py | 4 +- api/core/variables/segment_group.py | 2 +- api/core/variables/segments.py | 24 ++--- api/core/workflow/errors.py | 4 +- api/core/workflow/nodes/list_operator/node.py | 4 +- api/core/workflow/nodes/llm/node.py | 3 +- api/factories/file_factory.py | 4 +- api/fields/_value_type_serializer.py | 5 +- api/libs/external_api.py | 14 ++- api/libs/helper.py | 7 -- api/pyrightconfig.json | 54 +++++++--- api/services/account_service.py | 4 +- api/services/annotation_service.py | 54 ++++++---- .../clear_free_plan_tenant_expired_logs.py | 1 + api/services/dataset_service.py | 66 ++---------- api/services/external_knowledge_service.py | 2 +- api/services/file_service.py | 4 +- api/services/model_load_balancing_service.py | 17 +-- api/services/plugin/plugin_migration.py | 1 + .../tools/builtin_tools_manage_service.py | 10 +- api/services/workflow/workflow_converter.py | 16 ++- api/services/workflow_service.py | 4 +- api/services/workspace_service.py | 2 +- .../services/test_account_service.py | 4 +- .../workflow/test_workflow_converter.py | 3 +- .../services/test_account_service.py | 16 +-- 100 files changed, 847 insertions(+), 497 deletions(-) diff --git a/api/commands.py b/api/commands.py index 9b13cc2e1a..2bef83b2a7 100644 --- a/api/commands.py +++ b/api/commands.py @@ -511,7 +511,7 @@ def add_qdrant_index(field: str): from qdrant_client.http.exceptions import UnexpectedResponse from qdrant_client.http.models import PayloadSchemaType - from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig + from core.rag.datasource.vdb.qdrant.qdrant_vector import PathQdrantParams, QdrantConfig for binding in bindings: if dify_config.QDRANT_URL is None: @@ -525,7 +525,21 @@ def add_qdrant_index(field: str): prefer_grpc=dify_config.QDRANT_GRPC_ENABLED, ) try: - client = qdrant_client.QdrantClient(**qdrant_config.to_qdrant_params()) + params = qdrant_config.to_qdrant_params() + # Check the type before using + if isinstance(params, PathQdrantParams): + # PathQdrantParams case + client = qdrant_client.QdrantClient(path=params.path) + else: + # UrlQdrantParams case - params is UrlQdrantParams + client = qdrant_client.QdrantClient( + url=params.url, + api_key=params.api_key, + timeout=int(params.timeout), + verify=params.verify, + grpc_port=params.grpc_port, + prefer_grpc=params.prefer_grpc, + ) # create payload index client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD) create_count += 1 diff --git a/api/constants/__init__.py b/api/constants/__init__.py index c98f4d55c8..fe8f4f8785 100644 --- a/api/constants/__init__.py +++ b/api/constants/__init__.py @@ -16,14 +16,14 @@ AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "amr", "mpga"] AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS]) +_doc_extensions: list[str] if dify_config.ETL_TYPE == "Unstructured": - DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"] - DOCUMENT_EXTENSIONS.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub")) + _doc_extensions = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"] + _doc_extensions.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub")) if dify_config.UNSTRUCTURED_API_URL: - DOCUMENT_EXTENSIONS.append("ppt") - DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS]) + _doc_extensions.append("ppt") else: - DOCUMENT_EXTENSIONS = [ + _doc_extensions = [ "txt", "markdown", "md", @@ -38,4 +38,4 @@ else: "vtt", "properties", ] - DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS]) +DOCUMENT_EXTENSIONS = _doc_extensions + [ext.upper() for ext in _doc_extensions] diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py index ae41a2c03a..a07e6a08a6 100644 --- a/api/contexts/__init__.py +++ b/api/contexts/__init__.py @@ -8,7 +8,6 @@ if TYPE_CHECKING: from core.model_runtime.entities.model_entities import AIModelEntity from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.tools.plugin_tool.provider import PluginToolProviderController - from core.workflow.entities.variable_pool import VariablePool """ diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 5ad7645969..9a8e840554 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -43,56 +43,64 @@ api.add_resource(AppImportConfirmApi, "/apps/imports//confirm" api.add_resource(AppImportCheckDependenciesApi, "/apps/imports//check-dependencies") # Import other controllers -from . import admin, apikey, extension, feature, ping, setup, version +from . import admin, apikey, extension, feature, ping, setup, version # pyright: ignore[reportUnusedImport] # Import app controllers from .app import ( - advanced_prompt_template, - agent, - annotation, - app, - audio, - completion, - conversation, - conversation_variables, - generator, - mcp_server, - message, - model_config, - ops_trace, - site, - statistic, - workflow, - workflow_app_log, - workflow_draft_variable, - workflow_run, - workflow_statistic, + advanced_prompt_template, # pyright: ignore[reportUnusedImport] + agent, # pyright: ignore[reportUnusedImport] + annotation, # pyright: ignore[reportUnusedImport] + app, # pyright: ignore[reportUnusedImport] + audio, # pyright: ignore[reportUnusedImport] + completion, # pyright: ignore[reportUnusedImport] + conversation, # pyright: ignore[reportUnusedImport] + conversation_variables, # pyright: ignore[reportUnusedImport] + generator, # pyright: ignore[reportUnusedImport] + mcp_server, # pyright: ignore[reportUnusedImport] + message, # pyright: ignore[reportUnusedImport] + model_config, # pyright: ignore[reportUnusedImport] + ops_trace, # pyright: ignore[reportUnusedImport] + site, # pyright: ignore[reportUnusedImport] + statistic, # pyright: ignore[reportUnusedImport] + workflow, # pyright: ignore[reportUnusedImport] + workflow_app_log, # pyright: ignore[reportUnusedImport] + workflow_draft_variable, # pyright: ignore[reportUnusedImport] + workflow_run, # pyright: ignore[reportUnusedImport] + workflow_statistic, # pyright: ignore[reportUnusedImport] ) # Import auth controllers -from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth, oauth_server +from .auth import ( + activate, # pyright: ignore[reportUnusedImport] + data_source_bearer_auth, # pyright: ignore[reportUnusedImport] + data_source_oauth, # pyright: ignore[reportUnusedImport] + forgot_password, # pyright: ignore[reportUnusedImport] + login, # pyright: ignore[reportUnusedImport] + oauth, # pyright: ignore[reportUnusedImport] + oauth_server, # pyright: ignore[reportUnusedImport] +) # Import billing controllers -from .billing import billing, compliance +from .billing import billing, compliance # pyright: ignore[reportUnusedImport] # Import datasets controllers from .datasets import ( - data_source, - datasets, - datasets_document, - datasets_segments, - external, - hit_testing, - metadata, - website, + data_source, # pyright: ignore[reportUnusedImport] + datasets, # pyright: ignore[reportUnusedImport] + datasets_document, # pyright: ignore[reportUnusedImport] + datasets_segments, # pyright: ignore[reportUnusedImport] + external, # pyright: ignore[reportUnusedImport] + hit_testing, # pyright: ignore[reportUnusedImport] + metadata, # pyright: ignore[reportUnusedImport] + website, # pyright: ignore[reportUnusedImport] ) # Import explore controllers from .explore import ( - installed_app, - parameter, - recommended_app, - saved_message, + installed_app, # pyright: ignore[reportUnusedImport] + parameter, # pyright: ignore[reportUnusedImport] + recommended_app, # pyright: ignore[reportUnusedImport] + saved_message, # pyright: ignore[reportUnusedImport] ) # Explore Audio @@ -167,18 +175,18 @@ api.add_resource( ) # Import tag controllers -from .tag import tags +from .tag import tags # pyright: ignore[reportUnusedImport] # Import workspace controllers from .workspace import ( - account, - agent_providers, - endpoint, - load_balancing_config, - members, - model_providers, - models, - plugin, - tool_providers, - workspace, + account, # pyright: ignore[reportUnusedImport] + agent_providers, # pyright: ignore[reportUnusedImport] + endpoint, # pyright: ignore[reportUnusedImport] + load_balancing_config, # pyright: ignore[reportUnusedImport] + members, # pyright: ignore[reportUnusedImport] + model_providers, # pyright: ignore[reportUnusedImport] + models, # pyright: ignore[reportUnusedImport] + plugin, # pyright: ignore[reportUnusedImport] + tool_providers, # pyright: ignore[reportUnusedImport] + workspace, # pyright: ignore[reportUnusedImport] ) diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index cfd5f73ade..58a1d437d1 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -1,8 +1,9 @@ -from typing import Any, Optional +from typing import Optional import flask_restx from flask_login import current_user from flask_restx import Resource, fields, marshal_with +from flask_restx._http import HTTPStatus from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden @@ -40,7 +41,7 @@ def _get_resource(resource_id, tenant_id, resource_model): ).scalar_one_or_none() if resource is None: - flask_restx.abort(404, message=f"{resource_model.__name__} not found.") + flask_restx.abort(HTTPStatus.NOT_FOUND, message=f"{resource_model.__name__} not found.") return resource @@ -49,7 +50,7 @@ class BaseApiKeyListResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] resource_type: str | None = None - resource_model: Optional[Any] = None + resource_model: Optional[type] = None resource_id_field: str | None = None token_prefix: str | None = None max_keys = 10 @@ -82,7 +83,7 @@ class BaseApiKeyListResource(Resource): if current_key_count >= self.max_keys: flask_restx.abort( - 400, + HTTPStatus.BAD_REQUEST, message=f"Cannot create more than {self.max_keys} API keys for this resource type.", custom="max_keys_exceeded", ) @@ -102,7 +103,7 @@ class BaseApiKeyResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] resource_type: str | None = None - resource_model: Optional[Any] = None + resource_model: Optional[type] = None resource_id_field: str | None = None def delete(self, resource_id, api_key_id): @@ -126,7 +127,7 @@ class BaseApiKeyResource(Resource): ) if key is None: - flask_restx.abort(404, message="API key not found") + flask_restx.abort(HTTPStatus.NOT_FOUND, message="API key not found") db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() db.session.commit() diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 10753d2f95..1db9d2e764 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -115,6 +115,10 @@ class AppListApi(Resource): raise BadRequest("mode is required") app_service = AppService() + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") + if current_user.current_tenant_id is None: + raise ValueError("current_user.current_tenant_id cannot be None") app = app_service.create_app(current_user.current_tenant_id, args, current_user) return app, 201 @@ -161,14 +165,26 @@ class AppApi(Resource): args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app(app_model, args) + # Construct ArgsDict from parsed arguments + from services.app_service import AppService as AppServiceType + + args_dict: AppServiceType.ArgsDict = { + "name": args["name"], + "description": args.get("description", ""), + "icon_type": args.get("icon_type", ""), + "icon": args.get("icon", ""), + "icon_background": args.get("icon_background", ""), + "use_icon_as_answer_icon": args.get("use_icon_as_answer_icon", False), + "max_active_requests": args.get("max_active_requests", 0), + } + app_model = app_service.update_app(app_model, args_dict) return app_model + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def delete(self, app_model): """Delete app""" # The role of the current user in the ta table must be admin, owner, or editor @@ -224,10 +240,10 @@ class AppCopyApi(Resource): class AppExportApi(Resource): + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): """Export app""" # The role of the current user in the ta table must be admin, owner, or editor @@ -263,7 +279,7 @@ class AppNameApi(Resource): args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_name(app_model, args.get("name")) + app_model = app_service.update_app_name(app_model, args["name"]) return app_model @@ -285,7 +301,7 @@ class AppIconApi(Resource): args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_icon(app_model, args.get("icon"), args.get("icon_background")) + app_model = app_service.update_app_icon(app_model, args.get("icon") or "", args.get("icon_background") or "") return app_model @@ -306,7 +322,7 @@ class AppSiteStatus(Resource): args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_site_status(app_model, args.get("enable_site")) + app_model = app_service.update_app_site_status(app_model, args["enable_site"]) return app_model @@ -327,7 +343,7 @@ class AppApiStatus(Resource): args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_api_status(app_model, args.get("enable_api")) + app_model = app_service.update_app_api_status(app_model, args["enable_api"]) return app_model diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index aaf5c3dfaa..447bcb37c2 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -77,10 +77,10 @@ class ChatMessageAudioApi(Resource): class ChatMessageTextApi(Resource): + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def post(self, app_model: App): try: parser = reqparse.RequestParser() @@ -125,10 +125,10 @@ class ChatMessageTextApi(Resource): class TextModesApi(Resource): + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): try: parser = reqparse.RequestParser() diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 701ebb0b4a..2083c15a9b 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -1,6 +1,5 @@ import logging -import flask_login from flask import request from flask_restx import Resource, reqparse from werkzeug.exceptions import InternalServerError, NotFound @@ -29,7 +28,8 @@ from core.helper.trace_id_helper import get_external_trace_id from core.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value -from libs.login import login_required +from libs.login import current_user, login_required +from models import Account from models.model import AppMode from services.app_generate_service import AppGenerateService from services.errors.llm import InvokeRateLimitError @@ -56,11 +56,11 @@ class CompletionMessageApi(Resource): streaming = args["response_mode"] != "blocking" args["auto_generate_name"] = False - account = flask_login.current_user - try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account or EndUser instance") response = AppGenerateService.generate( - app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming ) return helper.compact_generate_response(response) @@ -92,9 +92,9 @@ class CompletionMessageStopApi(Resource): @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) def post(self, app_model, task_id): - account = flask_login.current_user - - AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") + AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) return {"result": "success"}, 200 @@ -123,11 +123,11 @@ class ChatMessageApi(Resource): if external_trace_id: args["external_trace_id"] = external_trace_id - account = flask_login.current_user - try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account or EndUser instance") response = AppGenerateService.generate( - app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming ) return helper.compact_generate_response(response) @@ -161,9 +161,9 @@ class ChatMessageStopApi(Resource): @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) def post(self, app_model, task_id): - account = flask_login.current_user - - AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") + AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) return {"result": "success"}, 200 diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index bc825effad..2f2cd66aaa 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -22,7 +22,7 @@ from fields.conversation_fields import ( from libs.datetime_utils import naive_utc_now from libs.helper import DatetimeString from libs.login import login_required -from models import Conversation, EndUser, Message, MessageAnnotation +from models import Account, Conversation, EndUser, Message, MessageAnnotation from models.model import AppMode from services.conversation_service import ConversationService from services.errors.conversation import ConversationNotExistsError @@ -124,6 +124,8 @@ class CompletionConversationDetailApi(Resource): conversation_id = str(conversation_id) try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") ConversationService.delete(app_model, conversation_id, current_user) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -282,6 +284,8 @@ class ChatConversationDetailApi(Resource): conversation_id = str(conversation_id) try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") ConversationService.delete(app_model, conversation_id, current_user) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index f0605a37f9..272f360c06 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -1,6 +1,5 @@ import logging -from flask_login import current_user from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx.inputs import int_range from sqlalchemy import exists, select @@ -27,7 +26,8 @@ from extensions.ext_database import db from fields.conversation_fields import annotation_fields, message_detail_fields from libs.helper import uuid_value from libs.infinite_scroll_pagination import InfiniteScrollPagination -from libs.login import login_required +from libs.login import current_user, login_required +from models.account import Account from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback from services.annotation_service import AppAnnotationService from services.errors.conversation import ConversationNotExistsError @@ -118,11 +118,14 @@ class ChatMessageListApi(Resource): class MessageFeedbackApi(Resource): + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def post(self, app_model): + if current_user is None: + raise Forbidden() + parser = reqparse.RequestParser() parser.add_argument("message_id", required=True, type=uuid_value, location="json") parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") @@ -167,6 +170,8 @@ class MessageAnnotationApi(Resource): @get_app_model @marshal_with(annotation_fields) def post(self, app_model): + if not isinstance(current_user, Account): + raise Forbidden() if not current_user.is_editor: raise Forbidden() @@ -182,10 +187,10 @@ class MessageAnnotationApi(Resource): class MessageAnnotationCountApi(Resource): + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count() diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 778ce92da6..871efd989c 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -10,7 +10,7 @@ from extensions.ext_database import db from fields.app_fields import app_site_fields from libs.datetime_utils import naive_utc_now from libs.login import login_required -from models import Site +from models import Account, Site def parse_app_site_args(): @@ -75,6 +75,8 @@ class AppSite(Resource): if value is not None: setattr(site, attr_name, value) + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") site.updated_by = current_user.id site.updated_at = naive_utc_now() db.session.commit() @@ -99,6 +101,8 @@ class AppSiteAccessTokenReset(Resource): raise NotFound site.code = Site.generate_code(16) + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") site.updated_by = current_user.id site.updated_at = naive_utc_now() db.session.commit() diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index 27e405af38..2116732c73 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -18,10 +18,10 @@ from models import AppMode, Message class DailyMessageStatistic(Resource): + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user @@ -75,10 +75,10 @@ WHERE class DailyConversationStatistic(Resource): + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user @@ -127,10 +127,10 @@ class DailyConversationStatistic(Resource): class DailyTerminalsStatistic(Resource): + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user @@ -184,10 +184,10 @@ WHERE class DailyTokenCostStatistic(Resource): + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user @@ -320,10 +320,10 @@ ORDER BY class UserSatisfactionRateStatistic(Resource): + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user @@ -443,10 +443,10 @@ WHERE class TokensPerSecondStatistic(Resource): + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index 7cef175c14..da7216086e 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -18,10 +18,10 @@ from models.model import AppMode class WorkflowDailyRunsStatistic(Resource): + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user @@ -80,10 +80,10 @@ WHERE class WorkflowDailyTerminalsStatistic(Resource): + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user @@ -142,10 +142,10 @@ WHERE class WorkflowDailyTokenCostStatistic(Resource): + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 332a98c474..06151ee39b 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -77,6 +77,9 @@ class OAuthCallback(Resource): if state: invite_token = state + if not code: + return {"error": "Authorization code is required"}, 400 + try: token = oauth_provider.get_access_token(code) user_info = oauth_provider.get_user_info(token) @@ -86,7 +89,7 @@ class OAuthCallback(Resource): return {"error": "OAuth process failed"}, 400 if invite_token and RegisterService.is_valid_invite_token(invite_token): - invitation = RegisterService._get_invitation_by_token(token=invite_token) + invitation = RegisterService.get_invitation_by_token(token=invite_token) if invitation: invitation_email = invitation.get("email", None) if invitation_email != user_info.email: diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index cc46f54ea3..a99708b7cd 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -1,6 +1,5 @@ import logging -from flask_login import current_user from flask_restx import reqparse from werkzeug.exceptions import InternalServerError, NotFound @@ -28,6 +27,8 @@ from extensions.ext_database import db from libs import helper from libs.datetime_utils import naive_utc_now from libs.helper import uuid_value +from libs.login import current_user +from models import Account from models.model import AppMode from services.app_generate_service import AppGenerateService from services.errors.llm import InvokeRateLimitError @@ -57,6 +58,8 @@ class CompletionApi(InstalledAppResource): db.session.commit() try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") response = AppGenerateService.generate( app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming ) @@ -90,6 +93,8 @@ class CompletionStopApi(InstalledAppResource): if app_model.mode != "completion": raise NotCompletionAppError() + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) return {"result": "success"}, 200 @@ -117,6 +122,8 @@ class ChatApi(InstalledAppResource): db.session.commit() try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") response = AppGenerateService.generate( app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True ) @@ -153,6 +160,8 @@ class ChatStopApi(InstalledAppResource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) return {"result": "success"}, 200 diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 43ad3ecfbd..1aef9c544d 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -1,4 +1,3 @@ -from flask_login import current_user from flask_restx import marshal_with, reqparse from flask_restx.inputs import int_range from sqlalchemy.orm import Session @@ -10,6 +9,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from libs.helper import uuid_value +from libs.login import current_user +from models import Account from models.model import AppMode from services.conversation_service import ConversationService from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError @@ -35,6 +36,8 @@ class ConversationListApi(InstalledAppResource): pinned = args["pinned"] == "true" try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") with Session(db.engine) as session: return WebConversationService.pagination_by_last_id( session=session, @@ -58,6 +61,8 @@ class ConversationApi(InstalledAppResource): conversation_id = str(c_id) try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") ConversationService.delete(app_model, conversation_id, current_user) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -81,6 +86,8 @@ class ConversationRenameApi(InstalledAppResource): args = parser.parse_args() try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") return ConversationService.rename( app_model, conversation_id, current_user, args["name"], args["auto_generate"] ) @@ -98,6 +105,8 @@ class ConversationPinApi(InstalledAppResource): conversation_id = str(c_id) try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") WebConversationService.pin(app_model, conversation_id, current_user) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -113,6 +122,8 @@ class ConversationUnPinApi(InstalledAppResource): raise NotChatAppError() conversation_id = str(c_id) + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") WebConversationService.unpin(app_model, conversation_id, current_user) return {"result": "success"} diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 3ccedd654b..22aa753d92 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -2,7 +2,6 @@ import logging from typing import Any from flask import request -from flask_login import current_user from flask_restx import Resource, inputs, marshal_with, reqparse from sqlalchemy import and_ from werkzeug.exceptions import BadRequest, Forbidden, NotFound @@ -13,8 +12,8 @@ from controllers.console.wraps import account_initialization_required, cloud_edi from extensions.ext_database import db from fields.installed_app_fields import installed_app_list_fields from libs.datetime_utils import naive_utc_now -from libs.login import login_required -from models import App, InstalledApp, RecommendedApp +from libs.login import current_user, login_required +from models import Account, App, InstalledApp, RecommendedApp from services.account_service import TenantService from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService @@ -29,6 +28,8 @@ class InstalledAppsListApi(Resource): @marshal_with(installed_app_list_fields) def get(self): app_id = request.args.get("app_id", default=None, type=str) + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") current_tenant_id = current_user.current_tenant_id if app_id: @@ -40,6 +41,8 @@ class InstalledAppsListApi(Resource): else: installed_apps = db.session.query(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id).all() + if current_user.current_tenant is None: + raise ValueError("current_user.current_tenant must not be None") current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) installed_app_list: list[dict[str, Any]] = [ { @@ -115,6 +118,8 @@ class InstalledAppsListApi(Resource): if recommended_app is None: raise NotFound("App not found") + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") current_tenant_id = current_user.current_tenant_id app = db.session.query(App).where(App.id == args["app_id"]).first() @@ -154,6 +159,8 @@ class InstalledAppApi(InstalledAppResource): """ def delete(self, installed_app): + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") if installed_app.app_owner_tenant_id == current_user.current_tenant_id: raise BadRequest("You can't uninstall an app owned by the current tenant") diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 608bc6d007..c46c1c1f4f 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -1,6 +1,5 @@ import logging -from flask_login import current_user from flask_restx import marshal_with, reqparse from flask_restx.inputs import int_range from werkzeug.exceptions import InternalServerError, NotFound @@ -24,6 +23,8 @@ from core.model_runtime.errors.invoke import InvokeError from fields.message_fields import message_infinite_scroll_pagination_fields from libs import helper from libs.helper import uuid_value +from libs.login import current_user +from models import Account from models.model import AppMode from services.app_generate_service import AppGenerateService from services.errors.app import MoreLikeThisDisabledError @@ -54,6 +55,8 @@ class MessageListApi(InstalledAppResource): args = parser.parse_args() try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") return MessageService.pagination_by_first_id( app_model, current_user, args["conversation_id"], args["first_id"], args["limit"] ) @@ -75,6 +78,8 @@ class MessageFeedbackApi(InstalledAppResource): args = parser.parse_args() try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") MessageService.create_feedback( app_model=app_model, message_id=message_id, @@ -105,6 +110,8 @@ class MessageMoreLikeThisApi(InstalledAppResource): streaming = args["response_mode"] == "streaming" try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") response = AppGenerateService.generate_more_like_this( app_model=app_model, user=current_user, @@ -142,6 +149,8 @@ class MessageSuggestedQuestionApi(InstalledAppResource): message_id = str(message_id) try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") questions = MessageService.get_suggested_questions_after_answer( app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE ) diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 62f9350b71..974222ddf7 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -1,11 +1,10 @@ -from flask_login import current_user from flask_restx import Resource, fields, marshal_with, reqparse from constants.languages import languages from controllers.console import api from controllers.console.wraps import account_initialization_required from libs.helper import AppIconUrlField -from libs.login import login_required +from libs.login import current_user, login_required from services.recommended_app_service import RecommendedAppService app_fields = { @@ -46,8 +45,9 @@ class RecommendedAppListApi(Resource): parser.add_argument("language", type=str, location="args") args = parser.parse_args() - if args.get("language") and args.get("language") in languages: - language_prefix = args.get("language") + language = args.get("language") + if language and language in languages: + language_prefix = language elif current_user and current_user.interface_language: language_prefix = current_user.interface_language else: diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index 5353dbcad5..6f05f898f9 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -1,4 +1,3 @@ -from flask_login import current_user from flask_restx import fields, marshal_with, reqparse from flask_restx.inputs import int_range from werkzeug.exceptions import NotFound @@ -8,6 +7,8 @@ from controllers.console.explore.error import NotCompletionAppError from controllers.console.explore.wraps import InstalledAppResource from fields.conversation_fields import message_file_fields from libs.helper import TimestampField, uuid_value +from libs.login import current_user +from models import Account from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService @@ -42,6 +43,8 @@ class SavedMessageListApi(InstalledAppResource): parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"]) def post(self, installed_app): @@ -54,6 +57,8 @@ class SavedMessageListApi(InstalledAppResource): args = parser.parse_args() try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") SavedMessageService.save(app_model, current_user, args["message_id"]) except MessageNotExistsError: raise NotFound("Message Not Exists.") @@ -70,6 +75,8 @@ class SavedMessageApi(InstalledAppResource): if app_model.mode != "completion": raise NotCompletionAppError() + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") SavedMessageService.delete(app_model, current_user, message_id) return {"result": "success"}, 204 diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py index 101a49a32e..5d11dec523 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -22,6 +22,7 @@ from controllers.console.wraps import ( ) from fields.file_fields import file_fields, upload_config_fields from libs.login import login_required +from models import Account from services.file_service import FileService PREVIEW_WORDS_LIMIT = 3000 @@ -68,6 +69,8 @@ class FileApi(Resource): source = None try: + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") upload_file = FileService.upload_file( filename=file.filename, content=file.read(), diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index 95515c38f9..8409e7d1ab 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -34,14 +34,14 @@ class VersionApi(Resource): return result try: - response = requests.get(check_update_url, {"current_version": args.get("current_version")}, timeout=(3, 10)) + response = requests.get(check_update_url, {"current_version": args["current_version"]}, timeout=(3, 10)) except Exception as error: logger.warning("Check update version error: %s.", str(error)) - result["version"] = args.get("current_version") + result["version"] = args["current_version"] return result content = json.loads(response.content) - if _has_new_version(latest_version=content["version"], current_version=f"{args.get('current_version')}"): + if _has_new_version(latest_version=content["version"], current_version=f"{args['current_version']}"): result["version"] = content["version"] result["release_date"] = content["releaseDate"] result["release_notes"] = content["releaseNotes"] diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 5b2828dbab..bd078729c4 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -49,6 +49,8 @@ class AccountInitApi(Resource): @setup_required @login_required def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") account = current_user if account.status == "active": @@ -102,6 +104,8 @@ class AccountProfileApi(Resource): @marshal_with(account_fields) @enterprise_license_required def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") return current_user @@ -111,6 +115,8 @@ class AccountNameApi(Resource): @account_initialization_required @marshal_with(account_fields) def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, location="json") args = parser.parse_args() @@ -130,6 +136,8 @@ class AccountAvatarApi(Resource): @account_initialization_required @marshal_with(account_fields) def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("avatar", type=str, required=True, location="json") args = parser.parse_args() @@ -145,6 +153,8 @@ class AccountInterfaceLanguageApi(Resource): @account_initialization_required @marshal_with(account_fields) def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("interface_language", type=supported_language, required=True, location="json") args = parser.parse_args() @@ -160,6 +170,8 @@ class AccountInterfaceThemeApi(Resource): @account_initialization_required @marshal_with(account_fields) def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json") args = parser.parse_args() @@ -175,6 +187,8 @@ class AccountTimezoneApi(Resource): @account_initialization_required @marshal_with(account_fields) def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("timezone", type=str, required=True, location="json") args = parser.parse_args() @@ -194,6 +208,8 @@ class AccountPasswordApi(Resource): @account_initialization_required @marshal_with(account_fields) def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("password", type=str, required=False, location="json") parser.add_argument("new_password", type=str, required=True, location="json") @@ -228,6 +244,8 @@ class AccountIntegrateApi(Resource): @account_initialization_required @marshal_with(integrate_list_fields) def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") account = current_user account_integrates = db.session.query(AccountIntegrate).where(AccountIntegrate.account_id == account.id).all() @@ -268,6 +286,8 @@ class AccountDeleteVerifyApi(Resource): @login_required @account_initialization_required def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") account = current_user token, code = AccountService.generate_account_deletion_verification_code(account) @@ -281,6 +301,8 @@ class AccountDeleteApi(Resource): @login_required @account_initialization_required def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") account = current_user parser = reqparse.RequestParser() @@ -321,6 +343,8 @@ class EducationVerifyApi(Resource): @cloud_edition_billing_enabled @marshal_with(verify_fields) def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") account = current_user return BillingService.EducationIdentity.verify(account.id, account.email) @@ -340,6 +364,8 @@ class EducationApi(Resource): @only_edition_cloud @cloud_edition_billing_enabled def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") account = current_user parser = reqparse.RequestParser() @@ -357,6 +383,8 @@ class EducationApi(Resource): @cloud_edition_billing_enabled @marshal_with(status_fields) def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") account = current_user res = BillingService.EducationIdentity.status(account.id) @@ -421,6 +449,8 @@ class ChangeEmailSendEmailApi(Resource): raise InvalidTokenError() user_email = reset_data.get("email", "") + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") if user_email != current_user.email: raise InvalidEmailError() else: @@ -501,6 +531,8 @@ class ChangeEmailResetApi(Resource): AccountService.revoke_change_email_token(args["token"]) old_email = reset_data.get("old_email", "") + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") if current_user.email != old_email: raise AccountNotFound() diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index cf2a10f453..77f0c9a735 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -1,8 +1,8 @@ from urllib import parse -from flask import request +from flask import abort, request from flask_login import current_user -from flask_restx import Resource, abort, marshal_with, reqparse +from flask_restx import Resource, marshal_with, reqparse import services from configs import dify_config @@ -41,6 +41,10 @@ class MemberListApi(Resource): @account_initialization_required @marshal_with(account_with_role_list_fields) def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant: + raise ValueError("No current tenant") members = TenantService.get_tenant_members(current_user.current_tenant) return {"result": "success", "accounts": members}, 200 @@ -65,7 +69,11 @@ class MemberInviteEmailApi(Resource): if not TenantAccountRole.is_non_owner_role(invitee_role): return {"code": "invalid-role", "message": "Invalid role"}, 400 + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") inviter = current_user + if not inviter.current_tenant: + raise ValueError("No current tenant") invitation_results = [] console_web_url = dify_config.CONSOLE_WEB_URL @@ -76,6 +84,8 @@ class MemberInviteEmailApi(Resource): for invitee_email in invitee_emails: try: + if not inviter.current_tenant: + raise ValueError("No current tenant") token = RegisterService.invite_new_member( inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter ) @@ -97,7 +107,7 @@ class MemberInviteEmailApi(Resource): return { "result": "success", "invitation_results": invitation_results, - "tenant_id": str(current_user.current_tenant.id), + "tenant_id": str(inviter.current_tenant.id) if inviter.current_tenant else "", }, 201 @@ -108,6 +118,10 @@ class MemberCancelInviteApi(Resource): @login_required @account_initialization_required def delete(self, member_id): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant: + raise ValueError("No current tenant") member = db.session.query(Account).where(Account.id == str(member_id)).first() if member is None: abort(404) @@ -123,7 +137,10 @@ class MemberCancelInviteApi(Resource): except Exception as e: raise ValueError(str(e)) - return {"result": "success", "tenant_id": str(current_user.current_tenant.id)}, 200 + return { + "result": "success", + "tenant_id": str(current_user.current_tenant.id) if current_user.current_tenant else "", + }, 200 class MemberUpdateRoleApi(Resource): @@ -141,6 +158,10 @@ class MemberUpdateRoleApi(Resource): if not TenantAccountRole.is_valid_role(new_role): return {"code": "invalid-role", "message": "Invalid role"}, 400 + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant: + raise ValueError("No current tenant") member = db.session.get(Account, str(member_id)) if not member: abort(404) @@ -164,6 +185,10 @@ class DatasetOperatorMemberListApi(Resource): @account_initialization_required @marshal_with(account_with_role_list_fields) def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant: + raise ValueError("No current tenant") members = TenantService.get_dataset_operator_members(current_user.current_tenant) return {"result": "success", "accounts": members}, 200 @@ -184,6 +209,10 @@ class SendOwnerTransferEmailApi(Resource): raise EmailSendIpLimitError() # check if the current user is the owner of the workspace + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant: + raise ValueError("No current tenant") if not TenantService.is_owner(current_user, current_user.current_tenant): raise NotOwnerError() @@ -198,7 +227,7 @@ class SendOwnerTransferEmailApi(Resource): account=current_user, email=email, language=language, - workspace_name=current_user.current_tenant.name, + workspace_name=current_user.current_tenant.name if current_user.current_tenant else "", ) return {"result": "success", "data": token} @@ -215,6 +244,10 @@ class OwnerTransferCheckApi(Resource): parser.add_argument("token", type=str, required=True, nullable=False, location="json") args = parser.parse_args() # check if the current user is the owner of the workspace + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant: + raise ValueError("No current tenant") if not TenantService.is_owner(current_user, current_user.current_tenant): raise NotOwnerError() @@ -256,6 +289,10 @@ class OwnerTransfer(Resource): args = parser.parse_args() # check if the current user is the owner of the workspace + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant: + raise ValueError("No current tenant") if not TenantService.is_owner(current_user, current_user.current_tenant): raise NotOwnerError() @@ -274,9 +311,11 @@ class OwnerTransfer(Resource): member = db.session.get(Account, str(member_id)) if not member: abort(404) - else: - member_account = member - if not TenantService.is_member(member_account, current_user.current_tenant): + return # Never reached, but helps type checker + + if not current_user.current_tenant: + raise ValueError("No current tenant") + if not TenantService.is_member(member, current_user.current_tenant): raise MemberNotInTenantError() try: @@ -286,13 +325,13 @@ class OwnerTransfer(Resource): AccountService.send_new_owner_transfer_notify_email( account=member, email=member.email, - workspace_name=current_user.current_tenant.name, + workspace_name=current_user.current_tenant.name if current_user.current_tenant else "", ) AccountService.send_old_owner_transfer_notify_email( account=current_user, email=current_user.email, - workspace_name=current_user.current_tenant.name, + workspace_name=current_user.current_tenant.name if current_user.current_tenant else "", new_owner_email=member.email, ) diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index bfcc9a7f0a..0c9db660aa 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -12,6 +12,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder from libs.helper import StrLen, uuid_value from libs.login import login_required +from models.account import Account from services.billing_service import BillingService from services.model_provider_service import ModelProviderService @@ -21,6 +22,10 @@ class ModelProviderListApi(Resource): @login_required @account_initialization_required def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant_id: + raise ValueError("No current tenant") tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() @@ -45,6 +50,10 @@ class ModelProviderCredentialApi(Resource): @login_required @account_initialization_required def get(self, provider: str): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant_id: + raise ValueError("No current tenant") tenant_id = current_user.current_tenant_id # if credential_id is not provided, return current used credential parser = reqparse.RequestParser() @@ -62,6 +71,8 @@ class ModelProviderCredentialApi(Resource): @login_required @account_initialization_required def post(self, provider: str): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") if not current_user.is_admin_or_owner: raise Forbidden() @@ -72,6 +83,8 @@ class ModelProviderCredentialApi(Resource): model_provider_service = ModelProviderService() + if not current_user.current_tenant_id: + raise ValueError("No current tenant") try: model_provider_service.create_provider_credential( tenant_id=current_user.current_tenant_id, @@ -88,6 +101,8 @@ class ModelProviderCredentialApi(Resource): @login_required @account_initialization_required def put(self, provider: str): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") if not current_user.is_admin_or_owner: raise Forbidden() @@ -99,6 +114,8 @@ class ModelProviderCredentialApi(Resource): model_provider_service = ModelProviderService() + if not current_user.current_tenant_id: + raise ValueError("No current tenant") try: model_provider_service.update_provider_credential( tenant_id=current_user.current_tenant_id, @@ -116,12 +133,16 @@ class ModelProviderCredentialApi(Resource): @login_required @account_initialization_required def delete(self, provider: str): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") if not current_user.is_admin_or_owner: raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") args = parser.parse_args() + if not current_user.current_tenant_id: + raise ValueError("No current tenant") model_provider_service = ModelProviderService() model_provider_service.remove_provider_credential( tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"] @@ -135,12 +156,16 @@ class ModelProviderCredentialSwitchApi(Resource): @login_required @account_initialization_required def post(self, provider: str): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") if not current_user.is_admin_or_owner: raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") args = parser.parse_args() + if not current_user.current_tenant_id: + raise ValueError("No current tenant") service = ModelProviderService() service.switch_active_provider_credential( tenant_id=current_user.current_tenant_id, @@ -155,10 +180,14 @@ class ModelProviderValidateApi(Resource): @login_required @account_initialization_required def post(self, provider: str): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() + if not current_user.current_tenant_id: + raise ValueError("No current tenant") tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() @@ -205,9 +234,13 @@ class PreferredProviderTypeUpdateApi(Resource): @login_required @account_initialization_required def post(self, provider: str): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") if not current_user.is_admin_or_owner: raise Forbidden() + if not current_user.current_tenant_id: + raise ValueError("No current tenant") tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() @@ -236,7 +269,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource): def get(self, provider: str): if provider != "anthropic": raise ValueError(f"provider name {provider} is invalid") + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") BillingService.is_tenant_owner_or_admin(current_user) + if not current_user.current_tenant_id: + raise ValueError("No current tenant") data = BillingService.get_model_provider_payment_link( provider_name=provider, tenant_id=current_user.current_tenant_id, diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index e7a3aca66c..655afbe73f 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -25,7 +25,7 @@ from controllers.console.wraps import ( from extensions.ext_database import db from libs.helper import TimestampField from libs.login import login_required -from models.account import Tenant, TenantStatus +from models.account import Account, Tenant, TenantStatus from services.account_service import TenantService from services.feature_service import FeatureService from services.file_service import FileService @@ -70,6 +70,8 @@ class TenantListApi(Resource): @login_required @account_initialization_required def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") tenants = TenantService.get_join_tenants(current_user) tenant_dicts = [] @@ -83,7 +85,7 @@ class TenantListApi(Resource): "status": tenant.status, "created_at": tenant.created_at, "plan": features.billing.subscription.plan if features.billing.enabled else "sandbox", - "current": tenant.id == current_user.current_tenant_id, + "current": tenant.id == current_user.current_tenant_id if current_user.current_tenant_id else False, } tenant_dicts.append(tenant_dict) @@ -125,7 +127,11 @@ class TenantApi(Resource): if request.path == "/info": logger.warning("Deprecated URL /info was used.") + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") tenant = current_user.current_tenant + if not tenant: + raise ValueError("No current tenant") if tenant.status == TenantStatus.ARCHIVE: tenants = TenantService.get_join_tenants(current_user) @@ -137,6 +143,8 @@ class TenantApi(Resource): else: raise Unauthorized("workspace is archived") + if not tenant: + raise ValueError("No tenant available") return WorkspaceService.get_tenant_info(tenant), 200 @@ -145,6 +153,8 @@ class SwitchWorkspaceApi(Resource): @login_required @account_initialization_required def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("tenant_id", type=str, required=True, location="json") args = parser.parse_args() @@ -168,11 +178,15 @@ class CustomConfigWorkspaceApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("workspace_custom") def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("remove_webapp_brand", type=bool, location="json") parser.add_argument("replace_webapp_logo", type=str, location="json") args = parser.parse_args() + if not current_user.current_tenant_id: + raise ValueError("No current tenant") tenant = db.get_or_404(Tenant, current_user.current_tenant_id) custom_config_dict = { @@ -194,6 +208,8 @@ class WebappLogoWorkspaceApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("workspace_custom") def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") # check file if "file" not in request.files: raise NoFileUploadedError() @@ -232,10 +248,14 @@ class WorkspaceInfoApi(Resource): @account_initialization_required # Change workspace name def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, location="json") args = parser.parse_args() + if not current_user.current_tenant_id: + raise ValueError("No current tenant") tenant = db.get_or_404(Tenant, current_user.current_tenant_id) tenant.name = args["name"] db.session.commit() diff --git a/api/controllers/files/__init__.py b/api/controllers/files/__init__.py index 821ad220a2..a1b8bb7cfe 100644 --- a/api/controllers/files/__init__.py +++ b/api/controllers/files/__init__.py @@ -15,6 +15,6 @@ api = ExternalApi( files_ns = Namespace("files", description="File operations", path="/") -from . import image_preview, tool_files, upload +from . import image_preview, tool_files, upload # pyright: ignore[reportUnusedImport] api.add_namespace(files_ns) diff --git a/api/controllers/inner_api/__init__.py b/api/controllers/inner_api/__init__.py index d29a7be139..b09c39309f 100644 --- a/api/controllers/inner_api/__init__.py +++ b/api/controllers/inner_api/__init__.py @@ -16,8 +16,8 @@ api = ExternalApi( # Create namespace inner_api_ns = Namespace("inner_api", description="Internal API operations", path="/") -from . import mail -from .plugin import plugin -from .workspace import workspace +from . import mail as _mail # pyright: ignore[reportUnusedImport] +from .plugin import plugin as _plugin # pyright: ignore[reportUnusedImport] +from .workspace import workspace as _workspace # pyright: ignore[reportUnusedImport] api.add_namespace(inner_api_ns) diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 170a794d89..c5bb2f2545 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -37,9 +37,9 @@ from models.model import EndUser @inner_api_ns.route("/invoke/llm") class PluginInvokeLLMApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeLLM) @inner_api_ns.doc("plugin_invoke_llm") @inner_api_ns.doc(description="Invoke LLM models through plugin interface") @@ -60,9 +60,9 @@ class PluginInvokeLLMApi(Resource): @inner_api_ns.route("/invoke/llm/structured-output") class PluginInvokeLLMWithStructuredOutputApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeLLMWithStructuredOutput) @inner_api_ns.doc("plugin_invoke_llm_structured") @inner_api_ns.doc(description="Invoke LLM models with structured output through plugin interface") @@ -85,9 +85,9 @@ class PluginInvokeLLMWithStructuredOutputApi(Resource): @inner_api_ns.route("/invoke/text-embedding") class PluginInvokeTextEmbeddingApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeTextEmbedding) @inner_api_ns.doc("plugin_invoke_text_embedding") @inner_api_ns.doc(description="Invoke text embedding models through plugin interface") @@ -115,9 +115,9 @@ class PluginInvokeTextEmbeddingApi(Resource): @inner_api_ns.route("/invoke/rerank") class PluginInvokeRerankApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeRerank) @inner_api_ns.doc("plugin_invoke_rerank") @inner_api_ns.doc(description="Invoke rerank models through plugin interface") @@ -141,9 +141,9 @@ class PluginInvokeRerankApi(Resource): @inner_api_ns.route("/invoke/tts") class PluginInvokeTTSApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeTTS) @inner_api_ns.doc("plugin_invoke_tts") @inner_api_ns.doc(description="Invoke text-to-speech models through plugin interface") @@ -168,9 +168,9 @@ class PluginInvokeTTSApi(Resource): @inner_api_ns.route("/invoke/speech2text") class PluginInvokeSpeech2TextApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeSpeech2Text) @inner_api_ns.doc("plugin_invoke_speech2text") @inner_api_ns.doc(description="Invoke speech-to-text models through plugin interface") @@ -194,9 +194,9 @@ class PluginInvokeSpeech2TextApi(Resource): @inner_api_ns.route("/invoke/moderation") class PluginInvokeModerationApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeModeration) @inner_api_ns.doc("plugin_invoke_moderation") @inner_api_ns.doc(description="Invoke moderation models through plugin interface") @@ -220,9 +220,9 @@ class PluginInvokeModerationApi(Resource): @inner_api_ns.route("/invoke/tool") class PluginInvokeToolApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeTool) @inner_api_ns.doc("plugin_invoke_tool") @inner_api_ns.doc(description="Invoke tools through plugin interface") @@ -252,9 +252,9 @@ class PluginInvokeToolApi(Resource): @inner_api_ns.route("/invoke/parameter-extractor") class PluginInvokeParameterExtractorNodeApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeParameterExtractorNode) @inner_api_ns.doc("plugin_invoke_parameter_extractor") @inner_api_ns.doc(description="Invoke parameter extractor node through plugin interface") @@ -285,9 +285,9 @@ class PluginInvokeParameterExtractorNodeApi(Resource): @inner_api_ns.route("/invoke/question-classifier") class PluginInvokeQuestionClassifierNodeApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeQuestionClassifierNode) @inner_api_ns.doc("plugin_invoke_question_classifier") @inner_api_ns.doc(description="Invoke question classifier node through plugin interface") @@ -318,9 +318,9 @@ class PluginInvokeQuestionClassifierNodeApi(Resource): @inner_api_ns.route("/invoke/app") class PluginInvokeAppApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeApp) @inner_api_ns.doc("plugin_invoke_app") @inner_api_ns.doc(description="Invoke application through plugin interface") @@ -348,9 +348,9 @@ class PluginInvokeAppApi(Resource): @inner_api_ns.route("/invoke/encrypt") class PluginInvokeEncryptApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeEncrypt) @inner_api_ns.doc("plugin_invoke_encrypt") @inner_api_ns.doc(description="Encrypt or decrypt data through plugin interface") @@ -375,9 +375,9 @@ class PluginInvokeEncryptApi(Resource): @inner_api_ns.route("/invoke/summary") class PluginInvokeSummaryApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeSummary) @inner_api_ns.doc("plugin_invoke_summary") @inner_api_ns.doc(description="Invoke summary functionality through plugin interface") @@ -405,9 +405,9 @@ class PluginInvokeSummaryApi(Resource): @inner_api_ns.route("/upload/file/request") class PluginUploadFileRequestApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestRequestUploadFile) @inner_api_ns.doc("plugin_upload_file_request") @inner_api_ns.doc(description="Request signed URL for file upload through plugin interface") @@ -426,9 +426,9 @@ class PluginUploadFileRequestApi(Resource): @inner_api_ns.route("/fetch/app/info") class PluginFetchAppInfoApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestFetchAppInfo) @inner_api_ns.doc("plugin_fetch_app_info") @inner_api_ns.doc(description="Fetch application information through plugin interface") diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index 68711f7257..18b530f2c4 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -1,6 +1,6 @@ from collections.abc import Callable from functools import wraps -from typing import Optional, ParamSpec, TypeVar +from typing import Optional, ParamSpec, TypeVar, cast from flask import current_app, request from flask_login import user_logged_in @@ -10,7 +10,7 @@ from sqlalchemy.orm import Session from core.file.constants import DEFAULT_SERVICE_API_USER_ID from extensions.ext_database import db -from libs.login import _get_user +from libs.login import current_user from models.account import Tenant from models.model import EndUser @@ -66,8 +66,8 @@ def get_user_tenant(view: Optional[Callable[P, R]] = None): p = parser.parse_args() - user_id: Optional[str] = p.get("user_id") - tenant_id: str = p.get("tenant_id") + user_id = cast(str, p.get("user_id")) + tenant_id = cast(str, p.get("tenant_id")) if not tenant_id: raise ValueError("tenant_id is required") @@ -98,7 +98,7 @@ def get_user_tenant(view: Optional[Callable[P, R]] = None): kwargs["user_model"] = user current_app.login_manager._update_request_context_with_user(user) # type: ignore - user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore + user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore return view_func(*args, **kwargs) diff --git a/api/controllers/mcp/__init__.py b/api/controllers/mcp/__init__.py index c344ffad08..43b36a70b4 100644 --- a/api/controllers/mcp/__init__.py +++ b/api/controllers/mcp/__init__.py @@ -15,6 +15,6 @@ api = ExternalApi( mcp_ns = Namespace("mcp", description="MCP operations", path="/") -from . import mcp +from . import mcp # pyright: ignore[reportUnusedImport] api.add_namespace(mcp_ns) diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index 763345d723..d69f49d957 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -15,9 +15,27 @@ api = ExternalApi( service_api_ns = Namespace("service_api", description="Service operations", path="/") -from . import index -from .app import annotation, app, audio, completion, conversation, file, file_preview, message, site, workflow -from .dataset import dataset, document, hit_testing, metadata, segment, upload_file -from .workspace import models +from . import index # pyright: ignore[reportUnusedImport] +from .app import ( + annotation, # pyright: ignore[reportUnusedImport] + app, # pyright: ignore[reportUnusedImport] + audio, # pyright: ignore[reportUnusedImport] + completion, # pyright: ignore[reportUnusedImport] + conversation, # pyright: ignore[reportUnusedImport] + file, # pyright: ignore[reportUnusedImport] + file_preview, # pyright: ignore[reportUnusedImport] + message, # pyright: ignore[reportUnusedImport] + site, # pyright: ignore[reportUnusedImport] + workflow, # pyright: ignore[reportUnusedImport] +) +from .dataset import ( + dataset, # pyright: ignore[reportUnusedImport] + document, # pyright: ignore[reportUnusedImport] + hit_testing, # pyright: ignore[reportUnusedImport] + metadata, # pyright: ignore[reportUnusedImport] + segment, # pyright: ignore[reportUnusedImport] + upload_file, # pyright: ignore[reportUnusedImport] +) +from .workspace import models # pyright: ignore[reportUnusedImport] api.add_namespace(service_api_ns) diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 4860bf3a79..711dd5704c 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,4 +1,5 @@ from flask_restx import Resource, reqparse +from flask_restx._http import HTTPStatus from flask_restx.inputs import int_range from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, NotFound @@ -121,7 +122,7 @@ class ConversationDetailApi(Resource): } ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) - @service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=204) + @service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=HTTPStatus.NO_CONTENT) def delete(self, app_model: App, end_user: EndUser, c_id): """Delete a specific conversation.""" app_mode = AppMode.value_of(app_model.mode) diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index de41384270..721cf530c3 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -30,6 +30,7 @@ from extensions.ext_database import db from fields.document_fields import document_fields, document_status_fields from libs.login import current_user from models.dataset import Dataset, Document, DocumentSegment +from models.model import EndUser from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from services.file_service import FileService @@ -298,6 +299,9 @@ class DocumentAddByFileApi(DatasetApiResource): if not file.filename: raise FilenameNotExistsError + if not isinstance(current_user, EndUser): + raise ValueError("Invalid user account") + upload_file = FileService.upload_file( filename=file.filename, content=file.read(), @@ -387,6 +391,8 @@ class DocumentUpdateByFileApi(DatasetApiResource): raise FilenameNotExistsError try: + if not isinstance(current_user, EndUser): + raise ValueError("Invalid user account") upload_file = FileService.upload_file( filename=file.filename, content=file.read(), diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 4394e64ad9..64a2f5445c 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -17,7 +17,7 @@ from core.file.constants import DEFAULT_SERVICE_API_USER_ID from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now -from libs.login import _get_user +from libs.login import current_user from models.account import Account, Tenant, TenantAccountJoin, TenantStatus from models.dataset import Dataset, RateLimitLog from models.model import ApiToken, App, EndUser @@ -210,7 +210,7 @@ def validate_dataset_token(view: Optional[Callable[Concatenate[T, P], R]] = None if account: account.current_tenant = tenant current_app.login_manager._update_request_context_with_user(account) # type: ignore - user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore + user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore else: raise Unauthorized("Tenant owner account does not exist.") else: diff --git a/api/controllers/web/__init__.py b/api/controllers/web/__init__.py index 3b0a9e341a..a825a2a0d8 100644 --- a/api/controllers/web/__init__.py +++ b/api/controllers/web/__init__.py @@ -17,20 +17,20 @@ api = ExternalApi( web_ns = Namespace("web", description="Web application API operations", path="/") from . import ( - app, - audio, - completion, - conversation, - feature, - files, - forgot_password, - login, - message, - passport, - remote_files, - saved_message, - site, - workflow, + app, # pyright: ignore[reportUnusedImport] + audio, # pyright: ignore[reportUnusedImport] + completion, # pyright: ignore[reportUnusedImport] + conversation, # pyright: ignore[reportUnusedImport] + feature, # pyright: ignore[reportUnusedImport] + files, # pyright: ignore[reportUnusedImport] + forgot_password, # pyright: ignore[reportUnusedImport] + login, # pyright: ignore[reportUnusedImport] + message, # pyright: ignore[reportUnusedImport] + passport, # pyright: ignore[reportUnusedImport] + remote_files, # pyright: ignore[reportUnusedImport] + saved_message, # pyright: ignore[reportUnusedImport] + site, # pyright: ignore[reportUnusedImport] + workflow, # pyright: ignore[reportUnusedImport] ) api.add_namespace(web_ns) diff --git a/api/core/__init__.py b/api/core/__init__.py index 6eaea7b1c8..e69de29bb2 100644 --- a/api/core/__init__.py +++ b/api/core/__init__.py @@ -1 +0,0 @@ -import core.moderation.base diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index b94a60c40a..d1d5a011e0 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -72,6 +72,8 @@ class CotAgentRunner(BaseAgentRunner, ABC): function_call_state = True llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} final_answer = "" + prompt_messages: list = [] # Initialize prompt_messages + agent_thought_id = "" # Initialize agent_thought_id def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage): if not final_llm_usage_dict["usage"]: diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 9eb853aa74..5236266908 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -54,6 +54,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): function_call_state = True llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} final_answer = "" + prompt_messages: list = [] # Initialize prompt_messages # get tracing instance trace_manager = app_generate_entity.trace_manager diff --git a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py index 037037e6ca..97ede178c7 100644 --- a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py +++ b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py @@ -21,7 +21,7 @@ class SensitiveWordAvoidanceConfigManager: @classmethod def validate_and_set_defaults( - cls, tenant_id, config: dict, only_structure_validate: bool = False + cls, tenant_id: str, config: dict, only_structure_validate: bool = False ) -> tuple[dict, list[str]]: if not config.get("sensitive_word_avoidance"): config["sensitive_word_avoidance"] = {"enabled": False} @@ -38,7 +38,14 @@ class SensitiveWordAvoidanceConfigManager: if not only_structure_validate: typ = config["sensitive_word_avoidance"]["type"] - sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"] + if not isinstance(typ, str): + raise ValueError("sensitive_word_avoidance.type must be a string") + + sensitive_word_avoidance_config = config["sensitive_word_avoidance"].get("config") + if sensitive_word_avoidance_config is None: + sensitive_word_avoidance_config = {} + if not isinstance(sensitive_word_avoidance_config, dict): + raise ValueError("sensitive_word_avoidance.config must be a dict") ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config) diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index e6ab31e586..cda17c0010 100644 --- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -25,10 +25,14 @@ class PromptTemplateConfigManager: if chat_prompt_config: chat_prompt_messages = [] for message in chat_prompt_config.get("prompt", []): + text = message.get("text") + if not isinstance(text, str): + raise ValueError("message text must be a string") + role = message.get("role") + if not isinstance(role, str): + raise ValueError("message role must be a string") chat_prompt_messages.append( - AdvancedChatMessageEntity( - **{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])} - ) + AdvancedChatMessageEntity(text=text, role=PromptMessageRole.value_of(role)) ) advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages) diff --git a/api/core/app/apps/advanced_chat/generate_response_converter.py b/api/core/app/apps/advanced_chat/generate_response_converter.py index 627f6b47ce..02ec96f209 100644 --- a/api/core/app/apps/advanced_chat/generate_response_converter.py +++ b/api/core/app/apps/advanced_chat/generate_response_converter.py @@ -71,7 +71,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): yield "ping" continue - response_chunk = { + response_chunk: dict[str, Any] = { "event": sub_stream_response.event.value, "conversation_id": chunk.conversation_id, "message_id": chunk.message_id, @@ -82,7 +82,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk @classmethod @@ -102,7 +102,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): yield "ping" continue - response_chunk = { + response_chunk: dict[str, Any] = { "event": sub_stream_response.event.value, "conversation_id": chunk.conversation_id, "message_id": chunk.message_id, @@ -110,7 +110,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): } if isinstance(sub_stream_response, MessageEndStreamResponse): - sub_stream_response_dict = sub_stream_response.to_dict() + sub_stream_response_dict = sub_stream_response.model_dump(mode="json") metadata = sub_stream_response_dict.get("metadata", {}) sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) @@ -118,8 +118,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse): - response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute] + response_chunk.update(sub_stream_response.to_ignore_detail_dict()) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 8207b70f9e..cec3b83674 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -174,7 +174,7 @@ class AdvancedChatAppGenerateTaskPipeline: generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) - if self._base_task_pipeline._stream: + if self._base_task_pipeline.stream: return self._to_stream_response(generator) else: return self._to_blocking_response(generator) @@ -302,13 +302,13 @@ class AdvancedChatAppGenerateTaskPipeline: def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]: """Handle ping events.""" - yield self._base_task_pipeline._ping_stream_response() + yield self._base_task_pipeline.ping_stream_response() def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]: """Handle error events.""" with self._database_session() as session: - err = self._base_task_pipeline._handle_error(event=event, session=session, message_id=self._message_id) - yield self._base_task_pipeline._error_to_stream_response(err) + err = self._base_task_pipeline.handle_error(event=event, session=session, message_id=self._message_id) + yield self._base_task_pipeline.error_to_stream_response(err) def _handle_workflow_started_event(self, *args, **kwargs) -> Generator[StreamResponse, None, None]: """Handle workflow started events.""" @@ -627,10 +627,10 @@ class AdvancedChatAppGenerateTaskPipeline: workflow_execution=workflow_execution, ) err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}")) - err = self._base_task_pipeline._handle_error(event=err_event, session=session, message_id=self._message_id) + err = self._base_task_pipeline.handle_error(event=err_event, session=session, message_id=self._message_id) yield workflow_finish_resp - yield self._base_task_pipeline._error_to_stream_response(err) + yield self._base_task_pipeline.error_to_stream_response(err) def _handle_stop_event( self, @@ -683,7 +683,7 @@ class AdvancedChatAppGenerateTaskPipeline: """Handle advanced chat message end events.""" self._ensure_graph_runtime_initialized(graph_runtime_state) - output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished( + output_moderation_answer = self._base_task_pipeline.handle_output_moderation_when_task_finished( self._task_state.answer ) if output_moderation_answer: @@ -899,7 +899,7 @@ class AdvancedChatAppGenerateTaskPipeline: message.answer = answer_text message.updated_at = naive_utc_now() - message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at + message.provider_response_latency = time.perf_counter() - self._base_task_pipeline.start_at message.message_metadata = self._task_state.metadata.model_dump_json() message_files = [ MessageFile( @@ -955,9 +955,9 @@ class AdvancedChatAppGenerateTaskPipeline: :param text: text :return: True if output moderation should direct output, otherwise False """ - if self._base_task_pipeline._output_moderation_handler: - if self._base_task_pipeline._output_moderation_handler.should_direct_output(): - self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output() + if self._base_task_pipeline.output_moderation_handler: + if self._base_task_pipeline.output_moderation_handler.should_direct_output(): + self._task_state.answer = self._base_task_pipeline.output_moderation_handler.get_final_output() self._base_task_pipeline.queue_manager.publish( QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE ) @@ -967,7 +967,7 @@ class AdvancedChatAppGenerateTaskPipeline: ) return True else: - self._base_task_pipeline._output_moderation_handler.append_new_token(text) + self._base_task_pipeline.output_moderation_handler.append_new_token(text) return False diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py index 349b583833..54d1a9595f 100644 --- a/api/core/app/apps/agent_chat/app_config_manager.py +++ b/api/core/app/apps/agent_chat/app_config_manager.py @@ -1,6 +1,6 @@ import uuid from collections.abc import Mapping -from typing import Any, Optional +from typing import Any, Optional, cast from core.agent.entities import AgentEntity from core.app.app_config.base_app_config_manager import BaseAppConfigManager @@ -160,7 +160,9 @@ class AgentChatAppConfigManager(BaseAppConfigManager): return filtered_config @classmethod - def validate_agent_mode_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: + def validate_agent_mode_and_set_defaults( + cls, tenant_id: str, config: dict[str, Any] + ) -> tuple[dict[str, Any], list[str]]: """ Validate agent_mode and set defaults for agent feature @@ -170,30 +172,32 @@ class AgentChatAppConfigManager(BaseAppConfigManager): if not config.get("agent_mode"): config["agent_mode"] = {"enabled": False, "tools": []} - if not isinstance(config["agent_mode"], dict): + agent_mode = config["agent_mode"] + if not isinstance(agent_mode, dict): raise ValueError("agent_mode must be of object type") - if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]: - config["agent_mode"]["enabled"] = False + # FIXME(-LAN-): Cast needed due to basedpyright limitation with dict type narrowing + agent_mode = cast(dict[str, Any], agent_mode) - if not isinstance(config["agent_mode"]["enabled"], bool): + if "enabled" not in agent_mode or not agent_mode["enabled"]: + agent_mode["enabled"] = False + + if not isinstance(agent_mode["enabled"], bool): raise ValueError("enabled in agent_mode must be of boolean type") - if not config["agent_mode"].get("strategy"): - config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value + if not agent_mode.get("strategy"): + agent_mode["strategy"] = PlanningStrategy.ROUTER.value - if config["agent_mode"]["strategy"] not in [ - member.value for member in list(PlanningStrategy.__members__.values()) - ]: + if agent_mode["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]: raise ValueError("strategy in agent_mode must be in the specified strategy list") - if not config["agent_mode"].get("tools"): - config["agent_mode"]["tools"] = [] + if not agent_mode.get("tools"): + agent_mode["tools"] = [] - if not isinstance(config["agent_mode"]["tools"], list): + if not isinstance(agent_mode["tools"], list): raise ValueError("tools in agent_mode must be a list of objects") - for tool in config["agent_mode"]["tools"]: + for tool in agent_mode["tools"]: key = list(tool.keys())[0] if key in OLD_TOOLS: # old style, use tool name as key diff --git a/api/core/app/apps/agent_chat/generate_response_converter.py b/api/core/app/apps/agent_chat/generate_response_converter.py index 89a5b8e3b5..e35e9d9408 100644 --- a/api/core/app/apps/agent_chat/generate_response_converter.py +++ b/api/core/app/apps/agent_chat/generate_response_converter.py @@ -46,7 +46,10 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): response = cls.convert_blocking_full_response(blocking_response) metadata = response.get("metadata", {}) - response["metadata"] = cls._get_simple_metadata(metadata) + if isinstance(metadata, dict): + response["metadata"] = cls._get_simple_metadata(metadata) + else: + response["metadata"] = {} return response @@ -78,7 +81,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk @classmethod @@ -106,7 +109,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): } if isinstance(sub_stream_response, MessageEndStreamResponse): - sub_stream_response_dict = sub_stream_response.to_dict() + sub_stream_response_dict = sub_stream_response.model_dump(mode="json") metadata = sub_stream_response_dict.get("metadata", {}) sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) @@ -114,6 +117,6 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index 795a7befff..2a7fe7902b 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -32,6 +32,7 @@ class AppQueueManager: self._task_id = task_id self._user_id = user_id self._invoke_from = invoke_from + self.invoke_from = invoke_from # Public accessor for invoke_from user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user" redis_client.setex( diff --git a/api/core/app/apps/chat/generate_response_converter.py b/api/core/app/apps/chat/generate_response_converter.py index 816d6d79a9..3aa1161fd8 100644 --- a/api/core/app/apps/chat/generate_response_converter.py +++ b/api/core/app/apps/chat/generate_response_converter.py @@ -46,7 +46,10 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): response = cls.convert_blocking_full_response(blocking_response) metadata = response.get("metadata", {}) - response["metadata"] = cls._get_simple_metadata(metadata) + if isinstance(metadata, dict): + response["metadata"] = cls._get_simple_metadata(metadata) + else: + response["metadata"] = {} return response @@ -78,7 +81,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk @classmethod @@ -106,7 +109,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): } if isinstance(sub_stream_response, MessageEndStreamResponse): - sub_stream_response_dict = sub_stream_response.to_dict() + sub_stream_response_dict = sub_stream_response.model_dump(mode="json") metadata = sub_stream_response_dict.get("metadata", {}) sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) @@ -114,6 +117,6 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 8485ce7519..843328f904 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -271,6 +271,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator): raise MoreLikeThisDisabledError() app_model_config = message.app_model_config + if not app_model_config: + raise ValueError("Message app_model_config is None") override_model_config_dict = app_model_config.to_dict() model_dict = override_model_config_dict["model"] completion_params = model_dict.get("completion_params") diff --git a/api/core/app/apps/completion/generate_response_converter.py b/api/core/app/apps/completion/generate_response_converter.py index 4d45c61145..d7e9ebdf24 100644 --- a/api/core/app/apps/completion/generate_response_converter.py +++ b/api/core/app/apps/completion/generate_response_converter.py @@ -45,7 +45,10 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): response = cls.convert_blocking_full_response(blocking_response) metadata = response.get("metadata", {}) - response["metadata"] = cls._get_simple_metadata(metadata) + if isinstance(metadata, dict): + response["metadata"] = cls._get_simple_metadata(metadata) + else: + response["metadata"] = {} return response @@ -76,7 +79,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk @classmethod @@ -103,14 +106,16 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): } if isinstance(sub_stream_response, MessageEndStreamResponse): - sub_stream_response_dict = sub_stream_response.to_dict() + sub_stream_response_dict = sub_stream_response.model_dump(mode="json") metadata = sub_stream_response_dict.get("metadata", {}) + if not isinstance(metadata, dict): + metadata = {} sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) if isinstance(sub_stream_response, ErrorStreamResponse): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk diff --git a/api/core/app/apps/workflow/generate_response_converter.py b/api/core/app/apps/workflow/generate_response_converter.py index 210f6110b1..01ecf0298f 100644 --- a/api/core/app/apps/workflow/generate_response_converter.py +++ b/api/core/app/apps/workflow/generate_response_converter.py @@ -23,7 +23,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): :param blocking_response: blocking response :return: """ - return dict(blocking_response.to_dict()) + return blocking_response.model_dump() @classmethod def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override] @@ -51,7 +51,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): yield "ping" continue - response_chunk = { + response_chunk: dict[str, object] = { "event": sub_stream_response.event.value, "workflow_run_id": chunk.workflow_run_id, } @@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk @classmethod @@ -80,7 +80,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): yield "ping" continue - response_chunk = { + response_chunk: dict[str, object] = { "event": sub_stream_response.event.value, "workflow_run_id": chunk.workflow_run_id, } @@ -91,5 +91,5 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse): response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute] else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 6ab89dbd61..1c950063dd 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -137,7 +137,7 @@ class WorkflowAppGenerateTaskPipeline: self._application_generate_entity = application_generate_entity self._workflow_features_dict = workflow.features_dict self._workflow_run_id = "" - self._invoke_from = queue_manager._invoke_from + self._invoke_from = queue_manager.invoke_from self._draft_var_saver_factory = draft_var_saver_factory def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: @@ -146,7 +146,7 @@ class WorkflowAppGenerateTaskPipeline: :return: """ generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) - if self._base_task_pipeline._stream: + if self._base_task_pipeline.stream: return self._to_stream_response(generator) else: return self._to_blocking_response(generator) @@ -276,12 +276,12 @@ class WorkflowAppGenerateTaskPipeline: def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]: """Handle ping events.""" - yield self._base_task_pipeline._ping_stream_response() + yield self._base_task_pipeline.ping_stream_response() def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]: """Handle error events.""" - err = self._base_task_pipeline._handle_error(event=event) - yield self._base_task_pipeline._error_to_stream_response(err) + err = self._base_task_pipeline.handle_error(event=event) + yield self._base_task_pipeline.error_to_stream_response(err) def _handle_workflow_started_event( self, event: QueueWorkflowStartedEvent, **kwargs diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 9151137fe8..1d5ebabaf7 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -123,7 +123,7 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity): """ # app config - app_config: EasyUIBasedAppConfig + app_config: EasyUIBasedAppConfig = None # type: ignore model_conf: ModelConfigWithCredentialsEntity query: Optional[str] = None @@ -186,7 +186,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity): """ # app config - app_config: WorkflowUIBasedAppConfig + app_config: WorkflowUIBasedAppConfig = None # type: ignore workflow_run_id: Optional[str] = None query: str @@ -218,7 +218,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity): """ # app config - app_config: WorkflowUIBasedAppConfig + app_config: WorkflowUIBasedAppConfig = None # type: ignore workflow_execution_id: str class SingleIterationRunEntity(BaseModel): diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 29f3e3427e..31183d19a3 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -5,7 +5,6 @@ from typing import Any, Optional from pydantic import BaseModel, ConfigDict, Field from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from core.model_runtime.utils.encoders import jsonable_encoder from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.entities.node_entities import AgentNodeStrategyInit from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus @@ -92,9 +91,6 @@ class StreamResponse(BaseModel): event: StreamEvent task_id: str - def to_dict(self): - return jsonable_encoder(self) - class ErrorStreamResponse(StreamResponse): """ @@ -745,9 +741,6 @@ class AppBlockingResponse(BaseModel): task_id: str - def to_dict(self): - return jsonable_encoder(self) - class ChatbotAppBlockingResponse(AppBlockingResponse): """ diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index be183e2086..3853dccdc5 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -35,6 +35,9 @@ class AnnotationReplyFeature: collection_binding_detail = annotation_setting.collection_binding_detail + if not collection_binding_detail: + return None + try: score_threshold = annotation_setting.score_threshold or 1 embedding_provider_name = collection_binding_detail.provider_name diff --git a/api/core/app/features/rate_limiting/__init__.py b/api/core/app/features/rate_limiting/__init__.py index 6624f6ad9d..4ad33acd0f 100644 --- a/api/core/app/features/rate_limiting/__init__.py +++ b/api/core/app/features/rate_limiting/__init__.py @@ -1 +1,3 @@ from .rate_limit import RateLimit + +__all__ = ["RateLimit"] diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py index f526d2a16a..6f13f11da0 100644 --- a/api/core/app/features/rate_limiting/rate_limit.py +++ b/api/core/app/features/rate_limiting/rate_limit.py @@ -19,7 +19,7 @@ class RateLimit: _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes _instance_dict: dict[str, "RateLimit"] = {} - def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int): + def __new__(cls, client_id: str, max_active_requests: int): if client_id not in cls._instance_dict: instance = super().__new__(cls) cls._instance_dict[client_id] = instance diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 7d98cceb1a..4931300901 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -38,11 +38,11 @@ class BasedGenerateTaskPipeline: ): self._application_generate_entity = application_generate_entity self.queue_manager = queue_manager - self._start_at = time.perf_counter() - self._output_moderation_handler = self._init_output_moderation() - self._stream = stream + self.start_at = time.perf_counter() + self.output_moderation_handler = self._init_output_moderation() + self.stream = stream - def _handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""): + def handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""): logger.debug("error: %s", event.error) e = event.error err: Exception @@ -86,7 +86,7 @@ class BasedGenerateTaskPipeline: return message - def _error_to_stream_response(self, e: Exception): + def error_to_stream_response(self, e: Exception): """ Error to stream response. :param e: exception @@ -94,7 +94,7 @@ class BasedGenerateTaskPipeline: """ return ErrorStreamResponse(task_id=self._application_generate_entity.task_id, err=e) - def _ping_stream_response(self) -> PingStreamResponse: + def ping_stream_response(self) -> PingStreamResponse: """ Ping stream response. :return: @@ -118,21 +118,21 @@ class BasedGenerateTaskPipeline: ) return None - def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]: + def handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]: """ Handle output moderation when task finished. :param completion: completion :return: """ # response moderation - if self._output_moderation_handler: - self._output_moderation_handler.stop_thread() + if self.output_moderation_handler: + self.output_moderation_handler.stop_thread() - completion, flagged = self._output_moderation_handler.moderation_completion( + completion, flagged = self.output_moderation_handler.moderation_completion( completion=completion, public_event=False ) - self._output_moderation_handler = None + self.output_moderation_handler = None if flagged: return completion diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 0dad0a5a9d..71fd5ac653 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -125,7 +125,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): ) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) - if self._stream: + if self.stream: return self._to_stream_response(generator) else: return self._to_blocking_response(generator) @@ -265,9 +265,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): if isinstance(event, QueueErrorEvent): with Session(db.engine) as session: - err = self._handle_error(event=event, session=session, message_id=self._message_id) + err = self.handle_error(event=event, session=session, message_id=self._message_id) session.commit() - yield self._error_to_stream_response(err) + yield self.error_to_stream_response(err) break elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): if isinstance(event, QueueMessageEndEvent): @@ -277,7 +277,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): self._handle_stop(event) # handle output moderation - output_moderation_answer = self._handle_output_moderation_when_task_finished( + output_moderation_answer = self.handle_output_moderation_when_task_finished( cast(str, self._task_state.llm_result.message.content) ) if output_moderation_answer: @@ -354,7 +354,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): elif isinstance(event, QueueMessageReplaceEvent): yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text) elif isinstance(event, QueuePingEvent): - yield self._ping_stream_response() + yield self.ping_stream_response() else: continue if publisher: @@ -394,7 +394,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): message.answer_tokens = usage.completion_tokens message.answer_unit_price = usage.completion_unit_price message.answer_price_unit = usage.completion_price_unit - message.provider_response_latency = time.perf_counter() - self._start_at + message.provider_response_latency = time.perf_counter() - self.start_at message.total_price = usage.total_price message.currency = usage.currency self._task_state.llm_result.usage.latency = message.provider_response_latency @@ -438,7 +438,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): # transform usage model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) - self._task_state.llm_result.usage = model_type_instance._calc_response_usage( + self._task_state.llm_result.usage = model_type_instance.calc_response_usage( model, credentials, prompt_tokens, completion_tokens ) @@ -498,10 +498,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): :param text: text :return: True if output moderation should direct output, otherwise False """ - if self._output_moderation_handler: - if self._output_moderation_handler.should_direct_output(): + if self.output_moderation_handler: + if self.output_moderation_handler.should_direct_output(): # stop subscribe new token when output moderation should direct output - self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output() + self._task_state.llm_result.message.content = self.output_moderation_handler.get_final_output() self.queue_manager.publish( QueueLLMChunkEvent( chunk=LLMResultChunk( @@ -521,6 +521,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): ) return True else: - self._output_moderation_handler.append_new_token(text) + self.output_moderation_handler.append_new_token(text) return False diff --git a/api/core/base/tts/app_generator_tts_publisher.py b/api/core/base/tts/app_generator_tts_publisher.py index 4e6422e2df..89190c36cc 100644 --- a/api/core/base/tts/app_generator_tts_publisher.py +++ b/api/core/base/tts/app_generator_tts_publisher.py @@ -72,7 +72,7 @@ class AppGeneratorTTSPublisher: self.voice = voice if not voice or voice not in values: self.voice = self.voices[0].get("value") - self.MAX_SENTENCE = 2 + self.max_sentence = 2 self._last_audio_event: Optional[AudioTrunk] = None # FIXME better way to handle this threading.start threading.Thread(target=self._runtime).start() @@ -113,8 +113,8 @@ class AppGeneratorTTSPublisher: self.msg_text += message.event.outputs.get("output", "") self.last_message = message sentence_arr, text_tmp = self._extract_sentence(self.msg_text) - if len(sentence_arr) >= min(self.MAX_SENTENCE, 7): - self.MAX_SENTENCE += 1 + if len(sentence_arr) >= min(self.max_sentence, 7): + self.max_sentence += 1 text_content = "".join(sentence_arr) futures_result = self.executor.submit( _invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 9cf35e559d..5309e4e638 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -1840,8 +1840,14 @@ class ProviderConfigurations(BaseModel): def __setitem__(self, key, value): self.configurations[key] = value + def __contains__(self, key): + if "/" not in key: + key = str(ModelProviderID(key)) + return key in self.configurations + def __iter__(self): - return iter(self.configurations) + # Return an iterator of (key, value) tuples to match BaseModel's __iter__ + yield from self.configurations.items() def values(self) -> Iterator[ProviderConfiguration]: return iter(self.configurations.values()) diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py index e3fd175d95..2a5f6c3dc7 100644 --- a/api/core/file/file_manager.py +++ b/api/core/file/file_manager.py @@ -98,7 +98,7 @@ def to_prompt_message_content( def download(f: File, /): if f.transfer_method in (FileTransferMethod.TOOL_FILE, FileTransferMethod.LOCAL_FILE): - return _download_file_content(f._storage_key) + return _download_file_content(f.storage_key) elif f.transfer_method == FileTransferMethod.REMOTE_URL: response = ssrf_proxy.get(f.remote_url, follow_redirects=True) response.raise_for_status() @@ -134,9 +134,9 @@ def _get_encoded_string(f: File, /): response.raise_for_status() data = response.content case FileTransferMethod.LOCAL_FILE: - data = _download_file_content(f._storage_key) + data = _download_file_content(f.storage_key) case FileTransferMethod.TOOL_FILE: - data = _download_file_content(f._storage_key) + data = _download_file_content(f.storage_key) encoded_string = base64.b64encode(data).decode("utf-8") return encoded_string diff --git a/api/core/file/models.py b/api/core/file/models.py index f61334e7bc..9b74fa387f 100644 --- a/api/core/file/models.py +++ b/api/core/file/models.py @@ -146,3 +146,11 @@ class File(BaseModel): if not self.related_id: raise ValueError("Missing file related_id") return self + + @property + def storage_key(self) -> str: + return self._storage_key + + @storage_key.setter + def storage_key(self, value: str): + self._storage_key = value diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index efeba9e5ee..cbb78939d2 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -13,18 +13,18 @@ logger = logging.getLogger(__name__) SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES -HTTP_REQUEST_NODE_SSL_VERIFY = True # Default value for HTTP_REQUEST_NODE_SSL_VERIFY is True +http_request_node_ssl_verify = True # Default value for http_request_node_ssl_verify is True try: - HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY - http_request_node_ssl_verify_lower = str(HTTP_REQUEST_NODE_SSL_VERIFY).lower() + config_value = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY + http_request_node_ssl_verify_lower = str(config_value).lower() if http_request_node_ssl_verify_lower == "true": - HTTP_REQUEST_NODE_SSL_VERIFY = True + http_request_node_ssl_verify = True elif http_request_node_ssl_verify_lower == "false": - HTTP_REQUEST_NODE_SSL_VERIFY = False + http_request_node_ssl_verify = False else: raise ValueError("Invalid value. HTTP_REQUEST_NODE_SSL_VERIFY should be 'True' or 'False'") except NameError: - HTTP_REQUEST_NODE_SSL_VERIFY = True + http_request_node_ssl_verify = True BACKOFF_FACTOR = 0.5 STATUS_FORCELIST = [429, 500, 502, 503, 504] @@ -51,7 +51,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): ) if "ssl_verify" not in kwargs: - kwargs["ssl_verify"] = HTTP_REQUEST_NODE_SSL_VERIFY + kwargs["ssl_verify"] = http_request_node_ssl_verify ssl_verify = kwargs.pop("ssl_verify") diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 89a05e02c8..ed02b70b03 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -529,6 +529,7 @@ class IndexingRunner: # chunk nodes by chunk size indexing_start_at = time.perf_counter() tokens = 0 + create_keyword_thread = None if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy": # create keyword index create_keyword_thread = threading.Thread( @@ -567,7 +568,11 @@ class IndexingRunner: for future in futures: tokens += future.result() - if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy": + if ( + dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX + and dataset.indexing_technique == "economy" + and create_keyword_thread is not None + ): create_keyword_thread.join() indexing_end_at = time.perf_counter() diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 94b8258e9c..d4c4f10a12 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -20,7 +20,7 @@ from core.llm_generator.prompts import ( ) from core.model_manager import ModelManager from core.model_runtime.entities.llm_entities import LLMResult -from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage +from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.entities.trace_entity import TraceTaskName @@ -313,14 +313,20 @@ class LLMGenerator: model_type=ModelType.LLM, ) - prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)] + prompt_messages: list[PromptMessage] = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)] - response: LLMResult = model_instance.invoke_llm( + # Explicitly use the non-streaming overload + result = model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters={"temperature": 0.01, "max_tokens": 2000}, stream=False, ) + # Runtime type check since pyright has issues with the overload + if not isinstance(result, LLMResult): + raise TypeError("Expected LLMResult when stream=False") + response = result + answer = cast(str, response.message.content) return answer.strip() diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py index 28833fe8e8..e0b70c132f 100644 --- a/api/core/llm_generator/output_parser/structured_output.py +++ b/api/core/llm_generator/output_parser/structured_output.py @@ -45,6 +45,7 @@ class SpecialModelType(StrEnum): @overload def invoke_llm_with_structured_output( + *, provider: str, model_schema: AIModelEntity, model_instance: ModelInstance, @@ -53,14 +54,13 @@ def invoke_llm_with_structured_output( model_parameters: Optional[Mapping] = None, tools: Sequence[PromptMessageTool] | None = None, stop: Optional[list[str]] = None, - stream: Literal[True] = True, + stream: Literal[True], user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... - - @overload def invoke_llm_with_structured_output( + *, provider: str, model_schema: AIModelEntity, model_instance: ModelInstance, @@ -69,14 +69,13 @@ def invoke_llm_with_structured_output( model_parameters: Optional[Mapping] = None, tools: Sequence[PromptMessageTool] | None = None, stop: Optional[list[str]] = None, - stream: Literal[False] = False, + stream: Literal[False], user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, ) -> LLMResultWithStructuredOutput: ... - - @overload def invoke_llm_with_structured_output( + *, provider: str, model_schema: AIModelEntity, model_instance: ModelInstance, @@ -89,9 +88,8 @@ def invoke_llm_with_structured_output( user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ... - - def invoke_llm_with_structured_output( + *, provider: str, model_schema: AIModelEntity, model_instance: ModelInstance, diff --git a/api/core/mcp/client/sse_client.py b/api/core/mcp/client/sse_client.py index cc4263c0aa..6db22a09e0 100644 --- a/api/core/mcp/client/sse_client.py +++ b/api/core/mcp/client/sse_client.py @@ -23,13 +23,13 @@ DEFAULT_QUEUE_READ_TIMEOUT = 3 @final class _StatusReady: def __init__(self, endpoint_url: str): - self._endpoint_url = endpoint_url + self.endpoint_url = endpoint_url @final class _StatusError: def __init__(self, exc: Exception): - self._exc = exc + self.exc = exc # Type aliases for better readability @@ -211,9 +211,9 @@ class SSETransport: raise ValueError("failed to get endpoint URL") if isinstance(status, _StatusReady): - return status._endpoint_url + return status.endpoint_url elif isinstance(status, _StatusError): - raise status._exc + raise status.exc else: raise ValueError("failed to get endpoint URL") diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py index 3d51ac2333..6f52c65234 100644 --- a/api/core/mcp/server/streamable_http.py +++ b/api/core/mcp/server/streamable_http.py @@ -38,6 +38,7 @@ def handle_mcp_request( """ request_type = type(request.root) + request_root = request.root def create_success_response(result_data: mcp_types.Result) -> mcp_types.JSONRPCResponse: """Create success response with business result data""" @@ -58,21 +59,20 @@ def handle_mcp_request( error=error_data, ) - # Request handler mapping using functional approach - request_handlers = { - mcp_types.InitializeRequest: lambda: handle_initialize(mcp_server.description), - mcp_types.ListToolsRequest: lambda: handle_list_tools( - app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict - ), - mcp_types.CallToolRequest: lambda: handle_call_tool(app, request, user_input_form, end_user), - mcp_types.PingRequest: lambda: handle_ping(), - } - try: - # Dispatch request to appropriate handler - handler = request_handlers.get(request_type) - if handler: - return create_success_response(handler()) + # Dispatch request to appropriate handler based on instance type + if isinstance(request_root, mcp_types.InitializeRequest): + return create_success_response(handle_initialize(mcp_server.description)) + elif isinstance(request_root, mcp_types.ListToolsRequest): + return create_success_response( + handle_list_tools( + app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict + ) + ) + elif isinstance(request_root, mcp_types.CallToolRequest): + return create_success_response(handle_call_tool(app, request, user_input_form, end_user)) + elif isinstance(request_root, mcp_types.PingRequest): + return create_success_response(handle_ping()) else: return create_error_response(mcp_types.METHOD_NOT_FOUND, f"Method not found: {request_type.__name__}") diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index 96c48034c7..fbad5576aa 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -81,7 +81,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): self.request_meta = request_meta self.request = request self._session = session - self._completed = False + self.completed = False self._on_complete = on_complete self._entered = False # Track if we're in a context manager @@ -98,7 +98,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): ): """Exit the context manager, performing cleanup and notifying completion.""" try: - if self._completed: + if self.completed: self._on_complete(self) finally: self._entered = False @@ -113,9 +113,9 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): """ if not self._entered: raise RuntimeError("RequestResponder must be used as a context manager") - assert not self._completed, "Request already responded to" + assert not self.completed, "Request already responded to" - self._completed = True + self.completed = True self._session._send_response(request_id=self.request_id, response=response) @@ -124,7 +124,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): if not self._entered: raise RuntimeError("RequestResponder must be used as a context manager") - self._completed = True # Mark as completed so it's removed from in_flight + self.completed = True # Mark as completed so it's removed from in_flight # Send an error response to indicate cancellation self._session._send_response( request_id=self.request_id, @@ -351,7 +351,7 @@ class BaseSession( self._in_flight[responder.request_id] = responder self._received_request(responder) - if not responder._completed: + if not responder.completed: self._handle_incoming(responder) elif isinstance(message.message.root, JSONRPCNotification): diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 24b206fdbe..1d7fd7d447 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -354,7 +354,7 @@ class LargeLanguageModel(AIModel): ) return 0 - def _calc_response_usage( + def calc_response_usage( self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int ) -> LLMUsage: """ diff --git a/api/core/plugin/entities/parameters.py b/api/core/plugin/entities/parameters.py index 47290ee613..92427a7426 100644 --- a/api/core/plugin/entities/parameters.py +++ b/api/core/plugin/entities/parameters.py @@ -1,4 +1,5 @@ import enum +import json from typing import Any, Optional, Union from pydantic import BaseModel, Field, field_validator @@ -162,8 +163,6 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /): # Try to parse JSON string for arrays if isinstance(value, str): try: - import json - parsed_value = json.loads(value) if isinstance(parsed_value, list): return parsed_value @@ -176,8 +175,6 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /): # Try to parse JSON string for objects if isinstance(value, str): try: - import json - parsed_value = json.loads(value) if isinstance(parsed_value, dict): return parsed_value diff --git a/api/core/plugin/utils/chunk_merger.py b/api/core/plugin/utils/chunk_merger.py index ec66ba02ee..e30076f9d3 100644 --- a/api/core/plugin/utils/chunk_merger.py +++ b/api/core/plugin/utils/chunk_merger.py @@ -82,7 +82,9 @@ def merge_blob_chunks( message_class = type(resp) merged_message = message_class( type=ToolInvokeMessage.MessageType.BLOB, - message=ToolInvokeMessage.BlobMessage(blob=files[chunk_id].data[: files[chunk_id].bytes_written]), + message=ToolInvokeMessage.BlobMessage( + blob=bytes(files[chunk_id].data[: files[chunk_id].bytes_written]) + ), meta=resp.meta, ) yield cast(MessageType, merged_message) diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index d75a230d73..d15cb7cbc1 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -101,9 +101,22 @@ class SimplePromptTransform(PromptTransform): with_memory_prompt=histories is not None, ) - variables = {k: inputs[k] for k in prompt_template_config["custom_variable_keys"] if k in inputs} + custom_variable_keys_obj = prompt_template_config["custom_variable_keys"] + special_variable_keys_obj = prompt_template_config["special_variable_keys"] - for v in prompt_template_config["special_variable_keys"]: + # Type check for custom_variable_keys + if not isinstance(custom_variable_keys_obj, list): + raise TypeError(f"Expected list for custom_variable_keys, got {type(custom_variable_keys_obj)}") + custom_variable_keys = cast(list[str], custom_variable_keys_obj) + + # Type check for special_variable_keys + if not isinstance(special_variable_keys_obj, list): + raise TypeError(f"Expected list for special_variable_keys, got {type(special_variable_keys_obj)}") + special_variable_keys = cast(list[str], special_variable_keys_obj) + + variables = {k: inputs[k] for k in custom_variable_keys if k in inputs} + + for v in special_variable_keys: # support #context#, #query# and #histories# if v == "#context#": variables["#context#"] = context or "" @@ -113,9 +126,16 @@ class SimplePromptTransform(PromptTransform): variables["#histories#"] = histories or "" prompt_template = prompt_template_config["prompt_template"] + if not isinstance(prompt_template, PromptTemplateParser): + raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template)}") + prompt = prompt_template.format(variables) - return prompt, prompt_template_config["prompt_rules"] + prompt_rules = prompt_template_config["prompt_rules"] + if not isinstance(prompt_rules, dict): + raise TypeError(f"Expected dict for prompt_rules, got {type(prompt_rules)}") + + return prompt, prompt_rules def get_prompt_template( self, @@ -126,11 +146,11 @@ class SimplePromptTransform(PromptTransform): has_context: bool, query_in_prompt: bool, with_memory_prompt: bool = False, - ): + ) -> dict[str, object]: prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model) - custom_variable_keys = [] - special_variable_keys = [] + custom_variable_keys: list[str] = [] + special_variable_keys: list[str] = [] prompt = "" for order in prompt_rules["system_prompt_orders"]: diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index 12d97c500f..d329220580 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -40,6 +40,19 @@ if TYPE_CHECKING: MetadataFilter = Union[DictFilter, common_types.Filter] +class PathQdrantParams(BaseModel): + path: str + + +class UrlQdrantParams(BaseModel): + url: str + api_key: Optional[str] + timeout: float + verify: bool + grpc_port: int + prefer_grpc: bool + + class QdrantConfig(BaseModel): endpoint: str api_key: Optional[str] = None @@ -50,7 +63,7 @@ class QdrantConfig(BaseModel): replication_factor: int = 1 write_consistency_factor: int = 1 - def to_qdrant_params(self): + def to_qdrant_params(self) -> PathQdrantParams | UrlQdrantParams: if self.endpoint and self.endpoint.startswith("path:"): path = self.endpoint.replace("path:", "") if not os.path.isabs(path): @@ -58,23 +71,23 @@ class QdrantConfig(BaseModel): raise ValueError("Root path is not set") path = os.path.join(self.root_path, path) - return {"path": path} + return PathQdrantParams(path=path) else: - return { - "url": self.endpoint, - "api_key": self.api_key, - "timeout": self.timeout, - "verify": self.endpoint.startswith("https"), - "grpc_port": self.grpc_port, - "prefer_grpc": self.prefer_grpc, - } + return UrlQdrantParams( + url=self.endpoint, + api_key=self.api_key, + timeout=self.timeout, + verify=self.endpoint.startswith("https"), + grpc_port=self.grpc_port, + prefer_grpc=self.prefer_grpc, + ) class QdrantVector(BaseVector): def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = "Cosine"): super().__init__(collection_name) self._client_config = config - self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params()) + self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params().model_dump()) self._distance_func = distance_func.upper() self._group_id = group_id diff --git a/api/core/repositories/celery_workflow_node_execution_repository.py b/api/core/repositories/celery_workflow_node_execution_repository.py index b36252dba2..95ad9f25fe 100644 --- a/api/core/repositories/celery_workflow_node_execution_repository.py +++ b/api/core/repositories/celery_workflow_node_execution_repository.py @@ -94,10 +94,10 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER # In-memory cache for workflow node executions - self._execution_cache: dict[str, WorkflowNodeExecution] = {} + self._execution_cache = {} # Cache for mapping workflow_execution_ids to execution IDs for efficient retrieval - self._workflow_execution_mapping: dict[str, list[str]] = {} + self._workflow_execution_mapping = {} logger.info( "Initialized CeleryWorkflowNodeExecutionRepository for tenant %s, app %s, triggered_from %s", diff --git a/api/core/variables/segment_group.py b/api/core/variables/segment_group.py index b363255b2c..0a41b64228 100644 --- a/api/core/variables/segment_group.py +++ b/api/core/variables/segment_group.py @@ -4,7 +4,7 @@ from .types import SegmentType class SegmentGroup(Segment): value_type: SegmentType = SegmentType.GROUP - value: list[Segment] + value: list[Segment] = None # type: ignore @property def text(self): diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index 7da43a6504..28644b0169 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -74,12 +74,12 @@ class NoneSegment(Segment): class StringSegment(Segment): value_type: SegmentType = SegmentType.STRING - value: str + value: str = None # type: ignore class FloatSegment(Segment): value_type: SegmentType = SegmentType.FLOAT - value: float + value: float = None # type: ignore # NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems. # The following tests cannot pass. # @@ -98,12 +98,12 @@ class FloatSegment(Segment): class IntegerSegment(Segment): value_type: SegmentType = SegmentType.INTEGER - value: int + value: int = None # type: ignore class ObjectSegment(Segment): value_type: SegmentType = SegmentType.OBJECT - value: Mapping[str, Any] + value: Mapping[str, Any] = None # type: ignore @property def text(self) -> str: @@ -136,7 +136,7 @@ class ArraySegment(Segment): class FileSegment(Segment): value_type: SegmentType = SegmentType.FILE - value: File + value: File = None # type: ignore @property def markdown(self) -> str: @@ -153,17 +153,17 @@ class FileSegment(Segment): class BooleanSegment(Segment): value_type: SegmentType = SegmentType.BOOLEAN - value: bool + value: bool = None # type: ignore class ArrayAnySegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_ANY - value: Sequence[Any] + value: Sequence[Any] = None # type: ignore class ArrayStringSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_STRING - value: Sequence[str] + value: Sequence[str] = None # type: ignore @property def text(self) -> str: @@ -175,17 +175,17 @@ class ArrayStringSegment(ArraySegment): class ArrayNumberSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_NUMBER - value: Sequence[float | int] + value: Sequence[float | int] = None # type: ignore class ArrayObjectSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_OBJECT - value: Sequence[Mapping[str, Any]] + value: Sequence[Mapping[str, Any]] = None # type: ignore class ArrayFileSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_FILE - value: Sequence[File] + value: Sequence[File] = None # type: ignore @property def markdown(self) -> str: @@ -205,7 +205,7 @@ class ArrayFileSegment(ArraySegment): class ArrayBooleanSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_BOOLEAN - value: Sequence[bool] + value: Sequence[bool] = None # type: ignore def get_segment_discriminator(v: Any) -> SegmentType | None: diff --git a/api/core/workflow/errors.py b/api/core/workflow/errors.py index 594bb2b32e..63513bdc9f 100644 --- a/api/core/workflow/errors.py +++ b/api/core/workflow/errors.py @@ -3,6 +3,6 @@ from core.workflow.nodes.base import BaseNode class WorkflowNodeRunFailedError(Exception): def __init__(self, node: BaseNode, err_msg: str): - self._node = node - self._error = err_msg + self.node = node + self.error = err_msg super().__init__(f"Node {node.title} run failed: {err_msg}") diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index eb7b9fc2c6..cf46870254 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -67,8 +67,8 @@ class ListOperatorNode(BaseNode): return "1" def _run(self): - inputs: dict[str, list] = {} - process_data: dict[str, list] = {} + inputs: dict[str, Sequence[object]] = {} + process_data: dict[str, Sequence[object]] = {} outputs: dict[str, Any] = {} variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable) diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index c34a06d981..fdcdac1ec2 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -1183,7 +1183,8 @@ def _combine_message_content_with_role( return AssistantPromptMessage(content=contents) case PromptMessageRole.SYSTEM: return SystemPromptMessage(content=contents) - raise NotImplementedError(f"Role {role} is not supported") + case _: + raise NotImplementedError(f"Role {role} is not supported") def _render_jinja2_message( diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 9433b312cf..f2c37e1a4b 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -462,9 +462,9 @@ class StorageKeyLoader: upload_file_row = upload_files.get(model_id) if upload_file_row is None: raise ValueError(f"Upload file not found for id: {model_id}") - file._storage_key = upload_file_row.key + file.storage_key = upload_file_row.key elif file.transfer_method == FileTransferMethod.TOOL_FILE: tool_file_row = tool_files.get(model_id) if tool_file_row is None: raise ValueError(f"Tool file not found for id: {model_id}") - file._storage_key = tool_file_row.file_key + file.storage_key = tool_file_row.file_key diff --git a/api/fields/_value_type_serializer.py b/api/fields/_value_type_serializer.py index 8288bd54a3..b2b793d40e 100644 --- a/api/fields/_value_type_serializer.py +++ b/api/fields/_value_type_serializer.py @@ -12,4 +12,7 @@ def serialize_value_type(v: _VarTypedDict | Segment) -> str: if isinstance(v, Segment): return v.value_type.exposed_type().value else: - return v["value_type"].exposed_type().value + value_type = v.get("value_type") + if value_type is None: + raise ValueError("value_type is required but not provided") + return value_type.exposed_type().value diff --git a/api/libs/external_api.py b/api/libs/external_api.py index cee80f7f24..cf91b0117f 100644 --- a/api/libs/external_api.py +++ b/api/libs/external_api.py @@ -69,6 +69,8 @@ def register_external_error_handlers(api: Api): headers["WWW-Authenticate"] = 'Bearer realm="api"' return data, status_code, headers + _ = handle_http_exception + @api.errorhandler(ValueError) def handle_value_error(e: ValueError): got_request_exception.send(current_app, exception=e) @@ -76,6 +78,8 @@ def register_external_error_handlers(api: Api): data = {"code": "invalid_param", "message": str(e), "status": status_code} return data, status_code + _ = handle_value_error + @api.errorhandler(AppInvokeQuotaExceededError) def handle_quota_exceeded(e: AppInvokeQuotaExceededError): got_request_exception.send(current_app, exception=e) @@ -83,15 +87,17 @@ def register_external_error_handlers(api: Api): data = {"code": "too_many_requests", "message": str(e), "status": status_code} return data, status_code + _ = handle_quota_exceeded + @api.errorhandler(Exception) def handle_general_exception(e: Exception): got_request_exception.send(current_app, exception=e) status_code = 500 - data: dict[str, Any] = getattr(e, "data", {"message": http_status_message(status_code)}) + data = getattr(e, "data", {"message": http_status_message(status_code)}) # 🔒 Normalize non-mapping data (e.g., if someone set e.data = Response) - if not isinstance(data, Mapping): + if not isinstance(data, dict): data = {"message": str(e)} data.setdefault("code", "unknown") @@ -101,10 +107,12 @@ def register_external_error_handlers(api: Api): exc_info: Any = sys.exc_info() if exc_info[1] is None: exc_info = None - current_app.log_exception(exc_info) # ty: ignore [invalid-argument-type] + current_app.log_exception(exc_info) return data, status_code + _ = handle_general_exception + class ExternalApi(Api): _authorizations = { diff --git a/api/libs/helper.py b/api/libs/helper.py index 139cb329de..f3c46b4843 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -167,13 +167,6 @@ class DatetimeString: return value -def _get_float(value): - try: - return float(value) - except (TypeError, ValueError): - raise ValueError(f"{value} is not a valid float") - - def timezone(timezone_string): if timezone_string and timezone_string in available_timezones(): return timezone_string diff --git a/api/pyrightconfig.json b/api/pyrightconfig.json index 352161523f..7c59c2ca28 100644 --- a/api/pyrightconfig.json +++ b/api/pyrightconfig.json @@ -1,24 +1,44 @@ { "include": ["."], - "exclude": [".venv", "tests/", "migrations/"], - "ignore": [ - "core/", - "controllers/", - "tasks/", - "services/", - "schedule/", - "extensions/", - "utils/", - "repositories/", - "libs/", - "fields/", - "factories/", - "events/", - "contexts/", - "constants/", - "commands.py" + "exclude": [ + ".venv", + "tests/", + "migrations/", + "core/rag", + "extensions", + "libs", + "controllers/console/datasets", + "controllers/service_api/dataset", + "core/ops", + "core/tools", + "core/model_runtime", + "core/workflow", + "core/app/app_config/easy_ui_based_app/dataset" ], "typeCheckingMode": "strict", + "allowedUntypedLibraries": [ + "flask_restx", + "flask_login", + "opentelemetry.instrumentation.celery", + "opentelemetry.instrumentation.flask", + "opentelemetry.instrumentation.requests", + "opentelemetry.instrumentation.sqlalchemy", + "opentelemetry.instrumentation.redis" + ], + "reportUnknownMemberType": "hint", + "reportUnknownParameterType": "hint", + "reportUnknownArgumentType": "hint", + "reportUnknownVariableType": "hint", + "reportUnknownLambdaType": "hint", + "reportMissingParameterType": "hint", + "reportMissingTypeArgument": "hint", + "reportUnnecessaryContains": "hint", + "reportUnnecessaryComparison": "hint", + "reportUnnecessaryCast": "hint", + "reportUnnecessaryIsInstance": "hint", + "reportUntypedFunctionDecorator": "hint", + + "reportAttributeAccessIssue": "hint", "pythonVersion": "3.11", "pythonPlatform": "All" } diff --git a/api/services/account_service.py b/api/services/account_service.py index a76792f88e..f66c1aa677 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -1318,7 +1318,7 @@ class RegisterService: def get_invitation_if_token_valid( cls, workspace_id: Optional[str], email: str, token: str ) -> Optional[dict[str, Any]]: - invitation_data = cls._get_invitation_by_token(token, workspace_id, email) + invitation_data = cls.get_invitation_by_token(token, workspace_id, email) if not invitation_data: return None @@ -1355,7 +1355,7 @@ class RegisterService: } @classmethod - def _get_invitation_by_token( + def get_invitation_by_token( cls, token: str, workspace_id: Optional[str] = None, email: Optional[str] = None ) -> Optional[dict[str, str]]: if workspace_id is not None and email is not None: diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index ba86a31240..82b1d21179 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -349,7 +349,7 @@ class AppAnnotationService: try: # Skip the first row - df = pd.read_csv(file, dtype=str) + df = pd.read_csv(file.stream, dtype=str) result = [] for _, row in df.iterrows(): content = {"question": row.iloc[0], "answer": row.iloc[1]} @@ -463,15 +463,23 @@ class AppAnnotationService: annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() if annotation_setting: collection_binding_detail = annotation_setting.collection_binding_detail - return { - "id": annotation_setting.id, - "enabled": True, - "score_threshold": annotation_setting.score_threshold, - "embedding_model": { - "embedding_provider_name": collection_binding_detail.provider_name, - "embedding_model_name": collection_binding_detail.model_name, - }, - } + if collection_binding_detail: + return { + "id": annotation_setting.id, + "enabled": True, + "score_threshold": annotation_setting.score_threshold, + "embedding_model": { + "embedding_provider_name": collection_binding_detail.provider_name, + "embedding_model_name": collection_binding_detail.model_name, + }, + } + else: + return { + "id": annotation_setting.id, + "enabled": True, + "score_threshold": annotation_setting.score_threshold, + "embedding_model": {}, + } return {"enabled": False} @classmethod @@ -506,15 +514,23 @@ class AppAnnotationService: collection_binding_detail = annotation_setting.collection_binding_detail - return { - "id": annotation_setting.id, - "enabled": True, - "score_threshold": annotation_setting.score_threshold, - "embedding_model": { - "embedding_provider_name": collection_binding_detail.provider_name, - "embedding_model_name": collection_binding_detail.model_name, - }, - } + if collection_binding_detail: + return { + "id": annotation_setting.id, + "enabled": True, + "score_threshold": annotation_setting.score_threshold, + "embedding_model": { + "embedding_provider_name": collection_binding_detail.provider_name, + "embedding_model_name": collection_binding_detail.model_name, + }, + } + else: + return { + "id": annotation_setting.id, + "enabled": True, + "score_threshold": annotation_setting.score_threshold, + "embedding_model": {}, + } @classmethod def clear_all_annotations(cls, app_id: str): diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index 2f1b63664f..3b4cb1900a 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -407,6 +407,7 @@ class ClearFreePlanTenantExpiredLogs: datetime.timedelta(hours=1), ] + tenant_count = 0 for test_interval in test_intervals: tenant_count = ( session.query(Tenant.id) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 65dc673100..20a9c73f08 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -134,11 +134,14 @@ class DatasetService: # Check if tag_ids is not empty to avoid WHERE false condition if tag_ids and len(tag_ids) > 0: - target_ids = TagService.get_target_ids_by_tag_ids( - "knowledge", - tenant_id, # ty: ignore [invalid-argument-type] - tag_ids, - ) + if tenant_id is not None: + target_ids = TagService.get_target_ids_by_tag_ids( + "knowledge", + tenant_id, + tag_ids, + ) + else: + target_ids = [] if target_ids and len(target_ids) > 0: query = query.where(Dataset.id.in_(target_ids)) else: @@ -987,7 +990,8 @@ class DocumentService: for document in documents if document.data_source_type == "upload_file" and document.data_source_info_dict ] - batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids) + if dataset.doc_form is not None: + batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids) for document in documents: db.session.delete(document) @@ -2688,56 +2692,6 @@ class SegmentService: return paginated_segments.items, paginated_segments.total - @classmethod - def update_segment_by_id( - cls, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, segment_data: dict, user_id: str - ) -> tuple[DocumentSegment, Document]: - """Update a segment by its ID with validation and checks.""" - # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() - if not dataset: - raise NotFound("Dataset not found.") - - # check user's model setting - DatasetService.check_dataset_model_setting(dataset) - - # check document - document = DocumentService.get_document(dataset_id, document_id) - if not document: - raise NotFound("Document not found.") - - # check embedding model setting if high quality - if dataset.indexing_technique == "high_quality": - try: - model_manager = ModelManager() - model_manager.get_model_instance( - tenant_id=user_id, - provider=dataset.embedding_model_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model, - ) - except LLMBadRequestError: - raise ValueError( - "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." - ) - except ProviderTokenNotInitError as ex: - raise ValueError(ex.description) - - # check segment - segment = ( - db.session.query(DocumentSegment) - .where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id) - .first() - ) - if not segment: - raise NotFound("Segment not found.") - - # validate and update segment - cls.segment_create_args_validate(segment_data, document) - updated_segment = cls.update_segment(SegmentUpdateArgs(**segment_data), segment, document, dataset) - - return updated_segment, document - @classmethod def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]: """Get a segment by its ID.""" diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 3262a00663..3911b763b6 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -181,7 +181,7 @@ class ExternalDatasetService: do http request depending on api bundle """ - kwargs = { + kwargs: dict[str, Any] = { "url": settings.url, "headers": settings.headers, "follow_redirects": True, diff --git a/api/services/file_service.py b/api/services/file_service.py index 8a4655d25e..364a872a91 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -1,7 +1,7 @@ import hashlib import os import uuid -from typing import Any, Literal, Union +from typing import Literal, Union from werkzeug.exceptions import NotFound @@ -35,7 +35,7 @@ class FileService: filename: str, content: bytes, mimetype: str, - user: Union[Account, EndUser, Any], + user: Union[Account, EndUser], source: Literal["datasets"] | None = None, source_url: str = "", ) -> UploadFile: diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index c638087f63..d0e2230540 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -165,7 +165,7 @@ class ModelLoadBalancingService: try: if load_balancing_config.encrypted_config: - credentials = json.loads(load_balancing_config.encrypted_config) + credentials: dict[str, object] = json.loads(load_balancing_config.encrypted_config) else: credentials = {} except JSONDecodeError: @@ -180,11 +180,13 @@ class ModelLoadBalancingService: for variable in credential_secret_variables: if variable in credentials: try: - credentials[variable] = encrypter.decrypt_token_with_decoding( - credentials.get(variable), # ty: ignore [invalid-argument-type] - decoding_rsa_key, - decoding_cipher_rsa, - ) + token_value = credentials.get(variable) + if isinstance(token_value, str): + credentials[variable] = encrypter.decrypt_token_with_decoding( + token_value, + decoding_rsa_key, + decoding_cipher_rsa, + ) except ValueError: pass @@ -345,8 +347,9 @@ class ModelLoadBalancingService: credential_id = config.get("credential_id") enabled = config.get("enabled") + credential_record: ProviderCredential | ProviderModelCredential | None = None + if credential_id: - credential_record: ProviderCredential | ProviderModelCredential | None = None if config_from == "predefined-model": credential_record = ( db.session.query(ProviderCredential) diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py index 8dbf117fd3..bae2921a27 100644 --- a/api/services/plugin/plugin_migration.py +++ b/api/services/plugin/plugin_migration.py @@ -99,6 +99,7 @@ class PluginMigration: datetime.timedelta(hours=1), ] + tenant_count = 0 for test_interval in test_intervals: tenant_count = ( session.query(Tenant.id) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index bce389b949..cb31111485 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -223,8 +223,8 @@ class BuiltinToolManageService: """ add builtin tool provider """ - try: - with Session(db.engine) as session: + with Session(db.engine) as session: + try: lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}" with redis_client.lock(lock, timeout=20): provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) @@ -285,9 +285,9 @@ class BuiltinToolManageService: session.add(db_provider) session.commit() - except Exception as e: - session.rollback() - raise ValueError(str(e)) + except Exception as e: + session.rollback() + raise ValueError(str(e)) return {"result": "success"} @staticmethod diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 2994856b54..8a58289b22 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -18,6 +18,7 @@ from core.helper import encrypter from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.simple_prompt_transform import SimplePromptTransform +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.workflow.nodes import NodeType from events.app_event import app_was_created from extensions.ext_database import db @@ -420,7 +421,11 @@ class WorkflowConverter: query_in_prompt=False, ) - template = prompt_template_config["prompt_template"].template + prompt_template_obj = prompt_template_config["prompt_template"] + if not isinstance(prompt_template_obj, PromptTemplateParser): + raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template_obj)}") + + template = prompt_template_obj.template if not template: prompts = [] else: @@ -457,7 +462,11 @@ class WorkflowConverter: query_in_prompt=False, ) - template = prompt_template_config["prompt_template"].template + prompt_template_obj = prompt_template_config["prompt_template"] + if not isinstance(prompt_template_obj, PromptTemplateParser): + raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template_obj)}") + + template = prompt_template_obj.template template = self._replace_template_variables( template=template, variables=start_node["data"]["variables"], @@ -467,6 +476,9 @@ class WorkflowConverter: prompts = {"text": template} prompt_rules = prompt_template_config["prompt_rules"] + if not isinstance(prompt_rules, dict): + raise TypeError(f"Expected dict for prompt_rules, got {type(prompt_rules)}") + role_prefix = { "user": prompt_rules.get("human_prefix", "Human"), "assistant": prompt_rules.get("assistant_prefix", "Assistant"), diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 0a14007349..4e0ae15841 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -769,10 +769,10 @@ class WorkflowService: ) error = node_run_result.error if not run_succeeded else None except WorkflowNodeRunFailedError as e: - node = e._node + node = e.node run_succeeded = False node_run_result = None - error = e._error + error = e.error # Create a NodeExecution domain model node_execution = WorkflowNodeExecution( diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index d4fc68a084..292ac6e008 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -12,7 +12,7 @@ class WorkspaceService: def get_tenant_info(cls, tenant: Tenant): if not tenant: return None - tenant_info = { + tenant_info: dict[str, object] = { "id": tenant.id, "name": tenant.name, "plan": tenant.plan, diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index 415e65ce51..6b5ac713e6 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -3278,7 +3278,7 @@ class TestRegisterService: redis_client.setex(cache_key, 24 * 60 * 60, account_id) # Execute invitation retrieval - result = RegisterService._get_invitation_by_token( + result = RegisterService.get_invitation_by_token( token=token, workspace_id=workspace_id, email=email, @@ -3316,7 +3316,7 @@ class TestRegisterService: redis_client.setex(token_key, 24 * 60 * 60, json.dumps(invitation_data)) # Execute invitation retrieval - result = RegisterService._get_invitation_by_token(token=token) + result = RegisterService.get_invitation_by_token(token=token) # Verify result contains expected data assert result is not None diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py index 8b3db27525..18ab4bb73c 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py @@ -14,6 +14,7 @@ from core.app.app_config.entities import ( VariableEntityType, ) from core.model_runtime.entities.llm_entities import LLMMode +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.account import Account, Tenant from models.api_based_extension import APIBasedExtension from models.model import App, AppMode, AppModelConfig @@ -37,7 +38,7 @@ class TestWorkflowConverter: # Setup default mock returns mock_encrypter.decrypt_token.return_value = "decrypted_api_key" mock_prompt_transform.return_value.get_prompt_template.return_value = { - "prompt_template": type("obj", (object,), {"template": "You are a helpful assistant {{text_input}}"})(), + "prompt_template": PromptTemplateParser(template="You are a helpful assistant {{text_input}}"), "prompt_rules": {"human_prefix": "Human", "assistant_prefix": "Assistant"}, } mock_agent_chat_config_manager.get_app_config.return_value = self._create_mock_app_config() diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index 442839e44e..d7404ee90a 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -1370,8 +1370,8 @@ class TestRegisterService: account_id="user-123", email="test@example.com" ) - with patch("services.account_service.RegisterService._get_invitation_by_token") as mock_get_invitation_by_token: - # Mock the invitation data returned by _get_invitation_by_token + with patch("services.account_service.RegisterService.get_invitation_by_token") as mock_get_invitation_by_token: + # Mock the invitation data returned by get_invitation_by_token invitation_data = { "account_id": "user-123", "email": "test@example.com", @@ -1503,12 +1503,12 @@ class TestRegisterService: assert result == "member_invite:token:test-token" def test_get_invitation_by_token_with_workspace_and_email(self, mock_redis_dependencies): - """Test _get_invitation_by_token with workspace ID and email.""" + """Test get_invitation_by_token with workspace ID and email.""" # Setup mock mock_redis_dependencies.get.return_value = b"user-123" # Execute test - result = RegisterService._get_invitation_by_token("token-123", "workspace-456", "test@example.com") + result = RegisterService.get_invitation_by_token("token-123", "workspace-456", "test@example.com") # Verify results assert result is not None @@ -1517,7 +1517,7 @@ class TestRegisterService: assert result["workspace_id"] == "workspace-456" def test_get_invitation_by_token_without_workspace_and_email(self, mock_redis_dependencies): - """Test _get_invitation_by_token without workspace ID and email.""" + """Test get_invitation_by_token without workspace ID and email.""" # Setup mock invitation_data = { "account_id": "user-123", @@ -1527,19 +1527,19 @@ class TestRegisterService: mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode() # Execute test - result = RegisterService._get_invitation_by_token("token-123") + result = RegisterService.get_invitation_by_token("token-123") # Verify results assert result is not None assert result == invitation_data def test_get_invitation_by_token_no_data(self, mock_redis_dependencies): - """Test _get_invitation_by_token with no data.""" + """Test get_invitation_by_token with no data.""" # Setup mock mock_redis_dependencies.get.return_value = None # Execute test - result = RegisterService._get_invitation_by_token("token-123") + result = RegisterService.get_invitation_by_token("token-123") # Verify results assert result is None