Merge branch 'main' into feat/trigger

This commit is contained in:
Yeuoly 2025-11-12 17:04:31 +08:00
commit 6744306818
4 changed files with 271 additions and 15 deletions

View File

@ -113,6 +113,11 @@ class AppGenerateEntity(BaseModel):
inputs: Mapping[str, Any] inputs: Mapping[str, Any]
files: Sequence[File] files: Sequence[File]
# Unique identifier of the user initiating the execution.
# This corresponds to `Account.id` for platform users or `EndUser.id` for end users.
#
# Note: The `user_id` field does not indicate whether the user is a platform user or an end user.
user_id: str user_id: str
# extras # extras

View File

@ -1,15 +1,64 @@
from typing import Annotated, Literal, Self, TypeAlias
from pydantic import BaseModel, Field
from sqlalchemy import Engine from sqlalchemy import Engine
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Session, sessionmaker
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.workflow.graph_engine.layers.base import GraphEngineLayer from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent from core.workflow.graph_events.base import GraphEngineEvent
from core.workflow.graph_events.graph import GraphRunPausedEvent from core.workflow.graph_events.graph import GraphRunPausedEvent
from models.model import AppMode
from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.factory import DifyAPIRepositoryFactory from repositories.factory import DifyAPIRepositoryFactory
# Wrapper types for `WorkflowAppGenerateEntity` and
# `AdvancedChatAppGenerateEntity`. These wrappers enable type discrimination
# and correct reconstruction of the entity field during (de)serialization.
class _WorkflowGenerateEntityWrapper(BaseModel):
type: Literal[AppMode.WORKFLOW] = AppMode.WORKFLOW
entity: WorkflowAppGenerateEntity
class _AdvancedChatAppGenerateEntityWrapper(BaseModel):
type: Literal[AppMode.ADVANCED_CHAT] = AppMode.ADVANCED_CHAT
entity: AdvancedChatAppGenerateEntity
_GenerateEntityUnion: TypeAlias = Annotated[
_WorkflowGenerateEntityWrapper | _AdvancedChatAppGenerateEntityWrapper,
Field(discriminator="type"),
]
class WorkflowResumptionContext(BaseModel):
"""WorkflowResumptionContext captures all state necessary for resumption."""
version: Literal["1"] = "1"
# Only workflow / chatflow could be paused.
generate_entity: _GenerateEntityUnion
serialized_graph_runtime_state: str
def dumps(self) -> str:
return self.model_dump_json()
@classmethod
def loads(cls, value: str) -> Self:
return cls.model_validate_json(value)
def get_generate_entity(self) -> WorkflowAppGenerateEntity | AdvancedChatAppGenerateEntity:
return self.generate_entity.entity
class PauseStatePersistenceLayer(GraphEngineLayer): class PauseStatePersistenceLayer(GraphEngineLayer):
def __init__(self, session_factory: Engine | sessionmaker[Session], state_owner_user_id: str): def __init__(
self,
session_factory: Engine | sessionmaker[Session],
generate_entity: WorkflowAppGenerateEntity | AdvancedChatAppGenerateEntity,
state_owner_user_id: str,
):
"""Create a PauseStatePersistenceLayer. """Create a PauseStatePersistenceLayer.
The `state_owner_user_id` is used when creating state file for pause. The `state_owner_user_id` is used when creating state file for pause.
@ -19,6 +68,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
session_factory = sessionmaker(session_factory) session_factory = sessionmaker(session_factory)
self._session_maker = session_factory self._session_maker = session_factory
self._state_owner_user_id = state_owner_user_id self._state_owner_user_id = state_owner_user_id
self._generate_entity = generate_entity
def _get_repo(self) -> APIWorkflowRunRepository: def _get_repo(self) -> APIWorkflowRunRepository:
return DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_maker) return DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_maker)
@ -49,13 +99,25 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
return return
assert self.graph_runtime_state is not None assert self.graph_runtime_state is not None
entity_wrapper: _GenerateEntityUnion
if isinstance(self._generate_entity, WorkflowAppGenerateEntity):
entity_wrapper = _WorkflowGenerateEntityWrapper(entity=self._generate_entity)
else:
entity_wrapper = _AdvancedChatAppGenerateEntityWrapper(entity=self._generate_entity)
state = WorkflowResumptionContext(
serialized_graph_runtime_state=self.graph_runtime_state.dumps(),
generate_entity=entity_wrapper,
)
workflow_run_id: str | None = self.graph_runtime_state.system_variable.workflow_execution_id workflow_run_id: str | None = self.graph_runtime_state.system_variable.workflow_execution_id
assert workflow_run_id is not None assert workflow_run_id is not None
repo = self._get_repo() repo = self._get_repo()
repo.create_workflow_pause( repo.create_workflow_pause(
workflow_run_id=workflow_run_id, workflow_run_id=workflow_run_id,
state_owner_user_id=self._state_owner_user_id, state_owner_user_id=self._state_owner_user_id,
state=self.graph_runtime_state.dumps(), state=state.dumps(),
) )
def on_graph_end(self, error: Exception | None) -> None: def on_graph_end(self, error: Exception | None) -> None:

View File

@ -25,7 +25,12 @@ import pytest
from sqlalchemy import Engine, delete, select from sqlalchemy import Engine, delete, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from core.app.layers.pause_state_persist_layer import PauseStatePersistenceLayer from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.layers.pause_state_persist_layer import (
PauseStatePersistenceLayer,
WorkflowResumptionContext,
)
from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.pause_reason import SchedulingPause from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.enums import WorkflowExecutionStatus from core.workflow.enums import WorkflowExecutionStatus
@ -39,7 +44,7 @@ from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from models import Account from models import Account
from models import WorkflowPause as WorkflowPauseModel from models import WorkflowPause as WorkflowPauseModel
from models.model import UploadFile from models.model import AppMode, UploadFile
from models.workflow import Workflow, WorkflowRun from models.workflow import Workflow, WorkflowRun
from services.file_service import FileService from services.file_service import FileService
from services.workflow_run_service import WorkflowRunService from services.workflow_run_service import WorkflowRunService
@ -226,11 +231,39 @@ class TestPauseStatePersistenceLayerTestContainers:
return ReadOnlyGraphRuntimeStateWrapper(graph_runtime_state) return ReadOnlyGraphRuntimeStateWrapper(graph_runtime_state)
def _create_generate_entity(
self,
workflow_execution_id: str | None = None,
user_id: str | None = None,
workflow_id: str | None = None,
) -> WorkflowAppGenerateEntity:
execution_id = workflow_execution_id or getattr(self, "test_workflow_run_id", str(uuid.uuid4()))
wf_id = workflow_id or getattr(self, "test_workflow_id", str(uuid.uuid4()))
tenant_id = getattr(self, "test_tenant_id", "tenant-123")
app_id = getattr(self, "test_app_id", "app-123")
app_config = WorkflowUIBasedAppConfig(
tenant_id=str(tenant_id),
app_id=str(app_id),
app_mode=AppMode.WORKFLOW,
workflow_id=str(wf_id),
)
return WorkflowAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
inputs={},
files=[],
user_id=user_id or getattr(self, "test_user_id", str(uuid.uuid4())),
stream=False,
invoke_from=InvokeFrom.DEBUGGER,
workflow_execution_id=execution_id,
)
def _create_pause_state_persistence_layer( def _create_pause_state_persistence_layer(
self, self,
workflow_run: WorkflowRun | None = None, workflow_run: WorkflowRun | None = None,
workflow: Workflow | None = None, workflow: Workflow | None = None,
state_owner_user_id: str | None = None, state_owner_user_id: str | None = None,
generate_entity: WorkflowAppGenerateEntity | None = None,
) -> PauseStatePersistenceLayer: ) -> PauseStatePersistenceLayer:
"""Create PauseStatePersistenceLayer with real dependencies.""" """Create PauseStatePersistenceLayer with real dependencies."""
owner_id = state_owner_user_id owner_id = state_owner_user_id
@ -244,10 +277,23 @@ class TestPauseStatePersistenceLayerTestContainers:
assert owner_id is not None assert owner_id is not None
owner_id = str(owner_id) owner_id = str(owner_id)
workflow_execution_id = (
workflow_run.id if workflow_run is not None else getattr(self, "test_workflow_run_id", None)
)
assert workflow_execution_id is not None
workflow_id = workflow.id if workflow is not None else getattr(self, "test_workflow_id", None)
assert workflow_id is not None
entity_user_id = getattr(self, "test_user_id", owner_id)
entity = generate_entity or self._create_generate_entity(
workflow_execution_id=str(workflow_execution_id),
user_id=entity_user_id,
workflow_id=str(workflow_id),
)
return PauseStatePersistenceLayer( return PauseStatePersistenceLayer(
session_factory=self.session.get_bind(), session_factory=self.session.get_bind(),
state_owner_user_id=owner_id, state_owner_user_id=owner_id,
generate_entity=entity,
) )
def test_complete_pause_flow_with_real_dependencies(self, db_session_with_containers): def test_complete_pause_flow_with_real_dependencies(self, db_session_with_containers):
@ -297,10 +343,15 @@ class TestPauseStatePersistenceLayerTestContainers:
assert pause_model.resumed_at is None assert pause_model.resumed_at is None
storage_content = storage.load(pause_model.state_object_key).decode() storage_content = storage.load(pause_model.state_object_key).decode()
resumption_context = WorkflowResumptionContext.loads(storage_content)
assert resumption_context.version == "1"
assert resumption_context.serialized_graph_runtime_state == graph_runtime_state.dumps()
expected_state = json.loads(graph_runtime_state.dumps()) expected_state = json.loads(graph_runtime_state.dumps())
actual_state = json.loads(storage_content) actual_state = json.loads(resumption_context.serialized_graph_runtime_state)
assert actual_state == expected_state assert actual_state == expected_state
persisted_entity = resumption_context.get_generate_entity()
assert isinstance(persisted_entity, WorkflowAppGenerateEntity)
assert persisted_entity.workflow_execution_id == self.test_workflow_run_id
def test_state_persistence_and_retrieval(self, db_session_with_containers): def test_state_persistence_and_retrieval(self, db_session_with_containers):
"""Test that pause state can be persisted and retrieved correctly.""" """Test that pause state can be persisted and retrieved correctly."""
@ -341,13 +392,15 @@ class TestPauseStatePersistenceLayerTestContainers:
assert pause_entity.workflow_execution_id == self.test_workflow_run_id assert pause_entity.workflow_execution_id == self.test_workflow_run_id
state_bytes = pause_entity.get_state() state_bytes = pause_entity.get_state()
retrieved_state = json.loads(state_bytes.decode()) resumption_context = WorkflowResumptionContext.loads(state_bytes.decode())
retrieved_state = json.loads(resumption_context.serialized_graph_runtime_state)
expected_state = json.loads(graph_runtime_state.dumps()) expected_state = json.loads(graph_runtime_state.dumps())
assert retrieved_state == expected_state assert retrieved_state == expected_state
assert retrieved_state["outputs"] == complex_outputs assert retrieved_state["outputs"] == complex_outputs
assert retrieved_state["total_tokens"] == 250 assert retrieved_state["total_tokens"] == 250
assert retrieved_state["node_run_steps"] == 10 assert retrieved_state["node_run_steps"] == 10
assert resumption_context.get_generate_entity().workflow_execution_id == self.test_workflow_run_id
def test_database_transaction_handling(self, db_session_with_containers): def test_database_transaction_handling(self, db_session_with_containers):
"""Test that database transactions are handled correctly.""" """Test that database transactions are handled correctly."""
@ -410,7 +463,9 @@ class TestPauseStatePersistenceLayerTestContainers:
# Verify content in storage # Verify content in storage
storage_content = storage.load(pause_model.state_object_key).decode() storage_content = storage.load(pause_model.state_object_key).decode()
assert storage_content == graph_runtime_state.dumps() resumption_context = WorkflowResumptionContext.loads(storage_content)
assert resumption_context.serialized_graph_runtime_state == graph_runtime_state.dumps()
assert resumption_context.get_generate_entity().workflow_execution_id == self.test_workflow_run_id
def test_workflow_with_different_creators(self, db_session_with_containers): def test_workflow_with_different_creators(self, db_session_with_containers):
"""Test pause state with workflows created by different users.""" """Test pause state with workflows created by different users."""
@ -474,6 +529,8 @@ class TestPauseStatePersistenceLayerTestContainers:
# Verify the state owner is the workflow creator # Verify the state owner is the workflow creator
pause_entity = self.workflow_run_service._workflow_run_repo.get_workflow_pause(different_workflow_run.id) pause_entity = self.workflow_run_service._workflow_run_repo.get_workflow_pause(different_workflow_run.id)
assert pause_entity is not None assert pause_entity is not None
resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode())
assert resumption_context.get_generate_entity().workflow_execution_id == different_workflow_run.id
def test_layer_ignores_non_pause_events(self, db_session_with_containers): def test_layer_ignores_non_pause_events(self, db_session_with_containers):
"""Test that layer ignores non-pause events.""" """Test that layer ignores non-pause events."""

View File

@ -4,7 +4,14 @@ from unittest.mock import Mock
import pytest import pytest
from core.app.layers.pause_state_persist_layer import PauseStatePersistenceLayer from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.layers.pause_state_persist_layer import (
PauseStatePersistenceLayer,
WorkflowResumptionContext,
_AdvancedChatAppGenerateEntityWrapper,
_WorkflowGenerateEntityWrapper,
)
from core.variables.segments import Segment from core.variables.segments import Segment
from core.workflow.entities.pause_reason import SchedulingPause from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.graph_engine.entities.commands import GraphEngineCommand from core.workflow.graph_engine.entities.commands import GraphEngineCommand
@ -15,6 +22,7 @@ from core.workflow.graph_events.graph import (
GraphRunSucceededEvent, GraphRunSucceededEvent,
) )
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool
from models.model import AppMode
from repositories.factory import DifyAPIRepositoryFactory from repositories.factory import DifyAPIRepositoryFactory
@ -170,6 +178,25 @@ class MockCommandChannel:
class TestPauseStatePersistenceLayer: class TestPauseStatePersistenceLayer:
"""Unit tests for PauseStatePersistenceLayer.""" """Unit tests for PauseStatePersistenceLayer."""
@staticmethod
def _create_generate_entity(workflow_execution_id: str = "run-123") -> WorkflowAppGenerateEntity:
app_config = WorkflowUIBasedAppConfig(
tenant_id="tenant-123",
app_id="app-123",
app_mode=AppMode.WORKFLOW,
workflow_id="workflow-123",
)
return WorkflowAppGenerateEntity(
task_id="task-123",
app_config=app_config,
inputs={},
files=[],
user_id="user-123",
stream=False,
invoke_from=InvokeFrom.DEBUGGER,
workflow_execution_id=workflow_execution_id,
)
def test_init_with_dependency_injection(self): def test_init_with_dependency_injection(self):
session_factory = Mock(name="session_factory") session_factory = Mock(name="session_factory")
state_owner_user_id = "user-123" state_owner_user_id = "user-123"
@ -177,6 +204,7 @@ class TestPauseStatePersistenceLayer:
layer = PauseStatePersistenceLayer( layer = PauseStatePersistenceLayer(
session_factory=session_factory, session_factory=session_factory,
state_owner_user_id=state_owner_user_id, state_owner_user_id=state_owner_user_id,
generate_entity=self._create_generate_entity(),
) )
assert layer._session_maker is session_factory assert layer._session_maker is session_factory
@ -186,7 +214,11 @@ class TestPauseStatePersistenceLayer:
def test_initialize_sets_dependencies(self): def test_initialize_sets_dependencies(self):
session_factory = Mock(name="session_factory") session_factory = Mock(name="session_factory")
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner") layer = PauseStatePersistenceLayer(
session_factory=session_factory,
state_owner_user_id="owner",
generate_entity=self._create_generate_entity(),
)
graph_runtime_state = MockReadOnlyGraphRuntimeState() graph_runtime_state = MockReadOnlyGraphRuntimeState()
command_channel = MockCommandChannel() command_channel = MockCommandChannel()
@ -198,7 +230,12 @@ class TestPauseStatePersistenceLayer:
def test_on_event_with_graph_run_paused_event(self, monkeypatch: pytest.MonkeyPatch): def test_on_event_with_graph_run_paused_event(self, monkeypatch: pytest.MonkeyPatch):
session_factory = Mock(name="session_factory") session_factory = Mock(name="session_factory")
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123") generate_entity = self._create_generate_entity(workflow_execution_id="run-123")
layer = PauseStatePersistenceLayer(
session_factory=session_factory,
state_owner_user_id="owner-123",
generate_entity=generate_entity,
)
mock_repo = Mock() mock_repo = Mock()
mock_factory = Mock(return_value=mock_repo) mock_factory = Mock(return_value=mock_repo)
@ -221,12 +258,20 @@ class TestPauseStatePersistenceLayer:
mock_repo.create_workflow_pause.assert_called_once_with( mock_repo.create_workflow_pause.assert_called_once_with(
workflow_run_id="run-123", workflow_run_id="run-123",
state_owner_user_id="owner-123", state_owner_user_id="owner-123",
state=expected_state, state=mock_repo.create_workflow_pause.call_args.kwargs["state"],
) )
serialized_state = mock_repo.create_workflow_pause.call_args.kwargs["state"]
resumption_context = WorkflowResumptionContext.loads(serialized_state)
assert resumption_context.serialized_graph_runtime_state == expected_state
assert resumption_context.get_generate_entity().model_dump() == generate_entity.model_dump()
def test_on_event_ignores_non_paused_events(self, monkeypatch: pytest.MonkeyPatch): def test_on_event_ignores_non_paused_events(self, monkeypatch: pytest.MonkeyPatch):
session_factory = Mock(name="session_factory") session_factory = Mock(name="session_factory")
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123") layer = PauseStatePersistenceLayer(
session_factory=session_factory,
state_owner_user_id="owner-123",
generate_entity=self._create_generate_entity(),
)
mock_repo = Mock() mock_repo = Mock()
mock_factory = Mock(return_value=mock_repo) mock_factory = Mock(return_value=mock_repo)
@ -250,7 +295,11 @@ class TestPauseStatePersistenceLayer:
def test_on_event_raises_attribute_error_when_graph_runtime_state_is_none(self): def test_on_event_raises_attribute_error_when_graph_runtime_state_is_none(self):
session_factory = Mock(name="session_factory") session_factory = Mock(name="session_factory")
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123") layer = PauseStatePersistenceLayer(
session_factory=session_factory,
state_owner_user_id="owner-123",
generate_entity=self._create_generate_entity(),
)
event = TestDataFactory.create_graph_run_paused_event() event = TestDataFactory.create_graph_run_paused_event()
@ -259,7 +308,11 @@ class TestPauseStatePersistenceLayer:
def test_on_event_asserts_when_workflow_execution_id_missing(self, monkeypatch: pytest.MonkeyPatch): def test_on_event_asserts_when_workflow_execution_id_missing(self, monkeypatch: pytest.MonkeyPatch):
session_factory = Mock(name="session_factory") session_factory = Mock(name="session_factory")
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123") layer = PauseStatePersistenceLayer(
session_factory=session_factory,
state_owner_user_id="owner-123",
generate_entity=self._create_generate_entity(),
)
mock_repo = Mock() mock_repo = Mock()
mock_factory = Mock(return_value=mock_repo) mock_factory = Mock(return_value=mock_repo)
@ -276,3 +329,82 @@ class TestPauseStatePersistenceLayer:
mock_factory.assert_not_called() mock_factory.assert_not_called()
mock_repo.create_workflow_pause.assert_not_called() mock_repo.create_workflow_pause.assert_not_called()
def _build_workflow_generate_entity_for_roundtrip() -> WorkflowResumptionContext:
"""Create a WorkflowAppGenerateEntity with realistic data for WorkflowResumptionContext tests."""
app_config = WorkflowUIBasedAppConfig(
tenant_id="tenant-roundtrip",
app_id="app-roundtrip",
app_mode=AppMode.WORKFLOW,
workflow_id="workflow-roundtrip",
)
serialized_state = json.dumps({"state": "workflow"})
return WorkflowResumptionContext(
serialized_graph_runtime_state=serialized_state,
generate_entity=_WorkflowGenerateEntityWrapper(
entity=WorkflowAppGenerateEntity(
task_id="workflow-task",
app_config=app_config,
inputs={"input_key": "input_value"},
files=[],
user_id="user-roundtrip",
stream=False,
invoke_from=InvokeFrom.DEBUGGER,
workflow_execution_id="workflow-exec-roundtrip",
)
),
)
def _build_advanced_chat_generate_entity_for_roundtrip() -> WorkflowResumptionContext:
"""Create an AdvancedChatAppGenerateEntity with realistic data for WorkflowResumptionContext tests."""
app_config = WorkflowUIBasedAppConfig(
tenant_id="tenant-advanced",
app_id="app-advanced",
app_mode=AppMode.ADVANCED_CHAT,
workflow_id="workflow-advanced",
)
serialized_state = json.dumps({"state": "workflow"})
return WorkflowResumptionContext(
serialized_graph_runtime_state=serialized_state,
generate_entity=_AdvancedChatAppGenerateEntityWrapper(
entity=AdvancedChatAppGenerateEntity(
task_id="advanced-task",
app_config=app_config,
inputs={"topic": "roundtrip"},
files=[],
user_id="advanced-user",
stream=False,
invoke_from=InvokeFrom.DEBUGGER,
workflow_run_id="advanced-run-id",
query="Explain serialization behavior",
)
),
)
@pytest.mark.parametrize(
"state",
[
pytest.param(
_build_advanced_chat_generate_entity_for_roundtrip(),
id="advanced_chat",
),
pytest.param(
_build_workflow_generate_entity_for_roundtrip(),
id="workflow",
),
],
)
def test_workflow_resumption_context_dumps_loads_roundtrip(state: WorkflowResumptionContext):
"""WorkflowResumptionContext roundtrip preserves workflow generate entity metadata."""
dumped = state.dumps()
loaded = WorkflowResumptionContext.loads(dumped)
assert loaded == state
assert loaded.serialized_graph_runtime_state == state.serialized_graph_runtime_state
restored_entity = loaded.get_generate_entity()
assert isinstance(restored_entity, type(state.generate_entity.entity))