mirror of https://github.com/langgenius/dify.git
feat: refactor datasource authentication APIs for improved credential management
This commit is contained in:
parent
57b48f51b5
commit
17da96bdd8
|
|
@ -149,33 +149,40 @@ class DatasourceAuth(Resource):
|
|||
)
|
||||
return {"result": datasources}, 200
|
||||
|
||||
class DatasourceAuthDeleteApi(Resource):
|
||||
|
||||
class DatasourceAuthUpdateDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider_id: str, auth_id: str):
|
||||
def post(self, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
provider_name = datasource_provider_id.provider_name
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.remove_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
auth_id=auth_id,
|
||||
auth_id=args["credential_id"],
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
class DatasourceAuthUpdateApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, provider_id: str, auth_id: str):
|
||||
def post(self, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
|
@ -183,10 +190,11 @@ class DatasourceAuthUpdateDeleteApi(Resource):
|
|||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.update_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
auth_id=auth_id,
|
||||
auth_id=args["credential_id"],
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
credentials=args["credentials"],
|
||||
credentials=args.get("credentials", {}),
|
||||
name=args.get("name", None),
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
|
@ -228,6 +236,17 @@ class DatasourceAuthOauthCustomClient(Resource):
|
|||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.remove_oauth_custom_client_params(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
class DatasourceAuthDefaultApi(Resource):
|
||||
@setup_required
|
||||
|
|
@ -237,14 +256,14 @@ class DatasourceAuthDefaultApi(Resource):
|
|||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.set_default_datasource_provider(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
credential_id=args["credential_id"],
|
||||
credential_id=args["id"],
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
|
@ -284,8 +303,13 @@ api.add_resource(
|
|||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceAuthUpdateDeleteApi,
|
||||
"/auth/plugin/datasource/<path:provider_id>/<string:auth_id>",
|
||||
DatasourceAuthUpdateApi,
|
||||
"/auth/plugin/datasource/<path:provider_id>/update",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceAuthDeleteApi,
|
||||
"/auth/plugin/datasource/<path:provider_id>/delete",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from typing import Any
|
|||
from flask_login import current_user
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||
from core.helper import encrypter
|
||||
from core.helper.name_generator import generate_incremental_name
|
||||
|
|
@ -29,6 +30,18 @@ class DatasourceProviderService:
|
|||
def __init__(self) -> None:
|
||||
self.provider_manager = PluginDatasourceManager()
|
||||
|
||||
def remove_oauth_custom_client_params(self, tenant_id: str, datasource_provider_id: DatasourceProviderID):
|
||||
"""
|
||||
remove oauth custom client params
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
session.query(DatasourceOauthTenantParamConfig).filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
).delete()
|
||||
session.commit()
|
||||
|
||||
def get_default_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> dict[str, Any]:
|
||||
"""
|
||||
get default credentials
|
||||
|
|
@ -512,6 +525,10 @@ class DatasourceProviderService:
|
|||
credentials = self.get_datasource_credentials(
|
||||
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
|
||||
)
|
||||
redirect_uri = (
|
||||
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/"
|
||||
f"{datasource_provider_id}/datasource/callback"
|
||||
)
|
||||
datasource_credentials.append(
|
||||
{
|
||||
"provider": datasource.provider,
|
||||
|
|
@ -542,6 +559,7 @@ class DatasourceProviderService:
|
|||
tenant_id, datasource_provider_id
|
||||
),
|
||||
"is_system_oauth_params_exists": self.is_system_oauth_params_exist(datasource_provider_id),
|
||||
"redirect_uri": redirect_uri
|
||||
}
|
||||
if datasource.declaration.oauth_schema
|
||||
else None,
|
||||
|
|
@ -594,38 +612,50 @@ class DatasourceProviderService:
|
|||
return copy_credentials_list
|
||||
|
||||
def update_datasource_credentials(
|
||||
self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict
|
||||
self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict | None, name: str | None
|
||||
) -> None:
|
||||
"""
|
||||
update datasource credentials.
|
||||
"""
|
||||
credential_valid = self.provider_manager.validate_provider_credentials(
|
||||
tenant_id=tenant_id,
|
||||
user_id=current_user.id,
|
||||
provider=provider,
|
||||
plugin_id=plugin_id,
|
||||
credentials=credentials,
|
||||
)
|
||||
if credential_valid:
|
||||
# Get all provider configurations of the current workspace
|
||||
with Session(db.engine) as session:
|
||||
datasource_provider = (
|
||||
db.session.query(DatasourceProvider)
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not datasource_provider:
|
||||
raise ValueError("Datasource provider not found")
|
||||
else:
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=f"{plugin_id}/{provider}",
|
||||
credential_type=datasource_provider.auth_type,
|
||||
)
|
||||
# update name
|
||||
if name and name != datasource_provider.name:
|
||||
if (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(tenant_id=tenant_id, name=name, provider=provider, plugin_id=plugin_id)
|
||||
.count()
|
||||
> 0
|
||||
):
|
||||
raise ValueError("name is already exists")
|
||||
datasource_provider.name = name
|
||||
|
||||
# update credentials
|
||||
if credentials:
|
||||
try:
|
||||
self.provider_manager.validate_provider_credentials(
|
||||
tenant_id=tenant_id,
|
||||
user_id=current_user.id,
|
||||
provider=provider,
|
||||
plugin_id=plugin_id,
|
||||
credentials=credentials,
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to validate credentials: {str(e)}")
|
||||
|
||||
original_credentials = datasource_provider.encrypted_credentials
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
# if send [__HIDDEN__] in secret input, it will be same as original value
|
||||
if key in self.extract_secret_variables(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=f"{plugin_id}/{provider}",
|
||||
credential_type=datasource_provider.auth_type,
|
||||
):
|
||||
if value == HIDDEN_VALUE and key in original_credentials:
|
||||
original_value = encrypter.encrypt_token(tenant_id, original_credentials[key])
|
||||
credentials[key] = encrypter.encrypt_token(tenant_id, original_value)
|
||||
|
|
@ -633,9 +663,7 @@ class DatasourceProviderService:
|
|||
credentials[key] = encrypter.encrypt_token(tenant_id, value)
|
||||
|
||||
datasource_provider.encrypted_credentials = credentials
|
||||
db.session.commit()
|
||||
else:
|
||||
raise CredentialsValidateFailedError()
|
||||
session.commit()
|
||||
|
||||
def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None:
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in New Issue