diff --git a/api/app.py b/api/app.py index 11decffe96..4f393f6c20 100644 --- a/api/app.py +++ b/api/app.py @@ -38,4 +38,4 @@ else: celery = app.extensions["celery"] if __name__ == "__main__": - app.run(host="0.0.0.0", port=5001,debug=True) + app.run(host="0.0.0.0", port=5001) diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index e3285a16c7..a46071059f 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -127,9 +127,10 @@ class ToolBuiltinProviderUpdateApi(Resource): result = BuiltinToolManageService.update_builtin_tool_provider( user_id=user_id, tenant_id=tenant_id, + provider_name=provider, credentials=args["credentials"], credential_id=args["credential_id"], - name=args["name"] + name=args["name"], ) session.commit() return result diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 14ec4ebae0..a837552007 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -1,4 +1,3 @@ -import os from datetime import timedelta import pytz @@ -25,17 +24,12 @@ def init_app(app: DifyApp) -> Celery: }, } - - flask_debugging = os.environ.get("FLASK_DEBUG", "0").lower() in {"true", "1", "yes"} - celery_app = Celery( app.name, task_cls=FlaskTask, broker=dify_config.CELERY_BROKER_URL, backend=dify_config.CELERY_BACKEND, task_ignore_result=True, - task_always_eager=flask_debugging, - task_eager_propagates=flask_debugging, ) # Add SSL options to the Celery configuration diff --git a/api/models/tools.py b/api/models/tools.py index 4b493e7596..9e50cec52f 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -19,8 +19,11 @@ from .types import StringUUID class ToolProviderCredentialType(enum.StrEnum): - API_KEY = "api_key", - OAUTH2 = "oauth2", + API_KEY = "api_key" + OAUTH2 = "oauth2" + + def get_name(self): + return self.value.replace("_", " ").upper() def is_editable(self): return self == ToolProviderCredentialType.API_KEY diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 31bc2e650d..7dc3e4c0f8 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -118,12 +118,8 @@ class BuiltinToolManageService: if provider is None: raise ValueError(f"you have not added provider {provider_name}") - if not ToolProviderCredentialType.get_credential_type(provider.credential_type).is_editable(): - raise ValueError(f"you cannot update oauth2 provider {provider_name} credentials") - try: - # exclude oauth2 provider - if provider.credential_type != ToolProviderCredentialType.OAUTH2.value: + if ToolProviderCredentialType.get_credential_type(provider.credential_type).is_editable(): provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) if not provider_controller.need_credentials: raise ValueError(f"provider {provider_name} does not need credentials") @@ -139,11 +135,15 @@ class BuiltinToolManageService: credentials = BuiltinToolManageService._decrypt_and_restore_credentials( provider_controller, tool_configuration, provider, credentials ) - + # Encrypt and save the credentials BuiltinToolManageService._encrypt_and_save_credentials( provider_controller, tool_configuration, provider, credentials, user_id ) + else: + raise ValueError( + f"provider {provider_name} is not editable, you can only delete it and add a new one" + ) # update name if provided if name is not None and provider.name != name: @@ -162,15 +162,60 @@ class BuiltinToolManageService: @staticmethod def add_builtin_tool_provider( - user_id: str, tenant_id: str, provider_name: str, credentials: dict, name: str | None = None + user_id: str, type: ToolProviderCredentialType, tenant_id: str, provider_name:str, credentials: dict, name: str | None = None ): """ add builtin tool provider """ - + if name is None: + name = BuiltinToolManageService.get_next_builtin_tool_provider_name(tenant_id, type) + + provider = BuiltinToolProvider( + tenant_id=tenant_id, + user_id=user_id, + provider=provider_name, + credential_type=type.value, + credentials=json.dumps(credentials), + name=name, + ) + + provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) + if not provider_controller.need_credentials: + raise ValueError(f"provider {provider_name} does not need credentials") + tool_configuration = ProviderConfigEncrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], + provider_type=provider_controller.provider_type.value, + provider_identity=provider_controller.entity.identity.name, + ) + + # Encrypt and save the credentials + BuiltinToolManageService._encrypt_and_save_credentials( + provider_controller, tool_configuration, provider, credentials, user_id + ) + db.session.add(provider) return {"result": "success"} + @staticmethod + def get_next_builtin_tool_provider_name(tenant_id: str, type: ToolProviderCredentialType) -> str: + """ + next name = max(provider_names) + 1 + """ + provider_names = db.session.query(BuiltinToolProvider).filter_by( + tenant_id=tenant_id, + credential_type=type.value, + ).all() + if not provider_names: + return f"{type.value} 1" + # OAuth 1 then OAuth 2, if don't have OAuth 1, then return OAuth 1 + # if dont have number, then get name and add 1 + for provider_name in provider_names: + if provider_name.provider.startswith(type.value): + return f"{type.value} {int(provider_name.provider.split(' ')[1]) + 1}" + return f"{type.value} 1" + + @staticmethod def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str): """ @@ -416,7 +461,7 @@ class BuiltinToolManageService: def _decrypt_and_restore_credentials(provider_controller, tool_configuration, provider, credentials): """ Decrypt original credentials and restore masked values from the input credentials - + :param provider_controller: the provider controller :param tool_configuration: the tool configuration encrypter :param provider: the provider object from database @@ -425,19 +470,19 @@ class BuiltinToolManageService: """ original_credentials = tool_configuration.decrypt(provider.credentials) masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) - + # check if the credential has changed, save the original credential for name, value in credentials.items(): if name in masked_credentials and value == masked_credentials[name]: # type: ignore credentials[name] = original_credentials[name] # type: ignore - + return credentials @staticmethod def _encrypt_and_save_credentials(provider_controller, tool_configuration, provider, credentials, user_id): """ Validate and encrypt credentials, then save to database - + :param provider_controller: the provider controller :param tool_configuration: the tool configuration encrypter :param provider: the provider object from database