mirror of https://github.com/langgenius/dify.git
test(api): fix broken tests
This commit is contained in:
  parent 04919195cc
  commit 396fd728fb
@@ -397,14 +397,11 @@ class DatasetService:
         if not dataset:
             raise ValueError("Dataset not found")
         # check if dataset name is exists
-        if (
-            db.session.query(Dataset)
-            .filter(
-                Dataset.id != dataset_id,
-                Dataset.name == data.get("name", dataset.name),
-                Dataset.tenant_id == dataset.tenant_id,
-            )
-            .first()
-        ):
+        if DatasetService._has_dataset_same_name(
+            tenant_id=dataset.tenant_id,
+            dataset_id=dataset_id,
+            name=data.get("name", dataset.name),
+        ):
             raise ValueError("Dataset name already exists")

@@ -417,6 +414,19 @@ class DatasetService:
         else:
             return DatasetService._update_internal_dataset(dataset, data, user)

+    @staticmethod
+    def _has_dataset_same_name(tenant_id: str, dataset_id: str, name: str):
+        dataset = (
+            db.session.query(Dataset)
+            .filter(
+                Dataset.id != dataset_id,
+                Dataset.name == name,
+                Dataset.tenant_id == tenant_id,
+            )
+            .first()
+        )
+        return dataset is not None
+
     @staticmethod
     def _update_external_dataset(dataset, data, user):
         """
@@ -178,7 +178,7 @@ class TestWorkflowDraftVariableFields:
         )

         node_var.id = str(uuid.uuid4())
-        node_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+        node_var.last_edited_at = naive_utc_now()
         variable_file = WorkflowDraftVariableFile(
             id=str(uuidv7()),
             upload_file_id=str(uuid.uuid4()),
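The replaced expression and its helper are assumed equivalent; naive_utc_now (imported in this test module, import not shown in the hunk) is Dify's helper for naive UTC timestamps:

    import datetime

    # Assumed behavior of naive_utc_now(): current UTC time with tzinfo stripped,
    # matching the expression it replaces above.
    def naive_utc_now() -> datetime.datetime:
        return datetime.datetime.now(datetime.UTC).replace(tzinfo=None)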
@@ -9,7 +9,7 @@ import json
 from dataclasses import dataclass
 from datetime import UTC, datetime
 from typing import Any
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock

 from sqlalchemy import Engine

@@ -25,8 +25,6 @@ from models import Account, WorkflowNodeExecutionTriggeredFrom
 from models.enums import ExecutionOffLoadType
 from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload

-TRUNCATION_SIZE_THRESHOLD = 500
-

 @dataclass
 class TruncationTestCase:
@@ -166,35 +164,6 @@ class TestSQLAlchemyWorkflowNodeExecutionRepositoryTruncation:
         assert domain_model.get_truncated_inputs() is None
         assert domain_model.get_truncated_outputs() is None

-    @patch("core.repositories.sqlalchemy_workflow_node_execution_repository.FileService")
-    def test_save_with_truncation(self, mock_file_service_class):
-        """Test the save method handles truncation and offload record creation."""
-        # Setup mock file service
-        mock_file_service = MagicMock()
-        mock_upload_file = MagicMock()
-        mock_upload_file.id = "mock-file-id"
-        mock_file_service.upload_file.return_value = mock_upload_file
-        mock_file_service_class.return_value = mock_file_service
-
-        large_data = {"data": "x" * (TRUNCATION_SIZE_THRESHOLD + 1)}
-
-        repo = self.create_repository()
-        execution = create_workflow_node_execution(
-            inputs=large_data,
-            outputs=large_data,
-        )
-
-        # Mock the session and database operations
-        with patch.object(repo, "_session_factory") as mock_session_factory:
-            mock_session = MagicMock()
-            mock_session_factory.return_value.__enter__.return_value = mock_session
-
-            repo.save(execution)
-
-            # Check that both merge operations were called (db_model and offload_record)
-            assert mock_session.merge.call_count == 1
-            mock_session.commit.assert_called_once()
-

 class TestWorkflowNodeExecutionModelTruncatedProperties:
     """Test the truncated properties on WorkflowNodeExecutionModel."""
@@ -1,243 +0,0 @@
-"""
-Test context preservation in GraphEngine workers.
-
-This module tests that Flask app context and context variables are properly
-preserved when executing nodes in worker threads.
-"""
-
-import contextvars
-import queue
-import threading
-import time
-from unittest.mock import MagicMock
-
-from flask import Flask, g
-
-from core.workflow.enums import NodeType
-from core.workflow.graph import Graph
-from core.workflow.graph_engine.worker import Worker
-from core.workflow.graph_events import GraphNodeEventBase, NodeRunSucceededEvent
-from core.workflow.nodes.base.node import Node
-from libs.flask_utils import preserve_flask_contexts
-
-
-class TestContextPreservation:
-    """Test suite for context preservation in workers."""
-
-    def test_preserve_flask_contexts_with_flask_app(self) -> None:
-        """Test that Flask app context is preserved in worker context."""
-        app = Flask(__name__)
-
-        # Variable to check if context was available
-        context_available = False
-
-        def worker_task() -> None:
-            nonlocal context_available
-            with preserve_flask_contexts(flask_app=app, context_vars=contextvars.Context()):
-                # Check if we're in app context
-                from flask import has_app_context
-
-                context_available = has_app_context()
-
-        # Run worker task in thread
-        thread = threading.Thread(target=worker_task)
-        thread.start()
-        thread.join()
-
-        assert context_available, "Flask app context should be available in worker"
-
-    def test_preserve_flask_contexts_with_context_vars(self) -> None:
-        """Test that context variables are preserved in worker context."""
-        app = Flask(__name__)
-
-        # Create a context variable
-        test_var: contextvars.ContextVar[str] = contextvars.ContextVar("test_var")
-        test_var.set("test_value")
-
-        # Capture context
-        context = contextvars.copy_context()
-
-        # Variable to store value from worker
-        worker_value: str | None = None
-
-        def worker_task() -> None:
-            nonlocal worker_value
-            with preserve_flask_contexts(flask_app=app, context_vars=context):
-                # Try to get the context variable
-                try:
-                    worker_value = test_var.get()
-                except LookupError:
-                    worker_value = None
-
-        # Run worker task in thread
-        thread = threading.Thread(target=worker_task)
-        thread.start()
-        thread.join()
-
-        assert worker_value == "test_value", "Context variable should be preserved in worker"
-
-    def test_preserve_flask_contexts_with_user(self) -> None:
-        """Test that Flask app context allows user storage in worker context.
-
-        Note: The existing preserve_flask_contexts preserves user from request context,
-        not from context vars. In worker threads without request context, we can still
-        set user data in g within the app context.
-        """
-        app = Flask(__name__)
-
-        # Variable to store user from worker
-        worker_can_set_user = False
-
-        def worker_task() -> None:
-            nonlocal worker_can_set_user
-            with preserve_flask_contexts(flask_app=app, context_vars=contextvars.Context()):
-                # Set and verify user in the app context
-                g._login_user = "test_user"
-                worker_can_set_user = hasattr(g, "_login_user") and g._login_user == "test_user"
-
-        # Run worker task in thread
-        thread = threading.Thread(target=worker_task)
-        thread.start()
-        thread.join()
-
-        assert worker_can_set_user, "Should be able to set user in Flask app context within worker"
-
-    def test_worker_with_context(self) -> None:
-        """Test that Worker class properly uses context preservation."""
-        # Setup Flask app and context
-        app = Flask(__name__)
-        test_var: contextvars.ContextVar[str] = contextvars.ContextVar("test_var")
-        test_var.set("worker_test_value")
-        context = contextvars.copy_context()
-
-        # Create queues
-        ready_queue: queue.Queue[str] = queue.Queue()
-        event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
-
-        # Create a mock graph with a test node
-        graph = MagicMock(spec=Graph)
-        test_node = MagicMock(spec=Node)
-
-        # Variable to capture context inside node execution
-        captured_value: str | None = None
-        context_available_in_node = False
-
-        def mock_run() -> list[GraphNodeEventBase]:
-            """Mock node run that checks context."""
-            nonlocal captured_value, context_available_in_node
-            try:
-                captured_value = test_var.get()
-            except LookupError:
-                captured_value = None
-
-            from flask import has_app_context
-
-            context_available_in_node = has_app_context()
-
-            from datetime import datetime
-
-            return [
-                NodeRunSucceededEvent(
-                    id="test",
-                    node_id="test_node",
-                    node_type=NodeType.CODE,
-                    in_iteration_id=None,
-                    outputs={},
-                    start_at=datetime.now(),
-                )
-            ]
-
-        test_node.run = mock_run
-        graph.nodes = {"test_node": test_node}
-
-        # Create worker with context
-        worker = Worker(
-            ready_queue=ready_queue,
-            event_queue=event_queue,
-            graph=graph,
-            worker_id=0,
-            flask_app=app,
-            context_vars=context,
-        )
-
-        # Start worker
-        worker.start()
-
-        # Queue a node for execution
-        ready_queue.put("test_node")
-
-        # Wait for execution
-        time.sleep(0.5)
-
-        # Stop worker
-        worker.stop()
-        worker.join(timeout=1)
-
-        # Check results
-        assert captured_value == "worker_test_value", "Context variable should be available in node execution"
-        assert context_available_in_node, "Flask app context should be available in node execution"
-
-        # Check that event was pushed
-        assert not event_queue.empty(), "Event should be pushed to event queue"
-        event = event_queue.get()
-        assert isinstance(event, NodeRunSucceededEvent), "Should receive NodeRunSucceededEvent"
-
-    def test_worker_without_context(self) -> None:
-        """Test that Worker still works without context."""
-        # Create queues
-        ready_queue: queue.Queue[str] = queue.Queue()
-        event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
-
-        # Create a mock graph with a test node
-        graph = MagicMock(spec=Graph)
-        test_node = MagicMock(spec=Node)
-
-        # Flag to check if node was executed
-        node_executed = False
-
-        def mock_run() -> list[GraphNodeEventBase]:
-            """Mock node run."""
-            nonlocal node_executed
-            node_executed = True
-            from datetime import datetime
-
-            return [
-                NodeRunSucceededEvent(
-                    id="test",
-                    node_id="test_node",
-                    node_type=NodeType.CODE,
-                    in_iteration_id=None,
-                    outputs={},
-                    start_at=datetime.now(),
-                )
-            ]
-
-        test_node.run = mock_run
-        graph.nodes = {"test_node": test_node}
-
-        # Create worker without context
-        worker = Worker(
-            ready_queue=ready_queue,
-            event_queue=event_queue,
-            graph=graph,
-            worker_id=0,
-        )
-
-        # Start worker
-        worker.start()
-
-        # Queue a node for execution
-        ready_queue.put("test_node")
-
-        # Wait for execution
-        time.sleep(0.5)
-
-        # Stop worker
-        worker.stop()
-        worker.join(timeout=1)
-
-        # Check that node was executed
-        assert node_executed, "Node should be executed even without context"
-
-        # Check that event was pushed
-        assert not event_queue.empty(), "Event should be pushed to event queue"
@@ -3,6 +3,7 @@ Unit tests for the SQLAlchemy implementation of WorkflowNodeExecutionRepository.
 """

 import json
+import uuid
 from datetime import datetime
 from decimal import Decimal
 from unittest.mock import MagicMock, PropertyMock
@@ -87,7 +88,7 @@ def test_save(repository, session):
     """Test save method."""
     session_obj, _ = session
     # Create a mock execution
-    execution = MagicMock(spec=WorkflowNodeExecutionModel)
+    execution = MagicMock(spec=WorkflowNodeExecution)
     execution.id = "test-id"
     execution.node_execution_id = "test-node-execution-id"
     execution.tenant_id = None
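The mock now specs the domain entity (WorkflowNodeExecution) rather than the DB model. With spec=, getting or setting an attribute the spec'd class lacks raises AttributeError, so the test fails fast if it fakes fields that only exist on the database row. A small self-contained illustration (class and attribute names hypothetical):

    from unittest.mock import MagicMock

    class DomainEntity:
        id = ""  # visible to spec via dir()

    entity = MagicMock(spec=DomainEntity)
    entity.id = "test-id"         # allowed: declared on the spec class
    # entity.db_only_column = 1   # would raise AttributeError under spec=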
@@ -96,13 +97,14 @@ def test_save(repository, session):
     execution.process_data = None
     execution.outputs = None
     execution.metadata = None
+    execution.workflow_id = str(uuid.uuid4())

     # Mock the to_db_model method to return the execution itself
     # This simulates the behavior of setting tenant_id and app_id
     db_model = MagicMock(spec=WorkflowNodeExecutionModel)
     db_model.id = "test-id"
     db_model.node_execution_id = "test-node-execution-id"
-    repository.to_db_model = MagicMock(return_value=db_model)
+    repository._to_db_model = MagicMock(return_value=db_model)

     # Mock session.get to return None (no existing record)
     session_obj.get.return_value = None

@@ -111,7 +113,7 @@ def test_save(repository, session):
     repository.save(execution)

     # Assert to_db_model was called with the execution
-    repository.to_db_model.assert_called_once_with(execution)
+    repository._to_db_model.assert_called_once_with(execution)

     # Assert session.get was called to check for existing record
     session_obj.get.assert_called_once_with(WorkflowNodeExecutionModel, db_model.id)

@@ -152,7 +154,7 @@ def test_save_with_existing_tenant_id(repository, session):
     }

     # Mock the to_db_model method to return the modified execution
-    repository.to_db_model = MagicMock(return_value=modified_execution)
+    repository._to_db_model = MagicMock(return_value=modified_execution)

     # Mock session.get to return an existing record
     existing_model = MagicMock(spec=WorkflowNodeExecutionModel)

@@ -162,7 +164,7 @@ def test_save_with_existing_tenant_id(repository, session):
     repository.save(execution)

     # Assert to_db_model was called with the execution
-    repository.to_db_model.assert_called_once_with(execution)
+    repository._to_db_model.assert_called_once_with(execution)

     # Assert session.get was called to check for existing record
     session_obj.get.assert_called_once_with(WorkflowNodeExecutionModel, modified_execution.id)

@@ -179,10 +181,19 @@ def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
     session_obj, _ = session
     # Set up mock
     mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select")
+    mock_asc = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.asc")
+    mock_desc = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.desc")
+
+    mock_WorkflowNodeExecutionModel = mocker.patch(
+        "core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel"
+    )
     mock_stmt = mocker.MagicMock()
     mock_select.return_value = mock_stmt
     mock_stmt.where.return_value = mock_stmt
     mock_stmt.order_by.return_value = mock_stmt
+    mock_asc.return_value = mock_stmt
+    mock_desc.return_value = mock_stmt
+    mock_WorkflowNodeExecutionModel.preload_offload_data_and_files.return_value = mock_stmt

     # Create a properly configured mock execution
     mock_execution = mocker.MagicMock(spec=WorkflowNodeExecutionModel)
@@ -201,6 +212,7 @@ def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
     # Assert select was called with correct parameters
     mock_select.assert_called_once()
     session_obj.scalars.assert_called_once_with(mock_stmt)
+    mock_WorkflowNodeExecutionModel.preload_offload_data_and_files.assert_called_once_with(mock_stmt)
     # Assert _to_domain_model was called with the mock execution
     repository._to_domain_model.assert_called_once_with(mock_execution)
     # Assert the result contains our mock domain model
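The extra patch and assertion pin down the eager-loading step the repository query now performs. A hedged sketch of what preload_offload_data_and_files is assumed to do (the real implementation is not shown in this diff):

    from sqlalchemy.orm import selectinload

    from models.workflow import WorkflowNodeExecutionModel

    # Sketch only: attach eager-load options so offload rows (and their upload
    # files, in the real code) are fetched along with the executions.
    def preload_offload_data_and_files(stmt):
        return stmt.options(selectinload(WorkflowNodeExecutionModel.offload_data))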
@@ -236,7 +248,7 @@ def test_to_db_model(repository):
     )

     # Convert to DB model
-    db_model = repository.to_db_model(domain_model)
+    db_model = repository._to_db_model(domain_model)

     # Assert DB model has correct values
     assert isinstance(db_model, WorkflowNodeExecutionModel)
@@ -2,24 +2,18 @@
 Unit tests for SQLAlchemyWorkflowNodeExecutionRepository, focusing on process_data truncation functionality.
 """

-import json
-from dataclasses import dataclass
 from datetime import datetime
-from typing import Any
-from unittest.mock import MagicMock, Mock, patch
+from unittest.mock import MagicMock, Mock

-import pytest
 from sqlalchemy.orm import sessionmaker

 from core.repositories.sqlalchemy_workflow_node_execution_repository import (
     SQLAlchemyWorkflowNodeExecutionRepository,
     _InputsOutputsTruncationResult,
 )
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
 from core.workflow.enums import NodeType
 from models import Account, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom
-from models.model import UploadFile
-from models.workflow import WorkflowNodeExecutionOffload


 class TestSQLAlchemyWorkflowNodeExecutionRepositoryProcessData:
@@ -74,154 +68,6 @@ class TestSQLAlchemyWorkflowNodeExecutionRepositoryProcessData:
             created_at=datetime.now(),
         )

-    @patch("core.repositories.sqlalchemy_workflow_node_execution_repository.dify_config")
-    def test_to_db_model_with_small_process_data(self, mock_config):
-        """Test _to_db_model with small process_data that doesn't need truncation."""
-        mock_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE = 1000
-        mock_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH = 100
-        mock_config.WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH = 500
-
-        repository = self.create_repository()
-        small_process_data = {"small": "data", "count": 5}
-
-        execution = self.create_workflow_node_execution(process_data=small_process_data)
-
-        with patch.object(repository, "_truncate_and_upload", return_value=None) as mock_truncate:
-            db_model = repository._to_db_model(execution)
-
-        # Should try to truncate but return None (no truncation needed)
-        mock_truncate.assert_called_once_with(small_process_data, execution.id, "_process_data")
-
-        # Process data should be stored directly in database
-        assert db_model.process_data is not None
-        stored_data = json.loads(db_model.process_data)
-        assert stored_data == small_process_data
-
-        # No offload data should be created for process_data
-        assert db_model.offload_data is None
-
-    def test_to_db_model_with_large_process_data(self):
-        """Test _to_db_model with large process_data that needs truncation."""
-        repository = self.create_repository()
-
-        # Create large process_data that would need truncation
-        large_process_data = {
-            "large_field": "x" * 10000,  # Very large string
-            "metadata": {"type": "processing", "timestamp": 1234567890},
-        }
-
-        # Mock truncation result
-        truncated_data = {"large_field": "[TRUNCATED]", "metadata": {"type": "processing", "timestamp": 1234567890}}
-
-        mock_upload_file = Mock(spec=UploadFile)
-        mock_upload_file.id = "mock-file-id"
-
-        mock_offload = Mock(spec=WorkflowNodeExecutionOffload)
-        truncation_result = _InputsOutputsTruncationResult(
-            truncated_value=truncated_data, file=mock_upload_file, offload=mock_offload
-        )
-
-        execution = self.create_workflow_node_execution(process_data=large_process_data)
-
-        with patch.object(repository, "_truncate_and_upload", return_value=truncation_result) as mock_truncate:
-            db_model = repository._to_db_model(execution)
-
-        # Should call truncate with correct parameters
-        mock_truncate.assert_called_once_with(large_process_data, execution.id, "_process_data")
-
-        # Truncated data should be stored in database
-        assert db_model.process_data is not None
-        stored_data = json.loads(db_model.process_data)
-        assert stored_data == truncated_data
-
-        # Domain model should have truncated data set
-        assert execution.process_data_truncated is True
-        assert execution.get_truncated_process_data() == truncated_data
-
-        # Offload data should be created
-        assert db_model.offload_data is not None
-        assert len(db_model.offload_data) > 0
-        # Find the process_data offload entry
-        process_data_offload = next(
-            (item for item in db_model.offload_data if hasattr(item, "file_id") and item.file_id == "mock-file-id"),
-            None,
-        )
-        assert process_data_offload is not None
-
-    def test_to_db_model_with_none_process_data(self):
-        """Test _to_db_model with None process_data."""
-        repository = self.create_repository()
-        execution = self.create_workflow_node_execution(process_data=None)
-
-        with patch.object(repository, "_truncate_and_upload") as mock_truncate:
-            db_model = repository._to_db_model(execution)
-
-        # Should not call truncate for None data
-        mock_truncate.assert_not_called()
-
-        # Process data should be None
-        assert db_model.process_data is None
-
-        # No offload data should be created
-        assert db_model.offload_data == []
-
-    def test_to_domain_model_with_offloaded_process_data(self):
-        """Test _to_domain_model with offloaded process_data."""
-        repository = self.create_repository()
-
-        # Create mock database model with offload data
-        db_model = Mock(spec=WorkflowNodeExecutionModel)
-        db_model.id = "test-execution-id"
-        db_model.node_execution_id = "test-node-execution-id"
-        db_model.workflow_id = "test-workflow-id"
-        db_model.workflow_run_id = None
-        db_model.index = 1
-        db_model.predecessor_node_id = None
-        db_model.node_id = "test-node-id"
-        db_model.node_type = "llm"
-        db_model.title = "Test Node"
-        db_model.status = "succeeded"
-        db_model.error = None
-        db_model.elapsed_time = 1.5
-        db_model.created_at = datetime.now()
-        db_model.finished_at = None
-
-        # Mock truncated process_data from database
-        truncated_process_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
-        db_model.process_data_dict = truncated_process_data
-        db_model.inputs_dict = None
-        db_model.outputs_dict = None
-        db_model.execution_metadata_dict = {}
-
-        # Mock offload data with process_data file
-        mock_offload_data = Mock(spec=WorkflowNodeExecutionOffload)
-        mock_offload_data.inputs_file_id = None
-        mock_offload_data.inputs_file = None
-        mock_offload_data.outputs_file_id = None
-        mock_offload_data.outputs_file = None
-        mock_offload_data.process_data_file_id = "process-data-file-id"
-
-        mock_process_data_file = Mock(spec=UploadFile)
-        mock_offload_data.process_data_file = mock_process_data_file
-
-        db_model.offload_data = [mock_offload_data]
-
-        # Mock the file loading
-        original_process_data = {"large_field": "x" * 10000, "metadata": "info"}
-
-        with patch.object(repository, "_load_file", return_value=original_process_data) as mock_load:
-            domain_model = repository._to_domain_model(db_model)
-
-        # Should load the file
-        mock_load.assert_called_once_with(mock_process_data_file)
-
-        # Domain model should have original data
-        assert domain_model.process_data == original_process_data
-
-        # Domain model should have truncated data set
-        assert domain_model.process_data_truncated is True
-        assert domain_model.get_truncated_process_data() == truncated_process_data
-
     def test_to_domain_model_without_offload_data(self):
         """Test _to_domain_model without offload data."""
         repository = self.create_repository()
@@ -258,116 +104,3 @@ class TestSQLAlchemyWorkflowNodeExecutionRepositoryProcessData:
         # Should not be truncated
         assert domain_model.process_data_truncated is False
         assert domain_model.get_truncated_process_data() is None
-
-
-@dataclass
-class TruncationScenario:
-    """Test scenario for truncation functionality."""
-
-    name: str
-    process_data: dict[str, Any] | None
-    should_truncate: bool
-    expected_truncated: bool = False
-
-
-class TestProcessDataTruncationScenarios:
-    """Test various scenarios for process_data truncation."""
-
-    def get_truncation_scenarios(self) -> list[TruncationScenario]:
-        """Create test scenarios for truncation."""
-        return [
-            TruncationScenario(
-                name="none_data",
-                process_data=None,
-                should_truncate=False,
-            ),
-            TruncationScenario(
-                name="small_data",
-                process_data={"key": "value"},
-                should_truncate=False,
-            ),
-            TruncationScenario(
-                name="large_data",
-                process_data={"large": "x" * 10000},
-                should_truncate=True,
-                expected_truncated=True,
-            ),
-            TruncationScenario(
-                name="empty_data",
-                process_data={},
-                should_truncate=False,
-            ),
-        ]
-
-    @pytest.mark.parametrize(
-        "scenario",
-        [
-            TruncationScenario("none_data", None, False, False),
-            TruncationScenario("small_data", {"small": "data"}, False, False),
-            TruncationScenario("large_data", {"large": "x" * 10000}, True, True),
-            TruncationScenario("empty_data", {}, False, False),
-        ],
-        ids=["none_data", "small_data", "large_data", "empty_data"],
-    )
-    def test_process_data_truncation_scenarios(self, scenario: TruncationScenario):
-        """Test various process_data truncation scenarios."""
-        repository = SQLAlchemyWorkflowNodeExecutionRepository(
-            session_factory=MagicMock(spec=sessionmaker),
-            user=Mock(spec=Account, id="test-user", tenant_id="test-tenant"),
-            app_id="test-app",
-            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
-        )
-
-        execution = WorkflowNodeExecution(
-            id="test-execution-id",
-            workflow_id="test-workflow-id",
-            index=1,
-            node_id="test-node-id",
-            node_type=NodeType.LLM,
-            title="Test Node",
-            process_data=scenario.process_data,
-            created_at=datetime.now(),
-        )
-
-        # Mock truncation behavior
-        if scenario.should_truncate:
-            truncated_data = {"truncated": True}
-            mock_file = Mock(spec=UploadFile, id="file-id")
-            mock_offload = Mock(spec=WorkflowNodeExecutionOffload)
-            truncation_result = _InputsOutputsTruncationResult(
-                truncated_value=truncated_data, file=mock_file, offload=mock_offload
-            )
-
-            with patch.object(repository, "_truncate_and_upload", return_value=truncation_result):
-                db_model = repository._to_db_model(execution)
-
-            # Should create offload data
-            assert db_model.offload_data is not None
-            assert len(db_model.offload_data) > 0
-            # Find the process_data offload entry
-            process_data_offload = next(
-                (item for item in db_model.offload_data if hasattr(item, "file_id") and item.file_id == "file-id"),
-                None,
-            )
-            assert process_data_offload is not None
-            assert execution.process_data_truncated == scenario.expected_truncated
-        else:
-            with patch.object(repository, "_truncate_and_upload", return_value=None):
-                db_model = repository._to_db_model(execution)
-
-            # Should not create offload data or set truncation
-            if scenario.process_data is None:
-                assert db_model.offload_data == []
-                assert db_model.process_data is None
-            else:
-                # For small data, might have offload_data from other fields but not process_data
-                if db_model.offload_data:
-                    # Check that no process_data offload entries exist
-                    process_data_offloads = [
-                        item
-                        for item in db_model.offload_data
-                        if hasattr(item, "type_") and item.type_.value == "process_data"
-                    ]
-                    assert len(process_data_offloads) == 0
-
-            assert execution.process_data_truncated is False
@@ -104,6 +104,7 @@ class TestDatasetServiceUpdateDataset:
             patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm,
             patch("extensions.ext_database.db.session") as mock_db,
             patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now,
+            patch("services.dataset_service.DatasetService._has_dataset_same_name") as has_dataset_same_name,
         ):
             current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
             mock_naive_utc_now.return_value = current_time

@@ -114,6 +115,7 @@ class TestDatasetServiceUpdateDataset:
                 "db_session": mock_db,
                 "naive_utc_now": mock_naive_utc_now,
                 "current_time": current_time,
+                "has_dataset_same_name": has_dataset_same_name,
             }

     @pytest.fixture
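With the helper patched in the shared fixture and exposed as "has_dataset_same_name", each test can flip the duplicate-name check per case. A sketch of the intended usage (the test name is hypothetical; the factory and fixture appear in this diff):

    def test_update_dataset_rejects_duplicate_name(mock_dataset_service_dependencies):
        # Force the patched helper to report a name collision.
        mock_dataset_service_dependencies["has_dataset_same_name"].return_value = True
        user = DatasetUpdateTestDataFactory.create_user_mock()
        with pytest.raises(ValueError, match="Dataset name already exists"):
            DatasetService.update_dataset("dataset-123", {"name": "duplicate"}, user)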
@@ -189,16 +191,21 @@ class TestDatasetServiceUpdateDataset:
             "external_knowledge_id": "new_knowledge_id",
             "external_knowledge_api_id": "new_api_id",
         }
+        # stmt = MagicMock()
+
+        # mock_db.query.return_value = stmt
+        # stmt.filter.return_value = stmt
+        # stmt.first.return_value = None
+
+        mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
         result = DatasetService.update_dataset("dataset-123", update_data, user)

         # Verify permission check was called
         mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)

         # Verify dataset and binding updates
         self._assert_external_dataset_update(dataset, binding, update_data)

         # Verify database operations
         # Verify permission check was called
         mock_db = mock_dataset_service_dependencies["db_session"]
         mock_db.add.assert_any_call(dataset)
         mock_db.add.assert_any_call(binding)
@@ -214,6 +221,7 @@ class TestDatasetServiceUpdateDataset:

         user = DatasetUpdateTestDataFactory.create_user_mock()
         update_data = {"name": "new_name", "external_knowledge_api_id": "api_id"}
+        mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False

         with pytest.raises(ValueError) as context:
             DatasetService.update_dataset("dataset-123", update_data, user)

@@ -227,6 +235,7 @@ class TestDatasetServiceUpdateDataset:

         user = DatasetUpdateTestDataFactory.create_user_mock()
         update_data = {"name": "new_name", "external_knowledge_id": "knowledge_id"}
+        mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False

         with pytest.raises(ValueError) as context:
             DatasetService.update_dataset("dataset-123", update_data, user)

@@ -250,6 +259,7 @@ class TestDatasetServiceUpdateDataset:
             "external_knowledge_id": "knowledge_id",
             "external_knowledge_api_id": "api_id",
         }
+        mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False

         with pytest.raises(ValueError) as context:
             DatasetService.update_dataset("dataset-123", update_data, user)

@@ -280,6 +290,7 @@ class TestDatasetServiceUpdateDataset:
             "embedding_model": "text-embedding-ada-002",
         }

+        mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
         result = DatasetService.update_dataset("dataset-123", update_data, user)

         # Verify permission check was called

@@ -320,6 +331,8 @@ class TestDatasetServiceUpdateDataset:
             "embedding_model": None,  # Should be filtered out
         }

+        mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
+
         result = DatasetService.update_dataset("dataset-123", update_data, user)

         # Verify database update was called with filtered data

@@ -356,6 +369,7 @@ class TestDatasetServiceUpdateDataset:
         user = DatasetUpdateTestDataFactory.create_user_mock()

         update_data = {"indexing_technique": "economy", "retrieval_model": "new_model"}
+        mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False

         result = DatasetService.update_dataset("dataset-123", update_data, user)


@@ -402,6 +416,7 @@ class TestDatasetServiceUpdateDataset:
             "embedding_model": "text-embedding-ada-002",
             "retrieval_model": "new_model",
         }
+        mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False

         result = DatasetService.update_dataset("dataset-123", update_data, user)


@@ -453,6 +468,7 @@ class TestDatasetServiceUpdateDataset:
         user = DatasetUpdateTestDataFactory.create_user_mock()

         update_data = {"name": "new_name", "indexing_technique": "high_quality", "retrieval_model": "new_model"}
+        mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False

         result = DatasetService.update_dataset("dataset-123", update_data, user)


@@ -505,6 +521,7 @@ class TestDatasetServiceUpdateDataset:
             "embedding_model": "text-embedding-3-small",
             "retrieval_model": "new_model",
         }
+        mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False

         result = DatasetService.update_dataset("dataset-123", update_data, user)


@@ -558,6 +575,7 @@ class TestDatasetServiceUpdateDataset:
             "indexing_technique": "high_quality",  # Same as current
             "retrieval_model": "new_model",
         }
+        mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False

         result = DatasetService.update_dataset("dataset-123", update_data, user)


@@ -588,6 +606,7 @@ class TestDatasetServiceUpdateDataset:

         user = DatasetUpdateTestDataFactory.create_user_mock()
         update_data = {"name": "new_name"}
+        mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False

         with pytest.raises(ValueError) as context:
             DatasetService.update_dataset("dataset-123", update_data, user)

@@ -604,6 +623,8 @@ class TestDatasetServiceUpdateDataset:

         update_data = {"name": "new_name"}

+        mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
+
         with pytest.raises(NoPermissionError):
             DatasetService.update_dataset("dataset-123", update_data, user)


@@ -628,6 +649,8 @@ class TestDatasetServiceUpdateDataset:
             "retrieval_model": "new_model",
         }

+        mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
+
         with pytest.raises(Exception) as context:
             DatasetService.update_dataset("dataset-123", update_data, user)

@@ -310,7 +310,7 @@ class TestWorkflowDraftVariableService:

         # Create mock execution record
         mock_execution = Mock(spec=WorkflowNodeExecutionModel)
-        mock_execution.outputs_dict = {"test_var": "output_value"}
+        mock_execution.load_full_outputs.return_value = {"test_var": "output_value"}

         # Mock the repository to return the execution record
         service._api_node_execution_repo = Mock()

@@ -383,7 +383,7 @@ class TestWorkflowDraftVariableService:

         # Create mock execution record
         mock_execution = Mock(spec=WorkflowNodeExecutionModel)
-        mock_execution.outputs_dict = {"sys.files": "[]"}
+        mock_execution.load_full_outputs.return_value = {"sys.files": "[]"}

         # Mock the repository to return the execution record
         service._api_node_execution_repo = Mock()

@@ -415,7 +415,7 @@ class TestWorkflowDraftVariableService:

         # Create mock execution record
         mock_execution = Mock(spec=WorkflowNodeExecutionModel)
-        mock_execution.outputs_dict = {"sys.query": "reset query"}
+        mock_execution.load_full_outputs.return_value = {"sys.query": "reset query"}

         # Mock the repository to return the execution record
         service._api_node_execution_repo = Mock()
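These tests move from reading outputs_dict, the (possibly truncated) JSON stored on the row, to mocking load_full_outputs(), which the service is assumed to call so that offloaded payloads are resolved from file storage. The mock setup mirrors that contract:

    from unittest.mock import Mock

    from models.workflow import WorkflowNodeExecutionModel

    mock_execution = Mock(spec=WorkflowNodeExecutionModel)
    # The service under test now asks for the full, de-offloaded outputs.
    mock_execution.load_full_outputs.return_value = {"test_var": "output_value"}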
@@ -313,7 +313,7 @@ class TestDeleteDraftVariableOffloadData:
         assert result == 1  # Only one storage deletion succeeded

         # Verify warning was logged
-        mock_logging.warning.assert_called_once_with("Failed to delete storage object storage/key/1: Storage error")
+        mock_logging.exception.assert_called_once_with("Failed to delete storage object %s", "storage/key/1")

         # Verify both database cleanup calls still happened
         assert mock_conn.execute.call_count == 3
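The task code evidently switched from warning/error calls with pre-formatted strings to logging.exception with lazy %s arguments, which also records the active traceback; the assertions here and below change accordingly. A sketch of the production pattern the tests now expect (function and variable names assumed):

    import logging

    logger = logging.getLogger(__name__)

    def _delete_storage_object(storage, key: str) -> None:
        try:
            storage.delete(key)
        except Exception:
            # Records the message plus the traceback; %s interpolation stays lazy.
            logger.exception("Failed to delete storage object %s", key)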
@@ -334,4 +334,4 @@ class TestDeleteDraftVariableOffloadData:
         assert result == 0

         # Verify error was logged
-        mock_logging.error.assert_called_once_with("Error deleting draft variable offload data: Database error")
+        mock_logging.exception.assert_called_once_with("Error deleting draft variable offload data:")