mirror of https://github.com/langgenius/dify.git
commit ec1c4efca9 (parent 0f10852b6b)
@@ -20,7 +20,7 @@ class DatasourcePlugin(ABC):
         self.runtime = runtime
 
     @abstractmethod
-    def datasource_provider_type(self) -> DatasourceProviderType:
+    def datasource_provider_type(self) -> str:
         """
         returns the type of the datasource provider
         """
@@ -9,10 +9,10 @@ from core.tools.errors import ToolProviderCredentialValidationError
 
 
 class DatasourcePluginProviderController(ABC):
-    entity: DatasourceProviderEntityWithPlugin
+    entity: DatasourceProviderEntityWithPlugin | None
     tenant_id: str
 
-    def __init__(self, entity: DatasourceProviderEntityWithPlugin, tenant_id: str) -> None:
+    def __init__(self, entity: DatasourceProviderEntityWithPlugin | None, tenant_id: str) -> None:
         self.entity = entity
         self.tenant_id = tenant_id
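Note: entity becomes Optional here, presumably because the built-in local-file datasource is not backed by a plugin declaration (see the hardcoded provider payload later in this commit). A minimal guard callers might want, sketched here rather than taken from the diff (the helper name is hypothetical):

def require_entity(controller: DatasourcePluginProviderController) -> DatasourceProviderEntityWithPlugin:
    # hypothetical helper: with entity now `... | None`, code that needs the
    # declaration should fail loudly instead of attribute-erroring on None
    if controller.entity is None:
        raise ValueError("datasource provider has no plugin declaration")
    return controller.entity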
@@ -24,5 +24,5 @@ class LocalFileDatasourcePlugin(DatasourcePlugin):
         self.icon = icon
         self.plugin_unique_identifier = plugin_unique_identifier
 
-    def datasource_provider_type(self) -> DatasourceProviderType:
+    def datasource_provider_type(self) -> str:
         return DatasourceProviderType.LOCAL_FILE
@@ -69,5 +69,5 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
             provider_type=provider_type,
         )
 
-    def datasource_provider_type(self) -> DatasourceProviderType:
+    def datasource_provider_type(self) -> str:
         return DatasourceProviderType.ONLINE_DOCUMENT
@@ -49,5 +49,5 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
             provider_type=provider_type,
         )
 
-    def datasource_provider_type(self) -> DatasourceProviderType:
+    def datasource_provider_type(self) -> str:
         return DatasourceProviderType.WEBSITE_CRAWL
@@ -10,7 +10,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon
     plugin_unique_identifier: str
 
     def __init__(
-        self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
+        self, entity: DatasourceProviderEntityWithPlugin | None, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
     ) -> None:
         super().__init__(entity, tenant_id)
         self.plugin_id = plugin_id
@@ -6,7 +6,7 @@ from core.datasource.entities.datasource_entities import (
     GetOnlineDocumentPageContentRequest,
     GetOnlineDocumentPageContentResponse,
     GetOnlineDocumentPagesResponse,
-    GetWebsiteCrawlResponse,
+    GetWebsiteCrawlResponse, DatasourceProviderEntity,
 )
 from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
 from core.plugin.entities.plugin_daemon import (
@@ -17,7 +17,7 @@ from core.plugin.impl.base import BasePluginClient
 
 
 class PluginDatasourceManager(BasePluginClient):
-    def fetch_datasource_providers(self, tenant_id: str) -> list[DatasourceProviderApiEntity]:
+    def fetch_datasource_providers(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]:
         """
         Fetch datasource providers for the given tenant.
         """
@@ -46,12 +46,15 @@ class PluginDatasourceManager(BasePluginClient):
         # for datasource in provider.declaration.datasources:
         #     datasource.identity.provider = provider.declaration.identity.name
 
-        return [DatasourceProviderApiEntity(**self._get_local_file_datasource_provider())]
+        return [PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())]
 
     def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginDatasourceProviderEntity:
         """
         Fetch datasource provider for the given tenant and plugin.
         """
+        if provider == "langgenius/file/file":
+            return PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
+
         tool_provider_id = ToolProviderID(provider)
 
         def transformer(json_response: dict[str, Any]) -> dict:
@@ -218,6 +221,7 @@ class PluginDatasourceManager(BasePluginClient):
                 "X-Plugin-ID": tool_provider_id.plugin_id,
                 "Content-Type": "application/json",
             },
+
         )
 
         for resp in response:
@@ -228,27 +232,48 @@ class PluginDatasourceManager(BasePluginClient):
     def _get_local_file_datasource_provider(self) -> dict[str, Any]:
         return {
             "id": "langgenius/file/file",
             "author": "langgenius",
             "name": "langgenius/file/file",
             "plugin_id": "langgenius/file",
+            "provider": "langgenius",
             "plugin_unique_identifier": "langgenius/file:0.0.1@dify",
-            "description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
-            "icon": "https://cloud.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id=945b4365-9d99-48c1-8c47-90593fe8b9c9&filename=13d9312f6b1352d3939b90a5257de58ff3cd619d5be4f5b266ff0298935ac328.svg",
-            "label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
-            "type": "datasource",
-            "team_credentials": {},
-            "is_team_authorization": False,
-            "allow_delete": True,
-            "datasources": [
-                {
-                    "identity": {
-                        "author": "langgenius",
-                        "name": "upload_file",
-                        "label": {"en_US": "File", "zh_Hans": "File", "pt_BR": "File", "ja_JP": "File"},
-                        "description": {"en_US": "File", "zh_Hans": "File", "pt_BR": "File", "ja_JP": "File."},
-                    },
-                    "parameters": [],
-                    "labels": ["search"],
-                    "output_schema": None,
-                }
-            ],
+            "declaration": {
+                "identity": {
+                    "author": "langgenius",
+                    "name": "langgenius/file/file",
+                    "label": {
+                        "zh_Hans": "File",
+                        "en_US": "File",
+                        "pt_BR": "File",
+                        "ja_JP": "File"
+                    },
+                    "icon": "https://cloud.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id=945b4365-9d99-48c1-8c47-90593fe8b9c9&filename=13d9312f6b1352d3939b90a5257de58ff3cd619d5be4f5b266ff0298935ac328.svg",
+                    "description": {
+                        "zh_Hans": "File",
+                        "en_US": "File",
+                        "pt_BR": "File",
+                        "ja_JP": "File"
+                    }
+                },
+                "credentials_schema": [],
+                "provider_type": "local_file",
+                "datasources": [{
+                    "identity": {
+                        "author": "langgenius",
+                        "name": "local_file",
+                        "provider": "langgenius",
+                        "label": {
+                            "zh_Hans": "File",
+                            "en_US": "File",
+                            "pt_BR": "File",
+                            "ja_JP": "File"
+                        }
+                    },
+                    "parameters": [],
+                    "labels": ["search"],
+                    "description": {
+                        "zh_Hans": "File",
+                        "en_US": "File",
+                        "pt_BR": "File",
+                        "ja_JP": "File"
+                    }
+                }]
+            }
         }
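Note: the hardcoded local-file provider now mirrors the nested "declaration" shape that PluginDatasourceProviderEntity expects, instead of the old flat DatasourceProviderApiEntity layout. A self-contained sketch showing how such a payload validates — the models below are simplified stand-ins, not dify's real entities:

from typing import Any, Optional
from pydantic import BaseModel

class Identity(BaseModel):
    author: str
    name: str
    label: dict[str, str]
    icon: Optional[str] = None
    description: Optional[dict[str, str]] = None

class Declaration(BaseModel):
    identity: Identity
    credentials_schema: list[Any] = []
    provider_type: str
    datasources: list[dict[str, Any]] = []

class ProviderPayload(BaseModel):
    id: str
    author: str
    name: str
    plugin_id: str
    provider: str
    plugin_unique_identifier: str
    declaration: Declaration

payload = ProviderPayload(
    id="langgenius/file/file",
    author="langgenius",
    name="langgenius/file/file",
    plugin_id="langgenius/file",
    provider="langgenius",
    plugin_unique_identifier="langgenius/file:0.0.1@dify",
    declaration={
        "identity": {"author": "langgenius", "name": "langgenius/file/file", "label": {"en_US": "File"}},
        "credentials_schema": [],
        "provider_type": "local_file",
        "datasources": [],
    },
)
print(payload.declaration.provider_type)  # -> local_file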
@@ -1,3 +1,3 @@
-from .tool_node import ToolNode
+from .datasource_node import DatasourceNode
 
-__all__ = ["ToolNode"]
+__all__ = ["DatasourceNode"]
@@ -40,14 +40,19 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
 
         node_data = cast(DatasourceNodeData, self.node_data)
         variable_pool = self.graph_runtime_state.variable_pool
 
+        datasource_type = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value])
+        if not datasource_type:
+            raise DatasourceNodeError("Datasource type is not set")
+        datasource_type = datasource_type.value
+        datasource_info = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value])
+        if not datasource_info:
+            raise DatasourceNodeError("Datasource info is not set")
+        datasource_info = datasource_info.value
         # get datasource runtime
         try:
             from core.datasource.datasource_manager import DatasourceManager
 
-            datasource_type = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value])
-
-            datasource_info = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value])
-            if datasource_type is None:
-                raise DatasourceNodeError("Datasource type is not set")
 
@@ -84,47 +89,55 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
         )
 
         try:
-            if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT:
-                datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
-                online_document_result: GetOnlineDocumentPageContentResponse = (
-                    datasource_runtime._get_online_document_page_content(
-                        user_id=self.user_id,
-                        datasource_parameters=GetOnlineDocumentPageContentRequest(**parameters),
-                        provider_type=datasource_runtime.datasource_provider_type(),
-                    )
-                )
-                yield RunCompletedEvent(
-                    run_result=NodeRunResult(
-                        status=WorkflowNodeExecutionStatus.SUCCEEDED,
-                        inputs=parameters_for_log,
-                        metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
-                        outputs={
-                            "online_document": online_document_result.result.model_dump(),
-                            "datasource_type": datasource_runtime.datasource_provider_type,
-                        },
-                    )
-                )
-            elif (
-                datasource_runtime.datasource_provider_type in (
-                    DatasourceProviderType.WEBSITE_CRAWL,
-                    DatasourceProviderType.LOCAL_FILE,
-                )
-            ):
-                yield RunCompletedEvent(
-                    run_result=NodeRunResult(
-                        status=WorkflowNodeExecutionStatus.SUCCEEDED,
-                        inputs=parameters_for_log,
-                        metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
-                        outputs={
-                            "website": datasource_info,
-                            "datasource_type": datasource_runtime.datasource_provider_type,
-                        },
-                    )
-                )
-            else:
-                raise DatasourceNodeError(
-                    f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}"
-                )
+            match datasource_type:
+                case DatasourceProviderType.ONLINE_DOCUMENT:
+                    datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
+                    online_document_result: GetOnlineDocumentPageContentResponse = (
+                        datasource_runtime._get_online_document_page_content(
+                            user_id=self.user_id,
+                            datasource_parameters=GetOnlineDocumentPageContentRequest(**parameters),
+                            provider_type=datasource_type,
+                        )
+                    )
+                    yield RunCompletedEvent(
+                        run_result=NodeRunResult(
+                            status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                            inputs=parameters_for_log,
+                            metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
+                            outputs={
+                                "online_document": online_document_result.result.model_dump(),
+                                "datasource_type": datasource_type,
+                            },
+                        )
+                    )
+                case DatasourceProviderType.WEBSITE_CRAWL | DatasourceProviderType.LOCAL_FILE:
+                    yield RunCompletedEvent(
+                        run_result=NodeRunResult(
+                            status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                            inputs=parameters_for_log,
+                            metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
+                            outputs={
+                                "website": datasource_info,
+                                "datasource_type": datasource_type,
+                            },
+                        )
+                    )
+                case DatasourceProviderType.LOCAL_FILE:
+                    yield RunCompletedEvent(
+                        run_result=NodeRunResult(
+                            status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                            inputs=parameters_for_log,
+                            metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
+                            outputs={
+                                "file": datasource_info,
+                                "datasource_type": datasource_runtime.datasource_provider_type,
+                            },
+                        )
+                    )
+                case _:
+                    raise DatasourceNodeError(
+                        f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}"
+                    )
         except PluginDaemonClientSideError as e:
             yield RunCompletedEvent(
                 run_result=NodeRunResult(
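Note: the branching now matches on datasource_type (the plain string unwrapped from the variable pool) instead of calling datasource_runtime.datasource_provider_type(). Value patterns in `match` compare with ==, so this works when DatasourceProviderType is a str-backed enum; a minimal runnable sketch under that assumption:

from enum import Enum

class DatasourceProviderType(str, Enum):
    ONLINE_DOCUMENT = "online_document"
    WEBSITE_CRAWL = "website_crawl"
    LOCAL_FILE = "local_file"

def describe(datasource_type: str) -> str:
    match datasource_type:
        case DatasourceProviderType.ONLINE_DOCUMENT:
            return "online document"
        case DatasourceProviderType.WEBSITE_CRAWL | DatasourceProviderType.LOCAL_FILE:
            return "website crawl or local file"
        case _:
            return "unsupported"

print(describe("website_crawl"))  # -> website crawl or local file

Also worth flagging: the or-pattern above already captures LOCAL_FILE, so the dedicated `case DatasourceProviderType.LOCAL_FILE:` branch in this hunk is unreachable as written.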
@@ -170,23 +183,24 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
         datasource_parameters_dictionary = {parameter.name: parameter for parameter in datasource_parameters}
 
         result: dict[str, Any] = {}
-        for parameter_name in node_data.datasource_parameters:
-            parameter = datasource_parameters_dictionary.get(parameter_name)
-            if not parameter:
-                result[parameter_name] = None
-                continue
-            datasource_input = node_data.datasource_parameters[parameter_name]
-            if datasource_input.type == "variable":
-                variable = variable_pool.get(datasource_input.value)
-                if variable is None:
-                    raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist")
-                parameter_value = variable.value
-            elif datasource_input.type in {"mixed", "constant"}:
-                segment_group = variable_pool.convert_template(str(datasource_input.value))
-                parameter_value = segment_group.log if for_log else segment_group.text
-            else:
-                raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'")
-            result[parameter_name] = parameter_value
+        if node_data.datasource_parameters:
+            for parameter_name in node_data.datasource_parameters:
+                parameter = datasource_parameters_dictionary.get(parameter_name)
+                if not parameter:
+                    result[parameter_name] = None
+                    continue
+                datasource_input = node_data.datasource_parameters[parameter_name]
+                if datasource_input.type == "variable":
+                    variable = variable_pool.get(datasource_input.value)
+                    if variable is None:
+                        raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist")
+                    parameter_value = variable.value
+                elif datasource_input.type in {"mixed", "constant"}:
+                    segment_group = variable_pool.convert_template(str(datasource_input.value))
+                    parameter_value = segment_group.log if for_log else segment_group.text
+                else:
+                    raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'")
+                result[parameter_name] = parameter_value
 
         return result
@@ -1,4 +1,4 @@
-from typing import Any, Literal, Union
+from typing import Any, Literal, Union, Optional
 
 from pydantic import BaseModel, field_validator
 from pydantic_core.core_schema import ValidationInfo
@@ -9,30 +9,17 @@ from core.workflow.nodes.base.entities import BaseNodeData
 class DatasourceEntity(BaseModel):
     provider_id: str
-    provider_name: str  # redundancy
-    datasource_name: str
-    tool_label: str  # redundancy
-    datasource_configurations: dict[str, Any]
+    provider_type: str
+    datasource_name: Optional[str] = "local_file"
+    datasource_configurations: dict[str, Any] | None = None
+    plugin_unique_identifier: str | None = None  # redundancy
 
-    @field_validator("tool_configurations", mode="before")
-    @classmethod
-    def validate_tool_configurations(cls, value, values: ValidationInfo):
-        if not isinstance(value, dict):
-            raise ValueError("tool_configurations must be a dictionary")
-
-        for key in values.data.get("tool_configurations", {}):
-            value = values.data.get("tool_configurations", {}).get(key)
-            if not isinstance(value, str | int | float | bool):
-                raise ValueError(f"{key} must be a string")
-
-        return value
-
 
 class DatasourceNodeData(BaseNodeData, DatasourceEntity):
     class DatasourceInput(BaseModel):
         # TODO: check this type
-        value: Union[Any, list[str]]
-        type: Literal["mixed", "variable", "constant"]
+        value: Optional[Union[Any, list[str]]] = None
+        type: Optional[Literal["mixed", "variable", "constant"]] = None
 
         @field_validator("type", mode="before")
         @classmethod
@@ -51,4 +38,4 @@ class DatasourceNodeData(BaseNodeData, DatasourceEntity):
             raise ValueError("value must be a string, int, float, or bool")
         return typ
 
-    datasource_parameters: dict[str, DatasourceInput]
+    datasource_parameters: dict[str, DatasourceInput] | None = None
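Note: the entity schema is loosened so a datasource node can omit its parameters and configurations (the local-file path supplies none). A runnable sketch of the relaxed fields, mirroring the names in this diff but standing alone:

from typing import Any, Literal, Optional, Union
from pydantic import BaseModel

class DatasourceInput(BaseModel):
    value: Optional[Union[Any, list[str]]] = None
    type: Optional[Literal["mixed", "variable", "constant"]] = None

class DatasourceEntity(BaseModel):
    provider_id: str
    provider_type: str
    datasource_name: Optional[str] = "local_file"
    datasource_configurations: dict[str, Any] | None = None
    plugin_unique_identifier: str | None = None

# minimal payload now validates without parameters/configurations
entity = DatasourceEntity(provider_id="langgenius/file/file", provider_type="local_file")
print(entity.datasource_name)  # -> local_file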
@@ -19,6 +19,7 @@ from .entities import KnowledgeIndexNodeData
 from .exc import (
     KnowledgeIndexNodeError,
 )
+from ..base import BaseNode
 
 logger = logging.getLogger(__name__)
 
@@ -31,7 +32,7 @@ default_retrieval_model = {
 }
 
 
-class KnowledgeIndexNode(LLMNode):
+class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
     _node_data_cls = KnowledgeIndexNodeData  # type: ignore
     _node_type = NodeType.KNOWLEDGE_INDEX
 
@@ -44,7 +45,7 @@ class KnowledgeIndexNode(LLMNode):
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED,
                 inputs={},
-                error="Query variable is not object type.",
+                error="Index chunk variable is not object type.",
             )
         chunks = variable.value
         variables = {"chunks": chunks}
@@ -4,12 +4,14 @@ from core.workflow.nodes.agent.agent_node import AgentNode
 from core.workflow.nodes.answer import AnswerNode
 from core.workflow.nodes.base import BaseNode
 from core.workflow.nodes.code import CodeNode
+from core.workflow.nodes.datasource.datasource_node import DatasourceNode
 from core.workflow.nodes.document_extractor import DocumentExtractorNode
 from core.workflow.nodes.end import EndNode
 from core.workflow.nodes.enums import NodeType
 from core.workflow.nodes.http_request import HttpRequestNode
 from core.workflow.nodes.if_else import IfElseNode
 from core.workflow.nodes.iteration import IterationNode, IterationStartNode
+from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
 from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
 from core.workflow.nodes.list_operator import ListOperatorNode
 from core.workflow.nodes.llm import LLMNode
@@ -119,4 +121,12 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
         LATEST_VERSION: AgentNode,
         "1": AgentNode,
     },
+    NodeType.DATASOURCE: {
+        LATEST_VERSION: DatasourceNode,
+        "1": DatasourceNode,
+    },
+    NodeType.KNOWLEDGE_INDEX: {
+        LATEST_VERSION: KnowledgeIndexNode,
+        "1": KnowledgeIndexNode,
+    },
 }
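Note: the two new node types register in the same versioned shape as the existing entries. Illustrative lookup only — the stub classes, string keys, and resolve helper below are a sketch, not dify's actual code, which keys by NodeType and defines its own LATEST_VERSION constant:

from collections.abc import Mapping

LATEST_VERSION = "latest"  # assumed sentinel

class DatasourceNode: ...
class KnowledgeIndexNode: ...

NODE_TYPE_CLASSES_MAPPING: Mapping[str, Mapping[str, type]] = {
    "datasource": {LATEST_VERSION: DatasourceNode, "1": DatasourceNode},
    "knowledge-index": {LATEST_VERSION: KnowledgeIndexNode, "1": KnowledgeIndexNode},
}

def resolve(node_type: str, version: str) -> type:
    # fall back to the latest registered class when a version is unknown
    versions = NODE_TYPE_CLASSES_MAPPING[node_type]
    return versions.get(version, versions[LATEST_VERSION])

assert resolve("datasource", "1") is DatasourceNode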