r2 transform

This commit is contained in:
jyong 2025-07-18 19:22:31 +08:00
parent dc359c6442
commit 34a6ed74b6
2 changed files with 49 additions and 4 deletions

View File

@ -1,5 +1,5 @@
from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast
from typing import Any, Optional, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -24,7 +24,8 @@ from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.tool.exc import ToolFileError
from core.workflow.utils.variable_template_parser import VariableTemplateParser
@ -46,6 +47,28 @@ class DatasourceNode(BaseNode):
_node_data: DatasourceNodeData
_node_type = NodeType.DATASOURCE
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = DatasourceNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
def _run(self) -> Generator:
"""
Run the datasource node

View File

@ -2,7 +2,7 @@ import datetime
import logging
import time
from collections.abc import Mapping
from typing import Any, cast
from typing import Any, Optional, cast
from sqlalchemy import func
@ -13,7 +13,8 @@ from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
@ -38,6 +39,27 @@ class KnowledgeIndexNode(BaseNode):
_node_data: KnowledgeIndexNodeData
_node_type = NodeType.KNOWLEDGE_INDEX
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = KnowledgeIndexNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
def _run(self) -> NodeRunResult: # type: ignore
node_data = cast(KnowledgeIndexNodeData, self._node_data)
variable_pool = self.graph_runtime_state.variable_pool