diff --git a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py
index e96d70c4a9..98e84f2032 100644
--- a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py
+++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py
@@ -3,16 +3,27 @@ import unittest
 import uuid
 
 import pytest
+from sqlalchemy import delete
 from sqlalchemy.orm import Session
 
+from core.variables.segments import StringSegment
+from core.variables.types import SegmentType
 from core.variables.variables import StringVariable
 from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
 from core.workflow.nodes import NodeType
+from extensions.ext_storage import storage
 from factories.variable_factory import build_segment
 from libs import datetime_utils
 from models import db
-from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel
-from services.workflow_draft_variable_service import DraftVarLoader, VariableResetError, WorkflowDraftVariableService
+from models.enums import CreatorUserRole
+from models.model import UploadFile
+from models.workflow import Workflow, WorkflowDraftVariable, WorkflowDraftVariableFile, WorkflowNodeExecutionModel
+from services.workflow_draft_variable_service import (
+    DraftVariableSaver,
+    DraftVarLoader,
+    VariableResetError,
+    WorkflowDraftVariableService,
+)
 
 
 @pytest.mark.usefixtures("flask_req_ctx")
@@ -241,6 +269,246 @@ class TestDraftVariableLoader(unittest.TestCase):
         node1_var = next(v for v in variables if v.selector[0] == self._node1_id)
         assert node1_var.id == self._node_var_id
 
+    @pytest.mark.usefixtures("setup_account")
+    def test_load_offloaded_variable_string_type_integration(self, setup_account):
+        """Test _load_offloaded_variable with string type using DraftVariableSaver for data creation."""
+
+        # Create a large string that will be offloaded
+        test_content = "x" * 15000  # Larger than LARGE_VARIABLE_THRESHOLD (10KB)
+        large_string_segment = StringSegment(value=test_content)
+
+        node_execution_id = str(uuid.uuid4())
+
+        try:
+            with Session(bind=db.engine, expire_on_commit=False) as session:
+                # Use DraftVariableSaver to create the offloaded variable (this mimics production)
+                saver = DraftVariableSaver(
+                    session=session,
+                    app_id=self._test_app_id,
+                    node_id="test_offload_node",
+                    node_type=NodeType.LLM,  # Use a real node type
+                    node_execution_id=node_execution_id,
+                    user=setup_account,
+                )
+
+                # Save the variable - this will trigger offloading due to the large size
+                saver.save(outputs={"offloaded_string_var": large_string_segment})
+                session.commit()
+
+                # Now test loading using DraftVarLoader
+                var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
+
+                # Load the variable using the standard workflow
+                variables = 
var_loader.load_variables([["test_offload_node", "offloaded_string_var"]]) + + # Verify results + assert len(variables) == 1 + loaded_variable = variables[0] + assert loaded_variable.name == "offloaded_string_var" + assert loaded_variable.selector == ["test_offload_node", "offloaded_string_var"] + assert isinstance(loaded_variable.value, StringSegment) + assert loaded_variable.value.value == test_content + + finally: + # Clean up - delete all draft variables for this app + with Session(bind=db.engine) as session: + service = WorkflowDraftVariableService(session) + service.delete_workflow_variables(self._test_app_id) + session.commit() + + def test_load_offloaded_variable_object_type_integration(self): + """Test _load_offloaded_variable with object type using real storage and service.""" + + # Create a test object + test_object = {"key1": "value1", "key2": 42, "nested": {"inner": "data"}} + test_json = json.dumps(test_object, ensure_ascii=False, separators=(",", ":")) + content_bytes = test_json.encode() + + # Create an upload file record + upload_file = UploadFile( + tenant_id=self._test_tenant_id, + storage_type="local", + key=f"test_offload_{uuid.uuid4()}.json", + name="test_offload.json", + size=len(content_bytes), + extension="json", + mime_type="application/json", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid.uuid4()), + created_at=datetime_utils.naive_utc_now(), + used=True, + used_by=str(uuid.uuid4()), + used_at=datetime_utils.naive_utc_now(), + ) + + # Store the content in storage + storage.save(upload_file.key, content_bytes) + + # Create a variable file record + variable_file = WorkflowDraftVariableFile( + upload_file_id=upload_file.id, + value_type=SegmentType.OBJECT, + tenant_id=self._test_tenant_id, + app_id=self._test_app_id, + user_id=str(uuid.uuid4()), + size=len(content_bytes), + created_at=datetime_utils.naive_utc_now(), + ) + + try: + with Session(bind=db.engine, expire_on_commit=False) as session: + # Add upload file and variable file first to get their IDs + session.add_all([upload_file, variable_file]) + session.flush() # This generates the IDs + + # Now create the offloaded draft variable with the correct file_id + offloaded_var = WorkflowDraftVariable.new_node_variable( + app_id=self._test_app_id, + node_id="test_offload_node", + name="offloaded_object_var", + value=build_segment({"truncated": True}), + visible=True, + node_execution_id=str(uuid.uuid4()), + ) + offloaded_var.file_id = variable_file.id + + session.add(offloaded_var) + session.flush() + session.commit() + + # Use the service method that properly preloads relationships + service = WorkflowDraftVariableService(session) + draft_vars = service.get_draft_variables_by_selectors( + self._test_app_id, [["test_offload_node", "offloaded_object_var"]] + ) + + assert len(draft_vars) == 1 + loaded_var = draft_vars[0] + assert loaded_var.is_truncated() + + # Create DraftVarLoader and test loading + var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id) + + # Test the _load_offloaded_variable method + selector_tuple, variable = var_loader._load_offloaded_variable(loaded_var) + + # Verify the results + assert selector_tuple == ("test_offload_node", "offloaded_object_var") + assert variable.id == loaded_var.id + assert variable.name == "offloaded_object_var" + assert variable.value.value == test_object + + finally: + # Clean up + with Session(bind=db.engine) as session: + # Query and delete by ID to ensure they're tracked in this session + 
session.query(WorkflowDraftVariable).filter_by(id=offloaded_var.id).delete() + session.query(WorkflowDraftVariableFile).filter_by(id=variable_file.id).delete() + session.query(UploadFile).filter_by(id=upload_file.id).delete() + session.commit() + # Clean up storage + try: + storage.delete(upload_file.key) + except Exception: + pass # Ignore cleanup failures + + def test_load_variables_with_offloaded_variables_integration(self): + """Test load_variables method with mix of regular and offloaded variables using real storage.""" + # Create a regular variable (already exists from setUp) + # Create offloaded variable content + test_content = "This is offloaded content for integration test" + content_bytes = test_content.encode() + + # Create upload file record + upload_file = UploadFile( + tenant_id=self._test_tenant_id, + storage_type="local", + key=f"test_integration_{uuid.uuid4()}.txt", + name="test_integration.txt", + size=len(content_bytes), + extension="txt", + mime_type="text/plain", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid.uuid4()), + created_at=datetime_utils.naive_utc_now(), + used=True, + used_by=str(uuid.uuid4()), + used_at=datetime_utils.naive_utc_now(), + ) + + # Store the content + storage.save(upload_file.key, content_bytes) + + # Create variable file + variable_file = WorkflowDraftVariableFile( + upload_file_id=upload_file.id, + value_type=SegmentType.STRING, + tenant_id=self._test_tenant_id, + app_id=self._test_app_id, + user_id=str(uuid.uuid4()), + size=len(content_bytes), + created_at=datetime_utils.naive_utc_now(), + ) + + try: + with Session(bind=db.engine, expire_on_commit=False) as session: + # Add upload file and variable file first to get their IDs + session.add_all([upload_file, variable_file]) + session.flush() # This generates the IDs + + # Now create the offloaded draft variable with the correct file_id + offloaded_var = WorkflowDraftVariable.new_node_variable( + app_id=self._test_app_id, + node_id="test_integration_node", + name="offloaded_integration_var", + value=build_segment("truncated"), + visible=True, + node_execution_id=str(uuid.uuid4()), + ) + offloaded_var.file_id = variable_file.id + + session.add(offloaded_var) + session.flush() + session.commit() + + # Test load_variables with both regular and offloaded variables + # This method should handle the relationship preloading internally + var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id) + + variables = var_loader.load_variables( + [ + [SYSTEM_VARIABLE_NODE_ID, "sys_var"], # Regular variable from setUp + ["test_integration_node", "offloaded_integration_var"], # Offloaded variable + ] + ) + + # Verify results + assert len(variables) == 2 + + # Find regular variable + regular_var = next(v for v in variables if v.selector[0] == SYSTEM_VARIABLE_NODE_ID) + assert regular_var.id == self._sys_var_id + assert regular_var.value == "sys_value" + + # Find offloaded variable + offloaded_loaded_var = next(v for v in variables if v.selector[0] == "test_integration_node") + assert offloaded_loaded_var.id == offloaded_var.id + assert offloaded_loaded_var.value == test_content + + finally: + # Clean up + with Session(bind=db.engine) as session: + # Query and delete by ID to ensure they're tracked in this session + session.query(WorkflowDraftVariable).filter_by(id=offloaded_var.id).delete() + session.query(WorkflowDraftVariableFile).filter_by(id=variable_file.id).delete() + session.query(UploadFile).filter_by(id=upload_file.id).delete() + 
session.commit() + # Clean up storage + try: + storage.delete(upload_file.key) + except Exception: + pass # Ignore cleanup failures + @pytest.mark.usefixtures("flask_req_ctx") class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase): @@ -272,7 +540,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase): triggered_from="workflow-run", workflow_run_id=str(uuid.uuid4()), index=1, - node_execution_id=self._node_exec_id, + node_execution_id=str(uuid.uuid4()), node_id=self._node_id, node_type=NodeType.LLM.value, title="Test Node", @@ -281,7 +549,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase): outputs='{"test_var": "output_value", "other_var": "other_output"}', status="succeeded", elapsed_time=1.5, - created_by_role="account", + created_by_role=CreatorUserRole.ACCOUNT, created_by=str(uuid.uuid4()), ) @@ -336,10 +604,14 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase): ) self._conv_var.last_edited_at = datetime_utils.naive_utc_now() + with Session(db.engine, expire_on_commit=False) as persistent_session, persistent_session.begin(): + persistent_session.add( + self._workflow_node_execution, + ) + # Add all to database db.session.add_all( [ - self._workflow_node_execution, self._node_var_with_exec, self._node_var_without_exec, self._node_var_missing_exec, @@ -354,6 +626,14 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase): self._node_var_missing_exec_id = self._node_var_missing_exec.id self._conv_var_id = self._conv_var.id + def tearDown(self): + self._session.rollback() + with Session(db.engine) as session, session.begin(): + stmt = delete(WorkflowNodeExecutionModel).where( + WorkflowNodeExecutionModel.id == self._workflow_node_execution.id + ) + session.execute(stmt) + def _get_test_srv(self) -> WorkflowDraftVariableService: return WorkflowDraftVariableService(session=self._session) @@ -380,9 +660,6 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase): ) return workflow - def tearDown(self): - self._session.rollback() - def test_reset_node_variable_with_valid_execution_record(self): """Test resetting a node variable with valid execution record - should restore from execution""" srv = self._get_test_srv() diff --git a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py index 2f7fc60ada..29a66f1d9d 100644 --- a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -1,12 +1,14 @@ import uuid +from unittest.mock import patch import pytest from sqlalchemy import delete from core.variables.segments import StringSegment from models import Tenant, db -from models.model import App -from models.workflow import WorkflowDraftVariable +from models.enums import CreatorUserRole +from models.model import App, UploadFile +from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile from tasks.remove_app_and_related_data_task import _delete_draft_variables, delete_draft_variables_batch @@ -212,3 +214,255 @@ class TestDeleteDraftVariablesIntegration: .execution_options(synchronize_session=False) ) db.session.execute(query) + + +class TestDeleteDraftVariablesWithOffloadIntegration: + """Integration tests for draft variable deletion with Offload data.""" + + @pytest.fixture + def setup_offload_test_data(self, app_and_tenant): + """Create test data with draft variables 
that have associated Offload files.""" + tenant, app = app_and_tenant + + # Create UploadFile records + from libs.datetime_utils import naive_utc_now + + upload_file1 = UploadFile( + tenant_id=tenant.id, + storage_type="local", + key="test/file1.json", + name="file1.json", + size=1024, + extension="json", + mime_type="application/json", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid.uuid4()), + created_at=naive_utc_now(), + used=False, + ) + upload_file2 = UploadFile( + tenant_id=tenant.id, + storage_type="local", + key="test/file2.json", + name="file2.json", + size=2048, + extension="json", + mime_type="application/json", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid.uuid4()), + created_at=naive_utc_now(), + used=False, + ) + db.session.add(upload_file1) + db.session.add(upload_file2) + db.session.flush() + + # Create WorkflowDraftVariableFile records + from core.variables.types import SegmentType + var_file1 = WorkflowDraftVariableFile( + tenant_id=tenant.id, + app_id=app.id, + user_id=str(uuid.uuid4()), + upload_file_id=upload_file1.id, + size=1024, + length=10, + value_type=SegmentType.STRING, + ) + var_file2 = WorkflowDraftVariableFile( + tenant_id=tenant.id, + app_id=app.id, + user_id=str(uuid.uuid4()), + upload_file_id=upload_file2.id, + size=2048, + length=20, + value_type=SegmentType.OBJECT, + ) + db.session.add(var_file1) + db.session.add(var_file2) + db.session.flush() + + # Create WorkflowDraftVariable records with file associations + draft_var1 = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id="node_1", + name="large_var_1", + value=StringSegment(value="truncated..."), + node_execution_id=str(uuid.uuid4()), + file_id=var_file1.id, + ) + draft_var2 = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id="node_2", + name="large_var_2", + value=StringSegment(value="truncated..."), + node_execution_id=str(uuid.uuid4()), + file_id=var_file2.id, + ) + # Create a regular variable without Offload data + draft_var3 = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id="node_3", + name="regular_var", + value=StringSegment(value="regular_value"), + node_execution_id=str(uuid.uuid4()), + ) + + db.session.add(draft_var1) + db.session.add(draft_var2) + db.session.add(draft_var3) + db.session.commit() + + yield { + "app": app, + "tenant": tenant, + "upload_files": [upload_file1, upload_file2], + "variable_files": [var_file1, var_file2], + "draft_variables": [draft_var1, draft_var2, draft_var3], + } + + # Cleanup + db.session.rollback() + + # Clean up any remaining records + for table, ids in [ + (WorkflowDraftVariable, [v.id for v in [draft_var1, draft_var2, draft_var3]]), + (WorkflowDraftVariableFile, [vf.id for vf in [var_file1, var_file2]]), + (UploadFile, [uf.id for uf in [upload_file1, upload_file2]]), + ]: + cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False) + db.session.execute(cleanup_query) + + db.session.commit() + + @patch("extensions.ext_storage.storage") + def test_delete_draft_variables_with_offload_data(self, mock_storage, setup_offload_test_data): + """Test that deleting draft variables also cleans up associated Offload data.""" + data = setup_offload_test_data + app_id = data["app"].id + + # Mock storage deletion to succeed + mock_storage.delete.return_value = None + + # Verify initial state + draft_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + var_files_before = 
db.session.query(WorkflowDraftVariableFile).count() + upload_files_before = db.session.query(UploadFile).count() + + assert draft_vars_before == 3 # 2 with files + 1 regular + assert var_files_before == 2 + assert upload_files_before == 2 + + # Delete draft variables + deleted_count = delete_draft_variables_batch(app_id, batch_size=10) + + # Verify results + assert deleted_count == 3 + + # Check that all draft variables are deleted + draft_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + assert draft_vars_after == 0 + + # Check that associated Offload data is cleaned up + var_files_after = db.session.query(WorkflowDraftVariableFile).count() + upload_files_after = db.session.query(UploadFile).count() + + assert var_files_after == 0 # All variable files should be deleted + assert upload_files_after == 0 # All upload files should be deleted + + # Verify storage deletion was called for both files + assert mock_storage.delete.call_count == 2 + storage_keys_deleted = [call.args[0] for call in mock_storage.delete.call_args_list] + assert "test/file1.json" in storage_keys_deleted + assert "test/file2.json" in storage_keys_deleted + + @patch("extensions.ext_storage.storage") + def test_delete_draft_variables_storage_failure_continues_cleanup(self, mock_storage, setup_offload_test_data): + """Test that database cleanup continues even when storage deletion fails.""" + data = setup_offload_test_data + app_id = data["app"].id + + # Mock storage deletion to fail for first file, succeed for second + mock_storage.delete.side_effect = [Exception("Storage error"), None] + + # Delete draft variables + deleted_count = delete_draft_variables_batch(app_id, batch_size=10) + + # Verify that all draft variables are still deleted + assert deleted_count == 3 + + draft_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + assert draft_vars_after == 0 + + # Database cleanup should still succeed even with storage errors + var_files_after = db.session.query(WorkflowDraftVariableFile).count() + upload_files_after = db.session.query(UploadFile).count() + + assert var_files_after == 0 + assert upload_files_after == 0 + + # Verify storage deletion was attempted for both files + assert mock_storage.delete.call_count == 2 + + @patch("extensions.ext_storage.storage") + def test_delete_draft_variables_partial_offload_data(self, mock_storage, setup_offload_test_data): + """Test deletion with mix of variables with and without Offload data.""" + data = setup_offload_test_data + app_id = data["app"].id + + # Create additional app with only regular variables (no offload data) + tenant = data["tenant"] + app2 = App( + tenant_id=tenant.id, + name="Test App 2", + mode="workflow", + enable_site=True, + enable_api=True, + ) + db.session.add(app2) + db.session.flush() + + # Add regular variables to app2 + regular_vars = [] + for i in range(3): + var = WorkflowDraftVariable.new_node_variable( + app_id=app2.id, + node_id=f"node_{i}", + name=f"var_{i}", + value=StringSegment(value="regular_value"), + node_execution_id=str(uuid.uuid4()), + ) + db.session.add(var) + regular_vars.append(var) + db.session.commit() + + try: + # Mock storage deletion + mock_storage.delete.return_value = None + + # Delete variables for app2 (no offload data) + deleted_count_app2 = delete_draft_variables_batch(app2.id, batch_size=10) + assert deleted_count_app2 == 3 + + # Verify storage wasn't called for app2 (no offload files) + mock_storage.delete.assert_not_called() + + # Delete variables 
for original app (with offload data) + deleted_count_app1 = delete_draft_variables_batch(app_id, batch_size=10) + assert deleted_count_app1 == 3 + + # Now storage should be called for the offload files + assert mock_storage.delete.call_count == 2 + + finally: + # Cleanup app2 and its variables + cleanup_vars_query = ( + delete(WorkflowDraftVariable) + .where(WorkflowDraftVariable.app_id == app2.id) + .execution_options(synchronize_session=False) + ) + db.session.execute(cleanup_vars_query) + + app2_obj = db.session.get(App, app2.id) + if app2_obj: + db.session.delete(app2_obj) + db.session.commit() diff --git a/api/tests/integration_tests/test_offload.py b/api/tests/integration_tests/test_offload.py new file mode 100644 index 0000000000..a49330475e --- /dev/null +++ b/api/tests/integration_tests/test_offload.py @@ -0,0 +1,213 @@ +import uuid + +import pytest +from sqlalchemy.orm import Session, joinedload, selectinload + +from libs.datetime_utils import naive_utc_now +from libs.uuid_utils import uuidv7 +from models import db +from models.enums import CreatorUserRole +from models.model import UploadFile +from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom + + +@pytest.fixture +def session(flask_req_ctx): + with Session(bind=db.engine, expire_on_commit=False) as session: + yield session + + +def test_offload(session, setup_account): + tenant_id = str(uuid.uuid4()) + app_id = str(uuid.uuid4()) + # step 1: create a UploadFile + input_upload_file = UploadFile( + tenant_id=tenant_id, + storage_type="local", + key="fake_storage_key", + name="test_file.txt", + size=1024, + extension="txt", + mime_type="text/plain", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=setup_account.id, + created_at=naive_utc_now(), + used=False, + ) + output_upload_file = UploadFile( + tenant_id=tenant_id, + storage_type="local", + key="fake_storage_key", + name="test_file.txt", + size=1024, + extension="txt", + mime_type="text/plain", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=setup_account.id, + created_at=naive_utc_now(), + used=False, + ) + session.add(input_upload_file) + session.add(output_upload_file) + session.flush() + + # step 2: create a WorkflowNodeExecutionModel + node_execution = WorkflowNodeExecutionModel( + id=str(uuid.uuid4()), + tenant_id=tenant_id, + app_id=app_id, + workflow_id=str(uuid.uuid4()), + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + index=1, + node_id="test_node_id", + node_type="test", + title="Test Node", + status="succeeded", + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by=setup_account.id, + ) + session.add(node_execution) + session.flush() + + # step 3: create a WorkflowNodeExecutionOffload + offload = WorkflowNodeExecutionOffload( + id=uuidv7(), + tenant_id=tenant_id, + app_id=app_id, + node_execution_id=node_execution.id, + inputs_file_id=input_upload_file.id, + outputs_file_id=output_upload_file.id, + ) + session.add(offload) + session.flush() + + # Test preloading - this should work without raising LazyLoadError + result = ( + session.query(WorkflowNodeExecutionModel) + .options( + selectinload(WorkflowNodeExecutionModel.offload_data).options( + joinedload( + WorkflowNodeExecutionOffload.inputs_file, + ), + joinedload( + WorkflowNodeExecutionOffload.outputs_file, + ), + ) + ) + .filter(WorkflowNodeExecutionModel.id == node_execution.id) + .first() + ) + + # Verify the relationships are properly loaded + assert result is not None + assert result.offload_data 
is not None
+    assert result.offload_data.inputs_file is not None
+    assert result.offload_data.inputs_file.id == input_upload_file.id
+    assert result.offload_data.inputs_file.name == "test_file.txt"
+
+    # Test the computed properties
+    assert result.inputs_truncated is True
+    assert result.outputs_truncated is False
+
+
+# Kept disabled (leading underscore): demonstrates persisting the offload row and
+# both upload files through relationship assignment on the execution model.
+def _test_offload_save(session, setup_account):
+    tenant_id = str(uuid.uuid4())
+    app_id = str(uuid.uuid4())
+    # step 1: create the input and output UploadFiles
+    input_upload_file = UploadFile(
+        tenant_id=tenant_id,
+        storage_type="local",
+        key="fake_storage_key",
+        name="test_file.txt",
+        size=1024,
+        extension="txt",
+        mime_type="text/plain",
+        created_by_role=CreatorUserRole.ACCOUNT,
+        created_by=setup_account.id,
+        created_at=naive_utc_now(),
+        used=False,
+    )
+    output_upload_file = UploadFile(
+        tenant_id=tenant_id,
+        storage_type="local",
+        key="fake_storage_key",
+        name="test_file.txt",
+        size=1024,
+        extension="txt",
+        mime_type="text/plain",
+        created_by_role=CreatorUserRole.ACCOUNT,
+        created_by=setup_account.id,
+        created_at=naive_utc_now(),
+        used=False,
+    )
+
+    node_execution_id = str(uuid.uuid4())
+
+    # step 2: create a WorkflowNodeExecutionOffload and attach both files
+    offload = WorkflowNodeExecutionOffload(
+        id=uuidv7(),
+        tenant_id=tenant_id,
+        app_id=app_id,
+        node_execution_id=node_execution_id,
+    )
+    offload.inputs_file = input_upload_file
+    offload.outputs_file = output_upload_file
+
+    # step 3: create a WorkflowNodeExecutionModel and associate the offload data
+    node_execution = WorkflowNodeExecutionModel(
+        id=node_execution_id,
+        tenant_id=tenant_id,
+        app_id=app_id,
+        workflow_id=str(uuid.uuid4()),
+        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
+        index=1,
+        node_id="test_node_id",
+        node_type="test",
+        title="Test Node",
+        status="succeeded",
+        created_by_role=CreatorUserRole.ACCOUNT.value,
+        created_by=setup_account.id,
+    )
+    node_execution.offload_data = offload
+    session.add(node_execution)
+    session.flush()
diff --git a/api/tests/integration_tests/workflow/test_process_data_truncation_integration.py b/api/tests/integration_tests/workflow/test_process_data_truncation_integration.py
new file mode 100644
index 0000000000..624deb4abb
--- /dev/null
+++ b/api/tests/integration_tests/workflow/test_process_data_truncation_integration.py
@@ -0,0 +1,421 @@
+"""
+Integration tests for process_data truncation functionality.
+
+These tests verify the end-to-end behavior of process_data truncation across
+the entire system, from database storage to API responses.
+""" + +import json +from dataclasses import dataclass +from datetime import datetime +from unittest.mock import Mock, patch + +import pytest +from sqlalchemy import create_engine, text +from sqlalchemy.orm import sessionmaker + +from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus +from core.workflow.nodes.enums import NodeType +from models import Account +from models.workflow import WorkflowNodeExecutionTriggeredFrom + + +@dataclass +class TruncationTestData: + """Test data for truncation scenarios.""" + name: str + process_data: dict[str, any] + should_truncate: bool + expected_storage_interaction: bool + + +class TestProcessDataTruncationIntegration: + """Integration tests for process_data truncation functionality.""" + + @pytest.fixture + def in_memory_db_engine(self): + """Create an in-memory SQLite database for testing.""" + engine = create_engine("sqlite:///:memory:") + + # Create minimal table structure for testing + with engine.connect() as conn: + # Create workflow_node_executions table + conn.execute(text(""" + CREATE TABLE workflow_node_executions ( + id TEXT PRIMARY KEY, + tenant_id TEXT NOT NULL, + app_id TEXT NOT NULL, + workflow_id TEXT NOT NULL, + triggered_from TEXT NOT NULL, + workflow_run_id TEXT, + index_ INTEGER NOT NULL, + predecessor_node_id TEXT, + node_execution_id TEXT, + node_id TEXT NOT NULL, + node_type TEXT NOT NULL, + title TEXT NOT NULL, + inputs TEXT, + process_data TEXT, + outputs TEXT, + status TEXT NOT NULL, + error TEXT, + elapsed_time REAL DEFAULT 0, + execution_metadata TEXT, + created_at DATETIME NOT NULL, + created_by_role TEXT NOT NULL, + created_by TEXT NOT NULL, + finished_at DATETIME + ) + """)) + + # Create workflow_node_execution_offload table + conn.execute(text(""" + CREATE TABLE workflow_node_execution_offload ( + id TEXT PRIMARY KEY, + created_at DATETIME NOT NULL, + tenant_id TEXT NOT NULL, + app_id TEXT NOT NULL, + node_execution_id TEXT NOT NULL UNIQUE, + inputs_file_id TEXT, + outputs_file_id TEXT, + process_data_file_id TEXT + ) + """)) + + # Create upload_files table (simplified) + conn.execute(text(""" + CREATE TABLE upload_files ( + id TEXT PRIMARY KEY, + tenant_id TEXT NOT NULL, + storage_key TEXT NOT NULL, + filename TEXT NOT NULL, + size INTEGER NOT NULL, + created_at DATETIME NOT NULL + ) + """)) + + conn.commit() + + return engine + + @pytest.fixture + def mock_account(self): + """Create a mock account for testing.""" + account = Mock(spec=Account) + account.id = "test-user-id" + account.tenant_id = "test-tenant-id" + return account + + @pytest.fixture + def repository(self, in_memory_db_engine, mock_account): + """Create a repository instance for testing.""" + session_factory = sessionmaker(bind=in_memory_db_engine) + + return SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=session_factory, + user=mock_account, + app_id="test-app-id", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + def create_test_execution( + self, + process_data: dict[str, any] | None = None, + execution_id: str = "test-execution-id" + ) -> WorkflowNodeExecution: + """Create a test execution with process_data.""" + return WorkflowNodeExecution( + id=execution_id, + workflow_id="test-workflow-id", + workflow_execution_id="test-run-id", + index=1, + node_id="test-node-id", + node_type=NodeType.LLM, + title="Test Node", + 
process_data=process_data, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + created_at=datetime.now(), + finished_at=datetime.now(), + ) + + def get_truncation_test_data(self) -> list[TruncationTestData]: + """Get test data for various truncation scenarios.""" + return [ + TruncationTestData( + name="small_process_data", + process_data={"small": "data", "count": 5}, + should_truncate=False, + expected_storage_interaction=False, + ), + TruncationTestData( + name="large_process_data", + process_data={"large_field": "x" * 10000, "metadata": "info"}, + should_truncate=True, + expected_storage_interaction=True, + ), + TruncationTestData( + name="complex_large_data", + process_data={ + "logs": ["log entry"] * 500, # Large array + "config": {"setting": "value"}, + "status": "processing", + "details": {"description": "y" * 5000} # Large string + }, + should_truncate=True, + expected_storage_interaction=True, + ), + ] + + @patch('core.repositories.sqlalchemy_workflow_node_execution_repository.dify_config') + @patch('services.file_service.FileService.upload_file') + @patch('extensions.ext_storage.storage') + def test_end_to_end_process_data_truncation( + self, + mock_storage, + mock_upload_file, + mock_config, + repository + ): + """Test end-to-end process_data truncation functionality.""" + # Configure truncation limits + mock_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE = 1000 + mock_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH = 100 + mock_config.WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH = 500 + + # Create large process_data that should be truncated + large_process_data = { + "large_field": "x" * 10000, # Exceeds string length limit + "metadata": {"type": "processing", "timestamp": 1234567890} + } + + # Mock file upload + mock_file = Mock() + mock_file.id = "mock-process-data-file-id" + mock_upload_file.return_value = mock_file + + # Create and save execution + execution = self.create_test_execution(process_data=large_process_data) + repository.save(execution) + + # Verify truncation occurred + assert execution.process_data_truncated is True + truncated_data = execution.get_truncated_process_data() + assert truncated_data is not None + assert truncated_data != large_process_data # Should be different due to truncation + + # Verify file upload was called for process_data + assert mock_upload_file.called + upload_args = mock_upload_file.call_args + assert "_process_data" in upload_args[1]["filename"] + + @patch('core.repositories.sqlalchemy_workflow_node_execution_repository.dify_config') + def test_small_process_data_no_truncation(self, mock_config, repository): + """Test that small process_data is not truncated.""" + # Configure truncation limits + mock_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE = 1000 + mock_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH = 100 + mock_config.WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH = 500 + + # Create small process_data + small_process_data = {"small": "data", "count": 5} + + execution = self.create_test_execution(process_data=small_process_data) + repository.save(execution) + + # Verify no truncation occurred + assert execution.process_data_truncated is False + assert execution.get_truncated_process_data() is None + assert execution.get_response_process_data() == small_process_data + + @pytest.mark.parametrize("test_data", [ + data for data in get_truncation_test_data(None) + ], ids=[data.name for data in get_truncation_test_data(None)]) + @patch('core.repositories.sqlalchemy_workflow_node_execution_repository.dify_config') + 
@patch('services.file_service.FileService.upload_file') + def test_various_truncation_scenarios( + self, + mock_upload_file, + mock_config, + test_data: TruncationTestData, + repository + ): + """Test various process_data truncation scenarios.""" + # Configure truncation limits + mock_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE = 1000 + mock_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH = 100 + mock_config.WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH = 500 + + if test_data.expected_storage_interaction: + # Mock file upload for truncation scenarios + mock_file = Mock() + mock_file.id = f"file-{test_data.name}" + mock_upload_file.return_value = mock_file + + execution = self.create_test_execution(process_data=test_data.process_data) + repository.save(execution) + + # Verify truncation behavior matches expectations + assert execution.process_data_truncated == test_data.should_truncate + + if test_data.should_truncate: + assert execution.get_truncated_process_data() is not None + assert execution.get_truncated_process_data() != test_data.process_data + assert mock_upload_file.called + else: + assert execution.get_truncated_process_data() is None + assert execution.get_response_process_data() == test_data.process_data + + @patch('core.repositories.sqlalchemy_workflow_node_execution_repository.dify_config') + @patch('services.file_service.FileService.upload_file') + @patch('extensions.ext_storage.storage') + def test_load_truncated_execution_from_database( + self, + mock_storage, + mock_upload_file, + mock_config, + repository, + in_memory_db_engine + ): + """Test loading an execution with truncated process_data from database.""" + # Configure truncation + mock_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE = 1000 + mock_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH = 100 + mock_config.WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH = 500 + + # Create and save execution with large process_data + large_process_data = { + "large_field": "x" * 10000, + "metadata": "info" + } + + # Mock file upload + mock_file = Mock() + mock_file.id = "process-data-file-id" + mock_upload_file.return_value = mock_file + + execution = self.create_test_execution(process_data=large_process_data) + repository.save(execution) + + # Mock storage load for reconstruction + mock_storage.load.return_value = json.dumps(large_process_data).encode() + + # Create a new repository instance to simulate fresh load + session_factory = sessionmaker(bind=in_memory_db_engine) + new_repository = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=session_factory, + user=Mock(spec=Account, id="test-user", tenant_id="test-tenant"), + app_id="test-app-id", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + # Load executions from database + executions = new_repository.get_by_workflow_run("test-run-id") + + assert len(executions) == 1 + loaded_execution = executions[0] + + # Verify that full data is loaded + assert loaded_execution.process_data == large_process_data + assert loaded_execution.process_data_truncated is True + + # Verify truncated data for responses + response_data = loaded_execution.get_response_process_data() + assert response_data != large_process_data # Should be truncated version + + def test_process_data_none_handling(self, repository): + """Test handling of None process_data.""" + execution = self.create_test_execution(process_data=None) + repository.save(execution) + + # Should handle None gracefully + assert execution.process_data is None + assert execution.process_data_truncated is False + 
assert execution.get_response_process_data() is None
+
+    def test_empty_process_data_handling(self, repository):
+        """Test handling of empty process_data."""
+        execution = self.create_test_execution(process_data={})
+        repository.save(execution)
+
+        # Should handle empty dict gracefully
+        assert execution.process_data == {}
+        assert execution.process_data_truncated is False
+        assert execution.get_response_process_data() == {}
+
+
+class TestProcessDataTruncationApiIntegration:
+    """Integration tests for API responses with process_data truncation."""
+
+    def test_api_response_includes_truncated_flag(self):
+        """Test that API responses include the process_data_truncated flag."""
+        from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
+        from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity
+        from core.app.entities.queue_entities import QueueNodeSucceededEvent
+
+        # Create execution with truncated process_data
+        execution = WorkflowNodeExecution(
+            id="test-execution-id",
+            workflow_id="test-workflow-id",
+            workflow_execution_id="test-run-id",
+            index=1,
+            node_id="test-node-id",
+            node_type=NodeType.LLM,
+            title="Test Node",
+            process_data={"large": "x" * 10000},
+            status=WorkflowNodeExecutionStatus.SUCCEEDED,
+            created_at=datetime.now(),
+            finished_at=datetime.now(),
+        )
+
+        # Set truncated data
+        execution.set_truncated_process_data({"large": "[TRUNCATED]"})
+
+        # Create converter and event
+        converter = WorkflowResponseConverter(
+            application_generate_entity=Mock(
+                spec=WorkflowAppGenerateEntity,
+                app_config=Mock(tenant_id="test-tenant"),
+            )
+        )
+
+        event = QueueNodeSucceededEvent(
+            node_id="test-node-id",
+            node_type=NodeType.LLM,
+            node_data=Mock(),
+            node_execution_id=str(uuid.uuid4()),
+            start_at=datetime.now(),
+            parallel_id=None,
+            parallel_start_node_id=None,
+            parent_parallel_id=None,
+            parent_parallel_start_node_id=None,
+            in_iteration_id=None,
+            in_loop_id=None,
+        )
+
+        # Generate response
+        response = converter.workflow_node_finish_to_stream_response(
+            event=event,
+            task_id="test-task-id",
+            workflow_node_execution=execution,
+        )
+
+        # Verify response includes truncated flag and data
+        assert response is not None
+        assert response.data.process_data_truncated is True
+        assert response.data.process_data == {"large": "[TRUNCATED]"}
+
+        # Verify response can be serialized
+        response_dict = response.to_dict()
+        assert "process_data_truncated" in response_dict["data"]
+        assert response_dict["data"]["process_data_truncated"] is True
+
+    def test_workflow_run_fields_include_truncated_flag(self):
+        """Test that workflow run fields include process_data_truncated."""
+        from fields.workflow_run_fields import workflow_run_node_execution_fields
+
+        # Verify the field is included in the definition
+        assert "process_data_truncated" in workflow_run_node_execution_fields
+
+        # The field should be a Boolean field
+        field = workflow_run_node_execution_fields["process_data_truncated"]
+        from flask_restx import fields
+        assert isinstance(field, fields.Boolean)
\ No newline at end of file
diff --git a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py
index ac3c8e45c9..723fdfeb17 100644
--- a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py
+++ b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py
@@ -1,7 +1,9 @@
 import uuid
 from collections import OrderedDict
 from typing import Any, NamedTuple
+from unittest.mock import MagicMock, patch
 
+import pytest
 from flask_restx import marshal
 
 from controllers.console.app.workflow_draft_variable import (
@@ -9,11 +11,14 @@
     _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS,
     _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS,
     _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
+    _serialize_full_content,
 )
+from core.variables.types import SegmentType
 from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
 from factories.variable_factory import build_segment
 from libs.datetime_utils import naive_utc_now
-from models.workflow import WorkflowDraftVariable
+from libs.uuid_utils import uuidv7
+from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
 from services.workflow_draft_variable_service import WorkflowDraftVariableList
 
 _TEST_APP_ID = "test_app_id"
@@ -21,6 +26,54 @@ _TEST_NODE_EXEC_ID = str(uuid.uuid4())
 
 
 class TestWorkflowDraftVariableFields:
+    def test_serialize_full_content(self):
+        """Test that _serialize_full_content uses pre-loaded relationships."""
+        # Create mock objects with relationships pre-loaded
+        mock_variable_file = MagicMock(spec=WorkflowDraftVariableFile)
+        mock_variable_file.size = 100000
+        mock_variable_file.length = 50
+        mock_variable_file.value_type = SegmentType.OBJECT
+        mock_variable_file.upload_file_id = "test-upload-file-id"
+
+        mock_variable = MagicMock(spec=WorkflowDraftVariable)
+        mock_variable.file_id = "test-file-id"
+        mock_variable.variable_file = mock_variable_file
+
+        # Mock the file helpers
+        with patch("controllers.console.app.workflow_draft_variable.file_helpers") as mock_file_helpers:
+            mock_file_helpers.get_signed_file_url.return_value = "http://example.com/signed-url"
+
+            # Call the function
+            result = _serialize_full_content(mock_variable)
+
+            # Verify it returns the expected structure
+            assert result is not None
+            assert result["size_bytes"] == 100000
+            assert result["length"] == 50
+            assert result["value_type"] == "object"
+            assert "download_url" in result
+            assert result["download_url"] == "http://example.com/signed-url"
+
+            # Verify it used the pre-loaded relationships (no database queries)
+            mock_file_helpers.get_signed_file_url.assert_called_once_with("test-upload-file-id", as_attachment=True)
+
+    def test_serialize_full_content_handles_none_cases(self):
+        """Test that 
_serialize_full_content handles None cases properly.""" + + # Test with no file_id + draft_var = WorkflowDraftVariable() + draft_var.file_id = None + result = _serialize_full_content(draft_var) + assert result is None + + def test_serialize_full_content_should_raises_when_file_id_exists_but_file_is_none(self): + # Test with no file_id + draft_var = WorkflowDraftVariable() + draft_var.file_id = str(uuid.uuid4()) + draft_var.variable_file = None + with pytest.raises(AssertionError): + result = _serialize_full_content(draft_var) + def test_conversation_variable(self): conv_var = WorkflowDraftVariable.new_conversation_variable( app_id=_TEST_APP_ID, name="conv_var", value=build_segment(1) @@ -39,12 +92,14 @@ class TestWorkflowDraftVariableFields: "value_type": "number", "edited": False, "visible": True, + "is_truncated": False, } ) assert marshal(conv_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value expected_with_value = expected_without_value.copy() expected_with_value["value"] = 1 + expected_with_value["full_content"] = None assert marshal(conv_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value def test_create_sys_variable(self): @@ -70,11 +125,13 @@ class TestWorkflowDraftVariableFields: "value_type": "string", "edited": True, "visible": True, + "is_truncated": False, } ) assert marshal(sys_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value expected_with_value = expected_without_value.copy() expected_with_value["value"] = "a" + expected_with_value["full_content"] = None assert marshal(sys_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value def test_node_variable(self): @@ -100,14 +157,65 @@ class TestWorkflowDraftVariableFields: "value_type": "array[any]", "edited": True, "visible": False, + "is_truncated": False, } ) assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value expected_with_value = expected_without_value.copy() expected_with_value["value"] = [1, "a"] + expected_with_value["full_content"] = None assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value + def test_node_variable_with_file(self): + node_var = WorkflowDraftVariable.new_node_variable( + app_id=_TEST_APP_ID, + node_id="test_node", + name="node_var", + value=build_segment([1, "a"]), + visible=False, + node_execution_id=_TEST_NODE_EXEC_ID, + ) + + node_var.id = str(uuid.uuid4()) + node_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + variable_file = WorkflowDraftVariableFile( + id=str(uuidv7()), + upload_file_id=str(uuid.uuid4()), + size=1024, + length=10, + value_type=SegmentType.ARRAY_STRING, + ) + node_var.variable_file = variable_file + node_var.file_id = variable_file.id + + expected_without_value: OrderedDict[str, Any] = OrderedDict( + { + "id": str(node_var.id), + "type": node_var.get_variable_type().value, + "name": "node_var", + "description": "", + "selector": ["test_node", "node_var"], + "value_type": "array[any]", + "edited": True, + "visible": False, + "is_truncated": True, + } + ) + + with patch("controllers.console.app.workflow_draft_variable.file_helpers") as mock_file_helpers: + mock_file_helpers.get_signed_file_url.return_value = "http://example.com/signed-url" + assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value + expected_with_value = expected_without_value.copy() + expected_with_value["value"] = [1, "a"] + expected_with_value["full_content"] = { + "size_bytes": 1024, + "value_type": 
"array[string]", + "length": 10, + "download_url": "http://example.com/signed-url", + } + assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value + class TestWorkflowDraftVariableList: def test_workflow_draft_variable_list(self): @@ -135,6 +243,7 @@ class TestWorkflowDraftVariableList: "value_type": "string", "edited": False, "visible": True, + "is_truncated": False, } ) diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_process_data.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_process_data.py new file mode 100644 index 0000000000..c2cd1e9296 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_process_data.py @@ -0,0 +1,429 @@ +""" +Unit tests for WorkflowResponseConverter focusing on process_data truncation functionality. +""" + +import uuid +from dataclasses import dataclass +from datetime import datetime +from typing import Any +from unittest.mock import Mock + +import pytest + +from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter +from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity +from core.app.entities.queue_entities import QueueNodeRetryEvent, QueueNodeSucceededEvent +from core.helper.code_executor.code_executor import CodeLanguage +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus +from core.workflow.nodes.code.entities import CodeNodeData +from core.workflow.nodes.enums import NodeType +from libs.datetime_utils import naive_utc_now + + +@dataclass +class ProcessDataResponseScenario: + """Test scenario for process_data in responses.""" + + name: str + original_process_data: dict[str, Any] | None + truncated_process_data: dict[str, Any] | None + expected_response_data: dict[str, Any] | None + expected_truncated_flag: bool + + +class TestWorkflowResponseConverterCenarios: + """Test process_data truncation in WorkflowResponseConverter.""" + + def create_mock_generate_entity(self) -> WorkflowAppGenerateEntity: + """Create a mock WorkflowAppGenerateEntity.""" + mock_entity = Mock(spec=WorkflowAppGenerateEntity) + mock_app_config = Mock() + mock_app_config.tenant_id = "test-tenant-id" + mock_entity.app_config = mock_app_config + return mock_entity + + def create_workflow_response_converter(self) -> WorkflowResponseConverter: + """Create a WorkflowResponseConverter for testing.""" + mock_entity = self.create_mock_generate_entity() + return WorkflowResponseConverter(application_generate_entity=mock_entity) + + def create_workflow_node_execution( + self, + process_data: dict[str, Any] | None = None, + truncated_process_data: dict[str, Any] | None = None, + execution_id: str = "test-execution-id", + ) -> WorkflowNodeExecution: + """Create a WorkflowNodeExecution for testing.""" + execution = WorkflowNodeExecution( + id=execution_id, + workflow_id="test-workflow-id", + workflow_execution_id="test-run-id", + index=1, + node_id="test-node-id", + node_type=NodeType.LLM, + title="Test Node", + process_data=process_data, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + created_at=datetime.now(), + finished_at=datetime.now(), + ) + + if truncated_process_data is not None: + execution.set_truncated_process_data(truncated_process_data) + + return execution + + def create_node_succeeded_event(self) -> QueueNodeSucceededEvent: + """Create a QueueNodeSucceededEvent for testing.""" + return QueueNodeSucceededEvent( + node_id="test-node-id", + 
node_type=NodeType.CODE, + node_data=CodeNodeData( + title="test code", + variables=[], + code_language=CodeLanguage.PYTHON3, + code="", + outputs={}, + ), + node_execution_id=str(uuid.uuid4()), + start_at=naive_utc_now(), + parallel_id=None, + parallel_start_node_id=None, + parent_parallel_id=None, + parent_parallel_start_node_id=None, + in_iteration_id=None, + in_loop_id=None, + ) + + def create_node_retry_event(self) -> QueueNodeRetryEvent: + """Create a QueueNodeRetryEvent for testing.""" + return QueueNodeRetryEvent( + inputs={"data": "inputs"}, + outputs={"data": "outputs"}, + error="oops", + retry_index=1, + node_id="test-node-id", + node_type=NodeType.CODE, + node_data=CodeNodeData( + title="test code", + variables=[], + code_language=CodeLanguage.PYTHON3, + code="", + outputs={}, + ), + node_execution_id=str(uuid.uuid4()), + start_at=naive_utc_now(), + parallel_id=None, + parallel_start_node_id=None, + parent_parallel_id=None, + parent_parallel_start_node_id=None, + in_iteration_id=None, + in_loop_id=None, + ) + + def test_workflow_node_finish_response_uses_truncated_process_data(self): + """Test that node finish response uses get_response_process_data().""" + converter = self.create_workflow_response_converter() + + original_data = {"large_field": "x" * 10000, "metadata": "info"} + truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"} + + execution = self.create_workflow_node_execution( + process_data=original_data, truncated_process_data=truncated_data + ) + event = self.create_node_succeeded_event() + + response = converter.workflow_node_finish_to_stream_response( + event=event, + task_id="test-task-id", + workflow_node_execution=execution, + ) + + # Response should use truncated data, not original + assert response is not None + assert response.data.process_data == truncated_data + assert response.data.process_data != original_data + assert response.data.process_data_truncated is True + + def test_workflow_node_finish_response_without_truncation(self): + """Test node finish response when no truncation is applied.""" + converter = self.create_workflow_response_converter() + + original_data = {"small": "data"} + + execution = self.create_workflow_node_execution(process_data=original_data) + event = self.create_node_succeeded_event() + + response = converter.workflow_node_finish_to_stream_response( + event=event, + task_id="test-task-id", + workflow_node_execution=execution, + ) + + # Response should use original data + assert response is not None + assert response.data.process_data == original_data + assert response.data.process_data_truncated is False + + def test_workflow_node_finish_response_with_none_process_data(self): + """Test node finish response when process_data is None.""" + converter = self.create_workflow_response_converter() + + execution = self.create_workflow_node_execution(process_data=None) + event = self.create_node_succeeded_event() + + response = converter.workflow_node_finish_to_stream_response( + event=event, + task_id="test-task-id", + workflow_node_execution=execution, + ) + + # Response should have None process_data + assert response is not None + assert response.data.process_data is None + assert response.data.process_data_truncated is False + + def test_workflow_node_retry_response_uses_truncated_process_data(self): + """Test that node retry response uses get_response_process_data().""" + converter = self.create_workflow_response_converter() + + original_data = {"large_field": "x" * 10000, "metadata": "info"} + truncated_data = 
{"large_field": "[TRUNCATED]", "metadata": "info"} + + execution = self.create_workflow_node_execution( + process_data=original_data, truncated_process_data=truncated_data + ) + event = self.create_node_retry_event() + + response = converter.workflow_node_retry_to_stream_response( + event=event, + task_id="test-task-id", + workflow_node_execution=execution, + ) + + # Response should use truncated data, not original + assert response is not None + assert response.data.process_data == truncated_data + assert response.data.process_data != original_data + assert response.data.process_data_truncated is True + + def test_workflow_node_retry_response_without_truncation(self): + """Test node retry response when no truncation is applied.""" + converter = self.create_workflow_response_converter() + + original_data = {"small": "data"} + + execution = self.create_workflow_node_execution(process_data=original_data) + event = self.create_node_retry_event() + + response = converter.workflow_node_retry_to_stream_response( + event=event, + task_id="test-task-id", + workflow_node_execution=execution, + ) + + # Response should use original data + assert response is not None + assert response.data.process_data == original_data + assert response.data.process_data_truncated is False + + def test_iteration_and_loop_nodes_return_none(self): + """Test that iteration and loop nodes return None (no change from existing behavior).""" + converter = self.create_workflow_response_converter() + + # Test iteration node + iteration_execution = self.create_workflow_node_execution(process_data={"test": "data"}) + iteration_execution.node_type = NodeType.ITERATION + + event = self.create_node_succeeded_event() + + response = converter.workflow_node_finish_to_stream_response( + event=event, + task_id="test-task-id", + workflow_node_execution=iteration_execution, + ) + + # Should return None for iteration nodes + assert response is None + + # Test loop node + loop_execution = self.create_workflow_node_execution(process_data={"test": "data"}) + loop_execution.node_type = NodeType.LOOP + + response = converter.workflow_node_finish_to_stream_response( + event=event, + task_id="test-task-id", + workflow_node_execution=loop_execution, + ) + + # Should return None for loop nodes + assert response is None + + def test_execution_without_workflow_execution_id_returns_none(self): + """Test that executions without workflow_execution_id return None.""" + converter = self.create_workflow_response_converter() + + execution = self.create_workflow_node_execution(process_data={"test": "data"}) + execution.workflow_execution_id = None # Single-step debugging + + event = self.create_node_succeeded_event() + + response = converter.workflow_node_finish_to_stream_response( + event=event, + task_id="test-task-id", + workflow_node_execution=execution, + ) + + # Should return None for single-step debugging + assert response is None + + @staticmethod + def get_process_data_response_scenarios() -> list[ProcessDataResponseScenario]: + """Create test scenarios for process_data responses.""" + return [ + ProcessDataResponseScenario( + name="none_process_data", + original_process_data=None, + truncated_process_data=None, + expected_response_data=None, + expected_truncated_flag=False, + ), + ProcessDataResponseScenario( + name="small_process_data_no_truncation", + original_process_data={"small": "data"}, + truncated_process_data=None, + expected_response_data={"small": "data"}, + expected_truncated_flag=False, + ), + ProcessDataResponseScenario( + 
name="large_process_data_with_truncation", + original_process_data={"large": "x" * 10000, "metadata": "info"}, + truncated_process_data={"large": "[TRUNCATED]", "metadata": "info"}, + expected_response_data={"large": "[TRUNCATED]", "metadata": "info"}, + expected_truncated_flag=True, + ), + ProcessDataResponseScenario( + name="empty_process_data", + original_process_data={}, + truncated_process_data=None, + expected_response_data={}, + expected_truncated_flag=False, + ), + ProcessDataResponseScenario( + name="complex_data_with_truncation", + original_process_data={ + "logs": ["entry"] * 1000, # Large array + "config": {"setting": "value"}, + "status": "processing", + }, + truncated_process_data={ + "logs": "[TRUNCATED: 1000 items]", + "config": {"setting": "value"}, + "status": "processing", + }, + expected_response_data={ + "logs": "[TRUNCATED: 1000 items]", + "config": {"setting": "value"}, + "status": "processing", + }, + expected_truncated_flag=True, + ), + ] + + @pytest.mark.parametrize( + "scenario", + [scenario for scenario in get_process_data_response_scenarios()], + ids=[scenario.name for scenario in get_process_data_response_scenarios()], + ) + def test_node_finish_response_scenarios(self, scenario: ProcessDataResponseScenario): + """Test various scenarios for node finish responses.""" + converter = WorkflowResponseConverter( + application_generate_entity=Mock(spec=WorkflowAppGenerateEntity, app_config=Mock(tenant_id="test-tenant")) + ) + + execution = WorkflowNodeExecution( + id="test-execution-id", + workflow_id="test-workflow-id", + workflow_execution_id="test-run-id", + index=1, + node_id="test-node-id", + node_type=NodeType.LLM, + title="Test Node", + process_data=scenario.original_process_data, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + created_at=datetime.now(), + finished_at=datetime.now(), + ) + + if scenario.truncated_process_data is not None: + execution.set_truncated_process_data(scenario.truncated_process_data) + + event = QueueNodeSucceededEvent( + node_id="test-node-id", + node_type=NodeType.CODE, + node_data=CodeNodeData( + title="test code", + variables=[], + code_language=CodeLanguage.PYTHON3, + code="", + outputs={}, + ), + node_execution_id=str(uuid.uuid4()), + start_at=naive_utc_now(), + parallel_id=None, + parallel_start_node_id=None, + parent_parallel_id=None, + parent_parallel_start_node_id=None, + in_iteration_id=None, + in_loop_id=None, + ) + + response = converter.workflow_node_finish_to_stream_response( + event=event, + task_id="test-task-id", + workflow_node_execution=execution, + ) + + assert response is not None + assert response.data.process_data == scenario.expected_response_data + assert response.data.process_data_truncated == scenario.expected_truncated_flag + + @pytest.mark.parametrize( + "scenario", + [scenario for scenario in get_process_data_response_scenarios()], + ids=[scenario.name for scenario in get_process_data_response_scenarios()], + ) + def test_node_retry_response_scenarios(self, scenario: ProcessDataResponseScenario): + """Test various scenarios for node retry responses.""" + converter = WorkflowResponseConverter( + application_generate_entity=Mock(spec=WorkflowAppGenerateEntity, app_config=Mock(tenant_id="test-tenant")) + ) + + execution = WorkflowNodeExecution( + id="test-execution-id", + workflow_id="test-workflow-id", + workflow_execution_id="test-run-id", + index=1, + node_id="test-node-id", + node_type=NodeType.LLM, + title="Test Node", + process_data=scenario.original_process_data, + 
status=WorkflowNodeExecutionStatus.FAILED,  # Retry scenario + created_at=datetime.now(), + finished_at=datetime.now(), + ) + + if scenario.truncated_process_data is not None: + execution.set_truncated_process_data(scenario.truncated_process_data) + + event = self.create_node_retry_event() + + response = converter.workflow_node_retry_to_stream_response( + event=event, + task_id="test-task-id", + workflow_node_execution=execution, + ) + + assert response is not None + assert response.data.process_data == scenario.expected_response_data + assert response.data.process_data_truncated == scenario.expected_truncated_flag diff --git a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py new file mode 100644 index 0000000000..c1c9707daf --- /dev/null +++ b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py @@ -0,0 +1,248 @@ +""" +Unit tests for WorkflowNodeExecution truncation functionality. + +Tests the truncation and offloading logic for large inputs and outputs +in the SQLAlchemyWorkflowNodeExecutionRepository. +""" + +import json +from dataclasses import dataclass +from datetime import UTC, datetime +from typing import Any +from unittest.mock import MagicMock, patch + +from sqlalchemy import Engine + +from core.repositories.sqlalchemy_workflow_node_execution_repository import ( + SQLAlchemyWorkflowNodeExecutionRepository, +) +from core.workflow.entities.workflow_node_execution import ( + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, +) +from core.workflow.nodes.enums import NodeType +from models import Account, WorkflowNodeExecutionTriggeredFrom +from models.enums import ExecutionOffLoadType +from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload + +TRUNCATION_SIZE_THRESHOLD = 500 + + +@dataclass +class TruncationTestCase: + """Test case data for truncation scenarios.""" + + name: str + inputs: dict[str, Any] | None + outputs: dict[str, Any] | None + should_truncate_inputs: bool + should_truncate_outputs: bool + description: str + + +def create_test_cases() -> list[TruncationTestCase]: + """Create test cases for different truncation scenarios.""" + # Create large data that will definitely exceed TRUNCATION_SIZE_THRESHOLD (500 bytes here) + large_data = {"data": "x" * (TRUNCATION_SIZE_THRESHOLD + 1000)} + small_data = {"data": "small"} + + return [ + TruncationTestCase( + name="small_data_no_truncation", + inputs=small_data, + outputs=small_data, + should_truncate_inputs=False, + should_truncate_outputs=False, + description="Small data should not be truncated", + ), + TruncationTestCase( + name="large_inputs_truncation", + inputs=large_data, + outputs=small_data, + should_truncate_inputs=True, + should_truncate_outputs=False, + description="Large inputs should be truncated", + ), + TruncationTestCase( + name="large_outputs_truncation", + inputs=small_data, + outputs=large_data, + should_truncate_inputs=False, + should_truncate_outputs=True, + description="Large outputs should be truncated", + ), + TruncationTestCase( + name="large_both_truncation", + inputs=large_data, + outputs=large_data, + should_truncate_inputs=True, + should_truncate_outputs=True, + description="Both large inputs and outputs should be truncated", + ), + TruncationTestCase( + name="none_inputs_outputs", + inputs=None, + outputs=None, + should_truncate_inputs=False, + should_truncate_outputs=False, + description="None inputs and outputs should not be truncated", + ), +
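+ # NOTE: the payload sizes above are expressed relative to TRUNCATION_SIZE_THRESHOLD,
+ # so both the truncating and non-truncating branches stay covered even if the
+ # module-level constant is tuned later.
+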
] + + +def create_workflow_node_execution( + execution_id: str = "test-execution-id", + inputs: dict[str, Any] | None = None, + outputs: dict[str, Any] | None = None, +) -> WorkflowNodeExecution: + """Factory function to create a WorkflowNodeExecution for testing.""" + return WorkflowNodeExecution( + id=execution_id, + node_execution_id="test-node-execution-id", + workflow_id="test-workflow-id", + workflow_execution_id="test-workflow-execution-id", + index=1, + node_id="test-node-id", + node_type=NodeType.LLM, + title="Test Node", + inputs=inputs, + outputs=outputs, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + created_at=datetime.now(UTC), + ) + + +def mock_user() -> Account: + """Create a mock Account user for testing.""" + from unittest.mock import MagicMock + + user = MagicMock(spec=Account) + user.id = "test-user-id" + user.current_tenant_id = "test-tenant-id" + return user + + +class TestSQLAlchemyWorkflowNodeExecutionRepositoryTruncation: + """Test class for truncation functionality in SQLAlchemyWorkflowNodeExecutionRepository.""" + + def create_repository(self) -> SQLAlchemyWorkflowNodeExecutionRepository: + """Create a repository instance for testing.""" + return SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=MagicMock(spec=Engine), + user=mock_user(), + app_id="test-app-id", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + def test_to_domain_model_without_offload_data(self): + """Test _to_domain_model correctly handles models without offload data.""" + repo = self.create_repository() + + # Create a mock database model without offload data + db_model = WorkflowNodeExecutionModel() + db_model.id = "test-id" + db_model.node_execution_id = "node-exec-id" + db_model.workflow_id = "workflow-id" + db_model.workflow_run_id = "run-id" + db_model.index = 1 + db_model.predecessor_node_id = None + db_model.node_id = "node-id" + db_model.node_type = NodeType.LLM.value + db_model.title = "Test Node" + db_model.inputs = json.dumps({"value": "inputs"}) + db_model.process_data = json.dumps({"value": "process_data"}) + db_model.outputs = json.dumps({"value": "outputs"}) + db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED.value + db_model.error = None + db_model.elapsed_time = 1.0 + db_model.execution_metadata = "{}" + db_model.created_at = datetime.now(UTC) + db_model.finished_at = None + db_model.offload_data = [] + + domain_model = repo._to_domain_model(db_model) + + # Check that no truncated data was set + assert domain_model.get_truncated_inputs() is None + assert domain_model.get_truncated_outputs() is None + + @patch("core.repositories.sqlalchemy_workflow_node_execution_repository.FileService") + def test_save_with_truncation(self, mock_file_service_class): + """Test the save method handles truncation and offload record creation.""" + # Setup mock file service + mock_file_service = MagicMock() + mock_upload_file = MagicMock() + mock_upload_file.id = "mock-file-id" + mock_file_service.upload_file.return_value = mock_upload_file + mock_file_service_class.return_value = mock_file_service + + large_data = {"data": "x" * (TRUNCATION_SIZE_THRESHOLD + 1)} + + repo = self.create_repository() + execution = create_workflow_node_execution( + inputs=large_data, + outputs=large_data, + ) + + # Mock the session and database operations + with patch.object(repo, "_session_factory") as mock_session_factory: + mock_session = MagicMock() + mock_session_factory.return_value.__enter__.return_value = mock_session + + repo.save(execution) + + # Check that both merge 
and commit were called on the session (the offload record is persisted through db_model, so merge runs once) + assert mock_session.merge.call_count == 1 + mock_session.commit.assert_called_once() + + +class TestWorkflowNodeExecutionModelTruncatedProperties: + """Test the truncated properties on WorkflowNodeExecutionModel.""" + + def test_inputs_truncated_with_offload_data(self): + """Test inputs_truncated property when offload data exists.""" + model = WorkflowNodeExecutionModel() + offload = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS) + model.offload_data = [offload] + + assert model.inputs_truncated is True + assert model.process_data_truncated is False + assert model.outputs_truncated is False + + def test_outputs_truncated_with_offload_data(self): + """Test outputs_truncated property when offload data exists.""" + model = WorkflowNodeExecutionModel() + + # Mock offload data with outputs file + offload = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS) + model.offload_data = [offload] + + assert model.inputs_truncated is False + assert model.process_data_truncated is False + assert model.outputs_truncated is True + + def test_process_data_truncated_with_offload_data(self): + model = WorkflowNodeExecutionModel() + offload = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.PROCESS_DATA) + model.offload_data = [offload] + assert model.process_data_truncated is True + assert model.inputs_truncated is False + assert model.outputs_truncated is False + + def test_truncated_properties_without_offload_data(self): + """Test truncated properties when no offload data exists.""" + model = WorkflowNodeExecutionModel() + model.offload_data = [] + + assert model.inputs_truncated is False + assert model.outputs_truncated is False + assert model.process_data_truncated is False + + def test_truncated_properties_without_offload_attribute(self): + """Test truncated properties when offload_data attribute doesn't exist.""" + model = WorkflowNodeExecutionModel() + # Don't set offload_data attribute at all + + assert model.inputs_truncated is False + assert model.outputs_truncated is False + assert model.process_data_truncated is False diff --git a/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py b/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py new file mode 100644 index 0000000000..431e62ce94 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py @@ -0,0 +1,225 @@ +""" +Unit tests for WorkflowNodeExecution domain model, focusing on process_data truncation functionality. 
+""" + +from dataclasses import dataclass +from datetime import datetime +from typing import Any + +import pytest + +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution +from core.workflow.nodes.enums import NodeType + + +class TestWorkflowNodeExecutionProcessDataTruncation: + """Test process_data truncation functionality in WorkflowNodeExecution domain model.""" + + def create_workflow_node_execution( + self, + process_data: dict[str, Any] | None = None, + ) -> WorkflowNodeExecution: + """Create a WorkflowNodeExecution instance for testing.""" + return WorkflowNodeExecution( + id="test-execution-id", + workflow_id="test-workflow-id", + index=1, + node_id="test-node-id", + node_type=NodeType.LLM, + title="Test Node", + process_data=process_data, + created_at=datetime.now(), + ) + + def test_initial_process_data_truncated_state(self): + """Test that process_data_truncated returns False initially.""" + execution = self.create_workflow_node_execution() + + assert execution.process_data_truncated is False + assert execution.get_truncated_process_data() is None + + def test_set_and_get_truncated_process_data(self): + """Test setting and getting truncated process_data.""" + execution = self.create_workflow_node_execution() + test_truncated_data = {"truncated": True, "key": "value"} + + execution.set_truncated_process_data(test_truncated_data) + + assert execution.process_data_truncated is True + assert execution.get_truncated_process_data() == test_truncated_data + + def test_set_truncated_process_data_to_none(self): + """Test setting truncated process_data to None.""" + execution = self.create_workflow_node_execution() + + # First set some data + execution.set_truncated_process_data({"key": "value"}) + assert execution.process_data_truncated is True + + # Then set to None + execution.set_truncated_process_data(None) + assert execution.process_data_truncated is False + assert execution.get_truncated_process_data() is None + + def test_get_response_process_data_with_no_truncation(self): + """Test get_response_process_data when no truncation is set.""" + original_data = {"original": True, "data": "value"} + execution = self.create_workflow_node_execution(process_data=original_data) + + response_data = execution.get_response_process_data() + + assert response_data == original_data + assert execution.process_data_truncated is False + + def test_get_response_process_data_with_truncation(self): + """Test get_response_process_data when truncation is set.""" + original_data = {"original": True, "large_data": "x" * 10000} + truncated_data = {"original": True, "large_data": "[TRUNCATED]"} + + execution = self.create_workflow_node_execution(process_data=original_data) + execution.set_truncated_process_data(truncated_data) + + response_data = execution.get_response_process_data() + + # Should return truncated data, not original + assert response_data == truncated_data + assert response_data != original_data + assert execution.process_data_truncated is True + + def test_get_response_process_data_with_none_process_data(self): + """Test get_response_process_data when process_data is None.""" + execution = self.create_workflow_node_execution(process_data=None) + + response_data = execution.get_response_process_data() + + assert response_data is None + assert execution.process_data_truncated is False + + def test_consistency_with_inputs_outputs_pattern(self): + """Test that process_data truncation follows the same pattern as inputs/outputs.""" + execution = 
self.create_workflow_node_execution() + + # Test that all truncation methods exist and behave consistently + test_data = {"test": "data"} + + # Test inputs truncation + execution.set_truncated_inputs(test_data) + assert execution.inputs_truncated is True + assert execution.get_truncated_inputs() == test_data + + # Test outputs truncation + execution.set_truncated_outputs(test_data) + assert execution.outputs_truncated is True + assert execution.get_truncated_outputs() == test_data + + # Test process_data truncation + execution.set_truncated_process_data(test_data) + assert execution.process_data_truncated is True + assert execution.get_truncated_process_data() == test_data + + @pytest.mark.parametrize( + "test_data", + [ + {"simple": "value"}, + {"nested": {"key": "value"}}, + {"list": [1, 2, 3]}, + {"mixed": {"string": "value", "number": 42, "list": [1, 2]}}, + {},  # empty dict + ], + ) + def test_truncated_process_data_with_various_data_types(self, test_data): + """Test that truncated process_data works with various data types.""" + execution = self.create_workflow_node_execution() + + execution.set_truncated_process_data(test_data) + + assert execution.process_data_truncated is True + assert execution.get_truncated_process_data() == test_data + assert execution.get_response_process_data() == test_data + + +@dataclass +class ProcessDataScenario: + """Test scenario data for process_data functionality.""" + + name: str + original_data: dict[str, Any] | None + truncated_data: dict[str, Any] | None + expected_truncated_flag: bool + expected_response_data: dict[str, Any] | None + + +class TestWorkflowNodeExecutionProcessDataScenarios: + """Test various scenarios for process_data handling.""" + + @staticmethod + def get_process_data_scenarios() -> list[ProcessDataScenario]: + """Create test scenarios for process_data functionality.""" + return [ + ProcessDataScenario( + name="no_process_data", + original_data=None, + truncated_data=None, + expected_truncated_flag=False, + expected_response_data=None, + ), + ProcessDataScenario( + name="process_data_without_truncation", + original_data={"small": "data"}, + truncated_data=None, + expected_truncated_flag=False, + expected_response_data={"small": "data"}, + ), + ProcessDataScenario( + name="process_data_with_truncation", + original_data={"large": "x" * 10000, "metadata": "info"}, + truncated_data={"large": "[TRUNCATED]", "metadata": "info"}, + expected_truncated_flag=True, + expected_response_data={"large": "[TRUNCATED]", "metadata": "info"}, + ), + ProcessDataScenario( + name="empty_process_data", + original_data={}, + truncated_data=None, + expected_truncated_flag=False, + expected_response_data={}, + ), + ProcessDataScenario( + name="complex_nested_data_with_truncation", + original_data={ + "config": {"setting": "value"}, + "logs": ["log1", "log2"] * 1000,  # Large list + "status": "running", + }, + truncated_data={"config": {"setting": "value"}, "logs": "[TRUNCATED: 2000 items]", "status": "running"}, + expected_truncated_flag=True, + expected_response_data={ + "config": {"setting": "value"}, + "logs": "[TRUNCATED: 2000 items]", + "status": "running", + }, + ), + ] + + @pytest.mark.parametrize( + "scenario", + [scenario for scenario in get_process_data_scenarios()], + ids=[scenario.name for scenario in get_process_data_scenarios()], + ) + def test_process_data_scenarios(self, scenario: ProcessDataScenario): + """Test various process_data scenarios.""" + execution = WorkflowNodeExecution( + id="test-execution-id", + workflow_id="test-workflow-id",
index=1, + node_id="test-node-id", + node_type=NodeType.LLM, + title="Test Node", + process_data=scenario.original_data, + created_at=datetime.now(), + ) + + if scenario.truncated_data is not None: + execution.set_truncated_process_data(scenario.truncated_data) + + assert execution.process_data_truncated == scenario.expected_truncated_flag + assert execution.get_response_process_data() == scenario.expected_response_data diff --git a/api/tests/unit_tests/models/test_workflow_node_execution_offload.py b/api/tests/unit_tests/models/test_workflow_node_execution_offload.py new file mode 100644 index 0000000000..93f66914c5 --- /dev/null +++ b/api/tests/unit_tests/models/test_workflow_node_execution_offload.py @@ -0,0 +1,181 @@ +""" +Unit tests for WorkflowNodeExecutionOffload model, focusing on process_data truncation functionality. +""" + +from unittest.mock import Mock + +import pytest + +from models.model import UploadFile +from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload + + +class TestWorkflowNodeExecutionOffload: + """Test WorkflowNodeExecutionOffload model with process_data fields.""" + + def test_model_is_instantiable(self): + """Smoke test: the offload model can be constructed without arguments.""" + assert WorkflowNodeExecutionOffload() is not None + + +class TestWorkflowNodeExecutionModel: + """Test WorkflowNodeExecutionModel with process_data truncation features.""" + + def create_mock_offload_data( + self, + inputs_file_id: str | None = None, + outputs_file_id: str | None = None, + process_data_file_id: str | None = None, + ) -> WorkflowNodeExecutionOffload: + """Create a mock offload data object.""" + offload = Mock(spec=WorkflowNodeExecutionOffload) + offload.inputs_file_id = inputs_file_id + offload.outputs_file_id = outputs_file_id + offload.process_data_file_id = process_data_file_id + + # Mock file objects + if inputs_file_id: + offload.inputs_file = Mock(spec=UploadFile) + else: + offload.inputs_file = None + + if outputs_file_id: + offload.outputs_file = Mock(spec=UploadFile) + else: + offload.outputs_file = None + + if process_data_file_id: + offload.process_data_file = Mock(spec=UploadFile) + else: + offload.process_data_file = None + + return offload + + def test_process_data_truncated_property_false_when_no_offload_data(self): + """Test process_data_truncated returns False when no offload_data.""" + execution = WorkflowNodeExecutionModel() + execution.offload_data = None + + assert execution.process_data_truncated is False + + def test_process_data_truncated_property_false_when_no_process_data_file(self): + """Test process_data_truncated returns False when no process_data file.""" + execution = WorkflowNodeExecutionModel() + + # Create real offload instance + offload_data = WorkflowNodeExecutionOffload() + offload_data.inputs_file_id = "inputs-file" + offload_data.outputs_file_id = "outputs-file" + offload_data.process_data_file_id = None  # No process_data file + execution.offload_data = offload_data + + assert execution.process_data_truncated is False + + def test_process_data_truncated_property_true_when_process_data_file_exists(self): + """Test process_data_truncated returns True when process_data file exists.""" + execution = WorkflowNodeExecutionModel() + + # Create a real offload instance rather than mock + offload_data = WorkflowNodeExecutionOffload() + offload_data.process_data_file_id = "process-data-file-id" + execution.offload_data = offload_data + + assert execution.process_data_truncated is True + + def test_load_full_process_data_with_no_offload_data(self): + """Test load_full_process_data when no offload data exists.""" +
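+ # Arrange: no offload record exists, so the method is expected to fall back to
+ # the inline (non-offloaded) process_data_dict value instead of hitting storage.
+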
execution = WorkflowNodeExecutionModel() + execution.offload_data = None + execution.process_data_dict = {"test": "data"} + + # Mock session and storage + mock_session = Mock() + mock_storage = Mock() + + result = execution.load_full_process_data(mock_session, mock_storage) + + assert result == {"test": "data"} + + def test_load_full_process_data_with_no_file(self): + """Test load_full_process_data when no process_data file exists.""" + execution = WorkflowNodeExecutionModel() + execution.offload_data = self.create_mock_offload_data(process_data_file_id=None) + execution.process_data_dict = {"test": "data"} + + # Mock session and storage + mock_session = Mock() + mock_storage = Mock() + + result = execution.load_full_process_data(mock_session, mock_storage) + + assert result == {"test": "data"} + + def test_load_full_process_data_with_file(self): + """Test load_full_process_data when process_data file exists.""" + execution = WorkflowNodeExecutionModel() + offload_data = self.create_mock_offload_data(process_data_file_id="file-id") + execution.offload_data = offload_data + execution.process_data_dict = {"truncated": "data"} + + # Mock session and storage + mock_session = Mock() + mock_storage = Mock() + + # Mock the _load_full_content method to return full data + full_process_data = {"full": "data", "large_field": "x" * 10000} + + with pytest.MonkeyPatch.context() as mp: + # Mock the _load_full_content method + def mock_load_full_content(session, file_id, storage): + assert session == mock_session + assert file_id == "file-id" + assert storage == mock_storage + return full_process_data + + mp.setattr(execution, "_load_full_content", mock_load_full_content) + + result = execution.load_full_process_data(mock_session, mock_storage) + + assert result == full_process_data + + def test_consistency_with_inputs_outputs_truncation(self): + """Test that process_data truncation behaves consistently with inputs/outputs.""" + execution = WorkflowNodeExecutionModel() + + # Test all three truncation properties together + offload_data = self.create_mock_offload_data( + inputs_file_id="inputs-file", outputs_file_id="outputs-file", process_data_file_id="process-data-file" + ) + execution.offload_data = offload_data + + # All should be truncated + assert execution.inputs_truncated is True + assert execution.outputs_truncated is True + assert execution.process_data_truncated is True + + def test_mixed_truncation_states(self): + """Test mixed states of truncation.""" + execution = WorkflowNodeExecutionModel() + + # Only process_data is truncated + offload_data = self.create_mock_offload_data( + inputs_file_id=None, outputs_file_id=None, process_data_file_id="process-data-file" + ) + execution.offload_data = offload_data + + assert execution.inputs_truncated is False + assert execution.outputs_truncated is False + assert execution.process_data_truncated is True + + def test_preload_offload_data_and_files_method_exists(self): + """Test that the preload method includes process_data_file.""" + # This test verifies the method exists and can be called + # The actual SQL behavior would be tested in integration tests + from sqlalchemy import select + + stmt = select(WorkflowNodeExecutionModel) + + # This should not raise an exception + preloaded_stmt = WorkflowNodeExecutionModel.preload_offload_data_and_files(stmt) + + # The statement should be modified (different object) + assert preloaded_stmt is not stmt diff --git 
a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_workflow_node_execution_repository.py b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_workflow_node_execution_repository.py new file mode 100644 index 0000000000..339d335a34 --- /dev/null +++ b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_workflow_node_execution_repository.py @@ -0,0 +1,362 @@ +""" +Unit tests for SQLAlchemyWorkflowNodeExecutionRepository, focusing on process_data truncation functionality. +""" + +import json +from dataclasses import dataclass +from datetime import datetime +from typing import Any +from unittest.mock import MagicMock, Mock, patch + +import pytest +from sqlalchemy.orm import sessionmaker + +from core.repositories.sqlalchemy_workflow_node_execution_repository import ( + SQLAlchemyWorkflowNodeExecutionRepository, + _InputsOutputsTruncationResult, +) +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution +from core.workflow.nodes.enums import NodeType +from models import Account, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom +from models.model import UploadFile +from models.workflow import WorkflowNodeExecutionOffload + + +class TestSQLAlchemyWorkflowNodeExecutionRepositoryProcessData: + """Test process_data truncation functionality in SQLAlchemyWorkflowNodeExecutionRepository.""" + + def create_mock_account(self) -> Account: + """Create a mock Account for testing.""" + account = Mock(spec=Account) + account.id = "test-user-id" + account.tenant_id = "test-tenant-id" + return account + + def create_mock_session_factory(self) -> sessionmaker: + """Create a mock session factory for testing.""" + mock_session = MagicMock() + mock_session_factory = MagicMock(spec=sessionmaker) + mock_session_factory.return_value.__enter__.return_value = mock_session + mock_session_factory.return_value.__exit__.return_value = None + return mock_session_factory + + def create_repository(self, mock_file_service=None) -> SQLAlchemyWorkflowNodeExecutionRepository: + """Create a repository instance for testing.""" + mock_account = self.create_mock_account() + mock_session_factory = self.create_mock_session_factory() + + repository = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test-app-id", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + if mock_file_service: + repository._file_service = mock_file_service + + return repository + + def create_workflow_node_execution( + self, + process_data: dict[str, Any] | None = None, + execution_id: str = "test-execution-id", + ) -> WorkflowNodeExecution: + """Create a WorkflowNodeExecution instance for testing.""" + return WorkflowNodeExecution( + id=execution_id, + workflow_id="test-workflow-id", + index=1, + node_id="test-node-id", + node_type=NodeType.LLM, + title="Test Node", + process_data=process_data, + created_at=datetime.now(), + ) + + @patch('core.repositories.sqlalchemy_workflow_node_execution_repository.dify_config') + def test_to_db_model_with_small_process_data(self, mock_config): + """Test _to_db_model with small process_data that doesn't need truncation.""" + mock_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE = 1000 + mock_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH = 100 + mock_config.WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH = 500 + + repository = self.create_repository() + small_process_data = {"small": "data", "count": 5} + + execution = 
self.create_workflow_node_execution(process_data=small_process_data) + + with patch.object(repository, '_truncate_and_upload', return_value=None) as mock_truncate: + db_model = repository._to_db_model(execution) + + # Should try to truncate but return None (no truncation needed) + mock_truncate.assert_called_once_with( + small_process_data, + execution.id, + "_process_data" + ) + + # Process data should be stored directly in database + assert db_model.process_data is not None + stored_data = json.loads(db_model.process_data) + assert stored_data == small_process_data + + # No offload data should be created for process_data + assert db_model.offload_data is None + + def test_to_db_model_with_large_process_data(self): + """Test _to_db_model with large process_data that needs truncation.""" + repository = self.create_repository() + + # Create large process_data that would need truncation + large_process_data = { + "large_field": "x" * 10000, # Very large string + "metadata": {"type": "processing", "timestamp": 1234567890} + } + + # Mock truncation result + truncated_data = { + "large_field": "[TRUNCATED]", + "metadata": {"type": "processing", "timestamp": 1234567890} + } + + mock_upload_file = Mock(spec=UploadFile) + mock_upload_file.id = "mock-file-id" + + truncation_result = _InputsOutputsTruncationResult( + truncated_value=truncated_data, + file=mock_upload_file + ) + + execution = self.create_workflow_node_execution(process_data=large_process_data) + + with patch.object(repository, '_truncate_and_upload', return_value=truncation_result) as mock_truncate: + db_model = repository._to_db_model(execution) + + # Should call truncate with correct parameters + mock_truncate.assert_called_once_with( + large_process_data, + execution.id, + "_process_data" + ) + + # Truncated data should be stored in database + assert db_model.process_data is not None + stored_data = json.loads(db_model.process_data) + assert stored_data == truncated_data + + # Domain model should have truncated data set + assert execution.process_data_truncated is True + assert execution.get_truncated_process_data() == truncated_data + + # Offload data should be created + assert db_model.offload_data is not None + assert db_model.offload_data.process_data_file == mock_upload_file + assert db_model.offload_data.process_data_file_id == "mock-file-id" + + def test_to_db_model_with_none_process_data(self): + """Test _to_db_model with None process_data.""" + repository = self.create_repository() + execution = self.create_workflow_node_execution(process_data=None) + + with patch.object(repository, '_truncate_and_upload') as mock_truncate: + db_model = repository._to_db_model(execution) + + # Should not call truncate for None data + mock_truncate.assert_not_called() + + # Process data should be None + assert db_model.process_data is None + + # No offload data should be created + assert db_model.offload_data is None + + def test_to_domain_model_with_offloaded_process_data(self): + """Test _to_domain_model with offloaded process_data.""" + repository = self.create_repository() + + # Create mock database model with offload data + db_model = Mock(spec=WorkflowNodeExecutionModel) + db_model.id = "test-execution-id" + db_model.node_execution_id = "test-node-execution-id" + db_model.workflow_id = "test-workflow-id" + db_model.workflow_run_id = None + db_model.index = 1 + db_model.predecessor_node_id = None + db_model.node_id = "test-node-id" + db_model.node_type = "llm" + db_model.title = "Test Node" + db_model.status = "succeeded" + 
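+ # The remaining attributes mirror the columns SQLAlchemy would hydrate from a row;
+ # only the process_data-related fields are significant for this test.
+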
db_model.error = None + db_model.elapsed_time = 1.5 + db_model.created_at = datetime.now() + db_model.finished_at = None + + # Mock truncated process_data from database + truncated_process_data = {"large_field": "[TRUNCATED]", "metadata": "info"} + db_model.process_data_dict = truncated_process_data + db_model.inputs_dict = None + db_model.outputs_dict = None + db_model.execution_metadata_dict = {} + + # Mock offload data with process_data file + mock_offload_data = Mock(spec=WorkflowNodeExecutionOffload) + mock_offload_data.inputs_file_id = None + mock_offload_data.inputs_file = None + mock_offload_data.outputs_file_id = None + mock_offload_data.outputs_file = None + mock_offload_data.process_data_file_id = "process-data-file-id" + + mock_process_data_file = Mock(spec=UploadFile) + mock_offload_data.process_data_file = mock_process_data_file + + db_model.offload_data = mock_offload_data + + # Mock the file loading + original_process_data = { + "large_field": "x" * 10000, + "metadata": "info" + } + + with patch.object(repository, '_load_file', return_value=original_process_data) as mock_load: + domain_model = repository._to_domain_model(db_model) + + # Should load the file + mock_load.assert_called_once_with(mock_process_data_file) + + # Domain model should have original data + assert domain_model.process_data == original_process_data + + # Domain model should have truncated data set + assert domain_model.process_data_truncated is True + assert domain_model.get_truncated_process_data() == truncated_process_data + + def test_to_domain_model_without_offload_data(self): + """Test _to_domain_model without offload data.""" + repository = self.create_repository() + + # Create mock database model without offload data + db_model = Mock(spec=WorkflowNodeExecutionModel) + db_model.id = "test-execution-id" + db_model.node_execution_id = "test-node-execution-id" + db_model.workflow_id = "test-workflow-id" + db_model.workflow_run_id = None + db_model.index = 1 + db_model.predecessor_node_id = None + db_model.node_id = "test-node-id" + db_model.node_type = "llm" + db_model.title = "Test Node" + db_model.status = "succeeded" + db_model.error = None + db_model.elapsed_time = 1.5 + db_model.created_at = datetime.now() + db_model.finished_at = None + + process_data = {"normal": "data"} + db_model.process_data_dict = process_data + db_model.inputs_dict = None + db_model.outputs_dict = None + db_model.execution_metadata_dict = {} + db_model.offload_data = None + + domain_model = repository._to_domain_model(db_model) + + # Domain model should have the data from database + assert domain_model.process_data == process_data + + # Should not be truncated + assert domain_model.process_data_truncated is False + assert domain_model.get_truncated_process_data() is None + + +@dataclass +class TruncationScenario: + """Test scenario for truncation functionality.""" + + name: str + process_data: dict[str, Any] | None + should_truncate: bool + expected_truncated: bool = False + + +class TestProcessDataTruncationScenarios: + """Test various scenarios for process_data truncation.""" + + @staticmethod + def get_truncation_scenarios() -> list[TruncationScenario]: + """Create test scenarios for truncation.""" + return [ + TruncationScenario( + name="none_data", + process_data=None, + should_truncate=False, + ), + TruncationScenario( + name="small_data", + process_data={"key": "value"}, + should_truncate=False, + ), + TruncationScenario( + name="large_data", + process_data={"large": "x" * 10000}, + should_truncate=True, + 
expected_truncated=True, + ), + TruncationScenario( + name="empty_data", + process_data={}, + should_truncate=False, + ), + ] + + @pytest.mark.parametrize( + "scenario", + [scenario for scenario in get_truncation_scenarios()], + ids=[scenario.name for scenario in get_truncation_scenarios()], + ) + def test_process_data_truncation_scenarios(self, scenario: TruncationScenario): + """Test various process_data truncation scenarios.""" + repository = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=MagicMock(spec=sessionmaker), + user=Mock(spec=Account, id="test-user", tenant_id="test-tenant"), + app_id="test-app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + execution = WorkflowNodeExecution( + id="test-execution-id", + workflow_id="test-workflow-id", + index=1, + node_id="test-node-id", + node_type=NodeType.LLM, + title="Test Node", + process_data=scenario.process_data, + created_at=datetime.now(), + ) + + # Mock truncation behavior + if scenario.should_truncate: + truncated_data = {"truncated": True} + mock_file = Mock(spec=UploadFile, id="file-id") + truncation_result = _InputsOutputsTruncationResult( + truncated_value=truncated_data, + file=mock_file + ) + + with patch.object(repository, '_truncate_and_upload', return_value=truncation_result): + db_model = repository._to_db_model(execution) + + # Should create offload data + assert db_model.offload_data is not None + assert db_model.offload_data.process_data_file_id == "file-id" + assert execution.process_data_truncated == scenario.expected_truncated + else: + with patch.object(repository, '_truncate_and_upload', return_value=None): + db_model = repository._to_db_model(execution) + + # Should not create offload data or set truncation + if scenario.process_data is None: + assert db_model.offload_data is None + assert db_model.process_data is None + else: + # For small data, might have offload_data from other fields but not process_data + if db_model.offload_data: + assert db_model.offload_data.process_data_file_id is None + assert db_model.offload_data.process_data_file is None + + assert execution.process_data_truncated is False \ No newline at end of file diff --git a/api/tests/unit_tests/services/test_variable_truncator.py b/api/tests/unit_tests/services/test_variable_truncator.py new file mode 100644 index 0000000000..86842f771b --- /dev/null +++ b/api/tests/unit_tests/services/test_variable_truncator.py @@ -0,0 +1,709 @@ +""" +Comprehensive unit tests for VariableTruncator class based on current implementation. 
+ +This test suite covers all functionality of the current VariableTruncator including: +- JSON size calculation for different data types +- String, array, and object truncation logic +- Segment-based truncation interface +- Helper methods for budget-based truncation +- Edge cases and error handling +""" + +import functools +import json +import uuid +from typing import Any +from uuid import uuid4 + +import pytest + +from core.file.enums import FileTransferMethod, FileType +from core.file.models import File +from core.variables.segments import ( + ArrayFileSegment, + ArraySegment, + FileSegment, + FloatSegment, + IntegerSegment, + NoneSegment, + ObjectSegment, + StringSegment, +) +from services.variable_truncator import ( + ARRAY_CHAR_LIMIT, + LARGE_VARIABLE_THRESHOLD, + OBJECT_CHAR_LIMIT, + MaxDepthExceededError, + TruncationResult, + UnknownTypeError, + VariableTruncator, +) + + +@pytest.fixture +def file() -> File: + return File( + id=str(uuid4()), # Generate new UUID for File.id + tenant_id=str(uuid.uuid4()), + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id=str(uuid.uuid4()), + filename="test_file.txt", + extension=".txt", + mime_type="text/plain", + size=1024, + storage_key="initial_key", + ) + + +_compact_json_dumps = functools.partial(json.dumps, separators=(",", ":")) + + +class TestCalculateJsonSize: + """Test calculate_json_size method with different data types.""" + + @pytest.fixture + def truncator(self): + return VariableTruncator() + + def test_string_size_calculation(self): + """Test JSON size calculation for strings.""" + # Simple ASCII string + assert VariableTruncator.calculate_json_size("hello") == 7 # "hello" + 2 quotes + + # Empty string + assert VariableTruncator.calculate_json_size("") == 2 # Just quotes + + # Unicode string + unicode_text = "你好" + expected_size = len(unicode_text.encode("utf-8")) + 2 + assert VariableTruncator.calculate_json_size(unicode_text) == expected_size + + def test_number_size_calculation(self, truncator): + """Test JSON size calculation for numbers.""" + assert truncator.calculate_json_size(123) == 3 + assert truncator.calculate_json_size(12.34) == 5 + assert truncator.calculate_json_size(-456) == 4 + assert truncator.calculate_json_size(0) == 1 + + def test_boolean_size_calculation(self, truncator): + """Test JSON size calculation for booleans.""" + assert truncator.calculate_json_size(True) == 4 # "true" + assert truncator.calculate_json_size(False) == 5 # "false" + + def test_null_size_calculation(self, truncator): + """Test JSON size calculation for None/null.""" + assert truncator.calculate_json_size(None) == 4 # "null" + + def test_array_size_calculation(self, truncator): + """Test JSON size calculation for arrays.""" + # Empty array + assert truncator.calculate_json_size([]) == 2 # "[]" + + # Simple array + simple_array = [1, 2, 3] + # [1,2,3] = 1 + 1 + 1 + 1 + 1 + 2 = 7 (numbers + commas + brackets) + assert truncator.calculate_json_size(simple_array) == 7 + + # Array with strings + string_array = ["a", "b"] + # ["a","b"] = 3 + 3 + 1 + 2 = 9 (quoted strings + comma + brackets) + assert truncator.calculate_json_size(string_array) == 9 + + def test_object_size_calculation(self, truncator): + """Test JSON size calculation for objects.""" + # Empty object + assert truncator.calculate_json_size({}) == 2 # "{}" + + # Simple object + simple_obj = {"a": 1} + # {"a":1} = 3 + 1 + 1 + 2 = 7 (key + colon + value + brackets) + assert truncator.calculate_json_size(simple_obj) == 7 + + # Multiple keys + 
multi_obj = {"a": 1, "b": 2} + # {"a":1,"b":2} = 3 + 1 + 1 + 1 + 3 + 1 + 1 + 2 = 13 + assert truncator.calculate_json_size(multi_obj) == 13 + + def test_nested_structure_size_calculation(self, truncator): + """Test JSON size calculation for nested structures.""" + nested = {"items": [1, 2, {"nested": "value"}]} + size = truncator.calculate_json_size(nested) + assert size > 0  # Should calculate without error + + # Verify it matches actual JSON length roughly + actual_json = _compact_json_dumps(nested) + # Should be close but not exact due to UTF-8 encoding considerations + assert abs(size - len(actual_json.encode())) <= 5 + + def test_calculate_json_size_max_depth_exceeded(self, truncator): + """Test that calculate_json_size rejects overly deep nesting.""" + # Create deeply nested structure + nested: dict[str, Any] = {"level": 0} + current = nested + for i in range(25):  # Create deep nesting + current["next"] = {"level": i + 1} + current = current["next"] + + # Nesting beyond the supported depth should raise rather than recurse endlessly + with pytest.raises(MaxDepthExceededError): + truncator.calculate_json_size(nested) + + def test_calculate_json_size_unknown_type(self, truncator): + """Test that calculate_json_size raises error for unknown types.""" + + class CustomType: + pass + + with pytest.raises(UnknownTypeError): + truncator.calculate_json_size(CustomType()) + + +class TestStringTruncation: + """Test string truncation functionality.""" + + @pytest.fixture + def small_truncator(self): + return VariableTruncator(string_length_limit=10) + + def test_short_string_no_truncation(self, small_truncator): + """Test that short strings are not truncated.""" + short_str = "hello" + result, was_truncated = small_truncator._truncate_string(short_str) + assert result == short_str + assert was_truncated is False + + def test_long_string_truncation(self, small_truncator: VariableTruncator): + """Test that long strings are truncated with ellipsis.""" + long_str = "this is a very long string that exceeds the limit" + result, was_truncated = small_truncator._truncate_string(long_str) + + assert was_truncated is True + assert result == long_str[:7] + "..." + assert len(result) == 10  # 7 chars + 3-char ellipsis = the 10-char limit
+ + def test_exact_limit_string(self, small_truncator): + """Test string exactly at limit.""" + exact_str = "1234567890"  # Exactly 10 chars + result, was_truncated = small_truncator._truncate_string(exact_str) + assert result == exact_str + assert was_truncated is False + + +class TestArrayTruncation: + """Test array truncation functionality.""" + + @pytest.fixture + def small_truncator(self): + return VariableTruncator(array_element_limit=3, max_size_bytes=100) + + def test_small_array_no_truncation(self, small_truncator): + """Test that small arrays are not truncated.""" + small_array = [1, 2] + result, was_truncated = small_truncator._truncate_array(small_array, 1000) + assert result == small_array + assert was_truncated is False + + def test_array_element_limit_truncation(self, small_truncator): + """Test that arrays over element limit are truncated.""" + large_array = [1, 2, 3, 4, 5, 6]  # Exceeds limit of 3 + result, was_truncated = small_truncator._truncate_array(large_array, 1000) + + assert was_truncated is True + assert len(result) == 3 + assert result == [1, 2, 3] + + def test_array_size_budget_truncation(self, small_truncator): + """Test array truncation due to size budget constraints.""" + # Create array with strings that will exceed size budget + large_strings = ["very long string " * 5, "another long string " * 5] + result, was_truncated = small_truncator._truncate_array(large_strings, 50) + + assert was_truncated is True + # Should have truncated the strings within the array + for item in result: + assert isinstance(item, str) + assert len(_compact_json_dumps(result).encode()) <= 50 + + def test_array_with_nested_objects(self, small_truncator): + """Test array truncation with nested objects.""" + nested_array = [ + {"name": "item1", "data": "some data"}, + {"name": "item2", "data": "more data"}, + {"name": "item3", "data": "even more data"}, + ] + result, was_truncated = small_truncator._truncate_array(nested_array, 80) + + assert isinstance(result, list) + assert len(result) <= 3 + # Should have processed nested objects appropriately + + +class TestObjectTruncation: + """Test object truncation functionality.""" + + @pytest.fixture + def small_truncator(self): + return VariableTruncator(max_size_bytes=100) + + def test_small_object_no_truncation(self, small_truncator): + """Test that small objects are not truncated.""" + small_obj = {"a": 1, "b": 2} + result, was_truncated = small_truncator._truncate_object(small_obj, 1000) + assert result == small_obj + assert was_truncated is False + + def test_empty_object_no_truncation(self, small_truncator): + """Test that empty objects are not truncated.""" + empty_obj = {} + result, was_truncated = small_truncator._truncate_object(empty_obj, 100) + assert result == empty_obj + assert was_truncated is False + + def test_object_value_truncation(self, small_truncator): + """Test object truncation where values are truncated to fit budget.""" + obj_with_long_values = { + "key1": "very long string " * 10, + "key2": "another long string " * 10, + "key3": "third long string " * 10, + } + result, was_truncated = small_truncator._truncate_object(obj_with_long_values, 80) + + assert was_truncated is True + assert isinstance(result, dict) + + # Keys should be preserved (deterministic order due to sorting) + if result:  # Only check if result is not empty + assert list(result.keys()) == sorted(result.keys()) + + # Values should be truncated if they exist + for key, value in result.items(): + if isinstance(value, str): + 
original_value = obj_with_long_values[key] + # The truncated value should be no longer than the original + assert len(value) <= len(original_value) + + def test_object_key_dropping(self, small_truncator): + """Test object truncation where keys are dropped due to size constraints.""" + large_obj = {f"key{i:02d}": f"value{i}" for i in range(20)} + result, was_truncated = small_truncator._truncate_object(large_obj, 50) + + assert was_truncated is True + assert len(result) < len(large_obj) + + # Should maintain sorted key order + result_keys = list(result.keys()) + assert result_keys == sorted(result_keys) + + def test_object_with_nested_structures(self, small_truncator): + """Test object truncation with nested arrays and objects.""" + nested_obj = {"simple": "value", "array": [1, 2, 3, 4, 5], "nested": {"inner": "data", "more": ["a", "b", "c"]}} + result, was_truncated = small_truncator._truncate_object(nested_obj, 60) + + assert isinstance(result, dict) + # Nested arrays and objects should be truncated recursively within the budget + + +class TestSegmentBasedTruncation: + """Test the main truncate method that works with Segments.""" + + @pytest.fixture + def truncator(self): + return VariableTruncator() + + @pytest.fixture + def small_truncator(self): + return VariableTruncator(string_length_limit=20, array_element_limit=3, max_size_bytes=200) + + def test_integer_segment_no_truncation(self, truncator): + """Test that integer segments are never truncated.""" + segment = IntegerSegment(value=12345) + result = truncator.truncate(segment) + + assert isinstance(result, TruncationResult) + assert result.truncated is False + assert result.result == segment + + def test_boolean_as_integer_segment(self, truncator): + """Test that boolean values in IntegerSegment are converted to int.""" + segment = IntegerSegment(value=True) + result = truncator.truncate(segment) + + assert isinstance(result, TruncationResult) + assert result.truncated is False + assert isinstance(result.result, IntegerSegment) + assert result.result.value == 1 # True converted to 1 + + def test_float_segment_no_truncation(self, truncator): + """Test that float segments are never truncated.""" + segment = FloatSegment(value=123.456) + result = truncator.truncate(segment) + + assert isinstance(result, TruncationResult) + assert result.truncated is False + assert result.result == segment + + def test_none_segment_no_truncation(self, truncator): + """Test that None segments are never truncated.""" + segment = NoneSegment() + result = truncator.truncate(segment) + + assert isinstance(result, TruncationResult) + assert result.truncated is False + assert result.result == segment + + def test_file_segment_no_truncation(self, truncator, file): + """Test that file segments are never truncated.""" + file_segment = FileSegment(value=file) + result = truncator.truncate(file_segment) + assert result.result == file_segment + assert result.truncated is False + + def test_array_file_segment_no_truncation(self, truncator, file): + """Test that array file segments are never truncated.""" + + array_file_segment = ArrayFileSegment(value=[file] * 20) + result = truncator.truncate(array_file_segment) + assert result.result == array_file_segment + assert result.truncated is False + + def test_string_segment_small_no_truncation(self, truncator): + """Test small string segments are not truncated.""" + segment = StringSegment(value="hello world") + result = truncator.truncate(segment) + + assert isinstance(result, TruncationResult) + assert result.truncated is False + assert result.result == segment + + def
test_string_segment_large_truncation(self, small_truncator): + """Test large string segments are truncated.""" + long_text = "this is a very long string that will definitely exceed the limit" + segment = StringSegment(value=long_text) + result = small_truncator.truncate(segment) + + assert isinstance(result, TruncationResult) + assert result.truncated is True + assert isinstance(result.result, StringSegment) + assert len(result.result.value) < len(long_text) + assert result.result.value.endswith("...") + + def test_array_segment_small_no_truncation(self, truncator): + """Test small array segments are not truncated.""" + from factories.variable_factory import build_segment + + segment = build_segment([1, 2, 3]) + result = truncator.truncate(segment) + + assert isinstance(result, TruncationResult) + assert result.truncated is False + assert result.result == segment + + def test_array_segment_large_truncation(self, small_truncator): + """Test large array segments are truncated.""" + from factories.variable_factory import build_segment + + large_array = list(range(10)) # Exceeds element limit of 3 + segment = build_segment(large_array) + result = small_truncator.truncate(segment) + + assert isinstance(result, TruncationResult) + assert result.truncated is True + assert isinstance(result.result, ArraySegment) + assert len(result.result.value) <= 3 + + def test_object_segment_small_no_truncation(self, truncator): + """Test small object segments are not truncated.""" + segment = ObjectSegment(value={"key": "value"}) + result = truncator.truncate(segment) + + assert isinstance(result, TruncationResult) + assert result.truncated is False + assert result.result == segment + + def test_object_segment_large_truncation(self, small_truncator): + """Test large object segments are truncated.""" + large_obj = {f"key{i}": f"very long value {i}" * 5 for i in range(5)} + segment = ObjectSegment(value=large_obj) + result = small_truncator.truncate(segment) + + assert isinstance(result, TruncationResult) + assert result.truncated is True + assert isinstance(result.result, ObjectSegment) + # The truncated object should be no larger than the original + original_size = small_truncator.calculate_json_size(large_obj) + result_size = small_truncator.calculate_json_size(result.result.value) + assert result_size <= original_size + + def test_final_size_fallback_to_json_string(self, small_truncator): + """Test final fallback when truncated result still exceeds size limit.""" + # Create data that will still be large after initial truncation + large_nested_data = {"data": ["very long string " * 5] * 5, "more": {"nested": "content " * 20}} + segment = ObjectSegment(value=large_nested_data) + + # Use very small limit to force JSON string fallback + tiny_truncator = VariableTruncator(max_size_bytes=50) + result = tiny_truncator.truncate(segment) + + assert isinstance(result, TruncationResult) + assert result.truncated is True + assert isinstance(result.result, StringSegment) + # Should be JSON string with possible truncation + assert len(result.result.value) <= 53 # 50 + "..."
= 53 + + def test_final_size_fallback_string_truncation(self, small_truncator): + """Test final fallback for string that still exceeds limit.""" + # Create very long string that exceeds string length limit + very_long_string = "x" * 6000 # Exceeds default string_length_limit of 5000 + segment = StringSegment(value=very_long_string) + + # Use small limit to test string fallback path + tiny_truncator = VariableTruncator(string_length_limit=100, max_size_bytes=50) + result = tiny_truncator.truncate(segment) + + assert isinstance(result, TruncationResult) + assert result.truncated is True + assert isinstance(result.result, StringSegment) + # Should be truncated due to string limit or final size limit + assert len(result.result.value) <= 1000 # far below the 6000-char original + + +class TestTruncationHelperMethods: + """Test helper methods used in truncation.""" + + @pytest.fixture + def truncator(self): + return VariableTruncator() + + def test_truncate_item_to_budget_string(self, truncator): + """Test _truncate_item_to_budget with string input.""" + item = "this is a long string" + budget = 15 + result, was_truncated = truncator._truncate_item_to_budget(item, budget) + + assert isinstance(result, str) + # Should be truncated to fit budget + if was_truncated: + assert len(result) <= budget + assert result.endswith("...") + + def test_truncate_item_to_budget_dict(self, truncator): + """Test _truncate_item_to_budget with dict input.""" + item = {"key": "value", "longer": "longer value"} + budget = 30 + result, was_truncated = truncator._truncate_item_to_budget(item, budget) + + assert isinstance(result, dict) + # Should apply object truncation logic + + def test_truncate_item_to_budget_list(self, truncator): + """Test _truncate_item_to_budget with list input.""" + item = [1, 2, 3, 4, 5] + budget = 15 + result, was_truncated = truncator._truncate_item_to_budget(item, budget) + + assert isinstance(result, list) + # Should apply array truncation logic + + def test_truncate_item_to_budget_other_types(self, truncator): + """Test _truncate_item_to_budget with other types.""" + # Small number that fits + result, was_truncated = truncator._truncate_item_to_budget(123, 10) + assert result == 123 + assert was_truncated is False + + # Large number that might not fit - should convert to string if needed + large_num = 123456789012345 + result, was_truncated = truncator._truncate_item_to_budget(large_num, 5) + if was_truncated: + assert isinstance(result, str) + + def test_truncate_value_to_budget_string(self, truncator): + """Test _truncate_value_to_budget with string input.""" + value = "x" * 100 + budget = 20 + result, was_truncated = truncator._truncate_value_to_budget(value, budget) + + assert isinstance(result, str) + if was_truncated: + assert len(result) <= 20 # Should respect budget + assert result.endswith("...") + + def test_truncate_value_to_budget_respects_object_char_limit(self, truncator): + """Test that _truncate_value_to_budget respects OBJECT_CHAR_LIMIT.""" + # Even with large budget, should respect OBJECT_CHAR_LIMIT + large_string = "x" * 10000 + large_budget = 20000 + result, was_truncated = truncator._truncate_value_to_budget(large_string, large_budget) + + if was_truncated: + assert len(result) <= OBJECT_CHAR_LIMIT + 3 # +3 for "..."
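+
+    # An illustrative combined check drawn from the two budget tests above (an
+    # inferred property, not a documented guarantee): the effective cap for a
+    # value is the smaller of the caller's budget and OBJECT_CHAR_LIMIT, plus
+    # the three-character ellipsis.
+    def test_truncate_value_to_budget_effective_cap_sketch(self, truncator):
+        """Illustrative check of the combined budget/OBJECT_CHAR_LIMIT cap."""
+        value = "x" * 30000
+        for budget in (20, 20000):
+            result, was_truncated = truncator._truncate_value_to_budget(value, budget)
+            if was_truncated:
+                assert len(result) <= min(budget, OBJECT_CHAR_LIMIT + 3)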
+ + +class TestEdgeCases: + """Test edge cases and error conditions.""" + + def test_empty_inputs(self): + """Test truncator with empty inputs.""" + truncator = VariableTruncator() + + # Empty string + result = truncator.truncate(StringSegment(value="")) + assert not result.truncated + assert result.result.value == "" + + # Empty array + from factories.variable_factory import build_segment + + result = truncator.truncate(build_segment([])) + assert not result.truncated + assert result.result.value == [] + + # Empty object + result = truncator.truncate(ObjectSegment(value={})) + assert not result.truncated + assert result.result.value == {} + + def test_zero_and_negative_limits(self): + """Test that zero or below-minimum limits are rejected.""" + # Limits below the supported minimums should raise + with pytest.raises(ValueError): + VariableTruncator(string_length_limit=3) + + with pytest.raises(ValueError): + VariableTruncator(array_element_limit=0) + + with pytest.raises(ValueError): + VariableTruncator(max_size_bytes=0) + + def test_unicode_and_special_characters(self): + """Test truncator with unicode and special characters.""" + truncator = VariableTruncator(string_length_limit=10) + + # Unicode characters: truncation counts code points, not bytes + unicode_text = "🌍🚀" * 6 # 12 characters, over the 10-character limit + result = truncator.truncate(StringSegment(value=unicode_text)) + assert result.truncated is True + + # Special JSON characters + special_chars = '{"key": "value with \\"quotes\\" and \\n newlines"}' + result = truncator.truncate(StringSegment(value=special_chars)) + assert isinstance(result.result, StringSegment) + + +class TestIntegrationScenarios: + """Test realistic integration scenarios.""" + + def test_workflow_output_scenario(self): + """Test truncation of typical workflow output data.""" + truncator = VariableTruncator() + + workflow_data = { + "result": "success", + "data": { + "users": [ + {"id": 1, "name": "Alice", "email": "alice@example.com"}, + {"id": 2, "name": "Bob", "email": "bob@example.com"}, + ] + * 3, # Multiply to make it larger + "metadata": { + "count": 6, + "processing_time": "1.23s", + "details": "x" * 200, # Long string but not too long + }, + }, + } + + segment = ObjectSegment(value=workflow_data) + result = truncator.truncate(segment) + + assert isinstance(result, TruncationResult) + assert isinstance(result.result, (ObjectSegment, StringSegment)) + # Complex nested structures should come back as a (possibly truncated) object or as the JSON-string fallback + + def test_large_text_processing_scenario(self): + """Test truncation of large text data.""" + truncator = VariableTruncator(string_length_limit=100) + + large_text = "This is a very long text document. " * 20 # Make it larger than limit + + segment = StringSegment(value=large_text) + result = truncator.truncate(segment) + + assert isinstance(result, TruncationResult) + assert result.truncated is True + assert isinstance(result.result, StringSegment) + assert len(result.result.value) <= 103 # 100 + "..."
+ assert result.result.value.endswith("...") + + def test_mixed_data_types_scenario(self): + """Test truncation with mixed data types in complex structure.""" + truncator = VariableTruncator(string_length_limit=30, array_element_limit=3, max_size_bytes=300) + + mixed_data = { + "strings": ["short", "medium length", "very long string " * 3], + "numbers": [1, 2.5, 999999], + "booleans": [True, False, True], + "nested": { + "more_strings": ["nested string " * 2], + "more_numbers": list(range(5)), + "deep": {"level": 3, "content": "deep content " * 3}, + }, + "nulls": [None, None], + } + + segment = ObjectSegment(value=mixed_data) + result = truncator.truncate(segment) + + assert isinstance(result, TruncationResult) + # Should handle all data types appropriately + if result.truncated: + # Verify the result is no larger than the original + original_size = truncator.calculate_json_size(mixed_data) + if isinstance(result.result, ObjectSegment): + result_size = truncator.calculate_json_size(result.result.value) + assert result_size <= original_size + + +class TestConstantsAndConfiguration: + """Test behavior with different configuration constants.""" + + def test_large_variable_threshold_constant(self): + """Test that LARGE_VARIABLE_THRESHOLD constant is properly used.""" + truncator = VariableTruncator() + assert truncator._max_size_bytes == LARGE_VARIABLE_THRESHOLD + assert LARGE_VARIABLE_THRESHOLD == 10 * 1024 # 10KB + + def test_string_truncation_limit_constant(self): + """Test that STRING_TRUNCATION_LIMIT constant is properly used.""" + truncator = VariableTruncator() + assert truncator._string_length_limit == 5000 # the STRING_TRUNCATION_LIMIT default + + def test_array_char_limit_constant(self): + """Test that ARRAY_CHAR_LIMIT is used in array item truncation.""" + truncator = VariableTruncator() + + # Test that ARRAY_CHAR_LIMIT is respected in array item truncation + long_string = "x" * 2000 + budget = 5000 # Large budget + + result, was_truncated = truncator._truncate_item_to_budget(long_string, budget) + if was_truncated: + # Should not exceed ARRAY_CHAR_LIMIT even with large budget + assert len(result) <= ARRAY_CHAR_LIMIT + 3 # +3 for "..." + + def test_object_char_limit_constant(self): + """Test that OBJECT_CHAR_LIMIT is used in object value truncation.""" + truncator = VariableTruncator() + + # Test that OBJECT_CHAR_LIMIT is respected in object value truncation + long_string = "x" * 8000 + large_budget = 20000 + + result, was_truncated = truncator._truncate_value_to_budget(long_string, large_budget) + if was_truncated: + # Should not exceed OBJECT_CHAR_LIMIT even with large budget + assert len(result) <= OBJECT_CHAR_LIMIT + 3 # +3 for "..."
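+
+    # Read together, these constants outline the pipeline this suite exercises
+    # (a simplified mental model, not the production code path): measure the
+    # JSON size, apply the per-type string/array/object limits, then fall back
+    # to a truncated JSON string when the result still exceeds max_size_bytes.
+    def test_pipeline_shape_sketch(self):
+        """Illustrative end-to-end check of the truncation pipeline shape."""
+        truncator = VariableTruncator(max_size_bytes=50)
+        segment = ObjectSegment(value={"data": ["very long string " * 5] * 5})
+        result = truncator.truncate(segment)
+        assert result.truncated is True
+        # Either a shrunken object or the JSON-string fallback is acceptable.
+        assert isinstance(result.result, (ObjectSegment, StringSegment))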
diff --git a/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py b/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py new file mode 100644 index 0000000000..78726f7dd7 --- /dev/null +++ b/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py @@ -0,0 +1,379 @@ +"""Simplified unit tests for DraftVarLoader focusing on core functionality.""" + +import json +from unittest.mock import Mock, patch + +import pytest +from sqlalchemy import Engine + +from core.variables.segments import ObjectSegment, StringSegment +from core.variables.types import SegmentType +from models.model import UploadFile +from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile +from services.workflow_draft_variable_service import DraftVarLoader + + +class TestDraftVarLoaderSimple: + """Simplified unit tests for DraftVarLoader core methods.""" + + @pytest.fixture + def mock_engine(self) -> Engine: + return Mock(spec=Engine) + + @pytest.fixture + def draft_var_loader(self, mock_engine): + """Create DraftVarLoader instance for testing.""" + return DraftVarLoader( + engine=mock_engine, + app_id="test-app-id", + tenant_id="test-tenant-id", + fallback_variables=[] + ) + + def test_load_offloaded_variable_string_type_unit(self, draft_var_loader): + """Test _load_offloaded_variable with string type - isolated unit test.""" + # Create mock objects + upload_file = Mock(spec=UploadFile) + upload_file.key = "storage/key/test.txt" + + variable_file = Mock(spec=WorkflowDraftVariableFile) + variable_file.value_type = SegmentType.STRING + variable_file.upload_file = upload_file + + draft_var = Mock(spec=WorkflowDraftVariable) + draft_var.id = "draft-var-id" + draft_var.node_id = "test-node-id" + draft_var.name = "test_variable" + draft_var.description = "test description" + draft_var.get_selector.return_value = ["test-node-id", "test_variable"] + draft_var.variable_file = variable_file + + test_content = "This is the full string content" + + with patch("services.workflow_draft_variable_service.storage") as mock_storage: + mock_storage.load.return_value = test_content.encode() + + with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable: + mock_variable = Mock() + mock_variable.id = "draft-var-id" + mock_variable.name = "test_variable" + mock_variable.description = "test description" + mock_variable.value = StringSegment(value=test_content) + mock_segment_to_variable.return_value = mock_variable + + # Execute the method + selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var) + + # Verify results + assert selector_tuple == ("test-node-id", "test_variable") + assert variable.id == "draft-var-id" + assert variable.name == "test_variable" + assert variable.description == "test description" + assert variable.value.value == test_content + + # Verify storage was called correctly + mock_storage.load.assert_called_once_with("storage/key/test.txt") + + def test_load_offloaded_variable_object_type_unit(self, draft_var_loader): + """Test _load_offloaded_variable with object type - isolated unit test.""" + # Create mock objects + upload_file = Mock(spec=UploadFile) + upload_file.key = "storage/key/test.json" + + variable_file = Mock(spec=WorkflowDraftVariableFile) + variable_file.value_type = SegmentType.OBJECT + variable_file.upload_file = upload_file + + draft_var = Mock(spec=WorkflowDraftVariable) + draft_var.id = "draft-var-id" + draft_var.node_id = "test-node-id" + draft_var.name = "test_object" + draft_var.description = "test description" +
draft_var.get_selector.return_value = ["test-node-id", "test_object"] + draft_var.variable_file = variable_file + + test_object = {"key1": "value1", "key2": 42} + test_json_content = json.dumps(test_object, ensure_ascii=False, separators=(",", ":")) + + with patch("services.workflow_draft_variable_service.storage") as mock_storage: + mock_storage.load.return_value = test_json_content.encode() + + with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment: + mock_segment = ObjectSegment(value=test_object) + mock_build_segment.return_value = mock_segment + + with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable: + mock_variable = Mock() + mock_variable.id = "draft-var-id" + mock_variable.name = "test_object" + mock_variable.value = mock_segment + mock_segment_to_variable.return_value = mock_variable + + # Execute the method + selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var) + + # Verify results + assert selector_tuple == ("test-node-id", "test_object") + assert variable.id == "draft-var-id" + assert variable.name == "test_object" + assert variable.description == "test description" + assert variable.value == test_object + + # Verify method calls + mock_storage.load.assert_called_once_with("storage/key/test.json") + mock_build_segment.assert_called_once_with(SegmentType.OBJECT, test_object) + + def test_load_offloaded_variable_missing_variable_file_unit(self, draft_var_loader): + """Test that assertion error is raised when variable_file is None.""" + draft_var = Mock(spec=WorkflowDraftVariable) + draft_var.variable_file = None + + with pytest.raises(AssertionError): + draft_var_loader._load_offloaded_variable(draft_var) + + def test_load_offloaded_variable_missing_upload_file_unit(self, draft_var_loader): + """Test that assertion error is raised when upload_file is None.""" + variable_file = Mock(spec=WorkflowDraftVariableFile) + variable_file.upload_file = None + + draft_var = Mock(spec=WorkflowDraftVariable) + draft_var.variable_file = variable_file + + with pytest.raises(AssertionError): + draft_var_loader._load_offloaded_variable(draft_var) + + def test_load_variables_empty_selectors_unit(self, draft_var_loader): + """Test load_variables returns empty list for empty selectors.""" + result = draft_var_loader.load_variables([]) + assert result == [] + + def test_selector_to_tuple_unit(self, draft_var_loader): + """Test _selector_to_tuple method.""" + selector = ["node_id", "var_name", "extra_field"] + result = draft_var_loader._selector_to_tuple(selector) + assert result == ("node_id", "var_name") + + def test_load_offloaded_variable_number_type_unit(self, draft_var_loader): + """Test _load_offloaded_variable with number type - isolated unit test.""" + # Create mock objects + upload_file = Mock(spec=UploadFile) + upload_file.key = "storage/key/test_number.json" + + variable_file = Mock(spec=WorkflowDraftVariableFile) + variable_file.value_type = SegmentType.NUMBER + variable_file.upload_file = upload_file + + draft_var = Mock(spec=WorkflowDraftVariable) + draft_var.id = "draft-var-id" + draft_var.node_id = "test-node-id" + draft_var.name = "test_number" + draft_var.description = "test number description" + draft_var.get_selector.return_value = ["test-node-id", "test_number"] + draft_var.variable_file = variable_file + + test_number = 123.45 + test_json_content = json.dumps(test_number) + + with patch("services.workflow_draft_variable_service.storage") as mock_storage: + 
mock_storage.load.return_value = test_json_content.encode() + + with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment: + from core.variables.segments import FloatSegment + mock_segment = FloatSegment(value=test_number) + mock_build_segment.return_value = mock_segment + + with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable: + mock_variable = Mock() + mock_variable.id = "draft-var-id" + mock_variable.name = "test_number" + mock_variable.description = "test number description" + mock_variable.value = mock_segment + mock_segment_to_variable.return_value = mock_variable + + # Execute the method + selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var) + + # Verify results + assert selector_tuple == ("test-node-id", "test_number") + assert variable.id == "draft-var-id" + assert variable.name == "test_number" + assert variable.description == "test number description" + + # Verify method calls + mock_storage.load.assert_called_once_with("storage/key/test_number.json") + mock_build_segment.assert_called_once_with(SegmentType.NUMBER, test_number) + + def test_load_offloaded_variable_array_type_unit(self, draft_var_loader): + """Test _load_offloaded_variable with array type - isolated unit test.""" + # Create mock objects + upload_file = Mock(spec=UploadFile) + upload_file.key = "storage/key/test_array.json" + + variable_file = Mock(spec=WorkflowDraftVariableFile) + variable_file.value_type = SegmentType.ARRAY_ANY + variable_file.upload_file = upload_file + + draft_var = Mock(spec=WorkflowDraftVariable) + draft_var.id = "draft-var-id" + draft_var.node_id = "test-node-id" + draft_var.name = "test_array" + draft_var.description = "test array description" + draft_var.get_selector.return_value = ["test-node-id", "test_array"] + draft_var.variable_file = variable_file + + test_array = ["item1", "item2", "item3"] + test_json_content = json.dumps(test_array) + + with patch("services.workflow_draft_variable_service.storage") as mock_storage: + mock_storage.load.return_value = test_json_content.encode() + + with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment: + from core.variables.segments import ArrayAnySegment + mock_segment = ArrayAnySegment(value=test_array) + mock_build_segment.return_value = mock_segment + + with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable: + mock_variable = Mock() + mock_variable.id = "draft-var-id" + mock_variable.name = "test_array" + mock_variable.description = "test array description" + mock_variable.value = mock_segment + mock_segment_to_variable.return_value = mock_variable + + # Execute the method + selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var) + + # Verify results + assert selector_tuple == ("test-node-id", "test_array") + assert variable.id == "draft-var-id" + assert variable.name == "test_array" + assert variable.description == "test array description" + + # Verify method calls + mock_storage.load.assert_called_once_with("storage/key/test_array.json") + mock_build_segment.assert_called_once_with(SegmentType.ARRAY_ANY, test_array) + + def test_load_variables_with_offloaded_variables_unit(self, draft_var_loader): + """Test load_variables method with mix of regular and offloaded variables.""" + selectors = [ + ["node1", "regular_var"], + ["node2", "offloaded_var"] + ] + + # Mock regular variable + regular_draft_var = Mock(spec=WorkflowDraftVariable) + regular_draft_var.is_truncated.return_value = False + regular_draft_var.node_id = "node1" + regular_draft_var.name = "regular_var" +
regular_draft_var.get_value.return_value = StringSegment(value="regular_value") + regular_draft_var.get_selector.return_value = ["node1", "regular_var"] + regular_draft_var.id = "regular-var-id" + regular_draft_var.description = "regular description" + + # Mock offloaded variable + upload_file = Mock(spec=UploadFile) + upload_file.key = "storage/key/offloaded.txt" + + variable_file = Mock(spec=WorkflowDraftVariableFile) + variable_file.value_type = SegmentType.STRING + variable_file.upload_file = upload_file + + offloaded_draft_var = Mock(spec=WorkflowDraftVariable) + offloaded_draft_var.is_truncated.return_value = True + offloaded_draft_var.node_id = "node2" + offloaded_draft_var.name = "offloaded_var" + offloaded_draft_var.get_selector.return_value = ["node2", "offloaded_var"] + offloaded_draft_var.variable_file = variable_file + offloaded_draft_var.id = "offloaded-var-id" + offloaded_draft_var.description = "offloaded description" + + draft_vars = [regular_draft_var, offloaded_draft_var] + + with patch("services.workflow_draft_variable_service.Session") as mock_session_cls: + mock_session = Mock() + mock_session_cls.return_value.__enter__.return_value = mock_session + + mock_service = Mock() + mock_service.get_draft_variables_by_selectors.return_value = draft_vars + + with patch("services.workflow_draft_variable_service.WorkflowDraftVariableService", return_value=mock_service): + with patch("services.workflow_draft_variable_service.StorageKeyLoader"): + with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable: + # Mock regular variable creation + regular_variable = Mock() + regular_variable.selector = ["node1", "regular_var"] + + # Mock offloaded variable creation + offloaded_variable = Mock() + offloaded_variable.selector = ["node2", "offloaded_var"] + + mock_segment_to_variable.return_value = regular_variable + + with patch("services.workflow_draft_variable_service.storage") as mock_storage: + mock_storage.load.return_value = b"offloaded_content" + + with patch.object(draft_var_loader, "_load_offloaded_variable") as mock_load_offloaded: + mock_load_offloaded.return_value = (("node2", "offloaded_var"), offloaded_variable) + + # Patch the name where it is used so the mock takes effect + with patch("services.workflow_draft_variable_service.ThreadPoolExecutor") as mock_executor_cls: + mock_executor = Mock() + mock_executor_cls.return_value.__enter__.return_value = mock_executor + mock_executor.map.return_value = [(("node2", "offloaded_var"), offloaded_variable)] + + # Execute the method + result = draft_var_loader.load_variables(selectors) + + # Verify results + assert len(result) == 2 + + # Verify service method was called + mock_service.get_draft_variables_by_selectors.assert_called_once_with( + draft_var_loader._app_id, selectors + ) + + # Verify offloaded variable loading was called + mock_load_offloaded.assert_called_once_with(offloaded_draft_var) + + def test_load_variables_all_offloaded_variables_unit(self, draft_var_loader): + """Test load_variables method with only offloaded variables.""" + selectors = [ + ["node1", "offloaded_var1"], + ["node2", "offloaded_var2"] + ] + + # Mock first offloaded variable + offloaded_var1 = Mock(spec=WorkflowDraftVariable) + offloaded_var1.is_truncated.return_value = True + offloaded_var1.node_id = "node1" + offloaded_var1.name = "offloaded_var1" + + # Mock second offloaded variable + offloaded_var2 = Mock(spec=WorkflowDraftVariable) + offloaded_var2.is_truncated.return_value = True + offloaded_var2.node_id = "node2" + offloaded_var2.name = "offloaded_var2" + + draft_vars = [offloaded_var1,
offloaded_var2] + + with patch("services.workflow_draft_variable_service.Session") as mock_session_cls: + mock_session = Mock() + mock_session_cls.return_value.__enter__.return_value = mock_session + + mock_service = Mock() + mock_service.get_draft_variables_by_selectors.return_value = draft_vars + + with patch("services.workflow_draft_variable_service.WorkflowDraftVariableService", return_value=mock_service): + with patch("services.workflow_draft_variable_service.StorageKeyLoader"): + with patch("services.workflow_draft_variable_service.ThreadPoolExecutor") as mock_executor_cls: + mock_executor = Mock() + mock_executor_cls.return_value.__enter__.return_value = mock_executor + mock_executor.map.return_value = [ + (("node1", "offloaded_var1"), Mock()), + (("node2", "offloaded_var2"), Mock()) + ] + + # Execute the method + result = draft_var_loader.load_variables(selectors) + + # Verify results - since we have only offloaded variables, should have 2 results + assert len(result) == 2 + + # Verify ThreadPoolExecutor was used + mock_executor_cls.assert_called_once_with(max_workers=10) + mock_executor.map.assert_called_once() \ No newline at end of file diff --git a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py index 8b1348b75b..7b5f4e3ffc 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py @@ -1,16 +1,26 @@ import dataclasses import secrets +import uuid from unittest.mock import MagicMock, Mock, patch import pytest from sqlalchemy import Engine from sqlalchemy.orm import Session -from core.variables import StringSegment +from core.variables.segments import StringSegment +from core.variables.types import SegmentType from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.nodes.enums import NodeType +from libs.uuid_utils import uuidv7 +from models.account import Account from models.enums import DraftVariableType -from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable +from models.workflow import ( + Workflow, + WorkflowDraftVariable, + WorkflowDraftVariableFile, + WorkflowNodeExecutionModel, + is_system_variable_editable, +) from services.workflow_draft_variable_service import ( DraftVariableSaver, VariableResetError, @@ -37,6 +47,7 @@ class TestDraftVariableSaver: def test__should_variable_be_visible(self): mock_session = MagicMock(spec=Session) + mock_user = Account(id=str(uuid.uuid4())) test_app_id = self._get_test_app_id() saver = DraftVariableSaver( session=mock_session, @@ -44,6 +55,7 @@ class TestDraftVariableSaver: node_id="test_node_id", node_type=NodeType.START, node_execution_id="test_execution_id", + user=mock_user, ) assert saver._should_variable_be_visible("123_456", NodeType.IF_ELSE, "output") == False assert saver._should_variable_be_visible("123", NodeType.START, "output") == True @@ -83,6 +95,7 @@ class TestDraftVariableSaver: ] mock_session = MagicMock(spec=Session) + mock_user = MagicMock() test_app_id = self._get_test_app_id() saver = DraftVariableSaver( session=mock_session, @@ -90,6 +103,7 @@ class TestDraftVariableSaver: node_id=_NODE_ID, node_type=NodeType.START, node_execution_id="test_execution_id", + user=mock_user, ) for idx, c in enumerate(cases, 1): fail_msg = f"Test case {c.name} failed, index={idx}" @@ -97,6 +111,76 @@ class 
TestDraftVariableSaver: + assert node_id == c.expected_node_id, fail_msg + assert name == c.expected_name, fail_msg + + @pytest.fixture + def mock_session(self): + """Mock SQLAlchemy session.""" + mock_session = MagicMock(spec=Session) + mock_engine = MagicMock(spec=Engine) + mock_session.get_bind.return_value = mock_engine + return mock_session + + @pytest.fixture + def draft_saver(self, mock_session): + """Create DraftVariableSaver instance with user context.""" + # Create a mock user + mock_user = MagicMock(spec=Account) + mock_user.id = "test-user-id" + mock_user.tenant_id = "test-tenant-id" + + return DraftVariableSaver( + session=mock_session, + app_id="test-app-id", + node_id="test-node-id", + node_type=NodeType.LLM, + node_execution_id="test-execution-id", + user=mock_user, + ) + + def test_draft_saver_with_small_variables(self, draft_saver, mock_session): + with patch( + "services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable" + ) as _mock_try_offload: + _mock_try_offload.return_value = None + mock_segment = StringSegment(value="small value") + draft_var = draft_saver._create_draft_variable(name="small_var", value=mock_segment, visible=True) + + # No offload occurred, so there is no variable-file reference + assert draft_var.file_id is None + + def test_draft_saver_with_large_variables(self, draft_saver, mock_session): + with patch( + "services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable" + ) as _mock_try_offload: + mock_segment = StringSegment(value="small value") + mock_draft_var_file = WorkflowDraftVariableFile( + id=str(uuidv7()), + size=1024, + length=10, + value_type=SegmentType.ARRAY_STRING, + upload_file_id=str(uuid.uuid4()), + ) + + _mock_try_offload.return_value = mock_segment, mock_draft_var_file + draft_var = draft_saver._create_draft_variable(name="large_var", value=mock_segment, visible=True) + + # Should reference the offloaded variable file + assert draft_var.file_id == mock_draft_var_file.id + + @patch("services.workflow_draft_variable_service._batch_upsert_draft_variable") + def test_save_method_integration(self, mock_batch_upsert, draft_saver): + """Test complete save workflow.""" + outputs = {"result": {"data": "test_output"}, "metadata": {"type": "llm_response"}} + + draft_saver.save(outputs=outputs) + + # Should batch upsert draft variables: two outputs -> two draft variables + mock_batch_upsert.assert_called_once() + draft_vars = mock_batch_upsert.call_args[0][1] + assert len(draft_vars) == 2 + class TestWorkflowDraftVariableService: def _get_test_app_id(self): diff --git a/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py index 673282a6f4..0b545da2ee 100644 --- a/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py @@ -1,14 +1,18 @@ from unittest.mock import ANY, MagicMock, call, patch import pytest -import sqlalchemy as sa -from tasks.remove_app_and_related_data_task import _delete_draft_variables, delete_draft_variables_batch +from tasks.remove_app_and_related_data_task import ( + _delete_draft_variable_offload_data, + _delete_draft_variables, + delete_draft_variables_batch, +) class TestDeleteDraftVariablesBatch: + @patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data") @patch("tasks.remove_app_and_related_data_task.db") - def test_delete_draft_variables_batch_success(self,
mock_db): + def test_delete_draft_variables_batch_success(self, mock_db, mock_offload_cleanup): """Test successful deletion of draft variables in batches.""" app_id = "test-app-id" batch_size = 100 @@ -24,13 +28,19 @@ class TestDeleteDraftVariablesBatch: mock_engine.begin.return_value = mock_context_manager # Mock two batches of results, then empty - batch1_ids = [f"var-{i}" for i in range(100)] - batch2_ids = [f"var-{i}" for i in range(100, 150)] + batch1_data = [(f"var-{i}", f"file-{i}" if i % 2 == 0 else None) for i in range(100)] + batch2_data = [(f"var-{i}", f"file-{i}" if i % 3 == 0 else None) for i in range(100, 150)] + + batch1_ids = [row[0] for row in batch1_data] + batch1_file_ids = [row[1] for row in batch1_data if row[1] is not None] + + batch2_ids = [row[0] for row in batch2_data] + batch2_file_ids = [row[1] for row in batch2_data if row[1] is not None] # Setup side effects for execute calls in the correct order: - # 1. SELECT (returns batch1_ids) + # 1. SELECT (returns batch1_data with id, file_id) # 2. DELETE (returns result with rowcount=100) - # 3. SELECT (returns batch2_ids) + # 3. SELECT (returns batch2_data) # 4. DELETE (returns result with rowcount=50) # 5. SELECT (returns empty, ends loop) @@ -41,14 +51,14 @@ class TestDeleteDraftVariablesBatch: # First SELECT result select_result1 = MagicMock() - select_result1.__iter__.return_value = iter([(id_,) for id_ in batch1_ids]) + select_result1.__iter__.return_value = iter(batch1_data) # First DELETE result delete_result1 = MockResult(rowcount=100) # Second SELECT result select_result2 = MagicMock() - select_result2.__iter__.return_value = iter([(id_,) for id_ in batch2_ids]) + select_result2.__iter__.return_value = iter(batch2_data) # Second DELETE result delete_result2 = MockResult(rowcount=50) @@ -66,6 +76,9 @@ class TestDeleteDraftVariablesBatch: select_result3, # Third SELECT (empty) ] + # Mock offload data cleanup + mock_offload_cleanup.side_effect = [len(batch1_file_ids), len(batch2_file_ids)] + # Execute the function result = delete_draft_variables_batch(app_id, batch_size) @@ -75,65 +88,18 @@ class TestDeleteDraftVariablesBatch: # Verify database calls assert mock_conn.execute.call_count == 5 # 3 selects + 2 deletes - # Verify the expected calls in order: - # 1. SELECT, 2. DELETE, 3. SELECT, 4. DELETE, 5. 
SELECT - expected_calls = [ - # First SELECT - call( - sa.text(""" - SELECT id FROM workflow_draft_variables - WHERE app_id = :app_id - LIMIT :batch_size - """), - {"app_id": app_id, "batch_size": batch_size}, - ), - # First DELETE - call( - sa.text(""" - DELETE FROM workflow_draft_variables - WHERE id IN :ids - """), - {"ids": tuple(batch1_ids)}, - ), - # Second SELECT - call( - sa.text(""" - SELECT id FROM workflow_draft_variables - WHERE app_id = :app_id - LIMIT :batch_size - """), - {"app_id": app_id, "batch_size": batch_size}, - ), - # Second DELETE - call( - sa.text(""" - DELETE FROM workflow_draft_variables - WHERE id IN :ids - """), - {"ids": tuple(batch2_ids)}, - ), - # Third SELECT (empty result) - call( - sa.text(""" - SELECT id FROM workflow_draft_variables - WHERE app_id = :app_id - LIMIT :batch_size - """), - {"app_id": app_id, "batch_size": batch_size}, - ), - ] + # Verify offload cleanup was called for both batches with file_ids + expected_offload_calls = [call(mock_conn, batch1_file_ids), call(mock_conn, batch2_file_ids)] + mock_offload_cleanup.assert_has_calls(expected_offload_calls) - # Check that all calls were made correctly - actual_calls = mock_conn.execute.call_args_list - assert len(actual_calls) == len(expected_calls) - - # Simplified verification - just check that the right number of calls were made + # Simplified verification - check that the right number of calls were made # and that the SQL queries contain the expected patterns + actual_calls = mock_conn.execute.call_args_list for i, actual_call in enumerate(actual_calls): if i % 2 == 0: # SELECT calls (even indices: 0, 2, 4) - # Verify it's a SELECT query + # Verify it's a SELECT query that now includes file_id sql_text = str(actual_call[0][0]) - assert "SELECT id FROM workflow_draft_variables" in sql_text + assert "SELECT id, file_id FROM workflow_draft_variables" in sql_text assert "WHERE app_id = :app_id" in sql_text assert "LIMIT :batch_size" in sql_text else: # DELETE calls (odd indices: 1, 3) @@ -142,8 +108,9 @@ class TestDeleteDraftVariablesBatch: assert "DELETE FROM workflow_draft_variables" in sql_text assert "WHERE id IN :ids" in sql_text + @patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data") @patch("tasks.remove_app_and_related_data_task.db") - def test_delete_draft_variables_batch_empty_result(self, mock_db): + def test_delete_draft_variables_batch_empty_result(self, mock_db, mock_offload_cleanup): """Test deletion when no draft variables exist for the app.""" app_id = "nonexistent-app-id" batch_size = 1000 @@ -167,6 +134,7 @@ class TestDeleteDraftVariablesBatch: assert result == 0 assert mock_conn.execute.call_count == 1 # Only one select query + mock_offload_cleanup.assert_not_called() # No files to clean up def test_delete_draft_variables_batch_invalid_batch_size(self): """Test that invalid batch size raises ValueError.""" @@ -178,9 +146,10 @@ class TestDeleteDraftVariablesBatch: with pytest.raises(ValueError, match="batch_size must be positive"): delete_draft_variables_batch(app_id, 0) + @patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data") @patch("tasks.remove_app_and_related_data_task.db") @patch("tasks.remove_app_and_related_data_task.logger") - def test_delete_draft_variables_batch_logs_progress(self, mock_logging, mock_db): + def test_delete_draft_variables_batch_logs_progress(self, mock_logging, mock_db, mock_offload_cleanup): """Test that batch deletion logs progress correctly.""" app_id = "test-app-id" batch_size = 50 @@ 
-196,10 +165,13 @@ class TestDeleteDraftVariablesBatch: mock_engine.begin.return_value = mock_context_manager # Mock one batch then empty - batch_ids = [f"var-{i}" for i in range(30)] + batch_data = [(f"var-{i}", f"file-{i}" if i % 3 == 0 else None) for i in range(30)] + batch_ids = [row[0] for row in batch_data] + batch_file_ids = [row[1] for row in batch_data if row[1] is not None] + # Create properly configured mocks select_result = MagicMock() - select_result.__iter__.return_value = iter([(id_,) for id_ in batch_ids]) + select_result.__iter__.return_value = iter(batch_data) # Create simple object with rowcount attribute class MockResult: @@ -220,10 +192,17 @@ class TestDeleteDraftVariablesBatch: empty_result, ] + # Mock offload cleanup + mock_offload_cleanup.return_value = len(batch_file_ids) + result = delete_draft_variables_batch(app_id, batch_size) assert result == 30 + # Verify offload cleanup was called with file_ids + if batch_file_ids: + mock_offload_cleanup.assert_called_once_with(mock_conn, batch_file_ids) + # Verify logging calls assert mock_logging.info.call_count == 2 mock_logging.info.assert_any_call( @@ -241,3 +220,118 @@ class TestDeleteDraftVariablesBatch: assert result == expected_return mock_batch_delete.assert_called_once_with(app_id, batch_size=1000) + + +class TestDeleteDraftVariableOffloadData: + """Test the Offload data cleanup functionality.""" + + @patch("extensions.ext_storage.storage") + def test_delete_draft_variable_offload_data_success(self, mock_storage): + """Test successful deletion of offload data.""" + + # Mock connection + mock_conn = MagicMock() + file_ids = ["file-1", "file-2", "file-3"] + + # Mock query results: (variable_file_id, storage_key, upload_file_id) + query_results = [ + ("file-1", "storage/key/1", "upload-1"), + ("file-2", "storage/key/2", "upload-2"), + ("file-3", "storage/key/3", "upload-3"), + ] + + mock_result = MagicMock() + mock_result.__iter__.return_value = iter(query_results) + mock_conn.execute.return_value = mock_result + + # Execute function + result = _delete_draft_variable_offload_data(mock_conn, file_ids) + + # Verify return value + assert result == 3 + + # Verify storage deletion calls + expected_storage_calls = [call("storage/key/1"), call("storage/key/2"), call("storage/key/3")] + mock_storage.delete.assert_has_calls(expected_storage_calls, any_order=True) + + # Verify database calls - should be 3 calls total + assert mock_conn.execute.call_count == 3 + + # Verify the queries were called + actual_calls = mock_conn.execute.call_args_list + + # First call should be the SELECT query + select_call_sql = str(actual_calls[0][0][0]) + assert "SELECT wdvf.id, uf.key, uf.id as upload_file_id" in select_call_sql + assert "FROM workflow_draft_variable_files wdvf" in select_call_sql + assert "JOIN upload_files uf ON wdvf.upload_file_id = uf.id" in select_call_sql + assert "WHERE wdvf.id IN :file_ids" in select_call_sql + + # Second call should be DELETE upload_files + delete_upload_call_sql = str(actual_calls[1][0][0]) + assert "DELETE FROM upload_files" in delete_upload_call_sql + assert "WHERE id IN :upload_file_ids" in delete_upload_call_sql + + # Third call should be DELETE workflow_draft_variable_files + delete_variable_files_call_sql = str(actual_calls[2][0][0]) + assert "DELETE FROM workflow_draft_variable_files" in delete_variable_files_call_sql + assert "WHERE id IN :file_ids" in delete_variable_files_call_sql + + def test_delete_draft_variable_offload_data_empty_file_ids(self): + """Test handling of empty file_ids 
list.""" + mock_conn = MagicMock() + + result = _delete_draft_variable_offload_data(mock_conn, []) + + assert result == 0 + mock_conn.execute.assert_not_called() + + @patch("extensions.ext_storage.storage") + @patch("tasks.remove_app_and_related_data_task.logging") + def test_delete_draft_variable_offload_data_storage_failure(self, mock_logging, mock_storage): + """Test handling of storage deletion failures.""" + mock_conn = MagicMock() + file_ids = ["file-1", "file-2"] + + # Mock query results + query_results = [ + ("file-1", "storage/key/1", "upload-1"), + ("file-2", "storage/key/2", "upload-2"), + ] + + mock_result = MagicMock() + mock_result.__iter__.return_value = iter(query_results) + mock_conn.execute.return_value = mock_result + + # Make storage.delete fail for the first file + mock_storage.delete.side_effect = [Exception("Storage error"), None] + + # Execute function + result = _delete_draft_variable_offload_data(mock_conn, file_ids) + + # Should still return 2 (both files processed, even if one storage delete failed) + assert result == 1 # Only one storage deletion succeeded + + # Verify warning was logged + mock_logging.warning.assert_called_once_with("Failed to delete storage object storage/key/1: Storage error") + + # Verify both database cleanup calls still happened + assert mock_conn.execute.call_count == 3 + + @patch("tasks.remove_app_and_related_data_task.logging") + def test_delete_draft_variable_offload_data_database_failure(self, mock_logging): + """Test handling of database operation failures.""" + mock_conn = MagicMock() + file_ids = ["file-1"] + + # Make execute raise an exception + mock_conn.execute.side_effect = Exception("Database error") + + # Execute function - should not raise, but log error + result = _delete_draft_variable_offload_data(mock_conn, file_ids) + + # Should return 0 when error occurs + assert result == 0 + + # Verify error was logged + mock_logging.error.assert_called_once_with("Error deleting draft variable offload data: Database error")