Merge branch 'main' into feat/hitl-frontend

This commit is contained in:
twwu 2026-01-05 15:31:08 +08:00
commit cf9b72d574
133 changed files with 29329 additions and 1338 deletions

View File

@ -3,6 +3,7 @@
"feature-dev@claude-plugins-official": true,
"context7@claude-plugins-official": true,
"typescript-lsp@claude-plugins-official": true,
"pyright-lsp@claude-plugins-official": true
"pyright-lsp@claude-plugins-official": true,
"ralph-wiggum@claude-plugins-official": true
}
}

View File

@ -0,0 +1,73 @@
---
name: frontend-code-review
description: "Trigger when the user requests a review of frontend files (e.g., `.tsx`, `.ts`, `.js`). Support both pending-change reviews and focused file reviews while applying the checklist rules."
---
# Frontend Code Review
## Intent
Use this skill whenever the user asks to review frontend code (especially `.tsx`, `.ts`, or `.js` files). Support two review modes:
1. **Pending-change review**: inspect staged/working-tree files slated for commit and flag checklist violations before submission.
2. **File-targeted review**: review the specific file(s) the user names and report the relevant checklist findings.
Stick to the checklist below for every applicable file and mode.
## Checklist
See [references/code-quality.md](references/code-quality.md), [references/performance.md](references/performance.md), [references/business-logic.md](references/business-logic.md) for the living checklist split by category—treat it as the canonical set of rules to follow.
Flag each rule violation with urgency metadata so future reviewers can prioritize fixes.
## Review Process
1. Open the relevant component/module. Gather lines that relate to class names, React Flow hooks, prop memoization, and styling.
2. For each checklist rule, note where the code deviates and capture a representative snippet.
3. Compose the review section per the template below. Group violations first by **Urgent** flag, then by category order (Code Quality, Performance, Business Logic).
## Required output
When invoked, the response must exactly follow one of the two templates:
### Template A (any findings)
```
# Code review
Found <N> urgent issues that need to be fixed:
## 1 <brief description of bug>
FilePath: <path> line <line>
<relevant code snippet or pointer>
### Suggested fix
<brief description of suggested fix>
---
... (repeat for each urgent issue) ...
Found <M> suggestions for improvement:
## 1 <brief description of suggestion>
FilePath: <path> line <line>
<relevant code snippet or pointer>
### Suggested fix
<brief description of suggested fix>
---
... (repeat for each suggestion) ...
```
If there are no urgent issues, omit that section. If there are no suggestions, omit that section.
If there are more than 10 issues, summarize the count as "10+ urgent issues" or "10+ suggestions" and output only the first 10.
Don't compress the blank lines between sections; keep them as-is for readability.
If you use Template A (i.e., there are issues to fix) and at least one issue requires code changes, append a brief follow-up question after the structured output asking whether the user wants you to apply the suggested fix(es). For example: "Would you like me to use the Suggested fix section to address these issues?"
### Template B (no issues)
```
# Code review
No issues found.
```

View File

@ -0,0 +1,15 @@
# Rule Catalog — Business Logic
## Can't use workflowStore in Node components
IsUrgent: True
### Description
File path pattern of node components: `web/app/components/workflow/nodes/[nodeName]/node.tsx`
Node components are also used when creating a RAG pipeline from a template, but in that context there is no workflowStore Provider, which results in a blank screen. [This issue](https://github.com/langgenius/dify/issues/29168) was caused by exactly this problem.
### Suggested Fix
Use `import { useNodes } from 'reactflow'` instead of `import useNodes from '@/app/components/workflow/store/workflow/use-nodes'`.
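A minimal sketch of that suggestion (component name and props are illustrative, not taken from the codebase): reading graph data through React Flow's own hook keeps the node renderable even when no workflowStore Provider is mounted.
```tsx
import { memo } from 'react'
import { useNodes } from 'reactflow'

// Hypothetical node component: reads nodes via React Flow's hook,
// so it still renders without a workflowStore Provider
// (e.g. when a node is shown from a pipeline template).
const ExampleNode = ({ id }: { id: string }) => {
  const nodes = useNodes()

  return <div>{`Node ${id} of ${nodes.length}`}</div>
}

export default memo(ExampleNode)
```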

View File

@ -0,0 +1,44 @@
# Rule Catalog — Code Quality
## Conditional class names use utility function
IsUrgent: True
Category: Code Quality
### Description
Ensure conditional CSS classes are composed with the shared `cn` utility from `@/utils/classnames` rather than hand-rolled ternaries, string concatenation, or template strings. Centralizing class logic keeps components consistent and easier to maintain.
### Suggested Fix
```ts
import { cn } from '@/utils/classnames'
const classNames = cn(isActive ? 'text-primary-600' : 'text-gray-500')
```
## Tailwind-first styling
IsUrgent: True
Category: Code Quality
### Description
Favor Tailwind CSS utility classes instead of adding new `.module.css` files unless a Tailwind combination cannot achieve the required styling. Keeping styles in Tailwind improves consistency and reduces maintenance overhead.
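As a rough illustration (component and class names are hypothetical), a small badge that might otherwise get its own `.module.css` file can usually be expressed with utility classes alone:
```tsx
// Hypothetical component styled entirely with Tailwind utilities,
// so no companion .module.css file is needed.
const Badge = ({ label }: { label: string }) => {
  return (
    <span className='inline-flex items-center rounded-full bg-gray-100 px-2 py-0.5 text-xs font-medium text-gray-700'>
      {label}
    </span>
  )
}

export default Badge
```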
Update this file when adding, editing, or removing Code Quality rules so the catalog remains accurate.
## Classname ordering for easy overrides
### Description
When writing components, always place the incoming `className` prop after the component's own class values so that downstream consumers can override or extend the styling. This keeps your component's defaults but still lets external callers change or remove specific styles.
Example:
```tsx
import { cn } from '@/utils/classnames'
const Button = ({ className }) => {
return <div className={cn('bg-primary-600', className)}></div>
}
```

View File

@ -0,0 +1,45 @@
# Rule Catalog — Performance
## React Flow data usage
IsUrgent: True
Category: Performance
### Description
When rendering React Flow, prefer `useNodes`/`useEdges` for UI consumption and rely on `useStoreApi` inside callbacks that mutate or read node/edge state. Avoid manually pulling Flow data outside of these hooks.
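A minimal sketch of the intended split (component and handler names are illustrative): `useNodes` feeds rendering, while `useStoreApi` reads the latest state inside a callback without subscribing the component to every store change.
```tsx
import { useCallback } from 'react'
import { useNodes, useStoreApi } from 'reactflow'

// Hypothetical toolbar button: renders from useNodes, reads fresh state in the handler.
const SelectionLogger = () => {
  const nodes = useNodes()
  const store = useStoreApi()

  const logSelection = useCallback(() => {
    const { getNodes } = store.getState()
    const selected = getNodes().filter(node => node.selected)
    console.log(`${selected.length} of ${nodes.length} nodes selected`)
  }, [store, nodes.length])

  return <button onClick={logSelection}>{`Nodes: ${nodes.length}`}</button>
}

export default SelectionLogger
```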
## Complex prop memoization
IsUrgent: True
Category: Performance
### Description
Wrap complex prop values (objects, arrays, maps) in `useMemo` prior to passing them into child components to guarantee stable references and prevent unnecessary renders.
Update this file when adding, editing, or removing Performance rules so the catalog remains accurate.
Wrong:
```tsx
<HeavyComp
config={{
provider: ...,
detail: ...
}}
/>
```
Right:
```tsx
const config = useMemo(() => ({
provider: ...,
detail: ...
}), [provider, detail]);
<HeavyComp
config={config}
/>
```

View File

@ -20,4 +20,4 @@
- [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
- [x] I've updated the documentation accordingly.
- [x] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods
- [x] I ran `make lint` and `make type-check` (backend) and `cd web && npx lint-staged` (frontend) to appease the lint gods

View File

@ -5,6 +5,7 @@ on:
branches: [main]
paths:
- 'web/i18n/en-US/*.json'
workflow_dispatch:
permissions:
contents: write
@ -18,7 +19,8 @@ jobs:
run:
working-directory: web
steps:
- uses: actions/checkout@v6
# Keep using the old checkout action version; see https://github.com/peter-evans/create-pull-request/issues/4272
- uses: actions/checkout@v4
with:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
@ -26,21 +28,28 @@ jobs:
- name: Check for file changes in i18n/en-US
id: check_files
run: |
git fetch origin "${{ github.event.before }}" || true
git fetch origin "${{ github.sha }}" || true
changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.json')
echo "Changed files: $changed_files"
if [ -n "$changed_files" ]; then
# Skip check for manual trigger, translate all files
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
echo "FILES_CHANGED=true" >> $GITHUB_ENV
file_args=""
for file in $changed_files; do
filename=$(basename "$file" .json)
file_args="$file_args --file $filename"
done
echo "FILE_ARGS=$file_args" >> $GITHUB_ENV
echo "File arguments: $file_args"
echo "FILE_ARGS=" >> $GITHUB_ENV
echo "Manual trigger: translating all files"
else
echo "FILES_CHANGED=false" >> $GITHUB_ENV
git fetch origin "${{ github.event.before }}" || true
git fetch origin "${{ github.sha }}" || true
changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.json')
echo "Changed files: $changed_files"
if [ -n "$changed_files" ]; then
echo "FILES_CHANGED=true" >> $GITHUB_ENV
file_args=""
for file in $changed_files; do
filename=$(basename "$file" .json)
file_args="$file_args --file $filename"
done
echo "FILE_ARGS=$file_args" >> $GITHUB_ENV
echo "File arguments: $file_args"
else
echo "FILES_CHANGED=false" >> $GITHUB_ENV
fi
fi
- name: Install pnpm

.gitignore vendored
View File

@ -235,3 +235,4 @@ scripts/stress-test/reports/
# settings
*.local.json
*.local.md

View File

@ -60,9 +60,10 @@ check:
@echo "✅ Code check complete"
lint:
@echo "🔧 Running ruff format, check with fixes, and import linter..."
@echo "🔧 Running ruff format, check with fixes, import linter, and dotenv-linter..."
@uv run --project api --dev sh -c 'ruff format ./api && ruff check --fix ./api'
@uv run --directory api --dev lint-imports
@uv run --project api --dev dotenv-linter ./api/.env.example ./web/.env.example
@echo "✅ Linting complete"
type-check:
@ -122,7 +123,7 @@ help:
@echo "Backend Code Quality:"
@echo " make format - Format code with ruff"
@echo " make check - Check code with ruff"
@echo " make lint - Format and fix code with ruff"
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
@echo " make type-check - Run type checking with basedpyright"
@echo " make test - Run backend unit tests"
@echo ""

View File

@ -502,6 +502,8 @@ LOG_FILE_BACKUP_COUNT=5
LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S
# Log Timezone
LOG_TZ=UTC
# Log output format: text or json
LOG_OUTPUT_FORMAT=text
# Log format
LOG_FORMAT=%(asctime)s,%(msecs)d %(levelname)-2s [%(filename)s:%(lineno)d] %(req_id)s %(message)s

View File

@ -3,9 +3,11 @@ root_packages =
core
configs
controllers
extensions
models
tasks
services
include_external_packages = True
[importlinter:contract:workflow]
name = Workflow
@ -33,6 +35,29 @@ ignore_imports =
core.workflow.nodes.loop.loop_node -> core.workflow.graph
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
[importlinter:contract:workflow-infrastructure-dependencies]
name = Workflow Infrastructure Dependencies
type = forbidden
source_modules =
core.workflow
forbidden_modules =
extensions.ext_database
extensions.ext_redis
allow_indirect_imports = True
ignore_imports =
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database
core.workflow.nodes.llm.file_saver -> extensions.ext_database
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
core.workflow.nodes.llm.node -> extensions.ext_database
core.workflow.nodes.tool.tool_node -> extensions.ext_database
core.workflow.nodes.variable_assigner.common.impl -> extensions.ext_database
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
core.workflow.graph_engine.manager -> extensions.ext_redis
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
[importlinter:contract:rsc]
name = RSC
type = layers

View File

@ -2,9 +2,11 @@ import logging
import time
from opentelemetry.trace import get_current_span
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
from configs import dify_config
from contexts.wrapper import RecyclableContextVar
from core.logging.context import init_request_context
from dify_app import DifyApp
logger = logging.getLogger(__name__)
@ -25,28 +27,35 @@ def create_flask_app_with_configs() -> DifyApp:
# add before request hook
@dify_app.before_request
def before_request():
# add an unique identifier to each request
# Initialize logging context for this request
init_request_context()
RecyclableContextVar.increment_thread_recycles()
# add after request hook for injecting X-Trace-Id header from OpenTelemetry span context
# add after request hook for injecting trace headers from OpenTelemetry span context
# Only adds headers when OTEL is enabled and has valid context
@dify_app.after_request
def add_trace_id_header(response):
def add_trace_headers(response):
try:
span = get_current_span()
ctx = span.get_span_context() if span else None
if ctx and ctx.is_valid:
trace_id_hex = format(ctx.trace_id, "032x")
# Avoid duplicates if some middleware added it
if "X-Trace-Id" not in response.headers:
response.headers["X-Trace-Id"] = trace_id_hex
if not ctx or not ctx.is_valid:
return response
# Inject trace headers from OTEL context
if ctx.trace_id != INVALID_TRACE_ID and "X-Trace-Id" not in response.headers:
response.headers["X-Trace-Id"] = format(ctx.trace_id, "032x")
if ctx.span_id != INVALID_SPAN_ID and "X-Span-Id" not in response.headers:
response.headers["X-Span-Id"] = format(ctx.span_id, "016x")
except Exception:
# Never break the response due to tracing header injection
logger.warning("Failed to add trace ID to response header", exc_info=True)
logger.warning("Failed to add trace headers to response", exc_info=True)
return response
# Capture the decorator's return value to avoid pyright reportUnusedFunction
_ = before_request
_ = add_trace_id_header
_ = add_trace_headers
return dify_app

View File

@ -1184,6 +1184,217 @@ def remove_orphaned_files_on_storage(force: bool):
click.echo(click.style(f"Removed {removed_files} orphaned files, with {error_files} errors.", fg="yellow"))
@click.command("file-usage", help="Query file usages and show where files are referenced.")
@click.option("--file-id", type=str, default=None, help="Filter by file UUID.")
@click.option("--key", type=str, default=None, help="Filter by storage key.")
@click.option("--src", type=str, default=None, help="Filter by table.column pattern (e.g., 'documents.%' or '%.icon').")
@click.option("--limit", type=int, default=100, help="Limit number of results (default: 100).")
@click.option("--offset", type=int, default=0, help="Offset for pagination (default: 0).")
@click.option("--json", "output_json", is_flag=True, help="Output results in JSON format.")
def file_usage(
file_id: str | None,
key: str | None,
src: str | None,
limit: int,
offset: int,
output_json: bool,
):
"""
Query file usages and show where files are referenced in the database.
This command reuses the same reference checking logic as clear-orphaned-file-records
and displays detailed information about where each file is referenced.
"""
# define tables and columns to process
files_tables = [
{"table": "upload_files", "id_column": "id", "key_column": "key"},
{"table": "tool_files", "id_column": "id", "key_column": "file_key"},
]
ids_tables = [
{"type": "uuid", "table": "message_files", "column": "upload_file_id", "pk_column": "id"},
{"type": "text", "table": "documents", "column": "data_source_info", "pk_column": "id"},
{"type": "text", "table": "document_segments", "column": "content", "pk_column": "id"},
{"type": "text", "table": "messages", "column": "answer", "pk_column": "id"},
{"type": "text", "table": "workflow_node_executions", "column": "inputs", "pk_column": "id"},
{"type": "text", "table": "workflow_node_executions", "column": "process_data", "pk_column": "id"},
{"type": "text", "table": "workflow_node_executions", "column": "outputs", "pk_column": "id"},
{"type": "text", "table": "conversations", "column": "introduction", "pk_column": "id"},
{"type": "text", "table": "conversations", "column": "system_instruction", "pk_column": "id"},
{"type": "text", "table": "accounts", "column": "avatar", "pk_column": "id"},
{"type": "text", "table": "apps", "column": "icon", "pk_column": "id"},
{"type": "text", "table": "sites", "column": "icon", "pk_column": "id"},
{"type": "json", "table": "messages", "column": "inputs", "pk_column": "id"},
{"type": "json", "table": "messages", "column": "message", "pk_column": "id"},
]
# Stream file usages with pagination to avoid holding all results in memory
paginated_usages = []
total_count = 0
# First, build a mapping of file_id -> storage_key from the base tables
file_key_map = {}
for files_table in files_tables:
query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}"
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for row in rs:
file_key_map[str(row[0])] = f"{files_table['table']}:{row[1]}"
# If filtering by key or file_id, verify it exists
if file_id and file_id not in file_key_map:
if output_json:
click.echo(json.dumps({"error": f"File ID {file_id} not found in base tables"}))
else:
click.echo(click.style(f"File ID {file_id} not found in base tables.", fg="red"))
return
if key:
valid_prefixes = {f"upload_files:{key}", f"tool_files:{key}"}
matching_file_ids = [fid for fid, fkey in file_key_map.items() if fkey in valid_prefixes]
if not matching_file_ids:
if output_json:
click.echo(json.dumps({"error": f"Key {key} not found in base tables"}))
else:
click.echo(click.style(f"Key {key} not found in base tables.", fg="red"))
return
guid_regexp = "[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
# For each reference table/column, find matching file IDs and record the references
for ids_table in ids_tables:
src_filter = f"{ids_table['table']}.{ids_table['column']}"
# Skip if src filter doesn't match (use fnmatch for wildcard patterns)
if src:
if "%" in src or "_" in src:
import fnmatch
# Convert SQL LIKE wildcards to fnmatch wildcards (% -> *, _ -> ?)
pattern = src.replace("%", "*").replace("_", "?")
if not fnmatch.fnmatch(src_filter, pattern):
continue
else:
if src_filter != src:
continue
if ids_table["type"] == "uuid":
# Direct UUID match
query = (
f"SELECT {ids_table['pk_column']}, {ids_table['column']} "
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for row in rs:
record_id = str(row[0])
ref_file_id = str(row[1])
if ref_file_id not in file_key_map:
continue
storage_key = file_key_map[ref_file_id]
# Apply filters
if file_id and ref_file_id != file_id:
continue
if key and not storage_key.endswith(key):
continue
# Only collect items within the requested page range
if offset <= total_count < offset + limit:
paginated_usages.append(
{
"src": f"{ids_table['table']}.{ids_table['column']}",
"record_id": record_id,
"file_id": ref_file_id,
"key": storage_key,
}
)
total_count += 1
elif ids_table["type"] in ("text", "json"):
# Extract UUIDs from text/json content
column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"]
query = (
f"SELECT {ids_table['pk_column']}, {column_cast} "
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for row in rs:
record_id = str(row[0])
content = str(row[1])
# Find all UUIDs in the content
import re
uuid_pattern = re.compile(guid_regexp, re.IGNORECASE)
matches = uuid_pattern.findall(content)
for ref_file_id in matches:
if ref_file_id not in file_key_map:
continue
storage_key = file_key_map[ref_file_id]
# Apply filters
if file_id and ref_file_id != file_id:
continue
if key and not storage_key.endswith(key):
continue
# Only collect items within the requested page range
if offset <= total_count < offset + limit:
paginated_usages.append(
{
"src": f"{ids_table['table']}.{ids_table['column']}",
"record_id": record_id,
"file_id": ref_file_id,
"key": storage_key,
}
)
total_count += 1
# Output results
if output_json:
result = {
"total": total_count,
"offset": offset,
"limit": limit,
"usages": paginated_usages,
}
click.echo(json.dumps(result, indent=2))
else:
click.echo(
click.style(f"Found {total_count} file usages (showing {len(paginated_usages)} results)", fg="white")
)
click.echo("")
if not paginated_usages:
click.echo(click.style("No file usages found matching the specified criteria.", fg="yellow"))
return
# Print table header
click.echo(
click.style(
f"{'Src (Table.Column)':<50} {'Record ID':<40} {'File ID':<40} {'Storage Key':<60}",
fg="cyan",
)
)
click.echo(click.style("-" * 190, fg="white"))
# Print each usage
for usage in paginated_usages:
click.echo(f"{usage['src']:<50} {usage['record_id']:<40} {usage['file_id']:<40} {usage['key']:<60}")
# Show pagination info
if offset + limit < total_count:
click.echo("")
click.echo(
click.style(
f"Showing {offset + 1}-{offset + len(paginated_usages)} of {total_count} results", fg="white"
)
)
click.echo(click.style(f"Use --offset {offset + limit} to see next page", fg="white"))
@click.command("setup-system-tool-oauth-client", help="Setup system tool oauth client.")
@click.option("--provider", prompt=True, help="Provider name")
@click.option("--client-params", prompt=True, help="Client Params")

View File

@ -587,6 +587,11 @@ class LoggingConfig(BaseSettings):
default="INFO",
)
LOG_OUTPUT_FORMAT: Literal["text", "json"] = Field(
description="Log output format: 'text' for human-readable, 'json' for structured JSON logs.",
default="text",
)
LOG_FILE: str | None = Field(
description="File path for log output.",
default=None,

View File

@ -16,7 +16,6 @@ class MilvusConfig(BaseSettings):
description="Authentication token for Milvus, if token-based authentication is enabled",
default=None,
)
MILVUS_USER: str | None = Field(
description="Username for authenticating with Milvus, if username/password authentication is enabled",
default=None,

View File

@ -13,7 +13,6 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import MessageTextField
from fields.raws import FilesContainedField
from libs.datetime_utils import naive_utc_now, parse_time_range
from libs.helper import TimestampField
@ -177,6 +176,12 @@ annotation_hit_history_model = console_ns.model(
},
)
class MessageTextField(fields.Raw):
def format(self, value):
return value[0]["text"] if value else ""
# Simple message detail model
simple_message_detail_model = console_ns.model(
"SimpleMessageDetail",

View File

@ -751,12 +751,12 @@ class DocumentApi(DocumentResource):
elif metadata == "without":
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
data_source_info = document.data_source_detail_dict
response = {
"id": document.id,
"position": document.position,
"data_source_type": document.data_source_type,
"data_source_info": data_source_info,
"data_source_info": document.data_source_info_dict,
"data_source_detail_dict": document.data_source_detail_dict,
"dataset_process_rule_id": document.dataset_process_rule_id,
"dataset_process_rule": dataset_process_rules,
"document_process_rule": document_process_rules,
@ -784,12 +784,12 @@ class DocumentApi(DocumentResource):
else:
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
data_source_info = document.data_source_detail_dict
response = {
"id": document.id,
"position": document.position,
"data_source_type": document.data_source_type,
"data_source_info": data_source_info,
"data_source_info": document.data_source_info_dict,
"data_source_detail_dict": document.data_source_detail_dict,
"dataset_process_rule_id": document.dataset_process_rule_id,
"dataset_process_rule": dataset_process_rules,
"document_process_rule": document_process_rules,

View File

@ -1,8 +1,7 @@
from typing import Any
from flask import request
from flask_restx import marshal_with
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field, TypeAdapter, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
@ -11,7 +10,11 @@ from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from fields.conversation_fields import (
ConversationInfiniteScrollPagination,
ResultResponse,
SimpleConversation,
)
from libs.helper import UUIDStrOrEmpty
from libs.login import current_user
from models import Account
@ -49,7 +52,6 @@ register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayl
endpoint="installed_app_conversations",
)
class ConversationListApi(InstalledAppResource):
@marshal_with(conversation_infinite_scroll_pagination_fields)
@console_ns.expect(console_ns.models[ConversationListQuery.__name__])
def get(self, installed_app):
app_model = installed_app.app
@ -73,7 +75,7 @@ class ConversationListApi(InstalledAppResource):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
with Session(db.engine) as session:
return WebConversationService.pagination_by_last_id(
pagination = WebConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=current_user,
@ -82,6 +84,13 @@ class ConversationListApi(InstalledAppResource):
invoke_from=InvokeFrom.EXPLORE,
pinned=args.pinned,
)
adapter = TypeAdapter(SimpleConversation)
conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data]
return ConversationInfiniteScrollPagination(
limit=pagination.limit,
has_more=pagination.has_more,
data=conversations,
).model_dump(mode="json")
except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
@ -105,7 +114,7 @@ class ConversationApi(InstalledAppResource):
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return {"result": "success"}, 204
return ResultResponse(result="success").model_dump(mode="json"), 204
@console_ns.route(
@ -113,7 +122,6 @@ class ConversationApi(InstalledAppResource):
endpoint="installed_app_conversation_rename",
)
class ConversationRenameApi(InstalledAppResource):
@marshal_with(simple_conversation_fields)
@console_ns.expect(console_ns.models[ConversationRenamePayload.__name__])
def post(self, installed_app, c_id):
app_model = installed_app.app
@ -128,9 +136,14 @@ class ConversationRenameApi(InstalledAppResource):
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
return ConversationService.rename(
conversation = ConversationService.rename(
app_model, conversation_id, current_user, payload.name, payload.auto_generate
)
return (
TypeAdapter(SimpleConversation)
.validate_python(conversation, from_attributes=True)
.model_dump(mode="json")
)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@ -155,7 +168,7 @@ class ConversationPinApi(InstalledAppResource):
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return {"result": "success"}
return ResultResponse(result="success").model_dump(mode="json")
@console_ns.route(
@ -174,4 +187,4 @@ class ConversationUnPinApi(InstalledAppResource):
raise ValueError("current_user must be an Account instance")
WebConversationService.unpin(app_model, conversation_id, current_user)
return {"result": "success"}
return ResultResponse(result="success").model_dump(mode="json")

View File

@ -2,8 +2,7 @@ import logging
from typing import Literal
from flask import request
from flask_restx import marshal_with
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.schema import register_schema_models
@ -23,7 +22,8 @@ from controllers.console.explore.wraps import InstalledAppResource
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from fields.message_fields import message_infinite_scroll_pagination_fields
from fields.conversation_fields import ResultResponse
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse
from libs import helper
from libs.helper import UUIDStrOrEmpty
from libs.login import current_account_with_tenant
@ -66,7 +66,6 @@ register_schema_models(console_ns, MessageListQuery, MessageFeedbackPayload, Mor
endpoint="installed_app_messages",
)
class MessageListApi(InstalledAppResource):
@marshal_with(message_infinite_scroll_pagination_fields)
@console_ns.expect(console_ns.models[MessageListQuery.__name__])
def get(self, installed_app):
current_user, _ = current_account_with_tenant()
@ -78,13 +77,20 @@ class MessageListApi(InstalledAppResource):
args = MessageListQuery.model_validate(request.args.to_dict())
try:
return MessageService.pagination_by_first_id(
pagination = MessageService.pagination_by_first_id(
app_model,
current_user,
str(args.conversation_id),
str(args.first_id) if args.first_id else None,
args.limit,
)
adapter = TypeAdapter(MessageListItem)
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
return MessageInfiniteScrollPagination(
limit=pagination.limit,
has_more=pagination.has_more,
data=items,
).model_dump(mode="json")
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except FirstMessageNotExistsError:
@ -116,7 +122,7 @@ class MessageFeedbackApi(InstalledAppResource):
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
return {"result": "success"}
return ResultResponse(result="success").model_dump(mode="json")
@console_ns.route(
@ -201,4 +207,4 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
logger.exception("internal server error.")
raise InternalServerError()
return {"data": questions}
return SuggestedQuestionsResponse(data=questions).model_dump(mode="json")

View File

@ -1,14 +1,14 @@
from flask import request
from flask_restx import fields, marshal_with
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField, UUIDStrOrEmpty
from fields.conversation_fields import ResultResponse
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
from libs.helper import UUIDStrOrEmpty
from libs.login import current_account_with_tenant
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
@ -26,28 +26,8 @@ class SavedMessageCreatePayload(BaseModel):
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)
feedback_fields = {"rating": fields.String}
message_fields = {
"id": fields.String,
"inputs": fields.Raw,
"query": fields.String,
"answer": fields.String,
"message_files": fields.List(fields.Nested(message_file_fields)),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"created_at": TimestampField,
}
@console_ns.route("/installed-apps/<uuid:installed_app_id>/saved-messages", endpoint="installed_app_saved_messages")
class SavedMessageListApi(InstalledAppResource):
saved_message_infinite_scroll_pagination_fields = {
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(message_fields)),
}
@marshal_with(saved_message_infinite_scroll_pagination_fields)
@console_ns.expect(console_ns.models[SavedMessageListQuery.__name__])
def get(self, installed_app):
current_user, _ = current_account_with_tenant()
@ -57,12 +37,19 @@ class SavedMessageListApi(InstalledAppResource):
args = SavedMessageListQuery.model_validate(request.args.to_dict())
return SavedMessageService.pagination_by_last_id(
pagination = SavedMessageService.pagination_by_last_id(
app_model,
current_user,
str(args.last_id) if args.last_id else None,
args.limit,
)
adapter = TypeAdapter(SavedMessageItem)
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
return SavedMessageInfiniteScrollPagination(
limit=pagination.limit,
has_more=pagination.has_more,
data=items,
).model_dump(mode="json")
@console_ns.expect(console_ns.models[SavedMessageCreatePayload.__name__])
def post(self, installed_app):
@ -78,7 +65,7 @@ class SavedMessageListApi(InstalledAppResource):
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
return {"result": "success"}
return ResultResponse(result="success").model_dump(mode="json")
@console_ns.route(
@ -96,4 +83,4 @@ class SavedMessageApi(InstalledAppResource):
SavedMessageService.delete(app_model, current_user, message_id)
return {"result": "success"}, 204
return ResultResponse(result="success").model_dump(mode="json"), 204

View File

@ -4,12 +4,11 @@ from typing import Any
from flask import make_response, redirect, request
from flask_restx import Resource, reqparse
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden
from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from controllers.web.error import NotFoundError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
@ -44,6 +43,12 @@ class TriggerSubscriptionUpdateRequest(BaseModel):
parameters: Mapping[str, Any] | None = Field(default=None, description="The parameters for the subscription")
properties: Mapping[str, Any] | None = Field(default=None, description="The properties for the subscription")
@model_validator(mode="after")
def check_at_least_one_field(self):
if all(v is None for v in (self.name, self.credentials, self.parameters, self.properties)):
raise ValueError("At least one of name, credentials, parameters, or properties must be provided")
return self
class TriggerSubscriptionVerifyRequest(BaseModel):
"""Request payload for verifying subscription credentials."""
@ -333,7 +338,7 @@ class TriggerSubscriptionUpdateApi(Resource):
user = current_user
assert user.current_tenant_id is not None
args = TriggerSubscriptionUpdateRequest.model_validate(console_ns.payload)
request = TriggerSubscriptionUpdateRequest.model_validate(console_ns.payload)
subscription = TriggerProviderService.get_subscription_by_id(
tenant_id=user.current_tenant_id,
@ -345,50 +350,32 @@ class TriggerSubscriptionUpdateApi(Resource):
provider_id = TriggerProviderID(subscription.provider_id)
try:
# rename only
if (
args.name is not None
and args.credentials is None
and args.parameters is None
and args.properties is None
):
# For rename only, just update the name
rename = request.name is not None and not any((request.credentials, request.parameters, request.properties))
# When the credential type is UNAUTHORIZED, the subscription was created manually.
# Manually created subscriptions don't have credentials or parameters;
# they only have a name and properties (which are input by the user).
manually_created = subscription.credential_type == CredentialType.UNAUTHORIZED
if rename or manually_created:
TriggerProviderService.update_trigger_subscription(
tenant_id=user.current_tenant_id,
subscription_id=subscription_id,
name=args.name,
name=request.name,
properties=request.properties,
)
return 200
# rebuild for create automatically by the provider
match subscription.credential_type:
case CredentialType.UNAUTHORIZED:
TriggerProviderService.update_trigger_subscription(
tenant_id=user.current_tenant_id,
subscription_id=subscription_id,
name=args.name,
properties=args.properties,
)
return 200
case CredentialType.API_KEY | CredentialType.OAUTH2:
if args.credentials:
new_credentials: dict[str, Any] = {
key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE)
for key, value in args.credentials.items()
}
else:
new_credentials = subscription.credentials
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=user.current_tenant_id,
name=args.name,
provider_id=provider_id,
subscription_id=subscription_id,
credentials=new_credentials,
parameters=args.parameters or subscription.parameters,
)
return 200
case _:
raise BadRequest("Invalid credential type")
# For the remaining cases (API_KEY, OAUTH2),
# we need to call the third-party provider (e.g. GitHub) to rebuild the subscription
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=user.current_tenant_id,
name=request.name,
provider_id=provider_id,
subscription_id=subscription_id,
credentials=request.credentials or subscription.credentials,
parameters=request.parameters or subscription.parameters,
)
return 200
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:

View File

@ -3,8 +3,7 @@ from uuid import UUID
from flask import request
from flask_restx import Resource
from flask_restx._http import HTTPStatus
from pydantic import BaseModel, Field, field_validator, model_validator
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound
@ -16,9 +15,9 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import (
build_conversation_delete_model,
build_conversation_infinite_scroll_pagination_model,
build_simple_conversation_model,
ConversationDelete,
ConversationInfiniteScrollPagination,
SimpleConversation,
)
from fields.conversation_variable_fields import (
build_conversation_variable_infinite_scroll_pagination_model,
@ -105,7 +104,6 @@ class ConversationApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@service_api_ns.marshal_with(build_conversation_infinite_scroll_pagination_model(service_api_ns))
def get(self, app_model: App, end_user: EndUser):
"""List all conversations for the current user.
@ -120,7 +118,7 @@ class ConversationApi(Resource):
try:
with Session(db.engine) as session:
return ConversationService.pagination_by_last_id(
pagination = ConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=end_user,
@ -129,6 +127,13 @@ class ConversationApi(Resource):
invoke_from=InvokeFrom.SERVICE_API,
sort_by=query_args.sort_by,
)
adapter = TypeAdapter(SimpleConversation)
conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data]
return ConversationInfiniteScrollPagination(
limit=pagination.limit,
has_more=pagination.has_more,
data=conversations,
).model_dump(mode="json")
except services.errors.conversation.LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
@ -146,7 +151,6 @@ class ConversationDetailApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=HTTPStatus.NO_CONTENT)
def delete(self, app_model: App, end_user: EndUser, c_id):
"""Delete a specific conversation."""
app_mode = AppMode.value_of(app_model.mode)
@ -159,7 +163,7 @@ class ConversationDetailApi(Resource):
ConversationService.delete(app_model, conversation_id, end_user)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return {"result": "success"}, 204
return ConversationDelete(result="success").model_dump(mode="json"), 204
@service_api_ns.route("/conversations/<uuid:c_id>/name")
@ -176,7 +180,6 @@ class ConversationRenameApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@service_api_ns.marshal_with(build_simple_conversation_model(service_api_ns))
def post(self, app_model: App, end_user: EndUser, c_id):
"""Rename a conversation or auto-generate a name."""
app_mode = AppMode.value_of(app_model.mode)
@ -188,7 +191,14 @@ class ConversationRenameApi(Resource):
payload = ConversationRenamePayload.model_validate(service_api_ns.payload or {})
try:
return ConversationService.rename(app_model, conversation_id, end_user, payload.name, payload.auto_generate)
conversation = ConversationService.rename(
app_model, conversation_id, end_user, payload.name, payload.auto_generate
)
return (
TypeAdapter(SimpleConversation)
.validate_python(conversation, from_attributes=True)
.model_dump(mode="json")
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")

View File

@ -1,11 +1,10 @@
import json
import logging
from typing import Literal
from uuid import UUID
from flask import request
from flask_restx import Namespace, Resource, fields
from pydantic import BaseModel, Field
from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services
@ -14,10 +13,8 @@ from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import build_message_file_model
from fields.message_fields import build_agent_thought_model, build_feedback_model
from fields.raws import FilesContainedField
from libs.helper import TimestampField
from fields.conversation_fields import ResultResponse
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
from models.model import App, AppMode, EndUser
from services.errors.message import (
FirstMessageNotExistsError,
@ -48,49 +45,6 @@ class FeedbackListQuery(BaseModel):
register_schema_models(service_api_ns, MessageListQuery, MessageFeedbackPayload, FeedbackListQuery)
def build_message_model(api_or_ns: Namespace):
"""Build the message model for the API or Namespace."""
# First build the nested models
feedback_model = build_feedback_model(api_or_ns)
agent_thought_model = build_agent_thought_model(api_or_ns)
message_file_model = build_message_file_model(api_or_ns)
# Then build the message fields with nested models
message_fields = {
"id": fields.String,
"conversation_id": fields.String,
"parent_message_id": fields.String,
"inputs": FilesContainedField,
"query": fields.String,
"answer": fields.String(attribute="re_sign_file_url_answer"),
"message_files": fields.List(fields.Nested(message_file_model)),
"feedback": fields.Nested(feedback_model, attribute="user_feedback", allow_null=True),
"retriever_resources": fields.Raw(
attribute=lambda obj: json.loads(obj.message_metadata).get("retriever_resources", [])
if obj.message_metadata
else []
),
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
"status": fields.String,
"error": fields.String,
}
return api_or_ns.model("Message", message_fields)
def build_message_infinite_scroll_pagination_model(api_or_ns: Namespace):
"""Build the message infinite scroll pagination model for the API or Namespace."""
# Build the nested message model first
message_model = build_message_model(api_or_ns)
message_infinite_scroll_pagination_fields = {
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(message_model)),
}
return api_or_ns.model("MessageInfiniteScrollPagination", message_infinite_scroll_pagination_fields)
@service_api_ns.route("/messages")
class MessageListApi(Resource):
@service_api_ns.expect(service_api_ns.models[MessageListQuery.__name__])
@ -104,7 +58,6 @@ class MessageListApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@service_api_ns.marshal_with(build_message_infinite_scroll_pagination_model(service_api_ns))
def get(self, app_model: App, end_user: EndUser):
"""List messages in a conversation.
@ -119,9 +72,16 @@ class MessageListApi(Resource):
first_id = str(query_args.first_id) if query_args.first_id else None
try:
return MessageService.pagination_by_first_id(
pagination = MessageService.pagination_by_first_id(
app_model, end_user, conversation_id, first_id, query_args.limit
)
adapter = TypeAdapter(MessageListItem)
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
return MessageInfiniteScrollPagination(
limit=pagination.limit,
has_more=pagination.has_more,
data=items,
).model_dump(mode="json")
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except FirstMessageNotExistsError:
@ -162,7 +122,7 @@ class MessageFeedbackApi(Resource):
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
return {"result": "success"}
return ResultResponse(result="success").model_dump(mode="json")
@service_api_ns.route("/app/feedbacks")

View File

@ -1,5 +1,6 @@
from flask_restx import fields, marshal_with, reqparse
from flask_restx import reqparse
from flask_restx.inputs import int_range
from pydantic import TypeAdapter
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
@ -8,7 +9,11 @@ from controllers.web.error import NotChatAppError
from controllers.web.wraps import WebApiResource
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from fields.conversation_fields import (
ConversationInfiniteScrollPagination,
ResultResponse,
SimpleConversation,
)
from libs.helper import uuid_value
from models.model import AppMode
from services.conversation_service import ConversationService
@ -54,7 +59,6 @@ class ConversationListApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -82,7 +86,7 @@ class ConversationListApi(WebApiResource):
try:
with Session(db.engine) as session:
return WebConversationService.pagination_by_last_id(
pagination = WebConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=end_user,
@ -92,16 +96,19 @@ class ConversationListApi(WebApiResource):
pinned=pinned,
sort_by=args["sort_by"],
)
adapter = TypeAdapter(SimpleConversation)
conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data]
return ConversationInfiniteScrollPagination(
limit=pagination.limit,
has_more=pagination.has_more,
data=conversations,
).model_dump(mode="json")
except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
@web_ns.route("/conversations/<uuid:c_id>")
class ConversationApi(WebApiResource):
delete_response_fields = {
"result": fields.String,
}
@web_ns.doc("Delete Conversation")
@web_ns.doc(description="Delete a specific conversation.")
@web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}})
@ -115,7 +122,6 @@ class ConversationApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(delete_response_fields)
def delete(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -126,7 +132,7 @@ class ConversationApi(WebApiResource):
ConversationService.delete(app_model, conversation_id, end_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return {"result": "success"}, 204
return ResultResponse(result="success").model_dump(mode="json"), 204
@web_ns.route("/conversations/<uuid:c_id>/name")
@ -155,7 +161,6 @@ class ConversationRenameApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(simple_conversation_fields)
def post(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -171,17 +176,20 @@ class ConversationRenameApi(WebApiResource):
args = parser.parse_args()
try:
return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"])
conversation = ConversationService.rename(
app_model, conversation_id, end_user, args["name"], args["auto_generate"]
)
return (
TypeAdapter(SimpleConversation)
.validate_python(conversation, from_attributes=True)
.model_dump(mode="json")
)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@web_ns.route("/conversations/<uuid:c_id>/pin")
class ConversationPinApi(WebApiResource):
pin_response_fields = {
"result": fields.String,
}
@web_ns.doc("Pin Conversation")
@web_ns.doc(description="Pin a specific conversation to keep it at the top of the list.")
@web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}})
@ -195,7 +203,6 @@ class ConversationPinApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(pin_response_fields)
def patch(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -208,15 +215,11 @@ class ConversationPinApi(WebApiResource):
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return {"result": "success"}
return ResultResponse(result="success").model_dump(mode="json")
@web_ns.route("/conversations/<uuid:c_id>/unpin")
class ConversationUnPinApi(WebApiResource):
unpin_response_fields = {
"result": fields.String,
}
@web_ns.doc("Unpin Conversation")
@web_ns.doc(description="Unpin a specific conversation to remove it from the top of the list.")
@web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}})
@ -230,7 +233,6 @@ class ConversationUnPinApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(unpin_response_fields)
def patch(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -239,4 +241,4 @@ class ConversationUnPinApi(WebApiResource):
conversation_id = str(c_id)
WebConversationService.unpin(app_model, conversation_id, end_user)
return {"result": "success"}
return ResultResponse(result="success").model_dump(mode="json")

View File

@ -2,8 +2,7 @@ import logging
from typing import Literal
from flask import request
from flask_restx import fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.schema import register_schema_models
@ -22,11 +21,10 @@ from controllers.web.wraps import WebApiResource
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from fields.conversation_fields import message_file_fields
from fields.message_fields import agent_thought_fields, feedback_fields, retriever_resource_fields
from fields.raws import FilesContainedField
from fields.conversation_fields import ResultResponse
from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem
from libs import helper
from libs.helper import TimestampField, uuid_value
from libs.helper import uuid_value
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import MoreLikeThisDisabledError
@ -70,29 +68,6 @@ register_schema_models(web_ns, MessageListQuery, MessageFeedbackPayload, Message
@web_ns.route("/messages")
class MessageListApi(WebApiResource):
message_fields = {
"id": fields.String,
"conversation_id": fields.String,
"parent_message_id": fields.String,
"inputs": FilesContainedField,
"query": fields.String,
"answer": fields.String(attribute="re_sign_file_url_answer"),
"message_files": fields.List(fields.Nested(message_file_fields)),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
"metadata": fields.Raw(attribute="message_metadata_dict"),
"status": fields.String,
"error": fields.String,
}
message_infinite_scroll_pagination_fields = {
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(message_fields)),
}
@web_ns.doc("Get Message List")
@web_ns.doc(description="Retrieve paginated list of messages from a conversation in a chat application.")
@web_ns.doc(
@ -121,7 +96,6 @@ class MessageListApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -131,9 +105,16 @@ class MessageListApi(WebApiResource):
query = MessageListQuery.model_validate(raw_args)
try:
return MessageService.pagination_by_first_id(
pagination = MessageService.pagination_by_first_id(
app_model, end_user, query.conversation_id, query.first_id, query.limit
)
adapter = TypeAdapter(WebMessageListItem)
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
return WebMessageInfiniteScrollPagination(
limit=pagination.limit,
has_more=pagination.has_more,
data=items,
).model_dump(mode="json")
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except FirstMessageNotExistsError:
@ -142,10 +123,6 @@ class MessageListApi(WebApiResource):
@web_ns.route("/messages/<uuid:message_id>/feedbacks")
class MessageFeedbackApi(WebApiResource):
feedback_response_fields = {
"result": fields.String,
}
@web_ns.doc("Create Message Feedback")
@web_ns.doc(description="Submit feedback (like/dislike) for a specific message.")
@web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}})
@ -170,7 +147,6 @@ class MessageFeedbackApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(feedback_response_fields)
def post(self, app_model, end_user, message_id):
message_id = str(message_id)
@ -187,7 +163,7 @@ class MessageFeedbackApi(WebApiResource):
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
return {"result": "success"}
return ResultResponse(result="success").model_dump(mode="json")
@web_ns.route("/messages/<uuid:message_id>/more-like-this")
@ -247,10 +223,6 @@ class MessageMoreLikeThisApi(WebApiResource):
@web_ns.route("/messages/<uuid:message_id>/suggested-questions")
class MessageSuggestedQuestionApi(WebApiResource):
suggested_questions_response_fields = {
"data": fields.List(fields.String),
}
@web_ns.doc("Get Suggested Questions")
@web_ns.doc(description="Get suggested follow-up questions after a message (chat apps only).")
@web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}})
@ -264,7 +236,6 @@ class MessageSuggestedQuestionApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(suggested_questions_response_fields)
def get(self, app_model, end_user, message_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -277,7 +248,6 @@ class MessageSuggestedQuestionApi(WebApiResource):
app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.WEB_APP
)
# questions is a list of strings, not a list of Message objects
# so we can directly return it
except MessageNotExistsError:
raise NotFound("Message not found")
except ConversationNotExistsError:
@ -296,4 +266,4 @@ class MessageSuggestedQuestionApi(WebApiResource):
logger.exception("internal server error.")
raise InternalServerError()
return {"data": questions}
return SuggestedQuestionsResponse(data=questions).model_dump(mode="json")

View File

@ -1,40 +1,20 @@
from flask_restx import fields, marshal_with, reqparse
from flask_restx import reqparse
from flask_restx.inputs import int_range
from pydantic import TypeAdapter
from werkzeug.exceptions import NotFound
from controllers.web import web_ns
from controllers.web.error import NotCompletionAppError
from controllers.web.wraps import WebApiResource
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField, uuid_value
from fields.conversation_fields import ResultResponse
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
from libs.helper import uuid_value
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
feedback_fields = {"rating": fields.String}
message_fields = {
"id": fields.String,
"inputs": fields.Raw,
"query": fields.String,
"answer": fields.String,
"message_files": fields.List(fields.Nested(message_file_fields)),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"created_at": TimestampField,
}
@web_ns.route("/saved-messages")
class SavedMessageListApi(WebApiResource):
saved_message_infinite_scroll_pagination_fields = {
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(message_fields)),
}
post_response_fields = {
"result": fields.String,
}
@web_ns.doc("Get Saved Messages")
@web_ns.doc(description="Retrieve paginated list of saved messages for a completion application.")
@web_ns.doc(
@ -58,7 +38,6 @@ class SavedMessageListApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(saved_message_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
if app_model.mode != "completion":
raise NotCompletionAppError()
@ -70,7 +49,14 @@ class SavedMessageListApi(WebApiResource):
)
args = parser.parse_args()
return SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"])
pagination = SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"])
adapter = TypeAdapter(SavedMessageItem)
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
return SavedMessageInfiniteScrollPagination(
limit=pagination.limit,
has_more=pagination.has_more,
data=items,
).model_dump(mode="json")
@web_ns.doc("Save Message")
@web_ns.doc(description="Save a specific message for later reference.")
@ -89,7 +75,6 @@ class SavedMessageListApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(post_response_fields)
def post(self, app_model, end_user):
if app_model.mode != "completion":
raise NotCompletionAppError()
@ -102,15 +87,11 @@ class SavedMessageListApi(WebApiResource):
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
return {"result": "success"}
return ResultResponse(result="success").model_dump(mode="json")
@web_ns.route("/saved-messages/<uuid:message_id>")
class SavedMessageApi(WebApiResource):
delete_response_fields = {
"result": fields.String,
}
@web_ns.doc("Delete Saved Message")
@web_ns.doc(description="Remove a message from saved messages.")
@web_ns.doc(params={"message_id": {"description": "Message UUID to delete", "type": "string", "required": True}})
@ -124,7 +105,6 @@ class SavedMessageApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(delete_response_fields)
def delete(self, app_model, end_user, message_id):
message_id = str(message_id)
@ -133,4 +113,4 @@ class SavedMessageApi(WebApiResource):
SavedMessageService.delete(app_model, end_user, message_id)
return {"result": "success"}, 204
return ResultResponse(result="success").model_dump(mode="json"), 204

View File

@ -88,7 +88,41 @@ def _get_user_provided_host_header(headers: dict | None) -> str | None:
return None
def _inject_trace_headers(headers: dict | None) -> dict:
"""
Inject W3C traceparent header for distributed tracing.
When OTEL is enabled, HTTPXClientInstrumentor handles trace propagation automatically.
When OTEL is disabled, we manually inject the traceparent header.
"""
if headers is None:
headers = {}
# Skip if already present (case-insensitive check)
for key in headers:
if key.lower() == "traceparent":
return headers
# Skip if OTEL is enabled - HTTPXClientInstrumentor handles this automatically
if dify_config.ENABLE_OTEL:
return headers
# Generate and inject traceparent for non-OTEL scenarios
try:
from core.helper.trace_id_helper import generate_traceparent_header
traceparent = generate_traceparent_header()
if traceparent:
headers["traceparent"] = traceparent
except Exception:
# Swallow the error (debug-log only) so header injection never breaks the request
logger.debug("Failed to generate traceparent header", exc_info=True)
return headers
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
# Convert requests-style allow_redirects to httpx-style follow_redirects
if "allow_redirects" in kwargs:
allow_redirects = kwargs.pop("allow_redirects")
if "follow_redirects" not in kwargs:
@ -106,18 +140,21 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
client = _get_ssrf_client(verify_option)
# Inject traceparent header for distributed tracing (when OTEL is not enabled)
headers = kwargs.get("headers") or {}
headers = _inject_trace_headers(headers)
kwargs["headers"] = headers
# Preserve user-provided Host header
# When using a forward proxy, httpx may override the Host header based on the URL.
# We extract and preserve any explicitly set Host header to support virtual hosting.
headers = kwargs.get("headers", {})
user_provided_host = _get_user_provided_host_header(headers)
retries = 0
while retries <= max_retries:
try:
# Build the request manually to preserve the Host header
# httpx may override the Host header when using a proxy, so we use
# the request API to explicitly set headers before sending
# Preserve the user-provided Host header
# httpx may override the Host header when using a proxy
headers = {k: v for k, v in headers.items() if k.lower() != "host"}
if user_provided_host is not None:
headers["host"] = user_provided_host

View File

@ -103,3 +103,60 @@ def parse_traceparent_header(traceparent: str) -> str | None:
if len(parts) == 4 and len(parts[1]) == 32:
return parts[1]
return None
def get_span_id_from_otel_context() -> str | None:
"""
Retrieve the current span ID from the active OpenTelemetry trace context.
Returns:
A 16-character hex string representing the span ID, or None if not available.
"""
try:
from opentelemetry.trace import get_current_span
from opentelemetry.trace.span import INVALID_SPAN_ID
span = get_current_span()
if not span:
return None
span_context = span.get_span_context()
if not span_context or span_context.span_id == INVALID_SPAN_ID:
return None
return f"{span_context.span_id:016x}"
except Exception:
return None
def generate_traceparent_header() -> str | None:
"""
Generate a W3C traceparent header from the current context.
Uses OpenTelemetry context if available, otherwise uses the
ContextVar-based trace_id from the logging context.
Format: {version}-{trace_id}-{span_id}-{flags}
Example: 00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01
Returns:
A valid traceparent header string, or None if generation fails.
"""
import uuid
# Try OTEL context first
trace_id = get_trace_id_from_otel_context()
span_id = get_span_id_from_otel_context()
if trace_id and span_id:
return f"00-{trace_id}-{span_id}-01"
# Fallback: use ContextVar-based trace_id or generate new one
from core.logging.context import get_trace_id as get_logging_trace_id
trace_id = get_logging_trace_id() or uuid.uuid4().hex
# Generate a new span_id (16 hex chars)
span_id = uuid.uuid4().hex[:16]
return f"00-{trace_id}-{span_id}-01"

View File

@ -0,0 +1,20 @@
"""Structured logging components for Dify."""
from core.logging.context import (
clear_request_context,
get_request_id,
get_trace_id,
init_request_context,
)
from core.logging.filters import IdentityContextFilter, TraceContextFilter
from core.logging.structured_formatter import StructuredJSONFormatter
__all__ = [
"IdentityContextFilter",
"StructuredJSONFormatter",
"TraceContextFilter",
"clear_request_context",
"get_request_id",
"get_trace_id",
"init_request_context",
]

View File

@ -0,0 +1,35 @@
"""Request context for logging - framework agnostic.
This module provides request-scoped context variables for logging,
using Python's contextvars for thread-safe and async-safe storage.
"""
import uuid
from contextvars import ContextVar
_request_id: ContextVar[str] = ContextVar("log_request_id", default="")
_trace_id: ContextVar[str] = ContextVar("log_trace_id", default="")
def get_request_id() -> str:
"""Get current request ID (10 hex chars)."""
return _request_id.get()
def get_trace_id() -> str:
"""Get fallback trace ID when OTEL is unavailable (32 hex chars)."""
return _trace_id.get()
def init_request_context() -> None:
"""Initialize request context. Call at start of each request."""
req_id = uuid.uuid4().hex[:10]
trace_id = uuid.uuid5(uuid.NAMESPACE_DNS, req_id).hex
_request_id.set(req_id)
_trace_id.set(trace_id)
def clear_request_context() -> None:
"""Clear request context. Call at end of request (optional)."""
_request_id.set("")
_trace_id.set("")
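Because `uuid5` is deterministic, the same request ID always maps to the same fallback trace ID, so all log lines in one request correlate even without OTEL. A standalone sketch of that behaviour (re-declaring the ContextVars locally rather than importing the module):

```python
import uuid
from contextvars import ContextVar

_request_id: ContextVar[str] = ContextVar("log_request_id", default="")
_trace_id: ContextVar[str] = ContextVar("log_trace_id", default="")

req_id = uuid.uuid4().hex[:10]
_request_id.set(req_id)
_trace_id.set(uuid.uuid5(uuid.NAMESPACE_DNS, req_id).hex)

# Deterministic: recomputing from the same req_id gives the same trace_id.
assert _trace_id.get() == uuid.uuid5(uuid.NAMESPACE_DNS, req_id).hex
print(_request_id.get(), _trace_id.get())
```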

View File

@ -0,0 +1,94 @@
"""Logging filters for structured logging."""
import contextlib
import logging
import flask
from core.logging.context import get_request_id, get_trace_id
class TraceContextFilter(logging.Filter):
"""
Filter that adds trace_id and span_id to log records.
Integrates with OpenTelemetry when available, falls back to ContextVar-based trace_id.
"""
def filter(self, record: logging.LogRecord) -> bool:
# Get trace context from OpenTelemetry
trace_id, span_id = self._get_otel_context()
# Set trace_id (fallback to ContextVar if no OTEL context)
if trace_id:
record.trace_id = trace_id
else:
record.trace_id = get_trace_id()
record.span_id = span_id or ""
# For backward compatibility, also set req_id
record.req_id = get_request_id()
return True
def _get_otel_context(self) -> tuple[str, str]:
"""Extract trace_id and span_id from OpenTelemetry context."""
with contextlib.suppress(Exception):
from opentelemetry.trace import get_current_span
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
span = get_current_span()
if span and span.get_span_context():
ctx = span.get_span_context()
if ctx.is_valid and ctx.trace_id != INVALID_TRACE_ID:
trace_id = f"{ctx.trace_id:032x}"
span_id = f"{ctx.span_id:016x}" if ctx.span_id != INVALID_SPAN_ID else ""
return trace_id, span_id
return "", ""
class IdentityContextFilter(logging.Filter):
"""
Filter that adds user identity context to log records.
Extracts tenant_id, user_id, and user_type from Flask-Login current_user.
"""
def filter(self, record: logging.LogRecord) -> bool:
identity = self._extract_identity()
record.tenant_id = identity.get("tenant_id", "")
record.user_id = identity.get("user_id", "")
record.user_type = identity.get("user_type", "")
return True
def _extract_identity(self) -> dict[str, str]:
"""Extract identity from current_user if in request context."""
try:
if not flask.has_request_context():
return {}
from flask_login import current_user
# Check if user is authenticated using the proxy
if not current_user.is_authenticated:
return {}
# Access the underlying user object
user = current_user
from models import Account
from models.model import EndUser
identity: dict[str, str] = {}
if isinstance(user, Account):
if user.current_tenant_id:
identity["tenant_id"] = user.current_tenant_id
identity["user_id"] = user.id
identity["user_type"] = "account"
elif isinstance(user, EndUser):
identity["tenant_id"] = user.tenant_id
identity["user_id"] = user.id
identity["user_type"] = user.type or "end_user"
return identity
except Exception:
return {}
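Both filters rely on the standard `logging.Filter` hook to attach extra attributes that a format string (or the JSON formatter below) can reference. A minimal standalone illustration of that mechanism, with hard-coded demo values instead of OTEL/Flask lookups:

```python
import logging


class DemoTraceFilter(logging.Filter):
    def filter(self, record: logging.LogRecord) -> bool:
        # In the real filters these come from OTEL / ContextVars / current_user.
        record.trace_id = "5b8aa5a2d2c872e8321cf37308d69df2"
        record.req_id = "abc123def0"
        return True


handler = logging.StreamHandler()
handler.addFilter(DemoTraceFilter())
handler.setFormatter(logging.Formatter("%(asctime)s [%(trace_id)s] %(req_id)s %(message)s"))

logger = logging.getLogger("demo.trace")
logger.addHandler(handler)
logger.setLevel(logging.INFO)
logger.info("hello")  # ... [5b8aa5a2d2c872e8321cf37308d69df2] abc123def0 hello
```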

View File

@ -0,0 +1,107 @@
"""Structured JSON log formatter for Dify."""
import logging
import traceback
from datetime import UTC, datetime
from typing import Any
import orjson
from configs import dify_config
class StructuredJSONFormatter(logging.Formatter):
"""
JSON log formatter following the specified schema:
{
"ts": "ISO 8601 UTC",
"severity": "INFO|ERROR|WARN|DEBUG",
"service": "service name",
"caller": "file:line",
"trace_id": "hex 32",
"span_id": "hex 16",
"identity": { "tenant_id", "user_id", "user_type" },
"message": "log message",
"attributes": { ... },
"stack_trace": "..."
}
"""
SEVERITY_MAP: dict[int, str] = {
logging.DEBUG: "DEBUG",
logging.INFO: "INFO",
logging.WARNING: "WARN",
logging.ERROR: "ERROR",
logging.CRITICAL: "ERROR",
}
def __init__(self, service_name: str | None = None):
super().__init__()
self._service_name = service_name or dify_config.APPLICATION_NAME
def format(self, record: logging.LogRecord) -> str:
log_dict = self._build_log_dict(record)
try:
return orjson.dumps(log_dict).decode("utf-8")
except TypeError:
# Fallback: convert non-serializable objects to string
import json
return json.dumps(log_dict, default=str, ensure_ascii=False)
def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]:
# Core fields
log_dict: dict[str, Any] = {
"ts": datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z"),
"severity": self.SEVERITY_MAP.get(record.levelno, "INFO"),
"service": self._service_name,
"caller": f"{record.filename}:{record.lineno}",
"message": record.getMessage(),
}
# Trace context (from TraceContextFilter)
trace_id = getattr(record, "trace_id", "")
span_id = getattr(record, "span_id", "")
if trace_id:
log_dict["trace_id"] = trace_id
if span_id:
log_dict["span_id"] = span_id
# Identity context (from IdentityContextFilter)
identity = self._extract_identity(record)
if identity:
log_dict["identity"] = identity
# Dynamic attributes
attributes = getattr(record, "attributes", None)
if attributes:
log_dict["attributes"] = attributes
# Stack trace for errors with exceptions
if record.exc_info and record.levelno >= logging.ERROR:
log_dict["stack_trace"] = self._format_exception(record.exc_info)
return log_dict
def _extract_identity(self, record: logging.LogRecord) -> dict[str, str] | None:
tenant_id = getattr(record, "tenant_id", None)
user_id = getattr(record, "user_id", None)
user_type = getattr(record, "user_type", None)
if not any([tenant_id, user_id, user_type]):
return None
identity: dict[str, str] = {}
if tenant_id:
identity["tenant_id"] = tenant_id
if user_id:
identity["user_id"] = user_id
if user_type:
identity["user_type"] = user_type
return identity
def _format_exception(self, exc_info: tuple[Any, ...]) -> str:
if exc_info and exc_info[0] is not None:
return "".join(traceback.format_exception(*exc_info))
return ""

View File

@ -103,6 +103,9 @@ class BasePluginClient:
prepared_headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
prepared_headers.setdefault("Accept-Encoding", "gzip, deflate, br")
# Inject traceparent header for distributed tracing
self._inject_trace_headers(prepared_headers)
prepared_data: bytes | dict[str, Any] | str | None = (
data if isinstance(data, (bytes, str, dict)) or data is None else None
)
@ -114,6 +117,31 @@ class BasePluginClient:
return str(url), prepared_headers, prepared_data, params, files
def _inject_trace_headers(self, headers: dict[str, str]) -> None:
"""
Inject W3C traceparent header for distributed tracing.
This ensures trace context is propagated to plugin daemon even if
HTTPXClientInstrumentor doesn't cover module-level httpx functions.
"""
if not dify_config.ENABLE_OTEL:
return
import contextlib
# Skip if already present (case-insensitive check)
for key in headers:
if key.lower() == "traceparent":
return
# Inject traceparent - works as fallback when OTEL instrumentation doesn't cover this call
with contextlib.suppress(Exception):
from core.helper.trace_id_helper import generate_traceparent_header
traceparent = generate_traceparent_header()
if traceparent:
headers["traceparent"] = traceparent
def _stream_request(
self,
method: str,

View File

@ -515,6 +515,7 @@ class DatasetRetrieval:
0
].embedding_model_provider
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
dataset_count = len(available_datasets)
with measure_time() as timer:
cancel_event = threading.Event()
thread_exceptions: list[Exception] = []
@ -537,6 +538,7 @@ class DatasetRetrieval:
"score_threshold": score_threshold,
"query": query,
"attachment_id": None,
"dataset_count": dataset_count,
"cancel_event": cancel_event,
"thread_exceptions": thread_exceptions,
},
@ -562,6 +564,7 @@ class DatasetRetrieval:
"score_threshold": score_threshold,
"query": None,
"attachment_id": attachment_id,
"dataset_count": dataset_count,
"cancel_event": cancel_event,
"thread_exceptions": thread_exceptions,
},
@ -1422,6 +1425,7 @@ class DatasetRetrieval:
score_threshold: float,
query: str | None,
attachment_id: str | None,
dataset_count: int,
cancel_event: threading.Event | None = None,
thread_exceptions: list[Exception] | None = None,
):
@ -1470,37 +1474,38 @@ class DatasetRetrieval:
if cancel_event and cancel_event.is_set():
break
if reranking_enable:
# do rerank for searched documents
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
if query:
all_documents_item = data_post_processor.invoke(
query=query,
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.TEXT_QUERY,
)
if attachment_id:
all_documents_item = data_post_processor.invoke(
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.IMAGE_QUERY,
query=attachment_id,
)
else:
if index_type == IndexTechniqueType.ECONOMY:
if not query:
all_documents_item = []
else:
all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
elif index_type == IndexTechniqueType.HIGH_QUALITY:
all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
# Skip second reranking when there is only one dataset
if reranking_enable and dataset_count > 1:
# do rerank for searched documents
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
if query:
all_documents_item = data_post_processor.invoke(
query=query,
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.TEXT_QUERY,
)
if attachment_id:
all_documents_item = data_post_processor.invoke(
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.IMAGE_QUERY,
query=attachment_id,
)
else:
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
if all_documents_item:
all_documents.extend(all_documents_item)
if index_type == IndexTechniqueType.ECONOMY:
if not query:
all_documents_item = []
else:
all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
elif index_type == IndexTechniqueType.HIGH_QUALITY:
all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
else:
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
if all_documents_item:
all_documents.extend(all_documents_item)
except Exception as e:
if cancel_event:
cancel_event.set()

View File

@ -1,8 +1,7 @@
from collections.abc import Mapping, Sequence
from decimal import Decimal
from typing import Any, cast
from typing import TYPE_CHECKING, Any, ClassVar, cast
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
@ -13,6 +12,7 @@ from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.entities import CodeNodeData
from core.workflow.nodes.code.limits import CodeNodeLimits
from .exc import (
CodeNodeError,
@ -20,9 +20,41 @@ from .exc import (
OutputValidationError,
)
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
class CodeNode(Node[CodeNodeData]):
node_type = NodeType.CODE
_DEFAULT_CODE_PROVIDERS: ClassVar[tuple[type[CodeNodeProvider], ...]] = (
Python3CodeProvider,
JavascriptCodeProvider,
)
_limits: CodeNodeLimits
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
code_executor: type[CodeExecutor] | None = None,
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
code_limits: CodeNodeLimits,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor
self._code_providers: tuple[type[CodeNodeProvider], ...] = (
tuple(code_providers) if code_providers else self._DEFAULT_CODE_PROVIDERS
)
self._limits = code_limits
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@ -35,11 +67,16 @@ class CodeNode(Node[CodeNodeData]):
if filters:
code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3))
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language))
code_provider: type[CodeNodeProvider] = next(
provider for provider in cls._DEFAULT_CODE_PROVIDERS if provider.is_accept_language(code_language)
)
return code_provider.get_default_config()
@classmethod
def default_code_providers(cls) -> tuple[type[CodeNodeProvider], ...]:
return cls._DEFAULT_CODE_PROVIDERS
@classmethod
def version(cls) -> str:
return "1"
@ -60,7 +97,8 @@ class CodeNode(Node[CodeNodeData]):
variables[variable_name] = variable.to_object() if variable else None
# Run code
try:
result = CodeExecutor.execute_workflow_code_template(
_ = self._select_code_provider(code_language)
result = self._code_executor.execute_workflow_code_template(
language=code_language,
code=code,
inputs=variables,
@ -75,6 +113,12 @@ class CodeNode(Node[CodeNodeData]):
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
def _select_code_provider(self, code_language: CodeLanguage) -> type[CodeNodeProvider]:
for provider in self._code_providers:
if provider.is_accept_language(code_language):
return provider
raise CodeNodeError(f"Unsupported code language: {code_language}")
def _check_string(self, value: str | None, variable: str) -> str | None:
"""
Check string
@ -85,10 +129,10 @@ class CodeNode(Node[CodeNodeData]):
if value is None:
return None
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
if len(value) > self._limits.max_string_length:
raise OutputValidationError(
f"The length of output variable `{variable}` must be"
f" less than {dify_config.CODE_MAX_STRING_LENGTH} characters"
f" less than {self._limits.max_string_length} characters"
)
return value.replace("\x00", "")
@ -109,20 +153,20 @@ class CodeNode(Node[CodeNodeData]):
if value is None:
return None
if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
if value > self._limits.max_number or value < self._limits.min_number:
raise OutputValidationError(
f"Output variable `{variable}` is out of range,"
f" it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}."
f" it must be between {self._limits.min_number} and {self._limits.max_number}."
)
if isinstance(value, float):
decimal_value = Decimal(str(value)).normalize()
precision = -decimal_value.as_tuple().exponent if decimal_value.as_tuple().exponent < 0 else 0 # type: ignore[operator]
# raise error if precision is too high
if precision > dify_config.CODE_MAX_PRECISION:
if precision > self._limits.max_precision:
raise OutputValidationError(
f"Output variable `{variable}` has too high precision,"
f" it must be less than {dify_config.CODE_MAX_PRECISION} digits."
f" it must be less than {self._limits.max_precision} digits."
)
return value
@ -137,8 +181,8 @@ class CodeNode(Node[CodeNodeData]):
# TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes.
# Note that `_transform_result` may produce lists containing `None` values,
# which don't conform to the type requirements of `Array*Segment` classes.
if depth > dify_config.CODE_MAX_DEPTH:
raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.")
if depth > self._limits.max_depth:
raise DepthLimitError(f"Depth limit {self._limits.max_depth} reached, object too deep.")
transformed_result: dict[str, Any] = {}
if output_schema is None:
@ -272,10 +316,10 @@ class CodeNode(Node[CodeNodeData]):
f"Output {prefix}{dot}{output_name} is not an array, got {type(value)} instead."
)
else:
if len(value) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH:
if len(value) > self._limits.max_number_array_length:
raise OutputValidationError(
f"The length of output variable `{prefix}{dot}{output_name}` must be"
f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements."
f" less than {self._limits.max_number_array_length} elements."
)
for i, inner_value in enumerate(value):
@ -305,10 +349,10 @@ class CodeNode(Node[CodeNodeData]):
f" got {type(result.get(output_name))} instead."
)
else:
if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH:
if len(result[output_name]) > self._limits.max_string_array_length:
raise OutputValidationError(
f"The length of output variable `{prefix}{dot}{output_name}` must be"
f" less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements."
f" less than {self._limits.max_string_array_length} elements."
)
transformed_result[output_name] = [
@ -326,10 +370,10 @@ class CodeNode(Node[CodeNodeData]):
f" got {type(result.get(output_name))} instead."
)
else:
if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH:
if len(result[output_name]) > self._limits.max_object_array_length:
raise OutputValidationError(
f"The length of output variable `{prefix}{dot}{output_name}` must be"
f" less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements."
f" less than {self._limits.max_object_array_length} elements."
)
for i, value in enumerate(result[output_name]):

View File

@ -0,0 +1,13 @@
from dataclasses import dataclass
@dataclass(frozen=True)
class CodeNodeLimits:
max_string_length: int
max_number: int | float
min_number: int | float
max_precision: int
max_depth: int
max_number_array_length: int
max_string_array_length: int
max_object_array_length: int
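A hypothetical instantiation for a unit test that wants tighter bounds than production (import path taken from the `code_node.py` hunk above; assumes the snippet runs inside the api package):

```python
from core.workflow.nodes.code.limits import CodeNodeLimits

test_limits = CodeNodeLimits(
    max_string_length=1_000,
    max_number=1e9,
    min_number=-1e9,
    max_precision=5,
    max_depth=3,
    max_number_array_length=100,
    max_string_array_length=100,
    max_object_array_length=10,
)
```

Because the dataclass is frozen, a test can share one instance safely across nodes without worrying about mutation.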

View File

@ -1,10 +1,21 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING, final
from typing_extensions import override
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutor
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.workflow.enums import NodeType
from core.workflow.graph import NodeFactory
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.limits import CodeNodeLimits
from core.workflow.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
Jinja2TemplateRenderer,
)
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from libs.typing import is_str, is_str_dict
from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
@ -27,9 +38,29 @@ class DifyNodeFactory(NodeFactory):
self,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
code_executor: type[CodeExecutor] | None = None,
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
code_limits: CodeNodeLimits | None = None,
template_renderer: Jinja2TemplateRenderer | None = None,
) -> None:
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor
self._code_providers: tuple[type[CodeNodeProvider], ...] = (
tuple(code_providers) if code_providers else CodeNode.default_code_providers()
)
self._code_limits = code_limits or CodeNodeLimits(
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
max_number=dify_config.CODE_MAX_NUMBER,
min_number=dify_config.CODE_MIN_NUMBER,
max_precision=dify_config.CODE_MAX_PRECISION,
max_depth=dify_config.CODE_MAX_DEPTH,
max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH,
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
@override
def create_node(self, node_config: dict[str, object]) -> Node:
@ -72,6 +103,26 @@ class DifyNodeFactory(NodeFactory):
raise ValueError(f"No latest version class found for node type: {node_type}")
# Create node instance
if node_type == NodeType.CODE:
return CodeNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
code_executor=self._code_executor,
code_providers=self._code_providers,
code_limits=self._code_limits,
)
if node_type == NodeType.TEMPLATE_TRANSFORM:
return TemplateTransformNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
template_renderer=self._template_renderer,
)
return node_class(
id=node_id,
config=node_config,
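The factory now forwards optional executor/provider/limit overrides into the nodes, defaulting to production classes when nothing is passed. A standalone sketch of that constructor-injection pattern (all class names here are stand-ins, not the real Dify types):

```python
class CodeExecutor:
    @classmethod
    def execute_workflow_code_template(cls, *, language: str, code: str, inputs: dict) -> dict:
        return {"result": f"ran {language}"}  # imagine a sandbox round-trip here


class FakeExecutor(CodeExecutor):
    @classmethod
    def execute_workflow_code_template(cls, *, language: str, code: str, inputs: dict) -> dict:
        return {"result": "stubbed"}  # no sandbox call in tests


class Factory:
    def __init__(self, *, code_executor: type[CodeExecutor] | None = None) -> None:
        self._code_executor = code_executor or CodeExecutor  # default to production

    def run(self) -> dict:
        return self._code_executor.execute_workflow_code_template(
            language="python3", code="print('hi')", inputs={}
        )


print(Factory().run())                            # {'result': 'ran python3'}
print(Factory(code_executor=FakeExecutor).run())  # {'result': 'stubbed'}
```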

View File

@ -0,0 +1,40 @@
from __future__ import annotations
from collections.abc import Mapping
from typing import Any, Protocol
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
class TemplateRenderError(ValueError):
"""Raised when rendering a Jinja2 template fails."""
class Jinja2TemplateRenderer(Protocol):
"""Render Jinja2 templates for template transform nodes."""
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
"""Render a Jinja2 template with provided variables."""
raise NotImplementedError
class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer):
"""Adapter that renders Jinja2 templates via CodeExecutor."""
_code_executor: type[CodeExecutor]
def __init__(self, code_executor: type[CodeExecutor] | None = None) -> None:
self._code_executor = code_executor or CodeExecutor
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
try:
result = self._code_executor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code=template, inputs=variables
)
except CodeExecutionError as exc:
raise TemplateRenderError(str(exc)) from exc
rendered = result.get("result")
if not isinstance(rendered, str):
raise TemplateRenderError("Template render result must be a string.")
return rendered
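Any object with a matching `render_template` method satisfies the protocol, so tests can swap in a renderer that never reaches the sandbox service. A hypothetical in-process implementation (using the `jinja2` package directly is an assumption about the test environment, not something the node requires):

```python
from collections.abc import Mapping
from typing import Any

from jinja2 import Environment


class InMemoryJinja2Renderer:
    """Renders templates locally; suitable only for trusted test templates."""

    def __init__(self) -> None:
        self._env = Environment(autoescape=False)

    def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
        return self._env.from_string(template).render(**variables)


print(InMemoryJinja2Renderer().render_template("Hi {{ name }}!", {"name": "Dify"}))  # Hi Dify!
```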

View File

@ -1,18 +1,44 @@
from collections.abc import Mapping, Sequence
from typing import Any
from typing import TYPE_CHECKING, Any
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
from core.workflow.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
Jinja2TemplateRenderer,
TemplateRenderError,
)
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
class TemplateTransformNode(Node[TemplateTransformNodeData]):
node_type = NodeType.TEMPLATE_TRANSFORM
_template_renderer: Jinja2TemplateRenderer
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
template_renderer: Jinja2TemplateRenderer | None = None,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@ -39,13 +65,11 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
variables[variable_name] = value.to_object() if value else None
# Run code
try:
result = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables
)
except CodeExecutionError as e:
rendered = self._template_renderer.render_template(self.node_data.template, variables)
except TemplateRenderError as e:
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
if len(rendered) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
return NodeRunResult(
inputs=variables,
status=WorkflowNodeExecutionStatus.FAILED,
@ -53,7 +77,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": result["result"]}
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": rendered}
)
@classmethod

View File

@ -46,7 +46,11 @@ def _get_celery_ssl_options() -> dict[str, Any] | None:
def init_app(app: DifyApp) -> Celery:
class FlaskTask(Task):
def __call__(self, *args: object, **kwargs: object) -> object:
from core.logging.context import init_request_context
with app.app_context():
# Initialize logging context for this task (similar to before_request in Flask)
init_request_context()
return self.run(*args, **kwargs)
broker_transport_options = {}

View File

@ -11,6 +11,7 @@ def init_app(app: DifyApp):
create_tenant,
extract_plugins,
extract_unique_plugins,
file_usage,
fix_app_site_missing,
install_plugins,
install_rag_pipeline_plugins,
@ -47,6 +48,7 @@ def init_app(app: DifyApp):
clear_free_plan_tenant_expired_logs,
clear_orphaned_file_records,
remove_orphaned_files_on_storage,
file_usage,
setup_system_tool_oauth_client,
setup_system_trigger_oauth_client,
cleanup_orphaned_draft_variables,

View File

@ -53,3 +53,10 @@ def _setup_gevent_compatibility():
def init_app(app: DifyApp):
db.init_app(app)
_setup_gevent_compatibility()
# Eagerly build the engine so pool_size/max_overflow/etc. come from config
try:
with app.app_context():
_ = db.engine # triggers engine creation with the configured options
except Exception:
logger.exception("Failed to initialize SQLAlchemy engine during app startup")

View File

@ -1,18 +1,19 @@
"""Logging extension for Dify Flask application."""
import logging
import os
import sys
import uuid
from logging.handlers import RotatingFileHandler
import flask
from configs import dify_config
from core.helper.trace_id_helper import get_trace_id_from_otel_context
from dify_app import DifyApp
def init_app(app: DifyApp):
"""Initialize logging with support for text or JSON format."""
log_handlers: list[logging.Handler] = []
# File handler
log_file = dify_config.LOG_FILE
if log_file:
log_dir = os.path.dirname(log_file)
@ -25,27 +26,53 @@ def init_app(app: DifyApp):
)
)
# Always add StreamHandler to log to console
# Console handler
sh = logging.StreamHandler(sys.stdout)
log_handlers.append(sh)
# Apply RequestIdFilter to all handlers
for handler in log_handlers:
handler.addFilter(RequestIdFilter())
# Apply filters to all handlers
from core.logging.filters import IdentityContextFilter, TraceContextFilter
for handler in log_handlers:
handler.addFilter(TraceContextFilter())
handler.addFilter(IdentityContextFilter())
# Configure formatter based on format type
formatter = _create_formatter()
for handler in log_handlers:
handler.setFormatter(formatter)
# Configure root logger
logging.basicConfig(
level=dify_config.LOG_LEVEL,
format=dify_config.LOG_FORMAT,
datefmt=dify_config.LOG_DATEFORMAT,
handlers=log_handlers,
force=True,
)
# Apply RequestIdFormatter to all handlers
apply_request_id_formatter()
# Disable propagation for noisy loggers to avoid duplicate logs
logging.getLogger("sqlalchemy.engine").propagate = False
# Apply timezone if specified (only for text format)
if dify_config.LOG_OUTPUT_FORMAT == "text":
_apply_timezone(log_handlers)
def _create_formatter() -> logging.Formatter:
"""Create appropriate formatter based on configuration."""
if dify_config.LOG_OUTPUT_FORMAT == "json":
from core.logging.structured_formatter import StructuredJSONFormatter
return StructuredJSONFormatter()
else:
# Text format - use existing pattern with backward compatible formatter
return _TextFormatter(
fmt=dify_config.LOG_FORMAT,
datefmt=dify_config.LOG_DATEFORMAT,
)
def _apply_timezone(handlers: list[logging.Handler]):
"""Apply timezone conversion to text formatters."""
log_tz = dify_config.LOG_TZ
if log_tz:
from datetime import datetime
@ -57,34 +84,51 @@ def init_app(app: DifyApp):
def time_converter(seconds):
return datetime.fromtimestamp(seconds, tz=timezone).timetuple()
for handler in logging.root.handlers:
for handler in handlers:
if handler.formatter:
handler.formatter.converter = time_converter
handler.formatter.converter = time_converter # type: ignore[attr-defined]
def get_request_id():
if getattr(flask.g, "request_id", None):
return flask.g.request_id
class _TextFormatter(logging.Formatter):
"""Text formatter that ensures trace_id and req_id are always present."""
new_uuid = uuid.uuid4().hex[:10]
flask.g.request_id = new_uuid
return new_uuid
def format(self, record: logging.LogRecord) -> str:
if not hasattr(record, "req_id"):
record.req_id = ""
if not hasattr(record, "trace_id"):
record.trace_id = ""
if not hasattr(record, "span_id"):
record.span_id = ""
return super().format(record)
def get_request_id() -> str:
"""Get request ID for current request context.
Deprecated: Use core.logging.context.get_request_id() directly.
"""
from core.logging.context import get_request_id as _get_request_id
return _get_request_id()
# Backward compatibility aliases
class RequestIdFilter(logging.Filter):
# This is a logging filter that makes the request ID available for use in
# the logging format. Note that we're checking if we're in a request
# context, as we may want to log things before Flask is fully loaded.
def filter(self, record):
trace_id = get_trace_id_from_otel_context() or ""
record.req_id = get_request_id() if flask.has_request_context() else ""
record.trace_id = trace_id
"""Deprecated: Use TraceContextFilter from core.logging.filters instead."""
def filter(self, record: logging.LogRecord) -> bool:
from core.logging.context import get_request_id as _get_request_id
from core.logging.context import get_trace_id as _get_trace_id
record.req_id = _get_request_id()
record.trace_id = _get_trace_id()
return True
class RequestIdFormatter(logging.Formatter):
def format(self, record):
"""Deprecated: Use _TextFormatter instead."""
def format(self, record: logging.LogRecord) -> str:
if not hasattr(record, "req_id"):
record.req_id = ""
if not hasattr(record, "trace_id"):
@ -93,6 +137,7 @@ class RequestIdFormatter(logging.Formatter):
def apply_request_id_formatter():
"""Deprecated: Formatter is now applied in init_app."""
for handler in logging.root.handlers:
if handler.formatter:
handler.formatter = RequestIdFormatter(dify_config.LOG_FORMAT, dify_config.LOG_DATEFORMAT)

View File

@ -22,6 +22,18 @@ from models.enums import WorkflowRunTriggeredFrom
logger = logging.getLogger(__name__)
def to_serializable(obj):
"""
Convert non-JSON-serializable objects into JSON-compatible formats.
- Uses `to_dict()` if it's a callable method.
- Falls back to string representation.
"""
if hasattr(obj, "to_dict") and callable(obj.to_dict):
return obj.to_dict()
return str(obj)
class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
def __init__(
self,
@ -108,9 +120,24 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
),
("type", domain_model.workflow_type.value),
("version", domain_model.workflow_version),
("graph", json.dumps(domain_model.graph, ensure_ascii=False) if domain_model.graph else "{}"),
("inputs", json.dumps(domain_model.inputs, ensure_ascii=False) if domain_model.inputs else "{}"),
("outputs", json.dumps(domain_model.outputs, ensure_ascii=False) if domain_model.outputs else "{}"),
(
"graph",
json.dumps(domain_model.graph, ensure_ascii=False, default=to_serializable)
if domain_model.graph
else "{}",
),
(
"inputs",
json.dumps(domain_model.inputs, ensure_ascii=False, default=to_serializable)
if domain_model.inputs
else "{}",
),
(
"outputs",
json.dumps(domain_model.outputs, ensure_ascii=False, default=to_serializable)
if domain_model.outputs
else "{}",
),
("status", domain_model.status.value),
("error_message", domain_model.error_message or ""),
("total_tokens", str(domain_model.total_tokens)),

View File

@ -19,26 +19,43 @@ logger = logging.getLogger(__name__)
class ExceptionLoggingHandler(logging.Handler):
"""
Handler that records exceptions to the current OpenTelemetry span.
Unlike creating a new span, this records exceptions on the existing span
to maintain trace context consistency throughout the request lifecycle.
"""
def emit(self, record: logging.LogRecord):
with contextlib.suppress(Exception):
if record.exc_info:
tracer = get_tracer_provider().get_tracer("dify.exception.logging")
with tracer.start_as_current_span(
"log.exception",
attributes={
"log.level": record.levelname,
"log.message": record.getMessage(),
"log.logger": record.name,
"log.file.path": record.pathname,
"log.file.line": record.lineno,
},
) as span:
span.set_status(StatusCode.ERROR)
if record.exc_info[1]:
span.record_exception(record.exc_info[1])
span.set_attribute("exception.message", str(record.exc_info[1]))
if record.exc_info[0]:
span.set_attribute("exception.type", record.exc_info[0].__name__)
if not record.exc_info:
return
from opentelemetry.trace import get_current_span
span = get_current_span()
if not span or not span.is_recording():
return
# Record exception on the current span instead of creating a new one
span.set_status(StatusCode.ERROR, record.getMessage())
# Add log context as span events/attributes
span.add_event(
"log.exception",
attributes={
"log.level": record.levelname,
"log.message": record.getMessage(),
"log.logger": record.name,
"log.file.path": record.pathname,
"log.file.line": record.lineno,
},
)
if record.exc_info[1]:
span.record_exception(record.exc_info[1])
if record.exc_info[0]:
span.set_attribute("exception.type", record.exc_info[0].__name__)
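With `instrument_exception_logging()` applied, a `logger.exception()` call inside an active span attaches the error to that span rather than opening a new one. A hedged usage sketch (assumes an SDK `TracerProvider` has been configured by Dify's OTEL setup; names like `demo.worker` are illustrative):

```python
import logging

from opentelemetry import trace

logger = logging.getLogger("demo.worker")
tracer = trace.get_tracer("demo")

with tracer.start_as_current_span("handle_request"):
    try:
        1 / 0
    except ZeroDivisionError:
        logger.exception("division failed")  # recorded as an event on handle_request
```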
def instrument_exception_logging() -> None:

View File

@ -1,236 +1,338 @@
from flask_restx import Namespace, fields
from __future__ import annotations
from fields.member_fields import simple_account_fields
from libs.helper import TimestampField
from datetime import datetime
from typing import Any, TypeAlias
from .raws import FilesContainedField
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from core.file import File
JSONValue: TypeAlias = Any
class MessageTextField(fields.Raw):
def format(self, value):
return value[0]["text"] if value else ""
class ResponseModel(BaseModel):
model_config = ConfigDict(
from_attributes=True,
extra="ignore",
populate_by_name=True,
serialize_by_alias=True,
protected_namespaces=(),
)
feedback_fields = {
"rating": fields.String,
"content": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account": fields.Nested(simple_account_fields, allow_null=True),
}
class MessageFile(ResponseModel):
id: str
filename: str
type: str
url: str | None = None
mime_type: str | None = None
size: int | None = None
transfer_method: str
belongs_to: str | None = None
upload_file_id: str | None = None
annotation_fields = {
"id": fields.String,
"question": fields.String,
"content": fields.String,
"account": fields.Nested(simple_account_fields, allow_null=True),
"created_at": TimestampField,
}
annotation_hit_history_fields = {
"annotation_id": fields.String(attribute="id"),
"annotation_create_account": fields.Nested(simple_account_fields, allow_null=True),
"created_at": TimestampField,
}
message_file_fields = {
"id": fields.String,
"filename": fields.String,
"type": fields.String,
"url": fields.String,
"mime_type": fields.String,
"size": fields.Integer,
"transfer_method": fields.String,
"belongs_to": fields.String(default="user"),
"upload_file_id": fields.String(default=None),
}
@field_validator("transfer_method", mode="before")
@classmethod
def _normalize_transfer_method(cls, value: object) -> str:
if isinstance(value, str):
return value
return str(value)
def build_message_file_model(api_or_ns: Namespace):
"""Build the message file fields for the API or Namespace."""
return api_or_ns.model("MessageFile", message_file_fields)
class SimpleConversation(ResponseModel):
id: str
name: str
inputs: dict[str, JSONValue]
status: str
introduction: str | None = None
created_at: int | None = None
updated_at: int | None = None
@field_validator("inputs", mode="before")
@classmethod
def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
return format_files_contained(value)
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return to_timestamp(value)
return value
agent_thought_fields = {
"id": fields.String,
"chain_id": fields.String,
"message_id": fields.String,
"position": fields.Integer,
"thought": fields.String,
"tool": fields.String,
"tool_labels": fields.Raw,
"tool_input": fields.String,
"created_at": TimestampField,
"observation": fields.String,
"files": fields.List(fields.String),
}
message_detail_fields = {
"id": fields.String,
"conversation_id": fields.String,
"inputs": FilesContainedField,
"query": fields.String,
"message": fields.Raw,
"message_tokens": fields.Integer,
"answer": fields.String(attribute="re_sign_file_url_answer"),
"answer_tokens": fields.Integer,
"provider_response_latency": fields.Float,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account_id": fields.String,
"feedbacks": fields.List(fields.Nested(feedback_fields)),
"workflow_run_id": fields.String,
"annotation": fields.Nested(annotation_fields, allow_null=True),
"annotation_hit_history": fields.Nested(annotation_hit_history_fields, allow_null=True),
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
"message_files": fields.List(fields.Nested(message_file_fields)),
"metadata": fields.Raw(attribute="message_metadata_dict"),
"status": fields.String,
"error": fields.String,
"parent_message_id": fields.String,
}
feedback_stat_fields = {"like": fields.Integer, "dislike": fields.Integer}
status_count_fields = {"success": fields.Integer, "failed": fields.Integer, "partial_success": fields.Integer}
model_config_fields = {
"opening_statement": fields.String,
"suggested_questions": fields.Raw,
"model": fields.Raw,
"user_input_form": fields.Raw,
"pre_prompt": fields.String,
"agent_mode": fields.Raw,
}
simple_model_config_fields = {
"model": fields.Raw(attribute="model_dict"),
"pre_prompt": fields.String,
}
simple_message_detail_fields = {
"inputs": FilesContainedField,
"query": fields.String,
"message": MessageTextField,
"answer": fields.String,
}
conversation_fields = {
"id": fields.String,
"status": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_end_user_session_id": fields.String(),
"from_account_id": fields.String,
"from_account_name": fields.String,
"read_at": TimestampField,
"created_at": TimestampField,
"updated_at": TimestampField,
"annotation": fields.Nested(annotation_fields, allow_null=True),
"model_config": fields.Nested(simple_model_config_fields),
"user_feedback_stats": fields.Nested(feedback_stat_fields),
"admin_feedback_stats": fields.Nested(feedback_stat_fields),
"message": fields.Nested(simple_message_detail_fields, attribute="first_message"),
}
conversation_pagination_fields = {
"page": fields.Integer,
"limit": fields.Integer(attribute="per_page"),
"total": fields.Integer,
"has_more": fields.Boolean(attribute="has_next"),
"data": fields.List(fields.Nested(conversation_fields), attribute="items"),
}
conversation_message_detail_fields = {
"id": fields.String,
"status": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account_id": fields.String,
"created_at": TimestampField,
"model_config": fields.Nested(model_config_fields),
"message": fields.Nested(message_detail_fields, attribute="first_message"),
}
conversation_with_summary_fields = {
"id": fields.String,
"status": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_end_user_session_id": fields.String,
"from_account_id": fields.String,
"from_account_name": fields.String,
"name": fields.String,
"summary": fields.String(attribute="summary_or_query"),
"read_at": TimestampField,
"created_at": TimestampField,
"updated_at": TimestampField,
"annotated": fields.Boolean,
"model_config": fields.Nested(simple_model_config_fields),
"message_count": fields.Integer,
"user_feedback_stats": fields.Nested(feedback_stat_fields),
"admin_feedback_stats": fields.Nested(feedback_stat_fields),
"status_count": fields.Nested(status_count_fields),
}
conversation_with_summary_pagination_fields = {
"page": fields.Integer,
"limit": fields.Integer(attribute="per_page"),
"total": fields.Integer,
"has_more": fields.Boolean(attribute="has_next"),
"data": fields.List(fields.Nested(conversation_with_summary_fields), attribute="items"),
}
conversation_detail_fields = {
"id": fields.String,
"status": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account_id": fields.String,
"created_at": TimestampField,
"updated_at": TimestampField,
"annotated": fields.Boolean,
"introduction": fields.String,
"model_config": fields.Nested(model_config_fields),
"message_count": fields.Integer,
"user_feedback_stats": fields.Nested(feedback_stat_fields),
"admin_feedback_stats": fields.Nested(feedback_stat_fields),
}
simple_conversation_fields = {
"id": fields.String,
"name": fields.String,
"inputs": FilesContainedField,
"status": fields.String,
"introduction": fields.String,
"created_at": TimestampField,
"updated_at": TimestampField,
}
conversation_delete_fields = {
"result": fields.String,
}
conversation_infinite_scroll_pagination_fields = {
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(simple_conversation_fields)),
}
class ConversationInfiniteScrollPagination(ResponseModel):
limit: int
has_more: bool
data: list[SimpleConversation]
def build_conversation_infinite_scroll_pagination_model(api_or_ns: Namespace):
"""Build the conversation infinite scroll pagination model for the API or Namespace."""
simple_conversation_model = build_simple_conversation_model(api_or_ns)
copied_fields = conversation_infinite_scroll_pagination_fields.copy()
copied_fields["data"] = fields.List(fields.Nested(simple_conversation_model))
return api_or_ns.model("ConversationInfiniteScrollPagination", copied_fields)
class ConversationDelete(ResponseModel):
result: str
def build_conversation_delete_model(api_or_ns: Namespace):
"""Build the conversation delete model for the API or Namespace."""
return api_or_ns.model("ConversationDelete", conversation_delete_fields)
class ResultResponse(ResponseModel):
result: str
def build_simple_conversation_model(api_or_ns: Namespace):
"""Build the simple conversation model for the API or Namespace."""
return api_or_ns.model("SimpleConversation", simple_conversation_fields)
class SimpleAccount(ResponseModel):
id: str
name: str
email: str
class Feedback(ResponseModel):
rating: str
content: str | None = None
from_source: str
from_end_user_id: str | None = None
from_account: SimpleAccount | None = None
class Annotation(ResponseModel):
id: str
question: str | None = None
content: str
account: SimpleAccount | None = None
created_at: int | None = None
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return to_timestamp(value)
return value
class AnnotationHitHistory(ResponseModel):
annotation_id: str
annotation_create_account: SimpleAccount | None = None
created_at: int | None = None
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return to_timestamp(value)
return value
class AgentThought(ResponseModel):
id: str
chain_id: str | None = None
message_chain_id: str | None = Field(default=None, exclude=True, validation_alias="message_chain_id")
message_id: str
position: int
thought: str | None = None
tool: str | None = None
tool_labels: JSONValue
tool_input: str | None = None
created_at: int | None = None
observation: str | None = None
files: list[str]
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return to_timestamp(value)
return value
@model_validator(mode="after")
def _fallback_chain_id(self):
if self.chain_id is None and self.message_chain_id:
self.chain_id = self.message_chain_id
return self
class MessageDetail(ResponseModel):
id: str
conversation_id: str
inputs: dict[str, JSONValue]
query: str
message: JSONValue
message_tokens: int
answer: str
answer_tokens: int
provider_response_latency: float
from_source: str
from_end_user_id: str | None = None
from_account_id: str | None = None
feedbacks: list[Feedback]
workflow_run_id: str | None = None
annotation: Annotation | None = None
annotation_hit_history: AnnotationHitHistory | None = None
created_at: int | None = None
agent_thoughts: list[AgentThought]
message_files: list[MessageFile]
metadata: JSONValue
status: str
error: str | None = None
parent_message_id: str | None = None
@field_validator("inputs", mode="before")
@classmethod
def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
return format_files_contained(value)
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return to_timestamp(value)
return value
class FeedbackStat(ResponseModel):
like: int
dislike: int
class StatusCount(ResponseModel):
success: int
failed: int
partial_success: int
class ModelConfig(ResponseModel):
opening_statement: str | None = None
suggested_questions: JSONValue | None = None
model: JSONValue | None = None
user_input_form: JSONValue | None = None
pre_prompt: str | None = None
agent_mode: JSONValue | None = None
class SimpleModelConfig(ResponseModel):
model: JSONValue | None = None
pre_prompt: str | None = None
class SimpleMessageDetail(ResponseModel):
inputs: dict[str, JSONValue]
query: str
message: str
answer: str
@field_validator("inputs", mode="before")
@classmethod
def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
return format_files_contained(value)
class Conversation(ResponseModel):
id: str
status: str
from_source: str
from_end_user_id: str | None = None
from_end_user_session_id: str | None = None
from_account_id: str | None = None
from_account_name: str | None = None
read_at: int | None = None
created_at: int | None = None
updated_at: int | None = None
annotation: Annotation | None = None
model_config_: SimpleModelConfig | None = Field(default=None, alias="model_config")
user_feedback_stats: FeedbackStat | None = None
admin_feedback_stats: FeedbackStat | None = None
message: SimpleMessageDetail | None = None
class ConversationPagination(ResponseModel):
page: int
limit: int
total: int
has_more: bool
data: list[Conversation]
class ConversationMessageDetail(ResponseModel):
id: str
status: str
from_source: str
from_end_user_id: str | None = None
from_account_id: str | None = None
created_at: int | None = None
model_config_: ModelConfig | None = Field(default=None, alias="model_config")
message: MessageDetail | None = None
class ConversationWithSummary(ResponseModel):
id: str
status: str
from_source: str
from_end_user_id: str | None = None
from_end_user_session_id: str | None = None
from_account_id: str | None = None
from_account_name: str | None = None
name: str
summary: str
read_at: int | None = None
created_at: int | None = None
updated_at: int | None = None
annotated: bool
model_config_: SimpleModelConfig | None = Field(default=None, alias="model_config")
message_count: int
user_feedback_stats: FeedbackStat | None = None
admin_feedback_stats: FeedbackStat | None = None
status_count: StatusCount | None = None
class ConversationWithSummaryPagination(ResponseModel):
page: int
limit: int
total: int
has_more: bool
data: list[ConversationWithSummary]
class ConversationDetail(ResponseModel):
id: str
status: str
from_source: str
from_end_user_id: str | None = None
from_account_id: str | None = None
created_at: int | None = None
updated_at: int | None = None
annotated: bool
introduction: str | None = None
model_config_: ModelConfig | None = Field(default=None, alias="model_config")
message_count: int
user_feedback_stats: FeedbackStat | None = None
admin_feedback_stats: FeedbackStat | None = None
def to_timestamp(value: datetime | None) -> int | None:
if value is None:
return None
return int(value.timestamp())
def format_files_contained(value: JSONValue) -> JSONValue:
if isinstance(value, File):
return value.model_dump()
if isinstance(value, dict):
return {k: format_files_contained(v) for k, v in value.items()}
if isinstance(value, list):
return [format_files_contained(v) for v in value]
return value
def message_text(value: JSONValue) -> str:
if isinstance(value, list) and value:
first = value[0]
if isinstance(first, dict):
text = first.get("text")
if isinstance(text, str):
return text
return ""
def extract_model_config(value: object | None) -> dict[str, JSONValue]:
if value is None:
return {}
if isinstance(value, dict):
return value
if hasattr(value, "to_dict"):
return value.to_dict()
return {}

View File

@ -1,77 +1,137 @@
from flask_restx import Namespace, fields
from __future__ import annotations
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField
from datetime import datetime
from typing import TypeAlias
from .raws import FilesContainedField
from pydantic import BaseModel, ConfigDict, Field, field_validator
feedback_fields = {
"rating": fields.String,
}
from core.file import File
from fields.conversation_fields import AgentThought, JSONValue, MessageFile
JSONValueType: TypeAlias = JSONValue
def build_feedback_model(api_or_ns: Namespace):
"""Build the feedback model for the API or Namespace."""
return api_or_ns.model("Feedback", feedback_fields)
class ResponseModel(BaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore")
agent_thought_fields = {
"id": fields.String,
"chain_id": fields.String,
"message_id": fields.String,
"position": fields.Integer,
"thought": fields.String,
"tool": fields.String,
"tool_labels": fields.Raw,
"tool_input": fields.String,
"created_at": TimestampField,
"observation": fields.String,
"files": fields.List(fields.String),
}
class SimpleFeedback(ResponseModel):
rating: str | None = None
def build_agent_thought_model(api_or_ns: Namespace):
"""Build the agent thought model for the API or Namespace."""
return api_or_ns.model("AgentThought", agent_thought_fields)
class RetrieverResource(ResponseModel):
id: str
message_id: str
position: int
dataset_id: str | None = None
dataset_name: str | None = None
document_id: str | None = None
document_name: str | None = None
data_source_type: str | None = None
segment_id: str | None = None
score: float | None = None
hit_count: int | None = None
word_count: int | None = None
segment_position: int | None = None
index_node_hash: str | None = None
content: str | None = None
created_at: int | None = None
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return to_timestamp(value)
return value
retriever_resource_fields = {
"id": fields.String,
"message_id": fields.String,
"position": fields.Integer,
"dataset_id": fields.String,
"dataset_name": fields.String,
"document_id": fields.String,
"document_name": fields.String,
"data_source_type": fields.String,
"segment_id": fields.String,
"score": fields.Float,
"hit_count": fields.Integer,
"word_count": fields.Integer,
"segment_position": fields.Integer,
"index_node_hash": fields.String,
"content": fields.String,
"created_at": TimestampField,
}
class MessageListItem(ResponseModel):
id: str
conversation_id: str
parent_message_id: str | None = None
inputs: dict[str, JSONValueType]
query: str
answer: str = Field(validation_alias="re_sign_file_url_answer")
feedback: SimpleFeedback | None = Field(default=None, validation_alias="user_feedback")
retriever_resources: list[RetrieverResource]
created_at: int | None = None
agent_thoughts: list[AgentThought]
message_files: list[MessageFile]
status: str
error: str | None = None
message_fields = {
"id": fields.String,
"conversation_id": fields.String,
"parent_message_id": fields.String,
"inputs": FilesContainedField,
"query": fields.String,
"answer": fields.String(attribute="re_sign_file_url_answer"),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
"message_files": fields.List(fields.Nested(message_file_fields)),
"status": fields.String,
"error": fields.String,
}
@field_validator("inputs", mode="before")
@classmethod
def _normalize_inputs(cls, value: JSONValueType) -> JSONValueType:
return format_files_contained(value)
message_infinite_scroll_pagination_fields = {
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(message_fields)),
}
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return to_timestamp(value)
return value
class WebMessageListItem(MessageListItem):
metadata: JSONValueType | None = Field(default=None, validation_alias="message_metadata_dict")
class MessageInfiniteScrollPagination(ResponseModel):
limit: int
has_more: bool
data: list[MessageListItem]
class WebMessageInfiniteScrollPagination(ResponseModel):
limit: int
has_more: bool
data: list[WebMessageListItem]
class SavedMessageItem(ResponseModel):
id: str
inputs: dict[str, JSONValueType]
query: str
answer: str
message_files: list[MessageFile]
feedback: SimpleFeedback | None = Field(default=None, validation_alias="user_feedback")
created_at: int | None = None
@field_validator("inputs", mode="before")
@classmethod
def _normalize_inputs(cls, value: JSONValueType) -> JSONValueType:
return format_files_contained(value)
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return to_timestamp(value)
return value
class SavedMessageInfiniteScrollPagination(ResponseModel):
limit: int
has_more: bool
data: list[SavedMessageItem]
class SuggestedQuestionsResponse(ResponseModel):
data: list[str]
def to_timestamp(value: datetime | None) -> int | None:
if value is None:
return None
return int(value.timestamp())
def format_files_contained(value: JSONValueType) -> JSONValueType:
if isinstance(value, File):
return value.model_dump()
if isinstance(value, dict):
return {k: format_files_contained(v) for k, v in value.items()}
if isinstance(value, list):
return [format_files_contained(v) for v in value]
return value
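# Hedged usage sketch (illustration only, not part of this change): the
# mode="before" validators defined above normalize datetime values into Unix
# timestamps during validation, e.g. for RetrieverResource:
if __name__ == "__main__":  # pragma: no cover - illustrative sketch
    from datetime import timezone

    _resource = RetrieverResource.model_validate(
        {
            "id": "res-1",
            "message_id": "msg-1",
            "position": 1,
            "created_at": datetime(2024, 1, 1, tzinfo=timezone.utc),
        }
    )
    assert _resource.created_at == 1704067200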

View File

@ -1,5 +1,4 @@
import re
import sys
from collections.abc import Mapping
from typing import Any
@ -109,11 +108,8 @@ def register_external_error_handlers(api: Api):
data.setdefault("code", "unknown")
data.setdefault("status", status_code)
# Log stack
exc_info: Any = sys.exc_info()
if exc_info[1] is None:
exc_info = (None, None, None)
current_app.log_exception(exc_info)
# Note: Exception logging is handled by Flask/Flask-RESTX framework automatically
# Explicit log_exception call removed to avoid duplicate log entries
return data, status_code

View File

@ -853,7 +853,7 @@ class TriggerProviderService:
"""
Create a subscription builder for rebuilding an existing subscription.
This method creates a builder pre-filled with data from the rebuild request,
This method rebuilds the subscription by calling the DELETE and CREATE APIs of the third-party provider (e.g. GitHub),
keeping the same subscription_id and endpoint_id so the webhook URL remains unchanged.
:param tenant_id: Tenant ID
@ -868,111 +868,50 @@ class TriggerProviderService:
if not provider_controller:
raise ValueError(f"Provider {provider_id} not found")
# Use distributed lock to prevent race conditions on the same subscription
lock_key = f"trigger_subscription_rebuild_lock:{tenant_id}_{subscription_id}"
with redis_client.lock(lock_key, timeout=20):
with Session(db.engine, expire_on_commit=False) as session:
try:
# Get subscription within the transaction
subscription: TriggerSubscription | None = (
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
)
if not subscription:
raise ValueError(f"Subscription {subscription_id} not found")
subscription = TriggerProviderService.get_subscription_by_id(
tenant_id=tenant_id,
subscription_id=subscription_id,
)
if not subscription:
raise ValueError(f"Subscription {subscription_id} not found")
credential_type = CredentialType.of(subscription.credential_type)
if credential_type not in [CredentialType.OAUTH2, CredentialType.API_KEY]:
raise ValueError("Credential type not supported for rebuild")
credential_type = CredentialType.of(subscription.credential_type)
if credential_type not in {CredentialType.OAUTH2, CredentialType.API_KEY}:
raise ValueError(f"Credential type {credential_type} not supported for auto creation")
# Decrypt existing credentials for merging
credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription(
tenant_id=tenant_id,
controller=provider_controller,
subscription=subscription,
)
decrypted_credentials = dict(credential_encrypter.decrypt(subscription.credentials))
# Delete the previous subscription
user_id = subscription.user_id
unsubscribe_result = TriggerManager.unsubscribe_trigger(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
subscription=subscription.to_entity(),
credentials=subscription.credentials,
credential_type=credential_type,
)
if not unsubscribe_result.success:
raise ValueError(f"Failed to delete previous subscription: {unsubscribe_result.message}")
# Merge credentials: if caller passed HIDDEN_VALUE, retain existing decrypted value
merged_credentials: dict[str, Any] = {
key: value if value != HIDDEN_VALUE else decrypted_credentials.get(key, UNKNOWN_VALUE)
for key, value in credentials.items()
}
user_id = subscription.user_id
# TODO: Try invoking the update API of the plugin trigger provider first.
# FALLBACK: if the update API is not implemented,
# delete the previous subscription and create a new one.
# Unsubscribe the previous subscription (external call, but we'll handle errors)
try:
TriggerManager.unsubscribe_trigger(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
subscription=subscription.to_entity(),
credentials=decrypted_credentials,
credential_type=credential_type,
)
except Exception as e:
logger.exception("Error unsubscribing trigger during rebuild", exc_info=e)
# Continue anyway - the subscription might already be deleted externally
# Create a new subscription with the same subscription_id and endpoint_id (external call)
new_subscription: TriggerSubscriptionEntity = TriggerManager.subscribe_trigger(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id),
parameters=parameters,
credentials=merged_credentials,
credential_type=credential_type,
)
# Update the subscription in the same transaction
# Inline update logic to reuse the same session
if name is not None and name != subscription.name:
existing = (
session.query(TriggerSubscription)
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
.first()
)
if existing and existing.id != subscription.id:
raise ValueError(f"Subscription name '{name}' already exists for this provider")
subscription.name = name
# Update parameters
subscription.parameters = dict(parameters)
# Update credentials with merged (and encrypted) values
subscription.credentials = dict(credential_encrypter.encrypt(merged_credentials))
# Update properties
if new_subscription.properties:
properties_encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=provider_controller.get_properties_schema(),
cache=NoOpProviderCredentialCache(),
)
subscription.properties = dict(properties_encrypter.encrypt(dict(new_subscription.properties)))
# Update expiration timestamp
if new_subscription.expires_at is not None:
subscription.expires_at = new_subscription.expires_at
# Commit the transaction
session.commit()
# Clear subscription cache
delete_cache_for_subscription(
tenant_id=tenant_id,
provider_id=subscription.provider_id,
subscription_id=subscription.id,
)
except Exception as e:
# Rollback on any error
session.rollback()
logger.exception("Failed to rebuild trigger subscription", exc_info=e)
raise
# Create a new subscription with the same subscription_id and endpoint_id
new_credentials: dict[str, Any] = {
key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE)
for key, value in credentials.items()
}
new_subscription: TriggerSubscriptionEntity = TriggerManager.subscribe_trigger(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id),
parameters=parameters,
credentials=new_credentials,
credential_type=credential_type,
)
TriggerProviderService.update_trigger_subscription(
tenant_id=tenant_id,
subscription_id=subscription.id,
name=name,
parameters=parameters,
credentials=new_credentials,
properties=new_subscription.properties,
expires_at=new_subscription.expires_at,
)

View File

@ -10,6 +10,7 @@ from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.limits import CodeNodeLimits
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
@ -67,6 +68,16 @@ def init_code_node(code_config: dict):
config=code_config,
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
code_limits=CodeNodeLimits(
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
max_number=dify_config.CODE_MAX_NUMBER,
min_number=dify_config.CODE_MIN_NUMBER,
max_precision=dify_config.CODE_MAX_PRECISION,
max_depth=dify_config.CODE_MAX_DEPTH,
max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH,
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
),
)
return node

View File

@ -474,64 +474,6 @@ class TestTriggerProviderService:
assert subscription.name == original_name
assert subscription.parameters == original_parameters
def test_rebuild_trigger_subscription_unsubscribe_error_continues(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test that unsubscribe errors are handled gracefully and operation continues.
This test verifies:
- Unsubscribe errors are caught and logged but don't stop the rebuild
- Rebuild continues even if unsubscribe fails
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
original_credentials = {"api_key": "original-key"}
subscription = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
original_credentials,
mock_external_service_dependencies,
)
# Make unsubscribe_trigger raise an error (should be caught and continue)
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.side_effect = ValueError(
"Unsubscribe failed"
)
new_subscription_entity = TriggerSubscriptionEntity(
endpoint=subscription.endpoint_id,
parameters={},
properties={},
expires_at=-1,
)
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
# Execute rebuild - should succeed despite unsubscribe error
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription.id,
credentials={"api_key": "new-key"},
parameters={},
)
# Verify subscribe was still called (operation continued)
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.assert_called_once()
# Verify subscription was updated
db.session.refresh(subscription)
assert subscription.parameters == {}
def test_rebuild_trigger_subscription_subscription_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
@ -558,70 +500,6 @@ class TestTriggerProviderService:
parameters={},
)
def test_rebuild_trigger_subscription_provider_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test error when provider is not found.
This test verifies:
- Proper error is raised when provider doesn't exist
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("non_existent_org/non_existent_plugin/non_existent_provider")
# Make get_trigger_provider return None
mock_external_service_dependencies["trigger_manager"].get_trigger_provider.return_value = None
with pytest.raises(ValueError, match="Provider.*not found"):
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=fake.uuid4(),
credentials={},
parameters={},
)
def test_rebuild_trigger_subscription_unsupported_credential_type(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test error when credential type is not supported for rebuild.
This test verifies:
- Proper error is raised for unsupported credential types (not OAUTH2 or API_KEY)
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.UNAUTHORIZED # Not supported
subscription = self._create_test_subscription(
db_session_with_containers,
tenant.id,
account.id,
provider_id,
credential_type,
{},
mock_external_service_dependencies,
)
with pytest.raises(ValueError, match="Credential type not supported for rebuild"):
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=tenant.id,
provider_id=provider_id,
subscription_id=subscription.id,
credentials={},
parameters={},
)
def test_rebuild_trigger_subscription_name_uniqueness_check(
self, db_session_with_containers, mock_external_service_dependencies
):

View File

@ -0,0 +1,145 @@
"""
Test for document detail API data_source_info serialization fix.
This test verifies that the document detail API returns both data_source_info
and data_source_detail_dict for all data_source_type values, including "local_file".
"""
import json
from typing import Generic, Literal, NotRequired, TypedDict, TypeVar, Union
from models.dataset import Document
class LocalFileInfo(TypedDict):
file_path: str
size: int
created_at: NotRequired[str]
class UploadFileInfo(TypedDict):
upload_file_id: str
class NotionImportInfo(TypedDict):
notion_page_id: str
workspace_id: str
class WebsiteCrawlInfo(TypedDict):
url: str
job_id: str
RawInfo = Union[LocalFileInfo, UploadFileInfo, NotionImportInfo, WebsiteCrawlInfo]
T_type = TypeVar("T_type", bound=str)
T_info = TypeVar("T_info", bound=Union[LocalFileInfo, UploadFileInfo, NotionImportInfo, WebsiteCrawlInfo])
class Case(TypedDict, Generic[T_type, T_info]):
data_source_type: T_type
data_source_info: str
expected_raw: T_info
LocalFileCase = Case[Literal["local_file"], LocalFileInfo]
UploadFileCase = Case[Literal["upload_file"], UploadFileInfo]
NotionImportCase = Case[Literal["notion_import"], NotionImportInfo]
WebsiteCrawlCase = Case[Literal["website_crawl"], WebsiteCrawlInfo]
AnyCase = Union[LocalFileCase, UploadFileCase, NotionImportCase, WebsiteCrawlCase]
case_1: LocalFileCase = {
"data_source_type": "local_file",
"data_source_info": json.dumps({"file_path": "/tmp/test.txt", "size": 1024}),
"expected_raw": {"file_path": "/tmp/test.txt", "size": 1024},
}
# ERROR: Expected LocalFileInfo, but got WebsiteCrawlInfo
case_2: LocalFileCase = {
"data_source_type": "local_file",
"data_source_info": "...",
"expected_raw": {"file_path": "https://google.com", "size": 123},
}
cases: list[AnyCase] = [case_1]
class TestDocumentDetailDataSourceInfo:
"""Test cases for document detail API data_source_info serialization."""
def test_data_source_info_dict_returns_raw_data(self):
"""Test that data_source_info_dict returns raw JSON data for all data_source_type values."""
# Test data for different data_source_type values
for case in cases:
document = Document(
data_source_type=case["data_source_type"],
data_source_info=case["data_source_info"],
)
# Test data_source_info_dict (raw data)
raw_result = document.data_source_info_dict
assert raw_result == case["expected_raw"], f"Failed for {case['data_source_type']}"
# Verify raw_result is always a valid dict
assert isinstance(raw_result, dict)
def test_local_file_data_source_info_without_db_context(self):
"""Test that local_file type data_source_info_dict works without database context."""
test_data: LocalFileInfo = {
"file_path": "/local/path/document.txt",
"size": 512,
"created_at": "2024-01-01T00:00:00Z",
}
document = Document(
data_source_type="local_file",
data_source_info=json.dumps(test_data),
)
# data_source_info_dict should return the raw data (this doesn't need DB context)
raw_data = document.data_source_info_dict
assert raw_data == test_data
assert isinstance(raw_data, dict)
# Verify the data contains expected keys for pipeline mode
assert "file_path" in raw_data
assert "size" in raw_data
def test_notion_and_website_crawl_data_source_detail(self):
"""Test that notion_import and website_crawl return raw data in data_source_detail_dict."""
# Test notion_import
notion_data: NotionImportInfo = {"notion_page_id": "page-123", "workspace_id": "ws-456"}
document = Document(
data_source_type="notion_import",
data_source_info=json.dumps(notion_data),
)
# data_source_detail_dict should return raw data for notion_import
detail_result = document.data_source_detail_dict
assert detail_result == notion_data
# Test website_crawl
website_data: WebsiteCrawlInfo = {"url": "https://example.com", "job_id": "job-789"}
document = Document(
data_source_type="website_crawl",
data_source_info=json.dumps(website_data),
)
# data_source_detail_dict should return raw data for website_crawl
detail_result = document.data_source_detail_dict
assert detail_result == website_data
def test_local_file_data_source_detail_dict_without_db(self):
"""Test that local_file returns empty data_source_detail_dict (this doesn't need DB context)."""
# Test local_file - this should work without database context since it returns {} early
document = Document(
data_source_type="local_file",
data_source_info=json.dumps({"file_path": "/tmp/test.txt"}),
)
# Should return empty dict for local_file type (handled in the model)
detail_result = document.data_source_detail_dict
assert detail_result == {}
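# Hedged sketch (assumption, not the actual models.dataset implementation): the
# behaviour these tests rely on is that data_source_info is persisted as a JSON
# string and data_source_info_dict parses it back into a plain dict.
def _sketch_data_source_info_dict(data_source_info: str | None) -> dict:
    return json.loads(data_source_info) if data_source_info else {}

assert _sketch_data_source_info_dict(json.dumps({"file_path": "/tmp/test.txt"})) == {"file_path": "/tmp/test.txt"}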

View File

@ -0,0 +1,174 @@
"""Unit tests for controllers.web.message message list mapping."""
from __future__ import annotations
import builtins
from datetime import datetime
from types import ModuleType, SimpleNamespace
from unittest.mock import patch
from uuid import uuid4
import pytest
from flask import Flask
from flask.views import MethodView
# Ensure flask_restx.api finds MethodView during import.
if not hasattr(builtins, "MethodView"):
builtins.MethodView = MethodView # type: ignore[attr-defined]
def _load_controller_module():
"""Import controllers.web.message using a stub package."""
import importlib
import importlib.util
import sys
parent_module_name = "controllers.web"
module_name = f"{parent_module_name}.message"
if parent_module_name not in sys.modules:
from flask_restx import Namespace
stub = ModuleType(parent_module_name)
stub.__file__ = "controllers/web/__init__.py"
stub.__path__ = ["controllers/web"]
stub.__package__ = "controllers"
stub.__spec__ = importlib.util.spec_from_loader(parent_module_name, loader=None, is_package=True)
stub.web_ns = Namespace("web", description="Web API", path="/")
sys.modules[parent_module_name] = stub
wraps_module_name = f"{parent_module_name}.wraps"
if wraps_module_name not in sys.modules:
wraps_stub = ModuleType(wraps_module_name)
class WebApiResource:
pass
wraps_stub.WebApiResource = WebApiResource
sys.modules[wraps_module_name] = wraps_stub
return importlib.import_module(module_name)
message_module = _load_controller_module()
MessageListApi = message_module.MessageListApi
@pytest.fixture
def app() -> Flask:
app = Flask(__name__)
app.config["TESTING"] = True
return app
def test_message_list_mapping(app: Flask) -> None:
conversation_id = str(uuid4())
message_id = str(uuid4())
created_at = datetime(2024, 1, 1, 12, 0, 0)
resource_created_at = datetime(2024, 1, 1, 13, 0, 0)
thought_created_at = datetime(2024, 1, 1, 14, 0, 0)
retriever_resource_obj = SimpleNamespace(
id="res-obj",
message_id=message_id,
position=2,
dataset_id="ds-1",
dataset_name="dataset",
document_id="doc-1",
document_name="document",
data_source_type="file",
segment_id="seg-1",
score=0.9,
hit_count=1,
word_count=10,
segment_position=0,
index_node_hash="hash",
content="content",
created_at=resource_created_at,
)
agent_thought = SimpleNamespace(
id="thought-1",
chain_id=None,
message_chain_id="chain-1",
message_id=message_id,
position=1,
thought="thinking",
tool="tool",
tool_labels={"label": "value"},
tool_input="{}",
created_at=thought_created_at,
observation="observed",
files=["file-a"],
)
message_file_obj = SimpleNamespace(
id="file-obj",
filename="b.txt",
type="file",
url=None,
mime_type=None,
size=None,
transfer_method="local",
belongs_to=None,
upload_file_id=None,
)
message = SimpleNamespace(
id=message_id,
conversation_id=conversation_id,
parent_message_id=None,
inputs={"foo": "bar"},
query="hello",
re_sign_file_url_answer="answer",
user_feedback=SimpleNamespace(rating="like"),
retriever_resources=[
{"id": "res-dict", "message_id": message_id, "position": 1},
retriever_resource_obj,
],
created_at=created_at,
agent_thoughts=[agent_thought],
message_files=[
{"id": "file-dict", "filename": "a.txt", "type": "file", "transfer_method": "local"},
message_file_obj,
],
status="success",
error=None,
message_metadata_dict={"meta": "value"},
)
pagination = SimpleNamespace(limit=20, has_more=False, data=[message])
app_model = SimpleNamespace(mode="chat")
end_user = SimpleNamespace()
with (
patch.object(message_module.MessageService, "pagination_by_first_id", return_value=pagination) as mock_page,
app.test_request_context(f"/messages?conversation_id={conversation_id}&limit=20"),
):
response = MessageListApi().get(app_model, end_user)
mock_page.assert_called_once_with(app_model, end_user, conversation_id, None, 20)
assert response["limit"] == 20
assert response["has_more"] is False
assert len(response["data"]) == 1
item = response["data"][0]
assert item["id"] == message_id
assert item["conversation_id"] == conversation_id
assert item["inputs"] == {"foo": "bar"}
assert item["answer"] == "answer"
assert item["feedback"]["rating"] == "like"
assert item["metadata"] == {"meta": "value"}
assert item["created_at"] == int(created_at.timestamp())
assert item["retriever_resources"][0]["id"] == "res-dict"
assert item["retriever_resources"][1]["id"] == "res-obj"
assert item["retriever_resources"][1]["created_at"] == int(resource_created_at.timestamp())
assert item["agent_thoughts"][0]["chain_id"] == "chain-1"
assert item["agent_thoughts"][0]["created_at"] == int(thought_created_at.timestamp())
assert item["message_files"][0]["id"] == "file-dict"
assert item["message_files"][1]["id"] == "file-obj"

View File

@ -14,12 +14,12 @@ def test_successful_request(mock_get_client):
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.send.return_value = mock_response
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
response = make_request("GET", "http://example.com")
assert response.status_code == 200
mock_client.request.assert_called_once()
@patch("core.helper.ssrf_proxy._get_ssrf_client")
@ -27,7 +27,6 @@ def test_retry_exceed_max_retries(mock_get_client):
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 500
mock_client.send.return_value = mock_response
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
@ -72,34 +71,12 @@ class TestGetUserProvidedHostHeader:
assert result in ("first.com", "second.com")
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_host_header_preservation_without_user_header(mock_get_client):
"""Test that when no Host header is provided, the default behavior is maintained."""
mock_client = MagicMock()
mock_request = MagicMock()
mock_request.headers = {}
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.send.return_value = mock_response
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
response = make_request("GET", "http://example.com")
assert response.status_code == 200
# Host should not be set if not provided by user
assert "Host" not in mock_request.headers or mock_request.headers.get("Host") is None
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_host_header_preservation_with_user_header(mock_get_client):
"""Test that user-provided Host header is preserved in the request."""
mock_client = MagicMock()
mock_request = MagicMock()
mock_request.headers = {}
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.send.return_value = mock_response
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
@ -107,3 +84,93 @@ def test_host_header_preservation_with_user_header(mock_get_client):
response = make_request("GET", "http://example.com", headers={"Host": custom_host})
assert response.status_code == 200
# Verify client.request was called with the host header preserved (lowercase)
call_kwargs = mock_client.request.call_args.kwargs
assert call_kwargs["headers"]["host"] == custom_host
@patch("core.helper.ssrf_proxy._get_ssrf_client")
@pytest.mark.parametrize("host_key", ["host", "HOST", "Host"])
def test_host_header_preservation_case_insensitive(mock_get_client, host_key):
"""Test that Host header is preserved regardless of case."""
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
response = make_request("GET", "http://example.com", headers={host_key: "api.example.com"})
assert response.status_code == 200
# Host header should be normalized to lowercase "host"
call_kwargs = mock_client.request.call_args.kwargs
assert call_kwargs["headers"]["host"] == "api.example.com"
class TestFollowRedirectsParameter:
"""Tests for follow_redirects parameter handling.
These tests verify that follow_redirects is correctly passed to client.request().
"""
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_follow_redirects_passed_to_request(self, mock_get_client):
"""Verify follow_redirects IS passed to client.request()."""
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
make_request("GET", "http://example.com", follow_redirects=True)
# Verify follow_redirects was passed to request
call_kwargs = mock_client.request.call_args.kwargs
assert call_kwargs.get("follow_redirects") is True
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_allow_redirects_converted_to_follow_redirects(self, mock_get_client):
"""Verify allow_redirects (requests-style) is converted to follow_redirects (httpx-style)."""
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
# Use allow_redirects (requests-style parameter)
make_request("GET", "http://example.com", allow_redirects=True)
# Verify it was converted to follow_redirects
call_kwargs = mock_client.request.call_args.kwargs
assert call_kwargs.get("follow_redirects") is True
assert "allow_redirects" not in call_kwargs
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_follow_redirects_not_set_when_not_specified(self, mock_get_client):
"""Verify follow_redirects is not in kwargs when not specified (httpx default behavior)."""
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
make_request("GET", "http://example.com")
# follow_redirects should not be in kwargs, letting httpx use its default
call_kwargs = mock_client.request.call_args.kwargs
assert "follow_redirects" not in call_kwargs
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_follow_redirects_takes_precedence_over_allow_redirects(self, mock_get_client):
"""Verify follow_redirects takes precedence when both are specified."""
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
# Both specified - follow_redirects should take precedence
make_request("GET", "http://example.com", allow_redirects=False, follow_redirects=True)
call_kwargs = mock_client.request.call_args.kwargs
assert call_kwargs.get("follow_redirects") is True
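# Hedged sketch (assumed shape, not the actual ssrf_proxy code): a kwargs
# normalization that would satisfy the tests above — requests-style
# allow_redirects is translated to httpx-style follow_redirects, and an
# explicit follow_redirects wins when both are given.
def _sketch_normalize_redirects(kwargs: dict) -> dict:
    if "allow_redirects" in kwargs:
        allow = kwargs.pop("allow_redirects")
        kwargs.setdefault("follow_redirects", allow)
    return kwargs

assert _sketch_normalize_redirects({"allow_redirects": True}) == {"follow_redirects": True}
assert _sketch_normalize_redirects({"follow_redirects": True, "allow_redirects": False}) == {"follow_redirects": True}
assert _sketch_normalize_redirects({}) == {}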

View File

@ -0,0 +1,79 @@
"""Tests for logging context module."""
import uuid
from core.logging.context import (
clear_request_context,
get_request_id,
get_trace_id,
init_request_context,
)
class TestLoggingContext:
"""Tests for the logging context functions."""
def test_init_creates_request_id(self):
"""init_request_context should create a 10-char request ID."""
init_request_context()
request_id = get_request_id()
assert len(request_id) == 10
assert all(c in "0123456789abcdef" for c in request_id)
def test_init_creates_trace_id(self):
"""init_request_context should create a 32-char trace ID."""
init_request_context()
trace_id = get_trace_id()
assert len(trace_id) == 32
assert all(c in "0123456789abcdef" for c in trace_id)
def test_trace_id_derived_from_request_id(self):
"""trace_id should be deterministically derived from request_id."""
init_request_context()
request_id = get_request_id()
trace_id = get_trace_id()
# Verify trace_id is derived using uuid5
expected_trace = uuid.uuid5(uuid.NAMESPACE_DNS, request_id).hex
assert trace_id == expected_trace
def test_clear_resets_context(self):
"""clear_request_context should reset both IDs to empty strings."""
init_request_context()
assert get_request_id() != ""
assert get_trace_id() != ""
clear_request_context()
assert get_request_id() == ""
assert get_trace_id() == ""
def test_default_values_are_empty(self):
"""Default values should be empty strings before init."""
clear_request_context()
assert get_request_id() == ""
assert get_trace_id() == ""
def test_multiple_inits_create_different_ids(self):
"""Each init should create new unique IDs."""
init_request_context()
first_request_id = get_request_id()
first_trace_id = get_trace_id()
init_request_context()
second_request_id = get_request_id()
second_trace_id = get_trace_id()
assert first_request_id != second_request_id
assert first_trace_id != second_trace_id
def test_context_isolation(self):
"""Context should be isolated per-call (no thread leakage in same thread)."""
init_request_context()
id1 = get_request_id()
# Simulate another request
init_request_context()
id2 = get_request_id()
# IDs should be different
assert id1 != id2
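# Hedged sketch (assumption, not the shipped core.logging.context module): a
# ContextVar-based implementation consistent with what the tests above assert —
# a 10-hex-char request ID and a trace ID derived from it via uuid5.
from contextvars import ContextVar

_request_id_var: ContextVar[str] = ContextVar("request_id", default="")
_trace_id_var: ContextVar[str] = ContextVar("trace_id", default="")

def _sketch_init_request_context() -> None:
    request_id = uuid.uuid4().hex[:10]
    _request_id_var.set(request_id)
    _trace_id_var.set(uuid.uuid5(uuid.NAMESPACE_DNS, request_id).hex)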

View File

@ -0,0 +1,114 @@
"""Tests for logging filters."""
import logging
from unittest import mock
import pytest
@pytest.fixture
def log_record():
return logging.LogRecord(
name="test",
level=logging.INFO,
pathname="",
lineno=0,
msg="test",
args=(),
exc_info=None,
)
class TestTraceContextFilter:
def test_sets_empty_trace_id_without_context(self, log_record):
from core.logging.context import clear_request_context
from core.logging.filters import TraceContextFilter
# Ensure no context is set
clear_request_context()
filter = TraceContextFilter()
result = filter.filter(log_record)
assert result is True
assert hasattr(log_record, "trace_id")
assert hasattr(log_record, "span_id")
assert hasattr(log_record, "req_id")
# Without context, IDs should be empty
assert log_record.trace_id == ""
assert log_record.req_id == ""
def test_sets_trace_id_from_context(self, log_record):
"""Test that trace_id and req_id are set from ContextVar when initialized."""
from core.logging.context import init_request_context
from core.logging.filters import TraceContextFilter
# Initialize context (no Flask needed!)
init_request_context()
filter = TraceContextFilter()
filter.filter(log_record)
# With context initialized, IDs should be set
assert log_record.trace_id != ""
assert len(log_record.trace_id) == 32
assert log_record.req_id != ""
assert len(log_record.req_id) == 10
def test_filter_always_returns_true(self, log_record):
from core.logging.filters import TraceContextFilter
filter = TraceContextFilter()
result = filter.filter(log_record)
assert result is True
def test_sets_trace_id_from_otel_when_available(self, log_record):
from core.logging.filters import TraceContextFilter
mock_span = mock.MagicMock()
mock_context = mock.MagicMock()
mock_context.trace_id = 0x5B8AA5A2D2C872E8321CF37308D69DF2
mock_context.span_id = 0x051581BF3BB55C45
mock_span.get_span_context.return_value = mock_context
with (
mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span),
mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0),
mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0),
):
filter = TraceContextFilter()
filter.filter(log_record)
assert log_record.trace_id == "5b8aa5a2d2c872e8321cf37308d69df2"
assert log_record.span_id == "051581bf3bb55c45"
class TestIdentityContextFilter:
def test_sets_empty_identity_without_request_context(self, log_record):
from core.logging.filters import IdentityContextFilter
filter = IdentityContextFilter()
result = filter.filter(log_record)
assert result is True
assert log_record.tenant_id == ""
assert log_record.user_id == ""
assert log_record.user_type == ""
def test_filter_always_returns_true(self, log_record):
from core.logging.filters import IdentityContextFilter
filter = IdentityContextFilter()
result = filter.filter(log_record)
assert result is True
def test_handles_exception_gracefully(self, log_record):
from core.logging.filters import IdentityContextFilter
filter = IdentityContextFilter()
# Should not raise even if something goes wrong
with mock.patch("core.logging.filters.flask.has_request_context", side_effect=Exception("Test error")):
result = filter.filter(log_record)
assert result is True
assert log_record.tenant_id == ""
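# Hedged sketch (assumption, not the actual filter code): the hex strings the
# OpenTelemetry test above expects follow the usual fixed-width formatting of
# the integer span-context fields — 32 hex chars for trace_id, 16 for span_id.
assert format(0x5B8AA5A2D2C872E8321CF37308D69DF2, "032x") == "5b8aa5a2d2c872e8321cf37308d69df2"
assert format(0x051581BF3BB55C45, "016x") == "051581bf3bb55c45"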

View File

@ -0,0 +1,267 @@
"""Tests for structured JSON formatter."""
import logging
import sys
import orjson
class TestStructuredJSONFormatter:
def test_basic_log_format(self):
from core.logging.structured_formatter import StructuredJSONFormatter
formatter = StructuredJSONFormatter(service_name="test-service")
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="test.py",
lineno=42,
msg="Test message",
args=(),
exc_info=None,
)
output = formatter.format(record)
log_dict = orjson.loads(output)
assert log_dict["severity"] == "INFO"
assert log_dict["service"] == "test-service"
assert log_dict["caller"] == "test.py:42"
assert log_dict["message"] == "Test message"
assert "ts" in log_dict
assert log_dict["ts"].endswith("Z")
def test_severity_mapping(self):
from core.logging.structured_formatter import StructuredJSONFormatter
formatter = StructuredJSONFormatter()
test_cases = [
(logging.DEBUG, "DEBUG"),
(logging.INFO, "INFO"),
(logging.WARNING, "WARN"),
(logging.ERROR, "ERROR"),
(logging.CRITICAL, "ERROR"),
]
for level, expected_severity in test_cases:
record = logging.LogRecord(
name="test",
level=level,
pathname="test.py",
lineno=1,
msg="Test",
args=(),
exc_info=None,
)
output = formatter.format(record)
log_dict = orjson.loads(output)
assert log_dict["severity"] == expected_severity, f"Level {level} should map to {expected_severity}"
def test_error_with_stack_trace(self):
from core.logging.structured_formatter import StructuredJSONFormatter
formatter = StructuredJSONFormatter()
try:
raise ValueError("Test error")
except ValueError:
exc_info = sys.exc_info()
record = logging.LogRecord(
name="test",
level=logging.ERROR,
pathname="test.py",
lineno=10,
msg="Error occurred",
args=(),
exc_info=exc_info,
)
output = formatter.format(record)
log_dict = orjson.loads(output)
assert log_dict["severity"] == "ERROR"
assert "stack_trace" in log_dict
assert "ValueError: Test error" in log_dict["stack_trace"]
def test_no_stack_trace_for_info(self):
from core.logging.structured_formatter import StructuredJSONFormatter
formatter = StructuredJSONFormatter()
try:
raise ValueError("Test error")
except ValueError:
exc_info = sys.exc_info()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="test.py",
lineno=10,
msg="Info message",
args=(),
exc_info=exc_info,
)
output = formatter.format(record)
log_dict = orjson.loads(output)
assert "stack_trace" not in log_dict
def test_trace_context_included(self):
from core.logging.structured_formatter import StructuredJSONFormatter
formatter = StructuredJSONFormatter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="test.py",
lineno=1,
msg="Test",
args=(),
exc_info=None,
)
record.trace_id = "5b8aa5a2d2c872e8321cf37308d69df2"
record.span_id = "051581bf3bb55c45"
output = formatter.format(record)
log_dict = orjson.loads(output)
assert log_dict["trace_id"] == "5b8aa5a2d2c872e8321cf37308d69df2"
assert log_dict["span_id"] == "051581bf3bb55c45"
def test_identity_context_included(self):
from core.logging.structured_formatter import StructuredJSONFormatter
formatter = StructuredJSONFormatter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="test.py",
lineno=1,
msg="Test",
args=(),
exc_info=None,
)
record.tenant_id = "t-global-corp"
record.user_id = "u-admin-007"
record.user_type = "admin"
output = formatter.format(record)
log_dict = orjson.loads(output)
assert "identity" in log_dict
assert log_dict["identity"]["tenant_id"] == "t-global-corp"
assert log_dict["identity"]["user_id"] == "u-admin-007"
assert log_dict["identity"]["user_type"] == "admin"
def test_no_identity_when_empty(self):
from core.logging.structured_formatter import StructuredJSONFormatter
formatter = StructuredJSONFormatter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="test.py",
lineno=1,
msg="Test",
args=(),
exc_info=None,
)
output = formatter.format(record)
log_dict = orjson.loads(output)
assert "identity" not in log_dict
def test_attributes_included(self):
from core.logging.structured_formatter import StructuredJSONFormatter
formatter = StructuredJSONFormatter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="test.py",
lineno=1,
msg="Test",
args=(),
exc_info=None,
)
record.attributes = {"order_id": "ord-123", "amount": 99.99}
output = formatter.format(record)
log_dict = orjson.loads(output)
assert log_dict["attributes"]["order_id"] == "ord-123"
assert log_dict["attributes"]["amount"] == 99.99
def test_message_with_args(self):
from core.logging.structured_formatter import StructuredJSONFormatter
formatter = StructuredJSONFormatter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="test.py",
lineno=1,
msg="User %s logged in from %s",
args=("john", "192.168.1.1"),
exc_info=None,
)
output = formatter.format(record)
log_dict = orjson.loads(output)
assert log_dict["message"] == "User john logged in from 192.168.1.1"
def test_timestamp_format(self):
from core.logging.structured_formatter import StructuredJSONFormatter
formatter = StructuredJSONFormatter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="test.py",
lineno=1,
msg="Test",
args=(),
exc_info=None,
)
output = formatter.format(record)
log_dict = orjson.loads(output)
# Verify ISO 8601 format with Z suffix
ts = log_dict["ts"]
assert ts.endswith("Z")
assert "T" in ts
# Should have milliseconds
assert "." in ts
def test_fallback_for_non_serializable_attributes(self):
from core.logging.structured_formatter import StructuredJSONFormatter
formatter = StructuredJSONFormatter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="test.py",
lineno=1,
msg="Test with non-serializable",
args=(),
exc_info=None,
)
# Set is not serializable by orjson
record.attributes = {"items": {1, 2, 3}, "custom": object()}
# Should not raise, fallback to json.dumps with default=str
output = formatter.format(record)
# Verify it's valid JSON (parse with stdlib json, since the output may come from the json.dumps fallback)
import json
log_dict = json.loads(output)
assert log_dict["message"] == "Test with non-serializable"
assert "attributes" in log_dict

View File

@ -0,0 +1,102 @@
"""Tests for trace helper functions."""
import re
from unittest import mock
class TestGetSpanIdFromOtelContext:
def test_returns_none_without_span(self):
from core.helper.trace_id_helper import get_span_id_from_otel_context
with mock.patch("opentelemetry.trace.get_current_span", return_value=None):
result = get_span_id_from_otel_context()
assert result is None
def test_returns_span_id_when_available(self):
from core.helper.trace_id_helper import get_span_id_from_otel_context
mock_span = mock.MagicMock()
mock_context = mock.MagicMock()
mock_context.span_id = 0x051581BF3BB55C45
mock_span.get_span_context.return_value = mock_context
with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span):
with mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0):
result = get_span_id_from_otel_context()
assert result == "051581bf3bb55c45"
def test_returns_none_on_exception(self):
from core.helper.trace_id_helper import get_span_id_from_otel_context
with mock.patch("opentelemetry.trace.get_current_span", side_effect=Exception("Test error")):
result = get_span_id_from_otel_context()
assert result is None
class TestGenerateTraceparentHeader:
def test_generates_valid_format(self):
from core.helper.trace_id_helper import generate_traceparent_header
with mock.patch("opentelemetry.trace.get_current_span", return_value=None):
result = generate_traceparent_header()
assert result is not None
# Format: 00-{trace_id}-{span_id}-01
parts = result.split("-")
assert len(parts) == 4
assert parts[0] == "00" # version
assert len(parts[1]) == 32 # trace_id (32 hex chars)
assert len(parts[2]) == 16 # span_id (16 hex chars)
assert parts[3] == "01" # flags
def test_uses_otel_context_when_available(self):
from core.helper.trace_id_helper import generate_traceparent_header
mock_span = mock.MagicMock()
mock_context = mock.MagicMock()
mock_context.trace_id = 0x5B8AA5A2D2C872E8321CF37308D69DF2
mock_context.span_id = 0x051581BF3BB55C45
mock_span.get_span_context.return_value = mock_context
with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span):
with (
mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0),
mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0),
):
result = generate_traceparent_header()
assert result == "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01"
def test_generates_hex_only_values(self):
from core.helper.trace_id_helper import generate_traceparent_header
with mock.patch("opentelemetry.trace.get_current_span", return_value=None):
result = generate_traceparent_header()
parts = result.split("-")
# All parts should be valid hex
assert re.match(r"^[0-9a-f]+$", parts[1])
assert re.match(r"^[0-9a-f]+$", parts[2])
class TestParseTraceparentHeader:
def test_parses_valid_traceparent(self):
from core.helper.trace_id_helper import parse_traceparent_header
traceparent = "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01"
result = parse_traceparent_header(traceparent)
assert result == "5b8aa5a2d2c872e8321cf37308d69df2"
def test_returns_none_for_invalid_format(self):
from core.helper.trace_id_helper import parse_traceparent_header
# Wrong number of parts
assert parse_traceparent_header("00-abc-def") is None
# Wrong trace_id length
assert parse_traceparent_header("00-abc-def-01") is None
def test_returns_none_for_empty_string(self):
from core.helper.trace_id_helper import parse_traceparent_header
assert parse_traceparent_header("") is None
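# Hedged sketch (illustration only, not the helper's actual code): the W3C
# traceparent shape asserted above — version "00", 32-hex trace_id, 16-hex
# span_id, and the sampled flag "01".
def _sketch_build_traceparent(trace_id_hex: str, span_id_hex: str) -> str:
    return f"00-{trace_id_hex}-{span_id_hex}-01"

assert (
    _sketch_build_traceparent("5b8aa5a2d2c872e8321cf37308d69df2", "051581bf3bb55c45")
    == "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01"
)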

View File

@ -73,6 +73,7 @@ import pytest
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.models.document import Document
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from models.dataset import Dataset
@ -1518,6 +1519,282 @@ class TestRetrievalService:
call_kwargs = mock_retrieve.call_args.kwargs
assert call_kwargs["reranking_model"] == reranking_model
# ==================== Multiple Retrieve Thread Tests ====================
@patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor")
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval._retriever")
def test_multiple_retrieve_thread_skips_second_reranking_with_single_dataset(
self, mock_retriever, mock_data_processor_class, mock_flask_app, mock_dataset
):
"""
Test that _multiple_retrieve_thread skips second reranking when dataset_count is 1.
When there is only one dataset, the second reranking is unnecessary
because the documents are already ranked from the first retrieval.
This optimization avoids the overhead of reranking when it won't
provide any benefit.
Verifies:
- DataPostProcessor is NOT called when dataset_count == 1
- Documents are still added to all_documents
- Standard scoring logic is applied instead
"""
# Arrange
dataset_retrieval = DatasetRetrieval()
tenant_id = str(uuid4())
# Create test documents
doc1 = Document(
page_content="Test content 1",
metadata={"doc_id": "doc1", "score": 0.9, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
provider="dify",
)
doc2 = Document(
page_content="Test content 2",
metadata={"doc_id": "doc2", "score": 0.8, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
provider="dify",
)
# Mock _retriever to return documents
def side_effect_retriever(
flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids
):
all_documents.extend([doc1, doc2])
mock_retriever.side_effect = side_effect_retriever
# Set up dataset with high_quality indexing
mock_dataset.indexing_technique = "high_quality"
all_documents = []
# Act - Call with dataset_count = 1
dataset_retrieval._multiple_retrieve_thread(
flask_app=mock_flask_app,
available_datasets=[mock_dataset],
metadata_condition=None,
metadata_filter_document_ids=None,
all_documents=all_documents,
tenant_id=tenant_id,
reranking_enable=True,
reranking_mode="reranking_model",
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"},
weights=None,
top_k=5,
score_threshold=0.5,
query="test query",
attachment_id=None,
dataset_count=1, # Single dataset - should skip second reranking
)
# Assert
# DataPostProcessor should NOT be called (second reranking skipped)
mock_data_processor_class.assert_not_called()
# Documents should still be added to all_documents
assert len(all_documents) == 2
assert all_documents[0].page_content == "Test content 1"
assert all_documents[1].page_content == "Test content 2"
@patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor")
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval._retriever")
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.calculate_vector_score")
def test_multiple_retrieve_thread_performs_second_reranking_with_multiple_datasets(
self, mock_calculate_vector_score, mock_retriever, mock_data_processor_class, mock_flask_app, mock_dataset
):
"""
Test that _multiple_retrieve_thread performs second reranking when dataset_count > 1.
When there are multiple datasets, the second reranking is necessary
to merge and re-rank results from different datasets. This ensures
the most relevant documents across all datasets are returned.
Verifies:
- DataPostProcessor IS called when dataset_count > 1
- Reranking is applied with correct parameters
- Documents are processed correctly
"""
# Arrange
dataset_retrieval = DatasetRetrieval()
tenant_id = str(uuid4())
# Create test documents
doc1 = Document(
page_content="Test content 1",
metadata={"doc_id": "doc1", "score": 0.7, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
provider="dify",
)
doc2 = Document(
page_content="Test content 2",
metadata={"doc_id": "doc2", "score": 0.6, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
provider="dify",
)
# Mock _retriever to return documents
def side_effect_retriever(
flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids
):
all_documents.extend([doc1, doc2])
mock_retriever.side_effect = side_effect_retriever
# Set up dataset with high_quality indexing
mock_dataset.indexing_technique = "high_quality"
# Mock DataPostProcessor instance and its invoke method
mock_processor_instance = Mock()
# Simulate reranking - return documents in reversed order with updated scores
reranked_docs = [
Document(
page_content="Test content 2",
metadata={"doc_id": "doc2", "score": 0.95, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
provider="dify",
),
Document(
page_content="Test content 1",
metadata={"doc_id": "doc1", "score": 0.85, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
provider="dify",
),
]
mock_processor_instance.invoke.return_value = reranked_docs
mock_data_processor_class.return_value = mock_processor_instance
all_documents = []
# Create second dataset
mock_dataset2 = Mock(spec=Dataset)
mock_dataset2.id = str(uuid4())
mock_dataset2.indexing_technique = "high_quality"
mock_dataset2.provider = "dify"
# Act - Call with dataset_count = 2
dataset_retrieval._multiple_retrieve_thread(
flask_app=mock_flask_app,
available_datasets=[mock_dataset, mock_dataset2],
metadata_condition=None,
metadata_filter_document_ids=None,
all_documents=all_documents,
tenant_id=tenant_id,
reranking_enable=True,
reranking_mode="reranking_model",
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"},
weights=None,
top_k=5,
score_threshold=0.5,
query="test query",
attachment_id=None,
dataset_count=2, # Multiple datasets - should perform second reranking
)
# Assert
# DataPostProcessor SHOULD be called (second reranking performed)
mock_data_processor_class.assert_called_once_with(
tenant_id,
"reranking_model",
{"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"},
None,
False,
)
# Verify invoke was called with correct parameters
mock_processor_instance.invoke.assert_called_once()
# Documents should be added to all_documents after reranking
assert len(all_documents) == 2
# The reranked order should be reflected
assert all_documents[0].page_content == "Test content 2"
assert all_documents[1].page_content == "Test content 1"
@patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor")
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval._retriever")
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.calculate_vector_score")
def test_multiple_retrieve_thread_single_dataset_uses_standard_scoring(
self, mock_calculate_vector_score, mock_retriever, mock_data_processor_class, mock_flask_app, mock_dataset
):
"""
Test that _multiple_retrieve_thread uses standard scoring when dataset_count is 1
and reranking is enabled.
When there's only one dataset, instead of using DataPostProcessor,
the method should fall through to the standard scoring logic
(calculate_vector_score for high_quality datasets).
Verifies:
- DataPostProcessor is NOT called
- calculate_vector_score IS called for high_quality indexing
- Documents are scored correctly
"""
# Arrange
dataset_retrieval = DatasetRetrieval()
tenant_id = str(uuid4())
# Create test documents
doc1 = Document(
page_content="Test content 1",
metadata={"doc_id": "doc1", "score": 0.9, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
provider="dify",
)
doc2 = Document(
page_content="Test content 2",
metadata={"doc_id": "doc2", "score": 0.8, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
provider="dify",
)
# Mock _retriever to return documents
def side_effect_retriever(
flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids
):
all_documents.extend([doc1, doc2])
mock_retriever.side_effect = side_effect_retriever
# Set up dataset with high_quality indexing
mock_dataset.indexing_technique = "high_quality"
# Mock calculate_vector_score to return scored documents
scored_docs = [
Document(
page_content="Test content 1",
metadata={"doc_id": "doc1", "score": 0.95, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
provider="dify",
),
]
mock_calculate_vector_score.return_value = scored_docs
all_documents = []
# Act - Call with dataset_count = 1
dataset_retrieval._multiple_retrieve_thread(
flask_app=mock_flask_app,
available_datasets=[mock_dataset],
metadata_condition=None,
metadata_filter_document_ids=None,
all_documents=all_documents,
tenant_id=tenant_id,
reranking_enable=True, # Reranking enabled but should be skipped for single dataset
reranking_mode="reranking_model",
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"},
weights=None,
top_k=5,
score_threshold=0.5,
query="test query",
attachment_id=None,
dataset_count=1,
)
# Assert
# DataPostProcessor should NOT be called
mock_data_processor_class.assert_not_called()
# calculate_vector_score SHOULD be called for high_quality datasets
mock_calculate_vector_score.assert_called_once()
call_args = mock_calculate_vector_score.call_args
assert call_args[0][1] == 5 # top_k
# Documents should be added after standard scoring
assert len(all_documents) == 1
assert all_documents[0].page_content == "Test content 1"
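The two tests above pin down the branching behaviour of `_multiple_retrieve_thread`: a second rerank via `DataPostProcessor` happens only when reranking is enabled and more than one dataset contributed results; otherwise the method falls through to the standard scoring path (`calculate_vector_score` for high-quality indexing). A minimal sketch of that decision, with the collaborators passed in as callables because the real method signature is much larger, might look like:

```python
from typing import Any, Callable, Sequence


def choose_scoring_path(
    documents: Sequence[Any],
    *,
    reranking_enable: bool,
    dataset_count: int,
    rerank: Callable[[Sequence[Any]], Sequence[Any]],
    standard_score: Callable[[Sequence[Any]], Sequence[Any]],
) -> Sequence[Any]:
    """Hypothetical helper mirroring the branch the tests exercise."""
    # Second rerank only when reranking is on AND results span several datasets.
    if reranking_enable and dataset_count > 1:
        return rerank(documents)
    # Single dataset (or reranking disabled): fall through to standard scoring.
    return standard_score(documents)
```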
class TestRetrievalMethods:
"""

View File

@ -0,0 +1,113 @@
import threading
from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from flask import Flask, current_app
from core.rag.models.document import Document
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from models.dataset import Dataset
class TestRetrievalService:
@pytest.fixture
def mock_dataset(self) -> Dataset:
dataset = Mock(spec=Dataset)
dataset.id = str(uuid4())
dataset.tenant_id = str(uuid4())
dataset.name = "test_dataset"
dataset.indexing_technique = "high_quality"
dataset.provider = "dify"
return dataset
def test_multiple_retrieve_reranking_with_app_context(self, mock_dataset):
"""
Repro test for current bug:
reranking runs after `with flask_app.app_context():` exits.
`_multiple_retrieve_thread` catches exceptions and stores them into `thread_exceptions`,
so we must assert from that list (not from an outer try/except).
"""
dataset_retrieval = DatasetRetrieval()
flask_app = Flask(__name__)
tenant_id = str(uuid4())
# second dataset to ensure dataset_count > 1 reranking branch
secondary_dataset = Mock(spec=Dataset)
secondary_dataset.id = str(uuid4())
secondary_dataset.provider = "dify"
secondary_dataset.indexing_technique = "high_quality"
# retriever returns 1 doc into internal list (all_documents_item)
document = Document(
page_content="Context aware doc",
metadata={
"doc_id": "doc1",
"score": 0.95,
"document_id": str(uuid4()),
"dataset_id": mock_dataset.id,
},
provider="dify",
)
def fake_retriever(
flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids
):
all_documents.append(document)
called = {"init": 0, "invoke": 0}
class ContextRequiredPostProcessor:
def __init__(self, *args, **kwargs):
called["init"] += 1
# will raise RuntimeError if no Flask app context exists
_ = current_app.name
def invoke(self, *args, **kwargs):
called["invoke"] += 1
_ = current_app.name
return kwargs.get("documents") or args[1]
# output list from _multiple_retrieve_thread
all_documents: list[Document] = []
# IMPORTANT: _multiple_retrieve_thread swallows exceptions and appends them here
thread_exceptions: list[Exception] = []
def target():
with patch.object(dataset_retrieval, "_retriever", side_effect=fake_retriever):
with patch(
"core.rag.retrieval.dataset_retrieval.DataPostProcessor",
ContextRequiredPostProcessor,
):
dataset_retrieval._multiple_retrieve_thread(
flask_app=flask_app,
available_datasets=[mock_dataset, secondary_dataset],
metadata_condition=None,
metadata_filter_document_ids=None,
all_documents=all_documents,
tenant_id=tenant_id,
reranking_enable=True,
reranking_mode="reranking_model",
reranking_model={
"reranking_provider_name": "cohere",
"reranking_model_name": "rerank-v2",
},
weights=None,
top_k=3,
score_threshold=0.0,
query="test query",
attachment_id=None,
dataset_count=2, # force reranking branch
thread_exceptions=thread_exceptions,  # exceptions are collected here instead of being raised
)
t = threading.Thread(target=target)
t.start()
t.join()
# Ensure reranking branch was actually executed
assert called["init"] >= 1, "DataPostProcessor was never constructed; reranking branch may not have run."
# With the bug present, the exception is recorded here rather than raised; after the fix the list must stay empty
assert not thread_exceptions, thread_exceptions
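The repro above only asserts where the failure surfaces; a minimal sketch of the intended remedy, assuming the fix is simply to keep the `DataPostProcessor` call inside the same `app_context()` block that the retrieval uses (names below are illustrative, not the actual method), could look like:

```python
from flask import Flask


def retrieve_then_rerank(flask_app: Flask, retrieve, rerank, query: str):
    # Both retrieval and reranking stay inside the app context, so anything
    # that touches current_app (e.g. the post-processor) keeps working even
    # when this function runs in a worker thread.
    with flask_app.app_context():
        documents = retrieve(query)
        return rerank(query, documents)
```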

View File

@ -103,13 +103,25 @@ class MockNodeFactory(DifyNodeFactory):
# Create mock node instance
mock_class = self._mock_node_types[node_type]
mock_instance = mock_class(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
)
if node_type == NodeType.CODE:
mock_instance = mock_class(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
code_executor=self._code_executor,
code_providers=self._code_providers,
code_limits=self._code_limits,
)
else:
mock_instance = mock_class(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
)
return mock_instance

View File

@ -40,12 +40,14 @@ class MockNodeMixin:
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
mock_config: Optional["MockConfig"] = None,
**kwargs: Any,
):
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
**kwargs,
)
self.mock_config = mock_config

View File

@ -5,11 +5,24 @@ This module tests the functionality of MockTemplateTransformNode and MockCodeNod
to ensure they work correctly with the TableTestRunner.
"""
from configs import dify_config
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.nodes.code.limits import CodeNodeLimits
from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig
from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory
from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockCodeNode, MockTemplateTransformNode
DEFAULT_CODE_LIMITS = CodeNodeLimits(
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
max_number=dify_config.CODE_MAX_NUMBER,
min_number=dify_config.CODE_MIN_NUMBER,
max_precision=dify_config.CODE_MAX_PRECISION,
max_depth=dify_config.CODE_MAX_DEPTH,
max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH,
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
class TestMockTemplateTransformNode:
"""Test cases for MockTemplateTransformNode."""
@ -306,6 +319,7 @@ class TestMockCodeNode:
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
mock_config=mock_config,
code_limits=DEFAULT_CODE_LIMITS,
)
# Run the node
@ -370,6 +384,7 @@ class TestMockCodeNode:
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
mock_config=mock_config,
code_limits=DEFAULT_CODE_LIMITS,
)
# Run the node
@ -438,6 +453,7 @@ class TestMockCodeNode:
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
mock_config=mock_config,
code_limits=DEFAULT_CODE_LIMITS,
)
# Run the node

View File

@ -1,3 +1,4 @@
from configs import dify_config
from core.helper.code_executor.code_executor import CodeLanguage
from core.variables.types import SegmentType
from core.workflow.nodes.code.code_node import CodeNode
@ -7,6 +8,18 @@ from core.workflow.nodes.code.exc import (
DepthLimitError,
OutputValidationError,
)
from core.workflow.nodes.code.limits import CodeNodeLimits
CodeNode._limits = CodeNodeLimits(
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
max_number=dify_config.CODE_MAX_NUMBER,
min_number=dify_config.CODE_MIN_NUMBER,
max_precision=dify_config.CODE_MAX_PRECISION,
max_depth=dify_config.CODE_MAX_DEPTH,
max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH,
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
class TestCodeNodeExceptions:

View File

@ -5,8 +5,8 @@ from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.helper.code_executor.code_executor import CodeExecutionError
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.nodes.template_transform.template_renderer import TemplateRenderError
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from models.workflow import WorkflowType
@ -127,7 +127,9 @@ class TestTemplateTransformNode:
"""Test version class method."""
assert TemplateTransformNode.version() == "1"
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_simple_template(
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
):
@ -145,7 +147,7 @@ class TestTemplateTransformNode:
mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
# Setup mock executor
mock_execute.return_value = {"result": "Hello Alice, you are 30 years old!"}
mock_execute.return_value = "Hello Alice, you are 30 years old!"
node = TemplateTransformNode(
id="test_node",
@ -162,7 +164,9 @@ class TestTemplateTransformNode:
assert result.inputs["name"] == "Alice"
assert result.inputs["age"] == 30
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_none_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with None variable values."""
node_data = {
@ -172,7 +176,7 @@ class TestTemplateTransformNode:
}
mock_graph_runtime_state.variable_pool.get.return_value = None
mock_execute.return_value = {"result": "Value: "}
mock_execute.return_value = "Value: "
node = TemplateTransformNode(
id="test_node",
@ -187,13 +191,15 @@ class TestTemplateTransformNode:
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.inputs["value"] is None
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_code_execution_error(
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
):
"""Test _run when code execution fails."""
mock_graph_runtime_state.variable_pool.get.return_value = MagicMock()
mock_execute.side_effect = CodeExecutionError("Template syntax error")
mock_execute.side_effect = TemplateRenderError("Template syntax error")
node = TemplateTransformNode(
id="test_node",
@ -208,14 +214,16 @@ class TestTemplateTransformNode:
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert "Template syntax error" in result.error
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
@patch("core.workflow.nodes.template_transform.template_transform_node.MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH", 10)
def test_run_output_length_exceeds_limit(
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
):
"""Test _run when output exceeds maximum length."""
mock_graph_runtime_state.variable_pool.get.return_value = MagicMock()
mock_execute.return_value = {"result": "This is a very long output that exceeds the limit"}
mock_execute.return_value = "This is a very long output that exceeds the limit"
node = TemplateTransformNode(
id="test_node",
@ -230,7 +238,9 @@ class TestTemplateTransformNode:
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert "Output length exceeds" in result.error
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_complex_jinja2_template(
self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params
):
@ -257,7 +267,7 @@ class TestTemplateTransformNode:
("sys", "show_total"): mock_show_total,
}
mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
mock_execute.return_value = {"result": "apple, banana, orange (Total: 3)"}
mock_execute.return_value = "apple, banana, orange (Total: 3)"
node = TemplateTransformNode(
id="test_node",
@ -292,7 +302,9 @@ class TestTemplateTransformNode:
assert mapping["node_123.var1"] == ["sys", "input1"]
assert mapping["node_123.var2"] == ["sys", "input2"]
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_empty_variables(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with no variables (static template)."""
node_data = {
@ -301,7 +313,7 @@ class TestTemplateTransformNode:
"template": "This is a static message.",
}
mock_execute.return_value = {"result": "This is a static message."}
mock_execute.return_value = "This is a static message."
node = TemplateTransformNode(
id="test_node",
@ -317,7 +329,9 @@ class TestTemplateTransformNode:
assert result.outputs["output"] == "This is a static message."
assert result.inputs == {}
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_numeric_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with numeric variable values."""
node_data = {
@ -339,7 +353,7 @@ class TestTemplateTransformNode:
("sys", "quantity"): mock_quantity,
}
mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
mock_execute.return_value = {"result": "Total: $31.5"}
mock_execute.return_value = "Total: $31.5"
node = TemplateTransformNode(
id="test_node",
@ -354,7 +368,9 @@ class TestTemplateTransformNode:
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["output"] == "Total: $31.5"
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_dict_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with dictionary variable values."""
node_data = {
@ -367,7 +383,7 @@ class TestTemplateTransformNode:
mock_user.to_object.return_value = {"name": "John Doe", "email": "john@example.com"}
mock_graph_runtime_state.variable_pool.get.return_value = mock_user
mock_execute.return_value = {"result": "Name: John Doe, Email: john@example.com"}
mock_execute.return_value = "Name: John Doe, Email: john@example.com"
node = TemplateTransformNode(
id="test_node",
@ -383,7 +399,9 @@ class TestTemplateTransformNode:
assert "John Doe" in result.outputs["output"]
assert "john@example.com" in result.outputs["output"]
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_list_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with list variable values."""
node_data = {
@ -396,7 +414,7 @@ class TestTemplateTransformNode:
mock_tags.to_object.return_value = ["python", "ai", "workflow"]
mock_graph_runtime_state.variable_pool.get.return_value = mock_tags
mock_execute.return_value = {"result": "Tags: #python #ai #workflow "}
mock_execute.return_value = "Tags: #python #ai #workflow "
node = TemplateTransformNode(
id="test_node",

View File

@ -99,29 +99,20 @@ def test_external_api_json_message_and_bad_request_rewrite():
assert res.get_json()["message"] == "Invalid JSON payload received or JSON payload is empty."
def test_external_api_param_mapping_and_quota_and_exc_info_none():
# Force exc_info() to return (None,None,None) only during request
import libs.external_api as ext
def test_external_api_param_mapping_and_quota():
app = _create_api_app()
client = app.test_client()
orig_exc_info = ext.sys.exc_info
try:
ext.sys.exc_info = lambda: (None, None, None)
# Param errors mapping payload path
res = client.get("/api/param-errors")
assert res.status_code == 400
data = res.get_json()
assert data["code"] == "invalid_param"
assert data["params"] == "field"
app = _create_api_app()
client = app.test_client()
# Param errors mapping payload path
res = client.get("/api/param-errors")
assert res.status_code == 400
data = res.get_json()
assert data["code"] == "invalid_param"
assert data["params"] == "field"
# Quota path — depending on Flask-RESTX internals it may be handled
res = client.get("/api/quota")
assert res.status_code in (400, 429)
finally:
ext.sys.exc_info = orig_exc_info # type: ignore[assignment]
# Quota path — depending on Flask-RESTX internals it may be handled
res = client.get("/api/quota")
assert res.status_code in (400, 429)
def test_unauthorized_and_force_logout_clears_cookies():

View File

@ -1953,14 +1953,14 @@ wheels = [
[[package]]
name = "fickling"
version = "0.1.5"
version = "0.1.6"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "stdlib-list" },
]
sdist = { url = "https://files.pythonhosted.org/packages/41/94/0d0ce455952c036cfee235637f786c1d1d07d1b90f6a4dfb50e0eff929d6/fickling-0.1.5.tar.gz", hash = "sha256:92f9b49e717fa8dbc198b4b7b685587adb652d85aa9ede8131b3e44494efca05", size = 282462, upload-time = "2025-11-18T05:04:30.748Z" }
sdist = { url = "https://files.pythonhosted.org/packages/07/ab/7571453f9365c17c047b5a7b7e82692a7f6be51203f295030886758fd57a/fickling-0.1.6.tar.gz", hash = "sha256:03cb5d7bd09f9169c7583d2079fad4b3b88b25f865ed0049172e5cb68582311d", size = 284033, upload-time = "2025-12-15T18:14:58.721Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/bf/a7/d25912b2e3a5b0a37e6f460050bbc396042b5906a6563a1962c484abc3c6/fickling-0.1.5-py3-none-any.whl", hash = "sha256:6aed7270bfa276e188b0abe043a27b3a042129d28ec1fa6ff389bdcc5ad178bb", size = 46240, upload-time = "2025-11-18T05:04:29.048Z" },
{ url = "https://files.pythonhosted.org/packages/76/99/cc04258dda421bc612cdfe4be8c253f45b922f1c7f268b5a0b9962d9cd12/fickling-0.1.6-py3-none-any.whl", hash = "sha256:465d0069548bfc731bdd75a583cb4cf5a4b2666739c0f76287807d724b147ed3", size = 47922, upload-time = "2025-12-15T18:14:57.526Z" },
]
[[package]]
@ -2955,14 +2955,14 @@ wheels = [
[[package]]
name = "intersystems-irispython"
version = "5.3.0"
version = "5.3.1"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/5d/56/16d93576b50408d97a5cbbd055d8da024d585e96a360e2adc95b41ae6284/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-macosx_10_9_universal2.whl", hash = "sha256:59d3176a35867a55b1ab69a6b5c75438b460291bccb254c2d2f4173be08b6e55", size = 6594480, upload-time = "2025-10-09T20:47:27.629Z" },
{ url = "https://files.pythonhosted.org/packages/99/bc/19e144ee805ea6ee0df6342a711e722c84347c05a75b3bf040c5fbe19982/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:56bccefd1997c25f9f9f6c4086214c18d4fdaac0a93319d4b21dd9a6c59c9e51", size = 14779928, upload-time = "2025-10-09T20:47:30.564Z" },
{ url = "https://files.pythonhosted.org/packages/e6/fb/59ba563a80b39e9450b4627b5696019aa831dce27dacc3831b8c1e669102/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e160adc0785c55bb64e4264b8e99075691a15b0afa5d8d529f1b4bac7e57b81", size = 14422035, upload-time = "2025-10-09T20:47:32.552Z" },
{ url = "https://files.pythonhosted.org/packages/c1/68/ade8ad43f0ed1e5fba60e1710fa5ddeb01285f031e465e8c006329072e63/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-win32.whl", hash = "sha256:820f2c5729119e5173a5bf6d6ac2a41275c4f1ffba6af6c59ea313ecd8f499cc", size = 2824316, upload-time = "2025-10-09T20:47:28.998Z" },
{ url = "https://files.pythonhosted.org/packages/f4/03/cd45cb94e42c01dc525efebf3c562543a18ee55b67fde4022665ca672351/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-win_amd64.whl", hash = "sha256:fc07ec24bc50b6f01573221cd7d86f2937549effe31c24af8db118e0131e340c", size = 3463297, upload-time = "2025-10-09T20:47:34.636Z" },
{ url = "https://files.pythonhosted.org/packages/33/5b/8eac672a6ef26bef6ef79a7c9557096167b50c4d3577d558ae6999c195fe/intersystems_irispython-5.3.1-cp38.cp39.cp310.cp311.cp312.cp313.cp314-cp38.cp39.cp310.cp311.cp312.cp313.cp314-macosx_10_9_universal2.whl", hash = "sha256:634c9b4ec620837d830ff49543aeb2797a1ce8d8570a0e868398b85330dfcc4d", size = 6736686, upload-time = "2025-12-19T16:24:57.734Z" },
{ url = "https://files.pythonhosted.org/packages/ba/17/bab3e525ffb6711355f7feea18c1b7dced9c2484cecbcdd83f74550398c0/intersystems_irispython-5.3.1-cp38.cp39.cp310.cp311.cp312.cp313.cp314-cp38.cp39.cp310.cp311.cp312.cp313.cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cf912f30f85e2a42f2c2ea77fbeb98a24154d5ea7428a50382786a684ec4f583", size = 16005259, upload-time = "2025-12-19T16:25:05.578Z" },
{ url = "https://files.pythonhosted.org/packages/39/59/9bb79d9e32e3e55fc9aed8071a797b4497924cbc6457cea9255bb09320b7/intersystems_irispython-5.3.1-cp38.cp39.cp310.cp311.cp312.cp313.cp314-cp38.cp39.cp310.cp311.cp312.cp313.cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:be5659a6bb57593910f2a2417eddb9f5dc2f93a337ead6ddca778f557b8a359a", size = 15638040, upload-time = "2025-12-19T16:24:54.429Z" },
{ url = "https://files.pythonhosted.org/packages/cf/47/654ccf9c5cca4f5491f070888544165c9e2a6a485e320ea703e4e38d2358/intersystems_irispython-5.3.1-cp38.cp39.cp310.cp311.cp312.cp313.cp314-cp38.cp39.cp310.cp311.cp312.cp313.cp314-win32.whl", hash = "sha256:583e4f17088c1e0530f32efda1c0ccb02993cbc22035bc8b4c71d8693b04ee7e", size = 2879644, upload-time = "2025-12-19T16:24:59.945Z" },
{ url = "https://files.pythonhosted.org/packages/68/95/19cc13d09f1b4120bd41b1434509052e1d02afd27f2679266d7ad9cc1750/intersystems_irispython-5.3.1-cp38.cp39.cp310.cp311.cp312.cp313.cp314-cp38.cp39.cp310.cp311.cp312.cp313.cp314-win_amd64.whl", hash = "sha256:1d5d40450a0cdeec2a1f48d12d946a8a8ffc7c128576fcae7d58e66e3a127eae", size = 3522092, upload-time = "2025-12-19T16:25:01.834Z" },
]
[[package]]

View File

@ -69,6 +69,8 @@ PYTHONIOENCODING=utf-8
# The log level for the application.
# Supported values are `DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL`
LOG_LEVEL=INFO
# Log output format: text or json
LOG_OUTPUT_FORMAT=text
# Log file path
LOG_FILE=/app/logs/server.log
# Log file max size, the unit is MB
@ -231,7 +233,7 @@ NEXT_PUBLIC_ENABLE_SINGLE_DOLLAR_LATEX=false
# You can adjust the database configuration according to your needs.
# ------------------------------
# Database type, supported values are `postgresql` and `mysql`
# Database type, supported values are `postgresql`, `mysql`, `oceanbase`, `seekdb`
DB_TYPE=postgresql
# For MySQL, only `root` user is supported for now
DB_USERNAME=postgres
@ -531,7 +533,7 @@ SUPABASE_URL=your-server-url
# ------------------------------
# The type of vector store to use.
# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `clickzetta`, `alibabacloud_mysql`, `iris`.
# Supported values are `weaviate`, `oceanbase`, `seekdb`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`, `vastbase`, `tidb`, `tidb_on_qdrant`, `baidu`, `lindorm`, `huawei_cloud`, `upstash`, `matrixone`, `clickzetta`, `alibabacloud_mysql`, `iris`.
VECTOR_STORE=weaviate
# Prefix used to create collection name in vector database
VECTOR_INDEX_NAME_PREFIX=Vector_index
@ -542,9 +544,9 @@ WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
WEAVIATE_GRPC_ENDPOINT=grpc://weaviate:50051
WEAVIATE_TOKENIZATION=word
# For OceanBase metadata database configuration, available when `DB_TYPE` is `mysql` and `COMPOSE_PROFILES` includes `oceanbase`.
# For OceanBase metadata database configuration, available when `DB_TYPE` is `oceanbase`.
# For OceanBase vector database configuration, available when `VECTOR_STORE` is `oceanbase`
# If you want to use OceanBase as both vector database and metadata database, you need to set `DB_TYPE` to `mysql`, `COMPOSE_PROFILES` is `oceanbase`, and set Database Configuration is the same as the vector database.
# If you want to use OceanBase as both the vector database and the metadata database, set both `DB_TYPE` and `VECTOR_STORE` to `oceanbase`, and use the same database configuration as the vector database.
# seekdb is the lite version of OceanBase and shares the connection configuration with OceanBase.
OCEANBASE_VECTOR_HOST=oceanbase
OCEANBASE_VECTOR_PORT=2881

View File

@ -129,6 +129,7 @@ services:
- ./middleware.env
environment:
# Use the shared environment variables.
LOG_OUTPUT_FORMAT: ${LOG_OUTPUT_FORMAT:-text}
DB_DATABASE: ${DB_PLUGIN_DATABASE:-dify_plugin}
REDIS_HOST: ${REDIS_HOST:-redis}
REDIS_PORT: ${REDIS_PORT:-6379}

View File

@ -17,6 +17,7 @@ x-shared-env: &shared-api-worker-env
LC_ALL: ${LC_ALL:-en_US.UTF-8}
PYTHONIOENCODING: ${PYTHONIOENCODING:-utf-8}
LOG_LEVEL: ${LOG_LEVEL:-INFO}
LOG_OUTPUT_FORMAT: ${LOG_OUTPUT_FORMAT:-text}
LOG_FILE: ${LOG_FILE:-/app/logs/server.log}
LOG_FILE_MAX_SIZE: ${LOG_FILE_MAX_SIZE:-20}
LOG_FILE_BACKUP_COUNT: ${LOG_FILE_BACKUP_COUNT:-5}

View File

@ -54,3 +54,52 @@ http_access allow src_all
# Unless the option's size is increased, an error will occur when uploading more than two files.
client_request_buffer_max_size 100 MB
################################## Performance & Concurrency ###############################
# Increase file descriptor limit for high concurrency
max_filedescriptors 65536
# Timeout configurations for image requests
connect_timeout 30 seconds
request_timeout 2 minutes
read_timeout 2 minutes
client_lifetime 5 minutes
shutdown_lifetime 30 seconds
# Persistent connections - improve performance for multiple requests
server_persistent_connections on
client_persistent_connections on
persistent_request_timeout 30 seconds
pconn_timeout 1 minute
# Connection pool and concurrency limits
client_db on
server_idle_pconn_timeout 2 minutes
client_idle_pconn_timeout 2 minutes
# Quick abort settings - don't abort requests that are mostly done
quick_abort_min 16 KB
quick_abort_max 16 MB
quick_abort_pct 95
# Memory and cache optimization
memory_cache_mode disk
cache_mem 256 MB
maximum_object_size_in_memory 512 KB
# DNS resolver settings for better performance
dns_timeout 30 seconds
dns_retransmit_interval 5 seconds
# By default, Squid uses the system's configured DNS resolvers.
# If you need to override them, set dns_nameservers to appropriate servers
# for your environment (for example, internal/corporate DNS). The following
# is an example using public DNS and SHOULD be customized before use:
# dns_nameservers 8.8.8.8 8.8.4.4
# Logging format for better debugging
logformat dify_log %ts.%03tu %6tr %>a %Ss/%03>Hs %<st %rm %ru %[un %Sh/%<a %mt
access_log daemon:/var/log/squid/access.log dify_log
# Access log to track concurrent requests and timeouts
logfile_rotate 10

View File

@ -7,8 +7,12 @@ logs
# node
node_modules
dist
build
coverage
.husky
.next
.pnpm-store
# vscode
.vscode
@ -22,3 +26,7 @@ node_modules
# Jetbrains
.idea
# git
.git
.gitignore

View File

@ -47,6 +47,8 @@ NEXT_PUBLIC_TOP_K_MAX_VALUE=10
# The maximum number of tokens for segmentation
NEXT_PUBLIC_INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000
# Used by web/docker/entrypoint.sh to overwrite/export NEXT_PUBLIC_INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH at container startup (Docker only)
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000
# Maximum loop count in the workflow
NEXT_PUBLIC_LOOP_NODE_MAX_COUNT=100

View File

@ -12,7 +12,8 @@ RUN apk add --no-cache tzdata
RUN corepack enable
ENV PNPM_HOME="/pnpm"
ENV PATH="$PNPM_HOME:$PATH"
ENV NEXT_PUBLIC_BASE_PATH=""
ARG NEXT_PUBLIC_BASE_PATH=""
ENV NEXT_PUBLIC_BASE_PATH="$NEXT_PUBLIC_BASE_PATH"
# install packages

View File

@ -110,7 +110,7 @@ const GotoAnything: FC<Props> = ({
isWorkflowPage,
isRagPipelinePage,
defaultLocale,
Object.keys(Actions).sort().join(','),
Actions,
],
queryFn: async () => {
const query = searchQueryDebouncedValue.toLowerCase()

View File

@ -1,10 +1,13 @@
'use client'
import type { PluginDetail } from '../types'
import type { FilterState } from './filter-management'
import { useDebounceFn } from 'ahooks'
import { useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import PluginDetailPanel from '@/app/components/plugins/plugin-detail-panel'
import { useGetLanguage } from '@/context/i18n'
import { renderI18nObject } from '@/i18n-config'
import { useInstalledLatestVersion, useInstalledPluginList, useInvalidateInstalledPluginList } from '@/service/use-plugins'
import Loading from '../../base/loading'
import { PluginSource } from '../types'
@ -13,8 +16,34 @@ import Empty from './empty'
import FilterManagement from './filter-management'
import List from './list'
const matchesSearchQuery = (plugin: PluginDetail & { latest_version: string }, query: string, locale: string): boolean => {
if (!query)
return true
const lowerQuery = query.toLowerCase()
const { declaration } = plugin
// Match plugin_id
if (plugin.plugin_id.toLowerCase().includes(lowerQuery))
return true
// Match plugin name
if (plugin.name?.toLowerCase().includes(lowerQuery))
return true
// Match declaration name
if (declaration.name?.toLowerCase().includes(lowerQuery))
return true
// Match localized label
const label = renderI18nObject(declaration.label, locale)
if (label?.toLowerCase().includes(lowerQuery))
return true
// Match localized description
const description = renderI18nObject(declaration.description, locale)
if (description?.toLowerCase().includes(lowerQuery))
return true
return false
}
const PluginsPanel = () => {
const { t } = useTranslation()
const locale = useGetLanguage()
const filters = usePluginPageContext(v => v.filters) as FilterState
const setFilters = usePluginPageContext(v => v.setFilters)
const { data: pluginList, isLoading: isPluginListLoading, isFetching, isLastPage, loadNextPage } = useInstalledPluginList()
@ -48,11 +77,11 @@ const PluginsPanel = () => {
return (
(categories.length === 0 || categories.includes(plugin.declaration.category))
&& (tags.length === 0 || tags.some(tag => plugin.declaration.tags.includes(tag)))
&& (searchQuery === '' || plugin.plugin_id.toLowerCase().includes(searchQuery.toLowerCase()))
&& matchesSearchQuery(plugin, searchQuery, locale)
)
})
return filteredList
}, [pluginListWithLatestVersion, filters])
}, [pluginListWithLatestVersion, filters, locale])
const currentPluginDetail = useMemo(() => {
const detail = pluginListWithLatestVersion.find(plugin => plugin.plugin_id === currentPluginID)

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -0,0 +1,971 @@
import type { PanelProps } from '@/app/components/workflow/panel'
import { render, screen, waitFor } from '@testing-library/react'
import * as React from 'react'
import RagPipelinePanel from './index'
// ============================================================================
// Mock External Dependencies
// ============================================================================
// Type definitions for dynamic module
type DynamicModule = {
default?: React.ComponentType<Record<string, unknown>>
}
type PromiseOrModule = Promise<DynamicModule> | DynamicModule
// Mock next/dynamic to return synchronous components immediately
vi.mock('next/dynamic', () => ({
default: (loader: () => PromiseOrModule, _options?: Record<string, unknown>) => {
let Component: React.ComponentType<Record<string, unknown>> | null = null
// Try to resolve the loader synchronously for mocked modules
try {
const result = loader() as PromiseOrModule
if (result && typeof (result as Promise<DynamicModule>).then === 'function') {
// For async modules, we need to handle them specially
// This will work with vi.mock since mocks resolve synchronously
(result as Promise<DynamicModule>).then((mod: DynamicModule) => {
Component = (mod.default || mod) as React.ComponentType<Record<string, unknown>>
})
}
else if (result) {
Component = ((result as DynamicModule).default || result) as React.ComponentType<Record<string, unknown>>
}
}
catch {
// If the module can't be resolved, Component stays null
}
// Return a simple wrapper that renders the component or null
const DynamicComponent = React.forwardRef((props: Record<string, unknown>, ref: React.Ref<unknown>) => {
// For mocked modules, Component should already be set
if (Component)
return <Component {...props} ref={ref} />
return null
})
DynamicComponent.displayName = 'DynamicComponent'
return DynamicComponent
},
}))
// Mock workflow store
let mockHistoryWorkflowData: Record<string, unknown> | null = null
let mockShowDebugAndPreviewPanel = false
let mockShowGlobalVariablePanel = false
let mockShowInputFieldPanel = false
let mockShowInputFieldPreviewPanel = false
let mockInputFieldEditPanelProps: Record<string, unknown> | null = null
let mockPipelineId = 'test-pipeline-123'
type MockStoreState = {
historyWorkflowData: Record<string, unknown> | null
showDebugAndPreviewPanel: boolean
showGlobalVariablePanel: boolean
showInputFieldPanel: boolean
showInputFieldPreviewPanel: boolean
inputFieldEditPanelProps: Record<string, unknown> | null
pipelineId: string
}
vi.mock('@/app/components/workflow/store', () => ({
useStore: (selector: (state: MockStoreState) => unknown) => {
const state: MockStoreState = {
historyWorkflowData: mockHistoryWorkflowData,
showDebugAndPreviewPanel: mockShowDebugAndPreviewPanel,
showGlobalVariablePanel: mockShowGlobalVariablePanel,
showInputFieldPanel: mockShowInputFieldPanel,
showInputFieldPreviewPanel: mockShowInputFieldPreviewPanel,
inputFieldEditPanelProps: mockInputFieldEditPanelProps,
pipelineId: mockPipelineId,
}
return selector(state)
},
}))
// Mock Panel component to capture props and render children
let capturedPanelProps: PanelProps | null = null
vi.mock('@/app/components/workflow/panel', () => ({
default: (props: PanelProps) => {
capturedPanelProps = props
return (
<div data-testid="workflow-panel">
<div data-testid="panel-left">{props.components?.left}</div>
<div data-testid="panel-right">{props.components?.right}</div>
</div>
)
},
}))
// Mock Record component
vi.mock('@/app/components/workflow/panel/record', () => ({
default: () => <div data-testid="record-panel">Record Panel</div>,
}))
// Mock TestRunPanel component
vi.mock('@/app/components/rag-pipeline/components/panel/test-run', () => ({
default: () => <div data-testid="test-run-panel">Test Run Panel</div>,
}))
// Mock InputFieldPanel component
vi.mock('./input-field', () => ({
default: () => <div data-testid="input-field-panel">Input Field Panel</div>,
}))
// Mock InputFieldEditorPanel component
const mockInputFieldEditorProps = vi.fn()
vi.mock('./input-field/editor', () => ({
default: (props: Record<string, unknown>) => {
mockInputFieldEditorProps(props)
return <div data-testid="input-field-editor-panel">Input Field Editor Panel</div>
},
}))
// Mock PreviewPanel component
vi.mock('./input-field/preview', () => ({
default: () => <div data-testid="preview-panel">Preview Panel</div>,
}))
// Mock GlobalVariablePanel component
vi.mock('@/app/components/workflow/panel/global-variable-panel', () => ({
default: () => <div data-testid="global-variable-panel">Global Variable Panel</div>,
}))
// ============================================================================
// Helper Functions
// ============================================================================
type SetupMockOptions = {
historyWorkflowData?: Record<string, unknown> | null
showDebugAndPreviewPanel?: boolean
showGlobalVariablePanel?: boolean
showInputFieldPanel?: boolean
showInputFieldPreviewPanel?: boolean
inputFieldEditPanelProps?: Record<string, unknown> | null
pipelineId?: string
}
const setupMocks = (options?: SetupMockOptions) => {
mockHistoryWorkflowData = options?.historyWorkflowData ?? null
mockShowDebugAndPreviewPanel = options?.showDebugAndPreviewPanel ?? false
mockShowGlobalVariablePanel = options?.showGlobalVariablePanel ?? false
mockShowInputFieldPanel = options?.showInputFieldPanel ?? false
mockShowInputFieldPreviewPanel = options?.showInputFieldPreviewPanel ?? false
mockInputFieldEditPanelProps = options?.inputFieldEditPanelProps ?? null
mockPipelineId = options?.pipelineId ?? 'test-pipeline-123'
capturedPanelProps = null
}
// ============================================================================
// RagPipelinePanel Component Tests
// ============================================================================
describe('RagPipelinePanel', () => {
beforeEach(() => {
vi.clearAllMocks()
setupMocks()
})
// -------------------------------------------------------------------------
// Rendering Tests
// -------------------------------------------------------------------------
describe('Rendering', () => {
it('should render without crashing', async () => {
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(screen.getByTestId('workflow-panel')).toBeInTheDocument()
})
})
it('should render Panel component with correct structure', async () => {
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(screen.getByTestId('panel-left')).toBeInTheDocument()
expect(screen.getByTestId('panel-right')).toBeInTheDocument()
})
})
it('should pass versionHistoryPanelProps to Panel', async () => {
// Arrange
setupMocks({ pipelineId: 'my-pipeline-456' })
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(capturedPanelProps?.versionHistoryPanelProps).toBeDefined()
expect(capturedPanelProps?.versionHistoryPanelProps?.getVersionListUrl).toBe(
'/rag/pipelines/my-pipeline-456/workflows',
)
})
})
})
// -------------------------------------------------------------------------
// Memoization Tests - versionHistoryPanelProps
// -------------------------------------------------------------------------
describe('Memoization - versionHistoryPanelProps', () => {
it('should compute correct getVersionListUrl based on pipelineId', async () => {
// Arrange
setupMocks({ pipelineId: 'pipeline-abc' })
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(capturedPanelProps?.versionHistoryPanelProps?.getVersionListUrl).toBe(
'/rag/pipelines/pipeline-abc/workflows',
)
})
})
it('should compute correct deleteVersionUrl function', async () => {
// Arrange
setupMocks({ pipelineId: 'pipeline-xyz' })
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
const deleteUrl = capturedPanelProps?.versionHistoryPanelProps?.deleteVersionUrl?.('version-1')
expect(deleteUrl).toBe('/rag/pipelines/pipeline-xyz/workflows/version-1')
})
})
it('should compute correct updateVersionUrl function', async () => {
// Arrange
setupMocks({ pipelineId: 'pipeline-def' })
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
const updateUrl = capturedPanelProps?.versionHistoryPanelProps?.updateVersionUrl?.('version-2')
expect(updateUrl).toBe('/rag/pipelines/pipeline-def/workflows/version-2')
})
})
it('should set latestVersionId to empty string', async () => {
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(capturedPanelProps?.versionHistoryPanelProps?.latestVersionId).toBe('')
})
})
})
// -------------------------------------------------------------------------
// Memoization Tests - panelProps
// -------------------------------------------------------------------------
describe('Memoization - panelProps', () => {
it('should pass components.left to Panel', async () => {
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(capturedPanelProps?.components?.left).toBeDefined()
})
})
it('should pass components.right to Panel', async () => {
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(capturedPanelProps?.components?.right).toBeDefined()
})
})
it('should pass versionHistoryPanelProps to panelProps', async () => {
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(capturedPanelProps?.versionHistoryPanelProps).toBeDefined()
})
})
})
// -------------------------------------------------------------------------
// Component Memoization Tests (React.memo)
// -------------------------------------------------------------------------
describe('Component Memoization', () => {
it('should be wrapped with React.memo', async () => {
// The component should not break when re-rendered
const { rerender } = render(<RagPipelinePanel />)
// Act - rerender without prop changes
rerender(<RagPipelinePanel />)
// Assert - component should still render correctly
await waitFor(() => {
expect(screen.getByTestId('workflow-panel')).toBeInTheDocument()
})
})
})
})
// ============================================================================
// RagPipelinePanelOnRight Component Tests
// ============================================================================
describe('RagPipelinePanelOnRight', () => {
beforeEach(() => {
vi.clearAllMocks()
setupMocks()
})
// -------------------------------------------------------------------------
// Conditional Rendering - Record Panel
// -------------------------------------------------------------------------
describe('Record Panel Conditional Rendering', () => {
it('should render Record panel when historyWorkflowData exists', async () => {
// Arrange
setupMocks({ historyWorkflowData: { id: 'history-1' } })
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(screen.getByTestId('record-panel')).toBeInTheDocument()
})
})
it('should not render Record panel when historyWorkflowData is null', async () => {
// Arrange
setupMocks({ historyWorkflowData: null })
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(screen.queryByTestId('record-panel')).not.toBeInTheDocument()
})
})
it('should not render Record panel when historyWorkflowData is undefined', async () => {
// Arrange
setupMocks({ historyWorkflowData: undefined })
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(screen.queryByTestId('record-panel')).not.toBeInTheDocument()
})
})
})
// -------------------------------------------------------------------------
// Conditional Rendering - TestRun Panel
// -------------------------------------------------------------------------
describe('TestRun Panel Conditional Rendering', () => {
it('should render TestRun panel when showDebugAndPreviewPanel is true', async () => {
// Arrange
setupMocks({ showDebugAndPreviewPanel: true })
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(screen.getByTestId('test-run-panel')).toBeInTheDocument()
})
})
it('should not render TestRun panel when showDebugAndPreviewPanel is false', async () => {
// Arrange
setupMocks({ showDebugAndPreviewPanel: false })
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(screen.queryByTestId('test-run-panel')).not.toBeInTheDocument()
})
})
})
// -------------------------------------------------------------------------
// Conditional Rendering - GlobalVariable Panel
// -------------------------------------------------------------------------
describe('GlobalVariable Panel Conditional Rendering', () => {
it('should render GlobalVariable panel when showGlobalVariablePanel is true', async () => {
// Arrange
setupMocks({ showGlobalVariablePanel: true })
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(screen.getByTestId('global-variable-panel')).toBeInTheDocument()
})
})
it('should not render GlobalVariable panel when showGlobalVariablePanel is false', async () => {
// Arrange
setupMocks({ showGlobalVariablePanel: false })
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(screen.queryByTestId('global-variable-panel')).not.toBeInTheDocument()
})
})
})
// -------------------------------------------------------------------------
// Multiple Panels Rendering
// -------------------------------------------------------------------------
describe('Multiple Panels Rendering', () => {
it('should render all right panels when all conditions are true', async () => {
// Arrange
setupMocks({
historyWorkflowData: { id: 'history-1' },
showDebugAndPreviewPanel: true,
showGlobalVariablePanel: true,
})
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(screen.getByTestId('record-panel')).toBeInTheDocument()
expect(screen.getByTestId('test-run-panel')).toBeInTheDocument()
expect(screen.getByTestId('global-variable-panel')).toBeInTheDocument()
})
})
it('should render no right panels when all conditions are false', async () => {
// Arrange
setupMocks({
historyWorkflowData: null,
showDebugAndPreviewPanel: false,
showGlobalVariablePanel: false,
})
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(screen.queryByTestId('record-panel')).not.toBeInTheDocument()
expect(screen.queryByTestId('test-run-panel')).not.toBeInTheDocument()
expect(screen.queryByTestId('global-variable-panel')).not.toBeInTheDocument()
})
})
it('should render only Record and TestRun panels', async () => {
// Arrange
setupMocks({
historyWorkflowData: { id: 'history-1' },
showDebugAndPreviewPanel: true,
showGlobalVariablePanel: false,
})
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(screen.getByTestId('record-panel')).toBeInTheDocument()
expect(screen.getByTestId('test-run-panel')).toBeInTheDocument()
expect(screen.queryByTestId('global-variable-panel')).not.toBeInTheDocument()
})
})
})
})
// ============================================================================
// RagPipelinePanelOnLeft Component Tests
// ============================================================================
describe('RagPipelinePanelOnLeft', () => {
beforeEach(() => {
vi.clearAllMocks()
setupMocks()
})
// -------------------------------------------------------------------------
// Conditional Rendering - Preview Panel
// -------------------------------------------------------------------------
describe('Preview Panel Conditional Rendering', () => {
it('should render Preview panel when showInputFieldPreviewPanel is true', async () => {
// Arrange
setupMocks({ showInputFieldPreviewPanel: true })
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(screen.getByTestId('preview-panel')).toBeInTheDocument()
})
})
it('should not render Preview panel when showInputFieldPreviewPanel is false', async () => {
// Arrange
setupMocks({ showInputFieldPreviewPanel: false })
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(screen.queryByTestId('preview-panel')).not.toBeInTheDocument()
})
})
})
// -------------------------------------------------------------------------
// Conditional Rendering - InputFieldEditor Panel
// -------------------------------------------------------------------------
describe('InputFieldEditor Panel Conditional Rendering', () => {
it('should render InputFieldEditor panel when inputFieldEditPanelProps is provided', async () => {
// Arrange
const editProps = {
onClose: vi.fn(),
onSubmit: vi.fn(),
initialData: { variable: 'test' },
}
setupMocks({ inputFieldEditPanelProps: editProps })
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(screen.getByTestId('input-field-editor-panel')).toBeInTheDocument()
})
})
it('should not render InputFieldEditor panel when inputFieldEditPanelProps is null', async () => {
// Arrange
setupMocks({ inputFieldEditPanelProps: null })
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(screen.queryByTestId('input-field-editor-panel')).not.toBeInTheDocument()
})
})
it('should pass props to InputFieldEditor panel', async () => {
// Arrange
const editProps = {
onClose: vi.fn(),
onSubmit: vi.fn(),
initialData: { variable: 'test_var', label: 'Test Label' },
}
setupMocks({ inputFieldEditPanelProps: editProps })
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(mockInputFieldEditorProps).toHaveBeenCalledWith(
expect.objectContaining({
onClose: editProps.onClose,
onSubmit: editProps.onSubmit,
initialData: editProps.initialData,
}),
)
})
})
})
// -------------------------------------------------------------------------
// Conditional Rendering - InputField Panel
// -------------------------------------------------------------------------
describe('InputField Panel Conditional Rendering', () => {
it('should render InputField panel when showInputFieldPanel is true', async () => {
// Arrange
setupMocks({ showInputFieldPanel: true })
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(screen.getByTestId('input-field-panel')).toBeInTheDocument()
})
})
it('should not render InputField panel when showInputFieldPanel is false', async () => {
// Arrange
setupMocks({ showInputFieldPanel: false })
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(screen.queryByTestId('input-field-panel')).not.toBeInTheDocument()
})
})
})
// -------------------------------------------------------------------------
// Multiple Panels Rendering
// -------------------------------------------------------------------------
describe('Multiple Left Panels Rendering', () => {
it('should render all left panels when all conditions are true', async () => {
// Arrange
setupMocks({
showInputFieldPreviewPanel: true,
inputFieldEditPanelProps: { onClose: vi.fn(), onSubmit: vi.fn() },
showInputFieldPanel: true,
})
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(screen.getByTestId('preview-panel')).toBeInTheDocument()
expect(screen.getByTestId('input-field-editor-panel')).toBeInTheDocument()
expect(screen.getByTestId('input-field-panel')).toBeInTheDocument()
})
})
it('should render no left panels when all conditions are false', async () => {
// Arrange
setupMocks({
showInputFieldPreviewPanel: false,
inputFieldEditPanelProps: null,
showInputFieldPanel: false,
})
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(screen.queryByTestId('preview-panel')).not.toBeInTheDocument()
expect(screen.queryByTestId('input-field-editor-panel')).not.toBeInTheDocument()
expect(screen.queryByTestId('input-field-panel')).not.toBeInTheDocument()
})
})
it('should render only Preview and InputField panels', async () => {
// Arrange
setupMocks({
showInputFieldPreviewPanel: true,
inputFieldEditPanelProps: null,
showInputFieldPanel: true,
})
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(screen.getByTestId('preview-panel')).toBeInTheDocument()
expect(screen.queryByTestId('input-field-editor-panel')).not.toBeInTheDocument()
expect(screen.getByTestId('input-field-panel')).toBeInTheDocument()
})
})
})
})
// ============================================================================
// Edge Cases Tests
// ============================================================================
describe('Edge Cases', () => {
beforeEach(() => {
vi.clearAllMocks()
setupMocks()
})
// -------------------------------------------------------------------------
// Empty/Undefined Values
// -------------------------------------------------------------------------
describe('Empty/Undefined Values', () => {
it('should handle empty pipelineId gracefully', async () => {
// Arrange
setupMocks({ pipelineId: '' })
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(capturedPanelProps?.versionHistoryPanelProps?.getVersionListUrl).toBe(
'/rag/pipelines//workflows',
)
})
})
it('should handle special characters in pipelineId', async () => {
// Arrange
setupMocks({ pipelineId: 'pipeline-with-special_chars.123' })
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(capturedPanelProps?.versionHistoryPanelProps?.getVersionListUrl).toBe(
'/rag/pipelines/pipeline-with-special_chars.123/workflows',
)
})
})
})
// -------------------------------------------------------------------------
// Props Spreading Tests
// -------------------------------------------------------------------------
describe('Props Spreading', () => {
it('should correctly spread inputFieldEditPanelProps to editor component', async () => {
// Arrange
const customProps = {
onClose: vi.fn(),
onSubmit: vi.fn(),
initialData: {
variable: 'custom_var',
label: 'Custom Label',
type: 'text',
},
extraProp: 'extra-value',
}
setupMocks({ inputFieldEditPanelProps: customProps })
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(mockInputFieldEditorProps).toHaveBeenCalledWith(
expect.objectContaining({
extraProp: 'extra-value',
}),
)
})
})
})
// -------------------------------------------------------------------------
// State Combinations
// -------------------------------------------------------------------------
describe('State Combinations', () => {
it('should handle all panels visible simultaneously', async () => {
// Arrange
setupMocks({
historyWorkflowData: { id: 'h1' },
showDebugAndPreviewPanel: true,
showGlobalVariablePanel: true,
showInputFieldPreviewPanel: true,
inputFieldEditPanelProps: { onClose: vi.fn(), onSubmit: vi.fn() },
showInputFieldPanel: true,
})
// Act
render(<RagPipelinePanel />)
// Assert - All panels should be visible
await waitFor(() => {
expect(screen.getByTestId('record-panel')).toBeInTheDocument()
expect(screen.getByTestId('test-run-panel')).toBeInTheDocument()
expect(screen.getByTestId('global-variable-panel')).toBeInTheDocument()
expect(screen.getByTestId('preview-panel')).toBeInTheDocument()
expect(screen.getByTestId('input-field-editor-panel')).toBeInTheDocument()
expect(screen.getByTestId('input-field-panel')).toBeInTheDocument()
})
})
})
})
// ============================================================================
// URL Generator Functions Tests
// ============================================================================
describe('URL Generator Functions', () => {
beforeEach(() => {
vi.clearAllMocks()
setupMocks()
})
it('should return consistent URLs for same versionId', async () => {
// Arrange
setupMocks({ pipelineId: 'stable-pipeline' })
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
const deleteUrl1 = capturedPanelProps?.versionHistoryPanelProps?.deleteVersionUrl?.('version-x')
const deleteUrl2 = capturedPanelProps?.versionHistoryPanelProps?.deleteVersionUrl?.('version-x')
expect(deleteUrl1).toBe(deleteUrl2)
})
})
it('should return different URLs for different versionIds', async () => {
// Arrange
setupMocks({ pipelineId: 'stable-pipeline' })
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
const deleteUrl1 = capturedPanelProps?.versionHistoryPanelProps?.deleteVersionUrl?.('version-1')
const deleteUrl2 = capturedPanelProps?.versionHistoryPanelProps?.deleteVersionUrl?.('version-2')
expect(deleteUrl1).not.toBe(deleteUrl2)
expect(deleteUrl1).toBe('/rag/pipelines/stable-pipeline/workflows/version-1')
expect(deleteUrl2).toBe('/rag/pipelines/stable-pipeline/workflows/version-2')
})
})
})
// ============================================================================
// Type Safety Tests
// ============================================================================
describe('Type Safety', () => {
beforeEach(() => {
vi.clearAllMocks()
setupMocks()
})
it('should pass correct PanelProps structure', async () => {
// Act
render(<RagPipelinePanel />)
// Assert - Check structure matches PanelProps
await waitFor(() => {
expect(capturedPanelProps).toHaveProperty('components')
expect(capturedPanelProps).toHaveProperty('versionHistoryPanelProps')
expect(capturedPanelProps?.components).toHaveProperty('left')
expect(capturedPanelProps?.components).toHaveProperty('right')
})
})
it('should pass correct versionHistoryPanelProps structure', async () => {
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(capturedPanelProps?.versionHistoryPanelProps).toHaveProperty('getVersionListUrl')
expect(capturedPanelProps?.versionHistoryPanelProps).toHaveProperty('deleteVersionUrl')
expect(capturedPanelProps?.versionHistoryPanelProps).toHaveProperty('updateVersionUrl')
expect(capturedPanelProps?.versionHistoryPanelProps).toHaveProperty('latestVersionId')
})
})
})
// ============================================================================
// Performance Tests
// ============================================================================
describe('Performance', () => {
beforeEach(() => {
vi.clearAllMocks()
setupMocks()
})
it('should handle multiple rerenders without issues', async () => {
// Arrange
const { rerender } = render(<RagPipelinePanel />)
// Act - Multiple rerenders
for (let i = 0; i < 10; i++)
rerender(<RagPipelinePanel />)
// Assert - Component should still work
await waitFor(() => {
expect(screen.getByTestId('workflow-panel')).toBeInTheDocument()
})
})
})
// ============================================================================
// Integration Tests
// ============================================================================
describe('Integration Tests', () => {
beforeEach(() => {
vi.clearAllMocks()
setupMocks()
})
it('should pass correct components to Panel', async () => {
// Arrange
setupMocks({
historyWorkflowData: { id: 'h1' },
showInputFieldPanel: true,
})
// Act
render(<RagPipelinePanel />)
// Assert
await waitFor(() => {
expect(capturedPanelProps?.components?.left).toBeDefined()
expect(capturedPanelProps?.components?.right).toBeDefined()
// Check that the components are React elements
expect(React.isValidElement(capturedPanelProps?.components?.left)).toBe(true)
expect(React.isValidElement(capturedPanelProps?.components?.right)).toBe(true)
})
})
it('should correctly consume all store selectors', async () => {
// Arrange
setupMocks({
historyWorkflowData: { id: 'test-history' },
showDebugAndPreviewPanel: true,
showGlobalVariablePanel: true,
showInputFieldPanel: true,
showInputFieldPreviewPanel: true,
inputFieldEditPanelProps: { onClose: vi.fn(), onSubmit: vi.fn() },
pipelineId: 'integration-test-pipeline',
})
// Act
render(<RagPipelinePanel />)
// Assert - All store-dependent rendering should work
await waitFor(() => {
expect(screen.getByTestId('record-panel')).toBeInTheDocument()
expect(screen.getByTestId('test-run-panel')).toBeInTheDocument()
expect(screen.getByTestId('global-variable-panel')).toBeInTheDocument()
expect(screen.getByTestId('input-field-panel')).toBeInTheDocument()
expect(screen.getByTestId('preview-panel')).toBeInTheDocument()
expect(screen.getByTestId('input-field-editor-panel')).toBeInTheDocument()
expect(capturedPanelProps?.versionHistoryPanelProps?.getVersionListUrl).toBe(
'/rag/pipelines/integration-test-pipeline/workflows',
)
})
})
})

View File

@ -49,6 +49,7 @@ const InputFieldEditorPanel = ({
</div>
<button
type="button"
data-testid="input-field-editor-close-btn"
className="absolute right-2.5 top-2.5 flex size-8 items-center justify-center"
onClick={onClose}
>

View File

@ -53,6 +53,7 @@ const FieldList = ({
{LabelRightContent}
</div>
<ActionButton
data-testid="field-list-add-btn"
onClick={() => handleOpenInputFieldEditor()}
disabled={readonly}
className={cn(readonly && 'cursor-not-allowed')}

File diff suppressed because it is too large

View File

@ -0,0 +1,937 @@
import type { WorkflowRunningData } from '@/app/components/workflow/types'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { WorkflowRunningStatus } from '@/app/components/workflow/types'
import { ChunkingMode } from '@/models/datasets'
import Header from './header'
// Import components after mocks
import TestRunPanel from './index'
// ============================================================================
// Mocks
// ============================================================================
// Mock workflow store
const mockIsPreparingDataSource = vi.fn(() => true)
const mockSetIsPreparingDataSource = vi.fn()
const mockWorkflowRunningData = vi.fn<() => WorkflowRunningData | undefined>(() => undefined)
const mockPipelineId = 'test-pipeline-id'
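// The useStore mock mirrors zustand's selector API: it builds a plain state object from the mock fns above and hands it to the selector, so each test can steer store state via the mocks' return values.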
vi.mock('@/app/components/workflow/store', () => ({
useStore: (selector: (state: Record<string, unknown>) => unknown) => {
const state = {
isPreparingDataSource: mockIsPreparingDataSource(),
workflowRunningData: mockWorkflowRunningData(),
pipelineId: mockPipelineId,
}
return selector(state)
},
useWorkflowStore: () => ({
getState: () => ({
isPreparingDataSource: mockIsPreparingDataSource(),
setIsPreparingDataSource: mockSetIsPreparingDataSource,
}),
}),
}))
// Mock workflow interactions
const mockHandleCancelDebugAndPreviewPanel = vi.fn()
vi.mock('@/app/components/workflow/hooks', () => ({
useWorkflowInteractions: () => ({
handleCancelDebugAndPreviewPanel: mockHandleCancelDebugAndPreviewPanel,
}),
useWorkflowRun: () => ({
handleRun: vi.fn(),
}),
useToolIcon: () => 'mock-tool-icon',
}))
// Mock data source provider
vi.mock('@/app/components/datasets/documents/create-from-pipeline/data-source/store/provider', () => ({
default: ({ children }: { children: React.ReactNode }) => <div data-testid="data-source-provider">{children}</div>,
}))
// Mock Preparation component
vi.mock('./preparation', () => ({
default: () => <div data-testid="preparation-component">Preparation</div>,
}))
// Mock Result component (for TestRunPanel tests only)
vi.mock('./result', () => ({
default: () => <div data-testid="result-component">Result</div>,
}))
// Mock ResultPanel from workflow
vi.mock('@/app/components/workflow/run/result-panel', () => ({
default: (props: Record<string, unknown>) => (
<div data-testid="result-panel">
ResultPanel -
{' '}
{props.status as string}
</div>
),
}))
// Mock TracingPanel from workflow
vi.mock('@/app/components/workflow/run/tracing-panel', () => ({
default: (props: { list: unknown[] }) => (
<div data-testid="tracing-panel">
TracingPanel -
{' '}
{props.list?.length ?? 0}
{' '}
items
</div>
),
}))
// Mock Loading component
vi.mock('@/app/components/base/loading', () => ({
default: () => <div data-testid="loading">Loading...</div>,
}))
// Mock config
vi.mock('@/config', () => ({
RAG_PIPELINE_PREVIEW_CHUNK_NUM: 5,
}))
// ============================================================================
// Test Data Factories
// ============================================================================
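// The factories below build minimal WorkflowRunningData and chunk-structure outputs so individual tests only override the fields they assert on.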
const createMockWorkflowRunningData = (overrides: Partial<WorkflowRunningData> = {}): WorkflowRunningData => ({
result: {
status: WorkflowRunningStatus.Succeeded,
outputs: '{"test": "output"}',
outputs_truncated: false,
inputs: '{"test": "input"}',
inputs_truncated: false,
process_data_truncated: false,
error: undefined,
elapsed_time: 1000,
total_tokens: 100,
created_at: Date.now(),
created_by: 'Test User',
total_steps: 5,
exceptions_count: 0,
},
tracing: [],
...overrides,
})
const createMockGeneralOutputs = (chunkContents: string[] = ['chunk1', 'chunk2']) => ({
chunk_structure: ChunkingMode.text,
preview: chunkContents.map(content => ({ content })),
})
const createMockParentChildOutputs = (parentMode: 'paragraph' | 'full-doc' = 'paragraph') => ({
chunk_structure: ChunkingMode.parentChild,
parent_mode: parentMode,
preview: [
{ content: 'parent1', child_chunks: ['child1', 'child2'] },
{ content: 'parent2', child_chunks: ['child3', 'child4'] },
],
})
const createMockQAOutputs = () => ({
chunk_structure: ChunkingMode.qa,
qa_preview: [
{ question: 'Q1', answer: 'A1' },
{ question: 'Q2', answer: 'A2' },
],
})
// ============================================================================
// TestRunPanel Component Tests
// ============================================================================
describe('TestRunPanel', () => {
beforeEach(() => {
vi.clearAllMocks()
mockIsPreparingDataSource.mockReturnValue(true)
mockWorkflowRunningData.mockReturnValue(undefined)
})
// Basic rendering tests
describe('Rendering', () => {
it('should render with correct container styles', () => {
const { container } = render(<TestRunPanel />)
const panelDiv = container.firstChild as HTMLElement
expect(panelDiv).toHaveClass('relative', 'flex', 'h-full', 'w-[480px]', 'flex-col')
})
it('should render Header component', () => {
render(<TestRunPanel />)
expect(screen.getByText('datasetPipeline.testRun.title')).toBeInTheDocument()
})
})
// Conditional rendering based on isPreparingDataSource
describe('Conditional Content Rendering', () => {
it('should render Preparation inside DataSourceProvider when isPreparingDataSource is true', () => {
mockIsPreparingDataSource.mockReturnValue(true)
render(<TestRunPanel />)
expect(screen.getByTestId('data-source-provider')).toBeInTheDocument()
expect(screen.getByTestId('preparation-component')).toBeInTheDocument()
expect(screen.queryByTestId('result-component')).not.toBeInTheDocument()
})
it('should render Result when isPreparingDataSource is false', () => {
mockIsPreparingDataSource.mockReturnValue(false)
render(<TestRunPanel />)
expect(screen.getByTestId('result-component')).toBeInTheDocument()
expect(screen.queryByTestId('data-source-provider')).not.toBeInTheDocument()
expect(screen.queryByTestId('preparation-component')).not.toBeInTheDocument()
})
})
})
// ============================================================================
// Header Component Tests
// ============================================================================
describe('Header', () => {
beforeEach(() => {
vi.clearAllMocks()
mockIsPreparingDataSource.mockReturnValue(true)
})
// Rendering tests
describe('Rendering', () => {
it('should render title with correct translation key', () => {
render(<Header />)
expect(screen.getByText('datasetPipeline.testRun.title')).toBeInTheDocument()
})
it('should render close button', () => {
render(<Header />)
const closeButton = screen.getByRole('button')
expect(closeButton).toBeInTheDocument()
})
it('should have correct layout classes', () => {
const { container } = render(<Header />)
const headerDiv = container.firstChild as HTMLElement
expect(headerDiv).toHaveClass('flex', 'items-center', 'gap-x-2', 'pl-4', 'pr-3', 'pt-4')
})
})
// Close button interactions
describe('Close Button Interaction', () => {
it('should call setIsPreparingDataSource(false) and handleCancelDebugAndPreviewPanel when clicked and isPreparingDataSource is true', () => {
mockIsPreparingDataSource.mockReturnValue(true)
render(<Header />)
const closeButton = screen.getByRole('button')
fireEvent.click(closeButton)
expect(mockSetIsPreparingDataSource).toHaveBeenCalledWith(false)
expect(mockHandleCancelDebugAndPreviewPanel).toHaveBeenCalledTimes(1)
})
it('should only call handleCancelDebugAndPreviewPanel when isPreparingDataSource is false', () => {
mockIsPreparingDataSource.mockReturnValue(false)
render(<Header />)
const closeButton = screen.getByRole('button')
fireEvent.click(closeButton)
expect(mockSetIsPreparingDataSource).not.toHaveBeenCalled()
expect(mockHandleCancelDebugAndPreviewPanel).toHaveBeenCalledTimes(1)
})
})
})
// ============================================================================
// Result Component Tests (Real Implementation)
// ============================================================================
// Unmock Result for these tests
vi.doUnmock('./result')
describe('Result', () => {
// Dynamically import Result to get real implementation
let Result: typeof import('./result').default
beforeAll(async () => {
const resultModule = await import('./result')
Result = resultModule.default
})
beforeEach(() => {
vi.clearAllMocks()
mockWorkflowRunningData.mockReturnValue(undefined)
})
// Rendering tests
describe('Rendering', () => {
it('should render with RESULT tab active by default', async () => {
render(<Result />)
await waitFor(() => {
const resultTab = screen.getByRole('button', { name: /runLog\.result/i })
expect(resultTab).toHaveClass('border-util-colors-blue-brand-blue-brand-600')
})
})
it('should render all three tabs', () => {
render(<Result />)
expect(screen.getByRole('button', { name: /runLog\.result/i })).toBeInTheDocument()
expect(screen.getByRole('button', { name: /runLog\.detail/i })).toBeInTheDocument()
expect(screen.getByRole('button', { name: /runLog\.tracing/i })).toBeInTheDocument()
})
})
// Tab switching tests
describe('Tab Switching', () => {
it('should switch to DETAIL tab when clicked', async () => {
mockWorkflowRunningData.mockReturnValue(createMockWorkflowRunningData())
render(<Result />)
const detailTab = screen.getByRole('button', { name: /runLog\.detail/i })
fireEvent.click(detailTab)
await waitFor(() => {
expect(screen.getByTestId('result-panel')).toBeInTheDocument()
})
})
it('should switch to TRACING tab when clicked', async () => {
mockWorkflowRunningData.mockReturnValue(createMockWorkflowRunningData({ tracing: [{ id: '1' }] as unknown as WorkflowRunningData['tracing'] }))
render(<Result />)
const tracingTab = screen.getByRole('button', { name: /runLog\.tracing/i })
fireEvent.click(tracingTab)
await waitFor(() => {
expect(screen.getByTestId('tracing-panel')).toBeInTheDocument()
})
})
})
// Loading states
describe('Loading States', () => {
it('should show loading in DETAIL tab when no result data', async () => {
mockWorkflowRunningData.mockReturnValue({
result: undefined as unknown as WorkflowRunningData['result'],
tracing: [],
})
render(<Result />)
const detailTab = screen.getByRole('button', { name: /runLog\.detail/i })
fireEvent.click(detailTab)
await waitFor(() => {
expect(screen.getByTestId('loading')).toBeInTheDocument()
})
})
it('should show loading in TRACING tab when no tracing data', async () => {
mockWorkflowRunningData.mockReturnValue(createMockWorkflowRunningData({ tracing: [] }))
render(<Result />)
const tracingTab = screen.getByRole('button', { name: /runLog\.tracing/i })
fireEvent.click(tracingTab)
await waitFor(() => {
expect(screen.getByTestId('loading')).toBeInTheDocument()
})
})
})
})
// ============================================================================
// ResultPreview Component Tests
// ============================================================================
// Import ResultPreview dynamically so these tests run against the real implementation
vi.doUnmock('./result/result-preview')
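// doUnmock only affects modules loaded after this point, so the real component is pulled in via the dynamic import in beforeAll below.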
describe('ResultPreview', () => {
let ResultPreview: typeof import('./result/result-preview').default
beforeAll(async () => {
const previewModule = await import('./result/result-preview')
ResultPreview = previewModule.default
})
const mockOnSwitchToDetail = vi.fn()
beforeEach(() => {
vi.clearAllMocks()
})
// Loading state
describe('Loading State', () => {
it('should show loading spinner when isRunning is true and no outputs', () => {
render(
<ResultPreview
isRunning={true}
outputs={undefined}
error={undefined}
onSwitchToDetail={mockOnSwitchToDetail}
/>,
)
expect(screen.getByText('pipeline.result.resultPreview.loading')).toBeInTheDocument()
})
it('should not show loading when outputs are available', () => {
render(
<ResultPreview
isRunning={true}
outputs={createMockGeneralOutputs()}
error={undefined}
onSwitchToDetail={mockOnSwitchToDetail}
/>,
)
expect(screen.queryByText('pipeline.result.resultPreview.loading')).not.toBeInTheDocument()
})
})
// Error state
describe('Error State', () => {
it('should show error message when not running and has error', () => {
render(
<ResultPreview
isRunning={false}
outputs={undefined}
error="Test error message"
onSwitchToDetail={mockOnSwitchToDetail}
/>,
)
expect(screen.getByText('pipeline.result.resultPreview.error')).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'pipeline.result.resultPreview.viewDetails' })).toBeInTheDocument()
})
it('should call onSwitchToDetail when View Details button is clicked', () => {
render(
<ResultPreview
isRunning={false}
outputs={undefined}
error="Test error message"
onSwitchToDetail={mockOnSwitchToDetail}
/>,
)
const viewDetailsButton = screen.getByRole('button', { name: 'pipeline.result.resultPreview.viewDetails' })
fireEvent.click(viewDetailsButton)
expect(mockOnSwitchToDetail).toHaveBeenCalledTimes(1)
})
it('should not show error when still running', () => {
render(
<ResultPreview
isRunning={true}
outputs={undefined}
error="Test error message"
onSwitchToDetail={mockOnSwitchToDetail}
/>,
)
expect(screen.queryByText('pipeline.result.resultPreview.error')).not.toBeInTheDocument()
})
})
// Success state with outputs
describe('Success State with Outputs', () => {
it('should render chunk content when outputs are available', () => {
render(
<ResultPreview
isRunning={false}
outputs={createMockGeneralOutputs(['test chunk content'])}
error={undefined}
onSwitchToDetail={mockOnSwitchToDetail}
/>,
)
// Check that chunk content is rendered (the real ChunkCardList renders the content)
expect(screen.getByText('test chunk content')).toBeInTheDocument()
})
it('should render multiple chunks when provided', () => {
render(
<ResultPreview
isRunning={false}
outputs={createMockGeneralOutputs(['chunk one', 'chunk two'])}
error={undefined}
onSwitchToDetail={mockOnSwitchToDetail}
/>,
)
expect(screen.getByText('chunk one')).toBeInTheDocument()
expect(screen.getByText('chunk two')).toBeInTheDocument()
})
it('should show footer tip', () => {
render(
<ResultPreview
isRunning={false}
outputs={createMockGeneralOutputs()}
error={undefined}
onSwitchToDetail={mockOnSwitchToDetail}
/>,
)
expect(screen.getByText(/pipeline\.result\.resultPreview\.footerTip/)).toBeInTheDocument()
})
})
// Edge cases
describe('Edge Cases', () => {
it('should handle empty outputs gracefully', () => {
render(
<ResultPreview
isRunning={false}
outputs={null}
error={undefined}
onSwitchToDetail={mockOnSwitchToDetail}
/>,
)
// Should not crash and should not show chunk card list
expect(screen.queryByTestId('chunk-card-list')).not.toBeInTheDocument()
})
it('should handle undefined outputs', () => {
render(
<ResultPreview
isRunning={false}
outputs={undefined}
error={undefined}
onSwitchToDetail={mockOnSwitchToDetail}
/>,
)
expect(screen.queryByTestId('chunk-card-list')).not.toBeInTheDocument()
})
})
})
// ============================================================================
// Tabs Component Tests
// ============================================================================
vi.doUnmock('./result/tabs')
describe('Tabs', () => {
let Tabs: typeof import('./result/tabs').default
beforeAll(async () => {
const tabsModule = await import('./result/tabs')
Tabs = tabsModule.default
})
const mockSwitchTab = vi.fn()
beforeEach(() => {
vi.clearAllMocks()
})
// Rendering tests
describe('Rendering', () => {
it('should render all three tabs', () => {
render(
<Tabs
currentTab="RESULT"
workflowRunningData={createMockWorkflowRunningData()}
switchTab={mockSwitchTab}
/>,
)
expect(screen.getByRole('button', { name: /runLog\.result/i })).toBeInTheDocument()
expect(screen.getByRole('button', { name: /runLog\.detail/i })).toBeInTheDocument()
expect(screen.getByRole('button', { name: /runLog\.tracing/i })).toBeInTheDocument()
})
})
// Active tab styling
describe('Active Tab Styling', () => {
it('should highlight RESULT tab when currentTab is RESULT', () => {
render(
<Tabs
currentTab="RESULT"
workflowRunningData={createMockWorkflowRunningData()}
switchTab={mockSwitchTab}
/>,
)
const resultTab = screen.getByRole('button', { name: /runLog\.result/i })
expect(resultTab).toHaveClass('border-util-colors-blue-brand-blue-brand-600')
})
it('should highlight DETAIL tab when currentTab is DETAIL', () => {
render(
<Tabs
currentTab="DETAIL"
workflowRunningData={createMockWorkflowRunningData()}
switchTab={mockSwitchTab}
/>,
)
const detailTab = screen.getByRole('button', { name: /runLog\.detail/i })
expect(detailTab).toHaveClass('border-util-colors-blue-brand-blue-brand-600')
})
})
// Tab click handling
describe('Tab Click Handling', () => {
it('should call switchTab with RESULT when RESULT tab is clicked', () => {
render(
<Tabs
currentTab="DETAIL"
workflowRunningData={createMockWorkflowRunningData()}
switchTab={mockSwitchTab}
/>,
)
fireEvent.click(screen.getByRole('button', { name: /runLog\.result/i }))
expect(mockSwitchTab).toHaveBeenCalledWith('RESULT')
})
it('should call switchTab with DETAIL when DETAIL tab is clicked', () => {
render(
<Tabs
currentTab="RESULT"
workflowRunningData={createMockWorkflowRunningData()}
switchTab={mockSwitchTab}
/>,
)
fireEvent.click(screen.getByRole('button', { name: /runLog\.detail/i }))
expect(mockSwitchTab).toHaveBeenCalledWith('DETAIL')
})
it('should call switchTab with TRACING when TRACING tab is clicked', () => {
render(
<Tabs
currentTab="RESULT"
workflowRunningData={createMockWorkflowRunningData()}
switchTab={mockSwitchTab}
/>,
)
fireEvent.click(screen.getByRole('button', { name: /runLog\.tracing/i }))
expect(mockSwitchTab).toHaveBeenCalledWith('TRACING')
})
})
// Disabled state when no data
describe('Disabled State', () => {
it('should disable tabs when workflowRunningData is undefined', () => {
render(
<Tabs
currentTab="RESULT"
workflowRunningData={undefined}
switchTab={mockSwitchTab}
/>,
)
const resultTab = screen.getByRole('button', { name: /runLog\.result/i })
expect(resultTab).toBeDisabled()
})
})
})
// ============================================================================
// Tab Component Tests
// ============================================================================
vi.doUnmock('./result/tabs/tab')
describe('Tab', () => {
let Tab: typeof import('./result/tabs/tab').default
beforeAll(async () => {
const tabModule = await import('./result/tabs/tab')
Tab = tabModule.default
})
const mockOnClick = vi.fn()
beforeEach(() => {
vi.clearAllMocks()
})
// Rendering tests
describe('Rendering', () => {
it('should render tab with label', () => {
render(
<Tab
isActive={false}
label="Test Tab"
value="TEST"
workflowRunningData={createMockWorkflowRunningData()}
onClick={mockOnClick}
/>,
)
expect(screen.getByRole('button', { name: 'Test Tab' })).toBeInTheDocument()
})
})
// Active state styling
describe('Active State', () => {
it('should have active styles when isActive is true', () => {
render(
<Tab
isActive={true}
label="Active Tab"
value="TEST"
workflowRunningData={createMockWorkflowRunningData()}
onClick={mockOnClick}
/>,
)
const tab = screen.getByRole('button')
expect(tab).toHaveClass('border-util-colors-blue-brand-blue-brand-600', 'text-text-primary')
})
it('should have inactive styles when isActive is false', () => {
render(
<Tab
isActive={false}
label="Inactive Tab"
value="TEST"
workflowRunningData={createMockWorkflowRunningData()}
onClick={mockOnClick}
/>,
)
const tab = screen.getByRole('button')
expect(tab).toHaveClass('border-transparent', 'text-text-tertiary')
})
})
// Click handling
describe('Click Handling', () => {
it('should call onClick with value when clicked', () => {
render(
<Tab
isActive={false}
label="Test Tab"
value="MY_VALUE"
workflowRunningData={createMockWorkflowRunningData()}
onClick={mockOnClick}
/>,
)
fireEvent.click(screen.getByRole('button'))
expect(mockOnClick).toHaveBeenCalledWith('MY_VALUE')
})
it('should not call onClick when disabled (no workflowRunningData)', () => {
render(
<Tab
isActive={false}
label="Test Tab"
value="MY_VALUE"
workflowRunningData={undefined}
onClick={mockOnClick}
/>,
)
const tab = screen.getByRole('button')
fireEvent.click(tab)
// React suppresses onClick on disabled buttons, so the handler must not fire
expect(mockOnClick).not.toHaveBeenCalled()
expect(tab).toBeDisabled()
})
})
// Disabled state
describe('Disabled State', () => {
it('should be disabled when workflowRunningData is undefined', () => {
render(
<Tab
isActive={false}
label="Test Tab"
value="TEST"
workflowRunningData={undefined}
onClick={mockOnClick}
/>,
)
const tab = screen.getByRole('button')
expect(tab).toBeDisabled()
expect(tab).toHaveClass('opacity-30')
})
it('should not be disabled when workflowRunningData is provided', () => {
render(
<Tab
isActive={false}
label="Test Tab"
value="TEST"
workflowRunningData={createMockWorkflowRunningData()}
onClick={mockOnClick}
/>,
)
const tab = screen.getByRole('button')
expect(tab).not.toBeDisabled()
})
})
})
// ============================================================================
// formatPreviewChunks Utility Tests
// ============================================================================
describe('formatPreviewChunks', () => {
let formatPreviewChunks: typeof import('./result/result-preview/utils').formatPreviewChunks
beforeAll(async () => {
const utilsModule = await import('./result/result-preview/utils')
formatPreviewChunks = utilsModule.formatPreviewChunks
})
// Edge cases
describe('Edge Cases', () => {
it('should return undefined for null outputs', () => {
expect(formatPreviewChunks(null)).toBeUndefined()
})
it('should return undefined for undefined outputs', () => {
expect(formatPreviewChunks(undefined)).toBeUndefined()
})
it('should return undefined for unknown chunk structure', () => {
const outputs = {
chunk_structure: 'unknown_mode',
preview: [],
}
expect(formatPreviewChunks(outputs)).toBeUndefined()
})
})
// General (text) chunks
describe('General Chunks (ChunkingMode.text)', () => {
it('should format general chunks correctly', () => {
const outputs = createMockGeneralOutputs(['content1', 'content2', 'content3'])
const result = formatPreviewChunks(outputs)
expect(result).toEqual(['content1', 'content2', 'content3'])
})
it('should limit to RAG_PIPELINE_PREVIEW_CHUNK_NUM chunks', () => {
const manyChunks = Array.from({ length: 10 }, (_, i) => `chunk${i}`)
const outputs = createMockGeneralOutputs(manyChunks)
const result = formatPreviewChunks(outputs) as string[]
// RAG_PIPELINE_PREVIEW_CHUNK_NUM is mocked to 5
expect(result).toHaveLength(5)
expect(result).toEqual(['chunk0', 'chunk1', 'chunk2', 'chunk3', 'chunk4'])
})
it('should handle empty preview array', () => {
const outputs = createMockGeneralOutputs([])
const result = formatPreviewChunks(outputs)
expect(result).toEqual([])
})
})
// Parent-child chunks
describe('Parent-Child Chunks (ChunkingMode.parentChild)', () => {
it('should format paragraph mode parent-child chunks correctly', () => {
const outputs = createMockParentChildOutputs('paragraph')
const result = formatPreviewChunks(outputs)
expect(result).toEqual({
parent_child_chunks: [
{ parent_content: 'parent1', child_contents: ['child1', 'child2'], parent_mode: 'paragraph' },
{ parent_content: 'parent2', child_contents: ['child3', 'child4'], parent_mode: 'paragraph' },
],
parent_mode: 'paragraph',
})
})
it('should format full-doc mode parent-child chunks and limit child chunks', () => {
const outputs = {
chunk_structure: ChunkingMode.parentChild,
parent_mode: 'full-doc' as const,
preview: [
{
content: 'full-doc-parent',
child_chunks: Array.from({ length: 10 }, (_, i) => `child${i}`),
},
],
}
const result = formatPreviewChunks(outputs)
expect(result).toEqual({
parent_child_chunks: [
{
parent_content: 'full-doc-parent',
child_contents: ['child0', 'child1', 'child2', 'child3', 'child4'], // Limited to 5
parent_mode: 'full-doc',
},
],
parent_mode: 'full-doc',
})
})
})
// QA chunks
describe('QA Chunks (ChunkingMode.qa)', () => {
it('should format QA chunks correctly', () => {
const outputs = createMockQAOutputs()
const result = formatPreviewChunks(outputs)
expect(result).toEqual({
qa_chunks: [
{ question: 'Q1', answer: 'A1' },
{ question: 'Q2', answer: 'A2' },
],
})
})
it('should limit QA chunks to RAG_PIPELINE_PREVIEW_CHUNK_NUM', () => {
const outputs = {
chunk_structure: ChunkingMode.qa,
qa_preview: Array.from({ length: 10 }, (_, i) => ({
question: `Q${i}`,
answer: `A${i}`,
})),
}
const result = formatPreviewChunks(outputs) as { qa_chunks: Array<{ question: string, answer: string }> }
expect(result.qa_chunks).toHaveLength(5)
})
})
})
// ============================================================================
// Types Tests
// ============================================================================
describe('Types', () => {
describe('TestRunStep Enum', () => {
it('should have correct enum values', async () => {
const { TestRunStep } = await import('./types')
expect(TestRunStep.dataSource).toBe('dataSource')
expect(TestRunStep.documentProcessing).toBe('documentProcessing')
})
})
})

View File

@ -0,0 +1,549 @@
import { fireEvent, render, screen } from '@testing-library/react'
import Actions from './index'
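// Actions renders a single next-step button; the suites below cover rendering, prop handling, interaction, callback stability, memoization, edge cases, accessibility, and integration flows.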
// ============================================================================
// Actions Component Tests
// ============================================================================
describe('Actions', () => {
beforeEach(() => {
vi.clearAllMocks()
})
// -------------------------------------------------------------------------
// Rendering Tests
// -------------------------------------------------------------------------
describe('Rendering', () => {
it('should render without crashing', () => {
// Arrange
const handleNextStep = vi.fn()
// Act
render(<Actions handleNextStep={handleNextStep} />)
// Assert
expect(screen.getByRole('button')).toBeInTheDocument()
})
it('should render button with translated text', () => {
// Arrange
const handleNextStep = vi.fn()
// Act
render(<Actions handleNextStep={handleNextStep} />)
// Assert - Translation mock returns key with namespace prefix
expect(screen.getByText('datasetCreation.stepOne.button')).toBeInTheDocument()
})
it('should render with correct container structure', () => {
// Arrange
const handleNextStep = vi.fn()
// Act
const { container } = render(<Actions handleNextStep={handleNextStep} />)
// Assert
const wrapper = container.firstChild as HTMLElement
expect(wrapper.className).toContain('flex')
expect(wrapper.className).toContain('justify-end')
expect(wrapper.className).toContain('p-4')
expect(wrapper.className).toContain('pt-2')
})
it('should render span with px-0.5 class around text', () => {
// Arrange
const handleNextStep = vi.fn()
// Act
const { container } = render(<Actions handleNextStep={handleNextStep} />)
// Assert
const span = container.querySelector('span')
expect(span).toBeInTheDocument()
expect(span?.className).toContain('px-0.5')
})
})
// -------------------------------------------------------------------------
// Props Variations Tests
// -------------------------------------------------------------------------
describe('Props Variations', () => {
it('should pass disabled=true to button when disabled prop is true', () => {
// Arrange
const handleNextStep = vi.fn()
// Act
render(<Actions disabled={true} handleNextStep={handleNextStep} />)
// Assert
expect(screen.getByRole('button')).toBeDisabled()
})
it('should pass disabled=false to button when disabled prop is false', () => {
// Arrange
const handleNextStep = vi.fn()
// Act
render(<Actions disabled={false} handleNextStep={handleNextStep} />)
// Assert
expect(screen.getByRole('button')).not.toBeDisabled()
})
it('should not disable button when disabled prop is undefined', () => {
// Arrange
const handleNextStep = vi.fn()
// Act
render(<Actions handleNextStep={handleNextStep} />)
// Assert
expect(screen.getByRole('button')).not.toBeDisabled()
})
it('should handle disabled switching from true to false', () => {
// Arrange
const handleNextStep = vi.fn()
// Act
const { rerender } = render(
<Actions disabled={true} handleNextStep={handleNextStep} />,
)
// Assert - Initially disabled
expect(screen.getByRole('button')).toBeDisabled()
// Act - Rerender with disabled=false
rerender(<Actions disabled={false} handleNextStep={handleNextStep} />)
// Assert - Now enabled
expect(screen.getByRole('button')).not.toBeDisabled()
})
it('should handle disabled switching from false to true', () => {
// Arrange
const handleNextStep = vi.fn()
// Act
const { rerender } = render(
<Actions disabled={false} handleNextStep={handleNextStep} />,
)
// Assert - Initially enabled
expect(screen.getByRole('button')).not.toBeDisabled()
// Act - Rerender with disabled=true
rerender(<Actions disabled={true} handleNextStep={handleNextStep} />)
// Assert - Now disabled
expect(screen.getByRole('button')).toBeDisabled()
})
it('should handle undefined disabled becoming true', () => {
// Arrange
const handleNextStep = vi.fn()
// Act
const { rerender } = render(
<Actions handleNextStep={handleNextStep} />,
)
// Assert - Initially not disabled (undefined)
expect(screen.getByRole('button')).not.toBeDisabled()
// Act - Rerender with disabled=true
rerender(<Actions disabled={true} handleNextStep={handleNextStep} />)
// Assert - Now disabled
expect(screen.getByRole('button')).toBeDisabled()
})
})
// -------------------------------------------------------------------------
// User Interaction Tests
// -------------------------------------------------------------------------
describe('User Interactions', () => {
it('should call handleNextStep when button is clicked', () => {
// Arrange
const handleNextStep = vi.fn()
// Act
render(<Actions handleNextStep={handleNextStep} />)
fireEvent.click(screen.getByRole('button'))
// Assert
expect(handleNextStep).toHaveBeenCalledTimes(1)
})
it('should call handleNextStep exactly once per click', () => {
// Arrange
const handleNextStep = vi.fn()
// Act
render(<Actions handleNextStep={handleNextStep} />)
fireEvent.click(screen.getByRole('button'))
// Assert
expect(handleNextStep).toHaveBeenCalled()
expect(handleNextStep.mock.calls).toHaveLength(1)
})
it('should call handleNextStep multiple times on multiple clicks', () => {
// Arrange
const handleNextStep = vi.fn()
// Act
render(<Actions handleNextStep={handleNextStep} />)
const button = screen.getByRole('button')
fireEvent.click(button)
fireEvent.click(button)
fireEvent.click(button)
// Assert
expect(handleNextStep).toHaveBeenCalledTimes(3)
})
it('should not call handleNextStep when button is disabled and clicked', () => {
// Arrange
const handleNextStep = vi.fn()
// Act
render(<Actions disabled={true} handleNextStep={handleNextStep} />)
fireEvent.click(screen.getByRole('button'))
// Assert - Disabled button should not trigger onClick
expect(handleNextStep).not.toHaveBeenCalled()
})
it('should handle rapid clicks when not disabled', () => {
// Arrange
const handleNextStep = vi.fn()
// Act
render(<Actions handleNextStep={handleNextStep} />)
const button = screen.getByRole('button')
// Simulate rapid clicks
for (let i = 0; i < 10; i++)
fireEvent.click(button)
// Assert
expect(handleNextStep).toHaveBeenCalledTimes(10)
})
})
// -------------------------------------------------------------------------
// Callback Stability Tests
// -------------------------------------------------------------------------
describe('Callback Stability', () => {
it('should use the new handleNextStep when prop changes', () => {
// Arrange
const handleNextStep1 = vi.fn()
const handleNextStep2 = vi.fn()
// Act
const { rerender } = render(
<Actions handleNextStep={handleNextStep1} />,
)
fireEvent.click(screen.getByRole('button'))
rerender(<Actions handleNextStep={handleNextStep2} />)
fireEvent.click(screen.getByRole('button'))
// Assert
expect(handleNextStep1).toHaveBeenCalledTimes(1)
expect(handleNextStep2).toHaveBeenCalledTimes(1)
})
it('should maintain functionality after rerender with same props', () => {
// Arrange
const handleNextStep = vi.fn()
// Act
const { rerender } = render(
<Actions handleNextStep={handleNextStep} />,
)
fireEvent.click(screen.getByRole('button'))
rerender(<Actions handleNextStep={handleNextStep} />)
fireEvent.click(screen.getByRole('button'))
// Assert
expect(handleNextStep).toHaveBeenCalledTimes(2)
})
it('should work correctly when handleNextStep changes multiple times', () => {
// Arrange
const handleNextStep1 = vi.fn()
const handleNextStep2 = vi.fn()
const handleNextStep3 = vi.fn()
// Act
const { rerender } = render(
<Actions handleNextStep={handleNextStep1} />,
)
fireEvent.click(screen.getByRole('button'))
rerender(<Actions handleNextStep={handleNextStep2} />)
fireEvent.click(screen.getByRole('button'))
rerender(<Actions handleNextStep={handleNextStep3} />)
fireEvent.click(screen.getByRole('button'))
// Assert
expect(handleNextStep1).toHaveBeenCalledTimes(1)
expect(handleNextStep2).toHaveBeenCalledTimes(1)
expect(handleNextStep3).toHaveBeenCalledTimes(1)
})
})
// -------------------------------------------------------------------------
// Memoization Tests
// -------------------------------------------------------------------------
describe('Memoization', () => {
it('should be wrapped with React.memo', () => {
// Arrange
const handleNextStep = vi.fn()
// Act - Rerender with identical props; a memo-wrapped component should handle this without issues
const { rerender } = render(
<Actions handleNextStep={handleNextStep} />,
)
// Rerender with same props should work without issues
rerender(<Actions handleNextStep={handleNextStep} />)
// Assert - Component should render correctly after rerender
expect(screen.getByRole('button')).toBeInTheDocument()
})
it('should not break when props remain the same across rerenders', () => {
// Arrange
const handleNextStep = vi.fn()
// Act
const { rerender } = render(
<Actions disabled={false} handleNextStep={handleNextStep} />,
)
// Multiple rerenders with same props
for (let i = 0; i < 5; i++) {
rerender(<Actions disabled={false} handleNextStep={handleNextStep} />)
}
// Assert - Should still function correctly
fireEvent.click(screen.getByRole('button'))
expect(handleNextStep).toHaveBeenCalledTimes(1)
})
it('should update correctly when only disabled prop changes', () => {
// Arrange
const handleNextStep = vi.fn()
// Act
const { rerender } = render(
<Actions disabled={false} handleNextStep={handleNextStep} />,
)
// Assert - Initially not disabled
expect(screen.getByRole('button')).not.toBeDisabled()
// Act - Change only disabled prop
rerender(<Actions disabled={true} handleNextStep={handleNextStep} />)
// Assert - Should reflect the new disabled state
expect(screen.getByRole('button')).toBeDisabled()
})
it('should update correctly when only handleNextStep prop changes', () => {
// Arrange
const handleNextStep1 = vi.fn()
const handleNextStep2 = vi.fn()
// Act
const { rerender } = render(
<Actions disabled={false} handleNextStep={handleNextStep1} />,
)
fireEvent.click(screen.getByRole('button'))
expect(handleNextStep1).toHaveBeenCalledTimes(1)
// Act - Change only handleNextStep prop
rerender(<Actions disabled={false} handleNextStep={handleNextStep2} />)
fireEvent.click(screen.getByRole('button'))
// Assert - New callback should be used
expect(handleNextStep1).toHaveBeenCalledTimes(1)
expect(handleNextStep2).toHaveBeenCalledTimes(1)
})
})
// -------------------------------------------------------------------------
// Edge Cases Tests
// -------------------------------------------------------------------------
describe('Edge Cases', () => {
it('should call handleNextStep even if it has side effects', () => {
// Arrange
let sideEffectValue = 0
const handleNextStep = vi.fn(() => {
sideEffectValue = 42
})
// Act
render(<Actions handleNextStep={handleNextStep} />)
fireEvent.click(screen.getByRole('button'))
// Assert
expect(handleNextStep).toHaveBeenCalledTimes(1)
expect(sideEffectValue).toBe(42)
})
it('should handle handleNextStep that returns a value', () => {
// Arrange
const handleNextStep = vi.fn(() => 'return value')
// Act
render(<Actions handleNextStep={handleNextStep} />)
fireEvent.click(screen.getByRole('button'))
// Assert
expect(handleNextStep).toHaveBeenCalledTimes(1)
expect(handleNextStep).toHaveReturnedWith('return value')
})
it('should handle handleNextStep that is async', async () => {
// Arrange
const handleNextStep = vi.fn().mockResolvedValue(undefined)
// Act
render(<Actions handleNextStep={handleNextStep} />)
fireEvent.click(screen.getByRole('button'))
// Assert
expect(handleNextStep).toHaveBeenCalledTimes(1)
})
it('should render correctly with both disabled=true and handleNextStep', () => {
// Arrange
const handleNextStep = vi.fn()
// Act
render(<Actions disabled={true} handleNextStep={handleNextStep} />)
// Assert
const button = screen.getByRole('button')
expect(button).toBeDisabled()
})
it('should handle component unmount gracefully', () => {
// Arrange
const handleNextStep = vi.fn()
// Act
const { unmount } = render(<Actions handleNextStep={handleNextStep} />)
// Assert - Unmount should not throw
expect(() => unmount()).not.toThrow()
})
it('should handle disabled as boolean-like falsy value', () => {
// Arrange
const handleNextStep = vi.fn()
// Act - Test with explicit false
render(<Actions disabled={false} handleNextStep={handleNextStep} />)
// Assert
expect(screen.getByRole('button')).not.toBeDisabled()
})
})
// -------------------------------------------------------------------------
// Accessibility Tests
// -------------------------------------------------------------------------
describe('Accessibility', () => {
it('should have button element that can receive focus', () => {
// Arrange
const handleNextStep = vi.fn()
// Act
render(<Actions handleNextStep={handleNextStep} />)
const button = screen.getByRole('button')
// Assert - Button should be focusable (not disabled by default)
expect(button).not.toBeDisabled()
button.focus()
expect(button).toHaveFocus()
})
it('should indicate disabled state correctly', () => {
// Arrange
const handleNextStep = vi.fn()
// Act
render(<Actions disabled={true} handleNextStep={handleNextStep} />)
// Assert
expect(screen.getByRole('button')).toHaveAttribute('disabled')
})
})
// -------------------------------------------------------------------------
// Integration Tests
// -------------------------------------------------------------------------
describe('Integration', () => {
it('should work in a typical workflow: enable -> click -> disable', () => {
// Arrange
const handleNextStep = vi.fn()
// Act - Start enabled
const { rerender } = render(
<Actions disabled={false} handleNextStep={handleNextStep} />,
)
// Assert - Can click when enabled
expect(screen.getByRole('button')).not.toBeDisabled()
fireEvent.click(screen.getByRole('button'))
expect(handleNextStep).toHaveBeenCalledTimes(1)
// Act - Disable after click (simulating loading state)
rerender(<Actions disabled={true} handleNextStep={handleNextStep} />)
// Assert - Cannot click when disabled
expect(screen.getByRole('button')).toBeDisabled()
fireEvent.click(screen.getByRole('button'))
expect(handleNextStep).toHaveBeenCalledTimes(1) // Still 1, not 2
// Act - Re-enable
rerender(<Actions disabled={false} handleNextStep={handleNextStep} />)
// Assert - Can click again
expect(screen.getByRole('button')).not.toBeDisabled()
fireEvent.click(screen.getByRole('button'))
expect(handleNextStep).toHaveBeenCalledTimes(2)
})
it('should maintain consistent rendering across multiple state changes', () => {
// Arrange
const handleNextStep = vi.fn()
// Act
const { rerender } = render(
<Actions disabled={false} handleNextStep={handleNextStep} />,
)
// Toggle disabled state multiple times
const states = [true, false, true, false, true]
states.forEach((disabled) => {
rerender(<Actions disabled={disabled} handleNextStep={handleNextStep} />)
if (disabled)
expect(screen.getByRole('button')).toBeDisabled()
else
expect(screen.getByRole('button')).not.toBeDisabled()
})
// Assert - Button should still render correctly
expect(screen.getByRole('button')).toBeInTheDocument()
expect(screen.getByText('datasetCreation.stepOne.button')).toBeInTheDocument()
})
})
})

File diff suppressed because it is too large

View File

@ -85,7 +85,11 @@ const PublishAsKnowledgePipelineModal = ({
>
<div className="title-2xl-semi-bold relative flex items-center p-6 pb-3 pr-14 text-text-primary">
{t('common.publishAs', { ns: 'pipeline' })}
<div className="absolute right-5 top-5 flex h-8 w-8 cursor-pointer items-center justify-center" onClick={onCancel}>
<div
data-testid="publish-modal-close-btn"
className="absolute right-5 top-5 flex h-8 w-8 cursor-pointer items-center justify-center"
onClick={onCancel}
>
<RiCloseLine className="h-4 w-4 text-text-tertiary" />
</div>
</div>

File diff suppressed because it is too large

Some files were not shown because too many files have changed in this diff