diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 286ba65a7f..311aa81279 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -8,15 +8,81 @@ from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from fields.workflow_run_fields import ( advanced_chat_workflow_run_pagination_fields, + workflow_run_count_fields, workflow_run_detail_fields, workflow_run_node_execution_list_fields, workflow_run_pagination_fields, ) +from libs.custom_inputs import time_duration from libs.helper import uuid_value from libs.login import current_user, login_required -from models import Account, App, AppMode, EndUser +from models import Account, App, AppMode, EndUser, WorkflowRunTriggeredFrom from services.workflow_run_service import WorkflowRunService +# Workflow run status choices for filtering +WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"] + + +def _parse_workflow_run_list_args(): + """ + Parse common arguments for workflow run list endpoints. + + Returns: + Parsed arguments containing last_id, limit, status, and triggered_from filters + """ + parser = reqparse.RequestParser() + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + parser.add_argument( + "status", + type=str, + choices=WORKFLOW_RUN_STATUS_CHOICES, + location="args", + required=False, + ) + parser.add_argument( + "triggered_from", + type=str, + choices=["debugging", "app-run"], + location="args", + required=False, + help="Filter by trigger source: debugging or app-run", + ) + return parser.parse_args() + + +def _parse_workflow_run_count_args(): + """ + Parse common arguments for workflow run count endpoints. + + Returns: + Parsed arguments containing status, time_range, and triggered_from filters + """ + parser = reqparse.RequestParser() + parser.add_argument( + "status", + type=str, + choices=WORKFLOW_RUN_STATUS_CHOICES, + location="args", + required=False, + ) + parser.add_argument( + "time_range", + type=time_duration, + location="args", + required=False, + help="Time range filter (e.g., 7d, 4h, 30m, 30s)", + ) + parser.add_argument( + "triggered_from", + type=str, + choices=["debugging", "app-run"], + location="args", + required=False, + help="Filter by trigger source: debugging or app-run", + ) + return parser.parse_args() + @console_ns.route("/apps//advanced-chat/workflow-runs") class AdvancedChatAppWorkflowRunListApi(Resource): @@ -24,6 +90,8 @@ class AdvancedChatAppWorkflowRunListApi(Resource): @api.doc(description="Get advanced chat workflow run list") @api.doc(params={"app_id": "Application ID"}) @api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"}) + @api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}) + @api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}) @api.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_fields) @setup_required @login_required @@ -34,13 +102,64 @@ class AdvancedChatAppWorkflowRunListApi(Resource): """ Get advanced chat app workflow run list """ - parser = reqparse.RequestParser() - parser.add_argument("last_id", type=uuid_value, location="args") - parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - args = parser.parse_args() + args = _parse_workflow_run_list_args() + + # Default to DEBUGGING if not specified + triggered_from = ( + WorkflowRunTriggeredFrom(args.get("triggered_from")) + if args.get("triggered_from") + else WorkflowRunTriggeredFrom.DEBUGGING + ) workflow_run_service = WorkflowRunService() - result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args=args) + result = workflow_run_service.get_paginate_advanced_chat_workflow_runs( + app_model=app_model, args=args, triggered_from=triggered_from + ) + + return result + + +@console_ns.route("/apps//advanced-chat/workflow-runs/count") +class AdvancedChatAppWorkflowRunCountApi(Resource): + @api.doc("get_advanced_chat_workflow_runs_count") + @api.doc(description="Get advanced chat workflow runs count statistics") + @api.doc(params={"app_id": "Application ID"}) + @api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}) + @api.doc( + params={ + "time_range": ( + "Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), " + "30m (30 minutes), 30s (30 seconds). Filters by created_at field." + ) + } + ) + @api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}) + @api.response(200, "Workflow runs count retrieved successfully", workflow_run_count_fields) + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT]) + @marshal_with(workflow_run_count_fields) + def get(self, app_model: App): + """ + Get advanced chat workflow runs count statistics + """ + args = _parse_workflow_run_count_args() + + # Default to DEBUGGING if not specified + triggered_from = ( + WorkflowRunTriggeredFrom(args.get("triggered_from")) + if args.get("triggered_from") + else WorkflowRunTriggeredFrom.DEBUGGING + ) + + workflow_run_service = WorkflowRunService() + result = workflow_run_service.get_workflow_runs_count( + app_model=app_model, + status=args.get("status"), + time_range=args.get("time_range"), + triggered_from=triggered_from, + ) return result @@ -51,6 +170,8 @@ class WorkflowRunListApi(Resource): @api.doc(description="Get workflow run list") @api.doc(params={"app_id": "Application ID"}) @api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"}) + @api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}) + @api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}) @api.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_fields) @setup_required @login_required @@ -61,13 +182,64 @@ class WorkflowRunListApi(Resource): """ Get workflow run list """ - parser = reqparse.RequestParser() - parser.add_argument("last_id", type=uuid_value, location="args") - parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - args = parser.parse_args() + args = _parse_workflow_run_list_args() + + # Default to DEBUGGING for workflow if not specified (backward compatibility) + triggered_from = ( + WorkflowRunTriggeredFrom(args.get("triggered_from")) + if args.get("triggered_from") + else WorkflowRunTriggeredFrom.DEBUGGING + ) workflow_run_service = WorkflowRunService() - result = workflow_run_service.get_paginate_workflow_runs(app_model=app_model, args=args) + result = workflow_run_service.get_paginate_workflow_runs( + app_model=app_model, args=args, triggered_from=triggered_from + ) + + return result + + +@console_ns.route("/apps//workflow-runs/count") +class WorkflowRunCountApi(Resource): + @api.doc("get_workflow_runs_count") + @api.doc(description="Get workflow runs count statistics") + @api.doc(params={"app_id": "Application ID"}) + @api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}) + @api.doc( + params={ + "time_range": ( + "Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), " + "30m (30 minutes), 30s (30 seconds). Filters by created_at field." + ) + } + ) + @api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}) + @api.response(200, "Workflow runs count retrieved successfully", workflow_run_count_fields) + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @marshal_with(workflow_run_count_fields) + def get(self, app_model: App): + """ + Get workflow runs count statistics + """ + args = _parse_workflow_run_count_args() + + # Default to DEBUGGING for workflow if not specified (backward compatibility) + triggered_from = ( + WorkflowRunTriggeredFrom(args.get("triggered_from")) + if args.get("triggered_from") + else WorkflowRunTriggeredFrom.DEBUGGING + ) + + workflow_run_service = WorkflowRunService() + result = workflow_run_service.get_workflow_runs_count( + app_model=app_model, + status=args.get("status"), + time_range=args.get("time_range"), + triggered_from=triggered_from, + ) return result diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 649e881848..79594beeed 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -64,6 +64,15 @@ workflow_run_pagination_fields = { "data": fields.List(fields.Nested(workflow_run_for_list_fields), attribute="data"), } +workflow_run_count_fields = { + "total": fields.Integer, + "running": fields.Integer, + "succeeded": fields.Integer, + "failed": fields.Integer, + "stopped": fields.Integer, + "partial_succeeded": fields.Integer(attribute="partial-succeeded"), +} + workflow_run_detail_fields = { "id": fields.String, "version": fields.String, diff --git a/api/libs/custom_inputs.py b/api/libs/custom_inputs.py new file mode 100644 index 0000000000..10d550ed65 --- /dev/null +++ b/api/libs/custom_inputs.py @@ -0,0 +1,32 @@ +"""Custom input types for Flask-RESTX request parsing.""" + +import re + + +def time_duration(value: str) -> str: + """ + Validate and return time duration string. + + Accepts formats: d (days), h (hours), m (minutes), s (seconds) + Examples: 7d, 4h, 30m, 30s + + Args: + value: The time duration string + + Returns: + The validated time duration string + + Raises: + ValueError: If the format is invalid + """ + if not value: + raise ValueError("Time duration cannot be empty") + + pattern = r"^(\d+)([dhms])$" + if not re.match(pattern, value.lower()): + raise ValueError( + "Invalid time duration format. Use: d (days), h (hours), " + "m (minutes), or s (seconds). Examples: 7d, 4h, 30m, 30s" + ) + + return value.lower() diff --git a/api/libs/time_parser.py b/api/libs/time_parser.py new file mode 100644 index 0000000000..1d9dd92a08 --- /dev/null +++ b/api/libs/time_parser.py @@ -0,0 +1,67 @@ +"""Time duration parser utility.""" + +import re +from datetime import UTC, datetime, timedelta + + +def parse_time_duration(duration_str: str) -> timedelta | None: + """ + Parse time duration string to timedelta. + + Supported formats: + - 7d: 7 days + - 4h: 4 hours + - 30m: 30 minutes + - 30s: 30 seconds + + Args: + duration_str: Duration string (e.g., "7d", "4h", "30m", "30s") + + Returns: + timedelta object or None if invalid format + """ + if not duration_str: + return None + + # Pattern: number followed by unit (d, h, m, s) + pattern = r"^(\d+)([dhms])$" + match = re.match(pattern, duration_str.lower()) + + if not match: + return None + + value = int(match.group(1)) + unit = match.group(2) + + if unit == "d": + return timedelta(days=value) + elif unit == "h": + return timedelta(hours=value) + elif unit == "m": + return timedelta(minutes=value) + elif unit == "s": + return timedelta(seconds=value) + + return None + + +def get_time_threshold(duration_str: str | None) -> datetime | None: + """ + Get datetime threshold from duration string. + + Calculates the datetime that is duration_str ago from now. + + Args: + duration_str: Duration string (e.g., "7d", "4h", "30m", "30s") + + Returns: + datetime object representing the threshold time, or None if no duration + """ + if not duration_str: + return None + + duration = parse_time_duration(duration_str) + if duration is None: + return None + + return datetime.now(UTC) - duration diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index 3ac28fad75..72de9fed31 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -59,6 +59,7 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): triggered_from: str, limit: int = 20, last_id: str | None = None, + status: str | None = None, ) -> InfiniteScrollPagination: """ Get paginated workflow runs with filtering. @@ -73,6 +74,7 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): triggered_from: Filter by trigger source (e.g., "debugging", "app-run") limit: Maximum number of records to return (default: 20) last_id: Cursor for pagination - ID of the last record from previous page + status: Optional filter by status (e.g., "running", "succeeded", "failed") Returns: InfiniteScrollPagination object containing: @@ -107,6 +109,43 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): """ ... + def get_workflow_runs_count( + self, + tenant_id: str, + app_id: str, + triggered_from: str, + status: str | None = None, + time_range: str | None = None, + ) -> dict[str, int]: + """ + Get workflow runs count statistics. + + Retrieves total count and count by status for workflow runs + matching the specified filters. + + Args: + tenant_id: Tenant identifier for multi-tenant isolation + app_id: Application identifier + triggered_from: Filter by trigger source (e.g., "debugging", "app-run") + status: Optional filter by specific status + time_range: Optional time range filter (e.g., "7d", "4h", "30m", "30s") + Filters records based on created_at field + + Returns: + Dictionary containing: + - total: Total count of all workflow runs (or filtered by status) + - running: Count of workflow runs with status "running" + - succeeded: Count of workflow runs with status "succeeded" + - failed: Count of workflow runs with status "failed" + - stopped: Count of workflow runs with status "stopped" + - partial_succeeded: Count of workflow runs with status "partial-succeeded" + + Note: If a status is provided, 'total' will be the count for that status, + and the specific status count will also be set to this value, with all + other status counts being 0. + """ + ... + def get_expired_runs_batch( self, tenant_id: str, diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index 6154273f33..68affb59f3 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -24,11 +24,12 @@ from collections.abc import Sequence from datetime import datetime from typing import cast -from sqlalchemy import delete, select +from sqlalchemy import delete, func, select from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, sessionmaker from libs.infinite_scroll_pagination import InfiniteScrollPagination +from libs.time_parser import get_time_threshold from models.workflow import WorkflowRun from repositories.api_workflow_run_repository import APIWorkflowRunRepository @@ -63,6 +64,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): triggered_from: str, limit: int = 20, last_id: str | None = None, + status: str | None = None, ) -> InfiniteScrollPagination: """ Get paginated workflow runs with filtering. @@ -79,6 +81,10 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): WorkflowRun.triggered_from == triggered_from, ) + # Add optional status filter + if status: + base_stmt = base_stmt.where(WorkflowRun.status == status) + if last_id: # Get the last workflow run for cursor-based pagination last_run_stmt = base_stmt.where(WorkflowRun.id == last_id) @@ -120,6 +126,73 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): ) return session.scalar(stmt) + def get_workflow_runs_count( + self, + tenant_id: str, + app_id: str, + triggered_from: str, + status: str | None = None, + time_range: str | None = None, + ) -> dict[str, int]: + """ + Get workflow runs count statistics grouped by status. + """ + _initial_status_counts = { + "running": 0, + "succeeded": 0, + "failed": 0, + "stopped": 0, + "partial-succeeded": 0, + } + + with self._session_maker() as session: + # Build base where conditions + base_conditions = [ + WorkflowRun.tenant_id == tenant_id, + WorkflowRun.app_id == app_id, + WorkflowRun.triggered_from == triggered_from, + ] + + # Add time range filter if provided + if time_range: + time_threshold = get_time_threshold(time_range) + if time_threshold: + base_conditions.append(WorkflowRun.created_at >= time_threshold) + + # If status filter is provided, return simple count + if status: + count_stmt = select(func.count(WorkflowRun.id)).where(*base_conditions, WorkflowRun.status == status) + total = session.scalar(count_stmt) or 0 + + result = {"total": total} | _initial_status_counts + + # Set the count for the filtered status + if status in result: + result[status] = total + + return result + + # No status filter - get counts grouped by status + base_stmt = ( + select(WorkflowRun.status, func.count(WorkflowRun.id).label("count")) + .where(*base_conditions) + .group_by(WorkflowRun.status) + ) + + # Execute query + results = session.execute(base_stmt).all() + + # Build response dictionary + status_counts = _initial_status_counts.copy() + + total = 0 + for status_val, count in results: + total += count + if status_val in status_counts: + status_counts[status_val] = count + + return {"total": total} | status_counts + def get_expired_runs_batch( self, tenant_id: str, diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index 6a2edd912a..5c8719b499 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -26,13 +26,15 @@ class WorkflowRunService: ) self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) - def get_paginate_advanced_chat_workflow_runs(self, app_model: App, args: dict) -> InfiniteScrollPagination: + def get_paginate_advanced_chat_workflow_runs( + self, app_model: App, args: dict, triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING + ) -> InfiniteScrollPagination: """ Get advanced chat app workflow run list - Only return triggered_from == advanced_chat :param app_model: app model :param args: request args + :param triggered_from: workflow run triggered from (default: DEBUGGING for preview runs) """ class WorkflowWithMessage: @@ -45,7 +47,7 @@ class WorkflowRunService: def __getattr__(self, item): return getattr(self._workflow_run, item) - pagination = self.get_paginate_workflow_runs(app_model, args) + pagination = self.get_paginate_workflow_runs(app_model, args, triggered_from) with_message_workflow_runs = [] for workflow_run in pagination.data: @@ -60,23 +62,27 @@ class WorkflowRunService: pagination.data = with_message_workflow_runs return pagination - def get_paginate_workflow_runs(self, app_model: App, args: dict) -> InfiniteScrollPagination: + def get_paginate_workflow_runs( + self, app_model: App, args: dict, triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING + ) -> InfiniteScrollPagination: """ - Get debug workflow run list - Only return triggered_from == debugging + Get workflow run list :param app_model: app model :param args: request args + :param triggered_from: workflow run triggered from (default: DEBUGGING) """ limit = int(args.get("limit", 20)) last_id = args.get("last_id") + status = args.get("status") return self._workflow_run_repo.get_paginated_workflow_runs( tenant_id=app_model.tenant_id, app_id=app_model.id, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + triggered_from=triggered_from, limit=limit, last_id=last_id, + status=status, ) def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun | None: @@ -92,6 +98,30 @@ class WorkflowRunService: run_id=run_id, ) + def get_workflow_runs_count( + self, + app_model: App, + status: str | None = None, + time_range: str | None = None, + triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING, + ) -> dict[str, int]: + """ + Get workflow runs count statistics + + :param app_model: app model + :param status: optional status filter + :param time_range: optional time range filter (e.g., "7d", "4h", "30m", "30s") + :param triggered_from: workflow run triggered from (default: DEBUGGING) + :return: dict with total and status counts + """ + return self._workflow_run_repo.get_workflow_runs_count( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + triggered_from=triggered_from, + status=status, + time_range=time_range, + ) + def get_workflow_run_node_executions( self, app_model: App, diff --git a/api/tests/unit_tests/libs/test_custom_inputs.py b/api/tests/unit_tests/libs/test_custom_inputs.py new file mode 100644 index 0000000000..7e4c3b4ff0 --- /dev/null +++ b/api/tests/unit_tests/libs/test_custom_inputs.py @@ -0,0 +1,68 @@ +"""Unit tests for custom input types.""" + +import pytest + +from libs.custom_inputs import time_duration + + +class TestTimeDuration: + """Test time_duration input validator.""" + + def test_valid_days(self): + """Test valid days format.""" + result = time_duration("7d") + assert result == "7d" + + def test_valid_hours(self): + """Test valid hours format.""" + result = time_duration("4h") + assert result == "4h" + + def test_valid_minutes(self): + """Test valid minutes format.""" + result = time_duration("30m") + assert result == "30m" + + def test_valid_seconds(self): + """Test valid seconds format.""" + result = time_duration("30s") + assert result == "30s" + + def test_uppercase_conversion(self): + """Test uppercase units are converted to lowercase.""" + result = time_duration("7D") + assert result == "7d" + + result = time_duration("4H") + assert result == "4h" + + def test_invalid_format_no_unit(self): + """Test invalid format without unit.""" + with pytest.raises(ValueError, match="Invalid time duration format"): + time_duration("7") + + def test_invalid_format_wrong_unit(self): + """Test invalid format with wrong unit.""" + with pytest.raises(ValueError, match="Invalid time duration format"): + time_duration("7days") + + with pytest.raises(ValueError, match="Invalid time duration format"): + time_duration("7x") + + def test_invalid_format_no_number(self): + """Test invalid format without number.""" + with pytest.raises(ValueError, match="Invalid time duration format"): + time_duration("d") + + with pytest.raises(ValueError, match="Invalid time duration format"): + time_duration("abc") + + def test_empty_string(self): + """Test empty string.""" + with pytest.raises(ValueError, match="Time duration cannot be empty"): + time_duration("") + + def test_none(self): + """Test None value.""" + with pytest.raises(ValueError, match="Time duration cannot be empty"): + time_duration(None) diff --git a/api/tests/unit_tests/libs/test_time_parser.py b/api/tests/unit_tests/libs/test_time_parser.py new file mode 100644 index 0000000000..83ff251272 --- /dev/null +++ b/api/tests/unit_tests/libs/test_time_parser.py @@ -0,0 +1,91 @@ +"""Unit tests for time parser utility.""" + +from datetime import UTC, datetime, timedelta + +from libs.time_parser import get_time_threshold, parse_time_duration + + +class TestParseTimeDuration: + """Test parse_time_duration function.""" + + def test_parse_days(self): + """Test parsing days.""" + result = parse_time_duration("7d") + assert result == timedelta(days=7) + + def test_parse_hours(self): + """Test parsing hours.""" + result = parse_time_duration("4h") + assert result == timedelta(hours=4) + + def test_parse_minutes(self): + """Test parsing minutes.""" + result = parse_time_duration("30m") + assert result == timedelta(minutes=30) + + def test_parse_seconds(self): + """Test parsing seconds.""" + result = parse_time_duration("30s") + assert result == timedelta(seconds=30) + + def test_parse_uppercase(self): + """Test parsing uppercase units.""" + result = parse_time_duration("7D") + assert result == timedelta(days=7) + + def test_parse_invalid_format(self): + """Test parsing invalid format.""" + result = parse_time_duration("7days") + assert result is None + + result = parse_time_duration("abc") + assert result is None + + result = parse_time_duration("7") + assert result is None + + def test_parse_empty_string(self): + """Test parsing empty string.""" + result = parse_time_duration("") + assert result is None + + def test_parse_none(self): + """Test parsing None.""" + result = parse_time_duration(None) + assert result is None + + +class TestGetTimeThreshold: + """Test get_time_threshold function.""" + + def test_get_threshold_days(self): + """Test getting threshold for days.""" + before = datetime.now(UTC) + result = get_time_threshold("7d") + after = datetime.now(UTC) + + assert result is not None + # Result should be approximately 7 days ago + expected = before - timedelta(days=7) + # Allow 1 second tolerance for test execution time + assert abs((result - expected).total_seconds()) < 1 + + def test_get_threshold_hours(self): + """Test getting threshold for hours.""" + before = datetime.now(UTC) + result = get_time_threshold("4h") + after = datetime.now(UTC) + + assert result is not None + expected = before - timedelta(hours=4) + assert abs((result - expected).total_seconds()) < 1 + + def test_get_threshold_invalid(self): + """Test getting threshold with invalid duration.""" + result = get_time_threshold("invalid") + assert result is None + + def test_get_threshold_none(self): + """Test getting threshold with None.""" + result = get_time_threshold(None) + assert result is None diff --git a/api/tests/unit_tests/repositories/test_workflow_run_repository.py b/api/tests/unit_tests/repositories/test_workflow_run_repository.py new file mode 100644 index 0000000000..8f47f0df48 --- /dev/null +++ b/api/tests/unit_tests/repositories/test_workflow_run_repository.py @@ -0,0 +1,251 @@ +"""Unit tests for workflow run repository with status filter.""" + +import uuid +from unittest.mock import MagicMock + +import pytest +from sqlalchemy.orm import sessionmaker + +from models import WorkflowRun, WorkflowRunTriggeredFrom +from repositories.sqlalchemy_api_workflow_run_repository import DifyAPISQLAlchemyWorkflowRunRepository + + +class TestDifyAPISQLAlchemyWorkflowRunRepository: + """Test workflow run repository with status filtering.""" + + @pytest.fixture + def mock_session_maker(self): + """Create a mock session maker.""" + return MagicMock(spec=sessionmaker) + + @pytest.fixture + def repository(self, mock_session_maker): + """Create repository instance with mock session.""" + return DifyAPISQLAlchemyWorkflowRunRepository(mock_session_maker) + + def test_get_paginated_workflow_runs_without_status(self, repository, mock_session_maker): + """Test getting paginated workflow runs without status filter.""" + # Arrange + tenant_id = str(uuid.uuid4()) + app_id = str(uuid.uuid4()) + mock_session = MagicMock() + mock_session_maker.return_value.__enter__.return_value = mock_session + + mock_runs = [MagicMock(spec=WorkflowRun) for _ in range(3)] + mock_session.scalars.return_value.all.return_value = mock_runs + + # Act + result = repository.get_paginated_workflow_runs( + tenant_id=tenant_id, + app_id=app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + limit=20, + last_id=None, + status=None, + ) + + # Assert + assert len(result.data) == 3 + assert result.limit == 20 + assert result.has_more is False + + def test_get_paginated_workflow_runs_with_status_filter(self, repository, mock_session_maker): + """Test getting paginated workflow runs with status filter.""" + # Arrange + tenant_id = str(uuid.uuid4()) + app_id = str(uuid.uuid4()) + mock_session = MagicMock() + mock_session_maker.return_value.__enter__.return_value = mock_session + + mock_runs = [MagicMock(spec=WorkflowRun, status="succeeded") for _ in range(2)] + mock_session.scalars.return_value.all.return_value = mock_runs + + # Act + result = repository.get_paginated_workflow_runs( + tenant_id=tenant_id, + app_id=app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + limit=20, + last_id=None, + status="succeeded", + ) + + # Assert + assert len(result.data) == 2 + assert all(run.status == "succeeded" for run in result.data) + + def test_get_workflow_runs_count_without_status(self, repository, mock_session_maker): + """Test getting workflow runs count without status filter.""" + # Arrange + tenant_id = str(uuid.uuid4()) + app_id = str(uuid.uuid4()) + mock_session = MagicMock() + mock_session_maker.return_value.__enter__.return_value = mock_session + + # Mock the GROUP BY query results + mock_results = [ + ("succeeded", 5), + ("failed", 2), + ("running", 1), + ] + mock_session.execute.return_value.all.return_value = mock_results + + # Act + result = repository.get_workflow_runs_count( + tenant_id=tenant_id, + app_id=app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + status=None, + ) + + # Assert + assert result["total"] == 8 + assert result["succeeded"] == 5 + assert result["failed"] == 2 + assert result["running"] == 1 + assert result["stopped"] == 0 + assert result["partial-succeeded"] == 0 + + def test_get_workflow_runs_count_with_status_filter(self, repository, mock_session_maker): + """Test getting workflow runs count with status filter.""" + # Arrange + tenant_id = str(uuid.uuid4()) + app_id = str(uuid.uuid4()) + mock_session = MagicMock() + mock_session_maker.return_value.__enter__.return_value = mock_session + + # Mock the count query for succeeded status + mock_session.scalar.return_value = 5 + + # Act + result = repository.get_workflow_runs_count( + tenant_id=tenant_id, + app_id=app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + status="succeeded", + ) + + # Assert + assert result["total"] == 5 + assert result["succeeded"] == 5 + assert result["running"] == 0 + assert result["failed"] == 0 + assert result["stopped"] == 0 + assert result["partial-succeeded"] == 0 + + def test_get_workflow_runs_count_with_invalid_status(self, repository, mock_session_maker): + """Test that invalid status is still counted in total but not in any specific status.""" + # Arrange + tenant_id = str(uuid.uuid4()) + app_id = str(uuid.uuid4()) + mock_session = MagicMock() + mock_session_maker.return_value.__enter__.return_value = mock_session + + # Mock count query returning 0 for invalid status + mock_session.scalar.return_value = 0 + + # Act + result = repository.get_workflow_runs_count( + tenant_id=tenant_id, + app_id=app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + status="invalid_status", + ) + + # Assert + assert result["total"] == 0 + assert all(result[status] == 0 for status in ["running", "succeeded", "failed", "stopped", "partial-succeeded"]) + + def test_get_workflow_runs_count_with_time_range(self, repository, mock_session_maker): + """Test getting workflow runs count with time range filter verifies SQL query construction.""" + # Arrange + tenant_id = str(uuid.uuid4()) + app_id = str(uuid.uuid4()) + mock_session = MagicMock() + mock_session_maker.return_value.__enter__.return_value = mock_session + + # Mock the GROUP BY query results + mock_results = [ + ("succeeded", 3), + ("running", 2), + ] + mock_session.execute.return_value.all.return_value = mock_results + + # Act + result = repository.get_workflow_runs_count( + tenant_id=tenant_id, + app_id=app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + status=None, + time_range="1d", + ) + + # Assert results + assert result["total"] == 5 + assert result["succeeded"] == 3 + assert result["running"] == 2 + assert result["failed"] == 0 + + # Verify that execute was called (which means GROUP BY query was used) + assert mock_session.execute.called, "execute should have been called for GROUP BY query" + + # Verify SQL query includes time filter by checking the statement + call_args = mock_session.execute.call_args + assert call_args is not None, "execute should have been called with a statement" + + # The first argument should be the SQL statement + stmt = call_args[0][0] + # Convert to string to inspect the query + query_str = str(stmt.compile(compile_kwargs={"literal_binds": True})) + + # Verify the query includes created_at filter + # The query should have a WHERE clause with created_at comparison + assert "created_at" in query_str.lower() or "workflow_runs.created_at" in query_str.lower(), ( + "Query should include created_at filter for time range" + ) + + def test_get_workflow_runs_count_with_status_and_time_range(self, repository, mock_session_maker): + """Test getting workflow runs count with both status and time range filters verifies SQL query.""" + # Arrange + tenant_id = str(uuid.uuid4()) + app_id = str(uuid.uuid4()) + mock_session = MagicMock() + mock_session_maker.return_value.__enter__.return_value = mock_session + + # Mock the count query for running status within time range + mock_session.scalar.return_value = 2 + + # Act + result = repository.get_workflow_runs_count( + tenant_id=tenant_id, + app_id=app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + status="running", + time_range="1d", + ) + + # Assert results + assert result["total"] == 2 + assert result["running"] == 2 + assert result["succeeded"] == 0 + assert result["failed"] == 0 + + # Verify that scalar was called (which means COUNT query was used) + assert mock_session.scalar.called, "scalar should have been called for count query" + + # Verify SQL query includes both status and time filter + call_args = mock_session.scalar.call_args + assert call_args is not None, "scalar should have been called with a statement" + + # The first argument should be the SQL statement + stmt = call_args[0][0] + # Convert to string to inspect the query + query_str = str(stmt.compile(compile_kwargs={"literal_binds": True})) + + # Verify the query includes both filters + assert "created_at" in query_str.lower() or "workflow_runs.created_at" in query_str.lower(), ( + "Query should include created_at filter for time range" + ) + assert "status" in query_str.lower() or "workflow_runs.status" in query_str.lower(), ( + "Query should include status filter" + )