This commit is contained in:
jyong 2025-05-30 00:37:27 +08:00
parent cb5cfb2dae
commit 69529fb16d
6 changed files with 49 additions and 15 deletions

View File

@ -108,6 +108,18 @@ class DatasourceAuth(Resource):
raise ValueError(str(ex))
return {"result": "success"}, 201
@setup_required
@login_required
@account_initialization_required
def get(self, provider, plugin_id):
datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id,
provider=provider,
plugin_id=plugin_id
)
return {"result": datasources}, 200
class DatasourceAuthDeleteApi(Resource):
@setup_required

View File

@ -52,6 +52,7 @@ class PluginDatasourceProviderEntity(BaseModel):
provider: str
plugin_unique_identifier: str
plugin_id: str
is_authorized: bool = False
declaration: DatasourceProviderEntityWithPlugin

View File

@ -56,6 +56,7 @@ from .model import (
TraceAppConfig,
UploadFile,
)
from .oauth import DatasourceOauthParamConfig, DatasourceProvider
from .provider import (
LoadBalancingModelConfig,
Provider,
@ -123,6 +124,8 @@ __all__ = [
"DatasetProcessRule",
"DatasetQuery",
"DatasetRetrieverResource",
"DatasourceOauthParamConfig",
"DatasourceProvider",
"DifySetup",
"Document",
"DocumentSegment",

View File

@ -28,6 +28,7 @@ class DatasourceProvider(Base):
db.UniqueConstraint("plugin_id", "provider", name="datasource_provider_plugin_id_provider_idx"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
plugin_id: Mapped[str] = db.Column(StringUUID, nullable=False)
provider: Mapped[str] = db.Column(db.String(255), nullable=False)
auth_type: Mapped[str] = db.Column(db.String(255), nullable=False)

View File

@ -92,7 +92,7 @@ class DatasourceProviderService:
return secret_input_form_variables
def get_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> Optional[dict]:
def get_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
"""
get datasource credentials.
@ -102,22 +102,30 @@ class DatasourceProviderService:
:return:
"""
# Get all provider configurations of the current workspace
datasource_provider: DatasourceProvider | None = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
datasource_providers: list[DatasourceProvider] = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id).first()
if not datasource_provider:
return None
encrypted_credentials = datasource_provider.encrypted_credentials
# Get provider credential secret variables
credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider)
plugin_id=plugin_id).all()
if not datasource_providers:
return []
copy_credentials_list = []
for datasource_provider in datasource_providers:
encrypted_credentials = datasource_provider.encrypted_credentials
# Get provider credential secret variables
credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider)
# Obfuscate provider credentials
copy_credentials = encrypted_credentials.copy()
for key, value in copy_credentials.items():
if key in credential_secret_variables:
copy_credentials[key] = encrypter.obfuscated_token(value)
# Obfuscate provider credentials
copy_credentials = encrypted_credentials.copy()
for key, value in copy_credentials.items():
if key in credential_secret_variables:
copy_credentials[key] = encrypter.obfuscated_token(value)
copy_credentials_list.append(
{
"credentials": copy_credentials,
"type": datasource_provider.auth_type,
}
)
return copy_credentials
return copy_credentials_list
def remove_datasource_credentials(self,

View File

@ -1,5 +1,6 @@
from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity
from core.plugin.impl.datasource import PluginDatasourceManager
from services.datasource_provider_service import DatasourceProviderService
class RagPipelineManageService:
@ -11,4 +12,12 @@ class RagPipelineManageService:
# get all builtin providers
manager = PluginDatasourceManager()
return manager.fetch_datasource_providers(tenant_id)
datasources = manager.fetch_datasource_providers(tenant_id)
for datasource in datasources:
datasource_provider_service = DatasourceProviderService()
credentials = datasource_provider_service.get_datasource_credentials(tenant_id=tenant_id,
provider=datasource.provider,
plugin_id=datasource.plugin_id)
if credentials:
datasource.is_authorized = True
return datasources