mirror of
https://github.com/langgenius/dify.git
synced 2026-04-16 02:16:57 +08:00
test: migrate SQLAlchemyWorkflowNodeExecutionRepository tests to testcontainers (#34926)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
1703df5c00
commit
98d3bcd079
@ -0,0 +1,395 @@
|
||||
"""Testcontainers integration tests for SQLAlchemyWorkflowNodeExecutionRepository."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from uuid import uuid4
|
||||
|
||||
from graphon.entities import WorkflowNodeExecution
|
||||
from graphon.enums import (
|
||||
BuiltinNodeTypes,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.repositories.factory import OrderConfig
|
||||
from models.account import Account, Tenant
|
||||
from models.enums import CreatorUserRole
|
||||
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
|
||||
def _create_account_with_tenant(session: Session) -> Account:
|
||||
tenant = Tenant(name="Test Workspace")
|
||||
session.add(tenant)
|
||||
session.flush()
|
||||
|
||||
account = Account(name="test", email=f"test-{uuid4()}@example.com")
|
||||
session.add(account)
|
||||
session.flush()
|
||||
|
||||
account._current_tenant = tenant
|
||||
return account
|
||||
|
||||
|
||||
def _make_repo(session: Session, account: Account, app_id: str) -> SQLAlchemyWorkflowNodeExecutionRepository:
|
||||
engine = session.get_bind()
|
||||
assert isinstance(engine, Engine)
|
||||
return SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=sessionmaker(bind=engine, expire_on_commit=False),
|
||||
user=account,
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
|
||||
def _create_node_execution_model(
|
||||
session: Session,
|
||||
*,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
workflow_id: str,
|
||||
workflow_run_id: str,
|
||||
index: int = 1,
|
||||
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING,
|
||||
) -> WorkflowNodeExecutionModel:
|
||||
model = WorkflowNodeExecutionModel(
|
||||
id=str(uuid4()),
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_id=workflow_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
workflow_run_id=workflow_run_id,
|
||||
index=index,
|
||||
predecessor_node_id=None,
|
||||
node_execution_id=str(uuid4()),
|
||||
node_id=f"node-{index}",
|
||||
node_type=BuiltinNodeTypes.START,
|
||||
title=f"Test Node {index}",
|
||||
inputs='{"input_key": "input_value"}',
|
||||
process_data='{"process_key": "process_value"}',
|
||||
outputs='{"output_key": "output_value"}',
|
||||
status=status,
|
||||
error=None,
|
||||
elapsed_time=1.5,
|
||||
execution_metadata="{}",
|
||||
created_at=datetime.now(),
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
finished_at=None,
|
||||
)
|
||||
session.add(model)
|
||||
session.flush()
|
||||
return model
|
||||
|
||||
|
||||
class TestSave:
|
||||
def test_save_new_record(self, db_session_with_containers: Session) -> None:
|
||||
account = _create_account_with_tenant(db_session_with_containers)
|
||||
app_id = str(uuid4())
|
||||
repo = _make_repo(db_session_with_containers, account, app_id)
|
||||
|
||||
execution = WorkflowNodeExecution(
|
||||
id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
node_execution_id=str(uuid4()),
|
||||
workflow_execution_id=str(uuid4()),
|
||||
index=1,
|
||||
predecessor_node_id=None,
|
||||
node_id="node-1",
|
||||
node_type=BuiltinNodeTypes.START,
|
||||
title="Test Node",
|
||||
inputs={"input_key": "input_value"},
|
||||
process_data={"process_key": "process_value"},
|
||||
outputs={"result": "success"},
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
error=None,
|
||||
elapsed_time=1.5,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100},
|
||||
created_at=datetime.now(),
|
||||
finished_at=None,
|
||||
)
|
||||
|
||||
repo.save(execution)
|
||||
|
||||
engine = db_session_with_containers.get_bind()
|
||||
assert isinstance(engine, Engine)
|
||||
with sessionmaker(bind=engine, expire_on_commit=False)() as verify_session:
|
||||
saved = verify_session.get(WorkflowNodeExecutionModel, execution.id)
|
||||
assert saved is not None
|
||||
assert saved.tenant_id == account.current_tenant_id
|
||||
assert saved.app_id == app_id
|
||||
assert saved.node_id == "node-1"
|
||||
assert saved.status == WorkflowNodeExecutionStatus.RUNNING
|
||||
|
||||
def test_save_updates_existing_record(self, db_session_with_containers: Session) -> None:
|
||||
account = _create_account_with_tenant(db_session_with_containers)
|
||||
repo = _make_repo(db_session_with_containers, account, str(uuid4()))
|
||||
|
||||
execution = WorkflowNodeExecution(
|
||||
id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
node_execution_id=str(uuid4()),
|
||||
workflow_execution_id=str(uuid4()),
|
||||
index=1,
|
||||
predecessor_node_id=None,
|
||||
node_id="node-1",
|
||||
node_type=BuiltinNodeTypes.START,
|
||||
title="Test Node",
|
||||
inputs=None,
|
||||
process_data=None,
|
||||
outputs=None,
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
error=None,
|
||||
elapsed_time=0.0,
|
||||
metadata=None,
|
||||
created_at=datetime.now(),
|
||||
finished_at=None,
|
||||
)
|
||||
|
||||
repo.save(execution)
|
||||
|
||||
execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
execution.elapsed_time = 2.5
|
||||
repo.save(execution)
|
||||
|
||||
engine = db_session_with_containers.get_bind()
|
||||
assert isinstance(engine, Engine)
|
||||
with sessionmaker(bind=engine, expire_on_commit=False)() as verify_session:
|
||||
saved = verify_session.get(WorkflowNodeExecutionModel, execution.id)
|
||||
assert saved is not None
|
||||
assert saved.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert saved.elapsed_time == 2.5
|
||||
|
||||
|
||||
class TestGetByWorkflowExecution:
|
||||
def test_returns_executions_ordered(self, db_session_with_containers: Session) -> None:
|
||||
account = _create_account_with_tenant(db_session_with_containers)
|
||||
tenant_id = account.current_tenant_id
|
||||
app_id = str(uuid4())
|
||||
workflow_id = str(uuid4())
|
||||
workflow_run_id = str(uuid4())
|
||||
repo = _make_repo(db_session_with_containers, account, app_id)
|
||||
|
||||
_create_node_execution_model(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_id=workflow_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
index=1,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
)
|
||||
_create_node_execution_model(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_id=workflow_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
index=2,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
order_config = OrderConfig(order_by=["index"], order_direction="desc")
|
||||
result = repo.get_by_workflow_execution(
|
||||
workflow_execution_id=workflow_run_id,
|
||||
order_config=order_config,
|
||||
)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0].index == 2
|
||||
assert result[1].index == 1
|
||||
assert all(isinstance(r, WorkflowNodeExecution) for r in result)
|
||||
|
||||
def test_excludes_paused_executions(self, db_session_with_containers: Session) -> None:
|
||||
account = _create_account_with_tenant(db_session_with_containers)
|
||||
tenant_id = account.current_tenant_id
|
||||
app_id = str(uuid4())
|
||||
workflow_id = str(uuid4())
|
||||
workflow_run_id = str(uuid4())
|
||||
repo = _make_repo(db_session_with_containers, account, app_id)
|
||||
|
||||
_create_node_execution_model(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_id=workflow_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
index=1,
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
)
|
||||
_create_node_execution_model(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_id=workflow_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
index=2,
|
||||
status=WorkflowNodeExecutionStatus.PAUSED,
|
||||
)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
result = repo.get_by_workflow_execution(workflow_execution_id=workflow_run_id)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].index == 1
|
||||
|
||||
|
||||
class TestToDbModel:
|
||||
def test_converts_domain_to_db_model(self, db_session_with_containers: Session) -> None:
|
||||
account = _create_account_with_tenant(db_session_with_containers)
|
||||
app_id = str(uuid4())
|
||||
repo = _make_repo(db_session_with_containers, account, app_id)
|
||||
|
||||
domain_model = WorkflowNodeExecution(
|
||||
id="test-id",
|
||||
workflow_id="test-workflow-id",
|
||||
node_execution_id="test-node-execution-id",
|
||||
workflow_execution_id="test-workflow-run-id",
|
||||
index=1,
|
||||
predecessor_node_id="test-predecessor-id",
|
||||
node_id="test-node-id",
|
||||
node_type=BuiltinNodeTypes.START,
|
||||
title="Test Node",
|
||||
inputs={"input_key": "input_value"},
|
||||
process_data={"process_key": "process_value"},
|
||||
outputs={"output_key": "output_value"},
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
error=None,
|
||||
elapsed_time=1.5,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: Decimal("0.0"),
|
||||
},
|
||||
created_at=datetime.now(),
|
||||
finished_at=None,
|
||||
)
|
||||
|
||||
db_model = repo._to_db_model(domain_model)
|
||||
|
||||
assert isinstance(db_model, WorkflowNodeExecutionModel)
|
||||
assert db_model.id == domain_model.id
|
||||
assert db_model.tenant_id == account.current_tenant_id
|
||||
assert db_model.app_id == app_id
|
||||
assert db_model.workflow_id == domain_model.workflow_id
|
||||
assert db_model.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
|
||||
assert db_model.workflow_run_id == domain_model.workflow_execution_id
|
||||
assert db_model.index == domain_model.index
|
||||
assert db_model.predecessor_node_id == domain_model.predecessor_node_id
|
||||
assert db_model.node_execution_id == domain_model.node_execution_id
|
||||
assert db_model.node_id == domain_model.node_id
|
||||
assert db_model.node_type == domain_model.node_type
|
||||
assert db_model.title == domain_model.title
|
||||
assert db_model.inputs_dict == domain_model.inputs
|
||||
assert db_model.process_data_dict == domain_model.process_data
|
||||
assert db_model.outputs_dict == domain_model.outputs
|
||||
assert db_model.execution_metadata_dict == jsonable_encoder(domain_model.metadata)
|
||||
assert db_model.status == domain_model.status
|
||||
assert db_model.error == domain_model.error
|
||||
assert db_model.elapsed_time == domain_model.elapsed_time
|
||||
assert db_model.created_at == domain_model.created_at
|
||||
assert db_model.created_by_role == CreatorUserRole.ACCOUNT
|
||||
assert db_model.created_by == account.id
|
||||
assert db_model.finished_at == domain_model.finished_at
|
||||
|
||||
|
||||
class TestToDomainModel:
|
||||
def test_converts_db_to_domain_model(self, db_session_with_containers: Session) -> None:
|
||||
account = _create_account_with_tenant(db_session_with_containers)
|
||||
app_id = str(uuid4())
|
||||
repo = _make_repo(db_session_with_containers, account, app_id)
|
||||
|
||||
inputs_dict = {"input_key": "input_value"}
|
||||
process_data_dict = {"process_key": "process_value"}
|
||||
outputs_dict = {"output_key": "output_value"}
|
||||
metadata_dict = {str(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS): 100}
|
||||
now = datetime.now()
|
||||
|
||||
db_model = WorkflowNodeExecutionModel()
|
||||
db_model.id = "test-id"
|
||||
db_model.tenant_id = account.current_tenant_id
|
||||
db_model.app_id = app_id
|
||||
db_model.workflow_id = "test-workflow-id"
|
||||
db_model.triggered_from = "workflow-run"
|
||||
db_model.workflow_run_id = "test-workflow-run-id"
|
||||
db_model.index = 1
|
||||
db_model.predecessor_node_id = "test-predecessor-id"
|
||||
db_model.node_execution_id = "test-node-execution-id"
|
||||
db_model.node_id = "test-node-id"
|
||||
db_model.node_type = BuiltinNodeTypes.START
|
||||
db_model.title = "Test Node"
|
||||
db_model.inputs = json.dumps(inputs_dict)
|
||||
db_model.process_data = json.dumps(process_data_dict)
|
||||
db_model.outputs = json.dumps(outputs_dict)
|
||||
db_model.status = WorkflowNodeExecutionStatus.RUNNING
|
||||
db_model.error = None
|
||||
db_model.elapsed_time = 1.5
|
||||
db_model.execution_metadata = json.dumps(metadata_dict)
|
||||
db_model.created_at = now
|
||||
db_model.created_by_role = "account"
|
||||
db_model.created_by = account.id
|
||||
db_model.finished_at = None
|
||||
|
||||
domain_model = repo._to_domain_model(db_model)
|
||||
|
||||
assert isinstance(domain_model, WorkflowNodeExecution)
|
||||
assert domain_model.id == "test-id"
|
||||
assert domain_model.workflow_id == "test-workflow-id"
|
||||
assert domain_model.workflow_execution_id == "test-workflow-run-id"
|
||||
assert domain_model.index == 1
|
||||
assert domain_model.predecessor_node_id == "test-predecessor-id"
|
||||
assert domain_model.node_execution_id == "test-node-execution-id"
|
||||
assert domain_model.node_id == "test-node-id"
|
||||
assert domain_model.node_type == BuiltinNodeTypes.START
|
||||
assert domain_model.title == "Test Node"
|
||||
assert domain_model.inputs == inputs_dict
|
||||
assert domain_model.process_data == process_data_dict
|
||||
assert domain_model.outputs == outputs_dict
|
||||
assert domain_model.status == WorkflowNodeExecutionStatus.RUNNING
|
||||
assert domain_model.error is None
|
||||
assert domain_model.elapsed_time == 1.5
|
||||
assert domain_model.metadata == {WorkflowNodeExecutionMetadataKey(k): v for k, v in metadata_dict.items()}
|
||||
assert domain_model.created_at == now
|
||||
assert domain_model.finished_at is None
|
||||
|
||||
def test_domain_model_without_offload_data(self, db_session_with_containers: Session) -> None:
|
||||
account = _create_account_with_tenant(db_session_with_containers)
|
||||
repo = _make_repo(db_session_with_containers, account, str(uuid4()))
|
||||
|
||||
process_data = {"normal": "data"}
|
||||
db_model = WorkflowNodeExecutionModel()
|
||||
db_model.id = str(uuid4())
|
||||
db_model.tenant_id = account.current_tenant_id
|
||||
db_model.app_id = str(uuid4())
|
||||
db_model.workflow_id = str(uuid4())
|
||||
db_model.triggered_from = "workflow-run"
|
||||
db_model.workflow_run_id = None
|
||||
db_model.index = 1
|
||||
db_model.predecessor_node_id = None
|
||||
db_model.node_execution_id = str(uuid4())
|
||||
db_model.node_id = "test-node-id"
|
||||
db_model.node_type = "llm"
|
||||
db_model.title = "Test Node"
|
||||
db_model.inputs = None
|
||||
db_model.process_data = json.dumps(process_data)
|
||||
db_model.outputs = None
|
||||
db_model.status = "succeeded"
|
||||
db_model.error = None
|
||||
db_model.elapsed_time = 1.5
|
||||
db_model.execution_metadata = "{}"
|
||||
db_model.created_at = datetime.now()
|
||||
db_model.created_by_role = "account"
|
||||
db_model.created_by = account.id
|
||||
db_model.finished_at = None
|
||||
|
||||
domain_model = repo._to_domain_model(db_model)
|
||||
|
||||
assert domain_model.process_data == process_data
|
||||
assert domain_model.process_data_truncated is False
|
||||
assert domain_model.get_truncated_process_data() is None
|
||||
@ -1,3 +0,0 @@
|
||||
"""
|
||||
Unit tests for workflow_node_execution repositories.
|
||||
"""
|
||||
@ -1,340 +0,0 @@
|
||||
"""
|
||||
Unit tests for the SQLAlchemy implementation of WorkflowNodeExecutionRepository.
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock, PropertyMock
|
||||
|
||||
import pytest
|
||||
from graphon.entities import (
|
||||
WorkflowNodeExecution,
|
||||
)
|
||||
from graphon.enums import (
|
||||
BuiltinNodeTypes,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.repositories.factory import OrderConfig
|
||||
from models.account import Account, Tenant
|
||||
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
|
||||
def configure_mock_execution(mock_execution):
|
||||
"""Configure a mock execution with proper JSON serializable values."""
|
||||
# Configure inputs, outputs, process_data, and execution_metadata to return JSON serializable values
|
||||
type(mock_execution).inputs = PropertyMock(return_value='{"key": "value"}')
|
||||
type(mock_execution).outputs = PropertyMock(return_value='{"result": "success"}')
|
||||
type(mock_execution).process_data = PropertyMock(return_value='{"process": "data"}')
|
||||
type(mock_execution).execution_metadata = PropertyMock(return_value='{"metadata": "info"}')
|
||||
|
||||
# Configure status and triggered_from to be valid enum values
|
||||
mock_execution.status = "running"
|
||||
mock_execution.triggered_from = "workflow-run"
|
||||
|
||||
return mock_execution
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session():
|
||||
"""Create a mock SQLAlchemy session."""
|
||||
session = MagicMock(spec=Session)
|
||||
# Configure the session to be used as a context manager
|
||||
session.__enter__ = MagicMock(return_value=session)
|
||||
session.__exit__ = MagicMock(return_value=None)
|
||||
|
||||
# Configure the session factory to return the session
|
||||
session_factory = MagicMock(spec=sessionmaker)
|
||||
session_factory.return_value = session
|
||||
return session, session_factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
"""Create a user instance for testing."""
|
||||
user = Account(name="test", email="test@example.com")
|
||||
user.id = "test-user-id"
|
||||
|
||||
tenant = Tenant(name="Test Workspace")
|
||||
tenant.id = "test-tenant"
|
||||
user._current_tenant = MagicMock()
|
||||
user._current_tenant.id = "test-tenant"
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def repository(session, mock_user):
|
||||
"""Create a repository instance with test data."""
|
||||
_, session_factory = session
|
||||
app_id = "test-app"
|
||||
return SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
user=mock_user,
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
|
||||
def test_save(repository, session):
|
||||
"""Test save method."""
|
||||
session_obj, _ = session
|
||||
# Create a mock execution
|
||||
execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
execution.id = "test-id"
|
||||
execution.node_execution_id = "test-node-execution-id"
|
||||
execution.tenant_id = None
|
||||
execution.app_id = None
|
||||
execution.inputs = None
|
||||
execution.process_data = None
|
||||
execution.outputs = None
|
||||
execution.metadata = None
|
||||
execution.workflow_id = str(uuid.uuid4())
|
||||
|
||||
# Mock the to_db_model method to return the execution itself
|
||||
# This simulates the behavior of setting tenant_id and app_id
|
||||
db_model = MagicMock(spec=WorkflowNodeExecutionModel)
|
||||
db_model.id = "test-id"
|
||||
db_model.node_execution_id = "test-node-execution-id"
|
||||
repository._to_db_model = MagicMock(return_value=db_model)
|
||||
|
||||
# Mock session.get to return None (no existing record)
|
||||
session_obj.get.return_value = None
|
||||
|
||||
# Call save method
|
||||
repository.save(execution)
|
||||
|
||||
# Assert to_db_model was called with the execution
|
||||
repository._to_db_model.assert_called_once_with(execution)
|
||||
|
||||
# Assert session.get was called to check for existing record
|
||||
session_obj.get.assert_called_once_with(WorkflowNodeExecutionModel, db_model.id)
|
||||
|
||||
# Assert session.add was called for new record
|
||||
session_obj.add.assert_called_once_with(db_model)
|
||||
|
||||
# Assert session.commit was called
|
||||
session_obj.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_save_with_existing_tenant_id(repository, session):
|
||||
"""Test save method with existing tenant_id."""
|
||||
session_obj, _ = session
|
||||
# Create a mock execution with existing tenant_id
|
||||
execution = MagicMock(spec=WorkflowNodeExecutionModel)
|
||||
execution.id = "existing-id"
|
||||
execution.node_execution_id = "existing-node-execution-id"
|
||||
execution.tenant_id = "existing-tenant"
|
||||
execution.app_id = None
|
||||
execution.inputs = None
|
||||
execution.process_data = None
|
||||
execution.outputs = None
|
||||
execution.metadata = None
|
||||
|
||||
# Create a modified execution that will be returned by _to_db_model
|
||||
modified_execution = MagicMock(spec=WorkflowNodeExecutionModel)
|
||||
modified_execution.id = "existing-id"
|
||||
modified_execution.node_execution_id = "existing-node-execution-id"
|
||||
modified_execution.tenant_id = "existing-tenant" # Tenant ID should not change
|
||||
modified_execution.app_id = repository._app_id # App ID should be set
|
||||
# Create a dictionary to simulate __dict__ for updating attributes
|
||||
modified_execution.__dict__ = {
|
||||
"id": "existing-id",
|
||||
"node_execution_id": "existing-node-execution-id",
|
||||
"tenant_id": "existing-tenant",
|
||||
"app_id": repository._app_id,
|
||||
}
|
||||
|
||||
# Mock the to_db_model method to return the modified execution
|
||||
repository._to_db_model = MagicMock(return_value=modified_execution)
|
||||
|
||||
# Mock session.get to return an existing record
|
||||
existing_model = MagicMock(spec=WorkflowNodeExecutionModel)
|
||||
session_obj.get.return_value = existing_model
|
||||
|
||||
# Call save method
|
||||
repository.save(execution)
|
||||
|
||||
# Assert to_db_model was called with the execution
|
||||
repository._to_db_model.assert_called_once_with(execution)
|
||||
|
||||
# Assert session.get was called to check for existing record
|
||||
session_obj.get.assert_called_once_with(WorkflowNodeExecutionModel, modified_execution.id)
|
||||
|
||||
# Assert session.add was NOT called since we're updating existing
|
||||
session_obj.add.assert_not_called()
|
||||
|
||||
# Assert session.commit was called
|
||||
session_obj.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_get_by_workflow_execution(repository, session, mocker: MockerFixture):
|
||||
"""Test get_by_workflow_execution method."""
|
||||
session_obj, _ = session
|
||||
# Set up mock
|
||||
mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select")
|
||||
mock_asc = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.asc")
|
||||
mock_desc = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.desc")
|
||||
|
||||
mock_WorkflowNodeExecutionModel = mocker.patch(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel"
|
||||
)
|
||||
mock_stmt = mocker.MagicMock()
|
||||
mock_select.return_value = mock_stmt
|
||||
mock_stmt.where.return_value = mock_stmt
|
||||
mock_stmt.order_by.return_value = mock_stmt
|
||||
mock_asc.return_value = mock_stmt
|
||||
mock_desc.return_value = mock_stmt
|
||||
mock_WorkflowNodeExecutionModel.preload_offload_data_and_files.return_value = mock_stmt
|
||||
|
||||
# Create a properly configured mock execution
|
||||
mock_execution = mocker.MagicMock(spec=WorkflowNodeExecutionModel)
|
||||
configure_mock_execution(mock_execution)
|
||||
session_obj.scalars.return_value.all.return_value = [mock_execution]
|
||||
|
||||
# Create a mock domain model to be returned by _to_domain_model
|
||||
mock_domain_model = mocker.MagicMock()
|
||||
# Mock the _to_domain_model method to return our mock domain model
|
||||
repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model)
|
||||
|
||||
# Call method
|
||||
order_config = OrderConfig(order_by=["index"], order_direction="desc")
|
||||
result = repository.get_by_workflow_execution(
|
||||
workflow_execution_id="test-workflow-run-id",
|
||||
order_config=order_config,
|
||||
)
|
||||
|
||||
# Assert select was called with correct parameters
|
||||
mock_select.assert_called_once()
|
||||
session_obj.scalars.assert_called_once_with(mock_stmt)
|
||||
mock_WorkflowNodeExecutionModel.preload_offload_data_and_files.assert_called_once_with(mock_stmt)
|
||||
# Assert _to_domain_model was called with the mock execution
|
||||
repository._to_domain_model.assert_called_once_with(mock_execution)
|
||||
# Assert the result contains our mock domain model
|
||||
assert len(result) == 1
|
||||
assert result[0] is mock_domain_model
|
||||
|
||||
|
||||
def test_to_db_model(repository):
|
||||
"""Test to_db_model method."""
|
||||
# Create a domain model
|
||||
domain_model = WorkflowNodeExecution(
|
||||
id="test-id",
|
||||
workflow_id="test-workflow-id",
|
||||
node_execution_id="test-node-execution-id",
|
||||
workflow_execution_id="test-workflow-run-id",
|
||||
index=1,
|
||||
predecessor_node_id="test-predecessor-id",
|
||||
node_id="test-node-id",
|
||||
node_type=BuiltinNodeTypes.START,
|
||||
title="Test Node",
|
||||
inputs={"input_key": "input_value"},
|
||||
process_data={"process_key": "process_value"},
|
||||
outputs={"output_key": "output_value"},
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
error=None,
|
||||
elapsed_time=1.5,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: Decimal("0.0"),
|
||||
},
|
||||
created_at=datetime.now(),
|
||||
finished_at=None,
|
||||
)
|
||||
|
||||
# Convert to DB model
|
||||
db_model = repository._to_db_model(domain_model)
|
||||
|
||||
# Assert DB model has correct values
|
||||
assert isinstance(db_model, WorkflowNodeExecutionModel)
|
||||
assert db_model.id == domain_model.id
|
||||
assert db_model.tenant_id == repository._tenant_id
|
||||
assert db_model.app_id == repository._app_id
|
||||
assert db_model.workflow_id == domain_model.workflow_id
|
||||
assert db_model.triggered_from == repository._triggered_from
|
||||
assert db_model.workflow_run_id == domain_model.workflow_execution_id
|
||||
assert db_model.index == domain_model.index
|
||||
assert db_model.predecessor_node_id == domain_model.predecessor_node_id
|
||||
assert db_model.node_execution_id == domain_model.node_execution_id
|
||||
assert db_model.node_id == domain_model.node_id
|
||||
assert db_model.node_type == domain_model.node_type
|
||||
assert db_model.title == domain_model.title
|
||||
|
||||
assert db_model.inputs_dict == domain_model.inputs
|
||||
assert db_model.process_data_dict == domain_model.process_data
|
||||
assert db_model.outputs_dict == domain_model.outputs
|
||||
assert db_model.execution_metadata_dict == jsonable_encoder(domain_model.metadata)
|
||||
|
||||
assert db_model.status == domain_model.status
|
||||
assert db_model.error == domain_model.error
|
||||
assert db_model.elapsed_time == domain_model.elapsed_time
|
||||
assert db_model.created_at == domain_model.created_at
|
||||
assert db_model.created_by_role == repository._creator_user_role
|
||||
assert db_model.created_by == repository._creator_user_id
|
||||
assert db_model.finished_at == domain_model.finished_at
|
||||
|
||||
|
||||
def test_to_domain_model(repository):
|
||||
"""Test _to_domain_model method."""
|
||||
# Create input dictionaries
|
||||
inputs_dict = {"input_key": "input_value"}
|
||||
process_data_dict = {"process_key": "process_value"}
|
||||
outputs_dict = {"output_key": "output_value"}
|
||||
metadata_dict = {str(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS): 100}
|
||||
|
||||
# Create a DB model using our custom subclass
|
||||
db_model = WorkflowNodeExecutionModel()
|
||||
db_model.id = "test-id"
|
||||
db_model.tenant_id = "test-tenant-id"
|
||||
db_model.app_id = "test-app-id"
|
||||
db_model.workflow_id = "test-workflow-id"
|
||||
db_model.triggered_from = "workflow-run"
|
||||
db_model.workflow_run_id = "test-workflow-run-id"
|
||||
db_model.index = 1
|
||||
db_model.predecessor_node_id = "test-predecessor-id"
|
||||
db_model.node_execution_id = "test-node-execution-id"
|
||||
db_model.node_id = "test-node-id"
|
||||
db_model.node_type = BuiltinNodeTypes.START
|
||||
db_model.title = "Test Node"
|
||||
db_model.inputs = json.dumps(inputs_dict)
|
||||
db_model.process_data = json.dumps(process_data_dict)
|
||||
db_model.outputs = json.dumps(outputs_dict)
|
||||
db_model.status = WorkflowNodeExecutionStatus.RUNNING
|
||||
db_model.error = None
|
||||
db_model.elapsed_time = 1.5
|
||||
db_model.execution_metadata = json.dumps(metadata_dict)
|
||||
db_model.created_at = datetime.now()
|
||||
db_model.created_by_role = "account"
|
||||
db_model.created_by = "test-user-id"
|
||||
db_model.finished_at = None
|
||||
|
||||
# Convert to domain model
|
||||
domain_model = repository._to_domain_model(db_model)
|
||||
|
||||
# Assert domain model has correct values
|
||||
assert isinstance(domain_model, WorkflowNodeExecution)
|
||||
assert domain_model.id == db_model.id
|
||||
assert domain_model.workflow_id == db_model.workflow_id
|
||||
assert domain_model.workflow_execution_id == db_model.workflow_run_id
|
||||
assert domain_model.index == db_model.index
|
||||
assert domain_model.predecessor_node_id == db_model.predecessor_node_id
|
||||
assert domain_model.node_execution_id == db_model.node_execution_id
|
||||
assert domain_model.node_id == db_model.node_id
|
||||
assert domain_model.node_type == db_model.node_type
|
||||
assert domain_model.title == db_model.title
|
||||
assert domain_model.inputs == inputs_dict
|
||||
assert domain_model.process_data == process_data_dict
|
||||
assert domain_model.outputs == outputs_dict
|
||||
assert domain_model.status == WorkflowNodeExecutionStatus(db_model.status)
|
||||
assert domain_model.error == db_model.error
|
||||
assert domain_model.elapsed_time == db_model.elapsed_time
|
||||
assert domain_model.metadata == metadata_dict
|
||||
assert domain_model.created_at == db_model.created_at
|
||||
assert domain_model.finished_at == db_model.finished_at
|
||||
@ -1,106 +0,0 @@
|
||||
"""
|
||||
Unit tests for SQLAlchemyWorkflowNodeExecutionRepository, focusing on process_data truncation functionality.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
from graphon.entities import WorkflowNodeExecution
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.repositories.sqlalchemy_workflow_node_execution_repository import (
|
||||
SQLAlchemyWorkflowNodeExecutionRepository,
|
||||
)
|
||||
from models import Account, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
|
||||
class TestSQLAlchemyWorkflowNodeExecutionRepositoryProcessData:
|
||||
"""Test process_data truncation functionality in SQLAlchemyWorkflowNodeExecutionRepository."""
|
||||
|
||||
def create_mock_account(self) -> Account:
|
||||
"""Create a mock Account for testing."""
|
||||
account = Mock(spec=Account)
|
||||
account.id = "test-user-id"
|
||||
account.tenant_id = "test-tenant-id"
|
||||
return account
|
||||
|
||||
def create_mock_session_factory(self) -> sessionmaker:
|
||||
"""Create a mock session factory for testing."""
|
||||
mock_session = MagicMock()
|
||||
mock_session_factory = MagicMock(spec=sessionmaker)
|
||||
mock_session_factory.return_value.__enter__.return_value = mock_session
|
||||
mock_session_factory.return_value.__exit__.return_value = None
|
||||
return mock_session_factory
|
||||
|
||||
def create_repository(self, mock_file_service=None) -> SQLAlchemyWorkflowNodeExecutionRepository:
|
||||
"""Create a repository instance for testing."""
|
||||
mock_account = self.create_mock_account()
|
||||
mock_session_factory = self.create_mock_session_factory()
|
||||
|
||||
repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
if mock_file_service:
|
||||
repository._file_service = mock_file_service
|
||||
|
||||
return repository
|
||||
|
||||
def create_workflow_node_execution(
|
||||
self,
|
||||
process_data: dict[str, Any] | None = None,
|
||||
execution_id: str = "test-execution-id",
|
||||
) -> WorkflowNodeExecution:
|
||||
"""Create a WorkflowNodeExecution instance for testing."""
|
||||
return WorkflowNodeExecution(
|
||||
id=execution_id,
|
||||
workflow_id="test-workflow-id",
|
||||
index=1,
|
||||
node_id="test-node-id",
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
title="Test Node",
|
||||
process_data=process_data,
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
|
||||
def test_to_domain_model_without_offload_data(self):
|
||||
"""Test _to_domain_model without offload data."""
|
||||
repository = self.create_repository()
|
||||
|
||||
# Create mock database model without offload data
|
||||
db_model = Mock(spec=WorkflowNodeExecutionModel)
|
||||
db_model.id = "test-execution-id"
|
||||
db_model.node_execution_id = "test-node-execution-id"
|
||||
db_model.workflow_id = "test-workflow-id"
|
||||
db_model.workflow_run_id = None
|
||||
db_model.index = 1
|
||||
db_model.predecessor_node_id = None
|
||||
db_model.node_id = "test-node-id"
|
||||
db_model.node_type = "llm"
|
||||
db_model.title = "Test Node"
|
||||
db_model.status = "succeeded"
|
||||
db_model.error = None
|
||||
db_model.elapsed_time = 1.5
|
||||
db_model.created_at = datetime.now()
|
||||
db_model.finished_at = None
|
||||
|
||||
process_data = {"normal": "data"}
|
||||
db_model.process_data_dict = process_data
|
||||
db_model.inputs_dict = None
|
||||
db_model.outputs_dict = None
|
||||
db_model.execution_metadata_dict = {}
|
||||
db_model.offload_data = None
|
||||
|
||||
domain_model = repository._to_domain_model(db_model)
|
||||
|
||||
# Domain model should have the data from database
|
||||
assert domain_model.process_data == process_data
|
||||
|
||||
# Should not be truncated
|
||||
assert domain_model.process_data_truncated is False
|
||||
assert domain_model.get_truncated_process_data() is None
|
||||
Loading…
Reference in New Issue
Block a user