Merge branch 'main' into feat/trigger

This commit is contained in:
Yeuoly 2025-11-12 17:04:31 +08:00
commit 6744306818
4 changed files with 271 additions and 15 deletions

View File

@ -113,6 +113,11 @@ class AppGenerateEntity(BaseModel):
inputs: Mapping[str, Any]
files: Sequence[File]
# Unique identifier of the user initiating the execution.
# This corresponds to `Account.id` for platform users or `EndUser.id` for end users.
#
# Note: The `user_id` field does not indicate whether the user is a platform user or an end user.
user_id: str
# extras

View File

@ -1,15 +1,64 @@
from typing import Annotated, Literal, Self, TypeAlias
from pydantic import BaseModel, Field
from sqlalchemy import Engine
from sqlalchemy.orm import Session, sessionmaker
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent
from core.workflow.graph_events.graph import GraphRunPausedEvent
from models.model import AppMode
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.factory import DifyAPIRepositoryFactory
# Wrapper types for `WorkflowAppGenerateEntity` and
# `AdvancedChatAppGenerateEntity`. These wrappers enable type discrimination
# and correct reconstruction of the entity field during (de)serialization.
class _WorkflowGenerateEntityWrapper(BaseModel):
type: Literal[AppMode.WORKFLOW] = AppMode.WORKFLOW
entity: WorkflowAppGenerateEntity
class _AdvancedChatAppGenerateEntityWrapper(BaseModel):
type: Literal[AppMode.ADVANCED_CHAT] = AppMode.ADVANCED_CHAT
entity: AdvancedChatAppGenerateEntity
_GenerateEntityUnion: TypeAlias = Annotated[
_WorkflowGenerateEntityWrapper | _AdvancedChatAppGenerateEntityWrapper,
Field(discriminator="type"),
]
class WorkflowResumptionContext(BaseModel):
"""WorkflowResumptionContext captures all state necessary for resumption."""
version: Literal["1"] = "1"
# Only workflow / chatflow could be paused.
generate_entity: _GenerateEntityUnion
serialized_graph_runtime_state: str
def dumps(self) -> str:
return self.model_dump_json()
@classmethod
def loads(cls, value: str) -> Self:
return cls.model_validate_json(value)
def get_generate_entity(self) -> WorkflowAppGenerateEntity | AdvancedChatAppGenerateEntity:
return self.generate_entity.entity
class PauseStatePersistenceLayer(GraphEngineLayer):
def __init__(self, session_factory: Engine | sessionmaker[Session], state_owner_user_id: str):
def __init__(
self,
session_factory: Engine | sessionmaker[Session],
generate_entity: WorkflowAppGenerateEntity | AdvancedChatAppGenerateEntity,
state_owner_user_id: str,
):
"""Create a PauseStatePersistenceLayer.
The `state_owner_user_id` is used when creating state file for pause.
@ -19,6 +68,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
session_factory = sessionmaker(session_factory)
self._session_maker = session_factory
self._state_owner_user_id = state_owner_user_id
self._generate_entity = generate_entity
def _get_repo(self) -> APIWorkflowRunRepository:
return DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_maker)
@ -49,13 +99,25 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
return
assert self.graph_runtime_state is not None
entity_wrapper: _GenerateEntityUnion
if isinstance(self._generate_entity, WorkflowAppGenerateEntity):
entity_wrapper = _WorkflowGenerateEntityWrapper(entity=self._generate_entity)
else:
entity_wrapper = _AdvancedChatAppGenerateEntityWrapper(entity=self._generate_entity)
state = WorkflowResumptionContext(
serialized_graph_runtime_state=self.graph_runtime_state.dumps(),
generate_entity=entity_wrapper,
)
workflow_run_id: str | None = self.graph_runtime_state.system_variable.workflow_execution_id
assert workflow_run_id is not None
repo = self._get_repo()
repo.create_workflow_pause(
workflow_run_id=workflow_run_id,
state_owner_user_id=self._state_owner_user_id,
state=self.graph_runtime_state.dumps(),
state=state.dumps(),
)
def on_graph_end(self, error: Exception | None) -> None:

View File

@ -25,7 +25,12 @@ import pytest
from sqlalchemy import Engine, delete, select
from sqlalchemy.orm import Session
from core.app.layers.pause_state_persist_layer import PauseStatePersistenceLayer
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.layers.pause_state_persist_layer import (
PauseStatePersistenceLayer,
WorkflowResumptionContext,
)
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.enums import WorkflowExecutionStatus
@ -39,7 +44,7 @@ from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from models import Account
from models import WorkflowPause as WorkflowPauseModel
from models.model import UploadFile
from models.model import AppMode, UploadFile
from models.workflow import Workflow, WorkflowRun
from services.file_service import FileService
from services.workflow_run_service import WorkflowRunService
@ -226,11 +231,39 @@ class TestPauseStatePersistenceLayerTestContainers:
return ReadOnlyGraphRuntimeStateWrapper(graph_runtime_state)
def _create_generate_entity(
self,
workflow_execution_id: str | None = None,
user_id: str | None = None,
workflow_id: str | None = None,
) -> WorkflowAppGenerateEntity:
execution_id = workflow_execution_id or getattr(self, "test_workflow_run_id", str(uuid.uuid4()))
wf_id = workflow_id or getattr(self, "test_workflow_id", str(uuid.uuid4()))
tenant_id = getattr(self, "test_tenant_id", "tenant-123")
app_id = getattr(self, "test_app_id", "app-123")
app_config = WorkflowUIBasedAppConfig(
tenant_id=str(tenant_id),
app_id=str(app_id),
app_mode=AppMode.WORKFLOW,
workflow_id=str(wf_id),
)
return WorkflowAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
inputs={},
files=[],
user_id=user_id or getattr(self, "test_user_id", str(uuid.uuid4())),
stream=False,
invoke_from=InvokeFrom.DEBUGGER,
workflow_execution_id=execution_id,
)
def _create_pause_state_persistence_layer(
self,
workflow_run: WorkflowRun | None = None,
workflow: Workflow | None = None,
state_owner_user_id: str | None = None,
generate_entity: WorkflowAppGenerateEntity | None = None,
) -> PauseStatePersistenceLayer:
"""Create PauseStatePersistenceLayer with real dependencies."""
owner_id = state_owner_user_id
@ -244,10 +277,23 @@ class TestPauseStatePersistenceLayerTestContainers:
assert owner_id is not None
owner_id = str(owner_id)
workflow_execution_id = (
workflow_run.id if workflow_run is not None else getattr(self, "test_workflow_run_id", None)
)
assert workflow_execution_id is not None
workflow_id = workflow.id if workflow is not None else getattr(self, "test_workflow_id", None)
assert workflow_id is not None
entity_user_id = getattr(self, "test_user_id", owner_id)
entity = generate_entity or self._create_generate_entity(
workflow_execution_id=str(workflow_execution_id),
user_id=entity_user_id,
workflow_id=str(workflow_id),
)
return PauseStatePersistenceLayer(
session_factory=self.session.get_bind(),
state_owner_user_id=owner_id,
generate_entity=entity,
)
def test_complete_pause_flow_with_real_dependencies(self, db_session_with_containers):
@ -297,10 +343,15 @@ class TestPauseStatePersistenceLayerTestContainers:
assert pause_model.resumed_at is None
storage_content = storage.load(pause_model.state_object_key).decode()
resumption_context = WorkflowResumptionContext.loads(storage_content)
assert resumption_context.version == "1"
assert resumption_context.serialized_graph_runtime_state == graph_runtime_state.dumps()
expected_state = json.loads(graph_runtime_state.dumps())
actual_state = json.loads(storage_content)
actual_state = json.loads(resumption_context.serialized_graph_runtime_state)
assert actual_state == expected_state
persisted_entity = resumption_context.get_generate_entity()
assert isinstance(persisted_entity, WorkflowAppGenerateEntity)
assert persisted_entity.workflow_execution_id == self.test_workflow_run_id
def test_state_persistence_and_retrieval(self, db_session_with_containers):
"""Test that pause state can be persisted and retrieved correctly."""
@ -341,13 +392,15 @@ class TestPauseStatePersistenceLayerTestContainers:
assert pause_entity.workflow_execution_id == self.test_workflow_run_id
state_bytes = pause_entity.get_state()
retrieved_state = json.loads(state_bytes.decode())
resumption_context = WorkflowResumptionContext.loads(state_bytes.decode())
retrieved_state = json.loads(resumption_context.serialized_graph_runtime_state)
expected_state = json.loads(graph_runtime_state.dumps())
assert retrieved_state == expected_state
assert retrieved_state["outputs"] == complex_outputs
assert retrieved_state["total_tokens"] == 250
assert retrieved_state["node_run_steps"] == 10
assert resumption_context.get_generate_entity().workflow_execution_id == self.test_workflow_run_id
def test_database_transaction_handling(self, db_session_with_containers):
"""Test that database transactions are handled correctly."""
@ -410,7 +463,9 @@ class TestPauseStatePersistenceLayerTestContainers:
# Verify content in storage
storage_content = storage.load(pause_model.state_object_key).decode()
assert storage_content == graph_runtime_state.dumps()
resumption_context = WorkflowResumptionContext.loads(storage_content)
assert resumption_context.serialized_graph_runtime_state == graph_runtime_state.dumps()
assert resumption_context.get_generate_entity().workflow_execution_id == self.test_workflow_run_id
def test_workflow_with_different_creators(self, db_session_with_containers):
"""Test pause state with workflows created by different users."""
@ -474,6 +529,8 @@ class TestPauseStatePersistenceLayerTestContainers:
# Verify the state owner is the workflow creator
pause_entity = self.workflow_run_service._workflow_run_repo.get_workflow_pause(different_workflow_run.id)
assert pause_entity is not None
resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode())
assert resumption_context.get_generate_entity().workflow_execution_id == different_workflow_run.id
def test_layer_ignores_non_pause_events(self, db_session_with_containers):
"""Test that layer ignores non-pause events."""

View File

@ -4,7 +4,14 @@ from unittest.mock import Mock
import pytest
from core.app.layers.pause_state_persist_layer import PauseStatePersistenceLayer
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.layers.pause_state_persist_layer import (
PauseStatePersistenceLayer,
WorkflowResumptionContext,
_AdvancedChatAppGenerateEntityWrapper,
_WorkflowGenerateEntityWrapper,
)
from core.variables.segments import Segment
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.graph_engine.entities.commands import GraphEngineCommand
@ -15,6 +22,7 @@ from core.workflow.graph_events.graph import (
GraphRunSucceededEvent,
)
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool
from models.model import AppMode
from repositories.factory import DifyAPIRepositoryFactory
@ -170,6 +178,25 @@ class MockCommandChannel:
class TestPauseStatePersistenceLayer:
"""Unit tests for PauseStatePersistenceLayer."""
@staticmethod
def _create_generate_entity(workflow_execution_id: str = "run-123") -> WorkflowAppGenerateEntity:
app_config = WorkflowUIBasedAppConfig(
tenant_id="tenant-123",
app_id="app-123",
app_mode=AppMode.WORKFLOW,
workflow_id="workflow-123",
)
return WorkflowAppGenerateEntity(
task_id="task-123",
app_config=app_config,
inputs={},
files=[],
user_id="user-123",
stream=False,
invoke_from=InvokeFrom.DEBUGGER,
workflow_execution_id=workflow_execution_id,
)
def test_init_with_dependency_injection(self):
session_factory = Mock(name="session_factory")
state_owner_user_id = "user-123"
@ -177,6 +204,7 @@ class TestPauseStatePersistenceLayer:
layer = PauseStatePersistenceLayer(
session_factory=session_factory,
state_owner_user_id=state_owner_user_id,
generate_entity=self._create_generate_entity(),
)
assert layer._session_maker is session_factory
@ -186,7 +214,11 @@ class TestPauseStatePersistenceLayer:
def test_initialize_sets_dependencies(self):
session_factory = Mock(name="session_factory")
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner")
layer = PauseStatePersistenceLayer(
session_factory=session_factory,
state_owner_user_id="owner",
generate_entity=self._create_generate_entity(),
)
graph_runtime_state = MockReadOnlyGraphRuntimeState()
command_channel = MockCommandChannel()
@ -198,7 +230,12 @@ class TestPauseStatePersistenceLayer:
def test_on_event_with_graph_run_paused_event(self, monkeypatch: pytest.MonkeyPatch):
session_factory = Mock(name="session_factory")
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
generate_entity = self._create_generate_entity(workflow_execution_id="run-123")
layer = PauseStatePersistenceLayer(
session_factory=session_factory,
state_owner_user_id="owner-123",
generate_entity=generate_entity,
)
mock_repo = Mock()
mock_factory = Mock(return_value=mock_repo)
@ -221,12 +258,20 @@ class TestPauseStatePersistenceLayer:
mock_repo.create_workflow_pause.assert_called_once_with(
workflow_run_id="run-123",
state_owner_user_id="owner-123",
state=expected_state,
state=mock_repo.create_workflow_pause.call_args.kwargs["state"],
)
serialized_state = mock_repo.create_workflow_pause.call_args.kwargs["state"]
resumption_context = WorkflowResumptionContext.loads(serialized_state)
assert resumption_context.serialized_graph_runtime_state == expected_state
assert resumption_context.get_generate_entity().model_dump() == generate_entity.model_dump()
def test_on_event_ignores_non_paused_events(self, monkeypatch: pytest.MonkeyPatch):
session_factory = Mock(name="session_factory")
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
layer = PauseStatePersistenceLayer(
session_factory=session_factory,
state_owner_user_id="owner-123",
generate_entity=self._create_generate_entity(),
)
mock_repo = Mock()
mock_factory = Mock(return_value=mock_repo)
@ -250,7 +295,11 @@ class TestPauseStatePersistenceLayer:
def test_on_event_raises_attribute_error_when_graph_runtime_state_is_none(self):
session_factory = Mock(name="session_factory")
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
layer = PauseStatePersistenceLayer(
session_factory=session_factory,
state_owner_user_id="owner-123",
generate_entity=self._create_generate_entity(),
)
event = TestDataFactory.create_graph_run_paused_event()
@ -259,7 +308,11 @@ class TestPauseStatePersistenceLayer:
def test_on_event_asserts_when_workflow_execution_id_missing(self, monkeypatch: pytest.MonkeyPatch):
session_factory = Mock(name="session_factory")
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
layer = PauseStatePersistenceLayer(
session_factory=session_factory,
state_owner_user_id="owner-123",
generate_entity=self._create_generate_entity(),
)
mock_repo = Mock()
mock_factory = Mock(return_value=mock_repo)
@ -276,3 +329,82 @@ class TestPauseStatePersistenceLayer:
mock_factory.assert_not_called()
mock_repo.create_workflow_pause.assert_not_called()
def _build_workflow_generate_entity_for_roundtrip() -> WorkflowResumptionContext:
"""Create a WorkflowAppGenerateEntity with realistic data for WorkflowResumptionContext tests."""
app_config = WorkflowUIBasedAppConfig(
tenant_id="tenant-roundtrip",
app_id="app-roundtrip",
app_mode=AppMode.WORKFLOW,
workflow_id="workflow-roundtrip",
)
serialized_state = json.dumps({"state": "workflow"})
return WorkflowResumptionContext(
serialized_graph_runtime_state=serialized_state,
generate_entity=_WorkflowGenerateEntityWrapper(
entity=WorkflowAppGenerateEntity(
task_id="workflow-task",
app_config=app_config,
inputs={"input_key": "input_value"},
files=[],
user_id="user-roundtrip",
stream=False,
invoke_from=InvokeFrom.DEBUGGER,
workflow_execution_id="workflow-exec-roundtrip",
)
),
)
def _build_advanced_chat_generate_entity_for_roundtrip() -> WorkflowResumptionContext:
"""Create an AdvancedChatAppGenerateEntity with realistic data for WorkflowResumptionContext tests."""
app_config = WorkflowUIBasedAppConfig(
tenant_id="tenant-advanced",
app_id="app-advanced",
app_mode=AppMode.ADVANCED_CHAT,
workflow_id="workflow-advanced",
)
serialized_state = json.dumps({"state": "workflow"})
return WorkflowResumptionContext(
serialized_graph_runtime_state=serialized_state,
generate_entity=_AdvancedChatAppGenerateEntityWrapper(
entity=AdvancedChatAppGenerateEntity(
task_id="advanced-task",
app_config=app_config,
inputs={"topic": "roundtrip"},
files=[],
user_id="advanced-user",
stream=False,
invoke_from=InvokeFrom.DEBUGGER,
workflow_run_id="advanced-run-id",
query="Explain serialization behavior",
)
),
)
@pytest.mark.parametrize(
"state",
[
pytest.param(
_build_advanced_chat_generate_entity_for_roundtrip(),
id="advanced_chat",
),
pytest.param(
_build_workflow_generate_entity_for_roundtrip(),
id="workflow",
),
],
)
def test_workflow_resumption_context_dumps_loads_roundtrip(state: WorkflowResumptionContext):
"""WorkflowResumptionContext roundtrip preserves workflow generate entity metadata."""
dumped = state.dumps()
loaded = WorkflowResumptionContext.loads(dumped)
assert loaded == state
assert loaded.serialized_graph_runtime_state == state.serialized_graph_runtime_state
restored_entity = loaded.get_generate_entity()
assert isinstance(restored_entity, type(state.generate_entity.entity))