fix(api): gracefully handle credential fetch failures in rag pipeline (#36165)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
林玮 (Jade Lin) 2026-05-14 16:27:19 +08:00 committed by GitHub
parent d9ccfcbc6e
commit e660d7af38
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 217 additions and 73 deletions

View File

@ -49,6 +49,94 @@ class DatasourceProviderService:
def __init__(self) -> None:
"""Create the PluginDatasourceManager instance stored on this service as provider_manager."""
self.provider_manager = PluginDatasourceManager()
@staticmethod
def _should_refresh_credentials(datasource_provider: DatasourceProvider, now: int | None = None) -> bool:
current_time = int(time.time()) if now is None else now
if datasource_provider.expires_at == -1:
return False
return (datasource_provider.expires_at - 60) < current_time
def _refresh_datasource_credentials(
    self,
    tenant_id: str,
    provider: str,
    plugin_id: str,
    datasource_provider: DatasourceProvider,
    current_user: Any,
) -> tuple[dict[str, Any], int]:
    """
    Refresh a single stored credential and return it re-encrypted.

    The stored credential is decrypted, handed to ``OAuthHandler().refresh_credentials``
    together with the OAuth client settings and callback URL, and the refreshed
    payload is encrypted again. This method does not assign the new values to
    ``datasource_provider``; it returns ``(encrypted_credentials, expires_at)`` for
    the caller to store.

    Raises ValueError (chained to the original exception) when the refresh call
    fails; the message names the provider and the credential name or id.
    """
    provider_id = DatasourceProviderID(f"{plugin_id}/{provider}")
    provider_name = provider_id.provider_name
    credential_id = getattr(datasource_provider, "id", None)
    credential_name = getattr(datasource_provider, "name", None)

    # Structured fields attached to every log record emitted by this method.
    log_context = {
        "tenant_id": tenant_id,
        "plugin_id": provider_id.plugin_id,
        "provider": provider_name,
        "credential_id": credential_id,
        "credential_name": credential_name,
    }

    logger.info(
        "Refreshing datasource credentials for provider %s",
        provider_name,
        extra={**log_context, "expires_at": datasource_provider.expires_at},
    )

    current_credentials = self.decrypt_datasource_provider_credentials(
        tenant_id=tenant_id,
        datasource_provider=datasource_provider,
        plugin_id=plugin_id,
        provider=provider,
    )
    callback_url = (
        f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
    )
    oauth_client = self.get_oauth_client(tenant_id, provider_id)

    try:
        refreshed = OAuthHandler().refresh_credentials(
            tenant_id=tenant_id,
            user_id=current_user.id,
            plugin_id=provider_id.plugin_id,
            provider=provider_name,
            redirect_uri=callback_url,
            system_credentials=oauth_client or {},
            credentials=current_credentials,
        )
    except Exception as exc:
        # Prefer the human-readable name, then the id, so the error identifies the credential.
        credential_label = credential_name or credential_id or "unknown"
        message = (
            f"Failed to refresh datasource credentials for provider {provider_name}"
            f" (credential: {credential_label})"
        )
        logger.exception(message, extra=log_context)
        raise ValueError(f"{message}: {exc}") from exc

    encrypted = self.encrypt_datasource_provider_credentials(
        tenant_id=tenant_id,
        raw_credentials=refreshed.credentials,
        provider=provider,
        plugin_id=plugin_id,
        datasource_provider=datasource_provider,
    )

    logger.info(
        "Refreshed datasource credentials for provider %s",
        provider_name,
        extra={**log_context, "expires_at": refreshed.expires_at},
    )
    return encrypted, refreshed.expires_at
def remove_oauth_custom_client_params(self, tenant_id: str, datasource_provider_id: DatasourceProviderID):
"""
remove oauth custom client params
@ -108,7 +196,10 @@ class DatasourceProviderService:
credential_id: str | None = None,
) -> dict[str, Any]:
"""
get credential by id
Return decrypted datasource credentials.
If the stored credential is expired or about to expire, this method refreshes
it through plugin-daemon and persists the refreshed credential before returning.
"""
with sessionmaker(bind=db.engine).begin() as session:
if credential_id:
@ -130,39 +221,17 @@ class DatasourceProviderService:
)
if not datasource_provider:
return {}
# refresh the credentials
if datasource_provider.expires_at != -1 and (datasource_provider.expires_at - 60) < int(time.time()):
if self._should_refresh_credentials(datasource_provider):
current_user = get_current_user()
decrypted_credentials = self.decrypt_datasource_provider_credentials(
encrypted_credentials, expires_at = self._refresh_datasource_credentials(
tenant_id=tenant_id,
datasource_provider=datasource_provider,
plugin_id=plugin_id,
provider=provider,
)
datasource_provider_id = DatasourceProviderID(f"{plugin_id}/{provider}")
provider_name = datasource_provider_id.provider_name
redirect_uri = (
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/"
f"{datasource_provider_id}/datasource/callback"
)
system_credentials = self.get_oauth_client(tenant_id, datasource_provider_id)
refreshed_credentials = OAuthHandler().refresh_credentials(
tenant_id=tenant_id,
user_id=current_user.id,
plugin_id=datasource_provider_id.plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=system_credentials or {},
credentials=decrypted_credentials,
)
datasource_provider.encrypted_credentials = self.encrypt_datasource_provider_credentials(
tenant_id=tenant_id,
raw_credentials=refreshed_credentials.credentials,
provider=provider,
plugin_id=plugin_id,
datasource_provider=datasource_provider,
current_user=current_user,
)
datasource_provider.expires_at = refreshed_credentials.expires_at
datasource_provider.encrypted_credentials = encrypted_credentials
datasource_provider.expires_at = expires_at
return self.decrypt_datasource_provider_credentials(
tenant_id=tenant_id,
@ -178,7 +247,10 @@ class DatasourceProviderService:
plugin_id: str,
) -> list[dict[str, Any]]:
"""
get all datasource credentials by provider
Return all decrypted datasource credentials for a provider.
Expired credentials are refreshed independently. A failed credential refresh is
logged and skipped so one broken authorization does not block other credentials.
"""
with sessionmaker(bind=db.engine).begin() as session:
datasource_providers = session.scalars(
@ -193,46 +265,39 @@ class DatasourceProviderService:
if not datasource_providers:
return []
current_user = get_current_user()
# refresh the credentials
real_credentials_list = []
for datasource_provider in datasource_providers:
decrypted_credentials = self.decrypt_datasource_provider_credentials(
tenant_id=tenant_id,
datasource_provider=datasource_provider,
plugin_id=plugin_id,
provider=provider,
)
datasource_provider_id = DatasourceProviderID(f"{plugin_id}/{provider}")
provider_name = datasource_provider_id.provider_name
redirect_uri = (
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/"
f"{datasource_provider_id}/datasource/callback"
)
system_credentials = self.get_oauth_client(tenant_id, datasource_provider_id)
refreshed_credentials = OAuthHandler().refresh_credentials(
tenant_id=tenant_id,
user_id=current_user.id,
plugin_id=datasource_provider_id.plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=system_credentials or {},
credentials=decrypted_credentials,
)
datasource_provider.encrypted_credentials = self.encrypt_datasource_provider_credentials(
tenant_id=tenant_id,
raw_credentials=refreshed_credentials.credentials,
provider=provider,
plugin_id=plugin_id,
datasource_provider=datasource_provider,
)
datasource_provider.expires_at = refreshed_credentials.expires_at
real_credentials = self.decrypt_datasource_provider_credentials(
tenant_id=tenant_id,
datasource_provider=datasource_provider,
plugin_id=plugin_id,
provider=provider,
)
real_credentials_list.append(real_credentials)
try:
if self._should_refresh_credentials(datasource_provider):
encrypted_credentials, expires_at = self._refresh_datasource_credentials(
tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
datasource_provider=datasource_provider,
current_user=current_user,
)
datasource_provider.encrypted_credentials = encrypted_credentials
datasource_provider.expires_at = expires_at
real_credentials = self.decrypt_datasource_provider_credentials(
tenant_id=tenant_id,
datasource_provider=datasource_provider,
plugin_id=plugin_id,
provider=provider,
)
real_credentials_list.append(real_credentials)
except Exception:
logger.exception(
"Skipping datasource credentials for provider %s after refresh or decrypt failure",
provider,
extra={
"tenant_id": tenant_id,
"plugin_id": plugin_id,
"provider": provider,
"credential_id": getattr(datasource_provider, "id", None),
"credential_name": getattr(datasource_provider, "name", None),
"expires_at": getattr(datasource_provider, "expires_at", None),
},
)
return real_credentials_list

View File

@ -1,7 +1,11 @@
import logging
from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity
from core.plugin.impl.datasource import PluginDatasourceManager
from services.datasource_provider_service import DatasourceProviderService
logger = logging.getLogger(__name__)
class RagPipelineManageService:
@staticmethod
@ -15,9 +19,21 @@ class RagPipelineManageService:
datasources = manager.fetch_datasource_providers(tenant_id)
for datasource in datasources:
datasource_provider_service = DatasourceProviderService()
credentials = datasource_provider_service.get_datasource_credentials(
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
)
if credentials:
datasource.is_authorized = True
try:
credentials = datasource_provider_service.get_datasource_credentials(
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
)
if credentials:
datasource.is_authorized = True
except Exception:
logger.exception(
"Skipping datasource credentials for provider %s after refresh or decrypt failure",
datasource.provider,
extra={
"tenant_id": tenant_id,
"plugin_id": datasource.plugin_id,
"provider": datasource.provider,
},
)
return datasources

View File

@ -243,7 +243,7 @@ class TestDatasourceProviderService:
assert service.get_datasource_credentials("t1", "prov", "org/plug") == {}
def test_should_refresh_oauth_tokens_when_expired(self, service, mock_db_session, mock_user):
"""Expired OAuth credential (expires_at near zero) triggers a silent refresh."""
"""Expired OAuth credential (expires_at near zero) triggers a refresh."""
p = MagicMock(spec=DatasourceProvider)
p.auth_type = "oauth2"
p.expires_at = 0 # expired
@ -256,6 +256,24 @@ class TestDatasourceProviderService:
):
service.get_datasource_credentials("t1", "prov", "org/plug")
def test_should_include_provider_name_when_refresh_fails(self, service, mock_db_session, mock_user):
"""A failing OAuth refresh is raised as ValueError whose message names the provider."""
p = MagicMock(spec=DatasourceProvider)
p.id = "cred-id"
p.name = "Credential"
p.auth_type = "oauth2"
p.expires_at = 0
p.encrypted_credentials = {"tok": "x"}
mock_db_session.scalar.return_value = p
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch("services.datasource_provider_service.OAuthHandler") as oauth_handler,
patch.object(service, "get_oauth_client", return_value={"oc": "v"}),
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"tok": "plain"}),
):
# Make the patched refresh call fail so the error path is exercised.
oauth_handler.return_value.refresh_credentials.side_effect = RuntimeError("token endpoint failed")
with pytest.raises(ValueError, match="provider prov"):
service.get_datasource_credentials("t1", "prov", "org/plug")
def test_should_return_decrypted_credentials_when_api_key_not_expired(self, service, mock_db_session, mock_user):
"""API key credentials with expires_at=-1 skip refresh and return directly."""
p = MagicMock(spec=DatasourceProvider)
@ -307,6 +325,51 @@ class TestDatasourceProviderService:
result = service.get_all_datasource_credentials_by_provider("t1", "prov", "org/plug")
assert len(result) == 1
def test_should_skip_failed_provider_when_refreshing_all_credentials(
self, service, mock_db_session, mock_user, caplog
):
"""A credential whose refresh raises is skipped and logged; the other credential is still returned."""
failed_provider = MagicMock(spec=DatasourceProvider)
failed_provider.id = "failed-cred"
failed_provider.name = "Failed"
failed_provider.auth_type = "oauth2"
failed_provider.expires_at = 0
working_provider = MagicMock(spec=DatasourceProvider)
working_provider.id = "working-cred"
working_provider.name = "Working"
working_provider.auth_type = "oauth2"
working_provider.expires_at = 0
mock_db_session.scalars.return_value.all.return_value = [failed_provider, working_provider]
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
# First refresh call raises, second returns (encrypted credentials, expires_at).
patch.object(
service,
"_refresh_datasource_credentials",
side_effect=[ValueError("refresh failed"), ({"t": "enc"}, 9999)],
) as refresh_credentials,
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"t": "plain"}),
):
result = service.get_all_datasource_credentials_by_provider("t1", "prov", "org/plug")
assert result == [{"t": "plain"}]
assert refresh_credentials.call_count == 2
assert "Skipping datasource credentials for provider prov" in caplog.text
def test_should_return_valid_credentials_without_refresh_when_getting_all_credentials(
self, service, mock_db_session, mock_user
):
"""A credential with expires_at=-1 is returned decrypted without calling the refresh helper."""
p = MagicMock(spec=DatasourceProvider)
p.auth_type = "oauth2"
p.expires_at = -1
p.encrypted_credentials = {"t": "x"}
mock_db_session.scalars.return_value.all.return_value = [p]
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service, "_refresh_datasource_credentials") as refresh_credentials,
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"t": "plain"}),
):
result = service.get_all_datasource_credentials_by_provider("t1", "prov", "org/plug")
assert result == [{"t": "plain"}]
refresh_credentials.assert_not_called()
# -----------------------------------------------------------------------
# update_datasource_provider_name (lines 236-303)
# -----------------------------------------------------------------------