mirror of https://github.com/langgenius/dify.git
feat: refactor DatasourceNode and KnowledgeIndexNode to use _node_data attribute
This commit is contained in:
parent
32fe8313b4
commit
f325662141
|
|
@ -16,7 +16,8 @@ from sqlalchemy.orm import sessionmaker
|
|||
import contexts
|
||||
from configs import dify_config
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
|
||||
from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager
|
||||
from core.app.apps.pipeline.pipeline_runner import PipelineRunner
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
|
|
|
|||
|
|
@ -38,12 +38,12 @@ from .entities import DatasourceNodeData
|
|||
from .exc import DatasourceNodeError, DatasourceParameterError
|
||||
|
||||
|
||||
class DatasourceNode(BaseNode[DatasourceNodeData]):
|
||||
class DatasourceNode(BaseNode):
|
||||
"""
|
||||
Datasource Node
|
||||
"""
|
||||
|
||||
_node_data_cls = DatasourceNodeData
|
||||
_node_data: DatasourceNodeData
|
||||
_node_type = NodeType.DATASOURCE
|
||||
|
||||
def _run(self) -> Generator:
|
||||
|
|
@ -51,7 +51,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
|||
Run the datasource node
|
||||
"""
|
||||
|
||||
node_data = cast(DatasourceNodeData, self.node_data)
|
||||
node_data = cast(DatasourceNodeData, self._node_data)
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
datasource_type = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value])
|
||||
if not datasource_type:
|
||||
|
|
@ -90,12 +90,12 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
|||
parameters = self._generate_parameters(
|
||||
datasource_parameters=datasource_parameters,
|
||||
variable_pool=variable_pool,
|
||||
node_data=self.node_data,
|
||||
node_data=self._node_data,
|
||||
)
|
||||
parameters_for_log = self._generate_parameters(
|
||||
datasource_parameters=datasource_parameters,
|
||||
variable_pool=variable_pool,
|
||||
node_data=self.node_data,
|
||||
node_data=self._node_data,
|
||||
for_log=True,
|
||||
)
|
||||
|
||||
|
|
@ -421,7 +421,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
|
|||
)
|
||||
elif message.type == DatasourceMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, DatasourceMessage.JsonMessage)
|
||||
if self.node_type == NodeType.AGENT:
|
||||
if self._node_type == NodeType.AGENT:
|
||||
msg_metadata = message.message.json_object.pop("execution_metadata", {})
|
||||
agent_execution_metadata = {
|
||||
key: value
|
||||
|
|
|
|||
|
|
@ -34,12 +34,12 @@ default_retrieval_model = {
|
|||
}
|
||||
|
||||
|
||||
class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
|
||||
_node_data_cls = KnowledgeIndexNodeData # type: ignore
|
||||
class KnowledgeIndexNode(BaseNode):
|
||||
_node_data: KnowledgeIndexNodeData
|
||||
_node_type = NodeType.KNOWLEDGE_INDEX
|
||||
|
||||
def _run(self) -> NodeRunResult: # type: ignore
|
||||
node_data = cast(KnowledgeIndexNodeData, self.node_data)
|
||||
node_data = cast(KnowledgeIndexNodeData, self._node_data)
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
|
||||
if not dataset_id:
|
||||
|
|
|
|||
Loading…
Reference in New Issue