feat(api): Implement truncation for WorkflowNodeExecution

QuantumGhost 2025-08-29 14:49:09 +08:00
parent 2fd337e610
commit 6b9d2e98b9
10 changed files with 366 additions and 46 deletions
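In short: when a node execution finishes, oversized `inputs`, `process_data` and `outputs` are truncated before being written to the node-execution row, the full JSON payload is uploaded to object storage, and a `WorkflowNodeExecutionOffload` record links the row to the uploaded file. A minimal sketch of that decision (the size limit and helper below are illustrative stand-ins; the real truncation is performed by `VariableTruncator`, which is not part of this diff):

import json
from collections.abc import Mapping
from typing import Any

MAX_SIZE_BYTES = 100_000  # stand-in for dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE

def maybe_truncate(values: Mapping[str, Any]) -> tuple[Mapping[str, Any], bool]:
    # Return (value to store on the DB row, whether the full payload must be offloaded).
    serialized = json.dumps(values, sort_keys=True)
    if len(serialized.encode("utf-8")) <= MAX_SIZE_BYTES:
        return values, False
    # Keep only a shallow preview on the row; the full JSON goes to object storage
    # and is referenced by a WorkflowNodeExecutionOffload record.
    preview = {key: "<truncated>" for key in values}
    return preview, True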

View File

@ -29,6 +29,7 @@ from core.ops.utils import measure_time
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_engine.entities.event import AgentLogEvent
from extensions.ext_storage import storage
from models import App, Message, WorkflowNodeExecutionModel, db
logger = logging.getLogger(__name__)
@ -443,7 +444,9 @@ class LLMGenerator:
) -> dict:
from services.workflow_service import WorkflowService
app: App | None = db.session.query(App).where(App.id == flow_id).first()
session = db.session()
app: App | None = session.query(App).where(App.id == flow_id).first()
if not app:
raise ValueError("App not found.")
workflow = WorkflowService().get_draft_workflow(app_model=app)
@ -487,8 +490,9 @@ class LLMGenerator:
return [dict_of_event(event) for event in parsed]
inputs = last_run.load_full_inputs(session, storage)
last_run_dict = {
"inputs": last_run.inputs_dict,
"inputs": inputs,
"status": last_run.status,
"error": last_run.error,
"agent_log": agent_log_of(last_run),

View File

@ -283,7 +283,7 @@ class AliyunDataTrace(BaseTraceInstance):
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=service_account,
app_id=trace_info.metadata.get("app_id"),
app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Get all executions for this workflow run

View File

@ -2,15 +2,18 @@
SQLAlchemy implementation of the WorkflowNodeExecutionRepository.
"""
import dataclasses
import json
import logging
from collections.abc import Sequence
from typing import Optional, Union
from collections.abc import Callable, Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Optional, TypeVar, Union
from sqlalchemy import UnaryExpression, asc, desc, select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
@ -20,7 +23,9 @@ from core.workflow.entities.workflow_node_execution import (
from core.workflow.nodes.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_storage import storage
from libs.helper import extract_tenant_id
from libs.uuid_utils import uuidv7
from models import (
Account,
CreatorUserRole,
@ -28,10 +33,22 @@ from models import (
WorkflowNodeExecutionModel,
WorkflowNodeExecutionTriggeredFrom,
)
from models.enums import ExecutionOffLoadType
from models.model import UploadFile
from models.workflow import WorkflowNodeExecutionOffload
from services.file_service import FileService
from services.variable_truncator import VariableTruncator
logger = logging.getLogger(__name__)
@dataclasses.dataclass(frozen=True)
class _InputsOutputsTruncationResult:
truncated_value: Mapping[str, Any]
file: UploadFile
offload: WorkflowNodeExecutionOffload
class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
"""
SQLAlchemy implementation of the WorkflowNodeExecutionRepository interface.
@ -48,7 +65,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
self,
session_factory: sessionmaker | Engine,
user: Union[Account, EndUser],
app_id: Optional[str],
app_id: str,
triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom],
):
"""
@ -82,6 +99,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
# Extract user context
self._triggered_from = triggered_from
self._creator_user_id = user.id
self._user = user # Store the user object directly
# Determine user role based on user type
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
@ -90,17 +108,30 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
# Key: node_execution_id, Value: WorkflowNodeExecution (DB model)
self._node_execution_cache: dict[str, WorkflowNodeExecutionModel] = {}
# Initialize FileService for handling offloaded data
self._file_service = FileService(session_factory)
def _create_truncator(self) -> VariableTruncator:
return VariableTruncator(
max_size_bytes=dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE,
array_element_limit=dify_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH,
string_length_limit=dify_config.WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH,
)
def _to_domain_model(self, db_model: WorkflowNodeExecutionModel) -> WorkflowNodeExecution:
"""
Convert a database model to a domain model.
This requires that `offload_data` and the corresponding `inputs_file` and `outputs_file` are preloaded.
Args:
db_model: The database model to convert
db_model: The database model to convert. It must have `offload_data`
and the corresponding `inputs_file` and `outputs_file` preloaded.
Returns:
The domain model
"""
# Parse JSON fields
# Parse JSON fields - these might be truncated versions
inputs = db_model.inputs_dict
process_data = db_model.process_data_dict
outputs = db_model.outputs_dict
@ -109,7 +140,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
# Convert status to domain enum
status = WorkflowNodeExecutionStatus(db_model.status)
return WorkflowNodeExecution(
domain_model = WorkflowNodeExecution(
id=db_model.id,
node_execution_id=db_model.node_execution_id,
workflow_id=db_model.workflow_id,
@ -130,15 +161,52 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
finished_at=db_model.finished_at,
)
def to_db_model(self, domain_model: WorkflowNodeExecution) -> WorkflowNodeExecutionModel:
if not db_model.offload_data:
return domain_model
offload_data = db_model.offload_data
# Store truncated versions for API responses
# TODO: consider loading the offloaded content concurrently.
input_offload = _find_first(offload_data, _filter_by_offload_type(ExecutionOffLoadType.INPUTS))
if input_offload is not None:
assert input_offload.file is not None
domain_model.inputs = self._load_file(input_offload.file)
domain_model.set_truncated_inputs(inputs)
outputs_offload = _find_first(offload_data, _filter_by_offload_type(ExecutionOffLoadType.OUTPUTS))
if outputs_offload is not None:
assert outputs_offload.file is not None
domain_model.outputs = self._load_file(outputs_offload.file)
domain_model.set_truncated_outputs(outputs)
process_data_offload = _find_first(offload_data, _filter_by_offload_type(ExecutionOffLoadType.PROCESS_DATA))
if process_data_offload is not None:
assert process_data_offload.file is not None
domain_model.process_data = self._load_file(process_data_offload.file)
domain_model.set_truncated_process_data(process_data)
return domain_model
def _load_file(self, file: UploadFile) -> Mapping[str, Any]:
content = storage.load(file.key)
return json.loads(content)
@staticmethod
def _json_encode(values: Mapping[str, Any]) -> str:
json_converter = WorkflowRuntimeTypeConverter()
return json.dumps(json_converter.to_json_encodable(values))
def _to_db_model(self, domain_model: WorkflowNodeExecution) -> WorkflowNodeExecutionModel:
"""
Convert a domain model to a database model.
Convert a domain model to a database model. This copies the inputs /
process_data / outputs from the domain model directly, without applying truncation.
Args:
domain_model: The domain model to convert
Returns:
The database model
The database model, without setting inputs, process_data and outputs fields.
"""
# Use values from constructor if provided
if not self._triggered_from:
@ -148,7 +216,9 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
if not self._creator_user_role:
raise ValueError("created_by_role is required in repository constructor")
json_converter = WorkflowRuntimeTypeConverter()
converter = WorkflowRuntimeTypeConverter()
db_model = WorkflowNodeExecutionModel()
db_model.id = domain_model.id
db_model.tenant_id = self._tenant_id
@ -164,16 +234,21 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
db_model.node_type = domain_model.node_type
db_model.title = domain_model.title
db_model.inputs = (
json.dumps(json_converter.to_json_encodable(domain_model.inputs)) if domain_model.inputs else None
_deterministic_json_dump(converter.to_json_encodable(domain_model.inputs))
if domain_model.inputs is not None
else None
)
db_model.process_data = (
json.dumps(json_converter.to_json_encodable(domain_model.process_data))
if domain_model.process_data
_deterministic_json_dump(converter.to_json_encodable(domain_model.process_data))
if domain_model.process_data is not None
else None
)
db_model.outputs = (
json.dumps(json_converter.to_json_encodable(domain_model.outputs)) if domain_model.outputs else None
_deterministic_json_dump(converter.to_json_encodable(domain_model.outputs))
if domain_model.outputs is not None
else None
)
# Truncation and offloading of large inputs, process_data and outputs are handled in save_execution_data
db_model.status = domain_model.status
db_model.error = domain_model.error
db_model.elapsed_time = domain_model.elapsed_time
@ -184,17 +259,59 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
db_model.created_by_role = self._creator_user_role
db_model.created_by = self._creator_user_id
db_model.finished_at = domain_model.finished_at
return db_model
def _truncate_and_upload(
self,
values: Mapping[str, Any] | None,
execution_id: str,
type_: ExecutionOffLoadType,
) -> _InputsOutputsTruncationResult | None:
if values is None:
return None
converter = WorkflowRuntimeTypeConverter()
json_encodable_value = converter.to_json_encodable(values)
truncator = self._create_truncator()
truncated_values, truncated = truncator.truncate_io_mapping(json_encodable_value)
if not truncated:
return None
value_json = _deterministic_json_dump(json_encodable_value)
assert value_json is not None, "value_json should not be None here."
suffix = type_.value
upload_file = self._file_service.upload_file(
filename=f"node_execution_{execution_id}_{suffix}.json",
content=value_json.encode("utf-8"),
mimetype="application/json",
user=self._user,
)
offload = WorkflowNodeExecutionOffload(
id=uuidv7(),
tenant_id=self._tenant_id,
app_id=self._app_id,
node_execution_id=execution_id,
type_=type_,
file_id=upload_file.id,
)
return _InputsOutputsTruncationResult(
truncated_value=truncated_values,
file=upload_file,
offload=offload,
)
def save(self, execution: WorkflowNodeExecution) -> None:
"""
Save or update a NodeExecution domain entity to the database.
This method serves as a domain-to-database adapter that:
1. Converts the domain entity to its database representation
2. Persists the database model using SQLAlchemy's merge operation
3. Maintains proper multi-tenancy by including tenant context during conversion
4. Updates the in-memory cache for faster subsequent lookups
2. Handles truncation and offloading of large inputs/outputs
3. Persists the database model using SQLAlchemy's merge operation
4. Maintains proper multi-tenancy by including tenant context during conversion
5. Updates the in-memory cache for faster subsequent lookups
The method handles both creating new records and updating existing ones through
SQLAlchemy's merge operation.
@ -202,8 +319,20 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
Args:
execution: The NodeExecution domain entity to persist
"""
# NOTE: As per the implementation of `WorkflowCycleManager`,
# the `save` method is invoked multiple times during the node's execution lifecycle, including:
#
# - When the node starts execution
# - When the node retries execution
# - When the node completes execution (either successfully or with failure)
#
# Only the final invocation will have `inputs` and `outputs` populated.
#
# This simplifies the logic for saving offloaded variables but introduces a tight coupling
# between this module and `WorkflowCycleManager`.
# Convert domain model to database model using tenant context and other attributes
db_model = self.to_db_model(execution)
db_model = self._to_db_model(execution)
# Create a new database session
with self._session_factory() as session:
@ -218,6 +347,66 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
logger.debug("Updating cache for node_execution_id: %s", db_model.node_execution_id)
self._node_execution_cache[db_model.node_execution_id] = db_model
def save_execution_data(self, execution: WorkflowNodeExecution):
domain_model = execution
with self._session_factory(expire_on_commit=False) as session:
query = WorkflowNodeExecutionModel.preload_offload_data(select(WorkflowNodeExecutionModel)).where(
WorkflowNodeExecutionModel.id == domain_model.id
)
db_model: WorkflowNodeExecutionModel | None = session.execute(query).scalars().first()
if db_model is not None:
offload_data = db_model.offload_data
else:
db_model = self._to_db_model(domain_model)
offload_data = []
if domain_model.inputs is not None:
result = self._truncate_and_upload(
domain_model.inputs,
domain_model.id,
ExecutionOffLoadType.INPUTS,
)
if result is not None:
db_model.inputs = self._json_encode(result.truncated_value)
domain_model.set_truncated_inputs(result.truncated_value)
offload_data = _replace_or_append_offload(offload_data, result.offload)
else:
db_model.inputs = self._json_encode(domain_model.inputs)
if domain_model.outputs is not None:
result = self._truncate_and_upload(
domain_model.outputs,
domain_model.id,
ExecutionOffLoadType.OUTPUTS,
)
if result is not None:
db_model.outputs = self._json_encode(result.truncated_value)
domain_model.set_truncated_outputs(result.truncated_value)
offload_data = _replace_or_append_offload(offload_data, result.offload)
else:
db_model.outputs = self._json_encode(domain_model.outputs)
if domain_model.process_data is not None:
result = self._truncate_and_upload(
domain_model.process_data,
domain_model.id,
ExecutionOffLoadType.PROCESS_DATA,
)
if result is not None:
db_model.process_data = self._json_encode(result.truncated_value)
domain_model.set_truncated_process_data(result.truncated_value)
offload_data = _replace_or_append_offload(offload_data, result.offload)
else:
db_model.process_data = self._json_encode(domain_model.process_data)
db_model.offload_data = offload_data
with self._session_factory() as session, session.begin():
session.merge(db_model)
session.flush()
def get_db_models_by_workflow_run(
self,
workflow_run_id: str,
@ -226,6 +415,9 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
"""
Retrieve all WorkflowNodeExecution database models for a specific workflow run.
The returned models have `offload_data` preloaded, along with the associated
`inputs_file` and `outputs_file` data.
This method directly returns database models without converting to domain models,
which is useful when you need to access database-specific fields like triggered_from.
It also updates the in-memory cache with the retrieved models.
@ -240,7 +432,8 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
A list of WorkflowNodeExecution database models
"""
with self._session_factory() as session:
stmt = select(WorkflowNodeExecutionModel).where(
stmt = WorkflowNodeExecutionModel.preload_offload_data_and_files(select(WorkflowNodeExecutionModel))
stmt = stmt.where(
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
WorkflowNodeExecutionModel.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
@ -296,10 +489,46 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
# Get the database models using the new method
db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config)
# Convert database models to domain models
domain_models = []
for model in db_models:
domain_model = self._to_domain_model(model)
domain_models.append(domain_model)
with ThreadPoolExecutor(max_workers=10) as executor:
domain_models = executor.map(self._to_domain_model, db_models, timeout=30)
return domain_models
return list(domain_models)
def _deterministic_json_dump(value: Mapping[str, Any]) -> str:
return json.dumps(value, sort_keys=True)
_T = TypeVar("_T")
def _find_first(seq: Sequence[_T], pred: Callable[[_T], bool]) -> _T | None:
filtered = [i for i in seq if pred(i)]
if filtered:
return filtered[0]
return None
def _filter_by_offload_type(offload_type: ExecutionOffLoadType) -> Callable[[WorkflowNodeExecutionOffload], bool]:
def f(offload: WorkflowNodeExecutionOffload) -> bool:
return offload.type_ == offload_type
return f
def _replace_or_append_offload(
seq: list[WorkflowNodeExecutionOffload], elem: WorkflowNodeExecutionOffload
) -> list[WorkflowNodeExecutionOffload]:
"""Replace all elements in `seq` that satisfy the equality condition defined by `eq_func` with `elem`.
Args:
seq: The sequence of elements to process.
elem: The new element to insert.
eq_func: A function that determines equality between elements.
Returns:
A new sequence with the specified elements replaced or appended.
"""
kept = [i for i in seq if i.type_ != elem.type_]
kept.append(elem)
return kept

View File

@ -11,7 +11,7 @@ from datetime import datetime
from enum import StrEnum
from typing import Any, Optional
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, PrivateAttr
from core.workflow.nodes.enums import NodeType
@ -90,6 +90,7 @@ class WorkflowNodeExecution(BaseModel):
title: str # Display title of the node
# Execution data
# The `inputs` and `outputs` fields hold the full content
inputs: Optional[Mapping[str, Any]] = None # Input variables used by this node
process_data: Optional[Mapping[str, Any]] = None # Intermediate processing data
outputs: Optional[Mapping[str, Any]] = None # Output variables produced by this node
@ -106,6 +107,58 @@ class WorkflowNodeExecution(BaseModel):
created_at: datetime # When execution started
finished_at: Optional[datetime] = None # When execution completed
_truncated_inputs: Mapping[str, Any] | None = PrivateAttr(None)
_truncated_outputs: Mapping[str, Any] | None = PrivateAttr(None)
_truncated_process_data: Mapping[str, Any] | None = PrivateAttr(None)
def get_truncated_inputs(self) -> Mapping[str, Any] | None:
return self._truncated_inputs
def get_truncated_outputs(self) -> Mapping[str, Any] | None:
return self._truncated_outputs
def get_truncated_process_data(self) -> Mapping[str, Any] | None:
return self._truncated_process_data
def set_truncated_inputs(self, truncated_inputs: Mapping[str, Any] | None):
self._truncated_inputs = truncated_inputs
def set_truncated_outputs(self, truncated_outputs: Mapping[str, Any] | None):
self._truncated_outputs = truncated_outputs
def set_truncated_process_data(self, truncated_process_data: Mapping[str, Any] | None):
self._truncated_process_data = truncated_process_data
def get_response_inputs(self) -> Mapping[str, Any] | None:
inputs = self.get_truncated_inputs()
if inputs is not None:
return inputs
return self.inputs
@property
def inputs_truncated(self):
return self._truncated_inputs is not None
@property
def outputs_truncated(self):
return self._truncated_outputs is not None
@property
def process_data_truncated(self):
return self._truncated_process_data is not None
def get_response_outputs(self) -> Mapping[str, Any] | None:
outputs = self.get_truncated_outputs()
if outputs is not None:
return outputs
return self.outputs
def get_response_process_data(self) -> Mapping[str, Any] | None:
process_data = self.get_truncated_process_data()
if process_data is not None:
return process_data
return self.process_data
def update_from_mapping(
self,
inputs: Optional[Mapping[str, Any]] = None,
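The truncated copies live in pydantic private attributes, so they never leak through normal model serialization, and the `get_response_*` accessors fall back to the full value when nothing was truncated. A condensed, self-contained sketch of the same pattern (field names mirror the diff; the class body is heavily abbreviated):

from collections.abc import Mapping
from typing import Any, Optional

from pydantic import BaseModel, PrivateAttr

class NodeExecutionSketch(BaseModel):
    outputs: Optional[Mapping[str, Any]] = None
    _truncated_outputs: Mapping[str, Any] | None = PrivateAttr(default=None)

    def set_truncated_outputs(self, value: Mapping[str, Any] | None) -> None:
        self._truncated_outputs = value

    @property
    def outputs_truncated(self) -> bool:
        return self._truncated_outputs is not None

    def get_response_outputs(self) -> Mapping[str, Any] | None:
        # Prefer the truncated copy for API responses; fall back to the full value.
        return self._truncated_outputs if self._truncated_outputs is not None else self.outputs

execution = NodeExecutionSketch(outputs={"text": "x" * 10_000})
execution.set_truncated_outputs({"text": "x" * 100 + "..."})
assert execution.outputs_truncated
assert execution.get_response_outputs() == {"text": "x" * 100 + "..."}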

View File

@ -30,6 +30,12 @@ class WorkflowNodeExecutionRepository(Protocol):
"""
Save or update a NodeExecution instance.
This method saves all data on the `WorkflowNodeExecution` object, except for `inputs`, `process_data`,
and `outputs`. Its primary purpose is to persist the status and various metadata, such as execution time
and execution-related details.
This method handles both creating new records and updating existing ones.
The implementation should determine whether to create or update based on
the execution's ID or other identifying fields.
@ -39,6 +45,14 @@ class WorkflowNodeExecutionRepository(Protocol):
"""
...
def save_execution_data(self, execution: WorkflowNodeExecution):
"""Save or update the inputs, process_data, or outputs associated with a specific
node_execution record.
If any of the inputs, process_data, or outputs are None, those fields will not be updated.
"""
...
def get_by_workflow_run(
self,
workflow_run_id: str,
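Taken together, the protocol now splits persistence in two: `save` records status and metadata on every lifecycle event, while `save_execution_data` persists the potentially large payloads once they are available, applying truncation and offloading. A condensed illustration of the expected calling sequence (the protocol below is a cut-down stand-in, not the full interface):

from typing import Protocol

class NodeExecutionRepo(Protocol):
    def save(self, execution) -> None: ...
    def save_execution_data(self, execution) -> None: ...

def finish_node(repo: NodeExecutionRepo, execution) -> None:
    repo.save(execution)                 # status, timing, metadata
    repo.save_execution_data(execution)  # inputs / process_data / outputs, truncated and offloaded as needed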

View File

@ -188,6 +188,7 @@ class WorkflowCycleManager:
)
self._workflow_node_execution_repository.save(domain_execution)
self._workflow_node_execution_repository.save_execution_data(domain_execution)
return domain_execution
def handle_workflow_node_execution_failed(
@ -220,6 +221,7 @@ class WorkflowCycleManager:
)
self._workflow_node_execution_repository.save(domain_execution)
self._workflow_node_execution_repository.save_execution_data(domain_execution)
return domain_execution
def handle_workflow_node_execution_retried(
@ -242,7 +244,9 @@ class WorkflowCycleManager:
domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=metadata)
return self._save_and_cache_node_execution(domain_execution)
execution = self._save_and_cache_node_execution(domain_execution)
self._workflow_node_execution_repository.save_execution_data(execution)
return execution
def _get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution:
# Check cache first
@ -275,7 +279,10 @@ class WorkflowCycleManager:
return execution
def _save_and_cache_node_execution(self, execution: WorkflowNodeExecution) -> WorkflowNodeExecution:
"""Save node execution to repository and cache it if it has an ID."""
"""Save node execution to repository and cache it if it has an ID.
This does not persist the `inputs` / `process_data` / `outputs` fields of the execution model.
"""
self._workflow_node_execution_repository.save(execution)
if execution.node_execution_id:
self._node_execution_cache[execution.node_execution_id] = execution

View File

@ -1,6 +1,6 @@
from collections.abc import Mapping
from decimal import Decimal
from typing import Any
from typing import Any, overload
from pydantic import BaseModel
@ -9,6 +9,11 @@ from core.variables import Segment
class WorkflowRuntimeTypeConverter:
@overload
def to_json_encodable(self, value: Mapping[str, Any]) -> Mapping[str, Any]: ...
@overload
def to_json_encodable(self, value: None) -> None: ...
def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
result = self._to_json_encodable_recursive(value)
return result if isinstance(result, Mapping) or result is None else dict(result)
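The two `@overload` declarations give callers exact None-in/None-out typing: passing a Mapping is known to return a Mapping, and passing None returns None, without widening every result to Optional. A minimal standalone illustration of the same pattern (not the Dify class itself):

from collections.abc import Mapping
from typing import Any, overload

@overload
def to_encodable(value: Mapping[str, Any]) -> Mapping[str, Any]: ...
@overload
def to_encodable(value: None) -> None: ...
def to_encodable(value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
    if value is None:
        return None
    return dict(value)  # placeholder for the recursive conversion

encodable = to_encodable({"a": 1})  # type checkers infer Mapping[str, Any]
nothing = to_encodable(None)        # inferred as None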

View File

@ -116,6 +116,9 @@ workflow_run_node_execution_fields = {
"created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
"created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True),
"finished_at": TimestampField,
"inputs_truncated": fields.Boolean,
"outputs_truncated": fields.Boolean,
"process_data_truncated": fields.Boolean,
}
workflow_run_node_execution_list_fields = {

View File

@ -63,11 +63,14 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
node_id: The node identifier
Returns:
The most recent WorkflowNodeExecutionModel for the node, or None if not found
The most recent WorkflowNodeExecutionModel for the node, or None if not found.
The returned WorkflowNodeExecutionModel will have `offload_data` preloaded.
"""
stmt = select(WorkflowNodeExecutionModel)
stmt = WorkflowNodeExecutionModel.preload_offload_data(stmt)
stmt = (
select(WorkflowNodeExecutionModel)
.where(
stmt.where(
WorkflowNodeExecutionModel.tenant_id == tenant_id,
WorkflowNodeExecutionModel.app_id == app_id,
WorkflowNodeExecutionModel.workflow_id == workflow_id,
@ -100,15 +103,12 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
Returns:
A sequence of WorkflowNodeExecutionModel instances ordered by index (desc)
"""
stmt = (
select(WorkflowNodeExecutionModel)
.where(
WorkflowNodeExecutionModel.tenant_id == tenant_id,
WorkflowNodeExecutionModel.app_id == app_id,
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
)
.order_by(desc(WorkflowNodeExecutionModel.index))
)
stmt = WorkflowNodeExecutionModel.preload_offload_data(select(WorkflowNodeExecutionModel))
stmt = stmt.where(
WorkflowNodeExecutionModel.tenant_id == tenant_id,
WorkflowNodeExecutionModel.app_id == app_id,
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
).order_by(desc(WorkflowNodeExecutionModel.index))
with self._session_maker() as session:
return session.execute(stmt).scalars().all()
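These query paths now funnel the statement through `WorkflowNodeExecutionModel.preload_offload_data`, so the offload records come back eagerly instead of triggering lazy loads after the session closes. The helper itself is defined on the model and is not shown in this diff; one plausible shape, assuming SQLAlchemy's selectinload, is sketched below (the companion `preload_offload_data_and_files` used in the repository above would additionally chain a load of the linked files):

from sqlalchemy import Select
from sqlalchemy.orm import selectinload

from models import WorkflowNodeExecutionModel

# Hypothetical sketch of the classmethod used in the statements above.
def preload_offload_data(stmt: Select) -> Select:
    return stmt.options(selectinload(WorkflowNodeExecutionModel.offload_data))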
@ -135,7 +135,8 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
Returns:
The WorkflowNodeExecutionModel if found, or None if not found
"""
stmt = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution_id)
stmt = WorkflowNodeExecutionModel.preload_offload_data(select(WorkflowNodeExecutionModel))
stmt = stmt.where(WorkflowNodeExecutionModel.id == execution_id)
# Add tenant filtering if provided
if tenant_id is not None:

View File

@ -31,6 +31,7 @@ from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
from extensions.ext_database import db
from extensions.ext_storage import storage
from factories.file_factory import build_from_mapping, build_from_mappings
from libs.datetime_utils import naive_utc_now
from models.account import Account
@ -425,6 +426,9 @@ class WorkflowService:
if workflow_node_execution is None:
raise ValueError(f"WorkflowNodeExecution with id {node_execution.id} not found after saving")
with Session(db.engine) as session:
outputs = workflow_node_execution.load_full_outputs(session, storage)
with Session(bind=db.engine) as session, session.begin():
draft_var_saver = DraftVariableSaver(
session=session,
@ -435,7 +439,7 @@ class WorkflowService:
node_execution_id=node_execution.id,
user=account,
)
draft_var_saver.save(process_data=node_execution.process_data, outputs=node_execution.outputs)
draft_var_saver.save(process_data=node_execution.process_data, outputs=outputs)
session.commit()
return workflow_node_execution