mirror of https://github.com/langgenius/dify.git
This commit is contained in:
parent
cb5cfb2dae
commit
69529fb16d
|
|
@ -108,6 +108,18 @@ class DatasourceAuth(Resource):
|
|||
raise ValueError(str(ex))
|
||||
|
||||
return {"result": "success"}, 201
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider, plugin_id):
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasources = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=provider,
|
||||
plugin_id=plugin_id
|
||||
)
|
||||
return {"result": datasources}, 200
|
||||
|
||||
class DatasourceAuthDeleteApi(Resource):
|
||||
@setup_required
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ class PluginDatasourceProviderEntity(BaseModel):
|
|||
provider: str
|
||||
plugin_unique_identifier: str
|
||||
plugin_id: str
|
||||
is_authorized: bool = False
|
||||
declaration: DatasourceProviderEntityWithPlugin
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -56,6 +56,7 @@ from .model import (
|
|||
TraceAppConfig,
|
||||
UploadFile,
|
||||
)
|
||||
from .oauth import DatasourceOauthParamConfig, DatasourceProvider
|
||||
from .provider import (
|
||||
LoadBalancingModelConfig,
|
||||
Provider,
|
||||
|
|
@ -123,6 +124,8 @@ __all__ = [
|
|||
"DatasetProcessRule",
|
||||
"DatasetQuery",
|
||||
"DatasetRetrieverResource",
|
||||
"DatasourceOauthParamConfig",
|
||||
"DatasourceProvider",
|
||||
"DifySetup",
|
||||
"Document",
|
||||
"DocumentSegment",
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ class DatasourceProvider(Base):
|
|||
db.UniqueConstraint("plugin_id", "provider", name="datasource_provider_plugin_id_provider_idx"),
|
||||
)
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
plugin_id: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||
provider: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
auth_type: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
|
|
|
|||
|
|
@ -92,7 +92,7 @@ class DatasourceProviderService:
|
|||
return secret_input_form_variables
|
||||
|
||||
|
||||
def get_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> Optional[dict]:
|
||||
def get_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
|
||||
"""
|
||||
get datasource credentials.
|
||||
|
||||
|
|
@ -102,22 +102,30 @@ class DatasourceProviderService:
|
|||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
datasource_provider: DatasourceProvider | None = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
|
||||
datasource_providers: list[DatasourceProvider] = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
plugin_id=plugin_id).first()
|
||||
if not datasource_provider:
|
||||
return None
|
||||
encrypted_credentials = datasource_provider.encrypted_credentials
|
||||
# Get provider credential secret variables
|
||||
credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider)
|
||||
plugin_id=plugin_id).all()
|
||||
if not datasource_providers:
|
||||
return []
|
||||
copy_credentials_list = []
|
||||
for datasource_provider in datasource_providers:
|
||||
encrypted_credentials = datasource_provider.encrypted_credentials
|
||||
# Get provider credential secret variables
|
||||
credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider)
|
||||
|
||||
# Obfuscate provider credentials
|
||||
copy_credentials = encrypted_credentials.copy()
|
||||
for key, value in copy_credentials.items():
|
||||
if key in credential_secret_variables:
|
||||
copy_credentials[key] = encrypter.obfuscated_token(value)
|
||||
# Obfuscate provider credentials
|
||||
copy_credentials = encrypted_credentials.copy()
|
||||
for key, value in copy_credentials.items():
|
||||
if key in credential_secret_variables:
|
||||
copy_credentials[key] = encrypter.obfuscated_token(value)
|
||||
copy_credentials_list.append(
|
||||
{
|
||||
"credentials": copy_credentials,
|
||||
"type": datasource_provider.auth_type,
|
||||
}
|
||||
)
|
||||
|
||||
return copy_credentials
|
||||
return copy_credentials_list
|
||||
|
||||
|
||||
def remove_datasource_credentials(self,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
|
||||
|
||||
class RagPipelineManageService:
|
||||
|
|
@ -11,4 +12,12 @@ class RagPipelineManageService:
|
|||
|
||||
# get all builtin providers
|
||||
manager = PluginDatasourceManager()
|
||||
return manager.fetch_datasource_providers(tenant_id)
|
||||
datasources = manager.fetch_datasource_providers(tenant_id)
|
||||
for datasource in datasources:
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credentials = datasource_provider_service.get_datasource_credentials(tenant_id=tenant_id,
|
||||
provider=datasource.provider,
|
||||
plugin_id=datasource.plugin_id)
|
||||
if credentials:
|
||||
datasource.is_authorized = True
|
||||
return datasources
|
||||
|
|
|
|||
Loading…
Reference in New Issue