diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 416bc8cef9..2245adb681 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -49,6 +49,94 @@ class DatasourceProviderService: def __init__(self) -> None: self.provider_manager = PluginDatasourceManager() + @staticmethod + def _should_refresh_credentials(datasource_provider: DatasourceProvider, now: int | None = None) -> bool: + current_time = int(time.time()) if now is None else now + if datasource_provider.expires_at == -1: + return False + return (datasource_provider.expires_at - 60) < current_time + + def _refresh_datasource_credentials( + self, + tenant_id: str, + provider: str, + plugin_id: str, + datasource_provider: DatasourceProvider, + current_user: Any, + ) -> tuple[dict[str, Any], int]: + datasource_provider_id = DatasourceProviderID(f"{plugin_id}/{provider}") + provider_name = datasource_provider_id.provider_name + credential_id = getattr(datasource_provider, "id", None) + credential_name = getattr(datasource_provider, "name", None) + logger.info( + "Refreshing datasource credentials for provider %s", + provider_name, + extra={ + "tenant_id": tenant_id, + "plugin_id": datasource_provider_id.plugin_id, + "provider": provider_name, + "credential_id": credential_id, + "credential_name": credential_name, + "expires_at": datasource_provider.expires_at, + }, + ) + decrypted_credentials = self.decrypt_datasource_provider_credentials( + tenant_id=tenant_id, + datasource_provider=datasource_provider, + plugin_id=plugin_id, + provider=provider, + ) + redirect_uri = ( + f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback" + ) + system_credentials = self.get_oauth_client(tenant_id, datasource_provider_id) + try: + refreshed_credentials = OAuthHandler().refresh_credentials( + tenant_id=tenant_id, + user_id=current_user.id, + plugin_id=datasource_provider_id.plugin_id, + provider=provider_name, + redirect_uri=redirect_uri, + system_credentials=system_credentials or {}, + credentials=decrypted_credentials, + ) + except Exception as exc: + message = ( + f"Failed to refresh datasource credentials for provider {provider_name}" + f" (credential: {credential_name or credential_id or 'unknown'})" + ) + logger.exception( + message, + extra={ + "tenant_id": tenant_id, + "plugin_id": datasource_provider_id.plugin_id, + "provider": provider_name, + "credential_id": credential_id, + "credential_name": credential_name, + }, + ) + raise ValueError(f"{message}: {exc}") from exc + encrypted_credentials = self.encrypt_datasource_provider_credentials( + tenant_id=tenant_id, + raw_credentials=refreshed_credentials.credentials, + provider=provider, + plugin_id=plugin_id, + datasource_provider=datasource_provider, + ) + logger.info( + "Refreshed datasource credentials for provider %s", + provider_name, + extra={ + "tenant_id": tenant_id, + "plugin_id": datasource_provider_id.plugin_id, + "provider": provider_name, + "credential_id": credential_id, + "credential_name": credential_name, + "expires_at": refreshed_credentials.expires_at, + }, + ) + return encrypted_credentials, refreshed_credentials.expires_at + def remove_oauth_custom_client_params(self, tenant_id: str, datasource_provider_id: DatasourceProviderID): """ remove oauth custom client params @@ -108,7 +196,10 @@ class DatasourceProviderService: credential_id: str | None = None, ) -> dict[str, Any]: """ - get credential by id + Return decrypted datasource credentials. + + If the stored credential is expired or about to expire, this method refreshes + it through plugin-daemon and persists the refreshed credential before returning. """ with sessionmaker(bind=db.engine).begin() as session: if credential_id: @@ -130,39 +221,17 @@ class DatasourceProviderService: ) if not datasource_provider: return {} - # refresh the credentials - if datasource_provider.expires_at != -1 and (datasource_provider.expires_at - 60) < int(time.time()): + if self._should_refresh_credentials(datasource_provider): current_user = get_current_user() - decrypted_credentials = self.decrypt_datasource_provider_credentials( + encrypted_credentials, expires_at = self._refresh_datasource_credentials( tenant_id=tenant_id, - datasource_provider=datasource_provider, - plugin_id=plugin_id, - provider=provider, - ) - datasource_provider_id = DatasourceProviderID(f"{plugin_id}/{provider}") - provider_name = datasource_provider_id.provider_name - redirect_uri = ( - f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/" - f"{datasource_provider_id}/datasource/callback" - ) - system_credentials = self.get_oauth_client(tenant_id, datasource_provider_id) - refreshed_credentials = OAuthHandler().refresh_credentials( - tenant_id=tenant_id, - user_id=current_user.id, - plugin_id=datasource_provider_id.plugin_id, - provider=provider_name, - redirect_uri=redirect_uri, - system_credentials=system_credentials or {}, - credentials=decrypted_credentials, - ) - datasource_provider.encrypted_credentials = self.encrypt_datasource_provider_credentials( - tenant_id=tenant_id, - raw_credentials=refreshed_credentials.credentials, provider=provider, plugin_id=plugin_id, datasource_provider=datasource_provider, + current_user=current_user, ) - datasource_provider.expires_at = refreshed_credentials.expires_at + datasource_provider.encrypted_credentials = encrypted_credentials + datasource_provider.expires_at = expires_at return self.decrypt_datasource_provider_credentials( tenant_id=tenant_id, @@ -178,7 +247,10 @@ class DatasourceProviderService: plugin_id: str, ) -> list[dict[str, Any]]: """ - get all datasource credentials by provider + Return all decrypted datasource credentials for a provider. + + Expired credentials are refreshed independently. A failed credential refresh is + logged and skipped so one broken authorization does not block other credentials. """ with sessionmaker(bind=db.engine).begin() as session: datasource_providers = session.scalars( @@ -193,46 +265,39 @@ class DatasourceProviderService: if not datasource_providers: return [] current_user = get_current_user() - # refresh the credentials real_credentials_list = [] for datasource_provider in datasource_providers: - decrypted_credentials = self.decrypt_datasource_provider_credentials( - tenant_id=tenant_id, - datasource_provider=datasource_provider, - plugin_id=plugin_id, - provider=provider, - ) - datasource_provider_id = DatasourceProviderID(f"{plugin_id}/{provider}") - provider_name = datasource_provider_id.provider_name - redirect_uri = ( - f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/" - f"{datasource_provider_id}/datasource/callback" - ) - system_credentials = self.get_oauth_client(tenant_id, datasource_provider_id) - refreshed_credentials = OAuthHandler().refresh_credentials( - tenant_id=tenant_id, - user_id=current_user.id, - plugin_id=datasource_provider_id.plugin_id, - provider=provider_name, - redirect_uri=redirect_uri, - system_credentials=system_credentials or {}, - credentials=decrypted_credentials, - ) - datasource_provider.encrypted_credentials = self.encrypt_datasource_provider_credentials( - tenant_id=tenant_id, - raw_credentials=refreshed_credentials.credentials, - provider=provider, - plugin_id=plugin_id, - datasource_provider=datasource_provider, - ) - datasource_provider.expires_at = refreshed_credentials.expires_at - real_credentials = self.decrypt_datasource_provider_credentials( - tenant_id=tenant_id, - datasource_provider=datasource_provider, - plugin_id=plugin_id, - provider=provider, - ) - real_credentials_list.append(real_credentials) + try: + if self._should_refresh_credentials(datasource_provider): + encrypted_credentials, expires_at = self._refresh_datasource_credentials( + tenant_id=tenant_id, + provider=provider, + plugin_id=plugin_id, + datasource_provider=datasource_provider, + current_user=current_user, + ) + datasource_provider.encrypted_credentials = encrypted_credentials + datasource_provider.expires_at = expires_at + real_credentials = self.decrypt_datasource_provider_credentials( + tenant_id=tenant_id, + datasource_provider=datasource_provider, + plugin_id=plugin_id, + provider=provider, + ) + real_credentials_list.append(real_credentials) + except Exception: + logger.exception( + "Skipping datasource credentials for provider %s after refresh or decrypt failure", + provider, + extra={ + "tenant_id": tenant_id, + "plugin_id": plugin_id, + "provider": provider, + "credential_id": getattr(datasource_provider, "id", None), + "credential_name": getattr(datasource_provider, "name", None), + "expires_at": getattr(datasource_provider, "expires_at", None), + }, + ) return real_credentials_list diff --git a/api/services/rag_pipeline/rag_pipeline_manage_service.py b/api/services/rag_pipeline/rag_pipeline_manage_service.py index 0908d30c12..eaf797760f 100644 --- a/api/services/rag_pipeline/rag_pipeline_manage_service.py +++ b/api/services/rag_pipeline/rag_pipeline_manage_service.py @@ -1,7 +1,11 @@ +import logging + from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity from core.plugin.impl.datasource import PluginDatasourceManager from services.datasource_provider_service import DatasourceProviderService +logger = logging.getLogger(__name__) + class RagPipelineManageService: @staticmethod @@ -15,9 +19,21 @@ class RagPipelineManageService: datasources = manager.fetch_datasource_providers(tenant_id) for datasource in datasources: datasource_provider_service = DatasourceProviderService() - credentials = datasource_provider_service.get_datasource_credentials( - tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id - ) - if credentials: - datasource.is_authorized = True + try: + credentials = datasource_provider_service.get_datasource_credentials( + tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id + ) + if credentials: + datasource.is_authorized = True + except Exception: + logger.exception( + "Skipping datasource credentials for provider %s after refresh or decrypt failure", + datasource.provider, + extra={ + "tenant_id": tenant_id, + "plugin_id": datasource.plugin_id, + "provider": datasource.provider, + }, + ) + return datasources diff --git a/api/tests/unit_tests/services/test_datasource_provider_service.py b/api/tests/unit_tests/services/test_datasource_provider_service.py index c389c4a635..f374a29482 100644 --- a/api/tests/unit_tests/services/test_datasource_provider_service.py +++ b/api/tests/unit_tests/services/test_datasource_provider_service.py @@ -243,7 +243,7 @@ class TestDatasourceProviderService: assert service.get_datasource_credentials("t1", "prov", "org/plug") == {} def test_should_refresh_oauth_tokens_when_expired(self, service, mock_db_session, mock_user): - """Expired OAuth credential (expires_at near zero) triggers a silent refresh.""" + """Expired OAuth credential (expires_at near zero) triggers a refresh.""" p = MagicMock(spec=DatasourceProvider) p.auth_type = "oauth2" p.expires_at = 0 # expired @@ -256,6 +256,24 @@ class TestDatasourceProviderService: ): service.get_datasource_credentials("t1", "prov", "org/plug") + def test_should_include_provider_name_when_refresh_fails(self, service, mock_db_session, mock_user): + p = MagicMock(spec=DatasourceProvider) + p.id = "cred-id" + p.name = "Credential" + p.auth_type = "oauth2" + p.expires_at = 0 + p.encrypted_credentials = {"tok": "x"} + mock_db_session.scalar.return_value = p + with ( + patch("services.datasource_provider_service.get_current_user", return_value=mock_user), + patch("services.datasource_provider_service.OAuthHandler") as oauth_handler, + patch.object(service, "get_oauth_client", return_value={"oc": "v"}), + patch.object(service, "decrypt_datasource_provider_credentials", return_value={"tok": "plain"}), + ): + oauth_handler.return_value.refresh_credentials.side_effect = RuntimeError("token endpoint failed") + with pytest.raises(ValueError, match="provider prov"): + service.get_datasource_credentials("t1", "prov", "org/plug") + def test_should_return_decrypted_credentials_when_api_key_not_expired(self, service, mock_db_session, mock_user): """API key credentials with expires_at=-1 skip refresh and return directly.""" p = MagicMock(spec=DatasourceProvider) @@ -307,6 +325,51 @@ class TestDatasourceProviderService: result = service.get_all_datasource_credentials_by_provider("t1", "prov", "org/plug") assert len(result) == 1 + def test_should_skip_failed_provider_when_refreshing_all_credentials( + self, service, mock_db_session, mock_user, caplog + ): + failed_provider = MagicMock(spec=DatasourceProvider) + failed_provider.id = "failed-cred" + failed_provider.name = "Failed" + failed_provider.auth_type = "oauth2" + failed_provider.expires_at = 0 + working_provider = MagicMock(spec=DatasourceProvider) + working_provider.id = "working-cred" + working_provider.name = "Working" + working_provider.auth_type = "oauth2" + working_provider.expires_at = 0 + mock_db_session.scalars.return_value.all.return_value = [failed_provider, working_provider] + with ( + patch("services.datasource_provider_service.get_current_user", return_value=mock_user), + patch.object( + service, + "_refresh_datasource_credentials", + side_effect=[ValueError("refresh failed"), ({"t": "enc"}, 9999)], + ) as refresh_credentials, + patch.object(service, "decrypt_datasource_provider_credentials", return_value={"t": "plain"}), + ): + result = service.get_all_datasource_credentials_by_provider("t1", "prov", "org/plug") + assert result == [{"t": "plain"}] + assert refresh_credentials.call_count == 2 + assert "Skipping datasource credentials for provider prov" in caplog.text + + def test_should_return_valid_credentials_without_refresh_when_getting_all_credentials( + self, service, mock_db_session, mock_user + ): + p = MagicMock(spec=DatasourceProvider) + p.auth_type = "oauth2" + p.expires_at = -1 + p.encrypted_credentials = {"t": "x"} + mock_db_session.scalars.return_value.all.return_value = [p] + with ( + patch("services.datasource_provider_service.get_current_user", return_value=mock_user), + patch.object(service, "_refresh_datasource_credentials") as refresh_credentials, + patch.object(service, "decrypt_datasource_provider_credentials", return_value={"t": "plain"}), + ): + result = service.get_all_datasource_credentials_by_provider("t1", "prov", "org/plug") + assert result == [{"t": "plain"}] + refresh_credentials.assert_not_called() + # ----------------------------------------------------------------------- # update_datasource_provider_name (lines 236-303) # -----------------------------------------------------------------------