mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into feat/hitl-frontend
This commit is contained in:
commit
cf9b72d574
|
|
@ -3,6 +3,7 @@
|
|||
"feature-dev@claude-plugins-official": true,
|
||||
"context7@claude-plugins-official": true,
|
||||
"typescript-lsp@claude-plugins-official": true,
|
||||
"pyright-lsp@claude-plugins-official": true
|
||||
"pyright-lsp@claude-plugins-official": true,
|
||||
"ralph-wiggum@claude-plugins-official": true
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,73 @@
|
|||
---
|
||||
name: frontend-code-review
|
||||
description: "Trigger when the user requests a review of frontend files (e.g., `.tsx`, `.ts`, `.js`). Support both pending-change reviews and focused file reviews while applying the checklist rules."
|
||||
---
|
||||
|
||||
# Frontend Code Review
|
||||
|
||||
## Intent
|
||||
Use this skill whenever the user asks to review frontend code (especially `.tsx`, `.ts`, or `.js` files). Support two review modes:
|
||||
|
||||
1. **Pending-change review** – inspect staged/working-tree files slated for commit and flag checklist violations before submission.
|
||||
2. **File-targeted review** – review the specific file(s) the user names and report the relevant checklist findings.
|
||||
|
||||
Stick to the checklist below for every applicable file and mode.
|
||||
|
||||
## Checklist
|
||||
See [references/code-quality.md](references/code-quality.md), [references/performance.md](references/performance.md), [references/business-logic.md](references/business-logic.md) for the living checklist split by category—treat it as the canonical set of rules to follow.
|
||||
|
||||
Flag each rule violation with urgency metadata so future reviewers can prioritize fixes.
|
||||
|
||||
## Review Process
|
||||
1. Open the relevant component/module. Gather lines that relate to class names, React Flow hooks, prop memoization, and styling.
|
||||
2. For each rule in the review point, note where the code deviates and capture a representative snippet.
|
||||
3. Compose the review section per the template below. Group violations first by **Urgent** flag, then by category order (Code Quality, Performance, Business Logic).
|
||||
|
||||
## Required output
|
||||
When invoked, the response must exactly follow one of the two templates:
|
||||
|
||||
### Template A (any findings)
|
||||
```
|
||||
# Code review
|
||||
Found <N> urgent issues need to be fixed:
|
||||
|
||||
## 1 <brief description of bug>
|
||||
FilePath: <path> line <line>
|
||||
<relevant code snippet or pointer>
|
||||
|
||||
|
||||
### Suggested fix
|
||||
<brief description of suggested fix>
|
||||
|
||||
---
|
||||
... (repeat for each urgent issue) ...
|
||||
|
||||
Found <M> suggestions for improvement:
|
||||
|
||||
## 1 <brief description of suggestion>
|
||||
FilePath: <path> line <line>
|
||||
<relevant code snippet or pointer>
|
||||
|
||||
|
||||
### Suggested fix
|
||||
<brief description of suggested fix>
|
||||
|
||||
---
|
||||
|
||||
... (repeat for each suggestion) ...
|
||||
```
|
||||
|
||||
If there are no urgent issues, omit that section. If there are no suggestions, omit that section.
|
||||
|
||||
If the issue number is more than 10, summarize as "10+ urgent issues" or "10+ suggestions" and just output the first 10 issues.
|
||||
|
||||
Don't compress the blank lines between sections; keep them as-is for readability.
|
||||
|
||||
If you use Template A (i.e., there are issues to fix) and at least one issue requires code changes, append a brief follow-up question after the structured output asking whether the user wants you to apply the suggested fix(es). For example: "Would you like me to use the Suggested fix section to address these issues?"
|
||||
|
||||
### Template B (no issues)
|
||||
```
|
||||
## Code review
|
||||
No issues found.
|
||||
```
|
||||
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
# Rule Catalog — Business Logic
|
||||
|
||||
## Can't use workflowStore in Node components
|
||||
|
||||
IsUrgent: True
|
||||
|
||||
### Description
|
||||
|
||||
File path pattern of node components: `web/app/components/workflow/nodes/[nodeName]/node.tsx`
|
||||
|
||||
Node components are also used when creating a RAG Pipe from a template, but in that context there is no workflowStore Provider, which results in a blank screen. [This Issue](https://github.com/langgenius/dify/issues/29168) was caused by exactly this reason.
|
||||
|
||||
### Suggested Fix
|
||||
|
||||
Use `import { useNodes } from 'reactflow'` instead of `import useNodes from '@/app/components/workflow/store/workflow/use-nodes'`.
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
# Rule Catalog — Code Quality
|
||||
|
||||
## Conditional class names use utility function
|
||||
|
||||
IsUrgent: True
|
||||
Category: Code Quality
|
||||
|
||||
### Description
|
||||
|
||||
Ensure conditional CSS is handled via the shared `classNames` instead of custom ternaries, string concatenation, or template strings. Centralizing class logic keeps components consistent and easier to maintain.
|
||||
|
||||
### Suggested Fix
|
||||
|
||||
```ts
|
||||
import { cn } from '@/utils/classnames'
|
||||
const classNames = cn(isActive ? 'text-primary-600' : 'text-gray-500')
|
||||
```
|
||||
|
||||
## Tailwind-first styling
|
||||
|
||||
IsUrgent: True
|
||||
Category: Code Quality
|
||||
|
||||
### Description
|
||||
|
||||
Favor Tailwind CSS utility classes instead of adding new `.module.css` files unless a Tailwind combination cannot achieve the required styling. Keeping styles in Tailwind improves consistency and reduces maintenance overhead.
|
||||
|
||||
Update this file when adding, editing, or removing Code Quality rules so the catalog remains accurate.
|
||||
|
||||
## Classname ordering for easy overrides
|
||||
|
||||
### Description
|
||||
|
||||
When writing components, always place the incoming `className` prop after the component’s own class values so that downstream consumers can override or extend the styling. This keeps your component’s defaults but still lets external callers change or remove specific styles.
|
||||
|
||||
Example:
|
||||
|
||||
```tsx
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
const Button = ({ className }) => {
|
||||
return <div className={cn('bg-primary-600', className)}></div>
|
||||
}
|
||||
```
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
# Rule Catalog — Performance
|
||||
|
||||
## React Flow data usage
|
||||
|
||||
IsUrgent: True
|
||||
Category: Performance
|
||||
|
||||
### Description
|
||||
|
||||
When rendering React Flow, prefer `useNodes`/`useEdges` for UI consumption and rely on `useStoreApi` inside callbacks that mutate or read node/edge state. Avoid manually pulling Flow data outside of these hooks.
|
||||
|
||||
## Complex prop memoization
|
||||
|
||||
IsUrgent: True
|
||||
Category: Performance
|
||||
|
||||
### Description
|
||||
|
||||
Wrap complex prop values (objects, arrays, maps) in `useMemo` prior to passing them into child components to guarantee stable references and prevent unnecessary renders.
|
||||
|
||||
Update this file when adding, editing, or removing Performance rules so the catalog remains accurate.
|
||||
|
||||
Wrong:
|
||||
|
||||
```tsx
|
||||
<HeavyComp
|
||||
config={{
|
||||
provider: ...,
|
||||
detail: ...
|
||||
}}
|
||||
/>
|
||||
```
|
||||
|
||||
Right:
|
||||
|
||||
```tsx
|
||||
const config = useMemo(() => ({
|
||||
provider: ...,
|
||||
detail: ...
|
||||
}), [provider, detail]);
|
||||
|
||||
<HeavyComp
|
||||
config={config}
|
||||
/>
|
||||
```
|
||||
|
|
@ -20,4 +20,4 @@
|
|||
- [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
|
||||
- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
|
||||
- [x] I've updated the documentation accordingly.
|
||||
- [x] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods
|
||||
- [x] I ran `make lint` and `make type-check` (backend) and `cd web && npx lint-staged` (frontend) to appease the lint gods
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ on:
|
|||
branches: [main]
|
||||
paths:
|
||||
- 'web/i18n/en-US/*.json'
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
|
@ -18,7 +19,8 @@ jobs:
|
|||
run:
|
||||
working-directory: web
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
# Keep use old checkout action version for https://github.com/peter-evans/create-pull-request/issues/4272
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
|
@ -26,21 +28,28 @@ jobs:
|
|||
- name: Check for file changes in i18n/en-US
|
||||
id: check_files
|
||||
run: |
|
||||
git fetch origin "${{ github.event.before }}" || true
|
||||
git fetch origin "${{ github.sha }}" || true
|
||||
changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.json')
|
||||
echo "Changed files: $changed_files"
|
||||
if [ -n "$changed_files" ]; then
|
||||
# Skip check for manual trigger, translate all files
|
||||
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
|
||||
echo "FILES_CHANGED=true" >> $GITHUB_ENV
|
||||
file_args=""
|
||||
for file in $changed_files; do
|
||||
filename=$(basename "$file" .json)
|
||||
file_args="$file_args --file $filename"
|
||||
done
|
||||
echo "FILE_ARGS=$file_args" >> $GITHUB_ENV
|
||||
echo "File arguments: $file_args"
|
||||
echo "FILE_ARGS=" >> $GITHUB_ENV
|
||||
echo "Manual trigger: translating all files"
|
||||
else
|
||||
echo "FILES_CHANGED=false" >> $GITHUB_ENV
|
||||
git fetch origin "${{ github.event.before }}" || true
|
||||
git fetch origin "${{ github.sha }}" || true
|
||||
changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.json')
|
||||
echo "Changed files: $changed_files"
|
||||
if [ -n "$changed_files" ]; then
|
||||
echo "FILES_CHANGED=true" >> $GITHUB_ENV
|
||||
file_args=""
|
||||
for file in $changed_files; do
|
||||
filename=$(basename "$file" .json)
|
||||
file_args="$file_args --file $filename"
|
||||
done
|
||||
echo "FILE_ARGS=$file_args" >> $GITHUB_ENV
|
||||
echo "File arguments: $file_args"
|
||||
else
|
||||
echo "FILES_CHANGED=false" >> $GITHUB_ENV
|
||||
fi
|
||||
fi
|
||||
|
||||
- name: Install pnpm
|
||||
|
|
|
|||
|
|
@ -235,3 +235,4 @@ scripts/stress-test/reports/
|
|||
|
||||
# settings
|
||||
*.local.json
|
||||
*.local.md
|
||||
|
|
|
|||
5
Makefile
5
Makefile
|
|
@ -60,9 +60,10 @@ check:
|
|||
@echo "✅ Code check complete"
|
||||
|
||||
lint:
|
||||
@echo "🔧 Running ruff format, check with fixes, and import linter..."
|
||||
@echo "🔧 Running ruff format, check with fixes, import linter, and dotenv-linter..."
|
||||
@uv run --project api --dev sh -c 'ruff format ./api && ruff check --fix ./api'
|
||||
@uv run --directory api --dev lint-imports
|
||||
@uv run --project api --dev dotenv-linter ./api/.env.example ./web/.env.example
|
||||
@echo "✅ Linting complete"
|
||||
|
||||
type-check:
|
||||
|
|
@ -122,7 +123,7 @@ help:
|
|||
@echo "Backend Code Quality:"
|
||||
@echo " make format - Format code with ruff"
|
||||
@echo " make check - Check code with ruff"
|
||||
@echo " make lint - Format and fix code with ruff"
|
||||
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
|
||||
@echo " make type-check - Run type checking with basedpyright"
|
||||
@echo " make test - Run backend unit tests"
|
||||
@echo ""
|
||||
|
|
|
|||
|
|
@ -502,6 +502,8 @@ LOG_FILE_BACKUP_COUNT=5
|
|||
LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S
|
||||
# Log Timezone
|
||||
LOG_TZ=UTC
|
||||
# Log output format: text or json
|
||||
LOG_OUTPUT_FORMAT=text
|
||||
# Log format
|
||||
LOG_FORMAT=%(asctime)s,%(msecs)d %(levelname)-2s [%(filename)s:%(lineno)d] %(req_id)s %(message)s
|
||||
|
||||
|
|
|
|||
|
|
@ -3,9 +3,11 @@ root_packages =
|
|||
core
|
||||
configs
|
||||
controllers
|
||||
extensions
|
||||
models
|
||||
tasks
|
||||
services
|
||||
include_external_packages = True
|
||||
|
||||
[importlinter:contract:workflow]
|
||||
name = Workflow
|
||||
|
|
@ -33,6 +35,29 @@ ignore_imports =
|
|||
core.workflow.nodes.loop.loop_node -> core.workflow.graph
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
|
||||
|
||||
[importlinter:contract:workflow-infrastructure-dependencies]
|
||||
name = Workflow Infrastructure Dependencies
|
||||
type = forbidden
|
||||
source_modules =
|
||||
core.workflow
|
||||
forbidden_modules =
|
||||
extensions.ext_database
|
||||
extensions.ext_redis
|
||||
allow_indirect_imports = True
|
||||
ignore_imports =
|
||||
core.workflow.nodes.agent.agent_node -> extensions.ext_database
|
||||
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database
|
||||
core.workflow.nodes.llm.file_saver -> extensions.ext_database
|
||||
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
|
||||
core.workflow.nodes.llm.node -> extensions.ext_database
|
||||
core.workflow.nodes.tool.tool_node -> extensions.ext_database
|
||||
core.workflow.nodes.variable_assigner.common.impl -> extensions.ext_database
|
||||
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
|
||||
core.workflow.graph_engine.manager -> extensions.ext_redis
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
|
||||
|
||||
[importlinter:contract:rsc]
|
||||
name = RSC
|
||||
type = layers
|
||||
|
|
|
|||
|
|
@ -2,9 +2,11 @@ import logging
|
|||
import time
|
||||
|
||||
from opentelemetry.trace import get_current_span
|
||||
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
|
||||
|
||||
from configs import dify_config
|
||||
from contexts.wrapper import RecyclableContextVar
|
||||
from core.logging.context import init_request_context
|
||||
from dify_app import DifyApp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -25,28 +27,35 @@ def create_flask_app_with_configs() -> DifyApp:
|
|||
# add before request hook
|
||||
@dify_app.before_request
|
||||
def before_request():
|
||||
# add an unique identifier to each request
|
||||
# Initialize logging context for this request
|
||||
init_request_context()
|
||||
RecyclableContextVar.increment_thread_recycles()
|
||||
|
||||
# add after request hook for injecting X-Trace-Id header from OpenTelemetry span context
|
||||
# add after request hook for injecting trace headers from OpenTelemetry span context
|
||||
# Only adds headers when OTEL is enabled and has valid context
|
||||
@dify_app.after_request
|
||||
def add_trace_id_header(response):
|
||||
def add_trace_headers(response):
|
||||
try:
|
||||
span = get_current_span()
|
||||
ctx = span.get_span_context() if span else None
|
||||
if ctx and ctx.is_valid:
|
||||
trace_id_hex = format(ctx.trace_id, "032x")
|
||||
# Avoid duplicates if some middleware added it
|
||||
if "X-Trace-Id" not in response.headers:
|
||||
response.headers["X-Trace-Id"] = trace_id_hex
|
||||
|
||||
if not ctx or not ctx.is_valid:
|
||||
return response
|
||||
|
||||
# Inject trace headers from OTEL context
|
||||
if ctx.trace_id != INVALID_TRACE_ID and "X-Trace-Id" not in response.headers:
|
||||
response.headers["X-Trace-Id"] = format(ctx.trace_id, "032x")
|
||||
if ctx.span_id != INVALID_SPAN_ID and "X-Span-Id" not in response.headers:
|
||||
response.headers["X-Span-Id"] = format(ctx.span_id, "016x")
|
||||
|
||||
except Exception:
|
||||
# Never break the response due to tracing header injection
|
||||
logger.warning("Failed to add trace ID to response header", exc_info=True)
|
||||
logger.warning("Failed to add trace headers to response", exc_info=True)
|
||||
return response
|
||||
|
||||
# Capture the decorator's return value to avoid pyright reportUnusedFunction
|
||||
_ = before_request
|
||||
_ = add_trace_id_header
|
||||
_ = add_trace_headers
|
||||
|
||||
return dify_app
|
||||
|
||||
|
|
|
|||
211
api/commands.py
211
api/commands.py
|
|
@ -1184,6 +1184,217 @@ def remove_orphaned_files_on_storage(force: bool):
|
|||
click.echo(click.style(f"Removed {removed_files} orphaned files, with {error_files} errors.", fg="yellow"))
|
||||
|
||||
|
||||
@click.command("file-usage", help="Query file usages and show where files are referenced.")
|
||||
@click.option("--file-id", type=str, default=None, help="Filter by file UUID.")
|
||||
@click.option("--key", type=str, default=None, help="Filter by storage key.")
|
||||
@click.option("--src", type=str, default=None, help="Filter by table.column pattern (e.g., 'documents.%' or '%.icon').")
|
||||
@click.option("--limit", type=int, default=100, help="Limit number of results (default: 100).")
|
||||
@click.option("--offset", type=int, default=0, help="Offset for pagination (default: 0).")
|
||||
@click.option("--json", "output_json", is_flag=True, help="Output results in JSON format.")
|
||||
def file_usage(
|
||||
file_id: str | None,
|
||||
key: str | None,
|
||||
src: str | None,
|
||||
limit: int,
|
||||
offset: int,
|
||||
output_json: bool,
|
||||
):
|
||||
"""
|
||||
Query file usages and show where files are referenced in the database.
|
||||
|
||||
This command reuses the same reference checking logic as clear-orphaned-file-records
|
||||
and displays detailed information about where each file is referenced.
|
||||
"""
|
||||
# define tables and columns to process
|
||||
files_tables = [
|
||||
{"table": "upload_files", "id_column": "id", "key_column": "key"},
|
||||
{"table": "tool_files", "id_column": "id", "key_column": "file_key"},
|
||||
]
|
||||
ids_tables = [
|
||||
{"type": "uuid", "table": "message_files", "column": "upload_file_id", "pk_column": "id"},
|
||||
{"type": "text", "table": "documents", "column": "data_source_info", "pk_column": "id"},
|
||||
{"type": "text", "table": "document_segments", "column": "content", "pk_column": "id"},
|
||||
{"type": "text", "table": "messages", "column": "answer", "pk_column": "id"},
|
||||
{"type": "text", "table": "workflow_node_executions", "column": "inputs", "pk_column": "id"},
|
||||
{"type": "text", "table": "workflow_node_executions", "column": "process_data", "pk_column": "id"},
|
||||
{"type": "text", "table": "workflow_node_executions", "column": "outputs", "pk_column": "id"},
|
||||
{"type": "text", "table": "conversations", "column": "introduction", "pk_column": "id"},
|
||||
{"type": "text", "table": "conversations", "column": "system_instruction", "pk_column": "id"},
|
||||
{"type": "text", "table": "accounts", "column": "avatar", "pk_column": "id"},
|
||||
{"type": "text", "table": "apps", "column": "icon", "pk_column": "id"},
|
||||
{"type": "text", "table": "sites", "column": "icon", "pk_column": "id"},
|
||||
{"type": "json", "table": "messages", "column": "inputs", "pk_column": "id"},
|
||||
{"type": "json", "table": "messages", "column": "message", "pk_column": "id"},
|
||||
]
|
||||
|
||||
# Stream file usages with pagination to avoid holding all results in memory
|
||||
paginated_usages = []
|
||||
total_count = 0
|
||||
|
||||
# First, build a mapping of file_id -> storage_key from the base tables
|
||||
file_key_map = {}
|
||||
for files_table in files_tables:
|
||||
query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}"
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for row in rs:
|
||||
file_key_map[str(row[0])] = f"{files_table['table']}:{row[1]}"
|
||||
|
||||
# If filtering by key or file_id, verify it exists
|
||||
if file_id and file_id not in file_key_map:
|
||||
if output_json:
|
||||
click.echo(json.dumps({"error": f"File ID {file_id} not found in base tables"}))
|
||||
else:
|
||||
click.echo(click.style(f"File ID {file_id} not found in base tables.", fg="red"))
|
||||
return
|
||||
|
||||
if key:
|
||||
valid_prefixes = {f"upload_files:{key}", f"tool_files:{key}"}
|
||||
matching_file_ids = [fid for fid, fkey in file_key_map.items() if fkey in valid_prefixes]
|
||||
if not matching_file_ids:
|
||||
if output_json:
|
||||
click.echo(json.dumps({"error": f"Key {key} not found in base tables"}))
|
||||
else:
|
||||
click.echo(click.style(f"Key {key} not found in base tables.", fg="red"))
|
||||
return
|
||||
|
||||
guid_regexp = "[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
|
||||
|
||||
# For each reference table/column, find matching file IDs and record the references
|
||||
for ids_table in ids_tables:
|
||||
src_filter = f"{ids_table['table']}.{ids_table['column']}"
|
||||
|
||||
# Skip if src filter doesn't match (use fnmatch for wildcard patterns)
|
||||
if src:
|
||||
if "%" in src or "_" in src:
|
||||
import fnmatch
|
||||
|
||||
# Convert SQL LIKE wildcards to fnmatch wildcards (% -> *, _ -> ?)
|
||||
pattern = src.replace("%", "*").replace("_", "?")
|
||||
if not fnmatch.fnmatch(src_filter, pattern):
|
||||
continue
|
||||
else:
|
||||
if src_filter != src:
|
||||
continue
|
||||
|
||||
if ids_table["type"] == "uuid":
|
||||
# Direct UUID match
|
||||
query = (
|
||||
f"SELECT {ids_table['pk_column']}, {ids_table['column']} "
|
||||
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for row in rs:
|
||||
record_id = str(row[0])
|
||||
ref_file_id = str(row[1])
|
||||
if ref_file_id not in file_key_map:
|
||||
continue
|
||||
storage_key = file_key_map[ref_file_id]
|
||||
|
||||
# Apply filters
|
||||
if file_id and ref_file_id != file_id:
|
||||
continue
|
||||
if key and not storage_key.endswith(key):
|
||||
continue
|
||||
|
||||
# Only collect items within the requested page range
|
||||
if offset <= total_count < offset + limit:
|
||||
paginated_usages.append(
|
||||
{
|
||||
"src": f"{ids_table['table']}.{ids_table['column']}",
|
||||
"record_id": record_id,
|
||||
"file_id": ref_file_id,
|
||||
"key": storage_key,
|
||||
}
|
||||
)
|
||||
total_count += 1
|
||||
|
||||
elif ids_table["type"] in ("text", "json"):
|
||||
# Extract UUIDs from text/json content
|
||||
column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"]
|
||||
query = (
|
||||
f"SELECT {ids_table['pk_column']}, {column_cast} "
|
||||
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query))
|
||||
for row in rs:
|
||||
record_id = str(row[0])
|
||||
content = str(row[1])
|
||||
|
||||
# Find all UUIDs in the content
|
||||
import re
|
||||
|
||||
uuid_pattern = re.compile(guid_regexp, re.IGNORECASE)
|
||||
matches = uuid_pattern.findall(content)
|
||||
|
||||
for ref_file_id in matches:
|
||||
if ref_file_id not in file_key_map:
|
||||
continue
|
||||
storage_key = file_key_map[ref_file_id]
|
||||
|
||||
# Apply filters
|
||||
if file_id and ref_file_id != file_id:
|
||||
continue
|
||||
if key and not storage_key.endswith(key):
|
||||
continue
|
||||
|
||||
# Only collect items within the requested page range
|
||||
if offset <= total_count < offset + limit:
|
||||
paginated_usages.append(
|
||||
{
|
||||
"src": f"{ids_table['table']}.{ids_table['column']}",
|
||||
"record_id": record_id,
|
||||
"file_id": ref_file_id,
|
||||
"key": storage_key,
|
||||
}
|
||||
)
|
||||
total_count += 1
|
||||
|
||||
# Output results
|
||||
if output_json:
|
||||
result = {
|
||||
"total": total_count,
|
||||
"offset": offset,
|
||||
"limit": limit,
|
||||
"usages": paginated_usages,
|
||||
}
|
||||
click.echo(json.dumps(result, indent=2))
|
||||
else:
|
||||
click.echo(
|
||||
click.style(f"Found {total_count} file usages (showing {len(paginated_usages)} results)", fg="white")
|
||||
)
|
||||
click.echo("")
|
||||
|
||||
if not paginated_usages:
|
||||
click.echo(click.style("No file usages found matching the specified criteria.", fg="yellow"))
|
||||
return
|
||||
|
||||
# Print table header
|
||||
click.echo(
|
||||
click.style(
|
||||
f"{'Src (Table.Column)':<50} {'Record ID':<40} {'File ID':<40} {'Storage Key':<60}",
|
||||
fg="cyan",
|
||||
)
|
||||
)
|
||||
click.echo(click.style("-" * 190, fg="white"))
|
||||
|
||||
# Print each usage
|
||||
for usage in paginated_usages:
|
||||
click.echo(f"{usage['src']:<50} {usage['record_id']:<40} {usage['file_id']:<40} {usage['key']:<60}")
|
||||
|
||||
# Show pagination info
|
||||
if offset + limit < total_count:
|
||||
click.echo("")
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Showing {offset + 1}-{offset + len(paginated_usages)} of {total_count} results", fg="white"
|
||||
)
|
||||
)
|
||||
click.echo(click.style(f"Use --offset {offset + limit} to see next page", fg="white"))
|
||||
|
||||
|
||||
@click.command("setup-system-tool-oauth-client", help="Setup system tool oauth client.")
|
||||
@click.option("--provider", prompt=True, help="Provider name")
|
||||
@click.option("--client-params", prompt=True, help="Client Params")
|
||||
|
|
|
|||
|
|
@ -587,6 +587,11 @@ class LoggingConfig(BaseSettings):
|
|||
default="INFO",
|
||||
)
|
||||
|
||||
LOG_OUTPUT_FORMAT: Literal["text", "json"] = Field(
|
||||
description="Log output format: 'text' for human-readable, 'json' for structured JSON logs.",
|
||||
default="text",
|
||||
)
|
||||
|
||||
LOG_FILE: str | None = Field(
|
||||
description="File path for log output.",
|
||||
default=None,
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ class MilvusConfig(BaseSettings):
|
|||
description="Authentication token for Milvus, if token-based authentication is enabled",
|
||||
default=None,
|
||||
)
|
||||
|
||||
MILVUS_USER: str | None = Field(
|
||||
description="Username for authenticating with Milvus, if username/password authentication is enabled",
|
||||
default=None,
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ from controllers.console.app.wraps import get_app_model
|
|||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import MessageTextField
|
||||
from fields.raws import FilesContainedField
|
||||
from libs.datetime_utils import naive_utc_now, parse_time_range
|
||||
from libs.helper import TimestampField
|
||||
|
|
@ -177,6 +176,12 @@ annotation_hit_history_model = console_ns.model(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class MessageTextField(fields.Raw):
|
||||
def format(self, value):
|
||||
return value[0]["text"] if value else ""
|
||||
|
||||
|
||||
# Simple message detail model
|
||||
simple_message_detail_model = console_ns.model(
|
||||
"SimpleMessageDetail",
|
||||
|
|
|
|||
|
|
@ -751,12 +751,12 @@ class DocumentApi(DocumentResource):
|
|||
elif metadata == "without":
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
|
||||
data_source_info = document.data_source_detail_dict
|
||||
response = {
|
||||
"id": document.id,
|
||||
"position": document.position,
|
||||
"data_source_type": document.data_source_type,
|
||||
"data_source_info": data_source_info,
|
||||
"data_source_info": document.data_source_info_dict,
|
||||
"data_source_detail_dict": document.data_source_detail_dict,
|
||||
"dataset_process_rule_id": document.dataset_process_rule_id,
|
||||
"dataset_process_rule": dataset_process_rules,
|
||||
"document_process_rule": document_process_rules,
|
||||
|
|
@ -784,12 +784,12 @@ class DocumentApi(DocumentResource):
|
|||
else:
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
|
||||
data_source_info = document.data_source_detail_dict
|
||||
response = {
|
||||
"id": document.id,
|
||||
"position": document.position,
|
||||
"data_source_type": document.data_source_type,
|
||||
"data_source_info": data_source_info,
|
||||
"data_source_info": document.data_source_info_dict,
|
||||
"data_source_detail_dict": document.data_source_detail_dict,
|
||||
"dataset_process_rule_id": document.dataset_process_rule_id,
|
||||
"dataset_process_rule": dataset_process_rules,
|
||||
"document_process_rule": document_process_rules,
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal_with
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, Field, TypeAdapter, model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
|
|
@ -11,7 +10,11 @@ from controllers.console.explore.error import NotChatAppError
|
|||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
||||
from fields.conversation_fields import (
|
||||
ConversationInfiniteScrollPagination,
|
||||
ResultResponse,
|
||||
SimpleConversation,
|
||||
)
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
|
|
@ -49,7 +52,6 @@ register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayl
|
|||
endpoint="installed_app_conversations",
|
||||
)
|
||||
class ConversationListApi(InstalledAppResource):
|
||||
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
||||
@console_ns.expect(console_ns.models[ConversationListQuery.__name__])
|
||||
def get(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
|
|
@ -73,7 +75,7 @@ class ConversationListApi(InstalledAppResource):
|
|||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
with Session(db.engine) as session:
|
||||
return WebConversationService.pagination_by_last_id(
|
||||
pagination = WebConversationService.pagination_by_last_id(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
|
|
@ -82,6 +84,13 @@ class ConversationListApi(InstalledAppResource):
|
|||
invoke_from=InvokeFrom.EXPLORE,
|
||||
pinned=args.pinned,
|
||||
)
|
||||
adapter = TypeAdapter(SimpleConversation)
|
||||
conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data]
|
||||
return ConversationInfiniteScrollPagination(
|
||||
limit=pagination.limit,
|
||||
has_more=pagination.has_more,
|
||||
data=conversations,
|
||||
).model_dump(mode="json")
|
||||
except LastConversationNotExistsError:
|
||||
raise NotFound("Last Conversation Not Exists.")
|
||||
|
||||
|
|
@ -105,7 +114,7 @@ class ConversationApi(InstalledAppResource):
|
|||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
|
|
@ -113,7 +122,6 @@ class ConversationApi(InstalledAppResource):
|
|||
endpoint="installed_app_conversation_rename",
|
||||
)
|
||||
class ConversationRenameApi(InstalledAppResource):
|
||||
@marshal_with(simple_conversation_fields)
|
||||
@console_ns.expect(console_ns.models[ConversationRenamePayload.__name__])
|
||||
def post(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
|
|
@ -128,9 +136,14 @@ class ConversationRenameApi(InstalledAppResource):
|
|||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
return ConversationService.rename(
|
||||
conversation = ConversationService.rename(
|
||||
app_model, conversation_id, current_user, payload.name, payload.auto_generate
|
||||
)
|
||||
return (
|
||||
TypeAdapter(SimpleConversation)
|
||||
.validate_python(conversation, from_attributes=True)
|
||||
.model_dump(mode="json")
|
||||
)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
|
|
@ -155,7 +168,7 @@ class ConversationPinApi(InstalledAppResource):
|
|||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
return {"result": "success"}
|
||||
return ResultResponse(result="success").model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
|
|
@ -174,4 +187,4 @@ class ConversationUnPinApi(InstalledAppResource):
|
|||
raise ValueError("current_user must be an Account instance")
|
||||
WebConversationService.unpin(app_model, conversation_id, current_user)
|
||||
|
||||
return {"result": "success"}
|
||||
return ResultResponse(result="success").model_dump(mode="json")
|
||||
|
|
|
|||
|
|
@ -2,8 +2,7 @@ import logging
|
|||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
|
|
@ -23,7 +22,8 @@ from controllers.console.explore.wraps import InstalledAppResource
|
|||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.message_fields import message_infinite_scroll_pagination_fields
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse
|
||||
from libs import helper
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.login import current_account_with_tenant
|
||||
|
|
@ -66,7 +66,6 @@ register_schema_models(console_ns, MessageListQuery, MessageFeedbackPayload, Mor
|
|||
endpoint="installed_app_messages",
|
||||
)
|
||||
class MessageListApi(InstalledAppResource):
|
||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
@console_ns.expect(console_ns.models[MessageListQuery.__name__])
|
||||
def get(self, installed_app):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
|
@ -78,13 +77,20 @@ class MessageListApi(InstalledAppResource):
|
|||
args = MessageListQuery.model_validate(request.args.to_dict())
|
||||
|
||||
try:
|
||||
return MessageService.pagination_by_first_id(
|
||||
pagination = MessageService.pagination_by_first_id(
|
||||
app_model,
|
||||
current_user,
|
||||
str(args.conversation_id),
|
||||
str(args.first_id) if args.first_id else None,
|
||||
args.limit,
|
||||
)
|
||||
adapter = TypeAdapter(MessageListItem)
|
||||
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
|
||||
return MessageInfiniteScrollPagination(
|
||||
limit=pagination.limit,
|
||||
has_more=pagination.has_more,
|
||||
data=items,
|
||||
).model_dump(mode="json")
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except FirstMessageNotExistsError:
|
||||
|
|
@ -116,7 +122,7 @@ class MessageFeedbackApi(InstalledAppResource):
|
|||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
return {"result": "success"}
|
||||
return ResultResponse(result="success").model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
|
|
@ -201,4 +207,4 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
|
|||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
return {"data": questions}
|
||||
return SuggestedQuestionsResponse(data=questions).model_dump(mode="json")
|
||||
|
|
|
|||
|
|
@ -1,14 +1,14 @@
|
|||
from flask import request
|
||||
from flask_restx import fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.explore.error import NotCompletionAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from libs.helper import TimestampField, UUIDStrOrEmpty
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.login import current_account_with_tenant
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
|
@ -26,28 +26,8 @@ class SavedMessageCreatePayload(BaseModel):
|
|||
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)
|
||||
|
||||
|
||||
feedback_fields = {"rating": fields.String}
|
||||
|
||||
message_fields = {
|
||||
"id": fields.String,
|
||||
"inputs": fields.Raw,
|
||||
"query": fields.String,
|
||||
"answer": fields.String,
|
||||
"message_files": fields.List(fields.Nested(message_file_fields)),
|
||||
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/installed-apps/<uuid:installed_app_id>/saved-messages", endpoint="installed_app_saved_messages")
|
||||
class SavedMessageListApi(InstalledAppResource):
|
||||
saved_message_infinite_scroll_pagination_fields = {
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(message_fields)),
|
||||
}
|
||||
|
||||
@marshal_with(saved_message_infinite_scroll_pagination_fields)
|
||||
@console_ns.expect(console_ns.models[SavedMessageListQuery.__name__])
|
||||
def get(self, installed_app):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
|
@ -57,12 +37,19 @@ class SavedMessageListApi(InstalledAppResource):
|
|||
|
||||
args = SavedMessageListQuery.model_validate(request.args.to_dict())
|
||||
|
||||
return SavedMessageService.pagination_by_last_id(
|
||||
pagination = SavedMessageService.pagination_by_last_id(
|
||||
app_model,
|
||||
current_user,
|
||||
str(args.last_id) if args.last_id else None,
|
||||
args.limit,
|
||||
)
|
||||
adapter = TypeAdapter(SavedMessageItem)
|
||||
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
|
||||
return SavedMessageInfiniteScrollPagination(
|
||||
limit=pagination.limit,
|
||||
has_more=pagination.has_more,
|
||||
data=items,
|
||||
).model_dump(mode="json")
|
||||
|
||||
@console_ns.expect(console_ns.models[SavedMessageCreatePayload.__name__])
|
||||
def post(self, installed_app):
|
||||
|
|
@ -78,7 +65,7 @@ class SavedMessageListApi(InstalledAppResource):
|
|||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
return {"result": "success"}
|
||||
return ResultResponse(result="success").model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
|
|
@ -96,4 +83,4 @@ class SavedMessageApi(InstalledAppResource):
|
|||
|
||||
SavedMessageService.delete(app_model, current_user, message_id)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
|
|
|
|||
|
|
@ -4,12 +4,11 @@ from typing import Any
|
|||
|
||||
from flask import make_response, redirect, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||
from controllers.web.error import NotFoundError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
|
|
@ -44,6 +43,12 @@ class TriggerSubscriptionUpdateRequest(BaseModel):
|
|||
parameters: Mapping[str, Any] | None = Field(default=None, description="The parameters for the subscription")
|
||||
properties: Mapping[str, Any] | None = Field(default=None, description="The properties for the subscription")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_at_least_one_field(self):
|
||||
if all(v is None for v in (self.name, self.credentials, self.parameters, self.properties)):
|
||||
raise ValueError("At least one of name, credentials, parameters, or properties must be provided")
|
||||
return self
|
||||
|
||||
|
||||
class TriggerSubscriptionVerifyRequest(BaseModel):
|
||||
"""Request payload for verifying subscription credentials."""
|
||||
|
|
@ -333,7 +338,7 @@ class TriggerSubscriptionUpdateApi(Resource):
|
|||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
args = TriggerSubscriptionUpdateRequest.model_validate(console_ns.payload)
|
||||
request = TriggerSubscriptionUpdateRequest.model_validate(console_ns.payload)
|
||||
|
||||
subscription = TriggerProviderService.get_subscription_by_id(
|
||||
tenant_id=user.current_tenant_id,
|
||||
|
|
@ -345,50 +350,32 @@ class TriggerSubscriptionUpdateApi(Resource):
|
|||
provider_id = TriggerProviderID(subscription.provider_id)
|
||||
|
||||
try:
|
||||
# rename only
|
||||
if (
|
||||
args.name is not None
|
||||
and args.credentials is None
|
||||
and args.parameters is None
|
||||
and args.properties is None
|
||||
):
|
||||
# For rename only, just update the name
|
||||
rename = request.name is not None and not any((request.credentials, request.parameters, request.properties))
|
||||
# When credential type is UNAUTHORIZED, it indicates the subscription was manually created
|
||||
# For Manually created subscription, they dont have credentials, parameters
|
||||
# They only have name and properties(which is input by user)
|
||||
manually_created = subscription.credential_type == CredentialType.UNAUTHORIZED
|
||||
if rename or manually_created:
|
||||
TriggerProviderService.update_trigger_subscription(
|
||||
tenant_id=user.current_tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
name=args.name,
|
||||
name=request.name,
|
||||
properties=request.properties,
|
||||
)
|
||||
return 200
|
||||
|
||||
# rebuild for create automatically by the provider
|
||||
match subscription.credential_type:
|
||||
case CredentialType.UNAUTHORIZED:
|
||||
TriggerProviderService.update_trigger_subscription(
|
||||
tenant_id=user.current_tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
name=args.name,
|
||||
properties=args.properties,
|
||||
)
|
||||
return 200
|
||||
case CredentialType.API_KEY | CredentialType.OAUTH2:
|
||||
if args.credentials:
|
||||
new_credentials: dict[str, Any] = {
|
||||
key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE)
|
||||
for key, value in args.credentials.items()
|
||||
}
|
||||
else:
|
||||
new_credentials = subscription.credentials
|
||||
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=user.current_tenant_id,
|
||||
name=args.name,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription_id,
|
||||
credentials=new_credentials,
|
||||
parameters=args.parameters or subscription.parameters,
|
||||
)
|
||||
return 200
|
||||
case _:
|
||||
raise BadRequest("Invalid credential type")
|
||||
# For the rest cases(API_KEY, OAUTH2)
|
||||
# we need to call third party provider(e.g. GitHub) to rebuild the subscription
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=user.current_tenant_id,
|
||||
name=request.name,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription_id,
|
||||
credentials=request.credentials or subscription.credentials,
|
||||
parameters=request.parameters or subscription.parameters,
|
||||
)
|
||||
return 200
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -3,8 +3,7 @@ from uuid import UUID
|
|||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from flask_restx._http import HTTPStatus
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
|
|
@ -16,9 +15,9 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
|
|||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import (
|
||||
build_conversation_delete_model,
|
||||
build_conversation_infinite_scroll_pagination_model,
|
||||
build_simple_conversation_model,
|
||||
ConversationDelete,
|
||||
ConversationInfiniteScrollPagination,
|
||||
SimpleConversation,
|
||||
)
|
||||
from fields.conversation_variable_fields import (
|
||||
build_conversation_variable_infinite_scroll_pagination_model,
|
||||
|
|
@ -105,7 +104,6 @@ class ConversationApi(Resource):
|
|||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||
@service_api_ns.marshal_with(build_conversation_infinite_scroll_pagination_model(service_api_ns))
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
"""List all conversations for the current user.
|
||||
|
||||
|
|
@ -120,7 +118,7 @@ class ConversationApi(Resource):
|
|||
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
return ConversationService.pagination_by_last_id(
|
||||
pagination = ConversationService.pagination_by_last_id(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
user=end_user,
|
||||
|
|
@ -129,6 +127,13 @@ class ConversationApi(Resource):
|
|||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
sort_by=query_args.sort_by,
|
||||
)
|
||||
adapter = TypeAdapter(SimpleConversation)
|
||||
conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data]
|
||||
return ConversationInfiniteScrollPagination(
|
||||
limit=pagination.limit,
|
||||
has_more=pagination.has_more,
|
||||
data=conversations,
|
||||
).model_dump(mode="json")
|
||||
except services.errors.conversation.LastConversationNotExistsError:
|
||||
raise NotFound("Last Conversation Not Exists.")
|
||||
|
||||
|
|
@ -146,7 +151,6 @@ class ConversationDetailApi(Resource):
|
|||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
@service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=HTTPStatus.NO_CONTENT)
|
||||
def delete(self, app_model: App, end_user: EndUser, c_id):
|
||||
"""Delete a specific conversation."""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
|
|
@ -159,7 +163,7 @@ class ConversationDetailApi(Resource):
|
|||
ConversationService.delete(app_model, conversation_id, end_user)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
return {"result": "success"}, 204
|
||||
return ConversationDelete(result="success").model_dump(mode="json"), 204
|
||||
|
||||
|
||||
@service_api_ns.route("/conversations/<uuid:c_id>/name")
|
||||
|
|
@ -176,7 +180,6 @@ class ConversationRenameApi(Resource):
|
|||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
@service_api_ns.marshal_with(build_simple_conversation_model(service_api_ns))
|
||||
def post(self, app_model: App, end_user: EndUser, c_id):
|
||||
"""Rename a conversation or auto-generate a name."""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
|
|
@ -188,7 +191,14 @@ class ConversationRenameApi(Resource):
|
|||
payload = ConversationRenamePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
try:
|
||||
return ConversationService.rename(app_model, conversation_id, end_user, payload.name, payload.auto_generate)
|
||||
conversation = ConversationService.rename(
|
||||
app_model, conversation_id, end_user, payload.name, payload.auto_generate
|
||||
)
|
||||
return (
|
||||
TypeAdapter(SimpleConversation)
|
||||
.validate_python(conversation, from_attributes=True)
|
||||
.model_dump(mode="json")
|
||||
)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,10 @@
|
|||
import json
|
||||
import logging
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Namespace, Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
|
|
@ -14,10 +13,8 @@ from controllers.service_api import service_api_ns
|
|||
from controllers.service_api.app.error import NotChatAppError
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from fields.conversation_fields import build_message_file_model
|
||||
from fields.message_fields import build_agent_thought_model, build_feedback_model
|
||||
from fields.raws import FilesContainedField
|
||||
from libs.helper import TimestampField
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.errors.message import (
|
||||
FirstMessageNotExistsError,
|
||||
|
|
@ -48,49 +45,6 @@ class FeedbackListQuery(BaseModel):
|
|||
register_schema_models(service_api_ns, MessageListQuery, MessageFeedbackPayload, FeedbackListQuery)
|
||||
|
||||
|
||||
def build_message_model(api_or_ns: Namespace):
|
||||
"""Build the message model for the API or Namespace."""
|
||||
# First build the nested models
|
||||
feedback_model = build_feedback_model(api_or_ns)
|
||||
agent_thought_model = build_agent_thought_model(api_or_ns)
|
||||
message_file_model = build_message_file_model(api_or_ns)
|
||||
|
||||
# Then build the message fields with nested models
|
||||
message_fields = {
|
||||
"id": fields.String,
|
||||
"conversation_id": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
"inputs": FilesContainedField,
|
||||
"query": fields.String,
|
||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||
"message_files": fields.List(fields.Nested(message_file_model)),
|
||||
"feedback": fields.Nested(feedback_model, attribute="user_feedback", allow_null=True),
|
||||
"retriever_resources": fields.Raw(
|
||||
attribute=lambda obj: json.loads(obj.message_metadata).get("retriever_resources", [])
|
||||
if obj.message_metadata
|
||||
else []
|
||||
),
|
||||
"created_at": TimestampField,
|
||||
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
|
||||
"status": fields.String,
|
||||
"error": fields.String,
|
||||
}
|
||||
return api_or_ns.model("Message", message_fields)
|
||||
|
||||
|
||||
def build_message_infinite_scroll_pagination_model(api_or_ns: Namespace):
|
||||
"""Build the message infinite scroll pagination model for the API or Namespace."""
|
||||
# Build the nested message model first
|
||||
message_model = build_message_model(api_or_ns)
|
||||
|
||||
message_infinite_scroll_pagination_fields = {
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(message_model)),
|
||||
}
|
||||
return api_or_ns.model("MessageInfiniteScrollPagination", message_infinite_scroll_pagination_fields)
|
||||
|
||||
|
||||
@service_api_ns.route("/messages")
|
||||
class MessageListApi(Resource):
|
||||
@service_api_ns.expect(service_api_ns.models[MessageListQuery.__name__])
|
||||
|
|
@ -104,7 +58,6 @@ class MessageListApi(Resource):
|
|||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||
@service_api_ns.marshal_with(build_message_infinite_scroll_pagination_model(service_api_ns))
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
"""List messages in a conversation.
|
||||
|
||||
|
|
@ -119,9 +72,16 @@ class MessageListApi(Resource):
|
|||
first_id = str(query_args.first_id) if query_args.first_id else None
|
||||
|
||||
try:
|
||||
return MessageService.pagination_by_first_id(
|
||||
pagination = MessageService.pagination_by_first_id(
|
||||
app_model, end_user, conversation_id, first_id, query_args.limit
|
||||
)
|
||||
adapter = TypeAdapter(MessageListItem)
|
||||
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
|
||||
return MessageInfiniteScrollPagination(
|
||||
limit=pagination.limit,
|
||||
has_more=pagination.has_more,
|
||||
data=items,
|
||||
).model_dump(mode="json")
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except FirstMessageNotExistsError:
|
||||
|
|
@ -162,7 +122,7 @@ class MessageFeedbackApi(Resource):
|
|||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
return {"result": "success"}
|
||||
return ResultResponse(result="success").model_dump(mode="json")
|
||||
|
||||
|
||||
@service_api_ns.route("/app/feedbacks")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from flask_restx import fields, marshal_with, reqparse
|
||||
from flask_restx import reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
|
|
@ -8,7 +9,11 @@ from controllers.web.error import NotChatAppError
|
|||
from controllers.web.wraps import WebApiResource
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
||||
from fields.conversation_fields import (
|
||||
ConversationInfiniteScrollPagination,
|
||||
ResultResponse,
|
||||
SimpleConversation,
|
||||
)
|
||||
from libs.helper import uuid_value
|
||||
from models.model import AppMode
|
||||
from services.conversation_service import ConversationService
|
||||
|
|
@ -54,7 +59,6 @@ class ConversationListApi(WebApiResource):
|
|||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model, end_user):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
|
|
@ -82,7 +86,7 @@ class ConversationListApi(WebApiResource):
|
|||
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
return WebConversationService.pagination_by_last_id(
|
||||
pagination = WebConversationService.pagination_by_last_id(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
user=end_user,
|
||||
|
|
@ -92,16 +96,19 @@ class ConversationListApi(WebApiResource):
|
|||
pinned=pinned,
|
||||
sort_by=args["sort_by"],
|
||||
)
|
||||
adapter = TypeAdapter(SimpleConversation)
|
||||
conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data]
|
||||
return ConversationInfiniteScrollPagination(
|
||||
limit=pagination.limit,
|
||||
has_more=pagination.has_more,
|
||||
data=conversations,
|
||||
).model_dump(mode="json")
|
||||
except LastConversationNotExistsError:
|
||||
raise NotFound("Last Conversation Not Exists.")
|
||||
|
||||
|
||||
@web_ns.route("/conversations/<uuid:c_id>")
|
||||
class ConversationApi(WebApiResource):
|
||||
delete_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@web_ns.doc("Delete Conversation")
|
||||
@web_ns.doc(description="Delete a specific conversation.")
|
||||
@web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}})
|
||||
|
|
@ -115,7 +122,6 @@ class ConversationApi(WebApiResource):
|
|||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(delete_response_fields)
|
||||
def delete(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
|
|
@ -126,7 +132,7 @@ class ConversationApi(WebApiResource):
|
|||
ConversationService.delete(app_model, conversation_id, end_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
return {"result": "success"}, 204
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
|
||||
|
||||
@web_ns.route("/conversations/<uuid:c_id>/name")
|
||||
|
|
@ -155,7 +161,6 @@ class ConversationRenameApi(WebApiResource):
|
|||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(simple_conversation_fields)
|
||||
def post(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
|
|
@ -171,17 +176,20 @@ class ConversationRenameApi(WebApiResource):
|
|||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"])
|
||||
conversation = ConversationService.rename(
|
||||
app_model, conversation_id, end_user, args["name"], args["auto_generate"]
|
||||
)
|
||||
return (
|
||||
TypeAdapter(SimpleConversation)
|
||||
.validate_python(conversation, from_attributes=True)
|
||||
.model_dump(mode="json")
|
||||
)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
|
||||
@web_ns.route("/conversations/<uuid:c_id>/pin")
|
||||
class ConversationPinApi(WebApiResource):
|
||||
pin_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@web_ns.doc("Pin Conversation")
|
||||
@web_ns.doc(description="Pin a specific conversation to keep it at the top of the list.")
|
||||
@web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}})
|
||||
|
|
@ -195,7 +203,6 @@ class ConversationPinApi(WebApiResource):
|
|||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(pin_response_fields)
|
||||
def patch(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
|
|
@ -208,15 +215,11 @@ class ConversationPinApi(WebApiResource):
|
|||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
return {"result": "success"}
|
||||
return ResultResponse(result="success").model_dump(mode="json")
|
||||
|
||||
|
||||
@web_ns.route("/conversations/<uuid:c_id>/unpin")
|
||||
class ConversationUnPinApi(WebApiResource):
|
||||
unpin_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@web_ns.doc("Unpin Conversation")
|
||||
@web_ns.doc(description="Unpin a specific conversation to remove it from the top of the list.")
|
||||
@web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}})
|
||||
|
|
@ -230,7 +233,6 @@ class ConversationUnPinApi(WebApiResource):
|
|||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(unpin_response_fields)
|
||||
def patch(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
|
|
@ -239,4 +241,4 @@ class ConversationUnPinApi(WebApiResource):
|
|||
conversation_id = str(c_id)
|
||||
WebConversationService.unpin(app_model, conversation_id, end_user)
|
||||
|
||||
return {"result": "success"}
|
||||
return ResultResponse(result="success").model_dump(mode="json")
|
||||
|
|
|
|||
|
|
@ -2,8 +2,7 @@ import logging
|
|||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
|
|
@ -22,11 +21,10 @@ from controllers.web.wraps import WebApiResource
|
|||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from fields.message_fields import agent_thought_fields, feedback_fields, retriever_resource_fields
|
||||
from fields.raws import FilesContainedField
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from libs.helper import uuid_value
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
|
|
@ -70,29 +68,6 @@ register_schema_models(web_ns, MessageListQuery, MessageFeedbackPayload, Message
|
|||
|
||||
@web_ns.route("/messages")
|
||||
class MessageListApi(WebApiResource):
|
||||
message_fields = {
|
||||
"id": fields.String,
|
||||
"conversation_id": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
"inputs": FilesContainedField,
|
||||
"query": fields.String,
|
||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||
"message_files": fields.List(fields.Nested(message_file_fields)),
|
||||
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
||||
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
|
||||
"created_at": TimestampField,
|
||||
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
|
||||
"metadata": fields.Raw(attribute="message_metadata_dict"),
|
||||
"status": fields.String,
|
||||
"error": fields.String,
|
||||
}
|
||||
|
||||
message_infinite_scroll_pagination_fields = {
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(message_fields)),
|
||||
}
|
||||
|
||||
@web_ns.doc("Get Message List")
|
||||
@web_ns.doc(description="Retrieve paginated list of messages from a conversation in a chat application.")
|
||||
@web_ns.doc(
|
||||
|
|
@ -121,7 +96,6 @@ class MessageListApi(WebApiResource):
|
|||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model, end_user):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
|
|
@ -131,9 +105,16 @@ class MessageListApi(WebApiResource):
|
|||
query = MessageListQuery.model_validate(raw_args)
|
||||
|
||||
try:
|
||||
return MessageService.pagination_by_first_id(
|
||||
pagination = MessageService.pagination_by_first_id(
|
||||
app_model, end_user, query.conversation_id, query.first_id, query.limit
|
||||
)
|
||||
adapter = TypeAdapter(WebMessageListItem)
|
||||
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
|
||||
return WebMessageInfiniteScrollPagination(
|
||||
limit=pagination.limit,
|
||||
has_more=pagination.has_more,
|
||||
data=items,
|
||||
).model_dump(mode="json")
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except FirstMessageNotExistsError:
|
||||
|
|
@ -142,10 +123,6 @@ class MessageListApi(WebApiResource):
|
|||
|
||||
@web_ns.route("/messages/<uuid:message_id>/feedbacks")
|
||||
class MessageFeedbackApi(WebApiResource):
|
||||
feedback_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@web_ns.doc("Create Message Feedback")
|
||||
@web_ns.doc(description="Submit feedback (like/dislike) for a specific message.")
|
||||
@web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}})
|
||||
|
|
@ -170,7 +147,6 @@ class MessageFeedbackApi(WebApiResource):
|
|||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(feedback_response_fields)
|
||||
def post(self, app_model, end_user, message_id):
|
||||
message_id = str(message_id)
|
||||
|
||||
|
|
@ -187,7 +163,7 @@ class MessageFeedbackApi(WebApiResource):
|
|||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
return {"result": "success"}
|
||||
return ResultResponse(result="success").model_dump(mode="json")
|
||||
|
||||
|
||||
@web_ns.route("/messages/<uuid:message_id>/more-like-this")
|
||||
|
|
@ -247,10 +223,6 @@ class MessageMoreLikeThisApi(WebApiResource):
|
|||
|
||||
@web_ns.route("/messages/<uuid:message_id>/suggested-questions")
|
||||
class MessageSuggestedQuestionApi(WebApiResource):
|
||||
suggested_questions_response_fields = {
|
||||
"data": fields.List(fields.String),
|
||||
}
|
||||
|
||||
@web_ns.doc("Get Suggested Questions")
|
||||
@web_ns.doc(description="Get suggested follow-up questions after a message (chat apps only).")
|
||||
@web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}})
|
||||
|
|
@ -264,7 +236,6 @@ class MessageSuggestedQuestionApi(WebApiResource):
|
|||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(suggested_questions_response_fields)
|
||||
def get(self, app_model, end_user, message_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
|
|
@ -277,7 +248,6 @@ class MessageSuggestedQuestionApi(WebApiResource):
|
|||
app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.WEB_APP
|
||||
)
|
||||
# questions is a list of strings, not a list of Message objects
|
||||
# so we can directly return it
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message not found")
|
||||
except ConversationNotExistsError:
|
||||
|
|
@ -296,4 +266,4 @@ class MessageSuggestedQuestionApi(WebApiResource):
|
|||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
return {"data": questions}
|
||||
return SuggestedQuestionsResponse(data=questions).model_dump(mode="json")
|
||||
|
|
|
|||
|
|
@ -1,40 +1,20 @@
|
|||
from flask_restx import fields, marshal_with, reqparse
|
||||
from flask_restx import reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from pydantic import TypeAdapter
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import NotCompletionAppError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
|
||||
from libs.helper import uuid_value
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
||||
feedback_fields = {"rating": fields.String}
|
||||
|
||||
message_fields = {
|
||||
"id": fields.String,
|
||||
"inputs": fields.Raw,
|
||||
"query": fields.String,
|
||||
"answer": fields.String,
|
||||
"message_files": fields.List(fields.Nested(message_file_fields)),
|
||||
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
|
||||
@web_ns.route("/saved-messages")
|
||||
class SavedMessageListApi(WebApiResource):
|
||||
saved_message_infinite_scroll_pagination_fields = {
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(message_fields)),
|
||||
}
|
||||
|
||||
post_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@web_ns.doc("Get Saved Messages")
|
||||
@web_ns.doc(description="Retrieve paginated list of saved messages for a completion application.")
|
||||
@web_ns.doc(
|
||||
|
|
@ -58,7 +38,6 @@ class SavedMessageListApi(WebApiResource):
|
|||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(saved_message_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model, end_user):
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
|
@ -70,7 +49,14 @@ class SavedMessageListApi(WebApiResource):
|
|||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
return SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"])
|
||||
pagination = SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"])
|
||||
adapter = TypeAdapter(SavedMessageItem)
|
||||
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
|
||||
return SavedMessageInfiniteScrollPagination(
|
||||
limit=pagination.limit,
|
||||
has_more=pagination.has_more,
|
||||
data=items,
|
||||
).model_dump(mode="json")
|
||||
|
||||
@web_ns.doc("Save Message")
|
||||
@web_ns.doc(description="Save a specific message for later reference.")
|
||||
|
|
@ -89,7 +75,6 @@ class SavedMessageListApi(WebApiResource):
|
|||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(post_response_fields)
|
||||
def post(self, app_model, end_user):
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
|
@ -102,15 +87,11 @@ class SavedMessageListApi(WebApiResource):
|
|||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
return {"result": "success"}
|
||||
return ResultResponse(result="success").model_dump(mode="json")
|
||||
|
||||
|
||||
@web_ns.route("/saved-messages/<uuid:message_id>")
|
||||
class SavedMessageApi(WebApiResource):
|
||||
delete_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@web_ns.doc("Delete Saved Message")
|
||||
@web_ns.doc(description="Remove a message from saved messages.")
|
||||
@web_ns.doc(params={"message_id": {"description": "Message UUID to delete", "type": "string", "required": True}})
|
||||
|
|
@ -124,7 +105,6 @@ class SavedMessageApi(WebApiResource):
|
|||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(delete_response_fields)
|
||||
def delete(self, app_model, end_user, message_id):
|
||||
message_id = str(message_id)
|
||||
|
||||
|
|
@ -133,4 +113,4 @@ class SavedMessageApi(WebApiResource):
|
|||
|
||||
SavedMessageService.delete(app_model, end_user, message_id)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
|
|
|
|||
|
|
@ -88,7 +88,41 @@ def _get_user_provided_host_header(headers: dict | None) -> str | None:
|
|||
return None
|
||||
|
||||
|
||||
def _inject_trace_headers(headers: dict | None) -> dict:
|
||||
"""
|
||||
Inject W3C traceparent header for distributed tracing.
|
||||
|
||||
When OTEL is enabled, HTTPXClientInstrumentor handles trace propagation automatically.
|
||||
When OTEL is disabled, we manually inject the traceparent header.
|
||||
"""
|
||||
if headers is None:
|
||||
headers = {}
|
||||
|
||||
# Skip if already present (case-insensitive check)
|
||||
for key in headers:
|
||||
if key.lower() == "traceparent":
|
||||
return headers
|
||||
|
||||
# Skip if OTEL is enabled - HTTPXClientInstrumentor handles this automatically
|
||||
if dify_config.ENABLE_OTEL:
|
||||
return headers
|
||||
|
||||
# Generate and inject traceparent for non-OTEL scenarios
|
||||
try:
|
||||
from core.helper.trace_id_helper import generate_traceparent_header
|
||||
|
||||
traceparent = generate_traceparent_header()
|
||||
if traceparent:
|
||||
headers["traceparent"] = traceparent
|
||||
except Exception:
|
||||
# Silently ignore errors to avoid breaking requests
|
||||
logger.debug("Failed to generate traceparent header", exc_info=True)
|
||||
|
||||
return headers
|
||||
|
||||
|
||||
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
# Convert requests-style allow_redirects to httpx-style follow_redirects
|
||||
if "allow_redirects" in kwargs:
|
||||
allow_redirects = kwargs.pop("allow_redirects")
|
||||
if "follow_redirects" not in kwargs:
|
||||
|
|
@ -106,18 +140,21 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
|||
verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
|
||||
client = _get_ssrf_client(verify_option)
|
||||
|
||||
# Inject traceparent header for distributed tracing (when OTEL is not enabled)
|
||||
headers = kwargs.get("headers") or {}
|
||||
headers = _inject_trace_headers(headers)
|
||||
kwargs["headers"] = headers
|
||||
|
||||
# Preserve user-provided Host header
|
||||
# When using a forward proxy, httpx may override the Host header based on the URL.
|
||||
# We extract and preserve any explicitly set Host header to support virtual hosting.
|
||||
headers = kwargs.get("headers", {})
|
||||
user_provided_host = _get_user_provided_host_header(headers)
|
||||
|
||||
retries = 0
|
||||
while retries <= max_retries:
|
||||
try:
|
||||
# Build the request manually to preserve the Host header
|
||||
# httpx may override the Host header when using a proxy, so we use
|
||||
# the request API to explicitly set headers before sending
|
||||
# Preserve the user-provided Host header
|
||||
# httpx may override the Host header when using a proxy
|
||||
headers = {k: v for k, v in headers.items() if k.lower() != "host"}
|
||||
if user_provided_host is not None:
|
||||
headers["host"] = user_provided_host
|
||||
|
|
|
|||
|
|
@ -103,3 +103,60 @@ def parse_traceparent_header(traceparent: str) -> str | None:
|
|||
if len(parts) == 4 and len(parts[1]) == 32:
|
||||
return parts[1]
|
||||
return None
|
||||
|
||||
|
||||
def get_span_id_from_otel_context() -> str | None:
|
||||
"""
|
||||
Retrieve the current span ID from the active OpenTelemetry trace context.
|
||||
|
||||
Returns:
|
||||
A 16-character hex string representing the span ID, or None if not available.
|
||||
"""
|
||||
try:
|
||||
from opentelemetry.trace import get_current_span
|
||||
from opentelemetry.trace.span import INVALID_SPAN_ID
|
||||
|
||||
span = get_current_span()
|
||||
if not span:
|
||||
return None
|
||||
|
||||
span_context = span.get_span_context()
|
||||
if not span_context or span_context.span_id == INVALID_SPAN_ID:
|
||||
return None
|
||||
|
||||
return f"{span_context.span_id:016x}"
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def generate_traceparent_header() -> str | None:
|
||||
"""
|
||||
Generate a W3C traceparent header from the current context.
|
||||
|
||||
Uses OpenTelemetry context if available, otherwise uses the
|
||||
ContextVar-based trace_id from the logging context.
|
||||
|
||||
Format: {version}-{trace_id}-{span_id}-{flags}
|
||||
Example: 00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01
|
||||
|
||||
Returns:
|
||||
A valid traceparent header string, or None if generation fails.
|
||||
"""
|
||||
import uuid
|
||||
|
||||
# Try OTEL context first
|
||||
trace_id = get_trace_id_from_otel_context()
|
||||
span_id = get_span_id_from_otel_context()
|
||||
|
||||
if trace_id and span_id:
|
||||
return f"00-{trace_id}-{span_id}-01"
|
||||
|
||||
# Fallback: use ContextVar-based trace_id or generate new one
|
||||
from core.logging.context import get_trace_id as get_logging_trace_id
|
||||
|
||||
trace_id = get_logging_trace_id() or uuid.uuid4().hex
|
||||
|
||||
# Generate a new span_id (16 hex chars)
|
||||
span_id = uuid.uuid4().hex[:16]
|
||||
|
||||
return f"00-{trace_id}-{span_id}-01"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,20 @@
|
|||
"""Structured logging components for Dify."""
|
||||
|
||||
from core.logging.context import (
|
||||
clear_request_context,
|
||||
get_request_id,
|
||||
get_trace_id,
|
||||
init_request_context,
|
||||
)
|
||||
from core.logging.filters import IdentityContextFilter, TraceContextFilter
|
||||
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||
|
||||
__all__ = [
|
||||
"IdentityContextFilter",
|
||||
"StructuredJSONFormatter",
|
||||
"TraceContextFilter",
|
||||
"clear_request_context",
|
||||
"get_request_id",
|
||||
"get_trace_id",
|
||||
"init_request_context",
|
||||
]
|
||||
|
|
@ -0,0 +1,35 @@
|
|||
"""Request context for logging - framework agnostic.
|
||||
|
||||
This module provides request-scoped context variables for logging,
|
||||
using Python's contextvars for thread-safe and async-safe storage.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from contextvars import ContextVar
|
||||
|
||||
_request_id: ContextVar[str] = ContextVar("log_request_id", default="")
|
||||
_trace_id: ContextVar[str] = ContextVar("log_trace_id", default="")
|
||||
|
||||
|
||||
def get_request_id() -> str:
|
||||
"""Get current request ID (10 hex chars)."""
|
||||
return _request_id.get()
|
||||
|
||||
|
||||
def get_trace_id() -> str:
|
||||
"""Get fallback trace ID when OTEL is unavailable (32 hex chars)."""
|
||||
return _trace_id.get()
|
||||
|
||||
|
||||
def init_request_context() -> None:
|
||||
"""Initialize request context. Call at start of each request."""
|
||||
req_id = uuid.uuid4().hex[:10]
|
||||
trace_id = uuid.uuid5(uuid.NAMESPACE_DNS, req_id).hex
|
||||
_request_id.set(req_id)
|
||||
_trace_id.set(trace_id)
|
||||
|
||||
|
||||
def clear_request_context() -> None:
|
||||
"""Clear request context. Call at end of request (optional)."""
|
||||
_request_id.set("")
|
||||
_trace_id.set("")
|
||||
|
|
@ -0,0 +1,94 @@
|
|||
"""Logging filters for structured logging."""
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
|
||||
import flask
|
||||
|
||||
from core.logging.context import get_request_id, get_trace_id
|
||||
|
||||
|
||||
class TraceContextFilter(logging.Filter):
|
||||
"""
|
||||
Filter that adds trace_id and span_id to log records.
|
||||
Integrates with OpenTelemetry when available, falls back to ContextVar-based trace_id.
|
||||
"""
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
# Get trace context from OpenTelemetry
|
||||
trace_id, span_id = self._get_otel_context()
|
||||
|
||||
# Set trace_id (fallback to ContextVar if no OTEL context)
|
||||
if trace_id:
|
||||
record.trace_id = trace_id
|
||||
else:
|
||||
record.trace_id = get_trace_id()
|
||||
|
||||
record.span_id = span_id or ""
|
||||
|
||||
# For backward compatibility, also set req_id
|
||||
record.req_id = get_request_id()
|
||||
|
||||
return True
|
||||
|
||||
def _get_otel_context(self) -> tuple[str, str]:
|
||||
"""Extract trace_id and span_id from OpenTelemetry context."""
|
||||
with contextlib.suppress(Exception):
|
||||
from opentelemetry.trace import get_current_span
|
||||
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
|
||||
|
||||
span = get_current_span()
|
||||
if span and span.get_span_context():
|
||||
ctx = span.get_span_context()
|
||||
if ctx.is_valid and ctx.trace_id != INVALID_TRACE_ID:
|
||||
trace_id = f"{ctx.trace_id:032x}"
|
||||
span_id = f"{ctx.span_id:016x}" if ctx.span_id != INVALID_SPAN_ID else ""
|
||||
return trace_id, span_id
|
||||
return "", ""
|
||||
|
||||
|
||||
class IdentityContextFilter(logging.Filter):
|
||||
"""
|
||||
Filter that adds user identity context to log records.
|
||||
Extracts tenant_id, user_id, and user_type from Flask-Login current_user.
|
||||
"""
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
identity = self._extract_identity()
|
||||
record.tenant_id = identity.get("tenant_id", "")
|
||||
record.user_id = identity.get("user_id", "")
|
||||
record.user_type = identity.get("user_type", "")
|
||||
return True
|
||||
|
||||
def _extract_identity(self) -> dict[str, str]:
|
||||
"""Extract identity from current_user if in request context."""
|
||||
try:
|
||||
if not flask.has_request_context():
|
||||
return {}
|
||||
from flask_login import current_user
|
||||
|
||||
# Check if user is authenticated using the proxy
|
||||
if not current_user.is_authenticated:
|
||||
return {}
|
||||
|
||||
# Access the underlying user object
|
||||
user = current_user
|
||||
|
||||
from models import Account
|
||||
from models.model import EndUser
|
||||
|
||||
identity: dict[str, str] = {}
|
||||
|
||||
if isinstance(user, Account):
|
||||
if user.current_tenant_id:
|
||||
identity["tenant_id"] = user.current_tenant_id
|
||||
identity["user_id"] = user.id
|
||||
identity["user_type"] = "account"
|
||||
elif isinstance(user, EndUser):
|
||||
identity["tenant_id"] = user.tenant_id
|
||||
identity["user_id"] = user.id
|
||||
identity["user_type"] = user.type or "end_user"
|
||||
|
||||
return identity
|
||||
except Exception:
|
||||
return {}
|
||||
|
|
@ -0,0 +1,107 @@
|
|||
"""Structured JSON log formatter for Dify."""
|
||||
|
||||
import logging
|
||||
import traceback
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
|
||||
class StructuredJSONFormatter(logging.Formatter):
|
||||
"""
|
||||
JSON log formatter following the specified schema:
|
||||
{
|
||||
"ts": "ISO 8601 UTC",
|
||||
"severity": "INFO|ERROR|WARN|DEBUG",
|
||||
"service": "service name",
|
||||
"caller": "file:line",
|
||||
"trace_id": "hex 32",
|
||||
"span_id": "hex 16",
|
||||
"identity": { "tenant_id", "user_id", "user_type" },
|
||||
"message": "log message",
|
||||
"attributes": { ... },
|
||||
"stack_trace": "..."
|
||||
}
|
||||
"""
|
||||
|
||||
SEVERITY_MAP: dict[int, str] = {
|
||||
logging.DEBUG: "DEBUG",
|
||||
logging.INFO: "INFO",
|
||||
logging.WARNING: "WARN",
|
||||
logging.ERROR: "ERROR",
|
||||
logging.CRITICAL: "ERROR",
|
||||
}
|
||||
|
||||
def __init__(self, service_name: str | None = None):
|
||||
super().__init__()
|
||||
self._service_name = service_name or dify_config.APPLICATION_NAME
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
log_dict = self._build_log_dict(record)
|
||||
try:
|
||||
return orjson.dumps(log_dict).decode("utf-8")
|
||||
except TypeError:
|
||||
# Fallback: convert non-serializable objects to string
|
||||
import json
|
||||
|
||||
return json.dumps(log_dict, default=str, ensure_ascii=False)
|
||||
|
||||
def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]:
|
||||
# Core fields
|
||||
log_dict: dict[str, Any] = {
|
||||
"ts": datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z"),
|
||||
"severity": self.SEVERITY_MAP.get(record.levelno, "INFO"),
|
||||
"service": self._service_name,
|
||||
"caller": f"{record.filename}:{record.lineno}",
|
||||
"message": record.getMessage(),
|
||||
}
|
||||
|
||||
# Trace context (from TraceContextFilter)
|
||||
trace_id = getattr(record, "trace_id", "")
|
||||
span_id = getattr(record, "span_id", "")
|
||||
|
||||
if trace_id:
|
||||
log_dict["trace_id"] = trace_id
|
||||
if span_id:
|
||||
log_dict["span_id"] = span_id
|
||||
|
||||
# Identity context (from IdentityContextFilter)
|
||||
identity = self._extract_identity(record)
|
||||
if identity:
|
||||
log_dict["identity"] = identity
|
||||
|
||||
# Dynamic attributes
|
||||
attributes = getattr(record, "attributes", None)
|
||||
if attributes:
|
||||
log_dict["attributes"] = attributes
|
||||
|
||||
# Stack trace for errors with exceptions
|
||||
if record.exc_info and record.levelno >= logging.ERROR:
|
||||
log_dict["stack_trace"] = self._format_exception(record.exc_info)
|
||||
|
||||
return log_dict
|
||||
|
||||
def _extract_identity(self, record: logging.LogRecord) -> dict[str, str] | None:
|
||||
tenant_id = getattr(record, "tenant_id", None)
|
||||
user_id = getattr(record, "user_id", None)
|
||||
user_type = getattr(record, "user_type", None)
|
||||
|
||||
if not any([tenant_id, user_id, user_type]):
|
||||
return None
|
||||
|
||||
identity: dict[str, str] = {}
|
||||
if tenant_id:
|
||||
identity["tenant_id"] = tenant_id
|
||||
if user_id:
|
||||
identity["user_id"] = user_id
|
||||
if user_type:
|
||||
identity["user_type"] = user_type
|
||||
return identity
|
||||
|
||||
def _format_exception(self, exc_info: tuple[Any, ...]) -> str:
|
||||
if exc_info and exc_info[0] is not None:
|
||||
return "".join(traceback.format_exception(*exc_info))
|
||||
return ""
|
||||
|
|
@ -103,6 +103,9 @@ class BasePluginClient:
|
|||
prepared_headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
|
||||
prepared_headers.setdefault("Accept-Encoding", "gzip, deflate, br")
|
||||
|
||||
# Inject traceparent header for distributed tracing
|
||||
self._inject_trace_headers(prepared_headers)
|
||||
|
||||
prepared_data: bytes | dict[str, Any] | str | None = (
|
||||
data if isinstance(data, (bytes, str, dict)) or data is None else None
|
||||
)
|
||||
|
|
@ -114,6 +117,31 @@ class BasePluginClient:
|
|||
|
||||
return str(url), prepared_headers, prepared_data, params, files
|
||||
|
||||
def _inject_trace_headers(self, headers: dict[str, str]) -> None:
|
||||
"""
|
||||
Inject W3C traceparent header for distributed tracing.
|
||||
|
||||
This ensures trace context is propagated to plugin daemon even if
|
||||
HTTPXClientInstrumentor doesn't cover module-level httpx functions.
|
||||
"""
|
||||
if not dify_config.ENABLE_OTEL:
|
||||
return
|
||||
|
||||
import contextlib
|
||||
|
||||
# Skip if already present (case-insensitive check)
|
||||
for key in headers:
|
||||
if key.lower() == "traceparent":
|
||||
return
|
||||
|
||||
# Inject traceparent - works as fallback when OTEL instrumentation doesn't cover this call
|
||||
with contextlib.suppress(Exception):
|
||||
from core.helper.trace_id_helper import generate_traceparent_header
|
||||
|
||||
traceparent = generate_traceparent_header()
|
||||
if traceparent:
|
||||
headers["traceparent"] = traceparent
|
||||
|
||||
def _stream_request(
|
||||
self,
|
||||
method: str,
|
||||
|
|
|
|||
|
|
@ -515,6 +515,7 @@ class DatasetRetrieval:
|
|||
0
|
||||
].embedding_model_provider
|
||||
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
|
||||
dataset_count = len(available_datasets)
|
||||
with measure_time() as timer:
|
||||
cancel_event = threading.Event()
|
||||
thread_exceptions: list[Exception] = []
|
||||
|
|
@ -537,6 +538,7 @@ class DatasetRetrieval:
|
|||
"score_threshold": score_threshold,
|
||||
"query": query,
|
||||
"attachment_id": None,
|
||||
"dataset_count": dataset_count,
|
||||
"cancel_event": cancel_event,
|
||||
"thread_exceptions": thread_exceptions,
|
||||
},
|
||||
|
|
@ -562,6 +564,7 @@ class DatasetRetrieval:
|
|||
"score_threshold": score_threshold,
|
||||
"query": None,
|
||||
"attachment_id": attachment_id,
|
||||
"dataset_count": dataset_count,
|
||||
"cancel_event": cancel_event,
|
||||
"thread_exceptions": thread_exceptions,
|
||||
},
|
||||
|
|
@ -1422,6 +1425,7 @@ class DatasetRetrieval:
|
|||
score_threshold: float,
|
||||
query: str | None,
|
||||
attachment_id: str | None,
|
||||
dataset_count: int,
|
||||
cancel_event: threading.Event | None = None,
|
||||
thread_exceptions: list[Exception] | None = None,
|
||||
):
|
||||
|
|
@ -1470,37 +1474,38 @@ class DatasetRetrieval:
|
|||
if cancel_event and cancel_event.is_set():
|
||||
break
|
||||
|
||||
if reranking_enable:
|
||||
# do rerank for searched documents
|
||||
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
|
||||
if query:
|
||||
all_documents_item = data_post_processor.invoke(
|
||||
query=query,
|
||||
documents=all_documents_item,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k,
|
||||
query_type=QueryType.TEXT_QUERY,
|
||||
)
|
||||
if attachment_id:
|
||||
all_documents_item = data_post_processor.invoke(
|
||||
documents=all_documents_item,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k,
|
||||
query_type=QueryType.IMAGE_QUERY,
|
||||
query=attachment_id,
|
||||
)
|
||||
else:
|
||||
if index_type == IndexTechniqueType.ECONOMY:
|
||||
if not query:
|
||||
all_documents_item = []
|
||||
else:
|
||||
all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
|
||||
elif index_type == IndexTechniqueType.HIGH_QUALITY:
|
||||
all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
|
||||
# Skip second reranking when there is only one dataset
|
||||
if reranking_enable and dataset_count > 1:
|
||||
# do rerank for searched documents
|
||||
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
|
||||
if query:
|
||||
all_documents_item = data_post_processor.invoke(
|
||||
query=query,
|
||||
documents=all_documents_item,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k,
|
||||
query_type=QueryType.TEXT_QUERY,
|
||||
)
|
||||
if attachment_id:
|
||||
all_documents_item = data_post_processor.invoke(
|
||||
documents=all_documents_item,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k,
|
||||
query_type=QueryType.IMAGE_QUERY,
|
||||
query=attachment_id,
|
||||
)
|
||||
else:
|
||||
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
|
||||
if all_documents_item:
|
||||
all_documents.extend(all_documents_item)
|
||||
if index_type == IndexTechniqueType.ECONOMY:
|
||||
if not query:
|
||||
all_documents_item = []
|
||||
else:
|
||||
all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
|
||||
elif index_type == IndexTechniqueType.HIGH_QUALITY:
|
||||
all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
|
||||
else:
|
||||
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
|
||||
if all_documents_item:
|
||||
all_documents.extend(all_documents_item)
|
||||
except Exception as e:
|
||||
if cancel_event:
|
||||
cancel_event.set()
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
from collections.abc import Mapping, Sequence
|
||||
from decimal import Decimal
|
||||
from typing import Any, cast
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||
|
|
@ -13,6 +12,7 @@ from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
|||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.code.entities import CodeNodeData
|
||||
from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
|
||||
from .exc import (
|
||||
CodeNodeError,
|
||||
|
|
@ -20,9 +20,41 @@ from .exc import (
|
|||
OutputValidationError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
class CodeNode(Node[CodeNodeData]):
|
||||
node_type = NodeType.CODE
|
||||
_DEFAULT_CODE_PROVIDERS: ClassVar[tuple[type[CodeNodeProvider], ...]] = (
|
||||
Python3CodeProvider,
|
||||
JavascriptCodeProvider,
|
||||
)
|
||||
_limits: CodeNodeLimits
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
code_executor: type[CodeExecutor] | None = None,
|
||||
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
|
||||
code_limits: CodeNodeLimits,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor
|
||||
self._code_providers: tuple[type[CodeNodeProvider], ...] = (
|
||||
tuple(code_providers) if code_providers else self._DEFAULT_CODE_PROVIDERS
|
||||
)
|
||||
self._limits = code_limits
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
|
|
@ -35,11 +67,16 @@ class CodeNode(Node[CodeNodeData]):
|
|||
if filters:
|
||||
code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3))
|
||||
|
||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||
code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language))
|
||||
code_provider: type[CodeNodeProvider] = next(
|
||||
provider for provider in cls._DEFAULT_CODE_PROVIDERS if provider.is_accept_language(code_language)
|
||||
)
|
||||
|
||||
return code_provider.get_default_config()
|
||||
|
||||
@classmethod
|
||||
def default_code_providers(cls) -> tuple[type[CodeNodeProvider], ...]:
|
||||
return cls._DEFAULT_CODE_PROVIDERS
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
|
@ -60,7 +97,8 @@ class CodeNode(Node[CodeNodeData]):
|
|||
variables[variable_name] = variable.to_object() if variable else None
|
||||
# Run code
|
||||
try:
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
_ = self._select_code_provider(code_language)
|
||||
result = self._code_executor.execute_workflow_code_template(
|
||||
language=code_language,
|
||||
code=code,
|
||||
inputs=variables,
|
||||
|
|
@ -75,6 +113,12 @@ class CodeNode(Node[CodeNodeData]):
|
|||
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
|
||||
|
||||
def _select_code_provider(self, code_language: CodeLanguage) -> type[CodeNodeProvider]:
|
||||
for provider in self._code_providers:
|
||||
if provider.is_accept_language(code_language):
|
||||
return provider
|
||||
raise CodeNodeError(f"Unsupported code language: {code_language}")
|
||||
|
||||
def _check_string(self, value: str | None, variable: str) -> str | None:
|
||||
"""
|
||||
Check string
|
||||
|
|
@ -85,10 +129,10 @@ class CodeNode(Node[CodeNodeData]):
|
|||
if value is None:
|
||||
return None
|
||||
|
||||
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
|
||||
if len(value) > self._limits.max_string_length:
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{variable}` must be"
|
||||
f" less than {dify_config.CODE_MAX_STRING_LENGTH} characters"
|
||||
f" less than {self._limits.max_string_length} characters"
|
||||
)
|
||||
|
||||
return value.replace("\x00", "")
|
||||
|
|
@ -109,20 +153,20 @@ class CodeNode(Node[CodeNodeData]):
|
|||
if value is None:
|
||||
return None
|
||||
|
||||
if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
|
||||
if value > self._limits.max_number or value < self._limits.min_number:
|
||||
raise OutputValidationError(
|
||||
f"Output variable `{variable}` is out of range,"
|
||||
f" it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}."
|
||||
f" it must be between {self._limits.min_number} and {self._limits.max_number}."
|
||||
)
|
||||
|
||||
if isinstance(value, float):
|
||||
decimal_value = Decimal(str(value)).normalize()
|
||||
precision = -decimal_value.as_tuple().exponent if decimal_value.as_tuple().exponent < 0 else 0 # type: ignore[operator]
|
||||
# raise error if precision is too high
|
||||
if precision > dify_config.CODE_MAX_PRECISION:
|
||||
if precision > self._limits.max_precision:
|
||||
raise OutputValidationError(
|
||||
f"Output variable `{variable}` has too high precision,"
|
||||
f" it must be less than {dify_config.CODE_MAX_PRECISION} digits."
|
||||
f" it must be less than {self._limits.max_precision} digits."
|
||||
)
|
||||
|
||||
return value
|
||||
|
|
@ -137,8 +181,8 @@ class CodeNode(Node[CodeNodeData]):
|
|||
# TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes.
|
||||
# Note that `_transform_result` may produce lists containing `None` values,
|
||||
# which don't conform to the type requirements of `Array*Segment` classes.
|
||||
if depth > dify_config.CODE_MAX_DEPTH:
|
||||
raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.")
|
||||
if depth > self._limits.max_depth:
|
||||
raise DepthLimitError(f"Depth limit {self._limits.max_depth} reached, object too deep.")
|
||||
|
||||
transformed_result: dict[str, Any] = {}
|
||||
if output_schema is None:
|
||||
|
|
@ -272,10 +316,10 @@ class CodeNode(Node[CodeNodeData]):
|
|||
f"Output {prefix}{dot}{output_name} is not an array, got {type(value)} instead."
|
||||
)
|
||||
else:
|
||||
if len(value) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH:
|
||||
if len(value) > self._limits.max_number_array_length:
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements."
|
||||
f" less than {self._limits.max_number_array_length} elements."
|
||||
)
|
||||
|
||||
for i, inner_value in enumerate(value):
|
||||
|
|
@ -305,10 +349,10 @@ class CodeNode(Node[CodeNodeData]):
|
|||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH:
|
||||
if len(result[output_name]) > self._limits.max_string_array_length:
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements."
|
||||
f" less than {self._limits.max_string_array_length} elements."
|
||||
)
|
||||
|
||||
transformed_result[output_name] = [
|
||||
|
|
@ -326,10 +370,10 @@ class CodeNode(Node[CodeNodeData]):
|
|||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH:
|
||||
if len(result[output_name]) > self._limits.max_object_array_length:
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements."
|
||||
f" less than {self._limits.max_object_array_length} elements."
|
||||
)
|
||||
|
||||
for i, value in enumerate(result[output_name]):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,13 @@
|
|||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CodeNodeLimits:
|
||||
max_string_length: int
|
||||
max_number: int | float
|
||||
min_number: int | float
|
||||
max_precision: int
|
||||
max_depth: int
|
||||
max_number_array_length: int
|
||||
max_string_array_length: int
|
||||
max_object_array_length: int
|
||||
|
|
@ -1,10 +1,21 @@
|
|||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.code_executor.code_executor import CodeExecutor
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph import NodeFactory
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
from core.workflow.nodes.template_transform.template_renderer import (
|
||||
CodeExecutorJinja2TemplateRenderer,
|
||||
Jinja2TemplateRenderer,
|
||||
)
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from libs.typing import is_str, is_str_dict
|
||||
|
||||
from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
|
|
@ -27,9 +38,29 @@ class DifyNodeFactory(NodeFactory):
|
|||
self,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
code_executor: type[CodeExecutor] | None = None,
|
||||
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
|
||||
code_limits: CodeNodeLimits | None = None,
|
||||
template_renderer: Jinja2TemplateRenderer | None = None,
|
||||
) -> None:
|
||||
self.graph_init_params = graph_init_params
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor
|
||||
self._code_providers: tuple[type[CodeNodeProvider], ...] = (
|
||||
tuple(code_providers) if code_providers else CodeNode.default_code_providers()
|
||||
)
|
||||
self._code_limits = code_limits or CodeNodeLimits(
|
||||
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
|
||||
max_number=dify_config.CODE_MAX_NUMBER,
|
||||
min_number=dify_config.CODE_MIN_NUMBER,
|
||||
max_precision=dify_config.CODE_MAX_PRECISION,
|
||||
max_depth=dify_config.CODE_MAX_DEPTH,
|
||||
max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH,
|
||||
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
|
||||
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
|
||||
)
|
||||
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
|
||||
|
||||
@override
|
||||
def create_node(self, node_config: dict[str, object]) -> Node:
|
||||
|
|
@ -72,6 +103,26 @@ class DifyNodeFactory(NodeFactory):
|
|||
raise ValueError(f"No latest version class found for node type: {node_type}")
|
||||
|
||||
# Create node instance
|
||||
if node_type == NodeType.CODE:
|
||||
return CodeNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
code_executor=self._code_executor,
|
||||
code_providers=self._code_providers,
|
||||
code_limits=self._code_limits,
|
||||
)
|
||||
|
||||
if node_type == NodeType.TEMPLATE_TRANSFORM:
|
||||
return TemplateTransformNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
template_renderer=self._template_renderer,
|
||||
)
|
||||
|
||||
return node_class(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,40 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
|
||||
|
||||
class TemplateRenderError(ValueError):
|
||||
"""Raised when rendering a Jinja2 template fails."""
|
||||
|
||||
|
||||
class Jinja2TemplateRenderer(Protocol):
|
||||
"""Render Jinja2 templates for template transform nodes."""
|
||||
|
||||
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
|
||||
"""Render a Jinja2 template with provided variables."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer):
|
||||
"""Adapter that renders Jinja2 templates via CodeExecutor."""
|
||||
|
||||
_code_executor: type[CodeExecutor]
|
||||
|
||||
def __init__(self, code_executor: type[CodeExecutor] | None = None) -> None:
|
||||
self._code_executor = code_executor or CodeExecutor
|
||||
|
||||
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
|
||||
try:
|
||||
result = self._code_executor.execute_workflow_code_template(
|
||||
language=CodeLanguage.JINJA2, code=template, inputs=variables
|
||||
)
|
||||
except CodeExecutionError as exc:
|
||||
raise TemplateRenderError(str(exc)) from exc
|
||||
|
||||
rendered = result.get("result")
|
||||
if not isinstance(rendered, str):
|
||||
raise TemplateRenderError("Template render result must be a string.")
|
||||
return rendered
|
||||
|
|
@ -1,18 +1,44 @@
|
|||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
|
||||
from core.workflow.nodes.template_transform.template_renderer import (
|
||||
CodeExecutorJinja2TemplateRenderer,
|
||||
Jinja2TemplateRenderer,
|
||||
TemplateRenderError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
|
||||
|
||||
|
||||
class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
node_type = NodeType.TEMPLATE_TRANSFORM
|
||||
_template_renderer: Jinja2TemplateRenderer
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
template_renderer: Jinja2TemplateRenderer | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
|
|
@ -39,13 +65,11 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
|||
variables[variable_name] = value.to_object() if value else None
|
||||
# Run code
|
||||
try:
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables
|
||||
)
|
||||
except CodeExecutionError as e:
|
||||
rendered = self._template_renderer.render_template(self.node_data.template, variables)
|
||||
except TemplateRenderError as e:
|
||||
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
|
||||
|
||||
if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
|
||||
if len(rendered) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
|
||||
return NodeRunResult(
|
||||
inputs=variables,
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
|
|
@ -53,7 +77,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
|||
)
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": result["result"]}
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": rendered}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -46,7 +46,11 @@ def _get_celery_ssl_options() -> dict[str, Any] | None:
|
|||
def init_app(app: DifyApp) -> Celery:
|
||||
class FlaskTask(Task):
|
||||
def __call__(self, *args: object, **kwargs: object) -> object:
|
||||
from core.logging.context import init_request_context
|
||||
|
||||
with app.app_context():
|
||||
# Initialize logging context for this task (similar to before_request in Flask)
|
||||
init_request_context()
|
||||
return self.run(*args, **kwargs)
|
||||
|
||||
broker_transport_options = {}
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ def init_app(app: DifyApp):
|
|||
create_tenant,
|
||||
extract_plugins,
|
||||
extract_unique_plugins,
|
||||
file_usage,
|
||||
fix_app_site_missing,
|
||||
install_plugins,
|
||||
install_rag_pipeline_plugins,
|
||||
|
|
@ -47,6 +48,7 @@ def init_app(app: DifyApp):
|
|||
clear_free_plan_tenant_expired_logs,
|
||||
clear_orphaned_file_records,
|
||||
remove_orphaned_files_on_storage,
|
||||
file_usage,
|
||||
setup_system_tool_oauth_client,
|
||||
setup_system_trigger_oauth_client,
|
||||
cleanup_orphaned_draft_variables,
|
||||
|
|
|
|||
|
|
@ -53,3 +53,10 @@ def _setup_gevent_compatibility():
|
|||
def init_app(app: DifyApp):
|
||||
db.init_app(app)
|
||||
_setup_gevent_compatibility()
|
||||
|
||||
# Eagerly build the engine so pool_size/max_overflow/etc. come from config
|
||||
try:
|
||||
with app.app_context():
|
||||
_ = db.engine # triggers engine creation with the configured options
|
||||
except Exception:
|
||||
logger.exception("Failed to initialize SQLAlchemy engine during app startup")
|
||||
|
|
|
|||
|
|
@ -1,18 +1,19 @@
|
|||
"""Logging extension for Dify Flask application."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
import flask
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.trace_id_helper import get_trace_id_from_otel_context
|
||||
from dify_app import DifyApp
|
||||
|
||||
|
||||
def init_app(app: DifyApp):
|
||||
"""Initialize logging with support for text or JSON format."""
|
||||
log_handlers: list[logging.Handler] = []
|
||||
|
||||
# File handler
|
||||
log_file = dify_config.LOG_FILE
|
||||
if log_file:
|
||||
log_dir = os.path.dirname(log_file)
|
||||
|
|
@ -25,27 +26,53 @@ def init_app(app: DifyApp):
|
|||
)
|
||||
)
|
||||
|
||||
# Always add StreamHandler to log to console
|
||||
# Console handler
|
||||
sh = logging.StreamHandler(sys.stdout)
|
||||
log_handlers.append(sh)
|
||||
|
||||
# Apply RequestIdFilter to all handlers
|
||||
for handler in log_handlers:
|
||||
handler.addFilter(RequestIdFilter())
|
||||
# Apply filters to all handlers
|
||||
from core.logging.filters import IdentityContextFilter, TraceContextFilter
|
||||
|
||||
for handler in log_handlers:
|
||||
handler.addFilter(TraceContextFilter())
|
||||
handler.addFilter(IdentityContextFilter())
|
||||
|
||||
# Configure formatter based on format type
|
||||
formatter = _create_formatter()
|
||||
for handler in log_handlers:
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
# Configure root logger
|
||||
logging.basicConfig(
|
||||
level=dify_config.LOG_LEVEL,
|
||||
format=dify_config.LOG_FORMAT,
|
||||
datefmt=dify_config.LOG_DATEFORMAT,
|
||||
handlers=log_handlers,
|
||||
force=True,
|
||||
)
|
||||
|
||||
# Apply RequestIdFormatter to all handlers
|
||||
apply_request_id_formatter()
|
||||
|
||||
# Disable propagation for noisy loggers to avoid duplicate logs
|
||||
logging.getLogger("sqlalchemy.engine").propagate = False
|
||||
|
||||
# Apply timezone if specified (only for text format)
|
||||
if dify_config.LOG_OUTPUT_FORMAT == "text":
|
||||
_apply_timezone(log_handlers)
|
||||
|
||||
|
||||
def _create_formatter() -> logging.Formatter:
|
||||
"""Create appropriate formatter based on configuration."""
|
||||
if dify_config.LOG_OUTPUT_FORMAT == "json":
|
||||
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||
|
||||
return StructuredJSONFormatter()
|
||||
else:
|
||||
# Text format - use existing pattern with backward compatible formatter
|
||||
return _TextFormatter(
|
||||
fmt=dify_config.LOG_FORMAT,
|
||||
datefmt=dify_config.LOG_DATEFORMAT,
|
||||
)
|
||||
|
||||
|
||||
def _apply_timezone(handlers: list[logging.Handler]):
|
||||
"""Apply timezone conversion to text formatters."""
|
||||
log_tz = dify_config.LOG_TZ
|
||||
if log_tz:
|
||||
from datetime import datetime
|
||||
|
|
@ -57,34 +84,51 @@ def init_app(app: DifyApp):
|
|||
def time_converter(seconds):
|
||||
return datetime.fromtimestamp(seconds, tz=timezone).timetuple()
|
||||
|
||||
for handler in logging.root.handlers:
|
||||
for handler in handlers:
|
||||
if handler.formatter:
|
||||
handler.formatter.converter = time_converter
|
||||
handler.formatter.converter = time_converter # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def get_request_id():
|
||||
if getattr(flask.g, "request_id", None):
|
||||
return flask.g.request_id
|
||||
class _TextFormatter(logging.Formatter):
|
||||
"""Text formatter that ensures trace_id and req_id are always present."""
|
||||
|
||||
new_uuid = uuid.uuid4().hex[:10]
|
||||
flask.g.request_id = new_uuid
|
||||
|
||||
return new_uuid
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
if not hasattr(record, "req_id"):
|
||||
record.req_id = ""
|
||||
if not hasattr(record, "trace_id"):
|
||||
record.trace_id = ""
|
||||
if not hasattr(record, "span_id"):
|
||||
record.span_id = ""
|
||||
return super().format(record)
|
||||
|
||||
|
||||
def get_request_id() -> str:
|
||||
"""Get request ID for current request context.
|
||||
|
||||
Deprecated: Use core.logging.context.get_request_id() directly.
|
||||
"""
|
||||
from core.logging.context import get_request_id as _get_request_id
|
||||
|
||||
return _get_request_id()
|
||||
|
||||
|
||||
# Backward compatibility aliases
|
||||
class RequestIdFilter(logging.Filter):
|
||||
# This is a logging filter that makes the request ID available for use in
|
||||
# the logging format. Note that we're checking if we're in a request
|
||||
# context, as we may want to log things before Flask is fully loaded.
|
||||
def filter(self, record):
|
||||
trace_id = get_trace_id_from_otel_context() or ""
|
||||
record.req_id = get_request_id() if flask.has_request_context() else ""
|
||||
record.trace_id = trace_id
|
||||
"""Deprecated: Use TraceContextFilter from core.logging.filters instead."""
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
from core.logging.context import get_request_id as _get_request_id
|
||||
from core.logging.context import get_trace_id as _get_trace_id
|
||||
|
||||
record.req_id = _get_request_id()
|
||||
record.trace_id = _get_trace_id()
|
||||
return True
|
||||
|
||||
|
||||
class RequestIdFormatter(logging.Formatter):
|
||||
def format(self, record):
|
||||
"""Deprecated: Use _TextFormatter instead."""
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
if not hasattr(record, "req_id"):
|
||||
record.req_id = ""
|
||||
if not hasattr(record, "trace_id"):
|
||||
|
|
@ -93,6 +137,7 @@ class RequestIdFormatter(logging.Formatter):
|
|||
|
||||
|
||||
def apply_request_id_formatter():
|
||||
"""Deprecated: Formatter is now applied in init_app."""
|
||||
for handler in logging.root.handlers:
|
||||
if handler.formatter:
|
||||
handler.formatter = RequestIdFormatter(dify_config.LOG_FORMAT, dify_config.LOG_DATEFORMAT)
|
||||
|
|
|
|||
|
|
@ -22,6 +22,18 @@ from models.enums import WorkflowRunTriggeredFrom
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def to_serializable(obj):
|
||||
"""
|
||||
Convert non-JSON-serializable objects into JSON-compatible formats.
|
||||
|
||||
- Uses `to_dict()` if it's a callable method.
|
||||
- Falls back to string representation.
|
||||
"""
|
||||
if hasattr(obj, "to_dict") and callable(obj.to_dict):
|
||||
return obj.to_dict()
|
||||
return str(obj)
|
||||
|
||||
|
||||
class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -108,9 +120,24 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
|
|||
),
|
||||
("type", domain_model.workflow_type.value),
|
||||
("version", domain_model.workflow_version),
|
||||
("graph", json.dumps(domain_model.graph, ensure_ascii=False) if domain_model.graph else "{}"),
|
||||
("inputs", json.dumps(domain_model.inputs, ensure_ascii=False) if domain_model.inputs else "{}"),
|
||||
("outputs", json.dumps(domain_model.outputs, ensure_ascii=False) if domain_model.outputs else "{}"),
|
||||
(
|
||||
"graph",
|
||||
json.dumps(domain_model.graph, ensure_ascii=False, default=to_serializable)
|
||||
if domain_model.graph
|
||||
else "{}",
|
||||
),
|
||||
(
|
||||
"inputs",
|
||||
json.dumps(domain_model.inputs, ensure_ascii=False, default=to_serializable)
|
||||
if domain_model.inputs
|
||||
else "{}",
|
||||
),
|
||||
(
|
||||
"outputs",
|
||||
json.dumps(domain_model.outputs, ensure_ascii=False, default=to_serializable)
|
||||
if domain_model.outputs
|
||||
else "{}",
|
||||
),
|
||||
("status", domain_model.status.value),
|
||||
("error_message", domain_model.error_message or ""),
|
||||
("total_tokens", str(domain_model.total_tokens)),
|
||||
|
|
|
|||
|
|
@ -19,26 +19,43 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class ExceptionLoggingHandler(logging.Handler):
|
||||
"""
|
||||
Handler that records exceptions to the current OpenTelemetry span.
|
||||
|
||||
Unlike creating a new span, this records exceptions on the existing span
|
||||
to maintain trace context consistency throughout the request lifecycle.
|
||||
"""
|
||||
|
||||
def emit(self, record: logging.LogRecord):
|
||||
with contextlib.suppress(Exception):
|
||||
if record.exc_info:
|
||||
tracer = get_tracer_provider().get_tracer("dify.exception.logging")
|
||||
with tracer.start_as_current_span(
|
||||
"log.exception",
|
||||
attributes={
|
||||
"log.level": record.levelname,
|
||||
"log.message": record.getMessage(),
|
||||
"log.logger": record.name,
|
||||
"log.file.path": record.pathname,
|
||||
"log.file.line": record.lineno,
|
||||
},
|
||||
) as span:
|
||||
span.set_status(StatusCode.ERROR)
|
||||
if record.exc_info[1]:
|
||||
span.record_exception(record.exc_info[1])
|
||||
span.set_attribute("exception.message", str(record.exc_info[1]))
|
||||
if record.exc_info[0]:
|
||||
span.set_attribute("exception.type", record.exc_info[0].__name__)
|
||||
if not record.exc_info:
|
||||
return
|
||||
|
||||
from opentelemetry.trace import get_current_span
|
||||
|
||||
span = get_current_span()
|
||||
if not span or not span.is_recording():
|
||||
return
|
||||
|
||||
# Record exception on the current span instead of creating a new one
|
||||
span.set_status(StatusCode.ERROR, record.getMessage())
|
||||
|
||||
# Add log context as span events/attributes
|
||||
span.add_event(
|
||||
"log.exception",
|
||||
attributes={
|
||||
"log.level": record.levelname,
|
||||
"log.message": record.getMessage(),
|
||||
"log.logger": record.name,
|
||||
"log.file.path": record.pathname,
|
||||
"log.file.line": record.lineno,
|
||||
},
|
||||
)
|
||||
|
||||
if record.exc_info[1]:
|
||||
span.record_exception(record.exc_info[1])
|
||||
if record.exc_info[0]:
|
||||
span.set_attribute("exception.type", record.exc_info[0].__name__)
|
||||
|
||||
|
||||
def instrument_exception_logging() -> None:
|
||||
|
|
|
|||
|
|
@ -1,236 +1,338 @@
|
|||
from flask_restx import Namespace, fields
|
||||
from __future__ import annotations
|
||||
|
||||
from fields.member_fields import simple_account_fields
|
||||
from libs.helper import TimestampField
|
||||
from datetime import datetime
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from .raws import FilesContainedField
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
from core.file import File
|
||||
|
||||
JSONValue: TypeAlias = Any
|
||||
|
||||
|
||||
class MessageTextField(fields.Raw):
|
||||
def format(self, value):
|
||||
return value[0]["text"] if value else ""
|
||||
class ResponseModel(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
extra="ignore",
|
||||
populate_by_name=True,
|
||||
serialize_by_alias=True,
|
||||
protected_namespaces=(),
|
||||
)
|
||||
|
||||
|
||||
feedback_fields = {
|
||||
"rating": fields.String,
|
||||
"content": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account": fields.Nested(simple_account_fields, allow_null=True),
|
||||
}
|
||||
class MessageFile(ResponseModel):
|
||||
id: str
|
||||
filename: str
|
||||
type: str
|
||||
url: str | None = None
|
||||
mime_type: str | None = None
|
||||
size: int | None = None
|
||||
transfer_method: str
|
||||
belongs_to: str | None = None
|
||||
upload_file_id: str | None = None
|
||||
|
||||
annotation_fields = {
|
||||
"id": fields.String,
|
||||
"question": fields.String,
|
||||
"content": fields.String,
|
||||
"account": fields.Nested(simple_account_fields, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
annotation_hit_history_fields = {
|
||||
"annotation_id": fields.String(attribute="id"),
|
||||
"annotation_create_account": fields.Nested(simple_account_fields, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
message_file_fields = {
|
||||
"id": fields.String,
|
||||
"filename": fields.String,
|
||||
"type": fields.String,
|
||||
"url": fields.String,
|
||||
"mime_type": fields.String,
|
||||
"size": fields.Integer,
|
||||
"transfer_method": fields.String,
|
||||
"belongs_to": fields.String(default="user"),
|
||||
"upload_file_id": fields.String(default=None),
|
||||
}
|
||||
@field_validator("transfer_method", mode="before")
|
||||
@classmethod
|
||||
def _normalize_transfer_method(cls, value: object) -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
|
||||
def build_message_file_model(api_or_ns: Namespace):
|
||||
"""Build the message file fields for the API or Namespace."""
|
||||
return api_or_ns.model("MessageFile", message_file_fields)
|
||||
class SimpleConversation(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
inputs: dict[str, JSONValue]
|
||||
status: str
|
||||
introduction: str | None = None
|
||||
created_at: int | None = None
|
||||
updated_at: int | None = None
|
||||
|
||||
@field_validator("inputs", mode="before")
|
||||
@classmethod
|
||||
def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
|
||||
return format_files_contained(value)
|
||||
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return to_timestamp(value)
|
||||
return value
|
||||
|
||||
|
||||
agent_thought_fields = {
|
||||
"id": fields.String,
|
||||
"chain_id": fields.String,
|
||||
"message_id": fields.String,
|
||||
"position": fields.Integer,
|
||||
"thought": fields.String,
|
||||
"tool": fields.String,
|
||||
"tool_labels": fields.Raw,
|
||||
"tool_input": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"observation": fields.String,
|
||||
"files": fields.List(fields.String),
|
||||
}
|
||||
|
||||
message_detail_fields = {
|
||||
"id": fields.String,
|
||||
"conversation_id": fields.String,
|
||||
"inputs": FilesContainedField,
|
||||
"query": fields.String,
|
||||
"message": fields.Raw,
|
||||
"message_tokens": fields.Integer,
|
||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||
"answer_tokens": fields.Integer,
|
||||
"provider_response_latency": fields.Float,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"feedbacks": fields.List(fields.Nested(feedback_fields)),
|
||||
"workflow_run_id": fields.String,
|
||||
"annotation": fields.Nested(annotation_fields, allow_null=True),
|
||||
"annotation_hit_history": fields.Nested(annotation_hit_history_fields, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
|
||||
"message_files": fields.List(fields.Nested(message_file_fields)),
|
||||
"metadata": fields.Raw(attribute="message_metadata_dict"),
|
||||
"status": fields.String,
|
||||
"error": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
}
|
||||
|
||||
feedback_stat_fields = {"like": fields.Integer, "dislike": fields.Integer}
|
||||
status_count_fields = {"success": fields.Integer, "failed": fields.Integer, "partial_success": fields.Integer}
|
||||
model_config_fields = {
|
||||
"opening_statement": fields.String,
|
||||
"suggested_questions": fields.Raw,
|
||||
"model": fields.Raw,
|
||||
"user_input_form": fields.Raw,
|
||||
"pre_prompt": fields.String,
|
||||
"agent_mode": fields.Raw,
|
||||
}
|
||||
|
||||
simple_model_config_fields = {
|
||||
"model": fields.Raw(attribute="model_dict"),
|
||||
"pre_prompt": fields.String,
|
||||
}
|
||||
|
||||
simple_message_detail_fields = {
|
||||
"inputs": FilesContainedField,
|
||||
"query": fields.String,
|
||||
"message": MessageTextField,
|
||||
"answer": fields.String,
|
||||
}
|
||||
|
||||
conversation_fields = {
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_end_user_session_id": fields.String(),
|
||||
"from_account_id": fields.String,
|
||||
"from_account_name": fields.String,
|
||||
"read_at": TimestampField,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"annotation": fields.Nested(annotation_fields, allow_null=True),
|
||||
"model_config": fields.Nested(simple_model_config_fields),
|
||||
"user_feedback_stats": fields.Nested(feedback_stat_fields),
|
||||
"admin_feedback_stats": fields.Nested(feedback_stat_fields),
|
||||
"message": fields.Nested(simple_message_detail_fields, attribute="first_message"),
|
||||
}
|
||||
|
||||
conversation_pagination_fields = {
|
||||
"page": fields.Integer,
|
||||
"limit": fields.Integer(attribute="per_page"),
|
||||
"total": fields.Integer,
|
||||
"has_more": fields.Boolean(attribute="has_next"),
|
||||
"data": fields.List(fields.Nested(conversation_fields), attribute="items"),
|
||||
}
|
||||
|
||||
conversation_message_detail_fields = {
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"model_config": fields.Nested(model_config_fields),
|
||||
"message": fields.Nested(message_detail_fields, attribute="first_message"),
|
||||
}
|
||||
|
||||
conversation_with_summary_fields = {
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_end_user_session_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"from_account_name": fields.String,
|
||||
"name": fields.String,
|
||||
"summary": fields.String(attribute="summary_or_query"),
|
||||
"read_at": TimestampField,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"annotated": fields.Boolean,
|
||||
"model_config": fields.Nested(simple_model_config_fields),
|
||||
"message_count": fields.Integer,
|
||||
"user_feedback_stats": fields.Nested(feedback_stat_fields),
|
||||
"admin_feedback_stats": fields.Nested(feedback_stat_fields),
|
||||
"status_count": fields.Nested(status_count_fields),
|
||||
}
|
||||
|
||||
conversation_with_summary_pagination_fields = {
|
||||
"page": fields.Integer,
|
||||
"limit": fields.Integer(attribute="per_page"),
|
||||
"total": fields.Integer,
|
||||
"has_more": fields.Boolean(attribute="has_next"),
|
||||
"data": fields.List(fields.Nested(conversation_with_summary_fields), attribute="items"),
|
||||
}
|
||||
|
||||
conversation_detail_fields = {
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"annotated": fields.Boolean,
|
||||
"introduction": fields.String,
|
||||
"model_config": fields.Nested(model_config_fields),
|
||||
"message_count": fields.Integer,
|
||||
"user_feedback_stats": fields.Nested(feedback_stat_fields),
|
||||
"admin_feedback_stats": fields.Nested(feedback_stat_fields),
|
||||
}
|
||||
|
||||
simple_conversation_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"inputs": FilesContainedField,
|
||||
"status": fields.String,
|
||||
"introduction": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
}
|
||||
|
||||
conversation_delete_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
conversation_infinite_scroll_pagination_fields = {
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(simple_conversation_fields)),
|
||||
}
|
||||
class ConversationInfiniteScrollPagination(ResponseModel):
|
||||
limit: int
|
||||
has_more: bool
|
||||
data: list[SimpleConversation]
|
||||
|
||||
|
||||
def build_conversation_infinite_scroll_pagination_model(api_or_ns: Namespace):
|
||||
"""Build the conversation infinite scroll pagination model for the API or Namespace."""
|
||||
simple_conversation_model = build_simple_conversation_model(api_or_ns)
|
||||
|
||||
copied_fields = conversation_infinite_scroll_pagination_fields.copy()
|
||||
copied_fields["data"] = fields.List(fields.Nested(simple_conversation_model))
|
||||
return api_or_ns.model("ConversationInfiniteScrollPagination", copied_fields)
|
||||
class ConversationDelete(ResponseModel):
|
||||
result: str
|
||||
|
||||
|
||||
def build_conversation_delete_model(api_or_ns: Namespace):
|
||||
"""Build the conversation delete model for the API or Namespace."""
|
||||
return api_or_ns.model("ConversationDelete", conversation_delete_fields)
|
||||
class ResultResponse(ResponseModel):
|
||||
result: str
|
||||
|
||||
|
||||
def build_simple_conversation_model(api_or_ns: Namespace):
|
||||
"""Build the simple conversation model for the API or Namespace."""
|
||||
return api_or_ns.model("SimpleConversation", simple_conversation_fields)
|
||||
class SimpleAccount(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
email: str
|
||||
|
||||
|
||||
class Feedback(ResponseModel):
|
||||
rating: str
|
||||
content: str | None = None
|
||||
from_source: str
|
||||
from_end_user_id: str | None = None
|
||||
from_account: SimpleAccount | None = None
|
||||
|
||||
|
||||
class Annotation(ResponseModel):
|
||||
id: str
|
||||
question: str | None = None
|
||||
content: str
|
||||
account: SimpleAccount | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return to_timestamp(value)
|
||||
return value
|
||||
|
||||
|
||||
class AnnotationHitHistory(ResponseModel):
|
||||
annotation_id: str
|
||||
annotation_create_account: SimpleAccount | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return to_timestamp(value)
|
||||
return value
|
||||
|
||||
|
||||
class AgentThought(ResponseModel):
|
||||
id: str
|
||||
chain_id: str | None = None
|
||||
message_chain_id: str | None = Field(default=None, exclude=True, validation_alias="message_chain_id")
|
||||
message_id: str
|
||||
position: int
|
||||
thought: str | None = None
|
||||
tool: str | None = None
|
||||
tool_labels: JSONValue
|
||||
tool_input: str | None = None
|
||||
created_at: int | None = None
|
||||
observation: str | None = None
|
||||
files: list[str]
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return to_timestamp(value)
|
||||
return value
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _fallback_chain_id(self):
|
||||
if self.chain_id is None and self.message_chain_id:
|
||||
self.chain_id = self.message_chain_id
|
||||
return self
|
||||
|
||||
|
||||
class MessageDetail(ResponseModel):
|
||||
id: str
|
||||
conversation_id: str
|
||||
inputs: dict[str, JSONValue]
|
||||
query: str
|
||||
message: JSONValue
|
||||
message_tokens: int
|
||||
answer: str
|
||||
answer_tokens: int
|
||||
provider_response_latency: float
|
||||
from_source: str
|
||||
from_end_user_id: str | None = None
|
||||
from_account_id: str | None = None
|
||||
feedbacks: list[Feedback]
|
||||
workflow_run_id: str | None = None
|
||||
annotation: Annotation | None = None
|
||||
annotation_hit_history: AnnotationHitHistory | None = None
|
||||
created_at: int | None = None
|
||||
agent_thoughts: list[AgentThought]
|
||||
message_files: list[MessageFile]
|
||||
metadata: JSONValue
|
||||
status: str
|
||||
error: str | None = None
|
||||
parent_message_id: str | None = None
|
||||
|
||||
@field_validator("inputs", mode="before")
|
||||
@classmethod
|
||||
def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
|
||||
return format_files_contained(value)
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return to_timestamp(value)
|
||||
return value
|
||||
|
||||
|
||||
class FeedbackStat(ResponseModel):
|
||||
like: int
|
||||
dislike: int
|
||||
|
||||
|
||||
class StatusCount(ResponseModel):
|
||||
success: int
|
||||
failed: int
|
||||
partial_success: int
|
||||
|
||||
|
||||
class ModelConfig(ResponseModel):
|
||||
opening_statement: str | None = None
|
||||
suggested_questions: JSONValue | None = None
|
||||
model: JSONValue | None = None
|
||||
user_input_form: JSONValue | None = None
|
||||
pre_prompt: str | None = None
|
||||
agent_mode: JSONValue | None = None
|
||||
|
||||
|
||||
class SimpleModelConfig(ResponseModel):
|
||||
model: JSONValue | None = None
|
||||
pre_prompt: str | None = None
|
||||
|
||||
|
||||
class SimpleMessageDetail(ResponseModel):
|
||||
inputs: dict[str, JSONValue]
|
||||
query: str
|
||||
message: str
|
||||
answer: str
|
||||
|
||||
@field_validator("inputs", mode="before")
|
||||
@classmethod
|
||||
def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
|
||||
return format_files_contained(value)
|
||||
|
||||
|
||||
class Conversation(ResponseModel):
|
||||
id: str
|
||||
status: str
|
||||
from_source: str
|
||||
from_end_user_id: str | None = None
|
||||
from_end_user_session_id: str | None = None
|
||||
from_account_id: str | None = None
|
||||
from_account_name: str | None = None
|
||||
read_at: int | None = None
|
||||
created_at: int | None = None
|
||||
updated_at: int | None = None
|
||||
annotation: Annotation | None = None
|
||||
model_config_: SimpleModelConfig | None = Field(default=None, alias="model_config")
|
||||
user_feedback_stats: FeedbackStat | None = None
|
||||
admin_feedback_stats: FeedbackStat | None = None
|
||||
message: SimpleMessageDetail | None = None
|
||||
|
||||
|
||||
class ConversationPagination(ResponseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[Conversation]
|
||||
|
||||
|
||||
class ConversationMessageDetail(ResponseModel):
|
||||
id: str
|
||||
status: str
|
||||
from_source: str
|
||||
from_end_user_id: str | None = None
|
||||
from_account_id: str | None = None
|
||||
created_at: int | None = None
|
||||
model_config_: ModelConfig | None = Field(default=None, alias="model_config")
|
||||
message: MessageDetail | None = None
|
||||
|
||||
|
||||
class ConversationWithSummary(ResponseModel):
|
||||
id: str
|
||||
status: str
|
||||
from_source: str
|
||||
from_end_user_id: str | None = None
|
||||
from_end_user_session_id: str | None = None
|
||||
from_account_id: str | None = None
|
||||
from_account_name: str | None = None
|
||||
name: str
|
||||
summary: str
|
||||
read_at: int | None = None
|
||||
created_at: int | None = None
|
||||
updated_at: int | None = None
|
||||
annotated: bool
|
||||
model_config_: SimpleModelConfig | None = Field(default=None, alias="model_config")
|
||||
message_count: int
|
||||
user_feedback_stats: FeedbackStat | None = None
|
||||
admin_feedback_stats: FeedbackStat | None = None
|
||||
status_count: StatusCount | None = None
|
||||
|
||||
|
||||
class ConversationWithSummaryPagination(ResponseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[ConversationWithSummary]
|
||||
|
||||
|
||||
class ConversationDetail(ResponseModel):
|
||||
id: str
|
||||
status: str
|
||||
from_source: str
|
||||
from_end_user_id: str | None = None
|
||||
from_account_id: str | None = None
|
||||
created_at: int | None = None
|
||||
updated_at: int | None = None
|
||||
annotated: bool
|
||||
introduction: str | None = None
|
||||
model_config_: ModelConfig | None = Field(default=None, alias="model_config")
|
||||
message_count: int
|
||||
user_feedback_stats: FeedbackStat | None = None
|
||||
admin_feedback_stats: FeedbackStat | None = None
|
||||
|
||||
|
||||
def to_timestamp(value: datetime | None) -> int | None:
|
||||
if value is None:
|
||||
return None
|
||||
return int(value.timestamp())
|
||||
|
||||
|
||||
def format_files_contained(value: JSONValue) -> JSONValue:
|
||||
if isinstance(value, File):
|
||||
return value.model_dump()
|
||||
if isinstance(value, dict):
|
||||
return {k: format_files_contained(v) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [format_files_contained(v) for v in value]
|
||||
return value
|
||||
|
||||
|
||||
def message_text(value: JSONValue) -> str:
|
||||
if isinstance(value, list) and value:
|
||||
first = value[0]
|
||||
if isinstance(first, dict):
|
||||
text = first.get("text")
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
return ""
|
||||
|
||||
|
||||
def extract_model_config(value: object | None) -> dict[str, JSONValue]:
|
||||
if value is None:
|
||||
return {}
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
if hasattr(value, "to_dict"):
|
||||
return value.to_dict()
|
||||
return {}
|
||||
|
|
|
|||
|
|
@ -1,77 +1,137 @@
|
|||
from flask_restx import Namespace, fields
|
||||
from __future__ import annotations
|
||||
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from libs.helper import TimestampField
|
||||
from datetime import datetime
|
||||
from typing import TypeAlias
|
||||
|
||||
from .raws import FilesContainedField
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
feedback_fields = {
|
||||
"rating": fields.String,
|
||||
}
|
||||
from core.file import File
|
||||
from fields.conversation_fields import AgentThought, JSONValue, MessageFile
|
||||
|
||||
JSONValueType: TypeAlias = JSONValue
|
||||
|
||||
|
||||
def build_feedback_model(api_or_ns: Namespace):
|
||||
"""Build the feedback model for the API or Namespace."""
|
||||
return api_or_ns.model("Feedback", feedback_fields)
|
||||
class ResponseModel(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True, extra="ignore")
|
||||
|
||||
|
||||
agent_thought_fields = {
|
||||
"id": fields.String,
|
||||
"chain_id": fields.String,
|
||||
"message_id": fields.String,
|
||||
"position": fields.Integer,
|
||||
"thought": fields.String,
|
||||
"tool": fields.String,
|
||||
"tool_labels": fields.Raw,
|
||||
"tool_input": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"observation": fields.String,
|
||||
"files": fields.List(fields.String),
|
||||
}
|
||||
class SimpleFeedback(ResponseModel):
|
||||
rating: str | None = None
|
||||
|
||||
|
||||
def build_agent_thought_model(api_or_ns: Namespace):
|
||||
"""Build the agent thought model for the API or Namespace."""
|
||||
return api_or_ns.model("AgentThought", agent_thought_fields)
|
||||
class RetrieverResource(ResponseModel):
|
||||
id: str
|
||||
message_id: str
|
||||
position: int
|
||||
dataset_id: str | None = None
|
||||
dataset_name: str | None = None
|
||||
document_id: str | None = None
|
||||
document_name: str | None = None
|
||||
data_source_type: str | None = None
|
||||
segment_id: str | None = None
|
||||
score: float | None = None
|
||||
hit_count: int | None = None
|
||||
word_count: int | None = None
|
||||
segment_position: int | None = None
|
||||
index_node_hash: str | None = None
|
||||
content: str | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return to_timestamp(value)
|
||||
return value
|
||||
|
||||
|
||||
retriever_resource_fields = {
|
||||
"id": fields.String,
|
||||
"message_id": fields.String,
|
||||
"position": fields.Integer,
|
||||
"dataset_id": fields.String,
|
||||
"dataset_name": fields.String,
|
||||
"document_id": fields.String,
|
||||
"document_name": fields.String,
|
||||
"data_source_type": fields.String,
|
||||
"segment_id": fields.String,
|
||||
"score": fields.Float,
|
||||
"hit_count": fields.Integer,
|
||||
"word_count": fields.Integer,
|
||||
"segment_position": fields.Integer,
|
||||
"index_node_hash": fields.String,
|
||||
"content": fields.String,
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
class MessageListItem(ResponseModel):
|
||||
id: str
|
||||
conversation_id: str
|
||||
parent_message_id: str | None = None
|
||||
inputs: dict[str, JSONValueType]
|
||||
query: str
|
||||
answer: str = Field(validation_alias="re_sign_file_url_answer")
|
||||
feedback: SimpleFeedback | None = Field(default=None, validation_alias="user_feedback")
|
||||
retriever_resources: list[RetrieverResource]
|
||||
created_at: int | None = None
|
||||
agent_thoughts: list[AgentThought]
|
||||
message_files: list[MessageFile]
|
||||
status: str
|
||||
error: str | None = None
|
||||
|
||||
message_fields = {
|
||||
"id": fields.String,
|
||||
"conversation_id": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
"inputs": FilesContainedField,
|
||||
"query": fields.String,
|
||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
||||
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
|
||||
"created_at": TimestampField,
|
||||
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
|
||||
"message_files": fields.List(fields.Nested(message_file_fields)),
|
||||
"status": fields.String,
|
||||
"error": fields.String,
|
||||
}
|
||||
@field_validator("inputs", mode="before")
|
||||
@classmethod
|
||||
def _normalize_inputs(cls, value: JSONValueType) -> JSONValueType:
|
||||
return format_files_contained(value)
|
||||
|
||||
message_infinite_scroll_pagination_fields = {
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(message_fields)),
|
||||
}
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return to_timestamp(value)
|
||||
return value
|
||||
|
||||
|
||||
class WebMessageListItem(MessageListItem):
|
||||
metadata: JSONValueType | None = Field(default=None, validation_alias="message_metadata_dict")
|
||||
|
||||
|
||||
class MessageInfiniteScrollPagination(ResponseModel):
|
||||
limit: int
|
||||
has_more: bool
|
||||
data: list[MessageListItem]
|
||||
|
||||
|
||||
class WebMessageInfiniteScrollPagination(ResponseModel):
|
||||
limit: int
|
||||
has_more: bool
|
||||
data: list[WebMessageListItem]
|
||||
|
||||
|
||||
class SavedMessageItem(ResponseModel):
|
||||
id: str
|
||||
inputs: dict[str, JSONValueType]
|
||||
query: str
|
||||
answer: str
|
||||
message_files: list[MessageFile]
|
||||
feedback: SimpleFeedback | None = Field(default=None, validation_alias="user_feedback")
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("inputs", mode="before")
|
||||
@classmethod
|
||||
def _normalize_inputs(cls, value: JSONValueType) -> JSONValueType:
|
||||
return format_files_contained(value)
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return to_timestamp(value)
|
||||
return value
|
||||
|
||||
|
||||
class SavedMessageInfiniteScrollPagination(ResponseModel):
|
||||
limit: int
|
||||
has_more: bool
|
||||
data: list[SavedMessageItem]
|
||||
|
||||
|
||||
class SuggestedQuestionsResponse(ResponseModel):
|
||||
data: list[str]
|
||||
|
||||
|
||||
def to_timestamp(value: datetime | None) -> int | None:
|
||||
if value is None:
|
||||
return None
|
||||
return int(value.timestamp())
|
||||
|
||||
|
||||
def format_files_contained(value: JSONValueType) -> JSONValueType:
|
||||
if isinstance(value, File):
|
||||
return value.model_dump()
|
||||
if isinstance(value, dict):
|
||||
return {k: format_files_contained(v) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [format_files_contained(v) for v in value]
|
||||
return value
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import re
|
||||
import sys
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -109,11 +108,8 @@ def register_external_error_handlers(api: Api):
|
|||
data.setdefault("code", "unknown")
|
||||
data.setdefault("status", status_code)
|
||||
|
||||
# Log stack
|
||||
exc_info: Any = sys.exc_info()
|
||||
if exc_info[1] is None:
|
||||
exc_info = (None, None, None)
|
||||
current_app.log_exception(exc_info)
|
||||
# Note: Exception logging is handled by Flask/Flask-RESTX framework automatically
|
||||
# Explicit log_exception call removed to avoid duplicate log entries
|
||||
|
||||
return data, status_code
|
||||
|
||||
|
|
|
|||
|
|
@ -853,7 +853,7 @@ class TriggerProviderService:
|
|||
"""
|
||||
Create a subscription builder for rebuilding an existing subscription.
|
||||
|
||||
This method creates a builder pre-filled with data from the rebuild request,
|
||||
This method rebuild the subscription by call DELETE and CREATE API of the third party provider(e.g. GitHub)
|
||||
keeping the same subscription_id and endpoint_id so the webhook URL remains unchanged.
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
|
|
@ -868,111 +868,50 @@ class TriggerProviderService:
|
|||
if not provider_controller:
|
||||
raise ValueError(f"Provider {provider_id} not found")
|
||||
|
||||
# Use distributed lock to prevent race conditions on the same subscription
|
||||
lock_key = f"trigger_subscription_rebuild_lock:{tenant_id}_{subscription_id}"
|
||||
with redis_client.lock(lock_key, timeout=20):
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
try:
|
||||
# Get subscription within the transaction
|
||||
subscription: TriggerSubscription | None = (
|
||||
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
|
||||
)
|
||||
if not subscription:
|
||||
raise ValueError(f"Subscription {subscription_id} not found")
|
||||
subscription = TriggerProviderService.get_subscription_by_id(
|
||||
tenant_id=tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
)
|
||||
if not subscription:
|
||||
raise ValueError(f"Subscription {subscription_id} not found")
|
||||
|
||||
credential_type = CredentialType.of(subscription.credential_type)
|
||||
if credential_type not in [CredentialType.OAUTH2, CredentialType.API_KEY]:
|
||||
raise ValueError("Credential type not supported for rebuild")
|
||||
credential_type = CredentialType.of(subscription.credential_type)
|
||||
if credential_type not in {CredentialType.OAUTH2, CredentialType.API_KEY}:
|
||||
raise ValueError(f"Credential type {credential_type} not supported for auto creation")
|
||||
|
||||
# Decrypt existing credentials for merging
|
||||
credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription(
|
||||
tenant_id=tenant_id,
|
||||
controller=provider_controller,
|
||||
subscription=subscription,
|
||||
)
|
||||
decrypted_credentials = dict(credential_encrypter.decrypt(subscription.credentials))
|
||||
# Delete the previous subscription
|
||||
user_id = subscription.user_id
|
||||
unsubscribe_result = TriggerManager.unsubscribe_trigger(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=provider_id,
|
||||
subscription=subscription.to_entity(),
|
||||
credentials=subscription.credentials,
|
||||
credential_type=credential_type,
|
||||
)
|
||||
if not unsubscribe_result.success:
|
||||
raise ValueError(f"Failed to delete previous subscription: {unsubscribe_result.message}")
|
||||
|
||||
# Merge credentials: if caller passed HIDDEN_VALUE, retain existing decrypted value
|
||||
merged_credentials: dict[str, Any] = {
|
||||
key: value if value != HIDDEN_VALUE else decrypted_credentials.get(key, UNKNOWN_VALUE)
|
||||
for key, value in credentials.items()
|
||||
}
|
||||
|
||||
user_id = subscription.user_id
|
||||
|
||||
# TODO: Trying to invoke update api of the plugin trigger provider
|
||||
|
||||
# FALLBACK: If the update api is not implemented,
|
||||
# delete the previous subscription and create a new one
|
||||
|
||||
# Unsubscribe the previous subscription (external call, but we'll handle errors)
|
||||
try:
|
||||
TriggerManager.unsubscribe_trigger(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=provider_id,
|
||||
subscription=subscription.to_entity(),
|
||||
credentials=decrypted_credentials,
|
||||
credential_type=credential_type,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error unsubscribing trigger during rebuild", exc_info=e)
|
||||
# Continue anyway - the subscription might already be deleted externally
|
||||
|
||||
# Create a new subscription with the same subscription_id and endpoint_id (external call)
|
||||
new_subscription: TriggerSubscriptionEntity = TriggerManager.subscribe_trigger(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=provider_id,
|
||||
endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id),
|
||||
parameters=parameters,
|
||||
credentials=merged_credentials,
|
||||
credential_type=credential_type,
|
||||
)
|
||||
|
||||
# Update the subscription in the same transaction
|
||||
# Inline update logic to reuse the same session
|
||||
if name is not None and name != subscription.name:
|
||||
existing = (
|
||||
session.query(TriggerSubscription)
|
||||
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
|
||||
.first()
|
||||
)
|
||||
if existing and existing.id != subscription.id:
|
||||
raise ValueError(f"Subscription name '{name}' already exists for this provider")
|
||||
subscription.name = name
|
||||
|
||||
# Update parameters
|
||||
subscription.parameters = dict(parameters)
|
||||
|
||||
# Update credentials with merged (and encrypted) values
|
||||
subscription.credentials = dict(credential_encrypter.encrypt(merged_credentials))
|
||||
|
||||
# Update properties
|
||||
if new_subscription.properties:
|
||||
properties_encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_properties_schema(),
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
subscription.properties = dict(properties_encrypter.encrypt(dict(new_subscription.properties)))
|
||||
|
||||
# Update expiration timestamp
|
||||
if new_subscription.expires_at is not None:
|
||||
subscription.expires_at = new_subscription.expires_at
|
||||
|
||||
# Commit the transaction
|
||||
session.commit()
|
||||
|
||||
# Clear subscription cache
|
||||
delete_cache_for_subscription(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=subscription.provider_id,
|
||||
subscription_id=subscription.id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# Rollback on any error
|
||||
session.rollback()
|
||||
logger.exception("Failed to rebuild trigger subscription", exc_info=e)
|
||||
raise
|
||||
# Create a new subscription with the same subscription_id and endpoint_id
|
||||
new_credentials: dict[str, Any] = {
|
||||
key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE)
|
||||
for key, value in credentials.items()
|
||||
}
|
||||
new_subscription: TriggerSubscriptionEntity = TriggerManager.subscribe_trigger(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=provider_id,
|
||||
endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id),
|
||||
parameters=parameters,
|
||||
credentials=new_credentials,
|
||||
credential_type=credential_type,
|
||||
)
|
||||
TriggerProviderService.update_trigger_subscription(
|
||||
tenant_id=tenant_id,
|
||||
subscription_id=subscription.id,
|
||||
name=name,
|
||||
parameters=parameters,
|
||||
credentials=new_credentials,
|
||||
properties=new_subscription.properties,
|
||||
expires_at=new_subscription.expires_at,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from core.workflow.enums import WorkflowNodeExecutionStatus
|
|||
from core.workflow.graph import Graph
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
|
@ -67,6 +68,16 @@ def init_code_node(code_config: dict):
|
|||
config=code_config,
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
code_limits=CodeNodeLimits(
|
||||
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
|
||||
max_number=dify_config.CODE_MAX_NUMBER,
|
||||
min_number=dify_config.CODE_MIN_NUMBER,
|
||||
max_precision=dify_config.CODE_MAX_PRECISION,
|
||||
max_depth=dify_config.CODE_MAX_DEPTH,
|
||||
max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH,
|
||||
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
|
||||
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
|
||||
),
|
||||
)
|
||||
|
||||
return node
|
||||
|
|
|
|||
|
|
@ -474,64 +474,6 @@ class TestTriggerProviderService:
|
|||
assert subscription.name == original_name
|
||||
assert subscription.parameters == original_parameters
|
||||
|
||||
def test_rebuild_trigger_subscription_unsubscribe_error_continues(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test that unsubscribe errors are handled gracefully and operation continues.
|
||||
|
||||
This test verifies:
|
||||
- Unsubscribe errors are caught and logged but don't stop the rebuild
|
||||
- Rebuild continues even if unsubscribe fails
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
|
||||
original_credentials = {"api_key": "original-key"}
|
||||
subscription = self._create_test_subscription(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
account.id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
original_credentials,
|
||||
mock_external_service_dependencies,
|
||||
)
|
||||
|
||||
# Make unsubscribe_trigger raise an error (should be caught and continue)
|
||||
mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.side_effect = ValueError(
|
||||
"Unsubscribe failed"
|
||||
)
|
||||
|
||||
new_subscription_entity = TriggerSubscriptionEntity(
|
||||
endpoint=subscription.endpoint_id,
|
||||
parameters={},
|
||||
properties={},
|
||||
expires_at=-1,
|
||||
)
|
||||
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity
|
||||
|
||||
# Execute rebuild - should succeed despite unsubscribe error
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription.id,
|
||||
credentials={"api_key": "new-key"},
|
||||
parameters={},
|
||||
)
|
||||
|
||||
# Verify subscribe was still called (operation continued)
|
||||
mock_external_service_dependencies["trigger_manager"].subscribe_trigger.assert_called_once()
|
||||
|
||||
# Verify subscription was updated
|
||||
db.session.refresh(subscription)
|
||||
assert subscription.parameters == {}
|
||||
|
||||
def test_rebuild_trigger_subscription_subscription_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
|
|
@ -558,70 +500,6 @@ class TestTriggerProviderService:
|
|||
parameters={},
|
||||
)
|
||||
|
||||
def test_rebuild_trigger_subscription_provider_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error when provider is not found.
|
||||
|
||||
This test verifies:
|
||||
- Proper error is raised when provider doesn't exist
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("non_existent_org/non_existent_plugin/non_existent_provider")
|
||||
|
||||
# Make get_trigger_provider return None
|
||||
mock_external_service_dependencies["trigger_manager"].get_trigger_provider.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="Provider.*not found"):
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=fake.uuid4(),
|
||||
credentials={},
|
||||
parameters={},
|
||||
)
|
||||
|
||||
def test_rebuild_trigger_subscription_unsupported_credential_type(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error when credential type is not supported for rebuild.
|
||||
|
||||
This test verifies:
|
||||
- Proper error is raised for unsupported credential types (not OAUTH2 or API_KEY)
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.UNAUTHORIZED # Not supported
|
||||
|
||||
subscription = self._create_test_subscription(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
account.id,
|
||||
provider_id,
|
||||
credential_type,
|
||||
{},
|
||||
mock_external_service_dependencies,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Credential type not supported for rebuild"):
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription.id,
|
||||
credentials={},
|
||||
parameters={},
|
||||
)
|
||||
|
||||
def test_rebuild_trigger_subscription_name_uniqueness_check(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,145 @@
|
|||
"""
|
||||
Test for document detail API data_source_info serialization fix.
|
||||
|
||||
This test verifies that the document detail API returns both data_source_info
|
||||
and data_source_detail_dict for all data_source_type values, including "local_file".
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Generic, Literal, NotRequired, TypedDict, TypeVar, Union
|
||||
|
||||
from models.dataset import Document
|
||||
|
||||
|
||||
class LocalFileInfo(TypedDict):
|
||||
file_path: str
|
||||
size: int
|
||||
created_at: NotRequired[str]
|
||||
|
||||
|
||||
class UploadFileInfo(TypedDict):
|
||||
upload_file_id: str
|
||||
|
||||
|
||||
class NotionImportInfo(TypedDict):
|
||||
notion_page_id: str
|
||||
workspace_id: str
|
||||
|
||||
|
||||
class WebsiteCrawlInfo(TypedDict):
|
||||
url: str
|
||||
job_id: str
|
||||
|
||||
|
||||
RawInfo = Union[LocalFileInfo, UploadFileInfo, NotionImportInfo, WebsiteCrawlInfo]
|
||||
T_type = TypeVar("T_type", bound=str)
|
||||
T_info = TypeVar("T_info", bound=Union[LocalFileInfo, UploadFileInfo, NotionImportInfo, WebsiteCrawlInfo])
|
||||
|
||||
|
||||
class Case(TypedDict, Generic[T_type, T_info]):
|
||||
data_source_type: T_type
|
||||
data_source_info: str
|
||||
expected_raw: T_info
|
||||
|
||||
|
||||
LocalFileCase = Case[Literal["local_file"], LocalFileInfo]
|
||||
UploadFileCase = Case[Literal["upload_file"], UploadFileInfo]
|
||||
NotionImportCase = Case[Literal["notion_import"], NotionImportInfo]
|
||||
WebsiteCrawlCase = Case[Literal["website_crawl"], WebsiteCrawlInfo]
|
||||
|
||||
AnyCase = Union[LocalFileCase, UploadFileCase, NotionImportCase, WebsiteCrawlCase]
|
||||
|
||||
|
||||
case_1: LocalFileCase = {
|
||||
"data_source_type": "local_file",
|
||||
"data_source_info": json.dumps({"file_path": "/tmp/test.txt", "size": 1024}),
|
||||
"expected_raw": {"file_path": "/tmp/test.txt", "size": 1024},
|
||||
}
|
||||
|
||||
|
||||
# ERROR: Expected LocalFileInfo, but got WebsiteCrawlInfo
|
||||
case_2: LocalFileCase = {
|
||||
"data_source_type": "local_file",
|
||||
"data_source_info": "...",
|
||||
"expected_raw": {"file_path": "https://google.com", "size": 123},
|
||||
}
|
||||
|
||||
cases: list[AnyCase] = [case_1]
|
||||
|
||||
|
||||
class TestDocumentDetailDataSourceInfo:
|
||||
"""Test cases for document detail API data_source_info serialization."""
|
||||
|
||||
def test_data_source_info_dict_returns_raw_data(self):
|
||||
"""Test that data_source_info_dict returns raw JSON data for all data_source_type values."""
|
||||
# Test data for different data_source_type values
|
||||
for case in cases:
|
||||
document = Document(
|
||||
data_source_type=case["data_source_type"],
|
||||
data_source_info=case["data_source_info"],
|
||||
)
|
||||
|
||||
# Test data_source_info_dict (raw data)
|
||||
raw_result = document.data_source_info_dict
|
||||
assert raw_result == case["expected_raw"], f"Failed for {case['data_source_type']}"
|
||||
|
||||
# Verify raw_result is always a valid dict
|
||||
assert isinstance(raw_result, dict)
|
||||
|
||||
def test_local_file_data_source_info_without_db_context(self):
|
||||
"""Test that local_file type data_source_info_dict works without database context."""
|
||||
test_data: LocalFileInfo = {
|
||||
"file_path": "/local/path/document.txt",
|
||||
"size": 512,
|
||||
"created_at": "2024-01-01T00:00:00Z",
|
||||
}
|
||||
|
||||
document = Document(
|
||||
data_source_type="local_file",
|
||||
data_source_info=json.dumps(test_data),
|
||||
)
|
||||
|
||||
# data_source_info_dict should return the raw data (this doesn't need DB context)
|
||||
raw_data = document.data_source_info_dict
|
||||
assert raw_data == test_data
|
||||
assert isinstance(raw_data, dict)
|
||||
|
||||
# Verify the data contains expected keys for pipeline mode
|
||||
assert "file_path" in raw_data
|
||||
assert "size" in raw_data
|
||||
|
||||
def test_notion_and_website_crawl_data_source_detail(self):
|
||||
"""Test that notion_import and website_crawl return raw data in data_source_detail_dict."""
|
||||
# Test notion_import
|
||||
notion_data: NotionImportInfo = {"notion_page_id": "page-123", "workspace_id": "ws-456"}
|
||||
document = Document(
|
||||
data_source_type="notion_import",
|
||||
data_source_info=json.dumps(notion_data),
|
||||
)
|
||||
|
||||
# data_source_detail_dict should return raw data for notion_import
|
||||
detail_result = document.data_source_detail_dict
|
||||
assert detail_result == notion_data
|
||||
|
||||
# Test website_crawl
|
||||
website_data: WebsiteCrawlInfo = {"url": "https://example.com", "job_id": "job-789"}
|
||||
document = Document(
|
||||
data_source_type="website_crawl",
|
||||
data_source_info=json.dumps(website_data),
|
||||
)
|
||||
|
||||
# data_source_detail_dict should return raw data for website_crawl
|
||||
detail_result = document.data_source_detail_dict
|
||||
assert detail_result == website_data
|
||||
|
||||
def test_local_file_data_source_detail_dict_without_db(self):
|
||||
"""Test that local_file returns empty data_source_detail_dict (this doesn't need DB context)."""
|
||||
# Test local_file - this should work without database context since it returns {} early
|
||||
document = Document(
|
||||
data_source_type="local_file",
|
||||
data_source_info=json.dumps({"file_path": "/tmp/test.txt"}),
|
||||
)
|
||||
|
||||
# Should return empty dict for local_file type (handled in the model)
|
||||
detail_result = document.data_source_detail_dict
|
||||
assert detail_result == {}
|
||||
|
|
@ -0,0 +1,174 @@
|
|||
"""Unit tests for controllers.web.message message list mapping."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
from datetime import datetime
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
# Ensure flask_restx.api finds MethodView during import.
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _load_controller_module():
|
||||
"""Import controllers.web.message using a stub package."""
|
||||
|
||||
import importlib
|
||||
import importlib.util
|
||||
import sys
|
||||
|
||||
parent_module_name = "controllers.web"
|
||||
module_name = f"{parent_module_name}.message"
|
||||
|
||||
if parent_module_name not in sys.modules:
|
||||
from flask_restx import Namespace
|
||||
|
||||
stub = ModuleType(parent_module_name)
|
||||
stub.__file__ = "controllers/web/__init__.py"
|
||||
stub.__path__ = ["controllers/web"]
|
||||
stub.__package__ = "controllers"
|
||||
stub.__spec__ = importlib.util.spec_from_loader(parent_module_name, loader=None, is_package=True)
|
||||
stub.web_ns = Namespace("web", description="Web API", path="/")
|
||||
sys.modules[parent_module_name] = stub
|
||||
|
||||
wraps_module_name = f"{parent_module_name}.wraps"
|
||||
if wraps_module_name not in sys.modules:
|
||||
wraps_stub = ModuleType(wraps_module_name)
|
||||
|
||||
class WebApiResource:
|
||||
pass
|
||||
|
||||
wraps_stub.WebApiResource = WebApiResource
|
||||
sys.modules[wraps_module_name] = wraps_stub
|
||||
|
||||
return importlib.import_module(module_name)
|
||||
|
||||
|
||||
message_module = _load_controller_module()
|
||||
MessageListApi = message_module.MessageListApi
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
def test_message_list_mapping(app: Flask) -> None:
|
||||
conversation_id = str(uuid4())
|
||||
message_id = str(uuid4())
|
||||
|
||||
created_at = datetime(2024, 1, 1, 12, 0, 0)
|
||||
resource_created_at = datetime(2024, 1, 1, 13, 0, 0)
|
||||
thought_created_at = datetime(2024, 1, 1, 14, 0, 0)
|
||||
|
||||
retriever_resource_obj = SimpleNamespace(
|
||||
id="res-obj",
|
||||
message_id=message_id,
|
||||
position=2,
|
||||
dataset_id="ds-1",
|
||||
dataset_name="dataset",
|
||||
document_id="doc-1",
|
||||
document_name="document",
|
||||
data_source_type="file",
|
||||
segment_id="seg-1",
|
||||
score=0.9,
|
||||
hit_count=1,
|
||||
word_count=10,
|
||||
segment_position=0,
|
||||
index_node_hash="hash",
|
||||
content="content",
|
||||
created_at=resource_created_at,
|
||||
)
|
||||
|
||||
agent_thought = SimpleNamespace(
|
||||
id="thought-1",
|
||||
chain_id=None,
|
||||
message_chain_id="chain-1",
|
||||
message_id=message_id,
|
||||
position=1,
|
||||
thought="thinking",
|
||||
tool="tool",
|
||||
tool_labels={"label": "value"},
|
||||
tool_input="{}",
|
||||
created_at=thought_created_at,
|
||||
observation="observed",
|
||||
files=["file-a"],
|
||||
)
|
||||
|
||||
message_file_obj = SimpleNamespace(
|
||||
id="file-obj",
|
||||
filename="b.txt",
|
||||
type="file",
|
||||
url=None,
|
||||
mime_type=None,
|
||||
size=None,
|
||||
transfer_method="local",
|
||||
belongs_to=None,
|
||||
upload_file_id=None,
|
||||
)
|
||||
|
||||
message = SimpleNamespace(
|
||||
id=message_id,
|
||||
conversation_id=conversation_id,
|
||||
parent_message_id=None,
|
||||
inputs={"foo": "bar"},
|
||||
query="hello",
|
||||
re_sign_file_url_answer="answer",
|
||||
user_feedback=SimpleNamespace(rating="like"),
|
||||
retriever_resources=[
|
||||
{"id": "res-dict", "message_id": message_id, "position": 1},
|
||||
retriever_resource_obj,
|
||||
],
|
||||
created_at=created_at,
|
||||
agent_thoughts=[agent_thought],
|
||||
message_files=[
|
||||
{"id": "file-dict", "filename": "a.txt", "type": "file", "transfer_method": "local"},
|
||||
message_file_obj,
|
||||
],
|
||||
status="success",
|
||||
error=None,
|
||||
message_metadata_dict={"meta": "value"},
|
||||
)
|
||||
|
||||
pagination = SimpleNamespace(limit=20, has_more=False, data=[message])
|
||||
app_model = SimpleNamespace(mode="chat")
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with (
|
||||
patch.object(message_module.MessageService, "pagination_by_first_id", return_value=pagination) as mock_page,
|
||||
app.test_request_context(f"/messages?conversation_id={conversation_id}&limit=20"),
|
||||
):
|
||||
response = MessageListApi().get(app_model, end_user)
|
||||
|
||||
mock_page.assert_called_once_with(app_model, end_user, conversation_id, None, 20)
|
||||
assert response["limit"] == 20
|
||||
assert response["has_more"] is False
|
||||
assert len(response["data"]) == 1
|
||||
|
||||
item = response["data"][0]
|
||||
assert item["id"] == message_id
|
||||
assert item["conversation_id"] == conversation_id
|
||||
assert item["inputs"] == {"foo": "bar"}
|
||||
assert item["answer"] == "answer"
|
||||
assert item["feedback"]["rating"] == "like"
|
||||
assert item["metadata"] == {"meta": "value"}
|
||||
assert item["created_at"] == int(created_at.timestamp())
|
||||
|
||||
assert item["retriever_resources"][0]["id"] == "res-dict"
|
||||
assert item["retriever_resources"][1]["id"] == "res-obj"
|
||||
assert item["retriever_resources"][1]["created_at"] == int(resource_created_at.timestamp())
|
||||
|
||||
assert item["agent_thoughts"][0]["chain_id"] == "chain-1"
|
||||
assert item["agent_thoughts"][0]["created_at"] == int(thought_created_at.timestamp())
|
||||
|
||||
assert item["message_files"][0]["id"] == "file-dict"
|
||||
assert item["message_files"][1]["id"] == "file-obj"
|
||||
|
|
@ -14,12 +14,12 @@ def test_successful_request(mock_get_client):
|
|||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_client.send.return_value = mock_response
|
||||
mock_client.request.return_value = mock_response
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
response = make_request("GET", "http://example.com")
|
||||
assert response.status_code == 200
|
||||
mock_client.request.assert_called_once()
|
||||
|
||||
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||
|
|
@ -27,7 +27,6 @@ def test_retry_exceed_max_retries(mock_get_client):
|
|||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_client.send.return_value = mock_response
|
||||
mock_client.request.return_value = mock_response
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
|
|
@ -72,34 +71,12 @@ class TestGetUserProvidedHostHeader:
|
|||
assert result in ("first.com", "second.com")
|
||||
|
||||
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||
def test_host_header_preservation_without_user_header(mock_get_client):
|
||||
"""Test that when no Host header is provided, the default behavior is maintained."""
|
||||
mock_client = MagicMock()
|
||||
mock_request = MagicMock()
|
||||
mock_request.headers = {}
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_client.send.return_value = mock_response
|
||||
mock_client.request.return_value = mock_response
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
response = make_request("GET", "http://example.com")
|
||||
|
||||
assert response.status_code == 200
|
||||
# Host should not be set if not provided by user
|
||||
assert "Host" not in mock_request.headers or mock_request.headers.get("Host") is None
|
||||
|
||||
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||
def test_host_header_preservation_with_user_header(mock_get_client):
|
||||
"""Test that user-provided Host header is preserved in the request."""
|
||||
mock_client = MagicMock()
|
||||
mock_request = MagicMock()
|
||||
mock_request.headers = {}
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_client.send.return_value = mock_response
|
||||
mock_client.request.return_value = mock_response
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
|
|
@ -107,3 +84,93 @@ def test_host_header_preservation_with_user_header(mock_get_client):
|
|||
response = make_request("GET", "http://example.com", headers={"Host": custom_host})
|
||||
|
||||
assert response.status_code == 200
|
||||
# Verify client.request was called with the host header preserved (lowercase)
|
||||
call_kwargs = mock_client.request.call_args.kwargs
|
||||
assert call_kwargs["headers"]["host"] == custom_host
|
||||
|
||||
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||
@pytest.mark.parametrize("host_key", ["host", "HOST", "Host"])
|
||||
def test_host_header_preservation_case_insensitive(mock_get_client, host_key):
|
||||
"""Test that Host header is preserved regardless of case."""
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_client.request.return_value = mock_response
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
response = make_request("GET", "http://example.com", headers={host_key: "api.example.com"})
|
||||
|
||||
assert response.status_code == 200
|
||||
# Host header should be normalized to lowercase "host"
|
||||
call_kwargs = mock_client.request.call_args.kwargs
|
||||
assert call_kwargs["headers"]["host"] == "api.example.com"
|
||||
|
||||
|
||||
class TestFollowRedirectsParameter:
|
||||
"""Tests for follow_redirects parameter handling.
|
||||
|
||||
These tests verify that follow_redirects is correctly passed to client.request().
|
||||
"""
|
||||
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||
def test_follow_redirects_passed_to_request(self, mock_get_client):
|
||||
"""Verify follow_redirects IS passed to client.request()."""
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_client.request.return_value = mock_response
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
make_request("GET", "http://example.com", follow_redirects=True)
|
||||
|
||||
# Verify follow_redirects was passed to request
|
||||
call_kwargs = mock_client.request.call_args.kwargs
|
||||
assert call_kwargs.get("follow_redirects") is True
|
||||
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||
def test_allow_redirects_converted_to_follow_redirects(self, mock_get_client):
|
||||
"""Verify allow_redirects (requests-style) is converted to follow_redirects (httpx-style)."""
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_client.request.return_value = mock_response
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
# Use allow_redirects (requests-style parameter)
|
||||
make_request("GET", "http://example.com", allow_redirects=True)
|
||||
|
||||
# Verify it was converted to follow_redirects
|
||||
call_kwargs = mock_client.request.call_args.kwargs
|
||||
assert call_kwargs.get("follow_redirects") is True
|
||||
assert "allow_redirects" not in call_kwargs
|
||||
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||
def test_follow_redirects_not_set_when_not_specified(self, mock_get_client):
|
||||
"""Verify follow_redirects is not in kwargs when not specified (httpx default behavior)."""
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_client.request.return_value = mock_response
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
make_request("GET", "http://example.com")
|
||||
|
||||
# follow_redirects should not be in kwargs, letting httpx use its default
|
||||
call_kwargs = mock_client.request.call_args.kwargs
|
||||
assert "follow_redirects" not in call_kwargs
|
||||
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||
def test_follow_redirects_takes_precedence_over_allow_redirects(self, mock_get_client):
|
||||
"""Verify follow_redirects takes precedence when both are specified."""
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_client.request.return_value = mock_response
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
# Both specified - follow_redirects should take precedence
|
||||
make_request("GET", "http://example.com", allow_redirects=False, follow_redirects=True)
|
||||
|
||||
call_kwargs = mock_client.request.call_args.kwargs
|
||||
assert call_kwargs.get("follow_redirects") is True
|
||||
|
|
|
|||
|
|
@ -0,0 +1,79 @@
|
|||
"""Tests for logging context module."""
|
||||
|
||||
import uuid
|
||||
|
||||
from core.logging.context import (
|
||||
clear_request_context,
|
||||
get_request_id,
|
||||
get_trace_id,
|
||||
init_request_context,
|
||||
)
|
||||
|
||||
|
||||
class TestLoggingContext:
|
||||
"""Tests for the logging context functions."""
|
||||
|
||||
def test_init_creates_request_id(self):
|
||||
"""init_request_context should create a 10-char request ID."""
|
||||
init_request_context()
|
||||
request_id = get_request_id()
|
||||
assert len(request_id) == 10
|
||||
assert all(c in "0123456789abcdef" for c in request_id)
|
||||
|
||||
def test_init_creates_trace_id(self):
|
||||
"""init_request_context should create a 32-char trace ID."""
|
||||
init_request_context()
|
||||
trace_id = get_trace_id()
|
||||
assert len(trace_id) == 32
|
||||
assert all(c in "0123456789abcdef" for c in trace_id)
|
||||
|
||||
def test_trace_id_derived_from_request_id(self):
|
||||
"""trace_id should be deterministically derived from request_id."""
|
||||
init_request_context()
|
||||
request_id = get_request_id()
|
||||
trace_id = get_trace_id()
|
||||
|
||||
# Verify trace_id is derived using uuid5
|
||||
expected_trace = uuid.uuid5(uuid.NAMESPACE_DNS, request_id).hex
|
||||
assert trace_id == expected_trace
|
||||
|
||||
def test_clear_resets_context(self):
|
||||
"""clear_request_context should reset both IDs to empty strings."""
|
||||
init_request_context()
|
||||
assert get_request_id() != ""
|
||||
assert get_trace_id() != ""
|
||||
|
||||
clear_request_context()
|
||||
assert get_request_id() == ""
|
||||
assert get_trace_id() == ""
|
||||
|
||||
def test_default_values_are_empty(self):
|
||||
"""Default values should be empty strings before init."""
|
||||
clear_request_context()
|
||||
assert get_request_id() == ""
|
||||
assert get_trace_id() == ""
|
||||
|
||||
def test_multiple_inits_create_different_ids(self):
|
||||
"""Each init should create new unique IDs."""
|
||||
init_request_context()
|
||||
first_request_id = get_request_id()
|
||||
first_trace_id = get_trace_id()
|
||||
|
||||
init_request_context()
|
||||
second_request_id = get_request_id()
|
||||
second_trace_id = get_trace_id()
|
||||
|
||||
assert first_request_id != second_request_id
|
||||
assert first_trace_id != second_trace_id
|
||||
|
||||
def test_context_isolation(self):
|
||||
"""Context should be isolated per-call (no thread leakage in same thread)."""
|
||||
init_request_context()
|
||||
id1 = get_request_id()
|
||||
|
||||
# Simulate another request
|
||||
init_request_context()
|
||||
id2 = get_request_id()
|
||||
|
||||
# IDs should be different
|
||||
assert id1 != id2
|
||||
|
|
@ -0,0 +1,114 @@
|
|||
"""Tests for logging filters."""
|
||||
|
||||
import logging
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def log_record():
|
||||
return logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="",
|
||||
lineno=0,
|
||||
msg="test",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
|
||||
|
||||
class TestTraceContextFilter:
|
||||
def test_sets_empty_trace_id_without_context(self, log_record):
|
||||
from core.logging.context import clear_request_context
|
||||
from core.logging.filters import TraceContextFilter
|
||||
|
||||
# Ensure no context is set
|
||||
clear_request_context()
|
||||
|
||||
filter = TraceContextFilter()
|
||||
result = filter.filter(log_record)
|
||||
|
||||
assert result is True
|
||||
assert hasattr(log_record, "trace_id")
|
||||
assert hasattr(log_record, "span_id")
|
||||
assert hasattr(log_record, "req_id")
|
||||
# Without context, IDs should be empty
|
||||
assert log_record.trace_id == ""
|
||||
assert log_record.req_id == ""
|
||||
|
||||
def test_sets_trace_id_from_context(self, log_record):
|
||||
"""Test that trace_id and req_id are set from ContextVar when initialized."""
|
||||
from core.logging.context import init_request_context
|
||||
from core.logging.filters import TraceContextFilter
|
||||
|
||||
# Initialize context (no Flask needed!)
|
||||
init_request_context()
|
||||
|
||||
filter = TraceContextFilter()
|
||||
filter.filter(log_record)
|
||||
|
||||
# With context initialized, IDs should be set
|
||||
assert log_record.trace_id != ""
|
||||
assert len(log_record.trace_id) == 32
|
||||
assert log_record.req_id != ""
|
||||
assert len(log_record.req_id) == 10
|
||||
|
||||
def test_filter_always_returns_true(self, log_record):
|
||||
from core.logging.filters import TraceContextFilter
|
||||
|
||||
filter = TraceContextFilter()
|
||||
result = filter.filter(log_record)
|
||||
assert result is True
|
||||
|
||||
def test_sets_trace_id_from_otel_when_available(self, log_record):
|
||||
from core.logging.filters import TraceContextFilter
|
||||
|
||||
mock_span = mock.MagicMock()
|
||||
mock_context = mock.MagicMock()
|
||||
mock_context.trace_id = 0x5B8AA5A2D2C872E8321CF37308D69DF2
|
||||
mock_context.span_id = 0x051581BF3BB55C45
|
||||
mock_span.get_span_context.return_value = mock_context
|
||||
|
||||
with (
|
||||
mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span),
|
||||
mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0),
|
||||
mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0),
|
||||
):
|
||||
filter = TraceContextFilter()
|
||||
filter.filter(log_record)
|
||||
|
||||
assert log_record.trace_id == "5b8aa5a2d2c872e8321cf37308d69df2"
|
||||
assert log_record.span_id == "051581bf3bb55c45"
|
||||
|
||||
|
||||
class TestIdentityContextFilter:
|
||||
def test_sets_empty_identity_without_request_context(self, log_record):
|
||||
from core.logging.filters import IdentityContextFilter
|
||||
|
||||
filter = IdentityContextFilter()
|
||||
result = filter.filter(log_record)
|
||||
|
||||
assert result is True
|
||||
assert log_record.tenant_id == ""
|
||||
assert log_record.user_id == ""
|
||||
assert log_record.user_type == ""
|
||||
|
||||
def test_filter_always_returns_true(self, log_record):
|
||||
from core.logging.filters import IdentityContextFilter
|
||||
|
||||
filter = IdentityContextFilter()
|
||||
result = filter.filter(log_record)
|
||||
assert result is True
|
||||
|
||||
def test_handles_exception_gracefully(self, log_record):
|
||||
from core.logging.filters import IdentityContextFilter
|
||||
|
||||
filter = IdentityContextFilter()
|
||||
|
||||
# Should not raise even if something goes wrong
|
||||
with mock.patch("core.logging.filters.flask.has_request_context", side_effect=Exception("Test error")):
|
||||
result = filter.filter(log_record)
|
||||
assert result is True
|
||||
assert log_record.tenant_id == ""
|
||||
|
|
@ -0,0 +1,267 @@
|
|||
"""Tests for structured JSON formatter."""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import orjson
|
||||
|
||||
|
||||
class TestStructuredJSONFormatter:
|
||||
def test_basic_log_format(self):
|
||||
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||
|
||||
formatter = StructuredJSONFormatter(service_name="test-service")
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="test.py",
|
||||
lineno=42,
|
||||
msg="Test message",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
|
||||
output = formatter.format(record)
|
||||
log_dict = orjson.loads(output)
|
||||
|
||||
assert log_dict["severity"] == "INFO"
|
||||
assert log_dict["service"] == "test-service"
|
||||
assert log_dict["caller"] == "test.py:42"
|
||||
assert log_dict["message"] == "Test message"
|
||||
assert "ts" in log_dict
|
||||
assert log_dict["ts"].endswith("Z")
|
||||
|
||||
def test_severity_mapping(self):
|
||||
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||
|
||||
formatter = StructuredJSONFormatter()
|
||||
|
||||
test_cases = [
|
||||
(logging.DEBUG, "DEBUG"),
|
||||
(logging.INFO, "INFO"),
|
||||
(logging.WARNING, "WARN"),
|
||||
(logging.ERROR, "ERROR"),
|
||||
(logging.CRITICAL, "ERROR"),
|
||||
]
|
||||
|
||||
for level, expected_severity in test_cases:
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=level,
|
||||
pathname="test.py",
|
||||
lineno=1,
|
||||
msg="Test",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
output = formatter.format(record)
|
||||
log_dict = orjson.loads(output)
|
||||
assert log_dict["severity"] == expected_severity, f"Level {level} should map to {expected_severity}"
|
||||
|
||||
def test_error_with_stack_trace(self):
|
||||
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||
|
||||
formatter = StructuredJSONFormatter()
|
||||
|
||||
try:
|
||||
raise ValueError("Test error")
|
||||
except ValueError:
|
||||
exc_info = sys.exc_info()
|
||||
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.ERROR,
|
||||
pathname="test.py",
|
||||
lineno=10,
|
||||
msg="Error occurred",
|
||||
args=(),
|
||||
exc_info=exc_info,
|
||||
)
|
||||
|
||||
output = formatter.format(record)
|
||||
log_dict = orjson.loads(output)
|
||||
|
||||
assert log_dict["severity"] == "ERROR"
|
||||
assert "stack_trace" in log_dict
|
||||
assert "ValueError: Test error" in log_dict["stack_trace"]
|
||||
|
||||
def test_no_stack_trace_for_info(self):
|
||||
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||
|
||||
formatter = StructuredJSONFormatter()
|
||||
|
||||
try:
|
||||
raise ValueError("Test error")
|
||||
except ValueError:
|
||||
exc_info = sys.exc_info()
|
||||
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="test.py",
|
||||
lineno=10,
|
||||
msg="Info message",
|
||||
args=(),
|
||||
exc_info=exc_info,
|
||||
)
|
||||
|
||||
output = formatter.format(record)
|
||||
log_dict = orjson.loads(output)
|
||||
|
||||
assert "stack_trace" not in log_dict
|
||||
|
||||
def test_trace_context_included(self):
|
||||
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||
|
||||
formatter = StructuredJSONFormatter()
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="test.py",
|
||||
lineno=1,
|
||||
msg="Test",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
record.trace_id = "5b8aa5a2d2c872e8321cf37308d69df2"
|
||||
record.span_id = "051581bf3bb55c45"
|
||||
|
||||
output = formatter.format(record)
|
||||
log_dict = orjson.loads(output)
|
||||
|
||||
assert log_dict["trace_id"] == "5b8aa5a2d2c872e8321cf37308d69df2"
|
||||
assert log_dict["span_id"] == "051581bf3bb55c45"
|
||||
|
||||
def test_identity_context_included(self):
|
||||
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||
|
||||
formatter = StructuredJSONFormatter()
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="test.py",
|
||||
lineno=1,
|
||||
msg="Test",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
record.tenant_id = "t-global-corp"
|
||||
record.user_id = "u-admin-007"
|
||||
record.user_type = "admin"
|
||||
|
||||
output = formatter.format(record)
|
||||
log_dict = orjson.loads(output)
|
||||
|
||||
assert "identity" in log_dict
|
||||
assert log_dict["identity"]["tenant_id"] == "t-global-corp"
|
||||
assert log_dict["identity"]["user_id"] == "u-admin-007"
|
||||
assert log_dict["identity"]["user_type"] == "admin"
|
||||
|
||||
def test_no_identity_when_empty(self):
|
||||
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||
|
||||
formatter = StructuredJSONFormatter()
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="test.py",
|
||||
lineno=1,
|
||||
msg="Test",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
|
||||
output = formatter.format(record)
|
||||
log_dict = orjson.loads(output)
|
||||
|
||||
assert "identity" not in log_dict
|
||||
|
||||
def test_attributes_included(self):
|
||||
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||
|
||||
formatter = StructuredJSONFormatter()
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="test.py",
|
||||
lineno=1,
|
||||
msg="Test",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
record.attributes = {"order_id": "ord-123", "amount": 99.99}
|
||||
|
||||
output = formatter.format(record)
|
||||
log_dict = orjson.loads(output)
|
||||
|
||||
assert log_dict["attributes"]["order_id"] == "ord-123"
|
||||
assert log_dict["attributes"]["amount"] == 99.99
|
||||
|
||||
def test_message_with_args(self):
|
||||
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||
|
||||
formatter = StructuredJSONFormatter()
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="test.py",
|
||||
lineno=1,
|
||||
msg="User %s logged in from %s",
|
||||
args=("john", "192.168.1.1"),
|
||||
exc_info=None,
|
||||
)
|
||||
|
||||
output = formatter.format(record)
|
||||
log_dict = orjson.loads(output)
|
||||
|
||||
assert log_dict["message"] == "User john logged in from 192.168.1.1"
|
||||
|
||||
def test_timestamp_format(self):
|
||||
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||
|
||||
formatter = StructuredJSONFormatter()
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="test.py",
|
||||
lineno=1,
|
||||
msg="Test",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
|
||||
output = formatter.format(record)
|
||||
log_dict = orjson.loads(output)
|
||||
|
||||
# Verify ISO 8601 format with Z suffix
|
||||
ts = log_dict["ts"]
|
||||
assert ts.endswith("Z")
|
||||
assert "T" in ts
|
||||
# Should have milliseconds
|
||||
assert "." in ts
|
||||
|
||||
def test_fallback_for_non_serializable_attributes(self):
|
||||
from core.logging.structured_formatter import StructuredJSONFormatter
|
||||
|
||||
formatter = StructuredJSONFormatter()
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="test.py",
|
||||
lineno=1,
|
||||
msg="Test with non-serializable",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
# Set is not serializable by orjson
|
||||
record.attributes = {"items": {1, 2, 3}, "custom": object()}
|
||||
|
||||
# Should not raise, fallback to json.dumps with default=str
|
||||
output = formatter.format(record)
|
||||
|
||||
# Verify it's valid JSON (parsed by stdlib json since orjson may fail)
|
||||
import json
|
||||
|
||||
log_dict = json.loads(output)
|
||||
assert log_dict["message"] == "Test with non-serializable"
|
||||
assert "attributes" in log_dict
|
||||
|
|
@ -0,0 +1,102 @@
|
|||
"""Tests for trace helper functions."""
|
||||
|
||||
import re
|
||||
from unittest import mock
|
||||
|
||||
|
||||
class TestGetSpanIdFromOtelContext:
|
||||
def test_returns_none_without_span(self):
|
||||
from core.helper.trace_id_helper import get_span_id_from_otel_context
|
||||
|
||||
with mock.patch("opentelemetry.trace.get_current_span", return_value=None):
|
||||
result = get_span_id_from_otel_context()
|
||||
assert result is None
|
||||
|
||||
def test_returns_span_id_when_available(self):
|
||||
from core.helper.trace_id_helper import get_span_id_from_otel_context
|
||||
|
||||
mock_span = mock.MagicMock()
|
||||
mock_context = mock.MagicMock()
|
||||
mock_context.span_id = 0x051581BF3BB55C45
|
||||
mock_span.get_span_context.return_value = mock_context
|
||||
|
||||
with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span):
|
||||
with mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0):
|
||||
result = get_span_id_from_otel_context()
|
||||
assert result == "051581bf3bb55c45"
|
||||
|
||||
def test_returns_none_on_exception(self):
|
||||
from core.helper.trace_id_helper import get_span_id_from_otel_context
|
||||
|
||||
with mock.patch("opentelemetry.trace.get_current_span", side_effect=Exception("Test error")):
|
||||
result = get_span_id_from_otel_context()
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGenerateTraceparentHeader:
|
||||
def test_generates_valid_format(self):
|
||||
from core.helper.trace_id_helper import generate_traceparent_header
|
||||
|
||||
with mock.patch("opentelemetry.trace.get_current_span", return_value=None):
|
||||
result = generate_traceparent_header()
|
||||
|
||||
assert result is not None
|
||||
# Format: 00-{trace_id}-{span_id}-01
|
||||
parts = result.split("-")
|
||||
assert len(parts) == 4
|
||||
assert parts[0] == "00" # version
|
||||
assert len(parts[1]) == 32 # trace_id (32 hex chars)
|
||||
assert len(parts[2]) == 16 # span_id (16 hex chars)
|
||||
assert parts[3] == "01" # flags
|
||||
|
||||
def test_uses_otel_context_when_available(self):
|
||||
from core.helper.trace_id_helper import generate_traceparent_header
|
||||
|
||||
mock_span = mock.MagicMock()
|
||||
mock_context = mock.MagicMock()
|
||||
mock_context.trace_id = 0x5B8AA5A2D2C872E8321CF37308D69DF2
|
||||
mock_context.span_id = 0x051581BF3BB55C45
|
||||
mock_span.get_span_context.return_value = mock_context
|
||||
|
||||
with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span):
|
||||
with (
|
||||
mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0),
|
||||
mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0),
|
||||
):
|
||||
result = generate_traceparent_header()
|
||||
|
||||
assert result == "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01"
|
||||
|
||||
def test_generates_hex_only_values(self):
|
||||
from core.helper.trace_id_helper import generate_traceparent_header
|
||||
|
||||
with mock.patch("opentelemetry.trace.get_current_span", return_value=None):
|
||||
result = generate_traceparent_header()
|
||||
|
||||
parts = result.split("-")
|
||||
# All parts should be valid hex
|
||||
assert re.match(r"^[0-9a-f]+$", parts[1])
|
||||
assert re.match(r"^[0-9a-f]+$", parts[2])
|
||||
|
||||
|
||||
class TestParseTraceparentHeader:
|
||||
def test_parses_valid_traceparent(self):
|
||||
from core.helper.trace_id_helper import parse_traceparent_header
|
||||
|
||||
traceparent = "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01"
|
||||
result = parse_traceparent_header(traceparent)
|
||||
|
||||
assert result == "5b8aa5a2d2c872e8321cf37308d69df2"
|
||||
|
||||
def test_returns_none_for_invalid_format(self):
|
||||
from core.helper.trace_id_helper import parse_traceparent_header
|
||||
|
||||
# Wrong number of parts
|
||||
assert parse_traceparent_header("00-abc-def") is None
|
||||
# Wrong trace_id length
|
||||
assert parse_traceparent_header("00-abc-def-01") is None
|
||||
|
||||
def test_returns_none_for_empty_string(self):
|
||||
from core.helper.trace_id_helper import parse_traceparent_header
|
||||
|
||||
assert parse_traceparent_header("") is None
|
||||
|
|
@ -73,6 +73,7 @@ import pytest
|
|||
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
|
@ -1518,6 +1519,282 @@ class TestRetrievalService:
|
|||
call_kwargs = mock_retrieve.call_args.kwargs
|
||||
assert call_kwargs["reranking_model"] == reranking_model
|
||||
|
||||
# ==================== Multiple Retrieve Thread Tests ====================
|
||||
|
||||
@patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor")
|
||||
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval._retriever")
|
||||
def test_multiple_retrieve_thread_skips_second_reranking_with_single_dataset(
|
||||
self, mock_retriever, mock_data_processor_class, mock_flask_app, mock_dataset
|
||||
):
|
||||
"""
|
||||
Test that _multiple_retrieve_thread skips second reranking when dataset_count is 1.
|
||||
|
||||
When there is only one dataset, the second reranking is unnecessary
|
||||
because the documents are already ranked from the first retrieval.
|
||||
This optimization avoids the overhead of reranking when it won't
|
||||
provide any benefit.
|
||||
|
||||
Verifies:
|
||||
- DataPostProcessor is NOT called when dataset_count == 1
|
||||
- Documents are still added to all_documents
|
||||
- Standard scoring logic is applied instead
|
||||
"""
|
||||
# Arrange
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
tenant_id = str(uuid4())
|
||||
|
||||
# Create test documents
|
||||
doc1 = Document(
|
||||
page_content="Test content 1",
|
||||
metadata={"doc_id": "doc1", "score": 0.9, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
|
||||
provider="dify",
|
||||
)
|
||||
doc2 = Document(
|
||||
page_content="Test content 2",
|
||||
metadata={"doc_id": "doc2", "score": 0.8, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
|
||||
provider="dify",
|
||||
)
|
||||
|
||||
# Mock _retriever to return documents
|
||||
def side_effect_retriever(
|
||||
flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids
|
||||
):
|
||||
all_documents.extend([doc1, doc2])
|
||||
|
||||
mock_retriever.side_effect = side_effect_retriever
|
||||
|
||||
# Set up dataset with high_quality indexing
|
||||
mock_dataset.indexing_technique = "high_quality"
|
||||
|
||||
all_documents = []
|
||||
|
||||
# Act - Call with dataset_count = 1
|
||||
dataset_retrieval._multiple_retrieve_thread(
|
||||
flask_app=mock_flask_app,
|
||||
available_datasets=[mock_dataset],
|
||||
metadata_condition=None,
|
||||
metadata_filter_document_ids=None,
|
||||
all_documents=all_documents,
|
||||
tenant_id=tenant_id,
|
||||
reranking_enable=True,
|
||||
reranking_mode="reranking_model",
|
||||
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"},
|
||||
weights=None,
|
||||
top_k=5,
|
||||
score_threshold=0.5,
|
||||
query="test query",
|
||||
attachment_id=None,
|
||||
dataset_count=1, # Single dataset - should skip second reranking
|
||||
)
|
||||
|
||||
# Assert
|
||||
# DataPostProcessor should NOT be called (second reranking skipped)
|
||||
mock_data_processor_class.assert_not_called()
|
||||
|
||||
# Documents should still be added to all_documents
|
||||
assert len(all_documents) == 2
|
||||
assert all_documents[0].page_content == "Test content 1"
|
||||
assert all_documents[1].page_content == "Test content 2"
|
||||
|
||||
@patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor")
|
||||
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval._retriever")
|
||||
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.calculate_vector_score")
|
||||
def test_multiple_retrieve_thread_performs_second_reranking_with_multiple_datasets(
|
||||
self, mock_calculate_vector_score, mock_retriever, mock_data_processor_class, mock_flask_app, mock_dataset
|
||||
):
|
||||
"""
|
||||
Test that _multiple_retrieve_thread performs second reranking when dataset_count > 1.
|
||||
|
||||
When there are multiple datasets, the second reranking is necessary
|
||||
to merge and re-rank results from different datasets. This ensures
|
||||
the most relevant documents across all datasets are returned.
|
||||
|
||||
Verifies:
|
||||
- DataPostProcessor IS called when dataset_count > 1
|
||||
- Reranking is applied with correct parameters
|
||||
- Documents are processed correctly
|
||||
"""
|
||||
# Arrange
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
tenant_id = str(uuid4())
|
||||
|
||||
# Create test documents
|
||||
doc1 = Document(
|
||||
page_content="Test content 1",
|
||||
metadata={"doc_id": "doc1", "score": 0.7, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
|
||||
provider="dify",
|
||||
)
|
||||
doc2 = Document(
|
||||
page_content="Test content 2",
|
||||
metadata={"doc_id": "doc2", "score": 0.6, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
|
||||
provider="dify",
|
||||
)
|
||||
|
||||
# Mock _retriever to return documents
|
||||
def side_effect_retriever(
|
||||
flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids
|
||||
):
|
||||
all_documents.extend([doc1, doc2])
|
||||
|
||||
mock_retriever.side_effect = side_effect_retriever
|
||||
|
||||
# Set up dataset with high_quality indexing
|
||||
mock_dataset.indexing_technique = "high_quality"
|
||||
|
||||
# Mock DataPostProcessor instance and its invoke method
|
||||
mock_processor_instance = Mock()
|
||||
# Simulate reranking - return documents in reversed order with updated scores
|
||||
reranked_docs = [
|
||||
Document(
|
||||
page_content="Test content 2",
|
||||
metadata={"doc_id": "doc2", "score": 0.95, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
|
||||
provider="dify",
|
||||
),
|
||||
Document(
|
||||
page_content="Test content 1",
|
||||
metadata={"doc_id": "doc1", "score": 0.85, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
|
||||
provider="dify",
|
||||
),
|
||||
]
|
||||
mock_processor_instance.invoke.return_value = reranked_docs
|
||||
mock_data_processor_class.return_value = mock_processor_instance
|
||||
|
||||
all_documents = []
|
||||
|
||||
# Create second dataset
|
||||
mock_dataset2 = Mock(spec=Dataset)
|
||||
mock_dataset2.id = str(uuid4())
|
||||
mock_dataset2.indexing_technique = "high_quality"
|
||||
mock_dataset2.provider = "dify"
|
||||
|
||||
# Act - Call with dataset_count = 2
|
||||
dataset_retrieval._multiple_retrieve_thread(
|
||||
flask_app=mock_flask_app,
|
||||
available_datasets=[mock_dataset, mock_dataset2],
|
||||
metadata_condition=None,
|
||||
metadata_filter_document_ids=None,
|
||||
all_documents=all_documents,
|
||||
tenant_id=tenant_id,
|
||||
reranking_enable=True,
|
||||
reranking_mode="reranking_model",
|
||||
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"},
|
||||
weights=None,
|
||||
top_k=5,
|
||||
score_threshold=0.5,
|
||||
query="test query",
|
||||
attachment_id=None,
|
||||
dataset_count=2, # Multiple datasets - should perform second reranking
|
||||
)
|
||||
|
||||
# Assert
|
||||
# DataPostProcessor SHOULD be called (second reranking performed)
|
||||
mock_data_processor_class.assert_called_once_with(
|
||||
tenant_id,
|
||||
"reranking_model",
|
||||
{"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"},
|
||||
None,
|
||||
False,
|
||||
)
|
||||
|
||||
# Verify invoke was called with correct parameters
|
||||
mock_processor_instance.invoke.assert_called_once()
|
||||
|
||||
# Documents should be added to all_documents after reranking
|
||||
assert len(all_documents) == 2
|
||||
# The reranked order should be reflected
|
||||
assert all_documents[0].page_content == "Test content 2"
|
||||
assert all_documents[1].page_content == "Test content 1"
|
||||
|
||||
@patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor")
|
||||
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval._retriever")
|
||||
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.calculate_vector_score")
|
||||
def test_multiple_retrieve_thread_single_dataset_uses_standard_scoring(
|
||||
self, mock_calculate_vector_score, mock_retriever, mock_data_processor_class, mock_flask_app, mock_dataset
|
||||
):
|
||||
"""
|
||||
Test that _multiple_retrieve_thread uses standard scoring when dataset_count is 1
|
||||
and reranking is enabled.
|
||||
|
||||
When there's only one dataset, instead of using DataPostProcessor,
|
||||
the method should fall through to the standard scoring logic
|
||||
(calculate_vector_score for high_quality datasets).
|
||||
|
||||
Verifies:
|
||||
- DataPostProcessor is NOT called
|
||||
- calculate_vector_score IS called for high_quality indexing
|
||||
- Documents are scored correctly
|
||||
"""
|
||||
# Arrange
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
tenant_id = str(uuid4())
|
||||
|
||||
# Create test documents
|
||||
doc1 = Document(
|
||||
page_content="Test content 1",
|
||||
metadata={"doc_id": "doc1", "score": 0.9, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
|
||||
provider="dify",
|
||||
)
|
||||
doc2 = Document(
|
||||
page_content="Test content 2",
|
||||
metadata={"doc_id": "doc2", "score": 0.8, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
|
||||
provider="dify",
|
||||
)
|
||||
|
||||
# Mock _retriever to return documents
|
||||
def side_effect_retriever(
|
||||
flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids
|
||||
):
|
||||
all_documents.extend([doc1, doc2])
|
||||
|
||||
mock_retriever.side_effect = side_effect_retriever
|
||||
|
||||
# Set up dataset with high_quality indexing
|
||||
mock_dataset.indexing_technique = "high_quality"
|
||||
|
||||
# Mock calculate_vector_score to return scored documents
|
||||
scored_docs = [
|
||||
Document(
|
||||
page_content="Test content 1",
|
||||
metadata={"doc_id": "doc1", "score": 0.95, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
|
||||
provider="dify",
|
||||
),
|
||||
]
|
||||
mock_calculate_vector_score.return_value = scored_docs
|
||||
|
||||
all_documents = []
|
||||
|
||||
# Act - Call with dataset_count = 1
|
||||
dataset_retrieval._multiple_retrieve_thread(
|
||||
flask_app=mock_flask_app,
|
||||
available_datasets=[mock_dataset],
|
||||
metadata_condition=None,
|
||||
metadata_filter_document_ids=None,
|
||||
all_documents=all_documents,
|
||||
tenant_id=tenant_id,
|
||||
reranking_enable=True, # Reranking enabled but should be skipped for single dataset
|
||||
reranking_mode="reranking_model",
|
||||
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"},
|
||||
weights=None,
|
||||
top_k=5,
|
||||
score_threshold=0.5,
|
||||
query="test query",
|
||||
attachment_id=None,
|
||||
dataset_count=1,
|
||||
)
|
||||
|
||||
# Assert
|
||||
# DataPostProcessor should NOT be called
|
||||
mock_data_processor_class.assert_not_called()
|
||||
|
||||
# calculate_vector_score SHOULD be called for high_quality datasets
|
||||
mock_calculate_vector_score.assert_called_once()
|
||||
call_args = mock_calculate_vector_score.call_args
|
||||
assert call_args[0][1] == 5 # top_k
|
||||
|
||||
# Documents should be added after standard scoring
|
||||
assert len(all_documents) == 1
|
||||
assert all_documents[0].page_content == "Test content 1"
|
||||
|
||||
|
||||
class TestRetrievalMethods:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,113 @@
|
|||
import threading
|
||||
from unittest.mock import Mock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask import Flask, current_app
|
||||
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class TestRetrievalService:
|
||||
@pytest.fixture
|
||||
def mock_dataset(self) -> Dataset:
|
||||
dataset = Mock(spec=Dataset)
|
||||
dataset.id = str(uuid4())
|
||||
dataset.tenant_id = str(uuid4())
|
||||
dataset.name = "test_dataset"
|
||||
dataset.indexing_technique = "high_quality"
|
||||
dataset.provider = "dify"
|
||||
return dataset
|
||||
|
||||
def test_multiple_retrieve_reranking_with_app_context(self, mock_dataset):
|
||||
"""
|
||||
Repro test for current bug:
|
||||
reranking runs after `with flask_app.app_context():` exits.
|
||||
`_multiple_retrieve_thread` catches exceptions and stores them into `thread_exceptions`,
|
||||
so we must assert from that list (not from an outer try/except).
|
||||
"""
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
flask_app = Flask(__name__)
|
||||
tenant_id = str(uuid4())
|
||||
|
||||
# second dataset to ensure dataset_count > 1 reranking branch
|
||||
secondary_dataset = Mock(spec=Dataset)
|
||||
secondary_dataset.id = str(uuid4())
|
||||
secondary_dataset.provider = "dify"
|
||||
secondary_dataset.indexing_technique = "high_quality"
|
||||
|
||||
# retriever returns 1 doc into internal list (all_documents_item)
|
||||
document = Document(
|
||||
page_content="Context aware doc",
|
||||
metadata={
|
||||
"doc_id": "doc1",
|
||||
"score": 0.95,
|
||||
"document_id": str(uuid4()),
|
||||
"dataset_id": mock_dataset.id,
|
||||
},
|
||||
provider="dify",
|
||||
)
|
||||
|
||||
def fake_retriever(
|
||||
flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids
|
||||
):
|
||||
all_documents.append(document)
|
||||
|
||||
called = {"init": 0, "invoke": 0}
|
||||
|
||||
class ContextRequiredPostProcessor:
|
||||
def __init__(self, *args, **kwargs):
|
||||
called["init"] += 1
|
||||
# will raise RuntimeError if no Flask app context exists
|
||||
_ = current_app.name
|
||||
|
||||
def invoke(self, *args, **kwargs):
|
||||
called["invoke"] += 1
|
||||
_ = current_app.name
|
||||
return kwargs.get("documents") or args[1]
|
||||
|
||||
# output list from _multiple_retrieve_thread
|
||||
all_documents: list[Document] = []
|
||||
|
||||
# IMPORTANT: _multiple_retrieve_thread swallows exceptions and appends them here
|
||||
thread_exceptions: list[Exception] = []
|
||||
|
||||
def target():
|
||||
with patch.object(dataset_retrieval, "_retriever", side_effect=fake_retriever):
|
||||
with patch(
|
||||
"core.rag.retrieval.dataset_retrieval.DataPostProcessor",
|
||||
ContextRequiredPostProcessor,
|
||||
):
|
||||
dataset_retrieval._multiple_retrieve_thread(
|
||||
flask_app=flask_app,
|
||||
available_datasets=[mock_dataset, secondary_dataset],
|
||||
metadata_condition=None,
|
||||
metadata_filter_document_ids=None,
|
||||
all_documents=all_documents,
|
||||
tenant_id=tenant_id,
|
||||
reranking_enable=True,
|
||||
reranking_mode="reranking_model",
|
||||
reranking_model={
|
||||
"reranking_provider_name": "cohere",
|
||||
"reranking_model_name": "rerank-v2",
|
||||
},
|
||||
weights=None,
|
||||
top_k=3,
|
||||
score_threshold=0.0,
|
||||
query="test query",
|
||||
attachment_id=None,
|
||||
dataset_count=2, # force reranking branch
|
||||
thread_exceptions=thread_exceptions, # ✅ key
|
||||
)
|
||||
|
||||
t = threading.Thread(target=target)
|
||||
t.start()
|
||||
t.join()
|
||||
|
||||
# Ensure reranking branch was actually executed
|
||||
assert called["init"] >= 1, "DataPostProcessor was never constructed; reranking branch may not have run."
|
||||
|
||||
# Current buggy code should record an exception (not raise it)
|
||||
assert not thread_exceptions, thread_exceptions
|
||||
|
|
@ -103,13 +103,25 @@ class MockNodeFactory(DifyNodeFactory):
|
|||
|
||||
# Create mock node instance
|
||||
mock_class = self._mock_node_types[node_type]
|
||||
mock_instance = mock_class(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
mock_config=self.mock_config,
|
||||
)
|
||||
if node_type == NodeType.CODE:
|
||||
mock_instance = mock_class(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
mock_config=self.mock_config,
|
||||
code_executor=self._code_executor,
|
||||
code_providers=self._code_providers,
|
||||
code_limits=self._code_limits,
|
||||
)
|
||||
else:
|
||||
mock_instance = mock_class(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
mock_config=self.mock_config,
|
||||
)
|
||||
|
||||
return mock_instance
|
||||
|
||||
|
|
|
|||
|
|
@ -40,12 +40,14 @@ class MockNodeMixin:
|
|||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
mock_config: Optional["MockConfig"] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
**kwargs,
|
||||
)
|
||||
self.mock_config = mock_config
|
||||
|
||||
|
|
|
|||
|
|
@ -5,11 +5,24 @@ This module tests the functionality of MockTemplateTransformNode and MockCodeNod
|
|||
to ensure they work correctly with the TableTestRunner.
|
||||
"""
|
||||
|
||||
from configs import dify_config
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig
|
||||
from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory
|
||||
from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockCodeNode, MockTemplateTransformNode
|
||||
|
||||
DEFAULT_CODE_LIMITS = CodeNodeLimits(
|
||||
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
|
||||
max_number=dify_config.CODE_MAX_NUMBER,
|
||||
min_number=dify_config.CODE_MIN_NUMBER,
|
||||
max_precision=dify_config.CODE_MAX_PRECISION,
|
||||
max_depth=dify_config.CODE_MAX_DEPTH,
|
||||
max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH,
|
||||
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
|
||||
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
|
||||
)
|
||||
|
||||
|
||||
class TestMockTemplateTransformNode:
|
||||
"""Test cases for MockTemplateTransformNode."""
|
||||
|
|
@ -306,6 +319,7 @@ class TestMockCodeNode:
|
|||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
code_limits=DEFAULT_CODE_LIMITS,
|
||||
)
|
||||
|
||||
# Run the node
|
||||
|
|
@ -370,6 +384,7 @@ class TestMockCodeNode:
|
|||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
code_limits=DEFAULT_CODE_LIMITS,
|
||||
)
|
||||
|
||||
# Run the node
|
||||
|
|
@ -438,6 +453,7 @@ class TestMockCodeNode:
|
|||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
code_limits=DEFAULT_CODE_LIMITS,
|
||||
)
|
||||
|
||||
# Run the node
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from configs import dify_config
|
||||
from core.helper.code_executor.code_executor import CodeLanguage
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
|
|
@ -7,6 +8,18 @@ from core.workflow.nodes.code.exc import (
|
|||
DepthLimitError,
|
||||
OutputValidationError,
|
||||
)
|
||||
from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
|
||||
CodeNode._limits = CodeNodeLimits(
|
||||
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
|
||||
max_number=dify_config.CODE_MAX_NUMBER,
|
||||
min_number=dify_config.CODE_MIN_NUMBER,
|
||||
max_precision=dify_config.CODE_MAX_PRECISION,
|
||||
max_depth=dify_config.CODE_MAX_DEPTH,
|
||||
max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH,
|
||||
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
|
||||
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
|
||||
)
|
||||
|
||||
|
||||
class TestCodeNodeExceptions:
|
||||
|
|
|
|||
|
|
@ -5,8 +5,8 @@ from core.workflow.graph_engine.entities.graph import Graph
|
|||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.template_transform.template_renderer import TemplateRenderError
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
|
|
@ -127,7 +127,9 @@ class TestTemplateTransformNode:
|
|||
"""Test version class method."""
|
||||
assert TemplateTransformNode.version() == "1"
|
||||
|
||||
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
|
||||
@patch(
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
|
||||
)
|
||||
def test_run_simple_template(
|
||||
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
|
||||
):
|
||||
|
|
@ -145,7 +147,7 @@ class TestTemplateTransformNode:
|
|||
mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
|
||||
|
||||
# Setup mock executor
|
||||
mock_execute.return_value = {"result": "Hello Alice, you are 30 years old!"}
|
||||
mock_execute.return_value = "Hello Alice, you are 30 years old!"
|
||||
|
||||
node = TemplateTransformNode(
|
||||
id="test_node",
|
||||
|
|
@ -162,7 +164,9 @@ class TestTemplateTransformNode:
|
|||
assert result.inputs["name"] == "Alice"
|
||||
assert result.inputs["age"] == 30
|
||||
|
||||
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
|
||||
@patch(
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
|
||||
)
|
||||
def test_run_with_none_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
|
||||
"""Test _run with None variable values."""
|
||||
node_data = {
|
||||
|
|
@ -172,7 +176,7 @@ class TestTemplateTransformNode:
|
|||
}
|
||||
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = None
|
||||
mock_execute.return_value = {"result": "Value: "}
|
||||
mock_execute.return_value = "Value: "
|
||||
|
||||
node = TemplateTransformNode(
|
||||
id="test_node",
|
||||
|
|
@ -187,13 +191,15 @@ class TestTemplateTransformNode:
|
|||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.inputs["value"] is None
|
||||
|
||||
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
|
||||
@patch(
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
|
||||
)
|
||||
def test_run_with_code_execution_error(
|
||||
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
|
||||
):
|
||||
"""Test _run when code execution fails."""
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = MagicMock()
|
||||
mock_execute.side_effect = CodeExecutionError("Template syntax error")
|
||||
mock_execute.side_effect = TemplateRenderError("Template syntax error")
|
||||
|
||||
node = TemplateTransformNode(
|
||||
id="test_node",
|
||||
|
|
@ -208,14 +214,16 @@ class TestTemplateTransformNode:
|
|||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert "Template syntax error" in result.error
|
||||
|
||||
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
|
||||
@patch(
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
|
||||
)
|
||||
@patch("core.workflow.nodes.template_transform.template_transform_node.MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH", 10)
|
||||
def test_run_output_length_exceeds_limit(
|
||||
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
|
||||
):
|
||||
"""Test _run when output exceeds maximum length."""
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = MagicMock()
|
||||
mock_execute.return_value = {"result": "This is a very long output that exceeds the limit"}
|
||||
mock_execute.return_value = "This is a very long output that exceeds the limit"
|
||||
|
||||
node = TemplateTransformNode(
|
||||
id="test_node",
|
||||
|
|
@ -230,7 +238,9 @@ class TestTemplateTransformNode:
|
|||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert "Output length exceeds" in result.error
|
||||
|
||||
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
|
||||
@patch(
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
|
||||
)
|
||||
def test_run_with_complex_jinja2_template(
|
||||
self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params
|
||||
):
|
||||
|
|
@ -257,7 +267,7 @@ class TestTemplateTransformNode:
|
|||
("sys", "show_total"): mock_show_total,
|
||||
}
|
||||
mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
|
||||
mock_execute.return_value = {"result": "apple, banana, orange (Total: 3)"}
|
||||
mock_execute.return_value = "apple, banana, orange (Total: 3)"
|
||||
|
||||
node = TemplateTransformNode(
|
||||
id="test_node",
|
||||
|
|
@ -292,7 +302,9 @@ class TestTemplateTransformNode:
|
|||
assert mapping["node_123.var1"] == ["sys", "input1"]
|
||||
assert mapping["node_123.var2"] == ["sys", "input2"]
|
||||
|
||||
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
|
||||
@patch(
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
|
||||
)
|
||||
def test_run_with_empty_variables(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
|
||||
"""Test _run with no variables (static template)."""
|
||||
node_data = {
|
||||
|
|
@ -301,7 +313,7 @@ class TestTemplateTransformNode:
|
|||
"template": "This is a static message.",
|
||||
}
|
||||
|
||||
mock_execute.return_value = {"result": "This is a static message."}
|
||||
mock_execute.return_value = "This is a static message."
|
||||
|
||||
node = TemplateTransformNode(
|
||||
id="test_node",
|
||||
|
|
@ -317,7 +329,9 @@ class TestTemplateTransformNode:
|
|||
assert result.outputs["output"] == "This is a static message."
|
||||
assert result.inputs == {}
|
||||
|
||||
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
|
||||
@patch(
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
|
||||
)
|
||||
def test_run_with_numeric_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
|
||||
"""Test _run with numeric variable values."""
|
||||
node_data = {
|
||||
|
|
@ -339,7 +353,7 @@ class TestTemplateTransformNode:
|
|||
("sys", "quantity"): mock_quantity,
|
||||
}
|
||||
mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
|
||||
mock_execute.return_value = {"result": "Total: $31.5"}
|
||||
mock_execute.return_value = "Total: $31.5"
|
||||
|
||||
node = TemplateTransformNode(
|
||||
id="test_node",
|
||||
|
|
@ -354,7 +368,9 @@ class TestTemplateTransformNode:
|
|||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs["output"] == "Total: $31.5"
|
||||
|
||||
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
|
||||
@patch(
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
|
||||
)
|
||||
def test_run_with_dict_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
|
||||
"""Test _run with dictionary variable values."""
|
||||
node_data = {
|
||||
|
|
@ -367,7 +383,7 @@ class TestTemplateTransformNode:
|
|||
mock_user.to_object.return_value = {"name": "John Doe", "email": "john@example.com"}
|
||||
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_user
|
||||
mock_execute.return_value = {"result": "Name: John Doe, Email: john@example.com"}
|
||||
mock_execute.return_value = "Name: John Doe, Email: john@example.com"
|
||||
|
||||
node = TemplateTransformNode(
|
||||
id="test_node",
|
||||
|
|
@ -383,7 +399,9 @@ class TestTemplateTransformNode:
|
|||
assert "John Doe" in result.outputs["output"]
|
||||
assert "john@example.com" in result.outputs["output"]
|
||||
|
||||
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
|
||||
@patch(
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
|
||||
)
|
||||
def test_run_with_list_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
|
||||
"""Test _run with list variable values."""
|
||||
node_data = {
|
||||
|
|
@ -396,7 +414,7 @@ class TestTemplateTransformNode:
|
|||
mock_tags.to_object.return_value = ["python", "ai", "workflow"]
|
||||
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_tags
|
||||
mock_execute.return_value = {"result": "Tags: #python #ai #workflow "}
|
||||
mock_execute.return_value = "Tags: #python #ai #workflow "
|
||||
|
||||
node = TemplateTransformNode(
|
||||
id="test_node",
|
||||
|
|
|
|||
|
|
@ -99,29 +99,20 @@ def test_external_api_json_message_and_bad_request_rewrite():
|
|||
assert res.get_json()["message"] == "Invalid JSON payload received or JSON payload is empty."
|
||||
|
||||
|
||||
def test_external_api_param_mapping_and_quota_and_exc_info_none():
|
||||
# Force exc_info() to return (None,None,None) only during request
|
||||
import libs.external_api as ext
|
||||
def test_external_api_param_mapping_and_quota():
|
||||
app = _create_api_app()
|
||||
client = app.test_client()
|
||||
|
||||
orig_exc_info = ext.sys.exc_info
|
||||
try:
|
||||
ext.sys.exc_info = lambda: (None, None, None)
|
||||
# Param errors mapping payload path
|
||||
res = client.get("/api/param-errors")
|
||||
assert res.status_code == 400
|
||||
data = res.get_json()
|
||||
assert data["code"] == "invalid_param"
|
||||
assert data["params"] == "field"
|
||||
|
||||
app = _create_api_app()
|
||||
client = app.test_client()
|
||||
|
||||
# Param errors mapping payload path
|
||||
res = client.get("/api/param-errors")
|
||||
assert res.status_code == 400
|
||||
data = res.get_json()
|
||||
assert data["code"] == "invalid_param"
|
||||
assert data["params"] == "field"
|
||||
|
||||
# Quota path — depending on Flask-RESTX internals it may be handled
|
||||
res = client.get("/api/quota")
|
||||
assert res.status_code in (400, 429)
|
||||
finally:
|
||||
ext.sys.exc_info = orig_exc_info # type: ignore[assignment]
|
||||
# Quota path — depending on Flask-RESTX internals it may be handled
|
||||
res = client.get("/api/quota")
|
||||
assert res.status_code in (400, 429)
|
||||
|
||||
|
||||
def test_unauthorized_and_force_logout_clears_cookies():
|
||||
|
|
|
|||
18
api/uv.lock
18
api/uv.lock
|
|
@ -1953,14 +1953,14 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "fickling"
|
||||
version = "0.1.5"
|
||||
version = "0.1.6"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "stdlib-list" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/41/94/0d0ce455952c036cfee235637f786c1d1d07d1b90f6a4dfb50e0eff929d6/fickling-0.1.5.tar.gz", hash = "sha256:92f9b49e717fa8dbc198b4b7b685587adb652d85aa9ede8131b3e44494efca05", size = 282462, upload-time = "2025-11-18T05:04:30.748Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/07/ab/7571453f9365c17c047b5a7b7e82692a7f6be51203f295030886758fd57a/fickling-0.1.6.tar.gz", hash = "sha256:03cb5d7bd09f9169c7583d2079fad4b3b88b25f865ed0049172e5cb68582311d", size = 284033, upload-time = "2025-12-15T18:14:58.721Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/bf/a7/d25912b2e3a5b0a37e6f460050bbc396042b5906a6563a1962c484abc3c6/fickling-0.1.5-py3-none-any.whl", hash = "sha256:6aed7270bfa276e188b0abe043a27b3a042129d28ec1fa6ff389bdcc5ad178bb", size = 46240, upload-time = "2025-11-18T05:04:29.048Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/76/99/cc04258dda421bc612cdfe4be8c253f45b922f1c7f268b5a0b9962d9cd12/fickling-0.1.6-py3-none-any.whl", hash = "sha256:465d0069548bfc731bdd75a583cb4cf5a4b2666739c0f76287807d724b147ed3", size = 47922, upload-time = "2025-12-15T18:14:57.526Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -2955,14 +2955,14 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "intersystems-irispython"
|
||||
version = "5.3.0"
|
||||
version = "5.3.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/5d/56/16d93576b50408d97a5cbbd055d8da024d585e96a360e2adc95b41ae6284/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-macosx_10_9_universal2.whl", hash = "sha256:59d3176a35867a55b1ab69a6b5c75438b460291bccb254c2d2f4173be08b6e55", size = 6594480, upload-time = "2025-10-09T20:47:27.629Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/99/bc/19e144ee805ea6ee0df6342a711e722c84347c05a75b3bf040c5fbe19982/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:56bccefd1997c25f9f9f6c4086214c18d4fdaac0a93319d4b21dd9a6c59c9e51", size = 14779928, upload-time = "2025-10-09T20:47:30.564Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e6/fb/59ba563a80b39e9450b4627b5696019aa831dce27dacc3831b8c1e669102/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e160adc0785c55bb64e4264b8e99075691a15b0afa5d8d529f1b4bac7e57b81", size = 14422035, upload-time = "2025-10-09T20:47:32.552Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c1/68/ade8ad43f0ed1e5fba60e1710fa5ddeb01285f031e465e8c006329072e63/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-win32.whl", hash = "sha256:820f2c5729119e5173a5bf6d6ac2a41275c4f1ffba6af6c59ea313ecd8f499cc", size = 2824316, upload-time = "2025-10-09T20:47:28.998Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f4/03/cd45cb94e42c01dc525efebf3c562543a18ee55b67fde4022665ca672351/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-win_amd64.whl", hash = "sha256:fc07ec24bc50b6f01573221cd7d86f2937549effe31c24af8db118e0131e340c", size = 3463297, upload-time = "2025-10-09T20:47:34.636Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/33/5b/8eac672a6ef26bef6ef79a7c9557096167b50c4d3577d558ae6999c195fe/intersystems_irispython-5.3.1-cp38.cp39.cp310.cp311.cp312.cp313.cp314-cp38.cp39.cp310.cp311.cp312.cp313.cp314-macosx_10_9_universal2.whl", hash = "sha256:634c9b4ec620837d830ff49543aeb2797a1ce8d8570a0e868398b85330dfcc4d", size = 6736686, upload-time = "2025-12-19T16:24:57.734Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ba/17/bab3e525ffb6711355f7feea18c1b7dced9c2484cecbcdd83f74550398c0/intersystems_irispython-5.3.1-cp38.cp39.cp310.cp311.cp312.cp313.cp314-cp38.cp39.cp310.cp311.cp312.cp313.cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cf912f30f85e2a42f2c2ea77fbeb98a24154d5ea7428a50382786a684ec4f583", size = 16005259, upload-time = "2025-12-19T16:25:05.578Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/39/59/9bb79d9e32e3e55fc9aed8071a797b4497924cbc6457cea9255bb09320b7/intersystems_irispython-5.3.1-cp38.cp39.cp310.cp311.cp312.cp313.cp314-cp38.cp39.cp310.cp311.cp312.cp313.cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:be5659a6bb57593910f2a2417eddb9f5dc2f93a337ead6ddca778f557b8a359a", size = 15638040, upload-time = "2025-12-19T16:24:54.429Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cf/47/654ccf9c5cca4f5491f070888544165c9e2a6a485e320ea703e4e38d2358/intersystems_irispython-5.3.1-cp38.cp39.cp310.cp311.cp312.cp313.cp314-cp38.cp39.cp310.cp311.cp312.cp313.cp314-win32.whl", hash = "sha256:583e4f17088c1e0530f32efda1c0ccb02993cbc22035bc8b4c71d8693b04ee7e", size = 2879644, upload-time = "2025-12-19T16:24:59.945Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/68/95/19cc13d09f1b4120bd41b1434509052e1d02afd27f2679266d7ad9cc1750/intersystems_irispython-5.3.1-cp38.cp39.cp310.cp311.cp312.cp313.cp314-cp38.cp39.cp310.cp311.cp312.cp313.cp314-win_amd64.whl", hash = "sha256:1d5d40450a0cdeec2a1f48d12d946a8a8ffc7c128576fcae7d58e66e3a127eae", size = 3522092, upload-time = "2025-12-19T16:25:01.834Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
|
|||
|
|
@ -69,6 +69,8 @@ PYTHONIOENCODING=utf-8
|
|||
# The log level for the application.
|
||||
# Supported values are `DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL`
|
||||
LOG_LEVEL=INFO
|
||||
# Log output format: text or json
|
||||
LOG_OUTPUT_FORMAT=text
|
||||
# Log file path
|
||||
LOG_FILE=/app/logs/server.log
|
||||
# Log file max size, the unit is MB
|
||||
|
|
@ -231,7 +233,7 @@ NEXT_PUBLIC_ENABLE_SINGLE_DOLLAR_LATEX=false
|
|||
# You can adjust the database configuration according to your needs.
|
||||
# ------------------------------
|
||||
|
||||
# Database type, supported values are `postgresql` and `mysql`
|
||||
# Database type, supported values are `postgresql`, `mysql`, `oceanbase`, `seekdb`
|
||||
DB_TYPE=postgresql
|
||||
# For MySQL, only `root` user is supported for now
|
||||
DB_USERNAME=postgres
|
||||
|
|
@ -531,7 +533,7 @@ SUPABASE_URL=your-server-url
|
|||
# ------------------------------
|
||||
|
||||
# The type of vector store to use.
|
||||
# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `clickzetta`, `alibabacloud_mysql`, `iris`.
|
||||
# Supported values are `weaviate`, `oceanbase`, `seekdb`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`, `vastbase`, `tidb`, `tidb_on_qdrant`, `baidu`, `lindorm`, `huawei_cloud`, `upstash`, `matrixone`, `clickzetta`, `alibabacloud_mysql`, `iris`.
|
||||
VECTOR_STORE=weaviate
|
||||
# Prefix used to create collection name in vector database
|
||||
VECTOR_INDEX_NAME_PREFIX=Vector_index
|
||||
|
|
@ -542,9 +544,9 @@ WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
|
|||
WEAVIATE_GRPC_ENDPOINT=grpc://weaviate:50051
|
||||
WEAVIATE_TOKENIZATION=word
|
||||
|
||||
# For OceanBase metadata database configuration, available when `DB_TYPE` is `mysql` and `COMPOSE_PROFILES` includes `oceanbase`.
|
||||
# For OceanBase metadata database configuration, available when `DB_TYPE` is `oceanbase`.
|
||||
# For OceanBase vector database configuration, available when `VECTOR_STORE` is `oceanbase`
|
||||
# If you want to use OceanBase as both vector database and metadata database, you need to set `DB_TYPE` to `mysql`, `COMPOSE_PROFILES` is `oceanbase`, and set Database Configuration is the same as the vector database.
|
||||
# If you want to use OceanBase as both vector database and metadata database, you need to set both `DB_TYPE` and `VECTOR_STORE` to `oceanbase`, and set Database Configuration is the same as the vector database.
|
||||
# seekdb is the lite version of OceanBase and shares the connection configuration with OceanBase.
|
||||
OCEANBASE_VECTOR_HOST=oceanbase
|
||||
OCEANBASE_VECTOR_PORT=2881
|
||||
|
|
|
|||
|
|
@ -129,6 +129,7 @@ services:
|
|||
- ./middleware.env
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
LOG_OUTPUT_FORMAT: ${LOG_OUTPUT_FORMAT:-text}
|
||||
DB_DATABASE: ${DB_PLUGIN_DATABASE:-dify_plugin}
|
||||
REDIS_HOST: ${REDIS_HOST:-redis}
|
||||
REDIS_PORT: ${REDIS_PORT:-6379}
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ x-shared-env: &shared-api-worker-env
|
|||
LC_ALL: ${LC_ALL:-en_US.UTF-8}
|
||||
PYTHONIOENCODING: ${PYTHONIOENCODING:-utf-8}
|
||||
LOG_LEVEL: ${LOG_LEVEL:-INFO}
|
||||
LOG_OUTPUT_FORMAT: ${LOG_OUTPUT_FORMAT:-text}
|
||||
LOG_FILE: ${LOG_FILE:-/app/logs/server.log}
|
||||
LOG_FILE_MAX_SIZE: ${LOG_FILE_MAX_SIZE:-20}
|
||||
LOG_FILE_BACKUP_COUNT: ${LOG_FILE_BACKUP_COUNT:-5}
|
||||
|
|
|
|||
|
|
@ -54,3 +54,52 @@ http_access allow src_all
|
|||
|
||||
# Unless the option's size is increased, an error will occur when uploading more than two files.
|
||||
client_request_buffer_max_size 100 MB
|
||||
|
||||
################################## Performance & Concurrency ###############################
|
||||
# Increase file descriptor limit for high concurrency
|
||||
max_filedescriptors 65536
|
||||
|
||||
# Timeout configurations for image requests
|
||||
connect_timeout 30 seconds
|
||||
request_timeout 2 minutes
|
||||
read_timeout 2 minutes
|
||||
client_lifetime 5 minutes
|
||||
shutdown_lifetime 30 seconds
|
||||
|
||||
# Persistent connections - improve performance for multiple requests
|
||||
server_persistent_connections on
|
||||
client_persistent_connections on
|
||||
persistent_request_timeout 30 seconds
|
||||
pconn_timeout 1 minute
|
||||
|
||||
# Connection pool and concurrency limits
|
||||
client_db on
|
||||
server_idle_pconn_timeout 2 minutes
|
||||
client_idle_pconn_timeout 2 minutes
|
||||
|
||||
# Quick abort settings - don't abort requests that are mostly done
|
||||
quick_abort_min 16 KB
|
||||
quick_abort_max 16 MB
|
||||
quick_abort_pct 95
|
||||
|
||||
# Memory and cache optimization
|
||||
memory_cache_mode disk
|
||||
cache_mem 256 MB
|
||||
maximum_object_size_in_memory 512 KB
|
||||
|
||||
# DNS resolver settings for better performance
|
||||
dns_timeout 30 seconds
|
||||
dns_retransmit_interval 5 seconds
|
||||
# By default, Squid uses the system's configured DNS resolvers.
|
||||
# If you need to override them, set dns_nameservers to appropriate servers
|
||||
# for your environment (for example, internal/corporate DNS). The following
|
||||
# is an example using public DNS and SHOULD be customized before use:
|
||||
# dns_nameservers 8.8.8.8 8.8.4.4
|
||||
|
||||
# Logging format for better debugging
|
||||
logformat dify_log %ts.%03tu %6tr %>a %Ss/%03>Hs %<st %rm %ru %[un %Sh/%<a %mt
|
||||
access_log daemon:/var/log/squid/access.log dify_log
|
||||
|
||||
# Access log to track concurrent requests and timeouts
|
||||
logfile_rotate 10
|
||||
|
||||
|
|
|
|||
|
|
@ -7,8 +7,12 @@ logs
|
|||
|
||||
# node
|
||||
node_modules
|
||||
dist
|
||||
build
|
||||
coverage
|
||||
.husky
|
||||
.next
|
||||
.pnpm-store
|
||||
|
||||
# vscode
|
||||
.vscode
|
||||
|
|
@ -22,3 +26,7 @@ node_modules
|
|||
|
||||
# Jetbrains
|
||||
.idea
|
||||
|
||||
# git
|
||||
.git
|
||||
.gitignore
|
||||
|
|
@ -47,6 +47,8 @@ NEXT_PUBLIC_TOP_K_MAX_VALUE=10
|
|||
|
||||
# The maximum number of tokens for segmentation
|
||||
NEXT_PUBLIC_INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000
|
||||
# Used by web/docker/entrypoint.sh to overwrite/export NEXT_PUBLIC_INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH at container startup (Docker only)
|
||||
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000
|
||||
|
||||
# Maximum loop count in the workflow
|
||||
NEXT_PUBLIC_LOOP_NODE_MAX_COUNT=100
|
||||
|
|
|
|||
|
|
@ -12,7 +12,8 @@ RUN apk add --no-cache tzdata
|
|||
RUN corepack enable
|
||||
ENV PNPM_HOME="/pnpm"
|
||||
ENV PATH="$PNPM_HOME:$PATH"
|
||||
ENV NEXT_PUBLIC_BASE_PATH=""
|
||||
ARG NEXT_PUBLIC_BASE_PATH=""
|
||||
ENV NEXT_PUBLIC_BASE_PATH="$NEXT_PUBLIC_BASE_PATH"
|
||||
|
||||
|
||||
# install packages
|
||||
|
|
|
|||
|
|
@ -110,7 +110,7 @@ const GotoAnything: FC<Props> = ({
|
|||
isWorkflowPage,
|
||||
isRagPipelinePage,
|
||||
defaultLocale,
|
||||
Object.keys(Actions).sort().join(','),
|
||||
Actions,
|
||||
],
|
||||
queryFn: async () => {
|
||||
const query = searchQueryDebouncedValue.toLowerCase()
|
||||
|
|
|
|||
|
|
@ -1,10 +1,13 @@
|
|||
'use client'
|
||||
import type { PluginDetail } from '../types'
|
||||
import type { FilterState } from './filter-management'
|
||||
import { useDebounceFn } from 'ahooks'
|
||||
import { useMemo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Button from '@/app/components/base/button'
|
||||
import PluginDetailPanel from '@/app/components/plugins/plugin-detail-panel'
|
||||
import { useGetLanguage } from '@/context/i18n'
|
||||
import { renderI18nObject } from '@/i18n-config'
|
||||
import { useInstalledLatestVersion, useInstalledPluginList, useInvalidateInstalledPluginList } from '@/service/use-plugins'
|
||||
import Loading from '../../base/loading'
|
||||
import { PluginSource } from '../types'
|
||||
|
|
@ -13,8 +16,34 @@ import Empty from './empty'
|
|||
import FilterManagement from './filter-management'
|
||||
import List from './list'
|
||||
|
||||
const matchesSearchQuery = (plugin: PluginDetail & { latest_version: string }, query: string, locale: string): boolean => {
|
||||
if (!query)
|
||||
return true
|
||||
const lowerQuery = query.toLowerCase()
|
||||
const { declaration } = plugin
|
||||
// Match plugin_id
|
||||
if (plugin.plugin_id.toLowerCase().includes(lowerQuery))
|
||||
return true
|
||||
// Match plugin name
|
||||
if (plugin.name?.toLowerCase().includes(lowerQuery))
|
||||
return true
|
||||
// Match declaration name
|
||||
if (declaration.name?.toLowerCase().includes(lowerQuery))
|
||||
return true
|
||||
// Match localized label
|
||||
const label = renderI18nObject(declaration.label, locale)
|
||||
if (label?.toLowerCase().includes(lowerQuery))
|
||||
return true
|
||||
// Match localized description
|
||||
const description = renderI18nObject(declaration.description, locale)
|
||||
if (description?.toLowerCase().includes(lowerQuery))
|
||||
return true
|
||||
return false
|
||||
}
|
||||
|
||||
const PluginsPanel = () => {
|
||||
const { t } = useTranslation()
|
||||
const locale = useGetLanguage()
|
||||
const filters = usePluginPageContext(v => v.filters) as FilterState
|
||||
const setFilters = usePluginPageContext(v => v.setFilters)
|
||||
const { data: pluginList, isLoading: isPluginListLoading, isFetching, isLastPage, loadNextPage } = useInstalledPluginList()
|
||||
|
|
@ -48,11 +77,11 @@ const PluginsPanel = () => {
|
|||
return (
|
||||
(categories.length === 0 || categories.includes(plugin.declaration.category))
|
||||
&& (tags.length === 0 || tags.some(tag => plugin.declaration.tags.includes(tag)))
|
||||
&& (searchQuery === '' || plugin.plugin_id.toLowerCase().includes(searchQuery.toLowerCase()))
|
||||
&& matchesSearchQuery(plugin, searchQuery, locale)
|
||||
)
|
||||
})
|
||||
return filteredList
|
||||
}, [pluginListWithLatestVersion, filters])
|
||||
}, [pluginListWithLatestVersion, filters, locale])
|
||||
|
||||
const currentPluginDetail = useMemo(() => {
|
||||
const detail = pluginListWithLatestVersion.find(plugin => plugin.plugin_id === currentPluginID)
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,971 @@
|
|||
import type { PanelProps } from '@/app/components/workflow/panel'
|
||||
import { render, screen, waitFor } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import RagPipelinePanel from './index'
|
||||
|
||||
// ============================================================================
|
||||
// Mock External Dependencies
|
||||
// ============================================================================
|
||||
|
||||
// Type definitions for dynamic module
|
||||
type DynamicModule = {
|
||||
default?: React.ComponentType<Record<string, unknown>>
|
||||
}
|
||||
|
||||
type PromiseOrModule = Promise<DynamicModule> | DynamicModule
|
||||
|
||||
// Mock next/dynamic to return synchronous components immediately
|
||||
vi.mock('next/dynamic', () => ({
|
||||
default: (loader: () => PromiseOrModule, _options?: Record<string, unknown>) => {
|
||||
let Component: React.ComponentType<Record<string, unknown>> | null = null
|
||||
|
||||
// Try to resolve the loader synchronously for mocked modules
|
||||
try {
|
||||
const result = loader() as PromiseOrModule
|
||||
if (result && typeof (result as Promise<DynamicModule>).then === 'function') {
|
||||
// For async modules, we need to handle them specially
|
||||
// This will work with vi.mock since mocks resolve synchronously
|
||||
(result as Promise<DynamicModule>).then((mod: DynamicModule) => {
|
||||
Component = (mod.default || mod) as React.ComponentType<Record<string, unknown>>
|
||||
})
|
||||
}
|
||||
else if (result) {
|
||||
Component = ((result as DynamicModule).default || result) as React.ComponentType<Record<string, unknown>>
|
||||
}
|
||||
}
|
||||
catch {
|
||||
// If the module can't be resolved, Component stays null
|
||||
}
|
||||
|
||||
// Return a simple wrapper that renders the component or null
|
||||
const DynamicComponent = React.forwardRef((props: Record<string, unknown>, ref: React.Ref<unknown>) => {
|
||||
// For mocked modules, Component should already be set
|
||||
if (Component)
|
||||
return <Component {...props} ref={ref} />
|
||||
|
||||
return null
|
||||
})
|
||||
|
||||
DynamicComponent.displayName = 'DynamicComponent'
|
||||
return DynamicComponent
|
||||
},
|
||||
}))
|
||||
|
||||
// Mock workflow store
|
||||
let mockHistoryWorkflowData: Record<string, unknown> | null = null
|
||||
let mockShowDebugAndPreviewPanel = false
|
||||
let mockShowGlobalVariablePanel = false
|
||||
let mockShowInputFieldPanel = false
|
||||
let mockShowInputFieldPreviewPanel = false
|
||||
let mockInputFieldEditPanelProps: Record<string, unknown> | null = null
|
||||
let mockPipelineId = 'test-pipeline-123'
|
||||
|
||||
type MockStoreState = {
|
||||
historyWorkflowData: Record<string, unknown> | null
|
||||
showDebugAndPreviewPanel: boolean
|
||||
showGlobalVariablePanel: boolean
|
||||
showInputFieldPanel: boolean
|
||||
showInputFieldPreviewPanel: boolean
|
||||
inputFieldEditPanelProps: Record<string, unknown> | null
|
||||
pipelineId: string
|
||||
}
|
||||
|
||||
vi.mock('@/app/components/workflow/store', () => ({
|
||||
useStore: (selector: (state: MockStoreState) => unknown) => {
|
||||
const state: MockStoreState = {
|
||||
historyWorkflowData: mockHistoryWorkflowData,
|
||||
showDebugAndPreviewPanel: mockShowDebugAndPreviewPanel,
|
||||
showGlobalVariablePanel: mockShowGlobalVariablePanel,
|
||||
showInputFieldPanel: mockShowInputFieldPanel,
|
||||
showInputFieldPreviewPanel: mockShowInputFieldPreviewPanel,
|
||||
inputFieldEditPanelProps: mockInputFieldEditPanelProps,
|
||||
pipelineId: mockPipelineId,
|
||||
}
|
||||
return selector(state)
|
||||
},
|
||||
}))
|
||||
|
||||
// Mock Panel component to capture props and render children
|
||||
let capturedPanelProps: PanelProps | null = null
|
||||
vi.mock('@/app/components/workflow/panel', () => ({
|
||||
default: (props: PanelProps) => {
|
||||
capturedPanelProps = props
|
||||
return (
|
||||
<div data-testid="workflow-panel">
|
||||
<div data-testid="panel-left">{props.components?.left}</div>
|
||||
<div data-testid="panel-right">{props.components?.right}</div>
|
||||
</div>
|
||||
)
|
||||
},
|
||||
}))
|
||||
|
||||
// Mock Record component
|
||||
vi.mock('@/app/components/workflow/panel/record', () => ({
|
||||
default: () => <div data-testid="record-panel">Record Panel</div>,
|
||||
}))
|
||||
|
||||
// Mock TestRunPanel component
|
||||
vi.mock('@/app/components/rag-pipeline/components/panel/test-run', () => ({
|
||||
default: () => <div data-testid="test-run-panel">Test Run Panel</div>,
|
||||
}))
|
||||
|
||||
// Mock InputFieldPanel component
|
||||
vi.mock('./input-field', () => ({
|
||||
default: () => <div data-testid="input-field-panel">Input Field Panel</div>,
|
||||
}))
|
||||
|
||||
// Mock InputFieldEditorPanel component
|
||||
const mockInputFieldEditorProps = vi.fn()
|
||||
vi.mock('./input-field/editor', () => ({
|
||||
default: (props: Record<string, unknown>) => {
|
||||
mockInputFieldEditorProps(props)
|
||||
return <div data-testid="input-field-editor-panel">Input Field Editor Panel</div>
|
||||
},
|
||||
}))
|
||||
|
||||
// Mock PreviewPanel component
|
||||
vi.mock('./input-field/preview', () => ({
|
||||
default: () => <div data-testid="preview-panel">Preview Panel</div>,
|
||||
}))
|
||||
|
||||
// Mock GlobalVariablePanel component
|
||||
vi.mock('@/app/components/workflow/panel/global-variable-panel', () => ({
|
||||
default: () => <div data-testid="global-variable-panel">Global Variable Panel</div>,
|
||||
}))
|
||||
|
||||
// ============================================================================
|
||||
// Helper Functions
|
||||
// ============================================================================
|
||||
|
||||
type SetupMockOptions = {
|
||||
historyWorkflowData?: Record<string, unknown> | null
|
||||
showDebugAndPreviewPanel?: boolean
|
||||
showGlobalVariablePanel?: boolean
|
||||
showInputFieldPanel?: boolean
|
||||
showInputFieldPreviewPanel?: boolean
|
||||
inputFieldEditPanelProps?: Record<string, unknown> | null
|
||||
pipelineId?: string
|
||||
}
|
||||
|
||||
const setupMocks = (options?: SetupMockOptions) => {
|
||||
mockHistoryWorkflowData = options?.historyWorkflowData ?? null
|
||||
mockShowDebugAndPreviewPanel = options?.showDebugAndPreviewPanel ?? false
|
||||
mockShowGlobalVariablePanel = options?.showGlobalVariablePanel ?? false
|
||||
mockShowInputFieldPanel = options?.showInputFieldPanel ?? false
|
||||
mockShowInputFieldPreviewPanel = options?.showInputFieldPreviewPanel ?? false
|
||||
mockInputFieldEditPanelProps = options?.inputFieldEditPanelProps ?? null
|
||||
mockPipelineId = options?.pipelineId ?? 'test-pipeline-123'
|
||||
capturedPanelProps = null
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// RagPipelinePanel Component Tests
|
||||
// ============================================================================
|
||||
|
||||
describe('RagPipelinePanel', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
setupMocks()
|
||||
})
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Rendering Tests
|
||||
// -------------------------------------------------------------------------
|
||||
describe('Rendering', () => {
|
||||
it('should render without crashing', async () => {
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('workflow-panel')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should render Panel component with correct structure', async () => {
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('panel-left')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('panel-right')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should pass versionHistoryPanelProps to Panel', async () => {
|
||||
// Arrange
|
||||
setupMocks({ pipelineId: 'my-pipeline-456' })
|
||||
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(capturedPanelProps?.versionHistoryPanelProps).toBeDefined()
|
||||
expect(capturedPanelProps?.versionHistoryPanelProps?.getVersionListUrl).toBe(
|
||||
'/rag/pipelines/my-pipeline-456/workflows',
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Memoization Tests - versionHistoryPanelProps
|
||||
// -------------------------------------------------------------------------
|
||||
describe('Memoization - versionHistoryPanelProps', () => {
|
||||
it('should compute correct getVersionListUrl based on pipelineId', async () => {
|
||||
// Arrange
|
||||
setupMocks({ pipelineId: 'pipeline-abc' })
|
||||
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(capturedPanelProps?.versionHistoryPanelProps?.getVersionListUrl).toBe(
|
||||
'/rag/pipelines/pipeline-abc/workflows',
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
it('should compute correct deleteVersionUrl function', async () => {
|
||||
// Arrange
|
||||
setupMocks({ pipelineId: 'pipeline-xyz' })
|
||||
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
const deleteUrl = capturedPanelProps?.versionHistoryPanelProps?.deleteVersionUrl?.('version-1')
|
||||
expect(deleteUrl).toBe('/rag/pipelines/pipeline-xyz/workflows/version-1')
|
||||
})
|
||||
})
|
||||
|
||||
it('should compute correct updateVersionUrl function', async () => {
|
||||
// Arrange
|
||||
setupMocks({ pipelineId: 'pipeline-def' })
|
||||
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
const updateUrl = capturedPanelProps?.versionHistoryPanelProps?.updateVersionUrl?.('version-2')
|
||||
expect(updateUrl).toBe('/rag/pipelines/pipeline-def/workflows/version-2')
|
||||
})
|
||||
})
|
||||
|
||||
it('should set latestVersionId to empty string', async () => {
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(capturedPanelProps?.versionHistoryPanelProps?.latestVersionId).toBe('')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Memoization Tests - panelProps
|
||||
// -------------------------------------------------------------------------
|
||||
describe('Memoization - panelProps', () => {
|
||||
it('should pass components.left to Panel', async () => {
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(capturedPanelProps?.components?.left).toBeDefined()
|
||||
})
|
||||
})
|
||||
|
||||
it('should pass components.right to Panel', async () => {
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(capturedPanelProps?.components?.right).toBeDefined()
|
||||
})
|
||||
})
|
||||
|
||||
it('should pass versionHistoryPanelProps to panelProps', async () => {
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(capturedPanelProps?.versionHistoryPanelProps).toBeDefined()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Component Memoization Tests (React.memo)
|
||||
// -------------------------------------------------------------------------
|
||||
describe('Component Memoization', () => {
|
||||
it('should be wrapped with React.memo', async () => {
|
||||
// The component should not break when re-rendered
|
||||
const { rerender } = render(<RagPipelinePanel />)
|
||||
|
||||
// Act - rerender without prop changes
|
||||
rerender(<RagPipelinePanel />)
|
||||
|
||||
// Assert - component should still render correctly
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('workflow-panel')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// ============================================================================
|
||||
// RagPipelinePanelOnRight Component Tests
|
||||
// ============================================================================
|
||||
|
||||
describe('RagPipelinePanelOnRight', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
setupMocks()
|
||||
})
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Conditional Rendering - Record Panel
|
||||
// -------------------------------------------------------------------------
|
||||
describe('Record Panel Conditional Rendering', () => {
|
||||
it('should render Record panel when historyWorkflowData exists', async () => {
|
||||
// Arrange
|
||||
setupMocks({ historyWorkflowData: { id: 'history-1' } })
|
||||
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('record-panel')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should not render Record panel when historyWorkflowData is null', async () => {
|
||||
// Arrange
|
||||
setupMocks({ historyWorkflowData: null })
|
||||
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByTestId('record-panel')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should not render Record panel when historyWorkflowData is undefined', async () => {
|
||||
// Arrange
|
||||
setupMocks({ historyWorkflowData: undefined })
|
||||
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByTestId('record-panel')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Conditional Rendering - TestRun Panel
|
||||
// -------------------------------------------------------------------------
|
||||
describe('TestRun Panel Conditional Rendering', () => {
|
||||
it('should render TestRun panel when showDebugAndPreviewPanel is true', async () => {
|
||||
// Arrange
|
||||
setupMocks({ showDebugAndPreviewPanel: true })
|
||||
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('test-run-panel')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should not render TestRun panel when showDebugAndPreviewPanel is false', async () => {
|
||||
// Arrange
|
||||
setupMocks({ showDebugAndPreviewPanel: false })
|
||||
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByTestId('test-run-panel')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Conditional Rendering - GlobalVariable Panel
|
||||
// -------------------------------------------------------------------------
|
||||
describe('GlobalVariable Panel Conditional Rendering', () => {
|
||||
it('should render GlobalVariable panel when showGlobalVariablePanel is true', async () => {
|
||||
// Arrange
|
||||
setupMocks({ showGlobalVariablePanel: true })
|
||||
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('global-variable-panel')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should not render GlobalVariable panel when showGlobalVariablePanel is false', async () => {
|
||||
// Arrange
|
||||
setupMocks({ showGlobalVariablePanel: false })
|
||||
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByTestId('global-variable-panel')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Multiple Panels Rendering
|
||||
// -------------------------------------------------------------------------
|
||||
describe('Multiple Panels Rendering', () => {
|
||||
it('should render all right panels when all conditions are true', async () => {
|
||||
// Arrange
|
||||
setupMocks({
|
||||
historyWorkflowData: { id: 'history-1' },
|
||||
showDebugAndPreviewPanel: true,
|
||||
showGlobalVariablePanel: true,
|
||||
})
|
||||
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('record-panel')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('test-run-panel')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('global-variable-panel')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should render no right panels when all conditions are false', async () => {
|
||||
// Arrange
|
||||
setupMocks({
|
||||
historyWorkflowData: null,
|
||||
showDebugAndPreviewPanel: false,
|
||||
showGlobalVariablePanel: false,
|
||||
})
|
||||
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByTestId('record-panel')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTestId('test-run-panel')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTestId('global-variable-panel')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should render only Record and TestRun panels', async () => {
|
||||
// Arrange
|
||||
setupMocks({
|
||||
historyWorkflowData: { id: 'history-1' },
|
||||
showDebugAndPreviewPanel: true,
|
||||
showGlobalVariablePanel: false,
|
||||
})
|
||||
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('record-panel')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('test-run-panel')).toBeInTheDocument()
|
||||
expect(screen.queryByTestId('global-variable-panel')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// ============================================================================
|
||||
// RagPipelinePanelOnLeft Component Tests
|
||||
// ============================================================================
|
||||
|
||||
describe('RagPipelinePanelOnLeft', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
setupMocks()
|
||||
})
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Conditional Rendering - Preview Panel
|
||||
// -------------------------------------------------------------------------
|
||||
describe('Preview Panel Conditional Rendering', () => {
|
||||
it('should render Preview panel when showInputFieldPreviewPanel is true', async () => {
|
||||
// Arrange
|
||||
setupMocks({ showInputFieldPreviewPanel: true })
|
||||
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('preview-panel')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should not render Preview panel when showInputFieldPreviewPanel is false', async () => {
|
||||
// Arrange
|
||||
setupMocks({ showInputFieldPreviewPanel: false })
|
||||
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByTestId('preview-panel')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Conditional Rendering - InputFieldEditor Panel
|
||||
// -------------------------------------------------------------------------
|
||||
describe('InputFieldEditor Panel Conditional Rendering', () => {
|
||||
it('should render InputFieldEditor panel when inputFieldEditPanelProps is provided', async () => {
|
||||
// Arrange
|
||||
const editProps = {
|
||||
onClose: vi.fn(),
|
||||
onSubmit: vi.fn(),
|
||||
initialData: { variable: 'test' },
|
||||
}
|
||||
setupMocks({ inputFieldEditPanelProps: editProps })
|
||||
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('input-field-editor-panel')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should not render InputFieldEditor panel when inputFieldEditPanelProps is null', async () => {
|
||||
// Arrange
|
||||
setupMocks({ inputFieldEditPanelProps: null })
|
||||
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByTestId('input-field-editor-panel')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should pass props to InputFieldEditor panel', async () => {
|
||||
// Arrange
|
||||
const editProps = {
|
||||
onClose: vi.fn(),
|
||||
onSubmit: vi.fn(),
|
||||
initialData: { variable: 'test_var', label: 'Test Label' },
|
||||
}
|
||||
setupMocks({ inputFieldEditPanelProps: editProps })
|
||||
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(mockInputFieldEditorProps).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
onClose: editProps.onClose,
|
||||
onSubmit: editProps.onSubmit,
|
||||
initialData: editProps.initialData,
|
||||
}),
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Conditional Rendering - InputField Panel
|
||||
// -------------------------------------------------------------------------
|
||||
describe('InputField Panel Conditional Rendering', () => {
|
||||
it('should render InputField panel when showInputFieldPanel is true', async () => {
|
||||
// Arrange
|
||||
setupMocks({ showInputFieldPanel: true })
|
||||
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('input-field-panel')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should not render InputField panel when showInputFieldPanel is false', async () => {
|
||||
// Arrange
|
||||
setupMocks({ showInputFieldPanel: false })
|
||||
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByTestId('input-field-panel')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Multiple Panels Rendering
|
||||
// -------------------------------------------------------------------------
|
||||
describe('Multiple Left Panels Rendering', () => {
|
||||
it('should render all left panels when all conditions are true', async () => {
|
||||
// Arrange
|
||||
setupMocks({
|
||||
showInputFieldPreviewPanel: true,
|
||||
inputFieldEditPanelProps: { onClose: vi.fn(), onSubmit: vi.fn() },
|
||||
showInputFieldPanel: true,
|
||||
})
|
||||
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('preview-panel')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('input-field-editor-panel')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('input-field-panel')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should render no left panels when all conditions are false', async () => {
|
||||
// Arrange
|
||||
setupMocks({
|
||||
showInputFieldPreviewPanel: false,
|
||||
inputFieldEditPanelProps: null,
|
||||
showInputFieldPanel: false,
|
||||
})
|
||||
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByTestId('preview-panel')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTestId('input-field-editor-panel')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTestId('input-field-panel')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should render only Preview and InputField panels', async () => {
|
||||
// Arrange
|
||||
setupMocks({
|
||||
showInputFieldPreviewPanel: true,
|
||||
inputFieldEditPanelProps: null,
|
||||
showInputFieldPanel: true,
|
||||
})
|
||||
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('preview-panel')).toBeInTheDocument()
|
||||
expect(screen.queryByTestId('input-field-editor-panel')).not.toBeInTheDocument()
|
||||
expect(screen.getByTestId('input-field-panel')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// ============================================================================
|
||||
// Edge Cases Tests
|
||||
// ============================================================================
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
setupMocks()
|
||||
})
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Empty/Undefined Values
|
||||
// -------------------------------------------------------------------------
|
||||
describe('Empty/Undefined Values', () => {
|
||||
it('should handle empty pipelineId gracefully', async () => {
|
||||
// Arrange
|
||||
setupMocks({ pipelineId: '' })
|
||||
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(capturedPanelProps?.versionHistoryPanelProps?.getVersionListUrl).toBe(
|
||||
'/rag/pipelines//workflows',
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle special characters in pipelineId', async () => {
|
||||
// Arrange
|
||||
setupMocks({ pipelineId: 'pipeline-with-special_chars.123' })
|
||||
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(capturedPanelProps?.versionHistoryPanelProps?.getVersionListUrl).toBe(
|
||||
'/rag/pipelines/pipeline-with-special_chars.123/workflows',
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Props Spreading Tests
|
||||
// -------------------------------------------------------------------------
|
||||
describe('Props Spreading', () => {
|
||||
it('should correctly spread inputFieldEditPanelProps to editor component', async () => {
|
||||
// Arrange
|
||||
const customProps = {
|
||||
onClose: vi.fn(),
|
||||
onSubmit: vi.fn(),
|
||||
initialData: {
|
||||
variable: 'custom_var',
|
||||
label: 'Custom Label',
|
||||
type: 'text',
|
||||
},
|
||||
extraProp: 'extra-value',
|
||||
}
|
||||
setupMocks({ inputFieldEditPanelProps: customProps })
|
||||
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(mockInputFieldEditorProps).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
extraProp: 'extra-value',
|
||||
}),
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// State Combinations
|
||||
// -------------------------------------------------------------------------
|
||||
describe('State Combinations', () => {
|
||||
it('should handle all panels visible simultaneously', async () => {
|
||||
// Arrange
|
||||
setupMocks({
|
||||
historyWorkflowData: { id: 'h1' },
|
||||
showDebugAndPreviewPanel: true,
|
||||
showGlobalVariablePanel: true,
|
||||
showInputFieldPreviewPanel: true,
|
||||
inputFieldEditPanelProps: { onClose: vi.fn(), onSubmit: vi.fn() },
|
||||
showInputFieldPanel: true,
|
||||
})
|
||||
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert - All panels should be visible
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('record-panel')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('test-run-panel')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('global-variable-panel')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('preview-panel')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('input-field-editor-panel')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('input-field-panel')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// ============================================================================
|
||||
// URL Generator Functions Tests
|
||||
// ============================================================================
|
||||
|
||||
describe('URL Generator Functions', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
setupMocks()
|
||||
})
|
||||
|
||||
it('should return consistent URLs for same versionId', async () => {
|
||||
// Arrange
|
||||
setupMocks({ pipelineId: 'stable-pipeline' })
|
||||
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
const deleteUrl1 = capturedPanelProps?.versionHistoryPanelProps?.deleteVersionUrl?.('version-x')
|
||||
const deleteUrl2 = capturedPanelProps?.versionHistoryPanelProps?.deleteVersionUrl?.('version-x')
|
||||
expect(deleteUrl1).toBe(deleteUrl2)
|
||||
})
|
||||
})
|
||||
|
||||
it('should return different URLs for different versionIds', async () => {
|
||||
// Arrange
|
||||
setupMocks({ pipelineId: 'stable-pipeline' })
|
||||
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
const deleteUrl1 = capturedPanelProps?.versionHistoryPanelProps?.deleteVersionUrl?.('version-1')
|
||||
const deleteUrl2 = capturedPanelProps?.versionHistoryPanelProps?.deleteVersionUrl?.('version-2')
|
||||
expect(deleteUrl1).not.toBe(deleteUrl2)
|
||||
expect(deleteUrl1).toBe('/rag/pipelines/stable-pipeline/workflows/version-1')
|
||||
expect(deleteUrl2).toBe('/rag/pipelines/stable-pipeline/workflows/version-2')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// ============================================================================
|
||||
// Type Safety Tests
|
||||
// ============================================================================
|
||||
|
||||
describe('Type Safety', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
setupMocks()
|
||||
})
|
||||
|
||||
it('should pass correct PanelProps structure', async () => {
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert - Check structure matches PanelProps
|
||||
await waitFor(() => {
|
||||
expect(capturedPanelProps).toHaveProperty('components')
|
||||
expect(capturedPanelProps).toHaveProperty('versionHistoryPanelProps')
|
||||
expect(capturedPanelProps?.components).toHaveProperty('left')
|
||||
expect(capturedPanelProps?.components).toHaveProperty('right')
|
||||
})
|
||||
})
|
||||
|
||||
it('should pass correct versionHistoryPanelProps structure', async () => {
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(capturedPanelProps?.versionHistoryPanelProps).toHaveProperty('getVersionListUrl')
|
||||
expect(capturedPanelProps?.versionHistoryPanelProps).toHaveProperty('deleteVersionUrl')
|
||||
expect(capturedPanelProps?.versionHistoryPanelProps).toHaveProperty('updateVersionUrl')
|
||||
expect(capturedPanelProps?.versionHistoryPanelProps).toHaveProperty('latestVersionId')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// ============================================================================
|
||||
// Performance Tests
|
||||
// ============================================================================
|
||||
|
||||
describe('Performance', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
setupMocks()
|
||||
})
|
||||
|
||||
it('should handle multiple rerenders without issues', async () => {
|
||||
// Arrange
|
||||
const { rerender } = render(<RagPipelinePanel />)
|
||||
|
||||
// Act - Multiple rerenders
|
||||
for (let i = 0; i < 10; i++)
|
||||
rerender(<RagPipelinePanel />)
|
||||
|
||||
// Assert - Component should still work
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('workflow-panel')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// ============================================================================
|
||||
// Integration Tests
|
||||
// ============================================================================
|
||||
|
||||
describe('Integration Tests', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
setupMocks()
|
||||
})
|
||||
|
||||
it('should pass correct components to Panel', async () => {
|
||||
// Arrange
|
||||
setupMocks({
|
||||
historyWorkflowData: { id: 'h1' },
|
||||
showInputFieldPanel: true,
|
||||
})
|
||||
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(capturedPanelProps?.components?.left).toBeDefined()
|
||||
expect(capturedPanelProps?.components?.right).toBeDefined()
|
||||
|
||||
// Check that the components are React elements
|
||||
expect(React.isValidElement(capturedPanelProps?.components?.left)).toBe(true)
|
||||
expect(React.isValidElement(capturedPanelProps?.components?.right)).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
it('should correctly consume all store selectors', async () => {
|
||||
// Arrange
|
||||
setupMocks({
|
||||
historyWorkflowData: { id: 'test-history' },
|
||||
showDebugAndPreviewPanel: true,
|
||||
showGlobalVariablePanel: true,
|
||||
showInputFieldPanel: true,
|
||||
showInputFieldPreviewPanel: true,
|
||||
inputFieldEditPanelProps: { onClose: vi.fn(), onSubmit: vi.fn() },
|
||||
pipelineId: 'integration-test-pipeline',
|
||||
})
|
||||
|
||||
// Act
|
||||
render(<RagPipelinePanel />)
|
||||
|
||||
// Assert - All store-dependent rendering should work
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('record-panel')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('test-run-panel')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('global-variable-panel')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('input-field-panel')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('preview-panel')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('input-field-editor-panel')).toBeInTheDocument()
|
||||
expect(capturedPanelProps?.versionHistoryPanelProps?.getVersionListUrl).toBe(
|
||||
'/rag/pipelines/integration-test-pipeline/workflows',
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
|
@ -49,6 +49,7 @@ const InputFieldEditorPanel = ({
|
|||
</div>
|
||||
<button
|
||||
type="button"
|
||||
data-testid="input-field-editor-close-btn"
|
||||
className="absolute right-2.5 top-2.5 flex size-8 items-center justify-center"
|
||||
onClick={onClose}
|
||||
>
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -53,6 +53,7 @@ const FieldList = ({
|
|||
{LabelRightContent}
|
||||
</div>
|
||||
<ActionButton
|
||||
data-testid="field-list-add-btn"
|
||||
onClick={() => handleOpenInputFieldEditor()}
|
||||
disabled={readonly}
|
||||
className={cn(readonly && 'cursor-not-allowed')}
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,937 @@
|
|||
import type { WorkflowRunningData } from '@/app/components/workflow/types'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { WorkflowRunningStatus } from '@/app/components/workflow/types'
|
||||
import { ChunkingMode } from '@/models/datasets'
|
||||
|
||||
import Header from './header'
|
||||
// Import components after mocks
|
||||
import TestRunPanel from './index'
|
||||
|
||||
// ============================================================================
|
||||
// Mocks
|
||||
// ============================================================================
|
||||
|
||||
// Mock workflow store
|
||||
const mockIsPreparingDataSource = vi.fn(() => true)
|
||||
const mockSetIsPreparingDataSource = vi.fn()
|
||||
const mockWorkflowRunningData = vi.fn<() => WorkflowRunningData | undefined>(() => undefined)
|
||||
const mockPipelineId = 'test-pipeline-id'
|
||||
|
||||
vi.mock('@/app/components/workflow/store', () => ({
|
||||
useStore: (selector: (state: Record<string, unknown>) => unknown) => {
|
||||
const state = {
|
||||
isPreparingDataSource: mockIsPreparingDataSource(),
|
||||
workflowRunningData: mockWorkflowRunningData(),
|
||||
pipelineId: mockPipelineId,
|
||||
}
|
||||
return selector(state)
|
||||
},
|
||||
useWorkflowStore: () => ({
|
||||
getState: () => ({
|
||||
isPreparingDataSource: mockIsPreparingDataSource(),
|
||||
setIsPreparingDataSource: mockSetIsPreparingDataSource,
|
||||
}),
|
||||
}),
|
||||
}))
|
||||
|
||||
// Mock workflow interactions
|
||||
const mockHandleCancelDebugAndPreviewPanel = vi.fn()
|
||||
vi.mock('@/app/components/workflow/hooks', () => ({
|
||||
useWorkflowInteractions: () => ({
|
||||
handleCancelDebugAndPreviewPanel: mockHandleCancelDebugAndPreviewPanel,
|
||||
}),
|
||||
useWorkflowRun: () => ({
|
||||
handleRun: vi.fn(),
|
||||
}),
|
||||
useToolIcon: () => 'mock-tool-icon',
|
||||
}))
|
||||
|
||||
// Mock data source provider
|
||||
vi.mock('@/app/components/datasets/documents/create-from-pipeline/data-source/store/provider', () => ({
|
||||
default: ({ children }: { children: React.ReactNode }) => <div data-testid="data-source-provider">{children}</div>,
|
||||
}))
|
||||
|
||||
// Mock Preparation component
|
||||
vi.mock('./preparation', () => ({
|
||||
default: () => <div data-testid="preparation-component">Preparation</div>,
|
||||
}))
|
||||
|
||||
// Mock Result component (for TestRunPanel tests only)
|
||||
vi.mock('./result', () => ({
|
||||
default: () => <div data-testid="result-component">Result</div>,
|
||||
}))
|
||||
|
||||
// Mock ResultPanel from workflow
|
||||
vi.mock('@/app/components/workflow/run/result-panel', () => ({
|
||||
default: (props: Record<string, unknown>) => (
|
||||
<div data-testid="result-panel">
|
||||
ResultPanel -
|
||||
{' '}
|
||||
{props.status as string}
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
// Mock TracingPanel from workflow
|
||||
vi.mock('@/app/components/workflow/run/tracing-panel', () => ({
|
||||
default: (props: { list: unknown[] }) => (
|
||||
<div data-testid="tracing-panel">
|
||||
TracingPanel -
|
||||
{' '}
|
||||
{props.list?.length ?? 0}
|
||||
{' '}
|
||||
items
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
// Mock Loading component
|
||||
vi.mock('@/app/components/base/loading', () => ({
|
||||
default: () => <div data-testid="loading">Loading...</div>,
|
||||
}))
|
||||
|
||||
// Mock config
|
||||
vi.mock('@/config', () => ({
|
||||
RAG_PIPELINE_PREVIEW_CHUNK_NUM: 5,
|
||||
}))
|
||||
|
||||
// ============================================================================
|
||||
// Test Data Factories
|
||||
// ============================================================================
|
||||
|
||||
const createMockWorkflowRunningData = (overrides: Partial<WorkflowRunningData> = {}): WorkflowRunningData => ({
|
||||
result: {
|
||||
status: WorkflowRunningStatus.Succeeded,
|
||||
outputs: '{"test": "output"}',
|
||||
outputs_truncated: false,
|
||||
inputs: '{"test": "input"}',
|
||||
inputs_truncated: false,
|
||||
process_data_truncated: false,
|
||||
error: undefined,
|
||||
elapsed_time: 1000,
|
||||
total_tokens: 100,
|
||||
created_at: Date.now(),
|
||||
created_by: 'Test User',
|
||||
total_steps: 5,
|
||||
exceptions_count: 0,
|
||||
},
|
||||
tracing: [],
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const createMockGeneralOutputs = (chunkContents: string[] = ['chunk1', 'chunk2']) => ({
|
||||
chunk_structure: ChunkingMode.text,
|
||||
preview: chunkContents.map(content => ({ content })),
|
||||
})
|
||||
|
||||
const createMockParentChildOutputs = (parentMode: 'paragraph' | 'full-doc' = 'paragraph') => ({
|
||||
chunk_structure: ChunkingMode.parentChild,
|
||||
parent_mode: parentMode,
|
||||
preview: [
|
||||
{ content: 'parent1', child_chunks: ['child1', 'child2'] },
|
||||
{ content: 'parent2', child_chunks: ['child3', 'child4'] },
|
||||
],
|
||||
})
|
||||
|
||||
const createMockQAOutputs = () => ({
|
||||
chunk_structure: ChunkingMode.qa,
|
||||
qa_preview: [
|
||||
{ question: 'Q1', answer: 'A1' },
|
||||
{ question: 'Q2', answer: 'A2' },
|
||||
],
|
||||
})
|
||||
|
||||
// ============================================================================
|
||||
// TestRunPanel Component Tests
|
||||
// ============================================================================
|
||||
|
||||
describe('TestRunPanel', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockIsPreparingDataSource.mockReturnValue(true)
|
||||
mockWorkflowRunningData.mockReturnValue(undefined)
|
||||
})
|
||||
|
||||
// Basic rendering tests
|
||||
describe('Rendering', () => {
|
||||
it('should render with correct container styles', () => {
|
||||
const { container } = render(<TestRunPanel />)
|
||||
const panelDiv = container.firstChild as HTMLElement
|
||||
|
||||
expect(panelDiv).toHaveClass('relative', 'flex', 'h-full', 'w-[480px]', 'flex-col')
|
||||
})
|
||||
|
||||
it('should render Header component', () => {
|
||||
render(<TestRunPanel />)
|
||||
|
||||
expect(screen.getByText('datasetPipeline.testRun.title')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// Conditional rendering based on isPreparingDataSource
|
||||
describe('Conditional Content Rendering', () => {
|
||||
it('should render Preparation inside DataSourceProvider when isPreparingDataSource is true', () => {
|
||||
mockIsPreparingDataSource.mockReturnValue(true)
|
||||
|
||||
render(<TestRunPanel />)
|
||||
|
||||
expect(screen.getByTestId('data-source-provider')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('preparation-component')).toBeInTheDocument()
|
||||
expect(screen.queryByTestId('result-component')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render Result when isPreparingDataSource is false', () => {
|
||||
mockIsPreparingDataSource.mockReturnValue(false)
|
||||
|
||||
render(<TestRunPanel />)
|
||||
|
||||
expect(screen.getByTestId('result-component')).toBeInTheDocument()
|
||||
expect(screen.queryByTestId('data-source-provider')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTestId('preparation-component')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// ============================================================================
|
||||
// Header Component Tests
|
||||
// ============================================================================
|
||||
|
||||
describe('Header', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockIsPreparingDataSource.mockReturnValue(true)
|
||||
})
|
||||
|
||||
// Rendering tests
|
||||
describe('Rendering', () => {
|
||||
it('should render title with correct translation key', () => {
|
||||
render(<Header />)
|
||||
|
||||
expect(screen.getByText('datasetPipeline.testRun.title')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render close button', () => {
|
||||
render(<Header />)
|
||||
|
||||
const closeButton = screen.getByRole('button')
|
||||
expect(closeButton).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should have correct layout classes', () => {
|
||||
const { container } = render(<Header />)
|
||||
const headerDiv = container.firstChild as HTMLElement
|
||||
|
||||
expect(headerDiv).toHaveClass('flex', 'items-center', 'gap-x-2', 'pl-4', 'pr-3', 'pt-4')
|
||||
})
|
||||
})
|
||||
|
||||
// Close button interactions
|
||||
describe('Close Button Interaction', () => {
|
||||
it('should call setIsPreparingDataSource(false) and handleCancelDebugAndPreviewPanel when clicked and isPreparingDataSource is true', () => {
|
||||
mockIsPreparingDataSource.mockReturnValue(true)
|
||||
|
||||
render(<Header />)
|
||||
|
||||
const closeButton = screen.getByRole('button')
|
||||
fireEvent.click(closeButton)
|
||||
|
||||
expect(mockSetIsPreparingDataSource).toHaveBeenCalledWith(false)
|
||||
expect(mockHandleCancelDebugAndPreviewPanel).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should only call handleCancelDebugAndPreviewPanel when isPreparingDataSource is false', () => {
|
||||
mockIsPreparingDataSource.mockReturnValue(false)
|
||||
|
||||
render(<Header />)
|
||||
|
||||
const closeButton = screen.getByRole('button')
|
||||
fireEvent.click(closeButton)
|
||||
|
||||
expect(mockSetIsPreparingDataSource).not.toHaveBeenCalled()
|
||||
expect(mockHandleCancelDebugAndPreviewPanel).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// ============================================================================
|
||||
// Result Component Tests (Real Implementation)
|
||||
// ============================================================================
|
||||
|
||||
// Unmock Result for these tests
|
||||
vi.doUnmock('./result')
|
||||
|
||||
describe('Result', () => {
|
||||
// Dynamically import Result to get real implementation
|
||||
let Result: typeof import('./result').default
|
||||
|
||||
beforeAll(async () => {
|
||||
const resultModule = await import('./result')
|
||||
Result = resultModule.default
|
||||
})
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockWorkflowRunningData.mockReturnValue(undefined)
|
||||
})
|
||||
|
||||
// Rendering tests
|
||||
describe('Rendering', () => {
|
||||
it('should render with RESULT tab active by default', async () => {
|
||||
render(<Result />)
|
||||
|
||||
await waitFor(() => {
|
||||
const resultTab = screen.getByRole('button', { name: /runLog\.result/i })
|
||||
expect(resultTab).toHaveClass('border-util-colors-blue-brand-blue-brand-600')
|
||||
})
|
||||
})
|
||||
|
||||
it('should render all three tabs', () => {
|
||||
render(<Result />)
|
||||
|
||||
expect(screen.getByRole('button', { name: /runLog\.result/i })).toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: /runLog\.detail/i })).toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: /runLog\.tracing/i })).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// Tab switching tests
|
||||
describe('Tab Switching', () => {
|
||||
it('should switch to DETAIL tab when clicked', async () => {
|
||||
mockWorkflowRunningData.mockReturnValue(createMockWorkflowRunningData())
|
||||
render(<Result />)
|
||||
|
||||
const detailTab = screen.getByRole('button', { name: /runLog\.detail/i })
|
||||
fireEvent.click(detailTab)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('result-panel')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should switch to TRACING tab when clicked', async () => {
|
||||
mockWorkflowRunningData.mockReturnValue(createMockWorkflowRunningData({ tracing: [{ id: '1' }] as unknown as WorkflowRunningData['tracing'] }))
|
||||
render(<Result />)
|
||||
|
||||
const tracingTab = screen.getByRole('button', { name: /runLog\.tracing/i })
|
||||
fireEvent.click(tracingTab)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('tracing-panel')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// Loading states
|
||||
describe('Loading States', () => {
|
||||
it('should show loading in DETAIL tab when no result data', async () => {
|
||||
mockWorkflowRunningData.mockReturnValue({
|
||||
result: undefined as unknown as WorkflowRunningData['result'],
|
||||
tracing: [],
|
||||
})
|
||||
render(<Result />)
|
||||
|
||||
const detailTab = screen.getByRole('button', { name: /runLog\.detail/i })
|
||||
fireEvent.click(detailTab)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('loading')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should show loading in TRACING tab when no tracing data', async () => {
|
||||
mockWorkflowRunningData.mockReturnValue(createMockWorkflowRunningData({ tracing: [] }))
|
||||
render(<Result />)
|
||||
|
||||
const tracingTab = screen.getByRole('button', { name: /runLog\.tracing/i })
|
||||
fireEvent.click(tracingTab)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('loading')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// ============================================================================
|
||||
// ResultPreview Component Tests
|
||||
// ============================================================================
|
||||
|
||||
// We need to import ResultPreview directly
|
||||
vi.doUnmock('./result/result-preview')
|
||||
|
||||
describe('ResultPreview', () => {
|
||||
let ResultPreview: typeof import('./result/result-preview').default
|
||||
|
||||
beforeAll(async () => {
|
||||
const previewModule = await import('./result/result-preview')
|
||||
ResultPreview = previewModule.default
|
||||
})
|
||||
|
||||
const mockOnSwitchToDetail = vi.fn()
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
// Loading state
|
||||
describe('Loading State', () => {
|
||||
it('should show loading spinner when isRunning is true and no outputs', () => {
|
||||
render(
|
||||
<ResultPreview
|
||||
isRunning={true}
|
||||
outputs={undefined}
|
||||
error={undefined}
|
||||
onSwitchToDetail={mockOnSwitchToDetail}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('pipeline.result.resultPreview.loading')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show loading when outputs are available', () => {
|
||||
render(
|
||||
<ResultPreview
|
||||
isRunning={true}
|
||||
outputs={createMockGeneralOutputs()}
|
||||
error={undefined}
|
||||
onSwitchToDetail={mockOnSwitchToDetail}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByText('pipeline.result.resultPreview.loading')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// Error state
|
||||
describe('Error State', () => {
|
||||
it('should show error message when not running and has error', () => {
|
||||
render(
|
||||
<ResultPreview
|
||||
isRunning={false}
|
||||
outputs={undefined}
|
||||
error="Test error message"
|
||||
onSwitchToDetail={mockOnSwitchToDetail}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('pipeline.result.resultPreview.error')).toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: 'pipeline.result.resultPreview.viewDetails' })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should call onSwitchToDetail when View Details button is clicked', () => {
|
||||
render(
|
||||
<ResultPreview
|
||||
isRunning={false}
|
||||
outputs={undefined}
|
||||
error="Test error message"
|
||||
onSwitchToDetail={mockOnSwitchToDetail}
|
||||
/>,
|
||||
)
|
||||
|
||||
const viewDetailsButton = screen.getByRole('button', { name: 'pipeline.result.resultPreview.viewDetails' })
|
||||
fireEvent.click(viewDetailsButton)
|
||||
|
||||
expect(mockOnSwitchToDetail).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should not show error when still running', () => {
|
||||
render(
|
||||
<ResultPreview
|
||||
isRunning={true}
|
||||
outputs={undefined}
|
||||
error="Test error message"
|
||||
onSwitchToDetail={mockOnSwitchToDetail}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByText('pipeline.result.resultPreview.error')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// Success state with outputs
|
||||
describe('Success State with Outputs', () => {
|
||||
it('should render chunk content when outputs are available', () => {
|
||||
render(
|
||||
<ResultPreview
|
||||
isRunning={false}
|
||||
outputs={createMockGeneralOutputs(['test chunk content'])}
|
||||
error={undefined}
|
||||
onSwitchToDetail={mockOnSwitchToDetail}
|
||||
/>,
|
||||
)
|
||||
|
||||
// Check that chunk content is rendered (the real ChunkCardList renders the content)
|
||||
expect(screen.getByText('test chunk content')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render multiple chunks when provided', () => {
|
||||
render(
|
||||
<ResultPreview
|
||||
isRunning={false}
|
||||
outputs={createMockGeneralOutputs(['chunk one', 'chunk two'])}
|
||||
error={undefined}
|
||||
onSwitchToDetail={mockOnSwitchToDetail}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('chunk one')).toBeInTheDocument()
|
||||
expect(screen.getByText('chunk two')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show footer tip', () => {
|
||||
render(
|
||||
<ResultPreview
|
||||
isRunning={false}
|
||||
outputs={createMockGeneralOutputs()}
|
||||
error={undefined}
|
||||
onSwitchToDetail={mockOnSwitchToDetail}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText(/pipeline\.result\.resultPreview\.footerTip/)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// Edge cases
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle empty outputs gracefully', () => {
|
||||
render(
|
||||
<ResultPreview
|
||||
isRunning={false}
|
||||
outputs={null}
|
||||
error={undefined}
|
||||
onSwitchToDetail={mockOnSwitchToDetail}
|
||||
/>,
|
||||
)
|
||||
|
||||
// Should not crash and should not show chunk card list
|
||||
expect(screen.queryByTestId('chunk-card-list')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle undefined outputs', () => {
|
||||
render(
|
||||
<ResultPreview
|
||||
isRunning={false}
|
||||
outputs={undefined}
|
||||
error={undefined}
|
||||
onSwitchToDetail={mockOnSwitchToDetail}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByTestId('chunk-card-list')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// ============================================================================
|
||||
// Tabs Component Tests
|
||||
// ============================================================================
|
||||
|
||||
vi.doUnmock('./result/tabs')
|
||||
|
||||
describe('Tabs', () => {
|
||||
let Tabs: typeof import('./result/tabs').default
|
||||
|
||||
beforeAll(async () => {
|
||||
const tabsModule = await import('./result/tabs')
|
||||
Tabs = tabsModule.default
|
||||
})
|
||||
|
||||
const mockSwitchTab = vi.fn()
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
// Rendering tests
|
||||
describe('Rendering', () => {
|
||||
it('should render all three tabs', () => {
|
||||
render(
|
||||
<Tabs
|
||||
currentTab="RESULT"
|
||||
workflowRunningData={createMockWorkflowRunningData()}
|
||||
switchTab={mockSwitchTab}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByRole('button', { name: /runLog\.result/i })).toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: /runLog\.detail/i })).toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: /runLog\.tracing/i })).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// Active tab styling
|
||||
describe('Active Tab Styling', () => {
|
||||
it('should highlight RESULT tab when currentTab is RESULT', () => {
|
||||
render(
|
||||
<Tabs
|
||||
currentTab="RESULT"
|
||||
workflowRunningData={createMockWorkflowRunningData()}
|
||||
switchTab={mockSwitchTab}
|
||||
/>,
|
||||
)
|
||||
|
||||
const resultTab = screen.getByRole('button', { name: /runLog\.result/i })
|
||||
expect(resultTab).toHaveClass('border-util-colors-blue-brand-blue-brand-600')
|
||||
})
|
||||
|
||||
it('should highlight DETAIL tab when currentTab is DETAIL', () => {
|
||||
render(
|
||||
<Tabs
|
||||
currentTab="DETAIL"
|
||||
workflowRunningData={createMockWorkflowRunningData()}
|
||||
switchTab={mockSwitchTab}
|
||||
/>,
|
||||
)
|
||||
|
||||
const detailTab = screen.getByRole('button', { name: /runLog\.detail/i })
|
||||
expect(detailTab).toHaveClass('border-util-colors-blue-brand-blue-brand-600')
|
||||
})
|
||||
})
|
||||
|
||||
// Tab click handling
|
||||
describe('Tab Click Handling', () => {
|
||||
it('should call switchTab with RESULT when RESULT tab is clicked', () => {
|
||||
render(
|
||||
<Tabs
|
||||
currentTab="DETAIL"
|
||||
workflowRunningData={createMockWorkflowRunningData()}
|
||||
switchTab={mockSwitchTab}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: /runLog\.result/i }))
|
||||
|
||||
expect(mockSwitchTab).toHaveBeenCalledWith('RESULT')
|
||||
})
|
||||
|
||||
it('should call switchTab with DETAIL when DETAIL tab is clicked', () => {
|
||||
render(
|
||||
<Tabs
|
||||
currentTab="RESULT"
|
||||
workflowRunningData={createMockWorkflowRunningData()}
|
||||
switchTab={mockSwitchTab}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: /runLog\.detail/i }))
|
||||
|
||||
expect(mockSwitchTab).toHaveBeenCalledWith('DETAIL')
|
||||
})
|
||||
|
||||
it('should call switchTab with TRACING when TRACING tab is clicked', () => {
|
||||
render(
|
||||
<Tabs
|
||||
currentTab="RESULT"
|
||||
workflowRunningData={createMockWorkflowRunningData()}
|
||||
switchTab={mockSwitchTab}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: /runLog\.tracing/i }))
|
||||
|
||||
expect(mockSwitchTab).toHaveBeenCalledWith('TRACING')
|
||||
})
|
||||
})
|
||||
|
||||
// Disabled state when no data
|
||||
describe('Disabled State', () => {
|
||||
it('should disable tabs when workflowRunningData is undefined', () => {
|
||||
render(
|
||||
<Tabs
|
||||
currentTab="RESULT"
|
||||
workflowRunningData={undefined}
|
||||
switchTab={mockSwitchTab}
|
||||
/>,
|
||||
)
|
||||
|
||||
const resultTab = screen.getByRole('button', { name: /runLog\.result/i })
|
||||
expect(resultTab).toBeDisabled()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// ============================================================================
|
||||
// Tab Component Tests
|
||||
// ============================================================================
|
||||
|
||||
vi.doUnmock('./result/tabs/tab')
|
||||
|
||||
describe('Tab', () => {
|
||||
let Tab: typeof import('./result/tabs/tab').default
|
||||
|
||||
beforeAll(async () => {
|
||||
const tabModule = await import('./result/tabs/tab')
|
||||
Tab = tabModule.default
|
||||
})
|
||||
|
||||
const mockOnClick = vi.fn()
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
// Rendering tests
|
||||
describe('Rendering', () => {
|
||||
it('should render tab with label', () => {
|
||||
render(
|
||||
<Tab
|
||||
isActive={false}
|
||||
label="Test Tab"
|
||||
value="TEST"
|
||||
workflowRunningData={createMockWorkflowRunningData()}
|
||||
onClick={mockOnClick}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByRole('button', { name: 'Test Tab' })).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// Active state styling
|
||||
describe('Active State', () => {
|
||||
it('should have active styles when isActive is true', () => {
|
||||
render(
|
||||
<Tab
|
||||
isActive={true}
|
||||
label="Active Tab"
|
||||
value="TEST"
|
||||
workflowRunningData={createMockWorkflowRunningData()}
|
||||
onClick={mockOnClick}
|
||||
/>,
|
||||
)
|
||||
|
||||
const tab = screen.getByRole('button')
|
||||
expect(tab).toHaveClass('border-util-colors-blue-brand-blue-brand-600', 'text-text-primary')
|
||||
})
|
||||
|
||||
it('should have inactive styles when isActive is false', () => {
|
||||
render(
|
||||
<Tab
|
||||
isActive={false}
|
||||
label="Inactive Tab"
|
||||
value="TEST"
|
||||
workflowRunningData={createMockWorkflowRunningData()}
|
||||
onClick={mockOnClick}
|
||||
/>,
|
||||
)
|
||||
|
||||
const tab = screen.getByRole('button')
|
||||
expect(tab).toHaveClass('border-transparent', 'text-text-tertiary')
|
||||
})
|
||||
})
|
||||
|
||||
// Click handling
|
||||
describe('Click Handling', () => {
|
||||
it('should call onClick with value when clicked', () => {
|
||||
render(
|
||||
<Tab
|
||||
isActive={false}
|
||||
label="Test Tab"
|
||||
value="MY_VALUE"
|
||||
workflowRunningData={createMockWorkflowRunningData()}
|
||||
onClick={mockOnClick}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
|
||||
expect(mockOnClick).toHaveBeenCalledWith('MY_VALUE')
|
||||
})
|
||||
|
||||
it('should not call onClick when disabled (no workflowRunningData)', () => {
|
||||
render(
|
||||
<Tab
|
||||
isActive={false}
|
||||
label="Test Tab"
|
||||
value="MY_VALUE"
|
||||
workflowRunningData={undefined}
|
||||
onClick={mockOnClick}
|
||||
/>,
|
||||
)
|
||||
|
||||
const tab = screen.getByRole('button')
|
||||
fireEvent.click(tab)
|
||||
|
||||
// The click handler is still called, but button is disabled
|
||||
expect(tab).toBeDisabled()
|
||||
})
|
||||
})
|
||||
|
||||
// Disabled state
|
||||
describe('Disabled State', () => {
|
||||
it('should be disabled when workflowRunningData is undefined', () => {
|
||||
render(
|
||||
<Tab
|
||||
isActive={false}
|
||||
label="Test Tab"
|
||||
value="TEST"
|
||||
workflowRunningData={undefined}
|
||||
onClick={mockOnClick}
|
||||
/>,
|
||||
)
|
||||
|
||||
const tab = screen.getByRole('button')
|
||||
expect(tab).toBeDisabled()
|
||||
expect(tab).toHaveClass('opacity-30')
|
||||
})
|
||||
|
||||
it('should not be disabled when workflowRunningData is provided', () => {
|
||||
render(
|
||||
<Tab
|
||||
isActive={false}
|
||||
label="Test Tab"
|
||||
value="TEST"
|
||||
workflowRunningData={createMockWorkflowRunningData()}
|
||||
onClick={mockOnClick}
|
||||
/>,
|
||||
)
|
||||
|
||||
const tab = screen.getByRole('button')
|
||||
expect(tab).not.toBeDisabled()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// ============================================================================
|
||||
// formatPreviewChunks Utility Tests
|
||||
// ============================================================================
|
||||
|
||||
describe('formatPreviewChunks', () => {
|
||||
let formatPreviewChunks: typeof import('./result/result-preview/utils').formatPreviewChunks
|
||||
|
||||
beforeAll(async () => {
|
||||
const utilsModule = await import('./result/result-preview/utils')
|
||||
formatPreviewChunks = utilsModule.formatPreviewChunks
|
||||
})
|
||||
|
||||
// Edge cases
|
||||
describe('Edge Cases', () => {
|
||||
it('should return undefined for null outputs', () => {
|
||||
expect(formatPreviewChunks(null)).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should return undefined for undefined outputs', () => {
|
||||
expect(formatPreviewChunks(undefined)).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should return undefined for unknown chunk structure', () => {
|
||||
const outputs = {
|
||||
chunk_structure: 'unknown_mode',
|
||||
preview: [],
|
||||
}
|
||||
expect(formatPreviewChunks(outputs)).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
// General (text) chunks
|
||||
describe('General Chunks (ChunkingMode.text)', () => {
|
||||
it('should format general chunks correctly', () => {
|
||||
const outputs = createMockGeneralOutputs(['content1', 'content2', 'content3'])
|
||||
const result = formatPreviewChunks(outputs)
|
||||
|
||||
expect(result).toEqual(['content1', 'content2', 'content3'])
|
||||
})
|
||||
|
||||
it('should limit to RAG_PIPELINE_PREVIEW_CHUNK_NUM chunks', () => {
|
||||
const manyChunks = Array.from({ length: 10 }, (_, i) => `chunk${i}`)
|
||||
const outputs = createMockGeneralOutputs(manyChunks)
|
||||
const result = formatPreviewChunks(outputs) as string[]
|
||||
|
||||
// RAG_PIPELINE_PREVIEW_CHUNK_NUM is mocked to 5
|
||||
expect(result).toHaveLength(5)
|
||||
expect(result).toEqual(['chunk0', 'chunk1', 'chunk2', 'chunk3', 'chunk4'])
|
||||
})
|
||||
|
||||
it('should handle empty preview array', () => {
|
||||
const outputs = createMockGeneralOutputs([])
|
||||
const result = formatPreviewChunks(outputs)
|
||||
|
||||
expect(result).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
// Parent-child chunks
|
||||
describe('Parent-Child Chunks (ChunkingMode.parentChild)', () => {
|
||||
it('should format paragraph mode parent-child chunks correctly', () => {
|
||||
const outputs = createMockParentChildOutputs('paragraph')
|
||||
const result = formatPreviewChunks(outputs)
|
||||
|
||||
expect(result).toEqual({
|
||||
parent_child_chunks: [
|
||||
{ parent_content: 'parent1', child_contents: ['child1', 'child2'], parent_mode: 'paragraph' },
|
||||
{ parent_content: 'parent2', child_contents: ['child3', 'child4'], parent_mode: 'paragraph' },
|
||||
],
|
||||
parent_mode: 'paragraph',
|
||||
})
|
||||
})
|
||||
|
||||
it('should format full-doc mode parent-child chunks and limit child chunks', () => {
|
||||
const outputs = {
|
||||
chunk_structure: ChunkingMode.parentChild,
|
||||
parent_mode: 'full-doc' as const,
|
||||
preview: [
|
||||
{
|
||||
content: 'full-doc-parent',
|
||||
child_chunks: Array.from({ length: 10 }, (_, i) => `child${i}`),
|
||||
},
|
||||
],
|
||||
}
|
||||
const result = formatPreviewChunks(outputs)
|
||||
|
||||
expect(result).toEqual({
|
||||
parent_child_chunks: [
|
||||
{
|
||||
parent_content: 'full-doc-parent',
|
||||
child_contents: ['child0', 'child1', 'child2', 'child3', 'child4'], // Limited to 5
|
||||
parent_mode: 'full-doc',
|
||||
},
|
||||
],
|
||||
parent_mode: 'full-doc',
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// QA chunks
|
||||
describe('QA Chunks (ChunkingMode.qa)', () => {
|
||||
it('should format QA chunks correctly', () => {
|
||||
const outputs = createMockQAOutputs()
|
||||
const result = formatPreviewChunks(outputs)
|
||||
|
||||
expect(result).toEqual({
|
||||
qa_chunks: [
|
||||
{ question: 'Q1', answer: 'A1' },
|
||||
{ question: 'Q2', answer: 'A2' },
|
||||
],
|
||||
})
|
||||
})
|
||||
|
||||
it('should limit QA chunks to RAG_PIPELINE_PREVIEW_CHUNK_NUM', () => {
|
||||
const outputs = {
|
||||
chunk_structure: ChunkingMode.qa,
|
||||
qa_preview: Array.from({ length: 10 }, (_, i) => ({
|
||||
question: `Q${i}`,
|
||||
answer: `A${i}`,
|
||||
})),
|
||||
}
|
||||
const result = formatPreviewChunks(outputs) as { qa_chunks: Array<{ question: string, answer: string }> }
|
||||
|
||||
expect(result.qa_chunks).toHaveLength(5)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// ============================================================================
|
||||
// Types Tests
|
||||
// ============================================================================
|
||||
|
||||
describe('Types', () => {
|
||||
describe('TestRunStep Enum', () => {
|
||||
it('should have correct enum values', async () => {
|
||||
const { TestRunStep } = await import('./types')
|
||||
|
||||
expect(TestRunStep.dataSource).toBe('dataSource')
|
||||
expect(TestRunStep.documentProcessing).toBe('documentProcessing')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,549 @@
|
|||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import Actions from './index'
|
||||
|
||||
// ============================================================================
|
||||
// Actions Component Tests
|
||||
// ============================================================================
|
||||
|
||||
describe('Actions', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Rendering Tests
|
||||
// -------------------------------------------------------------------------
|
||||
describe('Rendering', () => {
|
||||
it('should render without crashing', () => {
|
||||
// Arrange
|
||||
const handleNextStep = vi.fn()
|
||||
|
||||
// Act
|
||||
render(<Actions handleNextStep={handleNextStep} />)
|
||||
|
||||
// Assert
|
||||
expect(screen.getByRole('button')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render button with translated text', () => {
|
||||
// Arrange
|
||||
const handleNextStep = vi.fn()
|
||||
|
||||
// Act
|
||||
render(<Actions handleNextStep={handleNextStep} />)
|
||||
|
||||
// Assert - Translation mock returns key with namespace prefix
|
||||
expect(screen.getByText('datasetCreation.stepOne.button')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render with correct container structure', () => {
|
||||
// Arrange
|
||||
const handleNextStep = vi.fn()
|
||||
|
||||
// Act
|
||||
const { container } = render(<Actions handleNextStep={handleNextStep} />)
|
||||
|
||||
// Assert
|
||||
const wrapper = container.firstChild as HTMLElement
|
||||
expect(wrapper.className).toContain('flex')
|
||||
expect(wrapper.className).toContain('justify-end')
|
||||
expect(wrapper.className).toContain('p-4')
|
||||
expect(wrapper.className).toContain('pt-2')
|
||||
})
|
||||
|
||||
it('should render span with px-0.5 class around text', () => {
|
||||
// Arrange
|
||||
const handleNextStep = vi.fn()
|
||||
|
||||
// Act
|
||||
const { container } = render(<Actions handleNextStep={handleNextStep} />)
|
||||
|
||||
// Assert
|
||||
const span = container.querySelector('span')
|
||||
expect(span).toBeInTheDocument()
|
||||
expect(span?.className).toContain('px-0.5')
|
||||
})
|
||||
})
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Props Variations Tests
|
||||
// -------------------------------------------------------------------------
|
||||
describe('Props Variations', () => {
|
||||
it('should pass disabled=true to button when disabled prop is true', () => {
|
||||
// Arrange
|
||||
const handleNextStep = vi.fn()
|
||||
|
||||
// Act
|
||||
render(<Actions disabled={true} handleNextStep={handleNextStep} />)
|
||||
|
||||
// Assert
|
||||
expect(screen.getByRole('button')).toBeDisabled()
|
||||
})
|
||||
|
||||
it('should pass disabled=false to button when disabled prop is false', () => {
|
||||
// Arrange
|
||||
const handleNextStep = vi.fn()
|
||||
|
||||
// Act
|
||||
render(<Actions disabled={false} handleNextStep={handleNextStep} />)
|
||||
|
||||
// Assert
|
||||
expect(screen.getByRole('button')).not.toBeDisabled()
|
||||
})
|
||||
|
||||
it('should not disable button when disabled prop is undefined', () => {
|
||||
// Arrange
|
||||
const handleNextStep = vi.fn()
|
||||
|
||||
// Act
|
||||
render(<Actions handleNextStep={handleNextStep} />)
|
||||
|
||||
// Assert
|
||||
expect(screen.getByRole('button')).not.toBeDisabled()
|
||||
})
|
||||
|
||||
it('should handle disabled switching from true to false', () => {
|
||||
// Arrange
|
||||
const handleNextStep = vi.fn()
|
||||
|
||||
// Act
|
||||
const { rerender } = render(
|
||||
<Actions disabled={true} handleNextStep={handleNextStep} />,
|
||||
)
|
||||
|
||||
// Assert - Initially disabled
|
||||
expect(screen.getByRole('button')).toBeDisabled()
|
||||
|
||||
// Act - Rerender with disabled=false
|
||||
rerender(<Actions disabled={false} handleNextStep={handleNextStep} />)
|
||||
|
||||
// Assert - Now enabled
|
||||
expect(screen.getByRole('button')).not.toBeDisabled()
|
||||
})
|
||||
|
||||
it('should handle disabled switching from false to true', () => {
|
||||
// Arrange
|
||||
const handleNextStep = vi.fn()
|
||||
|
||||
// Act
|
||||
const { rerender } = render(
|
||||
<Actions disabled={false} handleNextStep={handleNextStep} />,
|
||||
)
|
||||
|
||||
// Assert - Initially enabled
|
||||
expect(screen.getByRole('button')).not.toBeDisabled()
|
||||
|
||||
// Act - Rerender with disabled=true
|
||||
rerender(<Actions disabled={true} handleNextStep={handleNextStep} />)
|
||||
|
||||
// Assert - Now disabled
|
||||
expect(screen.getByRole('button')).toBeDisabled()
|
||||
})
|
||||
|
||||
it('should handle undefined disabled becoming true', () => {
|
||||
// Arrange
|
||||
const handleNextStep = vi.fn()
|
||||
|
||||
// Act
|
||||
const { rerender } = render(
|
||||
<Actions handleNextStep={handleNextStep} />,
|
||||
)
|
||||
|
||||
// Assert - Initially not disabled (undefined)
|
||||
expect(screen.getByRole('button')).not.toBeDisabled()
|
||||
|
||||
// Act - Rerender with disabled=true
|
||||
rerender(<Actions disabled={true} handleNextStep={handleNextStep} />)
|
||||
|
||||
// Assert - Now disabled
|
||||
expect(screen.getByRole('button')).toBeDisabled()
|
||||
})
|
||||
})
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// User Interaction Tests
|
||||
// -------------------------------------------------------------------------
|
||||
describe('User Interactions', () => {
|
||||
it('should call handleNextStep when button is clicked', () => {
|
||||
// Arrange
|
||||
const handleNextStep = vi.fn()
|
||||
|
||||
// Act
|
||||
render(<Actions handleNextStep={handleNextStep} />)
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
|
||||
// Assert
|
||||
expect(handleNextStep).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should call handleNextStep exactly once per click', () => {
|
||||
// Arrange
|
||||
const handleNextStep = vi.fn()
|
||||
|
||||
// Act
|
||||
render(<Actions handleNextStep={handleNextStep} />)
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
|
||||
// Assert
|
||||
expect(handleNextStep).toHaveBeenCalled()
|
||||
expect(handleNextStep.mock.calls).toHaveLength(1)
|
||||
})
|
||||
|
||||
it('should call handleNextStep multiple times on multiple clicks', () => {
|
||||
// Arrange
|
||||
const handleNextStep = vi.fn()
|
||||
|
||||
// Act
|
||||
render(<Actions handleNextStep={handleNextStep} />)
|
||||
const button = screen.getByRole('button')
|
||||
fireEvent.click(button)
|
||||
fireEvent.click(button)
|
||||
fireEvent.click(button)
|
||||
|
||||
// Assert
|
||||
expect(handleNextStep).toHaveBeenCalledTimes(3)
|
||||
})
|
||||
|
||||
it('should not call handleNextStep when button is disabled and clicked', () => {
|
||||
// Arrange
|
||||
const handleNextStep = vi.fn()
|
||||
|
||||
// Act
|
||||
render(<Actions disabled={true} handleNextStep={handleNextStep} />)
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
|
||||
// Assert - Disabled button should not trigger onClick
|
||||
expect(handleNextStep).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle rapid clicks when not disabled', () => {
|
||||
// Arrange
|
||||
const handleNextStep = vi.fn()
|
||||
|
||||
// Act
|
||||
render(<Actions handleNextStep={handleNextStep} />)
|
||||
const button = screen.getByRole('button')
|
||||
|
||||
// Simulate rapid clicks
|
||||
for (let i = 0; i < 10; i++)
|
||||
fireEvent.click(button)
|
||||
|
||||
// Assert
|
||||
expect(handleNextStep).toHaveBeenCalledTimes(10)
|
||||
})
|
||||
})
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Callback Stability Tests
|
||||
// -------------------------------------------------------------------------
|
||||
describe('Callback Stability', () => {
|
||||
it('should use the new handleNextStep when prop changes', () => {
|
||||
// Arrange
|
||||
const handleNextStep1 = vi.fn()
|
||||
const handleNextStep2 = vi.fn()
|
||||
|
||||
// Act
|
||||
const { rerender } = render(
|
||||
<Actions handleNextStep={handleNextStep1} />,
|
||||
)
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
|
||||
rerender(<Actions handleNextStep={handleNextStep2} />)
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
|
||||
// Assert
|
||||
expect(handleNextStep1).toHaveBeenCalledTimes(1)
|
||||
expect(handleNextStep2).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should maintain functionality after rerender with same props', () => {
|
||||
// Arrange
|
||||
const handleNextStep = vi.fn()
|
||||
|
||||
// Act
|
||||
const { rerender } = render(
|
||||
<Actions handleNextStep={handleNextStep} />,
|
||||
)
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
|
||||
rerender(<Actions handleNextStep={handleNextStep} />)
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
|
||||
// Assert
|
||||
expect(handleNextStep).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
it('should work correctly when handleNextStep changes multiple times', () => {
|
||||
// Arrange
|
||||
const handleNextStep1 = vi.fn()
|
||||
const handleNextStep2 = vi.fn()
|
||||
const handleNextStep3 = vi.fn()
|
||||
|
||||
// Act
|
||||
const { rerender } = render(
|
||||
<Actions handleNextStep={handleNextStep1} />,
|
||||
)
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
|
||||
rerender(<Actions handleNextStep={handleNextStep2} />)
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
|
||||
rerender(<Actions handleNextStep={handleNextStep3} />)
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
|
||||
// Assert
|
||||
expect(handleNextStep1).toHaveBeenCalledTimes(1)
|
||||
expect(handleNextStep2).toHaveBeenCalledTimes(1)
|
||||
expect(handleNextStep3).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Memoization Tests
|
||||
// -------------------------------------------------------------------------
|
||||
describe('Memoization', () => {
|
||||
it('should be wrapped with React.memo', () => {
|
||||
// Arrange
|
||||
const handleNextStep = vi.fn()
|
||||
|
||||
// Act - Verify component is memoized by checking display name pattern
|
||||
const { rerender } = render(
|
||||
<Actions handleNextStep={handleNextStep} />,
|
||||
)
|
||||
|
||||
// Rerender with same props should work without issues
|
||||
rerender(<Actions handleNextStep={handleNextStep} />)
|
||||
|
||||
// Assert - Component should render correctly after rerender
|
||||
expect(screen.getByRole('button')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not break when props remain the same across rerenders', () => {
|
||||
// Arrange
|
||||
const handleNextStep = vi.fn()
|
||||
|
||||
// Act
|
||||
const { rerender } = render(
|
||||
<Actions disabled={false} handleNextStep={handleNextStep} />,
|
||||
)
|
||||
|
||||
// Multiple rerenders with same props
|
||||
for (let i = 0; i < 5; i++) {
|
||||
rerender(<Actions disabled={false} handleNextStep={handleNextStep} />)
|
||||
}
|
||||
|
||||
// Assert - Should still function correctly
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
expect(handleNextStep).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should update correctly when only disabled prop changes', () => {
|
||||
// Arrange
|
||||
const handleNextStep = vi.fn()
|
||||
|
||||
// Act
|
||||
const { rerender } = render(
|
||||
<Actions disabled={false} handleNextStep={handleNextStep} />,
|
||||
)
|
||||
|
||||
// Assert - Initially not disabled
|
||||
expect(screen.getByRole('button')).not.toBeDisabled()
|
||||
|
||||
// Act - Change only disabled prop
|
||||
rerender(<Actions disabled={true} handleNextStep={handleNextStep} />)
|
||||
|
||||
// Assert - Should reflect the new disabled state
|
||||
expect(screen.getByRole('button')).toBeDisabled()
|
||||
})
|
||||
|
||||
it('should update correctly when only handleNextStep prop changes', () => {
|
||||
// Arrange
|
||||
const handleNextStep1 = vi.fn()
|
||||
const handleNextStep2 = vi.fn()
|
||||
|
||||
// Act
|
||||
const { rerender } = render(
|
||||
<Actions disabled={false} handleNextStep={handleNextStep1} />,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
expect(handleNextStep1).toHaveBeenCalledTimes(1)
|
||||
|
||||
// Act - Change only handleNextStep prop
|
||||
rerender(<Actions disabled={false} handleNextStep={handleNextStep2} />)
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
|
||||
// Assert - New callback should be used
|
||||
expect(handleNextStep1).toHaveBeenCalledTimes(1)
|
||||
expect(handleNextStep2).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Edge Cases Tests
|
||||
// -------------------------------------------------------------------------
|
||||
describe('Edge Cases', () => {
|
||||
it('should call handleNextStep even if it has side effects', () => {
|
||||
// Arrange
|
||||
let sideEffectValue = 0
|
||||
const handleNextStep = vi.fn(() => {
|
||||
sideEffectValue = 42
|
||||
})
|
||||
|
||||
// Act
|
||||
render(<Actions handleNextStep={handleNextStep} />)
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
|
||||
// Assert
|
||||
expect(handleNextStep).toHaveBeenCalledTimes(1)
|
||||
expect(sideEffectValue).toBe(42)
|
||||
})
|
||||
|
||||
it('should handle handleNextStep that returns a value', () => {
|
||||
// Arrange
|
||||
const handleNextStep = vi.fn(() => 'return value')
|
||||
|
||||
// Act
|
||||
render(<Actions handleNextStep={handleNextStep} />)
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
|
||||
// Assert
|
||||
expect(handleNextStep).toHaveBeenCalledTimes(1)
|
||||
expect(handleNextStep).toHaveReturnedWith('return value')
|
||||
})
|
||||
|
||||
it('should handle handleNextStep that is async', async () => {
|
||||
// Arrange
|
||||
const handleNextStep = vi.fn().mockResolvedValue(undefined)
|
||||
|
||||
// Act
|
||||
render(<Actions handleNextStep={handleNextStep} />)
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
|
||||
// Assert
|
||||
expect(handleNextStep).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should render correctly with both disabled=true and handleNextStep', () => {
|
||||
// Arrange
|
||||
const handleNextStep = vi.fn()
|
||||
|
||||
// Act
|
||||
render(<Actions disabled={true} handleNextStep={handleNextStep} />)
|
||||
|
||||
// Assert
|
||||
const button = screen.getByRole('button')
|
||||
expect(button).toBeDisabled()
|
||||
})
|
||||
|
||||
it('should handle component unmount gracefully', () => {
|
||||
// Arrange
|
||||
const handleNextStep = vi.fn()
|
||||
|
||||
// Act
|
||||
const { unmount } = render(<Actions handleNextStep={handleNextStep} />)
|
||||
|
||||
// Assert - Unmount should not throw
|
||||
expect(() => unmount()).not.toThrow()
|
||||
})
|
||||
|
||||
it('should handle disabled as boolean-like falsy value', () => {
|
||||
// Arrange
|
||||
const handleNextStep = vi.fn()
|
||||
|
||||
// Act - Test with explicit false
|
||||
render(<Actions disabled={false} handleNextStep={handleNextStep} />)
|
||||
|
||||
// Assert
|
||||
expect(screen.getByRole('button')).not.toBeDisabled()
|
||||
})
|
||||
})
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Accessibility Tests
|
||||
// -------------------------------------------------------------------------
|
||||
describe('Accessibility', () => {
|
||||
it('should have button element that can receive focus', () => {
|
||||
// Arrange
|
||||
const handleNextStep = vi.fn()
|
||||
|
||||
// Act
|
||||
render(<Actions handleNextStep={handleNextStep} />)
|
||||
const button = screen.getByRole('button')
|
||||
|
||||
// Assert - Button should be focusable (not disabled by default)
|
||||
expect(button).not.toBeDisabled()
|
||||
})
|
||||
|
||||
it('should indicate disabled state correctly', () => {
|
||||
// Arrange
|
||||
const handleNextStep = vi.fn()
|
||||
|
||||
// Act
|
||||
render(<Actions disabled={true} handleNextStep={handleNextStep} />)
|
||||
|
||||
// Assert
|
||||
expect(screen.getByRole('button')).toHaveAttribute('disabled')
|
||||
})
|
||||
})
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Integration Tests
|
||||
// -------------------------------------------------------------------------
|
||||
describe('Integration', () => {
|
||||
it('should work in a typical workflow: enable -> click -> disable', () => {
|
||||
// Arrange
|
||||
const handleNextStep = vi.fn()
|
||||
|
||||
// Act - Start enabled
|
||||
const { rerender } = render(
|
||||
<Actions disabled={false} handleNextStep={handleNextStep} />,
|
||||
)
|
||||
|
||||
// Assert - Can click when enabled
|
||||
expect(screen.getByRole('button')).not.toBeDisabled()
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
expect(handleNextStep).toHaveBeenCalledTimes(1)
|
||||
|
||||
// Act - Disable after click (simulating loading state)
|
||||
rerender(<Actions disabled={true} handleNextStep={handleNextStep} />)
|
||||
|
||||
// Assert - Cannot click when disabled
|
||||
expect(screen.getByRole('button')).toBeDisabled()
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
expect(handleNextStep).toHaveBeenCalledTimes(1) // Still 1, not 2
|
||||
|
||||
// Act - Re-enable
|
||||
rerender(<Actions disabled={false} handleNextStep={handleNextStep} />)
|
||||
|
||||
// Assert - Can click again
|
||||
expect(screen.getByRole('button')).not.toBeDisabled()
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
expect(handleNextStep).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
it('should maintain consistent rendering across multiple state changes', () => {
|
||||
// Arrange
|
||||
const handleNextStep = vi.fn()
|
||||
|
||||
// Act
|
||||
const { rerender } = render(
|
||||
<Actions disabled={false} handleNextStep={handleNextStep} />,
|
||||
)
|
||||
|
||||
// Toggle disabled state multiple times
|
||||
const states = [true, false, true, false, true]
|
||||
states.forEach((disabled) => {
|
||||
rerender(<Actions disabled={disabled} handleNextStep={handleNextStep} />)
|
||||
if (disabled)
|
||||
expect(screen.getByRole('button')).toBeDisabled()
|
||||
else
|
||||
expect(screen.getByRole('button')).not.toBeDisabled()
|
||||
})
|
||||
|
||||
// Assert - Button should still render correctly
|
||||
expect(screen.getByRole('button')).toBeInTheDocument()
|
||||
expect(screen.getByText('datasetCreation.stepOne.button')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
|
@ -85,7 +85,11 @@ const PublishAsKnowledgePipelineModal = ({
|
|||
>
|
||||
<div className="title-2xl-semi-bold relative flex items-center p-6 pb-3 pr-14 text-text-primary">
|
||||
{t('common.publishAs', { ns: 'pipeline' })}
|
||||
<div className="absolute right-5 top-5 flex h-8 w-8 cursor-pointer items-center justify-center" onClick={onCancel}>
|
||||
<div
|
||||
data-testid="publish-modal-close-btn"
|
||||
className="absolute right-5 top-5 flex h-8 w-8 cursor-pointer items-center justify-center"
|
||||
onClick={onCancel}
|
||||
>
|
||||
<RiCloseLine className="h-4 w-4 text-text-tertiary" />
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue