Merge branch 'main' into feat/step-one-refactor

This commit is contained in:
Coding On Star 2025-12-24 14:08:07 +08:00 committed by GitHub
commit cf61393e58
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 44 additions and 11 deletions

View File

@ -572,7 +572,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info["credential_id"],
"credential_id": data_source_info.get("credential_id"),
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],

View File

@ -396,7 +396,7 @@ class IndexingRunner:
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info["credential_id"],
"credential_id": data_source_info.get("credential_id"),
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],

View File

@ -48,13 +48,21 @@ class NotionExtractor(BaseExtractor):
if notion_access_token:
self._notion_access_token = notion_access_token
else:
self._notion_access_token = self._get_access_token(tenant_id, self._credential_id)
if not self._notion_access_token:
try:
self._notion_access_token = self._get_access_token(tenant_id, self._credential_id)
except Exception as e:
logger.warning(
(
"Failed to get Notion access token from datasource credentials: %s, "
"falling back to environment variable NOTION_INTEGRATION_TOKEN"
),
e,
)
integration_token = dify_config.NOTION_INTEGRATION_TOKEN
if integration_token is None:
raise ValueError(
"Must specify `integration_token` or set environment variable `NOTION_INTEGRATION_TOKEN`."
)
) from e
self._notion_access_token = integration_token

View File

@ -247,6 +247,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
DATASOURCE_INFO = "datasource_info"
COMPLETED_REASON = "completed_reason" # completed reason for loop node
class WorkflowNodeExecutionStatus(StrEnum):

View File

@ -1,3 +1,4 @@
from enum import StrEnum
from typing import Annotated, Any, Literal
from pydantic import AfterValidator, BaseModel, Field, field_validator
@ -96,3 +97,8 @@ class LoopState(BaseLoopState):
Get current output.
"""
return self.current_output
class LoopCompletedReason(StrEnum):
LOOP_BREAK = "loop_break"
LOOP_COMPLETED = "loop_completed"

View File

@ -29,7 +29,7 @@ from core.workflow.node_events import (
)
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData
from core.workflow.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData
from core.workflow.utils.condition.processor import ConditionProcessor
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
from libs.datetime_utils import naive_utc_now
@ -96,6 +96,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
loop_duration_map: dict[str, float] = {}
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
loop_usage = LLMUsage.empty_usage()
loop_node_ids = self._extract_loop_node_ids_from_config(self.graph_config, self._node_id)
# Start Loop event
yield LoopStartedEvent(
@ -118,6 +119,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
loop_count = 0
for i in range(loop_count):
# Clear stale variables from previous loop iterations to avoid streaming old values
self._clear_loop_subgraph_variables(loop_node_ids)
graph_engine = self._create_graph_engine(start_at=start_at, root_node_id=root_node_id)
loop_start_time = naive_utc_now()
@ -177,7 +180,11 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
"completed_reason": "loop_break" if reach_break_condition else "loop_completed",
WorkflowNodeExecutionMetadataKey.COMPLETED_REASON: (
LoopCompletedReason.LOOP_BREAK
if reach_break_condition
else LoopCompletedReason.LOOP_COMPLETED.value
),
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
@ -274,6 +281,17 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
if WorkflowNodeExecutionMetadataKey.LOOP_ID not in current_metadata:
event.node_run_result.metadata = {**current_metadata, **loop_metadata}
def _clear_loop_subgraph_variables(self, loop_node_ids: set[str]) -> None:
"""
Remove variables produced by loop sub-graph nodes from previous iterations.
Keeping stale variables causes a freshly created response coordinator in the
next iteration to fall back to outdated values when no stream chunks exist.
"""
variable_pool = self.graph_runtime_state.variable_pool
for node_id in loop_node_ids:
variable_pool.remove([node_id])
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,

View File

@ -96,7 +96,7 @@ class TestNotionExtractorAuthentication:
def test_init_with_integration_token_fallback(self, mock_get_token, mock_config, mock_document_model):
"""Test NotionExtractor falls back to integration token when credential not found."""
# Arrange
mock_get_token.return_value = None
mock_get_token.side_effect = Exception("No credential id found")
mock_config.NOTION_INTEGRATION_TOKEN = "integration-token-fallback"
# Act
@ -105,7 +105,7 @@ class TestNotionExtractorAuthentication:
notion_obj_id="page-456",
notion_page_type="page",
tenant_id="tenant-789",
credential_id="cred-123",
credential_id=None,
document_model=mock_document_model,
)
@ -117,7 +117,7 @@ class TestNotionExtractorAuthentication:
def test_init_missing_credentials_raises_error(self, mock_get_token, mock_config, mock_document_model):
"""Test NotionExtractor raises error when no credentials available."""
# Arrange
mock_get_token.return_value = None
mock_get_token.side_effect = Exception("No credential id found")
mock_config.NOTION_INTEGRATION_TOKEN = None
# Act & Assert
@ -127,7 +127,7 @@ class TestNotionExtractorAuthentication:
notion_obj_id="page-456",
notion_page_type="page",
tenant_id="tenant-789",
credential_id="cred-123",
credential_id=None,
document_model=mock_document_model,
)
assert "Must specify `integration_token`" in str(exc_info.value)