Merge branch 'main' into feat/memory-orchestration-fed

zxhlyh 2025-08-14 15:04:56 +08:00
commit a5dfc09e90
158 changed files with 7513 additions and 862 deletions

View File

@ -5,7 +5,7 @@ cd web && pnpm install
pipx install uv
echo 'alias start-api="cd /workspaces/dify/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
echo 'alias start-worker="cd /workspaces/dify/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc
echo 'alias start-worker="cd /workspaces/dify/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage"' >> ~/.bashrc
echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc
echo 'alias start-web-prod="cd /workspaces/dify/web && pnpm build && pnpm start"' >> ~/.bashrc
echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d"' >> ~/.bashrc

View File

@ -8,7 +8,7 @@ inputs:
uv-version:
description: UV version to set up
required: true
default: '~=0.7.11'
default: '0.8.9'
uv-lockfile:
description: Path to the UV lockfile to restore cache from
required: true
@ -26,7 +26,7 @@ runs:
python-version: ${{ inputs.python-version }}
- name: Install uv
uses: astral-sh/setup-uv@v5
uses: astral-sh/setup-uv@v6
with:
version: ${{ inputs.uv-version }}
python-version: ${{ inputs.python-version }}

View File

@ -4,7 +4,7 @@ FROM python:3.12-slim-bookworm AS base
WORKDIR /app/api
# Install uv
ENV UV_VERSION=0.7.11
ENV UV_VERSION=0.8.9
RUN pip install --no-cache-dir uv==${UV_VERSION}

View File

@ -74,7 +74,7 @@
10. If you need to handle and debug async tasks (e.g., dataset importing and document indexing), start the worker service.
```bash
uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin
uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage
```
In addition, if you want to debug the Celery scheduled tasks, you can run the following command in another terminal:
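The exact beat command falls outside this hunk; a plausible form, mirroring the worker invocation above (assumed, not shown in this diff):

```bash
uv run celery -A app.celery beat --loglevel INFO
```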

View File

@ -36,6 +36,7 @@ from services.account_service import AccountService, RegisterService, TenantServ
from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
from services.plugin.data_migration import PluginDataMigration
from services.plugin.plugin_migration import PluginMigration
from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
@click.command("reset-password", help="Reset the account password.")
@ -1202,3 +1203,138 @@ def setup_system_tool_oauth_client(provider, client_params):
db.session.add(oauth_client)
db.session.commit()
click.echo(click.style(f"OAuth client params set up successfully. id: {oauth_client.id}", fg="green"))
def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
"""
Find draft variables that reference non-existent apps.
Args:
batch_size: Maximum number of orphaned app IDs to return
Returns:
List of app IDs that have draft variables but don't exist in the apps table
"""
query = """
SELECT DISTINCT wdv.app_id
FROM workflow_draft_variables AS wdv
WHERE NOT EXISTS(
SELECT 1 FROM apps WHERE apps.id = wdv.app_id
)
LIMIT :batch_size
"""
with db.engine.connect() as conn:
result = conn.execute(sa.text(query), {"batch_size": batch_size})
return [row[0] for row in result]
def _count_orphaned_draft_variables() -> dict[str, Any]:
"""
Count orphaned draft variables by app.
Returns:
Dictionary with statistics about orphaned variables
"""
query = """
SELECT
wdv.app_id,
COUNT(*) as variable_count
FROM workflow_draft_variables AS wdv
WHERE NOT EXISTS(
SELECT 1 FROM apps WHERE apps.id = wdv.app_id
)
GROUP BY wdv.app_id
ORDER BY variable_count DESC
"""
with db.engine.connect() as conn:
result = conn.execute(sa.text(query))
orphaned_by_app = {row[0]: row[1] for row in result}
total_orphaned = sum(orphaned_by_app.values())
app_count = len(orphaned_by_app)
return {
"total_orphaned_variables": total_orphaned,
"orphaned_app_count": app_count,
"orphaned_by_app": orphaned_by_app,
}
@click.command()
@click.option("--dry-run", is_flag=True, help="Show what would be deleted without actually deleting")
@click.option("--batch-size", default=1000, help="Number of records to process per batch (default 1000)")
@click.option("--max-apps", default=None, type=int, help="Maximum number of apps to process (default: no limit)")
@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.")
def cleanup_orphaned_draft_variables(
dry_run: bool,
batch_size: int,
max_apps: int | None,
force: bool = False,
):
"""
Clean up orphaned draft variables from the database.
This script finds and removes draft variables that belong to apps
that no longer exist in the database.
"""
logger = logging.getLogger(__name__)
# Get statistics
stats = _count_orphaned_draft_variables()
logger.info("Found %s orphaned draft variables", stats["total_orphaned_variables"])
logger.info("Across %s non-existent apps", stats["orphaned_app_count"])
if stats["total_orphaned_variables"] == 0:
logger.info("No orphaned draft variables found. Exiting.")
return
if dry_run:
logger.info("DRY RUN: Would delete the following:")
for app_id, count in sorted(stats["orphaned_by_app"].items(), key=lambda x: x[1], reverse=True)[
:10
]: # Show top 10
logger.info(" App %s: %s variables", app_id, count)
if len(stats["orphaned_by_app"]) > 10:
logger.info(" ... and %s more apps", len(stats["orphaned_by_app"]) - 10)
return
# Confirm deletion
if not force:
click.confirm(
f"Are you sure you want to delete {stats['total_orphaned_variables']} "
f"orphaned draft variables from {stats['orphaned_app_count']} apps?",
abort=True,
)
total_deleted = 0
processed_apps = 0
while True:
if max_apps and processed_apps >= max_apps:
logger.info("Reached maximum app limit (%s). Stopping.", max_apps)
break
orphaned_app_ids = _find_orphaned_draft_variables(batch_size=10)
if not orphaned_app_ids:
logger.info("No more orphaned draft variables found.")
break
for app_id in orphaned_app_ids:
if max_apps and processed_apps >= max_apps:
break
try:
deleted_count = delete_draft_variables_batch(app_id, batch_size)
total_deleted += deleted_count
processed_apps += 1
logger.info("Deleted %s variables for app %s", deleted_count, app_id)
except Exception:
logger.exception("Error processing app %s", app_id)
continue
logger.info("Cleanup completed. Total deleted: %s variables across %s apps", total_deleted, processed_apps)
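A minimal usage sketch for the new command via the Flask CLI, assuming Click's default hyphenated name derivation from the function name:

```bash
# Preview what would be deleted, without touching anything
uv run flask cleanup-orphaned-draft-variables --dry-run

# Delete in batches of 500, capped at 20 apps, skipping the confirmation prompt
uv run flask cleanup-orphaned-draft-variables --batch-size 500 --max-apps 20 -f
```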

View File

@ -552,12 +552,18 @@ class RepositoryConfig(BaseSettings):
"""
CORE_WORKFLOW_EXECUTION_REPOSITORY: str = Field(
description="Repository implementation for WorkflowExecution. Specify as a module path",
description="Repository implementation for WorkflowExecution. Options: "
"'core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository' (default), "
"'core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository'",
default="core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository",
)
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field(
description="Repository implementation for WorkflowNodeExecution. Specify as a module path",
description="Repository implementation for WorkflowNodeExecution. Options: "
"'core.repositories.sqlalchemy_workflow_node_execution_repository."
"SQLAlchemyWorkflowNodeExecutionRepository' (default), "
"'core.repositories.celery_workflow_node_execution_repository."
"CeleryWorkflowNodeExecutionRepository'",
default="core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository",
)
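For illustration, opting into the Celery-backed repositories via environment variables would look like this (a sketch using the exact module paths listed in the descriptions above):

```bash
CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.celery_workflow_node_execution_repository.CeleryWorkflowNodeExecutionRepository
```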

View File

@ -1,5 +1,7 @@
from werkzeug.exceptions import HTTPException
from libs.exception import BaseHTTPException
class FilenameNotExistsError(HTTPException):
code = 400
@ -9,3 +11,27 @@ class FilenameNotExistsError(HTTPException):
class RemoteFileUploadError(HTTPException):
code = 400
description = "Error uploading remote file."
class FileTooLargeError(BaseHTTPException):
error_code = "file_too_large"
description = "File size exceeded. {message}"
code = 413
class UnsupportedFileTypeError(BaseHTTPException):
error_code = "unsupported_file_type"
description = "File type not allowed."
code = 415
class TooManyFilesError(BaseHTTPException):
error_code = "too_many_files"
description = "Only one file is allowed."
code = 400
class NoFileUploadedError(BaseHTTPException):
error_code = "no_file_uploaded"
description = "Please upload your file."
code = 400

View File

@ -3,9 +3,8 @@ from flask_login import current_user
from flask_restful import Resource, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
from controllers.console import api
from controllers.console.app.error import NoFileUploadedError
from controllers.console.datasets.error import TooManyFilesError
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,

View File

@ -79,18 +79,6 @@ class ProviderNotSupportSpeechToTextError(BaseHTTPException):
code = 400
class NoFileUploadedError(BaseHTTPException):
error_code = "no_file_uploaded"
description = "Please upload your file."
code = 400
class TooManyFilesError(BaseHTTPException):
error_code = "too_many_files"
description = "Only one file is allowed."
code = 400
class DraftWorkflowNotExist(BaseHTTPException):
error_code = "draft_workflow_not_exist"
description = "Draft workflow need to be initialized."

View File

@ -27,7 +27,7 @@ from fields.conversation_fields import annotation_fields, message_detail_fields
from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import login_required
from models.model import AppMode, Conversation, Message, MessageAnnotation
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.annotation_service import AppAnnotationService
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
@ -124,17 +124,34 @@ class MessageFeedbackApi(Resource):
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
args = parser.parse_args()
try:
MessageService.create_feedback(
app_model=app_model,
message_id=str(args["message_id"]),
user=current_user,
rating=args.get("rating"),
content=None,
)
except MessageNotExistsError:
message_id = str(args["message_id"])
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
if not message:
raise NotFound("Message Not Exists.")
feedback = message.admin_feedback
if not args["rating"] and feedback:
db.session.delete(feedback)
elif args["rating"] and feedback:
feedback.rating = args["rating"]
elif not args["rating"] and not feedback:
raise ValueError("rating cannot be None when feedback does not exist")
else:
feedback = MessageFeedback(
app_id=app_model.id,
conversation_id=message.conversation_id,
message_id=message.id,
rating=args["rating"],
from_source="admin",
from_account_id=current_user.id,
)
db.session.add(feedback)
db.session.commit()
return {"result": "success"}

View File

@ -1,30 +1,6 @@
from libs.exception import BaseHTTPException
class NoFileUploadedError(BaseHTTPException):
error_code = "no_file_uploaded"
description = "Please upload your file."
code = 400
class TooManyFilesError(BaseHTTPException):
error_code = "too_many_files"
description = "Only one file is allowed."
code = 400
class FileTooLargeError(BaseHTTPException):
error_code = "file_too_large"
description = "File size exceeded. {message}"
code = 413
class UnsupportedFileTypeError(BaseHTTPException):
error_code = "unsupported_file_type"
description = "File type not allowed."
code = 415
class DatasetNotInitializedError(BaseHTTPException):
error_code = "dataset_not_initialized"
description = "The dataset is still being initialized or indexing. Please wait a moment."

View File

@ -76,30 +76,6 @@ class EmailSendIpLimitError(BaseHTTPException):
code = 429
class FileTooLargeError(BaseHTTPException):
error_code = "file_too_large"
description = "File size exceeded. {message}"
code = 413
class UnsupportedFileTypeError(BaseHTTPException):
error_code = "unsupported_file_type"
description = "File type not allowed."
code = 415
class TooManyFilesError(BaseHTTPException):
error_code = "too_many_files"
description = "Only one file is allowed."
code = 400
class NoFileUploadedError(BaseHTTPException):
error_code = "no_file_uploaded"
description = "Please upload your file."
code = 400
class UnauthorizedAndForceLogout(BaseHTTPException):
error_code = "unauthorized_and_force_logout"
description = "Unauthorized and force logout."

View File

@ -8,7 +8,13 @@ from werkzeug.exceptions import Forbidden
import services
from configs import dify_config
from constants import DOCUMENT_EXTENSIONS
from controllers.common.errors import FilenameNotExistsError
from controllers.common.errors import (
FilenameNotExistsError,
FileTooLargeError,
NoFileUploadedError,
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
@ -18,13 +24,6 @@ from fields.file_fields import file_fields, upload_config_fields
from libs.login import login_required
from services.file_service import FileService
from .error import (
FileTooLargeError,
NoFileUploadedError,
TooManyFilesError,
UnsupportedFileTypeError,
)
PREVIEW_WORDS_LIMIT = 3000

View File

@ -7,18 +7,17 @@ from flask_restful import Resource, marshal_with, reqparse
import services
from controllers.common import helpers
from controllers.common.errors import RemoteFileUploadError
from controllers.common.errors import (
FileTooLargeError,
RemoteFileUploadError,
UnsupportedFileTypeError,
)
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
from models.account import Account
from services.file_service import FileService
from .error import (
FileTooLargeError,
UnsupportedFileTypeError,
)
class RemoteFileInfoApi(Resource):
@marshal_with(remote_file_info_fields)

View File

@ -862,6 +862,10 @@ class ToolProviderMCPApi(Resource):
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
parser.add_argument("timeout", type=float, required=False, nullable=False, location="json", default=30)
parser.add_argument(
"sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300
)
args = parser.parse_args()
user = current_user
if not is_valid_url(args["server_url"]):
@ -876,6 +880,8 @@ class ToolProviderMCPApi(Resource):
icon_background=args["icon_background"],
user_id=user.id,
server_identifier=args["server_identifier"],
timeout=args["timeout"],
sse_read_timeout=args["sse_read_timeout"],
)
)
@ -891,6 +897,8 @@ class ToolProviderMCPApi(Resource):
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
parser.add_argument("timeout", type=float, required=False, nullable=True, location="json")
parser.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json")
args = parser.parse_args()
if not is_valid_url(args["server_url"]):
if "[__HIDDEN__]" in args["server_url"]:
@ -906,6 +914,8 @@ class ToolProviderMCPApi(Resource):
icon_type=args["icon_type"],
icon_background=args["icon_background"],
server_identifier=args["server_identifier"],
timeout=args.get("timeout"),
sse_read_timeout=args.get("sse_read_timeout"),
)
return {"result": "success"}
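As a sketch, a request body for these endpoints now carries the two new optional timeout fields (values in seconds; the other fields follow the parser arguments above, some of which fall outside this hunk):

```json
{
  "server_url": "https://mcp.example.com/mcp",
  "name": "example-mcp",
  "icon": "🔧",
  "icon_type": "emoji",
  "icon_background": "",
  "server_identifier": "example-mcp-server",
  "timeout": 10,
  "sse_read_timeout": 120
}
```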

View File

@ -7,15 +7,15 @@ from sqlalchemy import select
from werkzeug.exceptions import Unauthorized
import services
from controllers.common.errors import FilenameNotExistsError
from controllers.console import api
from controllers.console.admin import admin_required
from controllers.console.datasets.error import (
from controllers.common.errors import (
FilenameNotExistsError,
FileTooLargeError,
NoFileUploadedError,
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.console import api
from controllers.console.admin import admin_required
from controllers.console.error import AccountNotLinkTenantError
from controllers.console.wraps import (
account_initialization_required,

View File

@ -1,7 +0,0 @@
from libs.exception import BaseHTTPException
class UnsupportedFileTypeError(BaseHTTPException):
error_code = "unsupported_file_type"
description = "File type not allowed."
code = 415

View File

@ -5,8 +5,8 @@ from flask_restful import Resource, reqparse
from werkzeug.exceptions import NotFound
import services
from controllers.common.errors import UnsupportedFileTypeError
from controllers.files import api
from controllers.files.error import UnsupportedFileTypeError
from services.account_service import TenantService
from services.file_service import FileService

View File

@ -4,8 +4,8 @@ from flask import Response
from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden, NotFound
from controllers.common.errors import UnsupportedFileTypeError
from controllers.files import api
from controllers.files.error import UnsupportedFileTypeError
from core.tools.signature import verify_tool_file_signature
from core.tools.tool_file_manager import ToolFileManager
from models import db as global_db

View File

@ -5,11 +5,13 @@ from flask_restful import Resource, marshal_with
from werkzeug.exceptions import Forbidden
import services
from controllers.common.errors import (
FileTooLargeError,
UnsupportedFileTypeError,
)
from controllers.console.wraps import setup_required
from controllers.files import api
from controllers.files.error import UnsupportedFileTypeError
from controllers.inner_api.plugin.wraps import get_user
from controllers.service_api.app.error import FileTooLargeError
from core.file.helpers import verify_plugin_file_signature
from core.tools.tool_file_manager import ToolFileManager
from fields.file_fields import file_fields

View File

@ -85,30 +85,6 @@ class ProviderNotSupportSpeechToTextError(BaseHTTPException):
code = 400
class NoFileUploadedError(BaseHTTPException):
error_code = "no_file_uploaded"
description = "Please upload your file."
code = 400
class TooManyFilesError(BaseHTTPException):
error_code = "too_many_files"
description = "Only one file is allowed."
code = 400
class FileTooLargeError(BaseHTTPException):
error_code = "file_too_large"
description = "File size exceeded. {message}"
code = 413
class UnsupportedFileTypeError(BaseHTTPException):
error_code = "unsupported_file_type"
description = "File type not allowed."
code = 415
class FileNotFoundError(BaseHTTPException):
error_code = "file_not_found"
description = "The requested file was not found."

View File

@ -2,14 +2,14 @@ from flask import request
from flask_restful import Resource, marshal_with
import services
from controllers.common.errors import FilenameNotExistsError
from controllers.service_api import api
from controllers.service_api.app.error import (
from controllers.common.errors import (
FilenameNotExistsError,
FileTooLargeError,
NoFileUploadedError,
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.service_api import api
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from fields.file_fields import file_fields
from models.model import App, EndUser

View File

@ -6,15 +6,15 @@ from sqlalchemy import desc, select
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.common.errors import FilenameNotExistsError
from controllers.service_api import api
from controllers.service_api.app.error import (
from controllers.common.errors import (
FilenameNotExistsError,
FileTooLargeError,
NoFileUploadedError,
ProviderNotInitializeError,
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.service_api import api
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.dataset.error import (
ArchivedDocumentImmutableError,
DocumentIndexingError,

View File

@ -1,30 +1,6 @@
from libs.exception import BaseHTTPException
class NoFileUploadedError(BaseHTTPException):
error_code = "no_file_uploaded"
description = "Please upload your file."
code = 400
class TooManyFilesError(BaseHTTPException):
error_code = "too_many_files"
description = "Only one file is allowed."
code = 400
class FileTooLargeError(BaseHTTPException):
error_code = "file_too_large"
description = "File size exceeded. {message}"
code = 413
class UnsupportedFileTypeError(BaseHTTPException):
error_code = "unsupported_file_type"
description = "File type not allowed."
code = 415
class DatasetNotInitializedError(BaseHTTPException):
error_code = "dataset_not_initialized"
description = "The dataset is still being initialized or indexing. Please wait a moment."

View File

@ -97,30 +97,6 @@ class ProviderNotSupportSpeechToTextError(BaseHTTPException):
code = 400
class NoFileUploadedError(BaseHTTPException):
error_code = "no_file_uploaded"
description = "Please upload your file."
code = 400
class TooManyFilesError(BaseHTTPException):
error_code = "too_many_files"
description = "Only one file is allowed."
code = 400
class FileTooLargeError(BaseHTTPException):
error_code = "file_too_large"
description = "File size exceeded. {message}"
code = 413
class UnsupportedFileTypeError(BaseHTTPException):
error_code = "unsupported_file_type"
description = "File type not allowed."
code = 415
class WebAppAuthRequiredError(BaseHTTPException):
error_code = "web_sso_auth_required"
description = "Web app authentication required."

View File

@ -2,8 +2,13 @@ from flask import request
from flask_restful import marshal_with
import services
from controllers.common.errors import FilenameNotExistsError
from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError
from controllers.common.errors import (
FilenameNotExistsError,
FileTooLargeError,
NoFileUploadedError,
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.web.wraps import WebApiResource
from fields.file_fields import file_fields
from services.file_service import FileService

View File

@ -5,15 +5,17 @@ from flask_restful import marshal_with, reqparse
import services
from controllers.common import helpers
from controllers.common.errors import RemoteFileUploadError
from controllers.common.errors import (
FileTooLargeError,
RemoteFileUploadError,
UnsupportedFileTypeError,
)
from controllers.web.wraps import WebApiResource
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
from services.file_service import FileService
from .error import FileTooLargeError, UnsupportedFileTypeError
class RemoteFileInfoApi(WebApiResource):
@marshal_with(remote_file_info_fields)

View File

@ -74,6 +74,7 @@ from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from events.message_event import message_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Conversation, EndUser, Message, MessageFile
from models.account import Account
from models.enums import CreatorUserRole
@ -896,6 +897,7 @@ class AdvancedChatAppGenerateTaskPipeline:
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
message = self._get_message(session=session)
message.answer = self._task_state.answer
message.updated_at = naive_utc_now()
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
message.message_metadata = self._task_state.metadata.model_dump_json()
message_files = [

View File

@ -57,6 +57,7 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from events.message_event import message_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.model import AppMode, Conversation, Message, MessageAgentThought
logger = logging.getLogger(__name__)
@ -389,6 +390,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if llm_result.message.content
else ""
)
message.updated_at = naive_utc_now()
message.answer_tokens = usage.completion_tokens
message.answer_unit_price = usage.completion_unit_price
message.answer_price_unit = usage.completion_price_unit

View File

@ -327,7 +327,7 @@ def send_message(http_client: httpx.Client, endpoint_url: str, session_message:
)
response.raise_for_status()
logger.debug("Client message sent successfully: %s", response.status_code)
except Exception as exc:
except Exception:
logger.exception("Error sending message")
raise

View File

@ -55,14 +55,10 @@ DEFAULT_QUEUE_READ_TIMEOUT = 3
class StreamableHTTPError(Exception):
"""Base exception for StreamableHTTP transport errors."""
pass
class ResumptionError(StreamableHTTPError):
"""Raised when resumption request is invalid."""
pass
@dataclass
class RequestContext:
@ -74,7 +70,7 @@ class RequestContext:
session_message: SessionMessage
metadata: ClientMessageMetadata | None
server_to_client_queue: ServerToClientQueue # Renamed for clarity
sse_read_timeout: timedelta
sse_read_timeout: float
class StreamableHTTPTransport:
@ -84,8 +80,8 @@ class StreamableHTTPTransport:
self,
url: str,
headers: dict[str, Any] | None = None,
timeout: timedelta = timedelta(seconds=30),
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
timeout: float | timedelta = 30,
sse_read_timeout: float | timedelta = 60 * 5,
) -> None:
"""Initialize the StreamableHTTP transport.
@ -97,8 +93,10 @@ class StreamableHTTPTransport:
"""
self.url = url
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
self.sse_read_timeout = (
sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout
)
self.session_id: str | None = None
self.request_headers = {
ACCEPT: f"{JSON}, {SSE}",
@ -186,7 +184,7 @@ class StreamableHTTPTransport:
with ssrf_proxy_sse_connect(
self.url,
headers=headers,
timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds),
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
client=client,
method="GET",
) as event_source:
@ -215,7 +213,7 @@ class StreamableHTTPTransport:
with ssrf_proxy_sse_connect(
self.url,
headers=headers,
timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds),
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
client=ctx.client,
method="GET",
) as event_source:
@ -402,8 +400,8 @@ class StreamableHTTPTransport:
def streamablehttp_client(
url: str,
headers: dict[str, Any] | None = None,
timeout: timedelta = timedelta(seconds=30),
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
timeout: float | timedelta = 30,
sse_read_timeout: float | timedelta = 60 * 5,
terminate_on_close: bool = True,
) -> Generator[
tuple[
@ -436,7 +434,7 @@ def streamablehttp_client(
try:
with create_ssrf_proxy_mcp_http_client(
headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds),
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
) as client:
# Define callbacks that need access to thread pool
def start_get_stream() -> None:
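Since the transport now accepts either plain seconds or timedelta values, both call styles work; a brief sketch under that assumption (the URL is hypothetical):

```python
from datetime import timedelta

url = "https://mcp.example.com/mcp"  # hypothetical endpoint

# Plain seconds (the new style)
transport = StreamableHTTPTransport(url, timeout=30, sse_read_timeout=300)

# timedelta values are still accepted and normalized via total_seconds()
transport = StreamableHTTPTransport(
    url, timeout=timedelta(seconds=30), sse_read_timeout=timedelta(minutes=5)
)
```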

View File

@ -23,12 +23,18 @@ class MCPClient:
authed: bool = True,
authorization_code: Optional[str] = None,
for_list: bool = False,
headers: Optional[dict[str, str]] = None,
timeout: Optional[float] = None,
sse_read_timeout: Optional[float] = None,
):
# Initialize info
self.provider_id = provider_id
self.tenant_id = tenant_id
self.client_type = "streamable"
self.server_url = server_url
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
# Authentication info
self.authed = authed
@ -43,7 +49,7 @@ class MCPClient:
self._session: Optional[ClientSession] = None
self._streams_context: Optional[AbstractContextManager[Any]] = None
self._session_context: Optional[ClientSession] = None
self.exit_stack = ExitStack()
self._exit_stack = ExitStack()
# Whether the client has been initialized
self._initialized = False
@ -90,21 +96,26 @@ class MCPClient:
headers = (
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
if self.authed and self.token
else {}
else self.headers
)
self._streams_context = client_factory(
url=self.server_url,
headers=headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
)
self._streams_context = client_factory(url=self.server_url, headers=headers)
if not self._streams_context:
raise MCPConnectionError("Failed to create connection context")
# Use exit_stack to manage context managers properly
if method_name == "mcp":
read_stream, write_stream, _ = self.exit_stack.enter_context(self._streams_context)
read_stream, write_stream, _ = self._exit_stack.enter_context(self._streams_context)
streams = (read_stream, write_stream)
else: # sse_client
streams = self.exit_stack.enter_context(self._streams_context)
streams = self._exit_stack.enter_context(self._streams_context)
self._session_context = ClientSession(*streams)
self._session = self.exit_stack.enter_context(self._session_context)
self._session = self._exit_stack.enter_context(self._session_context)
session = cast(ClientSession, self._session)
session.initialize()
return
@ -120,9 +131,6 @@ class MCPClient:
if first_try:
return self.connect_server(client_factory, method_name, first_try=False)
except MCPConnectionError:
raise
def list_tools(self) -> list[Tool]:
"""Connect to an MCP server running with SSE transport"""
# List available tools to verify connection
@ -142,7 +150,7 @@ class MCPClient:
"""Clean up resources"""
try:
# ExitStack will handle proper cleanup of all managed context managers
self.exit_stack.close()
self._exit_stack.close()
except Exception as e:
logging.exception("Error during cleanup")
raise ValueError(f"Error during cleanup: {e}")
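A usage sketch of the extended constructor (keyword names come from the signature above; the header value is hypothetical, and custom headers apply only when no OAuth token is in play, per connect_server):

```python
with MCPClient(
    server_url="https://mcp.example.com/mcp",  # hypothetical
    provider_id=provider_id,
    tenant_id=tenant_id,
    authed=False,
    headers={"X-Api-Key": "..."},  # ignored when an OAuth token is present
    timeout=10.0,
    sse_read_timeout=120.0,
) as client:
    tools = client.list_tools()
```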

View File

@ -2,7 +2,6 @@ import logging
import queue
from collections.abc import Callable
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
from contextlib import ExitStack
from datetime import timedelta
from types import TracebackType
from typing import Any, Generic, Self, TypeVar
@ -170,7 +169,6 @@ class BaseSession(
self._receive_notification_type = receive_notification_type
self._session_read_timeout_seconds = read_timeout_seconds
self._in_flight = {}
self._exit_stack = ExitStack()
# Initialize executor and future to None for proper cleanup checks
self._executor: ThreadPoolExecutor | None = None
self._receiver_future: Future | None = None
@ -377,7 +375,7 @@ class BaseSession(
self._handle_incoming(RuntimeError(f"Server Error: {message}"))
except queue.Empty:
continue
except Exception as e:
except Exception:
logging.exception("Error in message processing loop")
raise
@ -389,14 +387,12 @@ class BaseSession(
If the request is responded to within this method, it will not be
forwarded on to the message stream.
"""
pass
def _received_notification(self, notification: ReceiveNotificationT) -> None:
"""
Can be overridden by subclasses to handle a notification without needing
to listen on the message stream.
"""
pass
def send_progress_notification(
self, progress_token: str | int, progress: float, total: float | None = None
@ -405,11 +401,9 @@ class BaseSession(
Sends a progress notification for a request that is currently being
processed.
"""
pass
def _handle_incoming(
self,
req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
) -> None:
"""A generic handler for incoming messages. Overwritten by subclasses."""
pass

View File

@ -1,3 +1,4 @@
import queue
from datetime import timedelta
from typing import Any, Protocol
@ -85,8 +86,8 @@ class ClientSession(
):
def __init__(
self,
read_stream,
write_stream,
read_stream: queue.Queue,
write_stream: queue.Queue,
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,

View File

@ -5,10 +5,14 @@ This package contains concrete implementations of the repository interfaces
defined in the core.workflow.repository package.
"""
from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
__all__ = [
"CeleryWorkflowExecutionRepository",
"CeleryWorkflowNodeExecutionRepository",
"DifyCoreRepositoryFactory",
"RepositoryImportError",
"SQLAlchemyWorkflowNodeExecutionRepository",

View File

@ -0,0 +1,126 @@
"""
Celery-based implementation of the WorkflowExecutionRepository.
This implementation uses Celery tasks for asynchronous storage operations,
providing improved performance by offloading database operations to background workers.
"""
import logging
from typing import Optional, Union
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from libs.helper import extract_tenant_id
from models import Account, CreatorUserRole, EndUser
from models.enums import WorkflowRunTriggeredFrom
from tasks.workflow_execution_tasks import (
save_workflow_execution_task,
)
logger = logging.getLogger(__name__)
class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository):
"""
Celery-based implementation of the WorkflowExecutionRepository interface.
This implementation provides asynchronous storage capabilities by using Celery tasks
to handle database operations in background workers. This improves performance by
reducing the blocking time for workflow execution storage operations.
Key features:
- Asynchronous save operations using Celery tasks
- Support for multi-tenancy through tenant/app filtering
- Automatic retry and error handling through Celery
"""
_session_factory: sessionmaker
_tenant_id: str
_app_id: Optional[str]
_triggered_from: Optional[WorkflowRunTriggeredFrom]
_creator_user_id: str
_creator_user_role: CreatorUserRole
def __init__(
self,
session_factory: sessionmaker | Engine,
user: Union[Account, EndUser],
app_id: Optional[str],
triggered_from: Optional[WorkflowRunTriggeredFrom],
):
"""
Initialize the repository with Celery task configuration and context information.
Args:
session_factory: SQLAlchemy sessionmaker or engine for fallback operations
user: Account or EndUser object containing tenant_id, user ID, and role information
app_id: App ID for filtering by application (can be None)
triggered_from: Source of the execution trigger (DEBUGGING or APP_RUN)
"""
# Store session factory for fallback operations
if isinstance(session_factory, Engine):
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
elif isinstance(session_factory, sessionmaker):
self._session_factory = session_factory
else:
raise ValueError(
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
)
# Extract tenant_id from user
tenant_id = extract_tenant_id(user)
if not tenant_id:
raise ValueError("User must have a tenant_id or current_tenant_id")
self._tenant_id = tenant_id # type: ignore[assignment] # We've already checked tenant_id is not None
# Store app context
self._app_id = app_id
# Extract user context
self._triggered_from = triggered_from
self._creator_user_id = user.id
# Determine user role based on user type
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
logger.info(
"Initialized CeleryWorkflowExecutionRepository for tenant %s, app %s, triggered_from %s",
self._tenant_id,
self._app_id,
self._triggered_from,
)
def save(self, execution: WorkflowExecution) -> None:
"""
Save or update a WorkflowExecution instance asynchronously using Celery.
This method queues the save operation as a Celery task and returns immediately,
providing improved performance for high-throughput scenarios.
Args:
execution: The WorkflowExecution instance to save or update
"""
try:
# Serialize execution for Celery task
execution_data = execution.model_dump()
# Queue the save operation as a Celery task (fire and forget)
save_workflow_execution_task.delay(
execution_data=execution_data,
tenant_id=self._tenant_id,
app_id=self._app_id or "",
triggered_from=self._triggered_from.value if self._triggered_from else "",
creator_user_id=self._creator_user_id,
creator_user_role=self._creator_user_role.value,
)
logger.debug("Queued async save for workflow execution: %s", execution.id_)
except Exception:
logger.exception("Failed to queue save operation for execution %s", execution.id_)
# In case of Celery failure, we could implement a fallback to synchronous save
# For now, we'll re-raise the exception
raise
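A construction sketch; in practice this class is instantiated through DifyCoreRepositoryFactory when the repository config points at it, and `db.engine`, `account`, `app`, and `execution` are stand-ins here:

```python
from models.enums import WorkflowRunTriggeredFrom

repo = CeleryWorkflowExecutionRepository(
    session_factory=db.engine,  # an Engine is wrapped into a sessionmaker internally
    user=account,               # Account or EndUser; tenant_id is extracted from it
    app_id=app.id,
    triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
repo.save(execution)  # queues save_workflow_execution_task and returns immediately
```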

View File

@ -0,0 +1,190 @@
"""
Celery-based implementation of the WorkflowNodeExecutionRepository.
This implementation uses Celery tasks for asynchronous storage operations,
providing improved performance by offloading database operations to background workers.
"""
import logging
from collections.abc import Sequence
from typing import Optional, Union
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
from core.workflow.repositories.workflow_node_execution_repository import (
OrderConfig,
WorkflowNodeExecutionRepository,
)
from libs.helper import extract_tenant_id
from models import Account, CreatorUserRole, EndUser
from models.workflow import WorkflowNodeExecutionTriggeredFrom
from tasks.workflow_node_execution_tasks import (
save_workflow_node_execution_task,
)
logger = logging.getLogger(__name__)
class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
"""
Celery-based implementation of the WorkflowNodeExecutionRepository interface.
This implementation provides asynchronous storage capabilities by using Celery tasks
to handle database operations in background workers. This improves performance by
reducing the blocking time for workflow node execution storage operations.
Key features:
- Asynchronous save operations using Celery tasks
- In-memory cache for immediate reads
- Support for multi-tenancy through tenant/app filtering
- Automatic retry and error handling through Celery
"""
_session_factory: sessionmaker
_tenant_id: str
_app_id: Optional[str]
_triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom]
_creator_user_id: str
_creator_user_role: CreatorUserRole
_execution_cache: dict[str, WorkflowNodeExecution]
_workflow_execution_mapping: dict[str, list[str]]
def __init__(
self,
session_factory: sessionmaker | Engine,
user: Union[Account, EndUser],
app_id: Optional[str],
triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom],
):
"""
Initialize the repository with Celery task configuration and context information.
Args:
session_factory: SQLAlchemy sessionmaker or engine for fallback operations
user: Account or EndUser object containing tenant_id, user ID, and role information
app_id: App ID for filtering by application (can be None)
triggered_from: Source of the execution trigger (SINGLE_STEP or WORKFLOW_RUN)
"""
# Store session factory for fallback operations
if isinstance(session_factory, Engine):
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
elif isinstance(session_factory, sessionmaker):
self._session_factory = session_factory
else:
raise ValueError(
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
)
# Extract tenant_id from user
tenant_id = extract_tenant_id(user)
if not tenant_id:
raise ValueError("User must have a tenant_id or current_tenant_id")
self._tenant_id = tenant_id # type: ignore[assignment] # We've already checked tenant_id is not None
# Store app context
self._app_id = app_id
# Extract user context
self._triggered_from = triggered_from
self._creator_user_id = user.id
# Determine user role based on user type
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
# In-memory cache for workflow node executions
self._execution_cache: dict[str, WorkflowNodeExecution] = {}
# Cache for mapping workflow_execution_ids to execution IDs for efficient retrieval
self._workflow_execution_mapping: dict[str, list[str]] = {}
logger.info(
"Initialized CeleryWorkflowNodeExecutionRepository for tenant %s, app %s, triggered_from %s",
self._tenant_id,
self._app_id,
self._triggered_from,
)
def save(self, execution: WorkflowNodeExecution) -> None:
"""
Save or update a WorkflowNodeExecution instance to cache and asynchronously to database.
This method stores the execution in cache immediately for fast reads and queues
the save operation as a Celery task without tracking the task status.
Args:
execution: The WorkflowNodeExecution instance to save or update
"""
try:
# Store in cache immediately for fast reads
self._execution_cache[execution.id] = execution
# Update workflow execution mapping for efficient retrieval
if execution.workflow_execution_id:
if execution.workflow_execution_id not in self._workflow_execution_mapping:
self._workflow_execution_mapping[execution.workflow_execution_id] = []
if execution.id not in self._workflow_execution_mapping[execution.workflow_execution_id]:
self._workflow_execution_mapping[execution.workflow_execution_id].append(execution.id)
# Serialize execution for Celery task
execution_data = execution.model_dump()
# Queue the save operation as a Celery task (fire and forget)
save_workflow_node_execution_task.delay(
execution_data=execution_data,
tenant_id=self._tenant_id,
app_id=self._app_id or "",
triggered_from=self._triggered_from.value if self._triggered_from else "",
creator_user_id=self._creator_user_id,
creator_user_role=self._creator_user_role.value,
)
logger.debug("Cached and queued async save for workflow node execution: %s", execution.id)
except Exception:
logger.exception("Failed to cache or queue save operation for node execution %s", execution.id)
# In case of Celery failure, we could implement a fallback to synchronous save
# For now, we'll re-raise the exception
raise
def get_by_workflow_run(
self,
workflow_run_id: str,
order_config: Optional[OrderConfig] = None,
) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all WorkflowNodeExecution instances for a specific workflow run from cache.
Args:
workflow_run_id: The workflow run ID
order_config: Optional configuration for ordering results
Returns:
A sequence of WorkflowNodeExecution instances
"""
try:
# Get execution IDs for this workflow run from cache
execution_ids = self._workflow_execution_mapping.get(workflow_run_id, [])
# Retrieve executions from cache
result = []
for execution_id in execution_ids:
if execution_id in self._execution_cache:
result.append(self._execution_cache[execution_id])
# Apply ordering if specified
if order_config and result:
# Sort based on the order configuration
reverse = order_config.order_direction == "desc"
# Sort by multiple fields if specified
for field_name in reversed(order_config.order_by):
result.sort(key=lambda x: getattr(x, field_name, 0), reverse=reverse)
logger.debug("Retrieved %d workflow node executions for run %s from cache", len(result), workflow_run_id)
return result
except Exception:
logger.exception("Failed to get workflow node executions for run %s from cache", workflow_run_id)
return []
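And a read-after-write sketch for the node-execution variant, assuming OrderConfig takes the order_by and order_direction fields used above (`db.engine`, `account`, `app`, `node_execution`, and `run_id` are stand-ins):

```python
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
from models.workflow import WorkflowNodeExecutionTriggeredFrom

node_repo = CeleryWorkflowNodeExecutionRepository(
    session_factory=db.engine,
    user=account,
    app_id=app.id,
    triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
node_repo.save(node_execution)  # cached in-memory, then queued to Celery

# Reads are served from the in-process cache, not the database
executions = node_repo.get_by_workflow_run(
    workflow_run_id=run_id,
    order_config=OrderConfig(order_by=["index"], order_direction="asc"),
)
```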

View File

@ -94,11 +94,9 @@ class DifyCoreRepositoryFactory:
def _validate_constructor_signature(repository_class: type, required_params: list[str]) -> None:
"""
Validate that a repository class constructor accepts required parameters.
Args:
repository_class: The class to validate
required_params: List of required parameter names
Raises:
RepositoryImportError: If the constructor doesn't accept required parameters
"""
@ -158,10 +156,8 @@ class DifyCoreRepositoryFactory:
try:
repository_class = cls._import_class(class_path)
cls._validate_repository_interface(repository_class, WorkflowExecutionRepository)
cls._validate_constructor_signature(
repository_class, ["session_factory", "user", "app_id", "triggered_from"]
)
# All repository types now use the same constructor parameters
return repository_class( # type: ignore[no-any-return]
session_factory=session_factory,
user=user,
@ -204,10 +200,8 @@ class DifyCoreRepositoryFactory:
try:
repository_class = cls._import_class(class_path)
cls._validate_repository_interface(repository_class, WorkflowNodeExecutionRepository)
cls._validate_constructor_signature(
repository_class, ["session_factory", "user", "app_id", "triggered_from"]
)
# All repository types now use the same constructor parameters
return repository_class( # type: ignore[no-any-return]
session_factory=session_factory,
user=user,

View File

@ -12,8 +12,6 @@ from core.tools.errors import ToolProviderCredentialValidationError
class ToolProviderController(ABC):
entity: ToolProviderEntity
def __init__(self, entity: ToolProviderEntity) -> None:
self.entity = entity

View File

@ -1,5 +1,5 @@
import json
from typing import Any
from typing import Any, Optional
from core.mcp.types import Tool as RemoteMCPTool
from core.tools.__base.tool_provider import ToolProviderController
@ -19,15 +19,24 @@ from services.tools.tools_transform_service import ToolTransformService
class MCPToolProviderController(ToolProviderController):
provider_id: str
entity: ToolProviderEntityWithPlugin
def __init__(self, entity: ToolProviderEntityWithPlugin, provider_id: str, tenant_id: str, server_url: str) -> None:
def __init__(
self,
entity: ToolProviderEntityWithPlugin,
provider_id: str,
tenant_id: str,
server_url: str,
headers: Optional[dict[str, str]] = None,
timeout: Optional[float] = None,
sse_read_timeout: Optional[float] = None,
) -> None:
super().__init__(entity)
self.entity = entity
self.entity: ToolProviderEntityWithPlugin = entity
self.tenant_id = tenant_id
self.provider_id = provider_id
self.server_url = server_url
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
@property
def provider_type(self) -> ToolProviderType:
@ -85,6 +94,9 @@ class MCPToolProviderController(ToolProviderController):
provider_id=db_provider.server_identifier or "",
tenant_id=db_provider.tenant_id or "",
server_url=db_provider.decrypted_server_url,
headers={}, # TODO: get headers from db provider
timeout=db_provider.timeout,
sse_read_timeout=db_provider.sse_read_timeout,
)
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
@ -111,6 +123,9 @@ class MCPToolProviderController(ToolProviderController):
icon=self.entity.identity.icon,
server_url=self.server_url,
provider_id=self.provider_id,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
)
def get_tools(self) -> list[MCPTool]: # type: ignore
@ -125,6 +140,9 @@ class MCPToolProviderController(ToolProviderController):
icon=self.entity.identity.icon,
server_url=self.server_url,
provider_id=self.provider_id,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
)
for tool_entity in self.entity.tools
]

View File

@ -13,13 +13,25 @@ from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, Too
class MCPTool(Tool):
def __init__(
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, server_url: str, provider_id: str
self,
entity: ToolEntity,
runtime: ToolRuntime,
tenant_id: str,
icon: str,
server_url: str,
provider_id: str,
headers: Optional[dict[str, str]] = None,
timeout: Optional[float] = None,
sse_read_timeout: Optional[float] = None,
) -> None:
super().__init__(entity, runtime)
self.tenant_id = tenant_id
self.icon = icon
self.server_url = server_url
self.provider_id = provider_id
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.MCP
@ -35,7 +47,15 @@ class MCPTool(Tool):
from core.tools.errors import ToolInvokeError
try:
with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client:
with MCPClient(
self.server_url,
self.provider_id,
self.tenant_id,
authed=True,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
) as mcp_client:
tool_parameters = self._handle_none_parameter(tool_parameters)
result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPAuthError as e:
@ -72,6 +92,9 @@ class MCPTool(Tool):
icon=self.icon,
server_url=self.server_url,
provider_id=self.provider_id,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
)
def _handle_none_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]:

View File

@ -789,9 +789,6 @@ class ToolManager:
"""
get api provider
"""
"""
get tool provider
"""
provider_name = provider
provider_obj: ApiToolProvider | None = (
db.session.query(ApiToolProvider)

View File

@ -276,17 +276,26 @@ class Executor:
encoded_credentials = credentials
headers[authorization.config.header] = f"Basic {encoded_credentials}"
elif self.auth.config.type == "custom":
headers[authorization.config.header] = authorization.config.api_key or ""
if authorization.config.header and authorization.config.api_key:
headers[authorization.config.header] = authorization.config.api_key
# Handle Content-Type for multipart/form-data requests
# Fix for issue #22880: Missing boundary when using multipart/form-data
# Fix for issue #23829: Missing boundary when using multipart/form-data
body = self.node_data.body
if body and body.type == "form-data":
# For multipart/form-data with files, let httpx handle the boundary automatically
# by not setting Content-Type header when files are present
if not self.files or all(f[0] == "__multipart_placeholder__" for f in self.files):
# Only set Content-Type when there are no actual files
# This ensures httpx generates the correct boundary
# For multipart/form-data with files (including placeholder files),
# remove any manually set Content-Type header to let httpx handle
# For multipart/form-data, if any files are present (including placeholder files),
# we must remove any manually set Content-Type header. This is because httpx needs to
# automatically set the Content-Type and boundary for multipart encoding whenever files
# are included, even if they are placeholders, to avoid boundary issues and ensure correct
# file upload behaviour. Manually setting Content-Type can cause httpx to fail to set the
# boundary, resulting in invalid requests.
if self.files:
# Remove Content-Type if it was manually set to avoid boundary issues
headers = {k: v for k, v in headers.items() if k.lower() != "content-type"}
else:
# No files at all, set Content-Type manually
if "content-type" not in (k.lower() for k in headers):
headers["Content-Type"] = "multipart/form-data"
elif body and body.type in BODY_TYPE_TO_CONTENT_TYPE:
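The rationale can be reproduced with httpx directly: when files are passed, httpx generates the multipart boundary itself, and a manually set bare Content-Type would suppress it (illustrative snippet, not part of this diff; the URL is hypothetical):

```python
import httpx

request = httpx.Request(
    "POST",
    "https://example.com/upload",  # hypothetical URL
    files={"file": ("a.txt", b"hello")},
)
# httpx sets e.g. "multipart/form-data; boundary=..." automatically
print(request.headers["content-type"])
```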

View File

@ -5,7 +5,7 @@ import logging
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.file import FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
@ -194,6 +194,17 @@ class LLMNode(BaseNode):
else []
)
# For single-step (debugger) runs with no previous node, fetch files from the sys.files system variable
if not files and self.invoke_from == InvokeFrom.DEBUGGER and not self.previous_node_id:
files = (
llm_utils.fetch_files(
variable_pool=variable_pool,
selector=["sys", "files"],
)
if self._node_data.vision.enabled
else []
)
if files:
node_inputs["#files#"] = [file.to_dict() for file in files]

View File

@ -1,4 +1,5 @@
from collections.abc import Mapping
from decimal import Decimal
from typing import Any
from pydantic import BaseModel
@ -17,6 +18,9 @@ class WorkflowRuntimeTypeConverter:
return value
if isinstance(value, (bool, int, str, float)):
return value
if isinstance(value, Decimal):
# Convert Decimal to float for JSON serialization
return float(value)
if isinstance(value, Segment):
return self._to_json_encodable_recursive(value.value)
if isinstance(value, File):
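A tiny sketch of the new branch, calling the recursive helper directly (a private method, used here only for illustration):

```python
from decimal import Decimal

converter = WorkflowRuntimeTypeConverter()
converter._to_json_encodable_recursive(Decimal("0.002"))  # -> 0.002 (a float, safe for json.dumps)
```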

View File

@ -32,7 +32,7 @@ if [[ "${MODE}" == "worker" ]]; then
exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \
--max-tasks-per-child ${MAX_TASK_PRE_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \
-Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin}
-Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin,workflow_storage}
elif [[ "${MODE}" == "beat" ]]; then
exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}

View File

@ -4,6 +4,7 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
from commands import (
add_qdrant_index,
cleanup_orphaned_draft_variables,
clear_free_plan_tenant_expired_logs,
clear_orphaned_file_records,
convert_to_agent_apps,
@ -42,6 +43,7 @@ def init_app(app: DifyApp):
clear_orphaned_file_records,
remove_orphaned_files_on_storage,
setup_system_tool_oauth_client,
cleanup_orphaned_draft_variables,
]
for cmd in cmds_to_register:
app.cli.add_command(cmd)

View File

@ -1,18 +1,23 @@
import functools
import logging
from collections.abc import Callable
from typing import Any, Union
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Union
import redis
from redis import RedisError
from redis.cache import CacheConfig
from redis.cluster import ClusterNode, RedisCluster
from redis.connection import Connection, SSLConnection
from redis.lock import Lock
from redis.sentinel import Sentinel
from configs import dify_config
from dify_app import DifyApp
if TYPE_CHECKING:
from redis.lock import Lock
logger = logging.getLogger(__name__)
@ -28,8 +33,8 @@ class RedisClientWrapper:
a failover in a Sentinel-managed Redis setup.
Attributes:
_client (redis.Redis): The actual Redis client instance. It remains None until
initialized with the `initialize` method.
_client: The actual Redis client instance. It remains None until
initialized with the `initialize` method.
Methods:
initialize(client): Initializes the Redis client if it hasn't been initialized already.
@ -37,20 +42,78 @@ class RedisClientWrapper:
if the client is not initialized.
"""
def __init__(self):
_client: Union[redis.Redis, RedisCluster, None]
def __init__(self) -> None:
self._client = None
def initialize(self, client):
def initialize(self, client: Union[redis.Redis, RedisCluster]) -> None:
if self._client is None:
self._client = client
def __getattr__(self, item):
if TYPE_CHECKING:
# Type hints for IDE support and static analysis
# These are not executed at runtime but provide type information
def get(self, name: str | bytes) -> Any: ...
def set(
self,
name: str | bytes,
value: Any,
ex: int | None = None,
px: int | None = None,
nx: bool = False,
xx: bool = False,
keepttl: bool = False,
get: bool = False,
exat: int | None = None,
pxat: int | None = None,
) -> Any: ...
def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any: ...
def setnx(self, name: str | bytes, value: Any) -> Any: ...
def delete(self, *names: str | bytes) -> Any: ...
def incr(self, name: str | bytes, amount: int = 1) -> Any: ...
def expire(
self,
name: str | bytes,
time: int | timedelta,
nx: bool = False,
xx: bool = False,
gt: bool = False,
lt: bool = False,
) -> Any: ...
def lock(
self,
name: str,
timeout: float | None = None,
sleep: float = 0.1,
blocking: bool = True,
blocking_timeout: float | None = None,
thread_local: bool = True,
) -> Lock: ...
def zadd(
self,
name: str | bytes,
mapping: dict[str | bytes | int | float, float | int | str | bytes],
nx: bool = False,
xx: bool = False,
ch: bool = False,
incr: bool = False,
gt: bool = False,
lt: bool = False,
) -> Any: ...
def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: ...
def zcard(self, name: str | bytes) -> Any: ...
def getdel(self, name: str | bytes) -> Any: ...
def __getattr__(self, item: str) -> Any:
if self._client is None:
raise RuntimeError("Redis client is not initialized. Call init_app first.")
return getattr(self._client, item)
redis_client = RedisClientWrapper()
redis_client: RedisClientWrapper = RedisClientWrapper()
def init_app(app: DifyApp):
@ -80,6 +143,9 @@ def init_app(app: DifyApp):
if dify_config.REDIS_USE_SENTINEL:
assert dify_config.REDIS_SENTINELS is not None, "REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True"
assert dify_config.REDIS_SENTINEL_SERVICE_NAME is not None, (
"REDIS_SENTINEL_SERVICE_NAME must be set when REDIS_USE_SENTINEL is True"
)
sentinel_hosts = [
(node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",")
]
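The TYPE_CHECKING stubs above exist purely for IDEs and static analysis; at runtime every attribute still resolves through __getattr__. A minimal sketch of the intended behavior, assuming a default local redis.Redis() connection purely for illustration:

import redis

wrapper = RedisClientWrapper()
try:
    wrapper.get("feature-flag")  # delegation fails before initialize()
except RuntimeError as exc:
    print(exc)  # "Redis client is not initialized. Call init_app first."

wrapper.initialize(redis.Redis())  # only the first initialize() takes effect
wrapper.set("feature-flag", "on", ex=600)
print(wrapper.get("feature-flag"))  # resolved at runtime via __getattr__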

View File

@ -248,6 +248,8 @@ def _get_remote_file_info(url: str):
# Initialize mime_type from filename as fallback
mime_type, _ = mimetypes.guess_type(filename)
if mime_type is None:
mime_type = ""
resp = ssrf_proxy.head(url, follow_redirects=True)
resp = cast(httpx.Response, resp)
@ -256,7 +258,12 @@ def _get_remote_file_info(url: str):
filename = str(content_disposition.split("filename=")[-1].strip('"'))
# Re-guess mime_type from updated filename
mime_type, _ = mimetypes.guess_type(filename)
if mime_type is None:
mime_type = ""
file_size = int(resp.headers.get("Content-Length", file_size))
# Fallback to Content-Type header if mime_type is still empty
if not mime_type:
mime_type = resp.headers.get("Content-Type", "").split(";")[0].strip()
return mime_type, filename, file_size
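Both hunks above normalize the None that mimetypes.guess_type returns for unknown names into an empty string, so the final Content-Type fallback can key off falsiness. A standalone sketch of that fallback order (resolve_mime_type is a hypothetical helper, not part of this diff):

import mimetypes

def resolve_mime_type(filename: str, content_type_header: str) -> str:
    # First guess from the filename; guess_type returns None when unknown.
    mime_type, _ = mimetypes.guess_type(filename)
    if mime_type is None:
        mime_type = ""
    # Fall back to the Content-Type header, dropping parameters such as "; charset=utf-8".
    if not mime_type:
        mime_type = content_type_header.split(";")[0].strip()
    return mime_type

print(resolve_mime_type("report.pdf", ""))                       # application/pdf
print(resolve_mime_type("download", "text/csv; charset=utf-8"))  # text/csv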

View File

@ -0,0 +1,33 @@
"""add timeout for tool_mcp_providers
Revision ID: fa8b0fa6f407
Revises: 532b3f888abf
Create Date: 2025-08-07 11:15:31.517985
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'fa8b0fa6f407'
down_revision = '532b3f888abf'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('timeout', sa.Float(), server_default=sa.text('30'), nullable=False))
batch_op.add_column(sa.Column('sse_read_timeout', sa.Float(), server_default=sa.text('300'), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op:
batch_op.drop_column('sse_read_timeout')
batch_op.drop_column('timeout')
# ### end Alembic commands ###

View File

@ -278,6 +278,8 @@ class MCPToolProvider(Base):
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
)
timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("30"))
sse_read_timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("300"))
def load_user(self) -> Account | None:
return db.session.query(Account).where(Account.id == self.user_id).first()

View File

@ -162,6 +162,7 @@ dev = [
"pandas-stubs~=2.2.3",
"scipy-stubs>=1.15.3.0",
"types-python-http-client>=3.3.7.20240910",
"types-redis>=4.6.0.20241004",
]
############################################################

View File

@ -1,5 +1,6 @@
import datetime
import time
from typing import Optional, TypedDict
import click
from sqlalchemy import func, select
@ -14,168 +15,140 @@ from models.dataset import Dataset, DatasetAutoDisableLog, DatasetQuery, Documen
from services.feature_service import FeatureService
class CleanupConfig(TypedDict):
clean_day: datetime.datetime
plan_filter: Optional[str]
add_logs: bool
@app.celery.task(queue="dataset")
def clean_unused_datasets_task():
click.echo(click.style("Start clean unused datasets indexes.", fg="green"))
plan_sandbox_clean_day_setting = dify_config.PLAN_SANDBOX_CLEAN_DAY_SETTING
plan_pro_clean_day_setting = dify_config.PLAN_PRO_CLEAN_DAY_SETTING
start_at = time.perf_counter()
plan_sandbox_clean_day = datetime.datetime.now() - datetime.timedelta(days=plan_sandbox_clean_day_setting)
plan_pro_clean_day = datetime.datetime.now() - datetime.timedelta(days=plan_pro_clean_day_setting)
while True:
try:
# Subquery for counting new documents
document_subquery_new = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
.where(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
Document.updated_at > plan_sandbox_clean_day,
)
.group_by(Document.dataset_id)
.subquery()
)
# Subquery for counting old documents
document_subquery_old = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
.where(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
Document.updated_at < plan_sandbox_clean_day,
)
.group_by(Document.dataset_id)
.subquery()
)
# Define cleanup configurations
cleanup_configs: list[CleanupConfig] = [
{
"clean_day": datetime.datetime.now() - datetime.timedelta(days=dify_config.PLAN_SANDBOX_CLEAN_DAY_SETTING),
"plan_filter": None,
"add_logs": True,
},
{
"clean_day": datetime.datetime.now() - datetime.timedelta(days=dify_config.PLAN_PRO_CLEAN_DAY_SETTING),
"plan_filter": "sandbox",
"add_logs": False,
},
]
# Main query with join and filter
stmt = (
select(Dataset)
.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
.outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
.where(
Dataset.created_at < plan_sandbox_clean_day,
func.coalesce(document_subquery_new.c.document_count, 0) == 0,
func.coalesce(document_subquery_old.c.document_count, 0) > 0,
)
.order_by(Dataset.created_at.desc())
)
for config in cleanup_configs:
clean_day = config["clean_day"]
plan_filter = config["plan_filter"]
add_logs = config["add_logs"]
datasets = db.paginate(stmt, page=1, per_page=50)
except SQLAlchemyError:
raise
if datasets.items is None or len(datasets.items) == 0:
break
for dataset in datasets:
dataset_query = (
db.session.query(DatasetQuery)
.where(DatasetQuery.created_at > plan_sandbox_clean_day, DatasetQuery.dataset_id == dataset.id)
.all()
)
if not dataset_query or len(dataset_query) == 0:
try:
# add auto disable log
documents = (
db.session.query(Document)
.where(
Document.dataset_id == dataset.id,
Document.enabled == True,
Document.archived == False,
)
.all()
while True:
try:
# Subquery for counting new documents
document_subquery_new = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
.where(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
Document.updated_at > clean_day,
)
for document in documents:
dataset_auto_disable_log = DatasetAutoDisableLog(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
)
db.session.add(dataset_auto_disable_log)
# remove index
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
index_processor.clean(dataset, None)
# update document
db.session.query(Document).filter_by(dataset_id=dataset.id).update({Document.enabled: False})
db.session.commit()
click.echo(click.style(f"Cleaned unused dataset {dataset.id} from db success!", fg="green"))
except Exception as e:
click.echo(click.style(f"clean dataset index error: {e.__class__.__name__} {str(e)}", fg="red"))
while True:
try:
# Subquery for counting new documents
document_subquery_new = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
.where(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
Document.updated_at > plan_pro_clean_day,
.group_by(Document.dataset_id)
.subquery()
)
.group_by(Document.dataset_id)
.subquery()
)
# Subquery for counting old documents
document_subquery_old = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
.where(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
Document.updated_at < plan_pro_clean_day,
# Subquery for counting old documents
document_subquery_old = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
.where(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
Document.updated_at < clean_day,
)
.group_by(Document.dataset_id)
.subquery()
)
.group_by(Document.dataset_id)
.subquery()
)
# Main query with join and filter
stmt = (
select(Dataset)
.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
.outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
.where(
Dataset.created_at < plan_pro_clean_day,
func.coalesce(document_subquery_new.c.document_count, 0) == 0,
func.coalesce(document_subquery_old.c.document_count, 0) > 0,
# Main query with join and filter
stmt = (
select(Dataset)
.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
.outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
.where(
Dataset.created_at < clean_day,
func.coalesce(document_subquery_new.c.document_count, 0) == 0,
func.coalesce(document_subquery_old.c.document_count, 0) > 0,
)
.order_by(Dataset.created_at.desc())
)
.order_by(Dataset.created_at.desc())
)
datasets = db.paginate(stmt, page=1, per_page=50)
except SQLAlchemyError:
raise
if datasets.items is None or len(datasets.items) == 0:
break
for dataset in datasets:
dataset_query = (
db.session.query(DatasetQuery)
.where(DatasetQuery.created_at > plan_pro_clean_day, DatasetQuery.dataset_id == dataset.id)
.all()
)
if not dataset_query or len(dataset_query) == 0:
try:
features_cache_key = f"features:{dataset.tenant_id}"
plan_cache = redis_client.get(features_cache_key)
if plan_cache is None:
features = FeatureService.get_features(dataset.tenant_id)
redis_client.setex(features_cache_key, 600, features.billing.subscription.plan)
plan = features.billing.subscription.plan
else:
plan = plan_cache.decode()
if plan == "sandbox":
# remove index
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
index_processor.clean(dataset, None)
datasets = db.paginate(stmt, page=1, per_page=50)
except SQLAlchemyError:
raise
if datasets.items is None or len(datasets.items) == 0:
break
for dataset in datasets:
dataset_query = (
db.session.query(DatasetQuery)
.where(DatasetQuery.created_at > clean_day, DatasetQuery.dataset_id == dataset.id)
.all()
)
if not dataset_query or len(dataset_query) == 0:
try:
should_clean = True
# Check plan filter if specified
if plan_filter:
features_cache_key = f"features:{dataset.tenant_id}"
plan_cache = redis_client.get(features_cache_key)
if plan_cache is None:
features = FeatureService.get_features(dataset.tenant_id)
redis_client.setex(features_cache_key, 600, features.billing.subscription.plan)
plan = features.billing.subscription.plan
else:
plan = plan_cache.decode()
should_clean = plan == plan_filter
if should_clean:
# Add auto disable log if required
if add_logs:
documents = (
db.session.query(Document)
.where(
Document.dataset_id == dataset.id,
Document.enabled == True,
Document.archived == False,
)
.all()
)
for document in documents:
dataset_auto_disable_log = DatasetAutoDisableLog(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
)
db.session.add(dataset_auto_disable_log)
# Remove index
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
index_processor.clean(dataset, None)
# Update document
db.session.query(Document).filter_by(dataset_id=dataset.id).update(
{Document.enabled: False}
)
db.session.commit()
click.echo(click.style(f"Cleaned unused dataset {dataset.id} from db success!", fg="green"))
except Exception as e:
click.echo(click.style(f"clean dataset index error: {e.__class__.__name__} {str(e)}", fg="red"))
# update document
db.session.query(Document).filter_by(dataset_id=dataset.id).update({Document.enabled: False})
db.session.commit()
click.echo(click.style(f"Cleaned unused dataset {dataset.id} from db success!", fg="green"))
except Exception as e:
click.echo(click.style(f"clean dataset index error: {e.__class__.__name__} {str(e)}", fg="red"))
end_at = time.perf_counter()
click.echo(click.style(f"Cleaned unused dataset from db success latency: {end_at - start_at}", fg="green"))

View File

@ -24,9 +24,20 @@ def queue_monitor_task():
queue_name = "dataset"
threshold = dify_config.QUEUE_MONITOR_THRESHOLD
if threshold is None:
logging.warning(click.style("QUEUE_MONITOR_THRESHOLD is not configured, skipping monitoring", fg="yellow"))
return
try:
queue_length = celery_redis.llen(f"{queue_name}")
logging.info(click.style(f"Start monitor {queue_name}", fg="green"))
if queue_length is None:
logging.error(
click.style(f"Failed to get queue length for {queue_name} - Redis may be unavailable", fg="red")
)
return
logging.info(click.style(f"Queue length: {queue_length}", fg="green"))
if queue_length >= threshold:

View File

@ -59,6 +59,8 @@ class MCPToolManageService:
icon_type: str,
icon_background: str,
server_identifier: str,
timeout: float,
sse_read_timeout: float,
) -> ToolProviderApiEntity:
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
existing_provider = (
@ -91,6 +93,8 @@ class MCPToolManageService:
tools="[]",
icon=json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon,
server_identifier=server_identifier,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
)
db.session.add(mcp_tool)
db.session.commit()
@ -166,6 +170,8 @@ class MCPToolManageService:
icon_type: str,
icon_background: str,
server_identifier: str,
timeout: float | None = None,
sse_read_timeout: float | None = None,
):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
@ -197,6 +203,10 @@ class MCPToolManageService:
mcp_provider.tools = reconnect_result["tools"]
mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"]
if timeout is not None:
mcp_provider.timeout = timeout
if sse_read_timeout is not None:
mcp_provider.sse_read_timeout = sse_read_timeout
db.session.commit()
except IntegrityError as e:
db.session.rollback()

View File

@ -33,7 +33,11 @@ from models import (
)
from models.tools import WorkflowToolProvider
from models.web import PinnedConversation, SavedMessage
from models.workflow import ConversationVariable, Workflow, WorkflowAppLog
from models.workflow import (
ConversationVariable,
Workflow,
WorkflowAppLog,
)
from repositories.factory import DifyAPIRepositoryFactory
@ -62,6 +66,7 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
_delete_end_users(tenant_id, app_id)
_delete_trace_app_configs(tenant_id, app_id)
_delete_conversation_variables(app_id=app_id)
_delete_draft_variables(app_id)
end_at = time.perf_counter()
logging.info(click.style(f"App and related data deleted: {app_id} latency: {end_at - start_at}", fg="green"))
@ -91,7 +96,12 @@ def _delete_app_site(tenant_id: str, app_id: str):
def del_site(site_id: str):
db.session.query(Site).where(Site.id == site_id).delete(synchronize_session=False)
_delete_records("""select id from sites where app_id=:app_id limit 1000""", {"app_id": app_id}, del_site, "site")
_delete_records(
"""select id from sites where app_id=:app_id limit 1000""",
{"app_id": app_id},
del_site,
"site",
)
def _delete_app_mcp_servers(tenant_id: str, app_id: str):
@ -111,7 +121,10 @@ def _delete_app_api_tokens(tenant_id: str, app_id: str):
db.session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False)
_delete_records(
"""select id from api_tokens where app_id=:app_id limit 1000""", {"app_id": app_id}, del_api_token, "api token"
"""select id from api_tokens where app_id=:app_id limit 1000""",
{"app_id": app_id},
del_api_token,
"api token",
)
@ -273,7 +286,10 @@ def _delete_app_messages(tenant_id: str, app_id: str):
db.session.query(Message).where(Message.id == message_id).delete()
_delete_records(
"""select id from messages where app_id=:app_id limit 1000""", {"app_id": app_id}, del_message, "message"
"""select id from messages where app_id=:app_id limit 1000""",
{"app_id": app_id},
del_message,
"message",
)
@ -329,6 +345,56 @@ def _delete_trace_app_configs(tenant_id: str, app_id: str):
)
def _delete_draft_variables(app_id: str):
"""Delete all workflow draft variables for an app in batches."""
return delete_draft_variables_batch(app_id, batch_size=1000)
def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
"""
Delete draft variables for an app in batches.
Args:
app_id: The ID of the app whose draft variables should be deleted
batch_size: Number of records to delete per batch
Returns:
Total number of records deleted
"""
if batch_size <= 0:
raise ValueError("batch_size must be positive")
total_deleted = 0
while True:
with db.engine.begin() as conn:
# Get a batch of draft variable IDs
query_sql = """
SELECT id FROM workflow_draft_variables
WHERE app_id = :app_id
LIMIT :batch_size
"""
result = conn.execute(sa.text(query_sql), {"app_id": app_id, "batch_size": batch_size})
draft_var_ids = [row[0] for row in result]
if not draft_var_ids:
break
# Delete the batch
delete_sql = """
DELETE FROM workflow_draft_variables
WHERE id IN :ids
"""
deleted_result = conn.execute(sa.text(delete_sql), {"ids": tuple(draft_var_ids)})
batch_deleted = deleted_result.rowcount
total_deleted += batch_deleted
logging.info(click.style(f"Deleted {batch_deleted} draft variables (batch) for app {app_id}", fg="green"))
logging.info(click.style(f"Deleted {total_deleted} total draft variables for app {app_id}", fg="green"))
return total_deleted
def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None:
while True:
with db.engine.begin() as conn:

View File

@ -0,0 +1,136 @@
"""
Celery tasks for asynchronous workflow execution storage operations.
These tasks provide asynchronous storage capabilities for workflow execution data,
improving performance by offloading storage operations to background workers.
"""
import json
import logging
from celery import shared_task # type: ignore[import-untyped]
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from models import CreatorUserRole, WorkflowRun
from models.enums import WorkflowRunTriggeredFrom
logger = logging.getLogger(__name__)
@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60)
def save_workflow_execution_task(
self,
execution_data: dict,
tenant_id: str,
app_id: str,
triggered_from: str,
creator_user_id: str,
creator_user_role: str,
) -> bool:
"""
Asynchronously save or update a workflow execution to the database.
Args:
execution_data: Serialized WorkflowExecution data
tenant_id: Tenant ID for multi-tenancy
app_id: Application ID
triggered_from: Source of the execution trigger
creator_user_id: ID of the user who created the execution
creator_user_role: Role of the user who created the execution
Returns:
True on success; on failure the task is retried with exponential backoff
"""
try:
# Create a new session for this task
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
with session_factory() as session:
# Deserialize execution data
execution = WorkflowExecution.model_validate(execution_data)
# Check if workflow run already exists
existing_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == execution.id_))
if existing_run:
# Update existing workflow run
_update_workflow_run_from_execution(existing_run, execution)
logger.debug("Updated existing workflow run: %s", execution.id_)
else:
# Create new workflow run
workflow_run = _create_workflow_run_from_execution(
execution=execution,
tenant_id=tenant_id,
app_id=app_id,
triggered_from=WorkflowRunTriggeredFrom(triggered_from),
creator_user_id=creator_user_id,
creator_user_role=CreatorUserRole(creator_user_role),
)
session.add(workflow_run)
logger.debug("Created new workflow run: %s", execution.id_)
session.commit()
return True
except Exception as e:
logger.exception("Failed to save workflow execution %s", execution_data.get("id_", "unknown"))
# Retry the task with exponential backoff
raise self.retry(exc=e, countdown=60 * (2**self.request.retries))
def _create_workflow_run_from_execution(
execution: WorkflowExecution,
tenant_id: str,
app_id: str,
triggered_from: WorkflowRunTriggeredFrom,
creator_user_id: str,
creator_user_role: CreatorUserRole,
) -> WorkflowRun:
"""
Create a WorkflowRun database model from a WorkflowExecution domain entity.
"""
workflow_run = WorkflowRun()
workflow_run.id = execution.id_
workflow_run.tenant_id = tenant_id
workflow_run.app_id = app_id
workflow_run.workflow_id = execution.workflow_id
workflow_run.type = execution.workflow_type.value
workflow_run.triggered_from = triggered_from.value
workflow_run.version = execution.workflow_version
json_converter = WorkflowRuntimeTypeConverter()
workflow_run.graph = json.dumps(json_converter.to_json_encodable(execution.graph))
workflow_run.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs))
workflow_run.status = execution.status.value
workflow_run.outputs = (
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
)
workflow_run.error = execution.error_message
workflow_run.elapsed_time = execution.elapsed_time
workflow_run.total_tokens = execution.total_tokens
workflow_run.total_steps = execution.total_steps
workflow_run.created_by_role = creator_user_role.value
workflow_run.created_by = creator_user_id
workflow_run.created_at = execution.started_at
workflow_run.finished_at = execution.finished_at
return workflow_run
def _update_workflow_run_from_execution(workflow_run: WorkflowRun, execution: WorkflowExecution) -> None:
"""
Update a WorkflowRun database model from a WorkflowExecution domain entity.
"""
json_converter = WorkflowRuntimeTypeConverter()
workflow_run.status = execution.status.value
workflow_run.outputs = (
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
)
workflow_run.error = execution.error_message
workflow_run.elapsed_time = execution.elapsed_time
workflow_run.total_tokens = execution.total_tokens
workflow_run.total_steps = execution.total_steps
workflow_run.finished_at = execution.finished_at
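For the producer side, a hedged sketch of how a repository might enqueue this task; the execution/tenant variables are assumed context, the APP_RUN and ACCOUNT enum values are illustrative choices, and model_dump(mode="json") and .delay() are standard Pydantic v2 and Celery calls:

def enqueue_save_workflow_execution(
    execution: WorkflowExecution,
    tenant_id: str,
    app_id: str,
    creator_user_id: str,
) -> None:
    # Hypothetical helper, not part of this diff: serialize the Pydantic
    # entity and hand the write to the workflow_storage queue so the caller
    # never blocks on the database.
    save_workflow_execution_task.delay(
        execution_data=execution.model_dump(mode="json"),
        tenant_id=tenant_id,
        app_id=app_id,
        triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value,
        creator_user_id=creator_user_id,
        creator_user_role=CreatorUserRole.ACCOUNT.value,
    )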

View File

@ -0,0 +1,171 @@
"""
Celery tasks for asynchronous workflow node execution storage operations.
These tasks provide asynchronous storage capabilities for workflow node execution data,
improving performance by offloading storage operations to background workers.
"""
import json
import logging
from celery import shared_task # type: ignore[import-untyped]
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
)
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from models import CreatorUserRole, WorkflowNodeExecutionModel
from models.workflow import WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)
@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60)
def save_workflow_node_execution_task(
self,
execution_data: dict,
tenant_id: str,
app_id: str,
triggered_from: str,
creator_user_id: str,
creator_user_role: str,
) -> bool:
"""
Asynchronously save or update a workflow node execution to the database.
Args:
execution_data: Serialized WorkflowNodeExecution data
tenant_id: Tenant ID for multi-tenancy
app_id: Application ID
triggered_from: Source of the execution trigger
creator_user_id: ID of the user who created the execution
creator_user_role: Role of the user who created the execution
Returns:
True on success; on failure the task is retried with exponential backoff
"""
try:
# Create a new session for this task
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
with session_factory() as session:
# Deserialize execution data
execution = WorkflowNodeExecution.model_validate(execution_data)
# Check if node execution already exists
existing_execution = session.scalar(
select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution.id)
)
if existing_execution:
# Update existing node execution
_update_node_execution_from_domain(existing_execution, execution)
logger.debug("Updated existing workflow node execution: %s", execution.id)
else:
# Create new node execution
node_execution = _create_node_execution_from_domain(
execution=execution,
tenant_id=tenant_id,
app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom(triggered_from),
creator_user_id=creator_user_id,
creator_user_role=CreatorUserRole(creator_user_role),
)
session.add(node_execution)
logger.debug("Created new workflow node execution: %s", execution.id)
session.commit()
return True
except Exception as e:
logger.exception("Failed to save workflow node execution %s", execution_data.get("id", "unknown"))
# Retry the task with exponential backoff
raise self.retry(exc=e, countdown=60 * (2**self.request.retries))
def _create_node_execution_from_domain(
execution: WorkflowNodeExecution,
tenant_id: str,
app_id: str,
triggered_from: WorkflowNodeExecutionTriggeredFrom,
creator_user_id: str,
creator_user_role: CreatorUserRole,
) -> WorkflowNodeExecutionModel:
"""
Create a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity.
"""
node_execution = WorkflowNodeExecutionModel()
node_execution.id = execution.id
node_execution.tenant_id = tenant_id
node_execution.app_id = app_id
node_execution.workflow_id = execution.workflow_id
node_execution.triggered_from = triggered_from.value
node_execution.workflow_run_id = execution.workflow_execution_id
node_execution.index = execution.index
node_execution.predecessor_node_id = execution.predecessor_node_id
node_execution.node_id = execution.node_id
node_execution.node_type = execution.node_type.value
node_execution.title = execution.title
node_execution.node_execution_id = execution.node_execution_id
# Serialize complex data as JSON
json_converter = WorkflowRuntimeTypeConverter()
node_execution.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) if execution.inputs else "{}"
node_execution.process_data = (
json.dumps(json_converter.to_json_encodable(execution.process_data)) if execution.process_data else "{}"
)
node_execution.outputs = (
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
)
# Convert metadata enum keys to strings for JSON serialization
if execution.metadata:
metadata_for_json = {
key.value if hasattr(key, "value") else str(key): value for key, value in execution.metadata.items()
}
node_execution.execution_metadata = json.dumps(json_converter.to_json_encodable(metadata_for_json))
else:
node_execution.execution_metadata = "{}"
node_execution.status = execution.status.value
node_execution.error = execution.error
node_execution.elapsed_time = execution.elapsed_time
node_execution.created_by_role = creator_user_role.value
node_execution.created_by = creator_user_id
node_execution.created_at = execution.created_at
node_execution.finished_at = execution.finished_at
return node_execution
def _update_node_execution_from_domain(
node_execution: WorkflowNodeExecutionModel, execution: WorkflowNodeExecution
) -> None:
"""
Update a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity.
"""
# Update serialized data
json_converter = WorkflowRuntimeTypeConverter()
node_execution.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) if execution.inputs else "{}"
node_execution.process_data = (
json.dumps(json_converter.to_json_encodable(execution.process_data)) if execution.process_data else "{}"
)
node_execution.outputs = (
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
)
# Convert metadata enum keys to strings for JSON serialization
if execution.metadata:
metadata_for_json = {
key.value if hasattr(key, "value") else str(key): value for key, value in execution.metadata.items()
}
node_execution.execution_metadata = json.dumps(json_converter.to_json_encodable(metadata_for_json))
else:
node_execution.execution_metadata = "{}"
# Update other fields
node_execution.status = execution.status.value
node_execution.error = execution.error
node_execution.elapsed_time = execution.elapsed_time
node_execution.finished_at = execution.finished_at

View File

@ -0,0 +1,214 @@
import uuid
import pytest
from sqlalchemy import delete
from core.variables.segments import StringSegment
from models import Tenant, db
from models.model import App
from models.workflow import WorkflowDraftVariable
from tasks.remove_app_and_related_data_task import _delete_draft_variables, delete_draft_variables_batch
@pytest.fixture
def app_and_tenant(flask_req_ctx):
tenant_id = uuid.uuid4()
tenant = Tenant(
id=tenant_id,
name="test_tenant",
)
db.session.add(tenant)
app = App(
tenant_id=tenant_id, # tenant.id was assigned explicitly above
name=f"Test App for tenant {tenant.id}",
mode="workflow",
enable_site=True,
enable_api=True,
)
db.session.add(app)
db.session.flush()
yield (tenant, app)
# Cleanup created test records
db.session.delete(app)
db.session.delete(tenant)
class TestDeleteDraftVariablesIntegration:
@pytest.fixture
def setup_test_data(self, app_and_tenant):
"""Create test data with apps and draft variables."""
tenant, app = app_and_tenant
# Create a second app for testing
app2 = App(
tenant_id=tenant.id,
name="Test App 2",
mode="workflow",
enable_site=True,
enable_api=True,
)
db.session.add(app2)
db.session.commit()
# Create draft variables for both apps
variables_app1 = []
variables_app2 = []
for i in range(5):
var1 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
)
db.session.add(var1)
variables_app1.append(var1)
var2 = WorkflowDraftVariable.new_node_variable(
app_id=app2.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
)
db.session.add(var2)
variables_app2.append(var2)
# Commit all the variables to the database
db.session.commit()
yield {
"app1": app,
"app2": app2,
"tenant": tenant,
"variables_app1": variables_app1,
"variables_app2": variables_app2,
}
# Cleanup - refresh session and check if objects still exist
db.session.rollback() # Clear any pending changes
# Clean up remaining variables
cleanup_query = (
delete(WorkflowDraftVariable)
.where(
WorkflowDraftVariable.app_id.in_([app.id, app2.id]),
)
.execution_options(synchronize_session=False)
)
db.session.execute(cleanup_query)
# Clean up app2
app2_obj = db.session.get(App, app2.id)
if app2_obj:
db.session.delete(app2_obj)
db.session.commit()
def test_delete_draft_variables_batch_removes_correct_variables(self, setup_test_data):
"""Test that batch deletion only removes variables for the specified app."""
data = setup_test_data
app1_id = data["app1"].id
app2_id = data["app2"].id
# Verify initial state
app1_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
app2_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count()
assert app1_vars_before == 5
assert app2_vars_before == 5
# Delete app1 variables
deleted_count = delete_draft_variables_batch(app1_id, batch_size=10)
# Verify results
assert deleted_count == 5
app1_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
app2_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count()
assert app1_vars_after == 0 # All app1 variables deleted
assert app2_vars_after == 5 # App2 variables unchanged
def test_delete_draft_variables_batch_with_small_batch_size(self, setup_test_data):
"""Test batch deletion with small batch size processes all records."""
data = setup_test_data
app1_id = data["app1"].id
# Use small batch size to force multiple batches
deleted_count = delete_draft_variables_batch(app1_id, batch_size=2)
assert deleted_count == 5
# Verify all variables are deleted
remaining_vars = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
assert remaining_vars == 0
def test_delete_draft_variables_batch_nonexistent_app(self, setup_test_data):
"""Test that deleting variables for nonexistent app returns 0."""
nonexistent_app_id = str(uuid.uuid4()) # Use a valid UUID format
deleted_count = delete_draft_variables_batch(nonexistent_app_id, batch_size=100)
assert deleted_count == 0
def test_delete_draft_variables_wrapper_function(self, setup_test_data):
"""Test that _delete_draft_variables wrapper function works correctly."""
data = setup_test_data
app1_id = data["app1"].id
# Verify initial state
vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
assert vars_before == 5
# Call wrapper function
deleted_count = _delete_draft_variables(app1_id)
# Verify results
assert deleted_count == 5
vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
assert vars_after == 0
def test_batch_deletion_handles_large_dataset(self, app_and_tenant):
"""Test batch deletion with larger dataset to verify batching logic."""
tenant, app = app_and_tenant
# Create many draft variables
variables = []
for i in range(25):
var = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
)
db.session.add(var)
variables.append(var)
variable_ids = [i.id for i in variables]
# Commit the variables to the database
db.session.commit()
try:
# Use small batch size to force multiple batches
deleted_count = delete_draft_variables_batch(app.id, batch_size=8)
assert deleted_count == 25
# Verify all variables are deleted
remaining_vars = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).count()
assert remaining_vars == 0
finally:
query = (
delete(WorkflowDraftVariable)
.where(
WorkflowDraftVariable.id.in_(variable_ids),
)
.execution_options(synchronize_session=False)
)
db.session.execute(query)

View File

@ -160,6 +160,177 @@ def test_custom_authorization_header(setup_http_mock):
assert "?A=b" in data
assert "X-Header: 123" in data
# Custom authorization header should be set (may be masked)
assert "X-Auth:" in data
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_custom_auth_with_empty_api_key_does_not_set_header(setup_http_mock):
"""Test: In custom authentication mode, when the api_key is empty, no header should be set."""
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.http_request.entities import (
HttpRequestNodeAuthorization,
HttpRequestNodeData,
HttpRequestNodeTimeout,
)
from core.workflow.nodes.http_request.executor import Executor
from core.workflow.system_variable import SystemVariable
# Create variable pool
variable_pool = VariablePool(
system_variables=SystemVariable(user_id="test", files=[]),
user_inputs={},
environment_variables=[],
conversation_variables=[],
)
# Create node data with custom auth and empty api_key
node_data = HttpRequestNodeData(
title="http",
desc="",
url="http://example.com",
method="get",
authorization=HttpRequestNodeAuthorization(
type="api-key",
config={
"type": "custom",
"api_key": "", # Empty api_key
"header": "X-Custom-Auth",
},
),
headers="",
params="",
body=None,
ssl_verify=True,
)
# Create executor
executor = Executor(
node_data=node_data, timeout=HttpRequestNodeTimeout(connect=10, read=30, write=10), variable_pool=variable_pool
)
# Get assembled headers
headers = executor._assembling_headers()
# When api_key is empty, the custom header should NOT be set
assert "X-Custom-Auth" not in headers
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_bearer_authorization_with_custom_header_ignored(setup_http_mock):
"""
Test that when switching from custom to bearer authorization,
the custom header settings don't interfere with the bearer token.
This test verifies the fix for issue #23554.
"""
node = init_http_node(
config={
"id": "1",
"data": {
"title": "http",
"desc": "",
"method": "get",
"url": "http://example.com",
"authorization": {
"type": "api-key",
"config": {
"type": "bearer",
"api_key": "test-token",
"header": "", # Empty header - should default to Authorization
},
},
"headers": "",
"params": "",
"body": None,
},
}
)
result = node._run()
assert result.process_data is not None
data = result.process_data.get("request", "")
# In bearer mode, should use Authorization header (value is masked with *)
assert "Authorization: " in data
# Should contain masked Bearer token
assert "*" in data
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_basic_authorization_with_custom_header_ignored(setup_http_mock):
"""
Test that when switching from custom to basic authorization,
the custom header settings don't interfere with basic auth.
This test verifies the fix for issue #23554.
"""
node = init_http_node(
config={
"id": "1",
"data": {
"title": "http",
"desc": "",
"method": "get",
"url": "http://example.com",
"authorization": {
"type": "api-key",
"config": {
"type": "basic",
"api_key": "user:pass",
"header": "", # Empty header - should default to Authorization
},
},
"headers": "",
"params": "",
"body": None,
},
}
)
result = node._run()
assert result.process_data is not None
data = result.process_data.get("request", "")
# In basic mode, should use Authorization header (value is masked with *)
assert "Authorization: " in data
# Should contain masked Basic credentials
assert "*" in data
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_custom_authorization_with_empty_api_key(setup_http_mock):
"""
Test that custom authorization doesn't set header when api_key is empty.
This test verifies the fix for issue #23554.
"""
node = init_http_node(
config={
"id": "1",
"data": {
"title": "http",
"desc": "",
"method": "get",
"url": "http://example.com",
"authorization": {
"type": "api-key",
"config": {
"type": "custom",
"api_key": "", # Empty api_key
"header": "X-Custom-Auth",
},
},
"headers": "",
"params": "",
"body": None,
},
}
)
result = node._run()
assert result.process_data is not None
data = result.process_data.get("request", "")
# Custom header should NOT be set when api_key is empty
assert "X-Custom-Auth:" not in data
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
@ -239,6 +410,7 @@ def test_json(setup_http_mock):
assert "X-Header: 123" in data
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_x_www_form_urlencoded(setup_http_mock):
node = init_http_node(
config={
@ -285,6 +457,7 @@ def test_x_www_form_urlencoded(setup_http_mock):
assert "X-Header: 123" in data
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_form_data(setup_http_mock):
node = init_http_node(
config={
@ -334,6 +507,7 @@ def test_form_data(setup_http_mock):
assert "X-Header: 123" in data
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_none_data(setup_http_mock):
node = init_http_node(
config={
@ -366,6 +540,7 @@ def test_none_data(setup_http_mock):
assert "123123123" not in data
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_mock_404(setup_http_mock):
node = init_http_node(
config={
@ -394,6 +569,7 @@ def test_mock_404(setup_http_mock):
assert "Not Found" in resp.get("body", "")
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_multi_colons_parse(setup_http_mock):
node = init_http_node(
config={

View File

@ -0,0 +1,885 @@
import copy
import pytest
from faker import Faker
from core.prompt.prompt_templates.advanced_prompt_templates import (
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG,
BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG,
BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
BAICHUAN_CONTEXT,
CHAT_APP_CHAT_PROMPT_CONFIG,
CHAT_APP_COMPLETION_PROMPT_CONFIG,
COMPLETION_APP_CHAT_PROMPT_CONFIG,
COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
CONTEXT,
)
from models.model import AppMode
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
class TestAdvancedPromptTemplateService:
"""Integration tests for AdvancedPromptTemplateService using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
# This service doesn't have external dependencies, but we keep the pattern
# for consistency with other test files
return {}
def test_get_prompt_baichuan_model_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful prompt generation for Baichuan model.
This test verifies:
- Proper prompt generation for Baichuan models
- Correct model detection logic
- Appropriate prompt template selection
"""
fake = Faker()
# Test data for Baichuan model
args = {
"app_mode": AppMode.CHAT.value,
"model_mode": "completion",
"model_name": "baichuan-13b-chat",
"has_context": "true",
}
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_prompt(args)
# Assert: Verify the expected outcomes
assert result is not None
assert "completion_prompt_config" in result
assert "prompt" in result["completion_prompt_config"]
assert "text" in result["completion_prompt_config"]["prompt"]
# Verify context is included for Baichuan model
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
assert BAICHUAN_CONTEXT in prompt_text
assert "{{#pre_prompt#}}" in prompt_text
assert "{{#histories#}}" in prompt_text
assert "{{#query#}}" in prompt_text
def test_get_prompt_common_model_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful prompt generation for common models.
This test verifies:
- Proper prompt generation for non-Baichuan models
- Correct model detection logic
- Appropriate prompt template selection
"""
fake = Faker()
# Test data for common model
args = {
"app_mode": AppMode.CHAT.value,
"model_mode": "completion",
"model_name": "gpt-3.5-turbo",
"has_context": "true",
}
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_prompt(args)
# Assert: Verify the expected outcomes
assert result is not None
assert "completion_prompt_config" in result
assert "prompt" in result["completion_prompt_config"]
assert "text" in result["completion_prompt_config"]["prompt"]
# Verify context is included for common model
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
assert CONTEXT in prompt_text
assert "{{#pre_prompt#}}" in prompt_text
assert "{{#histories#}}" in prompt_text
assert "{{#query#}}" in prompt_text
def test_get_prompt_case_insensitive_baichuan_detection(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test Baichuan model detection is case insensitive.
This test verifies:
- Model name detection works regardless of case
- Proper prompt template selection for different case variations
"""
fake = Faker()
# Test different case variations
test_cases = ["Baichuan-13B-Chat", "BAICHUAN-13B-CHAT", "baichuan-13b-chat", "BaiChuan-13B-Chat"]
for model_name in test_cases:
args = {
"app_mode": AppMode.CHAT.value,
"model_mode": "completion",
"model_name": model_name,
"has_context": "true",
}
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_prompt(args)
# Assert: Verify Baichuan template is used
assert result is not None
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
assert BAICHUAN_CONTEXT in prompt_text
def test_get_common_prompt_chat_app_completion_mode(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test common prompt generation for chat app with completion mode.
This test verifies:
- Correct prompt template selection for chat app + completion mode
- Proper context integration
- Template structure validation
"""
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "completion", "true")
# Assert: Verify the expected outcomes
assert result is not None
assert "completion_prompt_config" in result
assert "prompt" in result["completion_prompt_config"]
assert "text" in result["completion_prompt_config"]["prompt"]
assert "conversation_histories_role" in result["completion_prompt_config"]
assert "stop" in result
# Verify context is included
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
assert CONTEXT in prompt_text
assert "{{#pre_prompt#}}" in prompt_text
assert "{{#histories#}}" in prompt_text
assert "{{#query#}}" in prompt_text
def test_get_common_prompt_chat_app_chat_mode(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test common prompt generation for chat app with chat mode.
This test verifies:
- Correct prompt template selection for chat app + chat mode
- Proper context integration
- Template structure validation
"""
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "chat", "true")
# Assert: Verify the expected outcomes
assert result is not None
assert "chat_prompt_config" in result
assert "prompt" in result["chat_prompt_config"]
assert len(result["chat_prompt_config"]["prompt"]) > 0
assert "role" in result["chat_prompt_config"]["prompt"][0]
assert "text" in result["chat_prompt_config"]["prompt"][0]
# Verify context is included
prompt_text = result["chat_prompt_config"]["prompt"][0]["text"]
assert CONTEXT in prompt_text
assert "{{#pre_prompt#}}" in prompt_text
def test_get_common_prompt_completion_app_completion_mode(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test common prompt generation for completion app with completion mode.
This test verifies:
- Correct prompt template selection for completion app + completion mode
- Proper context integration
- Template structure validation
"""
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION.value, "completion", "true")
# Assert: Verify the expected outcomes
assert result is not None
assert "completion_prompt_config" in result
assert "prompt" in result["completion_prompt_config"]
assert "text" in result["completion_prompt_config"]["prompt"]
assert "stop" in result
# Verify context is included
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
assert CONTEXT in prompt_text
assert "{{#pre_prompt#}}" in prompt_text
def test_get_common_prompt_completion_app_chat_mode(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test common prompt generation for completion app with chat mode.
This test verifies:
- Correct prompt template selection for completion app + chat mode
- Proper context integration
- Template structure validation
"""
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION.value, "chat", "true")
# Assert: Verify the expected outcomes
assert result is not None
assert "chat_prompt_config" in result
assert "prompt" in result["chat_prompt_config"]
assert len(result["chat_prompt_config"]["prompt"]) > 0
assert "role" in result["chat_prompt_config"]["prompt"][0]
assert "text" in result["chat_prompt_config"]["prompt"][0]
# Verify context is included
prompt_text = result["chat_prompt_config"]["prompt"][0]["text"]
assert CONTEXT in prompt_text
assert "{{#pre_prompt#}}" in prompt_text
def test_get_common_prompt_no_context(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test common prompt generation without context.
This test verifies:
- Correct handling when has_context is "false"
- Context is not included in prompt
- Template structure remains intact
"""
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "completion", "false")
# Assert: Verify the expected outcomes
assert result is not None
assert "completion_prompt_config" in result
assert "prompt" in result["completion_prompt_config"]
assert "text" in result["completion_prompt_config"]["prompt"]
# Verify context is NOT included
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
assert CONTEXT not in prompt_text
assert "{{#pre_prompt#}}" in prompt_text
assert "{{#histories#}}" in prompt_text
assert "{{#query#}}" in prompt_text
def test_get_common_prompt_unsupported_app_mode(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test common prompt generation with unsupported app mode.
This test verifies:
- Proper handling of unsupported app modes
- Default empty dict return
"""
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_common_prompt("unsupported_mode", "completion", "true")
# Assert: Verify empty dict is returned
assert result == {}
def test_get_common_prompt_unsupported_model_mode(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test common prompt generation with unsupported model mode.
This test verifies:
- Proper handling of unsupported model modes
- Default empty dict return
"""
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "unsupported_mode", "true")
# Assert: Verify empty dict is returned
assert result == {}
def test_get_completion_prompt_with_context(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test completion prompt generation with context.
This test verifies:
- Proper context integration in completion prompts
- Template structure preservation
- Context placement at the beginning
"""
fake = Faker()
# Create test prompt template
prompt_template = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG)
original_text = prompt_template["completion_prompt_config"]["prompt"]["text"]
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_completion_prompt(prompt_template, "true", CONTEXT)
# Assert: Verify the expected outcomes
assert result is not None
assert "completion_prompt_config" in result
assert "prompt" in result["completion_prompt_config"]
assert "text" in result["completion_prompt_config"]["prompt"]
# Verify context is prepended to original text
result_text = result["completion_prompt_config"]["prompt"]["text"]
assert result_text.startswith(CONTEXT)
assert original_text in result_text
assert result_text == CONTEXT + original_text
def test_get_completion_prompt_without_context(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test completion prompt generation without context.
This test verifies:
- Original template is preserved when no context
- No modification to prompt text
"""
fake = Faker()
# Create test prompt template
prompt_template = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG)
original_text = prompt_template["completion_prompt_config"]["prompt"]["text"]
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_completion_prompt(prompt_template, "false", CONTEXT)
# Assert: Verify the expected outcomes
assert result is not None
assert "completion_prompt_config" in result
assert "prompt" in result["completion_prompt_config"]
assert "text" in result["completion_prompt_config"]["prompt"]
# Verify original text is unchanged
result_text = result["completion_prompt_config"]["prompt"]["text"]
assert result_text == original_text
assert CONTEXT not in result_text
def test_get_chat_prompt_with_context(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test chat prompt generation with context.
This test verifies:
- Proper context integration in chat prompts
- Template structure preservation
- Context placement at the beginning of first message
"""
fake = Faker()
# Create test prompt template
prompt_template = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG)
original_text = prompt_template["chat_prompt_config"]["prompt"][0]["text"]
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_chat_prompt(prompt_template, "true", CONTEXT)
# Assert: Verify the expected outcomes
assert result is not None
assert "chat_prompt_config" in result
assert "prompt" in result["chat_prompt_config"]
assert len(result["chat_prompt_config"]["prompt"]) > 0
assert "text" in result["chat_prompt_config"]["prompt"][0]
# Verify context is prepended to original text
result_text = result["chat_prompt_config"]["prompt"][0]["text"]
assert result_text.startswith(CONTEXT)
assert original_text in result_text
assert result_text == CONTEXT + original_text
def test_get_chat_prompt_without_context(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test chat prompt generation without context.
This test verifies:
- Original template is preserved when no context
- No modification to prompt text
"""
fake = Faker()
# Create test prompt template
prompt_template = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG)
original_text = prompt_template["chat_prompt_config"]["prompt"][0]["text"]
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_chat_prompt(prompt_template, "false", CONTEXT)
# Assert: Verify the expected outcomes
assert result is not None
assert "chat_prompt_config" in result
assert "prompt" in result["chat_prompt_config"]
assert len(result["chat_prompt_config"]["prompt"]) > 0
assert "text" in result["chat_prompt_config"]["prompt"][0]
# Verify original text is unchanged
result_text = result["chat_prompt_config"]["prompt"][0]["text"]
assert result_text == original_text
assert CONTEXT not in result_text
def test_get_baichuan_prompt_chat_app_completion_mode(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test Baichuan prompt generation for chat app with completion mode.
This test verifies:
- Correct Baichuan prompt template selection for chat app + completion mode
- Proper Baichuan context integration
- Template structure validation
"""
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "completion", "true")
# Assert: Verify the expected outcomes
assert result is not None
assert "completion_prompt_config" in result
assert "prompt" in result["completion_prompt_config"]
assert "text" in result["completion_prompt_config"]["prompt"]
assert "conversation_histories_role" in result["completion_prompt_config"]
assert "stop" in result
# Verify Baichuan context is included
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
assert BAICHUAN_CONTEXT in prompt_text
assert "{{#pre_prompt#}}" in prompt_text
assert "{{#histories#}}" in prompt_text
assert "{{#query#}}" in prompt_text
def test_get_baichuan_prompt_chat_app_chat_mode(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test Baichuan prompt generation for chat app with chat mode.
This test verifies:
- Correct Baichuan prompt template selection for chat app + chat mode
- Proper Baichuan context integration
- Template structure validation
"""
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "chat", "true")
# Assert: Verify the expected outcomes
assert result is not None
assert "chat_prompt_config" in result
assert "prompt" in result["chat_prompt_config"]
assert len(result["chat_prompt_config"]["prompt"]) > 0
assert "role" in result["chat_prompt_config"]["prompt"][0]
assert "text" in result["chat_prompt_config"]["prompt"][0]
# Verify Baichuan context is included
prompt_text = result["chat_prompt_config"]["prompt"][0]["text"]
assert BAICHUAN_CONTEXT in prompt_text
assert "{{#pre_prompt#}}" in prompt_text
def test_get_baichuan_prompt_completion_app_completion_mode(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test Baichuan prompt generation for completion app with completion mode.
This test verifies:
- Correct Baichuan prompt template selection for completion app + completion mode
- Proper Baichuan context integration
- Template structure validation
"""
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION.value, "completion", "true")
# Assert: Verify the expected outcomes
assert result is not None
assert "completion_prompt_config" in result
assert "prompt" in result["completion_prompt_config"]
assert "text" in result["completion_prompt_config"]["prompt"]
assert "stop" in result
# Verify Baichuan context is included
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
assert BAICHUAN_CONTEXT in prompt_text
assert "{{#pre_prompt#}}" in prompt_text
def test_get_baichuan_prompt_completion_app_chat_mode(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test Baichuan prompt generation for completion app with chat mode.
This test verifies:
- Correct Baichuan prompt template selection for completion app + chat mode
- Proper Baichuan context integration
- Template structure validation
"""
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION.value, "chat", "true")
# Assert: Verify the expected outcomes
assert result is not None
assert "chat_prompt_config" in result
assert "prompt" in result["chat_prompt_config"]
assert len(result["chat_prompt_config"]["prompt"]) > 0
assert "role" in result["chat_prompt_config"]["prompt"][0]
assert "text" in result["chat_prompt_config"]["prompt"][0]
# Verify Baichuan context is included
prompt_text = result["chat_prompt_config"]["prompt"][0]["text"]
assert BAICHUAN_CONTEXT in prompt_text
assert "{{#pre_prompt#}}" in prompt_text
def test_get_baichuan_prompt_no_context(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test Baichuan prompt generation without context.
This test verifies:
- Correct handling when has_context is "false"
- Baichuan context is not included in prompt
- Template structure remains intact
"""
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "completion", "false")
# Assert: Verify the expected outcomes
assert result is not None
assert "completion_prompt_config" in result
assert "prompt" in result["completion_prompt_config"]
assert "text" in result["completion_prompt_config"]["prompt"]
# Verify Baichuan context is NOT included
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
assert BAICHUAN_CONTEXT not in prompt_text
assert "{{#pre_prompt#}}" in prompt_text
assert "{{#histories#}}" in prompt_text
assert "{{#query#}}" in prompt_text
def test_get_baichuan_prompt_unsupported_app_mode(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test Baichuan prompt generation with unsupported app mode.
This test verifies:
- Proper handling of unsupported app modes
- Default empty dict return
"""
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_baichuan_prompt("unsupported_mode", "completion", "true")
# Assert: Verify empty dict is returned
assert result == {}
def test_get_baichuan_prompt_unsupported_model_mode(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test Baichuan prompt generation with unsupported model mode.
This test verifies:
- Proper handling of unsupported model modes
- Default empty dict return
"""
fake = Faker()
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "unsupported_mode", "true")
# Assert: Verify empty dict is returned
assert result == {}
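Both unsupported-mode tests expect an empty dict rather than an exception, which points at a lookup-style dispatch with a default. A minimal sketch, assuming a simple template table (the real service selects among its `BAICHUAN_*_PROMPT_CONFIG` constants):

```python
# Unknown (app_mode, model_mode) pairs fall through to an empty dict, never raise.
TEMPLATE_TABLE: dict[tuple[str, str], dict] = {
    ("chat", "completion"): {"completion_prompt_config": {}},
    ("chat", "chat"): {"chat_prompt_config": {}},
    ("completion", "completion"): {"completion_prompt_config": {}},
    ("completion", "chat"): {"chat_prompt_config": {}},
}

def get_baichuan_prompt_sketch(app_mode: str, model_mode: str) -> dict:
    return TEMPLATE_TABLE.get((app_mode, model_mode), {})

assert get_baichuan_prompt_sketch("unsupported_mode", "completion") == {}
assert get_baichuan_prompt_sketch("chat", "unsupported_mode") == {}
```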
def test_get_prompt_all_app_modes_common_model(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test prompt generation for all app modes with common model.
This test verifies:
- All app modes work correctly with common models
- Proper template selection for each combination
"""
fake = Faker()
# Test every app mode / model mode combination
app_modes = [AppMode.CHAT.value, AppMode.COMPLETION.value]
model_modes = ["completion", "chat"]
for app_mode in app_modes:
for model_mode in model_modes:
args = {
"app_mode": app_mode,
"model_mode": model_mode,
"model_name": "gpt-3.5-turbo",
"has_context": "true",
}
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_prompt(args)
# Assert: Verify result is not empty
assert result is not None
assert result != {}
def test_get_prompt_all_app_modes_baichuan_model(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test prompt generation for all app modes with Baichuan model.
This test verifies:
- All app modes work correctly with Baichuan models
- Proper template selection for each combination
"""
fake = Faker()
# Test every app mode / model mode combination
app_modes = [AppMode.CHAT.value, AppMode.COMPLETION.value]
model_modes = ["completion", "chat"]
for app_mode in app_modes:
for model_mode in model_modes:
args = {
"app_mode": app_mode,
"model_mode": model_mode,
"model_name": "baichuan-13b-chat",
"has_context": "true",
}
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_prompt(args)
# Assert: Verify result is not empty
assert result is not None
assert result != {}
def test_get_prompt_edge_cases(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test prompt generation with edge cases.
This test verifies:
- Handling of edge case inputs
- Proper error handling
- Consistent behavior with unusual inputs
"""
fake = Faker()
# Test edge cases
edge_cases = [
{"app_mode": "", "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "true"},
{"app_mode": AppMode.CHAT.value, "model_mode": "", "model_name": "gpt-3.5-turbo", "has_context": "true"},
{"app_mode": AppMode.CHAT.value, "model_mode": "completion", "model_name": "", "has_context": "true"},
{
"app_mode": AppMode.CHAT.value,
"model_mode": "completion",
"model_name": "gpt-3.5-turbo",
"has_context": "",
},
]
for args in edge_cases:
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_prompt(args)
# Assert: Verify method handles edge cases gracefully
# Should either return a valid result or empty dict, but not crash
assert result is not None
def test_template_immutability(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test that original templates are not modified.
This test verifies:
- Original template constants are not modified
- Deep copy is used properly
- Template immutability is maintained
"""
fake = Faker()
# Store original templates
original_chat_completion = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG)
original_chat_chat = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG)
original_completion_completion = copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG)
original_completion_chat = copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG)
# Test with context
args = {
"app_mode": AppMode.CHAT.value,
"model_mode": "completion",
"model_name": "gpt-3.5-turbo",
"has_context": "true",
}
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_prompt(args)
# Assert: Verify original templates are unchanged
assert original_chat_completion == CHAT_APP_COMPLETION_PROMPT_CONFIG
assert original_chat_chat == CHAT_APP_CHAT_PROMPT_CONFIG
assert original_completion_completion == COMPLETION_APP_COMPLETION_PROMPT_CONFIG
assert original_completion_chat == COMPLETION_APP_CHAT_PROMPT_CONFIG
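These assertions only hold if the service copies before mutating; an in-place mutation would corrupt the shared constant for every later caller. A short self-contained illustration of the failure mode being guarded against:

```python
import copy

TEMPLATE = {"prompt": {"text": "base"}}

# In-place mutation damages the shared constant...
aliased = TEMPLATE
aliased["prompt"]["text"] = "ctx" + aliased["prompt"]["text"]
assert TEMPLATE["prompt"]["text"] == "ctxbase"

# ...while a deep copy leaves it intact.
TEMPLATE = {"prompt": {"text": "base"}}
copied = copy.deepcopy(TEMPLATE)
copied["prompt"]["text"] = "ctx" + copied["prompt"]["text"]
assert TEMPLATE["prompt"]["text"] == "base"
```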
def test_baichuan_template_immutability(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test that original Baichuan templates are not modified.
This test verifies:
- Original Baichuan template constants are not modified
- Deep copy is used properly
- Template immutability is maintained
"""
fake = Faker()
# Store original templates
original_baichuan_chat_completion = copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG)
original_baichuan_chat_chat = copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG)
original_baichuan_completion_completion = copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG)
original_baichuan_completion_chat = copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG)
# Test with context
args = {
"app_mode": AppMode.CHAT.value,
"model_mode": "completion",
"model_name": "baichuan-13b-chat",
"has_context": "true",
}
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_prompt(args)
# Assert: Verify original templates are unchanged
assert original_baichuan_chat_completion == BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG
assert original_baichuan_chat_chat == BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG
assert original_baichuan_completion_completion == BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG
assert original_baichuan_completion_chat == BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG
def test_context_integration_consistency(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test consistency of context integration across different scenarios.
This test verifies:
- Context is always prepended correctly
- Context integration is consistent across different templates
- No context duplication or corruption
"""
fake = Faker()
# Test different scenarios
test_scenarios = [
{
"app_mode": AppMode.CHAT.value,
"model_mode": "completion",
"model_name": "gpt-3.5-turbo",
"has_context": "true",
},
{
"app_mode": AppMode.CHAT.value,
"model_mode": "chat",
"model_name": "gpt-3.5-turbo",
"has_context": "true",
},
{
"app_mode": AppMode.COMPLETION.value,
"model_mode": "completion",
"model_name": "gpt-3.5-turbo",
"has_context": "true",
},
{
"app_mode": AppMode.COMPLETION.value,
"model_mode": "chat",
"model_name": "gpt-3.5-turbo",
"has_context": "true",
},
]
for args in test_scenarios:
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_prompt(args)
# Assert: Verify context integration is consistent
assert result is not None
assert result != {}
# Check that context is properly integrated
if "completion_prompt_config" in result:
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
assert prompt_text.startswith(CONTEXT)
elif "chat_prompt_config" in result:
prompt_text = result["chat_prompt_config"]["prompt"][0]["text"]
assert prompt_text.startswith(CONTEXT)
def test_baichuan_context_integration_consistency(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test consistency of Baichuan context integration across different scenarios.
This test verifies:
- Baichuan context is always prepended correctly
- Context integration is consistent across different templates
- No context duplication or corruption
"""
fake = Faker()
# Test different scenarios
test_scenarios = [
{
"app_mode": AppMode.CHAT.value,
"model_mode": "completion",
"model_name": "baichuan-13b-chat",
"has_context": "true",
},
{
"app_mode": AppMode.CHAT.value,
"model_mode": "chat",
"model_name": "baichuan-13b-chat",
"has_context": "true",
},
{
"app_mode": AppMode.COMPLETION.value,
"model_mode": "completion",
"model_name": "baichuan-13b-chat",
"has_context": "true",
},
{
"app_mode": AppMode.COMPLETION.value,
"model_mode": "chat",
"model_name": "baichuan-13b-chat",
"has_context": "true",
},
]
for args in test_scenarios:
# Act: Execute the method under test
result = AdvancedPromptTemplateService.get_prompt(args)
# Assert: Verify context integration is consistent
assert result is not None
assert result != {}
# Check that Baichuan context is properly integrated
if "completion_prompt_config" in result:
prompt_text = result["completion_prompt_config"]["prompt"]["text"]
assert prompt_text.startswith(BAICHUAN_CONTEXT)
elif "chat_prompt_config" in result:
prompt_text = result["chat_prompt_config"]["prompt"][0]["text"]
assert prompt_text.startswith(BAICHUAN_CONTEXT)

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -0,0 +1,913 @@
import hashlib
from io import BytesIO
from unittest.mock import patch
import pytest
from faker import Faker
from werkzeug.exceptions import NotFound
from configs import dify_config
from models.account import Account, Tenant
from models.enums import CreatorUserRole
from models.model import EndUser, UploadFile
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
from services.file_service import FileService
class TestFileService:
"""Integration tests for FileService using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("services.file_service.storage") as mock_storage,
patch("services.file_service.file_helpers") as mock_file_helpers,
patch("services.file_service.ExtractProcessor") as mock_extract_processor,
):
# Setup default mock returns
mock_storage.save.return_value = None
mock_storage.load.return_value = BytesIO(b"mock file content")
mock_file_helpers.get_signed_file_url.return_value = "https://example.com/signed-url"
mock_file_helpers.verify_image_signature.return_value = True
mock_file_helpers.verify_file_signature.return_value = True
mock_extract_processor.load_from_upload_file.return_value = "extracted text content"
yield {
"storage": mock_storage,
"file_helpers": mock_file_helpers,
"extract_processor": mock_extract_processor,
}
def _create_test_account(self, db_session_with_containers, mock_external_service_dependencies):
"""
Helper method to create a test account for testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
mock_external_service_dependencies: Mock dependencies
Returns:
Account: Created account instance
"""
fake = Faker()
# Create account
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
)
from extensions.ext_database import db
db.session.add(account)
db.session.commit()
# Create tenant for the account
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
# Create tenant-account join
from models.account import TenantAccountJoin, TenantAccountRole
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
current=True,
)
db.session.add(join)
db.session.commit()
# Set current tenant for account
account.current_tenant = tenant
return account
def _create_test_end_user(self, db_session_with_containers, mock_external_service_dependencies):
"""
Helper method to create a test end user for testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
mock_external_service_dependencies: Mock dependencies
Returns:
EndUser: Created end user instance
"""
fake = Faker()
end_user = EndUser(
tenant_id=str(fake.uuid4()),
type="web",
name=fake.name(),
is_anonymous=False,
session_id=fake.uuid4(),
)
from extensions.ext_database import db
db.session.add(end_user)
db.session.commit()
return end_user
def _create_test_upload_file(self, db_session_with_containers, mock_external_service_dependencies, account):
"""
Helper method to create a test upload file for testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
mock_external_service_dependencies: Mock dependencies
account: Account instance
Returns:
UploadFile: Created upload file instance
"""
fake = Faker()
upload_file = UploadFile(
tenant_id=account.current_tenant_id if hasattr(account, "current_tenant_id") else str(fake.uuid4()),
storage_type="local",
key=f"upload_files/test/{fake.uuid4()}.txt",
name="test_file.txt",
size=1024,
extension="txt",
mime_type="text/plain",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=fake.date_time(),
used=False,
hash=hashlib.sha3_256(b"test content").hexdigest(),
source_url="",
)
from extensions.ext_database import db
db.session.add(upload_file)
db.session.commit()
return upload_file
# Test upload_file method
def test_upload_file_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful file upload with valid parameters.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
filename = "test_document.pdf"
content = b"test file content"
mimetype = "application/pdf"
upload_file = FileService.upload_file(
filename=filename,
content=content,
mimetype=mimetype,
user=account,
)
assert upload_file is not None
assert upload_file.name == filename
assert upload_file.size == len(content)
assert upload_file.extension == "pdf"
assert upload_file.mime_type == mimetype
assert upload_file.created_by == account.id
assert upload_file.created_by_role == CreatorUserRole.ACCOUNT.value
assert upload_file.used is False
assert upload_file.hash == hashlib.sha3_256(content).hexdigest()
# Verify storage was called
mock_external_service_dependencies["storage"].save.assert_called_once()
# Verify database state
from extensions.ext_database import db
db.session.refresh(upload_file)
assert upload_file.id is not None
def test_upload_file_with_end_user(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test file upload with end user instead of account.
"""
fake = Faker()
end_user = self._create_test_end_user(db_session_with_containers, mock_external_service_dependencies)
filename = "test_image.jpg"
content = b"test image content"
mimetype = "image/jpeg"
upload_file = FileService.upload_file(
filename=filename,
content=content,
mimetype=mimetype,
user=end_user,
)
assert upload_file is not None
assert upload_file.created_by == end_user.id
assert upload_file.created_by_role == CreatorUserRole.END_USER.value
def test_upload_file_with_datasets_source(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test file upload with datasets source parameter.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
filename = "test_document.pdf"
content = b"test file content"
mimetype = "application/pdf"
upload_file = FileService.upload_file(
filename=filename,
content=content,
mimetype=mimetype,
user=account,
source="datasets",
source_url="https://example.com/source",
)
assert upload_file is not None
assert upload_file.source_url == "https://example.com/source"
def test_upload_file_invalid_filename_characters(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test file upload with invalid filename characters.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
filename = "test/file<name>.txt"
content = b"test content"
mimetype = "text/plain"
with pytest.raises(ValueError, match="Filename contains invalid characters"):
FileService.upload_file(
filename=filename,
content=content,
mimetype=mimetype,
user=account,
)
def test_upload_file_filename_too_long(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test file upload with filename that exceeds length limit.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Create a filename longer than 200 characters
long_name = "a" * 250
filename = f"{long_name}.txt"
content = b"test content"
mimetype = "text/plain"
upload_file = FileService.upload_file(
filename=filename,
content=content,
mimetype=mimetype,
user=account,
)
# Verify filename was truncated (the logic truncates the base name to 200 chars + extension)
# So the total length should be <= 200 + len(extension) + 1 (for the dot)
assert len(upload_file.name) <= 200 + len(upload_file.extension) + 1
assert upload_file.name.endswith(".txt")
# Verify the base name was truncated
base_name = upload_file.name[:-4] # Remove .txt
assert len(base_name) <= 200
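The length assertions match a split-and-clip scheme: keep the extension, clip the stem to 200 characters. A plausible sketch, assuming the filename always has an extension (`clip_filename` is a hypothetical helper, not the service's actual code):

```python
def clip_filename(filename: str, max_stem: int = 200) -> str:
    # Split off the extension, clip only the stem, then reassemble,
    # so "aaa...a.txt" keeps its ".txt" suffix.
    stem, _, extension = filename.rpartition(".")  # assumes an extension exists
    return f"{stem[:max_stem]}.{extension}"

assert clip_filename("a" * 250 + ".txt") == "a" * 200 + ".txt"
```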
def test_upload_file_datasets_unsupported_type(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test file upload for datasets with unsupported file type.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
filename = "test_image.jpg"
content = b"test content"
mimetype = "image/jpeg"
with pytest.raises(UnsupportedFileTypeError):
FileService.upload_file(
filename=filename,
content=content,
mimetype=mimetype,
user=account,
source="datasets",
)
def test_upload_file_too_large(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test file upload with file size exceeding limit.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
filename = "large_image.jpg"
# Create content larger than the limit
content = b"x" * (dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024 + 1)
mimetype = "image/jpeg"
with pytest.raises(FileTooLargeError):
FileService.upload_file(
filename=filename,
content=content,
mimetype=mimetype,
user=account,
)
# Test is_file_size_within_limit method
def test_is_file_size_within_limit_image_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test file size check for image files within limit.
"""
extension = "jpg"
file_size = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024 # Exactly at limit
result = FileService.is_file_size_within_limit(extension=extension, file_size=file_size)
assert result is True
def test_is_file_size_within_limit_video_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test file size check for video files within limit.
"""
extension = "mp4"
file_size = dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT * 1024 * 1024 # Exactly at limit
result = FileService.is_file_size_within_limit(extension=extension, file_size=file_size)
assert result is True
def test_is_file_size_within_limit_audio_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test file size check for audio files within limit.
"""
extension = "mp3"
file_size = dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT * 1024 * 1024 # Exactly at limit
result = FileService.is_file_size_within_limit(extension=extension, file_size=file_size)
assert result is True
def test_is_file_size_within_limit_document_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test file size check for document files within limit.
"""
extension = "pdf"
file_size = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024 # Exactly at limit
result = FileService.is_file_size_within_limit(extension=extension, file_size=file_size)
assert result is True
def test_is_file_size_within_limit_image_exceeded(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test file size check for image files exceeding limit.
"""
extension = "jpg"
file_size = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024 + 1 # Exceeds limit
result = FileService.is_file_size_within_limit(extension=extension, file_size=file_size)
assert result is False
def test_is_file_size_within_limit_unknown_extension(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test file size check for unknown file extension.
"""
extension = "xyz"
file_size = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024 # Uses default limit
result = FileService.is_file_size_within_limit(extension=extension, file_size=file_size)
assert result is True
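Taken together, the limit tests suggest a per-category lookup with the general document limit as the fallback, which is why the unknown extension passes. A minimal sketch with made-up megabyte values; the real limits come from the `dify_config.UPLOAD_*_FILE_SIZE_LIMIT` settings:

```python
# Hypothetical per-category limits in megabytes.
LIMITS_MB = {"image": 10, "video": 100, "audio": 50, "document": 15}
IMAGE_EXTENSIONS = {"jpg", "jpeg", "png", "gif", "webp"}
VIDEO_EXTENSIONS = {"mp4", "mov", "mpeg"}
AUDIO_EXTENSIONS = {"mp3", "wav", "m4a"}

def is_file_size_within_limit_sketch(extension: str, file_size: int) -> bool:
    if extension in IMAGE_EXTENSIONS:
        limit_mb = LIMITS_MB["image"]
    elif extension in VIDEO_EXTENSIONS:
        limit_mb = LIMITS_MB["video"]
    elif extension in AUDIO_EXTENSIONS:
        limit_mb = LIMITS_MB["audio"]
    else:
        # Unknown extensions fall back to the general document limit.
        limit_mb = LIMITS_MB["document"]
    return file_size <= limit_mb * 1024 * 1024

assert is_file_size_within_limit_sketch("jpg", 10 * 1024 * 1024) is True
assert is_file_size_within_limit_sketch("jpg", 10 * 1024 * 1024 + 1) is False
assert is_file_size_within_limit_sketch("xyz", 15 * 1024 * 1024) is True
```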
# Test upload_text method
def test_upload_text_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful text upload.
"""
fake = Faker()
text = "This is a test text content"
text_name = "test_text.txt"
# Mock current_user
with patch("services.file_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = str(fake.uuid4())
mock_current_user.id = str(fake.uuid4())
upload_file = FileService.upload_text(text=text, text_name=text_name)
assert upload_file is not None
assert upload_file.name == text_name
assert upload_file.size == len(text)
assert upload_file.extension == "txt"
assert upload_file.mime_type == "text/plain"
assert upload_file.used is True
assert upload_file.used_by == mock_current_user.id
# Verify storage was called
mock_external_service_dependencies["storage"].save.assert_called_once()
def test_upload_text_name_too_long(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test text upload with name that exceeds length limit.
"""
fake = Faker()
text = "test content"
long_name = "a" * 250 # Longer than 200 characters
# Mock current_user
with patch("services.file_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = str(fake.uuid4())
mock_current_user.id = str(fake.uuid4())
upload_file = FileService.upload_text(text=text, text_name=long_name)
# Verify name was truncated
assert len(upload_file.name) <= 200
assert upload_file.name == "a" * 200
# Test get_file_preview method
def test_get_file_preview_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful file preview generation.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
upload_file = self._create_test_upload_file(
db_session_with_containers, mock_external_service_dependencies, account
)
# Update file to have document extension
upload_file.extension = "pdf"
from extensions.ext_database import db
db.session.commit()
result = FileService.get_file_preview(file_id=upload_file.id)
assert result == "extracted text content"
mock_external_service_dependencies["extract_processor"].load_from_upload_file.assert_called_once()
def test_get_file_preview_file_not_found(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test file preview with non-existent file.
"""
fake = Faker()
non_existent_id = str(fake.uuid4())
with pytest.raises(NotFound, match="File not found"):
FileService.get_file_preview(file_id=non_existent_id)
def test_get_file_preview_unsupported_file_type(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test file preview with unsupported file type.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
upload_file = self._create_test_upload_file(
db_session_with_containers, mock_external_service_dependencies, account
)
# Update file to have non-document extension
upload_file.extension = "jpg"
from extensions.ext_database import db
db.session.commit()
with pytest.raises(UnsupportedFileTypeError):
FileService.get_file_preview(file_id=upload_file.id)
def test_get_file_preview_text_truncation(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test file preview with text that exceeds preview limit.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
upload_file = self._create_test_upload_file(
db_session_with_containers, mock_external_service_dependencies, account
)
# Update file to have document extension
upload_file.extension = "pdf"
from extensions.ext_database import db
db.session.commit()
# Mock long text content
long_text = "x" * 5000 # Longer than PREVIEW_WORDS_LIMIT
mock_external_service_dependencies["extract_processor"].load_from_upload_file.return_value = long_text
result = FileService.get_file_preview(file_id=upload_file.id)
assert len(result) == 3000 # PREVIEW_WORDS_LIMIT
assert result == "x" * 3000
# Test get_image_preview method
def test_get_image_preview_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful image preview generation.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
upload_file = self._create_test_upload_file(
db_session_with_containers, mock_external_service_dependencies, account
)
# Update file to have image extension
upload_file.extension = "jpg"
from extensions.ext_database import db
db.session.commit()
timestamp = "1234567890"
nonce = "test_nonce"
sign = "test_signature"
generator, mime_type = FileService.get_image_preview(
file_id=upload_file.id,
timestamp=timestamp,
nonce=nonce,
sign=sign,
)
assert generator is not None
assert mime_type == upload_file.mime_type
mock_external_service_dependencies["file_helpers"].verify_image_signature.assert_called_once()
def test_get_image_preview_invalid_signature(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test image preview with invalid signature.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
upload_file = self._create_test_upload_file(
db_session_with_containers, mock_external_service_dependencies, account
)
# Mock invalid signature
mock_external_service_dependencies["file_helpers"].verify_image_signature.return_value = False
timestamp = "1234567890"
nonce = "test_nonce"
sign = "invalid_signature"
with pytest.raises(NotFound, match="File not found or signature is invalid"):
FileService.get_image_preview(
file_id=upload_file.id,
timestamp=timestamp,
nonce=nonce,
sign=sign,
)
def test_get_image_preview_file_not_found(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test image preview with non-existent file.
"""
fake = Faker()
non_existent_id = str(fake.uuid4())
timestamp = "1234567890"
nonce = "test_nonce"
sign = "test_signature"
with pytest.raises(NotFound, match="File not found or signature is invalid"):
FileService.get_image_preview(
file_id=non_existent_id,
timestamp=timestamp,
nonce=nonce,
sign=sign,
)
def test_get_image_preview_unsupported_file_type(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test image preview with non-image file type.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
upload_file = self._create_test_upload_file(
db_session_with_containers, mock_external_service_dependencies, account
)
# Update file to have non-image extension
upload_file.extension = "pdf"
from extensions.ext_database import db
db.session.commit()
timestamp = "1234567890"
nonce = "test_nonce"
sign = "test_signature"
with pytest.raises(UnsupportedFileTypeError):
FileService.get_image_preview(
file_id=upload_file.id,
timestamp=timestamp,
nonce=nonce,
sign=sign,
)
# Test get_file_generator_by_file_id method
def test_get_file_generator_by_file_id_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test successful file generator retrieval.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
upload_file = self._create_test_upload_file(
db_session_with_containers, mock_external_service_dependencies, account
)
timestamp = "1234567890"
nonce = "test_nonce"
sign = "test_signature"
generator, file_obj = FileService.get_file_generator_by_file_id(
file_id=upload_file.id,
timestamp=timestamp,
nonce=nonce,
sign=sign,
)
assert generator is not None
assert file_obj == upload_file
mock_external_service_dependencies["file_helpers"].verify_file_signature.assert_called_once()
def test_get_file_generator_by_file_id_invalid_signature(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test file generator retrieval with invalid signature.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
upload_file = self._create_test_upload_file(
db_session_with_containers, mock_external_service_dependencies, account
)
# Mock invalid signature
mock_external_service_dependencies["file_helpers"].verify_file_signature.return_value = False
timestamp = "1234567890"
nonce = "test_nonce"
sign = "invalid_signature"
with pytest.raises(NotFound, match="File not found or signature is invalid"):
FileService.get_file_generator_by_file_id(
file_id=upload_file.id,
timestamp=timestamp,
nonce=nonce,
sign=sign,
)
def test_get_file_generator_by_file_id_file_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test file generator retrieval with non-existent file.
"""
fake = Faker()
non_existent_id = str(fake.uuid4())
timestamp = "1234567890"
nonce = "test_nonce"
sign = "test_signature"
with pytest.raises(NotFound, match="File not found or signature is invalid"):
FileService.get_file_generator_by_file_id(
file_id=non_existent_id,
timestamp=timestamp,
nonce=nonce,
sign=sign,
)
# Test get_public_image_preview method
def test_get_public_image_preview_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful public image preview generation.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
upload_file = self._create_test_upload_file(
db_session_with_containers, mock_external_service_dependencies, account
)
# Update file to have image extension
upload_file.extension = "jpg"
from extensions.ext_database import db
db.session.commit()
generator, mime_type = FileService.get_public_image_preview(file_id=upload_file.id)
assert generator is not None
assert mime_type == upload_file.mime_type
mock_external_service_dependencies["storage"].load.assert_called_once()
def test_get_public_image_preview_file_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test public image preview with non-existent file.
"""
fake = Faker()
non_existent_id = str(fake.uuid4())
with pytest.raises(NotFound, match="File not found or signature is invalid"):
FileService.get_public_image_preview(file_id=non_existent_id)
def test_get_public_image_preview_unsupported_file_type(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test public image preview with non-image file type.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
upload_file = self._create_test_upload_file(
db_session_with_containers, mock_external_service_dependencies, account
)
# Update file to have non-image extension
upload_file.extension = "pdf"
from extensions.ext_database import db
db.session.commit()
with pytest.raises(UnsupportedFileTypeError):
FileService.get_public_image_preview(file_id=upload_file.id)
# Test edge cases and boundary conditions
def test_upload_file_empty_content(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test file upload with empty content.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
filename = "empty.txt"
content = b""
mimetype = "text/plain"
upload_file = FileService.upload_file(
filename=filename,
content=content,
mimetype=mimetype,
user=account,
)
assert upload_file is not None
assert upload_file.size == 0
def test_upload_file_special_characters_in_name(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test file upload with special characters in filename (but valid ones).
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
filename = "test-file_with_underscores_and.dots.txt"
content = b"test content"
mimetype = "text/plain"
upload_file = FileService.upload_file(
filename=filename,
content=content,
mimetype=mimetype,
user=account,
)
assert upload_file is not None
assert upload_file.name == filename
def test_upload_file_different_case_extensions(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test file upload with different case extensions.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
filename = "test.PDF"
content = b"test content"
mimetype = "application/pdf"
upload_file = FileService.upload_file(
filename=filename,
content=content,
mimetype=mimetype,
user=account,
)
assert upload_file is not None
assert upload_file.extension == "pdf" # Should be converted to lowercase
def test_upload_text_empty_text(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test text upload with empty text.
"""
fake = Faker()
text = ""
text_name = "empty.txt"
# Mock current_user
with patch("services.file_service.current_user") as mock_current_user:
mock_current_user.current_tenant_id = str(fake.uuid4())
mock_current_user.id = str(fake.uuid4())
upload_file = FileService.upload_text(text=text, text_name=text_name)
assert upload_file is not None
assert upload_file.size == 0
def test_file_size_limits_edge_cases(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test file size limits with edge case values.
"""
# For each file category: test exactly at the limit, then one byte over
for extension, limit_config in [
("jpg", dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT),
("mp4", dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT),
("mp3", dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT),
("pdf", dify_config.UPLOAD_FILE_SIZE_LIMIT),
]:
file_size = limit_config * 1024 * 1024
result = FileService.is_file_size_within_limit(extension=extension, file_size=file_size)
assert result is True
# Test one byte over limit
file_size = limit_config * 1024 * 1024 + 1
result = FileService.is_file_size_within_limit(extension=extension, file_size=file_size)
assert result is False
def test_upload_file_with_source_url(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test file upload with source URL that gets overridden by signed URL.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
filename = "test.pdf"
content = b"test content"
mimetype = "application/pdf"
source_url = "https://original-source.com/file.pdf"
upload_file = FileService.upload_file(
filename=filename,
content=content,
mimetype=mimetype,
user=account,
source_url=source_url,
)
# When source_url is provided, it should be preserved
assert upload_file.source_url == source_url
# The signed URL should only be set when source_url is empty
# Let's test that scenario
upload_file2 = FileService.upload_file(
filename="test2.pdf",
content=b"test content 2",
mimetype="application/pdf",
user=account,
source_url="", # Empty source_url
)
# Should have the signed URL when source_url is empty
assert upload_file2.source_url == "https://example.com/signed-url"

View File

@ -4,8 +4,8 @@ from unittest.mock import patch
import pytest
from werkzeug.exceptions import Forbidden
from controllers.common.errors import FilenameNotExistsError
from controllers.console.error import (
from controllers.common.errors import (
FilenameNotExistsError,
FileTooLargeError,
NoFileUploadedError,
TooManyFilesError,

View File

@ -0,0 +1,247 @@
"""
Unit tests for CeleryWorkflowExecutionRepository.
These tests verify the Celery-based asynchronous storage functionality
for workflow execution data.
"""
from datetime import UTC, datetime
from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowType
from models import Account, EndUser
from models.enums import WorkflowRunTriggeredFrom
@pytest.fixture
def mock_session_factory():
"""Mock SQLAlchemy session factory."""
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
# Create a real sessionmaker with in-memory SQLite for testing
engine = create_engine("sqlite:///:memory:")
return sessionmaker(bind=engine)
@pytest.fixture
def mock_account():
"""Mock Account user."""
account = Mock(spec=Account)
account.id = str(uuid4())
account.current_tenant_id = str(uuid4())
return account
@pytest.fixture
def mock_end_user():
"""Mock EndUser."""
user = Mock(spec=EndUser)
user.id = str(uuid4())
user.tenant_id = str(uuid4())
return user
@pytest.fixture
def sample_workflow_execution():
"""Sample WorkflowExecution for testing."""
return WorkflowExecution.new(
id_=str(uuid4()),
workflow_id=str(uuid4()),
workflow_type=WorkflowType.WORKFLOW,
workflow_version="1.0",
graph={"nodes": [], "edges": []},
inputs={"input1": "value1"},
started_at=datetime.now(UTC).replace(tzinfo=None),
)
class TestCeleryWorkflowExecutionRepository:
"""Test cases for CeleryWorkflowExecutionRepository."""
def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
"""Test repository initialization with sessionmaker."""
app_id = "test-app-id"
triggered_from = WorkflowRunTriggeredFrom.APP_RUN
repo = CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id=app_id,
triggered_from=triggered_from,
)
assert repo._tenant_id == mock_account.current_tenant_id
assert repo._app_id == app_id
assert repo._triggered_from == triggered_from
assert repo._creator_user_id == mock_account.id
assert repo._creator_user_role is not None
def test_init_basic_functionality(self, mock_session_factory, mock_account):
"""Test repository initialization basic functionality."""
repo = CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Verify basic initialization
assert repo._tenant_id == mock_account.current_tenant_id
assert repo._app_id == "test-app"
assert repo._triggered_from == WorkflowRunTriggeredFrom.DEBUGGING
def test_init_with_end_user(self, mock_session_factory, mock_end_user):
"""Test repository initialization with EndUser."""
repo = CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_end_user,
app_id="test-app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
assert repo._tenant_id == mock_end_user.tenant_id
def test_init_without_tenant_id_raises_error(self, mock_session_factory):
"""Test that initialization fails without tenant_id."""
# Create a mock Account with no tenant_id
user = Mock(spec=Account)
user.current_tenant_id = None
user.id = str(uuid4())
with pytest.raises(ValueError, match="User must have a tenant_id"):
CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=user,
app_id="test-app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
@patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
def test_save_queues_celery_task(self, mock_task, mock_session_factory, mock_account, sample_workflow_execution):
"""Test that save operation queues a Celery task without tracking."""
repo = CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
repo.save(sample_workflow_execution)
# Verify Celery task was queued with correct parameters
mock_task.delay.assert_called_once()
call_args = mock_task.delay.call_args[1]
assert call_args["execution_data"] == sample_workflow_execution.model_dump()
assert call_args["tenant_id"] == mock_account.current_tenant_id
assert call_args["app_id"] == "test-app"
assert call_args["triggered_from"] == WorkflowRunTriggeredFrom.APP_RUN.value
assert call_args["creator_user_id"] == mock_account.id
# Verify no task tracking occurs (no _pending_saves attribute)
assert not hasattr(repo, "_pending_saves")
@patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
def test_save_handles_celery_failure(
self, mock_task, mock_session_factory, mock_account, sample_workflow_execution
):
"""Test that save operation handles Celery task failures."""
mock_task.delay.side_effect = Exception("Celery is down")
repo = CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
with pytest.raises(Exception, match="Celery is down"):
repo.save(sample_workflow_execution)
@patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
def test_save_operation_fire_and_forget(
self, mock_task, mock_session_factory, mock_account, sample_workflow_execution
):
"""Test that save operation works in fire-and-forget mode."""
repo = CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
# Test that save doesn't block or maintain state
repo.save(sample_workflow_execution)
# Verify no pending saves are tracked (no _pending_saves attribute)
assert not hasattr(repo, "_pending_saves")
@patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
def test_multiple_save_operations(self, mock_task, mock_session_factory, mock_account):
"""Test multiple save operations work correctly."""
repo = CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
# Create multiple executions
exec1 = WorkflowExecution.new(
id_=str(uuid4()),
workflow_id=str(uuid4()),
workflow_type=WorkflowType.WORKFLOW,
workflow_version="1.0",
graph={"nodes": [], "edges": []},
inputs={"input1": "value1"},
started_at=datetime.now(UTC).replace(tzinfo=None),
)
exec2 = WorkflowExecution.new(
id_=str(uuid4()),
workflow_id=str(uuid4()),
workflow_type=WorkflowType.WORKFLOW,
workflow_version="1.0",
graph={"nodes": [], "edges": []},
inputs={"input2": "value2"},
started_at=datetime.now(UTC).replace(tzinfo=None),
)
# Save both executions
repo.save(exec1)
repo.save(exec2)
# Should work without issues and not maintain state (no _pending_saves attribute)
assert not hasattr(repo, "_pending_saves")
@patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
def test_save_with_different_user_types(self, mock_task, mock_session_factory, mock_end_user):
"""Test save operation with different user types."""
repo = CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_end_user,
app_id="test-app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
execution = WorkflowExecution.new(
id_=str(uuid4()),
workflow_id=str(uuid4()),
workflow_type=WorkflowType.WORKFLOW,
workflow_version="1.0",
graph={"nodes": [], "edges": []},
inputs={"input1": "value1"},
started_at=datetime.now(UTC).replace(tzinfo=None),
)
repo.save(execution)
# Verify task was called with EndUser context
mock_task.delay.assert_called_once()
call_args = mock_task.delay.call_args[1]
assert call_args["tenant_id"] == mock_end_user.tenant_id
assert call_args["creator_user_id"] == mock_end_user.id

View File

@ -0,0 +1,349 @@
"""
Unit tests for CeleryWorkflowNodeExecutionRepository.
These tests verify the Celery-based asynchronous storage functionality
for workflow node execution data.
"""
from datetime import UTC, datetime
from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
)
from core.workflow.nodes.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
from models import Account, EndUser
from models.workflow import WorkflowNodeExecutionTriggeredFrom
@pytest.fixture
def mock_session_factory():
"""Mock SQLAlchemy session factory."""
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
# Create a real sessionmaker with in-memory SQLite for testing
engine = create_engine("sqlite:///:memory:")
return sessionmaker(bind=engine)
@pytest.fixture
def mock_account():
"""Mock Account user."""
account = Mock(spec=Account)
account.id = str(uuid4())
account.current_tenant_id = str(uuid4())
return account
@pytest.fixture
def mock_end_user():
"""Mock EndUser."""
user = Mock(spec=EndUser)
user.id = str(uuid4())
user.tenant_id = str(uuid4())
return user
@pytest.fixture
def sample_workflow_node_execution():
"""Sample WorkflowNodeExecution for testing."""
return WorkflowNodeExecution(
id=str(uuid4()),
node_execution_id=str(uuid4()),
workflow_id=str(uuid4()),
workflow_execution_id=str(uuid4()),
index=1,
node_id="test_node",
node_type=NodeType.START,
title="Test Node",
inputs={"input1": "value1"},
status=WorkflowNodeExecutionStatus.RUNNING,
created_at=datetime.now(UTC).replace(tzinfo=None),
)
class TestCeleryWorkflowNodeExecutionRepository:
"""Test cases for CeleryWorkflowNodeExecutionRepository."""
def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
"""Test repository initialization with sessionmaker."""
app_id = "test-app-id"
triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id=app_id,
triggered_from=triggered_from,
)
assert repo._tenant_id == mock_account.current_tenant_id
assert repo._app_id == app_id
assert repo._triggered_from == triggered_from
assert repo._creator_user_id == mock_account.id
assert repo._creator_user_role is not None
def test_init_with_cache_initialized(self, mock_session_factory, mock_account):
"""Test repository initialization with cache properly initialized."""
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
assert repo._execution_cache == {}
assert repo._workflow_execution_mapping == {}
def test_init_with_end_user(self, mock_session_factory, mock_end_user):
"""Test repository initialization with EndUser."""
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_end_user,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
assert repo._tenant_id == mock_end_user.tenant_id
def test_init_without_tenant_id_raises_error(self, mock_session_factory):
"""Test that initialization fails without tenant_id."""
# Create a mock Account with no tenant_id
user = Mock(spec=Account)
user.current_tenant_id = None
user.id = str(uuid4())
with pytest.raises(ValueError, match="User must have a tenant_id"):
CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=user,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
@patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
def test_save_caches_and_queues_celery_task(
self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
):
"""Test that save operation caches execution and queues a Celery task."""
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
repo.save(sample_workflow_node_execution)
# Verify Celery task was queued with correct parameters
mock_task.delay.assert_called_once()
call_args = mock_task.delay.call_args[1]
assert call_args["execution_data"] == sample_workflow_node_execution.model_dump()
assert call_args["tenant_id"] == mock_account.current_tenant_id
assert call_args["app_id"] == "test-app"
assert call_args["triggered_from"] == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
assert call_args["creator_user_id"] == mock_account.id
# Verify execution is cached
assert sample_workflow_node_execution.id in repo._execution_cache
assert repo._execution_cache[sample_workflow_node_execution.id] == sample_workflow_node_execution
# Verify workflow execution mapping is updated
assert sample_workflow_node_execution.workflow_execution_id in repo._workflow_execution_mapping
assert (
sample_workflow_node_execution.id
in repo._workflow_execution_mapping[sample_workflow_node_execution.workflow_execution_id]
)
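Unlike the execution repository, the node repository keeps a read-back cache so `get_by_workflow_run` can answer without a database round trip. A sketch of the two structures the assertions inspect; whether the mapping holds a set or a list is an assumption:

```python
from collections import defaultdict
from typing import Any

class NodeExecutionCacheSketch:
    def __init__(self) -> None:
        # Mirrors _execution_cache and _workflow_execution_mapping above:
        # id -> execution, and workflow_execution_id -> ids of its node executions.
        self._execution_cache: dict[str, Any] = {}
        self._workflow_execution_mapping: dict[str, set[str]] = defaultdict(set)

    def cache(self, execution: Any) -> None:
        self._execution_cache[execution.id] = execution
        self._workflow_execution_mapping[execution.workflow_execution_id].add(execution.id)

    def get_by_workflow_run(self, workflow_run_id: str) -> list[Any]:
        ids = self._workflow_execution_mapping.get(workflow_run_id, set())
        return [self._execution_cache[i] for i in ids]
```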
@patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
def test_save_handles_celery_failure(
self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
):
"""Test that save operation handles Celery task failures."""
mock_task.delay.side_effect = Exception("Celery is down")
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
with pytest.raises(Exception, match="Celery is down"):
repo.save(sample_workflow_node_execution)
@patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
def test_get_by_workflow_run_from_cache(
self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
):
"""Test that get_by_workflow_run retrieves executions from cache."""
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Save execution to cache first
repo.save(sample_workflow_node_execution)
workflow_run_id = sample_workflow_node_execution.workflow_execution_id
order_config = OrderConfig(order_by=["index"], order_direction="asc")
result = repo.get_by_workflow_run(workflow_run_id, order_config)
# Verify results were retrieved from cache
assert len(result) == 1
assert result[0].id == sample_workflow_node_execution.id
assert result[0] is sample_workflow_node_execution
def test_get_by_workflow_run_without_order_config(self, mock_session_factory, mock_account):
"""Test get_by_workflow_run without order configuration."""
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
result = repo.get_by_workflow_run("workflow-run-id")
# Should return empty list since nothing in cache
assert len(result) == 0
@patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
def test_cache_operations(self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution):
"""Test cache operations work correctly."""
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Test saving to cache
repo.save(sample_workflow_node_execution)
# Verify cache contains the execution
assert sample_workflow_node_execution.id in repo._execution_cache
# Test retrieving from cache
result = repo.get_by_workflow_run(sample_workflow_node_execution.workflow_execution_id)
assert len(result) == 1
assert result[0].id == sample_workflow_node_execution.id
@patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
def test_multiple_executions_same_workflow(self, mock_task, mock_session_factory, mock_account):
"""Test multiple executions for the same workflow."""
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Create multiple executions for the same workflow
workflow_run_id = str(uuid4())
exec1 = WorkflowNodeExecution(
id=str(uuid4()),
node_execution_id=str(uuid4()),
workflow_id=str(uuid4()),
workflow_execution_id=workflow_run_id,
index=1,
node_id="node1",
node_type=NodeType.START,
title="Node 1",
inputs={"input1": "value1"},
status=WorkflowNodeExecutionStatus.RUNNING,
created_at=datetime.now(UTC).replace(tzinfo=None),
)
exec2 = WorkflowNodeExecution(
id=str(uuid4()),
node_execution_id=str(uuid4()),
workflow_id=str(uuid4()),
workflow_execution_id=workflow_run_id,
index=2,
node_id="node2",
node_type=NodeType.LLM,
title="Node 2",
inputs={"input2": "value2"},
status=WorkflowNodeExecutionStatus.RUNNING,
created_at=datetime.now(UTC).replace(tzinfo=None),
)
# Save both executions
repo.save(exec1)
repo.save(exec2)
# Verify both are cached and mapped
assert len(repo._execution_cache) == 2
assert len(repo._workflow_execution_mapping[workflow_run_id]) == 2
# Test retrieval
result = repo.get_by_workflow_run(workflow_run_id)
assert len(result) == 2
@patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
def test_ordering_functionality(self, mock_task, mock_session_factory, mock_account):
"""Test ordering functionality works correctly."""
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Create executions with different indices
workflow_run_id = str(uuid4())
exec1 = WorkflowNodeExecution(
id=str(uuid4()),
node_execution_id=str(uuid4()),
workflow_id=str(uuid4()),
workflow_execution_id=workflow_run_id,
index=2,
node_id="node2",
node_type=NodeType.START,
title="Node 2",
inputs={},
status=WorkflowNodeExecutionStatus.RUNNING,
created_at=datetime.now(UTC).replace(tzinfo=None),
)
exec2 = WorkflowNodeExecution(
id=str(uuid4()),
node_execution_id=str(uuid4()),
workflow_id=str(uuid4()),
workflow_execution_id=workflow_run_id,
index=1,
node_id="node1",
node_type=NodeType.LLM,
title="Node 1",
inputs={},
status=WorkflowNodeExecutionStatus.RUNNING,
created_at=datetime.now(UTC).replace(tzinfo=None),
)
# Save out of order (index 2 first, then index 1)
repo.save(exec1)
repo.save(exec2)
# Test ascending order
order_config = OrderConfig(order_by=["index"], order_direction="asc")
result = repo.get_by_workflow_run(workflow_run_id, order_config)
assert len(result) == 2
assert result[0].index == 1
assert result[1].index == 2
# Test descending order
order_config = OrderConfig(order_by=["index"], order_direction="desc")
result = repo.get_by_workflow_run(workflow_run_id, order_config)
assert len(result) == 2
assert result[0].index == 2
assert result[1].index == 1
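
For readers skimming these tests, the cache contract they encode is small: save() indexes each execution by id and groups ids per workflow run, and get_by_workflow_run() reads back from those maps, optionally sorting by index. A minimal sketch of that assumed shape follows; the real CeleryWorkflowNodeExecutionRepository also enqueues save_workflow_node_execution_task for asynchronous persistence, which the tests patch out.

```python
from collections import defaultdict


class NodeExecutionCacheSketch:
    """Assumed in-memory shape behind save()/get_by_workflow_run()."""

    def __init__(self):
        self._execution_cache = {}  # execution id -> execution
        self._workflow_execution_mapping = defaultdict(list)  # run id -> [execution ids]

    def save(self, execution):
        # Cache locally; the real repository additionally dispatches
        # save_workflow_node_execution_task to persist asynchronously.
        self._execution_cache[execution.id] = execution
        self._workflow_execution_mapping[execution.workflow_execution_id].append(execution.id)

    def get_by_workflow_run(self, workflow_run_id, order_config=None):
        executions = [
            self._execution_cache[eid]
            for eid in self._workflow_execution_mapping.get(workflow_run_id, [])
        ]
        if order_config is not None:
            # The tests only exercise ordering by "index", asc or desc.
            executions.sort(
                key=lambda e: e.index,
                reverse=(order_config.order_direction == "desc"),
            )
        return executions
```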

View File

@ -59,7 +59,7 @@ class TestRepositoryFactory:
def get_by_id(self):
pass
# Create a mock interface with the same methods
# Create a mock interface class
class MockInterface:
def save(self):
pass
@ -67,20 +67,20 @@ class TestRepositoryFactory:
def get_by_id(self):
pass
# Should not raise an exception
# Should not raise an exception when all methods are present
DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
def test_validate_repository_interface_missing_methods(self):
"""Test interface validation with missing methods."""
# Create a mock class that doesn't implement all required methods
# Create a mock class that's missing required methods
class IncompleteRepository:
def save(self):
pass
# Missing get_by_id method
# Create a mock interface with required methods
# Create a mock interface that requires both methods
class MockInterface:
def save(self):
pass
@ -88,57 +88,39 @@ class TestRepositoryFactory:
def get_by_id(self):
pass
def missing_method(self):
pass
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._validate_repository_interface(IncompleteRepository, MockInterface)
assert "does not implement required methods" in str(exc_info.value)
assert "get_by_id" in str(exc_info.value)
def test_validate_constructor_signature_success(self):
"""Test successful constructor signature validation."""
def test_validate_repository_interface_with_private_methods(self):
"""Test that private methods are ignored during interface validation."""
class MockRepository:
def __init__(self, session_factory, user, app_id, triggered_from):
def save(self):
pass
# Should not raise an exception
DifyCoreRepositoryFactory._validate_constructor_signature(
MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
)
def test_validate_constructor_signature_missing_params(self):
"""Test constructor validation with missing parameters."""
class IncompleteRepository:
def __init__(self, session_factory, user):
# Missing app_id and triggered_from parameters
def _private_method(self):
pass
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._validate_constructor_signature(
IncompleteRepository, ["session_factory", "user", "app_id", "triggered_from"]
)
assert "does not accept required parameters" in str(exc_info.value)
assert "app_id" in str(exc_info.value)
assert "triggered_from" in str(exc_info.value)
def test_validate_constructor_signature_inspection_error(self, mocker: MockerFixture):
"""Test constructor validation when inspection fails."""
# Mock inspect.signature to raise an exception
mocker.patch("inspect.signature", side_effect=Exception("Inspection failed"))
class MockRepository:
def __init__(self, session_factory):
# Create a mock interface with private methods
class MockInterface:
def save(self):
pass
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._validate_constructor_signature(MockRepository, ["session_factory"])
assert "Failed to validate constructor signature" in str(exc_info.value)
def _private_method(self):
pass
# Should not raise exception - private methods should be ignored
DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_execution_repository_success(self, mock_config, mocker: MockerFixture):
"""Test successful creation of WorkflowExecutionRepository."""
def test_create_workflow_execution_repository_success(self, mock_config):
"""Test successful WorkflowExecutionRepository creation."""
# Setup mock configuration
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
# Create mock dependencies
mock_session_factory = MagicMock(spec=sessionmaker)
@ -146,7 +128,7 @@ class TestRepositoryFactory:
app_id = "test-app-id"
triggered_from = WorkflowRunTriggeredFrom.APP_RUN
# Mock the imported class to be a valid repository
# Create mock repository class and instance
mock_repository_class = MagicMock()
mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository)
mock_repository_class.return_value = mock_repository_instance
@ -155,7 +137,6 @@ class TestRepositoryFactory:
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
):
result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_session_factory,
@ -177,7 +158,7 @@ class TestRepositoryFactory:
def test_create_workflow_execution_repository_import_error(self, mock_config):
"""Test WorkflowExecutionRepository creation with import error."""
# Setup mock configuration with invalid class path
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"
mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=Account)
@ -195,45 +176,46 @@ class TestRepositoryFactory:
def test_create_workflow_execution_repository_validation_error(self, mock_config, mocker: MockerFixture):
"""Test WorkflowExecutionRepository creation with validation error."""
# Setup mock configuration
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=Account)
# Mock import to succeed but validation to fail
# Mock the import to succeed but validation to fail
mock_repository_class = MagicMock()
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(
DifyCoreRepositoryFactory,
"_validate_repository_interface",
side_effect=RepositoryImportError("Interface validation failed"),
),
):
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
assert "Interface validation failed" in str(exc_info.value)
mocker.patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class)
mocker.patch.object(
DifyCoreRepositoryFactory,
"_validate_repository_interface",
side_effect=RepositoryImportError("Interface validation failed"),
)
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
assert "Interface validation failed" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_execution_repository_instantiation_error(self, mock_config, mocker: MockerFixture):
def test_create_workflow_execution_repository_instantiation_error(self, mock_config):
"""Test WorkflowExecutionRepository creation with instantiation error."""
# Setup mock configuration
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=Account)
# Mock import and validation to succeed but instantiation to fail
mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed"))
# Create a mock repository class that raises exception on instantiation
mock_repository_class = MagicMock()
mock_repository_class.side_effect = Exception("Instantiation failed")
# Mock the validation methods to succeed
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
):
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_execution_repository(
@ -245,18 +227,18 @@ class TestRepositoryFactory:
assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_node_execution_repository_success(self, mock_config, mocker: MockerFixture):
"""Test successful creation of WorkflowNodeExecutionRepository."""
def test_create_workflow_node_execution_repository_success(self, mock_config):
"""Test successful WorkflowNodeExecutionRepository creation."""
# Setup mock configuration
mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
# Create mock dependencies
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=EndUser)
app_id = "test-app-id"
triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP
# Mock the imported class to be a valid repository
# Create mock repository class and instance
mock_repository_class = MagicMock()
mock_repository_instance = MagicMock(spec=WorkflowNodeExecutionRepository)
mock_repository_class.return_value = mock_repository_instance
@ -265,7 +247,6 @@ class TestRepositoryFactory:
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
):
result = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
@ -287,7 +268,7 @@ class TestRepositoryFactory:
def test_create_workflow_node_execution_repository_import_error(self, mock_config):
"""Test WorkflowNodeExecutionRepository creation with import error."""
# Setup mock configuration with invalid class path
mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"
mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=EndUser)
@ -297,28 +278,83 @@ class TestRepositoryFactory:
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
assert "Cannot import repository class" in str(exc_info.value)
def test_repository_import_error_exception(self):
"""Test RepositoryImportError exception."""
error_message = "Test error message"
exception = RepositoryImportError(error_message)
assert str(exception) == error_message
assert isinstance(exception, Exception)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_node_execution_repository_validation_error(self, mock_config, mocker: MockerFixture):
"""Test WorkflowNodeExecutionRepository creation with validation error."""
# Setup mock configuration
mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=EndUser)
# Mock the import to succeed but validation to fail
mock_repository_class = MagicMock()
mocker.patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class)
mocker.patch.object(
DifyCoreRepositoryFactory,
"_validate_repository_interface",
side_effect=RepositoryImportError("Interface validation failed"),
)
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
assert "Interface validation failed" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
def test_create_with_engine_instead_of_sessionmaker(self, mock_config, mocker: MockerFixture):
def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config):
"""Test WorkflowNodeExecutionRepository creation with instantiation error."""
# Setup mock configuration
mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=EndUser)
# Create a mock repository class that raises exception on instantiation
mock_repository_class = MagicMock()
mock_repository_class.side_effect = Exception("Instantiation failed")
# Mock the validation methods to succeed
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
):
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value)
def test_repository_import_error_exception(self):
"""Test RepositoryImportError exception handling."""
error_message = "Custom error message"
error = RepositoryImportError(error_message)
assert str(error) == error_message
@patch("core.repositories.factory.dify_config")
def test_create_with_engine_instead_of_sessionmaker(self, mock_config):
"""Test repository creation with Engine instead of sessionmaker."""
# Setup mock configuration
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
# Create mock dependencies with Engine instead of sessionmaker
# Create mock dependencies using Engine instead of sessionmaker
mock_engine = MagicMock(spec=Engine)
mock_user = MagicMock(spec=Account)
app_id = "test-app-id"
triggered_from = WorkflowRunTriggeredFrom.APP_RUN
# Mock the imported class to be a valid repository
# Create mock repository class and instance
mock_repository_class = MagicMock()
mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository)
mock_repository_class.return_value = mock_repository_instance
@ -327,129 +363,19 @@ class TestRepositoryFactory:
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
):
result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_engine, # Using Engine instead of sessionmaker
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
app_id=app_id,
triggered_from=triggered_from,
)
# Verify the repository was created with the Engine
# Verify the repository was created with correct parameters
mock_repository_class.assert_called_once_with(
session_factory=mock_engine,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
app_id=app_id,
triggered_from=triggered_from,
)
assert result is mock_repository_instance
@patch("core.repositories.factory.dify_config")
def test_create_workflow_node_execution_repository_validation_error(self, mock_config):
"""Test WorkflowNodeExecutionRepository creation with validation error."""
# Setup mock configuration
mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=EndUser)
# Mock import to succeed but validation to fail
mock_repository_class = MagicMock()
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(
DifyCoreRepositoryFactory,
"_validate_repository_interface",
side_effect=RepositoryImportError("Interface validation failed"),
),
):
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
assert "Interface validation failed" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config):
"""Test WorkflowNodeExecutionRepository creation with instantiation error."""
# Setup mock configuration
mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
mock_session_factory = MagicMock(spec=sessionmaker)
mock_user = MagicMock(spec=EndUser)
# Mock import and validation to succeed but instantiation to fail
mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed"))
with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
):
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value)
def test_validate_repository_interface_with_private_methods(self):
"""Test interface validation ignores private methods."""
# Create a mock class with private methods
class MockRepository:
def save(self):
pass
def get_by_id(self):
pass
def _private_method(self):
pass
# Create a mock interface with private methods
class MockInterface:
def save(self):
pass
def get_by_id(self):
pass
def _private_method(self):
pass
# Should not raise an exception (private methods are ignored)
DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
def test_validate_constructor_signature_with_extra_params(self):
"""Test constructor validation with extra parameters (should pass)."""
class MockRepository:
def __init__(self, session_factory, user, app_id, triggered_from, extra_param=None):
pass
# Should not raise an exception (extra parameters are allowed)
DifyCoreRepositoryFactory._validate_constructor_signature(
MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
)
def test_validate_constructor_signature_with_kwargs(self):
"""Test constructor validation with **kwargs (current implementation doesn't support this)."""
class MockRepository:
def __init__(self, session_factory, user, **kwargs):
pass
# Current implementation doesn't handle **kwargs, so this should raise an exception
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._validate_constructor_signature(
MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
)
assert "does not accept required parameters" in str(exc_info.value)
assert "app_id" in str(exc_info.value)
assert "triggered_from" in str(exc_info.value)

View File

@ -243,8 +243,6 @@ def test_executor_with_form_data():
# Check the executor's data
assert executor.method == "post"
assert executor.url == "https://api.example.com/upload"
assert "Content-Type" in executor.headers
assert "multipart/form-data" in executor.headers["Content-Type"]
assert executor.params is None
assert executor.json is None
# '__multipart_placeholder__' is expected when no file inputs exist,
@ -252,6 +250,11 @@ def test_executor_with_form_data():
assert executor.files == [("__multipart_placeholder__", ("", b"", "application/octet-stream"))]
assert executor.content is None
# After fix for #23829: When placeholder files exist, Content-Type is removed
# to let httpx handle Content-Type and boundary automatically
headers = executor._assembling_headers()
assert "Content-Type" not in headers or "multipart/form-data" not in headers.get("Content-Type", "")
# Check that the form data is correctly loaded in executor.data
assert isinstance(executor.data, dict)
assert "text_field" in executor.data

View File

View File

@ -0,0 +1,243 @@
from unittest.mock import ANY, MagicMock, call, patch
import pytest
import sqlalchemy as sa
from tasks.remove_app_and_related_data_task import _delete_draft_variables, delete_draft_variables_batch
class TestDeleteDraftVariablesBatch:
@patch("tasks.remove_app_and_related_data_task.db")
def test_delete_draft_variables_batch_success(self, mock_db):
"""Test successful deletion of draft variables in batches."""
app_id = "test-app-id"
batch_size = 100
# Mock database connection and engine
mock_conn = MagicMock()
mock_engine = MagicMock()
mock_db.engine = mock_engine
# Properly mock the context manager
mock_context_manager = MagicMock()
mock_context_manager.__enter__.return_value = mock_conn
mock_context_manager.__exit__.return_value = None
mock_engine.begin.return_value = mock_context_manager
# Mock two batches of results, then empty
batch1_ids = [f"var-{i}" for i in range(100)]
batch2_ids = [f"var-{i}" for i in range(100, 150)]
# Set up side effects for execute calls in the correct order:
# 1. SELECT (returns batch1_ids)
# 2. DELETE (returns result with rowcount=100)
# 3. SELECT (returns batch2_ids)
# 4. DELETE (returns result with rowcount=50)
# 5. SELECT (returns empty, ends loop)
# Create mock results with actual integer rowcount attributes
class MockResult:
def __init__(self, rowcount):
self.rowcount = rowcount
# First SELECT result
select_result1 = MagicMock()
select_result1.__iter__.return_value = iter([(id_,) for id_ in batch1_ids])
# First DELETE result
delete_result1 = MockResult(rowcount=100)
# Second SELECT result
select_result2 = MagicMock()
select_result2.__iter__.return_value = iter([(id_,) for id_ in batch2_ids])
# Second DELETE result
delete_result2 = MockResult(rowcount=50)
# Third SELECT result (empty, ends loop)
select_result3 = MagicMock()
select_result3.__iter__.return_value = iter([])
# Configure side effects in the correct order
mock_conn.execute.side_effect = [
select_result1, # First SELECT
delete_result1, # First DELETE
select_result2, # Second SELECT
delete_result2, # Second DELETE
select_result3, # Third SELECT (empty)
]
# Execute the function
result = delete_draft_variables_batch(app_id, batch_size)
# Verify the result
assert result == 150
# Verify database calls
assert mock_conn.execute.call_count == 5 # 3 selects + 2 deletes
# Verify the expected calls in order:
# 1. SELECT, 2. DELETE, 3. SELECT, 4. DELETE, 5. SELECT
expected_calls = [
# First SELECT
call(
sa.text("""
SELECT id FROM workflow_draft_variables
WHERE app_id = :app_id
LIMIT :batch_size
"""),
{"app_id": app_id, "batch_size": batch_size},
),
# First DELETE
call(
sa.text("""
DELETE FROM workflow_draft_variables
WHERE id IN :ids
"""),
{"ids": tuple(batch1_ids)},
),
# Second SELECT
call(
sa.text("""
SELECT id FROM workflow_draft_variables
WHERE app_id = :app_id
LIMIT :batch_size
"""),
{"app_id": app_id, "batch_size": batch_size},
),
# Second DELETE
call(
sa.text("""
DELETE FROM workflow_draft_variables
WHERE id IN :ids
"""),
{"ids": tuple(batch2_ids)},
),
# Third SELECT (empty result)
call(
sa.text("""
SELECT id FROM workflow_draft_variables
WHERE app_id = :app_id
LIMIT :batch_size
"""),
{"app_id": app_id, "batch_size": batch_size},
),
]
# Check that all calls were made correctly
actual_calls = mock_conn.execute.call_args_list
assert len(actual_calls) == len(expected_calls)
# Simplified verification: check that the right number of calls were made
# and that the SQL queries contain the expected patterns
for i, actual_call in enumerate(actual_calls):
if i % 2 == 0: # SELECT calls (even indices: 0, 2, 4)
# Verify it's a SELECT query
sql_text = str(actual_call[0][0])
assert "SELECT id FROM workflow_draft_variables" in sql_text
assert "WHERE app_id = :app_id" in sql_text
assert "LIMIT :batch_size" in sql_text
else: # DELETE calls (odd indices: 1, 3)
# Verify it's a DELETE query
sql_text = str(actual_call[0][0])
assert "DELETE FROM workflow_draft_variables" in sql_text
assert "WHERE id IN :ids" in sql_text
@patch("tasks.remove_app_and_related_data_task.db")
def test_delete_draft_variables_batch_empty_result(self, mock_db):
"""Test deletion when no draft variables exist for the app."""
app_id = "nonexistent-app-id"
batch_size = 1000
# Mock database connection
mock_conn = MagicMock()
mock_engine = MagicMock()
mock_db.engine = mock_engine
# Properly mock the context manager
mock_context_manager = MagicMock()
mock_context_manager.__enter__.return_value = mock_conn
mock_context_manager.__exit__.return_value = None
mock_engine.begin.return_value = mock_context_manager
# Mock empty result
empty_result = MagicMock()
empty_result.__iter__.return_value = iter([])
mock_conn.execute.return_value = empty_result
result = delete_draft_variables_batch(app_id, batch_size)
assert result == 0
assert mock_conn.execute.call_count == 1 # Only one select query
def test_delete_draft_variables_batch_invalid_batch_size(self):
"""Test that invalid batch size raises ValueError."""
app_id = "test-app-id"
with pytest.raises(ValueError, match="batch_size must be positive"):
delete_draft_variables_batch(app_id, -1)
with pytest.raises(ValueError, match="batch_size must be positive"):
delete_draft_variables_batch(app_id, 0)
@patch("tasks.remove_app_and_related_data_task.db")
@patch("tasks.remove_app_and_related_data_task.logging")
def test_delete_draft_variables_batch_logs_progress(self, mock_logging, mock_db):
"""Test that batch deletion logs progress correctly."""
app_id = "test-app-id"
batch_size = 50
# Mock database
mock_conn = MagicMock()
mock_engine = MagicMock()
mock_db.engine = mock_engine
# Properly mock the context manager
mock_context_manager = MagicMock()
mock_context_manager.__enter__.return_value = mock_conn
mock_context_manager.__exit__.return_value = None
mock_engine.begin.return_value = mock_context_manager
# Mock one batch then empty
batch_ids = [f"var-{i}" for i in range(30)]
# Create properly configured mocks
select_result = MagicMock()
select_result.__iter__.return_value = iter([(id_,) for id_ in batch_ids])
# Create simple object with rowcount attribute
class MockResult:
def __init__(self, rowcount):
self.rowcount = rowcount
delete_result = MockResult(rowcount=30)
empty_result = MagicMock()
empty_result.__iter__.return_value = iter([])
mock_conn.execute.side_effect = [
# Select query result
select_result,
# Delete query result
delete_result,
# Empty select result (end condition)
empty_result,
]
result = delete_draft_variables_batch(app_id, batch_size)
assert result == 30
# Verify logging calls
assert mock_logging.info.call_count == 2
mock_logging.info.assert_any_call(
ANY  # the message is a click.style()-formatted string, so match anything
)
@patch("tasks.remove_app_and_related_data_task.delete_draft_variables_batch")
def test_delete_draft_variables_calls_batch_function(self, mock_batch_delete):
"""Test that _delete_draft_variables calls the batch function correctly."""
app_id = "test-app-id"
expected_return = 42
mock_batch_delete.return_value = expected_return
result = _delete_draft_variables(app_id)
assert result == expected_return
mock_batch_delete.assert_called_once_with(app_id, batch_size=1000)
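
For orientation, here is a minimal sketch of the function these tests pin down, reconstructed from the mocked engine calls: select up to batch_size ids, delete them, and repeat until the select comes back empty, logging once per batch plus a final summary (two info calls for a single batch, as asserted above). The import path for db is an assumption, and the real task styles its log lines with click.style.

```python
import logging

import sqlalchemy as sa

from extensions.ext_database import db  # assumed import path for the db handle


def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
    """Delete draft variables for an app in fixed-size batches."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    total_deleted = 0
    with db.engine.begin() as conn:
        while True:
            rows = conn.execute(
                sa.text(
                    """
                    SELECT id FROM workflow_draft_variables
                    WHERE app_id = :app_id
                    LIMIT :batch_size
                    """
                ),
                {"app_id": app_id, "batch_size": batch_size},
            )
            ids = [row[0] for row in rows]
            if not ids:
                break
            result = conn.execute(
                sa.text(
                    """
                    DELETE FROM workflow_draft_variables
                    WHERE id IN :ids
                    """
                ),
                {"ids": tuple(ids)},
            )
            total_deleted += result.rowcount
            logging.info("Deleted %d draft variables for app %s", result.rowcount, app_id)
    logging.info("Finished: %d draft variables deleted for app %s", total_deleted, app_id)
    return total_deleted
```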

View File

@ -1,5 +1,5 @@
version = 1
revision = 3
revision = 2
requires-python = ">=3.11, <3.13"
resolution-markers = [
"python_full_version >= '3.12.4' and platform_python_implementation != 'PyPy' and sys_platform == 'linux'",
@ -1371,6 +1371,7 @@ dev = [
{ name = "types-python-http-client" },
{ name = "types-pywin32" },
{ name = "types-pyyaml" },
{ name = "types-redis" },
{ name = "types-regex" },
{ name = "types-requests" },
{ name = "types-requests-oauthlib" },
@ -1557,6 +1558,7 @@ dev = [
{ name = "types-python-http-client", specifier = ">=3.3.7.20240910" },
{ name = "types-pywin32", specifier = "~=310.0.0" },
{ name = "types-pyyaml", specifier = "~=6.0.12" },
{ name = "types-redis", specifier = ">=4.6.0.20241004" },
{ name = "types-regex", specifier = "~=2024.11.6" },
{ name = "types-requests", specifier = "~=2.32.0" },
{ name = "types-requests-oauthlib", specifier = "~=2.0.0" },
@ -6064,6 +6066,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/99/5f/e0af6f7f6a260d9af67e1db4f54d732abad514252a7a378a6c4d17dd1036/types_pyyaml-6.0.12.20250516-py3-none-any.whl", hash = "sha256:8478208feaeb53a34cb5d970c56a7cd76b72659442e733e268a94dc72b2d0530", size = 20312, upload-time = "2025-05-16T03:08:04.019Z" },
]
[[package]]
name = "types-redis"
version = "4.6.0.20241004"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "cryptography" },
{ name = "types-pyopenssl" },
]
sdist = { url = "https://files.pythonhosted.org/packages/3a/95/c054d3ac940e8bac4ca216470c80c26688a0e79e09f520a942bb27da3386/types-redis-4.6.0.20241004.tar.gz", hash = "sha256:5f17d2b3f9091ab75384153bfa276619ffa1cf6a38da60e10d5e6749cc5b902e", size = 49679, upload-time = "2024-10-04T02:43:59.224Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/55/82/7d25dce10aad92d2226b269bce2f85cfd843b4477cd50245d7d40ecf8f89/types_redis-4.6.0.20241004-py3-none-any.whl", hash = "sha256:ef5da68cb827e5f606c8f9c0b49eeee4c2669d6d97122f301d3a55dc6a63f6ed", size = 58737, upload-time = "2024-10-04T02:43:57.968Z" },
]
[[package]]
name = "types-regex"
version = "2024.11.6.20250403"

View File

@ -8,4 +8,4 @@ cd "$SCRIPT_DIR/.."
uv --directory api run \
celery -A app.celery worker \
-P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion
-P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage

View File

@ -861,17 +861,23 @@ WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
# Repository configuration
# Core workflow execution repository implementation
# Options:
# - core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository (default)
# - core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository
CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository
# Core workflow node execution repository implementation
# Options:
# - core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository (default)
# - core.repositories.celery_workflow_node_execution_repository.CeleryWorkflowNodeExecutionRepository
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository
# API workflow node execution repository implementation
API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository
# API workflow run repository implementation
API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
# API workflow node execution repository implementation
API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository
# HTTP request node in workflow configuration
HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
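
To opt into the Celery-backed implementations listed above, point both core variables at the celery repository modules (the worker alias in this change also adds a workflow_storage queue, which these implementations presumably rely on):

```bash
CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.celery_workflow_node_execution_repository.CeleryWorkflowNodeExecutionRepository
```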

View File

@ -96,6 +96,7 @@ services:
MAX_TOOLS_NUM: ${MAX_TOOLS_NUM:-10}
MAX_PARALLEL_LIMIT: ${MAX_PARALLEL_LIMIT:-10}
MAX_ITERATIONS_NUM: ${MAX_ITERATIONS_NUM:-99}
MAX_TREE_DEPTH: ${MAX_TREE_DEPTH:-50}
ENABLE_WEBSITE_JINAREADER: ${ENABLE_WEBSITE_JINAREADER:-true}
ENABLE_WEBSITE_FIRECRAWL: ${ENABLE_WEBSITE_FIRECRAWL:-true}
ENABLE_WEBSITE_WATERCRAWL: ${ENABLE_WEBSITE_WATERCRAWL:-true}

View File

@ -390,8 +390,8 @@ x-shared-env: &shared-api-worker-env
WORKFLOW_NODE_EXECUTION_STORAGE: ${WORKFLOW_NODE_EXECUTION_STORAGE:-rdbms}
CORE_WORKFLOW_EXECUTION_REPOSITORY: ${CORE_WORKFLOW_EXECUTION_REPOSITORY:-core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository}
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY:-core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository}
API_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${API_WORKFLOW_NODE_EXECUTION_REPOSITORY:-repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository}
API_WORKFLOW_RUN_REPOSITORY: ${API_WORKFLOW_RUN_REPOSITORY:-repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository}
API_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${API_WORKFLOW_NODE_EXECUTION_REPOSITORY:-repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository}
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760}
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576}
HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True}
@ -662,6 +662,7 @@ services:
MAX_TOOLS_NUM: ${MAX_TOOLS_NUM:-10}
MAX_PARALLEL_LIMIT: ${MAX_PARALLEL_LIMIT:-10}
MAX_ITERATIONS_NUM: ${MAX_ITERATIONS_NUM:-99}
MAX_TREE_DEPTH: ${MAX_TREE_DEPTH:-50}
ENABLE_WEBSITE_JINAREADER: ${ENABLE_WEBSITE_JINAREADER:-true}
ENABLE_WEBSITE_FIRECRAWL: ${ENABLE_WEBSITE_FIRECRAWL:-true}
ENABLE_WEBSITE_WATERCRAWL: ${ENABLE_WEBSITE_WATERCRAWL:-true}

View File

@ -37,7 +37,7 @@ describe('CommandSelector', () => {
},
knowledge: {
key: '@knowledge',
shortcut: '@knowledge',
shortcut: '@kb',
title: 'Search Knowledge',
description: 'Search knowledge bases',
search: jest.fn(),
@ -75,7 +75,7 @@ describe('CommandSelector', () => {
)
expect(screen.getByTestId('command-item-@app')).toBeInTheDocument()
expect(screen.getByTestId('command-item-@knowledge')).toBeInTheDocument()
expect(screen.getByTestId('command-item-@kb')).toBeInTheDocument()
expect(screen.getByTestId('command-item-@plugin')).toBeInTheDocument()
expect(screen.getByTestId('command-item-@node')).toBeInTheDocument()
})
@ -90,7 +90,7 @@ describe('CommandSelector', () => {
)
expect(screen.getByTestId('command-item-@app')).toBeInTheDocument()
expect(screen.getByTestId('command-item-@knowledge')).toBeInTheDocument()
expect(screen.getByTestId('command-item-@kb')).toBeInTheDocument()
expect(screen.getByTestId('command-item-@plugin')).toBeInTheDocument()
expect(screen.getByTestId('command-item-@node')).toBeInTheDocument()
})
@ -107,7 +107,7 @@ describe('CommandSelector', () => {
)
expect(screen.queryByTestId('command-item-@app')).not.toBeInTheDocument()
expect(screen.getByTestId('command-item-@knowledge')).toBeInTheDocument()
expect(screen.getByTestId('command-item-@kb')).toBeInTheDocument()
expect(screen.queryByTestId('command-item-@plugin')).not.toBeInTheDocument()
expect(screen.queryByTestId('command-item-@node')).not.toBeInTheDocument()
})
@ -122,7 +122,7 @@ describe('CommandSelector', () => {
)
expect(screen.getByTestId('command-item-@app')).toBeInTheDocument()
expect(screen.queryByTestId('command-item-@knowledge')).not.toBeInTheDocument()
expect(screen.queryByTestId('command-item-@kb')).not.toBeInTheDocument()
expect(screen.getByTestId('command-item-@plugin')).toBeInTheDocument()
expect(screen.queryByTestId('command-item-@node')).not.toBeInTheDocument()
})
@ -137,7 +137,7 @@ describe('CommandSelector', () => {
)
expect(screen.getByTestId('command-item-@app')).toBeInTheDocument()
expect(screen.queryByTestId('command-item-@knowledge')).not.toBeInTheDocument()
expect(screen.queryByTestId('command-item-@kb')).not.toBeInTheDocument()
})
it('should match partial strings', () => {
@ -145,14 +145,14 @@ describe('CommandSelector', () => {
<CommandSelector
actions={mockActions}
onCommandSelect={mockOnCommandSelect}
searchFilter="nowl"
searchFilter="od"
/>,
)
expect(screen.queryByTestId('command-item-@app')).not.toBeInTheDocument()
expect(screen.getByTestId('command-item-@knowledge')).toBeInTheDocument()
expect(screen.queryByTestId('command-item-@kb')).not.toBeInTheDocument()
expect(screen.queryByTestId('command-item-@plugin')).not.toBeInTheDocument()
expect(screen.queryByTestId('command-item-@node')).not.toBeInTheDocument()
expect(screen.getByTestId('command-item-@node')).toBeInTheDocument()
})
})
@ -167,7 +167,7 @@ describe('CommandSelector', () => {
)
expect(screen.queryByTestId('command-item-@app')).not.toBeInTheDocument()
expect(screen.queryByTestId('command-item-@knowledge')).not.toBeInTheDocument()
expect(screen.queryByTestId('command-item-@kb')).not.toBeInTheDocument()
expect(screen.queryByTestId('command-item-@plugin')).not.toBeInTheDocument()
expect(screen.queryByTestId('command-item-@node')).not.toBeInTheDocument()
@ -210,7 +210,7 @@ describe('CommandSelector', () => {
/>,
)
expect(mockOnCommandValueChange).toHaveBeenCalledWith('@knowledge')
expect(mockOnCommandValueChange).toHaveBeenCalledWith('@kb')
})
it('should not call onCommandValueChange if current value still exists', () => {
@ -246,10 +246,10 @@ describe('CommandSelector', () => {
/>,
)
const knowledgeItem = screen.getByTestId('command-item-@knowledge')
const knowledgeItem = screen.getByTestId('command-item-@kb')
fireEvent.click(knowledgeItem)
expect(mockOnCommandSelect).toHaveBeenCalledWith('@knowledge')
expect(mockOnCommandSelect).toHaveBeenCalledWith('@kb')
})
})
@ -276,7 +276,7 @@ describe('CommandSelector', () => {
)
expect(screen.getByTestId('command-item-@app')).toBeInTheDocument()
expect(screen.getByTestId('command-item-@knowledge')).toBeInTheDocument()
expect(screen.getByTestId('command-item-@kb')).toBeInTheDocument()
expect(screen.getByTestId('command-item-@plugin')).toBeInTheDocument()
expect(screen.getByTestId('command-item-@node')).toBeInTheDocument()
})
@ -312,7 +312,7 @@ describe('CommandSelector', () => {
)
expect(screen.getByTestId('command-item-@app')).toBeInTheDocument()
expect(screen.getByTestId('command-item-@knowledge')).toBeInTheDocument()
expect(screen.getByTestId('command-item-@kb')).toBeInTheDocument()
expect(screen.getByTestId('command-item-@plugin')).toBeInTheDocument()
expect(screen.getByTestId('command-item-@node')).toBeInTheDocument()
})
@ -326,7 +326,7 @@ describe('CommandSelector', () => {
/>,
)
expect(screen.getByTestId('command-item-@knowledge')).toBeInTheDocument()
expect(screen.getByTestId('command-item-@kb')).toBeInTheDocument()
expect(screen.queryByTestId('command-item-@app')).not.toBeInTheDocument()
})
})

View File

@ -3,7 +3,7 @@ import type { FC } from 'react'
import React from 'react'
import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector'
import AppCard from '@/app/components/app/overview/appCard'
import AppCard from '@/app/components/app/overview/app-card'
import Loading from '@/app/components/base/loading'
import MCPServiceCard from '@/app/components/tools/mcp/mcp-service-card'
import { ToastContext } from '@/app/components/base/toast'
@ -17,7 +17,7 @@ import type { App } from '@/types/app'
import type { UpdateAppSiteCodeResponse } from '@/models/app'
import { asyncRunSafe } from '@/utils'
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
import type { IAppCardProps } from '@/app/components/app/overview/appCard'
import type { IAppCardProps } from '@/app/components/app/overview/app-card'
import { useStore as useAppStore } from '@/app/components/app/store'
export type ICardViewProps = {

View File

@ -3,8 +3,8 @@ import React, { useState } from 'react'
import dayjs from 'dayjs'
import quarterOfYear from 'dayjs/plugin/quarterOfYear'
import { useTranslation } from 'react-i18next'
import type { PeriodParams } from '@/app/components/app/overview/appChart'
import { AvgResponseTime, AvgSessionInteractions, AvgUserInteractions, ConversationsChart, CostChart, EndUsersChart, MessagesChart, TokenPerSecond, UserSatisfactionRate, WorkflowCostChart, WorkflowDailyTerminalsChart, WorkflowMessagesChart } from '@/app/components/app/overview/appChart'
import type { PeriodParams } from '@/app/components/app/overview/app-chart'
import { AvgResponseTime, AvgSessionInteractions, AvgUserInteractions, ConversationsChart, CostChart, EndUsersChart, MessagesChart, TokenPerSecond, UserSatisfactionRate, WorkflowCostChart, WorkflowDailyTerminalsChart, WorkflowMessagesChart } from '@/app/components/app/overview/app-chart'
import type { Item } from '@/app/components/base/select'
import { SimpleSelect } from '@/app/components/base/select'
import { TIME_PERIOD_MAPPING } from '@/app/components/app/log/filter'
@ -54,6 +54,7 @@ export default function ChartView({ appId, headerRight }: IChartViewProps) {
<SimpleSelect
items={Object.entries(TIME_PERIOD_MAPPING).map(([k, v]) => ({ value: k, name: t(`appLog.filter.period.${v.name}`) }))}
className='mt-0 !w-40'
notClearable={true}
onSelect={(item) => {
const id = item.value
const value = TIME_PERIOD_MAPPING[id]?.value ?? '-1'

View File

@ -1,5 +1,5 @@
import React from 'react'
import ChartView from './chartView'
import ChartView from './chart-view'
import TracingPanel from './tracing/panel'
import ApikeyInfoPanel from '@/app/components/app/overview/apikey-info-panel'

View File

@ -13,14 +13,14 @@ const Header = () => {
const router = useRouter()
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
const back = useCallback(() => {
router.back()
const goToStudio = useCallback(() => {
router.push('/apps')
}, [router])
return (
<div className='flex flex-1 items-center justify-between px-4'>
<div className='flex items-center gap-3'>
<div className='flex cursor-pointer items-center' onClick={back}>
<div className='flex cursor-pointer items-center' onClick={goToStudio}>
{systemFeatures.branding.enabled && systemFeatures.branding.login_page_logo
? <img
src={systemFeatures.branding.login_page_logo}
@ -33,7 +33,7 @@ const Header = () => {
<p className='title-3xl-semi-bold relative mt-[-2px] text-text-primary'>{t('common.account.account')}</p>
</div>
<div className='flex shrink-0 items-center gap-3'>
<Button className='system-sm-medium gap-2 px-3 py-2' onClick={back}>
<Button className='system-sm-medium gap-2 px-3 py-2' onClick={goToStudio}>
<RiRobot2Line className='h-4 w-4' />
<p>{t('common.account.studio')}</p>
<RiArrowRightUpLine className='h-4 w-4' />

View File

@ -25,7 +25,7 @@ import type { EnvironmentVariable } from '@/app/components/workflow/types'
import { fetchWorkflowDraft } from '@/service/workflow'
import ContentDialog from '@/app/components/base/content-dialog'
import Button from '@/app/components/base/button'
import CardView from '@/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/cardView'
import CardView from '@/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view'
import Divider from '../base/divider'
import type { Operation } from './app-operations'
import AppOperations from './app-operations'

View File

@ -35,7 +35,7 @@ import type { AppDetailResponse } from '@/models/app'
import { useAppContext } from '@/context/app-context'
import type { AppSSO } from '@/types/app'
import Indicator from '@/app/components/header/indicator'
import { fetchAppDetail } from '@/service/apps'
import { fetchAppDetailDirect } from '@/service/apps'
import { AccessMode } from '@/models/access-control'
import AccessControl from '../app-access-control'
import { useAppWhiteListSubjects } from '@/service/access-control'
@ -161,11 +161,15 @@ function AppCard({
return
setShowAccessControl(true)
}, [appDetail])
const handleAccessControlUpdate = useCallback(() => {
fetchAppDetail({ url: '/apps', id: appDetail!.id }).then((res) => {
const handleAccessControlUpdate = useCallback(async () => {
try {
const res = await fetchAppDetailDirect({ url: '/apps', id: appDetail!.id })
setAppDetail(res)
setShowAccessControl(false)
})
}
catch (error) {
console.error('Failed to fetch app detail:', error)
}
}, [appDetail, setAppDetail])
return (

View File

@ -123,7 +123,7 @@ const Chart: React.FC<IChartProps> = ({
dimensions: ['date', yField],
source: statistics,
},
grid: { top: 8, right: 36, bottom: 0, left: 0, containLabel: true },
grid: { top: 8, right: 36, bottom: 10, left: 25, containLabel: true },
tooltip: {
trigger: 'item',
position: 'top',
@ -165,7 +165,7 @@ const Chart: React.FC<IChartProps> = ({
lineStyle: {
color: COMMON_COLOR_MAP.splitLineDark,
},
interval(index, value) {
interval(_index, value) {
return !!value
},
},
@ -242,7 +242,7 @@ const Chart: React.FC<IChartProps> = ({
? ''
: <span>{t('appOverview.analysis.tokenUsage.consumed')} Tokens<span className='text-sm'>
<span className='ml-1 text-text-tertiary'>(</span>
<span className='text-orange-400'>~{sum(statistics.map(item => Number.parseFloat(get(item, 'total_price', '0')))).toLocaleString('en-US', { style: 'currency', currency: 'USD', minimumFractionDigits: 4 })}</span>
<span className='text-orange-400'>~{sum(statistics.map(item => Number.parseFloat(String(get(item, 'total_price', '0'))))).toLocaleString('en-US', { style: 'currency', currency: 'USD', minimumFractionDigits: 4 })}</span>
<span className='text-text-tertiary'>)</span>
</span></span>}
textStyle={{ main: `!text-3xl !font-normal ${sumData === 0 ? '!text-text-quaternary' : ''}` }} />
@ -268,7 +268,7 @@ export const MessagesChart: FC<IBizChartProps> = ({ id, period }) => {
const noDataFlag = !response.data || response.data.length === 0
return <Chart
basicInfo={{ title: t('appOverview.analysis.totalMessages.title'), explanation: t('appOverview.analysis.totalMessages.explanation'), timePeriod: period.name }}
chartData={!noDataFlag ? response : { data: getDefaultChartData(period.query ?? defaultPeriod) }}
chartData={!noDataFlag ? response : { data: getDefaultChartData(period.query ?? defaultPeriod) } as any}
chartType='messages'
{...(noDataFlag && { yMax: 500 })}
/>
@ -282,7 +282,7 @@ export const ConversationsChart: FC<IBizChartProps> = ({ id, period }) => {
const noDataFlag = !response.data || response.data.length === 0
return <Chart
basicInfo={{ title: t('appOverview.analysis.totalConversations.title'), explanation: t('appOverview.analysis.totalConversations.explanation'), timePeriod: period.name }}
chartData={!noDataFlag ? response : { data: getDefaultChartData(period.query ?? defaultPeriod) }}
chartData={!noDataFlag ? response : { data: getDefaultChartData(period.query ?? defaultPeriod) } as any}
chartType='conversations'
{...(noDataFlag && { yMax: 500 })}
/>
@ -297,7 +297,7 @@ export const EndUsersChart: FC<IBizChartProps> = ({ id, period }) => {
const noDataFlag = !response.data || response.data.length === 0
return <Chart
basicInfo={{ title: t('appOverview.analysis.activeUsers.title'), explanation: t('appOverview.analysis.activeUsers.explanation'), timePeriod: period.name }}
chartData={!noDataFlag ? response : { data: getDefaultChartData(period.query ?? defaultPeriod) }}
chartData={!noDataFlag ? response : { data: getDefaultChartData(period.query ?? defaultPeriod) } as any}
chartType='endUsers'
{...(noDataFlag && { yMax: 500 })}
/>
@ -380,7 +380,7 @@ export const CostChart: FC<IBizChartProps> = ({ id, period }) => {
const noDataFlag = !response.data || response.data.length === 0
return <Chart
basicInfo={{ title: t('appOverview.analysis.tokenUsage.title'), explanation: t('appOverview.analysis.tokenUsage.explanation'), timePeriod: period.name }}
chartData={!noDataFlag ? response : { data: getDefaultChartData(period.query ?? defaultPeriod) }}
chartData={!noDataFlag ? response : { data: getDefaultChartData(period.query ?? defaultPeriod) } as any}
chartType='costs'
{...(noDataFlag && { yMax: 100 })}
/>
@ -394,7 +394,7 @@ export const WorkflowMessagesChart: FC<IBizChartProps> = ({ id, period }) => {
const noDataFlag = !response.data || response.data.length === 0
return <Chart
basicInfo={{ title: t('appOverview.analysis.totalMessages.title'), explanation: t('appOverview.analysis.totalMessages.explanation'), timePeriod: period.name }}
chartData={!noDataFlag ? response : { data: getDefaultChartData({ ...(period.query ?? defaultPeriod), key: 'runs' }) }}
chartData={!noDataFlag ? response : { data: getDefaultChartData({ ...(period.query ?? defaultPeriod), key: 'runs' }) } as any}
chartType='conversations'
valueKey='runs'
{...(noDataFlag && { yMax: 500 })}
@ -410,7 +410,7 @@ export const WorkflowDailyTerminalsChart: FC<IBizChartProps> = ({ id, period })
const noDataFlag = !response.data || response.data.length === 0
return <Chart
basicInfo={{ title: t('appOverview.analysis.activeUsers.title'), explanation: t('appOverview.analysis.activeUsers.explanation'), timePeriod: period.name }}
chartData={!noDataFlag ? response : { data: getDefaultChartData(period.query ?? defaultPeriod) }}
chartData={!noDataFlag ? response : { data: getDefaultChartData(period.query ?? defaultPeriod) } as any}
chartType='endUsers'
{...(noDataFlag && { yMax: 500 })}
/>
@ -425,7 +425,7 @@ export const WorkflowCostChart: FC<IBizChartProps> = ({ id, period }) => {
const noDataFlag = !response.data || response.data.length === 0
return <Chart
basicInfo={{ title: t('appOverview.analysis.tokenUsage.title'), explanation: t('appOverview.analysis.tokenUsage.explanation'), timePeriod: period.name }}
chartData={!noDataFlag ? response : { data: getDefaultChartData(period.query ?? defaultPeriod) }}
chartData={!noDataFlag ? response : { data: getDefaultChartData(period.query ?? defaultPeriod) } as any}
chartType='workflowCosts'
{...(noDataFlag && { yMax: 100 })}
/>

View File

@ -18,3 +18,13 @@
.pluginInstallIcon {
background-image: url(../assets/chromeplugin-install.svg);
}
:global(html[data-theme="dark"]) .iframeIcon,
:global(html[data-theme="dark"]) .scriptsIcon,
:global(html[data-theme="dark"]) .chromePluginIcon {
filter: invert(0.86) hue-rotate(180deg) saturate(0.5) brightness(0.95);
}
:global(html[data-theme="dark"]) .pluginInstallIcon {
filter: invert(0.9);
}

View File

@ -117,7 +117,7 @@ const Flowchart = React.forwardRef((props: {
const [isInitialized, setIsInitialized] = useState(false)
const [currentTheme, setCurrentTheme] = useState<'light' | 'dark'>(props.theme || 'light')
const containerRef = useRef<HTMLDivElement>(null)
const chartId = useRef(`mermaid-chart-${Math.random().toString(36).substr(2, 9)}`).current
const chartId = useRef(`mermaid-chart-${Math.random().toString(36).slice(2, 11)}`).current
const [isLoading, setIsLoading] = useState(true)
const renderTimeoutRef = useRef<NodeJS.Timeout>()
const [errMsg, setErrMsg] = useState('')

View File

@ -52,7 +52,7 @@ export default function Modal({
<div className="flex min-h-full items-center justify-center p-4 text-center">
<TransitionChild>
<DialogPanel className={classNames(
'w-full max-w-[480px] rounded-2xl bg-components-panel-bg p-6 text-left align-middle shadow-xl transition-all',
'relative w-full max-w-[480px] rounded-2xl bg-components-panel-bg p-6 text-left align-middle shadow-xl transition-all',
overflowVisible ? 'overflow-visible' : 'overflow-hidden',
'duration-100 ease-in data-[closed]:scale-95 data-[closed]:opacity-0',
'data-[enter]:scale-100 data-[enter]:opacity-100',

View File

@ -259,7 +259,7 @@ function getFullMatchOffset(
): number {
let triggerOffset = offset
for (let i = triggerOffset; i <= entryText.length; i++) {
if (documentText.substr(-i) === entryText.substr(0, i))
if (documentText.slice(-i) === entryText.slice(0, i))
triggerOffset = i
}
return triggerOffset

View File

@ -214,78 +214,83 @@ const SimpleSelect: FC<ISelectProps> = ({
}
}}
>
<div className={classNames('group/simple-select relative h-9', wrapperClassName)}>
{renderTrigger && <ListboxButton className='w-full'>{renderTrigger(selectedItem)}</ListboxButton>}
{!renderTrigger && (
<ListboxButton onClick={() => {
// get data-open, use setTimeout to ensure the attribute is set
setTimeout(() => {
if (listboxRef.current)
onOpenChange?.(listboxRef.current.getAttribute('data-open') !== null)
})
}} className={classNames(`flex h-full w-full items-center rounded-lg border-0 bg-components-input-bg-normal pl-3 pr-10 focus-visible:bg-state-base-hover-alt focus-visible:outline-none group-hover/simple-select:bg-state-base-hover-alt sm:text-sm sm:leading-6 ${disabled ? 'cursor-not-allowed' : 'cursor-pointer'}`, className)}>
<span className={classNames('system-sm-regular block truncate text-left text-components-input-text-filled', !selectedItem?.name && 'text-components-input-text-placeholder')}>{selectedItem?.name ?? localPlaceholder}</span>
<span className="absolute inset-y-0 right-0 flex items-center pr-2">
{isLoading ? <RiLoader4Line className='h-3.5 w-3.5 animate-spin text-text-secondary' />
: (selectedItem && !notClearable)
? (
<XMarkIcon
onClick={(e) => {
e.stopPropagation()
setSelectedItem(null)
onSelect({ name: '', value: '' })
}}
className="h-4 w-4 cursor-pointer text-text-quaternary"
aria-hidden="false"
/>
)
: (
<ChevronDownIcon
className="h-4 w-4 text-text-quaternary group-hover/simple-select:text-text-secondary"
aria-hidden="true"
/>
)}
</span>
</ListboxButton>
)}
{(!disabled) && (
<ListboxOptions className={classNames('absolute z-10 mt-1 max-h-60 w-full overflow-auto rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur px-1 py-1 text-base shadow-lg backdrop-blur-sm focus:outline-none sm:text-sm', optionWrapClassName)}>
{items.map((item: Item) => (
<ListboxOption
key={item.value}
className={
classNames(
'relative cursor-pointer select-none rounded-lg py-2 pl-3 pr-9 text-text-secondary hover:bg-state-base-hover',
optionClassName,
{({ open }) => (
<div className={classNames('group/simple-select relative h-9', wrapperClassName)}>
{renderTrigger && <ListboxButton className='w-full'>{renderTrigger(selectedItem)}</ListboxButton>}
{!renderTrigger && (
<ListboxButton onClick={() => {
onOpenChange?.(open)
}} className={classNames(`flex h-full w-full items-center rounded-lg border-0 bg-components-input-bg-normal pl-3 pr-10 focus-visible:bg-state-base-hover-alt focus-visible:outline-none group-hover/simple-select:bg-state-base-hover-alt sm:text-sm sm:leading-6 ${disabled ? 'cursor-not-allowed' : 'cursor-pointer'}`, className)}>
<span className={classNames('system-sm-regular block truncate text-left text-components-input-text-filled', !selectedItem?.name && 'text-components-input-text-placeholder')}>{selectedItem?.name ?? localPlaceholder}</span>
<span className="absolute inset-y-0 right-0 flex items-center pr-2">
{isLoading ? <RiLoader4Line className='h-3.5 w-3.5 animate-spin text-text-secondary' />
: (selectedItem && !notClearable)
? (
<XMarkIcon
onClick={(e) => {
e.stopPropagation()
setSelectedItem(null)
onSelect({ name: '', value: '' })
}}
className="h-4 w-4 cursor-pointer text-text-quaternary"
aria-hidden="false"
/>
)
}
value={item}
disabled={disabled}
>
{({ /* active, */ selected }) => (
<>
{renderOption
? renderOption({ item, selected })
: (<>
<span className={classNames('block', selected && 'font-normal')}>{item.name}</span>
{selected && !hideChecked && (
<span
className={classNames(
'absolute inset-y-0 right-0 flex items-center pr-4 text-text-accent',
)}
>
<RiCheckLine className="h-4 w-4" aria-hidden="true" />
</span>
)}
</>)}
</>
)}
</ListboxOption>
))}
</ListboxOptions>
)}
</div>
: (
open ? (
<ChevronUpIcon
className="h-4 w-4 text-text-quaternary group-hover/simple-select:text-text-secondary"
aria-hidden="true"
/>
) : (
<ChevronDownIcon
className="h-4 w-4 text-text-quaternary group-hover/simple-select:text-text-secondary"
aria-hidden="true"
/>
)
)}
</span>
</ListboxButton>
)}
{(!disabled) && (
<ListboxOptions className={classNames('absolute z-10 mt-1 max-h-60 w-full overflow-auto rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur px-1 py-1 text-base shadow-lg backdrop-blur-sm focus:outline-none sm:text-sm', optionWrapClassName)}>
{items.map((item: Item) => (
<ListboxOption
key={item.value}
className={
classNames(
'relative cursor-pointer select-none rounded-lg py-2 pl-3 pr-9 text-text-secondary hover:bg-state-base-hover',
optionClassName,
)
}
value={item}
disabled={disabled}
>
{({ /* active, */ selected }) => (
<>
{renderOption
? renderOption({ item, selected })
: (<>
<span className={classNames('block', selected && 'font-normal')}>{item.name}</span>
{selected && !hideChecked && (
<span
className={classNames(
'absolute inset-y-0 right-0 flex items-center pr-4 text-text-accent',
)}
>
<RiCheckLine className="h-4 w-4" aria-hidden="true" />
</span>
)}
</>)}
</>
)}
</ListboxOption>
))}
</ListboxOptions>
)}
</div>
)}
</Listbox>
)
}
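For context on the chevron change above: Headless UI exposes the open/closed state of the options panel through a render prop, which is what the `open ? <ChevronUpIcon /> : <ChevronDownIcon />` branch consumes. A minimal sketch of that pattern (the component and option values here are illustrative, not from this diff):

```tsx
import { Listbox, ListboxButton, ListboxOption, ListboxOptions } from '@headlessui/react'

const FruitSelect = () => (
  <Listbox>
    {({ open }) => (
      <>
        {/* `open` is true while the options panel is visible, so the trigger can swap icons. */}
        <ListboxButton>{open ? 'close' : 'open'}</ListboxButton>
        <ListboxOptions>
          <ListboxOption value="apple">apple</ListboxOption>
        </ListboxOptions>
      </>
    )}
  </Listbox>
)
```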

View File

@ -69,9 +69,11 @@ const RenameDatasetModal = ({ show, dataset, onSuccess, onClose }: RenameDataset
isShow={show}
onClose={noop}
>
<div className='relative pb-2 text-xl font-medium leading-[30px] text-text-primary'>{t('datasetSettings.title')}</div>
<div className='absolute right-4 top-4 cursor-pointer p-2' onClick={onClose}>
<RiCloseLine className='h-4 w-4 text-text-tertiary' />
<div className='flex items-center justify-between pb-2'>
<div className='text-xl font-medium leading-[30px] text-text-primary'>{t('datasetSettings.title')}</div>
<div className='cursor-pointer p-2' onClick={onClose}>
<RiCloseLine className='h-4 w-4 text-text-tertiary' />
</div>
</div>
<div>
<div className={cn('flex flex-wrap items-center justify-between py-4')}>

View File

@ -0,0 +1,26 @@
export type CommandHandler = (args?: Record<string, any>) => void | Promise<void>
const handlers = new Map<string, CommandHandler>()
export const registerCommand = (name: string, handler: CommandHandler) => {
handlers.set(name, handler)
}
export const unregisterCommand = (name: string) => {
handlers.delete(name)
}
export const executeCommand = async (name: string, args?: Record<string, any>) => {
const handler = handlers.get(name)
if (!handler)
return
await handler(args)
}
export const registerCommands = (map: Record<string, CommandHandler>) => {
Object.entries(map).forEach(([name, handler]) => registerCommand(name, handler))
}
export const unregisterCommands = (names: string[]) => {
names.forEach(unregisterCommand)
}
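The bus is a plain module-level registry, so any feature can register a named handler and any other feature can dispatch it without importing the implementation. A minimal usage sketch (the `toast.show` command and its handler are hypothetical, not part of this diff):

```ts
import { executeCommand, registerCommand, unregisterCommand } from './command-bus'

// Hypothetical handler registered under a string name.
registerCommand('toast.show', async args => {
  console.log('toast:', args?.message)
})

const demo = async () => {
  // Unknown names resolve as a silent no-op, so callers need no existence check.
  await executeCommand('toast.show', { message: 'saved' })
  await executeCommand('not.registered') // returns without throwing
  unregisterCommand('toast.show')
}
demo()
```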

View File

@ -3,11 +3,13 @@ import { knowledgeAction } from './knowledge'
import { pluginAction } from './plugin'
import { workflowNodesAction } from './workflow-nodes'
import type { ActionItem, SearchResult } from './types'
import { commandAction } from './run'
export const Actions = {
app: appAction,
knowledge: knowledgeAction,
plugin: pluginAction,
run: commandAction,
node: workflowNodesAction,
}

View File

@ -0,0 +1,33 @@
import type { CommandSearchResult } from './types'
import { languages } from '@/i18n-config/language'
import { RiTranslate } from '@remixicon/react'
import i18n from '@/i18n-config/i18next-config'
export const buildLanguageCommands = (query: string): CommandSearchResult[] => {
const q = query.toLowerCase()
const list = languages.filter(item => item.supported && (
!q || item.name.toLowerCase().includes(q) || String(item.value).toLowerCase().includes(q)
))
return list.map(item => ({
id: `lang-${item.value}`,
title: item.name,
description: i18n.t('app.gotoAnything.actions.languageChangeDesc'),
type: 'command' as const,
data: { command: 'i18n.set', args: { locale: item.value } },
}))
}
export const buildLanguageRootItem = (): CommandSearchResult => {
return {
id: 'category-language',
title: i18n.t('app.gotoAnything.actions.languageCategoryTitle'),
description: i18n.t('app.gotoAnything.actions.languageCategoryDesc'),
type: 'command',
icon: (
<div className='flex h-6 w-6 items-center justify-center rounded-md border-[0.5px] border-divider-regular bg-components-panel-bg'>
<RiTranslate className='h-4 w-4 text-text-tertiary' />
</div>
),
data: { command: 'nav.search', args: { query: '@run language ' } },
}
}
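Each matching locale becomes a palette entry whose payload targets the `i18n.set` command. A sketch of the output shape (the locale name and value are assumptions about the contents of `languages`):

```ts
const hits = buildLanguageCommands('en')
// Assuming languages contains { name: 'English (United States)', value: 'en-US', supported: true },
// hits would include:
// {
//   id: 'lang-en-US',
//   title: 'English (United States)',
//   type: 'command',
//   data: { command: 'i18n.set', args: { locale: 'en-US' } },
// }
```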

View File

@ -0,0 +1,61 @@
import type { CommandSearchResult } from './types'
import type { ReactNode } from 'react'
import { RiComputerLine, RiMoonLine, RiPaletteLine, RiSunLine } from '@remixicon/react'
import i18n from '@/i18n-config/i18next-config'
const THEME_ITEMS: { id: 'light' | 'dark' | 'system'; titleKey: string; descKey: string; icon: ReactNode }[] = [
{
id: 'system',
titleKey: 'app.gotoAnything.actions.themeSystem',
descKey: 'app.gotoAnything.actions.themeSystemDesc',
icon: <RiComputerLine className='h-4 w-4 text-text-tertiary' />,
},
{
id: 'light',
titleKey: 'app.gotoAnything.actions.themeLight',
descKey: 'app.gotoAnything.actions.themeLightDesc',
icon: <RiSunLine className='h-4 w-4 text-text-tertiary' />,
},
{
id: 'dark',
titleKey: 'app.gotoAnything.actions.themeDark',
descKey: 'app.gotoAnything.actions.themeDarkDesc',
icon: <RiMoonLine className='h-4 w-4 text-text-tertiary' />,
},
]
export const buildThemeCommands = (query: string, locale?: string): CommandSearchResult[] => {
const q = query.toLowerCase()
const list = THEME_ITEMS.filter(item =>
!q
|| i18n.t(item.titleKey, { lng: locale }).toLowerCase().includes(q)
|| item.id.includes(q),
)
return list.map(item => ({
id: item.id,
title: i18n.t(item.titleKey, { lng: locale }),
description: i18n.t(item.descKey, { lng: locale }),
type: 'command' as const,
icon: (
<div className='flex h-6 w-6 items-center justify-center rounded-md border-[0.5px] border-divider-regular bg-components-panel-bg'>
{item.icon}
</div>
),
data: { command: 'theme.set', args: { value: item.id } },
}))
}
export const buildThemeRootItem = (): CommandSearchResult => {
return {
id: 'category-theme',
title: i18n.t('app.gotoAnything.actions.themeCategoryTitle'),
description: i18n.t('app.gotoAnything.actions.themeCategoryDesc'),
type: 'command',
icon: (
<div className='flex h-6 w-6 items-center justify-center rounded-md border-[0.5px] border-divider-regular bg-components-panel-bg'>
<RiPaletteLine className='h-4 w-4 text-text-tertiary' />
</div>
),
data: { command: 'nav.search', args: { query: '@run theme ' } },
}
}
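The filter matches either the translated title or the raw id, so typing `@run theme da` narrows to the dark entry (assuming the translated titles of the other two themes do not also contain "da"):

```ts
const hits = buildThemeCommands('da')
// 'da' matches the id 'dark', so the single hit carries:
// data: { command: 'theme.set', args: { value: 'dark' } }
```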

View File

@ -0,0 +1,97 @@
'use client'
import { useEffect } from 'react'
import type { ActionItem, CommandSearchResult } from './types'
import { buildLanguageCommands, buildLanguageRootItem } from './run-language'
import { buildThemeCommands, buildThemeRootItem } from './run-theme'
import i18n from '@/i18n-config/i18next-config'
import { executeCommand, registerCommands, unregisterCommands } from './command-bus'
import { useTheme } from 'next-themes'
import { setLocaleOnClient } from '@/i18n-config'
const rootParser = (query: string): CommandSearchResult[] => {
const q = query.toLowerCase()
const items: CommandSearchResult[] = []
if (!q || 'theme'.includes(q))
items.push(buildThemeRootItem())
if (!q || 'language'.includes(q) || 'lang'.includes(q))
items.push(buildLanguageRootItem())
return items
}
type RunContext = {
setTheme?: (value: 'light' | 'dark' | 'system') => void
setLocale?: (locale: string) => Promise<void>
search?: (query: string) => void
}
export const commandAction: ActionItem = {
key: '@run',
shortcut: '@run',
title: i18n.t('app.gotoAnything.actions.runTitle'),
description: i18n.t('app.gotoAnything.actions.runDesc'),
action: (result) => {
if (result.type !== 'command')
return
const { command, args } = result.data
// Every supported command is dispatched through the shared command bus.
if (['theme.set', 'i18n.set', 'nav.search'].includes(command))
executeCommand(command, args)
},
search: async (_, searchTerm = '') => {
const q = searchTerm.trim()
if (q.startsWith('theme'))
return buildThemeCommands(q.replace(/^theme\s*/, ''), i18n.language)
if (q.startsWith('language') || q.startsWith('lang'))
return buildLanguageCommands(q.replace(/^(language|lang)\s*/, ''))
// root categories
return rootParser(q)
},
}
// Register/unregister default handlers for @run commands with external dependencies.
export const registerRunCommands = (deps: RunContext) => {
registerCommands({
'theme.set': async (args) => {
deps.setTheme?.(args?.value)
},
'i18n.set': async (args) => {
const locale = args?.locale
if (locale)
await deps.setLocale?.(locale)
},
'nav.search': (args) => {
const q = args?.query
if (q)
deps.search?.(q)
},
})
}
export const unregisterRunCommands = () => {
unregisterCommands(['theme.set', 'i18n.set', 'nav.search'])
}
export const RunCommandProvider = ({ onNavSearch }: { onNavSearch?: (q: string) => void }) => {
const theme = useTheme()
useEffect(() => {
registerRunCommands({
setTheme: theme.setTheme,
setLocale: setLocaleOnClient,
search: onNavSearch,
})
return () => unregisterRunCommands()
}, [theme.setTheme, onNavSearch])
return null
}
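Handlers therefore exist only while `RunCommandProvider` is mounted; unmounting unregisters them, so a stale `setTheme` or `setLocale` closure is never invoked. A mounting sketch (the host component is illustrative; in this diff, GotoAnything mounts the provider directly):

```tsx
const PaletteHost = () => (
  <>
    {/* theme.set / i18n.set / nav.search are dispatchable only while this is mounted. */}
    <RunCommandProvider onNavSearch={q => console.log('reopen palette with', q)} />
  </>
)
```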

View File

@ -5,7 +5,7 @@ import type { Plugin } from '../../plugins/types'
import type { DataSet } from '@/models/datasets'
import type { CommonNodeType } from '../../workflow/types'
export type SearchResultType = 'app' | 'knowledge' | 'plugin' | 'workflow-node'
export type SearchResultType = 'app' | 'knowledge' | 'plugin' | 'workflow-node' | 'command'
export type BaseSearchResult<T = any> = {
id: string
@ -37,10 +37,14 @@ export type WorkflowNodeSearchResult = {
}
} & BaseSearchResult<CommonNodeType>
export type SearchResult = AppSearchResult | PluginSearchResult | KnowledgeSearchResult | WorkflowNodeSearchResult
export type CommandSearchResult = {
type: 'command'
} & BaseSearchResult<{ command: string; args?: Record<string, any> }>
export type SearchResult = AppSearchResult | PluginSearchResult | KnowledgeSearchResult | WorkflowNodeSearchResult | CommandSearchResult
export type ActionItem = {
key: '@app' | '@knowledge' | '@plugin' | '@node'
key: '@app' | '@knowledge' | '@plugin' | '@node' | '@run'
shortcut: string
title: string | TypeWithI18N
description: string
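With the widened unions, a command entry is just another `SearchResult`; a minimal literal (field values illustrative, `icon` left off since the language command items above also omit it):

```ts
const themeEntry: CommandSearchResult = {
  id: 'category-theme',
  title: 'Theme',
  description: 'Switch the color scheme',
  type: 'command',
  data: { command: 'nav.search', args: { query: '@run theme ' } },
}
```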

View File

@ -20,7 +20,6 @@ const CommandSelector: FC<Props> = ({ actions, onCommandSelect, searchFilter, co
return true
const filterLower = searchFilter.toLowerCase()
return action.shortcut.toLowerCase().includes(filterLower)
|| action.key.toLowerCase().includes(filterLower)
})
useEffect(() => {
@ -61,7 +60,7 @@ const CommandSelector: FC<Props> = ({ actions, onCommandSelect, searchFilter, co
className="flex cursor-pointer items-center rounded-md
p-2.5
transition-all
duration-150 hover:bg-state-base-hover aria-[selected=true]:bg-state-base-hover"
duration-150 hover:bg-state-base-hover aria-[selected=true]:bg-state-base-hover-alt"
onSelect={() => onCommandSelect(action.shortcut)}
>
<span className="min-w-[4.5rem] text-left font-mono text-xs text-text-tertiary">
@ -73,6 +72,7 @@ const CommandSelector: FC<Props> = ({ actions, onCommandSelect, searchFilter, co
'@app': 'app.gotoAnything.actions.searchApplicationsDesc',
'@plugin': 'app.gotoAnything.actions.searchPluginsDesc',
'@knowledge': 'app.gotoAnything.actions.searchKnowledgeBasesDesc',
'@run': 'app.gotoAnything.actions.runDesc',
'@node': 'app.gotoAnything.actions.searchWorkflowNodesDesc',
}
return t(keyMap[action.key])

View File

@ -18,6 +18,7 @@ import InstallFromMarketplace from '../plugins/install-plugin/install-from-marke
import type { Plugin } from '../plugins/types'
import { Command } from 'cmdk'
import CommandSelector from './command-selector'
import { RunCommandProvider } from './actions/run'
type Props = {
onHide?: () => void
@ -31,9 +32,14 @@ const GotoAnything: FC<Props> = ({
const { t } = useTranslation()
const [show, setShow] = useState<boolean>(false)
const [searchQuery, setSearchQuery] = useState<string>('')
const [cmdVal, setCmdVal] = useState<string>('')
const [cmdVal, setCmdVal] = useState<string>('_')
const inputRef = useRef<HTMLInputElement>(null)
const handleNavSearch = useCallback((q: string) => {
setShow(true)
setSearchQuery(q)
setCmdVal('')
requestAnimationFrame(() => inputRef.current?.focus())
}, [])
// Filter actions based on context
const Actions = useMemo(() => {
// Create a filtered copy of actions based on current page context
@ -43,8 +49,8 @@ const GotoAnything: FC<Props> = ({
}
else {
// Exclude node action on non-workflow pages
const { app, knowledge, plugin } = AllActions
return { app, knowledge, plugin }
const { app, knowledge, plugin, run } = AllActions
return { app, knowledge, plugin, run }
}
}, [isWorkflowPage])
@ -114,9 +120,14 @@ const GotoAnything: FC<Props> = ({
},
)
// Point cmdVal at a value no option uses ('_') so cmdk does not auto-select the first option
const clearSelection = () => {
setCmdVal('_')
}
const handleCommandSelect = useCallback((commandKey: string) => {
setSearchQuery(`${commandKey} `)
setCmdVal('')
clearSelection()
setTimeout(() => {
inputRef.current?.focus()
}, 0)
@ -128,6 +139,11 @@ const GotoAnything: FC<Props> = ({
setSearchQuery('')
switch (result.type) {
case 'command': {
const action = Object.values(Actions).find(a => a.key === '@run')
action?.action?.(result)
break
}
case 'plugin':
setActivePlugin(result.data)
break
@ -222,9 +238,6 @@ const GotoAnything: FC<Props> = ({
inputRef.current?.focus()
})
}
return () => {
setCmdVal('')
}
}, [show])
return (
@ -234,6 +247,7 @@ const GotoAnything: FC<Props> = ({
onClose={() => {
setShow(false)
setSearchQuery('')
clearSelection()
onHide?.()
}}
closable={false}
@ -245,6 +259,7 @@ const GotoAnything: FC<Props> = ({
className='outline-none'
value={cmdVal}
onValueChange={setCmdVal}
disablePointerSelection
>
<div className='flex items-center gap-3 border-b border-divider-subtle bg-components-panel-bg-blur px-4 py-3'>
<RiSearchLine className='h-4 w-4 text-text-quaternary' />
@ -256,7 +271,7 @@ const GotoAnything: FC<Props> = ({
onChange={(e) => {
setSearchQuery(e.target.value)
if (!e.target.value.startsWith('@'))
setCmdVal('')
clearSelection()
}}
className='flex-1 !border-0 !bg-transparent !shadow-none'
wrapperClassName='flex-1 !border-0 !bg-transparent'
@ -309,40 +324,40 @@ const GotoAnything: FC<Props> = ({
/>
) : (
Object.entries(groupedResults).map(([type, results], groupIndex) => (
<Command.Group key={groupIndex} heading={(() => {
  const typeMap: Record<string, string> = {
    'app': 'app.gotoAnything.groups.apps',
    'plugin': 'app.gotoAnything.groups.plugins',
    'knowledge': 'app.gotoAnything.groups.knowledgeBases',
    'workflow-node': 'app.gotoAnything.groups.workflowNodes',
  }
  return t(typeMap[type] || `${type}s`)
})()} className='p-2 capitalize text-text-secondary'>
  {results.map(result => (
    <Command.Item
      key={`${result.type}-${result.id}`}
      value={result.title}
      className='flex cursor-pointer items-center gap-3 rounded-md p-3 will-change-[background-color] aria-[selected=true]:bg-state-base-hover data-[selected=true]:bg-state-base-hover'
      onSelect={() => handleNavigate(result)}
    >
      {result.icon}
      <div className='min-w-0 flex-1'>
        <div className='truncate font-medium text-text-secondary'>
          {result.title}
        </div>
        {result.description && (
          <div className='mt-0.5 truncate text-xs text-text-quaternary'>
            {result.description}
          </div>
        )}
      </div>
      <div className='text-xs capitalize text-text-quaternary'>
        {result.type}
      </div>
    </Command.Item>
  ))}
</Command.Group>
))
)}
{!isCommandsMode && emptyResult}
{!isCommandsMode && defaultUI}
@ -361,7 +376,7 @@ const GotoAnything: FC<Props> = ({
{t('app.gotoAnything.resultCount', { count: searchResults.length })}
{searchMode !== 'general' && (
<span className='ml-2 opacity-60'>
{t('app.gotoAnything.inScope', { scope: searchMode.replace('@', '') })}
{t('app.gotoAnything.inScope', { scope: searchMode.replace('@', '') })}
</span>
)}
</>
@ -380,6 +395,7 @@ const GotoAnything: FC<Props> = ({
</div>
</Modal>
<RunCommandProvider onNavSearch={handleNavSearch} />
{
activePlugin && (
<InstallFromMarketplace

Some files were not shown because too many files have changed in this diff.