mirror of
https://github.com/langgenius/dify.git
synced 2026-06-07 16:32:01 +08:00
fix(api): gracefully handle credential fetch failures in rag pipeline (#36165)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
d9ccfcbc6e
commit
e660d7af38
@ -49,6 +49,94 @@ class DatasourceProviderService:
|
||||
def __init__(self) -> None:
|
||||
self.provider_manager = PluginDatasourceManager()
|
||||
|
||||
@staticmethod
|
||||
def _should_refresh_credentials(datasource_provider: DatasourceProvider, now: int | None = None) -> bool:
|
||||
current_time = int(time.time()) if now is None else now
|
||||
if datasource_provider.expires_at == -1:
|
||||
return False
|
||||
return (datasource_provider.expires_at - 60) < current_time
|
||||
|
||||
def _refresh_datasource_credentials(
|
||||
self,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
plugin_id: str,
|
||||
datasource_provider: DatasourceProvider,
|
||||
current_user: Any,
|
||||
) -> tuple[dict[str, Any], int]:
|
||||
datasource_provider_id = DatasourceProviderID(f"{plugin_id}/{provider}")
|
||||
provider_name = datasource_provider_id.provider_name
|
||||
credential_id = getattr(datasource_provider, "id", None)
|
||||
credential_name = getattr(datasource_provider, "name", None)
|
||||
logger.info(
|
||||
"Refreshing datasource credentials for provider %s",
|
||||
provider_name,
|
||||
extra={
|
||||
"tenant_id": tenant_id,
|
||||
"plugin_id": datasource_provider_id.plugin_id,
|
||||
"provider": provider_name,
|
||||
"credential_id": credential_id,
|
||||
"credential_name": credential_name,
|
||||
"expires_at": datasource_provider.expires_at,
|
||||
},
|
||||
)
|
||||
decrypted_credentials = self.decrypt_datasource_provider_credentials(
|
||||
tenant_id=tenant_id,
|
||||
datasource_provider=datasource_provider,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider,
|
||||
)
|
||||
redirect_uri = (
|
||||
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback"
|
||||
)
|
||||
system_credentials = self.get_oauth_client(tenant_id, datasource_provider_id)
|
||||
try:
|
||||
refreshed_credentials = OAuthHandler().refresh_credentials(
|
||||
tenant_id=tenant_id,
|
||||
user_id=current_user.id,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=system_credentials or {},
|
||||
credentials=decrypted_credentials,
|
||||
)
|
||||
except Exception as exc:
|
||||
message = (
|
||||
f"Failed to refresh datasource credentials for provider {provider_name}"
|
||||
f" (credential: {credential_name or credential_id or 'unknown'})"
|
||||
)
|
||||
logger.exception(
|
||||
message,
|
||||
extra={
|
||||
"tenant_id": tenant_id,
|
||||
"plugin_id": datasource_provider_id.plugin_id,
|
||||
"provider": provider_name,
|
||||
"credential_id": credential_id,
|
||||
"credential_name": credential_name,
|
||||
},
|
||||
)
|
||||
raise ValueError(f"{message}: {exc}") from exc
|
||||
encrypted_credentials = self.encrypt_datasource_provider_credentials(
|
||||
tenant_id=tenant_id,
|
||||
raw_credentials=refreshed_credentials.credentials,
|
||||
provider=provider,
|
||||
plugin_id=plugin_id,
|
||||
datasource_provider=datasource_provider,
|
||||
)
|
||||
logger.info(
|
||||
"Refreshed datasource credentials for provider %s",
|
||||
provider_name,
|
||||
extra={
|
||||
"tenant_id": tenant_id,
|
||||
"plugin_id": datasource_provider_id.plugin_id,
|
||||
"provider": provider_name,
|
||||
"credential_id": credential_id,
|
||||
"credential_name": credential_name,
|
||||
"expires_at": refreshed_credentials.expires_at,
|
||||
},
|
||||
)
|
||||
return encrypted_credentials, refreshed_credentials.expires_at
|
||||
|
||||
def remove_oauth_custom_client_params(self, tenant_id: str, datasource_provider_id: DatasourceProviderID):
|
||||
"""
|
||||
remove oauth custom client params
|
||||
@ -108,7 +196,10 @@ class DatasourceProviderService:
|
||||
credential_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
get credential by id
|
||||
Return decrypted datasource credentials.
|
||||
|
||||
If the stored credential is expired or about to expire, this method refreshes
|
||||
it through plugin-daemon and persists the refreshed credential before returning.
|
||||
"""
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
if credential_id:
|
||||
@ -130,39 +221,17 @@ class DatasourceProviderService:
|
||||
)
|
||||
if not datasource_provider:
|
||||
return {}
|
||||
# refresh the credentials
|
||||
if datasource_provider.expires_at != -1 and (datasource_provider.expires_at - 60) < int(time.time()):
|
||||
if self._should_refresh_credentials(datasource_provider):
|
||||
current_user = get_current_user()
|
||||
decrypted_credentials = self.decrypt_datasource_provider_credentials(
|
||||
encrypted_credentials, expires_at = self._refresh_datasource_credentials(
|
||||
tenant_id=tenant_id,
|
||||
datasource_provider=datasource_provider,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider,
|
||||
)
|
||||
datasource_provider_id = DatasourceProviderID(f"{plugin_id}/{provider}")
|
||||
provider_name = datasource_provider_id.provider_name
|
||||
redirect_uri = (
|
||||
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/"
|
||||
f"{datasource_provider_id}/datasource/callback"
|
||||
)
|
||||
system_credentials = self.get_oauth_client(tenant_id, datasource_provider_id)
|
||||
refreshed_credentials = OAuthHandler().refresh_credentials(
|
||||
tenant_id=tenant_id,
|
||||
user_id=current_user.id,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=system_credentials or {},
|
||||
credentials=decrypted_credentials,
|
||||
)
|
||||
datasource_provider.encrypted_credentials = self.encrypt_datasource_provider_credentials(
|
||||
tenant_id=tenant_id,
|
||||
raw_credentials=refreshed_credentials.credentials,
|
||||
provider=provider,
|
||||
plugin_id=plugin_id,
|
||||
datasource_provider=datasource_provider,
|
||||
current_user=current_user,
|
||||
)
|
||||
datasource_provider.expires_at = refreshed_credentials.expires_at
|
||||
datasource_provider.encrypted_credentials = encrypted_credentials
|
||||
datasource_provider.expires_at = expires_at
|
||||
|
||||
return self.decrypt_datasource_provider_credentials(
|
||||
tenant_id=tenant_id,
|
||||
@ -178,7 +247,10 @@ class DatasourceProviderService:
|
||||
plugin_id: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
get all datasource credentials by provider
|
||||
Return all decrypted datasource credentials for a provider.
|
||||
|
||||
Expired credentials are refreshed independently. A failed credential refresh is
|
||||
logged and skipped so one broken authorization does not block other credentials.
|
||||
"""
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
datasource_providers = session.scalars(
|
||||
@ -193,46 +265,39 @@ class DatasourceProviderService:
|
||||
if not datasource_providers:
|
||||
return []
|
||||
current_user = get_current_user()
|
||||
# refresh the credentials
|
||||
real_credentials_list = []
|
||||
for datasource_provider in datasource_providers:
|
||||
decrypted_credentials = self.decrypt_datasource_provider_credentials(
|
||||
tenant_id=tenant_id,
|
||||
datasource_provider=datasource_provider,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider,
|
||||
)
|
||||
datasource_provider_id = DatasourceProviderID(f"{plugin_id}/{provider}")
|
||||
provider_name = datasource_provider_id.provider_name
|
||||
redirect_uri = (
|
||||
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/"
|
||||
f"{datasource_provider_id}/datasource/callback"
|
||||
)
|
||||
system_credentials = self.get_oauth_client(tenant_id, datasource_provider_id)
|
||||
refreshed_credentials = OAuthHandler().refresh_credentials(
|
||||
tenant_id=tenant_id,
|
||||
user_id=current_user.id,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=system_credentials or {},
|
||||
credentials=decrypted_credentials,
|
||||
)
|
||||
datasource_provider.encrypted_credentials = self.encrypt_datasource_provider_credentials(
|
||||
tenant_id=tenant_id,
|
||||
raw_credentials=refreshed_credentials.credentials,
|
||||
provider=provider,
|
||||
plugin_id=plugin_id,
|
||||
datasource_provider=datasource_provider,
|
||||
)
|
||||
datasource_provider.expires_at = refreshed_credentials.expires_at
|
||||
real_credentials = self.decrypt_datasource_provider_credentials(
|
||||
tenant_id=tenant_id,
|
||||
datasource_provider=datasource_provider,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider,
|
||||
)
|
||||
real_credentials_list.append(real_credentials)
|
||||
try:
|
||||
if self._should_refresh_credentials(datasource_provider):
|
||||
encrypted_credentials, expires_at = self._refresh_datasource_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
plugin_id=plugin_id,
|
||||
datasource_provider=datasource_provider,
|
||||
current_user=current_user,
|
||||
)
|
||||
datasource_provider.encrypted_credentials = encrypted_credentials
|
||||
datasource_provider.expires_at = expires_at
|
||||
real_credentials = self.decrypt_datasource_provider_credentials(
|
||||
tenant_id=tenant_id,
|
||||
datasource_provider=datasource_provider,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider,
|
||||
)
|
||||
real_credentials_list.append(real_credentials)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Skipping datasource credentials for provider %s after refresh or decrypt failure",
|
||||
provider,
|
||||
extra={
|
||||
"tenant_id": tenant_id,
|
||||
"plugin_id": plugin_id,
|
||||
"provider": provider,
|
||||
"credential_id": getattr(datasource_provider, "id", None),
|
||||
"credential_name": getattr(datasource_provider, "name", None),
|
||||
"expires_at": getattr(datasource_provider, "expires_at", None),
|
||||
},
|
||||
)
|
||||
|
||||
return real_credentials_list
|
||||
|
||||
|
||||
@ -1,7 +1,11 @@
|
||||
import logging
|
||||
|
||||
from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RagPipelineManageService:
|
||||
@staticmethod
|
||||
@ -15,9 +19,21 @@ class RagPipelineManageService:
|
||||
datasources = manager.fetch_datasource_providers(tenant_id)
|
||||
for datasource in datasources:
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credentials = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
|
||||
)
|
||||
if credentials:
|
||||
datasource.is_authorized = True
|
||||
try:
|
||||
credentials = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
|
||||
)
|
||||
if credentials:
|
||||
datasource.is_authorized = True
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Skipping datasource credentials for provider %s after refresh or decrypt failure",
|
||||
datasource.provider,
|
||||
extra={
|
||||
"tenant_id": tenant_id,
|
||||
"plugin_id": datasource.plugin_id,
|
||||
"provider": datasource.provider,
|
||||
},
|
||||
)
|
||||
|
||||
return datasources
|
||||
|
||||
@ -243,7 +243,7 @@ class TestDatasourceProviderService:
|
||||
assert service.get_datasource_credentials("t1", "prov", "org/plug") == {}
|
||||
|
||||
def test_should_refresh_oauth_tokens_when_expired(self, service, mock_db_session, mock_user):
|
||||
"""Expired OAuth credential (expires_at near zero) triggers a silent refresh."""
|
||||
"""Expired OAuth credential (expires_at near zero) triggers a refresh."""
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
p.auth_type = "oauth2"
|
||||
p.expires_at = 0 # expired
|
||||
@ -256,6 +256,24 @@ class TestDatasourceProviderService:
|
||||
):
|
||||
service.get_datasource_credentials("t1", "prov", "org/plug")
|
||||
|
||||
def test_should_include_provider_name_when_refresh_fails(self, service, mock_db_session, mock_user):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
p.id = "cred-id"
|
||||
p.name = "Credential"
|
||||
p.auth_type = "oauth2"
|
||||
p.expires_at = 0
|
||||
p.encrypted_credentials = {"tok": "x"}
|
||||
mock_db_session.scalar.return_value = p
|
||||
with (
|
||||
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
|
||||
patch("services.datasource_provider_service.OAuthHandler") as oauth_handler,
|
||||
patch.object(service, "get_oauth_client", return_value={"oc": "v"}),
|
||||
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"tok": "plain"}),
|
||||
):
|
||||
oauth_handler.return_value.refresh_credentials.side_effect = RuntimeError("token endpoint failed")
|
||||
with pytest.raises(ValueError, match="provider prov"):
|
||||
service.get_datasource_credentials("t1", "prov", "org/plug")
|
||||
|
||||
def test_should_return_decrypted_credentials_when_api_key_not_expired(self, service, mock_db_session, mock_user):
|
||||
"""API key credentials with expires_at=-1 skip refresh and return directly."""
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
@ -307,6 +325,51 @@ class TestDatasourceProviderService:
|
||||
result = service.get_all_datasource_credentials_by_provider("t1", "prov", "org/plug")
|
||||
assert len(result) == 1
|
||||
|
||||
def test_should_skip_failed_provider_when_refreshing_all_credentials(
|
||||
self, service, mock_db_session, mock_user, caplog
|
||||
):
|
||||
failed_provider = MagicMock(spec=DatasourceProvider)
|
||||
failed_provider.id = "failed-cred"
|
||||
failed_provider.name = "Failed"
|
||||
failed_provider.auth_type = "oauth2"
|
||||
failed_provider.expires_at = 0
|
||||
working_provider = MagicMock(spec=DatasourceProvider)
|
||||
working_provider.id = "working-cred"
|
||||
working_provider.name = "Working"
|
||||
working_provider.auth_type = "oauth2"
|
||||
working_provider.expires_at = 0
|
||||
mock_db_session.scalars.return_value.all.return_value = [failed_provider, working_provider]
|
||||
with (
|
||||
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
|
||||
patch.object(
|
||||
service,
|
||||
"_refresh_datasource_credentials",
|
||||
side_effect=[ValueError("refresh failed"), ({"t": "enc"}, 9999)],
|
||||
) as refresh_credentials,
|
||||
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"t": "plain"}),
|
||||
):
|
||||
result = service.get_all_datasource_credentials_by_provider("t1", "prov", "org/plug")
|
||||
assert result == [{"t": "plain"}]
|
||||
assert refresh_credentials.call_count == 2
|
||||
assert "Skipping datasource credentials for provider prov" in caplog.text
|
||||
|
||||
def test_should_return_valid_credentials_without_refresh_when_getting_all_credentials(
|
||||
self, service, mock_db_session, mock_user
|
||||
):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
p.auth_type = "oauth2"
|
||||
p.expires_at = -1
|
||||
p.encrypted_credentials = {"t": "x"}
|
||||
mock_db_session.scalars.return_value.all.return_value = [p]
|
||||
with (
|
||||
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
|
||||
patch.object(service, "_refresh_datasource_credentials") as refresh_credentials,
|
||||
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"t": "plain"}),
|
||||
):
|
||||
result = service.get_all_datasource_credentials_by_provider("t1", "prov", "org/plug")
|
||||
assert result == [{"t": "plain"}]
|
||||
refresh_credentials.assert_not_called()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# update_datasource_provider_name (lines 236-303)
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
Loading…
Reference in New Issue
Block a user