mirror of https://github.com/langgenius/dify.git
feat: enhance datasource authentication by improving credential handling and updating API parameters
This commit is contained in:
parent
386d320650
commit
ba7f0b3004
|
|
@ -134,7 +134,7 @@ class DatasourceAuth(Resource):
|
|||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -149,8 +149,8 @@ class DatasourceAuth(Resource):
|
|||
)
|
||||
return {"result": datasources}, 200
|
||||
|
||||
class DatasourceAuthDeleteApi(Resource):
|
||||
|
||||
class DatasourceAuthDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -172,8 +172,8 @@ class DatasourceAuthDeleteApi(Resource):
|
|||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
class DatasourceAuthUpdateApi(Resource):
|
||||
|
||||
class DatasourceAuthUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -186,19 +186,15 @@ class DatasourceAuthUpdateApi(Resource):
|
|||
args = parser.parse_args()
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
try:
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.update_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
auth_id=args["credential_id"],
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
credentials=args.get("credentials", {}),
|
||||
name=args.get("name", None),
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.update_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
auth_id=args["credential_id"],
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
credentials=args.get("credentials", {}),
|
||||
name=args.get("name", None),
|
||||
)
|
||||
return {"result": "success"}, 201
|
||||
|
||||
|
||||
|
|
@ -223,7 +219,7 @@ class DatasourceAuthOauthCustomClient(Resource):
|
|||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
parser.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
|
||||
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
|
|
@ -231,7 +227,7 @@ class DatasourceAuthOauthCustomClient(Resource):
|
|||
tenant_id=current_user.current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
client_params=args.get("client_params", {}),
|
||||
enabled=args.get("enabled", False),
|
||||
enabled=args.get("enable_oauth_custom_client", False),
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
|
@ -247,6 +243,7 @@ class DatasourceAuthOauthCustomClient(Resource):
|
|||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
class DatasourceAuthDefaultApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -266,6 +263,7 @@ class DatasourceAuthDefaultApi(Resource):
|
|||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
class DatasourceUpdateProviderNameApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -329,4 +327,4 @@ api.add_resource(
|
|||
api.add_resource(
|
||||
DatasourceUpdateProviderNameApi,
|
||||
"/auth/plugin/datasource/<path:provider_id>/update-name",
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -496,9 +496,9 @@ class CredentialType(enum.StrEnum):
|
|||
@classmethod
|
||||
def of(cls, credential_type: str) -> "CredentialType":
|
||||
type_name = credential_type.lower()
|
||||
if type_name == "api-key":
|
||||
if type_name in {"api-key", "api_key"}:
|
||||
return cls.API_KEY
|
||||
elif type_name == "oauth2":
|
||||
elif type_name in {"oauth2", "oauth"}:
|
||||
return cls.OAUTH2
|
||||
else:
|
||||
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||
|
|
|
|||
|
|
@ -293,7 +293,6 @@ class DatasourceProviderService:
|
|||
tenant_id=tenant_id,
|
||||
provider=provider_id.provider_name,
|
||||
plugin_id=provider_id.plugin_id,
|
||||
auth_type=credential_type.value,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
|
@ -351,7 +350,7 @@ class DatasourceProviderService:
|
|||
)
|
||||
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=credential_type.value
|
||||
tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=credential_type
|
||||
)
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
|
|
@ -387,7 +386,7 @@ class DatasourceProviderService:
|
|||
provider_name = provider_id.provider_name
|
||||
plugin_id = provider_id.plugin_id
|
||||
with Session(db.engine) as session:
|
||||
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_api_key"
|
||||
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}"
|
||||
with redis_client.lock(lock, timeout=20):
|
||||
db_provider_name = name or self.generate_next_datasource_provider_name(
|
||||
session=session,
|
||||
|
|
@ -400,35 +399,36 @@ class DatasourceProviderService:
|
|||
if session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, name=db_provider_name).count() > 0:
|
||||
raise ValueError("Authorization name is already exists")
|
||||
|
||||
credential_valid = self.provider_manager.validate_provider_credentials(
|
||||
tenant_id=tenant_id,
|
||||
user_id=current_user.id,
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
credentials=credentials,
|
||||
)
|
||||
if credential_valid:
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type="api_key"
|
||||
)
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
# if send [__HIDDEN__] in secret input, it will be same as original value
|
||||
credentials[key] = encrypter.encrypt_token(tenant_id, value)
|
||||
datasource_provider = DatasourceProvider(
|
||||
try:
|
||||
self.provider_manager.validate_provider_credentials(
|
||||
tenant_id=tenant_id,
|
||||
name=db_provider_name,
|
||||
user_id=current_user.id,
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
auth_type="api_key",
|
||||
encrypted_credentials=credentials,
|
||||
credentials=credentials,
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
db.session.commit()
|
||||
else:
|
||||
raise CredentialsValidateFailedError()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to validate credentials: {str(e)}")
|
||||
|
||||
def extract_secret_variables(self, tenant_id: str, provider_id: str, credential_type: str) -> list[str]:
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=CredentialType.API_KEY
|
||||
)
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
# if send [__HIDDEN__] in secret input, it will be same as original value
|
||||
credentials[key] = encrypter.encrypt_token(tenant_id, value)
|
||||
datasource_provider = DatasourceProvider(
|
||||
tenant_id=tenant_id,
|
||||
name=db_provider_name,
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
auth_type=CredentialType.API_KEY.value,
|
||||
encrypted_credentials=credentials,
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
db.session.commit()
|
||||
|
||||
def extract_secret_variables(self, tenant_id: str, provider_id: str, credential_type: CredentialType) -> list[str]:
|
||||
"""
|
||||
Extract secret input form variables.
|
||||
|
||||
|
|
@ -439,9 +439,9 @@ class DatasourceProviderService:
|
|||
tenant_id=tenant_id, provider_id=provider_id
|
||||
)
|
||||
credential_form_schemas = []
|
||||
if credential_type == "api_key":
|
||||
if credential_type == CredentialType.API_KEY:
|
||||
credential_form_schemas = list(datasource_provider.declaration.credentials_schema)
|
||||
elif credential_type == "oauth2":
|
||||
elif credential_type == CredentialType.OAUTH2:
|
||||
if not datasource_provider.declaration.oauth_schema:
|
||||
raise ValueError("Datasource provider oauth schema not found")
|
||||
credential_form_schemas = list(datasource_provider.declaration.oauth_schema.credentials_schema)
|
||||
|
|
@ -489,7 +489,7 @@ class DatasourceProviderService:
|
|||
credential_secret_variables = self.extract_secret_variables(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=f"{plugin_id}/{provider}",
|
||||
credential_type=datasource_provider.auth_type,
|
||||
credential_type=CredentialType.of(datasource_provider.auth_type),
|
||||
)
|
||||
|
||||
# Obfuscate provider credentials
|
||||
|
|
@ -526,8 +526,7 @@ class DatasourceProviderService:
|
|||
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
|
||||
)
|
||||
redirect_uri = (
|
||||
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/"
|
||||
f"{datasource_provider_id}/datasource/callback"
|
||||
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback"
|
||||
)
|
||||
datasource_credentials.append(
|
||||
{
|
||||
|
|
@ -559,7 +558,7 @@ class DatasourceProviderService:
|
|||
tenant_id, datasource_provider_id
|
||||
),
|
||||
"is_system_oauth_params_exists": self.is_system_oauth_params_exist(datasource_provider_id),
|
||||
"redirect_uri": redirect_uri
|
||||
"redirect_uri": redirect_uri,
|
||||
}
|
||||
if datasource.declaration.oauth_schema
|
||||
else None,
|
||||
|
|
@ -594,7 +593,7 @@ class DatasourceProviderService:
|
|||
credential_secret_variables = self.extract_secret_variables(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=f"{plugin_id}/{provider}",
|
||||
credential_type=datasource_provider.auth_type,
|
||||
credential_type=CredentialType.of(datasource_provider.auth_type),
|
||||
)
|
||||
|
||||
# Obfuscate provider credentials
|
||||
|
|
@ -654,7 +653,7 @@ class DatasourceProviderService:
|
|||
if key in self.extract_secret_variables(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=f"{plugin_id}/{provider}",
|
||||
credential_type=datasource_provider.auth_type,
|
||||
credential_type=CredentialType.of(datasource_provider.auth_type),
|
||||
):
|
||||
if value == HIDDEN_VALUE and key in original_credentials:
|
||||
original_value = encrypter.encrypt_token(tenant_id, original_credentials[key])
|
||||
|
|
|
|||
Loading…
Reference in New Issue