mirror of https://github.com/langgenius/dify.git
fix: remove debugging flags
This commit is contained in:
parent
b3a8dbe2f5
commit
7f292dc261
|
|
@ -38,4 +38,4 @@ else:
|
|||
celery = app.extensions["celery"]
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(host="0.0.0.0", port=5001,debug=True)
|
||||
app.run(host="0.0.0.0", port=5001)
|
||||
|
|
|
|||
|
|
@ -127,9 +127,10 @@ class ToolBuiltinProviderUpdateApi(Resource):
|
|||
result = BuiltinToolManageService.update_builtin_tool_provider(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider,
|
||||
credentials=args["credentials"],
|
||||
credential_id=args["credential_id"],
|
||||
name=args["name"]
|
||||
name=args["name"],
|
||||
)
|
||||
session.commit()
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import os
|
||||
from datetime import timedelta
|
||||
|
||||
import pytz
|
||||
|
|
@ -25,17 +24,12 @@ def init_app(app: DifyApp) -> Celery:
|
|||
},
|
||||
}
|
||||
|
||||
|
||||
flask_debugging = os.environ.get("FLASK_DEBUG", "0").lower() in {"true", "1", "yes"}
|
||||
|
||||
celery_app = Celery(
|
||||
app.name,
|
||||
task_cls=FlaskTask,
|
||||
broker=dify_config.CELERY_BROKER_URL,
|
||||
backend=dify_config.CELERY_BACKEND,
|
||||
task_ignore_result=True,
|
||||
task_always_eager=flask_debugging,
|
||||
task_eager_propagates=flask_debugging,
|
||||
)
|
||||
|
||||
# Add SSL options to the Celery configuration
|
||||
|
|
|
|||
|
|
@ -19,8 +19,11 @@ from .types import StringUUID
|
|||
|
||||
|
||||
class ToolProviderCredentialType(enum.StrEnum):
|
||||
API_KEY = "api_key",
|
||||
OAUTH2 = "oauth2",
|
||||
API_KEY = "api_key"
|
||||
OAUTH2 = "oauth2"
|
||||
|
||||
def get_name(self):
|
||||
return self.value.replace("_", " ").upper()
|
||||
|
||||
def is_editable(self):
|
||||
return self == ToolProviderCredentialType.API_KEY
|
||||
|
|
|
|||
|
|
@ -118,12 +118,8 @@ class BuiltinToolManageService:
|
|||
if provider is None:
|
||||
raise ValueError(f"you have not added provider {provider_name}")
|
||||
|
||||
if not ToolProviderCredentialType.get_credential_type(provider.credential_type).is_editable():
|
||||
raise ValueError(f"you cannot update oauth2 provider {provider_name} credentials")
|
||||
|
||||
try:
|
||||
# exclude oauth2 provider
|
||||
if provider.credential_type != ToolProviderCredentialType.OAUTH2.value:
|
||||
if ToolProviderCredentialType.get_credential_type(provider.credential_type).is_editable():
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||
if not provider_controller.need_credentials:
|
||||
raise ValueError(f"provider {provider_name} does not need credentials")
|
||||
|
|
@ -139,11 +135,15 @@ class BuiltinToolManageService:
|
|||
credentials = BuiltinToolManageService._decrypt_and_restore_credentials(
|
||||
provider_controller, tool_configuration, provider, credentials
|
||||
)
|
||||
|
||||
|
||||
# Encrypt and save the credentials
|
||||
BuiltinToolManageService._encrypt_and_save_credentials(
|
||||
provider_controller, tool_configuration, provider, credentials, user_id
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"provider {provider_name} is not editable, you can only delete it and add a new one"
|
||||
)
|
||||
|
||||
# update name if provided
|
||||
if name is not None and provider.name != name:
|
||||
|
|
@ -162,15 +162,60 @@ class BuiltinToolManageService:
|
|||
|
||||
@staticmethod
|
||||
def add_builtin_tool_provider(
|
||||
user_id: str, tenant_id: str, provider_name: str, credentials: dict, name: str | None = None
|
||||
user_id: str, type: ToolProviderCredentialType, tenant_id: str, provider_name:str, credentials: dict, name: str | None = None
|
||||
):
|
||||
"""
|
||||
add builtin tool provider
|
||||
"""
|
||||
|
||||
if name is None:
|
||||
name = BuiltinToolManageService.get_next_builtin_tool_provider_name(tenant_id, type)
|
||||
|
||||
provider = BuiltinToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider=provider_name,
|
||||
credential_type=type.value,
|
||||
credentials=json.dumps(credentials),
|
||||
name=name,
|
||||
)
|
||||
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||
if not provider_controller.need_credentials:
|
||||
raise ValueError(f"provider {provider_name} does not need credentials")
|
||||
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
|
||||
# Encrypt and save the credentials
|
||||
BuiltinToolManageService._encrypt_and_save_credentials(
|
||||
provider_controller, tool_configuration, provider, credentials, user_id
|
||||
)
|
||||
db.session.add(provider)
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def get_next_builtin_tool_provider_name(tenant_id: str, type: ToolProviderCredentialType) -> str:
|
||||
"""
|
||||
next name = max(provider_names) + 1
|
||||
"""
|
||||
provider_names = db.session.query(BuiltinToolProvider).filter_by(
|
||||
tenant_id=tenant_id,
|
||||
credential_type=type.value,
|
||||
).all()
|
||||
if not provider_names:
|
||||
return f"{type.value} 1"
|
||||
# OAuth 1 then OAuth 2, if don't have OAuth 1, then return OAuth 1
|
||||
# if dont have number, then get name and add 1
|
||||
for provider_name in provider_names:
|
||||
if provider_name.provider.startswith(type.value):
|
||||
return f"{type.value} {int(provider_name.provider.split(' ')[1]) + 1}"
|
||||
return f"{type.value} 1"
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str):
|
||||
"""
|
||||
|
|
@ -416,7 +461,7 @@ class BuiltinToolManageService:
|
|||
def _decrypt_and_restore_credentials(provider_controller, tool_configuration, provider, credentials):
|
||||
"""
|
||||
Decrypt original credentials and restore masked values from the input credentials
|
||||
|
||||
|
||||
:param provider_controller: the provider controller
|
||||
:param tool_configuration: the tool configuration encrypter
|
||||
:param provider: the provider object from database
|
||||
|
|
@ -425,19 +470,19 @@ class BuiltinToolManageService:
|
|||
"""
|
||||
original_credentials = tool_configuration.decrypt(provider.credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
|
||||
|
||||
|
||||
# check if the credential has changed, save the original credential
|
||||
for name, value in credentials.items():
|
||||
if name in masked_credentials and value == masked_credentials[name]: # type: ignore
|
||||
credentials[name] = original_credentials[name] # type: ignore
|
||||
|
||||
|
||||
return credentials
|
||||
|
||||
@staticmethod
|
||||
def _encrypt_and_save_credentials(provider_controller, tool_configuration, provider, credentials, user_id):
|
||||
"""
|
||||
Validate and encrypt credentials, then save to database
|
||||
|
||||
|
||||
:param provider_controller: the provider controller
|
||||
:param tool_configuration: the tool configuration encrypter
|
||||
:param provider: the provider object from database
|
||||
|
|
|
|||
Loading…
Reference in New Issue