mirror of https://github.com/langgenius/dify.git
This commit is contained in:
parent
b9ab1555fb
commit
9437145218
|
|
@ -78,3 +78,68 @@ class DatasourceToolProviderController(BuiltinToolProviderController):
|
|||
)
|
||||
for datasource_entity in self.entity.datasources
|
||||
]
|
||||
|
||||
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the format of the credentials of the provider and set the default value if needed
|
||||
|
||||
:param credentials: the credentials of the tool
|
||||
"""
|
||||
credentials_schema = dict[str, ProviderConfig]()
|
||||
if credentials_schema is None:
|
||||
return
|
||||
|
||||
for credential in self.entity.credentials_schema:
|
||||
credentials_schema[credential.name] = credential
|
||||
|
||||
credentials_need_to_validate: dict[str, ProviderConfig] = {}
|
||||
for credential_name in credentials_schema:
|
||||
credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
|
||||
|
||||
for credential_name in credentials:
|
||||
if credential_name not in credentials_need_to_validate:
|
||||
raise ToolProviderCredentialValidationError(
|
||||
f"credential {credential_name} not found in provider {self.entity.identity.name}"
|
||||
)
|
||||
|
||||
# check type
|
||||
credential_schema = credentials_need_to_validate[credential_name]
|
||||
if not credential_schema.required and credentials[credential_name] is None:
|
||||
continue
|
||||
|
||||
if credential_schema.type in {ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT}:
|
||||
if not isinstance(credentials[credential_name], str):
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
|
||||
|
||||
elif credential_schema.type == ProviderConfig.Type.SELECT:
|
||||
if not isinstance(credentials[credential_name], str):
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
|
||||
|
||||
options = credential_schema.options
|
||||
if not isinstance(options, list):
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list")
|
||||
|
||||
if credentials[credential_name] not in [x.value for x in options]:
|
||||
raise ToolProviderCredentialValidationError(
|
||||
f"credential {credential_name} should be one of {options}"
|
||||
)
|
||||
|
||||
credentials_need_to_validate.pop(credential_name)
|
||||
|
||||
for credential_name in credentials_need_to_validate:
|
||||
credential_schema = credentials_need_to_validate[credential_name]
|
||||
if credential_schema.required:
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} is required")
|
||||
|
||||
# the credential is not set currently, set the default value if needed
|
||||
if credential_schema.default is not None:
|
||||
default_value = credential_schema.default
|
||||
# parse default value into the correct type
|
||||
if credential_schema.type in {
|
||||
ProviderConfig.Type.SECRET_INPUT,
|
||||
ProviderConfig.Type.TEXT_INPUT,
|
||||
ProviderConfig.Type.SELECT,
|
||||
}:
|
||||
default_value = str(default_value)
|
||||
|
||||
credentials[credential_name] = default_value
|
||||
|
|
@ -1,14 +1,16 @@
|
|||
from collections.abc import Generator
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.datasource.__base.datasource import Datasource
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceEntity, DatasourceParameter, DatasourceProviderType
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceInvokeMessage,
|
||||
DatasourceParameter,
|
||||
DatasourceProviderType,
|
||||
)
|
||||
from core.plugin.manager.datasource import PluginDatasourceManager
|
||||
from core.plugin.manager.tool import PluginToolManager
|
||||
from core.plugin.utils.converter import convert_parameters_to_plugin_format
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
|
||||
|
||||
|
||||
class DatasourcePlugin(Datasource):
|
||||
|
|
@ -16,11 +18,14 @@ class DatasourcePlugin(Datasource):
|
|||
icon: str
|
||||
plugin_unique_identifier: str
|
||||
runtime_parameters: Optional[list[DatasourceParameter]]
|
||||
entity: DatasourceEntity
|
||||
runtime: DatasourceRuntime
|
||||
|
||||
def __init__(
|
||||
self, entity: DatasourceEntity, runtime: ToolRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str
|
||||
self, entity: DatasourceEntity, runtime: DatasourceRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str
|
||||
) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
self.entity = entity
|
||||
self.runtime = runtime
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
|
@ -34,7 +39,7 @@ class DatasourcePlugin(Datasource):
|
|||
user_id: str,
|
||||
datasource_parameters: dict[str, Any],
|
||||
rag_pipeline_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
) -> Generator[DatasourceInvokeMessage, None, None]:
|
||||
manager = PluginDatasourceManager()
|
||||
|
||||
datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)
|
||||
|
|
@ -54,7 +59,7 @@ class DatasourcePlugin(Datasource):
|
|||
user_id: str,
|
||||
datasource_parameters: dict[str, Any],
|
||||
rag_pipeline_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
) -> Generator[DatasourceInvokeMessage, None, None]:
|
||||
manager = PluginDatasourceManager()
|
||||
|
||||
datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)
|
||||
|
|
|
|||
|
|
@ -105,7 +105,7 @@ class ApiProviderAuthType(Enum):
|
|||
raise ValueError(f"invalid mode value {value}")
|
||||
|
||||
|
||||
class ToolInvokeMessage(BaseModel):
|
||||
class DatasourceInvokeMessage(BaseModel):
|
||||
class TextMessage(BaseModel):
|
||||
text: str
|
||||
|
||||
|
|
@ -200,7 +200,7 @@ class ToolInvokeMessage(BaseModel):
|
|||
return v
|
||||
|
||||
|
||||
class ToolInvokeMessageBinary(BaseModel):
|
||||
class DatasourceInvokeMessageBinary(BaseModel):
|
||||
mimetype: str = Field(..., description="The mimetype of the binary")
|
||||
url: str = Field(..., description="The url of the binary")
|
||||
file_var: Optional[dict[str, Any]] = None
|
||||
|
|
|
|||
Loading…
Reference in New Issue