feat: refactor datasource authentication APIs for improved credential management

This commit is contained in:
Harry 2025-07-21 16:43:50 +08:00
parent 57b48f51b5
commit 17da96bdd8
2 changed files with 86 additions and 34 deletions

View File

@ -149,33 +149,40 @@ class DatasourceAuth(Resource):
)
return {"result": datasources}, 200
class DatasourceAuthDeleteApi(Resource):
class DatasourceAuthUpdateDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self, provider_id: str, auth_id: str):
def post(self, provider_id: str):
datasource_provider_id = DatasourceProviderID(provider_id)
plugin_id = datasource_provider_id.plugin_id
provider_name = datasource_provider_id.provider_name
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_datasource_credentials(
tenant_id=current_user.current_tenant_id,
auth_id=auth_id,
auth_id=args["credential_id"],
provider=provider_name,
plugin_id=plugin_id,
)
return {"result": "success"}, 200
class DatasourceAuthUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def patch(self, provider_id: str, auth_id: str):
def post(self, provider_id: str):
datasource_provider_id = DatasourceProviderID(provider_id)
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
if not current_user.is_editor:
raise Forbidden()
@ -183,10 +190,11 @@ class DatasourceAuthUpdateDeleteApi(Resource):
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_credentials(
tenant_id=current_user.current_tenant_id,
auth_id=auth_id,
auth_id=args["credential_id"],
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
credentials=args["credentials"],
credentials=args.get("credentials", {}),
name=args.get("name", None),
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
@ -228,6 +236,17 @@ class DatasourceAuthOauthCustomClient(Resource):
)
return {"result": "success"}, 200
@setup_required
@login_required
@account_initialization_required
def delete(self, provider_id: str):
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_oauth_custom_client_params(
tenant_id=current_user.current_tenant_id,
datasource_provider_id=datasource_provider_id,
)
return {"result": "success"}, 200
class DatasourceAuthDefaultApi(Resource):
@setup_required
@ -237,14 +256,14 @@ class DatasourceAuthDefaultApi(Resource):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.set_default_datasource_provider(
tenant_id=current_user.current_tenant_id,
datasource_provider_id=datasource_provider_id,
credential_id=args["credential_id"],
credential_id=args["id"],
)
return {"result": "success"}, 200
@ -284,8 +303,13 @@ api.add_resource(
)
api.add_resource(
DatasourceAuthUpdateDeleteApi,
"/auth/plugin/datasource/<path:provider_id>/<string:auth_id>",
DatasourceAuthUpdateApi,
"/auth/plugin/datasource/<path:provider_id>/update",
)
api.add_resource(
DatasourceAuthDeleteApi,
"/auth/plugin/datasource/<path:provider_id>/delete",
)
api.add_resource(

View File

@ -4,6 +4,7 @@ from typing import Any
from flask_login import current_user
from sqlalchemy.orm import Session
from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper import encrypter
from core.helper.name_generator import generate_incremental_name
@ -29,6 +30,18 @@ class DatasourceProviderService:
def __init__(self) -> None:
self.provider_manager = PluginDatasourceManager()
def remove_oauth_custom_client_params(self, tenant_id: str, datasource_provider_id: DatasourceProviderID):
"""
remove oauth custom client params
"""
with Session(db.engine) as session:
session.query(DatasourceOauthTenantParamConfig).filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
).delete()
session.commit()
def get_default_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> dict[str, Any]:
"""
get default credentials
@ -512,6 +525,10 @@ class DatasourceProviderService:
credentials = self.get_datasource_credentials(
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
)
redirect_uri = (
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/"
f"{datasource_provider_id}/datasource/callback"
)
datasource_credentials.append(
{
"provider": datasource.provider,
@ -542,6 +559,7 @@ class DatasourceProviderService:
tenant_id, datasource_provider_id
),
"is_system_oauth_params_exists": self.is_system_oauth_params_exist(datasource_provider_id),
"redirect_uri": redirect_uri
}
if datasource.declaration.oauth_schema
else None,
@ -594,38 +612,50 @@ class DatasourceProviderService:
return copy_credentials_list
def update_datasource_credentials(
self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict
self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict | None, name: str | None
) -> None:
"""
update datasource credentials.
"""
credential_valid = self.provider_manager.validate_provider_credentials(
tenant_id=tenant_id,
user_id=current_user.id,
provider=provider,
plugin_id=plugin_id,
credentials=credentials,
)
if credential_valid:
# Get all provider configurations of the current workspace
with Session(db.engine) as session:
datasource_provider = (
db.session.query(DatasourceProvider)
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
.first()
)
if not datasource_provider:
raise ValueError("Datasource provider not found")
else:
provider_credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id,
provider_id=f"{plugin_id}/{provider}",
credential_type=datasource_provider.auth_type,
)
# update name
if name and name != datasource_provider.name:
if (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, name=name, provider=provider, plugin_id=plugin_id)
.count()
> 0
):
raise ValueError("name is already exists")
datasource_provider.name = name
# update credentials
if credentials:
try:
self.provider_manager.validate_provider_credentials(
tenant_id=tenant_id,
user_id=current_user.id,
provider=provider,
plugin_id=plugin_id,
credentials=credentials,
)
except Exception as e:
raise ValueError(f"Failed to validate credentials: {str(e)}")
original_credentials = datasource_provider.encrypted_credentials
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if send [__HIDDEN__] in secret input, it will be same as original value
if key in self.extract_secret_variables(
tenant_id=tenant_id,
provider_id=f"{plugin_id}/{provider}",
credential_type=datasource_provider.auth_type,
):
if value == HIDDEN_VALUE and key in original_credentials:
original_value = encrypter.encrypt_token(tenant_id, original_credentials[key])
credentials[key] = encrypter.encrypt_token(tenant_id, original_value)
@ -633,9 +663,7 @@ class DatasourceProviderService:
credentials[key] = encrypter.encrypt_token(tenant_id, value)
datasource_provider.encrypted_credentials = credentials
db.session.commit()
else:
raise CredentialsValidateFailedError()
session.commit()
def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None:
"""