diff --git a/.claude/settings.json b/.claude/settings.json index 7d42234cae..509dbe8447 100644 --- a/.claude/settings.json +++ b/.claude/settings.json @@ -3,6 +3,7 @@ "feature-dev@claude-plugins-official": true, "context7@claude-plugins-official": true, "typescript-lsp@claude-plugins-official": true, - "pyright-lsp@claude-plugins-official": true + "pyright-lsp@claude-plugins-official": true, + "ralph-loop@claude-plugins-official": true } } diff --git a/.claude/skills/frontend-code-review/SKILL.md b/.claude/skills/frontend-code-review/SKILL.md new file mode 100644 index 0000000000..6cc23ca171 --- /dev/null +++ b/.claude/skills/frontend-code-review/SKILL.md @@ -0,0 +1,73 @@ +--- +name: frontend-code-review +description: "Trigger when the user requests a review of frontend files (e.g., `.tsx`, `.ts`, `.js`). Support both pending-change reviews and focused file reviews while applying the checklist rules." +--- + +# Frontend Code Review + +## Intent +Use this skill whenever the user asks to review frontend code (especially `.tsx`, `.ts`, or `.js` files). Support two review modes: + +1. **Pending-change review** – inspect staged/working-tree files slated for commit and flag checklist violations before submission. +2. **File-targeted review** – review the specific file(s) the user names and report the relevant checklist findings. + +Stick to the checklist below for every applicable file and mode. + +## Checklist +See [references/code-quality.md](references/code-quality.md), [references/performance.md](references/performance.md), [references/business-logic.md](references/business-logic.md) for the living checklist split by category—treat it as the canonical set of rules to follow. + +Flag each rule violation with urgency metadata so future reviewers can prioritize fixes. + +## Review Process +1. Open the relevant component/module. Gather lines that relate to class names, React Flow hooks, prop memoization, and styling. +2. For each rule in the review point, note where the code deviates and capture a representative snippet. +3. Compose the review section per the template below. Group violations first by **Urgent** flag, then by category order (Code Quality, Performance, Business Logic). + +## Required output +When invoked, the response must exactly follow one of the two templates: + +### Template A (any findings) +``` +# Code review +Found urgent issues need to be fixed: + +## 1 +FilePath: line + + + +### Suggested fix + + +--- +... (repeat for each urgent issue) ... + +Found suggestions for improvement: + +## 1 +FilePath: line + + + +### Suggested fix + + +--- + +... (repeat for each suggestion) ... +``` + +If there are no urgent issues, omit that section. If there are no suggestions, omit that section. + +If the issue number is more than 10, summarize as "10+ urgent issues" or "10+ suggestions" and just output the first 10 issues. + +Don't compress the blank lines between sections; keep them as-is for readability. + +If you use Template A (i.e., there are issues to fix) and at least one issue requires code changes, append a brief follow-up question after the structured output asking whether the user wants you to apply the suggested fix(es). For example: "Would you like me to use the Suggested fix section to address these issues?" + +### Template B (no issues) +``` +## Code review +No issues found. +``` + diff --git a/.claude/skills/frontend-code-review/references/business-logic.md b/.claude/skills/frontend-code-review/references/business-logic.md new file mode 100644 index 0000000000..4584f99dfc --- /dev/null +++ b/.claude/skills/frontend-code-review/references/business-logic.md @@ -0,0 +1,15 @@ +# Rule Catalog — Business Logic + +## Can't use workflowStore in Node components + +IsUrgent: True + +### Description + +File path pattern of node components: `web/app/components/workflow/nodes/[nodeName]/node.tsx` + +Node components are also used when creating a RAG Pipe from a template, but in that context there is no workflowStore Provider, which results in a blank screen. [This Issue](https://github.com/langgenius/dify/issues/29168) was caused by exactly this reason. + +### Suggested Fix + +Use `import { useNodes } from 'reactflow'` instead of `import useNodes from '@/app/components/workflow/store/workflow/use-nodes'`. diff --git a/.claude/skills/frontend-code-review/references/code-quality.md b/.claude/skills/frontend-code-review/references/code-quality.md new file mode 100644 index 0000000000..afdd40deb3 --- /dev/null +++ b/.claude/skills/frontend-code-review/references/code-quality.md @@ -0,0 +1,44 @@ +# Rule Catalog — Code Quality + +## Conditional class names use utility function + +IsUrgent: True +Category: Code Quality + +### Description + +Ensure conditional CSS is handled via the shared `classNames` instead of custom ternaries, string concatenation, or template strings. Centralizing class logic keeps components consistent and easier to maintain. + +### Suggested Fix + +```ts +import { cn } from '@/utils/classnames' +const classNames = cn(isActive ? 'text-primary-600' : 'text-gray-500') +``` + +## Tailwind-first styling + +IsUrgent: True +Category: Code Quality + +### Description + +Favor Tailwind CSS utility classes instead of adding new `.module.css` files unless a Tailwind combination cannot achieve the required styling. Keeping styles in Tailwind improves consistency and reduces maintenance overhead. + +Update this file when adding, editing, or removing Code Quality rules so the catalog remains accurate. + +## Classname ordering for easy overrides + +### Description + +When writing components, always place the incoming `className` prop after the component’s own class values so that downstream consumers can override or extend the styling. This keeps your component’s defaults but still lets external callers change or remove specific styles. + +Example: + +```tsx +import { cn } from '@/utils/classnames' + +const Button = ({ className }) => { + return
+} +``` diff --git a/.claude/skills/frontend-code-review/references/performance.md b/.claude/skills/frontend-code-review/references/performance.md new file mode 100644 index 0000000000..2d60072f5c --- /dev/null +++ b/.claude/skills/frontend-code-review/references/performance.md @@ -0,0 +1,45 @@ +# Rule Catalog — Performance + +## React Flow data usage + +IsUrgent: True +Category: Performance + +### Description + +When rendering React Flow, prefer `useNodes`/`useEdges` for UI consumption and rely on `useStoreApi` inside callbacks that mutate or read node/edge state. Avoid manually pulling Flow data outside of these hooks. + +## Complex prop memoization + +IsUrgent: True +Category: Performance + +### Description + +Wrap complex prop values (objects, arrays, maps) in `useMemo` prior to passing them into child components to guarantee stable references and prevent unnecessary renders. + +Update this file when adding, editing, or removing Performance rules so the catalog remains accurate. + +Wrong: + +```tsx + +``` + +Right: + +```tsx +const config = useMemo(() => ({ + provider: ..., + detail: ... +}), [provider, detail]); + + +``` diff --git a/.claude/skills/skill-creator/SKILL.md b/.claude/skills/skill-creator/SKILL.md new file mode 100644 index 0000000000..b49da5ac68 --- /dev/null +++ b/.claude/skills/skill-creator/SKILL.md @@ -0,0 +1,355 @@ +--- +name: skill-creator +description: Guide for creating effective skills. This skill should be used when users want to create a new skill (or update an existing skill) that extends Claude's capabilities with specialized knowledge, workflows, or tool integrations. +--- + +# Skill Creator + +This skill provides guidance for creating effective skills. + +## About Skills + +Skills are modular, self-contained packages that extend Claude's capabilities by providing +specialized knowledge, workflows, and tools. Think of them as "onboarding guides" for specific +domains or tasks—they transform Claude from a general-purpose agent into a specialized agent +equipped with procedural knowledge that no model can fully possess. + +### What Skills Provide + +1. Specialized workflows - Multi-step procedures for specific domains +2. Tool integrations - Instructions for working with specific file formats or APIs +3. Domain expertise - Company-specific knowledge, schemas, business logic +4. Bundled resources - Scripts, references, and assets for complex and repetitive tasks + +## Core Principles + +### Concise is Key + +The context window is a public good. Skills share the context window with everything else Claude needs: system prompt, conversation history, other Skills' metadata, and the actual user request. + +**Default assumption: Claude is already very smart.** Only add context Claude doesn't already have. Challenge each piece of information: "Does Claude really need this explanation?" and "Does this paragraph justify its token cost?" + +Prefer concise examples over verbose explanations. + +### Set Appropriate Degrees of Freedom + +Match the level of specificity to the task's fragility and variability: + +**High freedom (text-based instructions)**: Use when multiple approaches are valid, decisions depend on context, or heuristics guide the approach. + +**Medium freedom (pseudocode or scripts with parameters)**: Use when a preferred pattern exists, some variation is acceptable, or configuration affects behavior. + +**Low freedom (specific scripts, few parameters)**: Use when operations are fragile and error-prone, consistency is critical, or a specific sequence must be followed. + +Think of Claude as exploring a path: a narrow bridge with cliffs needs specific guardrails (low freedom), while an open field allows many routes (high freedom). + +### Anatomy of a Skill + +Every skill consists of a required SKILL.md file and optional bundled resources: + +``` +skill-name/ +├── SKILL.md (required) +│ ├── YAML frontmatter metadata (required) +│ │ ├── name: (required) +│ │ └── description: (required) +│ └── Markdown instructions (required) +└── Bundled Resources (optional) + ├── scripts/ - Executable code (Python/Bash/etc.) + ├── references/ - Documentation intended to be loaded into context as needed + └── assets/ - Files used in output (templates, icons, fonts, etc.) +``` + +#### SKILL.md (required) + +Every SKILL.md consists of: + +- **Frontmatter** (YAML): Contains `name` and `description` fields. These are the only fields that Claude reads to determine when the skill gets used, thus it is very important to be clear and comprehensive in describing what the skill is, and when it should be used. +- **Body** (Markdown): Instructions and guidance for using the skill. Only loaded AFTER the skill triggers (if at all). + +#### Bundled Resources (optional) + +##### Scripts (`scripts/`) + +Executable code (Python/Bash/etc.) for tasks that require deterministic reliability or are repeatedly rewritten. + +- **When to include**: When the same code is being rewritten repeatedly or deterministic reliability is needed +- **Example**: `scripts/rotate_pdf.py` for PDF rotation tasks +- **Benefits**: Token efficient, deterministic, may be executed without loading into context +- **Note**: Scripts may still need to be read by Claude for patching or environment-specific adjustments + +##### References (`references/`) + +Documentation and reference material intended to be loaded as needed into context to inform Claude's process and thinking. + +- **When to include**: For documentation that Claude should reference while working +- **Examples**: `references/finance.md` for financial schemas, `references/mnda.md` for company NDA template, `references/policies.md` for company policies, `references/api_docs.md` for API specifications +- **Use cases**: Database schemas, API documentation, domain knowledge, company policies, detailed workflow guides +- **Benefits**: Keeps SKILL.md lean, loaded only when Claude determines it's needed +- **Best practice**: If files are large (>10k words), include grep search patterns in SKILL.md +- **Avoid duplication**: Information should live in either SKILL.md or references files, not both. Prefer references files for detailed information unless it's truly core to the skill—this keeps SKILL.md lean while making information discoverable without hogging the context window. Keep only essential procedural instructions and workflow guidance in SKILL.md; move detailed reference material, schemas, and examples to references files. + +##### Assets (`assets/`) + +Files not intended to be loaded into context, but rather used within the output Claude produces. + +- **When to include**: When the skill needs files that will be used in the final output +- **Examples**: `assets/logo.png` for brand assets, `assets/slides.pptx` for PowerPoint templates, `assets/frontend-template/` for HTML/React boilerplate, `assets/font.ttf` for typography +- **Use cases**: Templates, images, icons, boilerplate code, fonts, sample documents that get copied or modified +- **Benefits**: Separates output resources from documentation, enables Claude to use files without loading them into context + +#### What to Not Include in a Skill + +A skill should only contain essential files that directly support its functionality. Do NOT create extraneous documentation or auxiliary files, including: + +- README.md +- INSTALLATION_GUIDE.md +- QUICK_REFERENCE.md +- CHANGELOG.md +- etc. + +The skill should only contain the information needed for an AI agent to do the job at hand. It should not contain auxilary context about the process that went into creating it, setup and testing procedures, user-facing documentation, etc. Creating additional documentation files just adds clutter and confusion. + +### Progressive Disclosure Design Principle + +Skills use a three-level loading system to manage context efficiently: + +1. **Metadata (name + description)** - Always in context (~100 words) +2. **SKILL.md body** - When skill triggers (<5k words) +3. **Bundled resources** - As needed by Claude (Unlimited because scripts can be executed without reading into context window) + +#### Progressive Disclosure Patterns + +Keep SKILL.md body to the essentials and under 500 lines to minimize context bloat. Split content into separate files when approaching this limit. When splitting out content into other files, it is very important to reference them from SKILL.md and describe clearly when to read them, to ensure the reader of the skill knows they exist and when to use them. + +**Key principle:** When a skill supports multiple variations, frameworks, or options, keep only the core workflow and selection guidance in SKILL.md. Move variant-specific details (patterns, examples, configuration) into separate reference files. + +**Pattern 1: High-level guide with references** + +```markdown +# PDF Processing + +## Quick start + +Extract text with pdfplumber: +[code example] + +## Advanced features + +- **Form filling**: See [FORMS.md](FORMS.md) for complete guide +- **API reference**: See [REFERENCE.md](REFERENCE.md) for all methods +- **Examples**: See [EXAMPLES.md](EXAMPLES.md) for common patterns +``` + +Claude loads FORMS.md, REFERENCE.md, or EXAMPLES.md only when needed. + +**Pattern 2: Domain-specific organization** + +For Skills with multiple domains, organize content by domain to avoid loading irrelevant context: + +``` +bigquery-skill/ +├── SKILL.md (overview and navigation) +└── reference/ + ├── finance.md (revenue, billing metrics) + ├── sales.md (opportunities, pipeline) + ├── product.md (API usage, features) + └── marketing.md (campaigns, attribution) +``` + +When a user asks about sales metrics, Claude only reads sales.md. + +Similarly, for skills supporting multiple frameworks or variants, organize by variant: + +``` +cloud-deploy/ +├── SKILL.md (workflow + provider selection) +└── references/ + ├── aws.md (AWS deployment patterns) + ├── gcp.md (GCP deployment patterns) + └── azure.md (Azure deployment patterns) +``` + +When the user chooses AWS, Claude only reads aws.md. + +**Pattern 3: Conditional details** + +Show basic content, link to advanced content: + +```markdown +# DOCX Processing + +## Creating documents + +Use docx-js for new documents. See [DOCX-JS.md](DOCX-JS.md). + +## Editing documents + +For simple edits, modify the XML directly. + +**For tracked changes**: See [REDLINING.md](REDLINING.md) +**For OOXML details**: See [OOXML.md](OOXML.md) +``` + +Claude reads REDLINING.md or OOXML.md only when the user needs those features. + +**Important guidelines:** + +- **Avoid deeply nested references** - Keep references one level deep from SKILL.md. All reference files should link directly from SKILL.md. +- **Structure longer reference files** - For files longer than 100 lines, include a table of contents at the top so Claude can see the full scope when previewing. + +## Skill Creation Process + +Skill creation involves these steps: + +1. Understand the skill with concrete examples +2. Plan reusable skill contents (scripts, references, assets) +3. Initialize the skill (run init_skill.py) +4. Edit the skill (implement resources and write SKILL.md) +5. Package the skill (run package_skill.py) +6. Iterate based on real usage + +Follow these steps in order, skipping only if there is a clear reason why they are not applicable. + +### Step 1: Understanding the Skill with Concrete Examples + +Skip this step only when the skill's usage patterns are already clearly understood. It remains valuable even when working with an existing skill. + +To create an effective skill, clearly understand concrete examples of how the skill will be used. This understanding can come from either direct user examples or generated examples that are validated with user feedback. + +For example, when building an image-editor skill, relevant questions include: + +- "What functionality should the image-editor skill support? Editing, rotating, anything else?" +- "Can you give some examples of how this skill would be used?" +- "I can imagine users asking for things like 'Remove the red-eye from this image' or 'Rotate this image'. Are there other ways you imagine this skill being used?" +- "What would a user say that should trigger this skill?" + +To avoid overwhelming users, avoid asking too many questions in a single message. Start with the most important questions and follow up as needed for better effectiveness. + +Conclude this step when there is a clear sense of the functionality the skill should support. + +### Step 2: Planning the Reusable Skill Contents + +To turn concrete examples into an effective skill, analyze each example by: + +1. Considering how to execute on the example from scratch +2. Identifying what scripts, references, and assets would be helpful when executing these workflows repeatedly + +Example: When building a `pdf-editor` skill to handle queries like "Help me rotate this PDF," the analysis shows: + +1. Rotating a PDF requires re-writing the same code each time +2. A `scripts/rotate_pdf.py` script would be helpful to store in the skill + +Example: When designing a `frontend-webapp-builder` skill for queries like "Build me a todo app" or "Build me a dashboard to track my steps," the analysis shows: + +1. Writing a frontend webapp requires the same boilerplate HTML/React each time +2. An `assets/hello-world/` template containing the boilerplate HTML/React project files would be helpful to store in the skill + +Example: When building a `big-query` skill to handle queries like "How many users have logged in today?" the analysis shows: + +1. Querying BigQuery requires re-discovering the table schemas and relationships each time +2. A `references/schema.md` file documenting the table schemas would be helpful to store in the skill + +To establish the skill's contents, analyze each concrete example to create a list of the reusable resources to include: scripts, references, and assets. + +### Step 3: Initializing the Skill + +At this point, it is time to actually create the skill. + +Skip this step only if the skill being developed already exists, and iteration or packaging is needed. In this case, continue to the next step. + +When creating a new skill from scratch, always run the `init_skill.py` script. The script conveniently generates a new template skill directory that automatically includes everything a skill requires, making the skill creation process much more efficient and reliable. + +Usage: + +```bash +scripts/init_skill.py --path +``` + +The script: + +- Creates the skill directory at the specified path +- Generates a SKILL.md template with proper frontmatter and TODO placeholders +- Creates example resource directories: `scripts/`, `references/`, and `assets/` +- Adds example files in each directory that can be customized or deleted + +After initialization, customize or remove the generated SKILL.md and example files as needed. + +### Step 4: Edit the Skill + +When editing the (newly-generated or existing) skill, remember that the skill is being created for another instance of Claude to use. Include information that would be beneficial and non-obvious to Claude. Consider what procedural knowledge, domain-specific details, or reusable assets would help another Claude instance execute these tasks more effectively. + +#### Learn Proven Design Patterns + +Consult these helpful guides based on your skill's needs: + +- **Multi-step processes**: See references/workflows.md for sequential workflows and conditional logic +- **Specific output formats or quality standards**: See references/output-patterns.md for template and example patterns + +These files contain established best practices for effective skill design. + +#### Start with Reusable Skill Contents + +To begin implementation, start with the reusable resources identified above: `scripts/`, `references/`, and `assets/` files. Note that this step may require user input. For example, when implementing a `brand-guidelines` skill, the user may need to provide brand assets or templates to store in `assets/`, or documentation to store in `references/`. + +Added scripts must be tested by actually running them to ensure there are no bugs and that the output matches what is expected. If there are many similar scripts, only a representative sample needs to be tested to ensure confidence that they all work while balancing time to completion. + +Any example files and directories not needed for the skill should be deleted. The initialization script creates example files in `scripts/`, `references/`, and `assets/` to demonstrate structure, but most skills won't need all of them. + +#### Update SKILL.md + +**Writing Guidelines:** Always use imperative/infinitive form. + +##### Frontmatter + +Write the YAML frontmatter with `name` and `description`: + +- `name`: The skill name +- `description`: This is the primary triggering mechanism for your skill, and helps Claude understand when to use the skill. + - Include both what the Skill does and specific triggers/contexts for when to use it. + - Include all "when to use" information here - Not in the body. The body is only loaded after triggering, so "When to Use This Skill" sections in the body are not helpful to Claude. + - Example description for a `docx` skill: "Comprehensive document creation, editing, and analysis with support for tracked changes, comments, formatting preservation, and text extraction. Use when Claude needs to work with professional documents (.docx files) for: (1) Creating new documents, (2) Modifying or editing content, (3) Working with tracked changes, (4) Adding comments, or any other document tasks" + +Do not include any other fields in YAML frontmatter. + +##### Body + +Write instructions for using the skill and its bundled resources. + +### Step 5: Packaging a Skill + +Once development of the skill is complete, it must be packaged into a distributable .skill file that gets shared with the user. The packaging process automatically validates the skill first to ensure it meets all requirements: + +```bash +scripts/package_skill.py +``` + +Optional output directory specification: + +```bash +scripts/package_skill.py ./dist +``` + +The packaging script will: + +1. **Validate** the skill automatically, checking: + + - YAML frontmatter format and required fields + - Skill naming conventions and directory structure + - Description completeness and quality + - File organization and resource references + +2. **Package** the skill if validation passes, creating a .skill file named after the skill (e.g., `my-skill.skill`) that includes all files and maintains the proper directory structure for distribution. The .skill file is a zip file with a .skill extension. + +If validation fails, the script will report the errors and exit without creating a package. Fix any validation errors and run the packaging command again. + +### Step 6: Iterate + +After testing the skill, users may request improvements. Often this happens right after using the skill, with fresh context of how the skill performed. + +**Iteration workflow:** + +1. Use the skill on real tasks +2. Notice struggles or inefficiencies +3. Identify how SKILL.md or bundled resources should be updated +4. Implement changes and test again diff --git a/.claude/skills/skill-creator/references/output-patterns.md b/.claude/skills/skill-creator/references/output-patterns.md new file mode 100644 index 0000000000..022e85fe5e --- /dev/null +++ b/.claude/skills/skill-creator/references/output-patterns.md @@ -0,0 +1,86 @@ +# Output Patterns + +Use these patterns when skills need to produce consistent, high-quality output. + +## Template Pattern + +Provide templates for output format. Match the level of strictness to your needs. + +**For strict requirements (like API responses or data formats):** + +```markdown +## Report structure + +ALWAYS use this exact template structure: + +# [Analysis Title] + +## Executive summary +[One-paragraph overview of key findings] + +## Key findings +- Finding 1 with supporting data +- Finding 2 with supporting data +- Finding 3 with supporting data + +## Recommendations +1. Specific actionable recommendation +2. Specific actionable recommendation +``` + +**For flexible guidance (when adaptation is useful):** + +```markdown +## Report structure + +Here is a sensible default format, but use your best judgment: + +# [Analysis Title] + +## Executive summary +[Overview] + +## Key findings +[Adapt sections based on what you discover] + +## Recommendations +[Tailor to the specific context] + +Adjust sections as needed for the specific analysis type. +``` + +## Examples Pattern + +For skills where output quality depends on seeing examples, provide input/output pairs: + +```markdown +## Commit message format + +Generate commit messages following these examples: + +**Example 1:** +Input: Added user authentication with JWT tokens +Output: +``` + +feat(auth): implement JWT-based authentication + +Add login endpoint and token validation middleware + +``` + +**Example 2:** +Input: Fixed bug where dates displayed incorrectly in reports +Output: +``` + +fix(reports): correct date formatting in timezone conversion + +Use UTC timestamps consistently across report generation + +``` + +Follow this style: type(scope): brief description, then detailed explanation. +``` + +Examples help Claude understand the desired style and level of detail more clearly than descriptions alone. diff --git a/.claude/skills/skill-creator/references/workflows.md b/.claude/skills/skill-creator/references/workflows.md new file mode 100644 index 0000000000..54b0174078 --- /dev/null +++ b/.claude/skills/skill-creator/references/workflows.md @@ -0,0 +1,28 @@ +# Workflow Patterns + +## Sequential Workflows + +For complex tasks, break operations into clear, sequential steps. It is often helpful to give Claude an overview of the process towards the beginning of SKILL.md: + +```markdown +Filling a PDF form involves these steps: + +1. Analyze the form (run analyze_form.py) +2. Create field mapping (edit fields.json) +3. Validate mapping (run validate_fields.py) +4. Fill the form (run fill_form.py) +5. Verify output (run verify_output.py) +``` + +## Conditional Workflows + +For tasks with branching logic, guide Claude through decision points: + +```markdown +1. Determine the modification type: + **Creating new content?** → Follow "Creation workflow" below + **Editing existing content?** → Follow "Editing workflow" below + +2. Creation workflow: [steps] +3. Editing workflow: [steps] +``` diff --git a/.claude/skills/skill-creator/scripts/init_skill.py b/.claude/skills/skill-creator/scripts/init_skill.py new file mode 100755 index 0000000000..249fffcbbd --- /dev/null +++ b/.claude/skills/skill-creator/scripts/init_skill.py @@ -0,0 +1,300 @@ +#!/usr/bin/env python3 +""" +Skill Initializer - Creates a new skill from template + +Usage: + init_skill.py --path + +Examples: + init_skill.py my-new-skill --path skills/public + init_skill.py my-api-helper --path skills/private + init_skill.py custom-skill --path /custom/location +""" + +import sys +from pathlib import Path + + +SKILL_TEMPLATE = """--- +name: {skill_name} +description: [TODO: Complete and informative explanation of what the skill does and when to use it. Include WHEN to use this skill - specific scenarios, file types, or tasks that trigger it.] +--- + +# {skill_title} + +## Overview + +[TODO: 1-2 sentences explaining what this skill enables] + +## Structuring This Skill + +[TODO: Choose the structure that best fits this skill's purpose. Common patterns: + +**1. Workflow-Based** (best for sequential processes) +- Works well when there are clear step-by-step procedures +- Example: DOCX skill with "Workflow Decision Tree" → "Reading" → "Creating" → "Editing" +- Structure: ## Overview → ## Workflow Decision Tree → ## Step 1 → ## Step 2... + +**2. Task-Based** (best for tool collections) +- Works well when the skill offers different operations/capabilities +- Example: PDF skill with "Quick Start" → "Merge PDFs" → "Split PDFs" → "Extract Text" +- Structure: ## Overview → ## Quick Start → ## Task Category 1 → ## Task Category 2... + +**3. Reference/Guidelines** (best for standards or specifications) +- Works well for brand guidelines, coding standards, or requirements +- Example: Brand styling with "Brand Guidelines" → "Colors" → "Typography" → "Features" +- Structure: ## Overview → ## Guidelines → ## Specifications → ## Usage... + +**4. Capabilities-Based** (best for integrated systems) +- Works well when the skill provides multiple interrelated features +- Example: Product Management with "Core Capabilities" → numbered capability list +- Structure: ## Overview → ## Core Capabilities → ### 1. Feature → ### 2. Feature... + +Patterns can be mixed and matched as needed. Most skills combine patterns (e.g., start with task-based, add workflow for complex operations). + +Delete this entire "Structuring This Skill" section when done - it's just guidance.] + +## [TODO: Replace with the first main section based on chosen structure] + +[TODO: Add content here. See examples in existing skills: +- Code samples for technical skills +- Decision trees for complex workflows +- Concrete examples with realistic user requests +- References to scripts/templates/references as needed] + +## Resources + +This skill includes example resource directories that demonstrate how to organize different types of bundled resources: + +### scripts/ +Executable code (Python/Bash/etc.) that can be run directly to perform specific operations. + +**Examples from other skills:** +- PDF skill: `fill_fillable_fields.py`, `extract_form_field_info.py` - utilities for PDF manipulation +- DOCX skill: `document.py`, `utilities.py` - Python modules for document processing + +**Appropriate for:** Python scripts, shell scripts, or any executable code that performs automation, data processing, or specific operations. + +**Note:** Scripts may be executed without loading into context, but can still be read by Claude for patching or environment adjustments. + +### references/ +Documentation and reference material intended to be loaded into context to inform Claude's process and thinking. + +**Examples from other skills:** +- Product management: `communication.md`, `context_building.md` - detailed workflow guides +- BigQuery: API reference documentation and query examples +- Finance: Schema documentation, company policies + +**Appropriate for:** In-depth documentation, API references, database schemas, comprehensive guides, or any detailed information that Claude should reference while working. + +### assets/ +Files not intended to be loaded into context, but rather used within the output Claude produces. + +**Examples from other skills:** +- Brand styling: PowerPoint template files (.pptx), logo files +- Frontend builder: HTML/React boilerplate project directories +- Typography: Font files (.ttf, .woff2) + +**Appropriate for:** Templates, boilerplate code, document templates, images, icons, fonts, or any files meant to be copied or used in the final output. + +--- + +**Any unneeded directories can be deleted.** Not every skill requires all three types of resources. +""" + +EXAMPLE_SCRIPT = '''#!/usr/bin/env python3 +""" +Example helper script for {skill_name} + +This is a placeholder script that can be executed directly. +Replace with actual implementation or delete if not needed. + +Example real scripts from other skills: +- pdf/scripts/fill_fillable_fields.py - Fills PDF form fields +- pdf/scripts/convert_pdf_to_images.py - Converts PDF pages to images +""" + +def main(): + print("This is an example script for {skill_name}") + # TODO: Add actual script logic here + # This could be data processing, file conversion, API calls, etc. + +if __name__ == "__main__": + main() +''' + +EXAMPLE_REFERENCE = """# Reference Documentation for {skill_title} + +This is a placeholder for detailed reference documentation. +Replace with actual reference content or delete if not needed. + +Example real reference docs from other skills: +- product-management/references/communication.md - Comprehensive guide for status updates +- product-management/references/context_building.md - Deep-dive on gathering context +- bigquery/references/ - API references and query examples + +## When Reference Docs Are Useful + +Reference docs are ideal for: +- Comprehensive API documentation +- Detailed workflow guides +- Complex multi-step processes +- Information too lengthy for main SKILL.md +- Content that's only needed for specific use cases + +## Structure Suggestions + +### API Reference Example +- Overview +- Authentication +- Endpoints with examples +- Error codes +- Rate limits + +### Workflow Guide Example +- Prerequisites +- Step-by-step instructions +- Common patterns +- Troubleshooting +- Best practices +""" + +EXAMPLE_ASSET = """# Example Asset File + +This placeholder represents where asset files would be stored. +Replace with actual asset files (templates, images, fonts, etc.) or delete if not needed. + +Asset files are NOT intended to be loaded into context, but rather used within +the output Claude produces. + +Example asset files from other skills: +- Brand guidelines: logo.png, slides_template.pptx +- Frontend builder: hello-world/ directory with HTML/React boilerplate +- Typography: custom-font.ttf, font-family.woff2 +- Data: sample_data.csv, test_dataset.json + +## Common Asset Types + +- Templates: .pptx, .docx, boilerplate directories +- Images: .png, .jpg, .svg, .gif +- Fonts: .ttf, .otf, .woff, .woff2 +- Boilerplate code: Project directories, starter files +- Icons: .ico, .svg +- Data files: .csv, .json, .xml, .yaml + +Note: This is a text placeholder. Actual assets can be any file type. +""" + + +def title_case_skill_name(skill_name): + """Convert hyphenated skill name to Title Case for display.""" + return " ".join(word.capitalize() for word in skill_name.split("-")) + + +def init_skill(skill_name, path): + """ + Initialize a new skill directory with template SKILL.md. + + Args: + skill_name: Name of the skill + path: Path where the skill directory should be created + + Returns: + Path to created skill directory, or None if error + """ + # Determine skill directory path + skill_dir = Path(path).resolve() / skill_name + + # Check if directory already exists + if skill_dir.exists(): + print(f"❌ Error: Skill directory already exists: {skill_dir}") + return None + + # Create skill directory + try: + skill_dir.mkdir(parents=True, exist_ok=False) + print(f"✅ Created skill directory: {skill_dir}") + except Exception as e: + print(f"❌ Error creating directory: {e}") + return None + + # Create SKILL.md from template + skill_title = title_case_skill_name(skill_name) + skill_content = SKILL_TEMPLATE.format(skill_name=skill_name, skill_title=skill_title) + + skill_md_path = skill_dir / "SKILL.md" + try: + skill_md_path.write_text(skill_content) + print("✅ Created SKILL.md") + except Exception as e: + print(f"❌ Error creating SKILL.md: {e}") + return None + + # Create resource directories with example files + try: + # Create scripts/ directory with example script + scripts_dir = skill_dir / "scripts" + scripts_dir.mkdir(exist_ok=True) + example_script = scripts_dir / "example.py" + example_script.write_text(EXAMPLE_SCRIPT.format(skill_name=skill_name)) + example_script.chmod(0o755) + print("✅ Created scripts/example.py") + + # Create references/ directory with example reference doc + references_dir = skill_dir / "references" + references_dir.mkdir(exist_ok=True) + example_reference = references_dir / "api_reference.md" + example_reference.write_text(EXAMPLE_REFERENCE.format(skill_title=skill_title)) + print("✅ Created references/api_reference.md") + + # Create assets/ directory with example asset placeholder + assets_dir = skill_dir / "assets" + assets_dir.mkdir(exist_ok=True) + example_asset = assets_dir / "example_asset.txt" + example_asset.write_text(EXAMPLE_ASSET) + print("✅ Created assets/example_asset.txt") + except Exception as e: + print(f"❌ Error creating resource directories: {e}") + return None + + # Print next steps + print(f"\n✅ Skill '{skill_name}' initialized successfully at {skill_dir}") + print("\nNext steps:") + print("1. Edit SKILL.md to complete the TODO items and update the description") + print("2. Customize or delete the example files in scripts/, references/, and assets/") + print("3. Run the validator when ready to check the skill structure") + + return skill_dir + + +def main(): + if len(sys.argv) < 4 or sys.argv[2] != "--path": + print("Usage: init_skill.py --path ") + print("\nSkill name requirements:") + print(" - Hyphen-case identifier (e.g., 'data-analyzer')") + print(" - Lowercase letters, digits, and hyphens only") + print(" - Max 40 characters") + print(" - Must match directory name exactly") + print("\nExamples:") + print(" init_skill.py my-new-skill --path skills/public") + print(" init_skill.py my-api-helper --path skills/private") + print(" init_skill.py custom-skill --path /custom/location") + sys.exit(1) + + skill_name = sys.argv[1] + path = sys.argv[3] + + print(f"🚀 Initializing skill: {skill_name}") + print(f" Location: {path}") + print() + + result = init_skill(skill_name, path) + + if result: + sys.exit(0) + else: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/.claude/skills/skill-creator/scripts/package_skill.py b/.claude/skills/skill-creator/scripts/package_skill.py new file mode 100755 index 0000000000..736b928be0 --- /dev/null +++ b/.claude/skills/skill-creator/scripts/package_skill.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 +""" +Skill Packager - Creates a distributable .skill file of a skill folder + +Usage: + python utils/package_skill.py [output-directory] + +Example: + python utils/package_skill.py skills/public/my-skill + python utils/package_skill.py skills/public/my-skill ./dist +""" + +import sys +import zipfile +from pathlib import Path +from quick_validate import validate_skill + + +def package_skill(skill_path, output_dir=None): + """ + Package a skill folder into a .skill file. + + Args: + skill_path: Path to the skill folder + output_dir: Optional output directory for the .skill file (defaults to current directory) + + Returns: + Path to the created .skill file, or None if error + """ + skill_path = Path(skill_path).resolve() + + # Validate skill folder exists + if not skill_path.exists(): + print(f"❌ Error: Skill folder not found: {skill_path}") + return None + + if not skill_path.is_dir(): + print(f"❌ Error: Path is not a directory: {skill_path}") + return None + + # Validate SKILL.md exists + skill_md = skill_path / "SKILL.md" + if not skill_md.exists(): + print(f"❌ Error: SKILL.md not found in {skill_path}") + return None + + # Run validation before packaging + print("🔍 Validating skill...") + valid, message = validate_skill(skill_path) + if not valid: + print(f"❌ Validation failed: {message}") + print(" Please fix the validation errors before packaging.") + return None + print(f"✅ {message}\n") + + # Determine output location + skill_name = skill_path.name + if output_dir: + output_path = Path(output_dir).resolve() + output_path.mkdir(parents=True, exist_ok=True) + else: + output_path = Path.cwd() + + skill_filename = output_path / f"{skill_name}.skill" + + # Create the .skill file (zip format) + try: + with zipfile.ZipFile(skill_filename, "w", zipfile.ZIP_DEFLATED) as zipf: + # Walk through the skill directory + for file_path in skill_path.rglob("*"): + if file_path.is_file(): + # Calculate the relative path within the zip + arcname = file_path.relative_to(skill_path.parent) + zipf.write(file_path, arcname) + print(f" Added: {arcname}") + + print(f"\n✅ Successfully packaged skill to: {skill_filename}") + return skill_filename + + except Exception as e: + print(f"❌ Error creating .skill file: {e}") + return None + + +def main(): + if len(sys.argv) < 2: + print("Usage: python utils/package_skill.py [output-directory]") + print("\nExample:") + print(" python utils/package_skill.py skills/public/my-skill") + print(" python utils/package_skill.py skills/public/my-skill ./dist") + sys.exit(1) + + skill_path = sys.argv[1] + output_dir = sys.argv[2] if len(sys.argv) > 2 else None + + print(f"📦 Packaging skill: {skill_path}") + if output_dir: + print(f" Output directory: {output_dir}") + print() + + result = package_skill(skill_path, output_dir) + + if result: + sys.exit(0) + else: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/.claude/skills/skill-creator/scripts/quick_validate.py b/.claude/skills/skill-creator/scripts/quick_validate.py new file mode 100755 index 0000000000..66eb0a71bf --- /dev/null +++ b/.claude/skills/skill-creator/scripts/quick_validate.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +""" +Quick validation script for skills - minimal version +""" + +import sys +import os +import re +import yaml +from pathlib import Path + + +def validate_skill(skill_path): + """Basic validation of a skill""" + skill_path = Path(skill_path) + + # Check SKILL.md exists + skill_md = skill_path / "SKILL.md" + if not skill_md.exists(): + return False, "SKILL.md not found" + + # Read and validate frontmatter + content = skill_md.read_text() + if not content.startswith("---"): + return False, "No YAML frontmatter found" + + # Extract frontmatter + match = re.match(r"^---\n(.*?)\n---", content, re.DOTALL) + if not match: + return False, "Invalid frontmatter format" + + frontmatter_text = match.group(1) + + # Parse YAML frontmatter + try: + frontmatter = yaml.safe_load(frontmatter_text) + if not isinstance(frontmatter, dict): + return False, "Frontmatter must be a YAML dictionary" + except yaml.YAMLError as e: + return False, f"Invalid YAML in frontmatter: {e}" + + # Define allowed properties + ALLOWED_PROPERTIES = {"name", "description", "license", "allowed-tools", "metadata"} + + # Check for unexpected properties (excluding nested keys under metadata) + unexpected_keys = set(frontmatter.keys()) - ALLOWED_PROPERTIES + if unexpected_keys: + return False, ( + f"Unexpected key(s) in SKILL.md frontmatter: {', '.join(sorted(unexpected_keys))}. " + f"Allowed properties are: {', '.join(sorted(ALLOWED_PROPERTIES))}" + ) + + # Check required fields + if "name" not in frontmatter: + return False, "Missing 'name' in frontmatter" + if "description" not in frontmatter: + return False, "Missing 'description' in frontmatter" + + # Extract name for validation + name = frontmatter.get("name", "") + if not isinstance(name, str): + return False, f"Name must be a string, got {type(name).__name__}" + name = name.strip() + if name: + # Check naming convention (hyphen-case: lowercase with hyphens) + if not re.match(r"^[a-z0-9-]+$", name): + return False, f"Name '{name}' should be hyphen-case (lowercase letters, digits, and hyphens only)" + if name.startswith("-") or name.endswith("-") or "--" in name: + return False, f"Name '{name}' cannot start/end with hyphen or contain consecutive hyphens" + # Check name length (max 64 characters per spec) + if len(name) > 64: + return False, f"Name is too long ({len(name)} characters). Maximum is 64 characters." + + # Extract and validate description + description = frontmatter.get("description", "") + if not isinstance(description, str): + return False, f"Description must be a string, got {type(description).__name__}" + description = description.strip() + if description: + # Check for angle brackets + if "<" in description or ">" in description: + return False, "Description cannot contain angle brackets (< or >)" + # Check description length (max 1024 characters per spec) + if len(description) > 1024: + return False, f"Description is too long ({len(description)} characters). Maximum is 1024 characters." + + return True, "Skill is valid!" + + +if __name__ == "__main__": + if len(sys.argv) != 2: + print("Usage: python quick_validate.py ") + sys.exit(1) + + valid, message = validate_skill(sys.argv[1]) + print(message) + sys.exit(0 if valid else 1) diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index aa5a50918a..50dbde2aee 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -20,4 +20,4 @@ - [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!) - [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change. - [x] I've updated the documentation accordingly. -- [x] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods +- [x] I ran `make lint` and `make type-check` (backend) and `cd web && npx lint-staged` (frontend) to appease the lint gods diff --git a/.gitignore b/.gitignore index 17a2bd5b7b..7bd919f095 100644 --- a/.gitignore +++ b/.gitignore @@ -235,3 +235,4 @@ scripts/stress-test/reports/ # settings *.local.json +*.local.md diff --git a/Makefile b/Makefile index 07afd8187e..60c32948b9 100644 --- a/Makefile +++ b/Makefile @@ -60,9 +60,10 @@ check: @echo "✅ Code check complete" lint: - @echo "🔧 Running ruff format, check with fixes, and import linter..." + @echo "🔧 Running ruff format, check with fixes, import linter, and dotenv-linter..." @uv run --project api --dev sh -c 'ruff format ./api && ruff check --fix ./api' @uv run --directory api --dev lint-imports + @uv run --project api --dev dotenv-linter ./api/.env.example ./web/.env.example @echo "✅ Linting complete" type-check: @@ -122,7 +123,7 @@ help: @echo "Backend Code Quality:" @echo " make format - Format code with ruff" @echo " make check - Check code with ruff" - @echo " make lint - Format and fix code with ruff" + @echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)" @echo " make type-check - Run type checking with basedpyright" @echo " make test - Run backend unit tests" @echo "" diff --git a/api/.env.example b/api/.env.example index 88611e016e..44d770ed70 100644 --- a/api/.env.example +++ b/api/.env.example @@ -575,6 +575,10 @@ LOGSTORE_DUAL_WRITE_ENABLED=false # Enable dual-read fallback to SQL database when LogStore returns no results (default: true) # Useful for migration scenarios where historical data exists only in SQL database LOGSTORE_DUAL_READ_ENABLED=true +# Control flag for whether to write the `graph` field to LogStore. +# If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field; +# otherwise write an empty {} instead. Defaults to writing the `graph` field. +LOGSTORE_ENABLE_PUT_GRAPH_FIELD=true # Celery beat configuration CELERY_BEAT_SCHEDULER_TIME=1 diff --git a/api/.importlinter b/api/.importlinter index 24ece72b30..2dec958788 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -3,9 +3,11 @@ root_packages = core configs controllers + extensions models tasks services +include_external_packages = True [importlinter:contract:workflow] name = Workflow @@ -33,6 +35,28 @@ ignore_imports = core.workflow.nodes.loop.loop_node -> core.workflow.graph core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels +[importlinter:contract:workflow-infrastructure-dependencies] +name = Workflow Infrastructure Dependencies +type = forbidden +source_modules = + core.workflow +forbidden_modules = + extensions.ext_database + extensions.ext_redis +allow_indirect_imports = True +ignore_imports = + core.workflow.nodes.agent.agent_node -> extensions.ext_database + core.workflow.nodes.datasource.datasource_node -> extensions.ext_database + core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database + core.workflow.nodes.llm.file_saver -> extensions.ext_database + core.workflow.nodes.llm.llm_utils -> extensions.ext_database + core.workflow.nodes.llm.node -> extensions.ext_database + core.workflow.nodes.tool.tool_node -> extensions.ext_database + core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis + core.workflow.graph_engine.manager -> extensions.ext_redis + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis + [importlinter:contract:rsc] name = RSC type = layers diff --git a/api/Dockerfile b/api/Dockerfile index 02df91bfc1..a08d4e3aab 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -50,16 +50,33 @@ WORKDIR /app/api # Create non-root user ARG dify_uid=1001 +ARG NODE_MAJOR=22 +ARG NODE_PACKAGE_VERSION=22.21.0-1nodesource1 +ARG NODESOURCE_KEY_FPR=6F71F525282841EEDAF851B42F59B5F99B1BE0B4 RUN groupadd -r -g ${dify_uid} dify && \ useradd -r -u ${dify_uid} -g ${dify_uid} -s /bin/bash dify && \ chown -R dify:dify /app RUN \ apt-get update \ + && apt-get install -y --no-install-recommends \ + ca-certificates \ + curl \ + gnupg \ + && mkdir -p /etc/apt/keyrings \ + && curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key -o /tmp/nodesource.gpg \ + && gpg --show-keys --with-colons /tmp/nodesource.gpg \ + | awk -F: '/^fpr:/ {print $10}' \ + | grep -Fx "${NODESOURCE_KEY_FPR}" \ + && gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg /tmp/nodesource.gpg \ + && rm -f /tmp/nodesource.gpg \ + && echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_${NODE_MAJOR}.x nodistro main" \ + > /etc/apt/sources.list.d/nodesource.list \ + && apt-get update \ # Install dependencies && apt-get install -y --no-install-recommends \ # basic environment - curl nodejs \ + nodejs=${NODE_PACKAGE_VERSION} \ # for gmpy2 \ libgmp-dev libmpfr-dev libmpc-dev \ # For Security @@ -79,7 +96,8 @@ COPY --from=packages --chown=dify:dify ${VIRTUAL_ENV} ${VIRTUAL_ENV} ENV PATH="${VIRTUAL_ENV}/bin:${PATH}" # Download nltk data -RUN mkdir -p /usr/local/share/nltk_data && NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords')" \ +RUN mkdir -p /usr/local/share/nltk_data \ + && NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; from unstructured.nlp.tokenize import download_nltk_packages; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords'); download_nltk_packages()" \ && chmod -R 755 /usr/local/share/nltk_data ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache diff --git a/api/commands.py b/api/commands.py index a8d89ac200..7ebf5b4874 100644 --- a/api/commands.py +++ b/api/commands.py @@ -235,7 +235,7 @@ def migrate_annotation_vector_database(): if annotations: for annotation in annotations: document = Document( - page_content=annotation.question, + page_content=annotation.question_text, metadata={"annotation_id": annotation.id, "app_id": app.id, "doc_id": annotation.id}, ) documents.append(document) @@ -1184,6 +1184,217 @@ def remove_orphaned_files_on_storage(force: bool): click.echo(click.style(f"Removed {removed_files} orphaned files, with {error_files} errors.", fg="yellow")) +@click.command("file-usage", help="Query file usages and show where files are referenced.") +@click.option("--file-id", type=str, default=None, help="Filter by file UUID.") +@click.option("--key", type=str, default=None, help="Filter by storage key.") +@click.option("--src", type=str, default=None, help="Filter by table.column pattern (e.g., 'documents.%' or '%.icon').") +@click.option("--limit", type=int, default=100, help="Limit number of results (default: 100).") +@click.option("--offset", type=int, default=0, help="Offset for pagination (default: 0).") +@click.option("--json", "output_json", is_flag=True, help="Output results in JSON format.") +def file_usage( + file_id: str | None, + key: str | None, + src: str | None, + limit: int, + offset: int, + output_json: bool, +): + """ + Query file usages and show where files are referenced in the database. + + This command reuses the same reference checking logic as clear-orphaned-file-records + and displays detailed information about where each file is referenced. + """ + # define tables and columns to process + files_tables = [ + {"table": "upload_files", "id_column": "id", "key_column": "key"}, + {"table": "tool_files", "id_column": "id", "key_column": "file_key"}, + ] + ids_tables = [ + {"type": "uuid", "table": "message_files", "column": "upload_file_id", "pk_column": "id"}, + {"type": "text", "table": "documents", "column": "data_source_info", "pk_column": "id"}, + {"type": "text", "table": "document_segments", "column": "content", "pk_column": "id"}, + {"type": "text", "table": "messages", "column": "answer", "pk_column": "id"}, + {"type": "text", "table": "workflow_node_executions", "column": "inputs", "pk_column": "id"}, + {"type": "text", "table": "workflow_node_executions", "column": "process_data", "pk_column": "id"}, + {"type": "text", "table": "workflow_node_executions", "column": "outputs", "pk_column": "id"}, + {"type": "text", "table": "conversations", "column": "introduction", "pk_column": "id"}, + {"type": "text", "table": "conversations", "column": "system_instruction", "pk_column": "id"}, + {"type": "text", "table": "accounts", "column": "avatar", "pk_column": "id"}, + {"type": "text", "table": "apps", "column": "icon", "pk_column": "id"}, + {"type": "text", "table": "sites", "column": "icon", "pk_column": "id"}, + {"type": "json", "table": "messages", "column": "inputs", "pk_column": "id"}, + {"type": "json", "table": "messages", "column": "message", "pk_column": "id"}, + ] + + # Stream file usages with pagination to avoid holding all results in memory + paginated_usages = [] + total_count = 0 + + # First, build a mapping of file_id -> storage_key from the base tables + file_key_map = {} + for files_table in files_tables: + query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}" + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for row in rs: + file_key_map[str(row[0])] = f"{files_table['table']}:{row[1]}" + + # If filtering by key or file_id, verify it exists + if file_id and file_id not in file_key_map: + if output_json: + click.echo(json.dumps({"error": f"File ID {file_id} not found in base tables"})) + else: + click.echo(click.style(f"File ID {file_id} not found in base tables.", fg="red")) + return + + if key: + valid_prefixes = {f"upload_files:{key}", f"tool_files:{key}"} + matching_file_ids = [fid for fid, fkey in file_key_map.items() if fkey in valid_prefixes] + if not matching_file_ids: + if output_json: + click.echo(json.dumps({"error": f"Key {key} not found in base tables"})) + else: + click.echo(click.style(f"Key {key} not found in base tables.", fg="red")) + return + + guid_regexp = "[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}" + + # For each reference table/column, find matching file IDs and record the references + for ids_table in ids_tables: + src_filter = f"{ids_table['table']}.{ids_table['column']}" + + # Skip if src filter doesn't match (use fnmatch for wildcard patterns) + if src: + if "%" in src or "_" in src: + import fnmatch + + # Convert SQL LIKE wildcards to fnmatch wildcards (% -> *, _ -> ?) + pattern = src.replace("%", "*").replace("_", "?") + if not fnmatch.fnmatch(src_filter, pattern): + continue + else: + if src_filter != src: + continue + + if ids_table["type"] == "uuid": + # Direct UUID match + query = ( + f"SELECT {ids_table['pk_column']}, {ids_table['column']} " + f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" + ) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for row in rs: + record_id = str(row[0]) + ref_file_id = str(row[1]) + if ref_file_id not in file_key_map: + continue + storage_key = file_key_map[ref_file_id] + + # Apply filters + if file_id and ref_file_id != file_id: + continue + if key and not storage_key.endswith(key): + continue + + # Only collect items within the requested page range + if offset <= total_count < offset + limit: + paginated_usages.append( + { + "src": f"{ids_table['table']}.{ids_table['column']}", + "record_id": record_id, + "file_id": ref_file_id, + "key": storage_key, + } + ) + total_count += 1 + + elif ids_table["type"] in ("text", "json"): + # Extract UUIDs from text/json content + column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"] + query = ( + f"SELECT {ids_table['pk_column']}, {column_cast} " + f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" + ) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for row in rs: + record_id = str(row[0]) + content = str(row[1]) + + # Find all UUIDs in the content + import re + + uuid_pattern = re.compile(guid_regexp, re.IGNORECASE) + matches = uuid_pattern.findall(content) + + for ref_file_id in matches: + if ref_file_id not in file_key_map: + continue + storage_key = file_key_map[ref_file_id] + + # Apply filters + if file_id and ref_file_id != file_id: + continue + if key and not storage_key.endswith(key): + continue + + # Only collect items within the requested page range + if offset <= total_count < offset + limit: + paginated_usages.append( + { + "src": f"{ids_table['table']}.{ids_table['column']}", + "record_id": record_id, + "file_id": ref_file_id, + "key": storage_key, + } + ) + total_count += 1 + + # Output results + if output_json: + result = { + "total": total_count, + "offset": offset, + "limit": limit, + "usages": paginated_usages, + } + click.echo(json.dumps(result, indent=2)) + else: + click.echo( + click.style(f"Found {total_count} file usages (showing {len(paginated_usages)} results)", fg="white") + ) + click.echo("") + + if not paginated_usages: + click.echo(click.style("No file usages found matching the specified criteria.", fg="yellow")) + return + + # Print table header + click.echo( + click.style( + f"{'Src (Table.Column)':<50} {'Record ID':<40} {'File ID':<40} {'Storage Key':<60}", + fg="cyan", + ) + ) + click.echo(click.style("-" * 190, fg="white")) + + # Print each usage + for usage in paginated_usages: + click.echo(f"{usage['src']:<50} {usage['record_id']:<40} {usage['file_id']:<40} {usage['key']:<60}") + + # Show pagination info + if offset + limit < total_count: + click.echo("") + click.echo( + click.style( + f"Showing {offset + 1}-{offset + len(paginated_usages)} of {total_count} results", fg="white" + ) + ) + click.echo(click.style(f"Use --offset {offset + limit} to see next page", fg="white")) + + @click.command("setup-system-tool-oauth-client", help="Setup system tool oauth client.") @click.option("--provider", prompt=True, help="Provider name") @click.option("--client-params", prompt=True, help="Client Params") diff --git a/api/configs/middleware/vdb/milvus_config.py b/api/configs/middleware/vdb/milvus_config.py index 05cee51cc9..eb9b0ac2ab 100644 --- a/api/configs/middleware/vdb/milvus_config.py +++ b/api/configs/middleware/vdb/milvus_config.py @@ -16,7 +16,6 @@ class MilvusConfig(BaseSettings): description="Authentication token for Milvus, if token-based authentication is enabled", default=None, ) - MILVUS_USER: str | None = Field( description="Username for authenticating with Milvus, if username/password authentication is enabled", default=None, diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 44cf89d6a9..d66bb7063f 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -1,14 +1,16 @@ import re import uuid -from typing import Literal +from datetime import datetime +from typing import Any, Literal, TypeAlias from flask import request -from flask_restx import Resource, fields, marshal, marshal_with -from pydantic import BaseModel, Field, field_validator +from flask_restx import Resource +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import ( @@ -19,27 +21,19 @@ from controllers.console.wraps import ( is_admin_or_owner_required, setup_required, ) +from core.file import helpers as file_helpers from core.ops.ops_trace_manager import OpsTraceManager from core.workflow.enums import NodeType from extensions.ext_database import db -from fields.app_fields import ( - deleted_tool_fields, - model_config_fields, - model_config_partial_fields, - site_fields, - tag_fields, -) -from fields.workflow_fields import workflow_partial_fields as _workflow_partial_fields_dict -from libs.helper import AppIconUrlField, TimestampField from libs.login import current_account_with_tenant, login_required from models import App, Workflow +from models.model import IconType from services.app_dsl_service import AppDslService, ImportMode from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class AppListQuery(BaseModel): @@ -192,124 +186,292 @@ class AppTracePayload(BaseModel): return value -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +JSONValue: TypeAlias = Any -reg(AppListQuery) -reg(CreateAppPayload) -reg(UpdateAppPayload) -reg(CopyAppPayload) -reg(AppExportQuery) -reg(AppNamePayload) -reg(AppIconPayload) -reg(AppSiteStatusPayload) -reg(AppApiStatusPayload) -reg(AppTracePayload) +class ResponseModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + extra="ignore", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) -# Register models for flask_restx to avoid dict type issues in Swagger -# Register base models first -tag_model = console_ns.model("Tag", tag_fields) -workflow_partial_model = console_ns.model("WorkflowPartial", _workflow_partial_fields_dict) +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -model_config_model = console_ns.model("ModelConfig", model_config_fields) -model_config_partial_model = console_ns.model("ModelConfigPartial", model_config_partial_fields) +def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None: + if icon is None or icon_type is None: + return None + icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type) + if icon_type_value.lower() != IconType.IMAGE.value: + return None + return file_helpers.get_signed_file_url(icon) -deleted_tool_model = console_ns.model("DeletedTool", deleted_tool_fields) -site_model = console_ns.model("Site", site_fields) +class Tag(ResponseModel): + id: str + name: str + type: str -app_partial_model = console_ns.model( - "AppPartial", - { - "id": fields.String, - "name": fields.String, - "max_active_requests": fields.Raw(), - "description": fields.String(attribute="desc_or_prompt"), - "mode": fields.String(attribute="mode_compatible_with_agent"), - "icon_type": fields.String, - "icon": fields.String, - "icon_background": fields.String, - "icon_url": AppIconUrlField, - "model_config": fields.Nested(model_config_partial_model, attribute="app_model_config", allow_null=True), - "workflow": fields.Nested(workflow_partial_model, allow_null=True), - "use_icon_as_answer_icon": fields.Boolean, - "created_by": fields.String, - "created_at": TimestampField, - "updated_by": fields.String, - "updated_at": TimestampField, - "tags": fields.List(fields.Nested(tag_model)), - "access_mode": fields.String, - "create_user_name": fields.String, - "author_name": fields.String, - "has_draft_trigger": fields.Boolean, - }, -) -app_detail_model = console_ns.model( - "AppDetail", - { - "id": fields.String, - "name": fields.String, - "description": fields.String, - "mode": fields.String(attribute="mode_compatible_with_agent"), - "icon": fields.String, - "icon_background": fields.String, - "enable_site": fields.Boolean, - "enable_api": fields.Boolean, - "model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True), - "workflow": fields.Nested(workflow_partial_model, allow_null=True), - "tracing": fields.Raw, - "use_icon_as_answer_icon": fields.Boolean, - "created_by": fields.String, - "created_at": TimestampField, - "updated_by": fields.String, - "updated_at": TimestampField, - "access_mode": fields.String, - "tags": fields.List(fields.Nested(tag_model)), - }, -) +class WorkflowPartial(ResponseModel): + id: str + created_by: str | None = None + created_at: int | None = None + updated_by: str | None = None + updated_at: int | None = None -app_detail_with_site_model = console_ns.model( - "AppDetailWithSite", - { - "id": fields.String, - "name": fields.String, - "description": fields.String, - "mode": fields.String(attribute="mode_compatible_with_agent"), - "icon_type": fields.String, - "icon": fields.String, - "icon_background": fields.String, - "icon_url": AppIconUrlField, - "enable_site": fields.Boolean, - "enable_api": fields.Boolean, - "model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True), - "workflow": fields.Nested(workflow_partial_model, allow_null=True), - "api_base_url": fields.String, - "use_icon_as_answer_icon": fields.Boolean, - "max_active_requests": fields.Integer, - "created_by": fields.String, - "created_at": TimestampField, - "updated_by": fields.String, - "updated_at": TimestampField, - "deleted_tools": fields.List(fields.Nested(deleted_tool_model)), - "access_mode": fields.String, - "tags": fields.List(fields.Nested(tag_model)), - "site": fields.Nested(site_model), - }, -) + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) -app_pagination_model = console_ns.model( - "AppPagination", - { - "page": fields.Integer, - "limit": fields.Integer(attribute="per_page"), - "total": fields.Integer, - "has_more": fields.Boolean(attribute="has_next"), - "data": fields.List(fields.Nested(app_partial_model), attribute="items"), - }, + +class ModelConfigPartial(ResponseModel): + model: JSONValue | None = Field(default=None, validation_alias=AliasChoices("model_dict", "model")) + pre_prompt: str | None = None + created_by: str | None = None + created_at: int | None = None + updated_by: str | None = None + updated_at: int | None = None + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class ModelConfig(ResponseModel): + opening_statement: str | None = None + suggested_questions: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("suggested_questions_list", "suggested_questions") + ) + suggested_questions_after_answer: JSONValue | None = Field( + default=None, + validation_alias=AliasChoices("suggested_questions_after_answer_dict", "suggested_questions_after_answer"), + ) + speech_to_text: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("speech_to_text_dict", "speech_to_text") + ) + text_to_speech: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("text_to_speech_dict", "text_to_speech") + ) + retriever_resource: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("retriever_resource_dict", "retriever_resource") + ) + annotation_reply: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("annotation_reply_dict", "annotation_reply") + ) + more_like_this: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("more_like_this_dict", "more_like_this") + ) + sensitive_word_avoidance: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("sensitive_word_avoidance_dict", "sensitive_word_avoidance") + ) + external_data_tools: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("external_data_tools_list", "external_data_tools") + ) + model: JSONValue | None = Field(default=None, validation_alias=AliasChoices("model_dict", "model")) + user_input_form: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("user_input_form_list", "user_input_form") + ) + dataset_query_variable: str | None = None + pre_prompt: str | None = None + agent_mode: JSONValue | None = Field(default=None, validation_alias=AliasChoices("agent_mode_dict", "agent_mode")) + prompt_type: str | None = None + chat_prompt_config: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("chat_prompt_config_dict", "chat_prompt_config") + ) + completion_prompt_config: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("completion_prompt_config_dict", "completion_prompt_config") + ) + dataset_configs: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("dataset_configs_dict", "dataset_configs") + ) + file_upload: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("file_upload_dict", "file_upload") + ) + created_by: str | None = None + created_at: int | None = None + updated_by: str | None = None + updated_at: int | None = None + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class Site(ResponseModel): + access_token: str | None = Field(default=None, validation_alias="code") + code: str | None = None + title: str | None = None + icon_type: str | IconType | None = None + icon: str | None = None + icon_background: str | None = None + description: str | None = None + default_language: str | None = None + chat_color_theme: str | None = None + chat_color_theme_inverted: bool | None = None + customize_domain: str | None = None + copyright: str | None = None + privacy_policy: str | None = None + custom_disclaimer: str | None = None + customize_token_strategy: str | None = None + prompt_public: bool | None = None + app_base_url: str | None = None + show_workflow_steps: bool | None = None + use_icon_as_answer_icon: bool | None = None + created_by: str | None = None + created_at: int | None = None + updated_by: str | None = None + updated_at: int | None = None + + @computed_field(return_type=str | None) # type: ignore + @property + def icon_url(self) -> str | None: + return _build_icon_url(self.icon_type, self.icon) + + @field_validator("icon_type", mode="before") + @classmethod + def _normalize_icon_type(cls, value: str | IconType | None) -> str | None: + if isinstance(value, IconType): + return value.value + return value + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class DeletedTool(ResponseModel): + type: str + tool_name: str + provider_id: str + + +class AppPartial(ResponseModel): + id: str + name: str + max_active_requests: int | None = None + description: str | None = Field(default=None, validation_alias=AliasChoices("desc_or_prompt", "description")) + mode: str = Field(validation_alias="mode_compatible_with_agent") + icon_type: str | None = None + icon: str | None = None + icon_background: str | None = None + model_config_: ModelConfigPartial | None = Field( + default=None, + validation_alias=AliasChoices("app_model_config", "model_config"), + alias="model_config", + ) + workflow: WorkflowPartial | None = None + use_icon_as_answer_icon: bool | None = None + created_by: str | None = None + created_at: int | None = None + updated_by: str | None = None + updated_at: int | None = None + tags: list[Tag] = Field(default_factory=list) + access_mode: str | None = None + create_user_name: str | None = None + author_name: str | None = None + has_draft_trigger: bool | None = None + + @computed_field(return_type=str | None) # type: ignore + @property + def icon_url(self) -> str | None: + return _build_icon_url(self.icon_type, self.icon) + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AppDetail(ResponseModel): + id: str + name: str + description: str | None = None + mode: str = Field(validation_alias="mode_compatible_with_agent") + icon: str | None = None + icon_background: str | None = None + enable_site: bool + enable_api: bool + model_config_: ModelConfig | None = Field( + default=None, + validation_alias=AliasChoices("app_model_config", "model_config"), + alias="model_config", + ) + workflow: WorkflowPartial | None = None + tracing: JSONValue | None = None + use_icon_as_answer_icon: bool | None = None + created_by: str | None = None + created_at: int | None = None + updated_by: str | None = None + updated_at: int | None = None + access_mode: str | None = None + tags: list[Tag] = Field(default_factory=list) + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AppDetailWithSite(AppDetail): + icon_type: str | None = None + api_base_url: str | None = None + max_active_requests: int | None = None + deleted_tools: list[DeletedTool] = Field(default_factory=list) + site: Site | None = None + + @computed_field(return_type=str | None) # type: ignore + @property + def icon_url(self) -> str | None: + return _build_icon_url(self.icon_type, self.icon) + + +class AppPagination(ResponseModel): + page: int + limit: int = Field(validation_alias=AliasChoices("per_page", "limit")) + total: int + has_more: bool = Field(validation_alias=AliasChoices("has_next", "has_more")) + data: list[AppPartial] = Field(validation_alias=AliasChoices("items", "data")) + + +class AppExportResponse(ResponseModel): + data: str + + +register_schema_models( + console_ns, + AppListQuery, + CreateAppPayload, + UpdateAppPayload, + CopyAppPayload, + AppExportQuery, + AppNamePayload, + AppIconPayload, + AppSiteStatusPayload, + AppApiStatusPayload, + AppTracePayload, + Tag, + WorkflowPartial, + ModelConfigPartial, + ModelConfig, + Site, + DeletedTool, + AppPartial, + AppDetail, + AppDetailWithSite, + AppPagination, + AppExportResponse, ) @@ -318,7 +480,7 @@ class AppListApi(Resource): @console_ns.doc("list_apps") @console_ns.doc(description="Get list of applications with pagination and filtering") @console_ns.expect(console_ns.models[AppListQuery.__name__]) - @console_ns.response(200, "Success", app_pagination_model) + @console_ns.response(200, "Success", console_ns.models[AppPagination.__name__]) @setup_required @login_required @account_initialization_required @@ -334,7 +496,8 @@ class AppListApi(Resource): app_service = AppService() app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args_dict) if not app_pagination: - return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False} + empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[]) + return empty.model_dump(mode="json"), 200 if FeatureService.get_system_features().webapp_auth.enabled: app_ids = [str(app.id) for app in app_pagination.items] @@ -378,18 +541,18 @@ class AppListApi(Resource): for app in app_pagination.items: app.has_draft_trigger = str(app.id) in draft_trigger_app_ids - return marshal(app_pagination, app_pagination_model), 200 + pagination_model = AppPagination.model_validate(app_pagination, from_attributes=True) + return pagination_model.model_dump(mode="json"), 200 @console_ns.doc("create_app") @console_ns.doc(description="Create a new application") @console_ns.expect(console_ns.models[CreateAppPayload.__name__]) - @console_ns.response(201, "App created successfully", app_detail_model) + @console_ns.response(201, "App created successfully", console_ns.models[AppDetail.__name__]) @console_ns.response(403, "Insufficient permissions") @console_ns.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required - @marshal_with(app_detail_model) @cloud_edition_billing_resource_check("apps") @edit_permission_required def post(self): @@ -399,8 +562,8 @@ class AppListApi(Resource): app_service = AppService() app = app_service.create_app(current_tenant_id, args.model_dump(), current_user) - - return app, 201 + app_detail = AppDetail.model_validate(app, from_attributes=True) + return app_detail.model_dump(mode="json"), 201 @console_ns.route("/apps/") @@ -408,13 +571,12 @@ class AppApi(Resource): @console_ns.doc("get_app_detail") @console_ns.doc(description="Get application details") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.response(200, "Success", app_detail_with_site_model) + @console_ns.response(200, "Success", console_ns.models[AppDetailWithSite.__name__]) @setup_required @login_required @account_initialization_required @enterprise_license_required - @get_app_model - @marshal_with(app_detail_with_site_model) + @get_app_model(mode=None) def get(self, app_model): """Get app detail""" app_service = AppService() @@ -425,21 +587,21 @@ class AppApi(Resource): app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id)) app_model.access_mode = app_setting.access_mode - return app_model + response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True) + return response_model.model_dump(mode="json") @console_ns.doc("update_app") @console_ns.doc(description="Update application details") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[UpdateAppPayload.__name__]) - @console_ns.response(200, "App updated successfully", app_detail_with_site_model) + @console_ns.response(200, "App updated successfully", console_ns.models[AppDetailWithSite.__name__]) @console_ns.response(403, "Insufficient permissions") @console_ns.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required - @get_app_model + @get_app_model(mode=None) @edit_permission_required - @marshal_with(app_detail_with_site_model) def put(self, app_model): """Update app""" args = UpdateAppPayload.model_validate(console_ns.payload) @@ -456,8 +618,8 @@ class AppApi(Resource): "max_active_requests": args.max_active_requests or 0, } app_model = app_service.update_app(app_model, args_dict) - - return app_model + response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True) + return response_model.model_dump(mode="json") @console_ns.doc("delete_app") @console_ns.doc(description="Delete application") @@ -483,14 +645,13 @@ class AppCopyApi(Resource): @console_ns.doc(description="Create a copy of an existing application") @console_ns.doc(params={"app_id": "Application ID to copy"}) @console_ns.expect(console_ns.models[CopyAppPayload.__name__]) - @console_ns.response(201, "App copied successfully", app_detail_with_site_model) + @console_ns.response(201, "App copied successfully", console_ns.models[AppDetailWithSite.__name__]) @console_ns.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required - @get_app_model + @get_app_model(mode=None) @edit_permission_required - @marshal_with(app_detail_with_site_model) def post(self, app_model): """Copy app""" # The role of the current user in the ta table must be admin, owner, or editor @@ -516,7 +677,8 @@ class AppCopyApi(Resource): stmt = select(App).where(App.id == result.app_id) app = session.scalar(stmt) - return app, 201 + response_model = AppDetailWithSite.model_validate(app, from_attributes=True) + return response_model.model_dump(mode="json"), 201 @console_ns.route("/apps//export") @@ -525,11 +687,7 @@ class AppExportApi(Resource): @console_ns.doc(description="Export application configuration as DSL") @console_ns.doc(params={"app_id": "Application ID to export"}) @console_ns.expect(console_ns.models[AppExportQuery.__name__]) - @console_ns.response( - 200, - "App exported successfully", - console_ns.model("AppExportResponse", {"data": fields.String(description="DSL export data")}), - ) + @console_ns.response(200, "App exported successfully", console_ns.models[AppExportResponse.__name__]) @console_ns.response(403, "Insufficient permissions") @get_app_model @setup_required @@ -540,13 +698,14 @@ class AppExportApi(Resource): """Export app""" args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - return { - "data": AppDslService.export_dsl( + payload = AppExportResponse( + data=AppDslService.export_dsl( app_model=app_model, include_secret=args.include_secret, workflow_id=args.workflow_id, ) - } + ) + return payload.model_dump(mode="json") @console_ns.route("/apps//name") @@ -555,20 +714,19 @@ class AppNameApi(Resource): @console_ns.doc(description="Check if app name is available") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[AppNamePayload.__name__]) - @console_ns.response(200, "Name availability checked") + @console_ns.response(200, "Name availability checked", console_ns.models[AppDetail.__name__]) @setup_required @login_required @account_initialization_required - @get_app_model - @marshal_with(app_detail_model) + @get_app_model(mode=None) @edit_permission_required def post(self, app_model): args = AppNamePayload.model_validate(console_ns.payload) app_service = AppService() app_model = app_service.update_app_name(app_model, args.name) - - return app_model + response_model = AppDetail.model_validate(app_model, from_attributes=True) + return response_model.model_dump(mode="json") @console_ns.route("/apps//icon") @@ -582,16 +740,15 @@ class AppIconApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model - @marshal_with(app_detail_model) + @get_app_model(mode=None) @edit_permission_required def post(self, app_model): args = AppIconPayload.model_validate(console_ns.payload or {}) app_service = AppService() app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "") - - return app_model + response_model = AppDetail.model_validate(app_model, from_attributes=True) + return response_model.model_dump(mode="json") @console_ns.route("/apps//site-enable") @@ -600,21 +757,20 @@ class AppSiteStatus(Resource): @console_ns.doc(description="Enable or disable app site") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[AppSiteStatusPayload.__name__]) - @console_ns.response(200, "Site status updated successfully", app_detail_model) + @console_ns.response(200, "Site status updated successfully", console_ns.models[AppDetail.__name__]) @console_ns.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required - @get_app_model - @marshal_with(app_detail_model) + @get_app_model(mode=None) @edit_permission_required def post(self, app_model): args = AppSiteStatusPayload.model_validate(console_ns.payload) app_service = AppService() app_model = app_service.update_app_site_status(app_model, args.enable_site) - - return app_model + response_model = AppDetail.model_validate(app_model, from_attributes=True) + return response_model.model_dump(mode="json") @console_ns.route("/apps//api-enable") @@ -623,21 +779,20 @@ class AppApiStatus(Resource): @console_ns.doc(description="Enable or disable app API") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[AppApiStatusPayload.__name__]) - @console_ns.response(200, "API status updated successfully", app_detail_model) + @console_ns.response(200, "API status updated successfully", console_ns.models[AppDetail.__name__]) @console_ns.response(403, "Insufficient permissions") @setup_required @login_required @is_admin_or_owner_required @account_initialization_required - @get_app_model - @marshal_with(app_detail_model) + @get_app_model(mode=None) def post(self, app_model): args = AppApiStatusPayload.model_validate(console_ns.payload) app_service = AppService() app_model = app_service.update_app_api_status(app_model, args.enable_api) - - return app_model + response_model = AppDetail.model_validate(app_model, from_attributes=True) + return response_model.model_dump(mode="json") @console_ns.route("/apps//trace") diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index ef2f86d4be..56816dd462 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -348,10 +348,13 @@ class CompletionConversationApi(Resource): ) if args.keyword: + from libs.helper import escape_like_pattern + + escaped_keyword = escape_like_pattern(args.keyword) query = query.join(Message, Message.conversation_id == Conversation.id).where( or_( - Message.query.ilike(f"%{args.keyword}%"), - Message.answer.ilike(f"%{args.keyword}%"), + Message.query.ilike(f"%{escaped_keyword}%", escape="\\"), + Message.answer.ilike(f"%{escaped_keyword}%", escape="\\"), ) ) @@ -460,7 +463,10 @@ class ChatConversationApi(Resource): query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False)) if args.keyword: - keyword_filter = f"%{args.keyword}%" + from libs.helper import escape_like_pattern + + escaped_keyword = escape_like_pattern(args.keyword) + keyword_filter = f"%{escaped_keyword}%" query = ( query.join( Message, @@ -469,11 +475,11 @@ class ChatConversationApi(Resource): .join(subquery, subquery.c.conversation_id == Conversation.id) .where( or_( - Message.query.ilike(keyword_filter), - Message.answer.ilike(keyword_filter), - Conversation.name.ilike(keyword_filter), - Conversation.introduction.ilike(keyword_filter), - subquery.c.from_end_user_session_id.ilike(keyword_filter), + Message.query.ilike(keyword_filter, escape="\\"), + Message.answer.ilike(keyword_filter, escape="\\"), + Conversation.name.ilike(keyword_filter, escape="\\"), + Conversation.introduction.ilike(keyword_filter, escape="\\"), + subquery.c.from_end_user_session_id.ilike(keyword_filter, escape="\\"), ), ) .group_by(Conversation.id) diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 772d98822e..4a52bf8abe 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -1,3 +1,5 @@ +from typing import Any + import flask_login from flask import make_response, request from flask_restx import Resource @@ -96,14 +98,13 @@ class LoginApi(Resource): if is_login_error_rate_limit: raise EmailPasswordLoginLimitError() - # TODO: why invitation is re-assigned with different type? - invitation = args.invite_token # type: ignore - if invitation: - invitation = RegisterService.get_invitation_if_token_valid(None, args.email, invitation) # type: ignore + invitation_data: dict[str, Any] | None = None + if args.invite_token: + invitation_data = RegisterService.get_invitation_if_token_valid(None, args.email, args.invite_token) try: - if invitation: - data = invitation.get("data", {}) # type: ignore + if invitation_data: + data = invitation_data.get("data", {}) invitee_email = data.get("email") if data else None if invitee_email != args.email: raise InvalidEmailError() diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index e94768f985..ac78d3854b 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -751,12 +751,12 @@ class DocumentApi(DocumentResource): elif metadata == "without": dataset_process_rules = DatasetService.get_process_rules(dataset_id) document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {} - data_source_info = document.data_source_detail_dict response = { "id": document.id, "position": document.position, "data_source_type": document.data_source_type, - "data_source_info": data_source_info, + "data_source_info": document.data_source_info_dict, + "data_source_detail_dict": document.data_source_detail_dict, "dataset_process_rule_id": document.dataset_process_rule_id, "dataset_process_rule": dataset_process_rules, "document_process_rule": document_process_rules, @@ -784,12 +784,12 @@ class DocumentApi(DocumentResource): else: dataset_process_rules = DatasetService.get_process_rules(dataset_id) document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {} - data_source_info = document.data_source_detail_dict response = { "id": document.id, "position": document.position, "data_source_type": document.data_source_type, - "data_source_info": data_source_info, + "data_source_info": document.data_source_info_dict, + "data_source_detail_dict": document.data_source_detail_dict, "dataset_process_rule_id": document.dataset_process_rule_id, "dataset_process_rule": dataset_process_rules, "document_process_rule": document_process_rules, diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 5a536af6d2..16fecb41c6 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -30,6 +30,7 @@ from core.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.segment_fields import child_chunk_fields, segment_fields +from libs.helper import escape_like_pattern from libs.login import current_account_with_tenant, login_required from models.dataset import ChildChunk, DocumentSegment from models.model import UploadFile @@ -145,6 +146,8 @@ class DatasetDocumentSegmentListApi(Resource): query = query.where(DocumentSegment.hit_count >= hit_count_gte) if keyword: + # Escape special characters in keyword to prevent SQL injection via LIKE wildcards + escaped_keyword = escape_like_pattern(keyword) # Search in both content and keywords fields # Use database-specific methods for JSON array search if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql": @@ -156,15 +159,15 @@ class DatasetDocumentSegmentListApi(Resource): .scalar_subquery() ), ",", - ).ilike(f"%{keyword}%") + ).ilike(f"%{escaped_keyword}%", escape="\\") else: # MySQL: Cast JSON to string for pattern matching # MySQL stores Chinese text directly in JSON without Unicode escaping - keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{keyword}%") + keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{escaped_keyword}%", escape="\\") query = query.where( or_( - DocumentSegment.content.ilike(f"%{keyword}%"), + DocumentSegment.content.ilike(f"%{escaped_keyword}%", escape="\\"), keywords_condition, ) ) diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index db7c50f422..db1a874437 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -1,7 +1,7 @@ import logging from typing import Any -from flask_restx import marshal, reqparse +from flask_restx import marshal from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -56,15 +56,10 @@ class DatasetsHitTestingBase: HitTestingService.hit_testing_args_check(args) @staticmethod - def parse_args(): - parser = ( - reqparse.RequestParser() - .add_argument("query", type=str, required=False, location="json") - .add_argument("attachment_ids", type=list, required=False, location="json") - .add_argument("retrieval_model", type=dict, required=False, location="json") - .add_argument("external_retrieval_model", type=dict, required=False, location="json") - ) - return parser.parse_args() + def parse_args(payload: dict[str, Any]) -> dict[str, Any]: + """Validate and return hit-testing arguments from an incoming payload.""" + hit_testing_payload = HitTestingPayload.model_validate(payload or {}) + return hit_testing_payload.model_dump(exclude_none=True) @staticmethod def perform_hit_testing(dataset, args): diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 46d67f0581..02efc54eea 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -355,7 +355,7 @@ class PublishedRagPipelineRunApi(Resource): pipeline=pipeline, user=current_user, args=args, - invoke_from=InvokeFrom.DEBUGGER if payload.is_preview else InvokeFrom.PUBLISHED, + invoke_from=InvokeFrom.DEBUGGER if payload.is_preview else InvokeFrom.PUBLISHED_PIPELINE, streaming=streaming, ) diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py index 29417dc896..109a3cd0d3 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -1,7 +1,7 @@ from typing import Literal from flask import request -from flask_restx import Resource, marshal_with +from flask_restx import Resource from werkzeug.exceptions import Forbidden import services @@ -15,18 +15,21 @@ from controllers.common.errors import ( TooManyFilesError, UnsupportedFileTypeError, ) +from controllers.common.schema import register_schema_models from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_resource_check, setup_required, ) from extensions.ext_database import db -from fields.file_fields import file_fields, upload_config_fields +from fields.file_fields import FileResponse, UploadConfig from libs.login import current_account_with_tenant, login_required from services.file_service import FileService from . import console_ns +register_schema_models(console_ns, UploadConfig, FileResponse) + PREVIEW_WORDS_LIMIT = 3000 @@ -35,26 +38,27 @@ class FileApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(upload_config_fields) + @console_ns.response(200, "Success", console_ns.models[UploadConfig.__name__]) def get(self): - return { - "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT, - "batch_count_limit": dify_config.UPLOAD_FILE_BATCH_LIMIT, - "file_upload_limit": dify_config.BATCH_UPLOAD_LIMIT, - "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT, - "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT, - "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT, - "workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT, - "image_file_batch_limit": dify_config.IMAGE_FILE_BATCH_LIMIT, - "single_chunk_attachment_limit": dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT, - "attachment_image_file_size_limit": dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT, - }, 200 + config = UploadConfig( + file_size_limit=dify_config.UPLOAD_FILE_SIZE_LIMIT, + batch_count_limit=dify_config.UPLOAD_FILE_BATCH_LIMIT, + file_upload_limit=dify_config.BATCH_UPLOAD_LIMIT, + image_file_size_limit=dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT, + video_file_size_limit=dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT, + audio_file_size_limit=dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT, + workflow_file_upload_limit=dify_config.WORKFLOW_FILE_UPLOAD_LIMIT, + image_file_batch_limit=dify_config.IMAGE_FILE_BATCH_LIMIT, + single_chunk_attachment_limit=dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT, + attachment_image_file_size_limit=dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT, + ) + return config.model_dump(mode="json"), 200 @setup_required @login_required @account_initialization_required - @marshal_with(file_fields) @cloud_edition_billing_resource_check("documents") + @console_ns.response(201, "File uploaded successfully", console_ns.models[FileResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() source_str = request.form.get("source") @@ -90,7 +94,8 @@ class FileApi(Resource): except services.errors.file.BlockedFileExtensionError as blocked_extension_error: raise BlockedFileExtensionError(blocked_extension_error.description) - return upload_file, 201 + response = FileResponse.model_validate(upload_file, from_attributes=True) + return response.model_dump(mode="json"), 201 @console_ns.route("/files//preview") diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index 47eef7eb7e..70c7b80ffa 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -1,7 +1,7 @@ import urllib.parse import httpx -from flask_restx import Resource, marshal_with +from flask_restx import Resource from pydantic import BaseModel, Field import services @@ -11,19 +11,22 @@ from controllers.common.errors import ( RemoteFileUploadError, UnsupportedFileTypeError, ) +from controllers.common.schema import register_schema_models from core.file import helpers as file_helpers from core.helper import ssrf_proxy from extensions.ext_database import db -from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields +from fields.file_fields import FileWithSignedUrl, RemoteFileInfo from libs.login import current_account_with_tenant from services.file_service import FileService from . import console_ns +register_schema_models(console_ns, RemoteFileInfo, FileWithSignedUrl) + @console_ns.route("/remote-files/") class RemoteFileInfoApi(Resource): - @marshal_with(remote_file_info_fields) + @console_ns.response(200, "Remote file info", console_ns.models[RemoteFileInfo.__name__]) def get(self, url): decoded_url = urllib.parse.unquote(url) resp = ssrf_proxy.head(decoded_url) @@ -31,10 +34,11 @@ class RemoteFileInfoApi(Resource): # failed back to get method resp = ssrf_proxy.get(decoded_url, timeout=3) resp.raise_for_status() - return { - "file_type": resp.headers.get("Content-Type", "application/octet-stream"), - "file_length": int(resp.headers.get("Content-Length", 0)), - } + info = RemoteFileInfo( + file_type=resp.headers.get("Content-Type", "application/octet-stream"), + file_length=int(resp.headers.get("Content-Length", 0)), + ) + return info.model_dump(mode="json") class RemoteFileUploadPayload(BaseModel): @@ -50,7 +54,7 @@ console_ns.schema_model( @console_ns.route("/remote-files/upload") class RemoteFileUploadApi(Resource): @console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__]) - @marshal_with(file_fields_with_signed_url) + @console_ns.response(201, "Remote file uploaded", console_ns.models[FileWithSignedUrl.__name__]) def post(self): args = RemoteFileUploadPayload.model_validate(console_ns.payload) url = args.url @@ -85,13 +89,14 @@ class RemoteFileUploadApi(Resource): except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() - return { - "id": upload_file.id, - "name": upload_file.name, - "size": upload_file.size, - "extension": upload_file.extension, - "url": file_helpers.get_signed_file_url(upload_file_id=upload_file.id), - "mime_type": upload_file.mime_type, - "created_by": upload_file.created_by, - "created_at": upload_file.created_at, - }, 201 + payload = FileWithSignedUrl( + id=upload_file.id, + name=upload_file.name, + size=upload_file.size, + extension=upload_file.extension, + url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id), + mime_type=upload_file.mime_type, + created_by=upload_file.created_by, + created_at=int(upload_file.created_at.timestamp()), + ) + return payload.model_dump(mode="json"), 201 diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 55eaa2f09f..03ad0f423b 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime from typing import Literal @@ -99,7 +101,7 @@ class AccountPasswordPayload(BaseModel): repeat_new_password: str @model_validator(mode="after") - def check_passwords_match(self) -> "AccountPasswordPayload": + def check_passwords_match(self) -> AccountPasswordPayload: if self.new_password != self.repeat_new_password: raise RepeatPasswordNotMatchError() return self diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py index 497e62b790..c13bfd986e 100644 --- a/api/controllers/console/workspace/trigger_providers.py +++ b/api/controllers/console/workspace/trigger_providers.py @@ -4,12 +4,11 @@ from typing import Any from flask import make_response, redirect, request from flask_restx import Resource, reqparse -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, Forbidden from configs import dify_config -from constants import HIDDEN_VALUE, UNKNOWN_VALUE from controllers.web.error import NotFoundError from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon import CredentialType @@ -44,6 +43,12 @@ class TriggerSubscriptionUpdateRequest(BaseModel): parameters: Mapping[str, Any] | None = Field(default=None, description="The parameters for the subscription") properties: Mapping[str, Any] | None = Field(default=None, description="The properties for the subscription") + @model_validator(mode="after") + def check_at_least_one_field(self): + if all(v is None for v in (self.name, self.credentials, self.parameters, self.properties)): + raise ValueError("At least one of name, credentials, parameters, or properties must be provided") + return self + class TriggerSubscriptionVerifyRequest(BaseModel): """Request payload for verifying subscription credentials.""" @@ -333,7 +338,7 @@ class TriggerSubscriptionUpdateApi(Resource): user = current_user assert user.current_tenant_id is not None - args = TriggerSubscriptionUpdateRequest.model_validate(console_ns.payload) + request = TriggerSubscriptionUpdateRequest.model_validate(console_ns.payload) subscription = TriggerProviderService.get_subscription_by_id( tenant_id=user.current_tenant_id, @@ -345,50 +350,32 @@ class TriggerSubscriptionUpdateApi(Resource): provider_id = TriggerProviderID(subscription.provider_id) try: - # rename only - if ( - args.name is not None - and args.credentials is None - and args.parameters is None - and args.properties is None - ): + # For rename only, just update the name + rename = request.name is not None and not any((request.credentials, request.parameters, request.properties)) + # When credential type is UNAUTHORIZED, it indicates the subscription was manually created + # For Manually created subscription, they dont have credentials, parameters + # They only have name and properties(which is input by user) + manually_created = subscription.credential_type == CredentialType.UNAUTHORIZED + if rename or manually_created: TriggerProviderService.update_trigger_subscription( tenant_id=user.current_tenant_id, subscription_id=subscription_id, - name=args.name, + name=request.name, + properties=request.properties, ) return 200 - # rebuild for create automatically by the provider - match subscription.credential_type: - case CredentialType.UNAUTHORIZED: - TriggerProviderService.update_trigger_subscription( - tenant_id=user.current_tenant_id, - subscription_id=subscription_id, - name=args.name, - properties=args.properties, - ) - return 200 - case CredentialType.API_KEY | CredentialType.OAUTH2: - if args.credentials: - new_credentials: dict[str, Any] = { - key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE) - for key, value in args.credentials.items() - } - else: - new_credentials = subscription.credentials - - TriggerProviderService.rebuild_trigger_subscription( - tenant_id=user.current_tenant_id, - name=args.name, - provider_id=provider_id, - subscription_id=subscription_id, - credentials=new_credentials, - parameters=args.parameters or subscription.parameters, - ) - return 200 - case _: - raise BadRequest("Invalid credential type") + # For the rest cases(API_KEY, OAUTH2) + # we need to call third party provider(e.g. GitHub) to rebuild the subscription + TriggerProviderService.rebuild_trigger_subscription( + tenant_id=user.current_tenant_id, + name=request.name, + provider_id=provider_id, + subscription_id=subscription_id, + credentials=request.credentials or subscription.credentials, + parameters=request.parameters or subscription.parameters, + ) + return 200 except ValueError as e: raise BadRequest(str(e)) except Exception as e: diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py index 6096a87c56..28ec4b3935 100644 --- a/api/controllers/files/upload.py +++ b/api/controllers/files/upload.py @@ -4,18 +4,18 @@ from flask import request from flask_restx import Resource from flask_restx.api import HTTPStatus from pydantic import BaseModel, Field -from werkzeug.datastructures import FileStorage from werkzeug.exceptions import Forbidden import services from core.file.helpers import verify_plugin_file_signature from core.tools.tool_file_manager import ToolFileManager -from fields.file_fields import build_file_model +from fields.file_fields import FileResponse from ..common.errors import ( FileTooLargeError, UnsupportedFileTypeError, ) +from ..common.schema import register_schema_models from ..console.wraps import setup_required from ..files import files_ns from ..inner_api.plugin.wraps import get_user @@ -35,6 +35,8 @@ files_ns.schema_model( PluginUploadQuery.__name__, PluginUploadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) ) +register_schema_models(files_ns, FileResponse) + @files_ns.route("/upload/for-plugin") class PluginUploadFileApi(Resource): @@ -51,7 +53,7 @@ class PluginUploadFileApi(Resource): 415: "Unsupported file type", } ) - @files_ns.marshal_with(build_file_model(files_ns), code=HTTPStatus.CREATED) + @files_ns.response(HTTPStatus.CREATED, "File uploaded", files_ns.models[FileResponse.__name__]) def post(self): """Upload a file for plugin usage. @@ -69,7 +71,7 @@ class PluginUploadFileApi(Resource): """ args = PluginUploadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - file: FileStorage | None = request.files.get("file") + file = request.files.get("file") if file is None: raise Forbidden("File is required.") @@ -80,8 +82,8 @@ class PluginUploadFileApi(Resource): user_id = args.user_id user = get_user(tenant_id, user_id) - filename: str | None = file.filename - mimetype: str | None = file.mimetype + filename = file.filename + mimetype = file.mimetype if not filename or not mimetype: raise Forbidden("Invalid request.") @@ -111,22 +113,22 @@ class PluginUploadFileApi(Resource): preview_url = ToolFileManager.sign_file(tool_file_id=tool_file.id, extension=extension) # Create a dictionary with all the necessary attributes - result = { - "id": tool_file.id, - "user_id": tool_file.user_id, - "tenant_id": tool_file.tenant_id, - "conversation_id": tool_file.conversation_id, - "file_key": tool_file.file_key, - "mimetype": tool_file.mimetype, - "original_url": tool_file.original_url, - "name": tool_file.name, - "size": tool_file.size, - "mime_type": mimetype, - "extension": extension, - "preview_url": preview_url, - } + result = FileResponse( + id=tool_file.id, + name=tool_file.name, + size=tool_file.size, + extension=extension, + mime_type=mimetype, + preview_url=preview_url, + source_url=tool_file.original_url, + original_url=tool_file.original_url, + user_id=tool_file.user_id, + tenant_id=tool_file.tenant_id, + conversation_id=tool_file.conversation_id, + file_key=tool_file.file_key, + ) - return result, 201 + return result.model_dump(mode="json"), 201 except services.errors.file.FileTooLargeError as file_too_large_error: raise FileTooLargeError(file_too_large_error.description) except services.errors.file.UnsupportedFileTypeError: diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py index ffe4e0b492..6f6dadf768 100644 --- a/api/controllers/service_api/app/file.py +++ b/api/controllers/service_api/app/file.py @@ -10,13 +10,16 @@ from controllers.common.errors import ( TooManyFilesError, UnsupportedFileTypeError, ) +from controllers.common.schema import register_schema_models from controllers.service_api import service_api_ns from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from extensions.ext_database import db -from fields.file_fields import build_file_model +from fields.file_fields import FileResponse from models import App, EndUser from services.file_service import FileService +register_schema_models(service_api_ns, FileResponse) + @service_api_ns.route("/files/upload") class FileApi(Resource): @@ -31,8 +34,8 @@ class FileApi(Resource): 415: "Unsupported file type", } ) - @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) - @service_api_ns.marshal_with(build_file_model(service_api_ns), code=HTTPStatus.CREATED) + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) # type: ignore + @service_api_ns.response(HTTPStatus.CREATED, "File uploaded", service_api_ns.models[FileResponse.__name__]) def post(self, app_model: App, end_user: EndUser): """Upload a file for use in conversations. @@ -64,4 +67,5 @@ class FileApi(Resource): except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() - return upload_file, 201 + response = FileResponse.model_validate(upload_file, from_attributes=True) + return response.model_dump(mode="json"), 201 diff --git a/api/controllers/service_api/dataset/hit_testing.py b/api/controllers/service_api/dataset/hit_testing.py index d81287d56f..8dbb690901 100644 --- a/api/controllers/service_api/dataset/hit_testing.py +++ b/api/controllers/service_api/dataset/hit_testing.py @@ -24,7 +24,7 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): dataset_id_str = str(dataset_id) dataset = self.get_and_validate_dataset(dataset_id_str) - args = self.parse_args() + args = self.parse_args(service_api_ns.payload) self.hit_testing_args_check(args) return self.perform_hit_testing(dataset, args) diff --git a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py index 0a2017e2bd..70b5030237 100644 --- a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py @@ -174,7 +174,7 @@ class PipelineRunApi(DatasetApiResource): pipeline=pipeline, user=current_user, args=payload.model_dump(), - invoke_from=InvokeFrom.PUBLISHED if payload.is_published else InvokeFrom.DEBUGGER, + invoke_from=InvokeFrom.PUBLISHED_PIPELINE if payload.is_published else InvokeFrom.DEBUGGER, streaming=payload.response_mode == "streaming", ) diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index 527eef6094..e76649495a 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -1,9 +1,11 @@ -from flask_restx import reqparse -from flask_restx.inputs import int_range -from pydantic import TypeAdapter +from typing import Literal + +from flask import request +from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound +from controllers.common.schema import register_schema_models from controllers.web import web_ns from controllers.web.error import NotChatAppError from controllers.web.wraps import WebApiResource @@ -21,6 +23,35 @@ from services.errors.conversation import ConversationNotExistsError, LastConvers from services.web_conversation_service import WebConversationService +class ConversationListQuery(BaseModel): + last_id: str | None = None + limit: int = Field(default=20, ge=1, le=100) + pinned: bool | None = None + sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = "-updated_at" + + @field_validator("last_id") + @classmethod + def validate_last_id(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + +class ConversationRenamePayload(BaseModel): + name: str | None = None + auto_generate: bool = False + + @model_validator(mode="after") + def validate_name_requirement(self): + if not self.auto_generate: + if self.name is None or not self.name.strip(): + raise ValueError("name is required when auto_generate is false") + return self + + +register_schema_models(web_ns, ConversationListQuery, ConversationRenamePayload) + + @web_ns.route("/conversations") class ConversationListApi(WebApiResource): @web_ns.doc("Get Conversation List") @@ -64,25 +95,8 @@ class ConversationListApi(WebApiResource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = ( - reqparse.RequestParser() - .add_argument("last_id", type=uuid_value, location="args") - .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - .add_argument("pinned", type=str, choices=["true", "false", None], location="args") - .add_argument( - "sort_by", - type=str, - choices=["created_at", "-created_at", "updated_at", "-updated_at"], - required=False, - default="-updated_at", - location="args", - ) - ) - args = parser.parse_args() - - pinned = None - if "pinned" in args and args["pinned"] is not None: - pinned = args["pinned"] == "true" + raw_args = request.args.to_dict() + query = ConversationListQuery.model_validate(raw_args) try: with Session(db.engine) as session: @@ -90,11 +104,11 @@ class ConversationListApi(WebApiResource): session=session, app_model=app_model, user=end_user, - last_id=args["last_id"], - limit=args["limit"], + last_id=query.last_id, + limit=query.limit, invoke_from=InvokeFrom.WEB_APP, - pinned=pinned, - sort_by=args["sort_by"], + pinned=query.pinned, + sort_by=query.sort_by, ) adapter = TypeAdapter(SimpleConversation) conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data] @@ -168,16 +182,11 @@ class ConversationRenameApi(WebApiResource): conversation_id = str(c_id) - parser = ( - reqparse.RequestParser() - .add_argument("name", type=str, required=False, location="json") - .add_argument("auto_generate", type=bool, required=False, default=False, location="json") - ) - args = parser.parse_args() + payload = ConversationRenamePayload.model_validate(web_ns.payload or {}) try: conversation = ConversationService.rename( - app_model, conversation_id, end_user, args["name"], args["auto_generate"] + app_model, conversation_id, end_user, payload.name, payload.auto_generate ) return ( TypeAdapter(SimpleConversation) diff --git a/api/controllers/web/files.py b/api/controllers/web/files.py index 80ad61e549..0036c90800 100644 --- a/api/controllers/web/files.py +++ b/api/controllers/web/files.py @@ -1,5 +1,4 @@ from flask import request -from flask_restx import marshal_with import services from controllers.common.errors import ( @@ -9,12 +8,15 @@ from controllers.common.errors import ( TooManyFilesError, UnsupportedFileTypeError, ) +from controllers.common.schema import register_schema_models from controllers.web import web_ns from controllers.web.wraps import WebApiResource from extensions.ext_database import db -from fields.file_fields import build_file_model +from fields.file_fields import FileResponse from services.file_service import FileService +register_schema_models(web_ns, FileResponse) + @web_ns.route("/files/upload") class FileApi(WebApiResource): @@ -28,7 +30,7 @@ class FileApi(WebApiResource): 415: "Unsupported file type", } ) - @marshal_with(build_file_model(web_ns)) + @web_ns.response(201, "File uploaded successfully", web_ns.models[FileResponse.__name__]) def post(self, app_model, end_user): """Upload a file for use in web applications. @@ -81,4 +83,5 @@ class FileApi(WebApiResource): except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() - return upload_file, 201 + response = FileResponse.model_validate(upload_file, from_attributes=True) + return response.model_dump(mode="json"), 201 diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index c1f976829f..b08b3fe858 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -1,7 +1,6 @@ import urllib.parse import httpx -from flask_restx import marshal_with from pydantic import BaseModel, Field, HttpUrl import services @@ -14,7 +13,7 @@ from controllers.common.errors import ( from core.file import helpers as file_helpers from core.helper import ssrf_proxy from extensions.ext_database import db -from fields.file_fields import build_file_with_signed_url_model, build_remote_file_info_model +from fields.file_fields import FileWithSignedUrl, RemoteFileInfo from services.file_service import FileService from ..common.schema import register_schema_models @@ -26,7 +25,7 @@ class RemoteFileUploadPayload(BaseModel): url: HttpUrl = Field(description="Remote file URL") -register_schema_models(web_ns, RemoteFileUploadPayload) +register_schema_models(web_ns, RemoteFileUploadPayload, RemoteFileInfo, FileWithSignedUrl) @web_ns.route("/remote-files/") @@ -41,7 +40,7 @@ class RemoteFileInfoApi(WebApiResource): 500: "Failed to fetch remote file", } ) - @marshal_with(build_remote_file_info_model(web_ns)) + @web_ns.response(200, "Remote file info", web_ns.models[RemoteFileInfo.__name__]) def get(self, app_model, end_user, url): """Get information about a remote file. @@ -65,10 +64,11 @@ class RemoteFileInfoApi(WebApiResource): # failed back to get method resp = ssrf_proxy.get(decoded_url, timeout=3) resp.raise_for_status() - return { - "file_type": resp.headers.get("Content-Type", "application/octet-stream"), - "file_length": int(resp.headers.get("Content-Length", -1)), - } + info = RemoteFileInfo( + file_type=resp.headers.get("Content-Type", "application/octet-stream"), + file_length=int(resp.headers.get("Content-Length", -1)), + ) + return info.model_dump(mode="json") @web_ns.route("/remote-files/upload") @@ -84,7 +84,7 @@ class RemoteFileUploadApi(WebApiResource): 500: "Failed to fetch remote file", } ) - @marshal_with(build_file_with_signed_url_model(web_ns)) + @web_ns.response(201, "Remote file uploaded", web_ns.models[FileWithSignedUrl.__name__]) def post(self, app_model, end_user): """Upload a file from a remote URL. @@ -139,13 +139,14 @@ class RemoteFileUploadApi(WebApiResource): except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError - return { - "id": upload_file.id, - "name": upload_file.name, - "size": upload_file.size, - "extension": upload_file.extension, - "url": file_helpers.get_signed_file_url(upload_file_id=upload_file.id), - "mime_type": upload_file.mime_type, - "created_by": upload_file.created_by, - "created_at": upload_file.created_at, - }, 201 + payload1 = FileWithSignedUrl( + id=upload_file.id, + name=upload_file.name, + size=upload_file.size, + extension=upload_file.extension, + url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id), + mime_type=upload_file.mime_type, + created_by=upload_file.created_by, + created_at=int(upload_file.created_at.timestamp()), + ) + return payload1.model_dump(mode="json"), 201 diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index 4e20690e9e..29993100f6 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -1,18 +1,30 @@ -from flask_restx import reqparse -from flask_restx.inputs import int_range -from pydantic import TypeAdapter +from flask import request +from pydantic import BaseModel, Field, TypeAdapter from werkzeug.exceptions import NotFound +from controllers.common.schema import register_schema_models from controllers.web import web_ns from controllers.web.error import NotCompletionAppError from controllers.web.wraps import WebApiResource from fields.conversation_fields import ResultResponse from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem -from libs.helper import uuid_value +from libs.helper import UUIDStrOrEmpty from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService +class SavedMessageListQuery(BaseModel): + last_id: UUIDStrOrEmpty | None = None + limit: int = Field(default=20, ge=1, le=100) + + +class SavedMessageCreatePayload(BaseModel): + message_id: UUIDStrOrEmpty + + +register_schema_models(web_ns, SavedMessageListQuery, SavedMessageCreatePayload) + + @web_ns.route("/saved-messages") class SavedMessageListApi(WebApiResource): @web_ns.doc("Get Saved Messages") @@ -42,14 +54,10 @@ class SavedMessageListApi(WebApiResource): if app_model.mode != "completion": raise NotCompletionAppError() - parser = ( - reqparse.RequestParser() - .add_argument("last_id", type=uuid_value, location="args") - .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - ) - args = parser.parse_args() + raw_args = request.args.to_dict() + query = SavedMessageListQuery.model_validate(raw_args) - pagination = SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"]) + pagination = SavedMessageService.pagination_by_last_id(app_model, end_user, query.last_id, query.limit) adapter = TypeAdapter(SavedMessageItem) items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data] return SavedMessageInfiniteScrollPagination( @@ -79,11 +87,10 @@ class SavedMessageListApi(WebApiResource): if app_model.mode != "completion": raise NotCompletionAppError() - parser = reqparse.RequestParser().add_argument("message_id", type=uuid_value, required=True, location="json") - args = parser.parse_args() + payload = SavedMessageCreatePayload.model_validate(web_ns.payload or {}) try: - SavedMessageService.save(app_model, end_user, args["message_id"]) + SavedMessageService.save(app_model, end_user, payload.message_id) except MessageNotExistsError: raise NotFound("Message Not Exists.") diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index ee092e55c5..a2ae8dec5b 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -20,6 +20,8 @@ from core.app.entities.queue_entities import ( QueueTextChunkEvent, ) from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature +from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer +from core.db.session_factory import session_factory from core.moderation.base import ModerationError from core.moderation.input_moderation import InputModeration from core.variables.variables import VariableUnion @@ -40,6 +42,7 @@ from models import Workflow from models.enums import UserFrom from models.model import App, Conversation, Message, MessageAnnotation from models.workflow import ConversationVariable +from services.conversation_variable_updater import ConversationVariableUpdater logger = logging.getLogger(__name__) @@ -200,6 +203,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): ) workflow_entry.graph_engine.layer(persistence_layer) + conversation_variable_layer = ConversationVariablePersistenceLayer( + ConversationVariableUpdater(session_factory.get_session_maker()) + ) + workflow_entry.graph_engine.layer(conversation_variable_layer) for layer in self._graph_engine_layers: workflow_entry.graph_engine.layer(layer) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index da1e9f19b6..4dd95be52d 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -358,6 +358,25 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): if node_finish_resp: yield node_finish_resp + # For ANSWER nodes, check if we need to send a message_replace event + # Only send if the final output differs from the accumulated task_state.answer + # This happens when variables were updated by variable_assigner during workflow execution + if event.node_type == NodeType.ANSWER and event.outputs: + final_answer = event.outputs.get("answer") + if final_answer is not None and final_answer != self._task_state.answer: + logger.info( + "ANSWER node final output '%s' differs from accumulated answer '%s', sending message_replace event", + final_answer, + self._task_state.answer, + ) + # Update the task state answer + self._task_state.answer = str(final_answer) + # Send message_replace event to update the UI + yield self._message_cycle_manager.message_replace_to_stream_response( + answer=str(final_answer), + reason="variable_update", + ) + def _handle_node_failed_events( self, event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent], diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 13eb40fd60..ea4441b5d8 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -130,7 +130,7 @@ class PipelineGenerator(BaseAppGenerator): pipeline=pipeline, workflow=workflow, start_node_id=start_node_id ) documents: list[Document] = [] - if invoke_from == InvokeFrom.PUBLISHED and not is_retry and not args.get("original_document_id"): + if invoke_from == InvokeFrom.PUBLISHED_PIPELINE and not is_retry and not args.get("original_document_id"): from services.dataset_service import DocumentService for datasource_info in datasource_info_list: @@ -156,7 +156,7 @@ class PipelineGenerator(BaseAppGenerator): for i, datasource_info in enumerate(datasource_info_list): workflow_run_id = str(uuid.uuid4()) document_id = args.get("original_document_id") or None - if invoke_from == InvokeFrom.PUBLISHED and not is_retry: + if invoke_from == InvokeFrom.PUBLISHED_PIPELINE and not is_retry: document_id = document_id or documents[i].id document_pipeline_execution_log = DocumentPipelineExecutionLog( document_id=document_id, diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 0cb573cb86..5bc453420d 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -42,7 +42,8 @@ class InvokeFrom(StrEnum): # DEBUGGER indicates that this invocation is from # the workflow (or chatflow) edit page. DEBUGGER = "debugger" - PUBLISHED = "published" + # PUBLISHED_PIPELINE indicates that this invocation runs a published RAG pipeline workflow. + PUBLISHED_PIPELINE = "published" # VALIDATION indicates that this invocation is from validation. VALIDATION = "validation" diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 79fbafe39e..3f9f3da9b2 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -75,7 +75,7 @@ class AnnotationReplyFeature: AppAnnotationService.add_annotation_history( annotation.id, app_record.id, - annotation.question, + annotation.question_text, annotation.content, query, user_id, diff --git a/api/core/app/layers/conversation_variable_persist_layer.py b/api/core/app/layers/conversation_variable_persist_layer.py new file mode 100644 index 0000000000..77cc00bdc9 --- /dev/null +++ b/api/core/app/layers/conversation_variable_persist_layer.py @@ -0,0 +1,60 @@ +import logging + +from core.variables import Variable +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID +from core.workflow.conversation_variable_updater import ConversationVariableUpdater +from core.workflow.enums import NodeType +from core.workflow.graph_engine.layers.base import GraphEngineLayer +from core.workflow.graph_events import GraphEngineEvent, NodeRunSucceededEvent +from core.workflow.nodes.variable_assigner.common import helpers as common_helpers + +logger = logging.getLogger(__name__) + + +class ConversationVariablePersistenceLayer(GraphEngineLayer): + def __init__(self, conversation_variable_updater: ConversationVariableUpdater) -> None: + super().__init__() + self._conversation_variable_updater = conversation_variable_updater + + def on_graph_start(self) -> None: + pass + + def on_event(self, event: GraphEngineEvent) -> None: + if not isinstance(event, NodeRunSucceededEvent): + return + if event.node_type != NodeType.VARIABLE_ASSIGNER: + return + if self.graph_runtime_state is None: + return + + updated_variables = common_helpers.get_updated_variables(event.node_run_result.process_data) or [] + if not updated_variables: + return + + conversation_id = self.graph_runtime_state.system_variable.conversation_id + if conversation_id is None: + return + + updated_any = False + for item in updated_variables: + selector = item.selector + if len(selector) < 2: + logger.warning("Conversation variable selector invalid. selector=%s", selector) + continue + if selector[0] != CONVERSATION_VARIABLE_NODE_ID: + continue + variable = self.graph_runtime_state.variable_pool.get(selector) + if not isinstance(variable, Variable): + logger.warning( + "Conversation variable not found in variable pool. selector=%s", + selector, + ) + continue + self._conversation_variable_updater.update(conversation_id=conversation_id, variable=variable) + updated_any = True + + if updated_any: + self._conversation_variable_updater.flush() + + def on_graph_end(self, error: Exception | None) -> None: + pass diff --git a/api/core/app/layers/pause_state_persist_layer.py b/api/core/app/layers/pause_state_persist_layer.py index 61a3e1baca..bf76ae8178 100644 --- a/api/core/app/layers/pause_state_persist_layer.py +++ b/api/core/app/layers/pause_state_persist_layer.py @@ -66,6 +66,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer): """ if isinstance(session_factory, Engine): session_factory = sessionmaker(session_factory) + super().__init__() self._session_maker = session_factory self._state_owner_user_id = state_owner_user_id self._generate_entity = generate_entity @@ -98,8 +99,6 @@ class PauseStatePersistenceLayer(GraphEngineLayer): if not isinstance(event, GraphRunPausedEvent): return - assert self.graph_runtime_state is not None - entity_wrapper: _GenerateEntityUnion if isinstance(self._generate_entity, WorkflowAppGenerateEntity): entity_wrapper = _WorkflowGenerateEntityWrapper(entity=self._generate_entity) diff --git a/api/core/app/layers/trigger_post_layer.py b/api/core/app/layers/trigger_post_layer.py index fe1a46a945..225b758fcb 100644 --- a/api/core/app/layers/trigger_post_layer.py +++ b/api/core/app/layers/trigger_post_layer.py @@ -33,6 +33,7 @@ class TriggerPostLayer(GraphEngineLayer): trigger_log_id: str, session_maker: sessionmaker[Session], ): + super().__init__() self.trigger_log_id = trigger_log_id self.start_time = start_time self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity @@ -57,10 +58,6 @@ class TriggerPostLayer(GraphEngineLayer): elapsed_time = (datetime.now(UTC) - self.start_time).total_seconds() # Extract relevant data from result - if not self.graph_runtime_state: - logger.exception("Graph runtime state is not set") - return - outputs = self.graph_runtime_state.outputs # BASICLY, workflow_execution_id is the same as workflow_run_id diff --git a/api/core/datasource/__base/datasource_plugin.py b/api/core/datasource/__base/datasource_plugin.py index 50c7249fe4..451e4fda0e 100644 --- a/api/core/datasource/__base/datasource_plugin.py +++ b/api/core/datasource/__base/datasource_plugin.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from abc import ABC, abstractmethod from configs import dify_config @@ -30,7 +32,7 @@ class DatasourcePlugin(ABC): """ return DatasourceProviderType.LOCAL_FILE - def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin": + def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> DatasourcePlugin: return self.__class__( entity=self.entity.model_copy(), runtime=runtime, diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index 260dcf04f5..dde7d59726 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import enum from enum import StrEnum from typing import Any @@ -31,7 +33,7 @@ class DatasourceProviderType(enum.StrEnum): ONLINE_DRIVE = "online_drive" @classmethod - def value_of(cls, value: str) -> "DatasourceProviderType": + def value_of(cls, value: str) -> DatasourceProviderType: """ Get value of given mode. @@ -81,7 +83,7 @@ class DatasourceParameter(PluginParameter): typ: DatasourceParameterType, required: bool, options: list[str] | None = None, - ) -> "DatasourceParameter": + ) -> DatasourceParameter: """ get a simple datasource parameter @@ -187,14 +189,14 @@ class DatasourceInvokeMeta(BaseModel): tool_config: dict | None = None @classmethod - def empty(cls) -> "DatasourceInvokeMeta": + def empty(cls) -> DatasourceInvokeMeta: """ Get an empty instance of DatasourceInvokeMeta """ return cls(time_cost=0.0, error=None, tool_config={}) @classmethod - def error_instance(cls, error: str) -> "DatasourceInvokeMeta": + def error_instance(cls, error: str) -> DatasourceInvokeMeta: """ Get an instance of DatasourceInvokeMeta with error """ diff --git a/api/core/db/session_factory.py b/api/core/db/session_factory.py index 1dae2eafd4..45d4bc4594 100644 --- a/api/core/db/session_factory.py +++ b/api/core/db/session_factory.py @@ -1,7 +1,7 @@ from sqlalchemy import Engine from sqlalchemy.orm import Session, sessionmaker -_session_maker: sessionmaker | None = None +_session_maker: sessionmaker[Session] | None = None def configure_session_factory(engine: Engine, expire_on_commit: bool = False): @@ -10,7 +10,7 @@ def configure_session_factory(engine: Engine, expire_on_commit: bool = False): _session_maker = sessionmaker(bind=engine, expire_on_commit=expire_on_commit) -def get_session_maker() -> sessionmaker: +def get_session_maker() -> sessionmaker[Session]: if _session_maker is None: raise RuntimeError("Session factory not configured. Call configure_session_factory() first.") return _session_maker @@ -27,7 +27,7 @@ class SessionFactory: configure_session_factory(engine, expire_on_commit) @staticmethod - def get_session_maker() -> sessionmaker: + def get_session_maker() -> sessionmaker[Session]: return get_session_maker() @staticmethod diff --git a/api/core/entities/mcp_provider.py b/api/core/entities/mcp_provider.py index 7fdf5e4be6..135d2a4945 100644 --- a/api/core/entities/mcp_provider.py +++ b/api/core/entities/mcp_provider.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json from datetime import datetime from enum import StrEnum @@ -75,7 +77,7 @@ class MCPProviderEntity(BaseModel): updated_at: datetime @classmethod - def from_db_model(cls, db_provider: "MCPToolProvider") -> "MCPProviderEntity": + def from_db_model(cls, db_provider: MCPToolProvider) -> MCPProviderEntity: """Create entity from database model with decryption""" return cls( diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 8a8067332d..0078ec7e4f 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from enum import StrEnum, auto from typing import Union @@ -178,7 +180,7 @@ class BasicProviderConfig(BaseModel): TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR @classmethod - def value_of(cls, value: str) -> "ProviderConfig.Type": + def value_of(cls, value: str) -> ProviderConfig.Type: """ Get value of given mode. diff --git a/api/core/file/helpers.py b/api/core/file/helpers.py index 6d553d7dc6..2ac483673a 100644 --- a/api/core/file/helpers.py +++ b/api/core/file/helpers.py @@ -8,8 +8,9 @@ import urllib.parse from configs import dify_config -def get_signed_file_url(upload_file_id: str, as_attachment=False) -> str: - url = f"{dify_config.FILES_URL}/files/{upload_file_id}/file-preview" +def get_signed_file_url(upload_file_id: str, as_attachment=False, for_external: bool = True) -> str: + base_url = dify_config.FILES_URL if for_external else (dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL) + url = f"{base_url}/files/{upload_file_id}/file-preview" timestamp = str(int(time.time())) nonce = os.urandom(16).hex() diff --git a/api/core/file/models.py b/api/core/file/models.py index d149205d77..6324523b22 100644 --- a/api/core/file/models.py +++ b/api/core/file/models.py @@ -112,17 +112,17 @@ class File(BaseModel): return text - def generate_url(self) -> str | None: + def generate_url(self, for_external: bool = True) -> str | None: if self.transfer_method == FileTransferMethod.REMOTE_URL: return self.remote_url elif self.transfer_method == FileTransferMethod.LOCAL_FILE: if self.related_id is None: raise ValueError("Missing file related_id") - return helpers.get_signed_file_url(upload_file_id=self.related_id) + return helpers.get_signed_file_url(upload_file_id=self.related_id, for_external=for_external) elif self.transfer_method in [FileTransferMethod.TOOL_FILE, FileTransferMethod.DATASOURCE_FILE]: assert self.related_id is not None assert self.extension is not None - return sign_tool_file(tool_file_id=self.related_id, extension=self.extension) + return sign_tool_file(tool_file_id=self.related_id, extension=self.extension, for_external=for_external) return None def to_plugin_parameter(self) -> dict[str, Any]: @@ -133,7 +133,7 @@ class File(BaseModel): "extension": self.extension, "size": self.size, "type": self.type, - "url": self.generate_url(), + "url": self.generate_url(for_external=False), } @model_validator(mode="after") diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index 6fda073913..5cdea19a8d 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -76,7 +76,7 @@ class TemplateTransformer(ABC): Post-process the result to convert scientific notation strings back to numbers """ - def convert_scientific_notation(value): + def convert_scientific_notation(value: Any) -> Any: if isinstance(value, str): # Check if the string looks like scientific notation if re.match(r"^-?\d+\.?\d*e[+-]\d+$", value, re.IGNORECASE): @@ -90,7 +90,7 @@ class TemplateTransformer(ABC): return [convert_scientific_notation(v) for v in value] return value - return convert_scientific_notation(result) # type: ignore[no-any-return] + return convert_scientific_notation(result) @classmethod @abstractmethod diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index c97ae6eac7..84a6fd0d1f 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -68,13 +68,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): request_id: RequestId, request_meta: RequestParams.Meta | None, request: ReceiveRequestT, - session: """BaseSession[ - SendRequestT, - SendNotificationT, - SendResultT, - ReceiveRequestT, - ReceiveNotificationT - ]""", + session: """BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]""", on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any], ): self.request_id = request_id diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index 89dae2dbff..3ac83b4c96 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from abc import ABC from collections.abc import Mapping, Sequence from enum import StrEnum, auto @@ -17,7 +19,7 @@ class PromptMessageRole(StrEnum): TOOL = auto() @classmethod - def value_of(cls, value: str) -> "PromptMessageRole": + def value_of(cls, value: str) -> PromptMessageRole: """ Get value of given mode. diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index aee6ce1108..19194d162c 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from decimal import Decimal from enum import StrEnum, auto from typing import Any @@ -20,7 +22,7 @@ class ModelType(StrEnum): TTS = auto() @classmethod - def value_of(cls, origin_model_type: str) -> "ModelType": + def value_of(cls, origin_model_type: str) -> ModelType: """ Get model type from origin model type. @@ -103,7 +105,7 @@ class DefaultParameterName(StrEnum): JSON_SCHEMA = auto() @classmethod - def value_of(cls, value: Any) -> "DefaultParameterName": + def value_of(cls, value: Any) -> DefaultParameterName: """ Get parameter name from value. diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index 12a202ce64..28f162a928 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import hashlib import logging from collections.abc import Sequence @@ -38,7 +40,7 @@ class ModelProviderFactory: plugin_providers = self.get_plugin_model_providers() return [provider.declaration for provider in plugin_providers] - def get_plugin_model_providers(self) -> Sequence["PluginModelProviderEntity"]: + def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]: """ Get all plugin model providers :return: list of plugin model providers @@ -76,7 +78,7 @@ class ModelProviderFactory: plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider) return plugin_model_provider_entity.declaration - def get_plugin_model_provider(self, provider: str) -> "PluginModelProviderEntity": + def get_plugin_model_provider(self, provider: str) -> PluginModelProviderEntity: """ Get plugin model provider :param provider: provider name diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 3b83121357..6674228dc0 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import enum from collections.abc import Mapping, Sequence from datetime import datetime @@ -242,7 +244,7 @@ class CredentialType(enum.StrEnum): return [item.value for item in cls] @classmethod - def of(cls, credential_type: str) -> "CredentialType": + def of(cls, credential_type: str) -> CredentialType: type_name = credential_type.lower() if type_name in {"api-key", "api_key"}: return cls.API_KEY diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py index a306f9ba0c..91bb71bfa6 100644 --- a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py +++ b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import contextlib import json import logging @@ -6,7 +8,7 @@ import re import threading import time import uuid -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any import clickzetta # type: ignore from pydantic import BaseModel, model_validator @@ -76,7 +78,7 @@ class ClickzettaConnectionPool: Manages connection reuse across ClickzettaVector instances. """ - _instance: Optional["ClickzettaConnectionPool"] = None + _instance: ClickzettaConnectionPool | None = None _lock = threading.Lock() def __init__(self): @@ -89,7 +91,7 @@ class ClickzettaConnectionPool: self._start_cleanup_thread() @classmethod - def get_instance(cls) -> "ClickzettaConnectionPool": + def get_instance(cls) -> ClickzettaConnectionPool: """Get singleton instance of connection pool.""" if cls._instance is None: with cls._lock: @@ -104,7 +106,7 @@ class ClickzettaConnectionPool: f"{config.workspace}:{config.vcluster}:{config.schema_name}" ) - def _create_connection(self, config: ClickzettaConfig) -> "Connection": + def _create_connection(self, config: ClickzettaConfig) -> Connection: """Create a new ClickZetta connection.""" max_retries = 3 retry_delay = 1.0 @@ -134,7 +136,7 @@ class ClickzettaConnectionPool: raise RuntimeError(f"Failed to create ClickZetta connection after {max_retries} attempts") - def _configure_connection(self, connection: "Connection"): + def _configure_connection(self, connection: Connection): """Configure connection session settings.""" try: with connection.cursor() as cursor: @@ -181,7 +183,7 @@ class ClickzettaConnectionPool: except Exception: logger.exception("Failed to configure connection, continuing with defaults") - def _is_connection_valid(self, connection: "Connection") -> bool: + def _is_connection_valid(self, connection: Connection) -> bool: """Check if connection is still valid.""" try: with connection.cursor() as cursor: @@ -190,7 +192,7 @@ class ClickzettaConnectionPool: except Exception: return False - def get_connection(self, config: ClickzettaConfig) -> "Connection": + def get_connection(self, config: ClickzettaConfig) -> Connection: """Get a connection from the pool or create a new one.""" config_key = self._get_config_key(config) @@ -221,7 +223,7 @@ class ClickzettaConnectionPool: # No valid connection found, create new one return self._create_connection(config) - def return_connection(self, config: ClickzettaConfig, connection: "Connection"): + def return_connection(self, config: ClickzettaConfig, connection: Connection): """Return a connection to the pool.""" config_key = self._get_config_key(config) @@ -315,22 +317,22 @@ class ClickzettaVector(BaseVector): self._connection_pool = ClickzettaConnectionPool.get_instance() self._init_write_queue() - def _get_connection(self) -> "Connection": + def _get_connection(self) -> Connection: """Get a connection from the pool.""" return self._connection_pool.get_connection(self._config) - def _return_connection(self, connection: "Connection"): + def _return_connection(self, connection: Connection): """Return a connection to the pool.""" self._connection_pool.return_connection(self._config, connection) class ConnectionContext: """Context manager for borrowing and returning connections.""" - def __init__(self, vector_instance: "ClickzettaVector"): + def __init__(self, vector_instance: ClickzettaVector): self.vector = vector_instance self.connection: Connection | None = None - def __enter__(self) -> "Connection": + def __enter__(self) -> Connection: self.connection = self.vector._get_connection() return self.connection @@ -338,7 +340,7 @@ class ClickzettaVector(BaseVector): if self.connection: self.vector._return_connection(self.connection) - def get_connection_context(self) -> "ClickzettaVector.ConnectionContext": + def get_connection_context(self) -> ClickzettaVector.ConnectionContext: """Get a connection context manager.""" return self.ConnectionContext(self) @@ -437,7 +439,7 @@ class ClickzettaVector(BaseVector): """Return the vector database type.""" return "clickzetta" - def _ensure_connection(self) -> "Connection": + def _ensure_connection(self) -> Connection: """Get a connection from the pool.""" return self._get_connection() @@ -984,9 +986,11 @@ class ClickzettaVector(BaseVector): # No need for dataset_id filter since each dataset has its own table - # Use simple quote escaping for LIKE clause - escaped_query = query.replace("'", "''") - filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%'") + # Escape special characters for LIKE clause to prevent SQL injection + from libs.helper import escape_like_pattern + + escaped_query = escape_like_pattern(query).replace("'", "''") + filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%' ESCAPE '\\\\'") where_clause = " AND ".join(filter_clauses) search_sql = f""" diff --git a/api/core/rag/datasource/vdb/iris/iris_vector.py b/api/core/rag/datasource/vdb/iris/iris_vector.py index b1bfabb76e..5bdb0af0b3 100644 --- a/api/core/rag/datasource/vdb/iris/iris_vector.py +++ b/api/core/rag/datasource/vdb/iris/iris_vector.py @@ -287,11 +287,15 @@ class IrisVector(BaseVector): cursor.execute(sql, (query,)) else: # Fallback to LIKE search (inefficient for large datasets) - query_pattern = f"%{query}%" + # Escape special characters for LIKE clause to prevent SQL injection + from libs.helper import escape_like_pattern + + escaped_query = escape_like_pattern(query) + query_pattern = f"%{escaped_query}%" sql = f""" SELECT TOP {top_k} id, text, meta FROM {self.schema}.{self.table_name} - WHERE text LIKE ? + WHERE text LIKE ? ESCAPE '\\' """ cursor.execute(sql, (query_pattern,)) diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 84d1e26b34..b48dd93f04 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -66,6 +66,8 @@ class WeaviateVector(BaseVector): in a Weaviate collection. """ + _DOCUMENT_ID_PROPERTY = "document_id" + def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list): """ Initializes the Weaviate vector store. @@ -353,15 +355,12 @@ class WeaviateVector(BaseVector): return [] col = self._client.collections.use(self._collection_name) - props = list({*self._attributes, "document_id", Field.TEXT_KEY.value}) + props = list({*self._attributes, self._DOCUMENT_ID_PROPERTY, Field.TEXT_KEY.value}) where = None doc_ids = kwargs.get("document_ids_filter") or [] if doc_ids: - ors = [Filter.by_property("document_id").equal(x) for x in doc_ids] - where = ors[0] - for f in ors[1:]: - where = where | f + where = Filter.by_property(self._DOCUMENT_ID_PROPERTY).contains_any(doc_ids) top_k = int(kwargs.get("top_k", 4)) score_threshold = float(kwargs.get("score_threshold") or 0.0) @@ -408,10 +407,7 @@ class WeaviateVector(BaseVector): where = None doc_ids = kwargs.get("document_ids_filter") or [] if doc_ids: - ors = [Filter.by_property("document_id").equal(x) for x in doc_ids] - where = ors[0] - for f in ors[1:]: - where = where | f + where = Filter.by_property(self._DOCUMENT_ID_PROPERTY).contains_any(doc_ids) top_k = int(kwargs.get("top_k", 4)) diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 1fe74d3042..69adac522d 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections.abc import Sequence from typing import Any @@ -22,7 +24,7 @@ class DatasetDocumentStore: self._document_id = document_id @classmethod - def from_dict(cls, config_dict: dict[str, Any]) -> "DatasetDocumentStore": + def from_dict(cls, config_dict: dict[str, Any]) -> DatasetDocumentStore: return cls(**config_dict) def to_dict(self) -> dict[str, Any]: diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index f67f613e9d..511f5a698d 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -7,10 +7,11 @@ import re import tempfile import uuid from urllib.parse import urlparse -from xml.etree import ElementTree import httpx from docx import Document as DocxDocument +from docx.oxml.ns import qn +from docx.text.run import Run from configs import dify_config from core.helper import ssrf_proxy @@ -229,44 +230,20 @@ class WordExtractor(BaseExtractor): image_map = self._extract_images_from_docx(doc) - hyperlinks_url = None - url_pattern = re.compile(r"http://[^\s+]+//|https://[^\s+]+") - for para in doc.paragraphs: - for run in para.runs: - if run.text and hyperlinks_url: - result = f" [{run.text}]({hyperlinks_url}) " - run.text = result - hyperlinks_url = None - if "HYPERLINK" in run.element.xml: - try: - xml = ElementTree.XML(run.element.xml) - x_child = [c for c in xml.iter() if c is not None] - for x in x_child: - if x is None: - continue - if x.tag.endswith("instrText"): - if x.text is None: - continue - for i in url_pattern.findall(x.text): - hyperlinks_url = str(i) - except Exception: - logger.exception("Failed to parse HYPERLINK xml") - def parse_paragraph(paragraph): - paragraph_content = [] - - def append_image_link(image_id, has_drawing): + def append_image_link(image_id, has_drawing, target_buffer): """Helper to append image link from image_map based on relationship type.""" rel = doc.part.rels[image_id] if rel.is_external: if image_id in image_map and not has_drawing: - paragraph_content.append(image_map[image_id]) + target_buffer.append(image_map[image_id]) else: image_part = rel.target_part if image_part in image_map and not has_drawing: - paragraph_content.append(image_map[image_part]) + target_buffer.append(image_map[image_part]) - for run in paragraph.runs: + def process_run(run, target_buffer): + # Helper to extract text and embedded images from a run element and append them to target_buffer if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"): # Process drawing type images drawing_elements = run.element.findall( @@ -287,13 +264,13 @@ class WordExtractor(BaseExtractor): # External image: use embed_id as key if embed_id in image_map: has_drawing = True - paragraph_content.append(image_map[embed_id]) + target_buffer.append(image_map[embed_id]) else: # Internal image: use target_part as key image_part = doc.part.related_parts.get(embed_id) if image_part in image_map: has_drawing = True - paragraph_content.append(image_map[image_part]) + target_buffer.append(image_map[image_part]) # Process pict type images shape_elements = run.element.findall( ".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}pict" @@ -308,7 +285,7 @@ class WordExtractor(BaseExtractor): "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id" ) if image_id and image_id in doc.part.rels: - append_image_link(image_id, has_drawing) + append_image_link(image_id, has_drawing, target_buffer) # Find imagedata element in VML image_data = shape.find(".//{urn:schemas-microsoft-com:vml}imagedata") if image_data is not None: @@ -316,9 +293,93 @@ class WordExtractor(BaseExtractor): "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id" ) if image_id and image_id in doc.part.rels: - append_image_link(image_id, has_drawing) + append_image_link(image_id, has_drawing, target_buffer) if run.text.strip(): - paragraph_content.append(run.text.strip()) + target_buffer.append(run.text.strip()) + + def process_hyperlink(hyperlink_elem, target_buffer): + # Helper to extract text from a hyperlink element and append it to target_buffer + r_id = hyperlink_elem.get(qn("r:id")) + + # Extract text from runs inside the hyperlink + link_text_parts = [] + for run_elem in hyperlink_elem.findall(qn("w:r")): + run = Run(run_elem, paragraph) + # Hyperlink text may be split across multiple runs (e.g., with different formatting), + # so collect all run texts first + if run.text: + link_text_parts.append(run.text) + + link_text = "".join(link_text_parts).strip() + + # Resolve URL + if r_id: + try: + rel = doc.part.rels.get(r_id) + if rel and rel.is_external: + link_text = f"[{link_text or rel.target_ref}]({rel.target_ref})" + except Exception: + logger.exception("Failed to resolve URL for hyperlink with r:id: %s", r_id) + + if link_text: + target_buffer.append(link_text) + + paragraph_content = [] + # State for legacy HYPERLINK fields + hyperlink_field_url = None + hyperlink_field_text_parts: list = [] + is_collecting_field_text = False + # Iterate through paragraph elements in document order + for child in paragraph._element: + tag = child.tag + if tag == qn("w:r"): + # Regular run + run = Run(child, paragraph) + + # Check for fldChar (begin/end/separate) and instrText for legacy hyperlinks + fld_chars = child.findall(qn("w:fldChar")) + instr_texts = child.findall(qn("w:instrText")) + + # Handle Fields + if fld_chars or instr_texts: + # Process instrText to find HYPERLINK "url" + for instr in instr_texts: + if instr.text and "HYPERLINK" in instr.text: + # Quick regex to extract URL + match = re.search(r'HYPERLINK\s+"([^"]+)"', instr.text, re.IGNORECASE) + if match: + hyperlink_field_url = match.group(1) + + # Process fldChar + for fld_char in fld_chars: + fld_char_type = fld_char.get(qn("w:fldCharType")) + if fld_char_type == "begin": + # Start of a field: reset legacy link state + hyperlink_field_url = None + hyperlink_field_text_parts = [] + is_collecting_field_text = False + elif fld_char_type == "separate": + # Separator: if we found a URL, start collecting visible text + if hyperlink_field_url: + is_collecting_field_text = True + elif fld_char_type == "end": + # End of field + if is_collecting_field_text and hyperlink_field_url: + # Create markdown link and append to main content + display_text = "".join(hyperlink_field_text_parts).strip() + if display_text: + link_md = f"[{display_text}]({hyperlink_field_url})" + paragraph_content.append(link_md) + # Reset state + hyperlink_field_url = None + hyperlink_field_text_parts = [] + is_collecting_field_text = False + + # Decide where to append content + target_buffer = hyperlink_field_text_parts if is_collecting_field_text else paragraph_content + process_run(run, target_buffer) + elif tag == qn("w:hyperlink"): + process_hyperlink(child, paragraph_content) return "".join(paragraph_content) if paragraph_content else "" paragraphs = doc.paragraphs.copy() diff --git a/api/core/rag/pipeline/queue.py b/api/core/rag/pipeline/queue.py index 7472598a7f..bf8db95b4e 100644 --- a/api/core/rag/pipeline/queue.py +++ b/api/core/rag/pipeline/queue.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json from collections.abc import Sequence from typing import Any @@ -16,7 +18,7 @@ class TaskWrapper(BaseModel): return self.model_dump_json() @classmethod - def deserialize(cls, serialized_data: str) -> "TaskWrapper": + def deserialize(cls, serialized_data: str) -> TaskWrapper: return cls.model_validate_json(serialized_data) diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 4ec59940e3..f8f85d141a 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -515,6 +515,7 @@ class DatasetRetrieval: 0 ].embedding_model_provider weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model + dataset_count = len(available_datasets) with measure_time() as timer: cancel_event = threading.Event() thread_exceptions: list[Exception] = [] @@ -537,6 +538,7 @@ class DatasetRetrieval: "score_threshold": score_threshold, "query": query, "attachment_id": None, + "dataset_count": dataset_count, "cancel_event": cancel_event, "thread_exceptions": thread_exceptions, }, @@ -562,6 +564,7 @@ class DatasetRetrieval: "score_threshold": score_threshold, "query": None, "attachment_id": attachment_id, + "dataset_count": dataset_count, "cancel_event": cancel_event, "thread_exceptions": thread_exceptions, }, @@ -1195,18 +1198,24 @@ class DatasetRetrieval: json_field = DatasetDocument.doc_metadata[metadata_name].as_string() + from libs.helper import escape_like_pattern + match condition: case "contains": - filters.append(json_field.like(f"%{value}%")) + escaped_value = escape_like_pattern(str(value)) + filters.append(json_field.like(f"%{escaped_value}%", escape="\\")) case "not contains": - filters.append(json_field.notlike(f"%{value}%")) + escaped_value = escape_like_pattern(str(value)) + filters.append(json_field.notlike(f"%{escaped_value}%", escape="\\")) case "start with": - filters.append(json_field.like(f"{value}%")) + escaped_value = escape_like_pattern(str(value)) + filters.append(json_field.like(f"{escaped_value}%", escape="\\")) case "end with": - filters.append(json_field.like(f"%{value}")) + escaped_value = escape_like_pattern(str(value)) + filters.append(json_field.like(f"%{escaped_value}", escape="\\")) case "is" | "=": if isinstance(value, str): @@ -1422,6 +1431,7 @@ class DatasetRetrieval: score_threshold: float, query: str | None, attachment_id: str | None, + dataset_count: int, cancel_event: threading.Event | None = None, thread_exceptions: list[Exception] | None = None, ): @@ -1470,37 +1480,38 @@ class DatasetRetrieval: if cancel_event and cancel_event.is_set(): break - if reranking_enable: - # do rerank for searched documents - data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False) - if query: - all_documents_item = data_post_processor.invoke( - query=query, - documents=all_documents_item, - score_threshold=score_threshold, - top_n=top_k, - query_type=QueryType.TEXT_QUERY, - ) - if attachment_id: - all_documents_item = data_post_processor.invoke( - documents=all_documents_item, - score_threshold=score_threshold, - top_n=top_k, - query_type=QueryType.IMAGE_QUERY, - query=attachment_id, - ) - else: - if index_type == IndexTechniqueType.ECONOMY: - if not query: - all_documents_item = [] - else: - all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k) - elif index_type == IndexTechniqueType.HIGH_QUALITY: - all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold) + # Skip second reranking when there is only one dataset + if reranking_enable and dataset_count > 1: + # do rerank for searched documents + data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False) + if query: + all_documents_item = data_post_processor.invoke( + query=query, + documents=all_documents_item, + score_threshold=score_threshold, + top_n=top_k, + query_type=QueryType.TEXT_QUERY, + ) + if attachment_id: + all_documents_item = data_post_processor.invoke( + documents=all_documents_item, + score_threshold=score_threshold, + top_n=top_k, + query_type=QueryType.IMAGE_QUERY, + query=attachment_id, + ) else: - all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item - if all_documents_item: - all_documents.extend(all_documents_item) + if index_type == IndexTechniqueType.ECONOMY: + if not query: + all_documents_item = [] + else: + all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k) + elif index_type == IndexTechniqueType.HIGH_QUALITY: + all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold) + else: + all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item + if all_documents_item: + all_documents.extend(all_documents_item) except Exception as e: if cancel_event: cancel_event.set() diff --git a/api/core/schemas/registry.py b/api/core/schemas/registry.py index 51bfae1cd3..b4ecfe47ff 100644 --- a/api/core/schemas/registry.py +++ b/api/core/schemas/registry.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import json import logging import threading from collections.abc import Mapping, MutableMapping from pathlib import Path -from typing import Any, ClassVar, Optional +from typing import Any, ClassVar class SchemaRegistry: @@ -11,7 +13,7 @@ class SchemaRegistry: logger: ClassVar[logging.Logger] = logging.getLogger(__name__) - _default_instance: ClassVar[Optional["SchemaRegistry"]] = None + _default_instance: ClassVar[SchemaRegistry | None] = None _lock: ClassVar[threading.Lock] = threading.Lock() def __init__(self, base_dir: str): @@ -20,7 +22,7 @@ class SchemaRegistry: self.metadata: MutableMapping[str, MutableMapping[str, Any]] = {} @classmethod - def default_registry(cls) -> "SchemaRegistry": + def default_registry(cls) -> SchemaRegistry: """Returns the default schema registry for builtin schemas (thread-safe singleton)""" if cls._default_instance is None: with cls._lock: diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py index 8ca4eabb7a..ebd200a822 100644 --- a/api/core/tools/__base/tool.py +++ b/api/core/tools/__base/tool.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from abc import ABC, abstractmethod from collections.abc import Generator from copy import deepcopy @@ -24,7 +26,7 @@ class Tool(ABC): self.entity = entity self.runtime = runtime - def fork_tool_runtime(self, runtime: ToolRuntime) -> "Tool": + def fork_tool_runtime(self, runtime: ToolRuntime) -> Tool: """ fork a new tool with metadata :return: the new tool @@ -166,7 +168,7 @@ class Tool(ABC): type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image) ) - def create_file_message(self, file: "File") -> ToolInvokeMessage: + def create_file_message(self, file: File) -> ToolInvokeMessage: return ToolInvokeMessage( type=ToolInvokeMessage.MessageType.FILE, message=ToolInvokeMessage.FileMessage(), diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py index 84efefba07..51b0407886 100644 --- a/api/core/tools/builtin_tool/tool.py +++ b/api/core/tools/builtin_tool/tool.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage from core.tools.__base.tool import Tool @@ -24,7 +26,7 @@ class BuiltinTool(Tool): super().__init__(**kwargs) self.provider = provider - def fork_tool_runtime(self, runtime: ToolRuntime) -> "BuiltinTool": + def fork_tool_runtime(self, runtime: ToolRuntime) -> BuiltinTool: """ fork a new tool with metadata :return: the new tool diff --git a/api/core/tools/custom_tool/provider.py b/api/core/tools/custom_tool/provider.py index 0cc992155a..e2f6c00555 100644 --- a/api/core/tools/custom_tool/provider.py +++ b/api/core/tools/custom_tool/provider.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from pydantic import Field from sqlalchemy import select @@ -32,7 +34,7 @@ class ApiToolProviderController(ToolProviderController): self.tools = [] @classmethod - def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController": + def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> ApiToolProviderController: credentials_schema = [ ProviderConfig( name="auth_type", diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 583a3584f7..b5c7a6310c 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import contextlib from collections.abc import Mapping @@ -55,7 +57,7 @@ class ToolProviderType(StrEnum): MCP = auto() @classmethod - def value_of(cls, value: str) -> "ToolProviderType": + def value_of(cls, value: str) -> ToolProviderType: """ Get value of given mode. @@ -79,7 +81,7 @@ class ApiProviderSchemaType(StrEnum): OPENAI_ACTIONS = auto() @classmethod - def value_of(cls, value: str) -> "ApiProviderSchemaType": + def value_of(cls, value: str) -> ApiProviderSchemaType: """ Get value of given mode. @@ -102,7 +104,7 @@ class ApiProviderAuthType(StrEnum): API_KEY_QUERY = auto() @classmethod - def value_of(cls, value: str) -> "ApiProviderAuthType": + def value_of(cls, value: str) -> ApiProviderAuthType: """ Get value of given mode. @@ -307,7 +309,7 @@ class ToolParameter(PluginParameter): typ: ToolParameterType, required: bool, options: list[str] | None = None, - ) -> "ToolParameter": + ) -> ToolParameter: """ get a simple tool parameter @@ -429,14 +431,14 @@ class ToolInvokeMeta(BaseModel): tool_config: dict | None = None @classmethod - def empty(cls) -> "ToolInvokeMeta": + def empty(cls) -> ToolInvokeMeta: """ Get an empty instance of ToolInvokeMeta """ return cls(time_cost=0.0, error=None, tool_config={}) @classmethod - def error_instance(cls, error: str) -> "ToolInvokeMeta": + def error_instance(cls, error: str) -> ToolInvokeMeta: """ Get an instance of ToolInvokeMeta with error """ diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index 96917045e3..ef9e9c103a 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import json import logging @@ -118,7 +120,7 @@ class MCPTool(Tool): for item in json_list: yield self.create_json_message(item) - def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool": + def fork_tool_runtime(self, runtime: ToolRuntime) -> MCPTool: return MCPTool( entity=self.entity, runtime=runtime, diff --git a/api/core/tools/plugin_tool/tool.py b/api/core/tools/plugin_tool/tool.py index 828dc3b810..d3a2ad488c 100644 --- a/api/core/tools/plugin_tool/tool.py +++ b/api/core/tools/plugin_tool/tool.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections.abc import Generator from typing import Any @@ -46,7 +48,7 @@ class PluginTool(Tool): message_id=message_id, ) - def fork_tool_runtime(self, runtime: ToolRuntime) -> "PluginTool": + def fork_tool_runtime(self, runtime: ToolRuntime) -> PluginTool: return PluginTool( entity=self.entity, runtime=runtime, diff --git a/api/core/tools/signature.py b/api/core/tools/signature.py index fef3157f27..22e099deba 100644 --- a/api/core/tools/signature.py +++ b/api/core/tools/signature.py @@ -7,12 +7,12 @@ import time from configs import dify_config -def sign_tool_file(tool_file_id: str, extension: str) -> str: +def sign_tool_file(tool_file_id: str, extension: str, for_external: bool = True) -> str: """ sign file to get a temporary url for plugin access """ - # Use internal URL for plugin/tool file access in Docker environments - base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL + # Use internal URL for plugin/tool file access in Docker environments, unless for_external is True + base_url = dify_config.FILES_URL if for_external else (dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL) file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}" timestamp = str(int(time.time())) diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index 5422f5250b..a706f101ca 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections.abc import Mapping from pydantic import Field @@ -47,7 +49,7 @@ class WorkflowToolProviderController(ToolProviderController): self.provider_id = provider_id @classmethod - def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController": + def from_db(cls, db_provider: WorkflowToolProvider) -> WorkflowToolProviderController: with session_factory.create_session() as session, session.begin(): app = session.get(App, db_provider.app_id) if not app: diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 30334f5da8..81a1d54199 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import logging from collections.abc import Generator, Mapping, Sequence @@ -181,7 +183,7 @@ class WorkflowTool(Tool): return found return None - def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool": + def fork_tool_runtime(self, runtime: ToolRuntime) -> WorkflowTool: """ fork a new tool with metadata diff --git a/api/core/variables/types.py b/api/core/variables/types.py index ce71711344..13b926c978 100644 --- a/api/core/variables/types.py +++ b/api/core/variables/types.py @@ -1,6 +1,8 @@ +from __future__ import annotations + from collections.abc import Mapping from enum import StrEnum -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from core.file.models import File @@ -52,7 +54,7 @@ class SegmentType(StrEnum): return self in _ARRAY_TYPES @classmethod - def infer_segment_type(cls, value: Any) -> Optional["SegmentType"]: + def infer_segment_type(cls, value: Any) -> SegmentType | None: """ Attempt to infer the `SegmentType` based on the Python type of the `value` parameter. @@ -173,7 +175,7 @@ class SegmentType(StrEnum): raise AssertionError("this statement should be unreachable.") @staticmethod - def cast_value(value: Any, type_: "SegmentType"): + def cast_value(value: Any, type_: SegmentType): # Cast Python's `bool` type to `int` when the runtime type requires # an integer or number. # @@ -193,7 +195,7 @@ class SegmentType(StrEnum): return [int(i) for i in value] return value - def exposed_type(self) -> "SegmentType": + def exposed_type(self) -> SegmentType: """Returns the type exposed to the frontend. The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so these are returned as `NUMBER` here. @@ -202,7 +204,7 @@ class SegmentType(StrEnum): return SegmentType.NUMBER return self - def element_type(self) -> "SegmentType | None": + def element_type(self) -> SegmentType | None: """Return the element type of the current segment type, or `None` if the element type is undefined. Raises: @@ -217,7 +219,7 @@ class SegmentType(StrEnum): return _ARRAY_ELEMENT_TYPES_MAPPING.get(self) @staticmethod - def get_zero_value(t: "SegmentType"): + def get_zero_value(t: SegmentType): # Lazy import to avoid circular dependency from factories import variable_factory diff --git a/api/core/workflow/README.md b/api/core/workflow/README.md index 72f5dbe1e2..9a39f976a6 100644 --- a/api/core/workflow/README.md +++ b/api/core/workflow/README.md @@ -64,6 +64,9 @@ engine.layer(DebugLoggingLayer(level="INFO")) engine.layer(ExecutionLimitsLayer(max_nodes=100)) ``` +`engine.layer()` binds the read-only runtime state before execution, so layer hooks +can assume `graph_runtime_state` is available. + ### Event-Driven Architecture All node executions emit events for monitoring and integration: diff --git a/api/core/workflow/entities/workflow_execution.py b/api/core/workflow/entities/workflow_execution.py index a8a86d3db2..1b3fb36f1f 100644 --- a/api/core/workflow/entities/workflow_execution.py +++ b/api/core/workflow/entities/workflow_execution.py @@ -5,6 +5,8 @@ Models are independent of the storage mechanism and don't contain implementation details like tenant_id, app_id, etc. """ +from __future__ import annotations + from collections.abc import Mapping from datetime import datetime from typing import Any @@ -59,7 +61,7 @@ class WorkflowExecution(BaseModel): graph: Mapping[str, Any], inputs: Mapping[str, Any], started_at: datetime, - ) -> "WorkflowExecution": + ) -> WorkflowExecution: return WorkflowExecution( id_=id_, workflow_id=workflow_id, diff --git a/api/core/workflow/graph/graph.py b/api/core/workflow/graph/graph.py index ba5a01fc94..7be94c2426 100644 --- a/api/core/workflow/graph/graph.py +++ b/api/core/workflow/graph/graph.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging from collections import defaultdict from collections.abc import Mapping, Sequence @@ -175,7 +177,7 @@ class Graph: def _create_node_instances( cls, node_configs_map: dict[str, dict[str, object]], - node_factory: "NodeFactory", + node_factory: NodeFactory, ) -> dict[str, Node]: """ Create node instances from configurations using the node factory. @@ -197,7 +199,7 @@ class Graph: return nodes @classmethod - def new(cls) -> "GraphBuilder": + def new(cls) -> GraphBuilder: """Create a fluent builder for assembling a graph programmatically.""" return GraphBuilder(graph_cls=cls) @@ -284,9 +286,9 @@ class Graph: cls, *, graph_config: Mapping[str, object], - node_factory: "NodeFactory", + node_factory: NodeFactory, root_node_id: str | None = None, - ) -> "Graph": + ) -> Graph: """ Initialize graph @@ -383,7 +385,7 @@ class GraphBuilder: self._edges: list[Edge] = [] self._edge_counter = 0 - def add_root(self, node: Node) -> "GraphBuilder": + def add_root(self, node: Node) -> GraphBuilder: """Register the root node. Must be called exactly once.""" if self._nodes: @@ -398,7 +400,7 @@ class GraphBuilder: *, from_node_id: str | None = None, source_handle: str = "source", - ) -> "GraphBuilder": + ) -> GraphBuilder: """Append a node and connect it from the specified predecessor.""" if not self._nodes: @@ -419,7 +421,7 @@ class GraphBuilder: return self - def connect(self, *, tail: str, head: str, source_handle: str = "source") -> "GraphBuilder": + def connect(self, *, tail: str, head: str, source_handle: str = "source") -> GraphBuilder: """Connect two existing nodes without adding a new node.""" if tail not in self._nodes_by_id: diff --git a/api/core/workflow/graph_engine/command_channels/redis_channel.py b/api/core/workflow/graph_engine/command_channels/redis_channel.py index 4be3adb8f8..0fccd4a0fd 100644 --- a/api/core/workflow/graph_engine/command_channels/redis_channel.py +++ b/api/core/workflow/graph_engine/command_channels/redis_channel.py @@ -9,7 +9,7 @@ Each instance uses a unique key for its command queue. import json from typing import TYPE_CHECKING, Any, final -from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand +from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand if TYPE_CHECKING: from extensions.ext_redis import RedisClientWrapper @@ -113,6 +113,8 @@ class RedisChannel: return AbortCommand.model_validate(data) if command_type == CommandType.PAUSE: return PauseCommand.model_validate(data) + if command_type == CommandType.UPDATE_VARIABLES: + return UpdateVariablesCommand.model_validate(data) # For other command types, use base class return GraphEngineCommand.model_validate(data) diff --git a/api/core/workflow/graph_engine/command_processing/__init__.py b/api/core/workflow/graph_engine/command_processing/__init__.py index 837f5e55fd..7b4f0dfff7 100644 --- a/api/core/workflow/graph_engine/command_processing/__init__.py +++ b/api/core/workflow/graph_engine/command_processing/__init__.py @@ -5,11 +5,12 @@ This package handles external commands sent to the engine during execution. """ -from .command_handlers import AbortCommandHandler, PauseCommandHandler +from .command_handlers import AbortCommandHandler, PauseCommandHandler, UpdateVariablesCommandHandler from .command_processor import CommandProcessor __all__ = [ "AbortCommandHandler", "CommandProcessor", "PauseCommandHandler", + "UpdateVariablesCommandHandler", ] diff --git a/api/core/workflow/graph_engine/command_processing/command_handlers.py b/api/core/workflow/graph_engine/command_processing/command_handlers.py index e9f109c88c..cfe856d9e8 100644 --- a/api/core/workflow/graph_engine/command_processing/command_handlers.py +++ b/api/core/workflow/graph_engine/command_processing/command_handlers.py @@ -4,9 +4,10 @@ from typing import final from typing_extensions import override from core.workflow.entities.pause_reason import SchedulingPause +from core.workflow.runtime import VariablePool from ..domain.graph_execution import GraphExecution -from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand +from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand, UpdateVariablesCommand from .command_processor import CommandHandler logger = logging.getLogger(__name__) @@ -31,3 +32,25 @@ class PauseCommandHandler(CommandHandler): reason = command.reason pause_reason = SchedulingPause(message=reason) execution.pause(pause_reason) + + +@final +class UpdateVariablesCommandHandler(CommandHandler): + def __init__(self, variable_pool: VariablePool) -> None: + self._variable_pool = variable_pool + + @override + def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: + assert isinstance(command, UpdateVariablesCommand) + for update in command.updates: + try: + variable = update.value + self._variable_pool.add(variable.selector, variable) + logger.debug("Updated variable %s for workflow %s", variable.selector, execution.workflow_id) + except ValueError as exc: + logger.warning( + "Skipping invalid variable selector %s for workflow %s: %s", + getattr(update.value, "selector", None), + execution.workflow_id, + exc, + ) diff --git a/api/core/workflow/graph_engine/entities/commands.py b/api/core/workflow/graph_engine/entities/commands.py index 0d51b2b716..6dce03c94d 100644 --- a/api/core/workflow/graph_engine/entities/commands.py +++ b/api/core/workflow/graph_engine/entities/commands.py @@ -5,17 +5,21 @@ This module defines command types that can be sent to a running GraphEngine instance to control its execution flow. """ -from enum import StrEnum +from collections.abc import Sequence +from enum import StrEnum, auto from typing import Any from pydantic import BaseModel, Field +from core.variables.variables import VariableUnion + class CommandType(StrEnum): """Types of commands that can be sent to GraphEngine.""" - ABORT = "abort" - PAUSE = "pause" + ABORT = auto() + PAUSE = auto() + UPDATE_VARIABLES = auto() class GraphEngineCommand(BaseModel): @@ -37,3 +41,16 @@ class PauseCommand(GraphEngineCommand): command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command") reason: str = Field(default="unknown reason", description="reason for pause") + + +class VariableUpdate(BaseModel): + """Represents a single variable update instruction.""" + + value: VariableUnion = Field(description="New variable value") + + +class UpdateVariablesCommand(GraphEngineCommand): + """Command to update a group of variables in the variable pool.""" + + command_type: CommandType = Field(default=CommandType.UPDATE_VARIABLES, description="Type of command") + updates: Sequence[VariableUpdate] = Field(default_factory=list, description="Variable updates") diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 2e8b8f345f..9a870d7bf5 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -5,9 +5,12 @@ This engine uses a modular architecture with separated packages following Domain-Driven Design principles for improved maintainability and testability. """ +from __future__ import annotations + import contextvars import logging import queue +import threading from collections.abc import Generator from typing import TYPE_CHECKING, cast, final @@ -30,8 +33,13 @@ from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWr if TYPE_CHECKING: # pragma: no cover - used only for static analysis from core.workflow.runtime.graph_runtime_state import GraphProtocol -from .command_processing import AbortCommandHandler, CommandProcessor, PauseCommandHandler -from .entities.commands import AbortCommand, PauseCommand +from .command_processing import ( + AbortCommandHandler, + CommandProcessor, + PauseCommandHandler, + UpdateVariablesCommandHandler, +) +from .entities.commands import AbortCommand, PauseCommand, UpdateVariablesCommand from .error_handler import ErrorHandler from .event_management import EventHandler, EventManager from .graph_state_manager import GraphStateManager @@ -70,10 +78,13 @@ class GraphEngine: scale_down_idle_time: float | None = None, ) -> None: """Initialize the graph engine with all subsystems and dependencies.""" + # stop event + self._stop_event = threading.Event() # Bind runtime state to current workflow context self._graph = graph self._graph_runtime_state = graph_runtime_state + self._graph_runtime_state.stop_event = self._stop_event self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph)) self._command_channel = command_channel @@ -140,6 +151,9 @@ class GraphEngine: pause_handler = PauseCommandHandler() self._command_processor.register_handler(PauseCommand, pause_handler) + update_variables_handler = UpdateVariablesCommandHandler(self._graph_runtime_state.variable_pool) + self._command_processor.register_handler(UpdateVariablesCommand, update_variables_handler) + # === Extensibility === # Layers allow plugins to extend engine functionality self._layers: list[GraphEngineLayer] = [] @@ -169,6 +183,7 @@ class GraphEngine: max_workers=self._max_workers, scale_up_threshold=self._scale_up_threshold, scale_down_idle_time=self._scale_down_idle_time, + stop_event=self._stop_event, ) # === Orchestration === @@ -199,6 +214,7 @@ class GraphEngine: event_handler=self._event_handler_registry, execution_coordinator=self._execution_coordinator, event_emitter=self._event_manager, + stop_event=self._stop_event, ) # === Validation === @@ -212,9 +228,16 @@ class GraphEngine: if id(node.graph_runtime_state) != expected_state_id: raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance") - def layer(self, layer: GraphEngineLayer) -> "GraphEngine": + def _bind_layer_context( + self, + layer: GraphEngineLayer, + ) -> None: + layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel) + + def layer(self, layer: GraphEngineLayer) -> GraphEngine: """Add a layer for extending functionality.""" self._layers.append(layer) + self._bind_layer_context(layer) return self def run(self) -> Generator[GraphEngineEvent, None, None]: @@ -301,14 +324,7 @@ class GraphEngine: def _initialize_layers(self) -> None: """Initialize layers with context.""" self._event_manager.set_layers(self._layers) - # Create a read-only wrapper for the runtime state - read_only_state = ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state) for layer in self._layers: - try: - layer.initialize(read_only_state, self._command_channel) - except Exception as e: - logger.warning("Failed to initialize layer %s: %s", layer.__class__.__name__, e) - try: layer.on_graph_start() except Exception as e: @@ -316,6 +332,7 @@ class GraphEngine: def _start_execution(self, *, resume: bool = False) -> None: """Start execution subsystems.""" + self._stop_event.clear() paused_nodes: list[str] = [] if resume: paused_nodes = self._graph_runtime_state.consume_paused_nodes() @@ -343,13 +360,12 @@ class GraphEngine: def _stop_execution(self) -> None: """Stop execution subsystems.""" + self._stop_event.set() self._dispatcher.stop() self._worker_pool.stop() # Don't mark complete here as the dispatcher already does it # Notify layers - logger = logging.getLogger(__name__) - for layer in self._layers: try: layer.on_graph_end(self._graph_execution.error) diff --git a/api/core/workflow/graph_engine/layers/README.md b/api/core/workflow/graph_engine/layers/README.md index 17845ee1f0..b0f295037c 100644 --- a/api/core/workflow/graph_engine/layers/README.md +++ b/api/core/workflow/graph_engine/layers/README.md @@ -8,7 +8,7 @@ Pluggable middleware for engine extensions. Abstract base class for layers. -- `initialize()` - Receive runtime context +- `initialize()` - Receive runtime context (runtime state is bound here and always available to hooks) - `on_graph_start()` - Execution start hook - `on_event()` - Process all events - `on_graph_end()` - Execution end hook @@ -34,6 +34,9 @@ engine.layer(debug_layer) engine.run() ``` +`engine.layer()` binds the read-only runtime state before execution, so +`graph_runtime_state` is always available inside layer hooks. + ## Custom Layers ```python diff --git a/api/core/workflow/graph_engine/layers/base.py b/api/core/workflow/graph_engine/layers/base.py index 780f92a0f4..89293b9b30 100644 --- a/api/core/workflow/graph_engine/layers/base.py +++ b/api/core/workflow/graph_engine/layers/base.py @@ -13,6 +13,14 @@ from core.workflow.nodes.base.node import Node from core.workflow.runtime import ReadOnlyGraphRuntimeState +class GraphEngineLayerNotInitializedError(Exception): + """Raised when a layer's runtime state is accessed before initialization.""" + + def __init__(self, layer_name: str | None = None) -> None: + name = layer_name or "GraphEngineLayer" + super().__init__(f"{name} runtime state is not initialized. Bind the layer to a GraphEngine before access.") + + class GraphEngineLayer(ABC): """ Abstract base class for GraphEngine layers. @@ -28,22 +36,27 @@ class GraphEngineLayer(ABC): def __init__(self) -> None: """Initialize the layer. Subclasses can override with custom parameters.""" - self.graph_runtime_state: ReadOnlyGraphRuntimeState | None = None + self._graph_runtime_state: ReadOnlyGraphRuntimeState | None = None self.command_channel: CommandChannel | None = None + @property + def graph_runtime_state(self) -> ReadOnlyGraphRuntimeState: + if self._graph_runtime_state is None: + raise GraphEngineLayerNotInitializedError(type(self).__name__) + return self._graph_runtime_state + def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None: """ Initialize the layer with engine dependencies. - Called by GraphEngine before execution starts to inject the read-only runtime state - and command channel. This allows layers to observe engine context and send - commands, but prevents direct state modification. - + Called by GraphEngine to inject the read-only runtime state and command channel. + This is invoked when the layer is registered with a `GraphEngine` instance. + Implementations should be idempotent. Args: graph_runtime_state: Read-only view of the runtime state command_channel: Channel for sending commands to the engine """ - self.graph_runtime_state = graph_runtime_state + self._graph_runtime_state = graph_runtime_state self.command_channel = command_channel @abstractmethod diff --git a/api/core/workflow/graph_engine/layers/debug_logging.py b/api/core/workflow/graph_engine/layers/debug_logging.py index 034ebcf54f..e0402cd09c 100644 --- a/api/core/workflow/graph_engine/layers/debug_logging.py +++ b/api/core/workflow/graph_engine/layers/debug_logging.py @@ -109,10 +109,8 @@ class DebugLoggingLayer(GraphEngineLayer): self.logger.info("=" * 80) self.logger.info("🚀 GRAPH EXECUTION STARTED") self.logger.info("=" * 80) - - if self.graph_runtime_state: - # Log initial state - self.logger.info("Initial State:") + # Log initial state + self.logger.info("Initial State:") @override def on_event(self, event: GraphEngineEvent) -> None: @@ -243,8 +241,7 @@ class DebugLoggingLayer(GraphEngineLayer): self.logger.info(" Node retries: %s", self.retry_count) # Log final state if available - if self.graph_runtime_state and self.include_outputs: - if self.graph_runtime_state.outputs: - self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs)) + if self.include_outputs and self.graph_runtime_state.outputs: + self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs)) self.logger.info("=" * 80) diff --git a/api/core/workflow/graph_engine/layers/persistence.py b/api/core/workflow/graph_engine/layers/persistence.py index b70f36ec9e..e81df4f3b7 100644 --- a/api/core/workflow/graph_engine/layers/persistence.py +++ b/api/core/workflow/graph_engine/layers/persistence.py @@ -337,8 +337,6 @@ class WorkflowPersistenceLayer(GraphEngineLayer): if update_finished: execution.finished_at = naive_utc_now() runtime_state = self.graph_runtime_state - if runtime_state is None: - return execution.total_tokens = runtime_state.total_tokens execution.total_steps = runtime_state.node_run_steps execution.outputs = execution.outputs or runtime_state.outputs @@ -404,6 +402,4 @@ class WorkflowPersistenceLayer(GraphEngineLayer): def _system_variables(self) -> Mapping[str, Any]: runtime_state = self.graph_runtime_state - if runtime_state is None: - return {} return runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID) diff --git a/api/core/workflow/graph_engine/manager.py b/api/core/workflow/graph_engine/manager.py index 0577ba8f02..d2cfa755d9 100644 --- a/api/core/workflow/graph_engine/manager.py +++ b/api/core/workflow/graph_engine/manager.py @@ -3,14 +3,20 @@ GraphEngine Manager for sending control commands via Redis channel. This module provides a simplified interface for controlling workflow executions using the new Redis command channel, without requiring user permission checks. -Supports stop, pause, and resume operations. """ import logging +from collections.abc import Sequence from typing import final from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel -from core.workflow.graph_engine.entities.commands import AbortCommand, GraphEngineCommand, PauseCommand +from core.workflow.graph_engine.entities.commands import ( + AbortCommand, + GraphEngineCommand, + PauseCommand, + UpdateVariablesCommand, + VariableUpdate, +) from extensions.ext_redis import redis_client logger = logging.getLogger(__name__) @@ -23,7 +29,6 @@ class GraphEngineManager: This class provides a simple interface for controlling workflow executions by sending commands through Redis channels, without user validation. - Supports stop and pause operations. """ @staticmethod @@ -45,6 +50,16 @@ class GraphEngineManager: pause_command = PauseCommand(reason=reason or "User requested pause") GraphEngineManager._send_command(task_id, pause_command) + @staticmethod + def send_update_variables_command(task_id: str, updates: Sequence[VariableUpdate]) -> None: + """Send a command to update variables in a running workflow.""" + + if not updates: + return + + update_command = UpdateVariablesCommand(updates=updates) + GraphEngineManager._send_command(task_id, update_command) + @staticmethod def _send_command(task_id: str, command: GraphEngineCommand) -> None: """Send a command to the workflow-specific Redis channel.""" diff --git a/api/core/workflow/graph_engine/orchestration/dispatcher.py b/api/core/workflow/graph_engine/orchestration/dispatcher.py index 334a3f77bf..27439a2412 100644 --- a/api/core/workflow/graph_engine/orchestration/dispatcher.py +++ b/api/core/workflow/graph_engine/orchestration/dispatcher.py @@ -44,6 +44,7 @@ class Dispatcher: event_queue: queue.Queue[GraphNodeEventBase], event_handler: "EventHandler", execution_coordinator: ExecutionCoordinator, + stop_event: threading.Event, event_emitter: EventManager | None = None, ) -> None: """ @@ -61,7 +62,7 @@ class Dispatcher: self._event_emitter = event_emitter self._thread: threading.Thread | None = None - self._stop_event = threading.Event() + self._stop_event = stop_event self._start_time: float | None = None def start(self) -> None: @@ -69,16 +70,14 @@ class Dispatcher: if self._thread and self._thread.is_alive(): return - self._stop_event.clear() self._start_time = time.time() self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True) self._thread.start() def stop(self) -> None: """Stop the dispatcher thread.""" - self._stop_event.set() if self._thread and self._thread.is_alive(): - self._thread.join(timeout=10.0) + self._thread.join(timeout=2.0) def _dispatcher_loop(self) -> None: """Main dispatcher loop.""" diff --git a/api/core/workflow/graph_engine/ready_queue/factory.py b/api/core/workflow/graph_engine/ready_queue/factory.py index 1144e1de69..a9d4f470e5 100644 --- a/api/core/workflow/graph_engine/ready_queue/factory.py +++ b/api/core/workflow/graph_engine/ready_queue/factory.py @@ -2,6 +2,8 @@ Factory for creating ReadyQueue instances from serialized state. """ +from __future__ import annotations + from typing import TYPE_CHECKING from .in_memory import InMemoryReadyQueue @@ -11,7 +13,7 @@ if TYPE_CHECKING: from .protocol import ReadyQueue -def create_ready_queue_from_state(state: ReadyQueueState) -> "ReadyQueue": +def create_ready_queue_from_state(state: ReadyQueueState) -> ReadyQueue: """ Create a ReadyQueue instance from a serialized state. diff --git a/api/core/workflow/graph_engine/response_coordinator/session.py b/api/core/workflow/graph_engine/response_coordinator/session.py index 8b7c2e441e..8ceaa428c3 100644 --- a/api/core/workflow/graph_engine/response_coordinator/session.py +++ b/api/core/workflow/graph_engine/response_coordinator/session.py @@ -5,6 +5,8 @@ This module contains the private ResponseSession class used internally by ResponseStreamCoordinator to manage streaming sessions. """ +from __future__ import annotations + from dataclasses import dataclass from core.workflow.nodes.answer.answer_node import AnswerNode @@ -27,7 +29,7 @@ class ResponseSession: index: int = 0 # Current position in the template segments @classmethod - def from_node(cls, node: Node) -> "ResponseSession": + def from_node(cls, node: Node) -> ResponseSession: """ Create a ResponseSession from an AnswerNode or EndNode. diff --git a/api/core/workflow/graph_engine/worker.py b/api/core/workflow/graph_engine/worker.py index e37a08ae47..83419830b6 100644 --- a/api/core/workflow/graph_engine/worker.py +++ b/api/core/workflow/graph_engine/worker.py @@ -42,6 +42,7 @@ class Worker(threading.Thread): event_queue: queue.Queue[GraphNodeEventBase], graph: Graph, layers: Sequence[GraphEngineLayer], + stop_event: threading.Event, worker_id: int = 0, flask_app: Flask | None = None, context_vars: contextvars.Context | None = None, @@ -65,13 +66,16 @@ class Worker(threading.Thread): self._worker_id = worker_id self._flask_app = flask_app self._context_vars = context_vars - self._stop_event = threading.Event() self._last_task_time = time.time() + self._stop_event = stop_event self._layers = layers if layers is not None else [] def stop(self) -> None: - """Signal the worker to stop processing.""" - self._stop_event.set() + """Worker is controlled via shared stop_event from GraphEngine. + + This method is a no-op retained for backward compatibility. + """ + pass @property def is_idle(self) -> bool: diff --git a/api/core/workflow/graph_engine/worker_management/worker_pool.py b/api/core/workflow/graph_engine/worker_management/worker_pool.py index 5b9234586b..df76ebe882 100644 --- a/api/core/workflow/graph_engine/worker_management/worker_pool.py +++ b/api/core/workflow/graph_engine/worker_management/worker_pool.py @@ -41,6 +41,7 @@ class WorkerPool: event_queue: queue.Queue[GraphNodeEventBase], graph: Graph, layers: list[GraphEngineLayer], + stop_event: threading.Event, flask_app: "Flask | None" = None, context_vars: "Context | None" = None, min_workers: int | None = None, @@ -81,6 +82,7 @@ class WorkerPool: self._worker_counter = 0 self._lock = threading.RLock() self._running = False + self._stop_event = stop_event # No longer tracking worker states with callbacks to avoid lock contention @@ -135,7 +137,7 @@ class WorkerPool: # Wait for workers to finish for worker in self._workers: if worker.is_alive(): - worker.join(timeout=10.0) + worker.join(timeout=2.0) self._workers.clear() @@ -152,6 +154,7 @@ class WorkerPool: worker_id=worker_id, flask_app=self._flask_app, context_vars=self._context_vars, + stop_event=self._stop_event, ) worker.start() diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 4be006de11..234651ce96 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, cast @@ -167,7 +169,7 @@ class AgentNode(Node[AgentNodeData]): variable_pool: VariablePool, node_data: AgentNodeData, for_log: bool = False, - strategy: "PluginAgentStrategy", + strategy: PluginAgentStrategy, ) -> dict[str, Any]: """ Generate parameters based on the given tool parameters, variable pool, and node data. @@ -328,7 +330,7 @@ class AgentNode(Node[AgentNodeData]): def _generate_credentials( self, parameters: dict[str, Any], - ) -> "InvokeCredentials": + ) -> InvokeCredentials: """ Generate credentials based on the given agent parameters. """ @@ -442,9 +444,7 @@ class AgentNode(Node[AgentNodeData]): model_schema.features.remove(feature) return model_schema - def _filter_mcp_type_tool( - self, strategy: "PluginAgentStrategy", tools: list[dict[str, Any]] - ) -> list[dict[str, Any]]: + def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]: """ Filter MCP type tool :param strategy: plugin agent strategy diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index 5aab6bbde4..e5a20c8e91 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json from abc import ABC from builtins import type as type_ @@ -111,7 +113,7 @@ class DefaultValue(BaseModel): raise DefaultValueTypeError(f"Cannot convert to number: {value}") @model_validator(mode="after") - def validate_value_type(self) -> "DefaultValue": + def validate_value_type(self) -> DefaultValue: # Type validation configuration type_validators = { DefaultValueType.STRING: { diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 8ebba3659c..55c8db40ea 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import importlib import logging import operator @@ -59,7 +61,7 @@ logger = logging.getLogger(__name__) class Node(Generic[NodeDataT]): - node_type: ClassVar["NodeType"] + node_type: ClassVar[NodeType] execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE _node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData @@ -198,14 +200,14 @@ class Node(Generic[NodeDataT]): return None # Global registry populated via __init_subclass__ - _registry: ClassVar[dict["NodeType", dict[str, type["Node"]]]] = {} + _registry: ClassVar[dict[NodeType, dict[str, type[Node]]]] = {} def __init__( self, id: str, config: Mapping[str, Any], - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, ) -> None: self._graph_init_params = graph_init_params self.id = id @@ -241,7 +243,7 @@ class Node(Generic[NodeDataT]): return @property - def graph_init_params(self) -> "GraphInitParams": + def graph_init_params(self) -> GraphInitParams: return self._graph_init_params @property @@ -264,6 +266,10 @@ class Node(Generic[NodeDataT]): """ raise NotImplementedError + def _should_stop(self) -> bool: + """Check if execution should be stopped.""" + return self.graph_runtime_state.stop_event.is_set() + def run(self) -> Generator[GraphNodeEventBase, None, None]: execution_id = self.ensure_execution_id() self._start_at = naive_utc_now() @@ -332,6 +338,21 @@ class Node(Generic[NodeDataT]): yield event else: yield event + + if self._should_stop(): + error_message = "Execution cancelled" + yield NodeRunFailedEvent( + id=self.execution_id, + node_id=self._node_id, + node_type=self.node_type, + start_at=self._start_at, + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=error_message, + ), + error=error_message, + ) + return except Exception as e: logger.exception("Node %s failed to run", self._node_id) result = NodeRunResult( @@ -438,7 +459,7 @@ class Node(Generic[NodeDataT]): raise NotImplementedError("subclasses of BaseNode must implement `version` method.") @classmethod - def get_node_type_classes_mapping(cls) -> Mapping["NodeType", Mapping[str, type["Node"]]]: + def get_node_type_classes_mapping(cls) -> Mapping[NodeType, Mapping[str, type[Node]]]: """Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry. Import all modules under core.workflow.nodes so subclasses register themselves on import. diff --git a/api/core/workflow/nodes/base/template.py b/api/core/workflow/nodes/base/template.py index ba3e2058cf..81f4b9f6fb 100644 --- a/api/core/workflow/nodes/base/template.py +++ b/api/core/workflow/nodes/base/template.py @@ -4,6 +4,8 @@ This module provides a unified template structure for both Answer and End nodes, similar to SegmentGroup but focused on template representation without values. """ +from __future__ import annotations + from abc import ABC, abstractmethod from collections.abc import Sequence from dataclasses import dataclass @@ -58,7 +60,7 @@ class Template: segments: list[TemplateSegmentUnion] @classmethod - def from_answer_template(cls, template_str: str) -> "Template": + def from_answer_template(cls, template_str: str) -> Template: """Create a Template from an Answer node template string. Example: @@ -107,7 +109,7 @@ class Template: return cls(segments=segments) @classmethod - def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> "Template": + def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> Template: """Create a Template from an End node outputs configuration. End nodes are treated as templates of concatenated variables with newlines. diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index a38e10030a..e3035d3bf0 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,8 +1,7 @@ from collections.abc import Mapping, Sequence from decimal import Decimal -from typing import Any, cast +from typing import TYPE_CHECKING, Any, ClassVar, cast -from configs import dify_config from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage from core.helper.code_executor.code_node_provider import CodeNodeProvider from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider @@ -13,6 +12,7 @@ from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base.node import Node from core.workflow.nodes.code.entities import CodeNodeData +from core.workflow.nodes.code.limits import CodeNodeLimits from .exc import ( CodeNodeError, @@ -20,9 +20,41 @@ from .exc import ( OutputValidationError, ) +if TYPE_CHECKING: + from core.workflow.entities import GraphInitParams + from core.workflow.runtime import GraphRuntimeState + class CodeNode(Node[CodeNodeData]): node_type = NodeType.CODE + _DEFAULT_CODE_PROVIDERS: ClassVar[tuple[type[CodeNodeProvider], ...]] = ( + Python3CodeProvider, + JavascriptCodeProvider, + ) + _limits: CodeNodeLimits + + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph_runtime_state: "GraphRuntimeState", + *, + code_executor: type[CodeExecutor] | None = None, + code_providers: Sequence[type[CodeNodeProvider]] | None = None, + code_limits: CodeNodeLimits, + ) -> None: + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor + self._code_providers: tuple[type[CodeNodeProvider], ...] = ( + tuple(code_providers) if code_providers else self._DEFAULT_CODE_PROVIDERS + ) + self._limits = code_limits @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: @@ -35,11 +67,16 @@ class CodeNode(Node[CodeNodeData]): if filters: code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3)) - providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider] - code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language)) + code_provider: type[CodeNodeProvider] = next( + provider for provider in cls._DEFAULT_CODE_PROVIDERS if provider.is_accept_language(code_language) + ) return code_provider.get_default_config() + @classmethod + def default_code_providers(cls) -> tuple[type[CodeNodeProvider], ...]: + return cls._DEFAULT_CODE_PROVIDERS + @classmethod def version(cls) -> str: return "1" @@ -60,7 +97,8 @@ class CodeNode(Node[CodeNodeData]): variables[variable_name] = variable.to_object() if variable else None # Run code try: - result = CodeExecutor.execute_workflow_code_template( + _ = self._select_code_provider(code_language) + result = self._code_executor.execute_workflow_code_template( language=code_language, code=code, inputs=variables, @@ -75,6 +113,12 @@ class CodeNode(Node[CodeNodeData]): return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result) + def _select_code_provider(self, code_language: CodeLanguage) -> type[CodeNodeProvider]: + for provider in self._code_providers: + if provider.is_accept_language(code_language): + return provider + raise CodeNodeError(f"Unsupported code language: {code_language}") + def _check_string(self, value: str | None, variable: str) -> str | None: """ Check string @@ -85,10 +129,10 @@ class CodeNode(Node[CodeNodeData]): if value is None: return None - if len(value) > dify_config.CODE_MAX_STRING_LENGTH: + if len(value) > self._limits.max_string_length: raise OutputValidationError( f"The length of output variable `{variable}` must be" - f" less than {dify_config.CODE_MAX_STRING_LENGTH} characters" + f" less than {self._limits.max_string_length} characters" ) return value.replace("\x00", "") @@ -109,20 +153,20 @@ class CodeNode(Node[CodeNodeData]): if value is None: return None - if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER: + if value > self._limits.max_number or value < self._limits.min_number: raise OutputValidationError( f"Output variable `{variable}` is out of range," - f" it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}." + f" it must be between {self._limits.min_number} and {self._limits.max_number}." ) if isinstance(value, float): decimal_value = Decimal(str(value)).normalize() precision = -decimal_value.as_tuple().exponent if decimal_value.as_tuple().exponent < 0 else 0 # type: ignore[operator] # raise error if precision is too high - if precision > dify_config.CODE_MAX_PRECISION: + if precision > self._limits.max_precision: raise OutputValidationError( f"Output variable `{variable}` has too high precision," - f" it must be less than {dify_config.CODE_MAX_PRECISION} digits." + f" it must be less than {self._limits.max_precision} digits." ) return value @@ -137,8 +181,8 @@ class CodeNode(Node[CodeNodeData]): # TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes. # Note that `_transform_result` may produce lists containing `None` values, # which don't conform to the type requirements of `Array*Segment` classes. - if depth > dify_config.CODE_MAX_DEPTH: - raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.") + if depth > self._limits.max_depth: + raise DepthLimitError(f"Depth limit {self._limits.max_depth} reached, object too deep.") transformed_result: dict[str, Any] = {} if output_schema is None: @@ -272,10 +316,10 @@ class CodeNode(Node[CodeNodeData]): f"Output {prefix}{dot}{output_name} is not an array, got {type(value)} instead." ) else: - if len(value) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH: + if len(value) > self._limits.max_number_array_length: raise OutputValidationError( f"The length of output variable `{prefix}{dot}{output_name}` must be" - f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements." + f" less than {self._limits.max_number_array_length} elements." ) for i, inner_value in enumerate(value): @@ -305,10 +349,10 @@ class CodeNode(Node[CodeNodeData]): f" got {type(result.get(output_name))} instead." ) else: - if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH: + if len(result[output_name]) > self._limits.max_string_array_length: raise OutputValidationError( f"The length of output variable `{prefix}{dot}{output_name}` must be" - f" less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements." + f" less than {self._limits.max_string_array_length} elements." ) transformed_result[output_name] = [ @@ -326,10 +370,10 @@ class CodeNode(Node[CodeNodeData]): f" got {type(result.get(output_name))} instead." ) else: - if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH: + if len(result[output_name]) > self._limits.max_object_array_length: raise OutputValidationError( f"The length of output variable `{prefix}{dot}{output_name}` must be" - f" less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements." + f" less than {self._limits.max_object_array_length} elements." ) for i, value in enumerate(result[output_name]): diff --git a/api/core/workflow/nodes/code/limits.py b/api/core/workflow/nodes/code/limits.py new file mode 100644 index 0000000000..a6b9e9e68e --- /dev/null +++ b/api/core/workflow/nodes/code/limits.py @@ -0,0 +1,13 @@ +from dataclasses import dataclass + + +@dataclass(frozen=True) +class CodeNodeLimits: + max_string_length: int + max_number: int | float + min_number: int | float + max_precision: int + max_depth: int + max_number_array_length: int + max_string_array_length: int + max_object_array_length: int diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 04e2802191..dfb55dcd80 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import io import json @@ -113,7 +115,7 @@ class LLMNode(Node[LLMNodeData]): # Instance attributes specific to LLMNode. # Output variable for file - _file_outputs: list["File"] + _file_outputs: list[File] _llm_file_saver: LLMFileSaver @@ -121,8 +123,8 @@ class LLMNode(Node[LLMNodeData]): self, id: str, config: Mapping[str, Any], - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, *, llm_file_saver: LLMFileSaver | None = None, ): @@ -361,7 +363,7 @@ class LLMNode(Node[LLMNodeData]): structured_output_enabled: bool, structured_output: Mapping[str, Any] | None = None, file_saver: LLMFileSaver, - file_outputs: list["File"], + file_outputs: list[File], node_id: str, node_type: NodeType, reasoning_format: Literal["separated", "tagged"] = "tagged", @@ -415,7 +417,7 @@ class LLMNode(Node[LLMNodeData]): *, invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None], file_saver: LLMFileSaver, - file_outputs: list["File"], + file_outputs: list[File], node_id: str, node_type: NodeType, reasoning_format: Literal["separated", "tagged"] = "tagged", @@ -525,7 +527,7 @@ class LLMNode(Node[LLMNodeData]): ) @staticmethod - def _image_file_to_markdown(file: "File", /): + def _image_file_to_markdown(file: File, /): text_chunk = f"![]({file.generate_url()})" return text_chunk @@ -774,7 +776,7 @@ class LLMNode(Node[LLMNodeData]): def fetch_prompt_messages( *, sys_query: str | None = None, - sys_files: Sequence["File"], + sys_files: Sequence[File], context: str | None = None, memory: TokenBufferMemory | None = None, model_config: ModelConfigWithCredentialsEntity, @@ -785,7 +787,7 @@ class LLMNode(Node[LLMNodeData]): variable_pool: VariablePool, jinja2_variables: Sequence[VariableSelector], tenant_id: str, - context_files: list["File"] | None = None, + context_files: list[File] | None = None, ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: prompt_messages: list[PromptMessage] = [] @@ -1137,7 +1139,7 @@ class LLMNode(Node[LLMNodeData]): *, invoke_result: LLMResult | LLMResultWithStructuredOutput, saver: LLMFileSaver, - file_outputs: list["File"], + file_outputs: list[File], reasoning_format: Literal["separated", "tagged"] = "tagged", request_latency: float | None = None, ) -> ModelInvokeCompletedEvent: @@ -1179,7 +1181,7 @@ class LLMNode(Node[LLMNodeData]): *, content: ImagePromptMessageContent, file_saver: LLMFileSaver, - ) -> "File": + ) -> File: """_save_multimodal_output saves multi-modal contents generated by LLM plugins. There are two kinds of multimodal outputs: @@ -1229,7 +1231,7 @@ class LLMNode(Node[LLMNodeData]): *, contents: str | list[PromptMessageContentUnionTypes] | None, file_saver: LLMFileSaver, - file_outputs: list["File"], + file_outputs: list[File], ) -> Generator[str, None, None]: """Convert intermediate prompt messages into strings and yield them to the caller. diff --git a/api/core/workflow/nodes/node_factory.py b/api/core/workflow/nodes/node_factory.py index c55ad346bf..557d3a330d 100644 --- a/api/core/workflow/nodes/node_factory.py +++ b/api/core/workflow/nodes/node_factory.py @@ -1,10 +1,21 @@ +from collections.abc import Sequence from typing import TYPE_CHECKING, final from typing_extensions import override +from configs import dify_config +from core.helper.code_executor.code_executor import CodeExecutor +from core.helper.code_executor.code_node_provider import CodeNodeProvider from core.workflow.enums import NodeType from core.workflow.graph import NodeFactory from core.workflow.nodes.base.node import Node +from core.workflow.nodes.code.code_node import CodeNode +from core.workflow.nodes.code.limits import CodeNodeLimits +from core.workflow.nodes.template_transform.template_renderer import ( + CodeExecutorJinja2TemplateRenderer, + Jinja2TemplateRenderer, +) +from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode from libs.typing import is_str, is_str_dict from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING @@ -27,9 +38,29 @@ class DifyNodeFactory(NodeFactory): self, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", + *, + code_executor: type[CodeExecutor] | None = None, + code_providers: Sequence[type[CodeNodeProvider]] | None = None, + code_limits: CodeNodeLimits | None = None, + template_renderer: Jinja2TemplateRenderer | None = None, ) -> None: self.graph_init_params = graph_init_params self.graph_runtime_state = graph_runtime_state + self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor + self._code_providers: tuple[type[CodeNodeProvider], ...] = ( + tuple(code_providers) if code_providers else CodeNode.default_code_providers() + ) + self._code_limits = code_limits or CodeNodeLimits( + max_string_length=dify_config.CODE_MAX_STRING_LENGTH, + max_number=dify_config.CODE_MAX_NUMBER, + min_number=dify_config.CODE_MIN_NUMBER, + max_precision=dify_config.CODE_MAX_PRECISION, + max_depth=dify_config.CODE_MAX_DEPTH, + max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH, + max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH, + max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH, + ) + self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer() @override def create_node(self, node_config: dict[str, object]) -> Node: @@ -72,6 +103,25 @@ class DifyNodeFactory(NodeFactory): raise ValueError(f"No latest version class found for node type: {node_type}") # Create node instance + if node_type == NodeType.CODE: + return CodeNode( + id=node_id, + config=node_config, + graph_init_params=self.graph_init_params, + graph_runtime_state=self.graph_runtime_state, + code_executor=self._code_executor, + code_providers=self._code_providers, + code_limits=self._code_limits, + ) + if node_type == NodeType.TEMPLATE_TRANSFORM: + return TemplateTransformNode( + id=node_id, + config=node_config, + graph_init_params=self.graph_init_params, + graph_runtime_state=self.graph_runtime_state, + template_renderer=self._template_renderer, + ) + return node_class( id=node_id, config=node_config, diff --git a/api/core/workflow/nodes/template_transform/template_renderer.py b/api/core/workflow/nodes/template_transform/template_renderer.py new file mode 100644 index 0000000000..a5f06bf2bb --- /dev/null +++ b/api/core/workflow/nodes/template_transform/template_renderer.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any, Protocol + +from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage + + +class TemplateRenderError(ValueError): + """Raised when rendering a Jinja2 template fails.""" + + +class Jinja2TemplateRenderer(Protocol): + """Render Jinja2 templates for template transform nodes.""" + + def render_template(self, template: str, variables: Mapping[str, Any]) -> str: + """Render a Jinja2 template with provided variables.""" + raise NotImplementedError + + +class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer): + """Adapter that renders Jinja2 templates via CodeExecutor.""" + + _code_executor: type[CodeExecutor] + + def __init__(self, code_executor: type[CodeExecutor] | None = None) -> None: + self._code_executor = code_executor or CodeExecutor + + def render_template(self, template: str, variables: Mapping[str, Any]) -> str: + try: + result = self._code_executor.execute_workflow_code_template( + language=CodeLanguage.JINJA2, code=template, inputs=variables + ) + except CodeExecutionError as exc: + raise TemplateRenderError(str(exc)) from exc + + rendered = result.get("result") + if not isinstance(rendered, str): + raise TemplateRenderError("Template render result must be a string.") + return rendered diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index 2274323960..f7e0bccccf 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -1,18 +1,44 @@ from collections.abc import Mapping, Sequence -from typing import Any +from typing import TYPE_CHECKING, Any from configs import dify_config -from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base.node import Node from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData +from core.workflow.nodes.template_transform.template_renderer import ( + CodeExecutorJinja2TemplateRenderer, + Jinja2TemplateRenderer, + TemplateRenderError, +) + +if TYPE_CHECKING: + from core.workflow.entities import GraphInitParams + from core.workflow.runtime import GraphRuntimeState MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH class TemplateTransformNode(Node[TemplateTransformNodeData]): node_type = NodeType.TEMPLATE_TRANSFORM + _template_renderer: Jinja2TemplateRenderer + + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph_runtime_state: "GraphRuntimeState", + *, + template_renderer: Jinja2TemplateRenderer | None = None, + ) -> None: + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer() @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: @@ -39,13 +65,11 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]): variables[variable_name] = value.to_object() if value else None # Run code try: - result = CodeExecutor.execute_workflow_code_template( - language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables - ) - except CodeExecutionError as e: + rendered = self._template_renderer.render_template(self.node_data.template, variables) + except TemplateRenderError as e: return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e)) - if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH: + if len(rendered) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH: return NodeRunResult( inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, @@ -53,7 +77,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]): ) return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": result["result"]} + status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": rendered} ) @classmethod diff --git a/api/core/workflow/nodes/variable_assigner/common/impl.py b/api/core/workflow/nodes/variable_assigner/common/impl.py deleted file mode 100644 index 050e213535..0000000000 --- a/api/core/workflow/nodes/variable_assigner/common/impl.py +++ /dev/null @@ -1,28 +0,0 @@ -from sqlalchemy import select -from sqlalchemy.orm import Session - -from core.variables.variables import Variable -from extensions.ext_database import db -from models import ConversationVariable - -from .exc import VariableOperatorNodeError - - -class ConversationVariableUpdaterImpl: - def update(self, conversation_id: str, variable: Variable): - stmt = select(ConversationVariable).where( - ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id - ) - with Session(db.engine) as session: - row = session.scalar(stmt) - if not row: - raise VariableOperatorNodeError("conversation variable not found in the database") - row.data = variable.model_dump_json() - session.commit() - - def flush(self): - pass - - -def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl: - return ConversationVariableUpdaterImpl() diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index da23207b62..d2ea7d94ea 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -1,9 +1,8 @@ -from collections.abc import Callable, Mapping, Sequence -from typing import TYPE_CHECKING, Any, TypeAlias +from collections.abc import Mapping, Sequence +from typing import TYPE_CHECKING, Any from core.variables import SegmentType, Variable from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID -from core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.entities import GraphInitParams from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult @@ -11,19 +10,14 @@ from core.workflow.nodes.base.node import Node from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from ..common.impl import conversation_variable_updater_factory from .node_data import VariableAssignerData, WriteMode if TYPE_CHECKING: from core.workflow.runtime import GraphRuntimeState -_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater] - - class VariableAssignerNode(Node[VariableAssignerData]): node_type = NodeType.VARIABLE_ASSIGNER - _conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY def __init__( self, @@ -31,7 +25,6 @@ class VariableAssignerNode(Node[VariableAssignerData]): config: Mapping[str, Any], graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", - conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory, ): super().__init__( id=id, @@ -39,7 +32,6 @@ class VariableAssignerNode(Node[VariableAssignerData]): graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - self._conv_var_updater_factory = conv_var_updater_factory @classmethod def version(cls) -> str: @@ -96,16 +88,7 @@ class VariableAssignerNode(Node[VariableAssignerData]): # Over write the variable. self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable) - # TODO: Move database operation to the pipeline. - # Update conversation variable. - conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"]) - if not conversation_id: - raise VariableOperatorNodeError("conversation_id not found") - conv_var_updater = self._conv_var_updater_factory() - conv_var_updater.update(conversation_id=conversation_id.text, variable=updated_variable) - conv_var_updater.flush() updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)] - return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={ diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index 389fb54d35..486e6bb6a7 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -1,24 +1,20 @@ import json from collections.abc import Mapping, MutableMapping, Sequence -from typing import Any, cast +from typing import TYPE_CHECKING, Any -from core.app.entities.app_invoke_entities import InvokeFrom from core.variables import SegmentType, Variable from core.variables.consts import SELECTORS_LENGTH from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID -from core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base.node import Node from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory from . import helpers from .entities import VariableAssignerNodeData, VariableOperationItem from .enums import InputType, Operation from .exc import ( - ConversationIDNotFoundError, InputTypeNotSupportedError, InvalidDataError, InvalidInputValueError, @@ -26,6 +22,10 @@ from .exc import ( VariableNotFoundError, ) +if TYPE_CHECKING: + from core.workflow.entities import GraphInitParams + from core.workflow.runtime import GraphRuntimeState + def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): selector_node_id = item.variable_selector[0] @@ -53,6 +53,20 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_ class VariableAssignerNode(Node[VariableAssignerNodeData]): node_type = NodeType.VARIABLE_ASSIGNER + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph_runtime_state: "GraphRuntimeState", + ): + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: """ Check if this Variable Assigner node blocks the output of specific variables. @@ -70,9 +84,6 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): return False - def _conv_var_updater_factory(self) -> ConversationVariableUpdater: - return conversation_variable_updater_factory() - @classmethod def version(cls) -> str: return "2" @@ -179,26 +190,12 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): # remove the duplicated items first. updated_variable_selectors = list(set(map(tuple, updated_variable_selectors))) - conv_var_updater = self._conv_var_updater_factory() - # Update variables for selector in updated_variable_selectors: variable = self.graph_runtime_state.variable_pool.get(selector) if not isinstance(variable, Variable): raise VariableNotFoundError(variable_selector=selector) process_data[variable.name] = variable.value - if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID: - conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"]) - if not conversation_id: - if self.invoke_from != InvokeFrom.DEBUGGER: - raise ConversationIDNotFoundError - else: - conversation_id = conversation_id.value - conv_var_updater.update( - conversation_id=cast(str, conversation_id), - variable=variable, - ) - conv_var_updater.flush() updated_variables = [ common_helpers.variable_to_processed_data(selector, seg) for selector in updated_variable_selectors diff --git a/api/core/workflow/repositories/draft_variable_repository.py b/api/core/workflow/repositories/draft_variable_repository.py index 97bfcd5666..66ef714c16 100644 --- a/api/core/workflow/repositories/draft_variable_repository.py +++ b/api/core/workflow/repositories/draft_variable_repository.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import abc from collections.abc import Mapping from typing import Any, Protocol @@ -23,7 +25,7 @@ class DraftVariableSaverFactory(Protocol): node_type: NodeType, node_execution_id: str, enclosing_node_id: str | None = None, - ) -> "DraftVariableSaver": + ) -> DraftVariableSaver: pass diff --git a/api/core/workflow/runtime/graph_runtime_state.py b/api/core/workflow/runtime/graph_runtime_state.py index 1561b789df..401cecc162 100644 --- a/api/core/workflow/runtime/graph_runtime_state.py +++ b/api/core/workflow/runtime/graph_runtime_state.py @@ -2,6 +2,7 @@ from __future__ import annotations import importlib import json +import threading from collections.abc import Mapping, Sequence from copy import deepcopy from dataclasses import dataclass @@ -168,6 +169,7 @@ class GraphRuntimeState: self._pending_response_coordinator_dump: str | None = None self._pending_graph_execution_workflow_id: str | None = None self._paused_nodes: set[str] = set() + self.stop_event: threading.Event = threading.Event() if graph is not None: self.attach_graph(graph) diff --git a/api/core/workflow/runtime/graph_runtime_state_protocol.py b/api/core/workflow/runtime/graph_runtime_state_protocol.py index 5e0878e873..bfbb5ba704 100644 --- a/api/core/workflow/runtime/graph_runtime_state_protocol.py +++ b/api/core/workflow/runtime/graph_runtime_state_protocol.py @@ -1,4 +1,4 @@ -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from typing import Any, Protocol from core.model_runtime.entities.llm_entities import LLMUsage @@ -9,7 +9,7 @@ from core.workflow.system_variable import SystemVariableReadOnlyView class ReadOnlyVariablePool(Protocol): """Read-only interface for VariablePool.""" - def get(self, node_id: str, variable_key: str) -> Segment | None: + def get(self, selector: Sequence[str], /) -> Segment | None: """Get a variable value (read-only).""" ... diff --git a/api/core/workflow/runtime/read_only_wrappers.py b/api/core/workflow/runtime/read_only_wrappers.py index 8539727fd6..d3e4c60d9b 100644 --- a/api/core/workflow/runtime/read_only_wrappers.py +++ b/api/core/workflow/runtime/read_only_wrappers.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from copy import deepcopy from typing import Any @@ -18,9 +18,9 @@ class ReadOnlyVariablePoolWrapper: def __init__(self, variable_pool: VariablePool) -> None: self._variable_pool = variable_pool - def get(self, node_id: str, variable_key: str) -> Segment | None: + def get(self, selector: Sequence[str], /) -> Segment | None: """Return a copy of a variable value if present.""" - value = self._variable_pool.get([node_id, variable_key]) + value = self._variable_pool.get(selector) return deepcopy(value) if value is not None else None def get_all_by_node(self, node_id: str) -> Mapping[str, object]: diff --git a/api/core/workflow/runtime/variable_pool.py b/api/core/workflow/runtime/variable_pool.py index 7fbaec9e70..85ceb9d59e 100644 --- a/api/core/workflow/runtime/variable_pool.py +++ b/api/core/workflow/runtime/variable_pool.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import re from collections import defaultdict from collections.abc import Mapping, Sequence @@ -267,6 +269,6 @@ class VariablePool(BaseModel): self.add(selector, value) @classmethod - def empty(cls) -> "VariablePool": + def empty(cls) -> VariablePool: """Create an empty variable pool.""" return cls(system_variables=SystemVariable.empty()) diff --git a/api/core/workflow/system_variable.py b/api/core/workflow/system_variable.py index ad925912a4..cda8091771 100644 --- a/api/core/workflow/system_variable.py +++ b/api/core/workflow/system_variable.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections.abc import Mapping, Sequence from types import MappingProxyType from typing import Any @@ -70,7 +72,7 @@ class SystemVariable(BaseModel): return data @classmethod - def empty(cls) -> "SystemVariable": + def empty(cls) -> SystemVariable: return cls() def to_dict(self) -> dict[SystemVariableKey, Any]: @@ -114,7 +116,7 @@ class SystemVariable(BaseModel): d[SystemVariableKey.TIMESTAMP] = self.timestamp return d - def as_view(self) -> "SystemVariableReadOnlyView": + def as_view(self) -> SystemVariableReadOnlyView: return SystemVariableReadOnlyView(self) diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index 5a69eb15ac..c0279f893b 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -3,8 +3,9 @@ set -e # Set UTF-8 encoding to address potential encoding issues in containerized environments -export LANG=${LANG:-en_US.UTF-8} -export LC_ALL=${LC_ALL:-en_US.UTF-8} +# Use C.UTF-8 which is universally available in all containers +export LANG=${LANG:-C.UTF-8} +export LC_ALL=${LC_ALL:-C.UTF-8} export PYTHONIOENCODING=${PYTHONIOENCODING:-utf-8} if [[ "${MIGRATION_ENABLED}" == "true" ]]; then diff --git a/api/extensions/ext_blueprints.py b/api/extensions/ext_blueprints.py index cf994c11df..7d13f0c061 100644 --- a/api/extensions/ext_blueprints.py +++ b/api/extensions/ext_blueprints.py @@ -6,6 +6,7 @@ BASE_CORS_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE, HEAD SERVICE_API_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, "Authorization") AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN) FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN) +EMBED_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE) EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id") @@ -42,10 +43,28 @@ def init_app(app: DifyApp): _apply_cors_once( web_bp, - resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}}, - supports_credentials=True, - allow_headers=list(AUTHENTICATED_HEADERS), - methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + resources={ + # Embedded bot endpoints (unauthenticated, cross-origin safe) + r"^/chat-messages$": { + "origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS, + "supports_credentials": False, + "allow_headers": list(EMBED_HEADERS), + "methods": ["GET", "POST", "OPTIONS"], + }, + r"^/chat-messages/.*": { + "origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS, + "supports_credentials": False, + "allow_headers": list(EMBED_HEADERS), + "methods": ["GET", "POST", "OPTIONS"], + }, + # Default web application endpoints (authenticated) + r"/*": { + "origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS, + "supports_credentials": True, + "allow_headers": list(AUTHENTICATED_HEADERS), + "methods": ["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + }, + }, expose_headers=list(EXPOSED_HEADERS), ) app.register_blueprint(web_bp) diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py index 71a63168a5..daa3756dba 100644 --- a/api/extensions/ext_commands.py +++ b/api/extensions/ext_commands.py @@ -11,6 +11,7 @@ def init_app(app: DifyApp): create_tenant, extract_plugins, extract_unique_plugins, + file_usage, fix_app_site_missing, install_plugins, install_rag_pipeline_plugins, @@ -47,6 +48,7 @@ def init_app(app: DifyApp): clear_free_plan_tenant_expired_logs, clear_orphaned_file_records, remove_orphaned_files_on_storage, + file_usage, setup_system_tool_oauth_client, setup_system_trigger_oauth_client, cleanup_orphaned_draft_variables, diff --git a/api/extensions/ext_database.py b/api/extensions/ext_database.py index c90b1d0a9f..2e0d4c889a 100644 --- a/api/extensions/ext_database.py +++ b/api/extensions/ext_database.py @@ -53,3 +53,10 @@ def _setup_gevent_compatibility(): def init_app(app: DifyApp): db.init_app(app) _setup_gevent_compatibility() + + # Eagerly build the engine so pool_size/max_overflow/etc. come from config + try: + with app.app_context(): + _ = db.engine # triggers engine creation with the configured options + except Exception: + logger.exception("Failed to initialize SQLAlchemy engine during app startup") diff --git a/api/extensions/logstore/aliyun_logstore.py b/api/extensions/logstore/aliyun_logstore.py index 22d1f473a3..8c64a25be4 100644 --- a/api/extensions/logstore/aliyun_logstore.py +++ b/api/extensions/logstore/aliyun_logstore.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import os import threading @@ -33,7 +35,7 @@ class AliyunLogStore: Ensures only one instance exists to prevent multiple PG connection pools. """ - _instance: "AliyunLogStore | None" = None + _instance: AliyunLogStore | None = None _initialized: bool = False # Track delayed PG connection for newly created projects @@ -66,7 +68,7 @@ class AliyunLogStore: "\t", ] - def __new__(cls) -> "AliyunLogStore": + def __new__(cls) -> AliyunLogStore: """Implement singleton pattern.""" if cls._instance is None: cls._instance = super().__new__(cls) diff --git a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py index 6e6631cfef..1119534d52 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py @@ -22,6 +22,18 @@ from models.enums import WorkflowRunTriggeredFrom logger = logging.getLogger(__name__) +def to_serializable(obj): + """ + Convert non-JSON-serializable objects into JSON-compatible formats. + + - Uses `to_dict()` if it's a callable method. + - Falls back to string representation. + """ + if hasattr(obj, "to_dict") and callable(obj.to_dict): + return obj.to_dict() + return str(obj) + + class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository): def __init__( self, @@ -69,6 +81,11 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository): # Set to True to enable dual-write for safe migration, False to use LogStore only self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true" + # Control flag for whether to write the `graph` field to LogStore. + # If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field; + # otherwise write an empty {} instead. Defaults to writing the `graph` field. + self._enable_put_graph_field = os.environ.get("LOGSTORE_ENABLE_PUT_GRAPH_FIELD", "true").lower() == "true" + def _to_logstore_model(self, domain_model: WorkflowExecution) -> list[tuple[str, str]]: """ Convert a domain model to a logstore model (List[Tuple[str, str]]). @@ -108,9 +125,24 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository): ), ("type", domain_model.workflow_type.value), ("version", domain_model.workflow_version), - ("graph", json.dumps(domain_model.graph, ensure_ascii=False) if domain_model.graph else "{}"), - ("inputs", json.dumps(domain_model.inputs, ensure_ascii=False) if domain_model.inputs else "{}"), - ("outputs", json.dumps(domain_model.outputs, ensure_ascii=False) if domain_model.outputs else "{}"), + ( + "graph", + json.dumps(domain_model.graph, ensure_ascii=False, default=to_serializable) + if domain_model.graph and self._enable_put_graph_field + else "{}", + ), + ( + "inputs", + json.dumps(domain_model.inputs, ensure_ascii=False, default=to_serializable) + if domain_model.inputs + else "{}", + ), + ( + "outputs", + json.dumps(domain_model.outputs, ensure_ascii=False, default=to_serializable) + if domain_model.outputs + else "{}", + ), ("status", domain_model.status.value), ("error_message", domain_model.error_message or ""), ("total_tokens", str(domain_model.total_tokens)), diff --git a/api/extensions/storage/clickzetta_volume/file_lifecycle.py b/api/extensions/storage/clickzetta_volume/file_lifecycle.py index 51a97b20f8..1d9911465b 100644 --- a/api/extensions/storage/clickzetta_volume/file_lifecycle.py +++ b/api/extensions/storage/clickzetta_volume/file_lifecycle.py @@ -5,6 +5,8 @@ automatic cleanup, backup and restore. Supports complete lifecycle management for knowledge base files. """ +from __future__ import annotations + import json import logging import operator @@ -48,7 +50,7 @@ class FileMetadata: return data @classmethod - def from_dict(cls, data: dict) -> "FileMetadata": + def from_dict(cls, data: dict) -> FileMetadata: """Create instance from dictionary""" data = data.copy() data["created_at"] = datetime.fromisoformat(data["created_at"]) diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index 70138404c7..913fb675f9 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -1,93 +1,85 @@ -from flask_restx import Namespace, fields +from __future__ import annotations -from libs.helper import TimestampField +from datetime import datetime -upload_config_fields = { - "file_size_limit": fields.Integer, - "batch_count_limit": fields.Integer, - "image_file_size_limit": fields.Integer, - "video_file_size_limit": fields.Integer, - "audio_file_size_limit": fields.Integer, - "workflow_file_upload_limit": fields.Integer, - "image_file_batch_limit": fields.Integer, - "single_chunk_attachment_limit": fields.Integer, -} +from pydantic import BaseModel, ConfigDict, field_validator -def build_upload_config_model(api_or_ns: Namespace): - """Build the upload config model for the API or Namespace. - - Args: - api_or_ns: Flask-RestX Api or Namespace instance - - Returns: - The registered model - """ - return api_or_ns.model("UploadConfig", upload_config_fields) +class ResponseModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + extra="ignore", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) -file_fields = { - "id": fields.String, - "name": fields.String, - "size": fields.Integer, - "extension": fields.String, - "mime_type": fields.String, - "created_by": fields.String, - "created_at": TimestampField, - "preview_url": fields.String, - "source_url": fields.String, -} +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -def build_file_model(api_or_ns: Namespace): - """Build the file model for the API or Namespace. - - Args: - api_or_ns: Flask-RestX Api or Namespace instance - - Returns: - The registered model - """ - return api_or_ns.model("File", file_fields) +class UploadConfig(ResponseModel): + file_size_limit: int + batch_count_limit: int + file_upload_limit: int | None = None + image_file_size_limit: int + video_file_size_limit: int + audio_file_size_limit: int + workflow_file_upload_limit: int + image_file_batch_limit: int + single_chunk_attachment_limit: int + attachment_image_file_size_limit: int | None = None -remote_file_info_fields = { - "file_type": fields.String(attribute="file_type"), - "file_length": fields.Integer(attribute="file_length"), -} +class FileResponse(ResponseModel): + id: str + name: str + size: int + extension: str | None = None + mime_type: str | None = None + created_by: str | None = None + created_at: int | None = None + preview_url: str | None = None + source_url: str | None = None + original_url: str | None = None + user_id: str | None = None + tenant_id: str | None = None + conversation_id: str | None = None + file_key: str | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) -def build_remote_file_info_model(api_or_ns: Namespace): - """Build the remote file info model for the API or Namespace. - - Args: - api_or_ns: Flask-RestX Api or Namespace instance - - Returns: - The registered model - """ - return api_or_ns.model("RemoteFileInfo", remote_file_info_fields) +class RemoteFileInfo(ResponseModel): + file_type: str + file_length: int -file_fields_with_signed_url = { - "id": fields.String, - "name": fields.String, - "size": fields.Integer, - "extension": fields.String, - "url": fields.String, - "mime_type": fields.String, - "created_by": fields.String, - "created_at": TimestampField, -} +class FileWithSignedUrl(ResponseModel): + id: str + name: str + size: int + extension: str | None = None + url: str | None = None + mime_type: str | None = None + created_by: str | None = None + created_at: int | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) -def build_file_with_signed_url_model(api_or_ns: Namespace): - """Build the file with signed URL model for the API or Namespace. - - Args: - api_or_ns: Flask-RestX Api or Namespace instance - - Returns: - The registered model - """ - return api_or_ns.model("FileWithSignedUrl", file_fields_with_signed_url) +__all__ = [ + "FileResponse", + "FileWithSignedUrl", + "RemoteFileInfo", + "UploadConfig", +] diff --git a/api/fields/rag_pipeline_fields.py b/api/fields/rag_pipeline_fields.py index f9e858c68b..97c02e7085 100644 --- a/api/fields/rag_pipeline_fields.py +++ b/api/fields/rag_pipeline_fields.py @@ -1,4 +1,4 @@ -from flask_restx import fields # type: ignore +from flask_restx import fields from fields.workflow_fields import workflow_partial_fields from libs.helper import AppIconUrlField, TimestampField diff --git a/api/libs/broadcast_channel/channel.py b/api/libs/broadcast_channel/channel.py index 5bbf0c79a3..d4cb3e9971 100644 --- a/api/libs/broadcast_channel/channel.py +++ b/api/libs/broadcast_channel/channel.py @@ -2,6 +2,8 @@ Broadcast channel for Pub/Sub messaging. """ +from __future__ import annotations + import types from abc import abstractmethod from collections.abc import Iterator @@ -129,6 +131,6 @@ class BroadcastChannel(Protocol): """ @abstractmethod - def topic(self, topic: str) -> "Topic": + def topic(self, topic: str) -> Topic: """topic returns a `Topic` instance for the given topic name.""" ... diff --git a/api/libs/broadcast_channel/redis/channel.py b/api/libs/broadcast_channel/redis/channel.py index 1fc3db8156..5bb4f579c1 100644 --- a/api/libs/broadcast_channel/redis/channel.py +++ b/api/libs/broadcast_channel/redis/channel.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from libs.broadcast_channel.channel import Producer, Subscriber, Subscription from redis import Redis @@ -20,7 +22,7 @@ class BroadcastChannel: ): self._client = redis_client - def topic(self, topic: str) -> "Topic": + def topic(self, topic: str) -> Topic: return Topic(self._client, topic) diff --git a/api/libs/broadcast_channel/redis/sharded_channel.py b/api/libs/broadcast_channel/redis/sharded_channel.py index 16e3a80ee1..d190c51bbc 100644 --- a/api/libs/broadcast_channel/redis/sharded_channel.py +++ b/api/libs/broadcast_channel/redis/sharded_channel.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from libs.broadcast_channel.channel import Producer, Subscriber, Subscription from redis import Redis @@ -18,7 +20,7 @@ class ShardedRedisBroadcastChannel: ): self._client = redis_client - def topic(self, topic: str) -> "ShardedTopic": + def topic(self, topic: str) -> ShardedTopic: return ShardedTopic(self._client, topic) diff --git a/api/libs/email_i18n.py b/api/libs/email_i18n.py index ff74ccbe8e..0828cf80bf 100644 --- a/api/libs/email_i18n.py +++ b/api/libs/email_i18n.py @@ -6,6 +6,8 @@ in Dify. It follows Domain-Driven Design principles with proper type hints and eliminates the need for repetitive language switching logic. """ +from __future__ import annotations + from dataclasses import dataclass from enum import StrEnum, auto from typing import Any, Protocol @@ -53,7 +55,7 @@ class EmailLanguage(StrEnum): ZH_HANS = "zh-Hans" @classmethod - def from_language_code(cls, language_code: str) -> "EmailLanguage": + def from_language_code(cls, language_code: str) -> EmailLanguage: """Convert a language code to EmailLanguage with fallback to English.""" if language_code == "zh-Hans": return cls.ZH_HANS diff --git a/api/libs/helper.py b/api/libs/helper.py index 74e1808e49..07c4823727 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -32,6 +32,38 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def escape_like_pattern(pattern: str) -> str: + """ + Escape special characters in a string for safe use in SQL LIKE patterns. + + This function escapes the special characters used in SQL LIKE patterns: + - Backslash (\\) -> \\ + - Percent (%) -> \\% + - Underscore (_) -> \\_ + + The escaped pattern can then be safely used in SQL LIKE queries with the + ESCAPE '\\' clause to prevent SQL injection via LIKE wildcards. + + Args: + pattern: The string pattern to escape + + Returns: + Escaped string safe for use in SQL LIKE queries + + Examples: + >>> escape_like_pattern("50% discount") + '50\\% discount' + >>> escape_like_pattern("test_data") + 'test\\_data' + >>> escape_like_pattern("path\\to\\file") + 'path\\\\to\\\\file' + """ + if not pattern: + return pattern + # Escape backslash first, then percent and underscore + return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") + + def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None: """ Extract tenant_id from Account or EndUser object. diff --git a/api/models/model.py b/api/models/model.py index 88cb945b3f..46df047237 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import re import uuid @@ -5,7 +7,7 @@ from collections.abc import Mapping from datetime import datetime from decimal import Decimal from enum import StrEnum, auto -from typing import TYPE_CHECKING, Any, Literal, Optional, cast +from typing import TYPE_CHECKING, Any, Literal, cast from uuid import uuid4 import sqlalchemy as sa @@ -54,7 +56,7 @@ class AppMode(StrEnum): RAG_PIPELINE = "rag-pipeline" @classmethod - def value_of(cls, value: str) -> "AppMode": + def value_of(cls, value: str) -> AppMode: """ Get value of given mode. @@ -70,6 +72,7 @@ class AppMode(StrEnum): class IconType(StrEnum): IMAGE = auto() EMOJI = auto() + LINK = auto() class App(Base): @@ -81,7 +84,7 @@ class App(Base): name: Mapped[str] = mapped_column(String(255)) description: Mapped[str] = mapped_column(LongText, default=sa.text("''")) mode: Mapped[str] = mapped_column(String(255)) - icon_type: Mapped[str | None] = mapped_column(String(255)) # image, emoji + icon_type: Mapped[str | None] = mapped_column(String(255)) # image, emoji, link icon = mapped_column(String(255)) icon_background: Mapped[str | None] = mapped_column(String(255)) app_model_config_id = mapped_column(StringUUID, nullable=True) @@ -120,19 +123,19 @@ class App(Base): return "" @property - def site(self) -> Optional["Site"]: + def site(self) -> Site | None: site = db.session.query(Site).where(Site.app_id == self.id).first() return site @property - def app_model_config(self) -> Optional["AppModelConfig"]: + def app_model_config(self) -> AppModelConfig | None: if self.app_model_config_id: return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first() return None @property - def workflow(self) -> Optional["Workflow"]: + def workflow(self) -> Workflow | None: if self.workflow_id: from .workflow import Workflow @@ -287,7 +290,7 @@ class App(Base): return deleted_tools @property - def tags(self) -> list["Tag"]: + def tags(self) -> list[Tag]: tags = ( db.session.query(Tag) .join(TagBinding, Tag.id == TagBinding.tag_id) @@ -1193,7 +1196,7 @@ class Message(Base): return json.loads(self.message_metadata) if self.message_metadata else {} @property - def agent_thoughts(self) -> list["MessageAgentThought"]: + def agent_thoughts(self) -> list[MessageAgentThought]: return ( db.session.query(MessageAgentThought) .where(MessageAgentThought.message_id == self.id) @@ -1306,7 +1309,7 @@ class Message(Base): } @classmethod - def from_dict(cls, data: dict[str, Any]) -> "Message": + def from_dict(cls, data: dict[str, Any]) -> Message: return cls( id=data["id"], app_id=data["app_id"], @@ -1419,15 +1422,20 @@ class MessageAnnotation(Base): app_id: Mapped[str] = mapped_column(StringUUID) conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id")) message_id: Mapped[str | None] = mapped_column(StringUUID) - question = mapped_column(LongText, nullable=True) - content = mapped_column(LongText, nullable=False) + question: Mapped[str | None] = mapped_column(LongText, nullable=True) + content: Mapped[str] = mapped_column(LongText, nullable=False) hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) - account_id = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column( + account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column( sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) + @property + def question_text(self) -> str: + """Return a non-null question string, falling back to the answer content.""" + return self.question or self.content + @property def account(self): account = db.session.query(Account).where(Account.id == self.account_id).first() @@ -1528,7 +1536,7 @@ class OperationLog(TypeBase): tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) action: Mapped[str] = mapped_column(String(255), nullable=False) - content: Mapped[Any] = mapped_column(sa.JSON) + content: Mapped[Any | None] = mapped_column(sa.JSON, nullable=True) created_at: Mapped[datetime] = mapped_column( sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) diff --git a/api/models/provider.py b/api/models/provider.py index 2afd8c5329..441b54c797 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime from enum import StrEnum, auto from functools import cached_property @@ -19,7 +21,7 @@ class ProviderType(StrEnum): SYSTEM = auto() @staticmethod - def value_of(value: str) -> "ProviderType": + def value_of(value: str) -> ProviderType: for member in ProviderType: if member.value == value: return member @@ -37,7 +39,7 @@ class ProviderQuotaType(StrEnum): """hosted trial quota""" @staticmethod - def value_of(value: str) -> "ProviderQuotaType": + def value_of(value: str) -> ProviderQuotaType: for member in ProviderQuotaType: if member.value == value: return member @@ -76,7 +78,7 @@ class Provider(TypeBase): quota_type: Mapped[str | None] = mapped_column(String(40), nullable=True, server_default=text("''"), default="") quota_limit: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True, default=None) - quota_used: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, default=0) + quota_used: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True, default=0) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), init=False diff --git a/api/models/tools.py b/api/models/tools.py index e4f9bcb582..e7b98dcf27 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json from datetime import datetime from decimal import Decimal @@ -167,11 +169,11 @@ class ApiToolProvider(TypeBase): ) @property - def schema_type(self) -> "ApiProviderSchemaType": + def schema_type(self) -> ApiProviderSchemaType: return ApiProviderSchemaType.value_of(self.schema_type_str) @property - def tools(self) -> list["ApiToolBundle"]: + def tools(self) -> list[ApiToolBundle]: return [ApiToolBundle.model_validate(tool) for tool in json.loads(self.tools_str)] @property @@ -267,7 +269,7 @@ class WorkflowToolProvider(TypeBase): return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() @property - def parameter_configurations(self) -> list["WorkflowToolParameterConfiguration"]: + def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]: return [ WorkflowToolParameterConfiguration.model_validate(config) for config in json.loads(self.parameter_configuration) @@ -359,7 +361,7 @@ class MCPToolProvider(TypeBase): except (json.JSONDecodeError, TypeError): return [] - def to_entity(self) -> "MCPProviderEntity": + def to_entity(self) -> MCPProviderEntity: """Convert to domain entity""" from core.entities.mcp_provider import MCPProviderEntity @@ -533,5 +535,5 @@ class DeprecatedPublishedAppTool(TypeBase): ) @property - def description_i18n(self) -> "I18nObject": + def description_i18n(self) -> I18nObject: return I18nObject.model_validate(json.loads(self.description)) diff --git a/api/models/trigger.py b/api/models/trigger.py index 87e2a5ccfc..209345eb84 100644 --- a/api/models/trigger.py +++ b/api/models/trigger.py @@ -415,7 +415,7 @@ class AppTrigger(TypeBase): node_id: Mapped[str | None] = mapped_column(String(64), nullable=False) trigger_type: Mapped[str] = mapped_column(EnumText(AppTriggerType, length=50), nullable=False) title: Mapped[str] = mapped_column(String(255), nullable=False) - provider_name: Mapped[str] = mapped_column(String(255), server_default="", default="") # why it is nullable? + provider_name: Mapped[str | None] = mapped_column(String(255), nullable=True, server_default="", default="") status: Mapped[str] = mapped_column( EnumText(AppTriggerStatus, length=50), nullable=False, default=AppTriggerStatus.ENABLED ) diff --git a/api/models/workflow.py b/api/models/workflow.py index 853d5afefc..a18939523b 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import json import logging from collections.abc import Generator, Mapping, Sequence from datetime import datetime from enum import StrEnum -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Union, cast from uuid import uuid4 import sqlalchemy as sa @@ -67,7 +69,7 @@ class WorkflowType(StrEnum): RAG_PIPELINE = "rag-pipeline" @classmethod - def value_of(cls, value: str) -> "WorkflowType": + def value_of(cls, value: str) -> WorkflowType: """ Get value of given mode. @@ -80,7 +82,7 @@ class WorkflowType(StrEnum): raise ValueError(f"invalid workflow type value {value}") @classmethod - def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType": + def from_app_mode(cls, app_mode: Union[str, AppMode]) -> WorkflowType: """ Get workflow type from app mode. @@ -181,7 +183,7 @@ class Workflow(Base): # bug rag_pipeline_variables: list[dict], marked_name: str = "", marked_comment: str = "", - ) -> "Workflow": + ) -> Workflow: workflow = Workflow() workflow.id = str(uuid4()) workflow.tenant_id = tenant_id @@ -619,7 +621,7 @@ class WorkflowRun(Base): finished_at: Mapped[datetime | None] = mapped_column(DateTime) exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) - pause: Mapped[Optional["WorkflowPause"]] = orm.relationship( + pause: Mapped[WorkflowPause | None] = orm.relationship( "WorkflowPause", primaryjoin="WorkflowRun.id == foreign(WorkflowPause.workflow_run_id)", uselist=False, @@ -689,7 +691,7 @@ class WorkflowRun(Base): } @classmethod - def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun": + def from_dict(cls, data: dict[str, Any]) -> WorkflowRun: return cls( id=data.get("id"), tenant_id=data.get("tenant_id"), @@ -841,7 +843,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo created_by: Mapped[str] = mapped_column(StringUUID) finished_at: Mapped[datetime | None] = mapped_column(DateTime) - offload_data: Mapped[list["WorkflowNodeExecutionOffload"]] = orm.relationship( + offload_data: Mapped[list[WorkflowNodeExecutionOffload]] = orm.relationship( "WorkflowNodeExecutionOffload", primaryjoin="WorkflowNodeExecutionModel.id == foreign(WorkflowNodeExecutionOffload.node_execution_id)", uselist=True, @@ -851,13 +853,13 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo @staticmethod def preload_offload_data( - query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"], + query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel], ): return query.options(orm.selectinload(WorkflowNodeExecutionModel.offload_data)) @staticmethod def preload_offload_data_and_files( - query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"], + query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel], ): return query.options( orm.selectinload(WorkflowNodeExecutionModel.offload_data).options( @@ -932,7 +934,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo ) return extras - def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]: + def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> WorkflowNodeExecutionOffload | None: return next(iter([i for i in self.offload_data if i.type_ == type_]), None) @property @@ -1046,7 +1048,7 @@ class WorkflowNodeExecutionOffload(Base): back_populates="offload_data", ) - file: Mapped[Optional["UploadFile"]] = orm.relationship( + file: Mapped[UploadFile | None] = orm.relationship( foreign_keys=[file_id], lazy="raise", uselist=False, @@ -1064,7 +1066,7 @@ class WorkflowAppLogCreatedFrom(StrEnum): INSTALLED_APP = "installed-app" @classmethod - def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom": + def value_of(cls, value: str) -> WorkflowAppLogCreatedFrom: """ Get value of given mode. @@ -1181,7 +1183,7 @@ class ConversationVariable(TypeBase): ) @classmethod - def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> "ConversationVariable": + def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> ConversationVariable: obj = cls( id=variable.id, app_id=app_id, @@ -1334,7 +1336,7 @@ class WorkflowDraftVariable(Base): ) # Relationship to WorkflowDraftVariableFile - variable_file: Mapped[Optional["WorkflowDraftVariableFile"]] = orm.relationship( + variable_file: Mapped[WorkflowDraftVariableFile | None] = orm.relationship( foreign_keys=[file_id], lazy="raise", uselist=False, @@ -1504,8 +1506,9 @@ class WorkflowDraftVariable(Base): node_execution_id: str | None, description: str = "", file_id: str | None = None, - ) -> "WorkflowDraftVariable": + ) -> WorkflowDraftVariable: variable = WorkflowDraftVariable() + variable.id = str(uuid4()) variable.created_at = naive_utc_now() variable.updated_at = naive_utc_now() variable.description = description @@ -1526,7 +1529,7 @@ class WorkflowDraftVariable(Base): name: str, value: Segment, description: str = "", - ) -> "WorkflowDraftVariable": + ) -> WorkflowDraftVariable: variable = cls._new( app_id=app_id, node_id=CONVERSATION_VARIABLE_NODE_ID, @@ -1547,7 +1550,7 @@ class WorkflowDraftVariable(Base): value: Segment, node_execution_id: str, editable: bool = False, - ) -> "WorkflowDraftVariable": + ) -> WorkflowDraftVariable: variable = cls._new( app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, @@ -1570,7 +1573,7 @@ class WorkflowDraftVariable(Base): visible: bool = True, editable: bool = True, file_id: str | None = None, - ) -> "WorkflowDraftVariable": + ) -> WorkflowDraftVariable: variable = cls._new( app_id=app_id, node_id=node_id, @@ -1666,7 +1669,7 @@ class WorkflowDraftVariableFile(Base): ) # Relationship to UploadFile - upload_file: Mapped["UploadFile"] = orm.relationship( + upload_file: Mapped[UploadFile] = orm.relationship( foreign_keys=[upload_file_id], lazy="raise", uselist=False, @@ -1733,7 +1736,7 @@ class WorkflowPause(DefaultFieldsMixin, Base): state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False) # Relationship to WorkflowRun - workflow_run: Mapped["WorkflowRun"] = orm.relationship( + workflow_run: Mapped[WorkflowRun] = orm.relationship( foreign_keys=[workflow_run_id], # require explicit preloading. lazy="raise", @@ -1789,7 +1792,7 @@ class WorkflowPauseReason(DefaultFieldsMixin, Base): ) @classmethod - def from_entity(cls, pause_reason: PauseReason) -> "WorkflowPauseReason": + def from_entity(cls, pause_reason: PauseReason) -> WorkflowPauseReason: if isinstance(pause_reason, HumanInputRequired): return cls( type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index d03cbddceb..b73302508a 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -77,7 +77,7 @@ class AppAnnotationService: if annotation_setting: add_annotation_to_index_task.delay( annotation.id, - annotation.question, + question, current_tenant_id, app_id, annotation_setting.collection_binding_id, @@ -137,13 +137,16 @@ class AppAnnotationService: if not app: raise NotFound("App not found") if keyword: + from libs.helper import escape_like_pattern + + escaped_keyword = escape_like_pattern(keyword) stmt = ( select(MessageAnnotation) .where(MessageAnnotation.app_id == app_id) .where( or_( - MessageAnnotation.question.ilike(f"%{keyword}%"), - MessageAnnotation.content.ilike(f"%{keyword}%"), + MessageAnnotation.question.ilike(f"%{escaped_keyword}%", escape="\\"), + MessageAnnotation.content.ilike(f"%{escaped_keyword}%", escape="\\"), ) ) .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()) @@ -253,7 +256,7 @@ class AppAnnotationService: if app_annotation_setting: update_annotation_to_index_task.delay( annotation.id, - annotation.question, + annotation.question_text, current_tenant_id, app_id, app_annotation_setting.collection_binding_id, diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index deba0b79e8..acd2a25a86 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -32,7 +32,7 @@ from extensions.ext_redis import redis_client from factories import variable_factory from libs.datetime_utils import naive_utc_now from models import Account, App, AppMode -from models.model import AppModelConfig +from models.model import AppModelConfig, IconType from models.workflow import Workflow from services.plugin.dependencies_analysis import DependenciesAnalysisService from services.workflow_draft_variable_service import WorkflowDraftVariableService @@ -428,10 +428,10 @@ class AppDslService: # Set icon type icon_type_value = icon_type or app_data.get("icon_type") - if icon_type_value in ["emoji", "link", "image"]: + if icon_type_value in [IconType.EMOJI.value, IconType.IMAGE.value, IconType.LINK.value]: icon_type = icon_type_value else: - icon_type = "emoji" + icon_type = IconType.EMOJI.value icon = icon or str(app_data.get("icon", "")) if app: diff --git a/api/services/app_service.py b/api/services/app_service.py index ef89a4fd10..02ebfbace0 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -55,8 +55,11 @@ class AppService: if args.get("is_created_by_me", False): filters.append(App.created_by == user_id) if args.get("name"): + from libs.helper import escape_like_pattern + name = args["name"][:30] - filters.append(App.name.ilike(f"%{name}%")) + escaped_name = escape_like_pattern(name) + filters.append(App.name.ilike(f"%{escaped_name}%", escape="\\")) # Check if tag_ids is not empty to avoid WHERE false condition if args.get("tag_ids") and len(args["tag_ids"]) > 0: target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, args["tag_ids"]) diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 659e7406fb..295d48d8a1 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -11,13 +11,13 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.db.session_factory import session_factory from core.llm_generator.llm_generator import LLMGenerator from core.variables.types import SegmentType -from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory from extensions.ext_database import db from factories import variable_factory from libs.datetime_utils import naive_utc_now from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account, ConversationVariable from models.model import App, Conversation, EndUser, Message +from services.conversation_variable_updater import ConversationVariableUpdater from services.errors.conversation import ( ConversationNotExistsError, ConversationVariableNotExistsError, @@ -218,7 +218,9 @@ class ConversationService: # Apply variable_name filter if provided if variable_name: # Filter using JSON extraction to match variable names case-insensitively - escaped_variable_name = variable_name.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") + from libs.helper import escape_like_pattern + + escaped_variable_name = escape_like_pattern(variable_name) # Filter using JSON extraction to match variable names case-insensitively if dify_config.DB_TYPE in ["mysql", "oceanbase", "seekdb"]: stmt = stmt.where( @@ -335,7 +337,7 @@ class ConversationService: updated_variable = variable_factory.build_conversation_variable_from_mapping(updated_variable_dict) # Use the conversation variable updater to persist the changes - updater = conversation_variable_updater_factory() + updater = ConversationVariableUpdater(session_factory.get_session_maker()) updater.update(conversation_id, updated_variable) updater.flush() diff --git a/api/services/conversation_variable_updater.py b/api/services/conversation_variable_updater.py new file mode 100644 index 0000000000..acc0ec2b22 --- /dev/null +++ b/api/services/conversation_variable_updater.py @@ -0,0 +1,28 @@ +from sqlalchemy import select +from sqlalchemy.orm import Session, sessionmaker + +from core.variables.variables import Variable +from models import ConversationVariable + + +class ConversationVariableNotFoundError(Exception): + pass + + +class ConversationVariableUpdater: + def __init__(self, session_maker: sessionmaker[Session]) -> None: + self._session_maker: sessionmaker[Session] = session_maker + + def update(self, conversation_id: str, variable: Variable) -> None: + stmt = select(ConversationVariable).where( + ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id + ) + with self._session_maker() as session: + row = session.scalar(stmt) + if not row: + raise ConversationVariableNotFoundError("conversation variable not found in the database") + row.data = variable.model_dump_json() + session.commit() + + def flush(self) -> None: + pass diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index ac4b25c5dc..18e5613438 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -144,7 +144,8 @@ class DatasetService: query = query.where(Dataset.permission == DatasetPermissionEnum.ALL_TEAM) if search: - query = query.where(Dataset.name.ilike(f"%{search}%")) + escaped_search = helper.escape_like_pattern(search) + query = query.where(Dataset.name.ilike(f"%{escaped_search}%", escape="\\")) # Check if tag_ids is not empty to avoid WHERE false condition if tag_ids and len(tag_ids) > 0: @@ -3423,7 +3424,8 @@ class SegmentService: .order_by(ChildChunk.position.asc()) ) if keyword: - query = query.where(ChildChunk.content.ilike(f"%{keyword}%")) + escaped_keyword = helper.escape_like_pattern(keyword) + query = query.where(ChildChunk.content.ilike(f"%{escaped_keyword}%", escape="\\")) return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) @classmethod @@ -3456,7 +3458,8 @@ class SegmentService: query = query.where(DocumentSegment.status.in_(status_list)) if keyword: - query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) + escaped_keyword = helper.escape_like_pattern(keyword) + query = query.where(DocumentSegment.content.ilike(f"%{escaped_keyword}%", escape="\\")) query = query.order_by(DocumentSegment.position.asc(), DocumentSegment.id.asc()) paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 40faa85b9a..65dd41af43 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -35,7 +35,10 @@ class ExternalDatasetService: .order_by(ExternalKnowledgeApis.created_at.desc()) ) if search: - query = query.where(ExternalKnowledgeApis.name.ilike(f"%{search}%")) + from libs.helper import escape_like_pattern + + escaped_search = escape_like_pattern(search) + query = query.where(ExternalKnowledgeApis.name.ilike(f"%{escaped_search}%", escape="\\")) external_knowledge_apis = db.paginate( select=query, page=page, per_page=per_page, max_per_page=100, error_out=False diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index f53448e7fe..1ba64813ba 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -874,7 +874,7 @@ class RagPipelineService: variable_pool = node_instance.graph_runtime_state.variable_pool invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) if invoke_from: - if invoke_from.value == InvokeFrom.PUBLISHED: + if invoke_from.value == InvokeFrom.PUBLISHED_PIPELINE: document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) if document_id: document = db.session.query(Document).where(Document.id == document_id.value).first() @@ -1318,7 +1318,7 @@ class RagPipelineService: "datasource_info_list": [json.loads(document_pipeline_execution_log.datasource_info)], "original_document_id": document.id, }, - invoke_from=InvokeFrom.PUBLISHED, + invoke_from=InvokeFrom.PUBLISHED_PIPELINE, streaming=False, call_depth=0, workflow_thread_pool_id=None, diff --git a/api/services/tag_service.py b/api/services/tag_service.py index 937e6593fe..bd3585acf4 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -19,7 +19,10 @@ class TagService: .where(Tag.type == tag_type, Tag.tenant_id == current_tenant_id) ) if keyword: - query = query.where(sa.and_(Tag.name.ilike(f"%{keyword}%"))) + from libs.helper import escape_like_pattern + + escaped_keyword = escape_like_pattern(keyword) + query = query.where(sa.and_(Tag.name.ilike(f"%{escaped_keyword}%", escape="\\"))) query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at) results: list = query.order_by(Tag.created_at.desc()).all() return results diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py index ef77c33c1b..4131d75145 100644 --- a/api/services/trigger/trigger_provider_service.py +++ b/api/services/trigger/trigger_provider_service.py @@ -853,7 +853,7 @@ class TriggerProviderService: """ Create a subscription builder for rebuilding an existing subscription. - This method creates a builder pre-filled with data from the rebuild request, + This method rebuild the subscription by call DELETE and CREATE API of the third party provider(e.g. GitHub) keeping the same subscription_id and endpoint_id so the webhook URL remains unchanged. :param tenant_id: Tenant ID @@ -868,111 +868,50 @@ class TriggerProviderService: if not provider_controller: raise ValueError(f"Provider {provider_id} not found") - # Use distributed lock to prevent race conditions on the same subscription - lock_key = f"trigger_subscription_rebuild_lock:{tenant_id}_{subscription_id}" - with redis_client.lock(lock_key, timeout=20): - with Session(db.engine, expire_on_commit=False) as session: - try: - # Get subscription within the transaction - subscription: TriggerSubscription | None = ( - session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() - ) - if not subscription: - raise ValueError(f"Subscription {subscription_id} not found") + subscription = TriggerProviderService.get_subscription_by_id( + tenant_id=tenant_id, + subscription_id=subscription_id, + ) + if not subscription: + raise ValueError(f"Subscription {subscription_id} not found") - credential_type = CredentialType.of(subscription.credential_type) - if credential_type not in [CredentialType.OAUTH2, CredentialType.API_KEY]: - raise ValueError("Credential type not supported for rebuild") + credential_type = CredentialType.of(subscription.credential_type) + if credential_type not in {CredentialType.OAUTH2, CredentialType.API_KEY}: + raise ValueError(f"Credential type {credential_type} not supported for auto creation") - # Decrypt existing credentials for merging - credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription( - tenant_id=tenant_id, - controller=provider_controller, - subscription=subscription, - ) - decrypted_credentials = dict(credential_encrypter.decrypt(subscription.credentials)) + # Delete the previous subscription + user_id = subscription.user_id + unsubscribe_result = TriggerManager.unsubscribe_trigger( + tenant_id=tenant_id, + user_id=user_id, + provider_id=provider_id, + subscription=subscription.to_entity(), + credentials=subscription.credentials, + credential_type=credential_type, + ) + if not unsubscribe_result.success: + raise ValueError(f"Failed to delete previous subscription: {unsubscribe_result.message}") - # Merge credentials: if caller passed HIDDEN_VALUE, retain existing decrypted value - merged_credentials: dict[str, Any] = { - key: value if value != HIDDEN_VALUE else decrypted_credentials.get(key, UNKNOWN_VALUE) - for key, value in credentials.items() - } - - user_id = subscription.user_id - - # TODO: Trying to invoke update api of the plugin trigger provider - - # FALLBACK: If the update api is not implemented, - # delete the previous subscription and create a new one - - # Unsubscribe the previous subscription (external call, but we'll handle errors) - try: - TriggerManager.unsubscribe_trigger( - tenant_id=tenant_id, - user_id=user_id, - provider_id=provider_id, - subscription=subscription.to_entity(), - credentials=decrypted_credentials, - credential_type=credential_type, - ) - except Exception as e: - logger.exception("Error unsubscribing trigger during rebuild", exc_info=e) - # Continue anyway - the subscription might already be deleted externally - - # Create a new subscription with the same subscription_id and endpoint_id (external call) - new_subscription: TriggerSubscriptionEntity = TriggerManager.subscribe_trigger( - tenant_id=tenant_id, - user_id=user_id, - provider_id=provider_id, - endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id), - parameters=parameters, - credentials=merged_credentials, - credential_type=credential_type, - ) - - # Update the subscription in the same transaction - # Inline update logic to reuse the same session - if name is not None and name != subscription.name: - existing = ( - session.query(TriggerSubscription) - .filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name) - .first() - ) - if existing and existing.id != subscription.id: - raise ValueError(f"Subscription name '{name}' already exists for this provider") - subscription.name = name - - # Update parameters - subscription.parameters = dict(parameters) - - # Update credentials with merged (and encrypted) values - subscription.credentials = dict(credential_encrypter.encrypt(merged_credentials)) - - # Update properties - if new_subscription.properties: - properties_encrypter, _ = create_provider_encrypter( - tenant_id=tenant_id, - config=provider_controller.get_properties_schema(), - cache=NoOpProviderCredentialCache(), - ) - subscription.properties = dict(properties_encrypter.encrypt(dict(new_subscription.properties))) - - # Update expiration timestamp - if new_subscription.expires_at is not None: - subscription.expires_at = new_subscription.expires_at - - # Commit the transaction - session.commit() - - # Clear subscription cache - delete_cache_for_subscription( - tenant_id=tenant_id, - provider_id=subscription.provider_id, - subscription_id=subscription.id, - ) - - except Exception as e: - # Rollback on any error - session.rollback() - logger.exception("Failed to rebuild trigger subscription", exc_info=e) - raise + # Create a new subscription with the same subscription_id and endpoint_id + new_credentials: dict[str, Any] = { + key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE) + for key, value in credentials.items() + } + new_subscription: TriggerSubscriptionEntity = TriggerManager.subscribe_trigger( + tenant_id=tenant_id, + user_id=user_id, + provider_id=provider_id, + endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id), + parameters=parameters, + credentials=new_credentials, + credential_type=credential_type, + ) + TriggerProviderService.update_trigger_subscription( + tenant_id=tenant_id, + subscription_id=subscription.id, + name=name, + parameters=parameters, + credentials=new_credentials, + properties=new_subscription.properties, + expires_at=new_subscription.expires_at, + ) diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py index 0f969207cf..f973361341 100644 --- a/api/services/variable_truncator.py +++ b/api/services/variable_truncator.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dataclasses from abc import ABC, abstractmethod from collections.abc import Mapping @@ -106,7 +108,7 @@ class VariableTruncator(BaseTruncator): self._max_size_bytes = max_size_bytes @classmethod - def default(cls) -> "VariableTruncator": + def default(cls) -> VariableTruncator: return VariableTruncator( max_size_bytes=dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE, array_element_limit=dify_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH, diff --git a/api/services/website_service.py b/api/services/website_service.py index a23f01ec71..fe48c3b08e 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime import json from dataclasses import dataclass @@ -78,7 +80,7 @@ class WebsiteCrawlApiRequest: return CrawlRequest(url=self.url, provider=self.provider, options=options) @classmethod - def from_args(cls, args: dict) -> "WebsiteCrawlApiRequest": + def from_args(cls, args: dict) -> WebsiteCrawlApiRequest: """Create from Flask-RESTful parsed arguments.""" provider = args.get("provider") url = args.get("url") @@ -102,7 +104,7 @@ class WebsiteCrawlStatusApiRequest: job_id: str @classmethod - def from_args(cls, args: dict, job_id: str) -> "WebsiteCrawlStatusApiRequest": + def from_args(cls, args: dict, job_id: str) -> WebsiteCrawlStatusApiRequest: """Create from Flask-RESTful parsed arguments.""" provider = args.get("provider") if not provider: diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index 01f0c7a55a..8574d30255 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -86,12 +86,19 @@ class WorkflowAppService: # Join to workflow run for filtering when needed. if keyword: - keyword_like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u") + from libs.helper import escape_like_pattern + + # Escape special characters in keyword to prevent SQL injection via LIKE wildcards + escaped_keyword = escape_like_pattern(keyword[:30]) + keyword_like_val = f"%{escaped_keyword}%" keyword_conditions = [ - WorkflowRun.inputs.ilike(keyword_like_val), - WorkflowRun.outputs.ilike(keyword_like_val), + WorkflowRun.inputs.ilike(keyword_like_val, escape="\\"), + WorkflowRun.outputs.ilike(keyword_like_val, escape="\\"), # filter keyword by end user session id if created by end user role - and_(WorkflowRun.created_by_role == "end_user", EndUser.session_id.ilike(keyword_like_val)), + and_( + WorkflowRun.created_by_role == "end_user", + EndUser.session_id.ilike(keyword_like_val, escape="\\"), + ), ] # filter keyword by workflow run id diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index f299ce3baa..9407a2b3f0 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -679,6 +679,7 @@ def _batch_upsert_draft_variable( def _model_to_insertion_dict(model: WorkflowDraftVariable) -> dict[str, Any]: d: dict[str, Any] = { + "id": model.id, "app_id": model.app_id, "last_edited_at": None, "node_id": model.node_id, diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index cdc07c77a8..be1de3cdd2 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -98,7 +98,7 @@ def enable_annotation_reply_task( if annotations: for annotation in annotations: document = Document( - page_content=annotation.question, + page_content=annotation.question_text, metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id}, ) documents.append(document) diff --git a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py index 1eef361a92..3c5e152520 100644 --- a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py @@ -178,7 +178,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], workflow_id=workflow_id, user=account, application_generate_entity=entity, - invoke_from=InvokeFrom.PUBLISHED, + invoke_from=InvokeFrom.PUBLISHED_PIPELINE, workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, diff --git a/api/tasks/rag_pipeline/rag_pipeline_run_task.py b/api/tasks/rag_pipeline/rag_pipeline_run_task.py index 275f5abe6e..093342d1a3 100644 --- a/api/tasks/rag_pipeline/rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/rag_pipeline_run_task.py @@ -178,7 +178,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], workflow_id=workflow_id, user=account, application_generate_entity=entity, - invoke_from=InvokeFrom.PUBLISHED, + invoke_from=InvokeFrom.PUBLISHED_PIPELINE, workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index e421e4ff36..9b0bd6275b 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -10,6 +10,7 @@ from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.graph import Graph from core.workflow.node_events import NodeRunResult from core.workflow.nodes.code.code_node import CodeNode +from core.workflow.nodes.code.limits import CodeNodeLimits from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable @@ -67,6 +68,16 @@ def init_code_node(code_config: dict): config=code_config, graph_init_params=init_params, graph_runtime_state=graph_runtime_state, + code_limits=CodeNodeLimits( + max_string_length=dify_config.CODE_MAX_STRING_LENGTH, + max_number=dify_config.CODE_MAX_NUMBER, + min_number=dify_config.CODE_MIN_NUMBER, + max_precision=dify_config.CODE_MAX_PRECISION, + max_depth=dify_config.CODE_MAX_DEPTH, + max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH, + max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH, + max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH, + ), ) return node diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py index 72469ad646..dcf31aeca7 100644 --- a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py @@ -35,6 +35,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage from core.workflow.entities.pause_reason import SchedulingPause from core.workflow.enums import WorkflowExecutionStatus from core.workflow.graph_engine.entities.commands import GraphEngineCommand +from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError from core.workflow.graph_events.graph import GraphRunPausedEvent from core.workflow.runtime.graph_runtime_state import GraphRuntimeState from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState @@ -569,10 +570,10 @@ class TestPauseStatePersistenceLayerTestContainers: """Test that layer requires proper initialization before handling events.""" # Arrange layer = self._create_pause_state_persistence_layer() - # Don't initialize - graph_runtime_state should not be set + # Don't initialize - graph_runtime_state should be uninitialized event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")]) - # Act & Assert - Should raise AttributeError - with pytest.raises(AttributeError): + # Act & Assert - Should raise GraphEngineLayerNotInitializedError + with pytest.raises(GraphEngineLayerNotInitializedError): layer.on_event(event) diff --git a/api/tests/test_containers_integration_tests/services/test_annotation_service.py b/api/tests/test_containers_integration_tests/services/test_annotation_service.py index da73122cd7..5555400ca6 100644 --- a/api/tests/test_containers_integration_tests/services/test_annotation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_annotation_service.py @@ -444,6 +444,78 @@ class TestAnnotationService: assert total == 1 assert unique_keyword in annotation_list[0].question or unique_keyword in annotation_list[0].content + def test_get_annotation_list_by_app_id_with_special_characters_in_keyword( + self, db_session_with_containers, mock_external_service_dependencies + ): + r""" + Test retrieval of annotation list with special characters in keyword to verify SQL injection prevention. + + This test verifies: + - Special characters (%, _, \) in keyword are properly escaped + - Search treats special characters as literal characters, not wildcards + - SQL injection via LIKE wildcards is prevented + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotations with special characters in content + annotation_with_percent = { + "question": "Question with 50% discount", + "answer": "Answer about 50% discount offer", + } + AppAnnotationService.insert_app_annotation_directly(annotation_with_percent, app.id) + + annotation_with_underscore = { + "question": "Question with test_data", + "answer": "Answer about test_data value", + } + AppAnnotationService.insert_app_annotation_directly(annotation_with_underscore, app.id) + + annotation_with_backslash = { + "question": "Question with path\\to\\file", + "answer": "Answer about path\\to\\file location", + } + AppAnnotationService.insert_app_annotation_directly(annotation_with_backslash, app.id) + + # Create annotation that should NOT match (contains % but as part of different text) + annotation_no_match = { + "question": "Question with 100% different", + "answer": "Answer about 100% different content", + } + AppAnnotationService.insert_app_annotation_directly(annotation_no_match, app.id) + + # Test 1: Search with % character - should find exact match only + annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id( + app.id, page=1, limit=10, keyword="50%" + ) + assert total == 1 + assert len(annotation_list) == 1 + assert "50%" in annotation_list[0].question or "50%" in annotation_list[0].content + + # Test 2: Search with _ character - should find exact match only + annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id( + app.id, page=1, limit=10, keyword="test_data" + ) + assert total == 1 + assert len(annotation_list) == 1 + assert "test_data" in annotation_list[0].question or "test_data" in annotation_list[0].content + + # Test 3: Search with \ character - should find exact match only + annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id( + app.id, page=1, limit=10, keyword="path\\to\\file" + ) + assert total == 1 + assert len(annotation_list) == 1 + assert "path\\to\\file" in annotation_list[0].question or "path\\to\\file" in annotation_list[0].content + + # Test 4: Search with % should NOT match 100% (verifies escaping works) + annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id( + app.id, page=1, limit=10, keyword="50%" + ) + # Should only find the 50% annotation, not the 100% one + assert total == 1 + assert all("50%" in (item.question or "") or "50%" in (item.content or "") for item in annotation_list) + def test_get_annotation_list_by_app_id_app_not_found( self, db_session_with_containers, mock_external_service_dependencies ): diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index e53392bcef..745d6c97b0 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -7,7 +7,9 @@ from constants.model_template import default_app_templates from models import Account from models.model import App, Site from services.account_service import AccountService, TenantService -from services.app_service import AppService + +# Delay import of AppService to avoid circular dependency +# from services.app_service import AppService class TestAppService: @@ -71,6 +73,9 @@ class TestAppService: } # Create app + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -109,6 +114,9 @@ class TestAppService: TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() # Test different app modes @@ -159,6 +167,9 @@ class TestAppService: "icon_background": "#45B7D1", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() created_app = app_service.create_app(tenant.id, app_args, account) @@ -194,6 +205,9 @@ class TestAppService: TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() # Create multiple apps @@ -245,6 +259,9 @@ class TestAppService: TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() # Create apps with different modes @@ -315,6 +332,9 @@ class TestAppService: TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() # Create an app @@ -392,6 +412,9 @@ class TestAppService: "icon_background": "#45B7D1", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -458,6 +481,9 @@ class TestAppService: "icon_background": "#45B7D1", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -508,6 +534,9 @@ class TestAppService: "icon_background": "#45B7D1", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -562,6 +591,9 @@ class TestAppService: "icon_background": "#74B9FF", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -617,6 +649,9 @@ class TestAppService: "icon_background": "#A29BFE", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -672,6 +707,9 @@ class TestAppService: "icon_background": "#FD79A8", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -720,6 +758,9 @@ class TestAppService: "icon_background": "#E17055", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -768,6 +809,9 @@ class TestAppService: "icon_background": "#00B894", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -826,6 +870,9 @@ class TestAppService: "icon_background": "#6C5CE7", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -862,6 +909,9 @@ class TestAppService: "icon_background": "#FDCB6E", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -899,6 +949,9 @@ class TestAppService: "icon_background": "#E84393", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -947,8 +1000,132 @@ class TestAppService: "icon_background": "#D63031", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() # Attempt to create app with invalid mode with pytest.raises(ValueError, match="invalid mode value"): app_service.create_app(tenant.id, app_args, account) + + def test_get_apps_with_special_characters_in_name( + self, db_session_with_containers, mock_external_service_dependencies + ): + r""" + Test app retrieval with special characters in name search to verify SQL injection prevention. + + This test verifies: + - Special characters (%, _, \) in name search are properly escaped + - Search treats special characters as literal characters, not wildcards + - SQL injection via LIKE wildcards is prevented + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Import here to avoid circular dependency + from services.app_service import AppService + + app_service = AppService() + + # Create apps with special characters in names + app_with_percent = app_service.create_app( + tenant.id, + { + "name": "App with 50% discount", + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + }, + account, + ) + + app_with_underscore = app_service.create_app( + tenant.id, + { + "name": "test_data_app", + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + }, + account, + ) + + app_with_backslash = app_service.create_app( + tenant.id, + { + "name": "path\\to\\app", + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + }, + account, + ) + + # Create app that should NOT match + app_no_match = app_service.create_app( + tenant.id, + { + "name": "100% different", + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + }, + account, + ) + + # Test 1: Search with % character + args = {"name": "50%", "mode": "chat", "page": 1, "limit": 10} + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + assert paginated_apps is not None + assert paginated_apps.total == 1 + assert len(paginated_apps.items) == 1 + assert paginated_apps.items[0].name == "App with 50% discount" + + # Test 2: Search with _ character + args = {"name": "test_data", "mode": "chat", "page": 1, "limit": 10} + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + assert paginated_apps is not None + assert paginated_apps.total == 1 + assert len(paginated_apps.items) == 1 + assert paginated_apps.items[0].name == "test_data_app" + + # Test 3: Search with \ character + args = {"name": "path\\to\\app", "mode": "chat", "page": 1, "limit": 10} + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + assert paginated_apps is not None + assert paginated_apps.total == 1 + assert len(paginated_apps.items) == 1 + assert paginated_apps.items[0].name == "path\\to\\app" + + # Test 4: Search with % should NOT match 100% (verifies escaping works) + args = {"name": "50%", "mode": "chat", "page": 1, "limit": 10} + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + assert paginated_apps is not None + assert paginated_apps.total == 1 + assert all("50%" in app.name for app in paginated_apps.items) diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index 6732b8d558..e8c7f17e0b 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -1,3 +1,4 @@ +import uuid from unittest.mock import create_autospec, patch import pytest @@ -312,6 +313,85 @@ class TestTagService: result_no_match = TagService.get_tags("app", tenant.id, keyword="nonexistent") assert len(result_no_match) == 0 + def test_get_tags_with_special_characters_in_keyword( + self, db_session_with_containers, mock_external_service_dependencies + ): + r""" + Test tag retrieval with special characters in keyword to verify SQL injection prevention. + + This test verifies: + - Special characters (%, _, \) in keyword are properly escaped + - Search treats special characters as literal characters, not wildcards + - SQL injection via LIKE wildcards is prevented + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + from extensions.ext_database import db + + # Create tags with special characters in names + tag_with_percent = Tag( + name="50% discount", + type="app", + tenant_id=tenant.id, + created_by=account.id, + ) + tag_with_percent.id = str(uuid.uuid4()) + db.session.add(tag_with_percent) + + tag_with_underscore = Tag( + name="test_data_tag", + type="app", + tenant_id=tenant.id, + created_by=account.id, + ) + tag_with_underscore.id = str(uuid.uuid4()) + db.session.add(tag_with_underscore) + + tag_with_backslash = Tag( + name="path\\to\\tag", + type="app", + tenant_id=tenant.id, + created_by=account.id, + ) + tag_with_backslash.id = str(uuid.uuid4()) + db.session.add(tag_with_backslash) + + # Create tag that should NOT match + tag_no_match = Tag( + name="100% different", + type="app", + tenant_id=tenant.id, + created_by=account.id, + ) + tag_no_match.id = str(uuid.uuid4()) + db.session.add(tag_no_match) + + db.session.commit() + + # Act & Assert: Test 1 - Search with % character + result = TagService.get_tags("app", tenant.id, keyword="50%") + assert len(result) == 1 + assert result[0].name == "50% discount" + + # Test 2 - Search with _ character + result = TagService.get_tags("app", tenant.id, keyword="test_data") + assert len(result) == 1 + assert result[0].name == "test_data_tag" + + # Test 3 - Search with \ character + result = TagService.get_tags("app", tenant.id, keyword="path\\to\\tag") + assert len(result) == 1 + assert result[0].name == "path\\to\\tag" + + # Test 4 - Search with % should NOT match 100% (verifies escaping works) + result = TagService.get_tags("app", tenant.id, keyword="50%") + assert len(result) == 1 + assert all("50%" in item.name for item in result) + def test_get_tags_empty_result(self, db_session_with_containers, mock_external_service_dependencies): """ Test tag retrieval when no tags exist. diff --git a/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py b/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py index 8322b9414e..5315960d73 100644 --- a/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py @@ -474,64 +474,6 @@ class TestTriggerProviderService: assert subscription.name == original_name assert subscription.parameters == original_parameters - def test_rebuild_trigger_subscription_unsubscribe_error_continues( - self, db_session_with_containers, mock_external_service_dependencies - ): - """ - Test that unsubscribe errors are handled gracefully and operation continues. - - This test verifies: - - Unsubscribe errors are caught and logged but don't stop the rebuild - - Rebuild continues even if unsubscribe fails - """ - fake = Faker() - account, tenant = self._create_test_account_and_tenant( - db_session_with_containers, mock_external_service_dependencies - ) - - provider_id = TriggerProviderID("test_org/test_plugin/test_provider") - credential_type = CredentialType.API_KEY - - original_credentials = {"api_key": "original-key"} - subscription = self._create_test_subscription( - db_session_with_containers, - tenant.id, - account.id, - provider_id, - credential_type, - original_credentials, - mock_external_service_dependencies, - ) - - # Make unsubscribe_trigger raise an error (should be caught and continue) - mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.side_effect = ValueError( - "Unsubscribe failed" - ) - - new_subscription_entity = TriggerSubscriptionEntity( - endpoint=subscription.endpoint_id, - parameters={}, - properties={}, - expires_at=-1, - ) - mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity - - # Execute rebuild - should succeed despite unsubscribe error - TriggerProviderService.rebuild_trigger_subscription( - tenant_id=tenant.id, - provider_id=provider_id, - subscription_id=subscription.id, - credentials={"api_key": "new-key"}, - parameters={}, - ) - - # Verify subscribe was still called (operation continued) - mock_external_service_dependencies["trigger_manager"].subscribe_trigger.assert_called_once() - - # Verify subscription was updated - db.session.refresh(subscription) - assert subscription.parameters == {} - def test_rebuild_trigger_subscription_subscription_not_found( self, db_session_with_containers, mock_external_service_dependencies ): @@ -558,70 +500,6 @@ class TestTriggerProviderService: parameters={}, ) - def test_rebuild_trigger_subscription_provider_not_found( - self, db_session_with_containers, mock_external_service_dependencies - ): - """ - Test error when provider is not found. - - This test verifies: - - Proper error is raised when provider doesn't exist - """ - fake = Faker() - account, tenant = self._create_test_account_and_tenant( - db_session_with_containers, mock_external_service_dependencies - ) - - provider_id = TriggerProviderID("non_existent_org/non_existent_plugin/non_existent_provider") - - # Make get_trigger_provider return None - mock_external_service_dependencies["trigger_manager"].get_trigger_provider.return_value = None - - with pytest.raises(ValueError, match="Provider.*not found"): - TriggerProviderService.rebuild_trigger_subscription( - tenant_id=tenant.id, - provider_id=provider_id, - subscription_id=fake.uuid4(), - credentials={}, - parameters={}, - ) - - def test_rebuild_trigger_subscription_unsupported_credential_type( - self, db_session_with_containers, mock_external_service_dependencies - ): - """ - Test error when credential type is not supported for rebuild. - - This test verifies: - - Proper error is raised for unsupported credential types (not OAUTH2 or API_KEY) - """ - fake = Faker() - account, tenant = self._create_test_account_and_tenant( - db_session_with_containers, mock_external_service_dependencies - ) - - provider_id = TriggerProviderID("test_org/test_plugin/test_provider") - credential_type = CredentialType.UNAUTHORIZED # Not supported - - subscription = self._create_test_subscription( - db_session_with_containers, - tenant.id, - account.id, - provider_id, - credential_type, - {}, - mock_external_service_dependencies, - ) - - with pytest.raises(ValueError, match="Credential type not supported for rebuild"): - TriggerProviderService.rebuild_trigger_subscription( - tenant_id=tenant.id, - provider_id=provider_id, - subscription_id=subscription.id, - credentials={}, - parameters={}, - ) - def test_rebuild_trigger_subscription_name_uniqueness_check( self, db_session_with_containers, mock_external_service_dependencies ): diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index 7b95944bbe..040fb826e1 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -10,7 +10,9 @@ from core.workflow.entities.workflow_execution import WorkflowExecutionStatus from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun from models.enums import CreatorUserRole from services.account_service import AccountService, TenantService -from services.app_service import AppService + +# Delay import of AppService to avoid circular dependency +# from services.app_service import AppService from services.workflow_app_service import WorkflowAppService @@ -86,6 +88,9 @@ class TestWorkflowAppService: "api_rpm": 10, } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -147,6 +152,9 @@ class TestWorkflowAppService: "api_rpm": 10, } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -308,6 +316,156 @@ class TestWorkflowAppService: assert result_no_match["total"] == 0 assert len(result_no_match["data"]) == 0 + def test_get_paginate_workflow_app_logs_with_special_characters_in_keyword( + self, db_session_with_containers, mock_external_service_dependencies + ): + r""" + Test workflow app logs pagination with special characters in keyword to verify SQL injection prevention. + + This test verifies: + - Special characters (%, _) in keyword are properly escaped + - Search treats special characters as literal characters, not wildcards + - SQL injection via LIKE wildcards is prevented + """ + # Arrange: Create test data + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + workflow, _, _ = self._create_test_workflow_data(db_session_with_containers, app, account) + + from extensions.ext_database import db + + service = WorkflowAppService() + + # Test 1: Search with % character + workflow_run_1 = WorkflowRun( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + type="workflow", + triggered_from="app-run", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + status="succeeded", + inputs=json.dumps({"search_term": "50% discount", "input2": "other_value"}), + outputs=json.dumps({"result": "50% discount applied", "status": "success"}), + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + created_at=datetime.now(UTC), + ) + db.session.add(workflow_run_1) + db.session.flush() + + workflow_app_log_1 = WorkflowAppLog( + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + workflow_run_id=workflow_run_1.id, + created_from="service-api", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + ) + workflow_app_log_1.id = str(uuid.uuid4()) + workflow_app_log_1.created_at = datetime.now(UTC) + db.session.add(workflow_app_log_1) + db.session.commit() + + result = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20 + ) + # Should find the workflow_run_1 entry + assert result["total"] >= 1 + assert len(result["data"]) >= 1 + assert any(log.workflow_run_id == workflow_run_1.id for log in result["data"]) + + # Test 2: Search with _ character + workflow_run_2 = WorkflowRun( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + type="workflow", + triggered_from="app-run", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + status="succeeded", + inputs=json.dumps({"search_term": "test_data_value", "input2": "other_value"}), + outputs=json.dumps({"result": "test_data_value found", "status": "success"}), + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + created_at=datetime.now(UTC), + ) + db.session.add(workflow_run_2) + db.session.flush() + + workflow_app_log_2 = WorkflowAppLog( + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + workflow_run_id=workflow_run_2.id, + created_from="service-api", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + ) + workflow_app_log_2.id = str(uuid.uuid4()) + workflow_app_log_2.created_at = datetime.now(UTC) + db.session.add(workflow_app_log_2) + db.session.commit() + + result = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, keyword="test_data", page=1, limit=20 + ) + # Should find the workflow_run_2 entry + assert result["total"] >= 1 + assert len(result["data"]) >= 1 + assert any(log.workflow_run_id == workflow_run_2.id for log in result["data"]) + + # Test 3: Search with % should NOT match 100% (verifies escaping works correctly) + workflow_run_4 = WorkflowRun( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + type="workflow", + triggered_from="app-run", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + status="succeeded", + inputs=json.dumps({"search_term": "100% different", "input2": "other_value"}), + outputs=json.dumps({"result": "100% different result", "status": "success"}), + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + created_at=datetime.now(UTC), + ) + db.session.add(workflow_run_4) + db.session.flush() + + workflow_app_log_4 = WorkflowAppLog( + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + workflow_run_id=workflow_run_4.id, + created_from="service-api", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + ) + workflow_app_log_4.id = str(uuid.uuid4()) + workflow_app_log_4.created_at = datetime.now(UTC) + db.session.add(workflow_app_log_4) + db.session.commit() + + result = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20 + ) + # Should only find the 50% entry (workflow_run_1), not the 100% entry (workflow_run_4) + # This verifies that escaping works correctly - 50% should not match 100% + assert result["total"] >= 1 + assert len(result["data"]) >= 1 + # Verify that we found workflow_run_1 (50% discount) but not workflow_run_4 (100% different) + found_run_ids = [log.workflow_run_id for log in result["data"]] + assert workflow_run_1.id in found_run_ids + assert workflow_run_4.id not in found_run_ids + def test_get_paginate_workflow_app_logs_with_status_filter( self, db_session_with_containers, mock_external_service_dependencies ): diff --git a/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py b/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py index e29b98037f..b9977b1fb6 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py +++ b/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py @@ -165,7 +165,7 @@ class TestRagPipelineRunTasks: "files": [], "user_id": account.id, "stream": False, - "invoke_from": "published", + "invoke_from": InvokeFrom.PUBLISHED_PIPELINE.value, "workflow_execution_id": str(uuid.uuid4()), "pipeline_config": { "app_id": str(uuid.uuid4()), @@ -249,7 +249,7 @@ class TestRagPipelineRunTasks: assert call_kwargs["pipeline"].id == pipeline.id assert call_kwargs["workflow_id"] == workflow.id assert call_kwargs["user"].id == account.id - assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED + assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED_PIPELINE assert call_kwargs["streaming"] == False assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) @@ -294,7 +294,7 @@ class TestRagPipelineRunTasks: assert call_kwargs["pipeline"].id == pipeline.id assert call_kwargs["workflow_id"] == workflow.id assert call_kwargs["user"].id == account.id - assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED + assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED_PIPELINE assert call_kwargs["streaming"] == False assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) @@ -743,7 +743,7 @@ class TestRagPipelineRunTasks: assert call_kwargs["pipeline"].id == pipeline.id assert call_kwargs["workflow_id"] == workflow.id assert call_kwargs["user"].id == account.id - assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED + assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED_PIPELINE assert call_kwargs["streaming"] == False assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py index 209b6bf59b..6fce7849f9 100644 --- a/api/tests/unit_tests/configs/test_dify_config.py +++ b/api/tests/unit_tests/configs/test_dify_config.py @@ -16,6 +16,7 @@ def test_dify_config(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") monkeypatch.setenv("HTTP_REQUEST_MAX_WRITE_TIMEOUT", "30") # Custom value for testing + monkeypatch.setenv("DB_TYPE", "postgresql") monkeypatch.setenv("DB_USERNAME", "postgres") monkeypatch.setenv("DB_PASSWORD", "postgres") monkeypatch.setenv("DB_HOST", "localhost") @@ -51,6 +52,7 @@ def test_http_timeout_defaults(monkeypatch: pytest.MonkeyPatch): os.environ.clear() # Set minimal required env vars + monkeypatch.setenv("DB_TYPE", "postgresql") monkeypatch.setenv("DB_USERNAME", "postgres") monkeypatch.setenv("DB_PASSWORD", "postgres") monkeypatch.setenv("DB_HOST", "localhost") @@ -75,6 +77,7 @@ def test_flask_configs(monkeypatch: pytest.MonkeyPatch): # Set environment variables using monkeypatch monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") + monkeypatch.setenv("DB_TYPE", "postgresql") monkeypatch.setenv("DB_USERNAME", "postgres") monkeypatch.setenv("DB_PASSWORD", "postgres") monkeypatch.setenv("DB_HOST", "localhost") @@ -124,6 +127,7 @@ def test_inner_api_config_exist(monkeypatch: pytest.MonkeyPatch): # Set environment variables using monkeypatch monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") + monkeypatch.setenv("DB_TYPE", "postgresql") monkeypatch.setenv("DB_USERNAME", "postgres") monkeypatch.setenv("DB_PASSWORD", "postgres") monkeypatch.setenv("DB_HOST", "localhost") @@ -140,6 +144,7 @@ def test_inner_api_config_exist(monkeypatch: pytest.MonkeyPatch): def test_db_extras_options_merging(monkeypatch: pytest.MonkeyPatch): """Test that DB_EXTRAS options are properly merged with default timezone setting""" # Set environment variables + monkeypatch.setenv("DB_TYPE", "postgresql") monkeypatch.setenv("DB_USERNAME", "postgres") monkeypatch.setenv("DB_PASSWORD", "postgres") monkeypatch.setenv("DB_HOST", "localhost") @@ -199,6 +204,7 @@ def test_celery_broker_url_with_special_chars_password( # Set up basic required environment variables (following existing pattern) monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") + monkeypatch.setenv("DB_TYPE", "postgresql") monkeypatch.setenv("DB_USERNAME", "postgres") monkeypatch.setenv("DB_PASSWORD", "postgres") monkeypatch.setenv("DB_HOST", "localhost") diff --git a/api/tests/unit_tests/controllers/__init__.py b/api/tests/unit_tests/controllers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/controllers/console/__init__.py b/api/tests/unit_tests/controllers/console/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/controllers/console/app/test_app_response_models.py b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py new file mode 100644 index 0000000000..40eb59a8f4 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py @@ -0,0 +1,285 @@ +from __future__ import annotations + +import builtins +import sys +from datetime import datetime +from importlib import util +from pathlib import Path +from types import ModuleType, SimpleNamespace +from typing import Any + +import pytest +from flask.views import MethodView + +# kombu references MethodView as a global when importing celery/kombu pools. +if not hasattr(builtins, "MethodView"): + builtins.MethodView = MethodView # type: ignore[attr-defined] + + +def _load_app_module(): + module_name = "controllers.console.app.app" + if module_name in sys.modules: + return sys.modules[module_name] + + root = Path(__file__).resolve().parents[5] + module_path = root / "controllers" / "console" / "app" / "app.py" + + class _StubNamespace: + def __init__(self): + self.models: dict[str, Any] = {} + self.payload = None + + def schema_model(self, name, schema): + self.models[name] = schema + + def _decorator(self, obj): + return obj + + def doc(self, *args, **kwargs): + return self._decorator + + def expect(self, *args, **kwargs): + return self._decorator + + def response(self, *args, **kwargs): + return self._decorator + + def route(self, *args, **kwargs): + def decorator(obj): + return obj + + return decorator + + stub_namespace = _StubNamespace() + + original_console = sys.modules.get("controllers.console") + original_app_pkg = sys.modules.get("controllers.console.app") + stubbed_modules: list[tuple[str, ModuleType | None]] = [] + + console_module = ModuleType("controllers.console") + console_module.__path__ = [str(root / "controllers" / "console")] + console_module.console_ns = stub_namespace + console_module.api = None + console_module.bp = None + sys.modules["controllers.console"] = console_module + + app_package = ModuleType("controllers.console.app") + app_package.__path__ = [str(root / "controllers" / "console" / "app")] + sys.modules["controllers.console.app"] = app_package + console_module.app = app_package + + def _stub_module(name: str, attrs: dict[str, Any]): + original = sys.modules.get(name) + module = ModuleType(name) + for key, value in attrs.items(): + setattr(module, key, value) + sys.modules[name] = module + stubbed_modules.append((name, original)) + + class _OpsTraceManager: + @staticmethod + def get_app_tracing_config(app_id: str) -> dict[str, Any]: + return {} + + @staticmethod + def update_app_tracing_config(app_id: str, **kwargs) -> None: + return None + + _stub_module( + "core.ops.ops_trace_manager", + { + "OpsTraceManager": _OpsTraceManager, + "TraceQueueManager": object, + "TraceTask": object, + }, + ) + + spec = util.spec_from_file_location(module_name, module_path) + module = util.module_from_spec(spec) + sys.modules[module_name] = module + + try: + assert spec.loader is not None + spec.loader.exec_module(module) + finally: + for name, original in reversed(stubbed_modules): + if original is not None: + sys.modules[name] = original + else: + sys.modules.pop(name, None) + if original_console is not None: + sys.modules["controllers.console"] = original_console + else: + sys.modules.pop("controllers.console", None) + if original_app_pkg is not None: + sys.modules["controllers.console.app"] = original_app_pkg + else: + sys.modules.pop("controllers.console.app", None) + + return module + + +_app_module = _load_app_module() +AppDetailWithSite = _app_module.AppDetailWithSite +AppPagination = _app_module.AppPagination +AppPartial = _app_module.AppPartial + + +@pytest.fixture(autouse=True) +def patch_signed_url(monkeypatch): + """Ensure icon URL generation uses a deterministic helper for tests.""" + + def _fake_signed_url(key: str | None) -> str | None: + if not key: + return None + return f"signed:{key}" + + monkeypatch.setattr(_app_module.file_helpers, "get_signed_file_url", _fake_signed_url) + + +def _ts(hour: int = 12) -> datetime: + return datetime(2024, 1, 1, hour, 0, 0) + + +def _dummy_model_config(): + return SimpleNamespace( + model_dict={"provider": "openai", "name": "gpt-4o"}, + pre_prompt="hello", + created_by="config-author", + created_at=_ts(9), + updated_by="config-editor", + updated_at=_ts(10), + ) + + +def _dummy_workflow(): + return SimpleNamespace( + id="wf-1", + created_by="workflow-author", + created_at=_ts(8), + updated_by="workflow-editor", + updated_at=_ts(9), + ) + + +def test_app_partial_serialization_uses_aliases(): + created_at = _ts() + app_obj = SimpleNamespace( + id="app-1", + name="My App", + desc_or_prompt="Prompt snippet", + mode_compatible_with_agent="chat", + icon_type="image", + icon="icon-key", + icon_background="#fff", + app_model_config=_dummy_model_config(), + workflow=_dummy_workflow(), + created_by="creator", + created_at=created_at, + updated_by="editor", + updated_at=created_at, + tags=[SimpleNamespace(id="tag-1", name="Utilities", type="app")], + access_mode="private", + create_user_name="Creator", + author_name="Author", + has_draft_trigger=True, + ) + + serialized = AppPartial.model_validate(app_obj, from_attributes=True).model_dump(mode="json") + + assert serialized["description"] == "Prompt snippet" + assert serialized["mode"] == "chat" + assert serialized["icon_url"] == "signed:icon-key" + assert serialized["created_at"] == int(created_at.timestamp()) + assert serialized["updated_at"] == int(created_at.timestamp()) + assert serialized["model_config"]["model"] == {"provider": "openai", "name": "gpt-4o"} + assert serialized["workflow"]["id"] == "wf-1" + assert serialized["tags"][0]["name"] == "Utilities" + + +def test_app_detail_with_site_includes_nested_serialization(): + timestamp = _ts(14) + site = SimpleNamespace( + code="site-code", + title="Public Site", + icon_type="image", + icon="site-icon", + created_at=timestamp, + updated_at=timestamp, + ) + app_obj = SimpleNamespace( + id="app-2", + name="Detailed App", + description="Desc", + mode_compatible_with_agent="advanced-chat", + icon_type="image", + icon="detail-icon", + icon_background="#123456", + enable_site=True, + enable_api=True, + app_model_config={ + "opening_statement": "hi", + "model": {"provider": "openai", "name": "gpt-4o"}, + "retriever_resource": {"enabled": True}, + }, + workflow=_dummy_workflow(), + tracing={"enabled": True}, + use_icon_as_answer_icon=True, + created_by="creator", + created_at=timestamp, + updated_by="editor", + updated_at=timestamp, + access_mode="public", + tags=[SimpleNamespace(id="tag-2", name="Prod", type="app")], + api_base_url="https://api.example.com/v1", + max_active_requests=5, + deleted_tools=[{"type": "api", "tool_name": "search", "provider_id": "prov"}], + site=site, + ) + + serialized = AppDetailWithSite.model_validate(app_obj, from_attributes=True).model_dump(mode="json") + + assert serialized["icon_url"] == "signed:detail-icon" + assert serialized["model_config"]["retriever_resource"] == {"enabled": True} + assert serialized["deleted_tools"][0]["tool_name"] == "search" + assert serialized["site"]["icon_url"] == "signed:site-icon" + assert serialized["site"]["created_at"] == int(timestamp.timestamp()) + + +def test_app_pagination_aliases_per_page_and_has_next(): + item_one = SimpleNamespace( + id="app-10", + name="Paginated One", + desc_or_prompt="Summary", + mode_compatible_with_agent="chat", + icon_type="image", + icon="first-icon", + created_at=_ts(15), + updated_at=_ts(15), + ) + item_two = SimpleNamespace( + id="app-11", + name="Paginated Two", + desc_or_prompt="Summary", + mode_compatible_with_agent="agent-chat", + icon_type="emoji", + icon="🙂", + created_at=_ts(16), + updated_at=_ts(16), + ) + pagination = SimpleNamespace( + page=2, + per_page=10, + total=50, + has_next=True, + items=[item_one, item_two], + ) + + serialized = AppPagination.model_validate(pagination, from_attributes=True).model_dump(mode="json") + + assert serialized["page"] == 2 + assert serialized["limit"] == 10 + assert serialized["has_more"] is True + assert len(serialized["data"]) == 2 + assert serialized["data"][0]["icon_url"] == "signed:first-icon" + assert serialized["data"][1]["icon_url"] is None diff --git a/api/tests/unit_tests/controllers/console/test_document_detail_api_data_source_info.py b/api/tests/unit_tests/controllers/console/test_document_detail_api_data_source_info.py new file mode 100644 index 0000000000..f8dd98fdb2 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_document_detail_api_data_source_info.py @@ -0,0 +1,145 @@ +""" +Test for document detail API data_source_info serialization fix. + +This test verifies that the document detail API returns both data_source_info +and data_source_detail_dict for all data_source_type values, including "local_file". +""" + +import json +from typing import Generic, Literal, NotRequired, TypedDict, TypeVar, Union + +from models.dataset import Document + + +class LocalFileInfo(TypedDict): + file_path: str + size: int + created_at: NotRequired[str] + + +class UploadFileInfo(TypedDict): + upload_file_id: str + + +class NotionImportInfo(TypedDict): + notion_page_id: str + workspace_id: str + + +class WebsiteCrawlInfo(TypedDict): + url: str + job_id: str + + +RawInfo = Union[LocalFileInfo, UploadFileInfo, NotionImportInfo, WebsiteCrawlInfo] +T_type = TypeVar("T_type", bound=str) +T_info = TypeVar("T_info", bound=Union[LocalFileInfo, UploadFileInfo, NotionImportInfo, WebsiteCrawlInfo]) + + +class Case(TypedDict, Generic[T_type, T_info]): + data_source_type: T_type + data_source_info: str + expected_raw: T_info + + +LocalFileCase = Case[Literal["local_file"], LocalFileInfo] +UploadFileCase = Case[Literal["upload_file"], UploadFileInfo] +NotionImportCase = Case[Literal["notion_import"], NotionImportInfo] +WebsiteCrawlCase = Case[Literal["website_crawl"], WebsiteCrawlInfo] + +AnyCase = Union[LocalFileCase, UploadFileCase, NotionImportCase, WebsiteCrawlCase] + + +case_1: LocalFileCase = { + "data_source_type": "local_file", + "data_source_info": json.dumps({"file_path": "/tmp/test.txt", "size": 1024}), + "expected_raw": {"file_path": "/tmp/test.txt", "size": 1024}, +} + + +# ERROR: Expected LocalFileInfo, but got WebsiteCrawlInfo +case_2: LocalFileCase = { + "data_source_type": "local_file", + "data_source_info": "...", + "expected_raw": {"file_path": "https://google.com", "size": 123}, +} + +cases: list[AnyCase] = [case_1] + + +class TestDocumentDetailDataSourceInfo: + """Test cases for document detail API data_source_info serialization.""" + + def test_data_source_info_dict_returns_raw_data(self): + """Test that data_source_info_dict returns raw JSON data for all data_source_type values.""" + # Test data for different data_source_type values + for case in cases: + document = Document( + data_source_type=case["data_source_type"], + data_source_info=case["data_source_info"], + ) + + # Test data_source_info_dict (raw data) + raw_result = document.data_source_info_dict + assert raw_result == case["expected_raw"], f"Failed for {case['data_source_type']}" + + # Verify raw_result is always a valid dict + assert isinstance(raw_result, dict) + + def test_local_file_data_source_info_without_db_context(self): + """Test that local_file type data_source_info_dict works without database context.""" + test_data: LocalFileInfo = { + "file_path": "/local/path/document.txt", + "size": 512, + "created_at": "2024-01-01T00:00:00Z", + } + + document = Document( + data_source_type="local_file", + data_source_info=json.dumps(test_data), + ) + + # data_source_info_dict should return the raw data (this doesn't need DB context) + raw_data = document.data_source_info_dict + assert raw_data == test_data + assert isinstance(raw_data, dict) + + # Verify the data contains expected keys for pipeline mode + assert "file_path" in raw_data + assert "size" in raw_data + + def test_notion_and_website_crawl_data_source_detail(self): + """Test that notion_import and website_crawl return raw data in data_source_detail_dict.""" + # Test notion_import + notion_data: NotionImportInfo = {"notion_page_id": "page-123", "workspace_id": "ws-456"} + document = Document( + data_source_type="notion_import", + data_source_info=json.dumps(notion_data), + ) + + # data_source_detail_dict should return raw data for notion_import + detail_result = document.data_source_detail_dict + assert detail_result == notion_data + + # Test website_crawl + website_data: WebsiteCrawlInfo = {"url": "https://example.com", "job_id": "job-789"} + document = Document( + data_source_type="website_crawl", + data_source_info=json.dumps(website_data), + ) + + # data_source_detail_dict should return raw data for website_crawl + detail_result = document.data_source_detail_dict + assert detail_result == website_data + + def test_local_file_data_source_detail_dict_without_db(self): + """Test that local_file returns empty data_source_detail_dict (this doesn't need DB context).""" + # Test local_file - this should work without database context since it returns {} early + document = Document( + data_source_type="local_file", + data_source_info=json.dumps({"file_path": "/tmp/test.txt"}), + ) + + # Should return empty dict for local_file type (handled in the model) + detail_result = document.data_source_detail_dict + assert detail_result == {} diff --git a/api/tests/unit_tests/controllers/console/test_files_security.py b/api/tests/unit_tests/controllers/console/test_files_security.py index 2630fbcfd0..370bf63fdb 100644 --- a/api/tests/unit_tests/controllers/console/test_files_security.py +++ b/api/tests/unit_tests/controllers/console/test_files_security.py @@ -1,7 +1,9 @@ +import builtins import io from unittest.mock import patch import pytest +from flask.views import MethodView from werkzeug.exceptions import Forbidden from controllers.common.errors import ( @@ -14,6 +16,9 @@ from controllers.common.errors import ( from services.errors.file import FileTooLargeError as ServiceFileTooLargeError from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError +if not hasattr(builtins, "MethodView"): + builtins.MethodView = MethodView # type: ignore[attr-defined] + class TestFileUploadSecurity: """Test file upload security logic without complex framework setup""" @@ -128,7 +133,7 @@ class TestFileUploadSecurity: # Test passes if no exception is raised # Test 4: Service error handling - @patch("services.file_service.FileService.upload_file") + @patch("controllers.console.files.FileService.upload_file") def test_should_handle_file_too_large_error(self, mock_upload): """Test that service FileTooLargeError is properly converted""" mock_upload.side_effect = ServiceFileTooLargeError("File too large") @@ -140,7 +145,7 @@ class TestFileUploadSecurity: with pytest.raises(FileTooLargeError): raise FileTooLargeError(e.description) - @patch("services.file_service.FileService.upload_file") + @patch("controllers.console.files.FileService.upload_file") def test_should_handle_unsupported_file_type_error(self, mock_upload): """Test that service UnsupportedFileTypeError is properly converted""" mock_upload.side_effect = ServiceUnsupportedFileTypeError() diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_answer_node.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_answer_node.py new file mode 100644 index 0000000000..205b157542 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_answer_node.py @@ -0,0 +1,390 @@ +""" +Tests for AdvancedChatAppGenerateTaskPipeline._handle_node_succeeded_event method, +specifically testing the ANSWER node message_replace logic. +""" + +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity +from core.app.entities.queue_entities import QueueNodeSucceededEvent +from core.workflow.enums import NodeType +from models import EndUser +from models.model import AppMode + + +class TestAnswerNodeMessageReplace: + """Test cases for ANSWER node message_replace event logic.""" + + @pytest.fixture + def mock_application_generate_entity(self): + """Create a mock application generate entity.""" + entity = Mock(spec=AdvancedChatAppGenerateEntity) + entity.task_id = "test-task-id" + entity.app_id = "test-app-id" + entity.workflow_run_id = "test-workflow-run-id" + # minimal app_config used by pipeline internals + entity.app_config = SimpleNamespace( + tenant_id="test-tenant-id", + app_id="test-app-id", + app_mode=AppMode.ADVANCED_CHAT, + app_model_config_dict={}, + additional_features=None, + sensitive_word_avoidance=None, + ) + entity.query = "test query" + entity.files = [] + entity.extras = {} + entity.trace_manager = None + entity.inputs = {} + entity.invoke_from = "debugger" + return entity + + @pytest.fixture + def mock_workflow(self): + """Create a mock workflow.""" + workflow = Mock() + workflow.id = "test-workflow-id" + workflow.features_dict = {} + return workflow + + @pytest.fixture + def mock_queue_manager(self): + """Create a mock queue manager.""" + manager = Mock() + manager.listen.return_value = [] + manager.graph_runtime_state = None + return manager + + @pytest.fixture + def mock_conversation(self): + """Create a mock conversation.""" + conversation = Mock() + conversation.id = "test-conversation-id" + conversation.mode = "advanced_chat" + return conversation + + @pytest.fixture + def mock_message(self): + """Create a mock message.""" + message = Mock() + message.id = "test-message-id" + message.query = "test query" + message.created_at = Mock() + message.created_at.timestamp.return_value = 1234567890 + return message + + @pytest.fixture + def mock_user(self): + """Create a mock end user.""" + user = MagicMock(spec=EndUser) + user.id = "test-user-id" + user.session_id = "test-session-id" + return user + + @pytest.fixture + def mock_draft_var_saver_factory(self): + """Create a mock draft variable saver factory.""" + return Mock() + + @pytest.fixture + def pipeline( + self, + mock_application_generate_entity, + mock_workflow, + mock_queue_manager, + mock_conversation, + mock_message, + mock_user, + mock_draft_var_saver_factory, + ): + """Create an AdvancedChatAppGenerateTaskPipeline instance with mocked dependencies.""" + from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline + + with patch("core.app.apps.advanced_chat.generate_task_pipeline.db"): + pipeline = AdvancedChatAppGenerateTaskPipeline( + application_generate_entity=mock_application_generate_entity, + workflow=mock_workflow, + queue_manager=mock_queue_manager, + conversation=mock_conversation, + message=mock_message, + user=mock_user, + stream=True, + dialogue_count=1, + draft_var_saver_factory=mock_draft_var_saver_factory, + ) + # Initialize workflow run id to avoid validation errors + pipeline._workflow_run_id = "test-workflow-run-id" + # Mock the message cycle manager methods we need to track + pipeline._message_cycle_manager.message_replace_to_stream_response = Mock() + return pipeline + + def test_answer_node_with_different_output_sends_message_replace(self, pipeline, mock_application_generate_entity): + """ + Test that when an ANSWER node's final output differs from accumulated answer, + a message_replace event is sent. + """ + # Arrange: Set initial accumulated answer + pipeline._task_state.answer = "initial answer" + + # Create ANSWER node succeeded event with different final output + event = QueueNodeSucceededEvent( + node_execution_id="test-node-execution-id", + node_id="test-answer-node", + node_type=NodeType.ANSWER, + start_at=datetime.now(), + outputs={"answer": "updated final answer"}, + ) + + # Mock the workflow response converter to avoid extra processing + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None) + pipeline._save_output_for_event = Mock() + + # Act + responses = list(pipeline._handle_node_succeeded_event(event)) + + # Assert + assert pipeline._task_state.answer == "updated final answer" + # Verify message_replace was called + pipeline._message_cycle_manager.message_replace_to_stream_response.assert_called_once_with( + answer="updated final answer", reason="variable_update" + ) + + def test_answer_node_with_same_output_does_not_send_message_replace(self, pipeline): + """ + Test that when an ANSWER node's final output is the same as accumulated answer, + no message_replace event is sent. + """ + # Arrange: Set initial accumulated answer + pipeline._task_state.answer = "same answer" + + # Create ANSWER node succeeded event with same output + event = QueueNodeSucceededEvent( + node_execution_id="test-node-execution-id", + node_id="test-answer-node", + node_type=NodeType.ANSWER, + start_at=datetime.now(), + outputs={"answer": "same answer"}, + ) + + # Mock the workflow response converter + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None) + pipeline._save_output_for_event = Mock() + + # Act + list(pipeline._handle_node_succeeded_event(event)) + + # Assert: answer should remain unchanged + assert pipeline._task_state.answer == "same answer" + # Verify message_replace was NOT called + pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called() + + def test_answer_node_with_none_output_does_not_send_message_replace(self, pipeline): + """ + Test that when an ANSWER node's output is None or missing 'answer' key, + no message_replace event is sent. + """ + # Arrange: Set initial accumulated answer + pipeline._task_state.answer = "existing answer" + + # Create ANSWER node succeeded event with None output + event = QueueNodeSucceededEvent( + node_execution_id="test-node-execution-id", + node_id="test-answer-node", + node_type=NodeType.ANSWER, + start_at=datetime.now(), + outputs={"answer": None}, + ) + + # Mock the workflow response converter + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None) + pipeline._save_output_for_event = Mock() + + # Act + list(pipeline._handle_node_succeeded_event(event)) + + # Assert: answer should remain unchanged + assert pipeline._task_state.answer == "existing answer" + # Verify message_replace was NOT called + pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called() + + def test_answer_node_with_empty_outputs_does_not_send_message_replace(self, pipeline): + """ + Test that when an ANSWER node has empty outputs dict, + no message_replace event is sent. + """ + # Arrange: Set initial accumulated answer + pipeline._task_state.answer = "existing answer" + + # Create ANSWER node succeeded event with empty outputs + event = QueueNodeSucceededEvent( + node_execution_id="test-node-execution-id", + node_id="test-answer-node", + node_type=NodeType.ANSWER, + start_at=datetime.now(), + outputs={}, + ) + + # Mock the workflow response converter + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None) + pipeline._save_output_for_event = Mock() + + # Act + list(pipeline._handle_node_succeeded_event(event)) + + # Assert: answer should remain unchanged + assert pipeline._task_state.answer == "existing answer" + # Verify message_replace was NOT called + pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called() + + def test_answer_node_with_no_answer_key_in_outputs(self, pipeline): + """ + Test that when an ANSWER node's outputs don't contain 'answer' key, + no message_replace event is sent. + """ + # Arrange: Set initial accumulated answer + pipeline._task_state.answer = "existing answer" + + # Create ANSWER node succeeded event without 'answer' key in outputs + event = QueueNodeSucceededEvent( + node_execution_id="test-node-execution-id", + node_id="test-answer-node", + node_type=NodeType.ANSWER, + start_at=datetime.now(), + outputs={"other_key": "some value"}, + ) + + # Mock the workflow response converter + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None) + pipeline._save_output_for_event = Mock() + + # Act + list(pipeline._handle_node_succeeded_event(event)) + + # Assert: answer should remain unchanged + assert pipeline._task_state.answer == "existing answer" + # Verify message_replace was NOT called + pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called() + + def test_non_answer_node_does_not_send_message_replace(self, pipeline): + """ + Test that non-ANSWER nodes (e.g., LLM, END) don't trigger message_replace events. + """ + # Arrange: Set initial accumulated answer + pipeline._task_state.answer = "existing answer" + + # Test with LLM node + llm_event = QueueNodeSucceededEvent( + node_execution_id="test-llm-execution-id", + node_id="test-llm-node", + node_type=NodeType.LLM, + start_at=datetime.now(), + outputs={"answer": "different answer"}, + ) + + # Mock the workflow response converter + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None) + pipeline._save_output_for_event = Mock() + + # Act + list(pipeline._handle_node_succeeded_event(llm_event)) + + # Assert: answer should remain unchanged + assert pipeline._task_state.answer == "existing answer" + # Verify message_replace was NOT called + pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called() + + def test_end_node_does_not_send_message_replace(self, pipeline): + """ + Test that END nodes don't trigger message_replace events even with 'answer' output. + """ + # Arrange: Set initial accumulated answer + pipeline._task_state.answer = "existing answer" + + # Create END node succeeded event with answer output + event = QueueNodeSucceededEvent( + node_execution_id="test-end-execution-id", + node_id="test-end-node", + node_type=NodeType.END, + start_at=datetime.now(), + outputs={"answer": "different answer"}, + ) + + # Mock the workflow response converter + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None) + pipeline._save_output_for_event = Mock() + + # Act + list(pipeline._handle_node_succeeded_event(event)) + + # Assert: answer should remain unchanged + assert pipeline._task_state.answer == "existing answer" + # Verify message_replace was NOT called + pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called() + + def test_answer_node_with_numeric_output_converts_to_string(self, pipeline): + """ + Test that when an ANSWER node's final output is numeric, + it gets converted to string properly. + """ + # Arrange: Set initial accumulated answer + pipeline._task_state.answer = "text answer" + + # Create ANSWER node succeeded event with numeric output + event = QueueNodeSucceededEvent( + node_execution_id="test-node-execution-id", + node_id="test-answer-node", + node_type=NodeType.ANSWER, + start_at=datetime.now(), + outputs={"answer": 12345}, + ) + + # Mock the workflow response converter + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None) + pipeline._save_output_for_event = Mock() + + # Act + list(pipeline._handle_node_succeeded_event(event)) + + # Assert: answer should be converted to string + assert pipeline._task_state.answer == "12345" + # Verify message_replace was called with string + pipeline._message_cycle_manager.message_replace_to_stream_response.assert_called_once_with( + answer="12345", reason="variable_update" + ) + + def test_answer_node_files_are_recorded(self, pipeline): + """ + Test that ANSWER nodes properly record files from outputs. + """ + # Arrange + pipeline._task_state.answer = "existing answer" + + # Create ANSWER node succeeded event with files + event = QueueNodeSucceededEvent( + node_execution_id="test-node-execution-id", + node_id="test-answer-node", + node_type=NodeType.ANSWER, + start_at=datetime.now(), + outputs={ + "answer": "same answer", + "files": [ + {"type": "image", "transfer_method": "remote_url", "remote_url": "http://example.com/img.png"} + ], + }, + ) + + # Mock the workflow response converter + pipeline._workflow_response_converter.fetch_files_from_node_outputs = Mock(return_value=event.outputs["files"]) + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None) + pipeline._save_output_for_event = Mock() + + # Act + list(pipeline._handle_node_succeeded_event(event)) + + # Assert: files should be recorded + assert len(pipeline._recorded_files) == 1 + assert pipeline._recorded_files[0] == event.outputs["files"][0] diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py index 1c9f577a50..6b40bf462b 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py @@ -431,10 +431,10 @@ class TestWorkflowResponseConverterServiceApiTruncation: description="Explore calls should have truncation enabled", ), TestCase( - name="published_truncation_enabled", - invoke_from=InvokeFrom.PUBLISHED, + name="published_pipeline_truncation_enabled", + invoke_from=InvokeFrom.PUBLISHED_PIPELINE, expected_truncation_enabled=True, - description="Published app calls should have truncation enabled", + description="Published pipeline calls should have truncation enabled", ), ], ids=lambda x: x.name, diff --git a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py new file mode 100644 index 0000000000..b6e8cc9c8e --- /dev/null +++ b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py @@ -0,0 +1,144 @@ +from collections.abc import Sequence +from datetime import datetime +from unittest.mock import Mock + +from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer +from core.variables import StringVariable +from core.variables.segments import Segment +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID +from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus +from core.workflow.graph_engine.protocols.command_channel import CommandChannel +from core.workflow.graph_events.node import NodeRunSucceededEvent +from core.workflow.node_events import NodeRunResult +from core.workflow.nodes.variable_assigner.common import helpers as common_helpers +from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState +from core.workflow.system_variable import SystemVariable + + +class MockReadOnlyVariablePool: + def __init__(self, variables: dict[tuple[str, str], Segment] | None = None) -> None: + self._variables = variables or {} + + def get(self, selector: Sequence[str]) -> Segment | None: + if len(selector) < 2: + return None + return self._variables.get((selector[0], selector[1])) + + def get_all_by_node(self, node_id: str) -> dict[str, object]: + return {key: value for (nid, key), value in self._variables.items() if nid == node_id} + + def get_by_prefix(self, prefix: str) -> dict[str, object]: + return {key: value for (nid, key), value in self._variables.items() if nid == prefix} + + +def _build_graph_runtime_state( + variable_pool: MockReadOnlyVariablePool, + conversation_id: str | None = None, +) -> ReadOnlyGraphRuntimeState: + graph_runtime_state = Mock(spec=ReadOnlyGraphRuntimeState) + graph_runtime_state.variable_pool = variable_pool + graph_runtime_state.system_variable = SystemVariable(conversation_id=conversation_id).as_view() + return graph_runtime_state + + +def _build_node_run_succeeded_event( + *, + node_type: NodeType, + outputs: dict[str, object] | None = None, + process_data: dict[str, object] | None = None, +) -> NodeRunSucceededEvent: + return NodeRunSucceededEvent( + id="node-exec-id", + node_id="assigner", + node_type=node_type, + start_at=datetime.utcnow(), + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs=outputs or {}, + process_data=process_data or {}, + ), + ) + + +def test_persists_conversation_variables_from_assigner_output(): + conversation_id = "conv-123" + variable = StringVariable( + id="var-1", + name="name", + value="updated", + selector=[CONVERSATION_VARIABLE_NODE_ID, "name"], + ) + process_data = common_helpers.set_updated_variables( + {}, [common_helpers.variable_to_processed_data(variable.selector, variable)] + ) + + variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable}) + + updater = Mock() + layer = ConversationVariablePersistenceLayer(updater) + layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) + + event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER, process_data=process_data) + layer.on_event(event) + + updater.update.assert_called_once_with(conversation_id=conversation_id, variable=variable) + updater.flush.assert_called_once() + + +def test_skips_when_outputs_missing(): + conversation_id = "conv-456" + variable = StringVariable( + id="var-2", + name="name", + value="updated", + selector=[CONVERSATION_VARIABLE_NODE_ID, "name"], + ) + + variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable}) + + updater = Mock() + layer = ConversationVariablePersistenceLayer(updater) + layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) + + event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER) + layer.on_event(event) + + updater.update.assert_not_called() + updater.flush.assert_not_called() + + +def test_skips_non_assigner_nodes(): + updater = Mock() + layer = ConversationVariablePersistenceLayer(updater) + layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool()), Mock(spec=CommandChannel)) + + event = _build_node_run_succeeded_event(node_type=NodeType.LLM) + layer.on_event(event) + + updater.update.assert_not_called() + updater.flush.assert_not_called() + + +def test_skips_non_conversation_variables(): + conversation_id = "conv-789" + non_conversation_variable = StringVariable( + id="var-3", + name="name", + value="updated", + selector=["environment", "name"], + ) + process_data = common_helpers.set_updated_variables( + {}, [common_helpers.variable_to_processed_data(non_conversation_variable.selector, non_conversation_variable)] + ) + + variable_pool = MockReadOnlyVariablePool() + + updater = Mock() + layer = ConversationVariablePersistenceLayer(updater) + layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) + + event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER, process_data=process_data) + layer.on_event(event) + + updater.update.assert_not_called() + updater.flush.assert_not_called() diff --git a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py index 534420f21e..1d885f6b2e 100644 --- a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py @@ -1,4 +1,5 @@ import json +from collections.abc import Sequence from time import time from unittest.mock import Mock @@ -15,6 +16,7 @@ from core.app.layers.pause_state_persist_layer import ( from core.variables.segments import Segment from core.workflow.entities.pause_reason import SchedulingPause from core.workflow.graph_engine.entities.commands import GraphEngineCommand +from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError from core.workflow.graph_events.graph import ( GraphRunFailedEvent, GraphRunPausedEvent, @@ -66,8 +68,10 @@ class MockReadOnlyVariablePool: def __init__(self, variables: dict[tuple[str, str], object] | None = None): self._variables = variables or {} - def get(self, node_id: str, variable_key: str) -> Segment | None: - value = self._variables.get((node_id, variable_key)) + def get(self, selector: Sequence[str]) -> Segment | None: + if len(selector) < 2: + return None + value = self._variables.get((selector[0], selector[1])) if value is None: return None mock_segment = Mock(spec=Segment) @@ -209,8 +213,9 @@ class TestPauseStatePersistenceLayer: assert layer._session_maker is session_factory assert layer._state_owner_user_id == state_owner_user_id - assert not hasattr(layer, "graph_runtime_state") - assert not hasattr(layer, "command_channel") + with pytest.raises(GraphEngineLayerNotInitializedError): + _ = layer.graph_runtime_state + assert layer.command_channel is None def test_initialize_sets_dependencies(self): session_factory = Mock(name="session_factory") @@ -295,7 +300,7 @@ class TestPauseStatePersistenceLayer: mock_factory.assert_not_called() mock_repo.create_workflow_pause.assert_not_called() - def test_on_event_raises_attribute_error_when_graph_runtime_state_is_none(self): + def test_on_event_raises_when_graph_runtime_state_is_uninitialized(self): session_factory = Mock(name="session_factory") layer = PauseStatePersistenceLayer( session_factory=session_factory, @@ -305,7 +310,7 @@ class TestPauseStatePersistenceLayer: event = TestDataFactory.create_graph_run_paused_event() - with pytest.raises(AttributeError): + with pytest.raises(GraphEngineLayerNotInitializedError): layer.on_event(event) def test_on_event_asserts_when_workflow_execution_id_missing(self, monkeypatch: pytest.MonkeyPatch): diff --git a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py index 3203aab8c3..f9e59a5f05 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py @@ -1,8 +1,12 @@ """Primarily used for testing merged cell scenarios""" +import os +import tempfile from types import SimpleNamespace from docx import Document +from docx.oxml import OxmlElement +from docx.oxml.ns import qn import core.rag.extractor.word_extractor as we from core.rag.extractor.word_extractor import WordExtractor @@ -165,3 +169,110 @@ def test_extract_images_from_docx_uses_internal_files_url(): dify_config.FILES_URL = original_files_url if original_internal_files_url is not None: dify_config.INTERNAL_FILES_URL = original_internal_files_url + + +def test_extract_hyperlinks(monkeypatch): + # Mock db and storage to avoid issues during image extraction (even if no images are present) + monkeypatch.setattr(we, "storage", SimpleNamespace(save=lambda k, d: None)) + db_stub = SimpleNamespace(session=SimpleNamespace(add=lambda o: None, commit=lambda: None)) + monkeypatch.setattr(we, "db", db_stub) + monkeypatch.setattr(we.dify_config, "FILES_URL", "http://files.local", raising=False) + monkeypatch.setattr(we.dify_config, "STORAGE_TYPE", "local", raising=False) + + doc = Document() + p = doc.add_paragraph("Visit ") + + # Adding a hyperlink manually + r_id = "rId99" + hyperlink = OxmlElement("w:hyperlink") + hyperlink.set(qn("r:id"), r_id) + + new_run = OxmlElement("w:r") + t = OxmlElement("w:t") + t.text = "Dify" + new_run.append(t) + hyperlink.append(new_run) + p._p.append(hyperlink) + + # Add relationship to the part + doc.part.rels.add_relationship( + "http://schemas.openxmlformats.org/officeDocument/2006/relationships/hyperlink", + "https://dify.ai", + r_id, + is_external=True, + ) + + with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp: + doc.save(tmp.name) + tmp_path = tmp.name + + try: + extractor = WordExtractor(tmp_path, "tenant_id", "user_id") + docs = extractor.extract() + # Verify modern hyperlink extraction + assert "Visit[Dify](https://dify.ai)" in docs[0].page_content + finally: + if os.path.exists(tmp_path): + os.remove(tmp_path) + + +def test_extract_legacy_hyperlinks(monkeypatch): + # Mock db and storage + monkeypatch.setattr(we, "storage", SimpleNamespace(save=lambda k, d: None)) + db_stub = SimpleNamespace(session=SimpleNamespace(add=lambda o: None, commit=lambda: None)) + monkeypatch.setattr(we, "db", db_stub) + monkeypatch.setattr(we.dify_config, "FILES_URL", "http://files.local", raising=False) + monkeypatch.setattr(we.dify_config, "STORAGE_TYPE", "local", raising=False) + + doc = Document() + p = doc.add_paragraph() + + # Construct a legacy HYPERLINK field: + # 1. w:fldChar (begin) + # 2. w:instrText (HYPERLINK "http://example.com") + # 3. w:fldChar (separate) + # 4. w:r (visible text "Example") + # 5. w:fldChar (end) + + run1 = OxmlElement("w:r") + fldCharBegin = OxmlElement("w:fldChar") + fldCharBegin.set(qn("w:fldCharType"), "begin") + run1.append(fldCharBegin) + p._p.append(run1) + + run2 = OxmlElement("w:r") + instrText = OxmlElement("w:instrText") + instrText.text = ' HYPERLINK "http://example.com" ' + run2.append(instrText) + p._p.append(run2) + + run3 = OxmlElement("w:r") + fldCharSep = OxmlElement("w:fldChar") + fldCharSep.set(qn("w:fldCharType"), "separate") + run3.append(fldCharSep) + p._p.append(run3) + + run4 = OxmlElement("w:r") + t4 = OxmlElement("w:t") + t4.text = "Example" + run4.append(t4) + p._p.append(run4) + + run5 = OxmlElement("w:r") + fldCharEnd = OxmlElement("w:fldChar") + fldCharEnd.set(qn("w:fldCharType"), "end") + run5.append(fldCharEnd) + p._p.append(run5) + + with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp: + doc.save(tmp.name) + tmp_path = tmp.name + + try: + extractor = WordExtractor(tmp_path, "tenant_id", "user_id") + docs = extractor.extract() + # Verify legacy hyperlink extraction + assert "[Example](http://example.com)" in docs[0].page_content + finally: + if os.path.exists(tmp_path): + os.remove(tmp_path) diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index 6306d665e7..ca08cb0591 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -73,6 +73,7 @@ import pytest from core.rag.datasource.retrieval_service import RetrievalService from core.rag.models.document import Document +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from models.dataset import Dataset @@ -1518,6 +1519,282 @@ class TestRetrievalService: call_kwargs = mock_retrieve.call_args.kwargs assert call_kwargs["reranking_model"] == reranking_model + # ==================== Multiple Retrieve Thread Tests ==================== + + @patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor") + @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval._retriever") + def test_multiple_retrieve_thread_skips_second_reranking_with_single_dataset( + self, mock_retriever, mock_data_processor_class, mock_flask_app, mock_dataset + ): + """ + Test that _multiple_retrieve_thread skips second reranking when dataset_count is 1. + + When there is only one dataset, the second reranking is unnecessary + because the documents are already ranked from the first retrieval. + This optimization avoids the overhead of reranking when it won't + provide any benefit. + + Verifies: + - DataPostProcessor is NOT called when dataset_count == 1 + - Documents are still added to all_documents + - Standard scoring logic is applied instead + """ + # Arrange + dataset_retrieval = DatasetRetrieval() + tenant_id = str(uuid4()) + + # Create test documents + doc1 = Document( + page_content="Test content 1", + metadata={"doc_id": "doc1", "score": 0.9, "document_id": str(uuid4()), "dataset_id": mock_dataset.id}, + provider="dify", + ) + doc2 = Document( + page_content="Test content 2", + metadata={"doc_id": "doc2", "score": 0.8, "document_id": str(uuid4()), "dataset_id": mock_dataset.id}, + provider="dify", + ) + + # Mock _retriever to return documents + def side_effect_retriever( + flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids + ): + all_documents.extend([doc1, doc2]) + + mock_retriever.side_effect = side_effect_retriever + + # Set up dataset with high_quality indexing + mock_dataset.indexing_technique = "high_quality" + + all_documents = [] + + # Act - Call with dataset_count = 1 + dataset_retrieval._multiple_retrieve_thread( + flask_app=mock_flask_app, + available_datasets=[mock_dataset], + metadata_condition=None, + metadata_filter_document_ids=None, + all_documents=all_documents, + tenant_id=tenant_id, + reranking_enable=True, + reranking_mode="reranking_model", + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"}, + weights=None, + top_k=5, + score_threshold=0.5, + query="test query", + attachment_id=None, + dataset_count=1, # Single dataset - should skip second reranking + ) + + # Assert + # DataPostProcessor should NOT be called (second reranking skipped) + mock_data_processor_class.assert_not_called() + + # Documents should still be added to all_documents + assert len(all_documents) == 2 + assert all_documents[0].page_content == "Test content 1" + assert all_documents[1].page_content == "Test content 2" + + @patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor") + @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval._retriever") + @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.calculate_vector_score") + def test_multiple_retrieve_thread_performs_second_reranking_with_multiple_datasets( + self, mock_calculate_vector_score, mock_retriever, mock_data_processor_class, mock_flask_app, mock_dataset + ): + """ + Test that _multiple_retrieve_thread performs second reranking when dataset_count > 1. + + When there are multiple datasets, the second reranking is necessary + to merge and re-rank results from different datasets. This ensures + the most relevant documents across all datasets are returned. + + Verifies: + - DataPostProcessor IS called when dataset_count > 1 + - Reranking is applied with correct parameters + - Documents are processed correctly + """ + # Arrange + dataset_retrieval = DatasetRetrieval() + tenant_id = str(uuid4()) + + # Create test documents + doc1 = Document( + page_content="Test content 1", + metadata={"doc_id": "doc1", "score": 0.7, "document_id": str(uuid4()), "dataset_id": mock_dataset.id}, + provider="dify", + ) + doc2 = Document( + page_content="Test content 2", + metadata={"doc_id": "doc2", "score": 0.6, "document_id": str(uuid4()), "dataset_id": mock_dataset.id}, + provider="dify", + ) + + # Mock _retriever to return documents + def side_effect_retriever( + flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids + ): + all_documents.extend([doc1, doc2]) + + mock_retriever.side_effect = side_effect_retriever + + # Set up dataset with high_quality indexing + mock_dataset.indexing_technique = "high_quality" + + # Mock DataPostProcessor instance and its invoke method + mock_processor_instance = Mock() + # Simulate reranking - return documents in reversed order with updated scores + reranked_docs = [ + Document( + page_content="Test content 2", + metadata={"doc_id": "doc2", "score": 0.95, "document_id": str(uuid4()), "dataset_id": mock_dataset.id}, + provider="dify", + ), + Document( + page_content="Test content 1", + metadata={"doc_id": "doc1", "score": 0.85, "document_id": str(uuid4()), "dataset_id": mock_dataset.id}, + provider="dify", + ), + ] + mock_processor_instance.invoke.return_value = reranked_docs + mock_data_processor_class.return_value = mock_processor_instance + + all_documents = [] + + # Create second dataset + mock_dataset2 = Mock(spec=Dataset) + mock_dataset2.id = str(uuid4()) + mock_dataset2.indexing_technique = "high_quality" + mock_dataset2.provider = "dify" + + # Act - Call with dataset_count = 2 + dataset_retrieval._multiple_retrieve_thread( + flask_app=mock_flask_app, + available_datasets=[mock_dataset, mock_dataset2], + metadata_condition=None, + metadata_filter_document_ids=None, + all_documents=all_documents, + tenant_id=tenant_id, + reranking_enable=True, + reranking_mode="reranking_model", + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"}, + weights=None, + top_k=5, + score_threshold=0.5, + query="test query", + attachment_id=None, + dataset_count=2, # Multiple datasets - should perform second reranking + ) + + # Assert + # DataPostProcessor SHOULD be called (second reranking performed) + mock_data_processor_class.assert_called_once_with( + tenant_id, + "reranking_model", + {"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"}, + None, + False, + ) + + # Verify invoke was called with correct parameters + mock_processor_instance.invoke.assert_called_once() + + # Documents should be added to all_documents after reranking + assert len(all_documents) == 2 + # The reranked order should be reflected + assert all_documents[0].page_content == "Test content 2" + assert all_documents[1].page_content == "Test content 1" + + @patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor") + @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval._retriever") + @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.calculate_vector_score") + def test_multiple_retrieve_thread_single_dataset_uses_standard_scoring( + self, mock_calculate_vector_score, mock_retriever, mock_data_processor_class, mock_flask_app, mock_dataset + ): + """ + Test that _multiple_retrieve_thread uses standard scoring when dataset_count is 1 + and reranking is enabled. + + When there's only one dataset, instead of using DataPostProcessor, + the method should fall through to the standard scoring logic + (calculate_vector_score for high_quality datasets). + + Verifies: + - DataPostProcessor is NOT called + - calculate_vector_score IS called for high_quality indexing + - Documents are scored correctly + """ + # Arrange + dataset_retrieval = DatasetRetrieval() + tenant_id = str(uuid4()) + + # Create test documents + doc1 = Document( + page_content="Test content 1", + metadata={"doc_id": "doc1", "score": 0.9, "document_id": str(uuid4()), "dataset_id": mock_dataset.id}, + provider="dify", + ) + doc2 = Document( + page_content="Test content 2", + metadata={"doc_id": "doc2", "score": 0.8, "document_id": str(uuid4()), "dataset_id": mock_dataset.id}, + provider="dify", + ) + + # Mock _retriever to return documents + def side_effect_retriever( + flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids + ): + all_documents.extend([doc1, doc2]) + + mock_retriever.side_effect = side_effect_retriever + + # Set up dataset with high_quality indexing + mock_dataset.indexing_technique = "high_quality" + + # Mock calculate_vector_score to return scored documents + scored_docs = [ + Document( + page_content="Test content 1", + metadata={"doc_id": "doc1", "score": 0.95, "document_id": str(uuid4()), "dataset_id": mock_dataset.id}, + provider="dify", + ), + ] + mock_calculate_vector_score.return_value = scored_docs + + all_documents = [] + + # Act - Call with dataset_count = 1 + dataset_retrieval._multiple_retrieve_thread( + flask_app=mock_flask_app, + available_datasets=[mock_dataset], + metadata_condition=None, + metadata_filter_document_ids=None, + all_documents=all_documents, + tenant_id=tenant_id, + reranking_enable=True, # Reranking enabled but should be skipped for single dataset + reranking_mode="reranking_model", + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"}, + weights=None, + top_k=5, + score_threshold=0.5, + query="test query", + attachment_id=None, + dataset_count=1, + ) + + # Assert + # DataPostProcessor should NOT be called + mock_data_processor_class.assert_not_called() + + # calculate_vector_score SHOULD be called for high_quality datasets + mock_calculate_vector_score.assert_called_once() + call_args = mock_calculate_vector_score.call_args + assert call_args[0][1] == 5 # top_k + + # Documents should be added after standard scoring + assert len(all_documents) == 1 + assert all_documents[0].page_content == "Test content 1" + class TestRetrievalMethods: """ diff --git a/api/tests/unit_tests/core/rag/retrieval/test_knowledge_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_knowledge_retrieval.py new file mode 100644 index 0000000000..5f461d53ae --- /dev/null +++ b/api/tests/unit_tests/core/rag/retrieval/test_knowledge_retrieval.py @@ -0,0 +1,113 @@ +import threading +from unittest.mock import Mock, patch +from uuid import uuid4 + +import pytest +from flask import Flask, current_app + +from core.rag.models.document import Document +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval +from models.dataset import Dataset + + +class TestRetrievalService: + @pytest.fixture + def mock_dataset(self) -> Dataset: + dataset = Mock(spec=Dataset) + dataset.id = str(uuid4()) + dataset.tenant_id = str(uuid4()) + dataset.name = "test_dataset" + dataset.indexing_technique = "high_quality" + dataset.provider = "dify" + return dataset + + def test_multiple_retrieve_reranking_with_app_context(self, mock_dataset): + """ + Repro test for current bug: + reranking runs after `with flask_app.app_context():` exits. + `_multiple_retrieve_thread` catches exceptions and stores them into `thread_exceptions`, + so we must assert from that list (not from an outer try/except). + """ + dataset_retrieval = DatasetRetrieval() + flask_app = Flask(__name__) + tenant_id = str(uuid4()) + + # second dataset to ensure dataset_count > 1 reranking branch + secondary_dataset = Mock(spec=Dataset) + secondary_dataset.id = str(uuid4()) + secondary_dataset.provider = "dify" + secondary_dataset.indexing_technique = "high_quality" + + # retriever returns 1 doc into internal list (all_documents_item) + document = Document( + page_content="Context aware doc", + metadata={ + "doc_id": "doc1", + "score": 0.95, + "document_id": str(uuid4()), + "dataset_id": mock_dataset.id, + }, + provider="dify", + ) + + def fake_retriever( + flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids + ): + all_documents.append(document) + + called = {"init": 0, "invoke": 0} + + class ContextRequiredPostProcessor: + def __init__(self, *args, **kwargs): + called["init"] += 1 + # will raise RuntimeError if no Flask app context exists + _ = current_app.name + + def invoke(self, *args, **kwargs): + called["invoke"] += 1 + _ = current_app.name + return kwargs.get("documents") or args[1] + + # output list from _multiple_retrieve_thread + all_documents: list[Document] = [] + + # IMPORTANT: _multiple_retrieve_thread swallows exceptions and appends them here + thread_exceptions: list[Exception] = [] + + def target(): + with patch.object(dataset_retrieval, "_retriever", side_effect=fake_retriever): + with patch( + "core.rag.retrieval.dataset_retrieval.DataPostProcessor", + ContextRequiredPostProcessor, + ): + dataset_retrieval._multiple_retrieve_thread( + flask_app=flask_app, + available_datasets=[mock_dataset, secondary_dataset], + metadata_condition=None, + metadata_filter_document_ids=None, + all_documents=all_documents, + tenant_id=tenant_id, + reranking_enable=True, + reranking_mode="reranking_model", + reranking_model={ + "reranking_provider_name": "cohere", + "reranking_model_name": "rerank-v2", + }, + weights=None, + top_k=3, + score_threshold=0.0, + query="test query", + attachment_id=None, + dataset_count=2, # force reranking branch + thread_exceptions=thread_exceptions, # ✅ key + ) + + t = threading.Thread(target=target) + t.start() + t.join() + + # Ensure reranking branch was actually executed + assert called["init"] >= 1, "DataPostProcessor was never constructed; reranking branch may not have run." + + # Current buggy code should record an exception (not raise it) + assert not thread_exceptions, thread_exceptions diff --git a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py b/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py index 8677325d4e..f33fd0deeb 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py @@ -3,8 +3,15 @@ import json from unittest.mock import MagicMock +from core.variables import IntegerVariable, StringVariable from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel -from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, GraphEngineCommand +from core.workflow.graph_engine.entities.commands import ( + AbortCommand, + CommandType, + GraphEngineCommand, + UpdateVariablesCommand, + VariableUpdate, +) class TestRedisChannel: @@ -148,6 +155,43 @@ class TestRedisChannel: assert commands[0].command_type == CommandType.ABORT assert isinstance(commands[1], AbortCommand) + def test_fetch_commands_with_update_variables_command(self): + """Test fetching update variables command from Redis.""" + mock_redis = MagicMock() + pending_pipe = MagicMock() + fetch_pipe = MagicMock() + pending_context = MagicMock() + fetch_context = MagicMock() + pending_context.__enter__.return_value = pending_pipe + pending_context.__exit__.return_value = None + fetch_context.__enter__.return_value = fetch_pipe + fetch_context.__exit__.return_value = None + mock_redis.pipeline.side_effect = [pending_context, fetch_context] + + update_command = UpdateVariablesCommand( + updates=[ + VariableUpdate( + value=StringVariable(name="foo", value="bar", selector=["node1", "foo"]), + ), + VariableUpdate( + value=IntegerVariable(name="baz", value=123, selector=["node2", "baz"]), + ), + ] + ) + command_json = json.dumps(update_command.model_dump()) + + pending_pipe.execute.return_value = [b"1", 1] + fetch_pipe.execute.return_value = [[command_json.encode()], 1] + + channel = RedisChannel(mock_redis, "test:key") + commands = channel.fetch_commands() + + assert len(commands) == 1 + assert isinstance(commands[0], UpdateVariablesCommand) + assert isinstance(commands[0].updates[0].value, StringVariable) + assert list(commands[0].updates[0].value.selector) == ["node1", "foo"] + assert commands[0].updates[0].value.value == "bar" + def test_fetch_commands_skips_invalid_json(self): """Test that invalid JSON commands are skipped.""" mock_redis = MagicMock() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py new file mode 100644 index 0000000000..d6ba61c50c --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +import pytest + +from core.workflow.graph_engine import GraphEngine +from core.workflow.graph_engine.command_channels import InMemoryChannel +from core.workflow.graph_engine.layers.base import ( + GraphEngineLayer, + GraphEngineLayerNotInitializedError, +) +from core.workflow.graph_events import GraphEngineEvent + +from ..test_table_runner import WorkflowRunner + + +class LayerForTest(GraphEngineLayer): + def on_graph_start(self) -> None: + pass + + def on_event(self, event: GraphEngineEvent) -> None: + pass + + def on_graph_end(self, error: Exception | None) -> None: + pass + + +def test_layer_runtime_state_raises_when_uninitialized() -> None: + layer = LayerForTest() + + with pytest.raises(GraphEngineLayerNotInitializedError): + _ = layer.graph_runtime_state + + +def test_layer_runtime_state_available_after_engine_layer() -> None: + runner = WorkflowRunner() + fixture_data = runner.load_fixture("simple_passthrough_workflow") + graph, graph_runtime_state = runner.create_graph_from_fixture( + fixture_data, + inputs={"query": "test layer state"}, + ) + engine = GraphEngine( + workflow_id="test_workflow", + graph=graph, + graph_runtime_state=graph_runtime_state, + command_channel=InMemoryChannel(), + ) + + layer = LayerForTest() + engine.layer(layer) + + outputs = layer.graph_runtime_state.outputs + ready_queue_size = layer.graph_runtime_state.ready_queue_size + + assert outputs == {} + assert isinstance(ready_queue_size, int) + assert ready_queue_size >= 0 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py b/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py index c1fc4acd73..fe3ea576c1 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py @@ -3,6 +3,7 @@ from __future__ import annotations import queue +import threading from unittest import mock from core.workflow.entities.pause_reason import SchedulingPause @@ -36,6 +37,7 @@ def test_dispatcher_should_consume_remains_events_after_pause(): event_queue=event_queue, event_handler=event_handler, execution_coordinator=execution_coordinator, + stop_event=threading.Event(), ) dispatcher._dispatcher_loop() assert event_queue.empty() @@ -96,6 +98,7 @@ def _run_dispatcher_for_event(event) -> int: event_queue=event_queue, event_handler=event_handler, execution_coordinator=coordinator, + stop_event=threading.Event(), ) dispatcher._dispatcher_loop() @@ -181,6 +184,7 @@ def test_dispatcher_drain_event_queue(): event_queue=event_queue, event_handler=event_handler, execution_coordinator=coordinator, + stop_event=threading.Event(), ) dispatcher._dispatcher_loop() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py index b074a11be9..d826f7a900 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py @@ -4,12 +4,19 @@ import time from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom +from core.variables import IntegerVariable, StringVariable from core.workflow.entities.graph_init_params import GraphInitParams from core.workflow.entities.pause_reason import SchedulingPause from core.workflow.graph import Graph from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine.command_channels import InMemoryChannel -from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand +from core.workflow.graph_engine.entities.commands import ( + AbortCommand, + CommandType, + PauseCommand, + UpdateVariablesCommand, + VariableUpdate, +) from core.workflow.graph_events import GraphRunAbortedEvent, GraphRunPausedEvent, GraphRunStartedEvent from core.workflow.nodes.start.start_node import StartNode from core.workflow.runtime import GraphRuntimeState, VariablePool @@ -180,3 +187,67 @@ def test_pause_command(): graph_execution = engine.graph_runtime_state.graph_execution assert graph_execution.pause_reasons == [SchedulingPause(message="User requested pause")] + + +def test_update_variables_command_updates_pool(): + """Test that GraphEngine updates variable pool via update variables command.""" + + shared_runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + shared_runtime_state.variable_pool.add(("node1", "foo"), "old value") + + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + mock_graph.root_node.id = "start" + + start_node = StartNode( + id="start", + config={"id": "start", "data": {"title": "start", "variables": []}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=shared_runtime_state, + ) + mock_graph.nodes["start"] = start_node + + mock_graph.get_outgoing_edges = MagicMock(return_value=[]) + mock_graph.get_incoming_edges = MagicMock(return_value=[]) + + command_channel = InMemoryChannel() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=shared_runtime_state, + command_channel=command_channel, + ) + + update_command = UpdateVariablesCommand( + updates=[ + VariableUpdate( + value=StringVariable(name="foo", value="new value", selector=["node1", "foo"]), + ), + VariableUpdate( + value=IntegerVariable(name="bar", value=123, selector=["node2", "bar"]), + ), + ] + ) + command_channel.send_command(update_command) + + list(engine.run()) + + updated_existing = shared_runtime_state.variable_pool.get(["node1", "foo"]) + added_new = shared_runtime_state.variable_pool.get(["node2", "bar"]) + + assert updated_existing is not None + assert updated_existing.value == "new value" + assert added_new is not None + assert added_new.value == 123 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py index b02f90588b..5ceb8dd7f7 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py @@ -5,6 +5,8 @@ This module provides a flexible configuration system for customizing the behavior of mock nodes during testing. """ +from __future__ import annotations + from collections.abc import Callable from dataclasses import dataclass, field from typing import Any @@ -95,67 +97,67 @@ class MockConfigBuilder: def __init__(self) -> None: self._config = MockConfig() - def with_auto_mock(self, enabled: bool = True) -> "MockConfigBuilder": + def with_auto_mock(self, enabled: bool = True) -> MockConfigBuilder: """Enable or disable auto-mocking.""" self._config.enable_auto_mock = enabled return self - def with_delays(self, enabled: bool = True) -> "MockConfigBuilder": + def with_delays(self, enabled: bool = True) -> MockConfigBuilder: """Enable or disable simulated execution delays.""" self._config.simulate_delays = enabled return self - def with_llm_response(self, response: str) -> "MockConfigBuilder": + def with_llm_response(self, response: str) -> MockConfigBuilder: """Set default LLM response.""" self._config.default_llm_response = response return self - def with_agent_response(self, response: str) -> "MockConfigBuilder": + def with_agent_response(self, response: str) -> MockConfigBuilder: """Set default agent response.""" self._config.default_agent_response = response return self - def with_tool_response(self, response: dict[str, Any]) -> "MockConfigBuilder": + def with_tool_response(self, response: dict[str, Any]) -> MockConfigBuilder: """Set default tool response.""" self._config.default_tool_response = response return self - def with_retrieval_response(self, response: str) -> "MockConfigBuilder": + def with_retrieval_response(self, response: str) -> MockConfigBuilder: """Set default retrieval response.""" self._config.default_retrieval_response = response return self - def with_http_response(self, response: dict[str, Any]) -> "MockConfigBuilder": + def with_http_response(self, response: dict[str, Any]) -> MockConfigBuilder: """Set default HTTP response.""" self._config.default_http_response = response return self - def with_template_transform_response(self, response: str) -> "MockConfigBuilder": + def with_template_transform_response(self, response: str) -> MockConfigBuilder: """Set default template transform response.""" self._config.default_template_transform_response = response return self - def with_code_response(self, response: dict[str, Any]) -> "MockConfigBuilder": + def with_code_response(self, response: dict[str, Any]) -> MockConfigBuilder: """Set default code execution response.""" self._config.default_code_response = response return self - def with_node_output(self, node_id: str, outputs: dict[str, Any]) -> "MockConfigBuilder": + def with_node_output(self, node_id: str, outputs: dict[str, Any]) -> MockConfigBuilder: """Set outputs for a specific node.""" self._config.set_node_outputs(node_id, outputs) return self - def with_node_error(self, node_id: str, error: str) -> "MockConfigBuilder": + def with_node_error(self, node_id: str, error: str) -> MockConfigBuilder: """Set error for a specific node.""" self._config.set_node_error(node_id, error) return self - def with_node_config(self, config: NodeMockConfig) -> "MockConfigBuilder": + def with_node_config(self, config: NodeMockConfig) -> MockConfigBuilder: """Add a node-specific configuration.""" self._config.set_node_config(config.node_id, config) return self - def with_default_config(self, node_type: NodeType, config: dict[str, Any]) -> "MockConfigBuilder": + def with_default_config(self, node_type: NodeType, config: dict[str, Any]) -> MockConfigBuilder: """Set default configuration for a node type.""" self._config.set_default_config(node_type, config) return self diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py index eeffdd27fe..6e9a432745 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py @@ -103,13 +103,25 @@ class MockNodeFactory(DifyNodeFactory): # Create mock node instance mock_class = self._mock_node_types[node_type] - mock_instance = mock_class( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - mock_config=self.mock_config, - ) + if node_type == NodeType.CODE: + mock_instance = mock_class( + id=node_id, + config=node_config, + graph_init_params=self.graph_init_params, + graph_runtime_state=self.graph_runtime_state, + mock_config=self.mock_config, + code_executor=self._code_executor, + code_providers=self._code_providers, + code_limits=self._code_limits, + ) + else: + mock_instance = mock_class( + id=node_id, + config=node_config, + graph_init_params=self.graph_init_params, + graph_runtime_state=self.graph_runtime_state, + mock_config=self.mock_config, + ) return mock_instance diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index fd94a5e833..5937bbfb39 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -40,12 +40,14 @@ class MockNodeMixin: graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", mock_config: Optional["MockConfig"] = None, + **kwargs: Any, ): super().__init__( id=id, config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, + **kwargs, ) self.mock_config = mock_config diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py index 4fb693a5c2..de08cc3497 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py @@ -5,11 +5,24 @@ This module tests the functionality of MockTemplateTransformNode and MockCodeNod to ensure they work correctly with the TableTestRunner. """ +from configs import dify_config from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus +from core.workflow.nodes.code.limits import CodeNodeLimits from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockCodeNode, MockTemplateTransformNode +DEFAULT_CODE_LIMITS = CodeNodeLimits( + max_string_length=dify_config.CODE_MAX_STRING_LENGTH, + max_number=dify_config.CODE_MAX_NUMBER, + min_number=dify_config.CODE_MIN_NUMBER, + max_precision=dify_config.CODE_MAX_PRECISION, + max_depth=dify_config.CODE_MAX_DEPTH, + max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH, + max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH, + max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH, +) + class TestMockTemplateTransformNode: """Test cases for MockTemplateTransformNode.""" @@ -306,6 +319,7 @@ class TestMockCodeNode: graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, mock_config=mock_config, + code_limits=DEFAULT_CODE_LIMITS, ) # Run the node @@ -370,6 +384,7 @@ class TestMockCodeNode: graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, mock_config=mock_config, + code_limits=DEFAULT_CODE_LIMITS, ) # Run the node @@ -438,6 +453,7 @@ class TestMockCodeNode: graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, mock_config=mock_config, + code_limits=DEFAULT_CODE_LIMITS, ) # Run the node diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_stop_event.py b/api/tests/unit_tests/core/workflow/graph_engine/test_stop_event.py new file mode 100644 index 0000000000..ea8d3a977f --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_stop_event.py @@ -0,0 +1,539 @@ +""" +Unit tests for stop_event functionality in GraphEngine. + +Tests the unified stop_event management by GraphEngine and its propagation +to WorkerPool, Worker, Dispatcher, and Nodes. +""" + +import threading +import time +from unittest.mock import MagicMock, Mock, patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.graph_init_params import GraphInitParams +from core.workflow.graph import Graph +from core.workflow.graph_engine import GraphEngine +from core.workflow.graph_engine.command_channels import InMemoryChannel +from core.workflow.graph_events import ( + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunStartedEvent, +) +from core.workflow.nodes.answer.answer_node import AnswerNode +from core.workflow.nodes.start.start_node import StartNode +from core.workflow.runtime import GraphRuntimeState, VariablePool +from models.enums import UserFrom + + +class TestStopEventPropagation: + """Test suite for stop_event propagation through GraphEngine components.""" + + def test_graph_engine_creates_stop_event(self): + """Test that GraphEngine creates a stop_event on initialization.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Verify stop_event was created + assert engine._stop_event is not None + assert isinstance(engine._stop_event, threading.Event) + + # Verify it was set in graph_runtime_state + assert runtime_state.stop_event is not None + assert runtime_state.stop_event is engine._stop_event + + def test_stop_event_cleared_on_start(self): + """Test that stop_event is cleared when execution starts.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + mock_graph.root_node.id = "start" # Set proper id + + start_node = StartNode( + id="start", + config={"id": "start", "data": {"title": "start", "variables": []}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + mock_graph.nodes["start"] = start_node + mock_graph.get_outgoing_edges = MagicMock(return_value=[]) + mock_graph.get_incoming_edges = MagicMock(return_value=[]) + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Set the stop_event before running + engine._stop_event.set() + assert engine._stop_event.is_set() + + # Run the engine (should clear the stop_event) + events = list(engine.run()) + + # After running, stop_event should be set again (by _stop_execution) + # But during start it was cleared + assert any(isinstance(e, GraphRunStartedEvent) for e in events) + assert any(isinstance(e, GraphRunSucceededEvent) for e in events) + + def test_stop_event_set_on_stop(self): + """Test that stop_event is set when execution stops.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + mock_graph.root_node.id = "start" # Set proper id + + start_node = StartNode( + id="start", + config={"id": "start", "data": {"title": "start", "variables": []}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + mock_graph.nodes["start"] = start_node + mock_graph.get_outgoing_edges = MagicMock(return_value=[]) + mock_graph.get_incoming_edges = MagicMock(return_value=[]) + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Initially not set + assert not engine._stop_event.is_set() + + # Run the engine + list(engine.run()) + + # After execution completes, stop_event should be set + assert engine._stop_event.is_set() + + def test_stop_event_passed_to_worker_pool(self): + """Test that stop_event is passed to WorkerPool.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Verify WorkerPool has the stop_event + assert engine._worker_pool._stop_event is not None + assert engine._worker_pool._stop_event is engine._stop_event + + def test_stop_event_passed_to_dispatcher(self): + """Test that stop_event is passed to Dispatcher.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Verify Dispatcher has the stop_event + assert engine._dispatcher._stop_event is not None + assert engine._dispatcher._stop_event is engine._stop_event + + +class TestNodeStopCheck: + """Test suite for Node._should_stop() functionality.""" + + def test_node_should_stop_checks_runtime_state(self): + """Test that Node._should_stop() checks GraphRuntimeState.stop_event.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + + answer_node = AnswerNode( + id="answer", + config={"id": "answer", "data": {"title": "answer", "answer": "{{#start.result#}}"}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + + # Initially stop_event is not set + assert not answer_node._should_stop() + + # Set the stop_event + runtime_state.stop_event.set() + + # Now _should_stop should return True + assert answer_node._should_stop() + + def test_node_run_checks_stop_event_between_yields(self): + """Test that Node.run() checks stop_event between yielding events.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + + # Create a simple node + answer_node = AnswerNode( + id="answer", + config={"id": "answer", "data": {"title": "answer", "answer": "hello"}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + + # Set stop_event BEFORE running the node + runtime_state.stop_event.set() + + # Run the node - should yield start event then detect stop + # The node should check stop_event before processing + assert answer_node._should_stop(), "stop_event should be set" + + # Run and collect events + events = list(answer_node.run()) + + # Since stop_event is set at the start, we should get: + # 1. NodeRunStartedEvent (always yielded first) + # 2. Either NodeRunFailedEvent (if detected early) or NodeRunSucceededEvent (if too fast) + assert len(events) >= 2 + assert isinstance(events[0], NodeRunStartedEvent) + + # Note: AnswerNode is very simple and might complete before stop check + # The important thing is that _should_stop() returns True when stop_event is set + assert answer_node._should_stop() + + +class TestStopEventIntegration: + """Integration tests for stop_event in workflow execution.""" + + def test_simple_workflow_respects_stop_event(self): + """Test that a simple workflow respects stop_event.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + mock_graph.root_node.id = "start" + + # Create start and answer nodes + start_node = StartNode( + id="start", + config={"id": "start", "data": {"title": "start", "variables": []}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + + answer_node = AnswerNode( + id="answer", + config={"id": "answer", "data": {"title": "answer", "answer": "hello"}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + + mock_graph.nodes["start"] = start_node + mock_graph.nodes["answer"] = answer_node + mock_graph.get_outgoing_edges = MagicMock(return_value=[]) + mock_graph.get_incoming_edges = MagicMock(return_value=[]) + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Set stop_event before running + runtime_state.stop_event.set() + + # Run the engine + events = list(engine.run()) + + # Should get started event but not succeeded (due to stop) + assert any(isinstance(e, GraphRunStartedEvent) for e in events) + # The workflow should still complete (start node runs quickly) + # but answer node might be cancelled depending on timing + + def test_stop_event_with_concurrent_nodes(self): + """Test stop_event behavior with multiple concurrent nodes.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + # Create multiple nodes + for i in range(3): + answer_node = AnswerNode( + id=f"answer_{i}", + config={"id": f"answer_{i}", "data": {"title": f"answer_{i}", "answer": f"test{i}"}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + mock_graph.nodes[f"answer_{i}"] = answer_node + + mock_graph.get_outgoing_edges = MagicMock(return_value=[]) + mock_graph.get_incoming_edges = MagicMock(return_value=[]) + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # All nodes should share the same stop_event + for node in mock_graph.nodes.values(): + assert node.graph_runtime_state.stop_event is runtime_state.stop_event + assert node.graph_runtime_state.stop_event is engine._stop_event + + +class TestStopEventTimeoutBehavior: + """Test stop_event behavior with join timeouts.""" + + @patch("core.workflow.graph_engine.orchestration.dispatcher.threading.Thread") + def test_dispatcher_uses_shorter_timeout(self, mock_thread_cls: MagicMock): + """Test that Dispatcher uses 2s timeout instead of 10s.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + dispatcher = engine._dispatcher + dispatcher.start() # This will create and start the mocked thread + + mock_thread_instance = mock_thread_cls.return_value + mock_thread_instance.is_alive.return_value = True + + dispatcher.stop() + + mock_thread_instance.join.assert_called_once_with(timeout=2.0) + + @patch("core.workflow.graph_engine.worker_management.worker_pool.Worker") + def test_worker_pool_uses_shorter_timeout(self, mock_worker_cls: MagicMock): + """Test that WorkerPool uses 2s timeout instead of 10s.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + worker_pool = engine._worker_pool + worker_pool.start(initial_count=1) # Start with one worker + + mock_worker_instance = mock_worker_cls.return_value + mock_worker_instance.is_alive.return_value = True + + worker_pool.stop() + + mock_worker_instance.join.assert_called_once_with(timeout=2.0) + + +class TestStopEventResumeBehavior: + """Test stop_event behavior during workflow resume.""" + + def test_stop_event_cleared_on_resume(self): + """Test that stop_event is cleared when resuming a paused workflow.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + mock_graph.root_node.id = "start" # Set proper id + + start_node = StartNode( + id="start", + config={"id": "start", "data": {"title": "start", "variables": []}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + mock_graph.nodes["start"] = start_node + mock_graph.get_outgoing_edges = MagicMock(return_value=[]) + mock_graph.get_incoming_edges = MagicMock(return_value=[]) + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Simulate a previous execution that set stop_event + engine._stop_event.set() + assert engine._stop_event.is_set() + + # Run the engine (should clear stop_event in _start_execution) + events = list(engine.run()) + + # Execution should complete successfully + assert any(isinstance(e, GraphRunStartedEvent) for e in events) + assert any(isinstance(e, GraphRunSucceededEvent) for e in events) + + +class TestWorkerStopBehavior: + """Test Worker behavior with shared stop_event.""" + + def test_worker_uses_shared_stop_event(self): + """Test that Worker uses shared stop_event from GraphEngine.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Get the worker pool and check workers + worker_pool = engine._worker_pool + + # Start the worker pool to create workers + worker_pool.start() + + # Check that at least one worker was created + assert len(worker_pool._workers) > 0 + + # Verify workers use the shared stop_event + for worker in worker_pool._workers: + assert worker._stop_event is engine._stop_event + + # Clean up + worker_pool.stop() + + def test_worker_stop_is_noop(self): + """Test that Worker.stop() is now a no-op.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + + # Create a mock worker + from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue + from core.workflow.graph_engine.worker import Worker + + ready_queue = InMemoryReadyQueue() + event_queue = MagicMock() + + # Create a proper mock graph with real dict + mock_graph = Mock(spec=Graph) + mock_graph.nodes = {} # Use real dict + + stop_event = threading.Event() + + worker = Worker( + ready_queue=ready_queue, + event_queue=event_queue, + graph=mock_graph, + layers=[], + stop_event=stop_event, + ) + + # Calling stop() should do nothing (no-op) + # and should NOT set the stop_event + worker.stop() + assert not stop_event.is_set() diff --git a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py index 596e72ddd0..2262d25a14 100644 --- a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py @@ -1,3 +1,4 @@ +from configs import dify_config from core.helper.code_executor.code_executor import CodeLanguage from core.variables.types import SegmentType from core.workflow.nodes.code.code_node import CodeNode @@ -7,6 +8,18 @@ from core.workflow.nodes.code.exc import ( DepthLimitError, OutputValidationError, ) +from core.workflow.nodes.code.limits import CodeNodeLimits + +CodeNode._limits = CodeNodeLimits( + max_string_length=dify_config.CODE_MAX_STRING_LENGTH, + max_number=dify_config.CODE_MAX_NUMBER, + min_number=dify_config.CODE_MIN_NUMBER, + max_precision=dify_config.CODE_MAX_PRECISION, + max_depth=dify_config.CODE_MAX_DEPTH, + max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH, + max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH, + max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH, +) class TestCodeNodeExceptions: diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py index e8f257bf2f..1e224d56a5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py @@ -78,7 +78,7 @@ class TestFileSaverImpl: file_binary=_PNG_DATA, mimetype=mime_type, ) - mocked_sign_file.assert_called_once_with(mock_tool_file.id, ".png") + mocked_sign_file.assert_called_once_with(tool_file_id=mock_tool_file.id, extension=".png", for_external=True) def test_save_remote_url_request_failed(self, monkeypatch: pytest.MonkeyPatch): _TEST_URL = "https://example.com/image.png" diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py index 1a67d5c3e3..66d6c3c56b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py @@ -5,8 +5,8 @@ from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState -from core.helper.code_executor.code_executor import CodeExecutionError from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus +from core.workflow.nodes.template_transform.template_renderer import TemplateRenderError from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode from models.workflow import WorkflowType @@ -127,7 +127,9 @@ class TestTemplateTransformNode: """Test version class method.""" assert TemplateTransformNode.version() == "1" - @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template") + @patch( + "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template" + ) def test_run_simple_template( self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params ): @@ -145,7 +147,7 @@ class TestTemplateTransformNode: mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector)) # Setup mock executor - mock_execute.return_value = {"result": "Hello Alice, you are 30 years old!"} + mock_execute.return_value = "Hello Alice, you are 30 years old!" node = TemplateTransformNode( id="test_node", @@ -162,7 +164,9 @@ class TestTemplateTransformNode: assert result.inputs["name"] == "Alice" assert result.inputs["age"] == 30 - @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template") + @patch( + "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template" + ) def test_run_with_none_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params): """Test _run with None variable values.""" node_data = { @@ -172,7 +176,7 @@ class TestTemplateTransformNode: } mock_graph_runtime_state.variable_pool.get.return_value = None - mock_execute.return_value = {"result": "Value: "} + mock_execute.return_value = "Value: " node = TemplateTransformNode( id="test_node", @@ -187,13 +191,15 @@ class TestTemplateTransformNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.inputs["value"] is None - @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template") + @patch( + "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template" + ) def test_run_with_code_execution_error( self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params ): """Test _run when code execution fails.""" mock_graph_runtime_state.variable_pool.get.return_value = MagicMock() - mock_execute.side_effect = CodeExecutionError("Template syntax error") + mock_execute.side_effect = TemplateRenderError("Template syntax error") node = TemplateTransformNode( id="test_node", @@ -208,14 +214,16 @@ class TestTemplateTransformNode: assert result.status == WorkflowNodeExecutionStatus.FAILED assert "Template syntax error" in result.error - @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template") + @patch( + "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template" + ) @patch("core.workflow.nodes.template_transform.template_transform_node.MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH", 10) def test_run_output_length_exceeds_limit( self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params ): """Test _run when output exceeds maximum length.""" mock_graph_runtime_state.variable_pool.get.return_value = MagicMock() - mock_execute.return_value = {"result": "This is a very long output that exceeds the limit"} + mock_execute.return_value = "This is a very long output that exceeds the limit" node = TemplateTransformNode( id="test_node", @@ -230,7 +238,9 @@ class TestTemplateTransformNode: assert result.status == WorkflowNodeExecutionStatus.FAILED assert "Output length exceeds" in result.error - @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template") + @patch( + "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template" + ) def test_run_with_complex_jinja2_template( self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params ): @@ -257,7 +267,7 @@ class TestTemplateTransformNode: ("sys", "show_total"): mock_show_total, } mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector)) - mock_execute.return_value = {"result": "apple, banana, orange (Total: 3)"} + mock_execute.return_value = "apple, banana, orange (Total: 3)" node = TemplateTransformNode( id="test_node", @@ -292,7 +302,9 @@ class TestTemplateTransformNode: assert mapping["node_123.var1"] == ["sys", "input1"] assert mapping["node_123.var2"] == ["sys", "input2"] - @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template") + @patch( + "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template" + ) def test_run_with_empty_variables(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params): """Test _run with no variables (static template).""" node_data = { @@ -301,7 +313,7 @@ class TestTemplateTransformNode: "template": "This is a static message.", } - mock_execute.return_value = {"result": "This is a static message."} + mock_execute.return_value = "This is a static message." node = TemplateTransformNode( id="test_node", @@ -317,7 +329,9 @@ class TestTemplateTransformNode: assert result.outputs["output"] == "This is a static message." assert result.inputs == {} - @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template") + @patch( + "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template" + ) def test_run_with_numeric_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params): """Test _run with numeric variable values.""" node_data = { @@ -339,7 +353,7 @@ class TestTemplateTransformNode: ("sys", "quantity"): mock_quantity, } mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector)) - mock_execute.return_value = {"result": "Total: $31.5"} + mock_execute.return_value = "Total: $31.5" node = TemplateTransformNode( id="test_node", @@ -354,7 +368,9 @@ class TestTemplateTransformNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["output"] == "Total: $31.5" - @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template") + @patch( + "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template" + ) def test_run_with_dict_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params): """Test _run with dictionary variable values.""" node_data = { @@ -367,7 +383,7 @@ class TestTemplateTransformNode: mock_user.to_object.return_value = {"name": "John Doe", "email": "john@example.com"} mock_graph_runtime_state.variable_pool.get.return_value = mock_user - mock_execute.return_value = {"result": "Name: John Doe, Email: john@example.com"} + mock_execute.return_value = "Name: John Doe, Email: john@example.com" node = TemplateTransformNode( id="test_node", @@ -383,7 +399,9 @@ class TestTemplateTransformNode: assert "John Doe" in result.outputs["output"] assert "john@example.com" in result.outputs["output"] - @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template") + @patch( + "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template" + ) def test_run_with_list_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params): """Test _run with list variable values.""" node_data = { @@ -396,7 +414,7 @@ class TestTemplateTransformNode: mock_tags.to_object.return_value = ["python", "ai", "workflow"] mock_graph_runtime_state.variable_pool.get.return_value = mock_tags - mock_execute.return_value = {"result": "Tags: #python #ai #workflow "} + mock_execute.return_value = "Tags: #python #ai #workflow " node = TemplateTransformNode( id="test_node", diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index 09b8191870..06927cddcf 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import sys import types from collections.abc import Generator @@ -21,7 +23,7 @@ if TYPE_CHECKING: # pragma: no cover - imported for type checking only @pytest.fixture -def tool_node(monkeypatch) -> "ToolNode": +def tool_node(monkeypatch) -> ToolNode: module_name = "core.ops.ops_trace_manager" if module_name not in sys.modules: ops_stub = types.ModuleType(module_name) @@ -85,7 +87,7 @@ def _collect_events(generator: Generator) -> tuple[list[Any], LLMUsage]: return events, stop.value -def _run_transform(tool_node: "ToolNode", message: ToolInvokeMessage) -> tuple[list[Any], LLMUsage]: +def _run_transform(tool_node: ToolNode, message: ToolInvokeMessage) -> tuple[list[Any], LLMUsage]: def _identity_transform(messages, *_args, **_kwargs): return messages @@ -103,7 +105,7 @@ def _run_transform(tool_node: "ToolNode", message: ToolInvokeMessage) -> tuple[l return _collect_events(generator) -def test_link_messages_with_file_populate_files_output(tool_node: "ToolNode"): +def test_link_messages_with_file_populate_files_output(tool_node: ToolNode): file_obj = File( tenant_id="tenant-id", type=FileType.DOCUMENT, @@ -139,7 +141,7 @@ def test_link_messages_with_file_populate_files_output(tool_node: "ToolNode"): assert files_segment.value == [file_obj] -def test_plain_link_messages_remain_links(tool_node: "ToolNode"): +def test_plain_link_messages_remain_links(tool_node: ToolNode): message = ToolInvokeMessage( type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text="https://dify.ai"), diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py index c62fc4d8fe..1df75380af 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py @@ -1,14 +1,14 @@ import time import uuid -from unittest import mock from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom from core.variables import ArrayStringVariable, StringVariable -from core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.entities import GraphInitParams from core.workflow.graph import Graph +from core.workflow.graph_events.node import NodeRunSucceededEvent from core.workflow.nodes.node_factory import DifyNodeFactory +from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode from core.workflow.runtime import GraphRuntimeState, VariablePool @@ -86,9 +86,6 @@ def test_overwrite_string_variable(): ) graph = Graph.init(graph_config=graph_config, node_factory=node_factory) - mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater) - mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater) - node_config = { "id": "node_id", "data": { @@ -104,20 +101,14 @@ def test_overwrite_string_variable(): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, config=node_config, - conv_var_updater_factory=mock_conv_var_updater_factory, ) - list(node.run()) - expected_var = StringVariable( - id=conversation_variable.id, - name=conversation_variable.name, - description=conversation_variable.description, - selector=conversation_variable.selector, - value_type=conversation_variable.value_type, - value=input_variable.value, - ) - mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var) - mock_conv_var_updater.flush.assert_called_once() + events = list(node.run()) + succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) + updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) + assert updated_variables is not None + assert updated_variables[0].name == conversation_variable.name + assert updated_variables[0].new_value == input_variable.value got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None @@ -191,9 +182,6 @@ def test_append_variable_to_array(): ) graph = Graph.init(graph_config=graph_config, node_factory=node_factory) - mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater) - mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater) - node_config = { "id": "node_id", "data": { @@ -209,22 +197,14 @@ def test_append_variable_to_array(): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, config=node_config, - conv_var_updater_factory=mock_conv_var_updater_factory, ) - list(node.run()) - expected_value = list(conversation_variable.value) - expected_value.append(input_variable.value) - expected_var = ArrayStringVariable( - id=conversation_variable.id, - name=conversation_variable.name, - description=conversation_variable.description, - selector=conversation_variable.selector, - value_type=conversation_variable.value_type, - value=expected_value, - ) - mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var) - mock_conv_var_updater.flush.assert_called_once() + events = list(node.run()) + succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) + updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) + assert updated_variables is not None + assert updated_variables[0].name == conversation_variable.name + assert updated_variables[0].new_value == ["the first value", "the second value"] got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None @@ -287,9 +267,6 @@ def test_clear_array(): ) graph = Graph.init(graph_config=graph_config, node_factory=node_factory) - mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater) - mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater) - node_config = { "id": "node_id", "data": { @@ -305,20 +282,14 @@ def test_clear_array(): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, config=node_config, - conv_var_updater_factory=mock_conv_var_updater_factory, ) - list(node.run()) - expected_var = ArrayStringVariable( - id=conversation_variable.id, - name=conversation_variable.name, - description=conversation_variable.description, - selector=conversation_variable.selector, - value_type=conversation_variable.value_type, - value=[], - ) - mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var) - mock_conv_var_updater.flush.assert_called_once() + events = list(node.run()) + succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) + updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) + assert updated_variables is not None + assert updated_variables[0].name == conversation_variable.name + assert updated_variables[0].new_value == [] got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py index caa36734ad..353d56fe25 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py @@ -390,3 +390,42 @@ def test_remove_last_from_empty_array(): got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None assert got.to_object() == [] + + +def test_node_factory_creates_variable_assigner_node(): + graph_config = { + "edges": [], + "nodes": [ + { + "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []}, + "id": "assigner", + }, + ], + } + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + variable_pool = VariablePool( + system_variables=SystemVariable(conversation_id="conversation_id"), + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + node_factory = DifyNodeFactory( + graph_init_params=init_params, + graph_runtime_state=graph_runtime_state, + ) + + node = node_factory.create_node(graph_config["nodes"][0]) + + assert isinstance(node, VariableAssignerNode) diff --git a/api/tests/unit_tests/fields/test_file_fields.py b/api/tests/unit_tests/fields/test_file_fields.py new file mode 100644 index 0000000000..8be8df16f4 --- /dev/null +++ b/api/tests/unit_tests/fields/test_file_fields.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from datetime import datetime +from types import SimpleNamespace + +from fields.file_fields import FileResponse, FileWithSignedUrl, RemoteFileInfo, UploadConfig + + +def test_file_response_serializes_datetime() -> None: + created_at = datetime(2024, 1, 1, 12, 0, 0) + file_obj = SimpleNamespace( + id="file-1", + name="example.txt", + size=1024, + extension="txt", + mime_type="text/plain", + created_by="user-1", + created_at=created_at, + preview_url="https://preview", + source_url="https://source", + original_url="https://origin", + user_id="user-1", + tenant_id="tenant-1", + conversation_id="conv-1", + file_key="key-1", + ) + + serialized = FileResponse.model_validate(file_obj, from_attributes=True).model_dump(mode="json") + + assert serialized["id"] == "file-1" + assert serialized["created_at"] == int(created_at.timestamp()) + assert serialized["preview_url"] == "https://preview" + assert serialized["source_url"] == "https://source" + assert serialized["original_url"] == "https://origin" + assert serialized["user_id"] == "user-1" + assert serialized["tenant_id"] == "tenant-1" + assert serialized["conversation_id"] == "conv-1" + assert serialized["file_key"] == "key-1" + + +def test_file_with_signed_url_builds_payload() -> None: + payload = FileWithSignedUrl( + id="file-2", + name="remote.pdf", + size=2048, + extension="pdf", + url="https://signed", + mime_type="application/pdf", + created_by="user-2", + created_at=datetime(2024, 1, 2, 0, 0, 0), + ) + + dumped = payload.model_dump(mode="json") + + assert dumped["url"] == "https://signed" + assert dumped["created_at"] == int(datetime(2024, 1, 2, 0, 0, 0).timestamp()) + + +def test_remote_file_info_and_upload_config() -> None: + info = RemoteFileInfo(file_type="text/plain", file_length=123) + assert info.model_dump(mode="json") == {"file_type": "text/plain", "file_length": 123} + + config = UploadConfig( + file_size_limit=1, + batch_count_limit=2, + file_upload_limit=3, + image_file_size_limit=4, + video_file_size_limit=5, + audio_file_size_limit=6, + workflow_file_upload_limit=7, + image_file_batch_limit=8, + single_chunk_attachment_limit=9, + attachment_image_file_size_limit=10, + ) + + dumped = config.model_dump(mode="json") + assert dumped["file_upload_limit"] == 3 + assert dumped["attachment_image_file_size_limit"] == 10 diff --git a/api/tests/unit_tests/libs/test_helper.py b/api/tests/unit_tests/libs/test_helper.py index 85789bfa7e..de74eff82f 100644 --- a/api/tests/unit_tests/libs/test_helper.py +++ b/api/tests/unit_tests/libs/test_helper.py @@ -1,6 +1,6 @@ import pytest -from libs.helper import extract_tenant_id +from libs.helper import escape_like_pattern, extract_tenant_id from models.account import Account from models.model import EndUser @@ -63,3 +63,51 @@ class TestExtractTenantId: with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"): extract_tenant_id(dict_user) + + +class TestEscapeLikePattern: + """Test cases for the escape_like_pattern utility function.""" + + def test_escape_percent_character(self): + """Test escaping percent character.""" + result = escape_like_pattern("50% discount") + assert result == "50\\% discount" + + def test_escape_underscore_character(self): + """Test escaping underscore character.""" + result = escape_like_pattern("test_data") + assert result == "test\\_data" + + def test_escape_backslash_character(self): + """Test escaping backslash character.""" + result = escape_like_pattern("path\\to\\file") + assert result == "path\\\\to\\\\file" + + def test_escape_combined_special_characters(self): + """Test escaping multiple special characters together.""" + result = escape_like_pattern("file_50%\\path") + assert result == "file\\_50\\%\\\\path" + + def test_escape_empty_string(self): + """Test escaping empty string returns empty string.""" + result = escape_like_pattern("") + assert result == "" + + def test_escape_none_handling(self): + """Test escaping None returns None (falsy check handles it).""" + # The function checks `if not pattern`, so None is falsy and returns as-is + result = escape_like_pattern(None) + assert result is None + + def test_escape_normal_string_no_change(self): + """Test that normal strings without special characters are unchanged.""" + result = escape_like_pattern("normal text") + assert result == "normal text" + + def test_escape_order_matters(self): + """Test that backslash is escaped first to prevent double escaping.""" + # If we escape % first, then escape \, we might get wrong results + # This test ensures the order is correct: \ first, then % and _ + result = escape_like_pattern("test\\%_value") + # Should be: test\\\%\_value + assert result == "test\\\\\\%\\_value" diff --git a/api/tests/unit_tests/models/test_app_models.py b/api/tests/unit_tests/models/test_app_models.py index e35788660d..8be2eea121 100644 --- a/api/tests/unit_tests/models/test_app_models.py +++ b/api/tests/unit_tests/models/test_app_models.py @@ -114,7 +114,7 @@ class TestAppModelValidation: def test_icon_type_validation(self): """Test icon type enum values.""" # Assert - assert {t.value for t in IconType} == {"image", "emoji"} + assert {t.value for t in IconType} == {"image", "emoji", "link"} def test_app_desc_or_prompt_with_description(self): """Test desc_or_prompt property when description exists.""" diff --git a/api/uv.lock b/api/uv.lock index fa032fa8d4..8e60fad3a7 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -2955,14 +2955,14 @@ wheels = [ [[package]] name = "intersystems-irispython" -version = "5.3.0" +version = "5.3.1" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/5d/56/16d93576b50408d97a5cbbd055d8da024d585e96a360e2adc95b41ae6284/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-macosx_10_9_universal2.whl", hash = "sha256:59d3176a35867a55b1ab69a6b5c75438b460291bccb254c2d2f4173be08b6e55", size = 6594480, upload-time = "2025-10-09T20:47:27.629Z" }, - { url = "https://files.pythonhosted.org/packages/99/bc/19e144ee805ea6ee0df6342a711e722c84347c05a75b3bf040c5fbe19982/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:56bccefd1997c25f9f9f6c4086214c18d4fdaac0a93319d4b21dd9a6c59c9e51", size = 14779928, upload-time = "2025-10-09T20:47:30.564Z" }, - { url = "https://files.pythonhosted.org/packages/e6/fb/59ba563a80b39e9450b4627b5696019aa831dce27dacc3831b8c1e669102/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e160adc0785c55bb64e4264b8e99075691a15b0afa5d8d529f1b4bac7e57b81", size = 14422035, upload-time = "2025-10-09T20:47:32.552Z" }, - { url = "https://files.pythonhosted.org/packages/c1/68/ade8ad43f0ed1e5fba60e1710fa5ddeb01285f031e465e8c006329072e63/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-win32.whl", hash = "sha256:820f2c5729119e5173a5bf6d6ac2a41275c4f1ffba6af6c59ea313ecd8f499cc", size = 2824316, upload-time = "2025-10-09T20:47:28.998Z" }, - { url = "https://files.pythonhosted.org/packages/f4/03/cd45cb94e42c01dc525efebf3c562543a18ee55b67fde4022665ca672351/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-win_amd64.whl", hash = "sha256:fc07ec24bc50b6f01573221cd7d86f2937549effe31c24af8db118e0131e340c", size = 3463297, upload-time = "2025-10-09T20:47:34.636Z" }, + { url = "https://files.pythonhosted.org/packages/33/5b/8eac672a6ef26bef6ef79a7c9557096167b50c4d3577d558ae6999c195fe/intersystems_irispython-5.3.1-cp38.cp39.cp310.cp311.cp312.cp313.cp314-cp38.cp39.cp310.cp311.cp312.cp313.cp314-macosx_10_9_universal2.whl", hash = "sha256:634c9b4ec620837d830ff49543aeb2797a1ce8d8570a0e868398b85330dfcc4d", size = 6736686, upload-time = "2025-12-19T16:24:57.734Z" }, + { url = "https://files.pythonhosted.org/packages/ba/17/bab3e525ffb6711355f7feea18c1b7dced9c2484cecbcdd83f74550398c0/intersystems_irispython-5.3.1-cp38.cp39.cp310.cp311.cp312.cp313.cp314-cp38.cp39.cp310.cp311.cp312.cp313.cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cf912f30f85e2a42f2c2ea77fbeb98a24154d5ea7428a50382786a684ec4f583", size = 16005259, upload-time = "2025-12-19T16:25:05.578Z" }, + { url = "https://files.pythonhosted.org/packages/39/59/9bb79d9e32e3e55fc9aed8071a797b4497924cbc6457cea9255bb09320b7/intersystems_irispython-5.3.1-cp38.cp39.cp310.cp311.cp312.cp313.cp314-cp38.cp39.cp310.cp311.cp312.cp313.cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:be5659a6bb57593910f2a2417eddb9f5dc2f93a337ead6ddca778f557b8a359a", size = 15638040, upload-time = "2025-12-19T16:24:54.429Z" }, + { url = "https://files.pythonhosted.org/packages/cf/47/654ccf9c5cca4f5491f070888544165c9e2a6a485e320ea703e4e38d2358/intersystems_irispython-5.3.1-cp38.cp39.cp310.cp311.cp312.cp313.cp314-cp38.cp39.cp310.cp311.cp312.cp313.cp314-win32.whl", hash = "sha256:583e4f17088c1e0530f32efda1c0ccb02993cbc22035bc8b4c71d8693b04ee7e", size = 2879644, upload-time = "2025-12-19T16:24:59.945Z" }, + { url = "https://files.pythonhosted.org/packages/68/95/19cc13d09f1b4120bd41b1434509052e1d02afd27f2679266d7ad9cc1750/intersystems_irispython-5.3.1-cp38.cp39.cp310.cp311.cp312.cp313.cp314-cp38.cp39.cp310.cp311.cp312.cp313.cp314-win_amd64.whl", hash = "sha256:1d5d40450a0cdeec2a1f48d12d946a8a8ffc7c128576fcae7d58e66e3a127eae", size = 3522092, upload-time = "2025-12-19T16:25:01.834Z" }, ] [[package]] diff --git a/docker/.env.example b/docker/.env.example index c3feccb102..09ee1060e2 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -58,8 +58,8 @@ FILES_URL= INTERNAL_FILES_URL= # Ensure UTF-8 encoding -LANG=en_US.UTF-8 -LC_ALL=en_US.UTF-8 +LANG=C.UTF-8 +LC_ALL=C.UTF-8 PYTHONIOENCODING=utf-8 # ------------------------------ @@ -233,7 +233,7 @@ NEXT_PUBLIC_ENABLE_SINGLE_DOLLAR_LATEX=false # You can adjust the database configuration according to your needs. # ------------------------------ -# Database type, supported values are `postgresql` and `mysql` +# Database type, supported values are `postgresql`, `mysql`, `oceanbase`, `seekdb` DB_TYPE=postgresql # For MySQL, only `root` user is supported for now DB_USERNAME=postgres @@ -533,7 +533,7 @@ SUPABASE_URL=your-server-url # ------------------------------ # The type of vector store to use. -# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `clickzetta`, `alibabacloud_mysql`, `iris`. +# Supported values are `weaviate`, `oceanbase`, `seekdb`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`, `vastbase`, `tidb`, `tidb_on_qdrant`, `baidu`, `lindorm`, `huawei_cloud`, `upstash`, `matrixone`, `clickzetta`, `alibabacloud_mysql`, `iris`. VECTOR_STORE=weaviate # Prefix used to create collection name in vector database VECTOR_INDEX_NAME_PREFIX=Vector_index @@ -544,9 +544,9 @@ WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih WEAVIATE_GRPC_ENDPOINT=grpc://weaviate:50051 WEAVIATE_TOKENIZATION=word -# For OceanBase metadata database configuration, available when `DB_TYPE` is `mysql` and `COMPOSE_PROFILES` includes `oceanbase`. +# For OceanBase metadata database configuration, available when `DB_TYPE` is `oceanbase`. # For OceanBase vector database configuration, available when `VECTOR_STORE` is `oceanbase` -# If you want to use OceanBase as both vector database and metadata database, you need to set `DB_TYPE` to `mysql`, `COMPOSE_PROFILES` is `oceanbase`, and set Database Configuration is the same as the vector database. +# If you want to use OceanBase as both vector database and metadata database, you need to set both `DB_TYPE` and `VECTOR_STORE` to `oceanbase`, and set Database Configuration is the same as the vector database. # seekdb is the lite version of OceanBase and shares the connection configuration with OceanBase. OCEANBASE_VECTOR_HOST=oceanbase OCEANBASE_VECTOR_PORT=2881 @@ -1077,6 +1077,10 @@ LOGSTORE_DUAL_WRITE_ENABLED=false # Enable dual-read fallback to SQL database when LogStore returns no results (default: true) # Useful for migration scenarios where historical data exists only in SQL database LOGSTORE_DUAL_READ_ENABLED=true +# Control flag for whether to write the `graph` field to LogStore. +# If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field; +# otherwise write an empty {} instead. Defaults to writing the `graph` field. +LOGSTORE_ENABLE_PUT_GRAPH_FIELD=true # HTTP request node in workflow configuration HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 3c88cddf8c..709aff23df 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -475,7 +475,8 @@ services: OB_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai} OB_SERVER_IP: 127.0.0.1 MODE: mini - LANG: en_US.UTF-8 + LANG: C.UTF-8 + LC_ALL: C.UTF-8 ports: - "${OCEANBASE_VECTOR_PORT:-2881}:2881" healthcheck: diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index a67141ce05..712de84c62 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -13,8 +13,8 @@ x-shared-env: &shared-api-worker-env APP_WEB_URL: ${APP_WEB_URL:-} FILES_URL: ${FILES_URL:-} INTERNAL_FILES_URL: ${INTERNAL_FILES_URL:-} - LANG: ${LANG:-en_US.UTF-8} - LC_ALL: ${LC_ALL:-en_US.UTF-8} + LANG: ${LANG:-C.UTF-8} + LC_ALL: ${LC_ALL:-C.UTF-8} PYTHONIOENCODING: ${PYTHONIOENCODING:-utf-8} LOG_LEVEL: ${LOG_LEVEL:-INFO} LOG_OUTPUT_FORMAT: ${LOG_OUTPUT_FORMAT:-text} @@ -475,6 +475,7 @@ x-shared-env: &shared-api-worker-env ALIYUN_SLS_LOGSTORE_TTL: ${ALIYUN_SLS_LOGSTORE_TTL:-365} LOGSTORE_DUAL_WRITE_ENABLED: ${LOGSTORE_DUAL_WRITE_ENABLED:-false} LOGSTORE_DUAL_READ_ENABLED: ${LOGSTORE_DUAL_READ_ENABLED:-true} + LOGSTORE_ENABLE_PUT_GRAPH_FIELD: ${LOGSTORE_ENABLE_PUT_GRAPH_FIELD:-true} HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760} HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576} HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True} @@ -1156,7 +1157,8 @@ services: OB_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai} OB_SERVER_IP: 127.0.0.1 MODE: mini - LANG: en_US.UTF-8 + LANG: C.UTF-8 + LC_ALL: C.UTF-8 ports: - "${OCEANBASE_VECTOR_PORT:-2881}:2881" healthcheck: diff --git a/docker/middleware.env.example b/docker/middleware.env.example index f7e0252a6f..c88dbe5511 100644 --- a/docker/middleware.env.example +++ b/docker/middleware.env.example @@ -233,4 +233,8 @@ ALIYUN_SLS_LOGSTORE_TTL=365 LOGSTORE_DUAL_WRITE_ENABLED=true # Enable dual-read fallback to SQL database when LogStore returns no results (default: true) # Useful for migration scenarios where historical data exists only in SQL database -LOGSTORE_DUAL_READ_ENABLED=true \ No newline at end of file +LOGSTORE_DUAL_READ_ENABLED=true +# Control flag for whether to write the `graph` field to LogStore. +# If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field; +# otherwise write an empty {} instead. Defaults to writing the `graph` field. +LOGSTORE_ENABLE_PUT_GRAPH_FIELD=true \ No newline at end of file diff --git a/docker/ssrf_proxy/squid.conf.template b/docker/ssrf_proxy/squid.conf.template index 1775a1fff9..256e669c8d 100644 --- a/docker/ssrf_proxy/squid.conf.template +++ b/docker/ssrf_proxy/squid.conf.template @@ -54,3 +54,52 @@ http_access allow src_all # Unless the option's size is increased, an error will occur when uploading more than two files. client_request_buffer_max_size 100 MB + +################################## Performance & Concurrency ############################### +# Increase file descriptor limit for high concurrency +max_filedescriptors 65536 + +# Timeout configurations for image requests +connect_timeout 30 seconds +request_timeout 2 minutes +read_timeout 2 minutes +client_lifetime 5 minutes +shutdown_lifetime 30 seconds + +# Persistent connections - improve performance for multiple requests +server_persistent_connections on +client_persistent_connections on +persistent_request_timeout 30 seconds +pconn_timeout 1 minute + +# Connection pool and concurrency limits +client_db on +server_idle_pconn_timeout 2 minutes +client_idle_pconn_timeout 2 minutes + +# Quick abort settings - don't abort requests that are mostly done +quick_abort_min 16 KB +quick_abort_max 16 MB +quick_abort_pct 95 + +# Memory and cache optimization +memory_cache_mode disk +cache_mem 256 MB +maximum_object_size_in_memory 512 KB + +# DNS resolver settings for better performance +dns_timeout 30 seconds +dns_retransmit_interval 5 seconds +# By default, Squid uses the system's configured DNS resolvers. +# If you need to override them, set dns_nameservers to appropriate servers +# for your environment (for example, internal/corporate DNS). The following +# is an example using public DNS and SHOULD be customized before use: +# dns_nameservers 8.8.8.8 8.8.4.4 + +# Logging format for better debugging +logformat dify_log %ts.%03tu %6tr %>a %Ss/%03>Hs % ({ + getI18n: () => ({ + t: (key: string) => key, + language: 'en', + }), +})) + // Mock API functions vi.mock('@/service/base', () => ({ postMarketplace: vi.fn(), diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx index 8080b565cd..1d65e4de53 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx @@ -1,11 +1,8 @@ -/* eslint-disable dify-i18n/require-ns-option */ -import * as React from 'react' +import { useTranslation } from '#i18n' import Form from '@/app/components/datasets/settings/form' -import { getLocaleOnServer, getTranslation } from '@/i18n-config/server' -const Settings = async () => { - const locale = await getLocaleOnServer() - const { t } = await getTranslation(locale, 'dataset-settings') +const Settings = () => { + const { t } = useTranslation('datasetSettings') return (
diff --git a/web/app/(commonLayout)/plugins/page.tsx b/web/app/(commonLayout)/plugins/page.tsx index 2df9cf23c4..81bda3a8a3 100644 --- a/web/app/(commonLayout)/plugins/page.tsx +++ b/web/app/(commonLayout)/plugins/page.tsx @@ -1,14 +1,12 @@ import Marketplace from '@/app/components/plugins/marketplace' import PluginPage from '@/app/components/plugins/plugin-page' import PluginsPanel from '@/app/components/plugins/plugin-page/plugins-panel' -import { getLocaleOnServer } from '@/i18n-config/server' const PluginList = async () => { - const locale = await getLocaleOnServer() return ( } - marketplace={} + marketplace={} /> ) } diff --git a/web/app/components/apps/app-card-skeleton.tsx b/web/app/components/apps/app-card-skeleton.tsx new file mode 100644 index 0000000000..806f19973a --- /dev/null +++ b/web/app/components/apps/app-card-skeleton.tsx @@ -0,0 +1,41 @@ +'use client' + +import * as React from 'react' +import { SkeletonContainer, SkeletonRectangle, SkeletonRow } from '@/app/components/base/skeleton' + +type AppCardSkeletonProps = { + count?: number +} + +/** + * Skeleton placeholder for App cards during loading states. + * Matches the visual layout of AppCard component. + */ +export const AppCardSkeleton = React.memo(({ count = 6 }: AppCardSkeletonProps) => { + return ( + <> + {Array.from({ length: count }).map((_, index) => ( +
+ + + +
+ + +
+
+
+ + +
+
+
+ ))} + + ) +}) + +AppCardSkeleton.displayName = 'AppCardSkeleton' diff --git a/web/app/components/apps/list.tsx b/web/app/components/apps/list.tsx index 003b463595..290a73fc7c 100644 --- a/web/app/components/apps/list.tsx +++ b/web/app/components/apps/list.tsx @@ -27,7 +27,9 @@ import { useGlobalPublicStore } from '@/context/global-public-context' import { CheckModal } from '@/hooks/use-pay' import { useInfiniteAppList } from '@/service/use-apps' import { AppModeEnum } from '@/types/app' +import { cn } from '@/utils/classnames' import AppCard from './app-card' +import { AppCardSkeleton } from './app-card-skeleton' import Empty from './empty' import Footer from './footer' import useAppsQueryState from './hooks/use-apps-query-state' @@ -45,7 +47,7 @@ const List = () => { const { t } = useTranslation() const { systemFeatures } = useGlobalPublicStore() const router = useRouter() - const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator } = useAppContext() + const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator, isLoadingCurrentWorkspace } = useAppContext() const showTagManagementModal = useTagStore(s => s.showTagManagementModal) const [activeTab, setActiveTab] = useQueryState( 'category', @@ -89,6 +91,7 @@ const List = () => { const { data, isLoading, + isFetching, isFetchingNextPage, fetchNextPage, hasNextPage, @@ -172,6 +175,8 @@ const List = () => { const pages = data?.pages ?? [] const hasAnyApp = (pages[0]?.total ?? 0) > 0 + // Show skeleton during initial load or when refetching with no previous data + const showSkeleton = isLoading || (isFetching && pages.length === 0) return ( <> @@ -205,23 +210,34 @@ const List = () => { />
- {hasAnyApp - ? ( -
- {isCurrentWorkspaceEditor - && } - {pages.map(({ data: apps }) => apps.map(app => ( - - )))} -
- ) - : ( -
- {isCurrentWorkspaceEditor - && } - -
- )} +
+ {(isCurrentWorkspaceEditor || isLoadingCurrentWorkspace) && ( + + )} + {(() => { + if (showSkeleton) + return + + if (hasAnyApp) { + return pages.flatMap(({ data: apps }) => apps).map(app => ( + + )) + } + + // No apps - show empty state + return + })()} +
{isCurrentWorkspaceEditor && (
import('@/app/components/app/create-fro export type CreateAppCardProps = { className?: string + isLoading?: boolean onSuccess?: () => void ref: React.RefObject selectedAppType?: string @@ -33,6 +34,7 @@ export type CreateAppCardProps = { const CreateAppCard = ({ ref, className, + isLoading = false, onSuccess, selectedAppType, }: CreateAppCardProps) => { @@ -56,7 +58,11 @@ const CreateAppCard = ({ return (
{t('createApp', { ns: 'app' })}
diff --git a/web/app/components/base/chat/chat-with-history/hooks.spec.tsx b/web/app/components/base/chat/chat-with-history/hooks.spec.tsx index 32ef133453..a6d51d8643 100644 --- a/web/app/components/base/chat/chat-with-history/hooks.spec.tsx +++ b/web/app/components/base/chat/chat-with-history/hooks.spec.tsx @@ -17,7 +17,7 @@ vi.mock('@/hooks/use-app-favicon', () => ({ useAppFavicon: vi.fn(), })) -vi.mock('@/i18n-config/i18next-config', () => ({ +vi.mock('@/i18n-config/client', () => ({ changeLanguage: vi.fn().mockResolvedValue(undefined), })) diff --git a/web/app/components/base/chat/chat-with-history/hooks.tsx b/web/app/components/base/chat/chat-with-history/hooks.tsx index 5ff8e61ff6..ed1981b530 100644 --- a/web/app/components/base/chat/chat-with-history/hooks.tsx +++ b/web/app/components/base/chat/chat-with-history/hooks.tsx @@ -25,7 +25,7 @@ import { useToastContext } from '@/app/components/base/toast' import { InputVarType } from '@/app/components/workflow/types' import { useWebAppStore } from '@/context/web-app-context' import { useAppFavicon } from '@/hooks/use-app-favicon' -import { changeLanguage } from '@/i18n-config/i18next-config' +import { changeLanguage } from '@/i18n-config/client' import { delConversation, pinConversation, diff --git a/web/app/components/base/chat/embedded-chatbot/hooks.spec.tsx b/web/app/components/base/chat/embedded-chatbot/hooks.spec.tsx index ca6a90c4d8..066fb8ebe9 100644 --- a/web/app/components/base/chat/embedded-chatbot/hooks.spec.tsx +++ b/web/app/components/base/chat/embedded-chatbot/hooks.spec.tsx @@ -13,7 +13,7 @@ import { shareQueryKeys } from '@/service/use-share' import { CONVERSATION_ID_INFO } from '../constants' import { useEmbeddedChatbot } from './hooks' -vi.mock('@/i18n-config/i18next-config', () => ({ +vi.mock('@/i18n-config/client', () => ({ changeLanguage: vi.fn().mockResolvedValue(undefined), })) diff --git a/web/app/components/base/chat/embedded-chatbot/hooks.tsx b/web/app/components/base/chat/embedded-chatbot/hooks.tsx index 803e905837..9028d10000 100644 --- a/web/app/components/base/chat/embedded-chatbot/hooks.tsx +++ b/web/app/components/base/chat/embedded-chatbot/hooks.tsx @@ -23,7 +23,7 @@ import { useToastContext } from '@/app/components/base/toast' import { addFileInfos, sortAgentSorts } from '@/app/components/tools/utils' import { InputVarType } from '@/app/components/workflow/types' import { useWebAppStore } from '@/context/web-app-context' -import { changeLanguage } from '@/i18n-config/i18next-config' +import { changeLanguage } from '@/i18n-config/client' import { updateFeedback } from '@/service/share' import { useInvalidateShareConversations, diff --git a/web/app/components/base/form/hooks/use-get-form-values.ts b/web/app/components/base/form/hooks/use-get-form-values.ts index 9ea418ea00..3dd2eceb30 100644 --- a/web/app/components/base/form/hooks/use-get-form-values.ts +++ b/web/app/components/base/form/hooks/use-get-form-values.ts @@ -4,7 +4,7 @@ import type { GetValuesOptions, } from '../types' import { useCallback } from 'react' -import { getTransformedValuesWhenSecretInputPristine } from '../utils' +import { getTransformedValuesWhenSecretInputPristine } from '../utils/secret-input' import { useCheckValidated } from './use-check-validated' export const useGetFormValues = (form: AnyFormApi, formSchemas: FormSchema[]) => { diff --git a/web/app/components/base/form/utils/index.ts b/web/app/components/base/form/utils/index.ts deleted file mode 100644 index 0abb8d1ad5..0000000000 --- a/web/app/components/base/form/utils/index.ts +++ /dev/null @@ -1 +0,0 @@ -export * from './secret-input' diff --git a/web/app/components/base/form/utils/zod-submit-validator.ts b/web/app/components/base/form/utils/zod-submit-validator.ts new file mode 100644 index 0000000000..23eacaf8a4 --- /dev/null +++ b/web/app/components/base/form/utils/zod-submit-validator.ts @@ -0,0 +1,22 @@ +import type { ZodSchema } from 'zod' + +type SubmitValidator = ({ value }: { value: T }) => { fields: Record } | undefined + +export const zodSubmitValidator = (schema: ZodSchema): SubmitValidator => { + return ({ value }) => { + const result = schema.safeParse(value) + if (!result.success) { + const fieldErrors: Record = {} + for (const issue of result.error.issues) { + const path = issue.path[0] + if (path === undefined) + continue + const key = String(path) + if (!fieldErrors[key]) + fieldErrors[key] = issue.message + } + return { fields: fieldErrors } + } + return undefined + } +} diff --git a/web/app/components/base/icons/assets/public/llm/Tongyi.svg b/web/app/components/base/icons/assets/public/llm/Tongyi.svg deleted file mode 100644 index cca23b3aae..0000000000 --- a/web/app/components/base/icons/assets/public/llm/Tongyi.svg +++ /dev/null @@ -1,17 +0,0 @@ - - - - - - - - - - - - - - - - - diff --git a/web/app/components/base/icons/assets/public/llm/anthropic-short-light.svg b/web/app/components/base/icons/assets/public/llm/anthropic-short-light.svg deleted file mode 100644 index c8e2370803..0000000000 --- a/web/app/components/base/icons/assets/public/llm/anthropic-short-light.svg +++ /dev/null @@ -1,4 +0,0 @@ - - - - diff --git a/web/app/components/base/icons/assets/public/llm/deepseek.svg b/web/app/components/base/icons/assets/public/llm/deepseek.svg deleted file mode 100644 index 046f89e1ce..0000000000 --- a/web/app/components/base/icons/assets/public/llm/deepseek.svg +++ /dev/null @@ -1,4 +0,0 @@ - - - - diff --git a/web/app/components/base/icons/assets/public/llm/gemini.svg b/web/app/components/base/icons/assets/public/llm/gemini.svg deleted file mode 100644 index 698f6ea629..0000000000 --- a/web/app/components/base/icons/assets/public/llm/gemini.svg +++ /dev/null @@ -1,105 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/web/app/components/base/icons/assets/public/llm/grok.svg b/web/app/components/base/icons/assets/public/llm/grok.svg deleted file mode 100644 index 6c0cbe227d..0000000000 --- a/web/app/components/base/icons/assets/public/llm/grok.svg +++ /dev/null @@ -1,11 +0,0 @@ - - - - - - - - - - - diff --git a/web/app/components/base/icons/assets/public/llm/openai-small.svg b/web/app/components/base/icons/assets/public/llm/openai-small.svg deleted file mode 100644 index 4af58790e4..0000000000 --- a/web/app/components/base/icons/assets/public/llm/openai-small.svg +++ /dev/null @@ -1,17 +0,0 @@ - - - - - - - - - - - - - - - - - diff --git a/web/app/components/base/icons/src/public/llm/AnthropicShortLight.json b/web/app/components/base/icons/src/public/llm/AnthropicShortLight.json deleted file mode 100644 index 2a8ff2f28a..0000000000 --- a/web/app/components/base/icons/src/public/llm/AnthropicShortLight.json +++ /dev/null @@ -1,36 +0,0 @@ -{ - "icon": { - "type": "element", - "isRootNode": true, - "name": "svg", - "attributes": { - "width": "40", - "height": "40", - "viewBox": "0 0 40 40", - "fill": "none", - "xmlns": "http://www.w3.org/2000/svg" - }, - "children": [ - { - "type": "element", - "name": "rect", - "attributes": { - "width": "40", - "height": "40", - "fill": "white" - }, - "children": [] - }, - { - "type": "element", - "name": "path", - "attributes": { - "d": "M25.7926 10.1311H21.5089L29.3208 29.869H33.6045L25.7926 10.1311ZM13.4164 10.1311L5.60449 29.869H9.97273L11.5703 25.724H19.743L21.3405 29.869H25.7087L17.8969 10.1311H13.4164ZM12.9834 22.0583L15.6566 15.1217L18.3299 22.0583H12.9834Z", - "fill": "black" - }, - "children": [] - } - ] - }, - "name": "AnthropicShortLight" -} diff --git a/web/app/components/base/icons/src/public/llm/AnthropicShortLight.tsx b/web/app/components/base/icons/src/public/llm/AnthropicShortLight.tsx deleted file mode 100644 index 2bd21f48da..0000000000 --- a/web/app/components/base/icons/src/public/llm/AnthropicShortLight.tsx +++ /dev/null @@ -1,20 +0,0 @@ -// GENERATE BY script -// DON NOT EDIT IT MANUALLY - -import type { IconData } from '@/app/components/base/icons/IconBase' -import * as React from 'react' -import IconBase from '@/app/components/base/icons/IconBase' -import data from './AnthropicShortLight.json' - -const Icon = ( - { - ref, - ...props - }: React.SVGProps & { - ref?: React.RefObject> - }, -) => - -Icon.displayName = 'AnthropicShortLight' - -export default Icon diff --git a/web/app/components/base/icons/src/public/llm/Deepseek.json b/web/app/components/base/icons/src/public/llm/Deepseek.json deleted file mode 100644 index 1483974a02..0000000000 --- a/web/app/components/base/icons/src/public/llm/Deepseek.json +++ /dev/null @@ -1,36 +0,0 @@ -{ - "icon": { - "type": "element", - "isRootNode": true, - "name": "svg", - "attributes": { - "width": "40", - "height": "40", - "viewBox": "0 0 40 40", - "fill": "none", - "xmlns": "http://www.w3.org/2000/svg" - }, - "children": [ - { - "type": "element", - "name": "rect", - "attributes": { - "width": "40", - "height": "40", - "fill": "white" - }, - "children": [] - }, - { - "type": "element", - "name": "path", - "attributes": { - "d": "M36.6676 11.2917C36.3316 11.1277 36.1871 11.4402 35.9906 11.599C35.9242 11.6511 35.8668 11.7188 35.8108 11.7787C35.3199 12.3048 34.747 12.6485 33.9996 12.6068C32.9046 12.5469 31.971 12.8907 31.1455 13.7293C30.9696 12.6954 30.3863 12.0782 29.4996 11.6824C29.0348 11.4766 28.5647 11.2709 28.2406 10.823C28.0127 10.5053 27.9515 10.1511 27.8368 9.80214C27.7652 9.59121 27.6923 9.37506 27.4502 9.33861C27.1871 9.29694 27.0843 9.51829 26.9814 9.70318C26.5674 10.4584 26.4084 11.2917 26.4228 12.1355C26.4592 14.0313 27.26 15.5417 28.8486 16.6173C29.0296 16.7397 29.0764 16.8646 29.0191 17.0443C28.9111 17.4141 28.7822 17.7735 28.6676 18.1433C28.596 18.3803 28.4879 18.4323 28.2354 18.3282C27.363 17.9637 26.609 17.4246 25.9436 16.7709C24.8135 15.6771 23.7914 14.4689 22.5166 13.5235C22.2171 13.3021 21.919 13.0964 21.609 12.9011C20.3082 11.6355 21.7796 10.5964 22.1194 10.474C22.4762 10.3464 22.2431 9.9037 21.092 9.90891C19.9423 9.91413 18.889 10.2995 17.5478 10.8126C17.3512 10.8907 17.1455 10.948 16.9332 10.9922C15.7158 10.7631 14.4515 10.711 13.1298 10.8594C10.6428 11.1381 8.65587 12.3152 7.19493 14.3255C5.44102 16.7397 5.02826 19.4845 5.53347 22.349C6.06473 25.3646 7.60249 27.8646 9.96707 29.8178C12.4176 31.8413 15.2406 32.8334 18.4606 32.6433C20.4163 32.5313 22.5947 32.2683 25.0504 30.1875C25.6702 30.4949 26.3199 30.6173 27.3994 30.711C28.2302 30.7891 29.0296 30.6694 29.6494 30.5417C30.6194 30.3361 30.5518 29.4375 30.2015 29.2709C27.3578 27.9454 27.9814 28.4845 27.4136 28.0495C28.859 26.3361 31.0374 24.5574 31.889 18.797C31.9554 18.3386 31.898 18.0522 31.889 17.6798C31.8838 17.4558 31.9346 17.3673 32.1923 17.3413C32.9046 17.2605 33.596 17.0651 34.2314 16.7137C36.0739 15.7058 36.816 14.0522 36.9918 12.0678C37.0179 11.7657 36.9866 11.4506 36.6676 11.2917ZM20.613 29.1485C17.8564 26.9793 16.5204 26.2657 15.9684 26.297C15.4527 26.3255 15.5452 26.9167 15.6584 27.3022C15.777 27.6823 15.9319 27.9454 16.1494 28.2787C16.2991 28.5001 16.402 28.8307 15.9996 29.0755C15.1116 29.6277 13.5687 28.8907 13.4958 28.8542C11.7001 27.797 10.1988 26.3985 9.14025 24.487C8.11941 22.6459 7.52566 20.6719 7.42801 18.5651C7.40197 18.0547 7.5517 17.875 8.05691 17.7839C8.72227 17.6615 9.40978 17.6355 10.0751 17.7318C12.8876 18.1433 15.2822 19.4037 17.2887 21.3959C18.4346 22.5339 19.3018 23.8907 20.195 25.2162C21.1442 26.6251 22.1663 27.9662 23.4671 29.0651C23.9254 29.4506 24.2926 29.7449 24.6428 29.961C23.5856 30.0782 21.8199 30.1042 20.613 29.1485ZM21.9332 20.6407C21.9332 20.4141 22.1142 20.2345 22.342 20.2345C22.3928 20.2345 22.4398 20.2449 22.4814 20.2605C22.5374 20.2813 22.5895 20.3126 22.6299 20.3594C22.7027 20.4298 22.7444 20.5339 22.7444 20.6407C22.7444 20.8673 22.5635 21.047 22.3368 21.047C22.109 21.047 21.9332 20.8673 21.9332 20.6407ZM26.036 22.7501C25.7731 22.8569 25.51 22.9506 25.2575 22.961C24.8655 22.9793 24.4371 22.8203 24.204 22.6251C23.8434 22.323 23.5856 22.1537 23.4762 21.6225C23.4306 21.3959 23.4567 21.047 23.497 20.8465C23.5908 20.4141 23.4866 20.1381 23.1832 19.8855C22.9346 19.6798 22.6207 19.6251 22.2744 19.6251C22.1455 19.6251 22.027 19.5678 21.9384 19.5209C21.7939 19.4479 21.6754 19.2683 21.7887 19.047C21.8251 18.9766 22.001 18.8022 22.0426 18.7709C22.5114 18.5027 23.053 18.5913 23.5543 18.7918C24.0191 18.9818 24.3694 19.3307 24.8746 19.823C25.3915 20.4194 25.484 20.5861 25.7783 21.0313C26.01 21.3829 26.2223 21.7422 26.3668 22.1537C26.454 22.4089 26.3408 22.6198 26.036 22.7501Z", - "fill": "#4D6BFE" - }, - "children": [] - } - ] - }, - "name": "Deepseek" -} diff --git a/web/app/components/base/icons/src/public/llm/Deepseek.tsx b/web/app/components/base/icons/src/public/llm/Deepseek.tsx deleted file mode 100644 index b19beb8b8f..0000000000 --- a/web/app/components/base/icons/src/public/llm/Deepseek.tsx +++ /dev/null @@ -1,20 +0,0 @@ -// GENERATE BY script -// DON NOT EDIT IT MANUALLY - -import type { IconData } from '@/app/components/base/icons/IconBase' -import * as React from 'react' -import IconBase from '@/app/components/base/icons/IconBase' -import data from './Deepseek.json' - -const Icon = ( - { - ref, - ...props - }: React.SVGProps & { - ref?: React.RefObject> - }, -) => - -Icon.displayName = 'Deepseek' - -export default Icon diff --git a/web/app/components/base/icons/src/public/llm/Gemini.json b/web/app/components/base/icons/src/public/llm/Gemini.json deleted file mode 100644 index 3121b1ea19..0000000000 --- a/web/app/components/base/icons/src/public/llm/Gemini.json +++ /dev/null @@ -1,807 +0,0 @@ -{ - "icon": { - "type": "element", - "isRootNode": true, - "name": "svg", - "attributes": { - "width": "40", - "height": "40", - "viewBox": "0 0 40 40", - "fill": "none", - "xmlns": "http://www.w3.org/2000/svg" - }, - "children": [ - { - "type": "element", - "name": "rect", - "attributes": { - "width": "40", - "height": "40", - "fill": "white" - }, - "children": [] - }, - { - "type": "element", - "name": "mask", - "attributes": { - "id": "mask0_3892_95663", - "style": "mask-type:alpha", - "maskUnits": "userSpaceOnUse", - "x": "6", - "y": "6", - "width": "28", - "height": "29" - }, - "children": [ - { - "type": "element", - "name": "path", - "attributes": { - "d": "M20 6C20.2936 6 20.5488 6.2005 20.6205 6.48556C20.8393 7.3566 21.1277 8.20866 21.4828 9.03356C22.4116 11.191 23.6854 13.0791 25.3032 14.6968C26.9218 16.3146 28.8095 17.5888 30.9664 18.5172C31.7941 18.8735 32.6436 19.16 33.5149 19.3795C33.6533 19.4143 33.7762 19.4942 33.8641 19.6067C33.9519 19.7192 33.9998 19.8578 34 20.0005C34 20.2941 33.7995 20.5492 33.5149 20.621C32.6437 20.8399 31.7915 21.1282 30.9664 21.4833C28.8095 22.4121 26.9209 23.6859 25.3032 25.3036C23.6854 26.9223 22.4116 28.8099 21.4828 30.9669C21.1278 31.7919 20.8394 32.6439 20.6205 33.5149C20.586 33.6534 20.5062 33.7764 20.3937 33.8644C20.2813 33.9524 20.1427 34.0003 20 34.0005C19.8572 34.0003 19.7186 33.9525 19.6062 33.8645C19.4937 33.7765 19.414 33.6535 19.3795 33.5149C19.1605 32.6439 18.872 31.7918 18.5167 30.9669C17.5884 28.8099 16.3151 26.9214 14.6964 25.3036C13.0782 23.6859 11.1906 22.4121 9.03309 21.4833C8.20814 21.1283 7.35608 20.8399 6.48509 20.621C6.34667 20.5864 6.22377 20.5065 6.13589 20.3941C6.04801 20.2817 6.00018 20.1432 6 20.0005C6.00024 19.8578 6.04808 19.7192 6.13594 19.6067C6.2238 19.4942 6.34667 19.4143 6.48509 19.3795C7.35612 19.1607 8.20819 18.8723 9.03309 18.5172C11.1906 17.5888 13.0786 16.3146 14.6964 14.6968C16.3141 13.0791 17.5884 11.191 18.5167 9.03356C18.8719 8.20862 19.1604 7.35656 19.3795 6.48556C19.4508 6.2005 19.7064 6 20 6Z", - "fill": "black" - }, - "children": [] - }, - { - "type": "element", - "name": "path", - "attributes": { - "d": "M20 6C20.2936 6 20.5488 6.2005 20.6205 6.48556C20.8393 7.3566 21.1277 8.20866 21.4828 9.03356C22.4116 11.191 23.6854 13.0791 25.3032 14.6968C26.9218 16.3146 28.8095 17.5888 30.9664 18.5172C31.7941 18.8735 32.6436 19.16 33.5149 19.3795C33.6533 19.4143 33.7762 19.4942 33.8641 19.6067C33.9519 19.7192 33.9998 19.8578 34 20.0005C34 20.2941 33.7995 20.5492 33.5149 20.621C32.6437 20.8399 31.7915 21.1282 30.9664 21.4833C28.8095 22.4121 26.9209 23.6859 25.3032 25.3036C23.6854 26.9223 22.4116 28.8099 21.4828 30.9669C21.1278 31.7919 20.8394 32.6439 20.6205 33.5149C20.586 33.6534 20.5062 33.7764 20.3937 33.8644C20.2813 33.9524 20.1427 34.0003 20 34.0005C19.8572 34.0003 19.7186 33.9525 19.6062 33.8645C19.4937 33.7765 19.414 33.6535 19.3795 33.5149C19.1605 32.6439 18.872 31.7918 18.5167 30.9669C17.5884 28.8099 16.3151 26.9214 14.6964 25.3036C13.0782 23.6859 11.1906 22.4121 9.03309 21.4833C8.20814 21.1283 7.35608 20.8399 6.48509 20.621C6.34667 20.5864 6.22377 20.5065 6.13589 20.3941C6.04801 20.2817 6.00018 20.1432 6 20.0005C6.00024 19.8578 6.04808 19.7192 6.13594 19.6067C6.2238 19.4942 6.34667 19.4143 6.48509 19.3795C7.35612 19.1607 8.20819 18.8723 9.03309 18.5172C11.1906 17.5888 13.0786 16.3146 14.6964 14.6968C16.3141 13.0791 17.5884 11.191 18.5167 9.03356C18.8719 8.20862 19.1604 7.35656 19.3795 6.48556C19.4508 6.2005 19.7064 6 20 6Z", - "fill": "url(#paint0_linear_3892_95663)" - }, - "children": [] - } - ] - }, - { - "type": "element", - "name": "g", - "attributes": { - "mask": "url(#mask0_3892_95663)" - }, - "children": [ - { - "type": "element", - "name": "g", - "attributes": { - "filter": "url(#filter0_f_3892_95663)" - }, - "children": [ - { - "type": "element", - "name": "path", - "attributes": { - "d": "M3.47232 27.8921C6.70753 29.0411 10.426 26.8868 11.7778 23.0804C13.1296 19.274 11.6028 15.2569 8.36763 14.108C5.13242 12.959 1.41391 15.1133 0.06211 18.9197C-1.28969 22.7261 0.23711 26.7432 3.47232 27.8921Z", - "fill": "#FFE432" - }, - "children": [] - } - ] - }, - { - "type": "element", - "name": "g", - "attributes": { - "filter": "url(#filter1_f_3892_95663)" - }, - "children": [ - { - "type": "element", - "name": "path", - "attributes": { - "d": "M17.8359 15.341C22.2806 15.341 25.8838 11.6588 25.8838 7.11644C25.8838 2.57412 22.2806 -1.10815 17.8359 -1.10815C13.3912 -1.10815 9.78809 2.57412 9.78809 7.11644C9.78809 11.6588 13.3912 15.341 17.8359 15.341Z", - "fill": "#FC413D" - }, - "children": [] - } - ] - }, - { - "type": "element", - "name": "g", - "attributes": { - "filter": "url(#filter2_f_3892_95663)" - }, - "children": [ - { - "type": "element", - "name": "path", - "attributes": { - "d": "M14.7081 41.6431C19.3478 41.4163 22.8707 36.3599 22.5768 30.3493C22.283 24.3387 18.2836 19.65 13.644 19.8769C9.00433 20.1037 5.48139 25.1601 5.77525 31.1707C6.06911 37.1813 10.0685 41.87 14.7081 41.6431Z", - "fill": "#00B95C" - }, - "children": [] - } - ] - }, - { - "type": "element", - "name": "g", - "attributes": { - "filter": "url(#filter3_f_3892_95663)" - }, - "children": [ - { - "type": "element", - "name": "path", - "attributes": { - "d": "M14.7081 41.6431C19.3478 41.4163 22.8707 36.3599 22.5768 30.3493C22.283 24.3387 18.2836 19.65 13.644 19.8769C9.00433 20.1037 5.48139 25.1601 5.77525 31.1707C6.06911 37.1813 10.0685 41.87 14.7081 41.6431Z", - "fill": "#00B95C" - }, - "children": [] - } - ] - }, - { - "type": "element", - "name": "g", - "attributes": { - "filter": "url(#filter4_f_3892_95663)" - }, - "children": [ - { - "type": "element", - "name": "path", - "attributes": { - "d": "M19.355 38.0071C23.2447 35.6405 24.2857 30.2506 21.6803 25.9684C19.0748 21.6862 13.8095 20.1334 9.91983 22.5C6.03016 24.8666 4.98909 30.2565 7.59454 34.5387C10.2 38.8209 15.4653 40.3738 19.355 38.0071Z", - "fill": "#00B95C" - }, - "children": [] - } - ] - }, - { - "type": "element", - "name": "g", - "attributes": { - "filter": "url(#filter5_f_3892_95663)" - }, - "children": [ - { - "type": "element", - "name": "path", - "attributes": { - "d": "M35.0759 24.5504C39.4477 24.5504 42.9917 21.1377 42.9917 16.9278C42.9917 12.7179 39.4477 9.30518 35.0759 9.30518C30.7042 9.30518 27.1602 12.7179 27.1602 16.9278C27.1602 21.1377 30.7042 24.5504 35.0759 24.5504Z", - "fill": "#3186FF" - }, - "children": [] - } - ] - }, - { - "type": "element", - "name": "g", - "attributes": { - "filter": "url(#filter6_f_3892_95663)" - }, - "children": [ - { - "type": "element", - "name": "path", - "attributes": { - "d": "M0.362818 23.6667C4.3882 26.7279 10.2688 25.7676 13.4976 21.5219C16.7264 17.2762 16.0806 11.3528 12.0552 8.29156C8.02982 5.23037 2.14917 6.19062 -1.07959 10.4364C-4.30835 14.6821 -3.66256 20.6055 0.362818 23.6667Z", - "fill": "#FBBC04" - }, - "children": [] - } - ] - }, - { - "type": "element", - "name": "g", - "attributes": { - "filter": "url(#filter7_f_3892_95663)" - }, - "children": [ - { - "type": "element", - "name": "path", - "attributes": { - "d": "M20.9877 28.1903C25.7924 31.4936 32.1612 30.5732 35.2128 26.1346C38.2644 21.696 36.8432 15.4199 32.0385 12.1166C27.2338 8.81334 20.865 9.73372 17.8134 14.1723C14.7618 18.611 16.183 24.887 20.9877 28.1903Z", - "fill": "#3186FF" - }, - "children": [] - } - ] - }, - { - "type": "element", - "name": "g", - "attributes": { - "filter": "url(#filter8_f_3892_95663)" - }, - "children": [ - { - "type": "element", - "name": "path", - "attributes": { - "d": "M29.7231 4.99175C30.9455 6.65415 29.3748 9.88535 26.2149 12.2096C23.0549 14.5338 19.5026 15.0707 18.2801 13.4088C17.0576 11.7468 18.6284 8.51514 21.7883 6.19092C24.9482 3.86717 28.5006 3.32982 29.7231 4.99175Z", - "fill": "#749BFF" - }, - "children": [] - } - ] - }, - { - "type": "element", - "name": "g", - "attributes": { - "filter": "url(#filter9_f_3892_95663)" - }, - "children": [ - { - "type": "element", - "name": "path", - "attributes": { - "d": "M19.6891 12.9486C24.5759 8.41581 26.2531 2.27858 23.4354 -0.759249C20.6176 -3.79708 14.3718 -2.58516 9.485 1.94765C4.59823 6.48046 2.92099 12.6177 5.73879 15.6555C8.55658 18.6933 14.8024 17.4814 19.6891 12.9486Z", - "fill": "#FC413D" - }, - "children": [] - } - ] - }, - { - "type": "element", - "name": "g", - "attributes": { - "filter": "url(#filter10_f_3892_95663)" - }, - "children": [ - { - "type": "element", - "name": "path", - "attributes": { - "d": "M9.6712 29.23C12.5757 31.3088 15.9102 31.6247 17.1191 29.9356C18.328 28.2465 16.9535 25.1921 14.049 23.1133C11.1446 21.0345 7.81003 20.7186 6.60113 22.4077C5.39223 24.0968 6.76675 27.1512 9.6712 29.23Z", - "fill": "#FFEE48" - }, - "children": [] - } - ] - } - ] - }, - { - "type": "element", - "name": "defs", - "attributes": {}, - "children": [ - { - "type": "element", - "name": "filter", - "attributes": { - "id": "filter0_f_3892_95663", - "x": "-3.44095", - "y": "10.7885", - "width": "18.7217", - "height": "20.4229", - "filterUnits": "userSpaceOnUse", - "color-interpolation-filters": "sRGB" - }, - "children": [ - { - "type": "element", - "name": "feFlood", - "attributes": { - "flood-opacity": "0", - "result": "BackgroundImageFix" - }, - "children": [] - }, - { - "type": "element", - "name": "feBlend", - "attributes": { - "mode": "normal", - "in": "SourceGraphic", - "in2": "BackgroundImageFix", - "result": "shape" - }, - "children": [] - }, - { - "type": "element", - "name": "feGaussianBlur", - "attributes": { - "stdDeviation": "1.50514", - "result": "effect1_foregroundBlur_3892_95663" - }, - "children": [] - } - ] - }, - { - "type": "element", - "name": "filter", - "attributes": { - "id": "filter1_f_3892_95663", - "x": "-4.76352", - "y": "-15.6598", - "width": "45.1989", - "height": "45.5524", - "filterUnits": "userSpaceOnUse", - "color-interpolation-filters": "sRGB" - }, - "children": [ - { - "type": "element", - "name": "feFlood", - "attributes": { - "flood-opacity": "0", - "result": "BackgroundImageFix" - }, - "children": [] - }, - { - "type": "element", - "name": "feBlend", - "attributes": { - "mode": "normal", - "in": "SourceGraphic", - "in2": "BackgroundImageFix", - "result": "shape" - }, - "children": [] - }, - { - "type": "element", - "name": "feGaussianBlur", - "attributes": { - "stdDeviation": "7.2758", - "result": "effect1_foregroundBlur_3892_95663" - }, - "children": [] - } - ] - }, - { - "type": "element", - "name": "filter", - "attributes": { - "id": "filter2_f_3892_95663", - "x": "-6.61209", - "y": "7.49899", - "width": "41.5757", - "height": "46.522", - "filterUnits": "userSpaceOnUse", - "color-interpolation-filters": "sRGB" - }, - "children": [ - { - "type": "element", - "name": "feFlood", - "attributes": { - "flood-opacity": "0", - "result": "BackgroundImageFix" - }, - "children": [] - }, - { - "type": "element", - "name": "feBlend", - "attributes": { - "mode": "normal", - "in": "SourceGraphic", - "in2": "BackgroundImageFix", - "result": "shape" - }, - "children": [] - }, - { - "type": "element", - "name": "feGaussianBlur", - "attributes": { - "stdDeviation": "6.18495", - "result": "effect1_foregroundBlur_3892_95663" - }, - "children": [] - } - ] - }, - { - "type": "element", - "name": "filter", - "attributes": { - "id": "filter3_f_3892_95663", - "x": "-6.61209", - "y": "7.49899", - "width": "41.5757", - "height": "46.522", - "filterUnits": "userSpaceOnUse", - "color-interpolation-filters": "sRGB" - }, - "children": [ - { - "type": "element", - "name": "feFlood", - "attributes": { - "flood-opacity": "0", - "result": "BackgroundImageFix" - }, - "children": [] - }, - { - "type": "element", - "name": "feBlend", - "attributes": { - "mode": "normal", - "in": "SourceGraphic", - "in2": "BackgroundImageFix", - "result": "shape" - }, - "children": [] - }, - { - "type": "element", - "name": "feGaussianBlur", - "attributes": { - "stdDeviation": "6.18495", - "result": "effect1_foregroundBlur_3892_95663" - }, - "children": [] - } - ] - }, - { - "type": "element", - "name": "filter", - "attributes": { - "id": "filter4_f_3892_95663", - "x": "-6.21073", - "y": "9.02316", - "width": "41.6959", - "height": "42.4608", - "filterUnits": "userSpaceOnUse", - "color-interpolation-filters": "sRGB" - }, - "children": [ - { - "type": "element", - "name": "feFlood", - "attributes": { - "flood-opacity": "0", - "result": "BackgroundImageFix" - }, - "children": [] - }, - { - "type": "element", - "name": "feBlend", - "attributes": { - "mode": "normal", - "in": "SourceGraphic", - "in2": "BackgroundImageFix", - "result": "shape" - }, - "children": [] - }, - { - "type": "element", - "name": "feGaussianBlur", - "attributes": { - "stdDeviation": "6.18495", - "result": "effect1_foregroundBlur_3892_95663" - }, - "children": [] - } - ] - }, - { - "type": "element", - "name": "filter", - "attributes": { - "id": "filter5_f_3892_95663", - "x": "15.405", - "y": "-2.44994", - "width": "39.3423", - "height": "38.7556", - "filterUnits": "userSpaceOnUse", - "color-interpolation-filters": "sRGB" - }, - "children": [ - { - "type": "element", - "name": "feFlood", - "attributes": { - "flood-opacity": "0", - "result": "BackgroundImageFix" - }, - "children": [] - }, - { - "type": "element", - "name": "feBlend", - "attributes": { - "mode": "normal", - "in": "SourceGraphic", - "in2": "BackgroundImageFix", - "result": "shape" - }, - "children": [] - }, - { - "type": "element", - "name": "feGaussianBlur", - "attributes": { - "stdDeviation": "5.87756", - "result": "effect1_foregroundBlur_3892_95663" - }, - "children": [] - } - ] - }, - { - "type": "element", - "name": "filter", - "attributes": { - "id": "filter6_f_3892_95663", - "x": "-13.7886", - "y": "-4.15284", - "width": "39.9951", - "height": "40.2639", - "filterUnits": "userSpaceOnUse", - "color-interpolation-filters": "sRGB" - }, - "children": [ - { - "type": "element", - "name": "feFlood", - "attributes": { - "flood-opacity": "0", - "result": "BackgroundImageFix" - }, - "children": [] - }, - { - "type": "element", - "name": "feBlend", - "attributes": { - "mode": "normal", - "in": "SourceGraphic", - "in2": "BackgroundImageFix", - "result": "shape" - }, - "children": [] - }, - { - "type": "element", - "name": "feGaussianBlur", - "attributes": { - "stdDeviation": "5.32691", - "result": "effect1_foregroundBlur_3892_95663" - }, - "children": [] - } - ] - }, - { - "type": "element", - "name": "filter", - "attributes": { - "id": "filter7_f_3892_95663", - "x": "6.6925", - "y": "0.620963", - "width": "39.6414", - "height": "39.065", - "filterUnits": "userSpaceOnUse", - "color-interpolation-filters": "sRGB" - }, - "children": [ - { - "type": "element", - "name": "feFlood", - "attributes": { - "flood-opacity": "0", - "result": "BackgroundImageFix" - }, - "children": [] - }, - { - "type": "element", - "name": "feBlend", - "attributes": { - "mode": "normal", - "in": "SourceGraphic", - "in2": "BackgroundImageFix", - "result": "shape" - }, - "children": [] - }, - { - "type": "element", - "name": "feGaussianBlur", - "attributes": { - "stdDeviation": "4.75678", - "result": "effect1_foregroundBlur_3892_95663" - }, - "children": [] - } - ] - }, - { - "type": "element", - "name": "filter", - "attributes": { - "id": "filter8_f_3892_95663", - "x": "9.35225", - "y": "-4.48661", - "width": "29.2984", - "height": "27.3739", - "filterUnits": "userSpaceOnUse", - "color-interpolation-filters": "sRGB" - }, - "children": [ - { - "type": "element", - "name": "feFlood", - "attributes": { - "flood-opacity": "0", - "result": "BackgroundImageFix" - }, - "children": [] - }, - { - "type": "element", - "name": "feBlend", - "attributes": { - "mode": "normal", - "in": "SourceGraphic", - "in2": "BackgroundImageFix", - "result": "shape" - }, - "children": [] - }, - { - "type": "element", - "name": "feGaussianBlur", - "attributes": { - "stdDeviation": "4.25649", - "result": "effect1_foregroundBlur_3892_95663" - }, - "children": [] - } - ] - }, - { - "type": "element", - "name": "filter", - "attributes": { - "id": "filter9_f_3892_95663", - "x": "-2.81919", - "y": "-9.62339", - "width": "34.8122", - "height": "34.143", - "filterUnits": "userSpaceOnUse", - "color-interpolation-filters": "sRGB" - }, - "children": [ - { - "type": "element", - "name": "feFlood", - "attributes": { - "flood-opacity": "0", - "result": "BackgroundImageFix" - }, - "children": [] - }, - { - "type": "element", - "name": "feBlend", - "attributes": { - "mode": "normal", - "in": "SourceGraphic", - "in2": "BackgroundImageFix", - "result": "shape" - }, - "children": [] - }, - { - "type": "element", - "name": "feGaussianBlur", - "attributes": { - "stdDeviation": "3.59514", - "result": "effect1_foregroundBlur_3892_95663" - }, - "children": [] - } - ] - }, - { - "type": "element", - "name": "filter", - "attributes": { - "id": "filter10_f_3892_95663", - "x": "-2.73761", - "y": "12.4221", - "width": "29.1949", - "height": "27.4994", - "filterUnits": "userSpaceOnUse", - "color-interpolation-filters": "sRGB" - }, - "children": [ - { - "type": "element", - "name": "feFlood", - "attributes": { - "flood-opacity": "0", - "result": "BackgroundImageFix" - }, - "children": [] - }, - { - "type": "element", - "name": "feBlend", - "attributes": { - "mode": "normal", - "in": "SourceGraphic", - "in2": "BackgroundImageFix", - "result": "shape" - }, - "children": [] - }, - { - "type": "element", - "name": "feGaussianBlur", - "attributes": { - "stdDeviation": "4.44986", - "result": "effect1_foregroundBlur_3892_95663" - }, - "children": [] - } - ] - }, - { - "type": "element", - "name": "linearGradient", - "attributes": { - "id": "paint0_linear_3892_95663", - "x1": "13.9595", - "y1": "24.7349", - "x2": "28.5025", - "y2": "12.4738", - "gradientUnits": "userSpaceOnUse" - }, - "children": [ - { - "type": "element", - "name": "stop", - "attributes": { - "stop-color": "#4893FC" - }, - "children": [] - }, - { - "type": "element", - "name": "stop", - "attributes": { - "offset": "0.27", - "stop-color": "#4893FC" - }, - "children": [] - }, - { - "type": "element", - "name": "stop", - "attributes": { - "offset": "0.777", - "stop-color": "#969DFF" - }, - "children": [] - }, - { - "type": "element", - "name": "stop", - "attributes": { - "offset": "1", - "stop-color": "#BD99FE" - }, - "children": [] - } - ] - } - ] - } - ] - }, - "name": "Gemini" -} diff --git a/web/app/components/base/icons/src/public/llm/Gemini.tsx b/web/app/components/base/icons/src/public/llm/Gemini.tsx deleted file mode 100644 index f5430036bb..0000000000 --- a/web/app/components/base/icons/src/public/llm/Gemini.tsx +++ /dev/null @@ -1,20 +0,0 @@ -// GENERATE BY script -// DON NOT EDIT IT MANUALLY - -import type { IconData } from '@/app/components/base/icons/IconBase' -import * as React from 'react' -import IconBase from '@/app/components/base/icons/IconBase' -import data from './Gemini.json' - -const Icon = ( - { - ref, - ...props - }: React.SVGProps & { - ref?: React.RefObject> - }, -) => - -Icon.displayName = 'Gemini' - -export default Icon diff --git a/web/app/components/base/icons/src/public/llm/Grok.json b/web/app/components/base/icons/src/public/llm/Grok.json deleted file mode 100644 index 590f845eeb..0000000000 --- a/web/app/components/base/icons/src/public/llm/Grok.json +++ /dev/null @@ -1,72 +0,0 @@ -{ - "icon": { - "type": "element", - "isRootNode": true, - "name": "svg", - "attributes": { - "width": "40", - "height": "40", - "viewBox": "0 0 40 40", - "fill": "none", - "xmlns": "http://www.w3.org/2000/svg" - }, - "children": [ - { - "type": "element", - "name": "rect", - "attributes": { - "width": "40", - "height": "40", - "fill": "white" - }, - "children": [] - }, - { - "type": "element", - "name": "g", - "attributes": { - "clip-path": "url(#clip0_3892_95659)" - }, - "children": [ - { - "type": "element", - "name": "path", - "attributes": { - "d": "M15.745 24.54L26.715 16.35C27.254 15.95 28.022 16.106 28.279 16.73C29.628 20.018 29.025 23.971 26.341 26.685C23.658 29.399 19.924 29.995 16.511 28.639L12.783 30.384C18.13 34.081 24.623 33.166 28.681 29.06C31.9 25.805 32.897 21.368 31.965 17.367L31.973 17.376C30.622 11.498 32.305 9.149 35.755 4.345L36 4L31.46 8.59V8.576L15.743 24.544M13.48 26.531C9.643 22.824 10.305 17.085 13.58 13.776C16 11.327 19.968 10.328 23.432 11.797L27.152 10.06C26.482 9.57 25.622 9.043 24.637 8.673C20.182 6.819 14.848 7.742 11.227 11.401C7.744 14.924 6.648 20.341 8.53 24.962C9.935 28.416 7.631 30.86 5.31 33.326C4.49 34.2 3.666 35.074 3 36L13.478 26.534", - "fill": "black" - }, - "children": [] - } - ] - }, - { - "type": "element", - "name": "defs", - "attributes": {}, - "children": [ - { - "type": "element", - "name": "clipPath", - "attributes": { - "id": "clip0_3892_95659" - }, - "children": [ - { - "type": "element", - "name": "rect", - "attributes": { - "width": "33", - "height": "32", - "fill": "white", - "transform": "translate(3 4)" - }, - "children": [] - } - ] - } - ] - } - ] - }, - "name": "Grok" -} diff --git a/web/app/components/base/icons/src/public/llm/Grok.tsx b/web/app/components/base/icons/src/public/llm/Grok.tsx deleted file mode 100644 index 8b378de490..0000000000 --- a/web/app/components/base/icons/src/public/llm/Grok.tsx +++ /dev/null @@ -1,20 +0,0 @@ -// GENERATE BY script -// DON NOT EDIT IT MANUALLY - -import type { IconData } from '@/app/components/base/icons/IconBase' -import * as React from 'react' -import IconBase from '@/app/components/base/icons/IconBase' -import data from './Grok.json' - -const Icon = ( - { - ref, - ...props - }: React.SVGProps & { - ref?: React.RefObject> - }, -) => - -Icon.displayName = 'Grok' - -export default Icon diff --git a/web/app/components/base/icons/src/public/llm/OpenaiBlue.json b/web/app/components/base/icons/src/public/llm/OpenaiBlue.json deleted file mode 100644 index c5d4f974a2..0000000000 --- a/web/app/components/base/icons/src/public/llm/OpenaiBlue.json +++ /dev/null @@ -1,37 +0,0 @@ -{ - "icon": { - "type": "element", - "isRootNode": true, - "name": "svg", - "attributes": { - "width": "24", - "height": "24", - "viewBox": "0 0 24 24", - "fill": "none", - "xmlns": "http://www.w3.org/2000/svg" - }, - "children": [ - { - "type": "element", - "name": "rect", - "attributes": { - "width": "24", - "height": "24", - "rx": "6", - "fill": "#03A4EE" - }, - "children": [] - }, - { - "type": "element", - "name": "path", - "attributes": { - "d": "M19.7758 11.5959C19.9546 11.9948 20.0681 12.4213 20.1145 12.8563C20.1592 13.2913 20.1369 13.7315 20.044 14.1596C19.9529 14.5878 19.7947 14.9987 19.5746 15.377C19.4302 15.6298 19.2599 15.867 19.0639 16.0854C18.8696 16.3021 18.653 16.4981 18.4174 16.67C18.1801 16.842 17.9274 16.9864 17.6591 17.105C17.3926 17.222 17.1141 17.3114 16.8286 17.3698C16.6945 17.7859 16.4951 18.1797 16.2371 18.5339C15.9809 18.8881 15.6697 19.1993 15.3155 19.4555C14.9613 19.7134 14.5693 19.9129 14.1532 20.047C13.7371 20.1829 13.302 20.2499 12.8636 20.2499C12.573 20.2516 12.2807 20.2207 11.9953 20.1622C11.7116 20.102 11.433 20.0109 11.1665 19.8923C10.9 19.7736 10.6472 19.6258 10.4116 19.4538C10.1778 19.2819 9.96115 19.0841 9.76857 18.8658C9.33871 18.9586 8.89853 18.981 8.46351 18.9363C8.02849 18.8898 7.60207 18.7763 7.20143 18.5975C6.80252 18.4204 6.43284 18.1797 6.10786 17.8857C5.78289 17.5916 5.50606 17.2478 5.28769 16.8695C5.14153 16.6167 5.02117 16.3502 4.93004 16.0734C4.83891 15.7965 4.77873 15.5111 4.74778 15.2205C4.71683 14.9317 4.71855 14.6393 4.7495 14.3488C4.78045 14.0599 4.84407 13.7745 4.9352 13.4976C4.64289 13.1727 4.40217 12.803 4.22335 12.4041C4.04624 12.0034 3.93104 11.5787 3.88634 11.1437C3.83991 10.7087 3.86398 10.2685 3.95511 9.84036C4.04624 9.41222 4.20443 9.00127 4.42452 8.62299C4.56896 8.37023 4.73918 8.13123 4.93348 7.91458C5.12778 7.69793 5.34615 7.50191 5.58171 7.32997C5.81728 7.15802 6.07176 7.01187 6.33827 6.89495C6.6065 6.7763 6.88506 6.68861 7.17048 6.63015C7.3046 6.21232 7.50406 5.82029 7.76026 5.46608C8.01817 5.11188 8.32939 4.80066 8.6836 4.54274C9.03781 4.28654 9.42984 4.08708 9.84595 3.95125C10.2621 3.81713 10.6971 3.74835 11.1355 3.75007C11.4261 3.74835 11.7184 3.77758 12.0039 3.83776C12.2893 3.89794 12.5678 3.98736 12.8344 4.106C13.1009 4.22636 13.3536 4.37251 13.5892 4.54446C13.8248 4.71812 14.0414 4.91414 14.234 5.13251C14.6621 5.04138 15.1023 5.01903 15.5373 5.06373C15.9723 5.10844 16.3971 5.22364 16.7977 5.40074C17.1966 5.57957 17.5663 5.81857 17.8913 6.1126C18.2162 6.4049 18.4931 6.74707 18.7114 7.12707C18.8576 7.37811 18.9779 7.64463 19.0691 7.92318C19.1602 8.20001 19.2221 8.48544 19.2513 8.77602C19.2823 9.06661 19.2823 9.35892 19.2496 9.64951C19.2187 9.94009 19.155 10.2255 19.0639 10.5024C19.3579 10.8273 19.5969 11.1953 19.7758 11.5959ZM14.0466 18.9363C14.4214 18.7815 14.7619 18.5528 15.049 18.2657C15.3362 17.9785 15.5648 17.6381 15.7196 17.2615C15.8743 16.8867 15.9552 16.4843 15.9552 16.0785V12.2442C15.954 12.2407 15.9529 12.2367 15.9517 12.2321C15.9506 12.2287 15.9488 12.2252 15.9466 12.2218C15.9443 12.2184 15.9414 12.2155 15.938 12.2132C15.9345 12.2098 15.9311 12.2075 15.9276 12.2063L14.54 11.4051V16.0373C14.54 16.0837 14.5332 16.1318 14.5211 16.1765C14.5091 16.223 14.4919 16.2659 14.4678 16.3072C14.4438 16.3485 14.4162 16.3863 14.3819 16.419C14.3484 16.4523 14.3109 16.4812 14.2701 16.505L10.9842 18.4015C10.9567 18.4187 10.9103 18.4428 10.8862 18.4565C11.0221 18.5717 11.1699 18.6732 11.3247 18.7626C11.4811 18.852 11.6428 18.9277 11.8113 18.9896C11.9798 19.0497 12.1535 19.0962 12.3288 19.1271C12.5059 19.1581 12.6848 19.1735 12.8636 19.1735C13.2694 19.1735 13.6717 19.0927 14.0466 18.9363ZM6.22135 16.333C6.42596 16.6855 6.69592 16.9916 7.01745 17.2392C7.34071 17.4868 7.70695 17.6673 8.09899 17.7722C8.49102 17.8771 8.90025 17.9046 9.3026 17.8513C9.70495 17.798 10.0918 17.6673 10.4443 17.4644L13.7663 15.5472L13.7749 15.5386C13.7772 15.5363 13.7789 15.5329 13.78 15.5283C13.7823 15.5249 13.7841 15.5214 13.7852 15.518V13.9017L9.77545 16.2212C9.73418 16.2453 9.6912 16.2625 9.64649 16.2763C9.60007 16.2883 9.55364 16.2935 9.5055 16.2935C9.45907 16.2935 9.41265 16.2883 9.36622 16.2763C9.32152 16.2625 9.27681 16.2453 9.23554 16.2212L5.94967 14.323C5.92044 14.3058 5.87746 14.28 5.85339 14.2645C5.82244 14.4416 5.80696 14.6204 5.80696 14.7993C5.80696 14.9781 5.82415 15.1569 5.85511 15.334C5.88605 15.5094 5.9342 15.6831 5.99438 15.8516C6.05628 16.0201 6.13194 16.1817 6.22135 16.3364V16.333ZM5.35818 9.1629C5.15529 9.51539 5.02461 9.90398 4.97131 10.3063C4.918 10.7087 4.94552 11.1162 5.0504 11.51C5.15529 11.902 5.33583 12.2682 5.58343 12.5915C5.83103 12.913 6.13881 13.183 6.48958 13.3859L9.80984 15.3048C9.81328 15.3059 9.81729 15.3071 9.82188 15.3082H9.83391C9.8385 15.3082 9.84251 15.3071 9.84595 15.3048C9.84939 15.3036 9.85283 15.3019 9.85627 15.2996L11.249 14.4949L7.23926 12.1805C7.19971 12.1565 7.16189 12.1272 7.1275 12.0946C7.09418 12.0611 7.06529 12.0236 7.04153 11.9828C7.01917 11.9415 7.00026 11.8985 6.98822 11.8521C6.97619 11.8074 6.96931 11.761 6.97103 11.7128V7.80797C6.80252 7.86987 6.63917 7.94553 6.48442 8.03494C6.32967 8.12607 6.18352 8.22924 6.04596 8.34444C5.91013 8.45965 5.78289 8.58688 5.66769 8.72444C5.55248 8.86028 5.45103 9.00815 5.36162 9.1629H5.35818ZM16.7633 11.8177C16.8046 11.8418 16.8424 11.8693 16.8768 11.9037C16.9094 11.9364 16.9387 11.9742 16.9628 12.0155C16.9851 12.0567 17.004 12.1014 17.0161 12.1461C17.0264 12.1926 17.0332 12.239 17.0315 12.2871V16.192C17.5835 15.9891 18.0649 15.6332 18.4208 15.1655C18.7785 14.6978 18.9934 14.139 19.0433 13.5544C19.0931 12.9698 18.9762 12.3817 18.7046 11.8607C18.4329 11.3397 18.0185 10.9064 17.5095 10.6141L14.1893 8.69521C14.1858 8.69406 14.1818 8.69292 14.1772 8.69177H14.1652C14.1618 8.69292 14.1578 8.69406 14.1532 8.69521C14.1497 8.69636 14.1463 8.69808 14.1429 8.70037L12.757 9.50163L16.7667 11.8177H16.7633ZM18.1475 9.7372H18.1457V9.73892L18.1475 9.7372ZM18.1457 9.73548C18.2455 9.15774 18.1784 8.56281 17.9514 8.02119C17.7262 7.47956 17.3496 7.01359 16.8682 6.67658C16.3867 6.34128 15.8193 6.1487 15.233 6.12291C14.6449 6.09884 14.0638 6.24155 13.5548 6.53386L10.2345 8.45105C10.2311 8.45334 10.2282 8.45621 10.2259 8.45965L10.2191 8.46996C10.2179 8.4734 10.2168 8.47741 10.2156 8.482C10.2145 8.48544 10.2139 8.48945 10.2139 8.49403V10.0966L14.2237 7.78046C14.2649 7.75639 14.3096 7.7392 14.3543 7.72544C14.4008 7.7134 14.4472 7.70825 14.4936 7.70825C14.5418 7.70825 14.5882 7.7134 14.6346 7.72544C14.6793 7.7392 14.7223 7.75639 14.7636 7.78046L18.0494 9.67874C18.0787 9.69593 18.1217 9.72 18.1457 9.73548ZM9.45735 7.96101C9.45735 7.91458 9.46423 7.86816 9.47627 7.82173C9.4883 7.77702 9.5055 7.73232 9.52957 7.69105C9.55364 7.6515 9.58115 7.61368 9.61554 7.57929C9.64821 7.54662 9.68604 7.51739 9.72731 7.49503L13.0132 5.59848C13.0441 5.57957 13.0871 5.55549 13.1112 5.54346C12.6607 5.1669 12.1105 4.92618 11.5276 4.85224C10.9447 4.77658 10.3532 4.86943 9.82188 5.11875C9.28885 5.36807 8.83835 5.76527 8.52369 6.26047C8.20903 6.75739 8.04224 7.33169 8.04224 7.91974V11.7541C8.04339 11.7587 8.04454 11.7627 8.04568 11.7661C8.04683 11.7696 8.04855 11.773 8.05084 11.7765C8.05313 11.7799 8.056 11.7833 8.05944 11.7868C8.06173 11.7891 8.06517 11.7914 8.06976 11.7937L9.45735 12.5949V7.96101ZM10.2105 13.0282L11.997 14.0599L13.7835 13.0282V10.9666L11.9987 9.93493L10.2122 10.9666L10.2105 13.0282Z", - "fill": "white" - }, - "children": [] - } - ] - }, - "name": "OpenaiBlue" -} diff --git a/web/app/components/base/icons/src/public/llm/OpenaiBlue.tsx b/web/app/components/base/icons/src/public/llm/OpenaiBlue.tsx deleted file mode 100644 index 9934a77591..0000000000 --- a/web/app/components/base/icons/src/public/llm/OpenaiBlue.tsx +++ /dev/null @@ -1,20 +0,0 @@ -// GENERATE BY script -// DON NOT EDIT IT MANUALLY - -import type { IconData } from '@/app/components/base/icons/IconBase' -import * as React from 'react' -import IconBase from '@/app/components/base/icons/IconBase' -import data from './OpenaiBlue.json' - -const Icon = ( - { - ref, - ...props - }: React.SVGProps & { - ref?: React.RefObject> - }, -) => - -Icon.displayName = 'OpenaiBlue' - -export default Icon diff --git a/web/app/components/base/icons/src/public/llm/OpenaiSmall.json b/web/app/components/base/icons/src/public/llm/OpenaiSmall.json deleted file mode 100644 index aa72f614bc..0000000000 --- a/web/app/components/base/icons/src/public/llm/OpenaiSmall.json +++ /dev/null @@ -1,128 +0,0 @@ -{ - "icon": { - "type": "element", - "isRootNode": true, - "name": "svg", - "attributes": { - "width": "26", - "height": "26", - "viewBox": "0 0 26 26", - "fill": "none", - "xmlns": "http://www.w3.org/2000/svg", - "xmlns:xlink": "http://www.w3.org/1999/xlink" - }, - "children": [ - { - "type": "element", - "name": "g", - "attributes": { - "clip-path": "url(#clip0_3892_83671)" - }, - "children": [ - { - "type": "element", - "name": "path", - "attributes": { - "d": "M1 13C1 9.27247 1 7.4087 1.60896 5.93853C2.42092 3.97831 3.97831 2.42092 5.93853 1.60896C7.4087 1 9.27247 1 13 1C16.7275 1 18.5913 1 20.0615 1.60896C22.0217 2.42092 23.5791 3.97831 24.391 5.93853C25 7.4087 25 9.27247 25 13C25 16.7275 25 18.5913 24.391 20.0615C23.5791 22.0217 22.0217 23.5791 20.0615 24.391C18.5913 25 16.7275 25 13 25C9.27247 25 7.4087 25 5.93853 24.391C3.97831 23.5791 2.42092 22.0217 1.60896 20.0615C1 18.5913 1 16.7275 1 13Z", - "fill": "white" - }, - "children": [] - }, - { - "type": "element", - "name": "rect", - "attributes": { - "width": "24", - "height": "24", - "transform": "translate(1 1)", - "fill": "url(#pattern0_3892_83671)" - }, - "children": [] - }, - { - "type": "element", - "name": "rect", - "attributes": { - "width": "24", - "height": "24", - "transform": "translate(1 1)", - "fill": "white", - "fill-opacity": "0.01" - }, - "children": [] - } - ] - }, - { - "type": "element", - "name": "path", - "attributes": { - "d": "M13 0.75C14.8603 0.75 16.2684 0.750313 17.3945 0.827148C18.5228 0.904144 19.3867 1.05876 20.1572 1.37793C22.1787 2.21525 23.7847 3.82133 24.6221 5.84277C24.9412 6.61333 25.0959 7.47723 25.1729 8.60547C25.2497 9.73161 25.25 11.1397 25.25 13C25.25 14.8603 25.2497 16.2684 25.1729 17.3945C25.0959 18.5228 24.9412 19.3867 24.6221 20.1572C23.7847 22.1787 22.1787 23.7847 20.1572 24.6221C19.3867 24.9412 18.5228 25.0959 17.3945 25.1729C16.2684 25.2497 14.8603 25.25 13 25.25C11.1397 25.25 9.73161 25.2497 8.60547 25.1729C7.47723 25.0959 6.61333 24.9412 5.84277 24.6221C3.82133 23.7847 2.21525 22.1787 1.37793 20.1572C1.05876 19.3867 0.904144 18.5228 0.827148 17.3945C0.750313 16.2684 0.75 14.8603 0.75 13C0.75 11.1397 0.750313 9.73161 0.827148 8.60547C0.904144 7.47723 1.05876 6.61333 1.37793 5.84277C2.21525 3.82133 3.82133 2.21525 5.84277 1.37793C6.61333 1.05876 7.47723 0.904144 8.60547 0.827148C9.73161 0.750313 11.1397 0.75 13 0.75Z", - "stroke": "#101828", - "stroke-opacity": "0.08", - "stroke-width": "0.5" - }, - "children": [] - }, - { - "type": "element", - "name": "defs", - "attributes": {}, - "children": [ - { - "type": "element", - "name": "pattern", - "attributes": { - "id": "pattern0_3892_83671", - "patternContentUnits": "objectBoundingBox", - "width": "1", - "height": "1" - }, - "children": [ - { - "type": "element", - "name": "use", - "attributes": { - "xlink:href": "#image0_3892_83671", - "transform": "scale(0.00625)" - }, - "children": [] - } - ] - }, - { - "type": "element", - "name": "clipPath", - "attributes": { - "id": "clip0_3892_83671" - }, - "children": [ - { - "type": "element", - "name": "path", - "attributes": { - "d": "M1 13C1 9.27247 1 7.4087 1.60896 5.93853C2.42092 3.97831 3.97831 2.42092 5.93853 1.60896C7.4087 1 9.27247 1 13 1C16.7275 1 18.5913 1 20.0615 1.60896C22.0217 2.42092 23.5791 3.97831 24.391 5.93853C25 7.4087 25 9.27247 25 13C25 16.7275 25 18.5913 24.391 20.0615C23.5791 22.0217 22.0217 23.5791 20.0615 24.391C18.5913 25 16.7275 25 13 25C9.27247 25 7.4087 25 5.93853 24.391C3.97831 23.5791 2.42092 22.0217 1.60896 20.0615C1 18.5913 1 16.7275 1 13Z", - "fill": "white" - }, - "children": [] - } - ] - }, - { - "type": "element", - "name": "image", - "attributes": { - "id": "image0_3892_83671", - "width": "160", - "height": "160", - "preserveAspectRatio": "none", - "xlink:href": "" - }, - "children": [] - } - ] - } - ] - }, - "name": "OpenaiSmall" -} diff --git a/web/app/components/base/icons/src/public/llm/OpenaiSmall.tsx b/web/app/components/base/icons/src/public/llm/OpenaiSmall.tsx deleted file mode 100644 index 6307091e0b..0000000000 --- a/web/app/components/base/icons/src/public/llm/OpenaiSmall.tsx +++ /dev/null @@ -1,20 +0,0 @@ -// GENERATE BY script -// DON NOT EDIT IT MANUALLY - -import type { IconData } from '@/app/components/base/icons/IconBase' -import * as React from 'react' -import IconBase from '@/app/components/base/icons/IconBase' -import data from './OpenaiSmall.json' - -const Icon = ( - { - ref, - ...props - }: React.SVGProps & { - ref?: React.RefObject> - }, -) => - -Icon.displayName = 'OpenaiSmall' - -export default Icon diff --git a/web/app/components/base/icons/src/public/llm/OpenaiTeal.json b/web/app/components/base/icons/src/public/llm/OpenaiTeal.json deleted file mode 100644 index ffd0981512..0000000000 --- a/web/app/components/base/icons/src/public/llm/OpenaiTeal.json +++ /dev/null @@ -1,37 +0,0 @@ -{ - "icon": { - "type": "element", - "isRootNode": true, - "name": "svg", - "attributes": { - "width": "24", - "height": "24", - "viewBox": "0 0 24 24", - "fill": "none", - "xmlns": "http://www.w3.org/2000/svg" - }, - "children": [ - { - "type": "element", - "name": "rect", - "attributes": { - "width": "24", - "height": "24", - "rx": "6", - "fill": "#009688" - }, - "children": [] - }, - { - "type": "element", - "name": "path", - "attributes": { - "d": "M19.7758 11.5959C19.9546 11.9948 20.0681 12.4213 20.1145 12.8563C20.1592 13.2913 20.1369 13.7315 20.044 14.1596C19.9529 14.5878 19.7947 14.9987 19.5746 15.377C19.4302 15.6298 19.2599 15.867 19.0639 16.0854C18.8696 16.3021 18.653 16.4981 18.4174 16.67C18.1801 16.842 17.9274 16.9864 17.6591 17.105C17.3926 17.222 17.1141 17.3114 16.8286 17.3698C16.6945 17.7859 16.4951 18.1797 16.2371 18.5339C15.9809 18.8881 15.6697 19.1993 15.3155 19.4555C14.9613 19.7134 14.5693 19.9129 14.1532 20.047C13.7371 20.1829 13.302 20.2499 12.8636 20.2499C12.573 20.2516 12.2807 20.2207 11.9953 20.1622C11.7116 20.102 11.433 20.0109 11.1665 19.8923C10.9 19.7736 10.6472 19.6258 10.4116 19.4538C10.1778 19.2819 9.96115 19.0841 9.76857 18.8658C9.33871 18.9586 8.89853 18.981 8.46351 18.9363C8.02849 18.8898 7.60207 18.7763 7.20143 18.5975C6.80252 18.4204 6.43284 18.1797 6.10786 17.8857C5.78289 17.5916 5.50606 17.2478 5.28769 16.8695C5.14153 16.6167 5.02117 16.3502 4.93004 16.0734C4.83891 15.7965 4.77873 15.5111 4.74778 15.2205C4.71683 14.9317 4.71855 14.6393 4.7495 14.3488C4.78045 14.0599 4.84407 13.7745 4.9352 13.4976C4.64289 13.1727 4.40217 12.803 4.22335 12.4041C4.04624 12.0034 3.93104 11.5787 3.88634 11.1437C3.83991 10.7087 3.86398 10.2685 3.95511 9.84036C4.04624 9.41222 4.20443 9.00127 4.42452 8.62299C4.56896 8.37023 4.73918 8.13123 4.93348 7.91458C5.12778 7.69793 5.34615 7.50191 5.58171 7.32997C5.81728 7.15802 6.07176 7.01187 6.33827 6.89495C6.6065 6.7763 6.88506 6.68861 7.17048 6.63015C7.3046 6.21232 7.50406 5.82029 7.76026 5.46608C8.01817 5.11188 8.32939 4.80066 8.6836 4.54274C9.03781 4.28654 9.42984 4.08708 9.84595 3.95125C10.2621 3.81713 10.6971 3.74835 11.1355 3.75007C11.4261 3.74835 11.7184 3.77758 12.0039 3.83776C12.2893 3.89794 12.5678 3.98736 12.8344 4.106C13.1009 4.22636 13.3536 4.37251 13.5892 4.54446C13.8248 4.71812 14.0414 4.91414 14.234 5.13251C14.6621 5.04138 15.1023 5.01903 15.5373 5.06373C15.9723 5.10844 16.3971 5.22364 16.7977 5.40074C17.1966 5.57957 17.5663 5.81857 17.8913 6.1126C18.2162 6.4049 18.4931 6.74707 18.7114 7.12707C18.8576 7.37811 18.9779 7.64463 19.0691 7.92318C19.1602 8.20001 19.2221 8.48544 19.2513 8.77602C19.2823 9.06661 19.2823 9.35892 19.2496 9.64951C19.2187 9.94009 19.155 10.2255 19.0639 10.5024C19.3579 10.8273 19.5969 11.1953 19.7758 11.5959ZM14.0466 18.9363C14.4214 18.7815 14.7619 18.5528 15.049 18.2657C15.3362 17.9785 15.5648 17.6381 15.7196 17.2615C15.8743 16.8867 15.9552 16.4843 15.9552 16.0785V12.2442C15.954 12.2407 15.9529 12.2367 15.9517 12.2321C15.9506 12.2287 15.9488 12.2252 15.9466 12.2218C15.9443 12.2184 15.9414 12.2155 15.938 12.2132C15.9345 12.2098 15.9311 12.2075 15.9276 12.2063L14.54 11.4051V16.0373C14.54 16.0837 14.5332 16.1318 14.5211 16.1765C14.5091 16.223 14.4919 16.2659 14.4678 16.3072C14.4438 16.3485 14.4162 16.3863 14.3819 16.419C14.3484 16.4523 14.3109 16.4812 14.2701 16.505L10.9842 18.4015C10.9567 18.4187 10.9103 18.4428 10.8862 18.4565C11.0221 18.5717 11.1699 18.6732 11.3247 18.7626C11.4811 18.852 11.6428 18.9277 11.8113 18.9896C11.9798 19.0497 12.1535 19.0962 12.3288 19.1271C12.5059 19.1581 12.6848 19.1735 12.8636 19.1735C13.2694 19.1735 13.6717 19.0927 14.0466 18.9363ZM6.22135 16.333C6.42596 16.6855 6.69592 16.9916 7.01745 17.2392C7.34071 17.4868 7.70695 17.6673 8.09899 17.7722C8.49102 17.8771 8.90025 17.9046 9.3026 17.8513C9.70495 17.798 10.0918 17.6673 10.4443 17.4644L13.7663 15.5472L13.7749 15.5386C13.7772 15.5363 13.7789 15.5329 13.78 15.5283C13.7823 15.5249 13.7841 15.5214 13.7852 15.518V13.9017L9.77545 16.2212C9.73418 16.2453 9.6912 16.2625 9.64649 16.2763C9.60007 16.2883 9.55364 16.2935 9.5055 16.2935C9.45907 16.2935 9.41265 16.2883 9.36622 16.2763C9.32152 16.2625 9.27681 16.2453 9.23554 16.2212L5.94967 14.323C5.92044 14.3058 5.87746 14.28 5.85339 14.2645C5.82244 14.4416 5.80696 14.6204 5.80696 14.7993C5.80696 14.9781 5.82415 15.1569 5.85511 15.334C5.88605 15.5094 5.9342 15.6831 5.99438 15.8516C6.05628 16.0201 6.13194 16.1817 6.22135 16.3364V16.333ZM5.35818 9.1629C5.15529 9.51539 5.02461 9.90398 4.97131 10.3063C4.918 10.7087 4.94552 11.1162 5.0504 11.51C5.15529 11.902 5.33583 12.2682 5.58343 12.5915C5.83103 12.913 6.13881 13.183 6.48958 13.3859L9.80984 15.3048C9.81328 15.3059 9.81729 15.3071 9.82188 15.3082H9.83391C9.8385 15.3082 9.84251 15.3071 9.84595 15.3048C9.84939 15.3036 9.85283 15.3019 9.85627 15.2996L11.249 14.4949L7.23926 12.1805C7.19971 12.1565 7.16189 12.1272 7.1275 12.0946C7.09418 12.0611 7.06529 12.0236 7.04153 11.9828C7.01917 11.9415 7.00026 11.8985 6.98822 11.8521C6.97619 11.8074 6.96931 11.761 6.97103 11.7128V7.80797C6.80252 7.86987 6.63917 7.94553 6.48442 8.03494C6.32967 8.12607 6.18352 8.22924 6.04596 8.34444C5.91013 8.45965 5.78289 8.58688 5.66769 8.72444C5.55248 8.86028 5.45103 9.00815 5.36162 9.1629H5.35818ZM16.7633 11.8177C16.8046 11.8418 16.8424 11.8693 16.8768 11.9037C16.9094 11.9364 16.9387 11.9742 16.9628 12.0155C16.9851 12.0567 17.004 12.1014 17.0161 12.1461C17.0264 12.1926 17.0332 12.239 17.0315 12.2871V16.192C17.5835 15.9891 18.0649 15.6332 18.4208 15.1655C18.7785 14.6978 18.9934 14.139 19.0433 13.5544C19.0931 12.9698 18.9762 12.3817 18.7046 11.8607C18.4329 11.3397 18.0185 10.9064 17.5095 10.6141L14.1893 8.69521C14.1858 8.69406 14.1818 8.69292 14.1772 8.69177H14.1652C14.1618 8.69292 14.1578 8.69406 14.1532 8.69521C14.1497 8.69636 14.1463 8.69808 14.1429 8.70037L12.757 9.50163L16.7667 11.8177H16.7633ZM18.1475 9.7372H18.1457V9.73892L18.1475 9.7372ZM18.1457 9.73548C18.2455 9.15774 18.1784 8.56281 17.9514 8.02119C17.7262 7.47956 17.3496 7.01359 16.8682 6.67658C16.3867 6.34128 15.8193 6.1487 15.233 6.12291C14.6449 6.09884 14.0638 6.24155 13.5548 6.53386L10.2345 8.45105C10.2311 8.45334 10.2282 8.45621 10.2259 8.45965L10.2191 8.46996C10.2179 8.4734 10.2168 8.47741 10.2156 8.482C10.2145 8.48544 10.2139 8.48945 10.2139 8.49403V10.0966L14.2237 7.78046C14.2649 7.75639 14.3096 7.7392 14.3543 7.72544C14.4008 7.7134 14.4472 7.70825 14.4936 7.70825C14.5418 7.70825 14.5882 7.7134 14.6346 7.72544C14.6793 7.7392 14.7223 7.75639 14.7636 7.78046L18.0494 9.67874C18.0787 9.69593 18.1217 9.72 18.1457 9.73548ZM9.45735 7.96101C9.45735 7.91458 9.46423 7.86816 9.47627 7.82173C9.4883 7.77702 9.5055 7.73232 9.52957 7.69105C9.55364 7.6515 9.58115 7.61368 9.61554 7.57929C9.64821 7.54662 9.68604 7.51739 9.72731 7.49503L13.0132 5.59848C13.0441 5.57957 13.0871 5.55549 13.1112 5.54346C12.6607 5.1669 12.1105 4.92618 11.5276 4.85224C10.9447 4.77658 10.3532 4.86943 9.82188 5.11875C9.28885 5.36807 8.83835 5.76527 8.52369 6.26047C8.20903 6.75739 8.04224 7.33169 8.04224 7.91974V11.7541C8.04339 11.7587 8.04454 11.7627 8.04568 11.7661C8.04683 11.7696 8.04855 11.773 8.05084 11.7765C8.05313 11.7799 8.056 11.7833 8.05944 11.7868C8.06173 11.7891 8.06517 11.7914 8.06976 11.7937L9.45735 12.5949V7.96101ZM10.2105 13.0282L11.997 14.0599L13.7835 13.0282V10.9666L11.9987 9.93493L10.2122 10.9666L10.2105 13.0282Z", - "fill": "white" - }, - "children": [] - } - ] - }, - "name": "OpenaiTeal" -} diff --git a/web/app/components/base/icons/src/public/llm/OpenaiTeal.tsx b/web/app/components/base/icons/src/public/llm/OpenaiTeal.tsx deleted file mode 100644 index ef803ea52f..0000000000 --- a/web/app/components/base/icons/src/public/llm/OpenaiTeal.tsx +++ /dev/null @@ -1,20 +0,0 @@ -// GENERATE BY script -// DON NOT EDIT IT MANUALLY - -import type { IconData } from '@/app/components/base/icons/IconBase' -import * as React from 'react' -import IconBase from '@/app/components/base/icons/IconBase' -import data from './OpenaiTeal.json' - -const Icon = ( - { - ref, - ...props - }: React.SVGProps & { - ref?: React.RefObject> - }, -) => - -Icon.displayName = 'OpenaiTeal' - -export default Icon diff --git a/web/app/components/base/icons/src/public/llm/OpenaiViolet.json b/web/app/components/base/icons/src/public/llm/OpenaiViolet.json deleted file mode 100644 index e80a85507e..0000000000 --- a/web/app/components/base/icons/src/public/llm/OpenaiViolet.json +++ /dev/null @@ -1,37 +0,0 @@ -{ - "icon": { - "type": "element", - "isRootNode": true, - "name": "svg", - "attributes": { - "width": "24", - "height": "24", - "viewBox": "0 0 24 24", - "fill": "none", - "xmlns": "http://www.w3.org/2000/svg" - }, - "children": [ - { - "type": "element", - "name": "rect", - "attributes": { - "width": "24", - "height": "24", - "rx": "6", - "fill": "#AB68FF" - }, - "children": [] - }, - { - "type": "element", - "name": "path", - "attributes": { - "d": "M19.7758 11.5959C19.9546 11.9948 20.0681 12.4213 20.1145 12.8563C20.1592 13.2913 20.1369 13.7315 20.044 14.1596C19.9529 14.5878 19.7947 14.9987 19.5746 15.377C19.4302 15.6298 19.2599 15.867 19.0639 16.0854C18.8696 16.3021 18.653 16.4981 18.4174 16.67C18.1801 16.842 17.9274 16.9864 17.6591 17.105C17.3926 17.222 17.1141 17.3114 16.8286 17.3698C16.6945 17.7859 16.4951 18.1797 16.2371 18.5339C15.9809 18.8881 15.6697 19.1993 15.3155 19.4555C14.9613 19.7134 14.5693 19.9129 14.1532 20.047C13.7371 20.1829 13.302 20.2499 12.8636 20.2499C12.573 20.2516 12.2807 20.2207 11.9953 20.1622C11.7116 20.102 11.433 20.0109 11.1665 19.8923C10.9 19.7736 10.6472 19.6258 10.4116 19.4538C10.1778 19.2819 9.96115 19.0841 9.76857 18.8658C9.33871 18.9586 8.89853 18.981 8.46351 18.9363C8.02849 18.8898 7.60207 18.7763 7.20143 18.5975C6.80252 18.4204 6.43284 18.1797 6.10786 17.8857C5.78289 17.5916 5.50606 17.2478 5.28769 16.8695C5.14153 16.6167 5.02117 16.3502 4.93004 16.0734C4.83891 15.7965 4.77873 15.5111 4.74778 15.2205C4.71683 14.9317 4.71855 14.6393 4.7495 14.3488C4.78045 14.0599 4.84407 13.7745 4.9352 13.4976C4.64289 13.1727 4.40217 12.803 4.22335 12.4041C4.04624 12.0034 3.93104 11.5787 3.88634 11.1437C3.83991 10.7087 3.86398 10.2685 3.95511 9.84036C4.04624 9.41222 4.20443 9.00127 4.42452 8.62299C4.56896 8.37023 4.73918 8.13123 4.93348 7.91458C5.12778 7.69793 5.34615 7.50191 5.58171 7.32997C5.81728 7.15802 6.07176 7.01187 6.33827 6.89495C6.6065 6.7763 6.88506 6.68861 7.17048 6.63015C7.3046 6.21232 7.50406 5.82029 7.76026 5.46608C8.01817 5.11188 8.32939 4.80066 8.6836 4.54274C9.03781 4.28654 9.42984 4.08708 9.84595 3.95125C10.2621 3.81713 10.6971 3.74835 11.1355 3.75007C11.4261 3.74835 11.7184 3.77758 12.0039 3.83776C12.2893 3.89794 12.5678 3.98736 12.8344 4.106C13.1009 4.22636 13.3536 4.37251 13.5892 4.54446C13.8248 4.71812 14.0414 4.91414 14.234 5.13251C14.6621 5.04138 15.1023 5.01903 15.5373 5.06373C15.9723 5.10844 16.3971 5.22364 16.7977 5.40074C17.1966 5.57957 17.5663 5.81857 17.8913 6.1126C18.2162 6.4049 18.4931 6.74707 18.7114 7.12707C18.8576 7.37811 18.9779 7.64463 19.0691 7.92318C19.1602 8.20001 19.2221 8.48544 19.2513 8.77602C19.2823 9.06661 19.2823 9.35892 19.2496 9.64951C19.2187 9.94009 19.155 10.2255 19.0639 10.5024C19.3579 10.8273 19.5969 11.1953 19.7758 11.5959ZM14.0466 18.9363C14.4214 18.7815 14.7619 18.5528 15.049 18.2657C15.3362 17.9785 15.5648 17.6381 15.7196 17.2615C15.8743 16.8867 15.9552 16.4843 15.9552 16.0785V12.2442C15.954 12.2407 15.9529 12.2367 15.9517 12.2321C15.9506 12.2287 15.9488 12.2252 15.9466 12.2218C15.9443 12.2184 15.9414 12.2155 15.938 12.2132C15.9345 12.2098 15.9311 12.2075 15.9276 12.2063L14.54 11.4051V16.0373C14.54 16.0837 14.5332 16.1318 14.5211 16.1765C14.5091 16.223 14.4919 16.2659 14.4678 16.3072C14.4438 16.3485 14.4162 16.3863 14.3819 16.419C14.3484 16.4523 14.3109 16.4812 14.2701 16.505L10.9842 18.4015C10.9567 18.4187 10.9103 18.4428 10.8862 18.4565C11.0221 18.5717 11.1699 18.6732 11.3247 18.7626C11.4811 18.852 11.6428 18.9277 11.8113 18.9896C11.9798 19.0497 12.1535 19.0962 12.3288 19.1271C12.5059 19.1581 12.6848 19.1735 12.8636 19.1735C13.2694 19.1735 13.6717 19.0927 14.0466 18.9363ZM6.22135 16.333C6.42596 16.6855 6.69592 16.9916 7.01745 17.2392C7.34071 17.4868 7.70695 17.6673 8.09899 17.7722C8.49102 17.8771 8.90025 17.9046 9.3026 17.8513C9.70495 17.798 10.0918 17.6673 10.4443 17.4644L13.7663 15.5472L13.7749 15.5386C13.7772 15.5363 13.7789 15.5329 13.78 15.5283C13.7823 15.5249 13.7841 15.5214 13.7852 15.518V13.9017L9.77545 16.2212C9.73418 16.2453 9.6912 16.2625 9.64649 16.2763C9.60007 16.2883 9.55364 16.2935 9.5055 16.2935C9.45907 16.2935 9.41265 16.2883 9.36622 16.2763C9.32152 16.2625 9.27681 16.2453 9.23554 16.2212L5.94967 14.323C5.92044 14.3058 5.87746 14.28 5.85339 14.2645C5.82244 14.4416 5.80696 14.6204 5.80696 14.7993C5.80696 14.9781 5.82415 15.1569 5.85511 15.334C5.88605 15.5094 5.9342 15.6831 5.99438 15.8516C6.05628 16.0201 6.13194 16.1817 6.22135 16.3364V16.333ZM5.35818 9.1629C5.15529 9.51539 5.02461 9.90398 4.97131 10.3063C4.918 10.7087 4.94552 11.1162 5.0504 11.51C5.15529 11.902 5.33583 12.2682 5.58343 12.5915C5.83103 12.913 6.13881 13.183 6.48958 13.3859L9.80984 15.3048C9.81328 15.3059 9.81729 15.3071 9.82188 15.3082H9.83391C9.8385 15.3082 9.84251 15.3071 9.84595 15.3048C9.84939 15.3036 9.85283 15.3019 9.85627 15.2996L11.249 14.4949L7.23926 12.1805C7.19971 12.1565 7.16189 12.1272 7.1275 12.0946C7.09418 12.0611 7.06529 12.0236 7.04153 11.9828C7.01917 11.9415 7.00026 11.8985 6.98822 11.8521C6.97619 11.8074 6.96931 11.761 6.97103 11.7128V7.80797C6.80252 7.86987 6.63917 7.94553 6.48442 8.03494C6.32967 8.12607 6.18352 8.22924 6.04596 8.34444C5.91013 8.45965 5.78289 8.58688 5.66769 8.72444C5.55248 8.86028 5.45103 9.00815 5.36162 9.1629H5.35818ZM16.7633 11.8177C16.8046 11.8418 16.8424 11.8693 16.8768 11.9037C16.9094 11.9364 16.9387 11.9742 16.9628 12.0155C16.9851 12.0567 17.004 12.1014 17.0161 12.1461C17.0264 12.1926 17.0332 12.239 17.0315 12.2871V16.192C17.5835 15.9891 18.0649 15.6332 18.4208 15.1655C18.7785 14.6978 18.9934 14.139 19.0433 13.5544C19.0931 12.9698 18.9762 12.3817 18.7046 11.8607C18.4329 11.3397 18.0185 10.9064 17.5095 10.6141L14.1893 8.69521C14.1858 8.69406 14.1818 8.69292 14.1772 8.69177H14.1652C14.1618 8.69292 14.1578 8.69406 14.1532 8.69521C14.1497 8.69636 14.1463 8.69808 14.1429 8.70037L12.757 9.50163L16.7667 11.8177H16.7633ZM18.1475 9.7372H18.1457V9.73892L18.1475 9.7372ZM18.1457 9.73548C18.2455 9.15774 18.1784 8.56281 17.9514 8.02119C17.7262 7.47956 17.3496 7.01359 16.8682 6.67658C16.3867 6.34128 15.8193 6.1487 15.233 6.12291C14.6449 6.09884 14.0638 6.24155 13.5548 6.53386L10.2345 8.45105C10.2311 8.45334 10.2282 8.45621 10.2259 8.45965L10.2191 8.46996C10.2179 8.4734 10.2168 8.47741 10.2156 8.482C10.2145 8.48544 10.2139 8.48945 10.2139 8.49403V10.0966L14.2237 7.78046C14.2649 7.75639 14.3096 7.7392 14.3543 7.72544C14.4008 7.7134 14.4472 7.70825 14.4936 7.70825C14.5418 7.70825 14.5882 7.7134 14.6346 7.72544C14.6793 7.7392 14.7223 7.75639 14.7636 7.78046L18.0494 9.67874C18.0787 9.69593 18.1217 9.72 18.1457 9.73548ZM9.45735 7.96101C9.45735 7.91458 9.46423 7.86816 9.47627 7.82173C9.4883 7.77702 9.5055 7.73232 9.52957 7.69105C9.55364 7.6515 9.58115 7.61368 9.61554 7.57929C9.64821 7.54662 9.68604 7.51739 9.72731 7.49503L13.0132 5.59848C13.0441 5.57957 13.0871 5.55549 13.1112 5.54346C12.6607 5.1669 12.1105 4.92618 11.5276 4.85224C10.9447 4.77658 10.3532 4.86943 9.82188 5.11875C9.28885 5.36807 8.83835 5.76527 8.52369 6.26047C8.20903 6.75739 8.04224 7.33169 8.04224 7.91974V11.7541C8.04339 11.7587 8.04454 11.7627 8.04568 11.7661C8.04683 11.7696 8.04855 11.773 8.05084 11.7765C8.05313 11.7799 8.056 11.7833 8.05944 11.7868C8.06173 11.7891 8.06517 11.7914 8.06976 11.7937L9.45735 12.5949V7.96101ZM10.2105 13.0282L11.997 14.0599L13.7835 13.0282V10.9666L11.9987 9.93493L10.2122 10.9666L10.2105 13.0282Z", - "fill": "white" - }, - "children": [] - } - ] - }, - "name": "OpenaiViolet" -} diff --git a/web/app/components/base/icons/src/public/llm/OpenaiViolet.tsx b/web/app/components/base/icons/src/public/llm/OpenaiViolet.tsx deleted file mode 100644 index 9aa08c0f3b..0000000000 --- a/web/app/components/base/icons/src/public/llm/OpenaiViolet.tsx +++ /dev/null @@ -1,20 +0,0 @@ -// GENERATE BY script -// DON NOT EDIT IT MANUALLY - -import type { IconData } from '@/app/components/base/icons/IconBase' -import * as React from 'react' -import IconBase from '@/app/components/base/icons/IconBase' -import data from './OpenaiViolet.json' - -const Icon = ( - { - ref, - ...props - }: React.SVGProps & { - ref?: React.RefObject> - }, -) => - -Icon.displayName = 'OpenaiViolet' - -export default Icon diff --git a/web/app/components/base/icons/src/public/llm/Tongyi.json b/web/app/components/base/icons/src/public/llm/Tongyi.json deleted file mode 100644 index 9150ca226b..0000000000 --- a/web/app/components/base/icons/src/public/llm/Tongyi.json +++ /dev/null @@ -1,128 +0,0 @@ -{ - "icon": { - "type": "element", - "isRootNode": true, - "name": "svg", - "attributes": { - "width": "25", - "height": "25", - "viewBox": "0 0 25 25", - "fill": "none", - "xmlns": "http://www.w3.org/2000/svg", - "xmlns:xlink": "http://www.w3.org/1999/xlink" - }, - "children": [ - { - "type": "element", - "name": "g", - "attributes": { - "clip-path": "url(#clip0_6305_73327)" - }, - "children": [ - { - "type": "element", - "name": "path", - "attributes": { - "d": "M0.5 12.5C0.5 8.77247 0.5 6.9087 1.10896 5.43853C1.92092 3.47831 3.47831 1.92092 5.43853 1.10896C6.9087 0.5 8.77247 0.5 12.5 0.5C16.2275 0.5 18.0913 0.5 19.5615 1.10896C21.5217 1.92092 23.0791 3.47831 23.891 5.43853C24.5 6.9087 24.5 8.77247 24.5 12.5C24.5 16.2275 24.5 18.0913 23.891 19.5615C23.0791 21.5217 21.5217 23.0791 19.5615 23.891C18.0913 24.5 16.2275 24.5 12.5 24.5C8.77247 24.5 6.9087 24.5 5.43853 23.891C3.47831 23.0791 1.92092 21.5217 1.10896 19.5615C0.5 18.0913 0.5 16.2275 0.5 12.5Z", - "fill": "white" - }, - "children": [] - }, - { - "type": "element", - "name": "rect", - "attributes": { - "width": "24", - "height": "24", - "transform": "translate(0.5 0.5)", - "fill": "url(#pattern0_6305_73327)" - }, - "children": [] - }, - { - "type": "element", - "name": "rect", - "attributes": { - "width": "24", - "height": "24", - "transform": "translate(0.5 0.5)", - "fill": "white", - "fill-opacity": "0.01" - }, - "children": [] - } - ] - }, - { - "type": "element", - "name": "path", - "attributes": { - "d": "M12.5 0.25C14.3603 0.25 15.7684 0.250313 16.8945 0.327148C18.0228 0.404144 18.8867 0.558755 19.6572 0.87793C21.6787 1.71525 23.2847 3.32133 24.1221 5.34277C24.4412 6.11333 24.5959 6.97723 24.6729 8.10547C24.7497 9.23161 24.75 10.6397 24.75 12.5C24.75 14.3603 24.7497 15.7684 24.6729 16.8945C24.5959 18.0228 24.4412 18.8867 24.1221 19.6572C23.2847 21.6787 21.6787 23.2847 19.6572 24.1221C18.8867 24.4412 18.0228 24.5959 16.8945 24.6729C15.7684 24.7497 14.3603 24.75 12.5 24.75C10.6397 24.75 9.23161 24.7497 8.10547 24.6729C6.97723 24.5959 6.11333 24.4412 5.34277 24.1221C3.32133 23.2847 1.71525 21.6787 0.87793 19.6572C0.558755 18.8867 0.404144 18.0228 0.327148 16.8945C0.250313 15.7684 0.25 14.3603 0.25 12.5C0.25 10.6397 0.250313 9.23161 0.327148 8.10547C0.404144 6.97723 0.558755 6.11333 0.87793 5.34277C1.71525 3.32133 3.32133 1.71525 5.34277 0.87793C6.11333 0.558755 6.97723 0.404144 8.10547 0.327148C9.23161 0.250313 10.6397 0.25 12.5 0.25Z", - "stroke": "#101828", - "stroke-opacity": "0.08", - "stroke-width": "0.5" - }, - "children": [] - }, - { - "type": "element", - "name": "defs", - "attributes": {}, - "children": [ - { - "type": "element", - "name": "pattern", - "attributes": { - "id": "pattern0_6305_73327", - "patternContentUnits": "objectBoundingBox", - "width": "1", - "height": "1" - }, - "children": [ - { - "type": "element", - "name": "use", - "attributes": { - "xlink:href": "#image0_6305_73327", - "transform": "scale(0.00625)" - }, - "children": [] - } - ] - }, - { - "type": "element", - "name": "clipPath", - "attributes": { - "id": "clip0_6305_73327" - }, - "children": [ - { - "type": "element", - "name": "path", - "attributes": { - "d": "M0.5 12.5C0.5 8.77247 0.5 6.9087 1.10896 5.43853C1.92092 3.47831 3.47831 1.92092 5.43853 1.10896C6.9087 0.5 8.77247 0.5 12.5 0.5C16.2275 0.5 18.0913 0.5 19.5615 1.10896C21.5217 1.92092 23.0791 3.47831 23.891 5.43853C24.5 6.9087 24.5 8.77247 24.5 12.5C24.5 16.2275 24.5 18.0913 23.891 19.5615C23.0791 21.5217 21.5217 23.0791 19.5615 23.891C18.0913 24.5 16.2275 24.5 12.5 24.5C8.77247 24.5 6.9087 24.5 5.43853 23.891C3.47831 23.0791 1.92092 21.5217 1.10896 19.5615C0.5 18.0913 0.5 16.2275 0.5 12.5Z", - "fill": "white" - }, - "children": [] - } - ] - }, - { - "type": "element", - "name": "image", - "attributes": { - "id": "image0_6305_73327", - "width": "160", - "height": "160", - "preserveAspectRatio": "none", - "xlink:href": "" - }, - "children": [] - } - ] - } - ] - }, - "name": "Tongyi" -} diff --git a/web/app/components/base/icons/src/public/llm/Tongyi.tsx b/web/app/components/base/icons/src/public/llm/Tongyi.tsx deleted file mode 100644 index 9934dee856..0000000000 --- a/web/app/components/base/icons/src/public/llm/Tongyi.tsx +++ /dev/null @@ -1,20 +0,0 @@ -// GENERATE BY script -// DON NOT EDIT IT MANUALLY - -import type { IconData } from '@/app/components/base/icons/IconBase' -import * as React from 'react' -import IconBase from '@/app/components/base/icons/IconBase' -import data from './Tongyi.json' - -const Icon = ( - { - ref, - ...props - }: React.SVGProps & { - ref?: React.RefObject> - }, -) => - -Icon.displayName = 'Tongyi' - -export default Icon diff --git a/web/app/components/base/icons/src/public/llm/index.ts b/web/app/components/base/icons/src/public/llm/index.ts index 0c5cef4a36..3a4306391e 100644 --- a/web/app/components/base/icons/src/public/llm/index.ts +++ b/web/app/components/base/icons/src/public/llm/index.ts @@ -1,7 +1,6 @@ export { default as Anthropic } from './Anthropic' export { default as AnthropicDark } from './AnthropicDark' export { default as AnthropicLight } from './AnthropicLight' -export { default as AnthropicShortLight } from './AnthropicShortLight' export { default as AnthropicText } from './AnthropicText' export { default as Azureai } from './Azureai' export { default as AzureaiText } from './AzureaiText' @@ -13,11 +12,8 @@ export { default as Chatglm } from './Chatglm' export { default as ChatglmText } from './ChatglmText' export { default as Cohere } from './Cohere' export { default as CohereText } from './CohereText' -export { default as Deepseek } from './Deepseek' -export { default as Gemini } from './Gemini' export { default as Gpt3 } from './Gpt3' export { default as Gpt4 } from './Gpt4' -export { default as Grok } from './Grok' export { default as Huggingface } from './Huggingface' export { default as HuggingfaceText } from './HuggingfaceText' export { default as HuggingfaceTextHub } from './HuggingfaceTextHub' @@ -30,19 +26,14 @@ export { default as Localai } from './Localai' export { default as LocalaiText } from './LocalaiText' export { default as Microsoft } from './Microsoft' export { default as OpenaiBlack } from './OpenaiBlack' -export { default as OpenaiBlue } from './OpenaiBlue' export { default as OpenaiGreen } from './OpenaiGreen' -export { default as OpenaiSmall } from './OpenaiSmall' -export { default as OpenaiTeal } from './OpenaiTeal' export { default as OpenaiText } from './OpenaiText' export { default as OpenaiTransparent } from './OpenaiTransparent' -export { default as OpenaiViolet } from './OpenaiViolet' export { default as OpenaiYellow } from './OpenaiYellow' export { default as Openllm } from './Openllm' export { default as OpenllmText } from './OpenllmText' export { default as Replicate } from './Replicate' export { default as ReplicateText } from './ReplicateText' -export { default as Tongyi } from './Tongyi' export { default as XorbitsInference } from './XorbitsInference' export { default as XorbitsInferenceText } from './XorbitsInferenceText' export { default as Zhipuai } from './Zhipuai' diff --git a/web/app/components/base/icons/src/public/tracing/DatabricksIcon.tsx b/web/app/components/base/icons/src/public/tracing/DatabricksIcon.tsx index a1e45d8bdf..87abe453ec 100644 --- a/web/app/components/base/icons/src/public/tracing/DatabricksIcon.tsx +++ b/web/app/components/base/icons/src/public/tracing/DatabricksIcon.tsx @@ -11,7 +11,7 @@ const Icon = ( ref, ...props }: React.SVGProps & { - ref?: React.RefObject> + ref?: React.RefObject> }, ) => diff --git a/web/app/components/base/icons/src/public/tracing/DatabricksIconBig.tsx b/web/app/components/base/icons/src/public/tracing/DatabricksIconBig.tsx index ef21c05a23..bebaa1b40e 100644 --- a/web/app/components/base/icons/src/public/tracing/DatabricksIconBig.tsx +++ b/web/app/components/base/icons/src/public/tracing/DatabricksIconBig.tsx @@ -11,7 +11,7 @@ const Icon = ( ref, ...props }: React.SVGProps & { - ref?: React.RefObject> + ref?: React.RefObject> }, ) => diff --git a/web/app/components/base/icons/src/public/tracing/MlflowIcon.tsx b/web/app/components/base/icons/src/public/tracing/MlflowIcon.tsx index 09a31882c9..3c86ed61f4 100644 --- a/web/app/components/base/icons/src/public/tracing/MlflowIcon.tsx +++ b/web/app/components/base/icons/src/public/tracing/MlflowIcon.tsx @@ -11,7 +11,7 @@ const Icon = ( ref, ...props }: React.SVGProps & { - ref?: React.RefObject> + ref?: React.RefObject> }, ) => diff --git a/web/app/components/base/icons/src/public/tracing/MlflowIconBig.tsx b/web/app/components/base/icons/src/public/tracing/MlflowIconBig.tsx index 03fef44991..fbb288d46a 100644 --- a/web/app/components/base/icons/src/public/tracing/MlflowIconBig.tsx +++ b/web/app/components/base/icons/src/public/tracing/MlflowIconBig.tsx @@ -11,7 +11,7 @@ const Icon = ( ref, ...props }: React.SVGProps & { - ref?: React.RefObject> + ref?: React.RefObject> }, ) => diff --git a/web/app/components/base/icons/src/public/tracing/TencentIcon.json b/web/app/components/base/icons/src/public/tracing/TencentIcon.json index 9fd54c0ce9..642fa75a92 100644 --- a/web/app/components/base/icons/src/public/tracing/TencentIcon.json +++ b/web/app/components/base/icons/src/public/tracing/TencentIcon.json @@ -1,16 +1,14 @@ { "icon": { "type": "element", - "isRootNode": true, "name": "svg", "attributes": { "width": "80px", "height": "18px", "viewBox": "0 0 80 18", - "version": "1.1", - "xmlns": "http://www.w3.org/2000/svg", - "xmlns:xlink": "http://www.w3.org/1999/xlink" + "version": "1.1" }, + "isRootNode": true, "children": [ { "type": "element", diff --git a/web/app/components/base/icons/src/public/tracing/TencentIconBig.json b/web/app/components/base/icons/src/public/tracing/TencentIconBig.json index 9abd81455f..d0582e7f8d 100644 --- a/web/app/components/base/icons/src/public/tracing/TencentIconBig.json +++ b/web/app/components/base/icons/src/public/tracing/TencentIconBig.json @@ -1,16 +1,14 @@ { "icon": { "type": "element", - "isRootNode": true, "name": "svg", "attributes": { - "width": "120px", - "height": "27px", + "width": "80px", + "height": "18px", "viewBox": "0 0 80 18", - "version": "1.1", - "xmlns": "http://www.w3.org/2000/svg", - "xmlns:xlink": "http://www.w3.org/1999/xlink" + "version": "1.1" }, + "isRootNode": true, "children": [ { "type": "element", diff --git a/web/app/components/billing/apps-full-in-dialog/index.spec.tsx b/web/app/components/billing/apps-full-in-dialog/index.spec.tsx index d006a3222d..a11b582b0f 100644 --- a/web/app/components/billing/apps-full-in-dialog/index.spec.tsx +++ b/web/app/components/billing/apps-full-in-dialog/index.spec.tsx @@ -75,9 +75,6 @@ const buildAppContext = (overrides: Partial = {}): AppContextVa created_at: 0, role: 'normal', providers: [], - trial_credits: 200, - trial_credits_used: 0, - next_credit_reset_date: 0, } const langGeniusVersionInfo: LangGeniusVersionResponse = { current_env: '', @@ -99,7 +96,6 @@ const buildAppContext = (overrides: Partial = {}): AppContextVa mutateCurrentWorkspace: vi.fn(), langGeniusVersionInfo, isLoadingCurrentWorkspace: false, - isValidatingCurrentWorkspace: false, } const useSelector: AppContextValue['useSelector'] = selector => selector({ ...base, useSelector }) return { diff --git a/web/app/components/datasets/create/step-one/components/data-source-type-selector.tsx b/web/app/components/datasets/create/step-one/components/data-source-type-selector.tsx new file mode 100644 index 0000000000..6bdc2ace56 --- /dev/null +++ b/web/app/components/datasets/create/step-one/components/data-source-type-selector.tsx @@ -0,0 +1,97 @@ +'use client' + +import { useCallback, useMemo } from 'react' +import { useTranslation } from 'react-i18next' +import { ENABLE_WEBSITE_FIRECRAWL, ENABLE_WEBSITE_JINAREADER, ENABLE_WEBSITE_WATERCRAWL } from '@/config' +import { DataSourceType } from '@/models/datasets' +import { cn } from '@/utils/classnames' +import s from '../index.module.css' + +type DataSourceTypeSelectorProps = { + currentType: DataSourceType + disabled: boolean + onChange: (type: DataSourceType) => void + onClearPreviews: (type: DataSourceType) => void +} + +type DataSourceLabelKey + = | 'stepOne.dataSourceType.file' + | 'stepOne.dataSourceType.notion' + | 'stepOne.dataSourceType.web' + +type DataSourceOption = { + type: DataSourceType + iconClass?: string + labelKey: DataSourceLabelKey +} + +const DATA_SOURCE_OPTIONS: DataSourceOption[] = [ + { + type: DataSourceType.FILE, + labelKey: 'stepOne.dataSourceType.file', + }, + { + type: DataSourceType.NOTION, + iconClass: s.notion, + labelKey: 'stepOne.dataSourceType.notion', + }, + { + type: DataSourceType.WEB, + iconClass: s.web, + labelKey: 'stepOne.dataSourceType.web', + }, +] + +/** + * Data source type selector component for choosing between file, notion, and web sources. + */ +function DataSourceTypeSelector({ + currentType, + disabled, + onChange, + onClearPreviews, +}: DataSourceTypeSelectorProps) { + const { t } = useTranslation() + + const isWebEnabled = ENABLE_WEBSITE_FIRECRAWL || ENABLE_WEBSITE_JINAREADER || ENABLE_WEBSITE_WATERCRAWL + + const handleTypeChange = useCallback((type: DataSourceType) => { + if (disabled) + return + onChange(type) + onClearPreviews(type) + }, [disabled, onChange, onClearPreviews]) + + const visibleOptions = useMemo(() => DATA_SOURCE_OPTIONS.filter((option) => { + if (option.type === DataSourceType.WEB) + return isWebEnabled + return true + }), [isWebEnabled]) + + return ( +
+ {visibleOptions.map(option => ( +
handleTypeChange(option.type)} + > + + + {t(option.labelKey, { ns: 'datasetCreation' })} + +
+ ))} +
+ ) +} + +export default DataSourceTypeSelector diff --git a/web/app/components/datasets/create/step-one/components/index.ts b/web/app/components/datasets/create/step-one/components/index.ts new file mode 100644 index 0000000000..5271835741 --- /dev/null +++ b/web/app/components/datasets/create/step-one/components/index.ts @@ -0,0 +1,3 @@ +export { default as DataSourceTypeSelector } from './data-source-type-selector' +export { default as NextStepButton } from './next-step-button' +export { default as PreviewPanel } from './preview-panel' diff --git a/web/app/components/datasets/create/step-one/components/next-step-button.tsx b/web/app/components/datasets/create/step-one/components/next-step-button.tsx new file mode 100644 index 0000000000..71e4e87fcf --- /dev/null +++ b/web/app/components/datasets/create/step-one/components/next-step-button.tsx @@ -0,0 +1,30 @@ +'use client' + +import { RiArrowRightLine } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import Button from '@/app/components/base/button' + +type NextStepButtonProps = { + disabled: boolean + onClick: () => void +} + +/** + * Reusable next step button component for dataset creation flow. + */ +function NextStepButton({ disabled, onClick }: NextStepButtonProps) { + const { t } = useTranslation() + + return ( +
+ +
+ ) +} + +export default NextStepButton diff --git a/web/app/components/datasets/create/step-one/components/preview-panel.tsx b/web/app/components/datasets/create/step-one/components/preview-panel.tsx new file mode 100644 index 0000000000..8ae0b7df55 --- /dev/null +++ b/web/app/components/datasets/create/step-one/components/preview-panel.tsx @@ -0,0 +1,62 @@ +'use client' + +import type { NotionPage } from '@/models/common' +import type { CrawlResultItem } from '@/models/datasets' +import { useTranslation } from 'react-i18next' +import PlanUpgradeModal from '@/app/components/billing/plan-upgrade-modal' +import FilePreview from '../../file-preview' +import NotionPagePreview from '../../notion-page-preview' +import WebsitePreview from '../../website/preview' + +type PreviewPanelProps = { + currentFile: File | undefined + currentNotionPage: NotionPage | undefined + currentWebsite: CrawlResultItem | undefined + notionCredentialId: string + isShowPlanUpgradeModal: boolean + hideFilePreview: () => void + hideNotionPagePreview: () => void + hideWebsitePreview: () => void + hidePlanUpgradeModal: () => void +} + +/** + * Right panel component for displaying file, notion page, or website previews. + */ +function PreviewPanel({ + currentFile, + currentNotionPage, + currentWebsite, + notionCredentialId, + isShowPlanUpgradeModal, + hideFilePreview, + hideNotionPagePreview, + hideWebsitePreview, + hidePlanUpgradeModal, +}: PreviewPanelProps) { + const { t } = useTranslation() + + return ( +
+ {currentFile && } + {currentNotionPage && ( + + )} + {currentWebsite && } + {isShowPlanUpgradeModal && ( + + )} +
+ ) +} + +export default PreviewPanel diff --git a/web/app/components/datasets/create/step-one/hooks/index.ts b/web/app/components/datasets/create/step-one/hooks/index.ts new file mode 100644 index 0000000000..bae5ce4fce --- /dev/null +++ b/web/app/components/datasets/create/step-one/hooks/index.ts @@ -0,0 +1,2 @@ +export { default as usePreviewState } from './use-preview-state' +export type { PreviewActions, PreviewState, UsePreviewStateReturn } from './use-preview-state' diff --git a/web/app/components/datasets/create/step-one/hooks/use-preview-state.ts b/web/app/components/datasets/create/step-one/hooks/use-preview-state.ts new file mode 100644 index 0000000000..3984947ab1 --- /dev/null +++ b/web/app/components/datasets/create/step-one/hooks/use-preview-state.ts @@ -0,0 +1,70 @@ +'use client' + +import type { NotionPage } from '@/models/common' +import type { CrawlResultItem } from '@/models/datasets' +import { useCallback, useState } from 'react' + +export type PreviewState = { + currentFile: File | undefined + currentNotionPage: NotionPage | undefined + currentWebsite: CrawlResultItem | undefined +} + +export type PreviewActions = { + showFilePreview: (file: File) => void + hideFilePreview: () => void + showNotionPagePreview: (page: NotionPage) => void + hideNotionPagePreview: () => void + showWebsitePreview: (website: CrawlResultItem) => void + hideWebsitePreview: () => void +} + +export type UsePreviewStateReturn = PreviewState & PreviewActions + +/** + * Custom hook for managing preview state across different data source types. + * Handles file, notion page, and website preview visibility. + */ +function usePreviewState(): UsePreviewStateReturn { + const [currentFile, setCurrentFile] = useState() + const [currentNotionPage, setCurrentNotionPage] = useState() + const [currentWebsite, setCurrentWebsite] = useState() + + const showFilePreview = useCallback((file: File) => { + setCurrentFile(file) + }, []) + + const hideFilePreview = useCallback(() => { + setCurrentFile(undefined) + }, []) + + const showNotionPagePreview = useCallback((page: NotionPage) => { + setCurrentNotionPage(page) + }, []) + + const hideNotionPagePreview = useCallback(() => { + setCurrentNotionPage(undefined) + }, []) + + const showWebsitePreview = useCallback((website: CrawlResultItem) => { + setCurrentWebsite(website) + }, []) + + const hideWebsitePreview = useCallback(() => { + setCurrentWebsite(undefined) + }, []) + + return { + currentFile, + currentNotionPage, + currentWebsite, + showFilePreview, + hideFilePreview, + showNotionPagePreview, + hideNotionPagePreview, + showWebsitePreview, + hideWebsitePreview, + } +} + +export default usePreviewState diff --git a/web/app/components/datasets/create/step-one/index.spec.tsx b/web/app/components/datasets/create/step-one/index.spec.tsx new file mode 100644 index 0000000000..1ff77dc1f6 --- /dev/null +++ b/web/app/components/datasets/create/step-one/index.spec.tsx @@ -0,0 +1,1204 @@ +import type { DataSourceAuth } from '@/app/components/header/account-setting/data-source-page-new/types' +import type { NotionPage } from '@/models/common' +import type { CrawlOptions, CrawlResultItem, DataSet, FileItem } from '@/models/datasets' +import { act, fireEvent, render, renderHook, screen } from '@testing-library/react' +import { Plan } from '@/app/components/billing/type' +import { DataSourceType } from '@/models/datasets' +import { DataSourceTypeSelector, NextStepButton, PreviewPanel } from './components' +import { usePreviewState } from './hooks' +import StepOne from './index' + +// ========================================== +// Mock External Dependencies +// ========================================== + +// Mock config for website crawl features +vi.mock('@/config', () => ({ + ENABLE_WEBSITE_FIRECRAWL: true, + ENABLE_WEBSITE_JINAREADER: false, + ENABLE_WEBSITE_WATERCRAWL: false, +})) + +// Mock dataset detail context +let mockDatasetDetail: DataSet | undefined +vi.mock('@/context/dataset-detail', () => ({ + useDatasetDetailContextWithSelector: (selector: (state: { dataset: DataSet | undefined }) => DataSet | undefined) => { + return selector({ dataset: mockDatasetDetail }) + }, +})) + +// Mock provider context +let mockPlan = { + type: Plan.professional, + usage: { vectorSpace: 50, buildApps: 0, documentsUploadQuota: 0, vectorStorageQuota: 0 }, + total: { vectorSpace: 100, buildApps: 0, documentsUploadQuota: 0, vectorStorageQuota: 0 }, +} +let mockEnableBilling = false + +vi.mock('@/context/provider-context', () => ({ + useProviderContext: () => ({ + plan: mockPlan, + enableBilling: mockEnableBilling, + }), +})) + +// Mock child components +vi.mock('../file-uploader', () => ({ + default: ({ onPreview, fileList }: { onPreview: (file: File) => void, fileList: FileItem[] }) => ( +
+ {fileList.length} + +
+ ), +})) + +vi.mock('../website', () => ({ + default: ({ onPreview }: { onPreview: (item: CrawlResultItem) => void }) => ( +
+ +
+ ), +})) + +vi.mock('../empty-dataset-creation-modal', () => ({ + default: ({ show, onHide }: { show: boolean, onHide: () => void }) => ( + show + ? ( +
+ +
+ ) + : null + ), +})) + +// NotionConnector is a base component - imported directly without mock +// It only depends on i18n which is globally mocked + +vi.mock('@/app/components/base/notion-page-selector', () => ({ + NotionPageSelector: ({ onPreview }: { onPreview: (page: NotionPage) => void }) => ( +
+ +
+ ), +})) + +vi.mock('@/app/components/billing/vector-space-full', () => ({ + default: () =>
Vector Space Full
, +})) + +vi.mock('@/app/components/billing/plan-upgrade-modal', () => ({ + default: ({ show, onClose }: { show: boolean, onClose: () => void }) => ( + show + ? ( +
+ +
+ ) + : null + ), +})) + +vi.mock('../file-preview', () => ({ + default: ({ file, hidePreview }: { file: File, hidePreview: () => void }) => ( +
+ {file.name} + +
+ ), +})) + +vi.mock('../notion-page-preview', () => ({ + default: ({ currentPage, hidePreview }: { currentPage: NotionPage, hidePreview: () => void }) => ( +
+ {currentPage.page_id} + +
+ ), +})) + +// WebsitePreview is a sibling component without API dependencies - imported directly +// It only depends on i18n which is globally mocked + +vi.mock('./upgrade-card', () => ({ + default: () =>
Upgrade Card
, +})) + +// ========================================== +// Test Data Builders +// ========================================== + +const createMockCustomFile = (overrides: { id?: string, name?: string } = {}) => { + const file = new File(['test content'], overrides.name ?? 'test.txt', { type: 'text/plain' }) + return Object.assign(file, { + id: overrides.id ?? 'uploaded-id', + extension: 'txt', + mime_type: 'text/plain', + created_by: 'user-1', + created_at: Date.now(), + }) +} + +const createMockFileItem = (overrides: Partial = {}): FileItem => ({ + fileID: `file-${Date.now()}`, + file: createMockCustomFile(overrides.file as { id?: string, name?: string }), + progress: 100, + ...overrides, +}) + +const createMockNotionPage = (overrides: Partial = {}): NotionPage => ({ + page_id: `page-${Date.now()}`, + type: 'page', + ...overrides, +} as NotionPage) + +const createMockCrawlResult = (overrides: Partial = {}): CrawlResultItem => ({ + title: 'Test Page', + markdown: 'Test content', + description: 'Test description', + source_url: 'https://example.com', + ...overrides, +}) + +const createMockDataSourceAuth = (overrides: Partial = {}): DataSourceAuth => ({ + credential_id: 'cred-1', + provider: 'notion_datasource', + plugin_id: 'plugin-1', + credentials_list: [{ id: 'cred-1', name: 'Workspace 1' }], + ...overrides, +} as DataSourceAuth) + +const defaultProps = { + dataSourceType: DataSourceType.FILE, + dataSourceTypeDisable: false, + onSetting: vi.fn(), + files: [] as FileItem[], + updateFileList: vi.fn(), + updateFile: vi.fn(), + notionPages: [] as NotionPage[], + notionCredentialId: '', + updateNotionPages: vi.fn(), + updateNotionCredentialId: vi.fn(), + onStepChange: vi.fn(), + changeType: vi.fn(), + websitePages: [] as CrawlResultItem[], + updateWebsitePages: vi.fn(), + onWebsiteCrawlProviderChange: vi.fn(), + onWebsiteCrawlJobIdChange: vi.fn(), + crawlOptions: { + crawl_sub_pages: true, + only_main_content: true, + includes: '', + excludes: '', + limit: 10, + max_depth: '', + use_sitemap: true, + } as CrawlOptions, + onCrawlOptionsChange: vi.fn(), + authedDataSourceList: [] as DataSourceAuth[], +} + +// ========================================== +// usePreviewState Hook Tests +// ========================================== +describe('usePreviewState Hook', () => { + // -------------------------------------------------------------------------- + // Initial State Tests + // -------------------------------------------------------------------------- + describe('Initial State', () => { + it('should initialize with all preview states undefined', () => { + // Arrange & Act + const { result } = renderHook(() => usePreviewState()) + + // Assert + expect(result.current.currentFile).toBeUndefined() + expect(result.current.currentNotionPage).toBeUndefined() + expect(result.current.currentWebsite).toBeUndefined() + }) + }) + + // -------------------------------------------------------------------------- + // File Preview Tests + // -------------------------------------------------------------------------- + describe('File Preview', () => { + it('should show file preview when showFilePreview is called', () => { + // Arrange + const { result } = renderHook(() => usePreviewState()) + const mockFile = new File(['test'], 'test.txt') + + // Act + act(() => { + result.current.showFilePreview(mockFile) + }) + + // Assert + expect(result.current.currentFile).toBe(mockFile) + }) + + it('should hide file preview when hideFilePreview is called', () => { + // Arrange + const { result } = renderHook(() => usePreviewState()) + const mockFile = new File(['test'], 'test.txt') + + act(() => { + result.current.showFilePreview(mockFile) + }) + + // Act + act(() => { + result.current.hideFilePreview() + }) + + // Assert + expect(result.current.currentFile).toBeUndefined() + }) + }) + + // -------------------------------------------------------------------------- + // Notion Page Preview Tests + // -------------------------------------------------------------------------- + describe('Notion Page Preview', () => { + it('should show notion page preview when showNotionPagePreview is called', () => { + // Arrange + const { result } = renderHook(() => usePreviewState()) + const mockPage = createMockNotionPage() + + // Act + act(() => { + result.current.showNotionPagePreview(mockPage) + }) + + // Assert + expect(result.current.currentNotionPage).toBe(mockPage) + }) + + it('should hide notion page preview when hideNotionPagePreview is called', () => { + // Arrange + const { result } = renderHook(() => usePreviewState()) + const mockPage = createMockNotionPage() + + act(() => { + result.current.showNotionPagePreview(mockPage) + }) + + // Act + act(() => { + result.current.hideNotionPagePreview() + }) + + // Assert + expect(result.current.currentNotionPage).toBeUndefined() + }) + }) + + // -------------------------------------------------------------------------- + // Website Preview Tests + // -------------------------------------------------------------------------- + describe('Website Preview', () => { + it('should show website preview when showWebsitePreview is called', () => { + // Arrange + const { result } = renderHook(() => usePreviewState()) + const mockWebsite = createMockCrawlResult() + + // Act + act(() => { + result.current.showWebsitePreview(mockWebsite) + }) + + // Assert + expect(result.current.currentWebsite).toBe(mockWebsite) + }) + + it('should hide website preview when hideWebsitePreview is called', () => { + // Arrange + const { result } = renderHook(() => usePreviewState()) + const mockWebsite = createMockCrawlResult() + + act(() => { + result.current.showWebsitePreview(mockWebsite) + }) + + // Act + act(() => { + result.current.hideWebsitePreview() + }) + + // Assert + expect(result.current.currentWebsite).toBeUndefined() + }) + }) + + // -------------------------------------------------------------------------- + // Callback Stability Tests (Memoization) + // -------------------------------------------------------------------------- + describe('Callback Stability', () => { + it('should maintain stable showFilePreview callback reference', () => { + // Arrange + const { result, rerender } = renderHook(() => usePreviewState()) + const initialCallback = result.current.showFilePreview + + // Act + rerender() + + // Assert + expect(result.current.showFilePreview).toBe(initialCallback) + }) + + it('should maintain stable hideFilePreview callback reference', () => { + // Arrange + const { result, rerender } = renderHook(() => usePreviewState()) + const initialCallback = result.current.hideFilePreview + + // Act + rerender() + + // Assert + expect(result.current.hideFilePreview).toBe(initialCallback) + }) + + it('should maintain stable showNotionPagePreview callback reference', () => { + // Arrange + const { result, rerender } = renderHook(() => usePreviewState()) + const initialCallback = result.current.showNotionPagePreview + + // Act + rerender() + + // Assert + expect(result.current.showNotionPagePreview).toBe(initialCallback) + }) + + it('should maintain stable hideNotionPagePreview callback reference', () => { + // Arrange + const { result, rerender } = renderHook(() => usePreviewState()) + const initialCallback = result.current.hideNotionPagePreview + + // Act + rerender() + + // Assert + expect(result.current.hideNotionPagePreview).toBe(initialCallback) + }) + + it('should maintain stable showWebsitePreview callback reference', () => { + // Arrange + const { result, rerender } = renderHook(() => usePreviewState()) + const initialCallback = result.current.showWebsitePreview + + // Act + rerender() + + // Assert + expect(result.current.showWebsitePreview).toBe(initialCallback) + }) + + it('should maintain stable hideWebsitePreview callback reference', () => { + // Arrange + const { result, rerender } = renderHook(() => usePreviewState()) + const initialCallback = result.current.hideWebsitePreview + + // Act + rerender() + + // Assert + expect(result.current.hideWebsitePreview).toBe(initialCallback) + }) + }) +}) + +// ========================================== +// DataSourceTypeSelector Component Tests +// ========================================== +describe('DataSourceTypeSelector', () => { + const defaultSelectorProps = { + currentType: DataSourceType.FILE, + disabled: false, + onChange: vi.fn(), + onClearPreviews: vi.fn(), + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + // -------------------------------------------------------------------------- + // Rendering Tests + // -------------------------------------------------------------------------- + describe('Rendering', () => { + it('should render all data source options when web is enabled', () => { + // Arrange & Act + render() + + // Assert + expect(screen.getByText('datasetCreation.stepOne.dataSourceType.file')).toBeInTheDocument() + expect(screen.getByText('datasetCreation.stepOne.dataSourceType.notion')).toBeInTheDocument() + expect(screen.getByText('datasetCreation.stepOne.dataSourceType.web')).toBeInTheDocument() + }) + + it('should highlight active type', () => { + // Arrange & Act + const { container } = render( + , + ) + + // Assert - The active item should have the active class + const items = container.querySelectorAll('[class*="dataSourceItem"]') + expect(items.length).toBeGreaterThan(0) + }) + }) + + // -------------------------------------------------------------------------- + // User Interactions Tests + // -------------------------------------------------------------------------- + describe('User Interactions', () => { + it('should call onChange when a type is clicked', () => { + // Arrange + const onChange = vi.fn() + render() + + // Act + fireEvent.click(screen.getByText('datasetCreation.stepOne.dataSourceType.notion')) + + // Assert + expect(onChange).toHaveBeenCalledWith(DataSourceType.NOTION) + }) + + it('should call onClearPreviews when a type is clicked', () => { + // Arrange + const onClearPreviews = vi.fn() + render() + + // Act + fireEvent.click(screen.getByText('datasetCreation.stepOne.dataSourceType.web')) + + // Assert + expect(onClearPreviews).toHaveBeenCalledWith(DataSourceType.WEB) + }) + + it('should not call onChange when disabled', () => { + // Arrange + const onChange = vi.fn() + render() + + // Act + fireEvent.click(screen.getByText('datasetCreation.stepOne.dataSourceType.notion')) + + // Assert + expect(onChange).not.toHaveBeenCalled() + }) + + it('should not call onClearPreviews when disabled', () => { + // Arrange + const onClearPreviews = vi.fn() + render() + + // Act + fireEvent.click(screen.getByText('datasetCreation.stepOne.dataSourceType.notion')) + + // Assert + expect(onClearPreviews).not.toHaveBeenCalled() + }) + }) +}) + +// ========================================== +// NextStepButton Component Tests +// ========================================== +describe('NextStepButton', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + // -------------------------------------------------------------------------- + // Rendering Tests + // -------------------------------------------------------------------------- + describe('Rendering', () => { + it('should render with correct label', () => { + // Arrange & Act + render() + + // Assert + expect(screen.getByText('datasetCreation.stepOne.button')).toBeInTheDocument() + }) + + it('should render with arrow icon', () => { + // Arrange & Act + const { container } = render() + + // Assert + const svgIcon = container.querySelector('svg') + expect(svgIcon).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Props Tests + // -------------------------------------------------------------------------- + describe('Props', () => { + it('should be disabled when disabled prop is true', () => { + // Arrange & Act + render() + + // Assert + expect(screen.getByRole('button')).toBeDisabled() + }) + + it('should be enabled when disabled prop is false', () => { + // Arrange & Act + render() + + // Assert + expect(screen.getByRole('button')).not.toBeDisabled() + }) + + it('should call onClick when clicked and not disabled', () => { + // Arrange + const onClick = vi.fn() + render() + + // Act + fireEvent.click(screen.getByRole('button')) + + // Assert + expect(onClick).toHaveBeenCalledTimes(1) + }) + + it('should not call onClick when clicked and disabled', () => { + // Arrange + const onClick = vi.fn() + render() + + // Act + fireEvent.click(screen.getByRole('button')) + + // Assert + expect(onClick).not.toHaveBeenCalled() + }) + }) +}) + +// ========================================== +// PreviewPanel Component Tests +// ========================================== +describe('PreviewPanel', () => { + const defaultPreviewProps = { + currentFile: undefined as File | undefined, + currentNotionPage: undefined as NotionPage | undefined, + currentWebsite: undefined as CrawlResultItem | undefined, + notionCredentialId: 'cred-1', + isShowPlanUpgradeModal: false, + hideFilePreview: vi.fn(), + hideNotionPagePreview: vi.fn(), + hideWebsitePreview: vi.fn(), + hidePlanUpgradeModal: vi.fn(), + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + // -------------------------------------------------------------------------- + // Conditional Rendering Tests + // -------------------------------------------------------------------------- + describe('Conditional Rendering', () => { + it('should not render FilePreview when currentFile is undefined', () => { + // Arrange & Act + render() + + // Assert + expect(screen.queryByTestId('file-preview')).not.toBeInTheDocument() + }) + + it('should render FilePreview when currentFile is defined', () => { + // Arrange + const file = new File(['test'], 'test.txt') + + // Act + render() + + // Assert + expect(screen.getByTestId('file-preview')).toBeInTheDocument() + }) + + it('should not render NotionPagePreview when currentNotionPage is undefined', () => { + // Arrange & Act + render() + + // Assert + expect(screen.queryByTestId('notion-page-preview')).not.toBeInTheDocument() + }) + + it('should render NotionPagePreview when currentNotionPage is defined', () => { + // Arrange + const page = createMockNotionPage() + + // Act + render() + + // Assert + expect(screen.getByTestId('notion-page-preview')).toBeInTheDocument() + }) + + it('should not render WebsitePreview when currentWebsite is undefined', () => { + // Arrange & Act + render() + + // Assert - pagePreview is the title shown in WebsitePreview + expect(screen.queryByText('datasetCreation.stepOne.pagePreview')).not.toBeInTheDocument() + }) + + it('should render WebsitePreview when currentWebsite is defined', () => { + // Arrange + const website = createMockCrawlResult() + + // Act + render() + + // Assert - Check for the preview title and source URL + expect(screen.getByText('datasetCreation.stepOne.pagePreview')).toBeInTheDocument() + expect(screen.getByText(website.source_url)).toBeInTheDocument() + }) + + it('should not render PlanUpgradeModal when isShowPlanUpgradeModal is false', () => { + // Arrange & Act + render() + + // Assert + expect(screen.queryByTestId('plan-upgrade-modal')).not.toBeInTheDocument() + }) + + it('should render PlanUpgradeModal when isShowPlanUpgradeModal is true', () => { + // Arrange & Act + render() + + // Assert + expect(screen.getByTestId('plan-upgrade-modal')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Event Handler Tests + // -------------------------------------------------------------------------- + describe('Event Handlers', () => { + it('should call hideFilePreview when file preview close is clicked', () => { + // Arrange + const hideFilePreview = vi.fn() + const file = new File(['test'], 'test.txt') + render() + + // Act + fireEvent.click(screen.getByTestId('hide-file-preview')) + + // Assert + expect(hideFilePreview).toHaveBeenCalledTimes(1) + }) + + it('should call hideNotionPagePreview when notion preview close is clicked', () => { + // Arrange + const hideNotionPagePreview = vi.fn() + const page = createMockNotionPage() + render() + + // Act + fireEvent.click(screen.getByTestId('hide-notion-preview')) + + // Assert + expect(hideNotionPagePreview).toHaveBeenCalledTimes(1) + }) + + it('should call hideWebsitePreview when website preview close is clicked', () => { + // Arrange + const hideWebsitePreview = vi.fn() + const website = createMockCrawlResult() + const { container } = render() + + // Act - Find the close button (div with cursor-pointer class containing the XMarkIcon) + const closeButton = container.querySelector('.cursor-pointer') + expect(closeButton).toBeInTheDocument() + fireEvent.click(closeButton!) + + // Assert + expect(hideWebsitePreview).toHaveBeenCalledTimes(1) + }) + + it('should call hidePlanUpgradeModal when modal close is clicked', () => { + // Arrange + const hidePlanUpgradeModal = vi.fn() + render() + + // Act + fireEvent.click(screen.getByTestId('close-upgrade-modal')) + + // Assert + expect(hidePlanUpgradeModal).toHaveBeenCalledTimes(1) + }) + }) +}) + +// ========================================== +// StepOne Component Tests +// ========================================== +describe('StepOne', () => { + beforeEach(() => { + vi.clearAllMocks() + mockDatasetDetail = undefined + mockPlan = { + type: Plan.professional, + usage: { vectorSpace: 50, buildApps: 0, documentsUploadQuota: 0, vectorStorageQuota: 0 }, + total: { vectorSpace: 100, buildApps: 0, documentsUploadQuota: 0, vectorStorageQuota: 0 }, + } + mockEnableBilling = false + }) + + // -------------------------------------------------------------------------- + // Rendering Tests + // -------------------------------------------------------------------------- + describe('Rendering', () => { + it('should render without crashing', () => { + // Arrange & Act + render() + + // Assert + expect(screen.getByText('datasetCreation.steps.one')).toBeInTheDocument() + }) + + it('should render DataSourceTypeSelector when not editing existing dataset', () => { + // Arrange & Act + render() + + // Assert + expect(screen.getByText('datasetCreation.stepOne.dataSourceType.file')).toBeInTheDocument() + }) + + it('should render FileUploader when dataSourceType is FILE', () => { + // Arrange & Act + render() + + // Assert + expect(screen.getByTestId('file-uploader')).toBeInTheDocument() + }) + + it('should render NotionConnector when dataSourceType is NOTION and not authenticated', () => { + // Arrange & Act + render() + + // Assert - NotionConnector shows sync title and connect button + expect(screen.getByText('datasetCreation.stepOne.notionSyncTitle')).toBeInTheDocument() + expect(screen.getByRole('button', { name: /datasetCreation.stepOne.connect/i })).toBeInTheDocument() + }) + + it('should render NotionPageSelector when dataSourceType is NOTION and authenticated', () => { + // Arrange + const authedDataSourceList = [createMockDataSourceAuth()] + + // Act + render() + + // Assert + expect(screen.getByTestId('notion-page-selector')).toBeInTheDocument() + }) + + it('should render Website when dataSourceType is WEB', () => { + // Arrange & Act + render() + + // Assert + expect(screen.getByTestId('website')).toBeInTheDocument() + }) + + it('should render empty dataset creation link when no datasetId', () => { + // Arrange & Act + render() + + // Assert + expect(screen.getByText('datasetCreation.stepOne.emptyDatasetCreation')).toBeInTheDocument() + }) + + it('should not render empty dataset creation link when datasetId exists', () => { + // Arrange & Act + render() + + // Assert + expect(screen.queryByText('datasetCreation.stepOne.emptyDatasetCreation')).not.toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Props Tests + // -------------------------------------------------------------------------- + describe('Props', () => { + it('should pass files to FileUploader', () => { + // Arrange + const files = [createMockFileItem()] + + // Act + render() + + // Assert + expect(screen.getByTestId('file-count')).toHaveTextContent('1') + }) + + it('should call onSetting when NotionConnector connect button is clicked', () => { + // Arrange + const onSetting = vi.fn() + render() + + // Act - The NotionConnector's button calls onSetting + fireEvent.click(screen.getByRole('button', { name: /datasetCreation.stepOne.connect/i })) + + // Assert + expect(onSetting).toHaveBeenCalledTimes(1) + }) + + it('should call changeType when data source type is changed', () => { + // Arrange + const changeType = vi.fn() + render() + + // Act + fireEvent.click(screen.getByText('datasetCreation.stepOne.dataSourceType.notion')) + + // Assert + expect(changeType).toHaveBeenCalledWith(DataSourceType.NOTION) + }) + }) + + // -------------------------------------------------------------------------- + // State Management Tests + // -------------------------------------------------------------------------- + describe('State Management', () => { + it('should open empty dataset modal when link is clicked', () => { + // Arrange + render() + + // Act + fireEvent.click(screen.getByText('datasetCreation.stepOne.emptyDatasetCreation')) + + // Assert + expect(screen.getByTestId('empty-dataset-modal')).toBeInTheDocument() + }) + + it('should close empty dataset modal when close is clicked', () => { + // Arrange + render() + fireEvent.click(screen.getByText('datasetCreation.stepOne.emptyDatasetCreation')) + + // Act + fireEvent.click(screen.getByTestId('close-modal')) + + // Assert + expect(screen.queryByTestId('empty-dataset-modal')).not.toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Memoization Tests + // -------------------------------------------------------------------------- + describe('Memoization', () => { + it('should correctly compute isNotionAuthed based on authedDataSourceList', () => { + // Arrange - No auth + const { rerender } = render() + // NotionConnector shows the sync title when not authenticated + expect(screen.getByText('datasetCreation.stepOne.notionSyncTitle')).toBeInTheDocument() + + // Act - Add auth + const authedDataSourceList = [createMockDataSourceAuth()] + rerender() + + // Assert + expect(screen.getByTestId('notion-page-selector')).toBeInTheDocument() + }) + + it('should correctly compute fileNextDisabled when files are empty', () => { + // Arrange & Act + render() + + // Assert - Button should be disabled + expect(screen.getByRole('button', { name: /datasetCreation.stepOne.button/i })).toBeDisabled() + }) + + it('should correctly compute fileNextDisabled when files are loaded', () => { + // Arrange + const files = [createMockFileItem()] + + // Act + render() + + // Assert - Button should be enabled + expect(screen.getByRole('button', { name: /datasetCreation.stepOne.button/i })).not.toBeDisabled() + }) + + it('should correctly compute fileNextDisabled when some files are not uploaded', () => { + // Arrange - Create a file item without id (not yet uploaded) + const file = new File(['test'], 'test.txt', { type: 'text/plain' }) + const fileItem: FileItem = { + fileID: 'temp-id', + file: Object.assign(file, { id: undefined, extension: 'txt', mime_type: 'text/plain' }), + progress: 0, + } + + // Act + render() + + // Assert - Button should be disabled + expect(screen.getByRole('button', { name: /datasetCreation.stepOne.button/i })).toBeDisabled() + }) + }) + + // -------------------------------------------------------------------------- + // Callback Tests + // -------------------------------------------------------------------------- + describe('Callbacks', () => { + it('should call onStepChange when next button is clicked with valid files', () => { + // Arrange + const onStepChange = vi.fn() + const files = [createMockFileItem()] + render() + + // Act + fireEvent.click(screen.getByRole('button', { name: /datasetCreation.stepOne.button/i })) + + // Assert + expect(onStepChange).toHaveBeenCalledTimes(1) + }) + + it('should show plan upgrade modal when batch upload not supported and multiple files', () => { + // Arrange + mockEnableBilling = true + mockPlan.type = Plan.sandbox + const files = [createMockFileItem(), createMockFileItem()] + render() + + // Act + fireEvent.click(screen.getByRole('button', { name: /datasetCreation.stepOne.button/i })) + + // Assert + expect(screen.getByTestId('plan-upgrade-modal')).toBeInTheDocument() + }) + + it('should show upgrade card when in sandbox plan with files', () => { + // Arrange + mockEnableBilling = true + mockPlan.type = Plan.sandbox + const files = [createMockFileItem()] + + // Act + render() + + // Assert + expect(screen.getByTestId('upgrade-card')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Vector Space Full Tests + // -------------------------------------------------------------------------- + describe('Vector Space Full', () => { + it('should show VectorSpaceFull when vector space is full and billing is enabled', () => { + // Arrange + mockEnableBilling = true + mockPlan.usage.vectorSpace = 100 + mockPlan.total.vectorSpace = 100 + const files = [createMockFileItem()] + + // Act + render() + + // Assert + expect(screen.getByTestId('vector-space-full')).toBeInTheDocument() + }) + + it('should disable next button when vector space is full', () => { + // Arrange + mockEnableBilling = true + mockPlan.usage.vectorSpace = 100 + mockPlan.total.vectorSpace = 100 + const files = [createMockFileItem()] + + // Act + render() + + // Assert + expect(screen.getByRole('button', { name: /datasetCreation.stepOne.button/i })).toBeDisabled() + }) + }) + + // -------------------------------------------------------------------------- + // Preview Integration Tests + // -------------------------------------------------------------------------- + describe('Preview Integration', () => { + it('should show file preview when file preview button is clicked', () => { + // Arrange + render() + + // Act + fireEvent.click(screen.getByTestId('preview-file')) + + // Assert + expect(screen.getByTestId('file-preview')).toBeInTheDocument() + }) + + it('should hide file preview when hide button is clicked', () => { + // Arrange + render() + fireEvent.click(screen.getByTestId('preview-file')) + + // Act + fireEvent.click(screen.getByTestId('hide-file-preview')) + + // Assert + expect(screen.queryByTestId('file-preview')).not.toBeInTheDocument() + }) + + it('should show notion page preview when preview button is clicked', () => { + // Arrange + const authedDataSourceList = [createMockDataSourceAuth()] + render() + + // Act + fireEvent.click(screen.getByTestId('preview-notion')) + + // Assert + expect(screen.getByTestId('notion-page-preview')).toBeInTheDocument() + }) + + it('should show website preview when preview button is clicked', () => { + // Arrange + render() + + // Act + fireEvent.click(screen.getByTestId('preview-website')) + + // Assert - Check for pagePreview title which is shown by WebsitePreview + expect(screen.getByText('datasetCreation.stepOne.pagePreview')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Edge Cases + // -------------------------------------------------------------------------- + describe('Edge Cases', () => { + it('should handle empty notionPages array', () => { + // Arrange + const authedDataSourceList = [createMockDataSourceAuth()] + + // Act + render() + + // Assert - Button should be disabled when no pages selected + expect(screen.getByRole('button', { name: /datasetCreation.stepOne.button/i })).toBeDisabled() + }) + + it('should handle empty websitePages array', () => { + // Arrange & Act + render() + + // Assert - Button should be disabled when no pages crawled + expect(screen.getByRole('button', { name: /datasetCreation.stepOne.button/i })).toBeDisabled() + }) + + it('should handle empty authedDataSourceList', () => { + // Arrange & Act + render() + + // Assert - Should show NotionConnector with connect button + expect(screen.getByText('datasetCreation.stepOne.notionSyncTitle')).toBeInTheDocument() + }) + + it('should handle authedDataSourceList without notion credentials', () => { + // Arrange + const authedDataSourceList = [createMockDataSourceAuth({ credentials_list: [] })] + + // Act + render() + + // Assert - Should show NotionConnector with connect button + expect(screen.getByText('datasetCreation.stepOne.notionSyncTitle')).toBeInTheDocument() + }) + + it('should clear previews when switching data source types', () => { + // Arrange + render() + fireEvent.click(screen.getByTestId('preview-file')) + expect(screen.getByTestId('file-preview')).toBeInTheDocument() + + // Act - Change to NOTION + fireEvent.click(screen.getByText('datasetCreation.stepOne.dataSourceType.notion')) + + // Assert - File preview should be cleared + expect(screen.queryByTestId('file-preview')).not.toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Integration Tests + // -------------------------------------------------------------------------- + describe('Integration', () => { + it('should complete file upload flow', () => { + // Arrange + const onStepChange = vi.fn() + const files = [createMockFileItem()] + + // Act + render() + fireEvent.click(screen.getByRole('button', { name: /datasetCreation.stepOne.button/i })) + + // Assert + expect(onStepChange).toHaveBeenCalled() + }) + + it('should complete notion page selection flow', () => { + // Arrange + const onStepChange = vi.fn() + const authedDataSourceList = [createMockDataSourceAuth()] + const notionPages = [createMockNotionPage()] + + // Act + render( + , + ) + fireEvent.click(screen.getByRole('button', { name: /datasetCreation.stepOne.button/i })) + + // Assert + expect(onStepChange).toHaveBeenCalled() + }) + + it('should complete website crawl flow', () => { + // Arrange + const onStepChange = vi.fn() + const websitePages = [createMockCrawlResult()] + + // Act + render( + , + ) + fireEvent.click(screen.getByRole('button', { name: /datasetCreation.stepOne.button/i })) + + // Assert + expect(onStepChange).toHaveBeenCalled() + }) + }) +}) diff --git a/web/app/components/datasets/create/step-one/index.tsx b/web/app/components/datasets/create/step-one/index.tsx index 5c74e69e7e..a86c9d86c2 100644 --- a/web/app/components/datasets/create/step-one/index.tsx +++ b/web/app/components/datasets/create/step-one/index.tsx @@ -1,29 +1,25 @@ 'use client' + import type { DataSourceAuth } from '@/app/components/header/account-setting/data-source-page-new/types' import type { DataSourceProvider, NotionPage } from '@/models/common' import type { CrawlOptions, CrawlResultItem, FileItem } from '@/models/datasets' -import { RiArrowRightLine, RiFolder6Line } from '@remixicon/react' +import { RiFolder6Line } from '@remixicon/react' import { useBoolean } from 'ahooks' -import * as React from 'react' -import { useCallback, useMemo, useState } from 'react' +import { useCallback, useMemo } from 'react' import { useTranslation } from 'react-i18next' -import Button from '@/app/components/base/button' import NotionConnector from '@/app/components/base/notion-connector' import { NotionPageSelector } from '@/app/components/base/notion-page-selector' -import PlanUpgradeModal from '@/app/components/billing/plan-upgrade-modal' import { Plan } from '@/app/components/billing/type' import VectorSpaceFull from '@/app/components/billing/vector-space-full' -import { ENABLE_WEBSITE_FIRECRAWL, ENABLE_WEBSITE_JINAREADER, ENABLE_WEBSITE_WATERCRAWL } from '@/config' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import { useProviderContext } from '@/context/provider-context' import { DataSourceType } from '@/models/datasets' import { cn } from '@/utils/classnames' import EmptyDatasetCreationModal from '../empty-dataset-creation-modal' -import FilePreview from '../file-preview' import FileUploader from '../file-uploader' -import NotionPagePreview from '../notion-page-preview' import Website from '../website' -import WebsitePreview from '../website/preview' +import { DataSourceTypeSelector, NextStepButton, PreviewPanel } from './components' +import { usePreviewState } from './hooks' import s from './index.module.css' import UpgradeCard from './upgrade-card' @@ -50,6 +46,24 @@ type IStepOneProps = { authedDataSourceList: DataSourceAuth[] } +// Helper function to check if notion is authenticated +function checkNotionAuth(authedDataSourceList: DataSourceAuth[]): boolean { + const notionSource = authedDataSourceList.find(item => item.provider === 'notion_datasource') + return Boolean(notionSource && notionSource.credentials_list.length > 0) +} + +// Helper function to get notion credential list +function getNotionCredentialList(authedDataSourceList: DataSourceAuth[]) { + return authedDataSourceList.find(item => item.provider === 'notion_datasource')?.credentials_list || [] +} + +// Lookup table for checking multiple items by data source type +const MULTIPLE_ITEMS_CHECK: Record boolean> = { + [DataSourceType.FILE]: ({ files }) => files.length > 1, + [DataSourceType.NOTION]: ({ notionPages }) => notionPages.length > 1, + [DataSourceType.WEB]: ({ websitePages }) => websitePages.length > 1, +} + const StepOne = ({ datasetId, dataSourceType: inCreatePageDataSourceType, @@ -72,76 +86,47 @@ const StepOne = ({ onCrawlOptionsChange, authedDataSourceList, }: IStepOneProps) => { - const dataset = useDatasetDetailContextWithSelector(state => state.dataset) - const [showModal, setShowModal] = useState(false) - const [currentFile, setCurrentFile] = useState() - const [currentNotionPage, setCurrentNotionPage] = useState() - const [currentWebsite, setCurrentWebsite] = useState() const { t } = useTranslation() + const dataset = useDatasetDetailContextWithSelector(state => state.dataset) + const { plan, enableBilling } = useProviderContext() - const modalShowHandle = () => setShowModal(true) - const modalCloseHandle = () => setShowModal(false) + // Preview state management + const { + currentFile, + currentNotionPage, + currentWebsite, + showFilePreview, + hideFilePreview, + showNotionPagePreview, + hideNotionPagePreview, + showWebsitePreview, + hideWebsitePreview, + } = usePreviewState() - const updateCurrentFile = useCallback((file: File) => { - setCurrentFile(file) - }, []) + // Empty dataset modal state + const [showModal, { setTrue: openModal, setFalse: closeModal }] = useBoolean(false) - const hideFilePreview = useCallback(() => { - setCurrentFile(undefined) - }, []) - - const updateCurrentPage = useCallback((page: NotionPage) => { - setCurrentNotionPage(page) - }, []) - - const hideNotionPagePreview = useCallback(() => { - setCurrentNotionPage(undefined) - }, []) - - const updateWebsite = useCallback((website: CrawlResultItem) => { - setCurrentWebsite(website) - }, []) - - const hideWebsitePreview = useCallback(() => { - setCurrentWebsite(undefined) - }, []) + // Plan upgrade modal state + const [isShowPlanUpgradeModal, { setTrue: showPlanUpgradeModal, setFalse: hidePlanUpgradeModal }] = useBoolean(false) + // Computed values const shouldShowDataSourceTypeList = !datasetId || (datasetId && !dataset?.data_source_type) const isInCreatePage = shouldShowDataSourceTypeList - const dataSourceType = isInCreatePage ? inCreatePageDataSourceType : dataset?.data_source_type - const { plan, enableBilling } = useProviderContext() - const allFileLoaded = (files.length > 0 && files.every(file => file.file.id)) - const hasNotin = notionPages.length > 0 + // Default to FILE type when no type is provided from either source + const dataSourceType = isInCreatePage + ? (inCreatePageDataSourceType ?? DataSourceType.FILE) + : (dataset?.data_source_type ?? DataSourceType.FILE) + + const allFileLoaded = files.length > 0 && files.every(file => file.file.id) + const hasNotion = notionPages.length > 0 const isVectorSpaceFull = plan.usage.vectorSpace >= plan.total.vectorSpace - const isShowVectorSpaceFull = (allFileLoaded || hasNotin) && isVectorSpaceFull && enableBilling + const isShowVectorSpaceFull = (allFileLoaded || hasNotion) && isVectorSpaceFull && enableBilling const supportBatchUpload = !enableBilling || plan.type !== Plan.sandbox - const notSupportBatchUpload = !supportBatchUpload - const [isShowPlanUpgradeModal, { - setTrue: showPlanUpgradeModal, - setFalse: hidePlanUpgradeModal, - }] = useBoolean(false) - const onStepChange = useCallback(() => { - if (notSupportBatchUpload) { - let isMultiple = false - if (dataSourceType === DataSourceType.FILE && files.length > 1) - isMultiple = true + const isNotionAuthed = useMemo(() => checkNotionAuth(authedDataSourceList), [authedDataSourceList]) + const notionCredentialList = useMemo(() => getNotionCredentialList(authedDataSourceList), [authedDataSourceList]) - if (dataSourceType === DataSourceType.NOTION && notionPages.length > 1) - isMultiple = true - - if (dataSourceType === DataSourceType.WEB && websitePages.length > 1) - isMultiple = true - - if (isMultiple) { - showPlanUpgradeModal() - return - } - } - doOnStepChange() - }, [dataSourceType, doOnStepChange, files.length, notSupportBatchUpload, notionPages.length, showPlanUpgradeModal, websitePages.length]) - - const nextDisabled = useMemo(() => { + const fileNextDisabled = useMemo(() => { if (!files.length) return true if (files.some(file => !file.file.id)) @@ -149,109 +134,50 @@ const StepOne = ({ return isShowVectorSpaceFull }, [files, isShowVectorSpaceFull]) - const isNotionAuthed = useMemo(() => { - if (!authedDataSourceList) - return false - const notionSource = authedDataSourceList.find(item => item.provider === 'notion_datasource') - if (!notionSource) - return false - return notionSource.credentials_list.length > 0 - }, [authedDataSourceList]) + // Clear previews when switching data source type + const handleClearPreviews = useCallback((newType: DataSourceType) => { + if (newType !== DataSourceType.FILE) + hideFilePreview() + if (newType !== DataSourceType.NOTION) + hideNotionPagePreview() + if (newType !== DataSourceType.WEB) + hideWebsitePreview() + }, [hideFilePreview, hideNotionPagePreview, hideWebsitePreview]) - const notionCredentialList = useMemo(() => { - return authedDataSourceList.find(item => item.provider === 'notion_datasource')?.credentials_list || [] - }, [authedDataSourceList]) + // Handle step change with batch upload check + const onStepChange = useCallback(() => { + if (!supportBatchUpload && dataSourceType) { + const checkFn = MULTIPLE_ITEMS_CHECK[dataSourceType] + if (checkFn?.({ files, notionPages, websitePages })) { + showPlanUpgradeModal() + return + } + } + doOnStepChange() + }, [dataSourceType, doOnStepChange, files, supportBatchUpload, notionPages, showPlanUpgradeModal, websitePages]) return (
+ {/* Left Panel - Form */}
- { - shouldShowDataSourceTypeList && ( + {shouldShowDataSourceTypeList && ( + <>
{t('steps.one', { ns: 'datasetCreation' })}
- ) - } - { - shouldShowDataSourceTypeList && ( -
-
{ - if (dataSourceTypeDisable) - return - changeType(DataSourceType.FILE) - hideNotionPagePreview() - hideWebsitePreview() - }} - > - - - {t('stepOne.dataSourceType.file', { ns: 'datasetCreation' })} - -
-
{ - if (dataSourceTypeDisable) - return - changeType(DataSourceType.NOTION) - hideFilePreview() - hideWebsitePreview() - }} - > - - - {t('stepOne.dataSourceType.notion', { ns: 'datasetCreation' })} - -
- {(ENABLE_WEBSITE_FIRECRAWL || ENABLE_WEBSITE_JINAREADER || ENABLE_WEBSITE_WATERCRAWL) && ( -
{ - if (dataSourceTypeDisable) - return - changeType(DataSourceType.WEB) - hideFilePreview() - hideNotionPagePreview() - }} - > - - - {t('stepOne.dataSourceType.web', { ns: 'datasetCreation' })} - -
- )} -
- ) - } + + + )} + + {/* File Data Source */} {dataSourceType === DataSourceType.FILE && ( <> {isShowVectorSpaceFull && ( @@ -268,24 +194,17 @@ const StepOne = ({
)} -
- -
- { - enableBilling && plan.type === Plan.sandbox && files.length > 0 && ( -
-
- -
- ) - } + + {enableBilling && plan.type === Plan.sandbox && files.length > 0 && ( +
+
+ +
+ )} )} + + {/* Notion Data Source */} {dataSourceType === DataSourceType.NOTION && ( <> {!isNotionAuthed && } @@ -295,7 +214,7 @@ const StepOne = ({ page.page_id)} onSelect={updateNotionPages} - onPreview={updateCurrentPage} + onPreview={showNotionPagePreview} credentialList={notionCredentialList} onSelectCredential={updateNotionCredentialId} datasetId={datasetId} @@ -306,23 +225,21 @@ const StepOne = ({
)} -
- -
+ )} )} + + {/* Web Data Source */} {dataSourceType === DataSourceType.WEB && ( <>
)} -
- -
+ )} + + {/* Empty Dataset Creation Link */} {!datasetId && ( <>
- + {t('stepOne.emptyDatasetCreation', { ns: 'datasetCreation' })} )}
- +
-
- {currentFile && } - {currentNotionPage && ( - - )} - {currentWebsite && } - {isShowPlanUpgradeModal && ( - - )} -
+ + {/* Right Panel - Preview */} +
) diff --git a/web/app/components/datasets/list/dataset-card/components/corner-labels.tsx b/web/app/components/datasets/list/dataset-card/components/corner-labels.tsx new file mode 100644 index 0000000000..03ca543ee7 --- /dev/null +++ b/web/app/components/datasets/list/dataset-card/components/corner-labels.tsx @@ -0,0 +1,36 @@ +import type { DataSet } from '@/models/datasets' +import * as React from 'react' +import { useTranslation } from 'react-i18next' +import CornerLabel from '@/app/components/base/corner-label' + +type CornerLabelsProps = { + dataset: DataSet +} + +const CornerLabels = ({ dataset }: CornerLabelsProps) => { + const { t } = useTranslation() + + if (!dataset.embedding_available) { + return ( + + ) + } + + if (dataset.runtime_mode === 'rag_pipeline') { + return ( + + ) + } + + return null +} + +export default React.memo(CornerLabels) diff --git a/web/app/components/datasets/list/dataset-card/components/dataset-card-footer.tsx b/web/app/components/datasets/list/dataset-card/components/dataset-card-footer.tsx new file mode 100644 index 0000000000..854f34f49c --- /dev/null +++ b/web/app/components/datasets/list/dataset-card/components/dataset-card-footer.tsx @@ -0,0 +1,62 @@ +import type { DataSet } from '@/models/datasets' +import { RiFileTextFill, RiRobot2Fill } from '@remixicon/react' +import * as React from 'react' +import { useMemo } from 'react' +import { useTranslation } from 'react-i18next' +import Tooltip from '@/app/components/base/tooltip' +import { useFormatTimeFromNow } from '@/hooks/use-format-time-from-now' +import { cn } from '@/utils/classnames' + +const EXTERNAL_PROVIDER = 'external' + +type DatasetCardFooterProps = { + dataset: DataSet +} + +const DatasetCardFooter = ({ dataset }: DatasetCardFooterProps) => { + const { t } = useTranslation() + const { formatTimeFromNow } = useFormatTimeFromNow() + const isExternalProvider = dataset.provider === EXTERNAL_PROVIDER + + const documentCount = useMemo(() => { + const availableDocCount = dataset.total_available_documents ?? 0 + if (availableDocCount < dataset.document_count) + return `${availableDocCount} / ${dataset.document_count}` + return `${dataset.document_count}` + }, [dataset.document_count, dataset.total_available_documents]) + + const documentCountTooltip = useMemo(() => { + const availableDocCount = dataset.total_available_documents ?? 0 + if (availableDocCount < dataset.document_count) + return t('partialEnabled', { ns: 'dataset', count: dataset.document_count, num: availableDocCount }) + return t('docAllEnabled', { ns: 'dataset', count: availableDocCount }) + }, [t, dataset.document_count, dataset.total_available_documents]) + + return ( +
+ +
+ + {documentCount} +
+
+ {!isExternalProvider && ( + +
+ + {dataset.app_count} +
+
+ )} + / + {`${t('updated', { ns: 'dataset' })} ${formatTimeFromNow(dataset.updated_at * 1000)}`} +
+ ) +} + +export default React.memo(DatasetCardFooter) diff --git a/web/app/components/datasets/list/dataset-card/components/dataset-card-header.tsx b/web/app/components/datasets/list/dataset-card/components/dataset-card-header.tsx new file mode 100644 index 0000000000..abe7595e14 --- /dev/null +++ b/web/app/components/datasets/list/dataset-card/components/dataset-card-header.tsx @@ -0,0 +1,148 @@ +import type { DataSet } from '@/models/datasets' +import * as React from 'react' +import { useMemo } from 'react' +import { useTranslation } from 'react-i18next' +import AppIcon from '@/app/components/base/app-icon' +import { useFormatTimeFromNow } from '@/hooks/use-format-time-from-now' +import { useKnowledge } from '@/hooks/use-knowledge' +import { DOC_FORM_ICON_WITH_BG, DOC_FORM_TEXT } from '@/models/datasets' +import { cn } from '@/utils/classnames' + +const EXTERNAL_PROVIDER = 'external' + +type DatasetCardHeaderProps = { + dataset: DataSet +} + +// DocModeInfo component - placed before usage +type DocModeInfoProps = { + dataset: DataSet + isExternalProvider: boolean + isShowDocModeInfo: boolean +} + +const DocModeInfo = ({ + dataset, + isExternalProvider, + isShowDocModeInfo, +}: DocModeInfoProps) => { + const { t } = useTranslation() + const { formatIndexingTechniqueAndMethod } = useKnowledge() + + if (isExternalProvider) { + return ( +
+ {t('externalKnowledgeBase', { ns: 'dataset' })} +
+ ) + } + + if (!isShowDocModeInfo) + return null + + const indexingText = dataset.indexing_technique + ? formatIndexingTechniqueAndMethod( + dataset.indexing_technique as 'economy' | 'high_quality', + dataset.retrieval_model_dict?.search_method as Parameters[1], + ) + : '' + + return ( +
+ {dataset.doc_form && ( + + {t(`chunkingMode.${DOC_FORM_TEXT[dataset.doc_form]}`, { ns: 'dataset' })} + + )} + {dataset.indexing_technique && indexingText && ( + + {indexingText} + + )} + {dataset.is_multimodal && ( + + {t('multimodal', { ns: 'dataset' })} + + )} +
+ ) +} + +// Main DatasetCardHeader component +const DatasetCardHeader = ({ dataset }: DatasetCardHeaderProps) => { + const { t } = useTranslation() + const { formatTimeFromNow } = useFormatTimeFromNow() + + const isExternalProvider = dataset.provider === EXTERNAL_PROVIDER + + const isShowChunkingModeIcon = dataset.doc_form && (dataset.runtime_mode !== 'rag_pipeline' || dataset.is_published) + const isShowDocModeInfo = Boolean( + dataset.doc_form + && dataset.indexing_technique + && dataset.retrieval_model_dict?.search_method + && (dataset.runtime_mode !== 'rag_pipeline' || dataset.is_published), + ) + + const chunkingModeIcon = dataset.doc_form ? DOC_FORM_ICON_WITH_BG[dataset.doc_form] : React.Fragment + const Icon = isExternalProvider ? DOC_FORM_ICON_WITH_BG.external : chunkingModeIcon + + const iconInfo = useMemo(() => dataset.icon_info || { + icon: '📙', + icon_type: 'emoji' as const, + icon_background: '#FFF4ED', + icon_url: '', + }, [dataset.icon_info]) + + const editTimeText = useMemo( + () => `${t('segment.editedAt', { ns: 'datasetDocuments' })} ${formatTimeFromNow(dataset.updated_at * 1000)}`, + [t, dataset.updated_at, formatTimeFromNow], + ) + + return ( +
+
+ + {(isShowChunkingModeIcon || isExternalProvider) && ( +
+ +
+ )} +
+
+
+ {dataset.name} +
+
+
{dataset.author_name}
+
·
+
{editTimeText}
+
+ +
+
+ ) +} + +export default React.memo(DatasetCardHeader) diff --git a/web/app/components/datasets/list/dataset-card/components/dataset-card-modals.tsx b/web/app/components/datasets/list/dataset-card/components/dataset-card-modals.tsx new file mode 100644 index 0000000000..8162bc94c4 --- /dev/null +++ b/web/app/components/datasets/list/dataset-card/components/dataset-card-modals.tsx @@ -0,0 +1,55 @@ +import type { DataSet } from '@/models/datasets' +import * as React from 'react' +import { useTranslation } from 'react-i18next' +import Confirm from '@/app/components/base/confirm' +import RenameDatasetModal from '../../../rename-modal' + +type ModalState = { + showRenameModal: boolean + showConfirmDelete: boolean + confirmMessage: string +} + +type DatasetCardModalsProps = { + dataset: DataSet + modalState: ModalState + onCloseRename: () => void + onCloseConfirm: () => void + onConfirmDelete: () => void + onSuccess?: () => void +} + +const DatasetCardModals = ({ + dataset, + modalState, + onCloseRename, + onCloseConfirm, + onConfirmDelete, + onSuccess, +}: DatasetCardModalsProps) => { + const { t } = useTranslation() + + return ( + <> + {modalState.showRenameModal && ( + + )} + {modalState.showConfirmDelete && ( + + )} + + ) +} + +export default React.memo(DatasetCardModals) diff --git a/web/app/components/datasets/list/dataset-card/components/description.tsx b/web/app/components/datasets/list/dataset-card/components/description.tsx new file mode 100644 index 0000000000..79604e92ab --- /dev/null +++ b/web/app/components/datasets/list/dataset-card/components/description.tsx @@ -0,0 +1,18 @@ +import type { DataSet } from '@/models/datasets' +import * as React from 'react' +import { cn } from '@/utils/classnames' + +type DescriptionProps = { + dataset: DataSet +} + +const Description = ({ dataset }: DescriptionProps) => ( +
+ {dataset.description} +
+) + +export default React.memo(Description) diff --git a/web/app/components/datasets/list/dataset-card/components/operations-popover.tsx b/web/app/components/datasets/list/dataset-card/components/operations-popover.tsx new file mode 100644 index 0000000000..80ae2fb7a1 --- /dev/null +++ b/web/app/components/datasets/list/dataset-card/components/operations-popover.tsx @@ -0,0 +1,52 @@ +import type { DataSet } from '@/models/datasets' +import { RiMoreFill } from '@remixicon/react' +import * as React from 'react' +import CustomPopover from '@/app/components/base/popover' +import { cn } from '@/utils/classnames' +import Operations from '../operations' + +type OperationsPopoverProps = { + dataset: DataSet + isCurrentWorkspaceDatasetOperator: boolean + openRenameModal: () => void + handleExportPipeline: (include?: boolean) => void + detectIsUsedByApp: () => void +} + +const OperationsPopover = ({ + dataset, + isCurrentWorkspaceDatasetOperator, + openRenameModal, + handleExportPipeline, + detectIsUsedByApp, +}: OperationsPopoverProps) => ( +
+ + )} + className="z-20 min-w-[186px]" + popupClassName="rounded-xl bg-none shadow-none ring-0 min-w-[186px]" + position="br" + trigger="click" + btnElement={( +
+ +
+ )} + btnClassName={open => + cn( + 'size-9 cursor-pointer justify-center rounded-[10px] border-[0.5px] border-components-actionbar-border bg-components-actionbar-bg p-0 shadow-lg shadow-shadow-shadow-5 ring-[2px] ring-inset ring-components-actionbar-bg hover:border-components-actionbar-border', + open ? 'border-components-actionbar-border bg-state-base-hover' : '', + )} + /> +
+) + +export default React.memo(OperationsPopover) diff --git a/web/app/components/datasets/list/dataset-card/components/tag-area.tsx b/web/app/components/datasets/list/dataset-card/components/tag-area.tsx new file mode 100644 index 0000000000..f55a064387 --- /dev/null +++ b/web/app/components/datasets/list/dataset-card/components/tag-area.tsx @@ -0,0 +1,55 @@ +import type { Tag } from '@/app/components/base/tag-management/constant' +import type { DataSet } from '@/models/datasets' +import * as React from 'react' +import TagSelector from '@/app/components/base/tag-management/selector' +import { cn } from '@/utils/classnames' + +type TagAreaProps = { + dataset: DataSet + tags: Tag[] + setTags: (tags: Tag[]) => void + onSuccess?: () => void + isHoveringTagSelector: boolean + onClick: (e: React.MouseEvent) => void +} + +const TagArea = React.forwardRef(({ + dataset, + tags, + setTags, + onSuccess, + isHoveringTagSelector, + onClick, +}, ref) => ( +
+
0 && 'visible', + )} + > + tag.id)} + selectedTags={tags} + onCacheUpdate={setTags} + onChange={onSuccess} + /> +
+
+
+)) +TagArea.displayName = 'TagArea' + +export default TagArea diff --git a/web/app/components/datasets/list/dataset-card/hooks/use-dataset-card-state.ts b/web/app/components/datasets/list/dataset-card/hooks/use-dataset-card-state.ts new file mode 100644 index 0000000000..ad68a1df1c --- /dev/null +++ b/web/app/components/datasets/list/dataset-card/hooks/use-dataset-card-state.ts @@ -0,0 +1,138 @@ +import type { Tag } from '@/app/components/base/tag-management/constant' +import type { DataSet } from '@/models/datasets' +import { useCallback, useEffect, useState } from 'react' +import { useTranslation } from 'react-i18next' +import Toast from '@/app/components/base/toast' +import { useCheckDatasetUsage, useDeleteDataset } from '@/service/use-dataset-card' +import { useExportPipelineDSL } from '@/service/use-pipeline' + +type ModalState = { + showRenameModal: boolean + showConfirmDelete: boolean + confirmMessage: string +} + +type UseDatasetCardStateOptions = { + dataset: DataSet + onSuccess?: () => void +} + +export const useDatasetCardState = ({ dataset, onSuccess }: UseDatasetCardStateOptions) => { + const { t } = useTranslation() + const [tags, setTags] = useState(dataset.tags) + + useEffect(() => { + setTags(dataset.tags) + }, [dataset.tags]) + + // Modal state + const [modalState, setModalState] = useState({ + showRenameModal: false, + showConfirmDelete: false, + confirmMessage: '', + }) + + // Export state + const [exporting, setExporting] = useState(false) + + // Modal handlers + const openRenameModal = useCallback(() => { + setModalState(prev => ({ ...prev, showRenameModal: true })) + }, []) + + const closeRenameModal = useCallback(() => { + setModalState(prev => ({ ...prev, showRenameModal: false })) + }, []) + + const closeConfirmDelete = useCallback(() => { + setModalState(prev => ({ ...prev, showConfirmDelete: false })) + }, []) + + // API mutations + const { mutateAsync: checkUsage } = useCheckDatasetUsage() + const { mutateAsync: deleteDatasetMutation } = useDeleteDataset() + const { mutateAsync: exportPipelineConfig } = useExportPipelineDSL() + + // Export pipeline handler + const handleExportPipeline = useCallback(async (include: boolean = false) => { + const { pipeline_id, name } = dataset + if (!pipeline_id || exporting) + return + + try { + setExporting(true) + const { data } = await exportPipelineConfig({ + pipelineId: pipeline_id, + include, + }) + const a = document.createElement('a') + const file = new Blob([data], { type: 'application/yaml' }) + const url = URL.createObjectURL(file) + a.href = url + a.download = `${name}.pipeline` + a.click() + URL.revokeObjectURL(url) + } + catch { + Toast.notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) }) + } + finally { + setExporting(false) + } + }, [dataset, exportPipelineConfig, exporting, t]) + + // Delete flow handlers + const detectIsUsedByApp = useCallback(async () => { + try { + const { is_using: isUsedByApp } = await checkUsage(dataset.id) + const message = isUsedByApp + ? t('datasetUsedByApp', { ns: 'dataset' })! + : t('deleteDatasetConfirmContent', { ns: 'dataset' })! + setModalState(prev => ({ + ...prev, + confirmMessage: message, + showConfirmDelete: true, + })) + } + catch (e: unknown) { + if (e instanceof Response) { + const res = await e.json() + Toast.notify({ type: 'error', message: res?.message || 'Unknown error' }) + } + else { + Toast.notify({ type: 'error', message: (e as Error)?.message || 'Unknown error' }) + } + } + }, [dataset.id, checkUsage, t]) + + const onConfirmDelete = useCallback(async () => { + try { + await deleteDatasetMutation(dataset.id) + Toast.notify({ type: 'success', message: t('datasetDeleted', { ns: 'dataset' }) }) + onSuccess?.() + } + finally { + closeConfirmDelete() + } + }, [dataset.id, deleteDatasetMutation, onSuccess, t, closeConfirmDelete]) + + return { + // Tag state + tags, + setTags, + + // Modal state + modalState, + openRenameModal, + closeRenameModal, + closeConfirmDelete, + + // Export state + exporting, + + // Handlers + handleExportPipeline, + detectIsUsedByApp, + onConfirmDelete, + } +} diff --git a/web/app/components/datasets/list/dataset-card/index.tsx b/web/app/components/datasets/list/dataset-card/index.tsx index 99404b0454..85dba7e8ff 100644 --- a/web/app/components/datasets/list/dataset-card/index.tsx +++ b/web/app/components/datasets/list/dataset-card/index.tsx @@ -1,28 +1,17 @@ 'use client' -import type { Tag } from '@/app/components/base/tag-management/constant' import type { DataSet } from '@/models/datasets' -import { RiFileTextFill, RiMoreFill, RiRobot2Fill } from '@remixicon/react' import { useHover } from 'ahooks' import { useRouter } from 'next/navigation' -import * as React from 'react' -import { useCallback, useEffect, useMemo, useRef, useState } from 'react' -import { useTranslation } from 'react-i18next' -import AppIcon from '@/app/components/base/app-icon' -import Confirm from '@/app/components/base/confirm' -import CornerLabel from '@/app/components/base/corner-label' -import CustomPopover from '@/app/components/base/popover' -import TagSelector from '@/app/components/base/tag-management/selector' -import Toast from '@/app/components/base/toast' -import Tooltip from '@/app/components/base/tooltip' +import { useMemo, useRef } from 'react' import { useSelector as useAppContextWithSelector } from '@/context/app-context' -import { useFormatTimeFromNow } from '@/hooks/use-format-time-from-now' -import { useKnowledge } from '@/hooks/use-knowledge' -import { DOC_FORM_ICON_WITH_BG, DOC_FORM_TEXT } from '@/models/datasets' -import { checkIsUsedInApp, deleteDataset } from '@/service/datasets' -import { useExportPipelineDSL } from '@/service/use-pipeline' -import { cn } from '@/utils/classnames' -import RenameDatasetModal from '../../rename-modal' -import Operations from './operations' +import CornerLabels from './components/corner-labels' +import DatasetCardFooter from './components/dataset-card-footer' +import DatasetCardHeader from './components/dataset-card-header' +import DatasetCardModals from './components/dataset-card-modals' +import Description from './components/description' +import OperationsPopover from './components/operations-popover' +import TagArea from './components/tag-area' +import { useDatasetCardState } from './hooks/use-dataset-card-state' const EXTERNAL_PROVIDER = 'external' @@ -35,320 +24,80 @@ const DatasetCard = ({ dataset, onSuccess, }: DatasetCardProps) => { - const { t } = useTranslation() const { push } = useRouter() const isCurrentWorkspaceDatasetOperator = useAppContextWithSelector(state => state.isCurrentWorkspaceDatasetOperator) - const [tags, setTags] = useState(dataset.tags) const tagSelectorRef = useRef(null) const isHoveringTagSelector = useHover(tagSelectorRef) - const [showRenameModal, setShowRenameModal] = useState(false) - const [showConfirmDelete, setShowConfirmDelete] = useState(false) - const [confirmMessage, setConfirmMessage] = useState('') - const [exporting, setExporting] = useState(false) + const { + tags, + setTags, + modalState, + openRenameModal, + closeRenameModal, + closeConfirmDelete, + handleExportPipeline, + detectIsUsedByApp, + onConfirmDelete, + } = useDatasetCardState({ dataset, onSuccess }) - const isExternalProvider = useMemo(() => { - return dataset.provider === EXTERNAL_PROVIDER - }, [dataset.provider]) + const isExternalProvider = dataset.provider === EXTERNAL_PROVIDER const isPipelineUnpublished = useMemo(() => { return dataset.runtime_mode === 'rag_pipeline' && !dataset.is_published }, [dataset.runtime_mode, dataset.is_published]) - const isShowChunkingModeIcon = useMemo(() => { - return dataset.doc_form && (dataset.runtime_mode !== 'rag_pipeline' || dataset.is_published) - }, [dataset.doc_form, dataset.runtime_mode, dataset.is_published]) - const isShowDocModeInfo = useMemo(() => { - return dataset.doc_form && dataset.indexing_technique && dataset.retrieval_model_dict?.search_method && (dataset.runtime_mode !== 'rag_pipeline' || dataset.is_published) - }, [dataset.doc_form, dataset.indexing_technique, dataset.retrieval_model_dict?.search_method, dataset.runtime_mode, dataset.is_published]) - const chunkingModeIcon = dataset.doc_form ? DOC_FORM_ICON_WITH_BG[dataset.doc_form] : React.Fragment - const Icon = isExternalProvider ? DOC_FORM_ICON_WITH_BG.external : chunkingModeIcon - const iconInfo = dataset.icon_info || { - icon: '📙', - icon_type: 'emoji', - icon_background: '#FFF4ED', - icon_url: '', + const handleCardClick = (e: React.MouseEvent) => { + e.preventDefault() + if (isExternalProvider) + push(`/datasets/${dataset.id}/hitTesting`) + else if (isPipelineUnpublished) + push(`/datasets/${dataset.id}/pipeline`) + else + push(`/datasets/${dataset.id}/documents`) } - const { formatIndexingTechniqueAndMethod } = useKnowledge() - const documentCount = useMemo(() => { - const availableDocCount = dataset.total_available_documents ?? 0 - if (availableDocCount === dataset.document_count) - return `${dataset.document_count}` - if (availableDocCount < dataset.document_count) - return `${availableDocCount} / ${dataset.document_count}` - }, [dataset.document_count, dataset.total_available_documents]) - const documentCountTooltip = useMemo(() => { - const availableDocCount = dataset.total_available_documents ?? 0 - if (availableDocCount === dataset.document_count) - return t('docAllEnabled', { ns: 'dataset', count: availableDocCount }) - if (availableDocCount < dataset.document_count) - return t('partialEnabled', { ns: 'dataset', count: dataset.document_count, num: availableDocCount }) - }, [t, dataset.document_count, dataset.total_available_documents]) - const { formatTimeFromNow } = useFormatTimeFromNow() - const editTimeText = useMemo(() => { - return `${t('segment.editedAt', { ns: 'datasetDocuments' })} ${formatTimeFromNow(dataset.updated_at * 1000)}` - }, [t, dataset.updated_at, formatTimeFromNow]) - - const openRenameModal = useCallback(() => { - setShowRenameModal(true) - }, []) - - const { mutateAsync: exportPipelineConfig } = useExportPipelineDSL() - - const handleExportPipeline = useCallback(async (include = false) => { - const { pipeline_id, name } = dataset - if (!pipeline_id) - return - - if (exporting) - return - - try { - setExporting(true) - const { data } = await exportPipelineConfig({ - pipelineId: pipeline_id, - include, - }) - const a = document.createElement('a') - const file = new Blob([data], { type: 'application/yaml' }) - const url = URL.createObjectURL(file) - a.href = url - a.download = `${name}.pipeline` - a.click() - URL.revokeObjectURL(url) - } - catch { - Toast.notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) }) - } - finally { - setExporting(false) - } - }, [dataset, exportPipelineConfig, exporting, t]) - - const detectIsUsedByApp = useCallback(async () => { - try { - const { is_using: isUsedByApp } = await checkIsUsedInApp(dataset.id) - setConfirmMessage(isUsedByApp ? t('datasetUsedByApp', { ns: 'dataset' })! : t('deleteDatasetConfirmContent', { ns: 'dataset' })!) - setShowConfirmDelete(true) - } - catch (e: any) { - const res = await e.json() - Toast.notify({ type: 'error', message: res?.message || 'Unknown error' }) - } - }, [dataset.id, t]) - - const onConfirmDelete = useCallback(async () => { - try { - await deleteDataset(dataset.id) - Toast.notify({ type: 'success', message: t('datasetDeleted', { ns: 'dataset' }) }) - if (onSuccess) - onSuccess() - } - finally { - setShowConfirmDelete(false) - } - }, [dataset.id, onSuccess, t]) - - useEffect(() => { - setTags(dataset.tags) - }, [dataset]) + const handleTagAreaClick = (e: React.MouseEvent) => { + e.stopPropagation() + e.preventDefault() + } return ( <>
{ - e.preventDefault() - if (isExternalProvider) - push(`/datasets/${dataset.id}/hitTesting`) - else if (isPipelineUnpublished) - push(`/datasets/${dataset.id}/pipeline`) - else - push(`/datasets/${dataset.id}/documents`) - }} + onClick={handleCardClick} > - {!dataset.embedding_available && ( - - )} - {dataset.embedding_available && dataset.runtime_mode === 'rag_pipeline' && ( - - )} -
-
- - {(isShowChunkingModeIcon || isExternalProvider) && ( -
- -
- )} -
-
-
- {dataset.name} -
-
-
{dataset.author_name}
-
·
-
{editTimeText}
-
-
- {isExternalProvider && {t('externalKnowledgeBase', { ns: 'dataset' })}} - {!isExternalProvider && isShowDocModeInfo && ( - <> - {dataset.doc_form && ( - - {t(`chunkingMode.${DOC_FORM_TEXT[dataset.doc_form]}`, { ns: 'dataset' })} - - )} - {dataset.indexing_technique && ( - - {formatIndexingTechniqueAndMethod(dataset.indexing_technique, dataset.retrieval_model_dict?.search_method) as any} - - )} - {dataset.is_multimodal && ( - - {t('multimodal', { ns: 'dataset' })} - - )} - - )} -
-
-
-
- {dataset.description} -
-
{ - e.stopPropagation() - e.preventDefault() - }} - > -
0 && 'visible', - )} - > - tag.id)} - selectedTags={tags} - onCacheUpdate={setTags} - onChange={onSuccess} - /> -
- {/* Tag Mask */} -
-
-
- -
- - {documentCount} -
-
- {!isExternalProvider && ( - -
- - {dataset.app_count} -
-
- )} - / - {`${t('updated', { ns: 'dataset' })} ${formatTimeFromNow(dataset.updated_at * 1000)}`} -
-
- - )} - className="z-20 min-w-[186px]" - popupClassName="rounded-xl bg-none shadow-none ring-0 min-w-[186px]" - position="br" - trigger="click" - btnElement={( -
- -
- )} - btnClassName={open => - cn( - 'size-9 cursor-pointer justify-center rounded-[10px] border-[0.5px] border-components-actionbar-border bg-components-actionbar-bg p-0 shadow-lg shadow-shadow-shadow-5 ring-[2px] ring-inset ring-components-actionbar-bg hover:border-components-actionbar-border', - open ? 'border-components-actionbar-border bg-state-base-hover' : '', - )} - /> -
-
- {showRenameModal && ( - + + + setShowRenameModal(false)} + tags={tags} + setTags={setTags} onSuccess={onSuccess} + isHoveringTagSelector={isHoveringTagSelector} + onClick={handleTagAreaClick} /> - )} - {showConfirmDelete && ( - setShowConfirmDelete(false)} + + - )} +
+ ) } diff --git a/web/app/components/goto-anything/actions/commands/account.tsx b/web/app/components/goto-anything/actions/commands/account.tsx index 6465932a75..d1fa36b6f0 100644 --- a/web/app/components/goto-anything/actions/commands/account.tsx +++ b/web/app/components/goto-anything/actions/commands/account.tsx @@ -1,7 +1,7 @@ import type { SlashCommandHandler } from './types' import { RiUser3Line } from '@remixicon/react' import * as React from 'react' -import i18n from '@/i18n-config/i18next-config' +import { getI18n } from 'react-i18next' import { registerCommands, unregisterCommands } from './command-bus' // Account command dependency types - no external dependencies needed @@ -21,6 +21,7 @@ export const accountCommand: SlashCommandHandler = { }, async search(args: string, locale: string = 'en') { + const i18n = getI18n() return [{ id: 'account', title: i18n.t('account.account', { ns: 'common', lng: locale }), diff --git a/web/app/components/goto-anything/actions/commands/community.tsx b/web/app/components/goto-anything/actions/commands/community.tsx index fcd9a15000..685149402d 100644 --- a/web/app/components/goto-anything/actions/commands/community.tsx +++ b/web/app/components/goto-anything/actions/commands/community.tsx @@ -1,7 +1,7 @@ import type { SlashCommandHandler } from './types' import { RiDiscordLine } from '@remixicon/react' import * as React from 'react' -import i18n from '@/i18n-config/i18next-config' +import { getI18n } from 'react-i18next' import { registerCommands, unregisterCommands } from './command-bus' // Community command dependency types @@ -22,6 +22,7 @@ export const communityCommand: SlashCommandHandler = { }, async search(args: string, locale: string = 'en') { + const i18n = getI18n() return [{ id: 'community', title: i18n.t('userProfile.community', { ns: 'common', lng: locale }), diff --git a/web/app/components/goto-anything/actions/commands/docs.tsx b/web/app/components/goto-anything/actions/commands/docs.tsx index 9f09d32094..8b04e84157 100644 --- a/web/app/components/goto-anything/actions/commands/docs.tsx +++ b/web/app/components/goto-anything/actions/commands/docs.tsx @@ -1,8 +1,8 @@ import type { SlashCommandHandler } from './types' import { RiBookOpenLine } from '@remixicon/react' import * as React from 'react' +import { getI18n } from 'react-i18next' import { defaultDocBaseUrl } from '@/context/i18n' -import i18n from '@/i18n-config/i18next-config' import { getDocLanguage } from '@/i18n-config/language' import { registerCommands, unregisterCommands } from './command-bus' @@ -19,6 +19,7 @@ export const docsCommand: SlashCommandHandler = { // Direct execution function execute: () => { + const i18n = getI18n() const currentLocale = i18n.language const docLanguage = getDocLanguage(currentLocale) const url = `${defaultDocBaseUrl}/${docLanguage}` @@ -26,6 +27,7 @@ export const docsCommand: SlashCommandHandler = { }, async search(args: string, locale: string = 'en') { + const i18n = getI18n() return [{ id: 'doc', title: i18n.t('userProfile.helpCenter', { ns: 'common', lng: locale }), @@ -41,6 +43,7 @@ export const docsCommand: SlashCommandHandler = { }, register(_deps: DocDeps) { + const i18n = getI18n() registerCommands({ 'navigation.doc': async (_args) => { // Get the current language from i18n diff --git a/web/app/components/goto-anything/actions/commands/forum.tsx b/web/app/components/goto-anything/actions/commands/forum.tsx index e32632b4b5..36116ceb1f 100644 --- a/web/app/components/goto-anything/actions/commands/forum.tsx +++ b/web/app/components/goto-anything/actions/commands/forum.tsx @@ -1,7 +1,7 @@ import type { SlashCommandHandler } from './types' import { RiFeedbackLine } from '@remixicon/react' import * as React from 'react' -import i18n from '@/i18n-config/i18next-config' +import { getI18n } from 'react-i18next' import { registerCommands, unregisterCommands } from './command-bus' // Forum command dependency types @@ -22,6 +22,7 @@ export const forumCommand: SlashCommandHandler = { }, async search(args: string, locale: string = 'en') { + const i18n = getI18n() return [{ id: 'forum', title: i18n.t('userProfile.forum', { ns: 'common', lng: locale }), diff --git a/web/app/components/goto-anything/actions/commands/language.tsx b/web/app/components/goto-anything/actions/commands/language.tsx index df94fd49ce..f4bafc1d58 100644 --- a/web/app/components/goto-anything/actions/commands/language.tsx +++ b/web/app/components/goto-anything/actions/commands/language.tsx @@ -1,6 +1,6 @@ import type { CommandSearchResult } from '../types' import type { SlashCommandHandler } from './types' -import i18n from '@/i18n-config/i18next-config' +import { getI18n } from 'react-i18next' import { languages } from '@/i18n-config/language' import { registerCommands, unregisterCommands } from './command-bus' @@ -14,6 +14,7 @@ const buildLanguageCommands = (query: string): CommandSearchResult[] => { const list = languages.filter(item => item.supported && ( !q || item.name.toLowerCase().includes(q) || String(item.value).toLowerCase().includes(q) )) + const i18n = getI18n() return list.map(item => ({ id: `lang-${item.value}`, title: item.name, diff --git a/web/app/components/goto-anything/actions/commands/slash.tsx b/web/app/components/goto-anything/actions/commands/slash.tsx index ec0f333cd4..6aad67731f 100644 --- a/web/app/components/goto-anything/actions/commands/slash.tsx +++ b/web/app/components/goto-anything/actions/commands/slash.tsx @@ -2,8 +2,8 @@ import type { ActionItem } from '../types' import { useTheme } from 'next-themes' import { useEffect } from 'react' +import { getI18n } from 'react-i18next' import { setLocaleOnClient } from '@/i18n-config' -import i18n from '@/i18n-config/i18next-config' import { accountCommand } from './account' import { executeCommand } from './command-bus' import { communityCommand } from './community' @@ -14,6 +14,8 @@ import { slashCommandRegistry } from './registry' import { themeCommand } from './theme' import { zenCommand } from './zen' +const i18n = getI18n() + export const slashAction: ActionItem = { key: '/', shortcut: '/', diff --git a/web/app/components/goto-anything/actions/commands/theme.tsx b/web/app/components/goto-anything/actions/commands/theme.tsx index 335182af67..ba1416229d 100644 --- a/web/app/components/goto-anything/actions/commands/theme.tsx +++ b/web/app/components/goto-anything/actions/commands/theme.tsx @@ -2,7 +2,7 @@ import type { CommandSearchResult } from '../types' import type { SlashCommandHandler } from './types' import { RiComputerLine, RiMoonLine, RiSunLine } from '@remixicon/react' import * as React from 'react' -import i18n from '@/i18n-config/i18next-config' +import { getI18n } from 'react-i18next' import { registerCommands, unregisterCommands } from './command-bus' // Theme dependency types @@ -32,6 +32,7 @@ const THEME_ITEMS = [ ] as const const buildThemeCommands = (query: string, locale?: string): CommandSearchResult[] => { + const i18n = getI18n() const q = query.toLowerCase() const list = THEME_ITEMS.filter(item => !q diff --git a/web/app/components/goto-anything/actions/commands/zen.tsx b/web/app/components/goto-anything/actions/commands/zen.tsx index d6d9f1e5a2..1645e40fd9 100644 --- a/web/app/components/goto-anything/actions/commands/zen.tsx +++ b/web/app/components/goto-anything/actions/commands/zen.tsx @@ -1,8 +1,8 @@ import type { SlashCommandHandler } from './types' import { RiFullscreenLine } from '@remixicon/react' import * as React from 'react' +import { getI18n } from 'react-i18next' import { isInWorkflowPage } from '@/app/components/workflow/constants' -import i18n from '@/i18n-config/i18next-config' import { registerCommands, unregisterCommands } from './command-bus' // Zen command dependency types - no external dependencies needed @@ -32,6 +32,7 @@ export const zenCommand: SlashCommandHandler = { execute: toggleZenMode, async search(_args: string, locale: string = 'en') { + const i18n = getI18n() return [{ id: 'zen', title: i18n.t('gotoAnything.actions.zenTitle', { ns: 'app', lng: locale }) || 'Zen Mode', diff --git a/web/app/components/header/account-setting/data-source-page-new/install-from-marketplace.tsx b/web/app/components/header/account-setting/data-source-page-new/install-from-marketplace.tsx index 956352d6d3..f02e276f55 100644 --- a/web/app/components/header/account-setting/data-source-page-new/install-from-marketplace.tsx +++ b/web/app/components/header/account-setting/data-source-page-new/install-from-marketplace.tsx @@ -15,7 +15,6 @@ import Divider from '@/app/components/base/divider' import Loading from '@/app/components/base/loading' import List from '@/app/components/plugins/marketplace/list' import ProviderCard from '@/app/components/plugins/provider-card' -import { getLocaleOnClient } from '@/i18n-config' import { cn } from '@/utils/classnames' import { getMarketplaceUrl } from '@/utils/var' import { @@ -33,7 +32,6 @@ const InstallFromMarketplace = ({ const { t } = useTranslation() const { theme } = useTheme() const [collapse, setCollapse] = useState(false) - const locale = getLocaleOnClient() const { plugins: allPlugins, isLoading: isAllPluginsLoading, @@ -70,7 +68,6 @@ const InstallFromMarketplace = ({ marketplaceCollectionPluginsMap={{}} plugins={allPlugins} showInstallButton - locale={locale} cardContainerClassName="grid grid-cols-2 gap-2" cardRender={cardRender} emptyClassName="h-auto" diff --git a/web/app/components/header/account-setting/language-page/index.tsx b/web/app/components/header/account-setting/language-page/index.tsx index 5d888281e9..2a0604421f 100644 --- a/web/app/components/header/account-setting/language-page/index.tsx +++ b/web/app/components/header/account-setting/language-page/index.tsx @@ -2,13 +2,13 @@ import type { Item } from '@/app/components/base/select' import type { Locale } from '@/i18n-config' +import { useRouter } from 'next/navigation' import { useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import { SimpleSelect } from '@/app/components/base/select' import { ToastContext } from '@/app/components/base/toast' import { useAppContext } from '@/context/app-context' - import { useLocale } from '@/context/i18n' import { setLocaleOnClient } from '@/i18n-config' import { languages } from '@/i18n-config/language' @@ -25,6 +25,7 @@ export default function LanguagePage() { const { notify } = useContext(ToastContext) const [editing, setEditing] = useState(false) const { t } = useTranslation() + const router = useRouter() const handleSelectLanguage = async (item: Item) => { const url = '/account/interface-language' @@ -35,7 +36,8 @@ export default function LanguagePage() { await updateUserProfile({ url, body: { [bodyKey]: item.value } }) notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) - setLocaleOnClient(item.value.toString() as Locale) + setLocaleOnClient(item.value.toString() as Locale, false) + router.refresh() } catch (e) { notify({ type: 'error', message: (e as Error).message }) diff --git a/web/app/components/header/account-setting/model-provider-page/index.tsx b/web/app/components/header/account-setting/model-provider-page/index.tsx index 57b464e0e7..d3daaee859 100644 --- a/web/app/components/header/account-setting/model-provider-page/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/index.tsx @@ -6,10 +6,8 @@ import { RiBrainLine, } from '@remixicon/react' import { useDebounce } from 'ahooks' -import { useEffect, useMemo } from 'react' +import { useMemo } from 'react' import { useTranslation } from 'react-i18next' -import { IS_CLOUD_EDITION } from '@/config' -import { useAppContext } from '@/context/app-context' import { useGlobalPublicStore } from '@/context/global-public-context' import { useProviderContext } from '@/context/provider-context' import { cn } from '@/utils/classnames' @@ -22,7 +20,6 @@ import { } from './hooks' import InstallFromMarketplace from './install-from-marketplace' import ProviderAddedCard from './provider-added-card' -import QuotaPanel from './provider-added-card/quota-panel' import SystemModelSelector from './system-model-selector' type Props = { @@ -34,16 +31,19 @@ const FixedModelProvider = ['langgenius/openai/openai', 'langgenius/anthropic/an const ModelProviderPage = ({ searchText }: Props) => { const debouncedSearchText = useDebounce(searchText, { wait: 500 }) const { t } = useTranslation() - const { mutateCurrentWorkspace, isValidatingCurrentWorkspace } = useAppContext() - const { data: textGenerationDefaultModel } = useDefaultModel(ModelTypeEnum.textGeneration) - const { data: embeddingsDefaultModel } = useDefaultModel(ModelTypeEnum.textEmbedding) - const { data: rerankDefaultModel } = useDefaultModel(ModelTypeEnum.rerank) - const { data: speech2textDefaultModel } = useDefaultModel(ModelTypeEnum.speech2text) - const { data: ttsDefaultModel } = useDefaultModel(ModelTypeEnum.tts) + const { data: textGenerationDefaultModel, isLoading: isTextGenerationDefaultModelLoading } = useDefaultModel(ModelTypeEnum.textGeneration) + const { data: embeddingsDefaultModel, isLoading: isEmbeddingsDefaultModelLoading } = useDefaultModel(ModelTypeEnum.textEmbedding) + const { data: rerankDefaultModel, isLoading: isRerankDefaultModelLoading } = useDefaultModel(ModelTypeEnum.rerank) + const { data: speech2textDefaultModel, isLoading: isSpeech2textDefaultModelLoading } = useDefaultModel(ModelTypeEnum.speech2text) + const { data: ttsDefaultModel, isLoading: isTTSDefaultModelLoading } = useDefaultModel(ModelTypeEnum.tts) const { modelProviders: providers } = useProviderContext() const { enable_marketplace } = useGlobalPublicStore(s => s.systemFeatures) - const defaultModelNotConfigured = !textGenerationDefaultModel && !embeddingsDefaultModel && !speech2textDefaultModel && !rerankDefaultModel && !ttsDefaultModel - + const isDefaultModelLoading = isTextGenerationDefaultModelLoading + || isEmbeddingsDefaultModelLoading + || isRerankDefaultModelLoading + || isSpeech2textDefaultModelLoading + || isTTSDefaultModelLoading + const defaultModelNotConfigured = !isDefaultModelLoading && !textGenerationDefaultModel && !embeddingsDefaultModel && !speech2textDefaultModel && !rerankDefaultModel && !ttsDefaultModel const [configuredProviders, notConfiguredProviders] = useMemo(() => { const configuredProviders: ModelProvider[] = [] const notConfiguredProviders: ModelProvider[] = [] @@ -88,10 +88,6 @@ const ModelProviderPage = ({ searchText }: Props) => { return [filteredConfiguredProviders, filteredNotConfiguredProviders] }, [configuredProviders, debouncedSearchText, notConfiguredProviders]) - useEffect(() => { - mutateCurrentWorkspace() - }, [mutateCurrentWorkspace]) - return (
@@ -115,10 +111,10 @@ const ModelProviderPage = ({ searchText }: Props) => { rerankDefaultModel={rerankDefaultModel} speech2textDefaultModel={speech2textDefaultModel} ttsDefaultModel={ttsDefaultModel} + isLoading={isDefaultModelLoading} />
- {IS_CLOUD_EDITION && } {!filteredConfiguredProviders?.length && (
diff --git a/web/app/components/header/account-setting/model-provider-page/install-from-marketplace.tsx b/web/app/components/header/account-setting/model-provider-page/install-from-marketplace.tsx index 0e7506bf96..289146f2d2 100644 --- a/web/app/components/header/account-setting/model-provider-page/install-from-marketplace.tsx +++ b/web/app/components/header/account-setting/model-provider-page/install-from-marketplace.tsx @@ -14,7 +14,6 @@ import Divider from '@/app/components/base/divider' import Loading from '@/app/components/base/loading' import List from '@/app/components/plugins/marketplace/list' import ProviderCard from '@/app/components/plugins/provider-card' -import { getLocaleOnClient } from '@/i18n-config' import { cn } from '@/utils/classnames' import { getMarketplaceUrl } from '@/utils/var' import { @@ -32,7 +31,6 @@ const InstallFromMarketplace = ({ const { t } = useTranslation() const { theme } = useTheme() const [collapse, setCollapse] = useState(false) - const locale = getLocaleOnClient() const { plugins: allPlugins, isLoading: isAllPluginsLoading, @@ -69,7 +67,6 @@ const InstallFromMarketplace = ({ marketplaceCollectionPluginsMap={{}} plugins={allPlugins} showInstallButton - locale={locale} cardContainerClassName="grid grid-cols-2 gap-2" cardRender={cardRender} emptyClassName="h-auto" diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.tsx index cbaef21a70..59d7b2c0c8 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.tsx @@ -7,7 +7,6 @@ import { useToastContext } from '@/app/components/base/toast' import { ConfigProvider } from '@/app/components/header/account-setting/model-provider-page/model-auth' import { useCredentialStatus } from '@/app/components/header/account-setting/model-provider-page/model-auth/hooks' import Indicator from '@/app/components/header/indicator' -import { IS_CLOUD_EDITION } from '@/config' import { useEventEmitterContextContext } from '@/context/event-emitter' import { changeModelProviderPriority } from '@/service/common' import { cn } from '@/utils/classnames' @@ -115,7 +114,7 @@ const CredentialPanel = ({ provider={provider} /> { - systemConfig.enabled && isCustomConfigured && IS_CLOUD_EDITION && ( + systemConfig.enabled && isCustomConfigured && ( = ({ const systemConfig = provider.system_configuration const hasModelList = fetched && !!modelList.length const { isCurrentWorkspaceManager } = useAppContext() - const showModelProvider = systemConfig.enabled && [...MODEL_PROVIDER_QUOTA_GET_PAID].includes(provider.provider as ModelProviderQuotaGetPaid) && !IS_CE_EDITION + const showQuota = systemConfig.enabled && [...MODEL_PROVIDER_QUOTA_GET_PAID].includes(provider.provider) && !IS_CE_EDITION const showCredential = configurationMethods.includes(ConfigurationMethodEnum.predefinedModel) && isCurrentWorkspaceManager const getModelList = async (providerName: string) => { @@ -104,6 +104,13 @@ const ProviderAddedCard: FC = ({ }
+ { + showQuota && ( + + ) + } { showCredential && ( = ({ { collapsed && (
- {(showModelProvider || !notConfigured) && ( + {(showQuota || !notConfigured) && ( <>
{ @@ -143,7 +150,7 @@ const ProviderAddedCard: FC = ({
)} - {!showModelProvider && notConfigured && ( + {!showQuota && notConfigured && (
{t('modelProvider.configureTip', { ns: 'common' })} diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/quota-panel.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/quota-panel.tsx index e296bc4555..cd49148403 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/quota-panel.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/quota-panel.tsx @@ -1,163 +1,66 @@ import type { FC } from 'react' import type { ModelProvider } from '../declarations' -import type { Plugin } from '@/app/components/plugins/types' -import { useBoolean } from 'ahooks' -import * as React from 'react' -import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -import { AnthropicShortLight, Deepseek, Gemini, Grok, OpenaiSmall, Tongyi } from '@/app/components/base/icons/src/public/llm' -import Loading from '@/app/components/base/loading' import Tooltip from '@/app/components/base/tooltip' -import InstallFromMarketplace from '@/app/components/plugins/install-plugin/install-from-marketplace' -import { useAppContext } from '@/context/app-context' -import useTimestamp from '@/hooks/use-timestamp' -import { cn } from '@/utils/classnames' import { formatNumber } from '@/utils/format' -import { PreferredProviderTypeEnum } from '../declarations' -import { useMarketplaceAllPlugins } from '../hooks' -import { modelNameMap, ModelProviderQuotaGetPaid } from '../utils' - -const allProviders = [ - { key: ModelProviderQuotaGetPaid.OPENAI, Icon: OpenaiSmall }, - { key: ModelProviderQuotaGetPaid.ANTHROPIC, Icon: AnthropicShortLight }, - { key: ModelProviderQuotaGetPaid.GEMINI, Icon: Gemini }, - { key: ModelProviderQuotaGetPaid.X, Icon: Grok }, - { key: ModelProviderQuotaGetPaid.DEEPSEEK, Icon: Deepseek }, - { key: ModelProviderQuotaGetPaid.TONGYI, Icon: Tongyi }, -] as const - -// Map provider key to plugin ID -// provider key format: langgenius/provider/model, plugin ID format: langgenius/provider -const providerKeyToPluginId: Record = { - [ModelProviderQuotaGetPaid.OPENAI]: 'langgenius/openai', - [ModelProviderQuotaGetPaid.ANTHROPIC]: 'langgenius/anthropic', - [ModelProviderQuotaGetPaid.GEMINI]: 'langgenius/gemini', - [ModelProviderQuotaGetPaid.X]: 'langgenius/x', - [ModelProviderQuotaGetPaid.DEEPSEEK]: 'langgenius/deepseek', - [ModelProviderQuotaGetPaid.TONGYI]: 'langgenius/tongyi', -} +import { + CustomConfigurationStatusEnum, + PreferredProviderTypeEnum, + QuotaUnitEnum, +} from '../declarations' +import { + MODEL_PROVIDER_QUOTA_GET_PAID, +} from '../utils' +import PriorityUseTip from './priority-use-tip' type QuotaPanelProps = { - providers: ModelProvider[] - isLoading?: boolean + provider: ModelProvider } const QuotaPanel: FC = ({ - providers, - isLoading = false, + provider, }) => { const { t } = useTranslation() - const { currentWorkspace } = useAppContext() - const credits = Math.max((currentWorkspace.trial_credits - currentWorkspace.trial_credits_used) || 0, 0) - const providerMap = useMemo(() => new Map( - providers.map(p => [p.provider, p.preferred_provider_type]), - ), [providers]) - const { formatTime } = useTimestamp() - const { - plugins: allPlugins, - } = useMarketplaceAllPlugins(providers, '') - const [selectedPlugin, setSelectedPlugin] = useState(null) - const [isShowInstallModal, { - setTrue: showInstallFromMarketplace, - setFalse: hideInstallFromMarketplace, - }] = useBoolean(false) - const selectedPluginIdRef = useRef(null) - const handleIconClick = useCallback((key: string) => { - const providerType = providerMap.get(key) - if (!providerType && allPlugins) { - const pluginId = providerKeyToPluginId[key] - const plugin = allPlugins.find(p => p.plugin_id === pluginId) - if (plugin) { - setSelectedPlugin(plugin) - selectedPluginIdRef.current = pluginId - showInstallFromMarketplace() - } - } - }, [allPlugins, providerMap, showInstallFromMarketplace]) - - useEffect(() => { - if (isShowInstallModal && selectedPluginIdRef.current) { - const isInstalled = providers.some(p => p.provider.startsWith(selectedPluginIdRef.current!)) - if (isInstalled) { - hideInstallFromMarketplace() - selectedPluginIdRef.current = null - } - } - }, [providers, isShowInstallModal, hideInstallFromMarketplace]) - - if (isLoading) { - return ( -
- -
- ) - } + const customConfig = provider.custom_configuration + const priorityUseType = provider.preferred_provider_type + const systemConfig = provider.system_configuration + const currentQuota = systemConfig.enabled && systemConfig.quota_configurations.find(item => item.quota_type === systemConfig.current_quota_type) + const openaiOrAnthropic = MODEL_PROVIDER_QUOTA_GET_PAID.includes(provider.provider) return ( -
+
{t('modelProvider.quota', { ns: 'common' })} - -
-
-
- {formatNumber(credits)} - {t('modelProvider.credits', { ns: 'common' })} - {currentWorkspace.next_credit_reset_date - ? ( - <> - · - - {t('modelProvider.resetDate', { - ns: 'common', - date: formatTime(currentWorkspace.next_credit_reset_date, t('dateFormat', { ns: 'appLog' })), - interpolation: { escapeValue: false }, - })} - - - ) - : null} -
-
- {allProviders.map(({ key, Icon }) => { - const providerType = providerMap.get(key) - const usingQuota = providerType === PreferredProviderTypeEnum.system - const getTooltipKey = () => { - if (usingQuota) - return 'modelProvider.card.modelSupported' - if (providerType === PreferredProviderTypeEnum.custom) - return 'modelProvider.card.modelAPI' - return 'modelProvider.card.modelNotSupported' - } - return ( - -
handleIconClick(key)} - > - - {!usingQuota && ( -
- )} -
- - ) - })} -
-
- {isShowInstallModal && selectedPlugin && ( - - )} +
+ { + currentQuota && ( +
+ {formatNumber(Math.max((currentQuota?.quota_limit || 0) - (currentQuota?.quota_used || 0), 0))} + { + currentQuota?.quota_unit === QuotaUnitEnum.tokens && 'Tokens' + } + { + currentQuota?.quota_unit === QuotaUnitEnum.times && t('modelProvider.callTimes', { ns: 'common' }) + } + { + currentQuota?.quota_unit === QuotaUnitEnum.credits && t('modelProvider.credits', { ns: 'common' }) + } +
+ ) + } + { + priorityUseType === PreferredProviderTypeEnum.system && customConfig.status === CustomConfigurationStatusEnum.active && ( + + ) + }
) } -export default React.memo(QuotaPanel) +export default QuotaPanel diff --git a/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx b/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx index 74222ed56d..29c71e04fc 100644 --- a/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx @@ -3,7 +3,7 @@ import type { DefaultModel, DefaultModelResponse, } from '../declarations' -import { RiEqualizer2Line } from '@remixicon/react' +import { RiEqualizer2Line, RiLoader2Line } from '@remixicon/react' import { useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' @@ -32,6 +32,7 @@ type SystemModelSelectorProps = { speech2textDefaultModel: DefaultModelResponse | undefined ttsDefaultModel: DefaultModelResponse | undefined notConfigured: boolean + isLoading?: boolean } const SystemModel: FC = ({ textGenerationDefaultModel, @@ -40,6 +41,7 @@ const SystemModel: FC = ({ speech2textDefaultModel, ttsDefaultModel, notConfigured, + isLoading, }) => { const { t } = useTranslation() const { notify } = useToastContext() @@ -129,13 +131,16 @@ const SystemModel: FC = ({ crossAxis: 8, }} > - setOpen(v => !v)}> + setOpen(v => !v)}> diff --git a/web/app/components/header/account-setting/model-provider-page/utils.ts b/web/app/components/header/account-setting/model-provider-page/utils.ts index d958f3eef3..b60d6a0c7b 100644 --- a/web/app/components/header/account-setting/model-provider-page/utils.ts +++ b/web/app/components/header/account-setting/model-provider-page/utils.ts @@ -17,25 +17,7 @@ import { ModelTypeEnum, } from './declarations' -export enum ModelProviderQuotaGetPaid { - ANTHROPIC = 'langgenius/anthropic/anthropic', - OPENAI = 'langgenius/openai/openai', - // AZURE_OPENAI = 'langgenius/azure_openai/azure_openai', - GEMINI = 'langgenius/gemini/google', - X = 'langgenius/x/x', - DEEPSEEK = 'langgenius/deepseek/deepseek', - TONGYI = 'langgenius/tongyi/tongyi', -} -export const MODEL_PROVIDER_QUOTA_GET_PAID = [ModelProviderQuotaGetPaid.ANTHROPIC, ModelProviderQuotaGetPaid.OPENAI, ModelProviderQuotaGetPaid.GEMINI, ModelProviderQuotaGetPaid.X, ModelProviderQuotaGetPaid.DEEPSEEK, ModelProviderQuotaGetPaid.TONGYI] - -export const modelNameMap = { - [ModelProviderQuotaGetPaid.OPENAI]: 'OpenAI', - [ModelProviderQuotaGetPaid.ANTHROPIC]: 'Anthropic', - [ModelProviderQuotaGetPaid.GEMINI]: 'Gemini', - [ModelProviderQuotaGetPaid.X]: 'xAI', - [ModelProviderQuotaGetPaid.DEEPSEEK]: 'DeepSeek', - [ModelProviderQuotaGetPaid.TONGYI]: 'TONGYI', -} +export const MODEL_PROVIDER_QUOTA_GET_PAID = ['langgenius/anthropic/anthropic', 'langgenius/openai/openai', 'langgenius/azure_openai/azure_openai'] export const isNullOrUndefined = (value: any) => { return value === undefined || value === null diff --git a/web/app/components/i18n-server.tsx b/web/app/components/i18n-server.tsx deleted file mode 100644 index 01dc5f0f13..0000000000 --- a/web/app/components/i18n-server.tsx +++ /dev/null @@ -1,22 +0,0 @@ -import * as React from 'react' -import { getLocaleOnServer } from '@/i18n-config/server' -import { ToastProvider } from './base/toast' -import I18N from './i18n' - -export type II18NServerProps = { - children: React.ReactNode -} - -const I18NServer = async ({ - children, -}: II18NServerProps) => { - const locale = await getLocaleOnServer() - - return ( - - {children} - - ) -} - -export default I18NServer diff --git a/web/app/components/i18n.tsx b/web/app/components/i18n.tsx deleted file mode 100644 index e9af2face9..0000000000 --- a/web/app/components/i18n.tsx +++ /dev/null @@ -1,45 +0,0 @@ -'use client' - -import type { FC } from 'react' -import type { Locale } from '@/i18n-config' -import { usePrefetchQuery } from '@tanstack/react-query' -import { useHydrateAtoms } from 'jotai/utils' -import * as React from 'react' -import { useEffect, useState } from 'react' -import { localeAtom } from '@/context/i18n' -import { setLocaleOnClient } from '@/i18n-config' -import { getSystemFeatures } from '@/service/common' -import Loading from './base/loading' - -export type II18nProps = { - locale: Locale - children: React.ReactNode -} -const I18n: FC = ({ - locale, - children, -}) => { - useHydrateAtoms([[localeAtom, locale]]) - const [loading, setLoading] = useState(true) - - usePrefetchQuery({ - queryKey: ['systemFeatures'], - queryFn: getSystemFeatures, - }) - - useEffect(() => { - setLocaleOnClient(locale, false).then(() => { - setLoading(false) - }) - }, [locale]) - - if (loading) - return
- - return ( - <> - {children} - - ) -} -export default React.memo(I18n) diff --git a/web/app/components/plugins/base/deprecation-notice.tsx b/web/app/components/plugins/base/deprecation-notice.tsx index c2ddfa6975..513b27a2cf 100644 --- a/web/app/components/plugins/base/deprecation-notice.tsx +++ b/web/app/components/plugins/base/deprecation-notice.tsx @@ -1,4 +1,5 @@ import type { FC } from 'react' +import { useTranslation } from '#i18n' import { RiAlertFill } from '@remixicon/react' import { camelCase } from 'es-toolkit/string' import Link from 'next/link' @@ -6,14 +7,12 @@ import * as React from 'react' import { useMemo } from 'react' import { Trans } from 'react-i18next' import { cn } from '@/utils/classnames' -import { useMixedTranslation } from '../marketplace/hooks' type DeprecationNoticeProps = { status: 'deleted' | 'active' deprecatedReason: string alternativePluginId: string alternativePluginURL: string - locale?: string className?: string innerWrapperClassName?: string iconWrapperClassName?: string @@ -34,13 +33,12 @@ const DeprecationNotice: FC = ({ deprecatedReason, alternativePluginId, alternativePluginURL, - locale, className, innerWrapperClassName, iconWrapperClassName, textClassName, }) => { - const { t } = useMixedTranslation(locale) + const { t } = useTranslation() const deprecatedReasonKey = useMemo(() => { if (!deprecatedReason) diff --git a/web/app/components/plugins/card/index.spec.tsx b/web/app/components/plugins/card/index.spec.tsx index 4a3e5a587b..fd97534ec4 100644 --- a/web/app/components/plugins/card/index.spec.tsx +++ b/web/app/components/plugins/card/index.spec.tsx @@ -502,31 +502,6 @@ describe('Card', () => { }) }) - // ================================ - // Locale Tests - // ================================ - describe('Locale', () => { - it('should use locale from props when provided', () => { - const plugin = createMockPlugin({ - label: { 'en-US': 'English Title', 'zh-Hans': '中文标题' }, - }) - - render() - - expect(screen.getByText('中文标题')).toBeInTheDocument() - }) - - it('should fallback to default locale when prop locale not found', () => { - const plugin = createMockPlugin({ - label: { 'en-US': 'English Title' }, - }) - - render() - - expect(screen.getByText('English Title')).toBeInTheDocument() - }) - }) - // ================================ // Memoization Tests // ================================ diff --git a/web/app/components/plugins/card/index.tsx b/web/app/components/plugins/card/index.tsx index ada26801de..8578421116 100644 --- a/web/app/components/plugins/card/index.tsx +++ b/web/app/components/plugins/card/index.tsx @@ -1,15 +1,13 @@ 'use client' import type { Plugin } from '../types' -import type { Locale } from '@/i18n-config' +import { useTranslation } from '#i18n' import { RiAlertFill } from '@remixicon/react' import * as React from 'react' -import { useMixedTranslation } from '@/app/components/plugins/marketplace/hooks' import { useGetLanguage } from '@/context/i18n' import useTheme from '@/hooks/use-theme' import { renderI18nObject, } from '@/i18n-config' -import { getLanguage } from '@/i18n-config/language' import { Theme } from '@/types/app' import { cn } from '@/utils/classnames' import Partner from '../base/badges/partner' @@ -33,7 +31,6 @@ export type Props = { footer?: React.ReactNode isLoading?: boolean loadingFileName?: string - locale?: Locale limitedInstall?: boolean } @@ -48,13 +45,11 @@ const Card = ({ footer, isLoading = false, loadingFileName, - locale: localeFromProps, limitedInstall = false, }: Props) => { - const defaultLocale = useGetLanguage() - const locale = localeFromProps ? getLanguage(localeFromProps) : defaultLocale - const { t } = useMixedTranslation(localeFromProps) - const { categoriesMap } = useCategories(t, true) + const locale = useGetLanguage() + const { t } = useTranslation() + const { categoriesMap } = useCategories(true) const { category, type, name, org, label, brief, icon, icon_dark, verified, badges = [] } = payload const { theme } = useTheme() const iconSrc = theme === Theme.dark && icon_dark ? icon_dark : icon diff --git a/web/app/components/plugins/hooks.ts b/web/app/components/plugins/hooks.ts index 262935205b..65d073cc2f 100644 --- a/web/app/components/plugins/hooks.ts +++ b/web/app/components/plugins/hooks.ts @@ -1,4 +1,3 @@ -import type { TFunction } from 'i18next' import type { CategoryKey, TagKey } from './constants' import { useMemo } from 'react' import { useTranslation } from 'react-i18next' @@ -13,9 +12,8 @@ export type Tag = { label: string } -export const useTags = (translateFromOut?: TFunction) => { - const { t: translation } = useTranslation() - const t = translateFromOut || translation +export const useTags = () => { + const { t } = useTranslation() const tags = useMemo(() => { return tagKeys.map((tag) => { @@ -53,9 +51,8 @@ type Category = { label: string } -export const useCategories = (translateFromOut?: TFunction, isSingle?: boolean) => { - const { t: translation } = useTranslation() - const t = translateFromOut || translation +export const useCategories = (isSingle?: boolean) => { + const { t } = useTranslation() const categories = useMemo(() => { return categoryKeys.map((category) => { diff --git a/web/app/components/plugins/marketplace/description/index.spec.tsx b/web/app/components/plugins/marketplace/description/index.spec.tsx index b5c8cb716b..054949ee1f 100644 --- a/web/app/components/plugins/marketplace/description/index.spec.tsx +++ b/web/app/components/plugins/marketplace/description/index.spec.tsx @@ -1,7 +1,5 @@ import { render, screen } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' - -// Import component after mocks are set up import Description from './index' // ================================ @@ -30,20 +28,18 @@ const commonTranslations: Record = { 'operation.in': 'in', } -// Mock getLocaleOnServer and translate -vi.mock('@/i18n-config/server', () => ({ - getLocaleOnServer: vi.fn(() => Promise.resolve(mockDefaultLocale)), - getTranslation: vi.fn((locale: string, ns: string) => { - return Promise.resolve({ - t: (key: string) => { - if (ns === 'plugin') - return pluginTranslations[key] || key - if (ns === 'common') - return commonTranslations[key] || key - return key - }, - }) - }), +// Mock i18n hooks +vi.mock('#i18n', () => ({ + useLocale: vi.fn(() => mockDefaultLocale), + useTranslation: vi.fn((ns: string) => ({ + t: (key: string) => { + if (ns === 'plugin') + return pluginTranslations[key] || key + if (ns === 'common') + return commonTranslations[key] || key + return key + }, + })), })) // ================================ @@ -59,29 +55,29 @@ describe('Description', () => { // Rendering Tests // ================================ describe('Rendering', () => { - it('should render without crashing', async () => { - const { container } = render(await Description({})) + it('should render without crashing', () => { + const { container } = render() expect(container.firstChild).toBeInTheDocument() }) - it('should render h1 heading with empower text', async () => { - render(await Description({})) + it('should render h1 heading with empower text', () => { + render() const heading = screen.getByRole('heading', { level: 1 }) expect(heading).toBeInTheDocument() expect(heading).toHaveTextContent('Empower your AI development') }) - it('should render h2 subheading', async () => { - render(await Description({})) + it('should render h2 subheading', () => { + render() const subheading = screen.getByRole('heading', { level: 2 }) expect(subheading).toBeInTheDocument() }) - it('should apply correct CSS classes to h1', async () => { - render(await Description({})) + it('should apply correct CSS classes to h1', () => { + render() const heading = screen.getByRole('heading', { level: 1 }) expect(heading).toHaveClass('title-4xl-semi-bold') @@ -90,8 +86,8 @@ describe('Description', () => { expect(heading).toHaveClass('text-text-primary') }) - it('should apply correct CSS classes to h2', async () => { - render(await Description({})) + it('should apply correct CSS classes to h2', () => { + render() const subheading = screen.getByRole('heading', { level: 2 }) expect(subheading).toHaveClass('body-md-regular') @@ -104,14 +100,18 @@ describe('Description', () => { // Non-Chinese Locale Rendering Tests // ================================ describe('Non-Chinese Locale Rendering', () => { - it('should render discover text for en-US locale', async () => { - render(await Description({ locale: 'en-US' })) + beforeEach(() => { + mockDefaultLocale = 'en-US' + }) + + it('should render discover text for en-US locale', () => { + render() expect(screen.getByText(/Discover/)).toBeInTheDocument() }) - it('should render all category names', async () => { - render(await Description({ locale: 'en-US' })) + it('should render all category names', () => { + render() expect(screen.getByText('Models')).toBeInTheDocument() expect(screen.getByText('Tools')).toBeInTheDocument() @@ -122,36 +122,36 @@ describe('Description', () => { expect(screen.getByText('Bundles')).toBeInTheDocument() }) - it('should render "and" conjunction text', async () => { - render(await Description({ locale: 'en-US' })) + it('should render "and" conjunction text', () => { + render() const subheading = screen.getByRole('heading', { level: 2 }) expect(subheading.textContent).toContain('and') }) - it('should render "in" preposition at the end for non-Chinese locales', async () => { - render(await Description({ locale: 'en-US' })) + it('should render "in" preposition at the end for non-Chinese locales', () => { + render() expect(screen.getByText('in')).toBeInTheDocument() }) - it('should render Dify Marketplace text at the end for non-Chinese locales', async () => { - render(await Description({ locale: 'en-US' })) + it('should render Dify Marketplace text at the end for non-Chinese locales', () => { + render() const subheading = screen.getByRole('heading', { level: 2 }) expect(subheading.textContent).toContain('Dify Marketplace') }) - it('should render category spans with styled underline effect', async () => { - const { container } = render(await Description({ locale: 'en-US' })) + it('should render category spans with styled underline effect', () => { + const { container } = render() const styledSpans = container.querySelectorAll('.body-md-medium.relative.z-\\[1\\]') // 7 category spans (models, tools, datasources, triggers, agents, extensions, bundles) expect(styledSpans.length).toBe(7) }) - it('should apply text-text-secondary class to category spans', async () => { - const { container } = render(await Description({ locale: 'en-US' })) + it('should apply text-text-secondary class to category spans', () => { + const { container } = render() const styledSpans = container.querySelectorAll('.text-text-secondary') expect(styledSpans.length).toBeGreaterThanOrEqual(7) @@ -162,29 +162,33 @@ describe('Description', () => { // Chinese (zh-Hans) Locale Rendering Tests // ================================ describe('Chinese (zh-Hans) Locale Rendering', () => { - it('should render "in" text at the beginning for zh-Hans locale', async () => { - render(await Description({ locale: 'zh-Hans' })) + beforeEach(() => { + mockDefaultLocale = 'zh-Hans' + }) + + it('should render "in" text at the beginning for zh-Hans locale', () => { + render() // In zh-Hans mode, "in" appears at the beginning const inElements = screen.getAllByText('in') expect(inElements.length).toBeGreaterThanOrEqual(1) }) - it('should render Dify Marketplace text for zh-Hans locale', async () => { - render(await Description({ locale: 'zh-Hans' })) + it('should render Dify Marketplace text for zh-Hans locale', () => { + render() const subheading = screen.getByRole('heading', { level: 2 }) expect(subheading.textContent).toContain('Dify Marketplace') }) - it('should render discover text for zh-Hans locale', async () => { - render(await Description({ locale: 'zh-Hans' })) + it('should render discover text for zh-Hans locale', () => { + render() expect(screen.getByText(/Discover/)).toBeInTheDocument() }) - it('should render all categories for zh-Hans locale', async () => { - render(await Description({ locale: 'zh-Hans' })) + it('should render all categories for zh-Hans locale', () => { + render() expect(screen.getByText('Models')).toBeInTheDocument() expect(screen.getByText('Tools')).toBeInTheDocument() @@ -195,8 +199,8 @@ describe('Description', () => { expect(screen.getByText('Bundles')).toBeInTheDocument() }) - it('should render both zh-Hans specific elements and shared elements', async () => { - render(await Description({ locale: 'zh-Hans' })) + it('should render both zh-Hans specific elements and shared elements', () => { + render() // zh-Hans has specific element order: "in" -> Dify Marketplace -> Discover // then the same category list with "and" -> Bundles @@ -206,61 +210,57 @@ describe('Description', () => { }) // ================================ - // Locale Prop Variations Tests + // Locale Variations Tests // ================================ - describe('Locale Prop Variations', () => { - it('should use default locale when locale prop is undefined', async () => { + describe('Locale Variations', () => { + it('should use en-US locale by default', () => { mockDefaultLocale = 'en-US' - render(await Description({})) + render() - // Should use the default locale from getLocaleOnServer expect(screen.getByText('Empower your AI development')).toBeInTheDocument() }) - it('should use provided locale prop instead of default', async () => { + it('should handle ja-JP locale as non-Chinese', () => { mockDefaultLocale = 'ja-JP' - render(await Description({ locale: 'en-US' })) - - // The locale prop should be used, triggering non-Chinese rendering - const subheading = screen.getByRole('heading', { level: 2 }) - expect(subheading).toBeInTheDocument() - }) - - it('should handle ja-JP locale as non-Chinese', async () => { - render(await Description({ locale: 'ja-JP' })) + render() // Should render in non-Chinese format (discover first, then "in Dify Marketplace" at end) const subheading = screen.getByRole('heading', { level: 2 }) expect(subheading.textContent).toContain('Dify Marketplace') }) - it('should handle ko-KR locale as non-Chinese', async () => { - render(await Description({ locale: 'ko-KR' })) + it('should handle ko-KR locale as non-Chinese', () => { + mockDefaultLocale = 'ko-KR' + render() // Should render in non-Chinese format expect(screen.getByText('Empower your AI development')).toBeInTheDocument() }) - it('should handle de-DE locale as non-Chinese', async () => { - render(await Description({ locale: 'de-DE' })) + it('should handle de-DE locale as non-Chinese', () => { + mockDefaultLocale = 'de-DE' + render() expect(screen.getByText('Empower your AI development')).toBeInTheDocument() }) - it('should handle fr-FR locale as non-Chinese', async () => { - render(await Description({ locale: 'fr-FR' })) + it('should handle fr-FR locale as non-Chinese', () => { + mockDefaultLocale = 'fr-FR' + render() expect(screen.getByText('Empower your AI development')).toBeInTheDocument() }) - it('should handle pt-BR locale as non-Chinese', async () => { - render(await Description({ locale: 'pt-BR' })) + it('should handle pt-BR locale as non-Chinese', () => { + mockDefaultLocale = 'pt-BR' + render() expect(screen.getByText('Empower your AI development')).toBeInTheDocument() }) - it('should handle es-ES locale as non-Chinese', async () => { - render(await Description({ locale: 'es-ES' })) + it('should handle es-ES locale as non-Chinese', () => { + mockDefaultLocale = 'es-ES' + render() expect(screen.getByText('Empower your AI development')).toBeInTheDocument() }) @@ -270,24 +270,27 @@ describe('Description', () => { // Conditional Rendering Tests // ================================ describe('Conditional Rendering', () => { - it('should render zh-Hans specific content when locale is zh-Hans', async () => { - const { container } = render(await Description({ locale: 'zh-Hans' })) + it('should render zh-Hans specific content when locale is zh-Hans', () => { + mockDefaultLocale = 'zh-Hans' + const { container } = render() // zh-Hans has additional span with mr-1 before "in" text at the start const mrSpan = container.querySelector('span.mr-1') expect(mrSpan).toBeInTheDocument() }) - it('should render non-Chinese specific content when locale is not zh-Hans', async () => { - render(await Description({ locale: 'en-US' })) + it('should render non-Chinese specific content when locale is not zh-Hans', () => { + mockDefaultLocale = 'en-US' + render() // Non-Chinese has "in" and "Dify Marketplace" at the end const subheading = screen.getByRole('heading', { level: 2 }) expect(subheading.textContent).toContain('Dify Marketplace') }) - it('should not render zh-Hans intro content for non-Chinese locales', async () => { - render(await Description({ locale: 'en-US' })) + it('should not render zh-Hans intro content for non-Chinese locales', () => { + mockDefaultLocale = 'en-US' + render() // For en-US, the order should be Discover ... in Dify Marketplace // The "in" text should only appear once at the end @@ -303,8 +306,9 @@ describe('Description', () => { expect(inIndex).toBeLessThan(marketplaceIndex) }) - it('should render zh-Hans with proper word order', async () => { - render(await Description({ locale: 'zh-Hans' })) + it('should render zh-Hans with proper word order', () => { + mockDefaultLocale = 'zh-Hans' + render() const subheading = screen.getByRole('heading', { level: 2 }) const content = subheading.textContent || '' @@ -323,58 +327,58 @@ describe('Description', () => { // Category Styling Tests // ================================ describe('Category Styling', () => { - it('should apply underline effect with after pseudo-element styling', async () => { - const { container } = render(await Description({})) + it('should apply underline effect with after pseudo-element styling', () => { + const { container } = render() const categorySpan = container.querySelector('.after\\:absolute') expect(categorySpan).toBeInTheDocument() }) - it('should apply correct after pseudo-element classes', async () => { - const { container } = render(await Description({})) + it('should apply correct after pseudo-element classes', () => { + const { container } = render() // Check for the specific after pseudo-element classes const categorySpans = container.querySelectorAll('.after\\:bottom-\\[1\\.5px\\]') expect(categorySpans.length).toBe(7) }) - it('should apply full width to after element', async () => { - const { container } = render(await Description({})) + it('should apply full width to after element', () => { + const { container } = render() const categorySpans = container.querySelectorAll('.after\\:w-full') expect(categorySpans.length).toBe(7) }) - it('should apply correct height to after element', async () => { - const { container } = render(await Description({})) + it('should apply correct height to after element', () => { + const { container } = render() const categorySpans = container.querySelectorAll('.after\\:h-2') expect(categorySpans.length).toBe(7) }) - it('should apply bg-text-text-selected to after element', async () => { - const { container } = render(await Description({})) + it('should apply bg-text-text-selected to after element', () => { + const { container } = render() const categorySpans = container.querySelectorAll('.after\\:bg-text-text-selected') expect(categorySpans.length).toBe(7) }) - it('should have z-index 1 on category spans', async () => { - const { container } = render(await Description({})) + it('should have z-index 1 on category spans', () => { + const { container } = render() const categorySpans = container.querySelectorAll('.z-\\[1\\]') expect(categorySpans.length).toBe(7) }) - it('should apply left margin to category spans', async () => { - const { container } = render(await Description({})) + it('should apply left margin to category spans', () => { + const { container } = render() const categorySpans = container.querySelectorAll('.ml-1') expect(categorySpans.length).toBeGreaterThanOrEqual(7) }) - it('should apply both left and right margin to specific spans', async () => { - const { container } = render(await Description({})) + it('should apply both left and right margin to specific spans', () => { + const { container } = render() // Extensions and Bundles spans have both ml-1 and mr-1 const extensionsBundlesSpans = container.querySelectorAll('.ml-1.mr-1') @@ -386,28 +390,17 @@ describe('Description', () => { // Edge Cases Tests // ================================ describe('Edge Cases', () => { - it('should handle empty props object', async () => { - const { container } = render(await Description({})) - - expect(container.firstChild).toBeInTheDocument() - }) - - it('should render fragment as root element', async () => { - const { container } = render(await Description({})) + it('should render fragment as root element', () => { + const { container } = render() // Fragment renders h1 and h2 as direct children expect(container.querySelector('h1')).toBeInTheDocument() expect(container.querySelector('h2')).toBeInTheDocument() }) - it('should handle locale prop with undefined value', async () => { - render(await Description({ locale: undefined })) - - expect(screen.getByRole('heading', { level: 1 })).toBeInTheDocument() - }) - - it('should handle zh-Hant as non-Chinese simplified', async () => { - render(await Description({ locale: 'zh-Hant' })) + it('should handle zh-Hant as non-Chinese simplified', () => { + mockDefaultLocale = 'zh-Hant' + render() // zh-Hant is different from zh-Hans, should use non-Chinese format const subheading = screen.getByRole('heading', { level: 2 }) @@ -426,8 +419,8 @@ describe('Description', () => { // Content Structure Tests // ================================ describe('Content Structure', () => { - it('should have comma separators between categories', async () => { - render(await Description({})) + it('should have comma separators between categories', () => { + render() const subheading = screen.getByRole('heading', { level: 2 }) const content = subheading.textContent || '' @@ -436,8 +429,8 @@ describe('Description', () => { expect(content).toMatch(/Models[^\n\r,\u2028\u2029]*,.*Tools[^\n\r,\u2028\u2029]*,.*Data Sources[^\n\r,\u2028\u2029]*,.*Triggers[^\n\r,\u2028\u2029]*,.*Agent Strategies[^\n\r,\u2028\u2029]*,.*Extensions/) }) - it('should have "and" before last category (Bundles)', async () => { - render(await Description({})) + it('should have "and" before last category (Bundles)', () => { + render() const subheading = screen.getByRole('heading', { level: 2 }) const content = subheading.textContent || '' @@ -449,8 +442,9 @@ describe('Description', () => { expect(andIndex).toBeLessThan(bundlesIndex) }) - it('should render all text elements in correct order for en-US', async () => { - render(await Description({ locale: 'en-US' })) + it('should render all text elements in correct order for en-US', () => { + mockDefaultLocale = 'en-US' + render() const subheading = screen.getByRole('heading', { level: 2 }) const content = subheading.textContent || '' @@ -477,8 +471,9 @@ describe('Description', () => { } }) - it('should render all text elements in correct order for zh-Hans', async () => { - render(await Description({ locale: 'zh-Hans' })) + it('should render all text elements in correct order for zh-Hans', () => { + mockDefaultLocale = 'zh-Hans' + render() const subheading = screen.getByRole('heading', { level: 2 }) const content = subheading.textContent || '' @@ -499,82 +494,48 @@ describe('Description', () => { // Layout Tests // ================================ describe('Layout', () => { - it('should have shrink-0 on h1 heading', async () => { - render(await Description({})) + it('should have shrink-0 on h1 heading', () => { + render() const heading = screen.getByRole('heading', { level: 1 }) expect(heading).toHaveClass('shrink-0') }) - it('should have shrink-0 on h2 subheading', async () => { - render(await Description({})) + it('should have shrink-0 on h2 subheading', () => { + render() const subheading = screen.getByRole('heading', { level: 2 }) expect(subheading).toHaveClass('shrink-0') }) - it('should have flex layout on h2', async () => { - render(await Description({})) + it('should have flex layout on h2', () => { + render() const subheading = screen.getByRole('heading', { level: 2 }) expect(subheading).toHaveClass('flex') }) - it('should have items-center on h2', async () => { - render(await Description({})) + it('should have items-center on h2', () => { + render() const subheading = screen.getByRole('heading', { level: 2 }) expect(subheading).toHaveClass('items-center') }) - it('should have justify-center on h2', async () => { - render(await Description({})) + it('should have justify-center on h2', () => { + render() const subheading = screen.getByRole('heading', { level: 2 }) expect(subheading).toHaveClass('justify-center') }) }) - // ================================ - // Translation Function Tests - // ================================ - describe('Translation Functions', () => { - it('should call getTranslation for plugin namespace', async () => { - const { getTranslation } = await import('@/i18n-config/server') - render(await Description({ locale: 'en-US' })) - - expect(getTranslation).toHaveBeenCalledWith('en-US', 'plugin') - }) - - it('should call getTranslation for common namespace', async () => { - const { getTranslation } = await import('@/i18n-config/server') - render(await Description({ locale: 'en-US' })) - - expect(getTranslation).toHaveBeenCalledWith('en-US', 'common') - }) - - it('should call getLocaleOnServer when locale prop is undefined', async () => { - const { getLocaleOnServer } = await import('@/i18n-config/server') - render(await Description({})) - - expect(getLocaleOnServer).toHaveBeenCalled() - }) - - it('should use locale prop when provided', async () => { - const { getTranslation } = await import('@/i18n-config/server') - render(await Description({ locale: 'ja-JP' })) - - expect(getTranslation).toHaveBeenCalledWith('ja-JP', 'plugin') - expect(getTranslation).toHaveBeenCalledWith('ja-JP', 'common') - }) - }) - // ================================ // Accessibility Tests // ================================ describe('Accessibility', () => { - it('should have proper heading hierarchy', async () => { - render(await Description({})) + it('should have proper heading hierarchy', () => { + render() const h1 = screen.getByRole('heading', { level: 1 }) const h2 = screen.getByRole('heading', { level: 2 }) @@ -583,22 +544,22 @@ describe('Description', () => { expect(h2).toBeInTheDocument() }) - it('should have readable text content', async () => { - render(await Description({})) + it('should have readable text content', () => { + render() const h1 = screen.getByRole('heading', { level: 1 }) expect(h1.textContent).not.toBe('') }) - it('should have visible h1 heading', async () => { - render(await Description({})) + it('should have visible h1 heading', () => { + render() const heading = screen.getByRole('heading', { level: 1 }) expect(heading).toBeVisible() }) - it('should have visible h2 heading', async () => { - render(await Description({})) + it('should have visible h2 heading', () => { + render() const subheading = screen.getByRole('heading', { level: 2 }) expect(subheading).toBeVisible() @@ -615,8 +576,8 @@ describe('Description Integration', () => { mockDefaultLocale = 'en-US' }) - it('should render complete component structure', async () => { - const { container } = render(await Description({ locale: 'en-US' })) + it('should render complete component structure', () => { + const { container } = render() // Main headings expect(container.querySelector('h1')).toBeInTheDocument() @@ -627,8 +588,9 @@ describe('Description Integration', () => { expect(categorySpans.length).toBe(7) }) - it('should render complete zh-Hans structure', async () => { - const { container } = render(await Description({ locale: 'zh-Hans' })) + it('should render complete zh-Hans structure', () => { + mockDefaultLocale = 'zh-Hans' + const { container } = render() // Main headings expect(container.querySelector('h1')).toBeInTheDocument() @@ -639,14 +601,16 @@ describe('Description Integration', () => { expect(categorySpans.length).toBe(7) }) - it('should correctly switch between zh-Hans and en-US layouts', async () => { + it('should correctly differentiate between zh-Hans and en-US layouts', () => { // Render en-US - const { container: enContainer, unmount: unmountEn } = render(await Description({ locale: 'en-US' })) + mockDefaultLocale = 'en-US' + const { container: enContainer, unmount: unmountEn } = render() const enContent = enContainer.querySelector('h2')?.textContent || '' unmountEn() // Render zh-Hans - const { container: zhContainer } = render(await Description({ locale: 'zh-Hans' })) + mockDefaultLocale = 'zh-Hans' + const { container: zhContainer } = render() const zhContent = zhContainer.querySelector('h2')?.textContent || '' // Both should have all categories @@ -666,14 +630,16 @@ describe('Description Integration', () => { expect(zhMarketplaceIndex).toBeLessThan(zhDiscoverIndex) }) - it('should maintain consistent styling across locales', async () => { + it('should maintain consistent styling across locales', () => { // Render en-US - const { container: enContainer, unmount: unmountEn } = render(await Description({ locale: 'en-US' })) + mockDefaultLocale = 'en-US' + const { container: enContainer, unmount: unmountEn } = render() const enCategoryCount = enContainer.querySelectorAll('.body-md-medium').length unmountEn() // Render zh-Hans - const { container: zhContainer } = render(await Description({ locale: 'zh-Hans' })) + mockDefaultLocale = 'zh-Hans' + const { container: zhContainer } = render() const zhCategoryCount = zhContainer.querySelectorAll('.body-md-medium').length // Both should have same number of styled category spans diff --git a/web/app/components/plugins/marketplace/description/index.tsx b/web/app/components/plugins/marketplace/description/index.tsx index d3ca964538..30ccbdb76e 100644 --- a/web/app/components/plugins/marketplace/description/index.tsx +++ b/web/app/components/plugins/marketplace/description/index.tsx @@ -1,17 +1,11 @@ -/* eslint-disable dify-i18n/require-ns-option */ -import type { Locale } from '@/i18n-config' -import { getLocaleOnServer, getTranslation } from '@/i18n-config/server' +import { useLocale, useTranslation } from '#i18n' -type DescriptionProps = { - locale?: Locale -} -const Description = async ({ - locale: localeFromProps, -}: DescriptionProps) => { - const localeDefault = await getLocaleOnServer() - const { t } = await getTranslation(localeFromProps || localeDefault, 'plugin') - const { t: tCommon } = await getTranslation(localeFromProps || localeDefault, 'common') - const isZhHans = localeFromProps === 'zh-Hans' +const Description = () => { + const { t } = useTranslation('plugin') + const { t: tCommon } = useTranslation('common') + const locale = useLocale() + + const isZhHans = locale === 'zh-Hans' return ( <> diff --git a/web/app/components/plugins/marketplace/empty/index.spec.tsx b/web/app/components/plugins/marketplace/empty/index.spec.tsx index 4cbc85a309..bc8e701dfc 100644 --- a/web/app/components/plugins/marketplace/empty/index.spec.tsx +++ b/web/app/components/plugins/marketplace/empty/index.spec.tsx @@ -7,9 +7,9 @@ import Line from './line' // Mock external dependencies only // ================================ -// Mock useMixedTranslation hook -vi.mock('../hooks', () => ({ - useMixedTranslation: (_locale?: string) => ({ +// Mock i18n translation hook +vi.mock('#i18n', () => ({ + useTranslation: () => ({ t: (key: string, options?: { ns?: string }) => { // Build full key with namespace prefix if provided const fullKey = options?.ns ? `${options.ns}.${key}` : key @@ -471,36 +471,6 @@ describe('Empty', () => { }) }) - // ================================ - // Locale Prop Tests - // ================================ - describe('Locale Prop', () => { - it('should pass locale to useMixedTranslation', () => { - render() - - // Translation should still work - expect(screen.getByText('No plugin found')).toBeInTheDocument() - }) - - it('should handle undefined locale', () => { - render() - - expect(screen.getByText('No plugin found')).toBeInTheDocument() - }) - - it('should handle en-US locale', () => { - render() - - expect(screen.getByText('No plugin found')).toBeInTheDocument() - }) - - it('should handle ja-JP locale', () => { - render() - - expect(screen.getByText('No plugin found')).toBeInTheDocument() - }) - }) - // ================================ // Placeholder Cards Layout Tests // ================================ @@ -634,7 +604,6 @@ describe('Empty', () => { text="Custom message" lightCard className="custom-wrapper" - locale="en-US" />, ) @@ -695,12 +664,6 @@ describe('Empty', () => { expect(container.querySelector('.only-class')).toBeInTheDocument() }) - it('should render with only locale prop', () => { - render() - - expect(screen.getByText('No plugin found')).toBeInTheDocument() - }) - it('should handle text with unicode characters', () => { render() @@ -813,7 +776,7 @@ describe('Empty and Line Integration', () => { }) it('should render complete Empty component structure', () => { - const { container } = render() + const { container } = render() // Container expect(container.querySelector('.test')).toBeInTheDocument() diff --git a/web/app/components/plugins/marketplace/empty/index.tsx b/web/app/components/plugins/marketplace/empty/index.tsx index 3c33d9b92a..6e5adff1b4 100644 --- a/web/app/components/plugins/marketplace/empty/index.tsx +++ b/web/app/components/plugins/marketplace/empty/index.tsx @@ -1,6 +1,6 @@ 'use client' +import { useTranslation } from '#i18n' import { Group } from '@/app/components/base/icons/src/vender/other' -import { useMixedTranslation } from '@/app/components/plugins/marketplace/hooks' import { cn } from '@/utils/classnames' import Line from './line' @@ -8,16 +8,14 @@ type Props = { text?: string lightCard?: boolean className?: string - locale?: string } const Empty = ({ text, lightCard, className, - locale, }: Props) => { - const { t } = useMixedTranslation(locale) + const { t } = useTranslation() return (
{ } } -/** - * ! Support zh-Hans, pt-BR, ja-JP and en-US for Marketplace page - * ! For other languages, use en-US as fallback - */ -export const useMixedTranslation = (localeFromOuter?: string) => { - let t = useTranslation().t - - if (localeFromOuter) - t = i18n.getFixedT(localeFromOuter) - - return { - t, - } -} - export const useMarketplaceContainerScroll = ( callback: () => void, scrollContainerId = 'marketplace-container', diff --git a/web/app/components/plugins/marketplace/index.spec.tsx b/web/app/components/plugins/marketplace/index.spec.tsx index 3073897ba1..b3b1d58dd4 100644 --- a/web/app/components/plugins/marketplace/index.spec.tsx +++ b/web/app/components/plugins/marketplace/index.spec.tsx @@ -11,7 +11,6 @@ import { PluginCategoryEnum } from '@/app/components/plugins/types' // Note: Import after mocks are set up import { DEFAULT_SORT, SCROLL_BOTTOM_THRESHOLD } from './constants' import { MarketplaceContext, MarketplaceContextProvider, useMarketplaceContext } from './context' -import { useMixedTranslation } from './hooks' import PluginTypeSwitch, { PLUGIN_TYPE_SEARCH_MAP } from './plugin-type-switch' import StickySearchAndSwitchWrapper from './sticky-search-and-switch-wrapper' import { @@ -602,48 +601,6 @@ describe('utils', () => { }) }) -// ================================ -// Hooks Tests -// ================================ -describe('hooks', () => { - describe('useMixedTranslation', () => { - it('should return translation function', () => { - const { result } = renderHook(() => useMixedTranslation()) - - expect(result.current.t).toBeDefined() - expect(typeof result.current.t).toBe('function') - }) - - it('should return translation key when no translation found', () => { - const { result } = renderHook(() => useMixedTranslation()) - - // The global mock returns key with namespace prefix - expect(result.current.t('category.all', { ns: 'plugin' })).toBe('plugin.category.all') - }) - - it('should use locale from outer when provided', () => { - const { result } = renderHook(() => useMixedTranslation('zh-Hans')) - - expect(result.current.t).toBeDefined() - }) - - it('should handle different locale values', () => { - const locales = ['en-US', 'zh-Hans', 'ja-JP', 'pt-BR'] - locales.forEach((locale) => { - const { result } = renderHook(() => useMixedTranslation(locale)) - expect(result.current.t).toBeDefined() - expect(typeof result.current.t).toBe('function') - }) - }) - - it('should use getFixedT when localeFromOuter is provided', () => { - const { result } = renderHook(() => useMixedTranslation('fr-FR')) - // The global mock returns key with namespace prefix - expect(result.current.t('search', { ns: 'plugin' })).toBe('plugin.search') - }) - }) -}) - // ================================ // useMarketplaceCollectionsAndPlugins Tests // ================================ @@ -2088,17 +2045,6 @@ describe('StickySearchAndSwitchWrapper', () => { }) describe('Props', () => { - it('should accept locale prop', () => { - render( - - - , - ) - - // Component should render without errors - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() - }) - it('should accept showSearchParams prop', () => { render( diff --git a/web/app/components/plugins/marketplace/index.tsx b/web/app/components/plugins/marketplace/index.tsx index ff9a4d60bc..08d1bc833f 100644 --- a/web/app/components/plugins/marketplace/index.tsx +++ b/web/app/components/plugins/marketplace/index.tsx @@ -1,6 +1,5 @@ import type { MarketplaceCollection, SearchParams } from './types' import type { Plugin } from '@/app/components/plugins/types' -import type { Locale } from '@/i18n-config' import { TanstackQueryInitializer } from '@/context/query-client' import { MarketplaceContextProvider } from './context' import Description from './description' @@ -9,7 +8,6 @@ import StickySearchAndSwitchWrapper from './sticky-search-and-switch-wrapper' import { getMarketplaceCollectionsAndPlugins } from './utils' type MarketplaceProps = { - locale: Locale showInstallButton?: boolean shouldExclude?: boolean searchParams?: SearchParams @@ -18,7 +16,6 @@ type MarketplaceProps = { showSearchParams?: boolean } const Marketplace = async ({ - locale, showInstallButton = true, shouldExclude, searchParams, @@ -42,14 +39,12 @@ const Marketplace = async ({ scrollContainerId={scrollContainerId} showSearchParams={showSearchParams} > - + { - const { t } = useMixedTranslation(locale) + const { t } = useTranslation() const { theme } = useTheme() const [isShowInstallFromMarketplace, { setTrue: showInstallFromMarketplace, setFalse: hideInstallFromMarketplace, }] = useBoolean(false) - const localeFromLocale = useLocale() - const { getTagLabel } = useTags(t) + const locale = useLocale() + const { getTagLabel } = useTags() // Memoize marketplace link params to prevent unnecessary re-renders const marketplaceLinkParams = useMemo(() => ({ - language: localeFromLocale, + language: locale, theme, - }), [localeFromLocale, theme]) + }), [locale, theme]) // Memoize tag labels to prevent recreating array on every render const tagLabels = useMemo(() => @@ -52,7 +48,6 @@ const CardWrapperComponent = ({ ({ - useMixedTranslation: (_locale?: string) => ({ +// Mock i18n translation hook +vi.mock('#i18n', () => ({ + useTranslation: () => ({ t: (key: string, options?: { ns?: string, num?: number }) => { // Build full key with namespace prefix if provided const fullKey = options?.ns ? `${options.ns}.${key}` : key @@ -28,6 +27,7 @@ vi.mock('../hooks', () => ({ return translations[fullKey] || key }, }), + useLocale: () => 'en-US', })) // Mock useMarketplaceContext with controllable values @@ -148,15 +148,15 @@ vi.mock('@/app/components/plugins/install-plugin/install-from-marketplace', () = // Mock SortDropdown component vi.mock('../sort-dropdown', () => ({ - default: ({ locale }: { locale: Locale }) => ( -
Sort
+ default: () => ( +
Sort
), })) // Mock Empty component vi.mock('../empty', () => ({ - default: ({ className, locale }: { className?: string, locale?: string }) => ( -
+ default: ({ className }: { className?: string }) => ( +
No plugins found
), @@ -233,7 +233,6 @@ describe('List', () => { marketplaceCollectionPluginsMap: {} as Record, plugins: undefined, showInstallButton: false, - locale: 'en-US' as Locale, cardContainerClassName: '', cardRender: undefined, onMoreClick: undefined, @@ -351,18 +350,6 @@ describe('List', () => { expect(screen.getByTestId('empty-component')).toHaveClass('custom-empty-class') }) - it('should pass locale to Empty component', () => { - render( - , - ) - - expect(screen.getByTestId('empty-component')).toHaveAttribute('data-locale', 'zh-CN') - }) - it('should pass showInstallButton to CardWrapper', () => { const plugins = createMockPluginList(1) @@ -508,7 +495,6 @@ describe('ListWithCollection', () => { marketplaceCollections: [] as MarketplaceCollection[], marketplaceCollectionPluginsMap: {} as Record, showInstallButton: false, - locale: 'en-US' as Locale, cardContainerClassName: '', cardRender: undefined, onMoreClick: undefined, @@ -820,7 +806,6 @@ describe('ListWrapper', () => { marketplaceCollections: [] as MarketplaceCollection[], marketplaceCollectionPluginsMap: {} as Record, showInstallButton: false, - locale: 'en-US' as Locale, } beforeEach(() => { @@ -901,14 +886,6 @@ describe('ListWrapper', () => { expect(screen.queryByTestId('sort-dropdown')).not.toBeInTheDocument() }) - - it('should pass locale to SortDropdown', () => { - mockContextValues.plugins = createMockPluginList(1) - - render() - - expect(screen.getByTestId('sort-dropdown')).toHaveAttribute('data-locale', 'zh-CN') - }) }) // ================================ @@ -1169,7 +1146,6 @@ describe('CardWrapper (via List integration)', () => { marketplaceCollections={[]} marketplaceCollectionPluginsMap={{}} plugins={[plugin]} - locale="en-US" />, ) @@ -1188,7 +1164,6 @@ describe('CardWrapper (via List integration)', () => { marketplaceCollections={[]} marketplaceCollectionPluginsMap={{}} plugins={[plugin]} - locale="en-US" />, ) @@ -1209,7 +1184,6 @@ describe('CardWrapper (via List integration)', () => { marketplaceCollections={[]} marketplaceCollectionPluginsMap={{}} plugins={plugins} - locale="en-US" />, ) @@ -1231,7 +1205,6 @@ describe('CardWrapper (via List integration)', () => { marketplaceCollectionPluginsMap={{}} plugins={[plugin]} showInstallButton={true} - locale="en-US" />, ) @@ -1252,7 +1225,6 @@ describe('CardWrapper (via List integration)', () => { marketplaceCollectionPluginsMap={{}} plugins={[plugin]} showInstallButton={true} - locale="en-US" />, ) @@ -1274,7 +1246,6 @@ describe('CardWrapper (via List integration)', () => { marketplaceCollectionPluginsMap={{}} plugins={[plugin]} showInstallButton={true} - locale="en-US" />, ) @@ -1293,7 +1264,6 @@ describe('CardWrapper (via List integration)', () => { marketplaceCollectionPluginsMap={{}} plugins={[plugin]} showInstallButton={true} - locale="en-US" />, ) @@ -1310,7 +1280,6 @@ describe('CardWrapper (via List integration)', () => { marketplaceCollectionPluginsMap={{}} plugins={[plugin]} showInstallButton={true} - locale="en-US" />, ) @@ -1327,7 +1296,6 @@ describe('CardWrapper (via List integration)', () => { marketplaceCollectionPluginsMap={{}} plugins={[plugin]} showInstallButton={true} - locale="en-US" />, ) @@ -1354,7 +1322,6 @@ describe('CardWrapper (via List integration)', () => { marketplaceCollectionPluginsMap={{}} plugins={[plugin]} showInstallButton={false} - locale="en-US" />, ) @@ -1375,7 +1342,6 @@ describe('CardWrapper (via List integration)', () => { marketplaceCollectionPluginsMap={{}} plugins={[plugin]} showInstallButton={false} - locale="en-US" />, ) @@ -1390,7 +1356,6 @@ describe('CardWrapper (via List integration)', () => { marketplaceCollections={[]} marketplaceCollectionPluginsMap={{}} plugins={[plugin]} - locale="en-US" />, ) @@ -1414,7 +1379,6 @@ describe('CardWrapper (via List integration)', () => { marketplaceCollections={[]} marketplaceCollectionPluginsMap={{}} plugins={[plugin]} - locale="en-US" />, ) @@ -1432,7 +1396,6 @@ describe('CardWrapper (via List integration)', () => { marketplaceCollections={[]} marketplaceCollectionPluginsMap={{}} plugins={[plugin]} - locale="en-US" />, ) @@ -1450,7 +1413,6 @@ describe('CardWrapper (via List integration)', () => { marketplaceCollections={[]} marketplaceCollectionPluginsMap={{}} plugins={[plugin]} - locale="en-US" />, ) @@ -1482,7 +1444,6 @@ describe('Combined Workflows', () => { , ) @@ -1501,7 +1462,6 @@ describe('Combined Workflows', () => { , ) @@ -1521,7 +1481,6 @@ describe('Combined Workflows', () => { , ) @@ -1535,7 +1494,6 @@ describe('Combined Workflows', () => { , ) @@ -1551,7 +1509,6 @@ describe('Combined Workflows', () => { , ) @@ -1569,7 +1526,6 @@ describe('Combined Workflows', () => { , ) @@ -1601,7 +1557,6 @@ describe('Accessibility', () => { , ) @@ -1625,7 +1580,6 @@ describe('Accessibility', () => { marketplaceCollections={collections} marketplaceCollectionPluginsMap={pluginsMap} onMoreClick={onMoreClick} - locale="en-US" />, ) @@ -1642,7 +1596,6 @@ describe('Accessibility', () => { marketplaceCollections={[]} marketplaceCollectionPluginsMap={{}} plugins={plugins} - locale="en-US" />, ) @@ -1668,7 +1621,6 @@ describe('Performance', () => { marketplaceCollections={[]} marketplaceCollectionPluginsMap={{}} plugins={plugins} - locale="en-US" />, ) const endTime = performance.now() @@ -1689,7 +1641,6 @@ describe('Performance', () => { , ) const endTime = performance.now() diff --git a/web/app/components/plugins/marketplace/list/index.tsx b/web/app/components/plugins/marketplace/list/index.tsx index 54889b232f..80b33d0ffd 100644 --- a/web/app/components/plugins/marketplace/list/index.tsx +++ b/web/app/components/plugins/marketplace/list/index.tsx @@ -1,7 +1,6 @@ 'use client' import type { Plugin } from '../../types' import type { MarketplaceCollection } from '../types' -import type { Locale } from '@/i18n-config' import { cn } from '@/utils/classnames' import Empty from '../empty' import CardWrapper from './card-wrapper' @@ -12,7 +11,6 @@ type ListProps = { marketplaceCollectionPluginsMap: Record plugins?: Plugin[] showInstallButton?: boolean - locale: Locale cardContainerClassName?: string cardRender?: (plugin: Plugin) => React.JSX.Element | null onMoreClick?: () => void @@ -23,7 +21,6 @@ const List = ({ marketplaceCollectionPluginsMap, plugins, showInstallButton, - locale, cardContainerClassName, cardRender, onMoreClick, @@ -37,7 +34,6 @@ const List = ({ marketplaceCollections={marketplaceCollections} marketplaceCollectionPluginsMap={marketplaceCollectionPluginsMap} showInstallButton={showInstallButton} - locale={locale} cardContainerClassName={cardContainerClassName} cardRender={cardRender} onMoreClick={onMoreClick} @@ -61,7 +57,6 @@ const List = ({ key={`${plugin.org}/${plugin.name}`} plugin={plugin} showInstallButton={showInstallButton} - locale={locale} /> ) }) @@ -71,7 +66,7 @@ const List = ({ } { plugins && !plugins.length && ( - + ) } diff --git a/web/app/components/plugins/marketplace/list/list-with-collection.tsx b/web/app/components/plugins/marketplace/list/list-with-collection.tsx index 8830cc5ddf..c17715e71e 100644 --- a/web/app/components/plugins/marketplace/list/list-with-collection.tsx +++ b/web/app/components/plugins/marketplace/list/list-with-collection.tsx @@ -3,9 +3,8 @@ import type { MarketplaceCollection } from '../types' import type { SearchParamsFromCollection } from '@/app/components/plugins/marketplace/types' import type { Plugin } from '@/app/components/plugins/types' -import type { Locale } from '@/i18n-config' +import { useLocale, useTranslation } from '#i18n' import { RiArrowRightSLine } from '@remixicon/react' -import { useMixedTranslation } from '@/app/components/plugins/marketplace/hooks' import { getLanguage } from '@/i18n-config/language' import { cn } from '@/utils/classnames' import CardWrapper from './card-wrapper' @@ -14,7 +13,6 @@ type ListWithCollectionProps = { marketplaceCollections: MarketplaceCollection[] marketplaceCollectionPluginsMap: Record showInstallButton?: boolean - locale: Locale cardContainerClassName?: string cardRender?: (plugin: Plugin) => React.JSX.Element | null onMoreClick?: (searchParams?: SearchParamsFromCollection) => void @@ -23,12 +21,12 @@ const ListWithCollection = ({ marketplaceCollections, marketplaceCollectionPluginsMap, showInstallButton, - locale, cardContainerClassName, cardRender, onMoreClick, }: ListWithCollectionProps) => { - const { t } = useMixedTranslation(locale) + const { t } = useTranslation() + const locale = useLocale() return ( <> @@ -72,7 +70,6 @@ const ListWithCollection = ({ key={plugin.plugin_id} plugin={plugin} showInstallButton={showInstallButton} - locale={locale} /> ) }) diff --git a/web/app/components/plugins/marketplace/list/list-wrapper.tsx b/web/app/components/plugins/marketplace/list/list-wrapper.tsx index f8126eb34b..84fcf92daf 100644 --- a/web/app/components/plugins/marketplace/list/list-wrapper.tsx +++ b/web/app/components/plugins/marketplace/list/list-wrapper.tsx @@ -1,10 +1,9 @@ 'use client' import type { Plugin } from '../../types' import type { MarketplaceCollection } from '../types' -import type { Locale } from '@/i18n-config' +import { useTranslation } from '#i18n' import { useEffect } from 'react' import Loading from '@/app/components/base/loading' -import { useMixedTranslation } from '@/app/components/plugins/marketplace/hooks' import { useMarketplaceContext } from '../context' import SortDropdown from '../sort-dropdown' import List from './index' @@ -13,15 +12,13 @@ type ListWrapperProps = { marketplaceCollections: MarketplaceCollection[] marketplaceCollectionPluginsMap: Record showInstallButton?: boolean - locale: Locale } const ListWrapper = ({ marketplaceCollections, marketplaceCollectionPluginsMap, showInstallButton, - locale, }: ListWrapperProps) => { - const { t } = useMixedTranslation(locale) + const { t } = useTranslation() const plugins = useMarketplaceContext(v => v.plugins) const pluginsTotal = useMarketplaceContext(v => v.pluginsTotal) const marketplaceCollectionsFromClient = useMarketplaceContext(v => v.marketplaceCollectionsFromClient) @@ -55,7 +52,7 @@ const ListWrapper = ({
{t('marketplace.pluginsResult', { ns: 'plugin', num: pluginsTotal })}
- +
) } @@ -73,7 +70,6 @@ const ListWrapper = ({ marketplaceCollectionPluginsMap={marketplaceCollectionPluginsMapFromClient || marketplaceCollectionPluginsMap} plugins={plugins} showInstallButton={showInstallButton} - locale={locale} onMoreClick={handleMoreClick} /> ) diff --git a/web/app/components/plugins/marketplace/plugin-type-switch.tsx b/web/app/components/plugins/marketplace/plugin-type-switch.tsx index 2a89e6847e..b9572413ed 100644 --- a/web/app/components/plugins/marketplace/plugin-type-switch.tsx +++ b/web/app/components/plugins/marketplace/plugin-type-switch.tsx @@ -1,4 +1,5 @@ 'use client' +import { useTranslation } from '#i18n' import { RiArchive2Line, RiBrain2Line, @@ -12,7 +13,6 @@ import { Trigger as TriggerIcon } from '@/app/components/base/icons/src/vender/p import { cn } from '@/utils/classnames' import { PluginCategoryEnum } from '../types' import { useMarketplaceContext } from './context' -import { useMixedTranslation } from './hooks' export const PLUGIN_TYPE_SEARCH_MAP = { all: 'all', @@ -25,16 +25,14 @@ export const PLUGIN_TYPE_SEARCH_MAP = { bundle: 'bundle', } type PluginTypeSwitchProps = { - locale?: string className?: string showSearchParams?: boolean } const PluginTypeSwitch = ({ - locale, className, showSearchParams, }: PluginTypeSwitchProps) => { - const { t } = useMixedTranslation(locale) + const { t } = useTranslation() const activePluginType = useMarketplaceContext(s => s.activePluginType) const handleActivePluginTypeChange = useMarketplaceContext(s => s.handleActivePluginTypeChange) diff --git a/web/app/components/plugins/marketplace/search-box/index.spec.tsx b/web/app/components/plugins/marketplace/search-box/index.spec.tsx index 8c3131f6d1..3e9cc40be0 100644 --- a/web/app/components/plugins/marketplace/search-box/index.spec.tsx +++ b/web/app/components/plugins/marketplace/search-box/index.spec.tsx @@ -10,9 +10,9 @@ import ToolSelectorTrigger from './trigger/tool-selector' // Mock external dependencies only // ================================ -// Mock useMixedTranslation hook -vi.mock('../hooks', () => ({ - useMixedTranslation: (_locale?: string) => ({ +// Mock i18n translation hook +vi.mock('#i18n', () => ({ + useTranslation: () => ({ t: (key: string, options?: { ns?: string }) => { // Build full key with namespace prefix if provided const fullKey = options?.ns ? `${options.ns}.${key}` : key @@ -364,13 +364,6 @@ describe('SearchBox', () => { expect(container.querySelector('.custom-input-class')).toBeInTheDocument() }) - it('should pass locale to TagsFilter', () => { - render() - - // TagsFilter should be rendered with locale - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() - }) - it('should handle empty placeholder', () => { render() @@ -449,12 +442,6 @@ describe('SearchBoxWrapper', () => { expect(screen.getByRole('textbox')).toBeInTheDocument() }) - it('should render with locale prop', () => { - render() - - expect(screen.getByRole('textbox')).toBeInTheDocument() - }) - it('should render in marketplace mode', () => { const { container } = render() @@ -500,13 +487,6 @@ describe('SearchBoxWrapper', () => { expect(screen.getByPlaceholderText('Search plugins')).toBeInTheDocument() }) - - it('should pass locale to useMixedTranslation', () => { - render() - - // Translation should still work - expect(screen.getByPlaceholderText('Search plugins')).toBeInTheDocument() - }) }) }) @@ -665,12 +645,6 @@ describe('MarketplaceTrigger', () => { }) describe('Props Variations', () => { - it('should handle locale prop', () => { - render() - - expect(screen.getByText('All Tags')).toBeInTheDocument() - }) - it('should handle empty tagsMap', () => { const { container } = render( , @@ -1251,7 +1225,6 @@ describe('Combined Workflows', () => { supportAddCustomTool onShowAddCustomCollectionModal={vi.fn()} placeholder="Search plugins" - locale="en-US" wrapperClassName="custom-wrapper" inputClassName="custom-input" autoFocus={false} diff --git a/web/app/components/plugins/marketplace/search-box/index.tsx b/web/app/components/plugins/marketplace/search-box/index.tsx index 05f98782b9..b6e1f8ee70 100644 --- a/web/app/components/plugins/marketplace/search-box/index.tsx +++ b/web/app/components/plugins/marketplace/search-box/index.tsx @@ -13,7 +13,6 @@ type SearchBoxProps = { tags: string[] onTagsChange: (tags: string[]) => void placeholder?: string - locale?: string supportAddCustomTool?: boolean usedInMarketplace?: boolean onShowAddCustomCollectionModal?: () => void @@ -28,7 +27,6 @@ const SearchBox = ({ tags, onTagsChange, placeholder = '', - locale, usedInMarketplace = false, supportAddCustomTool, onShowAddCustomCollectionModal, @@ -49,7 +47,6 @@ const SearchBox = ({ tags={tags} onTagsChange={onTagsChange} usedInMarketplace - locale={locale} />
@@ -109,7 +106,6 @@ const SearchBox = ({ ) diff --git a/web/app/components/plugins/marketplace/search-box/search-box-wrapper.tsx b/web/app/components/plugins/marketplace/search-box/search-box-wrapper.tsx index 1290c26210..d7fc004236 100644 --- a/web/app/components/plugins/marketplace/search-box/search-box-wrapper.tsx +++ b/web/app/components/plugins/marketplace/search-box/search-box-wrapper.tsx @@ -1,16 +1,11 @@ 'use client' +import { useTranslation } from '#i18n' import { useMarketplaceContext } from '../context' -import { useMixedTranslation } from '../hooks' import SearchBox from './index' -type SearchBoxWrapperProps = { - locale?: string -} -const SearchBoxWrapper = ({ - locale, -}: SearchBoxWrapperProps) => { - const { t } = useMixedTranslation(locale) +const SearchBoxWrapper = () => { + const { t } = useTranslation() const searchPluginText = useMarketplaceContext(v => v.searchPluginText) const handleSearchPluginTextChange = useMarketplaceContext(v => v.handleSearchPluginTextChange) const filterPluginTags = useMarketplaceContext(v => v.filterPluginTags) @@ -24,7 +19,6 @@ const SearchBoxWrapper = ({ onSearchChange={handleSearchPluginTextChange} tags={filterPluginTags} onTagsChange={handleFilterPluginTagsChange} - locale={locale} placeholder={t('searchPlugins', { ns: 'plugin' })} usedInMarketplace /> diff --git a/web/app/components/plugins/marketplace/search-box/tags-filter.tsx b/web/app/components/plugins/marketplace/search-box/tags-filter.tsx index df4d3eebab..9a8035e2e3 100644 --- a/web/app/components/plugins/marketplace/search-box/tags-filter.tsx +++ b/web/app/components/plugins/marketplace/search-box/tags-filter.tsx @@ -1,5 +1,6 @@ 'use client' +import { useTranslation } from '#i18n' import { useState } from 'react' import Checkbox from '@/app/components/base/checkbox' import Input from '@/app/components/base/input' @@ -9,7 +10,6 @@ import { PortalToFollowElemTrigger, } from '@/app/components/base/portal-to-follow-elem' import { useTags } from '@/app/components/plugins/hooks' -import { useMixedTranslation } from '@/app/components/plugins/marketplace/hooks' import MarketplaceTrigger from './trigger/marketplace' import ToolSelectorTrigger from './trigger/tool-selector' @@ -17,18 +17,16 @@ type TagsFilterProps = { tags: string[] onTagsChange: (tags: string[]) => void usedInMarketplace?: boolean - locale?: string } const TagsFilter = ({ tags, onTagsChange, usedInMarketplace = false, - locale, }: TagsFilterProps) => { - const { t } = useMixedTranslation(locale) + const { t } = useTranslation() const [open, setOpen] = useState(false) const [searchText, setSearchText] = useState('') - const { tags: options, tagsMap } = useTags(t) + const { tags: options, tagsMap } = useTags() const filteredOptions = options.filter(option => option.label.toLowerCase().includes(searchText.toLowerCase())) const handleCheck = (id: string) => { if (tags.includes(id)) @@ -59,7 +57,6 @@ const TagsFilter = ({ open={open} tags={tags} tagsMap={tagsMap} - locale={locale} onTagsChange={onTagsChange} /> ) diff --git a/web/app/components/plugins/marketplace/search-box/trigger/marketplace.tsx b/web/app/components/plugins/marketplace/search-box/trigger/marketplace.tsx index 2ba03bd2f2..e387d52d0e 100644 --- a/web/app/components/plugins/marketplace/search-box/trigger/marketplace.tsx +++ b/web/app/components/plugins/marketplace/search-box/trigger/marketplace.tsx @@ -1,15 +1,14 @@ import type { Tag } from '../../../hooks' +import { useTranslation } from '#i18n' import { RiArrowDownSLine, RiCloseCircleFill, RiFilter3Line } from '@remixicon/react' import * as React from 'react' import { cn } from '@/utils/classnames' -import { useMixedTranslation } from '../../hooks' type MarketplaceTriggerProps = { selectedTagsLength: number open: boolean tags: string[] tagsMap: Record - locale?: string onTagsChange: (tags: string[]) => void } @@ -18,10 +17,9 @@ const MarketplaceTrigger = ({ open, tags, tagsMap, - locale, onTagsChange, }: MarketplaceTriggerProps) => { - const { t } = useMixedTranslation(locale) + const { t } = useTranslation() return (
{ // Build full key with namespace prefix if provided const fullKey = options?.ns ? `${options.ns}.${key}` : key @@ -22,8 +22,8 @@ const mockTranslation = vi.fn((key: string, options?: { ns?: string }) => { return translations[fullKey] || key }) -vi.mock('../hooks', () => ({ - useMixedTranslation: (_locale?: string) => ({ +vi.mock('#i18n', () => ({ + useTranslation: () => ({ t: mockTranslation, }), })) @@ -145,36 +145,6 @@ describe('SortDropdown', () => { }) }) - // ================================ - // Props Testing - // ================================ - describe('Props', () => { - it('should accept locale prop', () => { - render() - - expect(screen.getByTestId('portal-wrapper')).toBeInTheDocument() - }) - - it('should call useMixedTranslation with provided locale', () => { - render() - - // Translation function should be called for labels - expect(mockTranslation).toHaveBeenCalledWith('marketplace.sortBy', { ns: 'plugin' }) - }) - - it('should render without locale prop (undefined)', () => { - render() - - expect(screen.getByText('Sort by')).toBeInTheDocument() - }) - - it('should render with empty string locale', () => { - render() - - expect(screen.getByText('Sort by')).toBeInTheDocument() - }) - }) - // ================================ // State Management Tests // ================================ @@ -618,13 +588,6 @@ describe('SortDropdown', () => { expect(mockTranslation).toHaveBeenCalledWith('marketplace.sortOption.newlyReleased', { ns: 'plugin' }) expect(mockTranslation).toHaveBeenCalledWith('marketplace.sortOption.firstReleased', { ns: 'plugin' }) }) - - it('should pass locale to useMixedTranslation', () => { - render() - - // Verify component renders with locale - expect(screen.getByTestId('portal-wrapper')).toBeInTheDocument() - }) }) // ================================ diff --git a/web/app/components/plugins/marketplace/sort-dropdown/index.tsx b/web/app/components/plugins/marketplace/sort-dropdown/index.tsx index a1f6631735..984b114d03 100644 --- a/web/app/components/plugins/marketplace/sort-dropdown/index.tsx +++ b/web/app/components/plugins/marketplace/sort-dropdown/index.tsx @@ -1,4 +1,5 @@ 'use client' +import { useTranslation } from '#i18n' import { RiArrowDownSLine, RiCheckLine, @@ -9,16 +10,10 @@ import { PortalToFollowElemContent, PortalToFollowElemTrigger, } from '@/app/components/base/portal-to-follow-elem' -import { useMixedTranslation } from '@/app/components/plugins/marketplace/hooks' import { useMarketplaceContext } from '../context' -type SortDropdownProps = { - locale?: string -} -const SortDropdown = ({ - locale, -}: SortDropdownProps) => { - const { t } = useMixedTranslation(locale) +const SortDropdown = () => { + const { t } = useTranslation() const options = [ { value: 'install_count', diff --git a/web/app/components/plugins/marketplace/sticky-search-and-switch-wrapper.tsx b/web/app/components/plugins/marketplace/sticky-search-and-switch-wrapper.tsx index 602a1e9af2..3d3530c83e 100644 --- a/web/app/components/plugins/marketplace/sticky-search-and-switch-wrapper.tsx +++ b/web/app/components/plugins/marketplace/sticky-search-and-switch-wrapper.tsx @@ -5,13 +5,11 @@ import PluginTypeSwitch from './plugin-type-switch' import SearchBoxWrapper from './search-box/search-box-wrapper' type StickySearchAndSwitchWrapperProps = { - locale?: string pluginTypeSwitchClassName?: string showSearchParams?: boolean } const StickySearchAndSwitchWrapper = ({ - locale, pluginTypeSwitchClassName, showSearchParams, }: StickySearchAndSwitchWrapperProps) => { @@ -25,9 +23,8 @@ const StickySearchAndSwitchWrapper = ({ pluginTypeSwitchClassName, )} > - +
diff --git a/web/app/components/plugins/plugin-item/index.tsx b/web/app/components/plugins/plugin-item/index.tsx index d287bd9e9a..3f658c63a8 100644 --- a/web/app/components/plugins/plugin-item/index.tsx +++ b/web/app/components/plugins/plugin-item/index.tsx @@ -44,7 +44,7 @@ const PluginItem: FC = ({ }) => { const { t } = useTranslation() const { theme } = useTheme() - const { categoriesMap } = useCategories(t, true) + const { categoriesMap } = useCategories(true) const currentPluginID = usePluginPageContext(v => v.currentPluginID) const setCurrentPluginID = usePluginPageContext(v => v.setCurrentPluginID) const { refreshPluginList } = useRefreshPluginList() diff --git a/web/app/components/plugins/provider-card.tsx b/web/app/components/plugins/provider-card.tsx index d76e222c4a..a3bba8d774 100644 --- a/web/app/components/plugins/provider-card.tsx +++ b/web/app/components/plugins/provider-card.tsx @@ -92,7 +92,7 @@ const ProviderCardComponent: FC = ({ manifest={payload} uniqueIdentifier={payload.latest_package_identifier} onClose={hideInstallFromMarketplace} - onSuccess={hideInstallFromMarketplace} + onSuccess={() => hideInstallFromMarketplace()} /> ) } diff --git a/web/app/components/provider/i18n-server.tsx b/web/app/components/provider/i18n-server.tsx new file mode 100644 index 0000000000..23391cf428 --- /dev/null +++ b/web/app/components/provider/i18n-server.tsx @@ -0,0 +1,21 @@ +import { getLocaleOnServer, getResources } from '@/i18n-config/server' + +import { I18nClientProvider } from './i18n' + +export async function I18nServerProvider({ + children, +}: { + children: React.ReactNode +}) { + const locale = await getLocaleOnServer() + const resource = await getResources(locale) + + return ( + + {children} + + ) +} diff --git a/web/app/components/provider/i18n.tsx b/web/app/components/provider/i18n.tsx new file mode 100644 index 0000000000..6441a09dd3 --- /dev/null +++ b/web/app/components/provider/i18n.tsx @@ -0,0 +1,24 @@ +'use client' + +import type { Resource } from 'i18next' +import type { Locale } from '@/i18n-config' +import { I18nextProvider } from 'react-i18next' +import { createI18nextInstance } from '@/i18n-config/client' + +export function I18nClientProvider({ + locale, + resource, + children, +}: { + locale: Locale + resource: Resource + children: React.ReactNode +}) { + const i18n = createI18nextInstance(locale, resource) + + return ( + + {children} + + ) +} diff --git a/web/app/components/rag-pipeline/components/chunk-card-list/index.spec.tsx b/web/app/components/rag-pipeline/components/chunk-card-list/index.spec.tsx new file mode 100644 index 0000000000..e665cf134e --- /dev/null +++ b/web/app/components/rag-pipeline/components/chunk-card-list/index.spec.tsx @@ -0,0 +1,1164 @@ +import type { GeneralChunks, ParentChildChunk, ParentChildChunks, QAChunk, QAChunks } from './types' +import { render, screen } from '@testing-library/react' +import { ChunkingMode } from '@/models/datasets' +import ChunkCard from './chunk-card' +import { ChunkCardList } from './index' +import QAItem from './q-a-item' +import { QAItemType } from './types' + +// ============================================================================= +// Test Data Factories +// ============================================================================= + +const createGeneralChunks = (overrides: string[] = []): GeneralChunks => { + if (overrides.length > 0) + return overrides + return [ + 'This is the first chunk of text content.', + 'This is the second chunk with different content.', + 'Third chunk here with more text.', + ] +} + +const createParentChildChunk = (overrides: Partial = {}): ParentChildChunk => ({ + child_contents: ['Child content 1', 'Child content 2'], + parent_content: 'This is the parent content that contains the children.', + parent_mode: 'paragraph', + ...overrides, +}) + +const createParentChildChunks = (overrides: Partial = {}): ParentChildChunks => ({ + parent_child_chunks: [ + createParentChildChunk(), + createParentChildChunk({ + child_contents: ['Another child 1', 'Another child 2', 'Another child 3'], + parent_content: 'Another parent content here.', + }), + ], + parent_mode: 'paragraph', + ...overrides, +}) + +const createQAChunk = (overrides: Partial = {}): QAChunk => ({ + question: 'What is the answer to life?', + answer: 'The answer is 42.', + ...overrides, +}) + +const createQAChunks = (overrides: Partial = {}): QAChunks => ({ + qa_chunks: [ + createQAChunk(), + createQAChunk({ + question: 'How does this work?', + answer: 'It works by processing data.', + }), + ], + ...overrides, +}) + +// ============================================================================= +// QAItem Component Tests +// ============================================================================= + +describe('QAItem', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + // Tests for basic rendering of QAItem component + describe('Rendering', () => { + it('should render question type with Q prefix', () => { + // Arrange & Act + render() + + // Assert + expect(screen.getByText('Q')).toBeInTheDocument() + expect(screen.getByText('What is this?')).toBeInTheDocument() + }) + + it('should render answer type with A prefix', () => { + // Arrange & Act + render() + + // Assert + expect(screen.getByText('A')).toBeInTheDocument() + expect(screen.getByText('This is the answer.')).toBeInTheDocument() + }) + }) + + // Tests for different prop variations + describe('Props', () => { + it('should render with empty text', () => { + // Arrange & Act + render() + + // Assert + expect(screen.getByText('Q')).toBeInTheDocument() + }) + + it('should render with long text content', () => { + // Arrange + const longText = 'A'.repeat(1000) + + // Act + render() + + // Assert + expect(screen.getByText(longText)).toBeInTheDocument() + }) + + it('should render with special characters in text', () => { + // Arrange + const specialText = ' & "quotes" \'apostrophe\'' + + // Act + render() + + // Assert + expect(screen.getByText(specialText)).toBeInTheDocument() + }) + }) + + // Tests for memoization behavior + describe('Memoization', () => { + it('should be memoized with React.memo', () => { + // Arrange & Act + const { rerender } = render() + + // Assert - component should render consistently + expect(screen.getByText('Q')).toBeInTheDocument() + expect(screen.getByText('Test')).toBeInTheDocument() + + // Rerender with same props - should not cause issues + rerender() + expect(screen.getByText('Q')).toBeInTheDocument() + }) + }) +}) + +// ============================================================================= +// ChunkCard Component Tests +// ============================================================================= + +describe('ChunkCard', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + // Tests for basic rendering with different chunk types + describe('Rendering', () => { + it('should render text chunk type correctly', () => { + // Arrange & Act + render( + , + ) + + // Assert + expect(screen.getByText('This is plain text content.')).toBeInTheDocument() + expect(screen.getByText(/Chunk-01/)).toBeInTheDocument() + }) + + it('should render QA chunk type with question and answer', () => { + // Arrange + const qaContent: QAChunk = { + question: 'What is React?', + answer: 'React is a JavaScript library.', + } + + // Act + render( + , + ) + + // Assert + expect(screen.getByText('Q')).toBeInTheDocument() + expect(screen.getByText('What is React?')).toBeInTheDocument() + expect(screen.getByText('A')).toBeInTheDocument() + expect(screen.getByText('React is a JavaScript library.')).toBeInTheDocument() + }) + + it('should render parent-child chunk type with child contents', () => { + // Arrange + const childContents = ['Child 1 content', 'Child 2 content'] + + // Act + render( + , + ) + + // Assert + expect(screen.getByText('Child 1 content')).toBeInTheDocument() + expect(screen.getByText('Child 2 content')).toBeInTheDocument() + expect(screen.getByText('C-1')).toBeInTheDocument() + expect(screen.getByText('C-2')).toBeInTheDocument() + }) + }) + + // Tests for parent mode variations + describe('Parent Mode Variations', () => { + it('should show Parent-Chunk label prefix for paragraph mode', () => { + // Arrange & Act + render( + , + ) + + // Assert + expect(screen.getByText(/Parent-Chunk-01/)).toBeInTheDocument() + }) + + it('should hide segment index tag for full-doc mode', () => { + // Arrange & Act + render( + , + ) + + // Assert - should not show Chunk or Parent-Chunk label + expect(screen.queryByText(/Chunk/)).not.toBeInTheDocument() + expect(screen.queryByText(/Parent-Chunk/)).not.toBeInTheDocument() + }) + + it('should show Chunk label prefix for text mode', () => { + // Arrange & Act + render( + , + ) + + // Assert + expect(screen.getByText(/Chunk-05/)).toBeInTheDocument() + }) + }) + + // Tests for word count display + describe('Word Count Display', () => { + it('should display formatted word count', () => { + // Arrange & Act + render( + , + ) + + // Assert - formatNumber(1234) returns '1,234' + expect(screen.getByText(/1,234/)).toBeInTheDocument() + }) + + it('should display word count with character translation key', () => { + // Arrange & Act + render( + , + ) + + // Assert - translation key is returned as-is by mock + expect(screen.getByText(/100\s+(?:\S.*)?characters/)).toBeInTheDocument() + }) + + it('should not display word count info for full-doc mode', () => { + // Arrange & Act + render( + , + ) + + // Assert - the header with word count should be hidden + expect(screen.queryByText(/500/)).not.toBeInTheDocument() + }) + }) + + // Tests for position ID variations + describe('Position ID', () => { + it('should handle numeric position ID', () => { + // Arrange & Act + render( + , + ) + + // Assert + expect(screen.getByText(/Chunk-42/)).toBeInTheDocument() + }) + + it('should handle string position ID', () => { + // Arrange & Act + render( + , + ) + + // Assert + expect(screen.getByText(/Chunk-99/)).toBeInTheDocument() + }) + + it('should pad single digit position ID', () => { + // Arrange & Act + render( + , + ) + + // Assert + expect(screen.getByText(/Chunk-03/)).toBeInTheDocument() + }) + }) + + // Tests for memoization dependencies + describe('Memoization', () => { + it('should update isFullDoc memo when parentMode changes', () => { + // Arrange + const { rerender } = render( + , + ) + + // Assert - paragraph mode shows label + expect(screen.getByText(/Parent-Chunk/)).toBeInTheDocument() + + // Act - change to full-doc + rerender( + , + ) + + // Assert - full-doc mode hides label + expect(screen.queryByText(/Parent-Chunk/)).not.toBeInTheDocument() + }) + + it('should update contentElement memo when content changes', () => { + // Arrange + const { rerender } = render( + , + ) + + // Assert + expect(screen.getByText('Initial content')).toBeInTheDocument() + + // Act + rerender( + , + ) + + // Assert + expect(screen.getByText('Updated content')).toBeInTheDocument() + expect(screen.queryByText('Initial content')).not.toBeInTheDocument() + }) + + it('should update contentElement memo when chunkType changes', () => { + // Arrange + const { rerender } = render( + , + ) + + // Assert + expect(screen.getByText('Text content')).toBeInTheDocument() + + // Act - change to QA type + const qaContent: QAChunk = { question: 'Q?', answer: 'A.' } + rerender( + , + ) + + // Assert + expect(screen.getByText('Q')).toBeInTheDocument() + expect(screen.getByText('Q?')).toBeInTheDocument() + }) + }) + + // Tests for edge cases + describe('Edge Cases', () => { + it('should handle empty child contents array', () => { + // Arrange & Act + render( + , + ) + + // Assert - should render without errors + expect(screen.getByText(/Parent-Chunk-01/)).toBeInTheDocument() + }) + + it('should handle QA chunk with empty strings', () => { + // Arrange + const emptyQA: QAChunk = { question: '', answer: '' } + + // Act + render( + , + ) + + // Assert + expect(screen.getByText('Q')).toBeInTheDocument() + expect(screen.getByText('A')).toBeInTheDocument() + }) + + it('should handle very long content', () => { + // Arrange + const longContent = 'A'.repeat(10000) + + // Act + render( + , + ) + + // Assert + expect(screen.getByText(longContent)).toBeInTheDocument() + }) + + it('should handle zero word count', () => { + // Arrange & Act + render( + , + ) + + // Assert - formatNumber returns falsy for 0, so it shows 0 + expect(screen.getByText(/0\s+(?:\S.*)?characters/)).toBeInTheDocument() + }) + }) +}) + +// ============================================================================= +// ChunkCardList Component Tests +// ============================================================================= + +describe('ChunkCardList', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + // Tests for rendering with different chunk types + describe('Rendering', () => { + it('should render text chunks correctly', () => { + // Arrange + const chunks = createGeneralChunks() + + // Act + render( + , + ) + + // Assert + expect(screen.getByText(chunks[0])).toBeInTheDocument() + expect(screen.getByText(chunks[1])).toBeInTheDocument() + expect(screen.getByText(chunks[2])).toBeInTheDocument() + }) + + it('should render parent-child chunks correctly', () => { + // Arrange + const chunks = createParentChildChunks() + + // Act + render( + , + ) + + // Assert - should render child contents from parent-child chunks + expect(screen.getByText('Child content 1')).toBeInTheDocument() + expect(screen.getByText('Child content 2')).toBeInTheDocument() + expect(screen.getByText('Another child 1')).toBeInTheDocument() + }) + + it('should render QA chunks correctly', () => { + // Arrange + const chunks = createQAChunks() + + // Act + render( + , + ) + + // Assert + expect(screen.getByText('What is the answer to life?')).toBeInTheDocument() + expect(screen.getByText('The answer is 42.')).toBeInTheDocument() + expect(screen.getByText('How does this work?')).toBeInTheDocument() + expect(screen.getByText('It works by processing data.')).toBeInTheDocument() + }) + }) + + // Tests for chunkList memoization + describe('Memoization - chunkList', () => { + it('should extract chunks from GeneralChunks for text mode', () => { + // Arrange + const chunks: GeneralChunks = ['Chunk 1', 'Chunk 2'] + + // Act + render( + , + ) + + // Assert + expect(screen.getByText('Chunk 1')).toBeInTheDocument() + expect(screen.getByText('Chunk 2')).toBeInTheDocument() + }) + + it('should extract parent_child_chunks from ParentChildChunks for parentChild mode', () => { + // Arrange + const chunks = createParentChildChunks({ + parent_child_chunks: [ + createParentChildChunk({ child_contents: ['Specific child'] }), + ], + }) + + // Act + render( + , + ) + + // Assert + expect(screen.getByText('Specific child')).toBeInTheDocument() + }) + + it('should extract qa_chunks from QAChunks for qa mode', () => { + // Arrange + const chunks: QAChunks = { + qa_chunks: [ + { question: 'Specific Q', answer: 'Specific A' }, + ], + } + + // Act + render( + , + ) + + // Assert + expect(screen.getByText('Specific Q')).toBeInTheDocument() + expect(screen.getByText('Specific A')).toBeInTheDocument() + }) + + it('should update chunkList when chunkInfo changes', () => { + // Arrange + const initialChunks = createGeneralChunks(['Initial chunk']) + + const { rerender } = render( + , + ) + + // Assert initial state + expect(screen.getByText('Initial chunk')).toBeInTheDocument() + + // Act - update chunks + const updatedChunks = createGeneralChunks(['Updated chunk']) + rerender( + , + ) + + // Assert updated state + expect(screen.getByText('Updated chunk')).toBeInTheDocument() + expect(screen.queryByText('Initial chunk')).not.toBeInTheDocument() + }) + }) + + // Tests for getWordCount function + describe('Word Count Calculation', () => { + it('should calculate word count for text chunks using string length', () => { + // Arrange - "Hello" has 5 characters + const chunks = createGeneralChunks(['Hello']) + + // Act + render( + , + ) + + // Assert - word count should be 5 (string length) + expect(screen.getByText(/5\s+(?:\S.*)?characters/)).toBeInTheDocument() + }) + + it('should calculate word count for parent-child chunks using parent_content length', () => { + // Arrange - parent_content length determines word count + const chunks = createParentChildChunks({ + parent_child_chunks: [ + createParentChildChunk({ + parent_content: 'Parent', // 6 characters + child_contents: ['Child'], + }), + ], + }) + + // Act + render( + , + ) + + // Assert - word count should be 6 (parent_content length) + expect(screen.getByText(/6\s+(?:\S.*)?characters/)).toBeInTheDocument() + }) + + it('should calculate word count for QA chunks using question + answer length', () => { + // Arrange - "Hi" (2) + "Bye" (3) = 5 + const chunks: QAChunks = { + qa_chunks: [ + { question: 'Hi', answer: 'Bye' }, + ], + } + + // Act + render( + , + ) + + // Assert - word count should be 5 (question.length + answer.length) + expect(screen.getByText(/5\s+(?:\S.*)?characters/)).toBeInTheDocument() + }) + }) + + // Tests for position ID assignment + describe('Position ID', () => { + it('should assign 1-based position IDs to chunks', () => { + // Arrange + const chunks = createGeneralChunks(['First', 'Second', 'Third']) + + // Act + render( + , + ) + + // Assert - position IDs should be 1, 2, 3 + expect(screen.getByText(/Chunk-01/)).toBeInTheDocument() + expect(screen.getByText(/Chunk-02/)).toBeInTheDocument() + expect(screen.getByText(/Chunk-03/)).toBeInTheDocument() + }) + }) + + // Tests for className prop + describe('Custom className', () => { + it('should apply custom className to container', () => { + // Arrange + const chunks = createGeneralChunks(['Test']) + + // Act + const { container } = render( + , + ) + + // Assert + expect(container.firstChild).toHaveClass('custom-class') + }) + + it('should merge custom className with default classes', () => { + // Arrange + const chunks = createGeneralChunks(['Test']) + + // Act + const { container } = render( + , + ) + + // Assert - should have both default and custom classes + expect(container.firstChild).toHaveClass('flex') + expect(container.firstChild).toHaveClass('w-full') + expect(container.firstChild).toHaveClass('flex-col') + expect(container.firstChild).toHaveClass('my-custom-class') + }) + + it('should render without className prop', () => { + // Arrange + const chunks = createGeneralChunks(['Test']) + + // Act + const { container } = render( + , + ) + + // Assert - should have default classes + expect(container.firstChild).toHaveClass('flex') + expect(container.firstChild).toHaveClass('w-full') + }) + }) + + // Tests for parentMode prop + describe('Parent Mode', () => { + it('should pass parentMode to ChunkCard for parent-child type', () => { + // Arrange + const chunks = createParentChildChunks() + + // Act + render( + , + ) + + // Assert - paragraph mode shows Parent-Chunk label + expect(screen.getAllByText(/Parent-Chunk/).length).toBeGreaterThan(0) + }) + + it('should handle full-doc parentMode', () => { + // Arrange + const chunks = createParentChildChunks() + + // Act + render( + , + ) + + // Assert - full-doc mode hides chunk labels + expect(screen.queryByText(/Parent-Chunk/)).not.toBeInTheDocument() + expect(screen.queryByText(/Chunk-/)).not.toBeInTheDocument() + }) + + it('should not use parentMode for text type', () => { + // Arrange + const chunks = createGeneralChunks(['Text']) + + // Act + render( + , + ) + + // Assert - should show Chunk label, not affected by parentMode + expect(screen.getByText(/Chunk-01/)).toBeInTheDocument() + }) + }) + + // Tests for edge cases + describe('Edge Cases', () => { + it('should handle empty GeneralChunks array', () => { + // Arrange + const chunks: GeneralChunks = [] + + // Act + const { container } = render( + , + ) + + // Assert - should render empty container + expect(container.firstChild).toBeInTheDocument() + expect(container.firstChild?.childNodes.length).toBe(0) + }) + + it('should handle empty ParentChildChunks', () => { + // Arrange + const chunks: ParentChildChunks = { + parent_child_chunks: [], + parent_mode: 'paragraph', + } + + // Act + const { container } = render( + , + ) + + // Assert + expect(container.firstChild).toBeInTheDocument() + expect(container.firstChild?.childNodes.length).toBe(0) + }) + + it('should handle empty QAChunks', () => { + // Arrange + const chunks: QAChunks = { + qa_chunks: [], + } + + // Act + const { container } = render( + , + ) + + // Assert + expect(container.firstChild).toBeInTheDocument() + expect(container.firstChild?.childNodes.length).toBe(0) + }) + + it('should handle single item in chunks', () => { + // Arrange + const chunks = createGeneralChunks(['Single chunk']) + + // Act + render( + , + ) + + // Assert + expect(screen.getByText('Single chunk')).toBeInTheDocument() + expect(screen.getByText(/Chunk-01/)).toBeInTheDocument() + }) + + it('should handle large number of chunks', () => { + // Arrange + const chunks = Array.from({ length: 100 }, (_, i) => `Chunk number ${i + 1}`) + + // Act + render( + , + ) + + // Assert + expect(screen.getByText('Chunk number 1')).toBeInTheDocument() + expect(screen.getByText('Chunk number 100')).toBeInTheDocument() + expect(screen.getByText(/Chunk-100/)).toBeInTheDocument() + }) + }) + + // Tests for key uniqueness + describe('Key Generation', () => { + it('should generate unique keys for chunks', () => { + // Arrange - chunks with same content + const chunks = createGeneralChunks(['Same content', 'Same content', 'Same content']) + + // Act + const { container } = render( + , + ) + + // Assert - all three should render (keys are based on chunkType-index) + const chunkCards = container.querySelectorAll('.bg-components-panel-bg') + expect(chunkCards.length).toBe(3) + }) + }) +}) + +// ============================================================================= +// Integration Tests +// ============================================================================= + +describe('ChunkCardList Integration', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + // Tests for complete workflow scenarios + describe('Complete Workflows', () => { + it('should render complete text chunking workflow', () => { + // Arrange + const textChunks = createGeneralChunks([ + 'First paragraph of the document.', + 'Second paragraph with more information.', + 'Final paragraph concluding the content.', + ]) + + // Act + render( + , + ) + + // Assert + expect(screen.getByText('First paragraph of the document.')).toBeInTheDocument() + expect(screen.getByText(/Chunk-01/)).toBeInTheDocument() + // "First paragraph of the document." = 32 characters + expect(screen.getByText(/32\s+(?:\S.*)?characters/)).toBeInTheDocument() + + expect(screen.getByText('Second paragraph with more information.')).toBeInTheDocument() + expect(screen.getByText(/Chunk-02/)).toBeInTheDocument() + + expect(screen.getByText('Final paragraph concluding the content.')).toBeInTheDocument() + expect(screen.getByText(/Chunk-03/)).toBeInTheDocument() + }) + + it('should render complete parent-child chunking workflow', () => { + // Arrange + const parentChildChunks = createParentChildChunks({ + parent_child_chunks: [ + { + parent_content: 'Main section about React components and their lifecycle.', + child_contents: [ + 'React components are building blocks.', + 'Lifecycle methods control component behavior.', + ], + parent_mode: 'paragraph', + }, + ], + }) + + // Act + render( + , + ) + + // Assert + expect(screen.getByText('React components are building blocks.')).toBeInTheDocument() + expect(screen.getByText('Lifecycle methods control component behavior.')).toBeInTheDocument() + expect(screen.getByText('C-1')).toBeInTheDocument() + expect(screen.getByText('C-2')).toBeInTheDocument() + expect(screen.getByText(/Parent-Chunk-01/)).toBeInTheDocument() + }) + + it('should render complete QA chunking workflow', () => { + // Arrange + const qaChunks = createQAChunks({ + qa_chunks: [ + { + question: 'What is Dify?', + answer: 'Dify is an open-source LLM application development platform.', + }, + { + question: 'How do I get started?', + answer: 'You can start by installing the platform using Docker.', + }, + ], + }) + + // Act + render( + , + ) + + // Assert + const qLabels = screen.getAllByText('Q') + const aLabels = screen.getAllByText('A') + expect(qLabels.length).toBe(2) + expect(aLabels.length).toBe(2) + + expect(screen.getByText('What is Dify?')).toBeInTheDocument() + expect(screen.getByText('Dify is an open-source LLM application development platform.')).toBeInTheDocument() + expect(screen.getByText('How do I get started?')).toBeInTheDocument() + expect(screen.getByText('You can start by installing the platform using Docker.')).toBeInTheDocument() + }) + }) + + // Tests for type switching scenarios + describe('Type Switching', () => { + it('should handle switching from text to QA type', () => { + // Arrange + const textChunks = createGeneralChunks(['Text content']) + const qaChunks = createQAChunks() + + const { rerender } = render( + , + ) + + // Assert initial text state + expect(screen.getByText('Text content')).toBeInTheDocument() + + // Act - switch to QA + rerender( + , + ) + + // Assert QA state + expect(screen.queryByText('Text content')).not.toBeInTheDocument() + expect(screen.getByText('What is the answer to life?')).toBeInTheDocument() + }) + + it('should handle switching from text to parent-child type', () => { + // Arrange + const textChunks = createGeneralChunks(['Simple text']) + const parentChildChunks = createParentChildChunks() + + const { rerender } = render( + , + ) + + // Assert initial state + expect(screen.getByText('Simple text')).toBeInTheDocument() + expect(screen.getByText(/Chunk-01/)).toBeInTheDocument() + + // Act - switch to parent-child + rerender( + , + ) + + // Assert parent-child state + expect(screen.queryByText('Simple text')).not.toBeInTheDocument() + // Multiple Parent-Chunk elements exist, so use getAllByText + expect(screen.getAllByText(/Parent-Chunk/).length).toBeGreaterThan(0) + }) + }) +}) diff --git a/web/app/components/rag-pipeline/components/index.spec.tsx b/web/app/components/rag-pipeline/components/index.spec.tsx new file mode 100644 index 0000000000..3f6b0dccc2 --- /dev/null +++ b/web/app/components/rag-pipeline/components/index.spec.tsx @@ -0,0 +1,1390 @@ +import type { PropsWithChildren } from 'react' +import type { EnvironmentVariable } from '@/app/components/workflow/types' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import { createMockProviderContextValue } from '@/__mocks__/provider-context' + +// ============================================================================ +// Import Components After Mocks Setup +// ============================================================================ + +import Conversion from './conversion' +import RagPipelinePanel from './panel' +import PublishAsKnowledgePipelineModal from './publish-as-knowledge-pipeline-modal' +import PublishToast from './publish-toast' +import RagPipelineChildren from './rag-pipeline-children' +import PipelineScreenShot from './screenshot' + +// ============================================================================ +// Mock External Dependencies - All vi.mock calls must come before any imports +// ============================================================================ + +// Mock next/navigation +const mockPush = vi.fn() +vi.mock('next/navigation', () => ({ + useParams: () => ({ datasetId: 'test-dataset-id' }), + useRouter: () => ({ push: mockPush }), +})) + +// Mock next/image +vi.mock('next/image', () => ({ + default: ({ src, alt, width, height }: { src: string, alt: string, width: number, height: number }) => ( + // eslint-disable-next-line next/no-img-element + {alt} + ), +})) + +// Mock next/dynamic +vi.mock('next/dynamic', () => ({ + default: (importFn: () => Promise<{ default: React.ComponentType }>, options?: { ssr?: boolean }) => { + const DynamicComponent = ({ children, ...props }: PropsWithChildren) => { + return
{children}
+ } + DynamicComponent.displayName = 'DynamicComponent' + return DynamicComponent + }, +})) + +// Mock workflow store - using controllable state +let mockShowImportDSLModal = false +const mockSetShowImportDSLModal = vi.fn((value: boolean) => { + mockShowImportDSLModal = value +}) +vi.mock('@/app/components/workflow/store', () => { + const mockSetShowInputFieldPanel = vi.fn() + const mockSetShowEnvPanel = vi.fn() + const mockSetShowDebugAndPreviewPanel = vi.fn() + const mockSetIsPreparingDataSource = vi.fn() + const mockSetPublishedAt = vi.fn() + const mockSetRagPipelineVariables = vi.fn() + const mockSetEnvironmentVariables = vi.fn() + + return { + useStore: (selector: (state: Record) => unknown) => { + const storeState = { + pipelineId: 'test-pipeline-id', + showDebugAndPreviewPanel: false, + showGlobalVariablePanel: false, + showInputFieldPanel: false, + showInputFieldPreviewPanel: false, + inputFieldEditPanelProps: null as null | object, + historyWorkflowData: null as null | object, + publishedAt: 0, + draftUpdatedAt: Date.now(), + knowledgeName: 'Test Knowledge', + knowledgeIcon: { + icon_type: 'emoji' as const, + icon: '📚', + icon_background: '#FFFFFF', + icon_url: '', + }, + showImportDSLModal: mockShowImportDSLModal, + setShowInputFieldPanel: mockSetShowInputFieldPanel, + setShowEnvPanel: mockSetShowEnvPanel, + setShowDebugAndPreviewPanel: mockSetShowDebugAndPreviewPanel, + setIsPreparingDataSource: mockSetIsPreparingDataSource, + setPublishedAt: mockSetPublishedAt, + setRagPipelineVariables: mockSetRagPipelineVariables, + setEnvironmentVariables: mockSetEnvironmentVariables, + setShowImportDSLModal: mockSetShowImportDSLModal, + } + return selector(storeState) + }, + useWorkflowStore: () => ({ + getState: () => ({ + pipelineId: 'test-pipeline-id', + setIsPreparingDataSource: mockSetIsPreparingDataSource, + setShowDebugAndPreviewPanel: mockSetShowDebugAndPreviewPanel, + setPublishedAt: mockSetPublishedAt, + setRagPipelineVariables: mockSetRagPipelineVariables, + setEnvironmentVariables: mockSetEnvironmentVariables, + }), + }), + } +}) + +// Mock workflow hooks - extract mock functions for assertions using vi.hoisted +const { + mockHandlePaneContextmenuCancel, + mockExportCheck, + mockHandleExportDSL, +} = vi.hoisted(() => ({ + mockHandlePaneContextmenuCancel: vi.fn(), + mockExportCheck: vi.fn(), + mockHandleExportDSL: vi.fn(), +})) +vi.mock('@/app/components/workflow/hooks', () => { + return { + useNodesSyncDraft: () => ({ + doSyncWorkflowDraft: vi.fn(), + syncWorkflowDraftWhenPageClose: vi.fn(), + handleSyncWorkflowDraft: vi.fn(), + }), + usePanelInteractions: () => ({ + handlePaneContextmenuCancel: mockHandlePaneContextmenuCancel, + }), + useDSL: () => ({ + exportCheck: mockExportCheck, + handleExportDSL: mockHandleExportDSL, + }), + useChecklistBeforePublish: () => ({ + handleCheckBeforePublish: vi.fn().mockResolvedValue(true), + }), + useWorkflowRun: () => ({ + handleStopRun: vi.fn(), + }), + useWorkflowStartRun: () => ({ + handleWorkflowStartRunInWorkflow: vi.fn(), + }), + } +}) + +// Mock rag-pipeline hooks +vi.mock('../hooks', () => ({ + useAvailableNodesMetaData: () => ({}), + useDSL: () => ({ + exportCheck: mockExportCheck, + handleExportDSL: mockHandleExportDSL, + }), + useNodesSyncDraft: () => ({ + doSyncWorkflowDraft: vi.fn(), + syncWorkflowDraftWhenPageClose: vi.fn(), + }), + usePipelineRefreshDraft: () => ({ + handleRefreshWorkflowDraft: vi.fn(), + }), + usePipelineRun: () => ({ + handleBackupDraft: vi.fn(), + handleLoadBackupDraft: vi.fn(), + handleRestoreFromPublishedWorkflow: vi.fn(), + handleRun: vi.fn(), + handleStopRun: vi.fn(), + }), + usePipelineStartRun: () => ({ + handleStartWorkflowRun: vi.fn(), + handleWorkflowStartRunInWorkflow: vi.fn(), + }), + useGetRunAndTraceUrl: () => ({ + getWorkflowRunAndTraceUrl: vi.fn(), + }), +})) + +// Mock rag-pipeline search hook +vi.mock('../hooks/use-rag-pipeline-search', () => ({ + useRagPipelineSearch: vi.fn(), +})) + +// Mock configs-map hook +vi.mock('../hooks/use-configs-map', () => ({ + useConfigsMap: () => ({}), +})) + +// Mock inspect-vars-crud hook +vi.mock('../hooks/use-inspect-vars-crud', () => ({ + useInspectVarsCrud: () => ({ + hasNodeInspectVars: vi.fn(), + hasSetInspectVar: vi.fn(), + fetchInspectVarValue: vi.fn(), + editInspectVarValue: vi.fn(), + renameInspectVarName: vi.fn(), + appendNodeInspectVars: vi.fn(), + deleteInspectVar: vi.fn(), + deleteNodeInspectorVars: vi.fn(), + deleteAllInspectorVars: vi.fn(), + isInspectVarEdited: vi.fn(), + resetToLastRunVar: vi.fn(), + invalidateSysVarValues: vi.fn(), + resetConversationVar: vi.fn(), + invalidateConversationVarValues: vi.fn(), + }), +})) + +// Mock workflow hooks for fetch-workflow-inspect-vars +vi.mock('@/app/components/workflow/hooks/use-fetch-workflow-inspect-vars', () => ({ + useSetWorkflowVarsWithValue: () => ({ + fetchInspectVars: vi.fn(), + }), +})) + +// Mock service hooks - with controllable convert function +let mockConvertFn = vi.fn() +let mockIsPending = false +vi.mock('@/service/use-pipeline', () => ({ + useConvertDatasetToPipeline: () => ({ + mutateAsync: mockConvertFn, + isPending: mockIsPending, + }), + useImportPipelineDSL: () => ({ + mutateAsync: vi.fn(), + }), + useImportPipelineDSLConfirm: () => ({ + mutateAsync: vi.fn(), + }), + publishedPipelineInfoQueryKeyPrefix: ['pipeline-info'], + useInvalidCustomizedTemplateList: () => vi.fn(), + usePublishAsCustomizedPipeline: () => ({ + mutateAsync: vi.fn(), + }), +})) + +vi.mock('@/service/use-base', () => ({ + useInvalid: () => vi.fn(), +})) + +vi.mock('@/service/knowledge/use-dataset', () => ({ + datasetDetailQueryKeyPrefix: ['dataset-detail'], + useInvalidDatasetList: () => vi.fn(), +})) + +vi.mock('@/service/workflow', () => ({ + fetchWorkflowDraft: vi.fn().mockResolvedValue({ + graph: { nodes: [], edges: [], viewport: {} }, + hash: 'test-hash', + rag_pipeline_variables: [], + }), +})) + +// Mock event emitter context - with controllable subscription +let mockEventSubscriptionCallback: ((v: { type: string, payload?: { data?: EnvironmentVariable[] } }) => void) | null = null +const mockUseSubscription = vi.fn((callback: (v: { type: string, payload?: { data?: EnvironmentVariable[] } }) => void) => { + mockEventSubscriptionCallback = callback +}) +vi.mock('@/context/event-emitter', () => ({ + useEventEmitterContextContext: () => ({ + eventEmitter: { + useSubscription: mockUseSubscription, + emit: vi.fn(), + }, + }), +})) + +// Mock toast +vi.mock('@/app/components/base/toast', () => ({ + default: { + notify: vi.fn(), + }, + useToastContext: () => ({ + notify: vi.fn(), + }), + ToastContext: { + Provider: ({ children }: PropsWithChildren) => children, + }, +})) + +// Mock useTheme hook +vi.mock('@/hooks/use-theme', () => ({ + default: () => ({ + theme: 'light', + }), +})) + +// Mock basePath +vi.mock('@/utils/var', () => ({ + basePath: '/public', +})) + +// Mock provider context +vi.mock('@/context/provider-context', () => ({ + useProviderContext: () => createMockProviderContextValue(), +})) + +// Mock WorkflowWithInnerContext +vi.mock('@/app/components/workflow', () => ({ + WorkflowWithInnerContext: ({ children }: PropsWithChildren) => ( +
{children}
+ ), +})) + +// Mock workflow panel +vi.mock('@/app/components/workflow/panel', () => ({ + default: ({ components }: { components?: { left?: React.ReactNode, right?: React.ReactNode } }) => ( +
+
{components?.left}
+
{components?.right}
+
+ ), +})) + +// Mock PluginDependency +vi.mock('../../workflow/plugin-dependency', () => ({ + default: () =>
, +})) + +// Mock plugin-dependency hooks +vi.mock('@/app/components/workflow/plugin-dependency/hooks', () => ({ + usePluginDependencies: () => ({ + handleCheckPluginDependencies: vi.fn().mockResolvedValue(undefined), + }), +})) + +// Mock DSLExportConfirmModal +vi.mock('@/app/components/workflow/dsl-export-confirm-modal', () => ({ + default: ({ envList, onConfirm, onClose }: { envList: EnvironmentVariable[], onConfirm: () => void, onClose: () => void }) => ( +
+ {envList.length} + + +
+ ), +})) + +// Mock workflow constants +vi.mock('@/app/components/workflow/constants', () => ({ + DSL_EXPORT_CHECK: 'DSL_EXPORT_CHECK', + WORKFLOW_DATA_UPDATE: 'WORKFLOW_DATA_UPDATE', +})) + +// Mock workflow utils +vi.mock('@/app/components/workflow/utils', () => ({ + initialNodes: vi.fn(nodes => nodes), + initialEdges: vi.fn(edges => edges), + getKeyboardKeyCodeBySystem: (key: string) => key, + getKeyboardKeyNameBySystem: (key: string) => key, +})) + +// Mock Confirm component +vi.mock('@/app/components/base/confirm', () => ({ + default: ({ title, content, isShow, onConfirm, onCancel, isLoading, isDisabled }: { + title: string + content: string + isShow: boolean + onConfirm: () => void + onCancel: () => void + isLoading?: boolean + isDisabled?: boolean + }) => isShow + ? ( +
+
{title}
+
{content}
+ + +
+ ) + : null, +})) + +// Mock Modal component +vi.mock('@/app/components/base/modal', () => ({ + default: ({ children, isShow, onClose, className }: PropsWithChildren<{ + isShow: boolean + onClose: () => void + className?: string + }>) => isShow + ? ( +
e.target === e.currentTarget && onClose()}> + {children} +
+ ) + : null, +})) + +// Mock Input component +vi.mock('@/app/components/base/input', () => ({ + default: ({ value, onChange, placeholder }: { + value: string + onChange: (e: React.ChangeEvent) => void + placeholder?: string + }) => ( + + ), +})) + +// Mock Textarea component +vi.mock('@/app/components/base/textarea', () => ({ + default: ({ value, onChange, placeholder, className }: { + value: string + onChange: (e: React.ChangeEvent) => void + placeholder?: string + className?: string + }) => ( +