diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index 82ba95444f..be6ce80dfc 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -20,7 +20,7 @@ jobs: cd api uv sync --dev # Fix lint errors - uv run ruff check --fix-only . + uv run ruff check --fix . # Format code uv run ruff format . - name: ast-grep diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 5a456f14fd..1a472e771d 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -11,11 +11,7 @@ from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services from configs import dify_config from controllers.console import api -from controllers.console.app.error import ( - ConversationCompletedError, - DraftWorkflowNotExist, - DraftWorkflowNotSync, -) +from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index c7e300279a..5a871f896a 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -1,6 +1,6 @@ from collections.abc import Callable from functools import wraps -from typing import Optional, Union +from typing import Optional, ParamSpec, TypeVar, Union from controllers.console.app.error import AppNotFoundError from extensions.ext_database import db @@ -8,6 +8,9 @@ from libs.login import current_user from models import App, AppMode from models.account import Account +P = ParamSpec("P") +R = TypeVar("R") + def _load_app_model(app_id: str) -> Optional[App]: assert isinstance(current_user, Account) @@ -19,10 +22,10 @@ def _load_app_model(app_id: str) -> Optional[App]: return app_model -def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None): - def decorator(view_func): +def get_app_model(view: Optional[Callable[P, R]] = None, *, mode: Union[AppMode, list[AppMode], None] = None): + def decorator(view_func: Callable[P, R]): @wraps(view_func) - def decorated_view(*args, **kwargs): + def decorated_view(*args: P.args, **kwargs: P.kwargs): if not kwargs.get("app_id"): raise ValueError("missing app_id in path parameters") diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index e4d5f1be6e..45c647659b 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -249,7 +249,7 @@ class DataSourceNotionDatasetSyncApi(Resource): documents = DocumentService.get_document_by_dataset_id(dataset_id_str) for document in documents: document_indexing_sync_task.delay(dataset_id_str, document.id) - return 200 + return {"result": "success"}, 200 class DataSourceNotionDocumentSyncApi(Resource): @@ -267,7 +267,7 @@ class DataSourceNotionDocumentSyncApi(Resource): if document is None: raise NotFound("Document not found.") document_indexing_sync_task.delay(dataset_id_str, document_id_str) - return 200 + return {"result": "success"}, 200 api.add_resource(DataSourceApi, "/data-source/integrates", "/data-source/integrates//") diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py index 6aa309f930..21ab5e4fe1 100644 --- a/api/controllers/console/datasets/metadata.py +++ b/api/controllers/console/datasets/metadata.py @@ -113,7 +113,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource): MetadataService.enable_built_in_field(dataset) elif action == "disable": MetadataService.disable_built_in_field(dataset) - return 200 + return {"result": "success"}, 200 class DocumentMetadataEditApi(Resource): @@ -135,7 +135,7 @@ class DocumentMetadataEditApi(Resource): MetadataService.update_documents_metadata(dataset, metadata_args) - return 200 + return {"result": "success"}, 200 api.add_resource(DatasetMetadataCreateApi, "/datasets//metadata") diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index c45e7dbb26..da236ee5af 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -111,7 +111,7 @@ class TagBindingCreateApi(Resource): args = parser.parse_args() TagService.save_tag_binding(args) - return 200 + return {"result": "success"}, 200 class TagBindingDeleteApi(Resource): @@ -132,7 +132,7 @@ class TagBindingDeleteApi(Resource): args = parser.parse_args() TagService.delete_tag_binding(args) - return 200 + return {"result": "success"}, 200 api.add_resource(TagListApi, "/tags") diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index f751e06ddf..68711f7257 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -1,6 +1,6 @@ from collections.abc import Callable from functools import wraps -from typing import Optional +from typing import Optional, ParamSpec, TypeVar from flask import current_app, request from flask_login import user_logged_in @@ -14,6 +14,9 @@ from libs.login import _get_user from models.account import Tenant from models.model import EndUser +P = ParamSpec("P") +R = TypeVar("R") + def get_user(tenant_id: str, user_id: str | None) -> EndUser: """ @@ -52,19 +55,19 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: return user_model -def get_user_tenant(view: Optional[Callable] = None): - def decorator(view_func): +def get_user_tenant(view: Optional[Callable[P, R]] = None): + def decorator(view_func: Callable[P, R]): @wraps(view_func) - def decorated_view(*args, **kwargs): + def decorated_view(*args: P.args, **kwargs: P.kwargs): # fetch json body parser = reqparse.RequestParser() parser.add_argument("tenant_id", type=str, required=True, location="json") parser.add_argument("user_id", type=str, required=True, location="json") - kwargs = parser.parse_args() + p = parser.parse_args() - user_id = kwargs.get("user_id") - tenant_id = kwargs.get("tenant_id") + user_id: Optional[str] = p.get("user_id") + tenant_id: str = p.get("tenant_id") if not tenant_id: raise ValueError("tenant_id is required") @@ -107,9 +110,9 @@ def get_user_tenant(view: Optional[Callable] = None): return decorator(view) -def plugin_data(view: Optional[Callable] = None, *, payload_type: type[BaseModel]): - def decorator(view_func): - def decorated_view(*args, **kwargs): +def plugin_data(view: Optional[Callable[P, R]] = None, *, payload_type: type[BaseModel]): + def decorator(view_func: Callable[P, R]): + def decorated_view(*args: P.args, **kwargs: P.kwargs): try: data = request.get_json() except Exception: diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py index de4f1da801..4bdcc6832a 100644 --- a/api/controllers/inner_api/wraps.py +++ b/api/controllers/inner_api/wraps.py @@ -46,9 +46,9 @@ def enterprise_inner_api_only(view: Callable[P, R]): return decorated -def enterprise_inner_api_user_auth(view): +def enterprise_inner_api_user_auth(view: Callable[P, R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): if not dify_config.INNER_API: return view(*args, **kwargs) diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index 444a791c01..c2df97eaec 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -174,7 +174,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): MetadataService.enable_built_in_field(dataset) elif action == "disable": MetadataService.disable_built_in_field(dataset) - return 200 + return {"result": "success"}, 200 @service_api_ns.route("/datasets//documents/metadata") @@ -204,4 +204,4 @@ class DocumentMetadataEditServiceApi(DatasetApiResource): MetadataService.update_documents_metadata(dataset, metadata_args) - return 200 + return {"result": "success"}, 200 diff --git a/api/controllers/service_api/workspace/models.py b/api/controllers/service_api/workspace/models.py index 536cf81a2f..fffcb47bd4 100644 --- a/api/controllers/service_api/workspace/models.py +++ b/api/controllers/service_api/workspace/models.py @@ -19,7 +19,7 @@ class ModelProviderAvailableModelApi(Resource): } ) @validate_dataset_token - def get(self, _, model_type): + def get(self, _, model_type: str): """Get available models by model type. Returns a list of available models for the specified model type. diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 14291578d5..4394e64ad9 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -3,7 +3,7 @@ from collections.abc import Callable from datetime import timedelta from enum import StrEnum, auto from functools import wraps -from typing import Optional, ParamSpec, TypeVar +from typing import Concatenate, Optional, ParamSpec, TypeVar from flask import current_app, request from flask_login import user_logged_in @@ -25,6 +25,7 @@ from services.feature_service import FeatureService P = ParamSpec("P") R = TypeVar("R") +T = TypeVar("T") class WhereisUserArg(StrEnum): @@ -42,10 +43,10 @@ class FetchUserArg(BaseModel): required: bool = False -def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optional[FetchUserArg] = None): - def decorator(view_func): +def validate_app_token(view: Optional[Callable[P, R]] = None, *, fetch_user_arg: Optional[FetchUserArg] = None): + def decorator(view_func: Callable[P, R]): @wraps(view_func) - def decorated_view(*args, **kwargs): + def decorated_view(*args: P.args, **kwargs: P.kwargs): api_token = validate_and_get_api_token("app") app_model = db.session.query(App).where(App.id == api_token.app_id).first() @@ -189,10 +190,10 @@ def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str): return interceptor -def validate_dataset_token(view=None): - def decorator(view): +def validate_dataset_token(view: Optional[Callable[Concatenate[T, P], R]] = None): + def decorator(view: Callable[Concatenate[T, P], R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): api_token = validate_and_get_api_token("dataset") tenant_account_join = ( db.session.query(Tenant, TenantAccountJoin) diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index 1fbb2c165f..e79456535a 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -1,6 +1,7 @@ +from collections.abc import Callable from datetime import UTC, datetime from functools import wraps -from typing import ParamSpec, TypeVar +from typing import Concatenate, Optional, ParamSpec, TypeVar from flask import request from flask_restx import Resource @@ -20,12 +21,11 @@ P = ParamSpec("P") R = TypeVar("R") -def validate_jwt_token(view=None): - def decorator(view): +def validate_jwt_token(view: Optional[Callable[Concatenate[App, EndUser, P], R]] = None): + def decorator(view: Callable[Concatenate[App, EndUser, P], R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): app_model, end_user = decode_jwt_token() - return view(app_model, end_user, *args, **kwargs) return decorated diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 048885e245..b64c804082 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -42,6 +42,7 @@ from models.provider import ( TenantPreferredModelProvider, ) from models.provider_ids import ModelProviderID +from services.enterprise.plugin_manager_service import PluginCredentialType logger = logging.getLogger(__name__) @@ -129,14 +130,38 @@ class ProviderConfiguration(BaseModel): return copy_credentials else: credentials = None + current_credential_id = None + if self.custom_configuration.models: for model_configuration in self.custom_configuration.models: if model_configuration.model_type == model_type and model_configuration.model == model: credentials = model_configuration.credentials + current_credential_id = model_configuration.current_credential_id break if not credentials and self.custom_configuration.provider: credentials = self.custom_configuration.provider.credentials + current_credential_id = self.custom_configuration.provider.current_credential_id + + if current_credential_id: + from core.helper.credential_utils import check_credential_policy_compliance + + check_credential_policy_compliance( + credential_id=current_credential_id, + provider=self.provider.provider, + credential_type=PluginCredentialType.MODEL, + ) + else: + # no current credential id, check all available credentials + if self.custom_configuration.provider: + for credential_configuration in self.custom_configuration.provider.available_credentials: + from core.helper.credential_utils import check_credential_policy_compliance + + check_credential_policy_compliance( + credential_id=credential_configuration.credential_id, + provider=self.provider.provider, + credential_type=PluginCredentialType.MODEL, + ) return credentials @@ -266,7 +291,6 @@ class ProviderConfiguration(BaseModel): :param credential_id: if provided, return the specified credential :return: """ - if credential_id: return self._get_specific_provider_credential(credential_id) @@ -739,6 +763,7 @@ class ProviderConfiguration(BaseModel): current_credential_id = credential_record.id current_credential_name = credential_record.credential_name + credentials = self.obfuscated_credentials( credentials=credentials, credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas @@ -793,6 +818,7 @@ class ProviderConfiguration(BaseModel): ): current_credential_id = model_configuration.current_credential_id current_credential_name = model_configuration.current_credential_name + credentials = self.obfuscated_credentials( credentials=model_configuration.credentials, credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 79a7514bbc..9b8baf1973 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -145,6 +145,7 @@ class ModelLoadBalancingConfiguration(BaseModel): name: str credentials: dict credential_source_type: str | None = None + credential_id: str | None = None class ModelSettings(BaseModel): diff --git a/api/core/helper/credential_utils.py b/api/core/helper/credential_utils.py new file mode 100644 index 0000000000..240f498181 --- /dev/null +++ b/api/core/helper/credential_utils.py @@ -0,0 +1,75 @@ +""" +Credential utility functions for checking credential existence and policy compliance. +""" + +from services.enterprise.plugin_manager_service import PluginCredentialType + + +def is_credential_exists(credential_id: str, credential_type: "PluginCredentialType") -> bool: + """ + Check if the credential still exists in the database. + + :param credential_id: The credential ID to check + :param credential_type: The type of credential (MODEL or TOOL) + :return: True if credential exists, False otherwise + """ + from sqlalchemy import select + from sqlalchemy.orm import Session + + from extensions.ext_database import db + from models.provider import ProviderCredential, ProviderModelCredential + from models.tools import BuiltinToolProvider + + with Session(db.engine) as session: + if credential_type == PluginCredentialType.MODEL: + # Check both pre-defined and custom model credentials using a single UNION query + stmt = ( + select(ProviderCredential.id) + .where(ProviderCredential.id == credential_id) + .union(select(ProviderModelCredential.id).where(ProviderModelCredential.id == credential_id)) + ) + return session.scalar(stmt) is not None + + if credential_type == PluginCredentialType.TOOL: + return ( + session.scalar(select(BuiltinToolProvider.id).where(BuiltinToolProvider.id == credential_id)) + is not None + ) + + return False + + +def check_credential_policy_compliance( + credential_id: str, provider: str, credential_type: "PluginCredentialType", check_existence: bool = True +) -> None: + """ + Check credential policy compliance for the given credential ID. + + :param credential_id: The credential ID to check + :param provider: The provider name + :param credential_type: The type of credential (MODEL or TOOL) + :param check_existence: Whether to check if credential exists in database first + :raises ValueError: If credential policy compliance check fails + """ + from services.enterprise.plugin_manager_service import ( + CheckCredentialPolicyComplianceRequest, + PluginManagerService, + ) + from services.feature_service import FeatureService + + if not FeatureService.get_system_features().plugin_manager.enabled or not credential_id: + return + + # Check if credential exists in database first (if requested) + if check_existence: + if not is_credential_exists(credential_id, credential_type): + raise ValueError(f"Credential with id {credential_id} for provider {provider} not found.") + + # Check policy compliance + PluginManagerService.check_credential_policy_compliance( + CheckCredentialPolicyComplianceRequest( + dify_credential_id=credential_id, + provider=provider, + credential_type=credential_type, + ) + ) diff --git a/api/core/model_manager.py b/api/core/model_manager.py index a59b0ae826..10df2ad79e 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -23,6 +23,7 @@ from core.model_runtime.model_providers.__base.tts_model import TTSModel from core.provider_manager import ProviderManager from extensions.ext_redis import redis_client from models.provider import ProviderType +from services.enterprise.plugin_manager_service import PluginCredentialType logger = logging.getLogger(__name__) @@ -362,6 +363,23 @@ class ModelInstance: else: raise last_exception + # Additional policy compliance check as fallback (in case fetch_next didn't catch it) + try: + from core.helper.credential_utils import check_credential_policy_compliance + + if lb_config.credential_id: + check_credential_policy_compliance( + credential_id=lb_config.credential_id, + provider=self.provider, + credential_type=PluginCredentialType.MODEL, + ) + except Exception as e: + logger.warning( + "Load balancing config %s failed policy compliance check in round-robin: %s", lb_config.id, str(e) + ) + self.load_balancing_manager.cooldown(lb_config, expire=60) + continue + try: if "credentials" in kwargs: del kwargs["credentials"] @@ -515,6 +533,24 @@ class LBModelManager: continue + # Check policy compliance for the selected configuration + try: + from core.helper.credential_utils import check_credential_policy_compliance + + if config.credential_id: + check_credential_policy_compliance( + credential_id=config.credential_id, + provider=self._provider, + credential_type=PluginCredentialType.MODEL, + ) + except Exception as e: + logger.warning("Load balancing config %s failed policy compliance check: %s", config.id, str(e)) + cooldown_load_balancing_configs.append(config) + if len(cooldown_load_balancing_configs) >= len(self._load_balancing_configs): + # all configs are in cooldown or failed policy compliance + return None + continue + if dify_config.DEBUG: logger.info( """Model LB diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index ecbeadac0b..51db6eb1a7 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -1129,6 +1129,7 @@ class ProviderManager: name=load_balancing_model_config.name, credentials=provider_model_credentials, credential_source_type=load_balancing_model_config.credential_source_type, + credential_id=load_balancing_model_config.credential_id, ) ) diff --git a/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py b/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py index 7da830f643..3dd073ce50 100644 --- a/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py +++ b/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py @@ -1,8 +1,9 @@ import json import logging import uuid +from collections.abc import Callable from functools import wraps -from typing import Any, Optional +from typing import Any, Concatenate, Optional, ParamSpec, TypeVar from mo_vector.client import MoVectorClient # type: ignore from pydantic import BaseModel, model_validator @@ -17,7 +18,6 @@ from extensions.ext_redis import redis_client from models.dataset import Dataset logger = logging.getLogger(__name__) -from typing import ParamSpec, TypeVar P = ParamSpec("P") R = TypeVar("R") @@ -47,16 +47,6 @@ class MatrixoneConfig(BaseModel): return values -def ensure_client(func): - @wraps(func) - def wrapper(self, *args, **kwargs): - if self.client is None: - self.client = self._get_client(None, False) - return func(self, *args, **kwargs) - - return wrapper - - class MatrixoneVector(BaseVector): """ Matrixone vector storage implementation. @@ -216,6 +206,19 @@ class MatrixoneVector(BaseVector): self.client.delete() +T = TypeVar("T", bound=MatrixoneVector) + + +def ensure_client(func: Callable[Concatenate[T, P], R]): + @wraps(func) + def wrapper(self: T, *args: P.args, **kwargs: P.kwargs): + if self.client is None: + self.client = self._get_client(None, False) + return func(self, *args, **kwargs) + + return wrapper + + class MatrixoneVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MatrixoneVector: if dataset.index_struct_dict: diff --git a/api/core/tools/errors.py b/api/core/tools/errors.py index c5f9ca4774..b0c2232857 100644 --- a/api/core/tools/errors.py +++ b/api/core/tools/errors.py @@ -29,6 +29,10 @@ class ToolApiSchemaError(ValueError): pass +class ToolCredentialPolicyViolationError(ValueError): + pass + + class ToolEngineInvokeError(Exception): meta: ToolInvokeMeta diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index faba457b75..5c836cfcd2 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -21,6 +21,7 @@ from core.helper.module_import_helper import load_single_subclass_from_source from core.helper.position_helper import is_filtered from core.helper.provider_cache import ToolProviderCredentialsCache from core.model_runtime.utils.encoders import jsonable_encoder +from core.plugin.impl.tool import PluginToolManager from core.tools.__base.tool import Tool from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_runtime import ToolRuntime @@ -44,16 +45,16 @@ from core.tools.mcp_tool.tool import MCPTool from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.plugin_tool.tool import PluginTool from core.tools.tool_label_manager import ToolLabelManager -from core.tools.utils.configuration import ( - ToolParameterConfigurationManager, -) +from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter from core.tools.utils.uuid_utils import is_valid_uuid from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool +from core.workflow.entities.variable_pool import VariablePool from extensions.ext_database import db from models.provider_ids import ToolProviderID from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider +from services.enterprise.plugin_manager_service import PluginCredentialType from services.tools.mcp_tools_manage_service import MCPToolManageService from services.tools.tools_transform_service import ToolTransformService @@ -115,7 +116,6 @@ class ToolManager: get the plugin provider """ # check if context is set - from core.plugin.impl.tool import PluginToolManager try: contexts.plugin_tool_providers.get() @@ -237,6 +237,16 @@ class ToolManager: if builtin_provider is None: raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") + # check if the credential is allowed to be used + from core.helper.credential_utils import check_credential_policy_compliance + + check_credential_policy_compliance( + credential_id=builtin_provider.id, + provider=provider_id, + credential_type=PluginCredentialType.TOOL, + check_existence=False, + ) + encrypter, cache = create_provider_encrypter( tenant_id=tenant_id, config=[ @@ -509,7 +519,6 @@ class ToolManager: """ list all the plugin providers """ - from core.plugin.impl.tool import PluginToolManager manager = PluginToolManager() provider_entities = manager.fetch_tool_providers(tenant_id) diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index cd01a31068..5571c0d9ba 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -86,9 +86,7 @@ def load_user_from_request(request_from_flask_login): if not app_mcp_server: raise NotFound("App MCP server not found.") end_user = ( - db.session.query(EndUser) - .where(EndUser.external_user_id == app_mcp_server.id, EndUser.type == "mcp") - .first() + db.session.query(EndUser).where(EndUser.session_id == app_mcp_server.id, EndUser.type == "mcp").first() ) if not end_user: raise NotFound("End user not found.") diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 8f95e327b2..b55233ecb3 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -217,7 +217,7 @@ class DatasetService: and retrieval_model.reranking_model.reranking_model_name ): # check if reranking model setting is valid - DatasetService.check_embedding_model_setting( + DatasetService.check_reranking_model_setting( tenant_id, retrieval_model.reranking_model.reranking_provider_name, retrieval_model.reranking_model.reranking_model_name, diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py index 3c3f970444..edb76408e8 100644 --- a/api/services/enterprise/base.py +++ b/api/services/enterprise/base.py @@ -3,18 +3,30 @@ import os import requests -class EnterpriseRequest: - base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL") - secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY") - +class BaseRequest: proxies = { "http": "", "https": "", } + base_url = "" + secret_key = "" + secret_key_header = "" @classmethod def send_request(cls, method, endpoint, json=None, params=None): - headers = {"Content-Type": "application/json", "Enterprise-Api-Secret-Key": cls.secret_key} + headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key} url = f"{cls.base_url}{endpoint}" response = requests.request(method, url, json=json, params=params, headers=headers, proxies=cls.proxies) return response.json() + + +class EnterpriseRequest(BaseRequest): + base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL") + secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY") + secret_key_header = "Enterprise-Api-Secret-Key" + + +class EnterprisePluginManagerRequest(BaseRequest): + base_url = os.environ.get("ENTERPRISE_PLUGIN_MANAGER_API_URL", "ENTERPRISE_PLUGIN_MANAGER_API_URL") + secret_key = os.environ.get("ENTERPRISE_PLUGIN_MANAGER_API_SECRET_KEY", "ENTERPRISE_PLUGIN_MANAGER_API_SECRET_KEY") + secret_key_header = "Plugin-Manager-Inner-Api-Secret-Key" diff --git a/api/services/enterprise/plugin_manager_service.py b/api/services/enterprise/plugin_manager_service.py new file mode 100644 index 0000000000..ee8a932ded --- /dev/null +++ b/api/services/enterprise/plugin_manager_service.py @@ -0,0 +1,57 @@ +import enum +import logging + +from pydantic import BaseModel + +from services.enterprise.base import EnterprisePluginManagerRequest +from services.errors.base import BaseServiceError + +logger = logging.getLogger(__name__) + + +class PluginCredentialType(enum.IntEnum): + MODEL = enum.auto() + TOOL = enum.auto() + + def to_number(self): + return self.value + + +class CheckCredentialPolicyComplianceRequest(BaseModel): + dify_credential_id: str + provider: str + credential_type: PluginCredentialType + + def model_dump(self, **kwargs): + data = super().model_dump(**kwargs) + data["credential_type"] = self.credential_type.to_number() + return data + + +class CredentialPolicyViolationError(BaseServiceError): + pass + + +class PluginManagerService: + @classmethod + def check_credential_policy_compliance(cls, body: CheckCredentialPolicyComplianceRequest): + try: + ret = EnterprisePluginManagerRequest.send_request( + "POST", "/check-credential-policy-compliance", json=body.model_dump() + ) + if not isinstance(ret, dict) or "result" not in ret: + raise ValueError("Invalid response format from plugin manager API") + except Exception as e: + raise CredentialPolicyViolationError( + f"error occurred while checking credential policy compliance: {e}" + ) from e + + if not ret.get("result", False): + raise CredentialPolicyViolationError("Credentials not available: Please use ENTERPRISE global credentials") + + logger.debug( + "Credential policy compliance checked for %s with credential %s, result: %s", + body.provider, + body.dify_credential_id, + ret.get("result", False), + ) diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 1441e6ce16..c27c0b0d58 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -134,6 +134,10 @@ class KnowledgeRateLimitModel(BaseModel): subscription_plan: str = "" +class PluginManagerModel(BaseModel): + enabled: bool = False + + class SystemFeatureModel(BaseModel): sso_enforced_for_signin: bool = False sso_enforced_for_signin_protocol: str = "" @@ -150,6 +154,7 @@ class SystemFeatureModel(BaseModel): webapp_auth: WebAppAuthModel = WebAppAuthModel() plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel() enable_change_email: bool = True + plugin_manager: PluginManagerModel = PluginManagerModel() class FeatureService: @@ -188,6 +193,7 @@ class FeatureService: system_features.branding.enabled = True system_features.webapp_auth.enabled = True system_features.enable_change_email = False + system_features.plugin_manager.enabled = True cls._fulfill_params_from_enterprise(system_features) if dify_config.MARKETPLACE_ENABLED: diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 14d242fd1f..e83a9c3095 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -32,22 +32,14 @@ from libs.datetime_utils import naive_utc_now from models.account import Account from models.model import App, AppMode from models.tools import WorkflowToolProvider -from models.workflow import ( - Workflow, - WorkflowNodeExecutionModel, - WorkflowNodeExecutionTriggeredFrom, - WorkflowType, -) +from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType from repositories.factory import DifyAPIRepositoryFactory +from services.enterprise.plugin_manager_service import PluginCredentialType from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError from services.workflow.workflow_converter import WorkflowConverter from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError -from .workflow_draft_variable_service import ( - DraftVariableSaver, - DraftVarLoader, - WorkflowDraftVariableService, -) +from .workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader, WorkflowDraftVariableService class WorkflowService: @@ -267,6 +259,12 @@ class WorkflowService: if not draft_workflow: raise ValueError("No valid workflow found.") + # Validate credentials before publishing, for credential policy check + from services.feature_service import FeatureService + + if FeatureService.get_system_features().plugin_manager.enabled: + self._validate_workflow_credentials(draft_workflow) + # create new workflow workflow = Workflow.new( tenant_id=app_model.tenant_id, @@ -291,6 +289,260 @@ class WorkflowService: # return new workflow return workflow + def _validate_workflow_credentials(self, workflow: Workflow) -> None: + """ + Validate all credentials in workflow nodes before publishing. + + :param workflow: The workflow to validate + :raises ValueError: If any credentials violate policy compliance + """ + graph_dict = workflow.graph_dict + nodes = graph_dict.get("nodes", []) + + for node in nodes: + node_data = node.get("data", {}) + node_type = node_data.get("type") + node_id = node.get("id", "unknown") + + try: + # Extract and validate credentials based on node type + if node_type == "tool": + credential_id = node_data.get("credential_id") + provider = node_data.get("provider_id") + if provider: + if credential_id: + # Check specific credential + from core.helper.credential_utils import check_credential_policy_compliance + + check_credential_policy_compliance( + credential_id=credential_id, + provider=provider, + credential_type=PluginCredentialType.TOOL, + ) + else: + # Check default workspace credential for this provider + self._check_default_tool_credential(workflow.tenant_id, provider) + + elif node_type == "agent": + agent_params = node_data.get("agent_parameters", {}) + + model_config = agent_params.get("model", {}).get("value", {}) + if model_config.get("provider") and model_config.get("model"): + self._validate_llm_model_config( + workflow.tenant_id, model_config["provider"], model_config["model"] + ) + + # Validate load balancing credentials for agent model if load balancing is enabled + agent_model_node_data = {"model": model_config} + self._validate_load_balancing_credentials(workflow, agent_model_node_data, node_id) + + # Validate agent tools + tools = agent_params.get("tools", {}).get("value", []) + for tool in tools: + # Agent tools store provider in provider_name field + provider = tool.get("provider_name") + credential_id = tool.get("credential_id") + if provider: + if credential_id: + from core.helper.credential_utils import check_credential_policy_compliance + + check_credential_policy_compliance(credential_id, provider, PluginCredentialType.TOOL) + else: + self._check_default_tool_credential(workflow.tenant_id, provider) + + elif node_type in ["llm", "knowledge_retrieval", "parameter_extractor", "question_classifier"]: + model_config = node_data.get("model", {}) + provider = model_config.get("provider") + model_name = model_config.get("name") + + if provider and model_name: + # Validate that the provider+model combination can fetch valid credentials + self._validate_llm_model_config(workflow.tenant_id, provider, model_name) + # Validate load balancing credentials if load balancing is enabled + self._validate_load_balancing_credentials(workflow, node_data, node_id) + else: + raise ValueError(f"Node {node_id} ({node_type}): Missing provider or model configuration") + + except Exception as e: + if isinstance(e, ValueError): + raise e + else: + raise ValueError(f"Node {node_id} ({node_type}): {str(e)}") + + def _validate_llm_model_config(self, tenant_id: str, provider: str, model_name: str) -> None: + """ + Validate that an LLM model configuration can fetch valid credentials. + + This method attempts to get the model instance and validates that: + 1. The provider exists and is configured + 2. The model exists in the provider + 3. Credentials can be fetched for the model + 4. The credentials pass policy compliance checks + + :param tenant_id: The tenant ID + :param provider: The provider name + :param model_name: The model name + :raises ValueError: If the model configuration is invalid or credentials fail policy checks + """ + try: + from core.model_manager import ModelManager + from core.model_runtime.entities.model_entities import ModelType + + # Get model instance to validate provider+model combination + model_manager = ModelManager() + model_manager.get_model_instance( + tenant_id=tenant_id, provider=provider, model_type=ModelType.LLM, model=model_name + ) + + # The ModelInstance constructor will automatically check credential policy compliance + # via ProviderConfiguration.get_current_credentials() -> _check_credential_policy_compliance() + # If it fails, an exception will be raised + + except Exception as e: + raise ValueError( + f"Failed to validate LLM model configuration (provider: {provider}, model: {model_name}): {str(e)}" + ) + + def _check_default_tool_credential(self, tenant_id: str, provider: str) -> None: + """ + Check credential policy compliance for the default workspace credential of a tool provider. + + This method finds the default credential for the given provider and validates it. + Uses the same fallback logic as runtime to handle deauthorized credentials. + + :param tenant_id: The tenant ID + :param provider: The tool provider name + :raises ValueError: If no default credential exists or if it fails policy compliance + """ + try: + from models.tools import BuiltinToolProvider + + # Use the same fallback logic as runtime: get the first available credential + # ordered by is_default DESC, created_at ASC (same as tool_manager.py) + default_provider = ( + db.session.query(BuiltinToolProvider) + .where( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider, + ) + .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) + .first() + ) + + if not default_provider: + raise ValueError("No default credential found") + + # Check credential policy compliance using the default credential ID + from core.helper.credential_utils import check_credential_policy_compliance + + check_credential_policy_compliance( + credential_id=default_provider.id, + provider=provider, + credential_type=PluginCredentialType.TOOL, + check_existence=False, + ) + + except Exception as e: + raise ValueError(f"Failed to validate default credential for tool provider {provider}: {str(e)}") + + def _validate_load_balancing_credentials(self, workflow: Workflow, node_data: dict, node_id: str) -> None: + """ + Validate load balancing credentials for a workflow node. + + :param workflow: The workflow being validated + :param node_data: The node data containing model configuration + :param node_id: The node ID for error reporting + :raises ValueError: If load balancing credentials violate policy compliance + """ + # Extract model configuration + model_config = node_data.get("model", {}) + provider = model_config.get("provider") + model_name = model_config.get("name") + + if not provider or not model_name: + return # No model config to validate + + # Check if this model has load balancing enabled + if self._is_load_balancing_enabled(workflow.tenant_id, provider, model_name): + # Get all load balancing configurations for this model + load_balancing_configs = self._get_load_balancing_configs(workflow.tenant_id, provider, model_name) + # Validate each load balancing configuration + try: + for config in load_balancing_configs: + if config.get("credential_id"): + from core.helper.credential_utils import check_credential_policy_compliance + + check_credential_policy_compliance( + config["credential_id"], provider, PluginCredentialType.MODEL + ) + except Exception as e: + raise ValueError(f"Invalid load balancing credentials for {provider}/{model_name}: {str(e)}") + + def _is_load_balancing_enabled(self, tenant_id: str, provider: str, model_name: str) -> bool: + """ + Check if load balancing is enabled for a specific model. + + :param tenant_id: The tenant ID + :param provider: The provider name + :param model_name: The model name + :return: True if load balancing is enabled, False otherwise + """ + try: + from core.model_runtime.entities.model_entities import ModelType + from core.provider_manager import ProviderManager + + # Get provider configurations + provider_manager = ProviderManager() + provider_configurations = provider_manager.get_configurations(tenant_id) + provider_configuration = provider_configurations.get(provider) + + if not provider_configuration: + return False + + # Get provider model setting + provider_model_setting = provider_configuration.get_provider_model_setting( + model_type=ModelType.LLM, + model=model_name, + ) + return provider_model_setting is not None and provider_model_setting.load_balancing_enabled + + except Exception: + # If we can't determine the status, assume load balancing is not enabled + return False + + def _get_load_balancing_configs(self, tenant_id: str, provider: str, model_name: str) -> list[dict]: + """ + Get all load balancing configurations for a model. + + :param tenant_id: The tenant ID + :param provider: The provider name + :param model_name: The model name + :return: List of load balancing configuration dictionaries + """ + try: + from services.model_load_balancing_service import ModelLoadBalancingService + + model_load_balancing_service = ModelLoadBalancingService() + _, configs = model_load_balancing_service.get_load_balancing_configs( + tenant_id=tenant_id, + provider=provider, + model=model_name, + model_type="llm", # Load balancing is primarily used for LLM models + config_from="predefined-model", # Check both predefined and custom models + ) + + _, custom_configs = model_load_balancing_service.get_load_balancing_configs( + tenant_id=tenant_id, provider=provider, model=model_name, model_type="llm", config_from="custom-model" + ) + all_configs = configs + custom_configs + + return [config for config in all_configs if config.get("credential_id")] + + except Exception: + # If we can't get the configurations, return empty list + # This will prevent validation errors from breaking the workflow + return [] + def get_default_block_configs(self) -> list[dict]: """ Get default block configs diff --git a/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx b/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx index 5890c2ea92..f3dbc9421c 100644 --- a/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx +++ b/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx @@ -43,9 +43,9 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => { const handleSaveAvatar = useCallback(async (uploadedFileId: string) => { try { await updateUserProfile({ url: 'account/avatar', body: { avatar: uploadedFileId } }) - notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) setIsShowAvatarPicker(false) onSave?.() + notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) } catch (e) { notify({ type: 'error', message: (e as Error).message }) diff --git a/web/app/components/apps/app-card.tsx b/web/app/components/apps/app-card.tsx index d0d42dc32c..e9a64d8867 100644 --- a/web/app/components/apps/app-card.tsx +++ b/web/app/components/apps/app-card.tsx @@ -279,12 +279,21 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { )} { - (isGettingUserCanAccessApp || !userCanAccessApp?.result) ? null : <> - - - + (!systemFeatures.webapp_auth.enabled) + ? <> + + + + : !(isGettingUserCanAccessApp || !userCanAccessApp?.result) && ( + <> + + + + ) } { diff --git a/web/app/components/base/avatar/index.tsx b/web/app/components/base/avatar/index.tsx index a6e04a0755..89019a19b0 100644 --- a/web/app/components/base/avatar/index.tsx +++ b/web/app/components/base/avatar/index.tsx @@ -1,5 +1,5 @@ 'use client' -import { useState } from 'react' +import { useEffect, useState } from 'react' import cn from '@/utils/classnames' export type AvatarProps = { @@ -27,6 +27,12 @@ const Avatar = ({ onError?.(true) } + // after uploaded, api would first return error imgs url: '.../files//file-preview/...'. Then return the right url, Which caused not show the avatar + useEffect(() => { + if(avatar && imgError) + setImgError(false) + }, [avatar]) + if (avatar && !imgError) { return ( - + )} @@ -232,7 +232,7 @@ const ModelLoadBalancingConfigs = ({ <> toggleConfigEntryEnabled(index, value)} diff --git a/web/app/components/workflow/candidate-node.tsx b/web/app/components/workflow/candidate-node.tsx index eb59a4618c..35bcd5c201 100644 --- a/web/app/components/workflow/candidate-node.tsx +++ b/web/app/components/workflow/candidate-node.tsx @@ -62,9 +62,9 @@ const CandidateNode = () => { }) setNodes(newNodes) if (candidateNode.type === CUSTOM_NOTE_NODE) - saveStateToHistory(WorkflowHistoryEvent.NoteAdd) + saveStateToHistory(WorkflowHistoryEvent.NoteAdd, { nodeId: candidateNode.id }) else - saveStateToHistory(WorkflowHistoryEvent.NodeAdd) + saveStateToHistory(WorkflowHistoryEvent.NodeAdd, { nodeId: candidateNode.id }) workflowStore.setState({ candidateNode: undefined }) diff --git a/web/app/components/workflow/header/view-workflow-history.tsx b/web/app/components/workflow/header/view-workflow-history.tsx index 5c31677f5e..42afd18d25 100644 --- a/web/app/components/workflow/header/view-workflow-history.tsx +++ b/web/app/components/workflow/header/view-workflow-history.tsx @@ -89,10 +89,19 @@ const ViewWorkflowHistory = () => { const calculateChangeList: ChangeHistoryList = useMemo(() => { const filterList = (list: any, startIndex = 0, reverse = false) => list.map((state: Partial, index: number) => { + const nodes = (state.nodes || store.getState().nodes) || [] + const nodeId = state?.workflowHistoryEventMeta?.nodeId + const targetTitle = nodes.find(n => n.id === nodeId)?.data?.title ?? '' return { label: state.workflowHistoryEvent && getHistoryLabel(state.workflowHistoryEvent), index: reverse ? list.length - 1 - index - startIndex : index - startIndex, - state, + state: { + ...state, + workflowHistoryEventMeta: state.workflowHistoryEventMeta ? { + ...state.workflowHistoryEventMeta, + nodeTitle: state.workflowHistoryEventMeta.nodeTitle || targetTitle, + } : undefined, + }, } }).filter(Boolean) @@ -110,6 +119,12 @@ const ViewWorkflowHistory = () => { } }, [futureStates, getHistoryLabel, pastStates, store]) + const composeHistoryItemLabel = useCallback((nodeTitle: string | undefined, baseLabel: string) => { + if (!nodeTitle) + return baseLabel + return `${nodeTitle} ${baseLabel}` + }, []) + return ( ( { 'flex items-center text-[13px] font-medium leading-[18px] text-text-secondary', )} > - {item?.label || t('workflow.changeHistory.sessionStart')} ({calculateStepLabel(item?.index)}{item?.index === currentHistoryStateIndex && t('workflow.changeHistory.currentState')}) + {composeHistoryItemLabel( + item?.state?.workflowHistoryEventMeta?.nodeTitle, + item?.label || t('workflow.changeHistory.sessionStart'), + )} ({calculateStepLabel(item?.index)}{item?.index === currentHistoryStateIndex && t('workflow.changeHistory.currentState')}) @@ -222,7 +240,10 @@ const ViewWorkflowHistory = () => { 'flex items-center text-[13px] font-medium leading-[18px] text-text-secondary', )} > - {item?.label || t('workflow.changeHistory.sessionStart')} ({calculateStepLabel(item?.index)}) + {composeHistoryItemLabel( + item?.state?.workflowHistoryEventMeta?.nodeTitle, + item?.label || t('workflow.changeHistory.sessionStart'), + )} ({calculateStepLabel(item?.index)}) diff --git a/web/app/components/workflow/hooks/use-nodes-interactions.ts b/web/app/components/workflow/hooks/use-nodes-interactions.ts index aae4594214..35740e0c28 100644 --- a/web/app/components/workflow/hooks/use-nodes-interactions.ts +++ b/web/app/components/workflow/hooks/use-nodes-interactions.ts @@ -173,7 +173,7 @@ export const useNodesInteractions = () => { if (x !== 0 && y !== 0) { // selecting a note will trigger a drag stop event with x and y as 0 - saveStateToHistory(WorkflowHistoryEvent.NodeDragStop) + saveStateToHistory(WorkflowHistoryEvent.NodeDragStop, { nodeId: node.id }) } } }, [workflowStore, getNodesReadOnly, saveStateToHistory, handleSyncWorkflowDraft]) @@ -404,7 +404,7 @@ export const useNodesInteractions = () => { setEdges(newEdges) handleSyncWorkflowDraft() - saveStateToHistory(WorkflowHistoryEvent.NodeConnect) + saveStateToHistory(WorkflowHistoryEvent.NodeConnect, { nodeId: targetNode?.id }) } else { const { @@ -640,10 +640,10 @@ export const useNodesInteractions = () => { handleSyncWorkflowDraft() if (currentNode.type === CUSTOM_NOTE_NODE) - saveStateToHistory(WorkflowHistoryEvent.NoteDelete) + saveStateToHistory(WorkflowHistoryEvent.NoteDelete, { nodeId: currentNode.id }) else - saveStateToHistory(WorkflowHistoryEvent.NodeDelete) + saveStateToHistory(WorkflowHistoryEvent.NodeDelete, { nodeId: currentNode.id }) }, [getNodesReadOnly, store, deleteNodeInspectorVars, handleSyncWorkflowDraft, saveStateToHistory, workflowStore, t]) const handleNodeAdd = useCallback(( @@ -1081,7 +1081,7 @@ export const useNodesInteractions = () => { setEdges(newEdges) } handleSyncWorkflowDraft() - saveStateToHistory(WorkflowHistoryEvent.NodeAdd) + saveStateToHistory(WorkflowHistoryEvent.NodeAdd, { nodeId: newNode.id }) }, [getNodesReadOnly, store, t, handleSyncWorkflowDraft, saveStateToHistory, workflowStore, getAfterNodesInSameBranch, checkNestedParallelLimit]) const handleNodeChange = useCallback(( @@ -1163,7 +1163,7 @@ export const useNodesInteractions = () => { setEdges(newEdges) handleSyncWorkflowDraft() - saveStateToHistory(WorkflowHistoryEvent.NodeChange) + saveStateToHistory(WorkflowHistoryEvent.NodeChange, { nodeId: currentNodeId }) }, [getNodesReadOnly, store, t, handleSyncWorkflowDraft, saveStateToHistory]) const handleNodesCancelSelected = useCallback(() => { @@ -1385,7 +1385,7 @@ export const useNodesInteractions = () => { setNodes([...nodes, ...nodesToPaste]) setEdges([...edges, ...edgesToPaste]) - saveStateToHistory(WorkflowHistoryEvent.NodePaste) + saveStateToHistory(WorkflowHistoryEvent.NodePaste, { nodeId: nodesToPaste?.[0]?.id }) handleSyncWorkflowDraft() } }, [getNodesReadOnly, workflowStore, store, reactflow, saveStateToHistory, handleSyncWorkflowDraft, handleNodeIterationChildrenCopy, handleNodeLoopChildrenCopy]) @@ -1482,7 +1482,7 @@ export const useNodesInteractions = () => { }) setNodes(newNodes) handleSyncWorkflowDraft() - saveStateToHistory(WorkflowHistoryEvent.NodeResize) + saveStateToHistory(WorkflowHistoryEvent.NodeResize, { nodeId }) }, [getNodesReadOnly, store, handleSyncWorkflowDraft, saveStateToHistory]) const handleNodeDisconnect = useCallback((nodeId: string) => { diff --git a/web/app/components/workflow/hooks/use-workflow-history.ts b/web/app/components/workflow/hooks/use-workflow-history.ts index 592c0b01cd..b7338dc4f8 100644 --- a/web/app/components/workflow/hooks/use-workflow-history.ts +++ b/web/app/components/workflow/hooks/use-workflow-history.ts @@ -8,6 +8,7 @@ import { } from 'reactflow' import { useTranslation } from 'react-i18next' import { useWorkflowHistoryStore } from '../workflow-history-store' +import type { WorkflowHistoryEventMeta } from '../workflow-history-store' /** * All supported Events that create a new history state. @@ -64,20 +65,21 @@ export const useWorkflowHistory = () => { // Some events may be triggered multiple times in a short period of time. // We debounce the history state update to avoid creating multiple history states // with minimal changes. - const saveStateToHistoryRef = useRef(debounce((event: WorkflowHistoryEvent) => { + const saveStateToHistoryRef = useRef(debounce((event: WorkflowHistoryEvent, meta?: WorkflowHistoryEventMeta) => { workflowHistoryStore.setState({ workflowHistoryEvent: event, + workflowHistoryEventMeta: meta, nodes: store.getState().getNodes(), edges: store.getState().edges, }) }, 500)) - const saveStateToHistory = useCallback((event: WorkflowHistoryEvent) => { + const saveStateToHistory = useCallback((event: WorkflowHistoryEvent, meta?: WorkflowHistoryEventMeta) => { switch (event) { case WorkflowHistoryEvent.NoteChange: // Hint: Note change does not trigger when note text changes, // because the note editors have their own history states. - saveStateToHistoryRef.current(event) + saveStateToHistoryRef.current(event, meta) break case WorkflowHistoryEvent.NodeTitleChange: case WorkflowHistoryEvent.NodeDescriptionChange: @@ -93,7 +95,7 @@ export const useWorkflowHistory = () => { case WorkflowHistoryEvent.NoteAdd: case WorkflowHistoryEvent.LayoutOrganize: case WorkflowHistoryEvent.NoteDelete: - saveStateToHistoryRef.current(event) + saveStateToHistoryRef.current(event, meta) break default: // We do not create a history state for every event. diff --git a/web/app/components/workflow/nodes/_base/components/workflow-panel/index.tsx b/web/app/components/workflow/nodes/_base/components/workflow-panel/index.tsx index 3594b8fdbc..a5bf1befbd 100644 --- a/web/app/components/workflow/nodes/_base/components/workflow-panel/index.tsx +++ b/web/app/components/workflow/nodes/_base/components/workflow-panel/index.tsx @@ -154,11 +154,11 @@ const BasePanel: FC = ({ const handleTitleBlur = useCallback((title: string) => { handleNodeDataUpdateWithSyncDraft({ id, data: { title } }) - saveStateToHistory(WorkflowHistoryEvent.NodeTitleChange) + saveStateToHistory(WorkflowHistoryEvent.NodeTitleChange, { nodeId: id }) }, [handleNodeDataUpdateWithSyncDraft, id, saveStateToHistory]) const handleDescriptionChange = useCallback((desc: string) => { handleNodeDataUpdateWithSyncDraft({ id, data: { desc } }) - saveStateToHistory(WorkflowHistoryEvent.NodeDescriptionChange) + saveStateToHistory(WorkflowHistoryEvent.NodeDescriptionChange, { nodeId: id }) }, [handleNodeDataUpdateWithSyncDraft, id, saveStateToHistory]) const isChildNode = !!(data.isInIteration || data.isInLoop) diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/error-message.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/error-message.tsx index c21aa1405e..6e8a2b2fad 100644 --- a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/error-message.tsx +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/error-message.tsx @@ -17,7 +17,7 @@ const ErrorMessage: FC = ({ className, )}> -
+
{message}
diff --git a/web/app/components/workflow/nodes/llm/utils.ts b/web/app/components/workflow/nodes/llm/utils.ts index 045acf3993..7f13998cd7 100644 --- a/web/app/components/workflow/nodes/llm/utils.ts +++ b/web/app/components/workflow/nodes/llm/utils.ts @@ -1,9 +1,8 @@ +import { z } from 'zod' import { ArrayType, Type } from './types' import type { ArrayItems, Field, LLMNodeType } from './types' -import type { Schema, ValidationError } from 'jsonschema' -import { Validator } from 'jsonschema' -import produce from 'immer' -import { z } from 'zod' +import { draft07Validator, forbidBooleanProperties } from '@/utils/validators' +import type { ValidationError } from 'jsonschema' export const checkNodeValid = (_payload: LLMNodeType) => { return true @@ -14,7 +13,7 @@ export const getFieldType = (field: Field) => { if (type !== Type.array || !items) return type - return ArrayType[items.type] + return ArrayType[items.type as keyof typeof ArrayType] } export const getHasChildren = (schema: Field) => { @@ -115,191 +114,22 @@ export const findPropertyWithPath = (target: any, path: string[]) => { return current } -const draft07MetaSchema = { - $schema: 'http://json-schema.org/draft-07/schema#', - $id: 'http://json-schema.org/draft-07/schema#', - title: 'Core schema meta-schema', - definitions: { - schemaArray: { - type: 'array', - minItems: 1, - items: { $ref: '#' }, - }, - nonNegativeInteger: { - type: 'integer', - minimum: 0, - }, - nonNegativeIntegerDefault0: { - allOf: [ - { $ref: '#/definitions/nonNegativeInteger' }, - { default: 0 }, - ], - }, - simpleTypes: { - enum: [ - 'array', - 'boolean', - 'integer', - 'null', - 'number', - 'object', - 'string', - ], - }, - stringArray: { - type: 'array', - items: { type: 'string' }, - uniqueItems: true, - default: [], - }, - }, - type: ['object', 'boolean'], - properties: { - $id: { - type: 'string', - format: 'uri-reference', - }, - $schema: { - type: 'string', - format: 'uri', - }, - $ref: { - type: 'string', - format: 'uri-reference', - }, - title: { - type: 'string', - }, - description: { - type: 'string', - }, - default: true, - readOnly: { - type: 'boolean', - default: false, - }, - examples: { - type: 'array', - items: true, - }, - multipleOf: { - type: 'number', - exclusiveMinimum: 0, - }, - maximum: { - type: 'number', - }, - exclusiveMaximum: { - type: 'number', - }, - minimum: { - type: 'number', - }, - exclusiveMinimum: { - type: 'number', - }, - maxLength: { $ref: '#/definitions/nonNegativeInteger' }, - minLength: { $ref: '#/definitions/nonNegativeIntegerDefault0' }, - pattern: { - type: 'string', - format: 'regex', - }, - additionalItems: { $ref: '#' }, - items: { - anyOf: [ - { $ref: '#' }, - { $ref: '#/definitions/schemaArray' }, - ], - default: true, - }, - maxItems: { $ref: '#/definitions/nonNegativeInteger' }, - minItems: { $ref: '#/definitions/nonNegativeIntegerDefault0' }, - uniqueItems: { - type: 'boolean', - default: false, - }, - contains: { $ref: '#' }, - maxProperties: { $ref: '#/definitions/nonNegativeInteger' }, - minProperties: { $ref: '#/definitions/nonNegativeIntegerDefault0' }, - required: { $ref: '#/definitions/stringArray' }, - additionalProperties: { $ref: '#' }, - definitions: { - type: 'object', - additionalProperties: { $ref: '#' }, - default: {}, - }, - properties: { - type: 'object', - additionalProperties: { $ref: '#' }, - default: {}, - }, - patternProperties: { - type: 'object', - additionalProperties: { $ref: '#' }, - propertyNames: { format: 'regex' }, - default: {}, - }, - dependencies: { - type: 'object', - additionalProperties: { - anyOf: [ - { $ref: '#' }, - { $ref: '#/definitions/stringArray' }, - ], - }, - }, - propertyNames: { $ref: '#' }, - const: true, - enum: { - type: 'array', - items: true, - minItems: 1, - uniqueItems: true, - }, - type: { - anyOf: [ - { $ref: '#/definitions/simpleTypes' }, - { - type: 'array', - items: { $ref: '#/definitions/simpleTypes' }, - minItems: 1, - uniqueItems: true, - }, - ], - }, - format: { type: 'string' }, - allOf: { $ref: '#/definitions/schemaArray' }, - anyOf: { $ref: '#/definitions/schemaArray' }, - oneOf: { $ref: '#/definitions/schemaArray' }, - not: { $ref: '#' }, - }, - default: true, -} as unknown as Schema - -const validator = new Validator() - export const validateSchemaAgainstDraft7 = (schemaToValidate: any) => { - const schema = produce(schemaToValidate, (draft: any) => { - // Make sure the schema has the $schema property for draft-07 - if (!draft.$schema) - draft.$schema = 'http://json-schema.org/draft-07/schema#' - }) + // First check against Draft-07 + const result = draft07Validator(schemaToValidate) + // Then apply custom rule + const customErrors = forbidBooleanProperties(schemaToValidate) - const result = validator.validate(schema, draft07MetaSchema, { - nestedErrors: true, - throwError: false, - }) - - // Access errors from the validation result - const errors = result.valid ? [] : result.errors || [] - - return errors + return [...result.errors, ...customErrors] } -export const getValidationErrorMessage = (errors: ValidationError[]) => { +export const getValidationErrorMessage = (errors: Array) => { const message = errors.map((error) => { - return `Error: ${error.path.join('.')} ${error.message} Details: ${JSON.stringify(error.stack)}` - }).join('; ') + if (typeof error === 'string') + return error + else + return `Error: ${error.stack}\n` + }).join('') return message } diff --git a/web/app/components/workflow/note-node/hooks.ts b/web/app/components/workflow/note-node/hooks.ts index 04e8081692..29642f90df 100644 --- a/web/app/components/workflow/note-node/hooks.ts +++ b/web/app/components/workflow/note-node/hooks.ts @@ -9,7 +9,7 @@ export const useNote = (id: string) => { const handleThemeChange = useCallback((theme: NoteTheme) => { handleNodeDataUpdateWithSyncDraft({ id, data: { theme } }) - saveStateToHistory(WorkflowHistoryEvent.NoteChange) + saveStateToHistory(WorkflowHistoryEvent.NoteChange, { nodeId: id }) }, [handleNodeDataUpdateWithSyncDraft, id, saveStateToHistory]) const handleEditorChange = useCallback((editorState: EditorState) => { @@ -21,7 +21,7 @@ export const useNote = (id: string) => { const handleShowAuthorChange = useCallback((showAuthor: boolean) => { handleNodeDataUpdateWithSyncDraft({ id, data: { showAuthor } }) - saveStateToHistory(WorkflowHistoryEvent.NoteChange) + saveStateToHistory(WorkflowHistoryEvent.NoteChange, { nodeId: id }) }, [handleNodeDataUpdateWithSyncDraft, id, saveStateToHistory]) return { diff --git a/web/app/components/workflow/variable-inspect/value-content.tsx b/web/app/components/workflow/variable-inspect/value-content.tsx index 2b28cd8ef4..a3ede311c4 100644 --- a/web/app/components/workflow/variable-inspect/value-content.tsx +++ b/web/app/components/workflow/variable-inspect/value-content.tsx @@ -60,18 +60,22 @@ const ValueContent = ({ const [fileValue, setFileValue] = useState(formatFileValue(currentVar)) const { run: debounceValueChange } = useDebounceFn(handleValueChange, { wait: 500 }) - if (showTextEditor) { - if (currentVar.value_type === 'number') - setValue(JSON.stringify(currentVar.value)) - if (!currentVar.value) - setValue('') - setValue(currentVar.value) - } - if (showJSONEditor) - setJson(currentVar.value ? JSON.stringify(currentVar.value, null, 2) : '') - if (showFileEditor) - setFileValue(formatFileValue(currentVar)) + // update default value when id changed + useEffect(() => { + if (showTextEditor) { + if (currentVar.value_type === 'number') + return setValue(JSON.stringify(currentVar.value)) + if (!currentVar.value) + return setValue('') + setValue(currentVar.value) + } + if (showJSONEditor) + setJson(currentVar.value ? JSON.stringify(currentVar.value, null, 2) : '') + + if (showFileEditor) + setFileValue(formatFileValue(currentVar)) + }, [currentVar.id, currentVar.value]) const handleTextChange = (value: string) => { if (currentVar.value_type === 'string') diff --git a/web/app/components/workflow/workflow-history-store.tsx b/web/app/components/workflow/workflow-history-store.tsx index 52132f3657..c250708177 100644 --- a/web/app/components/workflow/workflow-history-store.tsx +++ b/web/app/components/workflow/workflow-history-store.tsx @@ -51,6 +51,7 @@ export function useWorkflowHistoryStore() { setState: (state: WorkflowHistoryState) => { store.setState({ workflowHistoryEvent: state.workflowHistoryEvent, + workflowHistoryEventMeta: state.workflowHistoryEventMeta, nodes: state.nodes.map((node: Node) => ({ ...node, data: { ...node.data, selected: false } })), edges: state.edges.map((edge: Edge) => ({ ...edge, selected: false }) as Edge), }) @@ -76,6 +77,7 @@ function createStore({ (set, get) => { return { workflowHistoryEvent: undefined, + workflowHistoryEventMeta: undefined, nodes: storeNodes, edges: storeEdges, getNodes: () => get().nodes, @@ -97,6 +99,7 @@ export type WorkflowHistoryStore = { nodes: Node[] edges: Edge[] workflowHistoryEvent: WorkflowHistoryEvent | undefined + workflowHistoryEventMeta?: WorkflowHistoryEventMeta } export type WorkflowHistoryActions = { @@ -119,3 +122,8 @@ export type WorkflowWithHistoryProviderProps = { edges: Edge[] children: ReactNode } + +export type WorkflowHistoryEventMeta = { + nodeId?: string + nodeTitle?: string +} diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index 694b7fb2da..c815ecb5e7 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -12603,7 +12603,7 @@ snapshots: '@vue/compiler-sfc@3.5.17': dependencies: - '@babel/parser': 7.28.0 + '@babel/parser': 7.28.3 '@vue/compiler-core': 3.5.17 '@vue/compiler-dom': 3.5.17 '@vue/compiler-ssr': 3.5.17 diff --git a/web/utils/draft-07.json b/web/utils/draft-07.json new file mode 100644 index 0000000000..99389d7ab4 --- /dev/null +++ b/web/utils/draft-07.json @@ -0,0 +1,245 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "$id": "http://json-schema.org/draft-07/schema#", + "title": "Core schema meta-schema", + "definitions": { + "schemaArray": { + "type": "array", + "minItems": 1, + "items": { + "$ref": "#" + } + }, + "nonNegativeInteger": { + "type": "integer", + "minimum": 0 + }, + "nonNegativeIntegerDefault0": { + "allOf": [ + { + "$ref": "#/definitions/nonNegativeInteger" + }, + { + "default": 0 + } + ] + }, + "simpleTypes": { + "enum": [ + "array", + "boolean", + "integer", + "null", + "number", + "object", + "string" + ] + }, + "stringArray": { + "type": "array", + "items": { + "type": "string" + }, + "uniqueItems": true, + "default": [] + } + }, + "type": [ + "object", + "boolean" + ], + "properties": { + "$id": { + "type": "string", + "format": "uri-reference" + }, + "$schema": { + "type": "string", + "format": "uri" + }, + "$ref": { + "type": "string", + "format": "uri-reference" + }, + "$comment": { + "type": "string" + }, + "title": { + "type": "string" + }, + "description": { + "type": "string" + }, + "default": true, + "readOnly": { + "type": "boolean", + "default": false + }, + "writeOnly": { + "type": "boolean", + "default": false + }, + "examples": { + "type": "array", + "items": true + }, + "multipleOf": { + "type": "number", + "exclusiveMinimum": 0 + }, + "maximum": { + "type": "number" + }, + "exclusiveMaximum": { + "type": "number" + }, + "minimum": { + "type": "number" + }, + "exclusiveMinimum": { + "type": "number" + }, + "maxLength": { + "$ref": "#/definitions/nonNegativeInteger" + }, + "minLength": { + "$ref": "#/definitions/nonNegativeIntegerDefault0" + }, + "pattern": { + "type": "string", + "format": "regex" + }, + "additionalItems": { + "$ref": "#" + }, + "items": { + "anyOf": [ + { + "$ref": "#" + }, + { + "$ref": "#/definitions/schemaArray" + } + ], + "default": true + }, + "maxItems": { + "$ref": "#/definitions/nonNegativeInteger" + }, + "minItems": { + "$ref": "#/definitions/nonNegativeIntegerDefault0" + }, + "uniqueItems": { + "type": "boolean", + "default": false + }, + "contains": { + "$ref": "#" + }, + "maxProperties": { + "$ref": "#/definitions/nonNegativeInteger" + }, + "minProperties": { + "$ref": "#/definitions/nonNegativeIntegerDefault0" + }, + "required": { + "$ref": "#/definitions/stringArray" + }, + "additionalProperties": { + "$ref": "#" + }, + "definitions": { + "type": "object", + "additionalProperties": { + "$ref": "#" + }, + "default": {} + }, + "properties": { + "type": "object", + "additionalProperties": { + "$ref": "#" + }, + "default": {} + }, + "patternProperties": { + "type": "object", + "additionalProperties": { + "$ref": "#" + }, + "propertyNames": { + "format": "regex" + }, + "default": {} + }, + "dependencies": { + "type": "object", + "additionalProperties": { + "anyOf": [ + { + "$ref": "#" + }, + { + "$ref": "#/definitions/stringArray" + } + ] + } + }, + "propertyNames": { + "$ref": "#" + }, + "const": true, + "enum": { + "type": "array", + "items": true, + "minItems": 1, + "uniqueItems": true + }, + "type": { + "anyOf": [ + { + "$ref": "#/definitions/simpleTypes" + }, + { + "type": "array", + "items": { + "$ref": "#/definitions/simpleTypes" + }, + "minItems": 1, + "uniqueItems": true + } + ] + }, + "format": { + "type": "string" + }, + "contentMediaType": { + "type": "string" + }, + "contentEncoding": { + "type": "string" + }, + "if": { + "$ref": "#" + }, + "then": { + "$ref": "#" + }, + "else": { + "$ref": "#" + }, + "allOf": { + "$ref": "#/definitions/schemaArray" + }, + "anyOf": { + "$ref": "#/definitions/schemaArray" + }, + "oneOf": { + "$ref": "#/definitions/schemaArray" + }, + "not": { + "$ref": "#" + } + }, + "default": true +} diff --git a/web/utils/validators.ts b/web/utils/validators.ts new file mode 100644 index 0000000000..51b47feddf --- /dev/null +++ b/web/utils/validators.ts @@ -0,0 +1,27 @@ +import type { Schema } from 'jsonschema' +import { Validator } from 'jsonschema' +import draft07Schema from './draft-07.json' + +const validator = new Validator() + +export const draft07Validator = (schema: any) => { + return validator.validate(schema, draft07Schema as unknown as Schema) +} + +export const forbidBooleanProperties = (schema: any, path: string[] = []): string[] => { + let errors: string[] = [] + + if (schema && typeof schema === 'object' && schema.properties) { + for (const [key, val] of Object.entries(schema.properties)) { + if (typeof val === 'boolean') { + errors.push( + `Error: Property '${[...path, key].join('.')}' must not be a boolean schema`, + ) + } + else if (typeof val === 'object') { + errors = errors.concat(forbidBooleanProperties(val, [...path, key])) + } + } + } + return errors +}