mirror of https://github.com/langgenius/dify.git
feat(oauth): add credential validation for providers
This commit is contained in:
parent
0dc5bfb2c7
commit
ef330fec2c
|
|
@ -95,9 +95,7 @@ class BuiltinToolManageService:
|
|||
return entity
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_provider_credentials_schema(
|
||||
provider_name: str, credential_type: CredentialType, tenant_id: str
|
||||
):
|
||||
def list_builtin_provider_credentials_schema(provider_name: str, credential_type: CredentialType, tenant_id: str):
|
||||
"""
|
||||
list builtin provider credentials schema
|
||||
|
||||
|
|
@ -141,7 +139,8 @@ class BuiltinToolManageService:
|
|||
if key in masked_credentials and value == masked_credentials[key]:
|
||||
credentials[key] = original_credentials[key]
|
||||
|
||||
provider_controller.validate_credentials(user_id, credentials)
|
||||
if CredentialType.of(db_provider.credential_type).is_validate_allowed():
|
||||
provider_controller.validate_credentials(user_id, credentials)
|
||||
|
||||
# encrypt credentials
|
||||
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(credentials))
|
||||
|
|
@ -159,6 +158,7 @@ class BuiltinToolManageService:
|
|||
ToolNotFoundError,
|
||||
ToolProviderCredentialValidationError,
|
||||
) as e:
|
||||
db.session.rollback()
|
||||
raise ValueError(str(e))
|
||||
|
||||
return {"result": "success"}
|
||||
|
|
@ -176,46 +176,59 @@ class BuiltinToolManageService:
|
|||
add builtin tool provider
|
||||
"""
|
||||
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
|
||||
with redis_client.lock(lock, timeout=20):
|
||||
# check if the provider count is over the limit
|
||||
provider_count = (
|
||||
db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
|
||||
)
|
||||
if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__:
|
||||
raise ValueError(f"you have reached the maximum number of providers for {provider}")
|
||||
try:
|
||||
with redis_client.lock(lock, timeout=20):
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
if not provider_controller.need_credentials:
|
||||
raise ValueError(f"provider {provider} does not need credentials")
|
||||
|
||||
# TODO should we get name from oauth authentication?
|
||||
name = (
|
||||
name
|
||||
if name
|
||||
else BuiltinToolManageService.generate_builtin_tool_provider_name(
|
||||
tenant_id=tenant_id, provider=provider, credential_type=api_type
|
||||
provider_count = (
|
||||
db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
|
||||
)
|
||||
)
|
||||
|
||||
db_provider = BuiltinToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
encrypted_credentials=json.dumps(credentials),
|
||||
credential_type=api_type.value,
|
||||
name=name,
|
||||
)
|
||||
# check if the provider count is reached the limit
|
||||
if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__:
|
||||
raise ValueError(f"you have reached the maximum number of providers for {provider}")
|
||||
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
if not provider_controller.need_credentials:
|
||||
raise ValueError(f"provider {provider} does not need credentials")
|
||||
# validate credentials if allowed
|
||||
if CredentialType.of(api_type).is_validate_allowed():
|
||||
provider_controller.validate_credentials(user_id, credentials)
|
||||
|
||||
encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
|
||||
tenant_id, db_provider, provider, provider_controller
|
||||
)
|
||||
# generate name if not provided
|
||||
if name is None:
|
||||
name = BuiltinToolManageService.generate_builtin_tool_provider_name(
|
||||
tenant_id=tenant_id, provider=provider, credential_type=api_type
|
||||
)
|
||||
|
||||
# encrypt credentials
|
||||
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(credentials))
|
||||
# create encrypter
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[
|
||||
x.to_basic_provider_config()
|
||||
for x in provider_controller.get_credentials_schema_by_type(api_type)
|
||||
],
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
cache.delete()
|
||||
db.session.add(db_provider)
|
||||
db.session.commit()
|
||||
db_provider = BuiltinToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
|
||||
credential_type=api_type.value,
|
||||
name=name,
|
||||
)
|
||||
|
||||
db.session.add(db_provider)
|
||||
db.session.commit()
|
||||
except (
|
||||
PluginDaemonClientSideError,
|
||||
ToolProviderNotFoundError,
|
||||
ToolNotFoundError,
|
||||
ToolProviderCredentialValidationError,
|
||||
) as e:
|
||||
db.session.rollback()
|
||||
raise ValueError(str(e))
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -236,9 +249,7 @@ class BuiltinToolManageService:
|
|||
return encrypter, cache
|
||||
|
||||
@staticmethod
|
||||
def generate_builtin_tool_provider_name(
|
||||
tenant_id: str, provider: str, credential_type: CredentialType
|
||||
) -> str:
|
||||
def generate_builtin_tool_provider_name(tenant_id: str, provider: str, credential_type: CredentialType) -> str:
|
||||
try:
|
||||
db_providers = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
|
|
@ -324,7 +335,7 @@ class BuiltinToolManageService:
|
|||
is_oauth_custom_client_enabled=BuiltinToolManageService.is_oauth_custom_client_enabled(tenant_id, provider),
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
|
||||
return credential_info
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -362,8 +373,8 @@ class BuiltinToolManageService:
|
|||
|
||||
# clear default provider
|
||||
session.query(BuiltinToolProvider).filter_by(
|
||||
tenant_id=tenant_id, user_id=user_id, provider=provider, default=True
|
||||
).update({"default": False})
|
||||
tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True
|
||||
).update({"is_default": False})
|
||||
|
||||
# set new default provider
|
||||
target_provider.is_default = True
|
||||
|
|
|
|||
Loading…
Reference in New Issue