diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 9f8d06e322..0279725ff2 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -473,9 +473,21 @@ class ProviderConfiguration(BaseModel): self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session) else: - # some historical data may have a provider record but not be set as valid provider_record.is_valid = True + if provider_record.credential_id is None: + provider_record.credential_id = new_record.id + provider_record.updated_at = naive_utc_now() + + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_record.id, + cache_type=ProviderCredentialsCacheType.PROVIDER, + ) + provider_model_credentials_cache.delete() + + self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session) + session.commit() except Exception: session.rollback() diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index c538a557fb..ed34922346 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -196,6 +196,8 @@ class ProviderManager: if preferred_provider_type_record: preferred_provider_type = ProviderType.value_of(preferred_provider_type_record.preferred_provider_type) + elif dify_config.EDITION == "CLOUD" and system_configuration.enabled: + preferred_provider_type = ProviderType.SYSTEM elif custom_configuration.provider or custom_configuration.models: preferred_provider_type = ProviderType.CUSTOM elif system_configuration.enabled: diff --git a/api/services/plugin/plugin_service.py b/api/services/plugin/plugin_service.py index 55a3ffde78..ca83742d65 100644 --- a/api/services/plugin/plugin_service.py +++ b/api/services/plugin/plugin_service.py @@ -30,7 +30,7 @@ from core.plugin.impl.debugging import PluginDebuggingClient from core.plugin.impl.plugin import PluginInstaller from extensions.ext_database import db from extensions.ext_redis import redis_client -from models.provider import Provider, ProviderCredential +from models.provider import Provider, ProviderCredential, TenantPreferredModelProvider from models.provider_ids import GenericProviderID from services.enterprise.plugin_manager_service import ( PluginManagerService, @@ -534,6 +534,13 @@ class PluginService: plugin_id = plugin.plugin_id logger.info("Deleting credentials for plugin: %s", plugin_id) + session.execute( + delete(TenantPreferredModelProvider).where( + TenantPreferredModelProvider.tenant_id == tenant_id, + TenantPreferredModelProvider.provider_name.like(f"{plugin_id}/%"), + ) + ) + # Delete provider credentials that match this plugin credential_ids = session.scalars( select(ProviderCredential.id).where( diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py index 82f98d07a3..5ebefcd8d2 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py @@ -734,7 +734,7 @@ def test_create_provider_credential_creates_provider_record_when_missing() -> No def test_create_provider_credential_marks_existing_provider_as_valid() -> None: configuration = _build_provider_configuration() session = Mock() - provider_record = SimpleNamespace(is_valid=False) + provider_record = SimpleNamespace(id="provider-1", is_valid=False, credential_id="existing-cred") with _patched_session(session): with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=False): @@ -743,6 +743,25 @@ def test_create_provider_credential_marks_existing_provider_as_valid() -> None: configuration.create_provider_credential({"api_key": "raw"}, "Main") assert provider_record.is_valid is True + assert provider_record.credential_id == "existing-cred" + session.commit.assert_called_once() + + +def test_create_provider_credential_auto_activates_when_no_active_credential() -> None: + configuration = _build_provider_configuration() + session = Mock() + provider_record = SimpleNamespace(id="provider-1", is_valid=False, credential_id=None, updated_at=None) + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=False): + with patch.object(ProviderConfiguration, "validate_provider_credentials", return_value={"api_key": "enc"}): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=provider_record): + with patch("core.entities.provider_configuration.ProviderCredentialsCache"): + with patch.object(ProviderConfiguration, "switch_preferred_provider_type"): + configuration.create_provider_credential({"api_key": "raw"}, "Main") + + assert provider_record.is_valid is True + assert provider_record.credential_id is not None session.commit.assert_called_once()