From 64c9a2f678414ee9614a1467ad967d198236e617 Mon Sep 17 00:00:00 2001 From: Xiyuan Chen <52963600+GareArc@users.noreply.github.com> Date: Mon, 8 Sep 2025 23:45:05 -0700 Subject: [PATCH] Feat/credential policy (#25151) Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/app/workflow.py | 6 +- api/core/entities/provider_configuration.py | 28 +- api/core/entities/provider_entities.py | 1 + api/core/helper/credential_utils.py | 75 +++++ api/core/model_manager.py | 36 +++ api/core/provider_manager.py | 1 + api/core/tools/errors.py | 4 + api/core/tools/tool_manager.py | 15 +- api/services/enterprise/base.py | 22 +- .../enterprise/plugin_manager_service.py | 52 ++++ api/services/feature_service.py | 6 + api/services/workflow_service.py | 274 +++++++++++++++++- 12 files changed, 495 insertions(+), 25 deletions(-) create mode 100644 api/core/helper/credential_utils.py create mode 100644 api/services/enterprise/plugin_manager_service.py diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index bf20a5ae62..05178328fe 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -11,11 +11,7 @@ from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services from configs import dify_config from controllers.console import api -from controllers.console.app.error import ( - ConversationCompletedError, - DraftWorkflowNotExist, - DraftWorkflowNotSync, -) +from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 61a960c3d4..9cf35e559d 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -42,6 +42,7 @@ from models.provider import ( ProviderType, TenantPreferredModelProvider, ) +from services.enterprise.plugin_manager_service import PluginCredentialType logger = logging.getLogger(__name__) @@ -129,14 +130,38 @@ class ProviderConfiguration(BaseModel): return copy_credentials else: credentials = None + current_credential_id = None + if self.custom_configuration.models: for model_configuration in self.custom_configuration.models: if model_configuration.model_type == model_type and model_configuration.model == model: credentials = model_configuration.credentials + current_credential_id = model_configuration.current_credential_id break if not credentials and self.custom_configuration.provider: credentials = self.custom_configuration.provider.credentials + current_credential_id = self.custom_configuration.provider.current_credential_id + + if current_credential_id: + from core.helper.credential_utils import check_credential_policy_compliance + + check_credential_policy_compliance( + credential_id=current_credential_id, + provider=self.provider.provider, + credential_type=PluginCredentialType.MODEL, + ) + else: + # no current credential id, check all available credentials + if self.custom_configuration.provider: + for credential_configuration in self.custom_configuration.provider.available_credentials: + from core.helper.credential_utils import check_credential_policy_compliance + + check_credential_policy_compliance( + credential_id=credential_configuration.credential_id, + provider=self.provider.provider, + credential_type=PluginCredentialType.MODEL, + ) return credentials @@ -266,7 +291,6 @@ class ProviderConfiguration(BaseModel): :param credential_id: if provided, return the specified credential :return: """ - if credential_id: return self._get_specific_provider_credential(credential_id) @@ -738,6 +762,7 @@ class ProviderConfiguration(BaseModel): current_credential_id = credential_record.id current_credential_name = credential_record.credential_name + credentials = self.obfuscated_credentials( credentials=credentials, credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas @@ -792,6 +817,7 @@ class ProviderConfiguration(BaseModel): ): current_credential_id = model_configuration.current_credential_id current_credential_name = model_configuration.current_credential_name + credentials = self.obfuscated_credentials( credentials=model_configuration.credentials, credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 79a7514bbc..9b8baf1973 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -145,6 +145,7 @@ class ModelLoadBalancingConfiguration(BaseModel): name: str credentials: dict credential_source_type: str | None = None + credential_id: str | None = None class ModelSettings(BaseModel): diff --git a/api/core/helper/credential_utils.py b/api/core/helper/credential_utils.py new file mode 100644 index 0000000000..240f498181 --- /dev/null +++ b/api/core/helper/credential_utils.py @@ -0,0 +1,75 @@ +""" +Credential utility functions for checking credential existence and policy compliance. +""" + +from services.enterprise.plugin_manager_service import PluginCredentialType + + +def is_credential_exists(credential_id: str, credential_type: "PluginCredentialType") -> bool: + """ + Check if the credential still exists in the database. + + :param credential_id: The credential ID to check + :param credential_type: The type of credential (MODEL or TOOL) + :return: True if credential exists, False otherwise + """ + from sqlalchemy import select + from sqlalchemy.orm import Session + + from extensions.ext_database import db + from models.provider import ProviderCredential, ProviderModelCredential + from models.tools import BuiltinToolProvider + + with Session(db.engine) as session: + if credential_type == PluginCredentialType.MODEL: + # Check both pre-defined and custom model credentials using a single UNION query + stmt = ( + select(ProviderCredential.id) + .where(ProviderCredential.id == credential_id) + .union(select(ProviderModelCredential.id).where(ProviderModelCredential.id == credential_id)) + ) + return session.scalar(stmt) is not None + + if credential_type == PluginCredentialType.TOOL: + return ( + session.scalar(select(BuiltinToolProvider.id).where(BuiltinToolProvider.id == credential_id)) + is not None + ) + + return False + + +def check_credential_policy_compliance( + credential_id: str, provider: str, credential_type: "PluginCredentialType", check_existence: bool = True +) -> None: + """ + Check credential policy compliance for the given credential ID. + + :param credential_id: The credential ID to check + :param provider: The provider name + :param credential_type: The type of credential (MODEL or TOOL) + :param check_existence: Whether to check if credential exists in database first + :raises ValueError: If credential policy compliance check fails + """ + from services.enterprise.plugin_manager_service import ( + CheckCredentialPolicyComplianceRequest, + PluginManagerService, + ) + from services.feature_service import FeatureService + + if not FeatureService.get_system_features().plugin_manager.enabled or not credential_id: + return + + # Check if credential exists in database first (if requested) + if check_existence: + if not is_credential_exists(credential_id, credential_type): + raise ValueError(f"Credential with id {credential_id} for provider {provider} not found.") + + # Check policy compliance + PluginManagerService.check_credential_policy_compliance( + CheckCredentialPolicyComplianceRequest( + dify_credential_id=credential_id, + provider=provider, + credential_type=credential_type, + ) + ) diff --git a/api/core/model_manager.py b/api/core/model_manager.py index a59b0ae826..10df2ad79e 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -23,6 +23,7 @@ from core.model_runtime.model_providers.__base.tts_model import TTSModel from core.provider_manager import ProviderManager from extensions.ext_redis import redis_client from models.provider import ProviderType +from services.enterprise.plugin_manager_service import PluginCredentialType logger = logging.getLogger(__name__) @@ -362,6 +363,23 @@ class ModelInstance: else: raise last_exception + # Additional policy compliance check as fallback (in case fetch_next didn't catch it) + try: + from core.helper.credential_utils import check_credential_policy_compliance + + if lb_config.credential_id: + check_credential_policy_compliance( + credential_id=lb_config.credential_id, + provider=self.provider, + credential_type=PluginCredentialType.MODEL, + ) + except Exception as e: + logger.warning( + "Load balancing config %s failed policy compliance check in round-robin: %s", lb_config.id, str(e) + ) + self.load_balancing_manager.cooldown(lb_config, expire=60) + continue + try: if "credentials" in kwargs: del kwargs["credentials"] @@ -515,6 +533,24 @@ class LBModelManager: continue + # Check policy compliance for the selected configuration + try: + from core.helper.credential_utils import check_credential_policy_compliance + + if config.credential_id: + check_credential_policy_compliance( + credential_id=config.credential_id, + provider=self._provider, + credential_type=PluginCredentialType.MODEL, + ) + except Exception as e: + logger.warning("Load balancing config %s failed policy compliance check: %s", config.id, str(e)) + cooldown_load_balancing_configs.append(config) + if len(cooldown_load_balancing_configs) >= len(self._load_balancing_configs): + # all configs are in cooldown or failed policy compliance + return None + continue + if dify_config.DEBUG: logger.info( """Model LB diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 13dcef1a1f..e4e8b09a04 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -1129,6 +1129,7 @@ class ProviderManager: name=load_balancing_model_config.name, credentials=provider_model_credentials, credential_source_type=load_balancing_model_config.credential_source_type, + credential_id=load_balancing_model_config.credential_id, ) ) diff --git a/api/core/tools/errors.py b/api/core/tools/errors.py index c5f9ca4774..b0c2232857 100644 --- a/api/core/tools/errors.py +++ b/api/core/tools/errors.py @@ -29,6 +29,10 @@ class ToolApiSchemaError(ValueError): pass +class ToolCredentialPolicyViolationError(ValueError): + pass + + class ToolEngineInvokeError(Exception): meta: ToolInvokeMeta diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 00fc57a3f1..bc1f09a2fc 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -27,6 +27,7 @@ from core.tools.plugin_tool.tool import PluginTool from core.tools.utils.uuid_utils import is_valid_uuid from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.workflow.entities.variable_pool import VariablePool +from services.enterprise.plugin_manager_service import PluginCredentialType from services.tools.mcp_tools_manage_service import MCPToolManageService if TYPE_CHECKING: @@ -55,9 +56,7 @@ from core.tools.entities.tool_entities import ( ) from core.tools.errors import ToolProviderNotFoundError from core.tools.tool_label_manager import ToolLabelManager -from core.tools.utils.configuration import ( - ToolParameterConfigurationManager, -) +from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db @@ -237,6 +236,16 @@ class ToolManager: if builtin_provider is None: raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") + # check if the credential is allowed to be used + from core.helper.credential_utils import check_credential_policy_compliance + + check_credential_policy_compliance( + credential_id=builtin_provider.id, + provider=provider_id, + credential_type=PluginCredentialType.TOOL, + check_existence=False, + ) + encrypter, cache = create_provider_encrypter( tenant_id=tenant_id, config=[ diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py index 3c3f970444..edb76408e8 100644 --- a/api/services/enterprise/base.py +++ b/api/services/enterprise/base.py @@ -3,18 +3,30 @@ import os import requests -class EnterpriseRequest: - base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL") - secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY") - +class BaseRequest: proxies = { "http": "", "https": "", } + base_url = "" + secret_key = "" + secret_key_header = "" @classmethod def send_request(cls, method, endpoint, json=None, params=None): - headers = {"Content-Type": "application/json", "Enterprise-Api-Secret-Key": cls.secret_key} + headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key} url = f"{cls.base_url}{endpoint}" response = requests.request(method, url, json=json, params=params, headers=headers, proxies=cls.proxies) return response.json() + + +class EnterpriseRequest(BaseRequest): + base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL") + secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY") + secret_key_header = "Enterprise-Api-Secret-Key" + + +class EnterprisePluginManagerRequest(BaseRequest): + base_url = os.environ.get("ENTERPRISE_PLUGIN_MANAGER_API_URL", "ENTERPRISE_PLUGIN_MANAGER_API_URL") + secret_key = os.environ.get("ENTERPRISE_PLUGIN_MANAGER_API_SECRET_KEY", "ENTERPRISE_PLUGIN_MANAGER_API_SECRET_KEY") + secret_key_header = "Plugin-Manager-Inner-Api-Secret-Key" diff --git a/api/services/enterprise/plugin_manager_service.py b/api/services/enterprise/plugin_manager_service.py new file mode 100644 index 0000000000..cfcc39416a --- /dev/null +++ b/api/services/enterprise/plugin_manager_service.py @@ -0,0 +1,52 @@ +import enum +import logging + +from pydantic import BaseModel + +from services.enterprise.base import EnterprisePluginManagerRequest +from services.errors.base import BaseServiceError + + +class PluginCredentialType(enum.Enum): + MODEL = 0 + TOOL = 1 + + def to_number(self): + return self.value + + +class CheckCredentialPolicyComplianceRequest(BaseModel): + dify_credential_id: str + provider: str + credential_type: PluginCredentialType + + def model_dump(self, **kwargs): + data = super().model_dump(**kwargs) + data["credential_type"] = self.credential_type.to_number() + return data + + +class CredentialPolicyViolationError(BaseServiceError): + pass + + +class PluginManagerService: + @classmethod + def check_credential_policy_compliance(cls, body: CheckCredentialPolicyComplianceRequest): + try: + ret = EnterprisePluginManagerRequest.send_request( + "POST", "/check-credential-policy-compliance", json=body.model_dump() + ) + if not isinstance(ret, dict) or "result" not in ret: + raise ValueError("Invalid response format from plugin manager API") + except Exception as e: + raise CredentialPolicyViolationError( + f"error occurred while checking credential policy compliance: {e}" + ) from e + + if not ret.get("result", False): + raise CredentialPolicyViolationError("Credentials not available: Please use ENTERPRISE global credentials") + + logging.debug( + f"Credential policy compliance checked for {body.provider} with credential {body.dify_credential_id}, result: {ret.get('result', False)}" + ) diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 1441e6ce16..c27c0b0d58 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -134,6 +134,10 @@ class KnowledgeRateLimitModel(BaseModel): subscription_plan: str = "" +class PluginManagerModel(BaseModel): + enabled: bool = False + + class SystemFeatureModel(BaseModel): sso_enforced_for_signin: bool = False sso_enforced_for_signin_protocol: str = "" @@ -150,6 +154,7 @@ class SystemFeatureModel(BaseModel): webapp_auth: WebAppAuthModel = WebAppAuthModel() plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel() enable_change_email: bool = True + plugin_manager: PluginManagerModel = PluginManagerModel() class FeatureService: @@ -188,6 +193,7 @@ class FeatureService: system_features.branding.enabled = True system_features.webapp_auth.enabled = True system_features.enable_change_email = False + system_features.plugin_manager.enabled = True cls._fulfill_params_from_enterprise(system_features) if dify_config.MARKETPLACE_ENABLED: diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 350e52e438..0a14007349 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -36,22 +36,14 @@ from libs.datetime_utils import naive_utc_now from models.account import Account from models.model import App, AppMode from models.tools import WorkflowToolProvider -from models.workflow import ( - Workflow, - WorkflowNodeExecutionModel, - WorkflowNodeExecutionTriggeredFrom, - WorkflowType, -) +from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType from repositories.factory import DifyAPIRepositoryFactory +from services.enterprise.plugin_manager_service import PluginCredentialType from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError from services.workflow.workflow_converter import WorkflowConverter from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError -from .workflow_draft_variable_service import ( - DraftVariableSaver, - DraftVarLoader, - WorkflowDraftVariableService, -) +from .workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader, WorkflowDraftVariableService class WorkflowService: @@ -271,6 +263,12 @@ class WorkflowService: if not draft_workflow: raise ValueError("No valid workflow found.") + # Validate credentials before publishing, for credential policy check + from services.feature_service import FeatureService + + if FeatureService.get_system_features().plugin_manager.enabled: + self._validate_workflow_credentials(draft_workflow) + # create new workflow workflow = Workflow.new( tenant_id=app_model.tenant_id, @@ -295,6 +293,260 @@ class WorkflowService: # return new workflow return workflow + def _validate_workflow_credentials(self, workflow: Workflow) -> None: + """ + Validate all credentials in workflow nodes before publishing. + + :param workflow: The workflow to validate + :raises ValueError: If any credentials violate policy compliance + """ + graph_dict = workflow.graph_dict + nodes = graph_dict.get("nodes", []) + + for node in nodes: + node_data = node.get("data", {}) + node_type = node_data.get("type") + node_id = node.get("id", "unknown") + + try: + # Extract and validate credentials based on node type + if node_type == "tool": + credential_id = node_data.get("credential_id") + provider = node_data.get("provider_id") + if provider: + if credential_id: + # Check specific credential + from core.helper.credential_utils import check_credential_policy_compliance + + check_credential_policy_compliance( + credential_id=credential_id, + provider=provider, + credential_type=PluginCredentialType.TOOL, + ) + else: + # Check default workspace credential for this provider + self._check_default_tool_credential(workflow.tenant_id, provider) + + elif node_type == "agent": + agent_params = node_data.get("agent_parameters", {}) + + model_config = agent_params.get("model", {}).get("value", {}) + if model_config.get("provider") and model_config.get("model"): + self._validate_llm_model_config( + workflow.tenant_id, model_config["provider"], model_config["model"] + ) + + # Validate load balancing credentials for agent model if load balancing is enabled + agent_model_node_data = {"model": model_config} + self._validate_load_balancing_credentials(workflow, agent_model_node_data, node_id) + + # Validate agent tools + tools = agent_params.get("tools", {}).get("value", []) + for tool in tools: + # Agent tools store provider in provider_name field + provider = tool.get("provider_name") + credential_id = tool.get("credential_id") + if provider: + if credential_id: + from core.helper.credential_utils import check_credential_policy_compliance + + check_credential_policy_compliance(credential_id, provider, PluginCredentialType.TOOL) + else: + self._check_default_tool_credential(workflow.tenant_id, provider) + + elif node_type in ["llm", "knowledge_retrieval", "parameter_extractor", "question_classifier"]: + model_config = node_data.get("model", {}) + provider = model_config.get("provider") + model_name = model_config.get("name") + + if provider and model_name: + # Validate that the provider+model combination can fetch valid credentials + self._validate_llm_model_config(workflow.tenant_id, provider, model_name) + # Validate load balancing credentials if load balancing is enabled + self._validate_load_balancing_credentials(workflow, node_data, node_id) + else: + raise ValueError(f"Node {node_id} ({node_type}): Missing provider or model configuration") + + except Exception as e: + if isinstance(e, ValueError): + raise e + else: + raise ValueError(f"Node {node_id} ({node_type}): {str(e)}") + + def _validate_llm_model_config(self, tenant_id: str, provider: str, model_name: str) -> None: + """ + Validate that an LLM model configuration can fetch valid credentials. + + This method attempts to get the model instance and validates that: + 1. The provider exists and is configured + 2. The model exists in the provider + 3. Credentials can be fetched for the model + 4. The credentials pass policy compliance checks + + :param tenant_id: The tenant ID + :param provider: The provider name + :param model_name: The model name + :raises ValueError: If the model configuration is invalid or credentials fail policy checks + """ + try: + from core.model_manager import ModelManager + from core.model_runtime.entities.model_entities import ModelType + + # Get model instance to validate provider+model combination + model_manager = ModelManager() + model_manager.get_model_instance( + tenant_id=tenant_id, provider=provider, model_type=ModelType.LLM, model=model_name + ) + + # The ModelInstance constructor will automatically check credential policy compliance + # via ProviderConfiguration.get_current_credentials() -> _check_credential_policy_compliance() + # If it fails, an exception will be raised + + except Exception as e: + raise ValueError( + f"Failed to validate LLM model configuration (provider: {provider}, model: {model_name}): {str(e)}" + ) + + def _check_default_tool_credential(self, tenant_id: str, provider: str) -> None: + """ + Check credential policy compliance for the default workspace credential of a tool provider. + + This method finds the default credential for the given provider and validates it. + Uses the same fallback logic as runtime to handle deauthorized credentials. + + :param tenant_id: The tenant ID + :param provider: The tool provider name + :raises ValueError: If no default credential exists or if it fails policy compliance + """ + try: + from models.tools import BuiltinToolProvider + + # Use the same fallback logic as runtime: get the first available credential + # ordered by is_default DESC, created_at ASC (same as tool_manager.py) + default_provider = ( + db.session.query(BuiltinToolProvider) + .where( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider, + ) + .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) + .first() + ) + + if not default_provider: + raise ValueError("No default credential found") + + # Check credential policy compliance using the default credential ID + from core.helper.credential_utils import check_credential_policy_compliance + + check_credential_policy_compliance( + credential_id=default_provider.id, + provider=provider, + credential_type=PluginCredentialType.TOOL, + check_existence=False, + ) + + except Exception as e: + raise ValueError(f"Failed to validate default credential for tool provider {provider}: {str(e)}") + + def _validate_load_balancing_credentials(self, workflow: Workflow, node_data: dict, node_id: str) -> None: + """ + Validate load balancing credentials for a workflow node. + + :param workflow: The workflow being validated + :param node_data: The node data containing model configuration + :param node_id: The node ID for error reporting + :raises ValueError: If load balancing credentials violate policy compliance + """ + # Extract model configuration + model_config = node_data.get("model", {}) + provider = model_config.get("provider") + model_name = model_config.get("name") + + if not provider or not model_name: + return # No model config to validate + + # Check if this model has load balancing enabled + if self._is_load_balancing_enabled(workflow.tenant_id, provider, model_name): + # Get all load balancing configurations for this model + load_balancing_configs = self._get_load_balancing_configs(workflow.tenant_id, provider, model_name) + # Validate each load balancing configuration + try: + for config in load_balancing_configs: + if config.get("credential_id"): + from core.helper.credential_utils import check_credential_policy_compliance + + check_credential_policy_compliance( + config["credential_id"], provider, PluginCredentialType.MODEL + ) + except Exception as e: + raise ValueError(f"Invalid load balancing credentials for {provider}/{model_name}: {str(e)}") + + def _is_load_balancing_enabled(self, tenant_id: str, provider: str, model_name: str) -> bool: + """ + Check if load balancing is enabled for a specific model. + + :param tenant_id: The tenant ID + :param provider: The provider name + :param model_name: The model name + :return: True if load balancing is enabled, False otherwise + """ + try: + from core.model_runtime.entities.model_entities import ModelType + from core.provider_manager import ProviderManager + + # Get provider configurations + provider_manager = ProviderManager() + provider_configurations = provider_manager.get_configurations(tenant_id) + provider_configuration = provider_configurations.get(provider) + + if not provider_configuration: + return False + + # Get provider model setting + provider_model_setting = provider_configuration.get_provider_model_setting( + model_type=ModelType.LLM, + model=model_name, + ) + return provider_model_setting is not None and provider_model_setting.load_balancing_enabled + + except Exception: + # If we can't determine the status, assume load balancing is not enabled + return False + + def _get_load_balancing_configs(self, tenant_id: str, provider: str, model_name: str) -> list[dict]: + """ + Get all load balancing configurations for a model. + + :param tenant_id: The tenant ID + :param provider: The provider name + :param model_name: The model name + :return: List of load balancing configuration dictionaries + """ + try: + from services.model_load_balancing_service import ModelLoadBalancingService + + model_load_balancing_service = ModelLoadBalancingService() + _, configs = model_load_balancing_service.get_load_balancing_configs( + tenant_id=tenant_id, + provider=provider, + model=model_name, + model_type="llm", # Load balancing is primarily used for LLM models + config_from="predefined-model", # Check both predefined and custom models + ) + + _, custom_configs = model_load_balancing_service.get_load_balancing_configs( + tenant_id=tenant_id, provider=provider, model=model_name, model_type="llm", config_from="custom-model" + ) + all_configs = configs + custom_configs + + return [config for config in all_configs if config.get("credential_id")] + + except Exception: + # If we can't get the configurations, return empty list + # This will prevent validation errors from breaking the workflow + return [] + def get_default_block_configs(self) -> list[dict]: """ Get default block configs