diff --git a/.claude/settings.json b/.claude/settings.json index 7d42234cae..509dbe8447 100644 --- a/.claude/settings.json +++ b/.claude/settings.json @@ -3,6 +3,7 @@ "feature-dev@claude-plugins-official": true, "context7@claude-plugins-official": true, "typescript-lsp@claude-plugins-official": true, - "pyright-lsp@claude-plugins-official": true + "pyright-lsp@claude-plugins-official": true, + "ralph-loop@claude-plugins-official": true } } diff --git a/.claude/skills/frontend-code-review/SKILL.md b/.claude/skills/frontend-code-review/SKILL.md new file mode 100644 index 0000000000..6cc23ca171 --- /dev/null +++ b/.claude/skills/frontend-code-review/SKILL.md @@ -0,0 +1,73 @@ +--- +name: frontend-code-review +description: "Trigger when the user requests a review of frontend files (e.g., `.tsx`, `.ts`, `.js`). Support both pending-change reviews and focused file reviews while applying the checklist rules." +--- + +# Frontend Code Review + +## Intent +Use this skill whenever the user asks to review frontend code (especially `.tsx`, `.ts`, or `.js` files). Support two review modes: + +1. **Pending-change review** – inspect staged/working-tree files slated for commit and flag checklist violations before submission. +2. **File-targeted review** – review the specific file(s) the user names and report the relevant checklist findings. + +Stick to the checklist below for every applicable file and mode. + +## Checklist +See [references/code-quality.md](references/code-quality.md), [references/performance.md](references/performance.md), [references/business-logic.md](references/business-logic.md) for the living checklist split by category—treat it as the canonical set of rules to follow. + +Flag each rule violation with urgency metadata so future reviewers can prioritize fixes. + +## Review Process +1. Open the relevant component/module. Gather lines that relate to class names, React Flow hooks, prop memoization, and styling. +2. For each rule in the review point, note where the code deviates and capture a representative snippet. +3. Compose the review section per the template below. Group violations first by **Urgent** flag, then by category order (Code Quality, Performance, Business Logic). + +## Required output +When invoked, the response must exactly follow one of the two templates: + +### Template A (any findings) +``` +# Code review +Found urgent issues need to be fixed: + +## 1 +FilePath: line + + + +### Suggested fix + + +--- +... (repeat for each urgent issue) ... + +Found suggestions for improvement: + +## 1 +FilePath: line + + + +### Suggested fix + + +--- + +... (repeat for each suggestion) ... +``` + +If there are no urgent issues, omit that section. If there are no suggestions, omit that section. + +If the issue number is more than 10, summarize as "10+ urgent issues" or "10+ suggestions" and just output the first 10 issues. + +Don't compress the blank lines between sections; keep them as-is for readability. + +If you use Template A (i.e., there are issues to fix) and at least one issue requires code changes, append a brief follow-up question after the structured output asking whether the user wants you to apply the suggested fix(es). For example: "Would you like me to use the Suggested fix section to address these issues?" + +### Template B (no issues) +``` +## Code review +No issues found. 
+```
diff --git a/.claude/skills/frontend-code-review/references/business-logic.md b/.claude/skills/frontend-code-review/references/business-logic.md
new file mode 100644
index 0000000000..4584f99dfc
--- /dev/null
+++ b/.claude/skills/frontend-code-review/references/business-logic.md
@@ -0,0 +1,15 @@
+# Rule Catalog — Business Logic
+
+## Can't use workflowStore in Node components
+
+IsUrgent: True
+
+### Description
+
+File path pattern of node components: `web/app/components/workflow/nodes/[nodeName]/node.tsx`
+
+Node components are also used when creating a RAG Pipeline from a template, but in that context there is no workflowStore Provider, which results in a blank screen. [This Issue](https://github.com/langgenius/dify/issues/29168) was caused by exactly this.
+
+### Suggested Fix
+
+Use `import { useNodes } from 'reactflow'` instead of `import useNodes from '@/app/components/workflow/store/workflow/use-nodes'`.
diff --git a/.claude/skills/frontend-code-review/references/code-quality.md b/.claude/skills/frontend-code-review/references/code-quality.md
new file mode 100644
index 0000000000..afdd40deb3
--- /dev/null
+++ b/.claude/skills/frontend-code-review/references/code-quality.md
@@ -0,0 +1,44 @@
+# Rule Catalog — Code Quality
+
+## Conditional class names use utility function
+
+IsUrgent: True
+Category: Code Quality
+
+### Description
+
+Ensure conditional CSS is handled via the shared `cn` utility instead of custom ternaries, string concatenation, or template strings. Centralizing class logic keeps components consistent and easier to maintain.
+
+### Suggested Fix
+
+```ts
+import { cn } from '@/utils/classnames'
+const classNames = cn(isActive ? 'text-primary-600' : 'text-gray-500')
+```
+
+## Tailwind-first styling
+
+IsUrgent: True
+Category: Code Quality
+
+### Description
+
+Favor Tailwind CSS utility classes instead of adding new `.module.css` files unless a Tailwind combination cannot achieve the required styling. Keeping styles in Tailwind improves consistency and reduces maintenance overhead.
+
+Update this file when adding, editing, or removing Code Quality rules so the catalog remains accurate.
+
+## Classname ordering for easy overrides
+
+### Description
+
+When writing components, always place the incoming `className` prop after the component’s own class values so that downstream consumers can override or extend the styling. This keeps your component’s defaults but still lets external callers change or remove specific styles.
+
+Example:
+
+```tsx
+import { cn } from '@/utils/classnames'
+
+const Button = ({ className }) => {
+  return <div className={cn('px-3 py-2', className)} />
+}
+```
diff --git a/.claude/skills/frontend-code-review/references/performance.md b/.claude/skills/frontend-code-review/references/performance.md
new file mode 100644
index 0000000000..2d60072f5c
--- /dev/null
+++ b/.claude/skills/frontend-code-review/references/performance.md
@@ -0,0 +1,45 @@
+# Rule Catalog — Performance
+
+## React Flow data usage
+
+IsUrgent: True
+Category: Performance
+
+### Description
+
+When rendering React Flow, prefer `useNodes`/`useEdges` for UI consumption and rely on `useStoreApi` inside callbacks that mutate or read node/edge state. Avoid manually pulling Flow data outside of these hooks.
+
+## Complex prop memoization
+
+IsUrgent: True
+Category: Performance
+
+### Description
+
+Wrap complex prop values (objects, arrays, maps) in `useMemo` prior to passing them into child components to guarantee stable references and prevent unnecessary renders.
+
+Update this file when adding, editing, or removing Performance rules so the catalog remains accurate.
+
+Wrong:
+
+```tsx
+<Component config={{ provider: ..., detail: ... }} />
+```
+
+Right:
+
+```tsx
+const config = useMemo(() => ({
+  provider: ...,
+  detail: ...
+}), [provider, detail]);
+
+<Component config={config} />
+```
diff --git a/.claude/skills/frontend-testing/assets/component-test.template.tsx b/.claude/skills/frontend-testing/assets/component-test.template.tsx
index c39baff916..6b7803bd4b 100644
--- a/.claude/skills/frontend-testing/assets/component-test.template.tsx
+++ b/.claude/skills/frontend-testing/assets/component-test.template.tsx
@@ -28,17 +28,14 @@ import userEvent from '@testing-library/user-event'
 
 // i18n (automatically mocked)
 // WHY: Global mock in web/vitest.setup.ts is auto-loaded by Vitest setup
-// No explicit mock needed - it returns translation keys as-is
+// The global mock provides: useTranslation, Trans, useMixedTranslation, useGetLanguage
+// No explicit mock needed for most tests
+//
 // Override only if custom translations are required:
-// vi.mock('react-i18next', () => ({
-//   useTranslation: () => ({
-//     t: (key: string) => {
-//       const customTranslations: Record<string, string> = {
-//         'my.custom.key': 'Custom Translation',
-//       }
-//       return customTranslations[key] || key
-//     },
-//   }),
+// import { createReactI18nextMock } from '@/test/i18n-mock'
+// vi.mock('react-i18next', () => createReactI18nextMock({
+//   'my.custom.key': 'Custom Translation',
+//   'button.save': 'Save',
 // }))
 
 // Router (if component uses useRouter, usePathname, useSearchParams)
diff --git a/.claude/skills/frontend-testing/references/mocking.md b/.claude/skills/frontend-testing/references/mocking.md
index 23889c8d3d..c70bcf0ae5 100644
--- a/.claude/skills/frontend-testing/references/mocking.md
+++ b/.claude/skills/frontend-testing/references/mocking.md
@@ -52,23 +52,29 @@ Modules are not mocked automatically. Use `vi.mock` in test files, or add global
 
 ### 1. i18n (Auto-loaded via Global Mock)
 
 A global mock is defined in `web/vitest.setup.ts` and is auto-loaded by Vitest setup.
-**No explicit mock needed** for most tests - it returns translation keys as-is.
 
-For tests requiring custom translations, override the mock:
+The global mock provides:
+
+- `useTranslation` - returns translation keys with namespace prefix
+- `Trans` component - renders i18nKey and components
+- `useMixedTranslation` (from `@/app/components/plugins/marketplace/hooks`)
+- `useGetLanguage` (from `@/context/i18n`) - returns `'en-US'`
+
+**Default behavior**: Most tests should use the global mock (no local override needed).
+ +**For custom translations**: Use the helper function from `@/test/i18n-mock`: ```typescript -vi.mock('react-i18next', () => ({ - useTranslation: () => ({ - t: (key: string) => { - const translations: Record = { - 'my.custom.key': 'Custom translation', - } - return translations[key] || key - }, - }), +import { createReactI18nextMock } from '@/test/i18n-mock' + +vi.mock('react-i18next', () => createReactI18nextMock({ + 'my.custom.key': 'Custom translation', + 'button.save': 'Save', })) ``` +**Avoid**: Manually defining `useTranslation` mocks that just return the key - the global mock already does this. + ### 2. Next.js Router ```typescript diff --git a/.claude/skills/skill-creator/SKILL.md b/.claude/skills/skill-creator/SKILL.md new file mode 100644 index 0000000000..b49da5ac68 --- /dev/null +++ b/.claude/skills/skill-creator/SKILL.md @@ -0,0 +1,355 @@ +--- +name: skill-creator +description: Guide for creating effective skills. This skill should be used when users want to create a new skill (or update an existing skill) that extends Claude's capabilities with specialized knowledge, workflows, or tool integrations. +--- + +# Skill Creator + +This skill provides guidance for creating effective skills. + +## About Skills + +Skills are modular, self-contained packages that extend Claude's capabilities by providing +specialized knowledge, workflows, and tools. Think of them as "onboarding guides" for specific +domains or tasks—they transform Claude from a general-purpose agent into a specialized agent +equipped with procedural knowledge that no model can fully possess. + +### What Skills Provide + +1. Specialized workflows - Multi-step procedures for specific domains +2. Tool integrations - Instructions for working with specific file formats or APIs +3. Domain expertise - Company-specific knowledge, schemas, business logic +4. Bundled resources - Scripts, references, and assets for complex and repetitive tasks + +## Core Principles + +### Concise is Key + +The context window is a public good. Skills share the context window with everything else Claude needs: system prompt, conversation history, other Skills' metadata, and the actual user request. + +**Default assumption: Claude is already very smart.** Only add context Claude doesn't already have. Challenge each piece of information: "Does Claude really need this explanation?" and "Does this paragraph justify its token cost?" + +Prefer concise examples over verbose explanations. + +### Set Appropriate Degrees of Freedom + +Match the level of specificity to the task's fragility and variability: + +**High freedom (text-based instructions)**: Use when multiple approaches are valid, decisions depend on context, or heuristics guide the approach. + +**Medium freedom (pseudocode or scripts with parameters)**: Use when a preferred pattern exists, some variation is acceptable, or configuration affects behavior. + +**Low freedom (specific scripts, few parameters)**: Use when operations are fragile and error-prone, consistency is critical, or a specific sequence must be followed. + +Think of Claude as exploring a path: a narrow bridge with cliffs needs specific guardrails (low freedom), while an open field allows many routes (high freedom). 
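+
+As a rough illustration (the task and script names below are hypothetical, not taken from an existing skill), the same instruction might be written at each level like this:
+
+```markdown
+High freedom: "Summarize the findings in whatever structure best fits the data."
+
+Medium freedom: "Summarize the findings with scripts/summarize.py, adjusting --depth to match the dataset size."
+
+Low freedom: "Run scripts/summarize.py --depth 2 and include its output verbatim; do not edit the numbers."
+```
+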
+ +### Anatomy of a Skill + +Every skill consists of a required SKILL.md file and optional bundled resources: + +``` +skill-name/ +├── SKILL.md (required) +│ ├── YAML frontmatter metadata (required) +│ │ ├── name: (required) +│ │ └── description: (required) +│ └── Markdown instructions (required) +└── Bundled Resources (optional) + ├── scripts/ - Executable code (Python/Bash/etc.) + ├── references/ - Documentation intended to be loaded into context as needed + └── assets/ - Files used in output (templates, icons, fonts, etc.) +``` + +#### SKILL.md (required) + +Every SKILL.md consists of: + +- **Frontmatter** (YAML): Contains `name` and `description` fields. These are the only fields that Claude reads to determine when the skill gets used, thus it is very important to be clear and comprehensive in describing what the skill is, and when it should be used. +- **Body** (Markdown): Instructions and guidance for using the skill. Only loaded AFTER the skill triggers (if at all). + +#### Bundled Resources (optional) + +##### Scripts (`scripts/`) + +Executable code (Python/Bash/etc.) for tasks that require deterministic reliability or are repeatedly rewritten. + +- **When to include**: When the same code is being rewritten repeatedly or deterministic reliability is needed +- **Example**: `scripts/rotate_pdf.py` for PDF rotation tasks +- **Benefits**: Token efficient, deterministic, may be executed without loading into context +- **Note**: Scripts may still need to be read by Claude for patching or environment-specific adjustments + +##### References (`references/`) + +Documentation and reference material intended to be loaded as needed into context to inform Claude's process and thinking. + +- **When to include**: For documentation that Claude should reference while working +- **Examples**: `references/finance.md` for financial schemas, `references/mnda.md` for company NDA template, `references/policies.md` for company policies, `references/api_docs.md` for API specifications +- **Use cases**: Database schemas, API documentation, domain knowledge, company policies, detailed workflow guides +- **Benefits**: Keeps SKILL.md lean, loaded only when Claude determines it's needed +- **Best practice**: If files are large (>10k words), include grep search patterns in SKILL.md +- **Avoid duplication**: Information should live in either SKILL.md or references files, not both. Prefer references files for detailed information unless it's truly core to the skill—this keeps SKILL.md lean while making information discoverable without hogging the context window. Keep only essential procedural instructions and workflow guidance in SKILL.md; move detailed reference material, schemas, and examples to references files. + +##### Assets (`assets/`) + +Files not intended to be loaded into context, but rather used within the output Claude produces. + +- **When to include**: When the skill needs files that will be used in the final output +- **Examples**: `assets/logo.png` for brand assets, `assets/slides.pptx` for PowerPoint templates, `assets/frontend-template/` for HTML/React boilerplate, `assets/font.ttf` for typography +- **Use cases**: Templates, images, icons, boilerplate code, fonts, sample documents that get copied or modified +- **Benefits**: Separates output resources from documentation, enables Claude to use files without loading them into context + +#### What to Not Include in a Skill + +A skill should only contain essential files that directly support its functionality. 
Do NOT create extraneous documentation or auxiliary files, including: + +- README.md +- INSTALLATION_GUIDE.md +- QUICK_REFERENCE.md +- CHANGELOG.md +- etc. + +The skill should only contain the information needed for an AI agent to do the job at hand. It should not contain auxilary context about the process that went into creating it, setup and testing procedures, user-facing documentation, etc. Creating additional documentation files just adds clutter and confusion. + +### Progressive Disclosure Design Principle + +Skills use a three-level loading system to manage context efficiently: + +1. **Metadata (name + description)** - Always in context (~100 words) +2. **SKILL.md body** - When skill triggers (<5k words) +3. **Bundled resources** - As needed by Claude (Unlimited because scripts can be executed without reading into context window) + +#### Progressive Disclosure Patterns + +Keep SKILL.md body to the essentials and under 500 lines to minimize context bloat. Split content into separate files when approaching this limit. When splitting out content into other files, it is very important to reference them from SKILL.md and describe clearly when to read them, to ensure the reader of the skill knows they exist and when to use them. + +**Key principle:** When a skill supports multiple variations, frameworks, or options, keep only the core workflow and selection guidance in SKILL.md. Move variant-specific details (patterns, examples, configuration) into separate reference files. + +**Pattern 1: High-level guide with references** + +```markdown +# PDF Processing + +## Quick start + +Extract text with pdfplumber: +[code example] + +## Advanced features + +- **Form filling**: See [FORMS.md](FORMS.md) for complete guide +- **API reference**: See [REFERENCE.md](REFERENCE.md) for all methods +- **Examples**: See [EXAMPLES.md](EXAMPLES.md) for common patterns +``` + +Claude loads FORMS.md, REFERENCE.md, or EXAMPLES.md only when needed. + +**Pattern 2: Domain-specific organization** + +For Skills with multiple domains, organize content by domain to avoid loading irrelevant context: + +``` +bigquery-skill/ +├── SKILL.md (overview and navigation) +└── reference/ + ├── finance.md (revenue, billing metrics) + ├── sales.md (opportunities, pipeline) + ├── product.md (API usage, features) + └── marketing.md (campaigns, attribution) +``` + +When a user asks about sales metrics, Claude only reads sales.md. + +Similarly, for skills supporting multiple frameworks or variants, organize by variant: + +``` +cloud-deploy/ +├── SKILL.md (workflow + provider selection) +└── references/ + ├── aws.md (AWS deployment patterns) + ├── gcp.md (GCP deployment patterns) + └── azure.md (Azure deployment patterns) +``` + +When the user chooses AWS, Claude only reads aws.md. + +**Pattern 3: Conditional details** + +Show basic content, link to advanced content: + +```markdown +# DOCX Processing + +## Creating documents + +Use docx-js for new documents. See [DOCX-JS.md](DOCX-JS.md). + +## Editing documents + +For simple edits, modify the XML directly. + +**For tracked changes**: See [REDLINING.md](REDLINING.md) +**For OOXML details**: See [OOXML.md](OOXML.md) +``` + +Claude reads REDLINING.md or OOXML.md only when the user needs those features. + +**Important guidelines:** + +- **Avoid deeply nested references** - Keep references one level deep from SKILL.md. All reference files should link directly from SKILL.md. 
+- **Structure longer reference files** - For files longer than 100 lines, include a table of contents at the top so Claude can see the full scope when previewing. + +## Skill Creation Process + +Skill creation involves these steps: + +1. Understand the skill with concrete examples +2. Plan reusable skill contents (scripts, references, assets) +3. Initialize the skill (run init_skill.py) +4. Edit the skill (implement resources and write SKILL.md) +5. Package the skill (run package_skill.py) +6. Iterate based on real usage + +Follow these steps in order, skipping only if there is a clear reason why they are not applicable. + +### Step 1: Understanding the Skill with Concrete Examples + +Skip this step only when the skill's usage patterns are already clearly understood. It remains valuable even when working with an existing skill. + +To create an effective skill, clearly understand concrete examples of how the skill will be used. This understanding can come from either direct user examples or generated examples that are validated with user feedback. + +For example, when building an image-editor skill, relevant questions include: + +- "What functionality should the image-editor skill support? Editing, rotating, anything else?" +- "Can you give some examples of how this skill would be used?" +- "I can imagine users asking for things like 'Remove the red-eye from this image' or 'Rotate this image'. Are there other ways you imagine this skill being used?" +- "What would a user say that should trigger this skill?" + +To avoid overwhelming users, avoid asking too many questions in a single message. Start with the most important questions and follow up as needed for better effectiveness. + +Conclude this step when there is a clear sense of the functionality the skill should support. + +### Step 2: Planning the Reusable Skill Contents + +To turn concrete examples into an effective skill, analyze each example by: + +1. Considering how to execute on the example from scratch +2. Identifying what scripts, references, and assets would be helpful when executing these workflows repeatedly + +Example: When building a `pdf-editor` skill to handle queries like "Help me rotate this PDF," the analysis shows: + +1. Rotating a PDF requires re-writing the same code each time +2. A `scripts/rotate_pdf.py` script would be helpful to store in the skill + +Example: When designing a `frontend-webapp-builder` skill for queries like "Build me a todo app" or "Build me a dashboard to track my steps," the analysis shows: + +1. Writing a frontend webapp requires the same boilerplate HTML/React each time +2. An `assets/hello-world/` template containing the boilerplate HTML/React project files would be helpful to store in the skill + +Example: When building a `big-query` skill to handle queries like "How many users have logged in today?" the analysis shows: + +1. Querying BigQuery requires re-discovering the table schemas and relationships each time +2. A `references/schema.md` file documenting the table schemas would be helpful to store in the skill + +To establish the skill's contents, analyze each concrete example to create a list of the reusable resources to include: scripts, references, and assets. + +### Step 3: Initializing the Skill + +At this point, it is time to actually create the skill. + +Skip this step only if the skill being developed already exists, and iteration or packaging is needed. In this case, continue to the next step. + +When creating a new skill from scratch, always run the `init_skill.py` script. 
The script conveniently generates a new template skill directory that automatically includes everything a skill requires, making the skill creation process much more efficient and reliable. + +Usage: + +```bash +scripts/init_skill.py --path +``` + +The script: + +- Creates the skill directory at the specified path +- Generates a SKILL.md template with proper frontmatter and TODO placeholders +- Creates example resource directories: `scripts/`, `references/`, and `assets/` +- Adds example files in each directory that can be customized or deleted + +After initialization, customize or remove the generated SKILL.md and example files as needed. + +### Step 4: Edit the Skill + +When editing the (newly-generated or existing) skill, remember that the skill is being created for another instance of Claude to use. Include information that would be beneficial and non-obvious to Claude. Consider what procedural knowledge, domain-specific details, or reusable assets would help another Claude instance execute these tasks more effectively. + +#### Learn Proven Design Patterns + +Consult these helpful guides based on your skill's needs: + +- **Multi-step processes**: See references/workflows.md for sequential workflows and conditional logic +- **Specific output formats or quality standards**: See references/output-patterns.md for template and example patterns + +These files contain established best practices for effective skill design. + +#### Start with Reusable Skill Contents + +To begin implementation, start with the reusable resources identified above: `scripts/`, `references/`, and `assets/` files. Note that this step may require user input. For example, when implementing a `brand-guidelines` skill, the user may need to provide brand assets or templates to store in `assets/`, or documentation to store in `references/`. + +Added scripts must be tested by actually running them to ensure there are no bugs and that the output matches what is expected. If there are many similar scripts, only a representative sample needs to be tested to ensure confidence that they all work while balancing time to completion. + +Any example files and directories not needed for the skill should be deleted. The initialization script creates example files in `scripts/`, `references/`, and `assets/` to demonstrate structure, but most skills won't need all of them. + +#### Update SKILL.md + +**Writing Guidelines:** Always use imperative/infinitive form. + +##### Frontmatter + +Write the YAML frontmatter with `name` and `description`: + +- `name`: The skill name +- `description`: This is the primary triggering mechanism for your skill, and helps Claude understand when to use the skill. + - Include both what the Skill does and specific triggers/contexts for when to use it. + - Include all "when to use" information here - Not in the body. The body is only loaded after triggering, so "When to Use This Skill" sections in the body are not helpful to Claude. + - Example description for a `docx` skill: "Comprehensive document creation, editing, and analysis with support for tracked changes, comments, formatting preservation, and text extraction. Use when Claude needs to work with professional documents (.docx files) for: (1) Creating new documents, (2) Modifying or editing content, (3) Working with tracked changes, (4) Adding comments, or any other document tasks" + +Do not include any other fields in YAML frontmatter. + +##### Body + +Write instructions for using the skill and its bundled resources. 
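+
+A minimal body sketch (headings, scripts, and file names here are illustrative only; adapt them to the structure chosen for the skill):
+
+```markdown
+# Weekly Report Builder
+
+## Overview
+Generate the weekly metrics report from the analytics export.
+
+## Workflow
+1. Run scripts/extract_metrics.py on the exported CSV
+2. Check references/schema.md for column definitions before aggregating
+3. Copy assets/report-template.xlsx and fill in the computed figures
+```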
+ +### Step 5: Packaging a Skill + +Once development of the skill is complete, it must be packaged into a distributable .skill file that gets shared with the user. The packaging process automatically validates the skill first to ensure it meets all requirements: + +```bash +scripts/package_skill.py +``` + +Optional output directory specification: + +```bash +scripts/package_skill.py ./dist +``` + +The packaging script will: + +1. **Validate** the skill automatically, checking: + + - YAML frontmatter format and required fields + - Skill naming conventions and directory structure + - Description completeness and quality + - File organization and resource references + +2. **Package** the skill if validation passes, creating a .skill file named after the skill (e.g., `my-skill.skill`) that includes all files and maintains the proper directory structure for distribution. The .skill file is a zip file with a .skill extension. + +If validation fails, the script will report the errors and exit without creating a package. Fix any validation errors and run the packaging command again. + +### Step 6: Iterate + +After testing the skill, users may request improvements. Often this happens right after using the skill, with fresh context of how the skill performed. + +**Iteration workflow:** + +1. Use the skill on real tasks +2. Notice struggles or inefficiencies +3. Identify how SKILL.md or bundled resources should be updated +4. Implement changes and test again diff --git a/.claude/skills/skill-creator/references/output-patterns.md b/.claude/skills/skill-creator/references/output-patterns.md new file mode 100644 index 0000000000..022e85fe5e --- /dev/null +++ b/.claude/skills/skill-creator/references/output-patterns.md @@ -0,0 +1,86 @@ +# Output Patterns + +Use these patterns when skills need to produce consistent, high-quality output. + +## Template Pattern + +Provide templates for output format. Match the level of strictness to your needs. + +**For strict requirements (like API responses or data formats):** + +```markdown +## Report structure + +ALWAYS use this exact template structure: + +# [Analysis Title] + +## Executive summary +[One-paragraph overview of key findings] + +## Key findings +- Finding 1 with supporting data +- Finding 2 with supporting data +- Finding 3 with supporting data + +## Recommendations +1. Specific actionable recommendation +2. Specific actionable recommendation +``` + +**For flexible guidance (when adaptation is useful):** + +```markdown +## Report structure + +Here is a sensible default format, but use your best judgment: + +# [Analysis Title] + +## Executive summary +[Overview] + +## Key findings +[Adapt sections based on what you discover] + +## Recommendations +[Tailor to the specific context] + +Adjust sections as needed for the specific analysis type. +``` + +## Examples Pattern + +For skills where output quality depends on seeing examples, provide input/output pairs: + +```markdown +## Commit message format + +Generate commit messages following these examples: + +**Example 1:** +Input: Added user authentication with JWT tokens +Output: +``` + +feat(auth): implement JWT-based authentication + +Add login endpoint and token validation middleware + +``` + +**Example 2:** +Input: Fixed bug where dates displayed incorrectly in reports +Output: +``` + +fix(reports): correct date formatting in timezone conversion + +Use UTC timestamps consistently across report generation + +``` + +Follow this style: type(scope): brief description, then detailed explanation. 
+``` + +Examples help Claude understand the desired style and level of detail more clearly than descriptions alone. diff --git a/.claude/skills/skill-creator/references/workflows.md b/.claude/skills/skill-creator/references/workflows.md new file mode 100644 index 0000000000..54b0174078 --- /dev/null +++ b/.claude/skills/skill-creator/references/workflows.md @@ -0,0 +1,28 @@ +# Workflow Patterns + +## Sequential Workflows + +For complex tasks, break operations into clear, sequential steps. It is often helpful to give Claude an overview of the process towards the beginning of SKILL.md: + +```markdown +Filling a PDF form involves these steps: + +1. Analyze the form (run analyze_form.py) +2. Create field mapping (edit fields.json) +3. Validate mapping (run validate_fields.py) +4. Fill the form (run fill_form.py) +5. Verify output (run verify_output.py) +``` + +## Conditional Workflows + +For tasks with branching logic, guide Claude through decision points: + +```markdown +1. Determine the modification type: + **Creating new content?** → Follow "Creation workflow" below + **Editing existing content?** → Follow "Editing workflow" below + +2. Creation workflow: [steps] +3. Editing workflow: [steps] +``` diff --git a/.claude/skills/skill-creator/scripts/init_skill.py b/.claude/skills/skill-creator/scripts/init_skill.py new file mode 100755 index 0000000000..249fffcbbd --- /dev/null +++ b/.claude/skills/skill-creator/scripts/init_skill.py @@ -0,0 +1,300 @@ +#!/usr/bin/env python3 +""" +Skill Initializer - Creates a new skill from template + +Usage: + init_skill.py --path + +Examples: + init_skill.py my-new-skill --path skills/public + init_skill.py my-api-helper --path skills/private + init_skill.py custom-skill --path /custom/location +""" + +import sys +from pathlib import Path + + +SKILL_TEMPLATE = """--- +name: {skill_name} +description: [TODO: Complete and informative explanation of what the skill does and when to use it. Include WHEN to use this skill - specific scenarios, file types, or tasks that trigger it.] +--- + +# {skill_title} + +## Overview + +[TODO: 1-2 sentences explaining what this skill enables] + +## Structuring This Skill + +[TODO: Choose the structure that best fits this skill's purpose. Common patterns: + +**1. Workflow-Based** (best for sequential processes) +- Works well when there are clear step-by-step procedures +- Example: DOCX skill with "Workflow Decision Tree" → "Reading" → "Creating" → "Editing" +- Structure: ## Overview → ## Workflow Decision Tree → ## Step 1 → ## Step 2... + +**2. Task-Based** (best for tool collections) +- Works well when the skill offers different operations/capabilities +- Example: PDF skill with "Quick Start" → "Merge PDFs" → "Split PDFs" → "Extract Text" +- Structure: ## Overview → ## Quick Start → ## Task Category 1 → ## Task Category 2... + +**3. Reference/Guidelines** (best for standards or specifications) +- Works well for brand guidelines, coding standards, or requirements +- Example: Brand styling with "Brand Guidelines" → "Colors" → "Typography" → "Features" +- Structure: ## Overview → ## Guidelines → ## Specifications → ## Usage... + +**4. Capabilities-Based** (best for integrated systems) +- Works well when the skill provides multiple interrelated features +- Example: Product Management with "Core Capabilities" → numbered capability list +- Structure: ## Overview → ## Core Capabilities → ### 1. Feature → ### 2. Feature... + +Patterns can be mixed and matched as needed. 
Most skills combine patterns (e.g., start with task-based, add workflow for complex operations). + +Delete this entire "Structuring This Skill" section when done - it's just guidance.] + +## [TODO: Replace with the first main section based on chosen structure] + +[TODO: Add content here. See examples in existing skills: +- Code samples for technical skills +- Decision trees for complex workflows +- Concrete examples with realistic user requests +- References to scripts/templates/references as needed] + +## Resources + +This skill includes example resource directories that demonstrate how to organize different types of bundled resources: + +### scripts/ +Executable code (Python/Bash/etc.) that can be run directly to perform specific operations. + +**Examples from other skills:** +- PDF skill: `fill_fillable_fields.py`, `extract_form_field_info.py` - utilities for PDF manipulation +- DOCX skill: `document.py`, `utilities.py` - Python modules for document processing + +**Appropriate for:** Python scripts, shell scripts, or any executable code that performs automation, data processing, or specific operations. + +**Note:** Scripts may be executed without loading into context, but can still be read by Claude for patching or environment adjustments. + +### references/ +Documentation and reference material intended to be loaded into context to inform Claude's process and thinking. + +**Examples from other skills:** +- Product management: `communication.md`, `context_building.md` - detailed workflow guides +- BigQuery: API reference documentation and query examples +- Finance: Schema documentation, company policies + +**Appropriate for:** In-depth documentation, API references, database schemas, comprehensive guides, or any detailed information that Claude should reference while working. + +### assets/ +Files not intended to be loaded into context, but rather used within the output Claude produces. + +**Examples from other skills:** +- Brand styling: PowerPoint template files (.pptx), logo files +- Frontend builder: HTML/React boilerplate project directories +- Typography: Font files (.ttf, .woff2) + +**Appropriate for:** Templates, boilerplate code, document templates, images, icons, fonts, or any files meant to be copied or used in the final output. + +--- + +**Any unneeded directories can be deleted.** Not every skill requires all three types of resources. +""" + +EXAMPLE_SCRIPT = '''#!/usr/bin/env python3 +""" +Example helper script for {skill_name} + +This is a placeholder script that can be executed directly. +Replace with actual implementation or delete if not needed. + +Example real scripts from other skills: +- pdf/scripts/fill_fillable_fields.py - Fills PDF form fields +- pdf/scripts/convert_pdf_to_images.py - Converts PDF pages to images +""" + +def main(): + print("This is an example script for {skill_name}") + # TODO: Add actual script logic here + # This could be data processing, file conversion, API calls, etc. + +if __name__ == "__main__": + main() +''' + +EXAMPLE_REFERENCE = """# Reference Documentation for {skill_title} + +This is a placeholder for detailed reference documentation. +Replace with actual reference content or delete if not needed. 
+ +Example real reference docs from other skills: +- product-management/references/communication.md - Comprehensive guide for status updates +- product-management/references/context_building.md - Deep-dive on gathering context +- bigquery/references/ - API references and query examples + +## When Reference Docs Are Useful + +Reference docs are ideal for: +- Comprehensive API documentation +- Detailed workflow guides +- Complex multi-step processes +- Information too lengthy for main SKILL.md +- Content that's only needed for specific use cases + +## Structure Suggestions + +### API Reference Example +- Overview +- Authentication +- Endpoints with examples +- Error codes +- Rate limits + +### Workflow Guide Example +- Prerequisites +- Step-by-step instructions +- Common patterns +- Troubleshooting +- Best practices +""" + +EXAMPLE_ASSET = """# Example Asset File + +This placeholder represents where asset files would be stored. +Replace with actual asset files (templates, images, fonts, etc.) or delete if not needed. + +Asset files are NOT intended to be loaded into context, but rather used within +the output Claude produces. + +Example asset files from other skills: +- Brand guidelines: logo.png, slides_template.pptx +- Frontend builder: hello-world/ directory with HTML/React boilerplate +- Typography: custom-font.ttf, font-family.woff2 +- Data: sample_data.csv, test_dataset.json + +## Common Asset Types + +- Templates: .pptx, .docx, boilerplate directories +- Images: .png, .jpg, .svg, .gif +- Fonts: .ttf, .otf, .woff, .woff2 +- Boilerplate code: Project directories, starter files +- Icons: .ico, .svg +- Data files: .csv, .json, .xml, .yaml + +Note: This is a text placeholder. Actual assets can be any file type. +""" + + +def title_case_skill_name(skill_name): + """Convert hyphenated skill name to Title Case for display.""" + return " ".join(word.capitalize() for word in skill_name.split("-")) + + +def init_skill(skill_name, path): + """ + Initialize a new skill directory with template SKILL.md. 
+ + Args: + skill_name: Name of the skill + path: Path where the skill directory should be created + + Returns: + Path to created skill directory, or None if error + """ + # Determine skill directory path + skill_dir = Path(path).resolve() / skill_name + + # Check if directory already exists + if skill_dir.exists(): + print(f"❌ Error: Skill directory already exists: {skill_dir}") + return None + + # Create skill directory + try: + skill_dir.mkdir(parents=True, exist_ok=False) + print(f"✅ Created skill directory: {skill_dir}") + except Exception as e: + print(f"❌ Error creating directory: {e}") + return None + + # Create SKILL.md from template + skill_title = title_case_skill_name(skill_name) + skill_content = SKILL_TEMPLATE.format(skill_name=skill_name, skill_title=skill_title) + + skill_md_path = skill_dir / "SKILL.md" + try: + skill_md_path.write_text(skill_content) + print("✅ Created SKILL.md") + except Exception as e: + print(f"❌ Error creating SKILL.md: {e}") + return None + + # Create resource directories with example files + try: + # Create scripts/ directory with example script + scripts_dir = skill_dir / "scripts" + scripts_dir.mkdir(exist_ok=True) + example_script = scripts_dir / "example.py" + example_script.write_text(EXAMPLE_SCRIPT.format(skill_name=skill_name)) + example_script.chmod(0o755) + print("✅ Created scripts/example.py") + + # Create references/ directory with example reference doc + references_dir = skill_dir / "references" + references_dir.mkdir(exist_ok=True) + example_reference = references_dir / "api_reference.md" + example_reference.write_text(EXAMPLE_REFERENCE.format(skill_title=skill_title)) + print("✅ Created references/api_reference.md") + + # Create assets/ directory with example asset placeholder + assets_dir = skill_dir / "assets" + assets_dir.mkdir(exist_ok=True) + example_asset = assets_dir / "example_asset.txt" + example_asset.write_text(EXAMPLE_ASSET) + print("✅ Created assets/example_asset.txt") + except Exception as e: + print(f"❌ Error creating resource directories: {e}") + return None + + # Print next steps + print(f"\n✅ Skill '{skill_name}' initialized successfully at {skill_dir}") + print("\nNext steps:") + print("1. Edit SKILL.md to complete the TODO items and update the description") + print("2. Customize or delete the example files in scripts/, references/, and assets/") + print("3. 
Run the validator when ready to check the skill structure") + + return skill_dir + + +def main(): + if len(sys.argv) < 4 or sys.argv[2] != "--path": + print("Usage: init_skill.py --path ") + print("\nSkill name requirements:") + print(" - Hyphen-case identifier (e.g., 'data-analyzer')") + print(" - Lowercase letters, digits, and hyphens only") + print(" - Max 40 characters") + print(" - Must match directory name exactly") + print("\nExamples:") + print(" init_skill.py my-new-skill --path skills/public") + print(" init_skill.py my-api-helper --path skills/private") + print(" init_skill.py custom-skill --path /custom/location") + sys.exit(1) + + skill_name = sys.argv[1] + path = sys.argv[3] + + print(f"🚀 Initializing skill: {skill_name}") + print(f" Location: {path}") + print() + + result = init_skill(skill_name, path) + + if result: + sys.exit(0) + else: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/.claude/skills/skill-creator/scripts/package_skill.py b/.claude/skills/skill-creator/scripts/package_skill.py new file mode 100755 index 0000000000..736b928be0 --- /dev/null +++ b/.claude/skills/skill-creator/scripts/package_skill.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 +""" +Skill Packager - Creates a distributable .skill file of a skill folder + +Usage: + python utils/package_skill.py [output-directory] + +Example: + python utils/package_skill.py skills/public/my-skill + python utils/package_skill.py skills/public/my-skill ./dist +""" + +import sys +import zipfile +from pathlib import Path +from quick_validate import validate_skill + + +def package_skill(skill_path, output_dir=None): + """ + Package a skill folder into a .skill file. + + Args: + skill_path: Path to the skill folder + output_dir: Optional output directory for the .skill file (defaults to current directory) + + Returns: + Path to the created .skill file, or None if error + """ + skill_path = Path(skill_path).resolve() + + # Validate skill folder exists + if not skill_path.exists(): + print(f"❌ Error: Skill folder not found: {skill_path}") + return None + + if not skill_path.is_dir(): + print(f"❌ Error: Path is not a directory: {skill_path}") + return None + + # Validate SKILL.md exists + skill_md = skill_path / "SKILL.md" + if not skill_md.exists(): + print(f"❌ Error: SKILL.md not found in {skill_path}") + return None + + # Run validation before packaging + print("🔍 Validating skill...") + valid, message = validate_skill(skill_path) + if not valid: + print(f"❌ Validation failed: {message}") + print(" Please fix the validation errors before packaging.") + return None + print(f"✅ {message}\n") + + # Determine output location + skill_name = skill_path.name + if output_dir: + output_path = Path(output_dir).resolve() + output_path.mkdir(parents=True, exist_ok=True) + else: + output_path = Path.cwd() + + skill_filename = output_path / f"{skill_name}.skill" + + # Create the .skill file (zip format) + try: + with zipfile.ZipFile(skill_filename, "w", zipfile.ZIP_DEFLATED) as zipf: + # Walk through the skill directory + for file_path in skill_path.rglob("*"): + if file_path.is_file(): + # Calculate the relative path within the zip + arcname = file_path.relative_to(skill_path.parent) + zipf.write(file_path, arcname) + print(f" Added: {arcname}") + + print(f"\n✅ Successfully packaged skill to: {skill_filename}") + return skill_filename + + except Exception as e: + print(f"❌ Error creating .skill file: {e}") + return None + + +def main(): + if len(sys.argv) < 2: + print("Usage: python utils/package_skill.py 
[output-directory]") + print("\nExample:") + print(" python utils/package_skill.py skills/public/my-skill") + print(" python utils/package_skill.py skills/public/my-skill ./dist") + sys.exit(1) + + skill_path = sys.argv[1] + output_dir = sys.argv[2] if len(sys.argv) > 2 else None + + print(f"📦 Packaging skill: {skill_path}") + if output_dir: + print(f" Output directory: {output_dir}") + print() + + result = package_skill(skill_path, output_dir) + + if result: + sys.exit(0) + else: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/.claude/skills/skill-creator/scripts/quick_validate.py b/.claude/skills/skill-creator/scripts/quick_validate.py new file mode 100755 index 0000000000..66eb0a71bf --- /dev/null +++ b/.claude/skills/skill-creator/scripts/quick_validate.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +""" +Quick validation script for skills - minimal version +""" + +import sys +import os +import re +import yaml +from pathlib import Path + + +def validate_skill(skill_path): + """Basic validation of a skill""" + skill_path = Path(skill_path) + + # Check SKILL.md exists + skill_md = skill_path / "SKILL.md" + if not skill_md.exists(): + return False, "SKILL.md not found" + + # Read and validate frontmatter + content = skill_md.read_text() + if not content.startswith("---"): + return False, "No YAML frontmatter found" + + # Extract frontmatter + match = re.match(r"^---\n(.*?)\n---", content, re.DOTALL) + if not match: + return False, "Invalid frontmatter format" + + frontmatter_text = match.group(1) + + # Parse YAML frontmatter + try: + frontmatter = yaml.safe_load(frontmatter_text) + if not isinstance(frontmatter, dict): + return False, "Frontmatter must be a YAML dictionary" + except yaml.YAMLError as e: + return False, f"Invalid YAML in frontmatter: {e}" + + # Define allowed properties + ALLOWED_PROPERTIES = {"name", "description", "license", "allowed-tools", "metadata"} + + # Check for unexpected properties (excluding nested keys under metadata) + unexpected_keys = set(frontmatter.keys()) - ALLOWED_PROPERTIES + if unexpected_keys: + return False, ( + f"Unexpected key(s) in SKILL.md frontmatter: {', '.join(sorted(unexpected_keys))}. " + f"Allowed properties are: {', '.join(sorted(ALLOWED_PROPERTIES))}" + ) + + # Check required fields + if "name" not in frontmatter: + return False, "Missing 'name' in frontmatter" + if "description" not in frontmatter: + return False, "Missing 'description' in frontmatter" + + # Extract name for validation + name = frontmatter.get("name", "") + if not isinstance(name, str): + return False, f"Name must be a string, got {type(name).__name__}" + name = name.strip() + if name: + # Check naming convention (hyphen-case: lowercase with hyphens) + if not re.match(r"^[a-z0-9-]+$", name): + return False, f"Name '{name}' should be hyphen-case (lowercase letters, digits, and hyphens only)" + if name.startswith("-") or name.endswith("-") or "--" in name: + return False, f"Name '{name}' cannot start/end with hyphen or contain consecutive hyphens" + # Check name length (max 64 characters per spec) + if len(name) > 64: + return False, f"Name is too long ({len(name)} characters). Maximum is 64 characters." 
+ + # Extract and validate description + description = frontmatter.get("description", "") + if not isinstance(description, str): + return False, f"Description must be a string, got {type(description).__name__}" + description = description.strip() + if description: + # Check for angle brackets + if "<" in description or ">" in description: + return False, "Description cannot contain angle brackets (< or >)" + # Check description length (max 1024 characters per spec) + if len(description) > 1024: + return False, f"Description is too long ({len(description)} characters). Maximum is 1024 characters." + + return True, "Skill is valid!" + + +if __name__ == "__main__": + if len(sys.argv) != 2: + print("Usage: python quick_validate.py ") + sys.exit(1) + + valid, message = validate_skill(sys.argv[1]) + print(message) + sys.exit(0 if valid else 1) diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index aa5a50918a..50dbde2aee 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -20,4 +20,4 @@ - [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!) - [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change. - [x] I've updated the documentation accordingly. -- [x] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods +- [x] I ran `make lint` and `make type-check` (backend) and `cd web && npx lint-staged` (frontend) to appease the lint gods diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index d463349686..462ece303e 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -110,6 +110,16 @@ jobs: working-directory: ./web run: pnpm run type-check:tsgo + - name: Web dead code check + if: steps.changed-files.outputs.any_changed == 'true' + working-directory: ./web + run: pnpm run knip + + - name: Web build check + if: steps.changed-files.outputs.any_changed == 'true' + working-directory: ./web + run: pnpm run build + superlinter: name: SuperLinter runs-on: ubuntu-latest diff --git a/.github/workflows/translate-i18n-base-on-english.yml b/.github/workflows/translate-i18n-base-on-english.yml index a51350f630..16d36361fd 100644 --- a/.github/workflows/translate-i18n-base-on-english.yml +++ b/.github/workflows/translate-i18n-base-on-english.yml @@ -5,6 +5,7 @@ on: branches: [main] paths: - 'web/i18n/en-US/*.json' + workflow_dispatch: permissions: contents: write @@ -18,7 +19,8 @@ jobs: run: working-directory: web steps: - - uses: actions/checkout@v6 + # Keep use old checkout action version for https://github.com/peter-evans/create-pull-request/issues/4272 + - uses: actions/checkout@v4 with: fetch-depth: 0 token: ${{ secrets.GITHUB_TOKEN }} @@ -26,21 +28,28 @@ jobs: - name: Check for file changes in i18n/en-US id: check_files run: | - git fetch origin "${{ github.event.before }}" || true - git fetch origin "${{ github.sha }}" || true - changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.json') - echo "Changed files: $changed_files" - if [ -n "$changed_files" ]; then + # Skip check for manual trigger, translate all files + if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then echo "FILES_CHANGED=true" >> $GITHUB_ENV - file_args="" - for file in $changed_files; do - filename=$(basename "$file" .json) - file_args="$file_args --file $filename" - done - echo 
"FILE_ARGS=$file_args" >> $GITHUB_ENV - echo "File arguments: $file_args" + echo "FILE_ARGS=" >> $GITHUB_ENV + echo "Manual trigger: translating all files" else - echo "FILES_CHANGED=false" >> $GITHUB_ENV + git fetch origin "${{ github.event.before }}" || true + git fetch origin "${{ github.sha }}" || true + changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.json') + echo "Changed files: $changed_files" + if [ -n "$changed_files" ]; then + echo "FILES_CHANGED=true" >> $GITHUB_ENV + file_args="" + for file in $changed_files; do + filename=$(basename "$file" .json) + file_args="$file_args --file $filename" + done + echo "FILE_ARGS=$file_args" >> $GITHUB_ENV + echo "File arguments: $file_args" + else + echo "FILES_CHANGED=false" >> $GITHUB_ENV + fi fi - name: Install pnpm diff --git a/.gitignore b/.gitignore index 17a2bd5b7b..7bd919f095 100644 --- a/.gitignore +++ b/.gitignore @@ -235,3 +235,4 @@ scripts/stress-test/reports/ # settings *.local.json +*.local.md diff --git a/Makefile b/Makefile index 07afd8187e..60c32948b9 100644 --- a/Makefile +++ b/Makefile @@ -60,9 +60,10 @@ check: @echo "✅ Code check complete" lint: - @echo "🔧 Running ruff format, check with fixes, and import linter..." + @echo "🔧 Running ruff format, check with fixes, import linter, and dotenv-linter..." @uv run --project api --dev sh -c 'ruff format ./api && ruff check --fix ./api' @uv run --directory api --dev lint-imports + @uv run --project api --dev dotenv-linter ./api/.env.example ./web/.env.example @echo "✅ Linting complete" type-check: @@ -122,7 +123,7 @@ help: @echo "Backend Code Quality:" @echo " make format - Format code with ruff" @echo " make check - Check code with ruff" - @echo " make lint - Format and fix code with ruff" + @echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)" @echo " make type-check - Run type checking with basedpyright" @echo " make test - Run backend unit tests" @echo "" diff --git a/api/.env.example b/api/.env.example index 99cd2ba558..44d770ed70 100644 --- a/api/.env.example +++ b/api/.env.example @@ -101,6 +101,15 @@ S3_ACCESS_KEY=your-access-key S3_SECRET_KEY=your-secret-key S3_REGION=your-region +# Workflow run and Conversation archive storage (S3-compatible) +ARCHIVE_STORAGE_ENABLED=false +ARCHIVE_STORAGE_ENDPOINT= +ARCHIVE_STORAGE_ARCHIVE_BUCKET= +ARCHIVE_STORAGE_EXPORT_BUCKET= +ARCHIVE_STORAGE_ACCESS_KEY= +ARCHIVE_STORAGE_SECRET_KEY= +ARCHIVE_STORAGE_REGION=auto + # Azure Blob Storage configuration AZURE_BLOB_ACCOUNT_NAME=your-account-name AZURE_BLOB_ACCOUNT_KEY=your-account-key @@ -493,6 +502,8 @@ LOG_FILE_BACKUP_COUNT=5 LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S # Log Timezone LOG_TZ=UTC +# Log output format: text or json +LOG_OUTPUT_FORMAT=text # Log format LOG_FORMAT=%(asctime)s,%(msecs)d %(levelname)-2s [%(filename)s:%(lineno)d] %(req_id)s %(message)s @@ -564,6 +575,10 @@ LOGSTORE_DUAL_WRITE_ENABLED=false # Enable dual-read fallback to SQL database when LogStore returns no results (default: true) # Useful for migration scenarios where historical data exists only in SQL database LOGSTORE_DUAL_READ_ENABLED=true +# Control flag for whether to write the `graph` field to LogStore. +# If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field; +# otherwise write an empty {} instead. Defaults to writing the `graph` field. 
+LOGSTORE_ENABLE_PUT_GRAPH_FIELD=true # Celery beat configuration CELERY_BEAT_SCHEDULER_TIME=1 diff --git a/api/.importlinter b/api/.importlinter index 24ece72b30..2dec958788 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -3,9 +3,11 @@ root_packages = core configs controllers + extensions models tasks services +include_external_packages = True [importlinter:contract:workflow] name = Workflow @@ -33,6 +35,28 @@ ignore_imports = core.workflow.nodes.loop.loop_node -> core.workflow.graph core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels +[importlinter:contract:workflow-infrastructure-dependencies] +name = Workflow Infrastructure Dependencies +type = forbidden +source_modules = + core.workflow +forbidden_modules = + extensions.ext_database + extensions.ext_redis +allow_indirect_imports = True +ignore_imports = + core.workflow.nodes.agent.agent_node -> extensions.ext_database + core.workflow.nodes.datasource.datasource_node -> extensions.ext_database + core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database + core.workflow.nodes.llm.file_saver -> extensions.ext_database + core.workflow.nodes.llm.llm_utils -> extensions.ext_database + core.workflow.nodes.llm.node -> extensions.ext_database + core.workflow.nodes.tool.tool_node -> extensions.ext_database + core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis + core.workflow.graph_engine.manager -> extensions.ext_redis + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis + [importlinter:contract:rsc] name = RSC type = layers diff --git a/api/.ruff.toml b/api/.ruff.toml index 7206f7fa0f..8db0cbcb21 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -1,4 +1,8 @@ -exclude = ["migrations/*"] +exclude = [ + "migrations/*", + ".git", + ".git/**", +] line-length = 120 [format] diff --git a/api/Dockerfile b/api/Dockerfile index 02df91bfc1..a08d4e3aab 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -50,16 +50,33 @@ WORKDIR /app/api # Create non-root user ARG dify_uid=1001 +ARG NODE_MAJOR=22 +ARG NODE_PACKAGE_VERSION=22.21.0-1nodesource1 +ARG NODESOURCE_KEY_FPR=6F71F525282841EEDAF851B42F59B5F99B1BE0B4 RUN groupadd -r -g ${dify_uid} dify && \ useradd -r -u ${dify_uid} -g ${dify_uid} -s /bin/bash dify && \ chown -R dify:dify /app RUN \ apt-get update \ + && apt-get install -y --no-install-recommends \ + ca-certificates \ + curl \ + gnupg \ + && mkdir -p /etc/apt/keyrings \ + && curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key -o /tmp/nodesource.gpg \ + && gpg --show-keys --with-colons /tmp/nodesource.gpg \ + | awk -F: '/^fpr:/ {print $10}' \ + | grep -Fx "${NODESOURCE_KEY_FPR}" \ + && gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg /tmp/nodesource.gpg \ + && rm -f /tmp/nodesource.gpg \ + && echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_${NODE_MAJOR}.x nodistro main" \ + > /etc/apt/sources.list.d/nodesource.list \ + && apt-get update \ # Install dependencies && apt-get install -y --no-install-recommends \ # basic environment - curl nodejs \ + nodejs=${NODE_PACKAGE_VERSION} \ # for gmpy2 \ libgmp-dev libmpfr-dev libmpc-dev \ # For Security @@ -79,7 +96,8 @@ COPY --from=packages --chown=dify:dify ${VIRTUAL_ENV} ${VIRTUAL_ENV} ENV PATH="${VIRTUAL_ENV}/bin:${PATH}" # Download nltk data -RUN mkdir -p /usr/local/share/nltk_data && NLTK_DATA=/usr/local/share/nltk_data python 
-c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords')" \ +RUN mkdir -p /usr/local/share/nltk_data \ + && NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; from unstructured.nlp.tokenize import download_nltk_packages; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords'); download_nltk_packages()" \ && chmod -R 755 /usr/local/share/nltk_data ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache diff --git a/api/app_factory.py b/api/app_factory.py index bcad88e9e0..f827842d68 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -2,9 +2,11 @@ import logging import time from opentelemetry.trace import get_current_span +from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID from configs import dify_config from contexts.wrapper import RecyclableContextVar +from core.logging.context import init_request_context from dify_app import DifyApp logger = logging.getLogger(__name__) @@ -25,28 +27,35 @@ def create_flask_app_with_configs() -> DifyApp: # add before request hook @dify_app.before_request def before_request(): - # add an unique identifier to each request + # Initialize logging context for this request + init_request_context() RecyclableContextVar.increment_thread_recycles() - # add after request hook for injecting X-Trace-Id header from OpenTelemetry span context + # add after request hook for injecting trace headers from OpenTelemetry span context + # Only adds headers when OTEL is enabled and has valid context @dify_app.after_request - def add_trace_id_header(response): + def add_trace_headers(response): try: span = get_current_span() ctx = span.get_span_context() if span else None - if ctx and ctx.is_valid: - trace_id_hex = format(ctx.trace_id, "032x") - # Avoid duplicates if some middleware added it - if "X-Trace-Id" not in response.headers: - response.headers["X-Trace-Id"] = trace_id_hex + + if not ctx or not ctx.is_valid: + return response + + # Inject trace headers from OTEL context + if ctx.trace_id != INVALID_TRACE_ID and "X-Trace-Id" not in response.headers: + response.headers["X-Trace-Id"] = format(ctx.trace_id, "032x") + if ctx.span_id != INVALID_SPAN_ID and "X-Span-Id" not in response.headers: + response.headers["X-Span-Id"] = format(ctx.span_id, "016x") + except Exception: # Never break the response due to tracing header injection - logger.warning("Failed to add trace ID to response header", exc_info=True) + logger.warning("Failed to add trace headers to response", exc_info=True) return response # Capture the decorator's return value to avoid pyright reportUnusedFunction _ = before_request - _ = add_trace_id_header + _ = add_trace_headers return dify_app diff --git a/api/commands.py b/api/commands.py index a8d89ac200..7ebf5b4874 100644 --- a/api/commands.py +++ b/api/commands.py @@ -235,7 +235,7 @@ def migrate_annotation_vector_database(): if annotations: for annotation in annotations: document = Document( - page_content=annotation.question, + page_content=annotation.question_text, metadata={"annotation_id": annotation.id, "app_id": app.id, "doc_id": annotation.id}, ) documents.append(document) @@ -1184,6 +1184,217 @@ def remove_orphaned_files_on_storage(force: bool): click.echo(click.style(f"Removed {removed_files} orphaned files, with {error_files} errors.", fg="yellow")) +@click.command("file-usage", help="Query file usages and show where files are referenced.") +@click.option("--file-id", type=str, default=None, help="Filter by file UUID.") 
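With the `after_request` hook above, responses carry the OTEL trace and span IDs whenever a valid span context exists. A quick client-side check; the URL and auth header are placeholders, not values taken from this diff:

```python
# Read the new correlation headers off an API response (placeholder URL and token).
import requests

resp = requests.get(
    "http://localhost:5001/console/api/apps",
    headers={"Authorization": "Bearer <console-token>"},
    timeout=10,
)
print(resp.headers.get("X-Trace-Id"))  # 32 hex characters when an OTEL span context is valid
print(resp.headers.get("X-Span-Id"))   # 16 hex characters; absent when tracing is disabled
```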
+@click.option("--key", type=str, default=None, help="Filter by storage key.") +@click.option("--src", type=str, default=None, help="Filter by table.column pattern (e.g., 'documents.%' or '%.icon').") +@click.option("--limit", type=int, default=100, help="Limit number of results (default: 100).") +@click.option("--offset", type=int, default=0, help="Offset for pagination (default: 0).") +@click.option("--json", "output_json", is_flag=True, help="Output results in JSON format.") +def file_usage( + file_id: str | None, + key: str | None, + src: str | None, + limit: int, + offset: int, + output_json: bool, +): + """ + Query file usages and show where files are referenced in the database. + + This command reuses the same reference checking logic as clear-orphaned-file-records + and displays detailed information about where each file is referenced. + """ + # define tables and columns to process + files_tables = [ + {"table": "upload_files", "id_column": "id", "key_column": "key"}, + {"table": "tool_files", "id_column": "id", "key_column": "file_key"}, + ] + ids_tables = [ + {"type": "uuid", "table": "message_files", "column": "upload_file_id", "pk_column": "id"}, + {"type": "text", "table": "documents", "column": "data_source_info", "pk_column": "id"}, + {"type": "text", "table": "document_segments", "column": "content", "pk_column": "id"}, + {"type": "text", "table": "messages", "column": "answer", "pk_column": "id"}, + {"type": "text", "table": "workflow_node_executions", "column": "inputs", "pk_column": "id"}, + {"type": "text", "table": "workflow_node_executions", "column": "process_data", "pk_column": "id"}, + {"type": "text", "table": "workflow_node_executions", "column": "outputs", "pk_column": "id"}, + {"type": "text", "table": "conversations", "column": "introduction", "pk_column": "id"}, + {"type": "text", "table": "conversations", "column": "system_instruction", "pk_column": "id"}, + {"type": "text", "table": "accounts", "column": "avatar", "pk_column": "id"}, + {"type": "text", "table": "apps", "column": "icon", "pk_column": "id"}, + {"type": "text", "table": "sites", "column": "icon", "pk_column": "id"}, + {"type": "json", "table": "messages", "column": "inputs", "pk_column": "id"}, + {"type": "json", "table": "messages", "column": "message", "pk_column": "id"}, + ] + + # Stream file usages with pagination to avoid holding all results in memory + paginated_usages = [] + total_count = 0 + + # First, build a mapping of file_id -> storage_key from the base tables + file_key_map = {} + for files_table in files_tables: + query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}" + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for row in rs: + file_key_map[str(row[0])] = f"{files_table['table']}:{row[1]}" + + # If filtering by key or file_id, verify it exists + if file_id and file_id not in file_key_map: + if output_json: + click.echo(json.dumps({"error": f"File ID {file_id} not found in base tables"})) + else: + click.echo(click.style(f"File ID {file_id} not found in base tables.", fg="red")) + return + + if key: + valid_prefixes = {f"upload_files:{key}", f"tool_files:{key}"} + matching_file_ids = [fid for fid, fkey in file_key_map.items() if fkey in valid_prefixes] + if not matching_file_ids: + if output_json: + click.echo(json.dumps({"error": f"Key {key} not found in base tables"})) + else: + click.echo(click.style(f"Key {key} not found in base tables.", fg="red")) + return + + guid_regexp = 
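Stripped of the SQL plumbing, the scan in the loop below extracts UUIDs from each row, keeps only IDs known from the base tables, and counts everything while materializing just the requested page. A condensed sketch of that pattern, with the regex compiled once up front:

```python
# Page through UUID matches without holding every hit in memory; names are illustrative.
import re

UUID_RE = re.compile(
    r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
)


def page_matches(rows: list[str], known_ids: set[str], offset: int, limit: int) -> tuple[list[str], int]:
    """Return the matches inside [offset, offset + limit) plus the total match count."""
    page: list[str] = []
    total = 0
    for content in rows:
        for candidate in UUID_RE.findall(content):
            if candidate not in known_ids:
                continue
            if offset <= total < offset + limit:
                page.append(candidate)
            total += 1
    return page, total
```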
"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}" + + # For each reference table/column, find matching file IDs and record the references + for ids_table in ids_tables: + src_filter = f"{ids_table['table']}.{ids_table['column']}" + + # Skip if src filter doesn't match (use fnmatch for wildcard patterns) + if src: + if "%" in src or "_" in src: + import fnmatch + + # Convert SQL LIKE wildcards to fnmatch wildcards (% -> *, _ -> ?) + pattern = src.replace("%", "*").replace("_", "?") + if not fnmatch.fnmatch(src_filter, pattern): + continue + else: + if src_filter != src: + continue + + if ids_table["type"] == "uuid": + # Direct UUID match + query = ( + f"SELECT {ids_table['pk_column']}, {ids_table['column']} " + f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" + ) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for row in rs: + record_id = str(row[0]) + ref_file_id = str(row[1]) + if ref_file_id not in file_key_map: + continue + storage_key = file_key_map[ref_file_id] + + # Apply filters + if file_id and ref_file_id != file_id: + continue + if key and not storage_key.endswith(key): + continue + + # Only collect items within the requested page range + if offset <= total_count < offset + limit: + paginated_usages.append( + { + "src": f"{ids_table['table']}.{ids_table['column']}", + "record_id": record_id, + "file_id": ref_file_id, + "key": storage_key, + } + ) + total_count += 1 + + elif ids_table["type"] in ("text", "json"): + # Extract UUIDs from text/json content + column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"] + query = ( + f"SELECT {ids_table['pk_column']}, {column_cast} " + f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" + ) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for row in rs: + record_id = str(row[0]) + content = str(row[1]) + + # Find all UUIDs in the content + import re + + uuid_pattern = re.compile(guid_regexp, re.IGNORECASE) + matches = uuid_pattern.findall(content) + + for ref_file_id in matches: + if ref_file_id not in file_key_map: + continue + storage_key = file_key_map[ref_file_id] + + # Apply filters + if file_id and ref_file_id != file_id: + continue + if key and not storage_key.endswith(key): + continue + + # Only collect items within the requested page range + if offset <= total_count < offset + limit: + paginated_usages.append( + { + "src": f"{ids_table['table']}.{ids_table['column']}", + "record_id": record_id, + "file_id": ref_file_id, + "key": storage_key, + } + ) + total_count += 1 + + # Output results + if output_json: + result = { + "total": total_count, + "offset": offset, + "limit": limit, + "usages": paginated_usages, + } + click.echo(json.dumps(result, indent=2)) + else: + click.echo( + click.style(f"Found {total_count} file usages (showing {len(paginated_usages)} results)", fg="white") + ) + click.echo("") + + if not paginated_usages: + click.echo(click.style("No file usages found matching the specified criteria.", fg="yellow")) + return + + # Print table header + click.echo( + click.style( + f"{'Src (Table.Column)':<50} {'Record ID':<40} {'File ID':<40} {'Storage Key':<60}", + fg="cyan", + ) + ) + click.echo(click.style("-" * 190, fg="white")) + + # Print each usage + for usage in paginated_usages: + click.echo(f"{usage['src']:<50} {usage['record_id']:<40} {usage['file_id']:<40} {usage['key']:<60}") + + # Show pagination info + if offset + limit < total_count: + click.echo("") + 
click.echo( + click.style( + f"Showing {offset + 1}-{offset + len(paginated_usages)} of {total_count} results", fg="white" + ) + ) + click.echo(click.style(f"Use --offset {offset + limit} to see next page", fg="white")) + + @click.command("setup-system-tool-oauth-client", help="Setup system tool oauth client.") @click.option("--provider", prompt=True, help="Provider name") @click.option("--client-params", prompt=True, help="Client Params") diff --git a/api/configs/extra/__init__.py b/api/configs/extra/__init__.py index 4543b5389d..de97adfc0e 100644 --- a/api/configs/extra/__init__.py +++ b/api/configs/extra/__init__.py @@ -1,9 +1,11 @@ +from configs.extra.archive_config import ArchiveStorageConfig from configs.extra.notion_config import NotionConfig from configs.extra.sentry_config import SentryConfig class ExtraServiceConfig( # place the configs in alphabet order + ArchiveStorageConfig, NotionConfig, SentryConfig, ): diff --git a/api/configs/extra/archive_config.py b/api/configs/extra/archive_config.py new file mode 100644 index 0000000000..a85628fa61 --- /dev/null +++ b/api/configs/extra/archive_config.py @@ -0,0 +1,43 @@ +from pydantic import Field +from pydantic_settings import BaseSettings + + +class ArchiveStorageConfig(BaseSettings): + """ + Configuration settings for workflow run logs archiving storage. + """ + + ARCHIVE_STORAGE_ENABLED: bool = Field( + description="Enable workflow run logs archiving to S3-compatible storage", + default=False, + ) + + ARCHIVE_STORAGE_ENDPOINT: str | None = Field( + description="URL of the S3-compatible storage endpoint (e.g., 'https://storage.example.com')", + default=None, + ) + + ARCHIVE_STORAGE_ARCHIVE_BUCKET: str | None = Field( + description="Name of the bucket to store archived workflow logs", + default=None, + ) + + ARCHIVE_STORAGE_EXPORT_BUCKET: str | None = Field( + description="Name of the bucket to store exported workflow runs", + default=None, + ) + + ARCHIVE_STORAGE_ACCESS_KEY: str | None = Field( + description="Access key ID for authenticating with storage", + default=None, + ) + + ARCHIVE_STORAGE_SECRET_KEY: str | None = Field( + description="Secret access key for authenticating with storage", + default=None, + ) + + ARCHIVE_STORAGE_REGION: str = Field( + description="Region for storage (use 'auto' if the provider supports it)", + default="auto", + ) diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 43dddbd011..6a04171d2d 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -587,6 +587,11 @@ class LoggingConfig(BaseSettings): default="INFO", ) + LOG_OUTPUT_FORMAT: Literal["text", "json"] = Field( + description="Log output format: 'text' for human-readable, 'json' for structured JSON logs.", + default="text", + ) + LOG_FILE: str | None = Field( description="File path for log output.", default=None, diff --git a/api/configs/middleware/vdb/milvus_config.py b/api/configs/middleware/vdb/milvus_config.py index 05cee51cc9..eb9b0ac2ab 100644 --- a/api/configs/middleware/vdb/milvus_config.py +++ b/api/configs/middleware/vdb/milvus_config.py @@ -16,7 +16,6 @@ class MilvusConfig(BaseSettings): description="Authentication token for Milvus, if token-based authentication is enabled", default=None, ) - MILVUS_USER: str | None = Field( description="Username for authenticating with Milvus, if username/password authentication is enabled", default=None, diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py index df9de825de..c16a23fac8 100644 --- 
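`ArchiveStorageConfig` is a pydantic-settings model, so the new archive options are read straight from the environment. A small self-contained check; the endpoint and bucket values are placeholders:

```python
# Exercise the new settings class in isolation (placeholder values).
import os

from configs.extra.archive_config import ArchiveStorageConfig

os.environ["ARCHIVE_STORAGE_ENABLED"] = "true"
os.environ["ARCHIVE_STORAGE_ENDPOINT"] = "https://storage.example.com"
os.environ["ARCHIVE_STORAGE_ARCHIVE_BUCKET"] = "workflow-archive"

config = ArchiveStorageConfig()
assert config.ARCHIVE_STORAGE_ENABLED is True
assert config.ARCHIVE_STORAGE_REGION == "auto"  # default when the variable is unset
```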
a/api/controllers/common/fields.py +++ b/api/controllers/common/fields.py @@ -1,62 +1,59 @@ -from flask_restx import Api, Namespace, fields +from __future__ import annotations -from libs.helper import AppIconUrlField +from typing import Any, TypeAlias -parameters__system_parameters = { - "image_file_size_limit": fields.Integer, - "video_file_size_limit": fields.Integer, - "audio_file_size_limit": fields.Integer, - "file_size_limit": fields.Integer, - "workflow_file_upload_limit": fields.Integer, -} +from pydantic import BaseModel, ConfigDict, computed_field + +from core.file import helpers as file_helpers +from models.model import IconType + +JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any] +JSONObject: TypeAlias = dict[str, Any] -def build_system_parameters_model(api_or_ns: Api | Namespace): - """Build the system parameters model for the API or Namespace.""" - return api_or_ns.model("SystemParameters", parameters__system_parameters) +class SystemParameters(BaseModel): + image_file_size_limit: int + video_file_size_limit: int + audio_file_size_limit: int + file_size_limit: int + workflow_file_upload_limit: int -parameters_fields = { - "opening_statement": fields.String, - "suggested_questions": fields.Raw, - "suggested_questions_after_answer": fields.Raw, - "speech_to_text": fields.Raw, - "text_to_speech": fields.Raw, - "retriever_resource": fields.Raw, - "annotation_reply": fields.Raw, - "more_like_this": fields.Raw, - "user_input_form": fields.Raw, - "sensitive_word_avoidance": fields.Raw, - "file_upload": fields.Raw, - "system_parameters": fields.Nested(parameters__system_parameters), -} +class Parameters(BaseModel): + opening_statement: str | None = None + suggested_questions: list[str] + suggested_questions_after_answer: JSONObject + speech_to_text: JSONObject + text_to_speech: JSONObject + retriever_resource: JSONObject + annotation_reply: JSONObject + more_like_this: JSONObject + user_input_form: list[JSONObject] + sensitive_word_avoidance: JSONObject + file_upload: JSONObject + system_parameters: SystemParameters -def build_parameters_model(api_or_ns: Api | Namespace): - """Build the parameters model for the API or Namespace.""" - copied_fields = parameters_fields.copy() - copied_fields["system_parameters"] = fields.Nested(build_system_parameters_model(api_or_ns)) - return api_or_ns.model("Parameters", copied_fields) +class Site(BaseModel): + model_config = ConfigDict(from_attributes=True) + title: str + chat_color_theme: str | None = None + chat_color_theme_inverted: bool + icon_type: str | None = None + icon: str | None = None + icon_background: str | None = None + description: str | None = None + copyright: str | None = None + privacy_policy: str | None = None + custom_disclaimer: str | None = None + default_language: str + show_workflow_steps: bool + use_icon_as_answer_icon: bool -site_fields = { - "title": fields.String, - "chat_color_theme": fields.String, - "chat_color_theme_inverted": fields.Boolean, - "icon_type": fields.String, - "icon": fields.String, - "icon_background": fields.String, - "icon_url": AppIconUrlField, - "description": fields.String, - "copyright": fields.String, - "privacy_policy": fields.String, - "custom_disclaimer": fields.String, - "default_language": fields.String, - "show_workflow_steps": fields.Boolean, - "use_icon_as_answer_icon": fields.Boolean, -} - - -def build_site_model(api_or_ns: Api | Namespace): - """Build the site model for the API or Namespace.""" - return api_or_ns.model("Site", site_fields) + 
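The flask_restx field dictionaries in `controllers/common/fields.py` become plain pydantic models, so responses are now built by instantiating and dumping them rather than marshalling. A minimal usage sketch with placeholder limits:

```python
# Build and serialize the new response models directly (placeholder values).
from controllers.common.fields import Parameters, SystemParameters

params = Parameters(
    opening_statement="Hi!",
    suggested_questions=["What can you do?"],
    suggested_questions_after_answer={"enabled": False},
    speech_to_text={"enabled": False},
    text_to_speech={"enabled": False},
    retriever_resource={"enabled": True},
    annotation_reply={"enabled": False},
    more_like_this={"enabled": False},
    user_input_form=[],
    sensitive_word_avoidance={"enabled": False},
    file_upload={"enabled": False},
    system_parameters=SystemParameters(
        image_file_size_limit=10,
        video_file_size_limit=100,
        audio_file_size_limit=50,
        file_size_limit=15,
        workflow_file_upload_limit=10,
    ),
)
print(params.model_dump(mode="json")["system_parameters"]["file_size_limit"])  # 15
```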
@computed_field(return_type=str | None) # type: ignore + @property + def icon_url(self) -> str | None: + if self.icon and self.icon_type == IconType.IMAGE: + return file_helpers.get_signed_file_url(self.icon) + return None diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 62e997dae2..d66bb7063f 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -1,13 +1,16 @@ +import re import uuid -from typing import Literal +from datetime import datetime +from typing import Any, Literal, TypeAlias from flask import request -from flask_restx import Resource, fields, marshal, marshal_with -from pydantic import BaseModel, Field, field_validator +from flask_restx import Resource +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import ( @@ -18,27 +21,19 @@ from controllers.console.wraps import ( is_admin_or_owner_required, setup_required, ) +from core.file import helpers as file_helpers from core.ops.ops_trace_manager import OpsTraceManager from core.workflow.enums import NodeType from extensions.ext_database import db -from fields.app_fields import ( - deleted_tool_fields, - model_config_fields, - model_config_partial_fields, - site_fields, - tag_fields, -) -from fields.workflow_fields import workflow_partial_fields as _workflow_partial_fields_dict -from libs.helper import AppIconUrlField, TimestampField from libs.login import current_account_with_tenant, login_required from models import App, Workflow +from models.model import IconType from services.app_dsl_service import AppDslService, ImportMode from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class AppListQuery(BaseModel): @@ -73,6 +68,48 @@ class AppListQuery(BaseModel): raise ValueError("Invalid UUID format in tag_ids.") from exc +# XSS prevention: patterns that could lead to XSS attacks +# Includes: script tags, iframe tags, javascript: protocol, SVG with onload, etc. +_XSS_PATTERNS = [ + r"]*>.*?", # Script tags + r"]*?(?:/>|>.*?)", # Iframe tags (including self-closing) + r"javascript:", # JavaScript protocol + r"]*?\s+onload\s*=[^>]*>", # SVG with onload handler (attribute-aware, flexible whitespace) + r"<.*?on\s*\w+\s*=", # Event handlers like onclick, onerror, etc. + r"]*(?:\s*/>|>.*?)", # Object tags (opening tag) + r"]*>", # Embed tags (self-closing) + r"]*>", # Link tags with javascript +] + + +def _validate_xss_safe(value: str | None, field_name: str = "Field") -> str | None: + """ + Validate that a string value doesn't contain potential XSS payloads. 
+ + Args: + value: The string value to validate + field_name: Name of the field for error messages + + Returns: + The original value if safe + + Raises: + ValueError: If the value contains XSS patterns + """ + if value is None: + return None + + value_lower = value.lower() + for pattern in _XSS_PATTERNS: + if re.search(pattern, value_lower, re.DOTALL | re.IGNORECASE): + raise ValueError( + f"{field_name} contains invalid characters or patterns. " + "HTML tags, JavaScript, and other potentially dangerous content are not allowed." + ) + + return value + + class CreateAppPayload(BaseModel): name: str = Field(..., min_length=1, description="App name") description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400) @@ -81,6 +118,11 @@ class CreateAppPayload(BaseModel): icon: str | None = Field(default=None, description="Icon") icon_background: str | None = Field(default=None, description="Icon background color") + @field_validator("name", "description", mode="before") + @classmethod + def validate_xss_safe(cls, value: str | None, info) -> str | None: + return _validate_xss_safe(value, info.field_name) + class UpdateAppPayload(BaseModel): name: str = Field(..., min_length=1, description="App name") @@ -91,6 +133,11 @@ class UpdateAppPayload(BaseModel): use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon") max_active_requests: int | None = Field(default=None, description="Maximum active requests") + @field_validator("name", "description", mode="before") + @classmethod + def validate_xss_safe(cls, value: str | None, info) -> str | None: + return _validate_xss_safe(value, info.field_name) + class CopyAppPayload(BaseModel): name: str | None = Field(default=None, description="Name for the copied app") @@ -99,6 +146,11 @@ class CopyAppPayload(BaseModel): icon: str | None = Field(default=None, description="Icon") icon_background: str | None = Field(default=None, description="Icon background color") + @field_validator("name", "description", mode="before") + @classmethod + def validate_xss_safe(cls, value: str | None, info) -> str | None: + return _validate_xss_safe(value, info.field_name) + class AppExportQuery(BaseModel): include_secret: bool = Field(default=False, description="Include secrets in export") @@ -134,124 +186,292 @@ class AppTracePayload(BaseModel): return value -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +JSONValue: TypeAlias = Any -reg(AppListQuery) -reg(CreateAppPayload) -reg(UpdateAppPayload) -reg(CopyAppPayload) -reg(AppExportQuery) -reg(AppNamePayload) -reg(AppIconPayload) -reg(AppSiteStatusPayload) -reg(AppApiStatusPayload) -reg(AppTracePayload) +class ResponseModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + extra="ignore", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) -# Register models for flask_restx to avoid dict type issues in Swagger -# Register base models first -tag_model = console_ns.model("Tag", tag_fields) -workflow_partial_model = console_ns.model("WorkflowPartial", _workflow_partial_fields_dict) +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -model_config_model = console_ns.model("ModelConfig", model_config_fields) -model_config_partial_model = console_ns.model("ModelConfigPartial", model_config_partial_fields) +def 
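The `_validate_xss_safe` helper is wired into the name and description fields of the create, update, and copy payloads, so markup is rejected at parse time. A small sketch of the behaviour, assuming the pattern list flags `<script>` markup as the comments above describe:

```python
# Plain names pass through unchanged; markup raises a ValueError from the validator.
from controllers.console.app.app import _validate_xss_safe

assert _validate_xss_safe("Support bot", "name") == "Support bot"
assert _validate_xss_safe(None, "description") is None

try:
    _validate_xss_safe("<script>alert(1)</script>", "name")
except ValueError as exc:
    print(exc)  # "name contains invalid characters or patterns. ..."
```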
_build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None: + if icon is None or icon_type is None: + return None + icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type) + if icon_type_value.lower() != IconType.IMAGE.value: + return None + return file_helpers.get_signed_file_url(icon) -deleted_tool_model = console_ns.model("DeletedTool", deleted_tool_fields) -site_model = console_ns.model("Site", site_fields) +class Tag(ResponseModel): + id: str + name: str + type: str -app_partial_model = console_ns.model( - "AppPartial", - { - "id": fields.String, - "name": fields.String, - "max_active_requests": fields.Raw(), - "description": fields.String(attribute="desc_or_prompt"), - "mode": fields.String(attribute="mode_compatible_with_agent"), - "icon_type": fields.String, - "icon": fields.String, - "icon_background": fields.String, - "icon_url": AppIconUrlField, - "model_config": fields.Nested(model_config_partial_model, attribute="app_model_config", allow_null=True), - "workflow": fields.Nested(workflow_partial_model, allow_null=True), - "use_icon_as_answer_icon": fields.Boolean, - "created_by": fields.String, - "created_at": TimestampField, - "updated_by": fields.String, - "updated_at": TimestampField, - "tags": fields.List(fields.Nested(tag_model)), - "access_mode": fields.String, - "create_user_name": fields.String, - "author_name": fields.String, - "has_draft_trigger": fields.Boolean, - }, -) -app_detail_model = console_ns.model( - "AppDetail", - { - "id": fields.String, - "name": fields.String, - "description": fields.String, - "mode": fields.String(attribute="mode_compatible_with_agent"), - "icon": fields.String, - "icon_background": fields.String, - "enable_site": fields.Boolean, - "enable_api": fields.Boolean, - "model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True), - "workflow": fields.Nested(workflow_partial_model, allow_null=True), - "tracing": fields.Raw, - "use_icon_as_answer_icon": fields.Boolean, - "created_by": fields.String, - "created_at": TimestampField, - "updated_by": fields.String, - "updated_at": TimestampField, - "access_mode": fields.String, - "tags": fields.List(fields.Nested(tag_model)), - }, -) +class WorkflowPartial(ResponseModel): + id: str + created_by: str | None = None + created_at: int | None = None + updated_by: str | None = None + updated_at: int | None = None -app_detail_with_site_model = console_ns.model( - "AppDetailWithSite", - { - "id": fields.String, - "name": fields.String, - "description": fields.String, - "mode": fields.String(attribute="mode_compatible_with_agent"), - "icon_type": fields.String, - "icon": fields.String, - "icon_background": fields.String, - "icon_url": AppIconUrlField, - "enable_site": fields.Boolean, - "enable_api": fields.Boolean, - "model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True), - "workflow": fields.Nested(workflow_partial_model, allow_null=True), - "api_base_url": fields.String, - "use_icon_as_answer_icon": fields.Boolean, - "max_active_requests": fields.Integer, - "created_by": fields.String, - "created_at": TimestampField, - "updated_by": fields.String, - "updated_at": TimestampField, - "deleted_tools": fields.List(fields.Nested(deleted_tool_model)), - "access_mode": fields.String, - "tags": fields.List(fields.Nested(tag_model)), - "site": fields.Nested(site_model), - }, -) + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, 
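ORM rows hand back `datetime` objects while the public payloads use epoch seconds; the shared `_to_timestamp` validator bridges the two for every `created_at`/`updated_at` field. A quick check against `WorkflowPartial` as defined above:

```python
# Datetimes are normalized to epoch seconds; integers pass through untouched.
from datetime import datetime, timezone

from controllers.console.app.app import WorkflowPartial

wf = WorkflowPartial.model_validate(
    {"id": "wf-1", "created_at": datetime(2024, 1, 1, tzinfo=timezone.utc), "updated_at": 1704067200},
    from_attributes=True,
)
assert wf.created_at == wf.updated_at == 1704067200
```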
value: datetime | int | None) -> int | None: + return _to_timestamp(value) -app_pagination_model = console_ns.model( - "AppPagination", - { - "page": fields.Integer, - "limit": fields.Integer(attribute="per_page"), - "total": fields.Integer, - "has_more": fields.Boolean(attribute="has_next"), - "data": fields.List(fields.Nested(app_partial_model), attribute="items"), - }, + +class ModelConfigPartial(ResponseModel): + model: JSONValue | None = Field(default=None, validation_alias=AliasChoices("model_dict", "model")) + pre_prompt: str | None = None + created_by: str | None = None + created_at: int | None = None + updated_by: str | None = None + updated_at: int | None = None + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class ModelConfig(ResponseModel): + opening_statement: str | None = None + suggested_questions: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("suggested_questions_list", "suggested_questions") + ) + suggested_questions_after_answer: JSONValue | None = Field( + default=None, + validation_alias=AliasChoices("suggested_questions_after_answer_dict", "suggested_questions_after_answer"), + ) + speech_to_text: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("speech_to_text_dict", "speech_to_text") + ) + text_to_speech: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("text_to_speech_dict", "text_to_speech") + ) + retriever_resource: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("retriever_resource_dict", "retriever_resource") + ) + annotation_reply: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("annotation_reply_dict", "annotation_reply") + ) + more_like_this: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("more_like_this_dict", "more_like_this") + ) + sensitive_word_avoidance: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("sensitive_word_avoidance_dict", "sensitive_word_avoidance") + ) + external_data_tools: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("external_data_tools_list", "external_data_tools") + ) + model: JSONValue | None = Field(default=None, validation_alias=AliasChoices("model_dict", "model")) + user_input_form: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("user_input_form_list", "user_input_form") + ) + dataset_query_variable: str | None = None + pre_prompt: str | None = None + agent_mode: JSONValue | None = Field(default=None, validation_alias=AliasChoices("agent_mode_dict", "agent_mode")) + prompt_type: str | None = None + chat_prompt_config: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("chat_prompt_config_dict", "chat_prompt_config") + ) + completion_prompt_config: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("completion_prompt_config_dict", "completion_prompt_config") + ) + dataset_configs: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("dataset_configs_dict", "dataset_configs") + ) + file_upload: JSONValue | None = Field( + default=None, validation_alias=AliasChoices("file_upload_dict", "file_upload") + ) + created_by: str | None = None + created_at: int | None = None + updated_by: str | None = None + updated_at: int | None = None + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def 
_normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class Site(ResponseModel): + access_token: str | None = Field(default=None, validation_alias="code") + code: str | None = None + title: str | None = None + icon_type: str | IconType | None = None + icon: str | None = None + icon_background: str | None = None + description: str | None = None + default_language: str | None = None + chat_color_theme: str | None = None + chat_color_theme_inverted: bool | None = None + customize_domain: str | None = None + copyright: str | None = None + privacy_policy: str | None = None + custom_disclaimer: str | None = None + customize_token_strategy: str | None = None + prompt_public: bool | None = None + app_base_url: str | None = None + show_workflow_steps: bool | None = None + use_icon_as_answer_icon: bool | None = None + created_by: str | None = None + created_at: int | None = None + updated_by: str | None = None + updated_at: int | None = None + + @computed_field(return_type=str | None) # type: ignore + @property + def icon_url(self) -> str | None: + return _build_icon_url(self.icon_type, self.icon) + + @field_validator("icon_type", mode="before") + @classmethod + def _normalize_icon_type(cls, value: str | IconType | None) -> str | None: + if isinstance(value, IconType): + return value.value + return value + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class DeletedTool(ResponseModel): + type: str + tool_name: str + provider_id: str + + +class AppPartial(ResponseModel): + id: str + name: str + max_active_requests: int | None = None + description: str | None = Field(default=None, validation_alias=AliasChoices("desc_or_prompt", "description")) + mode: str = Field(validation_alias="mode_compatible_with_agent") + icon_type: str | None = None + icon: str | None = None + icon_background: str | None = None + model_config_: ModelConfigPartial | None = Field( + default=None, + validation_alias=AliasChoices("app_model_config", "model_config"), + alias="model_config", + ) + workflow: WorkflowPartial | None = None + use_icon_as_answer_icon: bool | None = None + created_by: str | None = None + created_at: int | None = None + updated_by: str | None = None + updated_at: int | None = None + tags: list[Tag] = Field(default_factory=list) + access_mode: str | None = None + create_user_name: str | None = None + author_name: str | None = None + has_draft_trigger: bool | None = None + + @computed_field(return_type=str | None) # type: ignore + @property + def icon_url(self) -> str | None: + return _build_icon_url(self.icon_type, self.icon) + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AppDetail(ResponseModel): + id: str + name: str + description: str | None = None + mode: str = Field(validation_alias="mode_compatible_with_agent") + icon: str | None = None + icon_background: str | None = None + enable_site: bool + enable_api: bool + model_config_: ModelConfig | None = Field( + default=None, + validation_alias=AliasChoices("app_model_config", "model_config"), + alias="model_config", + ) + workflow: WorkflowPartial | None = None + tracing: JSONValue | None = None + use_icon_as_answer_icon: bool | None = None + created_by: str | None = None + created_at: int | None = None + updated_by: 
str | None = None + updated_at: int | None = None + access_mode: str | None = None + tags: list[Tag] = Field(default_factory=list) + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AppDetailWithSite(AppDetail): + icon_type: str | None = None + api_base_url: str | None = None + max_active_requests: int | None = None + deleted_tools: list[DeletedTool] = Field(default_factory=list) + site: Site | None = None + + @computed_field(return_type=str | None) # type: ignore + @property + def icon_url(self) -> str | None: + return _build_icon_url(self.icon_type, self.icon) + + +class AppPagination(ResponseModel): + page: int + limit: int = Field(validation_alias=AliasChoices("per_page", "limit")) + total: int + has_more: bool = Field(validation_alias=AliasChoices("has_next", "has_more")) + data: list[AppPartial] = Field(validation_alias=AliasChoices("items", "data")) + + +class AppExportResponse(ResponseModel): + data: str + + +register_schema_models( + console_ns, + AppListQuery, + CreateAppPayload, + UpdateAppPayload, + CopyAppPayload, + AppExportQuery, + AppNamePayload, + AppIconPayload, + AppSiteStatusPayload, + AppApiStatusPayload, + AppTracePayload, + Tag, + WorkflowPartial, + ModelConfigPartial, + ModelConfig, + Site, + DeletedTool, + AppPartial, + AppDetail, + AppDetailWithSite, + AppPagination, + AppExportResponse, ) @@ -260,7 +480,7 @@ class AppListApi(Resource): @console_ns.doc("list_apps") @console_ns.doc(description="Get list of applications with pagination and filtering") @console_ns.expect(console_ns.models[AppListQuery.__name__]) - @console_ns.response(200, "Success", app_pagination_model) + @console_ns.response(200, "Success", console_ns.models[AppPagination.__name__]) @setup_required @login_required @account_initialization_required @@ -276,7 +496,8 @@ class AppListApi(Resource): app_service = AppService() app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args_dict) if not app_pagination: - return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False} + empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[]) + return empty.model_dump(mode="json"), 200 if FeatureService.get_system_features().webapp_auth.enabled: app_ids = [str(app.id) for app in app_pagination.items] @@ -320,18 +541,18 @@ class AppListApi(Resource): for app in app_pagination.items: app.has_draft_trigger = str(app.id) in draft_trigger_app_ids - return marshal(app_pagination, app_pagination_model), 200 + pagination_model = AppPagination.model_validate(app_pagination, from_attributes=True) + return pagination_model.model_dump(mode="json"), 200 @console_ns.doc("create_app") @console_ns.doc(description="Create a new application") @console_ns.expect(console_ns.models[CreateAppPayload.__name__]) - @console_ns.response(201, "App created successfully", app_detail_model) + @console_ns.response(201, "App created successfully", console_ns.models[AppDetail.__name__]) @console_ns.response(403, "Insufficient permissions") @console_ns.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required - @marshal_with(app_detail_model) @cloud_edition_billing_resource_check("apps") @edit_permission_required def post(self): @@ -341,8 +562,8 @@ class AppListApi(Resource): app_service = AppService() app = app_service.create_app(current_tenant_id, args.model_dump(), 
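`AppPagination` mirrors how `AppListApi` now serializes the paginator: `AliasChoices` lets the model read `per_page`, `has_next`, and `items` off the object while exposing `limit`, `has_more`, and `data`. A stand-in object keeps the sketch self-contained:

```python
# Validate a paginator-shaped object and dump the public JSON field names.
from types import SimpleNamespace

from controllers.console.app.app import AppPagination

fake_page = SimpleNamespace(page=1, per_page=20, total=0, has_next=False, items=[])
body = AppPagination.model_validate(fake_page, from_attributes=True).model_dump(mode="json")
print(body["limit"], body["has_more"], body["data"])  # 20 False []
```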
current_user) - - return app, 201 + app_detail = AppDetail.model_validate(app, from_attributes=True) + return app_detail.model_dump(mode="json"), 201 @console_ns.route("/apps/") @@ -350,13 +571,12 @@ class AppApi(Resource): @console_ns.doc("get_app_detail") @console_ns.doc(description="Get application details") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.response(200, "Success", app_detail_with_site_model) + @console_ns.response(200, "Success", console_ns.models[AppDetailWithSite.__name__]) @setup_required @login_required @account_initialization_required @enterprise_license_required - @get_app_model - @marshal_with(app_detail_with_site_model) + @get_app_model(mode=None) def get(self, app_model): """Get app detail""" app_service = AppService() @@ -367,21 +587,21 @@ class AppApi(Resource): app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id)) app_model.access_mode = app_setting.access_mode - return app_model + response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True) + return response_model.model_dump(mode="json") @console_ns.doc("update_app") @console_ns.doc(description="Update application details") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[UpdateAppPayload.__name__]) - @console_ns.response(200, "App updated successfully", app_detail_with_site_model) + @console_ns.response(200, "App updated successfully", console_ns.models[AppDetailWithSite.__name__]) @console_ns.response(403, "Insufficient permissions") @console_ns.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required - @get_app_model + @get_app_model(mode=None) @edit_permission_required - @marshal_with(app_detail_with_site_model) def put(self, app_model): """Update app""" args = UpdateAppPayload.model_validate(console_ns.payload) @@ -398,8 +618,8 @@ class AppApi(Resource): "max_active_requests": args.max_active_requests or 0, } app_model = app_service.update_app(app_model, args_dict) - - return app_model + response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True) + return response_model.model_dump(mode="json") @console_ns.doc("delete_app") @console_ns.doc(description="Delete application") @@ -425,14 +645,13 @@ class AppCopyApi(Resource): @console_ns.doc(description="Create a copy of an existing application") @console_ns.doc(params={"app_id": "Application ID to copy"}) @console_ns.expect(console_ns.models[CopyAppPayload.__name__]) - @console_ns.response(201, "App copied successfully", app_detail_with_site_model) + @console_ns.response(201, "App copied successfully", console_ns.models[AppDetailWithSite.__name__]) @console_ns.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required - @get_app_model + @get_app_model(mode=None) @edit_permission_required - @marshal_with(app_detail_with_site_model) def post(self, app_model): """Copy app""" # The role of the current user in the ta table must be admin, owner, or editor @@ -458,7 +677,8 @@ class AppCopyApi(Resource): stmt = select(App).where(App.id == result.app_id) app = session.scalar(stmt) - return app, 201 + response_model = AppDetailWithSite.model_validate(app, from_attributes=True) + return response_model.model_dump(mode="json"), 201 @console_ns.route("/apps//export") @@ -467,11 +687,7 @@ class AppExportApi(Resource): @console_ns.doc(description="Export application configuration as DSL") @console_ns.doc(params={"app_id": 
"Application ID to export"}) @console_ns.expect(console_ns.models[AppExportQuery.__name__]) - @console_ns.response( - 200, - "App exported successfully", - console_ns.model("AppExportResponse", {"data": fields.String(description="DSL export data")}), - ) + @console_ns.response(200, "App exported successfully", console_ns.models[AppExportResponse.__name__]) @console_ns.response(403, "Insufficient permissions") @get_app_model @setup_required @@ -482,13 +698,14 @@ class AppExportApi(Resource): """Export app""" args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - return { - "data": AppDslService.export_dsl( + payload = AppExportResponse( + data=AppDslService.export_dsl( app_model=app_model, include_secret=args.include_secret, workflow_id=args.workflow_id, ) - } + ) + return payload.model_dump(mode="json") @console_ns.route("/apps//name") @@ -497,20 +714,19 @@ class AppNameApi(Resource): @console_ns.doc(description="Check if app name is available") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[AppNamePayload.__name__]) - @console_ns.response(200, "Name availability checked") + @console_ns.response(200, "Name availability checked", console_ns.models[AppDetail.__name__]) @setup_required @login_required @account_initialization_required - @get_app_model - @marshal_with(app_detail_model) + @get_app_model(mode=None) @edit_permission_required def post(self, app_model): args = AppNamePayload.model_validate(console_ns.payload) app_service = AppService() app_model = app_service.update_app_name(app_model, args.name) - - return app_model + response_model = AppDetail.model_validate(app_model, from_attributes=True) + return response_model.model_dump(mode="json") @console_ns.route("/apps//icon") @@ -524,16 +740,15 @@ class AppIconApi(Resource): @setup_required @login_required @account_initialization_required - @get_app_model - @marshal_with(app_detail_model) + @get_app_model(mode=None) @edit_permission_required def post(self, app_model): args = AppIconPayload.model_validate(console_ns.payload or {}) app_service = AppService() app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "") - - return app_model + response_model = AppDetail.model_validate(app_model, from_attributes=True) + return response_model.model_dump(mode="json") @console_ns.route("/apps//site-enable") @@ -542,21 +757,20 @@ class AppSiteStatus(Resource): @console_ns.doc(description="Enable or disable app site") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[AppSiteStatusPayload.__name__]) - @console_ns.response(200, "Site status updated successfully", app_detail_model) + @console_ns.response(200, "Site status updated successfully", console_ns.models[AppDetail.__name__]) @console_ns.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required - @get_app_model - @marshal_with(app_detail_model) + @get_app_model(mode=None) @edit_permission_required def post(self, app_model): args = AppSiteStatusPayload.model_validate(console_ns.payload) app_service = AppService() app_model = app_service.update_app_site_status(app_model, args.enable_site) - - return app_model + response_model = AppDetail.model_validate(app_model, from_attributes=True) + return response_model.model_dump(mode="json") @console_ns.route("/apps//api-enable") @@ -565,21 +779,20 @@ class AppApiStatus(Resource): @console_ns.doc(description="Enable or disable app API") 
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[AppApiStatusPayload.__name__]) - @console_ns.response(200, "API status updated successfully", app_detail_model) + @console_ns.response(200, "API status updated successfully", console_ns.models[AppDetail.__name__]) @console_ns.response(403, "Insufficient permissions") @setup_required @login_required @is_admin_or_owner_required @account_initialization_required - @get_app_model - @marshal_with(app_detail_model) + @get_app_model(mode=None) def post(self, app_model): args = AppApiStatusPayload.model_validate(console_ns.payload) app_service = AppService() app_model = app_service.update_app_api_status(app_model, args.enable_api) - - return app_model + response_model = AppDetail.model_validate(app_model, from_attributes=True) + return response_model.model_dump(mode="json") @console_ns.route("/apps//trace") diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index c16dcfd91f..56816dd462 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -13,7 +13,6 @@ from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db -from fields.conversation_fields import MessageTextField from fields.raws import FilesContainedField from libs.datetime_utils import naive_utc_now, parse_time_range from libs.helper import TimestampField @@ -177,6 +176,12 @@ annotation_hit_history_model = console_ns.model( }, ) + +class MessageTextField(fields.Raw): + def format(self, value): + return value[0]["text"] if value else "" + + # Simple message detail model simple_message_detail_model = console_ns.model( "SimpleMessageDetail", @@ -343,10 +348,13 @@ class CompletionConversationApi(Resource): ) if args.keyword: + from libs.helper import escape_like_pattern + + escaped_keyword = escape_like_pattern(args.keyword) query = query.join(Message, Message.conversation_id == Conversation.id).where( or_( - Message.query.ilike(f"%{args.keyword}%"), - Message.answer.ilike(f"%{args.keyword}%"), + Message.query.ilike(f"%{escaped_keyword}%", escape="\\"), + Message.answer.ilike(f"%{escaped_keyword}%", escape="\\"), ) ) @@ -455,7 +463,10 @@ class ChatConversationApi(Resource): query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False)) if args.keyword: - keyword_filter = f"%{args.keyword}%" + from libs.helper import escape_like_pattern + + escaped_keyword = escape_like_pattern(args.keyword) + keyword_filter = f"%{escaped_keyword}%" query = ( query.join( Message, @@ -464,11 +475,11 @@ class ChatConversationApi(Resource): .join(subquery, subquery.c.conversation_id == Conversation.id) .where( or_( - Message.query.ilike(keyword_filter), - Message.answer.ilike(keyword_filter), - Conversation.name.ilike(keyword_filter), - Conversation.introduction.ilike(keyword_filter), - subquery.c.from_end_user_session_id.ilike(keyword_filter), + Message.query.ilike(keyword_filter, escape="\\"), + Message.answer.ilike(keyword_filter, escape="\\"), + Conversation.name.ilike(keyword_filter, escape="\\"), + Conversation.introduction.ilike(keyword_filter, escape="\\"), + subquery.c.from_end_user_session_id.ilike(keyword_filter, escape="\\"), ), ) .group_by(Conversation.id) diff --git a/api/controllers/console/auth/login.py 
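`escape_like_pattern` itself lives in `libs.helper` and is not part of this hunk; a plausible shape is sketched below only to make the `ilike(..., escape="\\")` calls above concrete, so treat the implementation as an assumption:

```python
# Hypothetical sketch: escape the escape character first, then the LIKE wildcards.
import sqlalchemy as sa


def escape_like_pattern(keyword: str) -> str:
    """Make user input match literally inside a LIKE/ILIKE pattern."""
    return keyword.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")


def keyword_condition(column: sa.ColumnElement[str], keyword: str) -> sa.ColumnElement[bool]:
    """Build the same kind of filter the conversation queries use above."""
    return column.ilike(f"%{escape_like_pattern(keyword)}%", escape="\\")
```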
b/api/controllers/console/auth/login.py index 772d98822e..4a52bf8abe 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -1,3 +1,5 @@ +from typing import Any + import flask_login from flask import make_response, request from flask_restx import Resource @@ -96,14 +98,13 @@ class LoginApi(Resource): if is_login_error_rate_limit: raise EmailPasswordLoginLimitError() - # TODO: why invitation is re-assigned with different type? - invitation = args.invite_token # type: ignore - if invitation: - invitation = RegisterService.get_invitation_if_token_valid(None, args.email, invitation) # type: ignore + invitation_data: dict[str, Any] | None = None + if args.invite_token: + invitation_data = RegisterService.get_invitation_if_token_valid(None, args.email, args.invite_token) try: - if invitation: - data = invitation.get("data", {}) # type: ignore + if invitation_data: + data = invitation_data.get("data", {}) invitee_email = data.get("email") if data else None if invitee_email != args.email: raise InvalidEmailError() diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 7ad1e56373..c20e83d36f 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -124,7 +124,7 @@ class OAuthCallback(Resource): return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}") try: - account = _generate_account(provider, user_info) + account, oauth_new_user = _generate_account(provider, user_info) except AccountNotFoundError: return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.") except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError): @@ -159,7 +159,10 @@ class OAuthCallback(Resource): ip_address=extract_remote_ip(request), ) - response = redirect(f"{dify_config.CONSOLE_WEB_URL}") + base_url = dify_config.CONSOLE_WEB_URL + query_char = "&" if "?" in base_url else "?" + target_url = f"{base_url}{query_char}oauth_new_user={str(oauth_new_user).lower()}" + response = redirect(target_url) set_access_token_to_cookie(request, response, token_pair.access_token) set_refresh_token_to_cookie(request, response, token_pair.refresh_token) @@ -177,9 +180,10 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> return account -def _generate_account(provider: str, user_info: OAuthUserInfo): +def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account, bool]: # Get account by openid or email. 
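The OAuth callback now reports whether the login created a brand-new account, and the separator check keeps any query string already present on `CONSOLE_WEB_URL` intact. The URL-building step in isolation; the example URLs are placeholders:

```python
# Append oauth_new_user with the correct separator.
def build_redirect(base_url: str, oauth_new_user: bool) -> str:
    query_char = "&" if "?" in base_url else "?"
    return f"{base_url}{query_char}oauth_new_user={str(oauth_new_user).lower()}"


assert build_redirect("https://cloud.example.com", True) == "https://cloud.example.com?oauth_new_user=true"
assert build_redirect("https://cloud.example.com?lang=en", False).endswith("&oauth_new_user=false")
```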
account = _get_account_by_openid_or_email(provider, user_info) + oauth_new_user = False if account: tenants = TenantService.get_join_tenants(account) @@ -193,6 +197,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): tenant_was_created.send(new_tenant) if not account: + oauth_new_user = True if not FeatureService.get_system_features().is_allow_register: if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email): raise AccountRegisterError( @@ -220,4 +225,4 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): # Link account AccountService.link_account_integrate(provider, user_info.id, account) - return account + return account, oauth_new_user diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index e94768f985..ac78d3854b 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -751,12 +751,12 @@ class DocumentApi(DocumentResource): elif metadata == "without": dataset_process_rules = DatasetService.get_process_rules(dataset_id) document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {} - data_source_info = document.data_source_detail_dict response = { "id": document.id, "position": document.position, "data_source_type": document.data_source_type, - "data_source_info": data_source_info, + "data_source_info": document.data_source_info_dict, + "data_source_detail_dict": document.data_source_detail_dict, "dataset_process_rule_id": document.dataset_process_rule_id, "dataset_process_rule": dataset_process_rules, "document_process_rule": document_process_rules, @@ -784,12 +784,12 @@ class DocumentApi(DocumentResource): else: dataset_process_rules = DatasetService.get_process_rules(dataset_id) document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {} - data_source_info = document.data_source_detail_dict response = { "id": document.id, "position": document.position, "data_source_type": document.data_source_type, - "data_source_info": data_source_info, + "data_source_info": document.data_source_info_dict, + "data_source_detail_dict": document.data_source_detail_dict, "dataset_process_rule_id": document.dataset_process_rule_id, "dataset_process_rule": dataset_process_rules, "document_process_rule": document_process_rules, diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index e73abc2555..16fecb41c6 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -3,10 +3,12 @@ import uuid from flask import request from flask_restx import Resource, marshal from pydantic import BaseModel, Field -from sqlalchemy import select +from sqlalchemy import String, cast, func, or_, select +from sqlalchemy.dialects.postgresql import JSONB from werkzeug.exceptions import Forbidden, NotFound import services +from configs import dify_config from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import ProviderNotInitializeError @@ -28,6 +30,7 @@ from core.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.segment_fields import child_chunk_fields, segment_fields +from libs.helper import escape_like_pattern 
from libs.login import current_account_with_tenant, login_required from models.dataset import ChildChunk, DocumentSegment from models.model import UploadFile @@ -143,7 +146,31 @@ class DatasetDocumentSegmentListApi(Resource): query = query.where(DocumentSegment.hit_count >= hit_count_gte) if keyword: - query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) + # Escape special characters in keyword to prevent SQL injection via LIKE wildcards + escaped_keyword = escape_like_pattern(keyword) + # Search in both content and keywords fields + # Use database-specific methods for JSON array search + if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql": + # PostgreSQL: Use jsonb_array_elements_text to properly handle Unicode/Chinese text + keywords_condition = func.array_to_string( + func.array( + select(func.jsonb_array_elements_text(cast(DocumentSegment.keywords, JSONB))) + .correlate(DocumentSegment) + .scalar_subquery() + ), + ",", + ).ilike(f"%{escaped_keyword}%", escape="\\") + else: + # MySQL: Cast JSON to string for pattern matching + # MySQL stores Chinese text directly in JSON without Unicode escaping + keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{escaped_keyword}%", escape="\\") + + query = query.where( + or_( + DocumentSegment.content.ilike(f"%{escaped_keyword}%", escape="\\"), + keywords_condition, + ) + ) if args.enabled.lower() != "all": if args.enabled.lower() == "true": diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index db7c50f422..db1a874437 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -1,7 +1,7 @@ import logging from typing import Any -from flask_restx import marshal, reqparse +from flask_restx import marshal from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -56,15 +56,10 @@ class DatasetsHitTestingBase: HitTestingService.hit_testing_args_check(args) @staticmethod - def parse_args(): - parser = ( - reqparse.RequestParser() - .add_argument("query", type=str, required=False, location="json") - .add_argument("attachment_ids", type=list, required=False, location="json") - .add_argument("retrieval_model", type=dict, required=False, location="json") - .add_argument("external_retrieval_model", type=dict, required=False, location="json") - ) - return parser.parse_args() + def parse_args(payload: dict[str, Any]) -> dict[str, Any]: + """Validate and return hit-testing arguments from an incoming payload.""" + hit_testing_payload = HitTestingPayload.model_validate(payload or {}) + return hit_testing_payload.model_dump(exclude_none=True) @staticmethod def perform_hit_testing(dataset, args): diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 46d67f0581..02efc54eea 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -355,7 +355,7 @@ class PublishedRagPipelineRunApi(Resource): pipeline=pipeline, user=current_user, args=args, - invoke_from=InvokeFrom.DEBUGGER if payload.is_preview else InvokeFrom.PUBLISHED, + invoke_from=InvokeFrom.DEBUGGER if payload.is_preview else InvokeFrom.PUBLISHED_PIPELINE, streaming=streaming, ) diff --git a/api/controllers/console/explore/conversation.py 
b/api/controllers/console/explore/conversation.py index 51995b8b8a..933c80f509 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -1,8 +1,7 @@ from typing import Any from flask import request -from flask_restx import marshal_with -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field, TypeAdapter, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound @@ -11,7 +10,11 @@ from controllers.console.explore.error import NotChatAppError from controllers.console.explore.wraps import InstalledAppResource from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db -from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields +from fields.conversation_fields import ( + ConversationInfiniteScrollPagination, + ResultResponse, + SimpleConversation, +) from libs.helper import UUIDStrOrEmpty from libs.login import current_user from models import Account @@ -49,7 +52,6 @@ register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayl endpoint="installed_app_conversations", ) class ConversationListApi(InstalledAppResource): - @marshal_with(conversation_infinite_scroll_pagination_fields) @console_ns.expect(console_ns.models[ConversationListQuery.__name__]) def get(self, installed_app): app_model = installed_app.app @@ -73,7 +75,7 @@ class ConversationListApi(InstalledAppResource): if not isinstance(current_user, Account): raise ValueError("current_user must be an Account instance") with Session(db.engine) as session: - return WebConversationService.pagination_by_last_id( + pagination = WebConversationService.pagination_by_last_id( session=session, app_model=app_model, user=current_user, @@ -82,6 +84,13 @@ class ConversationListApi(InstalledAppResource): invoke_from=InvokeFrom.EXPLORE, pinned=args.pinned, ) + adapter = TypeAdapter(SimpleConversation) + conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data] + return ConversationInfiniteScrollPagination( + limit=pagination.limit, + has_more=pagination.has_more, + data=conversations, + ).model_dump(mode="json") except LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") @@ -105,7 +114,7 @@ class ConversationApi(InstalledAppResource): except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - return {"result": "success"}, 204 + return ResultResponse(result="success").model_dump(mode="json"), 204 @console_ns.route( @@ -113,7 +122,6 @@ class ConversationApi(InstalledAppResource): endpoint="installed_app_conversation_rename", ) class ConversationRenameApi(InstalledAppResource): - @marshal_with(simple_conversation_fields) @console_ns.expect(console_ns.models[ConversationRenamePayload.__name__]) def post(self, installed_app, c_id): app_model = installed_app.app @@ -128,9 +136,14 @@ class ConversationRenameApi(InstalledAppResource): try: if not isinstance(current_user, Account): raise ValueError("current_user must be an Account instance") - return ConversationService.rename( + conversation = ConversationService.rename( app_model, conversation_id, current_user, payload.name, payload.auto_generate ) + return ( + TypeAdapter(SimpleConversation) + .validate_python(conversation, from_attributes=True) + .model_dump(mode="json") + ) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -155,7 +168,7 @@ class 
ConversationPinApi(InstalledAppResource): except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - return {"result": "success"} + return ResultResponse(result="success").model_dump(mode="json") @console_ns.route( @@ -174,4 +187,4 @@ class ConversationUnPinApi(InstalledAppResource): raise ValueError("current_user must be an Account instance") WebConversationService.unpin(app_model, conversation_id, current_user) - return {"result": "success"} + return ResultResponse(result="success").model_dump(mode="json") diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index d596d60b36..88487ac96f 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -2,8 +2,7 @@ import logging from typing import Literal from flask import request -from flask_restx import marshal_with -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, TypeAdapter from werkzeug.exceptions import InternalServerError, NotFound from controllers.common.schema import register_schema_models @@ -23,7 +22,8 @@ from controllers.console.explore.wraps import InstalledAppResource from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from fields.message_fields import message_infinite_scroll_pagination_fields +from fields.conversation_fields import ResultResponse +from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse from libs import helper from libs.helper import UUIDStrOrEmpty from libs.login import current_account_with_tenant @@ -66,7 +66,6 @@ register_schema_models(console_ns, MessageListQuery, MessageFeedbackPayload, Mor endpoint="installed_app_messages", ) class MessageListApi(InstalledAppResource): - @marshal_with(message_infinite_scroll_pagination_fields) @console_ns.expect(console_ns.models[MessageListQuery.__name__]) def get(self, installed_app): current_user, _ = current_account_with_tenant() @@ -78,13 +77,20 @@ class MessageListApi(InstalledAppResource): args = MessageListQuery.model_validate(request.args.to_dict()) try: - return MessageService.pagination_by_first_id( + pagination = MessageService.pagination_by_first_id( app_model, current_user, str(args.conversation_id), str(args.first_id) if args.first_id else None, args.limit, ) + adapter = TypeAdapter(MessageListItem) + items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data] + return MessageInfiniteScrollPagination( + limit=pagination.limit, + has_more=pagination.has_more, + data=items, + ).model_dump(mode="json") except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except FirstMessageNotExistsError: @@ -116,7 +122,7 @@ class MessageFeedbackApi(InstalledAppResource): except MessageNotExistsError: raise NotFound("Message Not Exists.") - return {"result": "success"} + return ResultResponse(result="success").model_dump(mode="json") @console_ns.route( @@ -201,4 +207,4 @@ class MessageSuggestedQuestionApi(InstalledAppResource): logger.exception("internal server error.") raise InternalServerError() - return {"data": questions} + return SuggestedQuestionsResponse(data=questions).model_dump(mode="json") diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index 9c6b2aedfb..660a4d5aea 100644 --- 
a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -1,5 +1,3 @@ -from flask_restx import marshal_with - from controllers.common import fields from controllers.console import console_ns from controllers.console.app.error import AppUnavailableError @@ -13,7 +11,6 @@ from services.app_service import AppService class AppParameterApi(InstalledAppResource): """Resource for app variables.""" - @marshal_with(fields.parameters_fields) def get(self, installed_app: InstalledApp): """Retrieve app parameters.""" app_model = installed_app.app @@ -37,7 +34,8 @@ class AppParameterApi(InstalledAppResource): user_input_form = features_dict.get("user_input_form", []) - return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) + parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) + return fields.Parameters.model_validate(parameters).model_dump(mode="json") @console_ns.route("/installed-apps//meta", endpoint="installed_app_meta") diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index bc7b8e7651..ea3de91741 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -1,14 +1,14 @@ from flask import request -from flask_restx import fields, marshal_with -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, TypeAdapter from werkzeug.exceptions import NotFound from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.explore.error import NotCompletionAppError from controllers.console.explore.wraps import InstalledAppResource -from fields.conversation_fields import message_file_fields -from libs.helper import TimestampField, UUIDStrOrEmpty +from fields.conversation_fields import ResultResponse +from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem +from libs.helper import UUIDStrOrEmpty from libs.login import current_account_with_tenant from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService @@ -26,28 +26,8 @@ class SavedMessageCreatePayload(BaseModel): register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload) -feedback_fields = {"rating": fields.String} - -message_fields = { - "id": fields.String, - "inputs": fields.Raw, - "query": fields.String, - "answer": fields.String, - "message_files": fields.List(fields.Nested(message_file_fields)), - "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), - "created_at": TimestampField, -} - - @console_ns.route("/installed-apps//saved-messages", endpoint="installed_app_saved_messages") class SavedMessageListApi(InstalledAppResource): - saved_message_infinite_scroll_pagination_fields = { - "limit": fields.Integer, - "has_more": fields.Boolean, - "data": fields.List(fields.Nested(message_fields)), - } - - @marshal_with(saved_message_infinite_scroll_pagination_fields) @console_ns.expect(console_ns.models[SavedMessageListQuery.__name__]) def get(self, installed_app): current_user, _ = current_account_with_tenant() @@ -57,12 +37,19 @@ class SavedMessageListApi(InstalledAppResource): args = SavedMessageListQuery.model_validate(request.args.to_dict()) - return SavedMessageService.pagination_by_last_id( + pagination = SavedMessageService.pagination_by_last_id( 
app_model, current_user, str(args.last_id) if args.last_id else None, args.limit, ) + adapter = TypeAdapter(SavedMessageItem) + items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data] + return SavedMessageInfiniteScrollPagination( + limit=pagination.limit, + has_more=pagination.has_more, + data=items, + ).model_dump(mode="json") @console_ns.expect(console_ns.models[SavedMessageCreatePayload.__name__]) def post(self, installed_app): @@ -78,7 +65,7 @@ class SavedMessageListApi(InstalledAppResource): except MessageNotExistsError: raise NotFound("Message Not Exists.") - return {"result": "success"} + return ResultResponse(result="success").model_dump(mode="json") @console_ns.route( @@ -96,4 +83,4 @@ class SavedMessageApi(InstalledAppResource): SavedMessageService.delete(app_model, current_user, message_id) - return {"result": "success"}, 204 + return ResultResponse(result="success").model_dump(mode="json"), 204 diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py index 29417dc896..109a3cd0d3 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -1,7 +1,7 @@ from typing import Literal from flask import request -from flask_restx import Resource, marshal_with +from flask_restx import Resource from werkzeug.exceptions import Forbidden import services @@ -15,18 +15,21 @@ from controllers.common.errors import ( TooManyFilesError, UnsupportedFileTypeError, ) +from controllers.common.schema import register_schema_models from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_resource_check, setup_required, ) from extensions.ext_database import db -from fields.file_fields import file_fields, upload_config_fields +from fields.file_fields import FileResponse, UploadConfig from libs.login import current_account_with_tenant, login_required from services.file_service import FileService from . 
import console_ns +register_schema_models(console_ns, UploadConfig, FileResponse) + PREVIEW_WORDS_LIMIT = 3000 @@ -35,26 +38,27 @@ class FileApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(upload_config_fields) + @console_ns.response(200, "Success", console_ns.models[UploadConfig.__name__]) def get(self): - return { - "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT, - "batch_count_limit": dify_config.UPLOAD_FILE_BATCH_LIMIT, - "file_upload_limit": dify_config.BATCH_UPLOAD_LIMIT, - "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT, - "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT, - "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT, - "workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT, - "image_file_batch_limit": dify_config.IMAGE_FILE_BATCH_LIMIT, - "single_chunk_attachment_limit": dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT, - "attachment_image_file_size_limit": dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT, - }, 200 + config = UploadConfig( + file_size_limit=dify_config.UPLOAD_FILE_SIZE_LIMIT, + batch_count_limit=dify_config.UPLOAD_FILE_BATCH_LIMIT, + file_upload_limit=dify_config.BATCH_UPLOAD_LIMIT, + image_file_size_limit=dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT, + video_file_size_limit=dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT, + audio_file_size_limit=dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT, + workflow_file_upload_limit=dify_config.WORKFLOW_FILE_UPLOAD_LIMIT, + image_file_batch_limit=dify_config.IMAGE_FILE_BATCH_LIMIT, + single_chunk_attachment_limit=dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT, + attachment_image_file_size_limit=dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT, + ) + return config.model_dump(mode="json"), 200 @setup_required @login_required @account_initialization_required - @marshal_with(file_fields) @cloud_edition_billing_resource_check("documents") + @console_ns.response(201, "File uploaded successfully", console_ns.models[FileResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() source_str = request.form.get("source") @@ -90,7 +94,8 @@ class FileApi(Resource): except services.errors.file.BlockedFileExtensionError as blocked_extension_error: raise BlockedFileExtensionError(blocked_extension_error.description) - return upload_file, 201 + response = FileResponse.model_validate(upload_file, from_attributes=True) + return response.model_dump(mode="json"), 201 @console_ns.route("/files//preview") diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index 47eef7eb7e..70c7b80ffa 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -1,7 +1,7 @@ import urllib.parse import httpx -from flask_restx import Resource, marshal_with +from flask_restx import Resource from pydantic import BaseModel, Field import services @@ -11,19 +11,22 @@ from controllers.common.errors import ( RemoteFileUploadError, UnsupportedFileTypeError, ) +from controllers.common.schema import register_schema_models from core.file import helpers as file_helpers from core.helper import ssrf_proxy from extensions.ext_database import db -from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields +from fields.file_fields import FileWithSignedUrl, RemoteFileInfo from libs.login import current_account_with_tenant from services.file_service import FileService from . 
import console_ns +register_schema_models(console_ns, RemoteFileInfo, FileWithSignedUrl) + @console_ns.route("/remote-files/") class RemoteFileInfoApi(Resource): - @marshal_with(remote_file_info_fields) + @console_ns.response(200, "Remote file info", console_ns.models[RemoteFileInfo.__name__]) def get(self, url): decoded_url = urllib.parse.unquote(url) resp = ssrf_proxy.head(decoded_url) @@ -31,10 +34,11 @@ class RemoteFileInfoApi(Resource): # failed back to get method resp = ssrf_proxy.get(decoded_url, timeout=3) resp.raise_for_status() - return { - "file_type": resp.headers.get("Content-Type", "application/octet-stream"), - "file_length": int(resp.headers.get("Content-Length", 0)), - } + info = RemoteFileInfo( + file_type=resp.headers.get("Content-Type", "application/octet-stream"), + file_length=int(resp.headers.get("Content-Length", 0)), + ) + return info.model_dump(mode="json") class RemoteFileUploadPayload(BaseModel): @@ -50,7 +54,7 @@ console_ns.schema_model( @console_ns.route("/remote-files/upload") class RemoteFileUploadApi(Resource): @console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__]) - @marshal_with(file_fields_with_signed_url) + @console_ns.response(201, "Remote file uploaded", console_ns.models[FileWithSignedUrl.__name__]) def post(self): args = RemoteFileUploadPayload.model_validate(console_ns.payload) url = args.url @@ -85,13 +89,14 @@ class RemoteFileUploadApi(Resource): except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() - return { - "id": upload_file.id, - "name": upload_file.name, - "size": upload_file.size, - "extension": upload_file.extension, - "url": file_helpers.get_signed_file_url(upload_file_id=upload_file.id), - "mime_type": upload_file.mime_type, - "created_by": upload_file.created_by, - "created_at": upload_file.created_at, - }, 201 + payload = FileWithSignedUrl( + id=upload_file.id, + name=upload_file.name, + size=upload_file.size, + extension=upload_file.extension, + url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id), + mime_type=upload_file.mime_type, + created_by=upload_file.created_by, + created_at=int(upload_file.created_at.timestamp()), + ) + return payload.model_dump(mode="json"), 201 diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 55eaa2f09f..03ad0f423b 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime from typing import Literal @@ -99,7 +101,7 @@ class AccountPasswordPayload(BaseModel): repeat_new_password: str @model_validator(mode="after") - def check_passwords_match(self) -> "AccountPasswordPayload": + def check_passwords_match(self) -> AccountPasswordPayload: if self.new_password != self.repeat_new_password: raise RepeatPasswordNotMatchError() return self diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py index 497e62b790..c13bfd986e 100644 --- a/api/controllers/console/workspace/trigger_providers.py +++ b/api/controllers/console/workspace/trigger_providers.py @@ -4,12 +4,11 @@ from typing import Any from flask import make_response, redirect, request from flask_restx import Resource, reqparse -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, Forbidden from configs import 
dify_config -from constants import HIDDEN_VALUE, UNKNOWN_VALUE from controllers.web.error import NotFoundError from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon import CredentialType @@ -44,6 +43,12 @@ class TriggerSubscriptionUpdateRequest(BaseModel): parameters: Mapping[str, Any] | None = Field(default=None, description="The parameters for the subscription") properties: Mapping[str, Any] | None = Field(default=None, description="The properties for the subscription") + @model_validator(mode="after") + def check_at_least_one_field(self): + if all(v is None for v in (self.name, self.credentials, self.parameters, self.properties)): + raise ValueError("At least one of name, credentials, parameters, or properties must be provided") + return self + class TriggerSubscriptionVerifyRequest(BaseModel): """Request payload for verifying subscription credentials.""" @@ -333,7 +338,7 @@ class TriggerSubscriptionUpdateApi(Resource): user = current_user assert user.current_tenant_id is not None - args = TriggerSubscriptionUpdateRequest.model_validate(console_ns.payload) + request = TriggerSubscriptionUpdateRequest.model_validate(console_ns.payload) subscription = TriggerProviderService.get_subscription_by_id( tenant_id=user.current_tenant_id, @@ -345,50 +350,32 @@ class TriggerSubscriptionUpdateApi(Resource): provider_id = TriggerProviderID(subscription.provider_id) try: - # rename only - if ( - args.name is not None - and args.credentials is None - and args.parameters is None - and args.properties is None - ): + # For a rename-only request, just update the name + rename = request.name is not None and not any((request.credentials, request.parameters, request.properties)) + # When credential type is UNAUTHORIZED, it indicates the subscription was manually created + # Manually created subscriptions don't have credentials or parameters; + # they only have a name and properties (provided by the user) + manually_created = subscription.credential_type == CredentialType.UNAUTHORIZED + if rename or manually_created: TriggerProviderService.update_trigger_subscription( tenant_id=user.current_tenant_id, subscription_id=subscription_id, - name=args.name, + name=request.name, + properties=request.properties, ) return 200 - # rebuild for create automatically by the provider - match subscription.credential_type: - case CredentialType.UNAUTHORIZED: - TriggerProviderService.update_trigger_subscription( - tenant_id=user.current_tenant_id, - subscription_id=subscription_id, - name=args.name, - properties=args.properties, - ) - return 200 - case CredentialType.API_KEY | CredentialType.OAUTH2: - if args.credentials: - new_credentials: dict[str, Any] = { - key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE) - for key, value in args.credentials.items() - } - else: - new_credentials = subscription.credentials - - TriggerProviderService.rebuild_trigger_subscription( - tenant_id=user.current_tenant_id, - name=args.name, - provider_id=provider_id, - subscription_id=subscription_id, - credentials=new_credentials, - parameters=args.parameters or subscription.parameters, - ) - return 200 - case _: - raise BadRequest("Invalid credential type") + # For the remaining cases (API_KEY, OAUTH2), + # we need to call the third-party provider (e.g. 
GitHub) to rebuild the subscription + TriggerProviderService.rebuild_trigger_subscription( + tenant_id=user.current_tenant_id, + name=request.name, + provider_id=provider_id, + subscription_id=subscription_id, + credentials=request.credentials or subscription.credentials, + parameters=request.parameters or subscription.parameters, + ) + return 200 except ValueError as e: raise BadRequest(str(e)) except Exception as e: diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py index 6096a87c56..28ec4b3935 100644 --- a/api/controllers/files/upload.py +++ b/api/controllers/files/upload.py @@ -4,18 +4,18 @@ from flask import request from flask_restx import Resource from flask_restx.api import HTTPStatus from pydantic import BaseModel, Field -from werkzeug.datastructures import FileStorage from werkzeug.exceptions import Forbidden import services from core.file.helpers import verify_plugin_file_signature from core.tools.tool_file_manager import ToolFileManager -from fields.file_fields import build_file_model +from fields.file_fields import FileResponse from ..common.errors import ( FileTooLargeError, UnsupportedFileTypeError, ) +from ..common.schema import register_schema_models from ..console.wraps import setup_required from ..files import files_ns from ..inner_api.plugin.wraps import get_user @@ -35,6 +35,8 @@ files_ns.schema_model( PluginUploadQuery.__name__, PluginUploadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) ) +register_schema_models(files_ns, FileResponse) + @files_ns.route("/upload/for-plugin") class PluginUploadFileApi(Resource): @@ -51,7 +53,7 @@ class PluginUploadFileApi(Resource): 415: "Unsupported file type", } ) - @files_ns.marshal_with(build_file_model(files_ns), code=HTTPStatus.CREATED) + @files_ns.response(HTTPStatus.CREATED, "File uploaded", files_ns.models[FileResponse.__name__]) def post(self): """Upload a file for plugin usage. 
@@ -69,7 +71,7 @@ class PluginUploadFileApi(Resource): """ args = PluginUploadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - file: FileStorage | None = request.files.get("file") + file = request.files.get("file") if file is None: raise Forbidden("File is required.") @@ -80,8 +82,8 @@ class PluginUploadFileApi(Resource): user_id = args.user_id user = get_user(tenant_id, user_id) - filename: str | None = file.filename - mimetype: str | None = file.mimetype + filename = file.filename + mimetype = file.mimetype if not filename or not mimetype: raise Forbidden("Invalid request.") @@ -111,22 +113,22 @@ class PluginUploadFileApi(Resource): preview_url = ToolFileManager.sign_file(tool_file_id=tool_file.id, extension=extension) # Create a dictionary with all the necessary attributes - result = { - "id": tool_file.id, - "user_id": tool_file.user_id, - "tenant_id": tool_file.tenant_id, - "conversation_id": tool_file.conversation_id, - "file_key": tool_file.file_key, - "mimetype": tool_file.mimetype, - "original_url": tool_file.original_url, - "name": tool_file.name, - "size": tool_file.size, - "mime_type": mimetype, - "extension": extension, - "preview_url": preview_url, - } + result = FileResponse( + id=tool_file.id, + name=tool_file.name, + size=tool_file.size, + extension=extension, + mime_type=mimetype, + preview_url=preview_url, + source_url=tool_file.original_url, + original_url=tool_file.original_url, + user_id=tool_file.user_id, + tenant_id=tool_file.tenant_id, + conversation_id=tool_file.conversation_id, + file_key=tool_file.file_key, + ) - return result, 201 + return result.model_dump(mode="json"), 201 except services.errors.file.FileTooLargeError as file_too_large_error: raise FileTooLargeError(file_too_large_error.description) except services.errors.file.UnsupportedFileTypeError: diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index 63c373b50f..85ac9336d6 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -1,7 +1,7 @@ from typing import Literal from flask import request -from flask_restx import Api, Namespace, Resource, fields +from flask_restx import Namespace, Resource, fields from flask_restx.api import HTTPStatus from pydantic import BaseModel, Field @@ -92,7 +92,7 @@ annotation_list_fields = { } -def build_annotation_list_model(api_or_ns: Api | Namespace): +def build_annotation_list_model(api_or_ns: Namespace): """Build the annotation list model for the API or Namespace.""" copied_annotation_list_fields = annotation_list_fields.copy() copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns))) diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index 25d7ccccec..562f5e33cc 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,6 +1,6 @@ from flask_restx import Resource -from controllers.common.fields import build_parameters_model +from controllers.common.fields import Parameters from controllers.service_api import service_api_ns from controllers.service_api.app.error import AppUnavailableError from controllers.service_api.wraps import validate_app_token @@ -23,7 +23,6 @@ class AppParameterApi(Resource): } ) @validate_app_token - @service_api_ns.marshal_with(build_parameters_model(service_api_ns)) def get(self, app_model: App): """Retrieve app parameters. 
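# Illustrative sketch of the serialization pattern used throughout these hunks: flask_restx's
# `marshal_with` is replaced by validating the source object into a Pydantic model and dumping it
# with `model_dump(mode="json")`. `DemoFileResponse` and `DemoUpload` below are hypothetical
# stand-ins (not names from this codebase), assuming Pydantic v2 semantics.
from pydantic import BaseModel

class DemoFileResponse(BaseModel):
    id: str
    name: str
    size: int

class DemoUpload:  # stands in for an ORM row such as an uploaded-file record
    def __init__(self, id: str, name: str, size: int):
        self.id, self.name, self.size = id, name, size

row = DemoUpload(id="f-1", name="report.pdf", size=1024)
# from_attributes=True lets Pydantic read plain object attributes instead of dict keys
body = DemoFileResponse.model_validate(row, from_attributes=True).model_dump(mode="json")
# body -> {"id": "f-1", "name": "report.pdf", "size": 1024}, a JSON-safe dict the view can return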
@@ -45,7 +44,8 @@ class AppParameterApi(Resource): user_input_form = features_dict.get("user_input_form", []) - return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) + parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) + return Parameters.model_validate(parameters).model_dump(mode="json") @service_api_ns.route("/meta") diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 40e4bde389..62e8258e25 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -3,8 +3,7 @@ from uuid import UUID from flask import request from flask_restx import Resource -from flask_restx._http import HTTPStatus -from pydantic import BaseModel, Field, field_validator, model_validator +from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, NotFound @@ -16,9 +15,9 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from fields.conversation_fields import ( - build_conversation_delete_model, - build_conversation_infinite_scroll_pagination_model, - build_simple_conversation_model, + ConversationDelete, + ConversationInfiniteScrollPagination, + SimpleConversation, ) from fields.conversation_variable_fields import ( build_conversation_variable_infinite_scroll_pagination_model, @@ -105,7 +104,6 @@ class ConversationApi(Resource): } ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) - @service_api_ns.marshal_with(build_conversation_infinite_scroll_pagination_model(service_api_ns)) def get(self, app_model: App, end_user: EndUser): """List all conversations for the current user. 
@@ -120,7 +118,7 @@ class ConversationApi(Resource): try: with Session(db.engine) as session: - return ConversationService.pagination_by_last_id( + pagination = ConversationService.pagination_by_last_id( session=session, app_model=app_model, user=end_user, @@ -129,6 +127,13 @@ class ConversationApi(Resource): invoke_from=InvokeFrom.SERVICE_API, sort_by=query_args.sort_by, ) + adapter = TypeAdapter(SimpleConversation) + conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data] + return ConversationInfiniteScrollPagination( + limit=pagination.limit, + has_more=pagination.has_more, + data=conversations, + ).model_dump(mode="json") except services.errors.conversation.LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") @@ -146,7 +151,6 @@ class ConversationDetailApi(Resource): } ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) - @service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=HTTPStatus.NO_CONTENT) def delete(self, app_model: App, end_user: EndUser, c_id): """Delete a specific conversation.""" app_mode = AppMode.value_of(app_model.mode) @@ -159,7 +163,7 @@ class ConversationDetailApi(Resource): ConversationService.delete(app_model, conversation_id, end_user) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - return {"result": "success"}, 204 + return ConversationDelete(result="success").model_dump(mode="json"), 204 @service_api_ns.route("/conversations//name") @@ -176,7 +180,6 @@ class ConversationRenameApi(Resource): } ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) - @service_api_ns.marshal_with(build_simple_conversation_model(service_api_ns)) def post(self, app_model: App, end_user: EndUser, c_id): """Rename a conversation or auto-generate a name.""" app_mode = AppMode.value_of(app_model.mode) @@ -188,7 +191,14 @@ class ConversationRenameApi(Resource): payload = ConversationRenamePayload.model_validate(service_api_ns.payload or {}) try: - return ConversationService.rename(app_model, conversation_id, end_user, payload.name, payload.auto_generate) + conversation = ConversationService.rename( + app_model, conversation_id, end_user, payload.name, payload.auto_generate + ) + return ( + TypeAdapter(SimpleConversation) + .validate_python(conversation, from_attributes=True) + .model_dump(mode="json") + ) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py index ffe4e0b492..6f6dadf768 100644 --- a/api/controllers/service_api/app/file.py +++ b/api/controllers/service_api/app/file.py @@ -10,13 +10,16 @@ from controllers.common.errors import ( TooManyFilesError, UnsupportedFileTypeError, ) +from controllers.common.schema import register_schema_models from controllers.service_api import service_api_ns from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from extensions.ext_database import db -from fields.file_fields import build_file_model +from fields.file_fields import FileResponse from models import App, EndUser from services.file_service import FileService +register_schema_models(service_api_ns, FileResponse) + @service_api_ns.route("/files/upload") class FileApi(Resource): @@ -31,8 +34,8 @@ class FileApi(Resource): 415: "Unsupported file type", } ) - 
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) - @service_api_ns.marshal_with(build_file_model(service_api_ns), code=HTTPStatus.CREATED) + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) # type: ignore + @service_api_ns.response(HTTPStatus.CREATED, "File uploaded", service_api_ns.models[FileResponse.__name__]) def post(self, app_model: App, end_user: EndUser): """Upload a file for use in conversations. @@ -64,4 +67,5 @@ class FileApi(Resource): except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() - return upload_file, 201 + response = FileResponse.model_validate(upload_file, from_attributes=True) + return response.model_dump(mode="json"), 201 diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index bb908a8fb1..8981bbd7d5 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -1,11 +1,10 @@ -import json import logging from typing import Literal from uuid import UUID from flask import request -from flask_restx import Namespace, Resource, fields -from pydantic import BaseModel, Field +from flask_restx import Resource +from pydantic import BaseModel, Field, TypeAdapter from werkzeug.exceptions import BadRequest, InternalServerError, NotFound import services @@ -14,10 +13,8 @@ from controllers.service_api import service_api_ns from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.app.entities.app_invoke_entities import InvokeFrom -from fields.conversation_fields import build_message_file_model -from fields.message_fields import build_agent_thought_model, build_feedback_model -from fields.raws import FilesContainedField -from libs.helper import TimestampField +from fields.conversation_fields import ResultResponse +from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem from models.model import App, AppMode, EndUser from services.errors.message import ( FirstMessageNotExistsError, @@ -48,50 +45,6 @@ class FeedbackListQuery(BaseModel): register_schema_models(service_api_ns, MessageListQuery, MessageFeedbackPayload, FeedbackListQuery) -def build_message_model(api_or_ns: Namespace): - """Build the message model for the API or Namespace.""" - # First build the nested models - feedback_model = build_feedback_model(api_or_ns) - agent_thought_model = build_agent_thought_model(api_or_ns) - message_file_model = build_message_file_model(api_or_ns) - - # Then build the message fields with nested models - message_fields = { - "id": fields.String, - "conversation_id": fields.String, - "parent_message_id": fields.String, - "inputs": FilesContainedField, - "query": fields.String, - "answer": fields.String(attribute="re_sign_file_url_answer"), - "message_files": fields.List(fields.Nested(message_file_model)), - "feedback": fields.Nested(feedback_model, attribute="user_feedback", allow_null=True), - "retriever_resources": fields.Raw( - attribute=lambda obj: json.loads(obj.message_metadata).get("retriever_resources", []) - if obj.message_metadata - else [] - ), - "created_at": TimestampField, - "agent_thoughts": fields.List(fields.Nested(agent_thought_model)), - "status": fields.String, - "error": fields.String, - "generation_detail": fields.Raw, - } - return api_or_ns.model("Message", message_fields) - - -def build_message_infinite_scroll_pagination_model(api_or_ns: Namespace): - 
"""Build the message infinite scroll pagination model for the API or Namespace.""" - # Build the nested message model first - message_model = build_message_model(api_or_ns) - - message_infinite_scroll_pagination_fields = { - "limit": fields.Integer, - "has_more": fields.Boolean, - "data": fields.List(fields.Nested(message_model)), - } - return api_or_ns.model("MessageInfiniteScrollPagination", message_infinite_scroll_pagination_fields) - - @service_api_ns.route("/messages") class MessageListApi(Resource): @service_api_ns.expect(service_api_ns.models[MessageListQuery.__name__]) @@ -105,7 +58,6 @@ class MessageListApi(Resource): } ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) - @service_api_ns.marshal_with(build_message_infinite_scroll_pagination_model(service_api_ns)) def get(self, app_model: App, end_user: EndUser): """List messages in a conversation. @@ -120,9 +72,16 @@ class MessageListApi(Resource): first_id = str(query_args.first_id) if query_args.first_id else None try: - return MessageService.pagination_by_first_id( + pagination = MessageService.pagination_by_first_id( app_model, end_user, conversation_id, first_id, query_args.limit ) + adapter = TypeAdapter(MessageListItem) + items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data] + return MessageInfiniteScrollPagination( + limit=pagination.limit, + has_more=pagination.has_more, + data=items, + ).model_dump(mode="json") except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except FirstMessageNotExistsError: @@ -163,7 +122,7 @@ class MessageFeedbackApi(Resource): except MessageNotExistsError: raise NotFound("Message Not Exists.") - return {"result": "success"} + return ResultResponse(result="success").model_dump(mode="json") @service_api_ns.route("/app/feedbacks") diff --git a/api/controllers/service_api/app/site.py b/api/controllers/service_api/app/site.py index 9f8324a84e..8b47a887bb 100644 --- a/api/controllers/service_api/app/site.py +++ b/api/controllers/service_api/app/site.py @@ -1,7 +1,7 @@ from flask_restx import Resource from werkzeug.exceptions import Forbidden -from controllers.common.fields import build_site_model +from controllers.common.fields import Site as SiteResponse from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_app_token from extensions.ext_database import db @@ -23,7 +23,6 @@ class AppSiteApi(Resource): } ) @validate_app_token - @service_api_ns.marshal_with(build_site_model(service_api_ns)) def get(self, app_model: App): """Retrieve app site info. 
@@ -38,4 +37,4 @@ class AppSiteApi(Resource): if app_model.tenant.status == TenantStatus.ARCHIVE: raise Forbidden() - return site + return SiteResponse.model_validate(site).model_dump(mode="json") diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 4964888fd6..6a549fc926 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -3,7 +3,7 @@ from typing import Any, Literal from dateutil.parser import isoparse from flask import request -from flask_restx import Api, Namespace, Resource, fields +from flask_restx import Namespace, Resource, fields from pydantic import BaseModel, Field from sqlalchemy.orm import Session, sessionmaker from werkzeug.exceptions import BadRequest, InternalServerError, NotFound @@ -78,7 +78,7 @@ workflow_run_fields = { } -def build_workflow_run_model(api_or_ns: Api | Namespace): +def build_workflow_run_model(api_or_ns: Namespace): """Build the workflow run model for the API or Namespace.""" return api_or_ns.model("WorkflowRun", workflow_run_fields) diff --git a/api/controllers/service_api/dataset/hit_testing.py b/api/controllers/service_api/dataset/hit_testing.py index d81287d56f..8dbb690901 100644 --- a/api/controllers/service_api/dataset/hit_testing.py +++ b/api/controllers/service_api/dataset/hit_testing.py @@ -24,7 +24,7 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): dataset_id_str = str(dataset_id) dataset = self.get_and_validate_dataset(dataset_id_str) - args = self.parse_args() + args = self.parse_args(service_api_ns.payload) self.hit_testing_args_check(args) return self.perform_hit_testing(dataset, args) diff --git a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py index 0a2017e2bd..70b5030237 100644 --- a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py @@ -174,7 +174,7 @@ class PipelineRunApi(DatasetApiResource): pipeline=pipeline, user=current_user, args=payload.model_dump(), - invoke_from=InvokeFrom.PUBLISHED if payload.is_published else InvokeFrom.DEBUGGER, + invoke_from=InvokeFrom.PUBLISHED_PIPELINE if payload.is_published else InvokeFrom.DEBUGGER, streaming=payload.response_mode == "streaming", ) diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index db3b93a4dc..62ea532eac 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,7 +1,7 @@ import logging from flask import request -from flask_restx import Resource, marshal_with +from flask_restx import Resource from pydantic import BaseModel, ConfigDict, Field from werkzeug.exceptions import Unauthorized @@ -50,7 +50,6 @@ class AppParameterApi(WebApiResource): 500: "Internal Server Error", } ) - @marshal_with(fields.parameters_fields) def get(self, app_model: App, end_user): """Retrieve app parameters.""" if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: @@ -69,7 +68,8 @@ class AppParameterApi(WebApiResource): user_input_form = features_dict.get("user_input_form", []) - return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) + parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) + return fields.Parameters.model_validate(parameters).model_dump(mode="json") @web_ns.route("/meta") diff --git 
a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index 86e19423e5..e76649495a 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -1,14 +1,21 @@ -from flask_restx import fields, marshal_with, reqparse -from flask_restx.inputs import int_range +from typing import Literal + +from flask import request +from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound +from controllers.common.schema import register_schema_models from controllers.web import web_ns from controllers.web.error import NotChatAppError from controllers.web.wraps import WebApiResource from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db -from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields +from fields.conversation_fields import ( + ConversationInfiniteScrollPagination, + ResultResponse, + SimpleConversation, +) from libs.helper import uuid_value from models.model import AppMode from services.conversation_service import ConversationService @@ -16,6 +23,35 @@ from services.errors.conversation import ConversationNotExistsError, LastConvers from services.web_conversation_service import WebConversationService +class ConversationListQuery(BaseModel): + last_id: str | None = None + limit: int = Field(default=20, ge=1, le=100) + pinned: bool | None = None + sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = "-updated_at" + + @field_validator("last_id") + @classmethod + def validate_last_id(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + +class ConversationRenamePayload(BaseModel): + name: str | None = None + auto_generate: bool = False + + @model_validator(mode="after") + def validate_name_requirement(self): + if not self.auto_generate: + if self.name is None or not self.name.strip(): + raise ValueError("name is required when auto_generate is false") + return self + + +register_schema_models(web_ns, ConversationListQuery, ConversationRenamePayload) + + @web_ns.route("/conversations") class ConversationListApi(WebApiResource): @web_ns.doc("Get Conversation List") @@ -54,54 +90,39 @@ class ConversationListApi(WebApiResource): 500: "Internal Server Error", } ) - @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, app_model, end_user): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = ( - reqparse.RequestParser() - .add_argument("last_id", type=uuid_value, location="args") - .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - .add_argument("pinned", type=str, choices=["true", "false", None], location="args") - .add_argument( - "sort_by", - type=str, - choices=["created_at", "-created_at", "updated_at", "-updated_at"], - required=False, - default="-updated_at", - location="args", - ) - ) - args = parser.parse_args() - - pinned = None - if "pinned" in args and args["pinned"] is not None: - pinned = args["pinned"] == "true" + raw_args = request.args.to_dict() + query = ConversationListQuery.model_validate(raw_args) try: with Session(db.engine) as session: - return WebConversationService.pagination_by_last_id( + pagination = WebConversationService.pagination_by_last_id( session=session, app_model=app_model, 
user=end_user, - last_id=args["last_id"], - limit=args["limit"], + last_id=query.last_id, + limit=query.limit, invoke_from=InvokeFrom.WEB_APP, - pinned=pinned, - sort_by=args["sort_by"], + pinned=query.pinned, + sort_by=query.sort_by, ) + adapter = TypeAdapter(SimpleConversation) + conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data] + return ConversationInfiniteScrollPagination( + limit=pagination.limit, + has_more=pagination.has_more, + data=conversations, + ).model_dump(mode="json") except LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") @web_ns.route("/conversations/") class ConversationApi(WebApiResource): - delete_response_fields = { - "result": fields.String, - } - @web_ns.doc("Delete Conversation") @web_ns.doc(description="Delete a specific conversation.") @web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}}) @@ -115,7 +136,6 @@ class ConversationApi(WebApiResource): 500: "Internal Server Error", } ) - @marshal_with(delete_response_fields) def delete(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: @@ -126,7 +146,7 @@ class ConversationApi(WebApiResource): ConversationService.delete(app_model, conversation_id, end_user) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - return {"result": "success"}, 204 + return ResultResponse(result="success").model_dump(mode="json"), 204 @web_ns.route("/conversations//name") @@ -155,7 +175,6 @@ class ConversationRenameApi(WebApiResource): 500: "Internal Server Error", } ) - @marshal_with(simple_conversation_fields) def post(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: @@ -163,25 +182,23 @@ class ConversationRenameApi(WebApiResource): conversation_id = str(c_id) - parser = ( - reqparse.RequestParser() - .add_argument("name", type=str, required=False, location="json") - .add_argument("auto_generate", type=bool, required=False, default=False, location="json") - ) - args = parser.parse_args() + payload = ConversationRenamePayload.model_validate(web_ns.payload or {}) try: - return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"]) + conversation = ConversationService.rename( + app_model, conversation_id, end_user, payload.name, payload.auto_generate + ) + return ( + TypeAdapter(SimpleConversation) + .validate_python(conversation, from_attributes=True) + .model_dump(mode="json") + ) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @web_ns.route("/conversations//pin") class ConversationPinApi(WebApiResource): - pin_response_fields = { - "result": fields.String, - } - @web_ns.doc("Pin Conversation") @web_ns.doc(description="Pin a specific conversation to keep it at the top of the list.") @web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}}) @@ -195,7 +212,6 @@ class ConversationPinApi(WebApiResource): 500: "Internal Server Error", } ) - @marshal_with(pin_response_fields) def patch(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: @@ -208,15 +224,11 @@ class ConversationPinApi(WebApiResource): except ConversationNotExistsError: raise NotFound("Conversation Not 
Exists.") - return {"result": "success"} + return ResultResponse(result="success").model_dump(mode="json") @web_ns.route("/conversations//unpin") class ConversationUnPinApi(WebApiResource): - unpin_response_fields = { - "result": fields.String, - } - @web_ns.doc("Unpin Conversation") @web_ns.doc(description="Unpin a specific conversation to remove it from the top of the list.") @web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}}) @@ -230,7 +242,6 @@ class ConversationUnPinApi(WebApiResource): 500: "Internal Server Error", } ) - @marshal_with(unpin_response_fields) def patch(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: @@ -239,4 +250,4 @@ class ConversationUnPinApi(WebApiResource): conversation_id = str(c_id) WebConversationService.unpin(app_model, conversation_id, end_user) - return {"result": "success"} + return ResultResponse(result="success").model_dump(mode="json") diff --git a/api/controllers/web/files.py b/api/controllers/web/files.py index 80ad61e549..0036c90800 100644 --- a/api/controllers/web/files.py +++ b/api/controllers/web/files.py @@ -1,5 +1,4 @@ from flask import request -from flask_restx import marshal_with import services from controllers.common.errors import ( @@ -9,12 +8,15 @@ from controllers.common.errors import ( TooManyFilesError, UnsupportedFileTypeError, ) +from controllers.common.schema import register_schema_models from controllers.web import web_ns from controllers.web.wraps import WebApiResource from extensions.ext_database import db -from fields.file_fields import build_file_model +from fields.file_fields import FileResponse from services.file_service import FileService +register_schema_models(web_ns, FileResponse) + @web_ns.route("/files/upload") class FileApi(WebApiResource): @@ -28,7 +30,7 @@ class FileApi(WebApiResource): 415: "Unsupported file type", } ) - @marshal_with(build_file_model(web_ns)) + @web_ns.response(201, "File uploaded successfully", web_ns.models[FileResponse.__name__]) def post(self, app_model, end_user): """Upload a file for use in web applications. 
@@ -81,4 +83,5 @@ class FileApi(WebApiResource): except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() - return upload_file, 201 + response = FileResponse.model_validate(upload_file, from_attributes=True) + return response.model_dump(mode="json"), 201 diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 51ce024a5b..80035ba818 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -2,8 +2,7 @@ import logging from typing import Literal from flask import request -from flask_restx import fields, marshal_with -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, TypeAdapter, field_validator from werkzeug.exceptions import InternalServerError, NotFound from controllers.common.schema import register_schema_models @@ -22,11 +21,10 @@ from controllers.web.wraps import WebApiResource from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from fields.conversation_fields import message_file_fields -from fields.message_fields import agent_thought_fields, feedback_fields, retriever_resource_fields -from fields.raws import FilesContainedField +from fields.conversation_fields import ResultResponse +from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem from libs import helper -from libs.helper import TimestampField, uuid_value +from libs.helper import uuid_value from models.model import AppMode from services.app_generate_service import AppGenerateService from services.errors.app import MoreLikeThisDisabledError @@ -70,30 +68,6 @@ register_schema_models(web_ns, MessageListQuery, MessageFeedbackPayload, Message @web_ns.route("/messages") class MessageListApi(WebApiResource): - message_fields = { - "id": fields.String, - "conversation_id": fields.String, - "parent_message_id": fields.String, - "inputs": FilesContainedField, - "query": fields.String, - "answer": fields.String(attribute="re_sign_file_url_answer"), - "message_files": fields.List(fields.Nested(message_file_fields)), - "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), - "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), - "created_at": TimestampField, - "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), - "metadata": fields.Raw(attribute="message_metadata_dict"), - "status": fields.String, - "error": fields.String, - "generation_detail": fields.Raw, - } - - message_infinite_scroll_pagination_fields = { - "limit": fields.Integer, - "has_more": fields.Boolean, - "data": fields.List(fields.Nested(message_fields)), - } - @web_ns.doc("Get Message List") @web_ns.doc(description="Retrieve paginated list of messages from a conversation in a chat application.") @web_ns.doc( @@ -122,7 +96,6 @@ class MessageListApi(WebApiResource): 500: "Internal Server Error", } ) - @marshal_with(message_infinite_scroll_pagination_fields) def get(self, app_model, end_user): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: @@ -132,9 +105,16 @@ class MessageListApi(WebApiResource): query = MessageListQuery.model_validate(raw_args) try: - return MessageService.pagination_by_first_id( + pagination = MessageService.pagination_by_first_id( app_model, end_user, 
query.conversation_id, query.first_id, query.limit ) + adapter = TypeAdapter(WebMessageListItem) + items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data] + return WebMessageInfiniteScrollPagination( + limit=pagination.limit, + has_more=pagination.has_more, + data=items, + ).model_dump(mode="json") except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except FirstMessageNotExistsError: @@ -143,10 +123,6 @@ class MessageListApi(WebApiResource): @web_ns.route("/messages//feedbacks") class MessageFeedbackApi(WebApiResource): - feedback_response_fields = { - "result": fields.String, - } - @web_ns.doc("Create Message Feedback") @web_ns.doc(description="Submit feedback (like/dislike) for a specific message.") @web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}}) @@ -171,7 +147,6 @@ class MessageFeedbackApi(WebApiResource): 500: "Internal Server Error", } ) - @marshal_with(feedback_response_fields) def post(self, app_model, end_user, message_id): message_id = str(message_id) @@ -188,7 +163,7 @@ class MessageFeedbackApi(WebApiResource): except MessageNotExistsError: raise NotFound("Message Not Exists.") - return {"result": "success"} + return ResultResponse(result="success").model_dump(mode="json") @web_ns.route("/messages//more-like-this") @@ -248,10 +223,6 @@ class MessageMoreLikeThisApi(WebApiResource): @web_ns.route("/messages//suggested-questions") class MessageSuggestedQuestionApi(WebApiResource): - suggested_questions_response_fields = { - "data": fields.List(fields.String), - } - @web_ns.doc("Get Suggested Questions") @web_ns.doc(description="Get suggested follow-up questions after a message (chat apps only).") @web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}}) @@ -265,7 +236,6 @@ class MessageSuggestedQuestionApi(WebApiResource): 500: "Internal Server Error", } ) - @marshal_with(suggested_questions_response_fields) def get(self, app_model, end_user, message_id): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: @@ -278,7 +248,6 @@ class MessageSuggestedQuestionApi(WebApiResource): app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.WEB_APP ) # questions is a list of strings, not a list of Message objects - # so we can directly return it except MessageNotExistsError: raise NotFound("Message not found") except ConversationNotExistsError: @@ -297,4 +266,4 @@ class MessageSuggestedQuestionApi(WebApiResource): logger.exception("internal server error.") raise InternalServerError() - return {"data": questions} + return SuggestedQuestionsResponse(data=questions).model_dump(mode="json") diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index c1f976829f..b08b3fe858 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -1,7 +1,6 @@ import urllib.parse import httpx -from flask_restx import marshal_with from pydantic import BaseModel, Field, HttpUrl import services @@ -14,7 +13,7 @@ from controllers.common.errors import ( from core.file import helpers as file_helpers from core.helper import ssrf_proxy from extensions.ext_database import db -from fields.file_fields import build_file_with_signed_url_model, build_remote_file_info_model +from fields.file_fields import FileWithSignedUrl, RemoteFileInfo from services.file_service import FileService from 
..common.schema import register_schema_models @@ -26,7 +25,7 @@ class RemoteFileUploadPayload(BaseModel): url: HttpUrl = Field(description="Remote file URL") -register_schema_models(web_ns, RemoteFileUploadPayload) +register_schema_models(web_ns, RemoteFileUploadPayload, RemoteFileInfo, FileWithSignedUrl) @web_ns.route("/remote-files/") @@ -41,7 +40,7 @@ class RemoteFileInfoApi(WebApiResource): 500: "Failed to fetch remote file", } ) - @marshal_with(build_remote_file_info_model(web_ns)) + @web_ns.response(200, "Remote file info", web_ns.models[RemoteFileInfo.__name__]) def get(self, app_model, end_user, url): """Get information about a remote file. @@ -65,10 +64,11 @@ class RemoteFileInfoApi(WebApiResource): # failed back to get method resp = ssrf_proxy.get(decoded_url, timeout=3) resp.raise_for_status() - return { - "file_type": resp.headers.get("Content-Type", "application/octet-stream"), - "file_length": int(resp.headers.get("Content-Length", -1)), - } + info = RemoteFileInfo( + file_type=resp.headers.get("Content-Type", "application/octet-stream"), + file_length=int(resp.headers.get("Content-Length", -1)), + ) + return info.model_dump(mode="json") @web_ns.route("/remote-files/upload") @@ -84,7 +84,7 @@ class RemoteFileUploadApi(WebApiResource): 500: "Failed to fetch remote file", } ) - @marshal_with(build_file_with_signed_url_model(web_ns)) + @web_ns.response(201, "Remote file uploaded", web_ns.models[FileWithSignedUrl.__name__]) def post(self, app_model, end_user): """Upload a file from a remote URL. @@ -139,13 +139,14 @@ class RemoteFileUploadApi(WebApiResource): except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError - return { - "id": upload_file.id, - "name": upload_file.name, - "size": upload_file.size, - "extension": upload_file.extension, - "url": file_helpers.get_signed_file_url(upload_file_id=upload_file.id), - "mime_type": upload_file.mime_type, - "created_by": upload_file.created_by, - "created_at": upload_file.created_at, - }, 201 + payload1 = FileWithSignedUrl( + id=upload_file.id, + name=upload_file.name, + size=upload_file.size, + extension=upload_file.extension, + url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id), + mime_type=upload_file.mime_type, + created_by=upload_file.created_by, + created_at=int(upload_file.created_at.timestamp()), + ) + return payload1.model_dump(mode="json"), 201 diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index 865f3610a7..29993100f6 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -1,40 +1,32 @@ -from flask_restx import fields, marshal_with, reqparse -from flask_restx.inputs import int_range +from flask import request +from pydantic import BaseModel, Field, TypeAdapter from werkzeug.exceptions import NotFound +from controllers.common.schema import register_schema_models from controllers.web import web_ns from controllers.web.error import NotCompletionAppError from controllers.web.wraps import WebApiResource -from fields.conversation_fields import message_file_fields -from libs.helper import TimestampField, uuid_value +from fields.conversation_fields import ResultResponse +from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem +from libs.helper import UUIDStrOrEmpty from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService -feedback_fields = {"rating": fields.String} -message_fields = { - "id": fields.String, - 
"inputs": fields.Raw, - "query": fields.String, - "answer": fields.String, - "message_files": fields.List(fields.Nested(message_file_fields)), - "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), - "created_at": TimestampField, -} +class SavedMessageListQuery(BaseModel): + last_id: UUIDStrOrEmpty | None = None + limit: int = Field(default=20, ge=1, le=100) + + +class SavedMessageCreatePayload(BaseModel): + message_id: UUIDStrOrEmpty + + +register_schema_models(web_ns, SavedMessageListQuery, SavedMessageCreatePayload) @web_ns.route("/saved-messages") class SavedMessageListApi(WebApiResource): - saved_message_infinite_scroll_pagination_fields = { - "limit": fields.Integer, - "has_more": fields.Boolean, - "data": fields.List(fields.Nested(message_fields)), - } - - post_response_fields = { - "result": fields.String, - } - @web_ns.doc("Get Saved Messages") @web_ns.doc(description="Retrieve paginated list of saved messages for a completion application.") @web_ns.doc( @@ -58,19 +50,21 @@ class SavedMessageListApi(WebApiResource): 500: "Internal Server Error", } ) - @marshal_with(saved_message_infinite_scroll_pagination_fields) def get(self, app_model, end_user): if app_model.mode != "completion": raise NotCompletionAppError() - parser = ( - reqparse.RequestParser() - .add_argument("last_id", type=uuid_value, location="args") - .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - ) - args = parser.parse_args() + raw_args = request.args.to_dict() + query = SavedMessageListQuery.model_validate(raw_args) - return SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"]) + pagination = SavedMessageService.pagination_by_last_id(app_model, end_user, query.last_id, query.limit) + adapter = TypeAdapter(SavedMessageItem) + items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data] + return SavedMessageInfiniteScrollPagination( + limit=pagination.limit, + has_more=pagination.has_more, + data=items, + ).model_dump(mode="json") @web_ns.doc("Save Message") @web_ns.doc(description="Save a specific message for later reference.") @@ -89,28 +83,22 @@ class SavedMessageListApi(WebApiResource): 500: "Internal Server Error", } ) - @marshal_with(post_response_fields) def post(self, app_model, end_user): if app_model.mode != "completion": raise NotCompletionAppError() - parser = reqparse.RequestParser().add_argument("message_id", type=uuid_value, required=True, location="json") - args = parser.parse_args() + payload = SavedMessageCreatePayload.model_validate(web_ns.payload or {}) try: - SavedMessageService.save(app_model, end_user, args["message_id"]) + SavedMessageService.save(app_model, end_user, payload.message_id) except MessageNotExistsError: raise NotFound("Message Not Exists.") - return {"result": "success"} + return ResultResponse(result="success").model_dump(mode="json") @web_ns.route("/saved-messages/") class SavedMessageApi(WebApiResource): - delete_response_fields = { - "result": fields.String, - } - @web_ns.doc("Delete Saved Message") @web_ns.doc(description="Remove a message from saved messages.") @web_ns.doc(params={"message_id": {"description": "Message UUID to delete", "type": "string", "required": True}}) @@ -124,7 +112,6 @@ class SavedMessageApi(WebApiResource): 500: "Internal Server Error", } ) - @marshal_with(delete_response_fields) def delete(self, app_model, end_user, message_id): message_id = str(message_id) @@ -133,4 +120,4 @@ class 
SavedMessageApi(WebApiResource): SavedMessageService.delete(app_model, end_user, message_id) - return {"result": "success"}, 204 + return ResultResponse(result="success").model_dump(mode="json"), 204 diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index ee092e55c5..a2ae8dec5b 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -20,6 +20,8 @@ from core.app.entities.queue_entities import ( QueueTextChunkEvent, ) from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature +from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer +from core.db.session_factory import session_factory from core.moderation.base import ModerationError from core.moderation.input_moderation import InputModeration from core.variables.variables import VariableUnion @@ -40,6 +42,7 @@ from models import Workflow from models.enums import UserFrom from models.model import App, Conversation, Message, MessageAnnotation from models.workflow import ConversationVariable +from services.conversation_variable_updater import ConversationVariableUpdater logger = logging.getLogger(__name__) @@ -200,6 +203,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): ) workflow_entry.graph_engine.layer(persistence_layer) + conversation_variable_layer = ConversationVariablePersistenceLayer( + ConversationVariableUpdater(session_factory.get_session_maker()) + ) + workflow_entry.graph_engine.layer(conversation_variable_layer) for layer in self._graph_engine_layers: workflow_entry.graph_engine.layer(layer) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index c4d89b8b2f..6c4f96c2e4 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -471,6 +471,25 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): if node_finish_resp: yield node_finish_resp + # For ANSWER nodes, check if we need to send a message_replace event + # Only send if the final output differs from the accumulated task_state.answer + # This happens when variables were updated by variable_assigner during workflow execution + if event.node_type == NodeType.ANSWER and event.outputs: + final_answer = event.outputs.get("answer") + if final_answer is not None and final_answer != self._task_state.answer: + logger.info( + "ANSWER node final output '%s' differs from accumulated answer '%s', sending message_replace event", + final_answer, + self._task_state.answer, + ) + # Update the task state answer + self._task_state.answer = str(final_answer) + # Send message_replace event to update the UI + yield self._message_cycle_manager.message_replace_to_stream_response( + answer=str(final_answer), + reason="variable_update", + ) + def _handle_node_failed_events( self, event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent], diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 13eb40fd60..ea4441b5d8 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -130,7 +130,7 @@ class PipelineGenerator(BaseAppGenerator): pipeline=pipeline, workflow=workflow, start_node_id=start_node_id ) documents: list[Document] = [] - if invoke_from == InvokeFrom.PUBLISHED and not is_retry and not 
args.get("original_document_id"): + if invoke_from == InvokeFrom.PUBLISHED_PIPELINE and not is_retry and not args.get("original_document_id"): from services.dataset_service import DocumentService for datasource_info in datasource_info_list: @@ -156,7 +156,7 @@ class PipelineGenerator(BaseAppGenerator): for i, datasource_info in enumerate(datasource_info_list): workflow_run_id = str(uuid.uuid4()) document_id = args.get("original_document_id") or None - if invoke_from == InvokeFrom.PUBLISHED and not is_retry: + if invoke_from == InvokeFrom.PUBLISHED_PIPELINE and not is_retry: document_id = document_id or documents[i].id document_pipeline_execution_log = DocumentPipelineExecutionLog( document_id=document_id, diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 0cb573cb86..5bc453420d 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -42,7 +42,8 @@ class InvokeFrom(StrEnum): # DEBUGGER indicates that this invocation is from # the workflow (or chatflow) edit page. DEBUGGER = "debugger" - PUBLISHED = "published" + # PUBLISHED_PIPELINE indicates that this invocation runs a published RAG pipeline workflow. + PUBLISHED_PIPELINE = "published" # VALIDATION indicates that this invocation is from validation. VALIDATION = "validation" diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 79fbafe39e..3f9f3da9b2 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -75,7 +75,7 @@ class AnnotationReplyFeature: AppAnnotationService.add_annotation_history( annotation.id, app_record.id, - annotation.question, + annotation.question_text, annotation.content, query, user_id, diff --git a/api/core/app/layers/conversation_variable_persist_layer.py b/api/core/app/layers/conversation_variable_persist_layer.py new file mode 100644 index 0000000000..77cc00bdc9 --- /dev/null +++ b/api/core/app/layers/conversation_variable_persist_layer.py @@ -0,0 +1,60 @@ +import logging + +from core.variables import Variable +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID +from core.workflow.conversation_variable_updater import ConversationVariableUpdater +from core.workflow.enums import NodeType +from core.workflow.graph_engine.layers.base import GraphEngineLayer +from core.workflow.graph_events import GraphEngineEvent, NodeRunSucceededEvent +from core.workflow.nodes.variable_assigner.common import helpers as common_helpers + +logger = logging.getLogger(__name__) + + +class ConversationVariablePersistenceLayer(GraphEngineLayer): + def __init__(self, conversation_variable_updater: ConversationVariableUpdater) -> None: + super().__init__() + self._conversation_variable_updater = conversation_variable_updater + + def on_graph_start(self) -> None: + pass + + def on_event(self, event: GraphEngineEvent) -> None: + if not isinstance(event, NodeRunSucceededEvent): + return + if event.node_type != NodeType.VARIABLE_ASSIGNER: + return + if self.graph_runtime_state is None: + return + + updated_variables = common_helpers.get_updated_variables(event.node_run_result.process_data) or [] + if not updated_variables: + return + + conversation_id = self.graph_runtime_state.system_variable.conversation_id + if conversation_id is None: + return + + updated_any = False + for item in updated_variables: + selector = item.selector + if len(selector) < 2: 
+ logger.warning("Conversation variable selector invalid. selector=%s", selector) + continue + if selector[0] != CONVERSATION_VARIABLE_NODE_ID: + continue + variable = self.graph_runtime_state.variable_pool.get(selector) + if not isinstance(variable, Variable): + logger.warning( + "Conversation variable not found in variable pool. selector=%s", + selector, + ) + continue + self._conversation_variable_updater.update(conversation_id=conversation_id, variable=variable) + updated_any = True + + if updated_any: + self._conversation_variable_updater.flush() + + def on_graph_end(self, error: Exception | None) -> None: + pass diff --git a/api/core/app/layers/pause_state_persist_layer.py b/api/core/app/layers/pause_state_persist_layer.py index 61a3e1baca..bf76ae8178 100644 --- a/api/core/app/layers/pause_state_persist_layer.py +++ b/api/core/app/layers/pause_state_persist_layer.py @@ -66,6 +66,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer): """ if isinstance(session_factory, Engine): session_factory = sessionmaker(session_factory) + super().__init__() self._session_maker = session_factory self._state_owner_user_id = state_owner_user_id self._generate_entity = generate_entity @@ -98,8 +99,6 @@ class PauseStatePersistenceLayer(GraphEngineLayer): if not isinstance(event, GraphRunPausedEvent): return - assert self.graph_runtime_state is not None - entity_wrapper: _GenerateEntityUnion if isinstance(self._generate_entity, WorkflowAppGenerateEntity): entity_wrapper = _WorkflowGenerateEntityWrapper(entity=self._generate_entity) diff --git a/api/core/app/layers/trigger_post_layer.py b/api/core/app/layers/trigger_post_layer.py index fe1a46a945..225b758fcb 100644 --- a/api/core/app/layers/trigger_post_layer.py +++ b/api/core/app/layers/trigger_post_layer.py @@ -33,6 +33,7 @@ class TriggerPostLayer(GraphEngineLayer): trigger_log_id: str, session_maker: sessionmaker[Session], ): + super().__init__() self.trigger_log_id = trigger_log_id self.start_time = start_time self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity @@ -57,10 +58,6 @@ class TriggerPostLayer(GraphEngineLayer): elapsed_time = (datetime.now(UTC) - self.start_time).total_seconds() # Extract relevant data from result - if not self.graph_runtime_state: - logger.exception("Graph runtime state is not set") - return - outputs = self.graph_runtime_state.outputs # BASICLY, workflow_execution_id is the same as workflow_run_id diff --git a/api/core/datasource/__base/datasource_plugin.py b/api/core/datasource/__base/datasource_plugin.py index 50c7249fe4..451e4fda0e 100644 --- a/api/core/datasource/__base/datasource_plugin.py +++ b/api/core/datasource/__base/datasource_plugin.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from abc import ABC, abstractmethod from configs import dify_config @@ -30,7 +32,7 @@ class DatasourcePlugin(ABC): """ return DatasourceProviderType.LOCAL_FILE - def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin": + def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> DatasourcePlugin: return self.__class__( entity=self.entity.model_copy(), runtime=runtime, diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index 260dcf04f5..dde7d59726 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import enum from enum import StrEnum from typing import Any @@ -31,7 +33,7 @@ class 
DatasourceProviderType(enum.StrEnum): ONLINE_DRIVE = "online_drive" @classmethod - def value_of(cls, value: str) -> "DatasourceProviderType": + def value_of(cls, value: str) -> DatasourceProviderType: """ Get value of given mode. @@ -81,7 +83,7 @@ class DatasourceParameter(PluginParameter): typ: DatasourceParameterType, required: bool, options: list[str] | None = None, - ) -> "DatasourceParameter": + ) -> DatasourceParameter: """ get a simple datasource parameter @@ -187,14 +189,14 @@ class DatasourceInvokeMeta(BaseModel): tool_config: dict | None = None @classmethod - def empty(cls) -> "DatasourceInvokeMeta": + def empty(cls) -> DatasourceInvokeMeta: """ Get an empty instance of DatasourceInvokeMeta """ return cls(time_cost=0.0, error=None, tool_config={}) @classmethod - def error_instance(cls, error: str) -> "DatasourceInvokeMeta": + def error_instance(cls, error: str) -> DatasourceInvokeMeta: """ Get an instance of DatasourceInvokeMeta with error """ diff --git a/api/core/db/session_factory.py b/api/core/db/session_factory.py index 1dae2eafd4..45d4bc4594 100644 --- a/api/core/db/session_factory.py +++ b/api/core/db/session_factory.py @@ -1,7 +1,7 @@ from sqlalchemy import Engine from sqlalchemy.orm import Session, sessionmaker -_session_maker: sessionmaker | None = None +_session_maker: sessionmaker[Session] | None = None def configure_session_factory(engine: Engine, expire_on_commit: bool = False): @@ -10,7 +10,7 @@ def configure_session_factory(engine: Engine, expire_on_commit: bool = False): _session_maker = sessionmaker(bind=engine, expire_on_commit=expire_on_commit) -def get_session_maker() -> sessionmaker: +def get_session_maker() -> sessionmaker[Session]: if _session_maker is None: raise RuntimeError("Session factory not configured. Call configure_session_factory() first.") return _session_maker @@ -27,7 +27,7 @@ class SessionFactory: configure_session_factory(engine, expire_on_commit) @staticmethod - def get_session_maker() -> sessionmaker: + def get_session_maker() -> sessionmaker[Session]: return get_session_maker() @staticmethod diff --git a/api/core/entities/mcp_provider.py b/api/core/entities/mcp_provider.py index 7fdf5e4be6..135d2a4945 100644 --- a/api/core/entities/mcp_provider.py +++ b/api/core/entities/mcp_provider.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json from datetime import datetime from enum import StrEnum @@ -75,7 +77,7 @@ class MCPProviderEntity(BaseModel): updated_at: datetime @classmethod - def from_db_model(cls, db_provider: "MCPToolProvider") -> "MCPProviderEntity": + def from_db_model(cls, db_provider: MCPToolProvider) -> MCPProviderEntity: """Create entity from database model with decryption""" return cls( diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 12431976f0..a123fb0321 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -30,7 +30,6 @@ class SimpleModelProviderEntity(BaseModel): label: I18nObject icon_small: I18nObject | None = None icon_small_dark: I18nObject | None = None - icon_large: I18nObject | None = None supported_model_types: list[ModelType] def __init__(self, provider_entity: ProviderEntity): @@ -44,7 +43,6 @@ class SimpleModelProviderEntity(BaseModel): label=provider_entity.label, icon_small=provider_entity.icon_small, icon_small_dark=provider_entity.icon_small_dark, - icon_large=provider_entity.icon_large, supported_model_types=provider_entity.supported_model_types, ) @@ -94,7 +92,6 @@ class 
DefaultModelProviderEntity(BaseModel): provider: str label: I18nObject icon_small: I18nObject | None = None - icon_large: I18nObject | None = None supported_model_types: Sequence[ModelType] = [] diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 8a8067332d..0078ec7e4f 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from enum import StrEnum, auto from typing import Union @@ -178,7 +180,7 @@ class BasicProviderConfig(BaseModel): TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR @classmethod - def value_of(cls, value: str) -> "ProviderConfig.Type": + def value_of(cls, value: str) -> ProviderConfig.Type: """ Get value of given mode. diff --git a/api/core/file/helpers.py b/api/core/file/helpers.py index 6d553d7dc6..2ac483673a 100644 --- a/api/core/file/helpers.py +++ b/api/core/file/helpers.py @@ -8,8 +8,9 @@ import urllib.parse from configs import dify_config -def get_signed_file_url(upload_file_id: str, as_attachment=False) -> str: - url = f"{dify_config.FILES_URL}/files/{upload_file_id}/file-preview" +def get_signed_file_url(upload_file_id: str, as_attachment=False, for_external: bool = True) -> str: + base_url = dify_config.FILES_URL if for_external else (dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL) + url = f"{base_url}/files/{upload_file_id}/file-preview" timestamp = str(int(time.time())) nonce = os.urandom(16).hex() diff --git a/api/core/file/models.py b/api/core/file/models.py index d149205d77..6324523b22 100644 --- a/api/core/file/models.py +++ b/api/core/file/models.py @@ -112,17 +112,17 @@ class File(BaseModel): return text - def generate_url(self) -> str | None: + def generate_url(self, for_external: bool = True) -> str | None: if self.transfer_method == FileTransferMethod.REMOTE_URL: return self.remote_url elif self.transfer_method == FileTransferMethod.LOCAL_FILE: if self.related_id is None: raise ValueError("Missing file related_id") - return helpers.get_signed_file_url(upload_file_id=self.related_id) + return helpers.get_signed_file_url(upload_file_id=self.related_id, for_external=for_external) elif self.transfer_method in [FileTransferMethod.TOOL_FILE, FileTransferMethod.DATASOURCE_FILE]: assert self.related_id is not None assert self.extension is not None - return sign_tool_file(tool_file_id=self.related_id, extension=self.extension) + return sign_tool_file(tool_file_id=self.related_id, extension=self.extension, for_external=for_external) return None def to_plugin_parameter(self) -> dict[str, Any]: @@ -133,7 +133,7 @@ class File(BaseModel): "extension": self.extension, "size": self.size, "type": self.type, - "url": self.generate_url(), + "url": self.generate_url(for_external=False), } @model_validator(mode="after") diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index 6fda073913..5cdea19a8d 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -76,7 +76,7 @@ class TemplateTransformer(ABC): Post-process the result to convert scientific notation strings back to numbers """ - def convert_scientific_notation(value): + def convert_scientific_notation(value: Any) -> Any: if isinstance(value, str): # Check if the string looks like scientific notation if re.match(r"^-?\d+\.?\d*e[+-]\d+$", value, re.IGNORECASE): @@ -90,7 +90,7 @@ class TemplateTransformer(ABC): return 
[convert_scientific_notation(v) for v in value] return value - return convert_scientific_notation(result) # type: ignore[no-any-return] + return convert_scientific_notation(result) @classmethod @abstractmethod diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 0b36969cf9..1785cbde4c 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -88,7 +88,41 @@ def _get_user_provided_host_header(headers: dict | None) -> str | None: return None +def _inject_trace_headers(headers: dict | None) -> dict: + """ + Inject W3C traceparent header for distributed tracing. + + When OTEL is enabled, HTTPXClientInstrumentor handles trace propagation automatically. + When OTEL is disabled, we manually inject the traceparent header. + """ + if headers is None: + headers = {} + + # Skip if already present (case-insensitive check) + for key in headers: + if key.lower() == "traceparent": + return headers + + # Skip if OTEL is enabled - HTTPXClientInstrumentor handles this automatically + if dify_config.ENABLE_OTEL: + return headers + + # Generate and inject traceparent for non-OTEL scenarios + try: + from core.helper.trace_id_helper import generate_traceparent_header + + traceparent = generate_traceparent_header() + if traceparent: + headers["traceparent"] = traceparent + except Exception: + # Silently ignore errors to avoid breaking requests + logger.debug("Failed to generate traceparent header", exc_info=True) + + return headers + + def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): + # Convert requests-style allow_redirects to httpx-style follow_redirects if "allow_redirects" in kwargs: allow_redirects = kwargs.pop("allow_redirects") if "follow_redirects" not in kwargs: @@ -106,18 +140,21 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY) client = _get_ssrf_client(verify_option) + # Inject traceparent header for distributed tracing (when OTEL is not enabled) + headers = kwargs.get("headers") or {} + headers = _inject_trace_headers(headers) + kwargs["headers"] = headers + # Preserve user-provided Host header # When using a forward proxy, httpx may override the Host header based on the URL. # We extract and preserve any explicitly set Host header to support virtual hosting. - headers = kwargs.get("headers", {}) user_provided_host = _get_user_provided_host_header(headers) retries = 0 while retries <= max_retries: try: - # Build the request manually to preserve the Host header - # httpx may override the Host header when using a proxy, so we use - # the request API to explicitly set headers before sending + # Preserve the user-provided Host header + # httpx may override the Host header when using a proxy headers = {k: v for k, v in headers.items() if k.lower() != "host"} if user_provided_host is not None: headers["host"] = user_provided_host diff --git a/api/core/helper/trace_id_helper.py b/api/core/helper/trace_id_helper.py index 820502e558..e827859109 100644 --- a/api/core/helper/trace_id_helper.py +++ b/api/core/helper/trace_id_helper.py @@ -103,3 +103,60 @@ def parse_traceparent_header(traceparent: str) -> str | None: if len(parts) == 4 and len(parts[1]) == 32: return parts[1] return None + + +def get_span_id_from_otel_context() -> str | None: + """ + Retrieve the current span ID from the active OpenTelemetry trace context. + + Returns: + A 16-character hex string representing the span ID, or None if not available. 
+ """ + try: + from opentelemetry.trace import get_current_span + from opentelemetry.trace.span import INVALID_SPAN_ID + + span = get_current_span() + if not span: + return None + + span_context = span.get_span_context() + if not span_context or span_context.span_id == INVALID_SPAN_ID: + return None + + return f"{span_context.span_id:016x}" + except Exception: + return None + + +def generate_traceparent_header() -> str | None: + """ + Generate a W3C traceparent header from the current context. + + Uses OpenTelemetry context if available, otherwise uses the + ContextVar-based trace_id from the logging context. + + Format: {version}-{trace_id}-{span_id}-{flags} + Example: 00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01 + + Returns: + A valid traceparent header string, or None if generation fails. + """ + import uuid + + # Try OTEL context first + trace_id = get_trace_id_from_otel_context() + span_id = get_span_id_from_otel_context() + + if trace_id and span_id: + return f"00-{trace_id}-{span_id}-01" + + # Fallback: use ContextVar-based trace_id or generate new one + from core.logging.context import get_trace_id as get_logging_trace_id + + trace_id = get_logging_trace_id() or uuid.uuid4().hex + + # Generate a new span_id (16 hex chars) + span_id = uuid.uuid4().hex[:16] + + return f"00-{trace_id}-{span_id}-01" diff --git a/api/core/logging/__init__.py b/api/core/logging/__init__.py new file mode 100644 index 0000000000..db046cc9fa --- /dev/null +++ b/api/core/logging/__init__.py @@ -0,0 +1,20 @@ +"""Structured logging components for Dify.""" + +from core.logging.context import ( + clear_request_context, + get_request_id, + get_trace_id, + init_request_context, +) +from core.logging.filters import IdentityContextFilter, TraceContextFilter +from core.logging.structured_formatter import StructuredJSONFormatter + +__all__ = [ + "IdentityContextFilter", + "StructuredJSONFormatter", + "TraceContextFilter", + "clear_request_context", + "get_request_id", + "get_trace_id", + "init_request_context", +] diff --git a/api/core/logging/context.py b/api/core/logging/context.py new file mode 100644 index 0000000000..18633a0b05 --- /dev/null +++ b/api/core/logging/context.py @@ -0,0 +1,35 @@ +"""Request context for logging - framework agnostic. + +This module provides request-scoped context variables for logging, +using Python's contextvars for thread-safe and async-safe storage. +""" + +import uuid +from contextvars import ContextVar + +_request_id: ContextVar[str] = ContextVar("log_request_id", default="") +_trace_id: ContextVar[str] = ContextVar("log_trace_id", default="") + + +def get_request_id() -> str: + """Get current request ID (10 hex chars).""" + return _request_id.get() + + +def get_trace_id() -> str: + """Get fallback trace ID when OTEL is unavailable (32 hex chars).""" + return _trace_id.get() + + +def init_request_context() -> None: + """Initialize request context. Call at start of each request.""" + req_id = uuid.uuid4().hex[:10] + trace_id = uuid.uuid5(uuid.NAMESPACE_DNS, req_id).hex + _request_id.set(req_id) + _trace_id.set(trace_id) + + +def clear_request_context() -> None: + """Clear request context. 
Call at end of request (optional).""" + _request_id.set("") + _trace_id.set("") diff --git a/api/core/logging/filters.py b/api/core/logging/filters.py new file mode 100644 index 0000000000..1e8aa8d566 --- /dev/null +++ b/api/core/logging/filters.py @@ -0,0 +1,94 @@ +"""Logging filters for structured logging.""" + +import contextlib +import logging + +import flask + +from core.logging.context import get_request_id, get_trace_id + + +class TraceContextFilter(logging.Filter): + """ + Filter that adds trace_id and span_id to log records. + Integrates with OpenTelemetry when available, falls back to ContextVar-based trace_id. + """ + + def filter(self, record: logging.LogRecord) -> bool: + # Get trace context from OpenTelemetry + trace_id, span_id = self._get_otel_context() + + # Set trace_id (fallback to ContextVar if no OTEL context) + if trace_id: + record.trace_id = trace_id + else: + record.trace_id = get_trace_id() + + record.span_id = span_id or "" + + # For backward compatibility, also set req_id + record.req_id = get_request_id() + + return True + + def _get_otel_context(self) -> tuple[str, str]: + """Extract trace_id and span_id from OpenTelemetry context.""" + with contextlib.suppress(Exception): + from opentelemetry.trace import get_current_span + from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID + + span = get_current_span() + if span and span.get_span_context(): + ctx = span.get_span_context() + if ctx.is_valid and ctx.trace_id != INVALID_TRACE_ID: + trace_id = f"{ctx.trace_id:032x}" + span_id = f"{ctx.span_id:016x}" if ctx.span_id != INVALID_SPAN_ID else "" + return trace_id, span_id + return "", "" + + +class IdentityContextFilter(logging.Filter): + """ + Filter that adds user identity context to log records. + Extracts tenant_id, user_id, and user_type from Flask-Login current_user. 
+ """ + + def filter(self, record: logging.LogRecord) -> bool: + identity = self._extract_identity() + record.tenant_id = identity.get("tenant_id", "") + record.user_id = identity.get("user_id", "") + record.user_type = identity.get("user_type", "") + return True + + def _extract_identity(self) -> dict[str, str]: + """Extract identity from current_user if in request context.""" + try: + if not flask.has_request_context(): + return {} + from flask_login import current_user + + # Check if user is authenticated using the proxy + if not current_user.is_authenticated: + return {} + + # Access the underlying user object + user = current_user + + from models import Account + from models.model import EndUser + + identity: dict[str, str] = {} + + if isinstance(user, Account): + if user.current_tenant_id: + identity["tenant_id"] = user.current_tenant_id + identity["user_id"] = user.id + identity["user_type"] = "account" + elif isinstance(user, EndUser): + identity["tenant_id"] = user.tenant_id + identity["user_id"] = user.id + identity["user_type"] = user.type or "end_user" + + return identity + except Exception: + return {} diff --git a/api/core/logging/structured_formatter.py b/api/core/logging/structured_formatter.py new file mode 100644 index 0000000000..4295d2dd34 --- /dev/null +++ b/api/core/logging/structured_formatter.py @@ -0,0 +1,107 @@ +"""Structured JSON log formatter for Dify.""" + +import logging +import traceback +from datetime import UTC, datetime +from typing import Any + +import orjson + +from configs import dify_config + + +class StructuredJSONFormatter(logging.Formatter): + """ + JSON log formatter following the specified schema: + { + "ts": "ISO 8601 UTC", + "severity": "INFO|ERROR|WARN|DEBUG", + "service": "service name", + "caller": "file:line", + "trace_id": "hex 32", + "span_id": "hex 16", + "identity": { "tenant_id", "user_id", "user_type" }, + "message": "log message", + "attributes": { ... }, + "stack_trace": "..." 
+ } + """ + + SEVERITY_MAP: dict[int, str] = { + logging.DEBUG: "DEBUG", + logging.INFO: "INFO", + logging.WARNING: "WARN", + logging.ERROR: "ERROR", + logging.CRITICAL: "ERROR", + } + + def __init__(self, service_name: str | None = None): + super().__init__() + self._service_name = service_name or dify_config.APPLICATION_NAME + + def format(self, record: logging.LogRecord) -> str: + log_dict = self._build_log_dict(record) + try: + return orjson.dumps(log_dict).decode("utf-8") + except TypeError: + # Fallback: convert non-serializable objects to string + import json + + return json.dumps(log_dict, default=str, ensure_ascii=False) + + def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]: + # Core fields + log_dict: dict[str, Any] = { + "ts": datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z"), + "severity": self.SEVERITY_MAP.get(record.levelno, "INFO"), + "service": self._service_name, + "caller": f"{record.filename}:{record.lineno}", + "message": record.getMessage(), + } + + # Trace context (from TraceContextFilter) + trace_id = getattr(record, "trace_id", "") + span_id = getattr(record, "span_id", "") + + if trace_id: + log_dict["trace_id"] = trace_id + if span_id: + log_dict["span_id"] = span_id + + # Identity context (from IdentityContextFilter) + identity = self._extract_identity(record) + if identity: + log_dict["identity"] = identity + + # Dynamic attributes + attributes = getattr(record, "attributes", None) + if attributes: + log_dict["attributes"] = attributes + + # Stack trace for errors with exceptions + if record.exc_info and record.levelno >= logging.ERROR: + log_dict["stack_trace"] = self._format_exception(record.exc_info) + + return log_dict + + def _extract_identity(self, record: logging.LogRecord) -> dict[str, str] | None: + tenant_id = getattr(record, "tenant_id", None) + user_id = getattr(record, "user_id", None) + user_type = getattr(record, "user_type", None) + + if not any([tenant_id, user_id, user_type]): + return None + + identity: dict[str, str] = {} + if tenant_id: + identity["tenant_id"] = tenant_id + if user_id: + identity["user_id"] = user_id + if user_type: + identity["user_type"] = user_type + return identity + + def _format_exception(self, exc_info: tuple[Any, ...]) -> str: + if exc_info and exc_info[0] is not None: + return "".join(traceback.format_exception(*exc_info)) + return "" diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index c97ae6eac7..84a6fd0d1f 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -68,13 +68,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): request_id: RequestId, request_meta: RequestParams.Meta | None, request: ReceiveRequestT, - session: """BaseSession[ - SendRequestT, - SendNotificationT, - SendResultT, - ReceiveRequestT, - ReceiveNotificationT - ]""", + session: """BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]""", on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any], ): self.request_id = request_id diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index 89dae2dbff..3ac83b4c96 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from abc import ABC from collections.abc import Mapping, Sequence from enum import StrEnum, 
auto @@ -17,7 +19,7 @@ class PromptMessageRole(StrEnum): TOOL = auto() @classmethod - def value_of(cls, value: str) -> "PromptMessageRole": + def value_of(cls, value: str) -> PromptMessageRole: """ Get value of given mode. diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index aee6ce1108..19194d162c 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from decimal import Decimal from enum import StrEnum, auto from typing import Any @@ -20,7 +22,7 @@ class ModelType(StrEnum): TTS = auto() @classmethod - def value_of(cls, origin_model_type: str) -> "ModelType": + def value_of(cls, origin_model_type: str) -> ModelType: """ Get model type from origin model type. @@ -103,7 +105,7 @@ class DefaultParameterName(StrEnum): JSON_SCHEMA = auto() @classmethod - def value_of(cls, value: Any) -> "DefaultParameterName": + def value_of(cls, value: Any) -> DefaultParameterName: """ Get parameter name from value. diff --git a/api/core/model_runtime/entities/provider_entities.py b/api/core/model_runtime/entities/provider_entities.py index 648b209ef1..2d88751668 100644 --- a/api/core/model_runtime/entities/provider_entities.py +++ b/api/core/model_runtime/entities/provider_entities.py @@ -100,7 +100,6 @@ class SimpleProviderEntity(BaseModel): label: I18nObject icon_small: I18nObject | None = None icon_small_dark: I18nObject | None = None - icon_large: I18nObject | None = None supported_model_types: Sequence[ModelType] models: list[AIModelEntity] = [] @@ -123,7 +122,6 @@ class ProviderEntity(BaseModel): label: I18nObject description: I18nObject | None = None icon_small: I18nObject | None = None - icon_large: I18nObject | None = None icon_small_dark: I18nObject | None = None background: str | None = None help: ProviderHelpEntity | None = None @@ -157,7 +155,6 @@ class ProviderEntity(BaseModel): provider=self.provider, label=self.label, icon_small=self.icon_small, - icon_large=self.icon_large, supported_model_types=self.supported_model_types, models=self.models, ) diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index b8704ef4ed..28f162a928 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import hashlib import logging from collections.abc import Sequence @@ -38,7 +40,7 @@ class ModelProviderFactory: plugin_providers = self.get_plugin_model_providers() return [provider.declaration for provider in plugin_providers] - def get_plugin_model_providers(self) -> Sequence["PluginModelProviderEntity"]: + def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]: """ Get all plugin model providers :return: list of plugin model providers @@ -76,7 +78,7 @@ class ModelProviderFactory: plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider) return plugin_model_provider_entity.declaration - def get_plugin_model_provider(self, provider: str) -> "PluginModelProviderEntity": + def get_plugin_model_provider(self, provider: str) -> PluginModelProviderEntity: """ Get plugin model provider :param provider: provider name @@ -285,7 +287,7 @@ class ModelProviderFactory: """ Get provider icon :param provider: provider name - :param icon_type: icon type (icon_small or 
icon_large) + :param icon_type: icon type (icon_small or icon_small_dark) :param lang: language (zh_Hans or en_US) :return: provider icon """ @@ -309,13 +311,7 @@ class ModelProviderFactory: else: file_name = provider_schema.icon_small_dark.en_US else: - if not provider_schema.icon_large: - raise ValueError(f"Provider {provider} does not have large icon.") - - if lang.lower() == "zh_hans": - file_name = provider_schema.icon_large.zh_Hans - else: - file_name = provider_schema.icon_large.en_US + raise ValueError(f"Unsupported icon type: {icon_type}.") if not file_name: raise ValueError(f"Provider {provider} does not have icon.") diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 3b83121357..6674228dc0 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import enum from collections.abc import Mapping, Sequence from datetime import datetime @@ -242,7 +244,7 @@ class CredentialType(enum.StrEnum): return [item.value for item in cls] @classmethod - def of(cls, credential_type: str) -> "CredentialType": + def of(cls, credential_type: str) -> CredentialType: type_name = credential_type.lower() if type_name in {"api-key", "api_key"}: return cls.API_KEY diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py index 7bb2749afa..0e49824ad0 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -103,6 +103,9 @@ class BasePluginClient: prepared_headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY prepared_headers.setdefault("Accept-Encoding", "gzip, deflate, br") + # Inject traceparent header for distributed tracing + self._inject_trace_headers(prepared_headers) + prepared_data: bytes | dict[str, Any] | str | None = ( data if isinstance(data, (bytes, str, dict)) or data is None else None ) @@ -114,6 +117,31 @@ class BasePluginClient: return str(url), prepared_headers, prepared_data, params, files + def _inject_trace_headers(self, headers: dict[str, str]) -> None: + """ + Inject W3C traceparent header for distributed tracing. + + This ensures trace context is propagated to plugin daemon even if + HTTPXClientInstrumentor doesn't cover module-level httpx functions. 
+ """ + if not dify_config.ENABLE_OTEL: + return + + import contextlib + + # Skip if already present (case-insensitive check) + for key in headers: + if key.lower() == "traceparent": + return + + # Inject traceparent - works as fallback when OTEL instrumentation doesn't cover this call + with contextlib.suppress(Exception): + from core.helper.trace_id_helper import generate_traceparent_header + + traceparent = generate_traceparent_header() + if traceparent: + headers["traceparent"] = traceparent + def _stream_request( self, method: str, diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 6c818bdc8b..10d86d1762 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -331,7 +331,6 @@ class ProviderManager: provider=provider_schema.provider, label=provider_schema.label, icon_small=provider_schema.icon_small, - icon_large=provider_schema.icon_large, supported_model_types=provider_schema.supported_model_types, ), ) diff --git a/api/core/rag/cleaner/clean_processor.py b/api/core/rag/cleaner/clean_processor.py index 9cb009035b..e182c35b99 100644 --- a/api/core/rag/cleaner/clean_processor.py +++ b/api/core/rag/cleaner/clean_processor.py @@ -27,26 +27,44 @@ class CleanProcessor: pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)" text = re.sub(pattern, "", text) - # Remove URL but keep Markdown image URLs - # First, temporarily replace Markdown image URLs with a placeholder - markdown_image_pattern = r"!\[.*?\]\((https?://[^\s)]+)\)" - placeholders: list[str] = [] + # Remove URL but keep Markdown image URLs and link URLs + # Replace the ENTIRE markdown link/image with a single placeholder to protect + # the link text (which might also be a URL) from being removed + markdown_link_pattern = r"\[([^\]]*)\]\((https?://[^)]+)\)" + markdown_image_pattern = r"!\[.*?\]\((https?://[^)]+)\)" + placeholders: list[tuple[str, str, str]] = [] # (type, text, url) - def replace_with_placeholder(match, placeholders=placeholders): + def replace_markdown_with_placeholder(match, placeholders=placeholders): + link_type = "link" + link_text = match.group(1) + url = match.group(2) + placeholder = f"__MARKDOWN_PLACEHOLDER_{len(placeholders)}__" + placeholders.append((link_type, link_text, url)) + return placeholder + + def replace_image_with_placeholder(match, placeholders=placeholders): + link_type = "image" url = match.group(1) - placeholder = f"__MARKDOWN_IMAGE_URL_{len(placeholders)}__" - placeholders.append(url) - return f"![image]({placeholder})" + placeholder = f"__MARKDOWN_PLACEHOLDER_{len(placeholders)}__" + placeholders.append((link_type, "image", url)) + return placeholder - text = re.sub(markdown_image_pattern, replace_with_placeholder, text) + # Protect markdown links first + text = re.sub(markdown_link_pattern, replace_markdown_with_placeholder, text) + # Then protect markdown images + text = re.sub(markdown_image_pattern, replace_image_with_placeholder, text) # Now remove all remaining URLs - url_pattern = r"https?://[^\s)]+" + url_pattern = r"https?://\S+" text = re.sub(url_pattern, "", text) - # Finally, restore the Markdown image URLs - for i, url in enumerate(placeholders): - text = text.replace(f"__MARKDOWN_IMAGE_URL_{i}__", url) + # Restore the Markdown links and images + for i, (link_type, text_or_alt, url) in enumerate(placeholders): + placeholder = f"__MARKDOWN_PLACEHOLDER_{i}__" + if link_type == "link": + text = text.replace(placeholder, f"[{text_or_alt}]({url})") + else: # image + text = text.replace(placeholder, 
f"![{text_or_alt}]({url})") return text def filter_string(self, text): diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py index a306f9ba0c..91bb71bfa6 100644 --- a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py +++ b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import contextlib import json import logging @@ -6,7 +8,7 @@ import re import threading import time import uuid -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any import clickzetta # type: ignore from pydantic import BaseModel, model_validator @@ -76,7 +78,7 @@ class ClickzettaConnectionPool: Manages connection reuse across ClickzettaVector instances. """ - _instance: Optional["ClickzettaConnectionPool"] = None + _instance: ClickzettaConnectionPool | None = None _lock = threading.Lock() def __init__(self): @@ -89,7 +91,7 @@ class ClickzettaConnectionPool: self._start_cleanup_thread() @classmethod - def get_instance(cls) -> "ClickzettaConnectionPool": + def get_instance(cls) -> ClickzettaConnectionPool: """Get singleton instance of connection pool.""" if cls._instance is None: with cls._lock: @@ -104,7 +106,7 @@ class ClickzettaConnectionPool: f"{config.workspace}:{config.vcluster}:{config.schema_name}" ) - def _create_connection(self, config: ClickzettaConfig) -> "Connection": + def _create_connection(self, config: ClickzettaConfig) -> Connection: """Create a new ClickZetta connection.""" max_retries = 3 retry_delay = 1.0 @@ -134,7 +136,7 @@ class ClickzettaConnectionPool: raise RuntimeError(f"Failed to create ClickZetta connection after {max_retries} attempts") - def _configure_connection(self, connection: "Connection"): + def _configure_connection(self, connection: Connection): """Configure connection session settings.""" try: with connection.cursor() as cursor: @@ -181,7 +183,7 @@ class ClickzettaConnectionPool: except Exception: logger.exception("Failed to configure connection, continuing with defaults") - def _is_connection_valid(self, connection: "Connection") -> bool: + def _is_connection_valid(self, connection: Connection) -> bool: """Check if connection is still valid.""" try: with connection.cursor() as cursor: @@ -190,7 +192,7 @@ class ClickzettaConnectionPool: except Exception: return False - def get_connection(self, config: ClickzettaConfig) -> "Connection": + def get_connection(self, config: ClickzettaConfig) -> Connection: """Get a connection from the pool or create a new one.""" config_key = self._get_config_key(config) @@ -221,7 +223,7 @@ class ClickzettaConnectionPool: # No valid connection found, create new one return self._create_connection(config) - def return_connection(self, config: ClickzettaConfig, connection: "Connection"): + def return_connection(self, config: ClickzettaConfig, connection: Connection): """Return a connection to the pool.""" config_key = self._get_config_key(config) @@ -315,22 +317,22 @@ class ClickzettaVector(BaseVector): self._connection_pool = ClickzettaConnectionPool.get_instance() self._init_write_queue() - def _get_connection(self) -> "Connection": + def _get_connection(self) -> Connection: """Get a connection from the pool.""" return self._connection_pool.get_connection(self._config) - def _return_connection(self, connection: "Connection"): + def _return_connection(self, connection: Connection): """Return a connection to the pool.""" 
self._connection_pool.return_connection(self._config, connection) class ConnectionContext: """Context manager for borrowing and returning connections.""" - def __init__(self, vector_instance: "ClickzettaVector"): + def __init__(self, vector_instance: ClickzettaVector): self.vector = vector_instance self.connection: Connection | None = None - def __enter__(self) -> "Connection": + def __enter__(self) -> Connection: self.connection = self.vector._get_connection() return self.connection @@ -338,7 +340,7 @@ class ClickzettaVector(BaseVector): if self.connection: self.vector._return_connection(self.connection) - def get_connection_context(self) -> "ClickzettaVector.ConnectionContext": + def get_connection_context(self) -> ClickzettaVector.ConnectionContext: """Get a connection context manager.""" return self.ConnectionContext(self) @@ -437,7 +439,7 @@ class ClickzettaVector(BaseVector): """Return the vector database type.""" return "clickzetta" - def _ensure_connection(self) -> "Connection": + def _ensure_connection(self) -> Connection: """Get a connection from the pool.""" return self._get_connection() @@ -984,9 +986,11 @@ class ClickzettaVector(BaseVector): # No need for dataset_id filter since each dataset has its own table - # Use simple quote escaping for LIKE clause - escaped_query = query.replace("'", "''") - filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%'") + # Escape special characters for LIKE clause to prevent SQL injection + from libs.helper import escape_like_pattern + + escaped_query = escape_like_pattern(query).replace("'", "''") + filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%' ESCAPE '\\\\'") where_clause = " AND ".join(filter_clauses) search_sql = f""" diff --git a/api/core/rag/datasource/vdb/iris/iris_vector.py b/api/core/rag/datasource/vdb/iris/iris_vector.py index b1bfabb76e..5bdb0af0b3 100644 --- a/api/core/rag/datasource/vdb/iris/iris_vector.py +++ b/api/core/rag/datasource/vdb/iris/iris_vector.py @@ -287,11 +287,15 @@ class IrisVector(BaseVector): cursor.execute(sql, (query,)) else: # Fallback to LIKE search (inefficient for large datasets) - query_pattern = f"%{query}%" + # Escape special characters for LIKE clause to prevent SQL injection + from libs.helper import escape_like_pattern + + escaped_query = escape_like_pattern(query) + query_pattern = f"%{escaped_query}%" sql = f""" SELECT TOP {top_k} id, text, meta FROM {self.schema}.{self.table_name} - WHERE text LIKE ? + WHERE text LIKE ? ESCAPE '\\' """ cursor.execute(sql, (query_pattern,)) diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 84d1e26b34..b48dd93f04 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -66,6 +66,8 @@ class WeaviateVector(BaseVector): in a Weaviate collection. """ + _DOCUMENT_ID_PROPERTY = "document_id" + def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list): """ Initializes the Weaviate vector store. 
@@ -353,15 +355,12 @@ class WeaviateVector(BaseVector): return [] col = self._client.collections.use(self._collection_name) - props = list({*self._attributes, "document_id", Field.TEXT_KEY.value}) + props = list({*self._attributes, self._DOCUMENT_ID_PROPERTY, Field.TEXT_KEY.value}) where = None doc_ids = kwargs.get("document_ids_filter") or [] if doc_ids: - ors = [Filter.by_property("document_id").equal(x) for x in doc_ids] - where = ors[0] - for f in ors[1:]: - where = where | f + where = Filter.by_property(self._DOCUMENT_ID_PROPERTY).contains_any(doc_ids) top_k = int(kwargs.get("top_k", 4)) score_threshold = float(kwargs.get("score_threshold") or 0.0) @@ -408,10 +407,7 @@ class WeaviateVector(BaseVector): where = None doc_ids = kwargs.get("document_ids_filter") or [] if doc_ids: - ors = [Filter.by_property("document_id").equal(x) for x in doc_ids] - where = ors[0] - for f in ors[1:]: - where = where | f + where = Filter.by_property(self._DOCUMENT_ID_PROPERTY).contains_any(doc_ids) top_k = int(kwargs.get("top_k", 4)) diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 1fe74d3042..69adac522d 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections.abc import Sequence from typing import Any @@ -22,7 +24,7 @@ class DatasetDocumentStore: self._document_id = document_id @classmethod - def from_dict(cls, config_dict: dict[str, Any]) -> "DatasetDocumentStore": + def from_dict(cls, config_dict: dict[str, Any]) -> DatasetDocumentStore: return cls(**config_dict) def to_dict(self) -> dict[str, Any]: diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index 013c287248..6d28ce25bc 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -112,7 +112,7 @@ class ExtractProcessor: if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) elif file_extension == ".pdf": - extractor = PdfExtractor(file_path) + extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by) elif file_extension in {".md", ".markdown", ".mdx"}: extractor = ( UnstructuredMarkdownExtractor(file_path, unstructured_api_url, unstructured_api_key) @@ -148,7 +148,7 @@ class ExtractProcessor: if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) elif file_extension == ".pdf": - extractor = PdfExtractor(file_path) + extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by) elif file_extension in {".md", ".markdown", ".mdx"}: extractor = MarkdownExtractor(file_path, autodetect_encoding=True) elif file_extension in {".htm", ".html"}: diff --git a/api/core/rag/extractor/pdf_extractor.py b/api/core/rag/extractor/pdf_extractor.py index 80530d99a6..6aabcac704 100644 --- a/api/core/rag/extractor/pdf_extractor.py +++ b/api/core/rag/extractor/pdf_extractor.py @@ -1,25 +1,57 @@ """Abstract interface for document loader implementations.""" import contextlib +import io +import logging +import uuid from collections.abc import Iterator +import pypdfium2 +import pypdfium2.raw as pdfium_c + +from configs import dify_config from core.rag.extractor.blob.blob import Blob from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document +from extensions.ext_database import db from extensions.ext_storage import storage +from libs.datetime_utils import 
naive_utc_now +from models.enums import CreatorUserRole +from models.model import UploadFile + +logger = logging.getLogger(__name__) class PdfExtractor(BaseExtractor): - """Load pdf files. - + """ + PdfExtractor is used to extract text and images from PDF files. Args: - file_path: Path to the file to load. + file_path: Path to the PDF file. + tenant_id: Workspace ID. + user_id: ID of the user performing the extraction. + file_cache_key: Optional cache key for the extracted text. """ - def __init__(self, file_path: str, file_cache_key: str | None = None): - """Initialize with file path.""" + # Magic bytes for image format detection: (magic_bytes, extension, mime_type) + IMAGE_FORMATS = [ + (b"\xff\xd8\xff", "jpg", "image/jpeg"), + (b"\x89PNG\r\n\x1a\n", "png", "image/png"), + (b"\x00\x00\x00\x0c\x6a\x50\x20\x20\x0d\x0a\x87\x0a", "jp2", "image/jp2"), + (b"GIF8", "gif", "image/gif"), + (b"BM", "bmp", "image/bmp"), + (b"II*\x00", "tiff", "image/tiff"), + (b"MM\x00*", "tiff", "image/tiff"), + (b"II+\x00", "tiff", "image/tiff"), + (b"MM\x00+", "tiff", "image/tiff"), + ] + MAX_MAGIC_LEN = max(len(m) for m, _, _ in IMAGE_FORMATS) + + def __init__(self, file_path: str, tenant_id: str, user_id: str, file_cache_key: str | None = None): + """Initialize PdfExtractor.""" self._file_path = file_path + self._tenant_id = tenant_id + self._user_id = user_id self._file_cache_key = file_cache_key def extract(self) -> list[Document]: @@ -50,7 +82,6 @@ class PdfExtractor(BaseExtractor): def parse(self, blob: Blob) -> Iterator[Document]: """Lazily parse the blob.""" - import pypdfium2 # type: ignore with blob.as_bytes_io() as file_path: pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True) @@ -59,8 +90,87 @@ class PdfExtractor(BaseExtractor): text_page = page.get_textpage() content = text_page.get_text_range() text_page.close() + + image_content = self._extract_images(page) + if image_content: + content += "\n" + image_content + page.close() metadata = {"source": blob.source, "page": page_number} yield Document(page_content=content, metadata=metadata) finally: pdf_reader.close() + + def _extract_images(self, page) -> str: + """ + Extract images from a PDF page, save them to storage and database, + and return markdown image links. + + Args: + page: pypdfium2 page object. + + Returns: + Markdown string containing links to the extracted images. + """ + image_content = [] + upload_files = [] + base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL + + try: + image_objects = page.get_objects(filter=(pdfium_c.FPDF_PAGEOBJ_IMAGE,)) + for obj in image_objects: + try: + # Extract image bytes + img_byte_arr = io.BytesIO() + # Extract DCTDecode (JPEG) and JPXDecode (JPEG 2000) images directly + # Fallback to png for other formats + obj.extract(img_byte_arr, fb_format="png") + img_bytes = img_byte_arr.getvalue() + + if not img_bytes: + continue + + header = img_bytes[: self.MAX_MAGIC_LEN] + image_ext = None + mime_type = None + for magic, ext, mime in self.IMAGE_FORMATS: + if header.startswith(magic): + image_ext = ext + mime_type = mime + break + + if not image_ext or not mime_type: + continue + + file_uuid = str(uuid.uuid4()) + file_key = "image_files/" + self._tenant_id + "/" + file_uuid + "." 
+ image_ext + + storage.save(file_key, img_bytes) + + # save file to db + upload_file = UploadFile( + tenant_id=self._tenant_id, + storage_type=dify_config.STORAGE_TYPE, + key=file_key, + name=file_key, + size=len(img_bytes), + extension=image_ext, + mime_type=mime_type, + created_by=self._user_id, + created_by_role=CreatorUserRole.ACCOUNT, + created_at=naive_utc_now(), + used=True, + used_by=self._user_id, + used_at=naive_utc_now(), + ) + upload_files.append(upload_file) + image_content.append(f"![image]({base_url}/files/{upload_file.id}/file-preview)") + except Exception as e: + logger.warning("Failed to extract image from PDF: %s", e) + continue + except Exception as e: + logger.warning("Failed to get objects from PDF page: %s", e) + if upload_files: + db.session.add_all(upload_files) + db.session.commit() + return "\n".join(image_content) diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index f67f613e9d..511f5a698d 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -7,10 +7,11 @@ import re import tempfile import uuid from urllib.parse import urlparse -from xml.etree import ElementTree import httpx from docx import Document as DocxDocument +from docx.oxml.ns import qn +from docx.text.run import Run from configs import dify_config from core.helper import ssrf_proxy @@ -229,44 +230,20 @@ class WordExtractor(BaseExtractor): image_map = self._extract_images_from_docx(doc) - hyperlinks_url = None - url_pattern = re.compile(r"http://[^\s+]+//|https://[^\s+]+") - for para in doc.paragraphs: - for run in para.runs: - if run.text and hyperlinks_url: - result = f" [{run.text}]({hyperlinks_url}) " - run.text = result - hyperlinks_url = None - if "HYPERLINK" in run.element.xml: - try: - xml = ElementTree.XML(run.element.xml) - x_child = [c for c in xml.iter() if c is not None] - for x in x_child: - if x is None: - continue - if x.tag.endswith("instrText"): - if x.text is None: - continue - for i in url_pattern.findall(x.text): - hyperlinks_url = str(i) - except Exception: - logger.exception("Failed to parse HYPERLINK xml") - def parse_paragraph(paragraph): - paragraph_content = [] - - def append_image_link(image_id, has_drawing): + def append_image_link(image_id, has_drawing, target_buffer): """Helper to append image link from image_map based on relationship type.""" rel = doc.part.rels[image_id] if rel.is_external: if image_id in image_map and not has_drawing: - paragraph_content.append(image_map[image_id]) + target_buffer.append(image_map[image_id]) else: image_part = rel.target_part if image_part in image_map and not has_drawing: - paragraph_content.append(image_map[image_part]) + target_buffer.append(image_map[image_part]) - for run in paragraph.runs: + def process_run(run, target_buffer): + # Helper to extract text and embedded images from a run element and append them to target_buffer if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"): # Process drawing type images drawing_elements = run.element.findall( @@ -287,13 +264,13 @@ class WordExtractor(BaseExtractor): # External image: use embed_id as key if embed_id in image_map: has_drawing = True - paragraph_content.append(image_map[embed_id]) + target_buffer.append(image_map[embed_id]) else: # Internal image: use target_part as key image_part = doc.part.related_parts.get(embed_id) if image_part in image_map: has_drawing = True - paragraph_content.append(image_map[image_part]) + 
target_buffer.append(image_map[image_part]) # Process pict type images shape_elements = run.element.findall( ".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}pict" @@ -308,7 +285,7 @@ class WordExtractor(BaseExtractor): "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id" ) if image_id and image_id in doc.part.rels: - append_image_link(image_id, has_drawing) + append_image_link(image_id, has_drawing, target_buffer) # Find imagedata element in VML image_data = shape.find(".//{urn:schemas-microsoft-com:vml}imagedata") if image_data is not None: @@ -316,9 +293,93 @@ class WordExtractor(BaseExtractor): "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id" ) if image_id and image_id in doc.part.rels: - append_image_link(image_id, has_drawing) + append_image_link(image_id, has_drawing, target_buffer) if run.text.strip(): - paragraph_content.append(run.text.strip()) + target_buffer.append(run.text.strip()) + + def process_hyperlink(hyperlink_elem, target_buffer): + # Helper to extract text from a hyperlink element and append it to target_buffer + r_id = hyperlink_elem.get(qn("r:id")) + + # Extract text from runs inside the hyperlink + link_text_parts = [] + for run_elem in hyperlink_elem.findall(qn("w:r")): + run = Run(run_elem, paragraph) + # Hyperlink text may be split across multiple runs (e.g., with different formatting), + # so collect all run texts first + if run.text: + link_text_parts.append(run.text) + + link_text = "".join(link_text_parts).strip() + + # Resolve URL + if r_id: + try: + rel = doc.part.rels.get(r_id) + if rel and rel.is_external: + link_text = f"[{link_text or rel.target_ref}]({rel.target_ref})" + except Exception: + logger.exception("Failed to resolve URL for hyperlink with r:id: %s", r_id) + + if link_text: + target_buffer.append(link_text) + + paragraph_content = [] + # State for legacy HYPERLINK fields + hyperlink_field_url = None + hyperlink_field_text_parts: list = [] + is_collecting_field_text = False + # Iterate through paragraph elements in document order + for child in paragraph._element: + tag = child.tag + if tag == qn("w:r"): + # Regular run + run = Run(child, paragraph) + + # Check for fldChar (begin/end/separate) and instrText for legacy hyperlinks + fld_chars = child.findall(qn("w:fldChar")) + instr_texts = child.findall(qn("w:instrText")) + + # Handle Fields + if fld_chars or instr_texts: + # Process instrText to find HYPERLINK "url" + for instr in instr_texts: + if instr.text and "HYPERLINK" in instr.text: + # Quick regex to extract URL + match = re.search(r'HYPERLINK\s+"([^"]+)"', instr.text, re.IGNORECASE) + if match: + hyperlink_field_url = match.group(1) + + # Process fldChar + for fld_char in fld_chars: + fld_char_type = fld_char.get(qn("w:fldCharType")) + if fld_char_type == "begin": + # Start of a field: reset legacy link state + hyperlink_field_url = None + hyperlink_field_text_parts = [] + is_collecting_field_text = False + elif fld_char_type == "separate": + # Separator: if we found a URL, start collecting visible text + if hyperlink_field_url: + is_collecting_field_text = True + elif fld_char_type == "end": + # End of field + if is_collecting_field_text and hyperlink_field_url: + # Create markdown link and append to main content + display_text = "".join(hyperlink_field_text_parts).strip() + if display_text: + link_md = f"[{display_text}]({hyperlink_field_url})" + paragraph_content.append(link_md) + # Reset state + hyperlink_field_url = None + hyperlink_field_text_parts = [] + 
is_collecting_field_text = False + + # Decide where to append content + target_buffer = hyperlink_field_text_parts if is_collecting_field_text else paragraph_content + process_run(run, target_buffer) + elif tag == qn("w:hyperlink"): + process_hyperlink(child, paragraph_content) return "".join(paragraph_content) if paragraph_content else "" paragraphs = doc.paragraphs.copy() diff --git a/api/core/rag/pipeline/queue.py b/api/core/rag/pipeline/queue.py index 7472598a7f..bf8db95b4e 100644 --- a/api/core/rag/pipeline/queue.py +++ b/api/core/rag/pipeline/queue.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json from collections.abc import Sequence from typing import Any @@ -16,7 +18,7 @@ class TaskWrapper(BaseModel): return self.model_dump_json() @classmethod - def deserialize(cls, serialized_data: str) -> "TaskWrapper": + def deserialize(cls, serialized_data: str) -> TaskWrapper: return cls.model_validate_json(serialized_data) diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 4ec59940e3..f8f85d141a 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -515,6 +515,7 @@ class DatasetRetrieval: 0 ].embedding_model_provider weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model + dataset_count = len(available_datasets) with measure_time() as timer: cancel_event = threading.Event() thread_exceptions: list[Exception] = [] @@ -537,6 +538,7 @@ class DatasetRetrieval: "score_threshold": score_threshold, "query": query, "attachment_id": None, + "dataset_count": dataset_count, "cancel_event": cancel_event, "thread_exceptions": thread_exceptions, }, @@ -562,6 +564,7 @@ class DatasetRetrieval: "score_threshold": score_threshold, "query": None, "attachment_id": attachment_id, + "dataset_count": dataset_count, "cancel_event": cancel_event, "thread_exceptions": thread_exceptions, }, @@ -1195,18 +1198,24 @@ class DatasetRetrieval: json_field = DatasetDocument.doc_metadata[metadata_name].as_string() + from libs.helper import escape_like_pattern + match condition: case "contains": - filters.append(json_field.like(f"%{value}%")) + escaped_value = escape_like_pattern(str(value)) + filters.append(json_field.like(f"%{escaped_value}%", escape="\\")) case "not contains": - filters.append(json_field.notlike(f"%{value}%")) + escaped_value = escape_like_pattern(str(value)) + filters.append(json_field.notlike(f"%{escaped_value}%", escape="\\")) case "start with": - filters.append(json_field.like(f"{value}%")) + escaped_value = escape_like_pattern(str(value)) + filters.append(json_field.like(f"{escaped_value}%", escape="\\")) case "end with": - filters.append(json_field.like(f"%{value}")) + escaped_value = escape_like_pattern(str(value)) + filters.append(json_field.like(f"%{escaped_value}", escape="\\")) case "is" | "=": if isinstance(value, str): @@ -1422,6 +1431,7 @@ class DatasetRetrieval: score_threshold: float, query: str | None, attachment_id: str | None, + dataset_count: int, cancel_event: threading.Event | None = None, thread_exceptions: list[Exception] | None = None, ): @@ -1470,37 +1480,38 @@ class DatasetRetrieval: if cancel_event and cancel_event.is_set(): break - if reranking_enable: - # do rerank for searched documents - data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False) - if query: - all_documents_item = data_post_processor.invoke( - query=query, - documents=all_documents_item, - 
score_threshold=score_threshold, - top_n=top_k, - query_type=QueryType.TEXT_QUERY, - ) - if attachment_id: - all_documents_item = data_post_processor.invoke( - documents=all_documents_item, - score_threshold=score_threshold, - top_n=top_k, - query_type=QueryType.IMAGE_QUERY, - query=attachment_id, - ) - else: - if index_type == IndexTechniqueType.ECONOMY: - if not query: - all_documents_item = [] - else: - all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k) - elif index_type == IndexTechniqueType.HIGH_QUALITY: - all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold) + # Skip second reranking when there is only one dataset + if reranking_enable and dataset_count > 1: + # do rerank for searched documents + data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False) + if query: + all_documents_item = data_post_processor.invoke( + query=query, + documents=all_documents_item, + score_threshold=score_threshold, + top_n=top_k, + query_type=QueryType.TEXT_QUERY, + ) + if attachment_id: + all_documents_item = data_post_processor.invoke( + documents=all_documents_item, + score_threshold=score_threshold, + top_n=top_k, + query_type=QueryType.IMAGE_QUERY, + query=attachment_id, + ) else: - all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item - if all_documents_item: - all_documents.extend(all_documents_item) + if index_type == IndexTechniqueType.ECONOMY: + if not query: + all_documents_item = [] + else: + all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k) + elif index_type == IndexTechniqueType.HIGH_QUALITY: + all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold) + else: + all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item + if all_documents_item: + all_documents.extend(all_documents_item) except Exception as e: if cancel_event: cancel_event.set() diff --git a/api/core/schemas/registry.py b/api/core/schemas/registry.py index 51bfae1cd3..b4ecfe47ff 100644 --- a/api/core/schemas/registry.py +++ b/api/core/schemas/registry.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import json import logging import threading from collections.abc import Mapping, MutableMapping from pathlib import Path -from typing import Any, ClassVar, Optional +from typing import Any, ClassVar class SchemaRegistry: @@ -11,7 +13,7 @@ class SchemaRegistry: logger: ClassVar[logging.Logger] = logging.getLogger(__name__) - _default_instance: ClassVar[Optional["SchemaRegistry"]] = None + _default_instance: ClassVar[SchemaRegistry | None] = None _lock: ClassVar[threading.Lock] = threading.Lock() def __init__(self, base_dir: str): @@ -20,7 +22,7 @@ class SchemaRegistry: self.metadata: MutableMapping[str, MutableMapping[str, Any]] = {} @classmethod - def default_registry(cls) -> "SchemaRegistry": + def default_registry(cls) -> SchemaRegistry: """Returns the default schema registry for builtin schemas (thread-safe singleton)""" if cls._default_instance is None: with cls._lock: diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py index cdbfd027ee..24fc11aefc 100644 --- a/api/core/tools/__base/tool.py +++ b/api/core/tools/__base/tool.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from abc import ABC, abstractmethod from collections.abc import Generator from copy import deepcopy @@ -25,7 +27,7 @@ class Tool(ABC): self.entity = entity self.runtime = runtime - def 
fork_tool_runtime(self, runtime: ToolRuntime) -> "Tool": + def fork_tool_runtime(self, runtime: ToolRuntime) -> Tool: """ fork a new tool with metadata :return: the new tool @@ -221,7 +223,7 @@ class Tool(ABC): type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image) ) - def create_file_message(self, file: "File") -> ToolInvokeMessage: + def create_file_message(self, file: File) -> ToolInvokeMessage: return ToolInvokeMessage( type=ToolInvokeMessage.MessageType.FILE, message=ToolInvokeMessage.FileMessage(), diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py index 84efefba07..51b0407886 100644 --- a/api/core/tools/builtin_tool/tool.py +++ b/api/core/tools/builtin_tool/tool.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage from core.tools.__base.tool import Tool @@ -24,7 +26,7 @@ class BuiltinTool(Tool): super().__init__(**kwargs) self.provider = provider - def fork_tool_runtime(self, runtime: ToolRuntime) -> "BuiltinTool": + def fork_tool_runtime(self, runtime: ToolRuntime) -> BuiltinTool: """ fork a new tool with metadata :return: the new tool diff --git a/api/core/tools/custom_tool/provider.py b/api/core/tools/custom_tool/provider.py index 0cc992155a..e2f6c00555 100644 --- a/api/core/tools/custom_tool/provider.py +++ b/api/core/tools/custom_tool/provider.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from pydantic import Field from sqlalchemy import select @@ -32,7 +34,7 @@ class ApiToolProviderController(ToolProviderController): self.tools = [] @classmethod - def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController": + def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> ApiToolProviderController: credentials_schema = [ ProviderConfig( name="auth_type", diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 583a3584f7..b5c7a6310c 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import contextlib from collections.abc import Mapping @@ -55,7 +57,7 @@ class ToolProviderType(StrEnum): MCP = auto() @classmethod - def value_of(cls, value: str) -> "ToolProviderType": + def value_of(cls, value: str) -> ToolProviderType: """ Get value of given mode. @@ -79,7 +81,7 @@ class ApiProviderSchemaType(StrEnum): OPENAI_ACTIONS = auto() @classmethod - def value_of(cls, value: str) -> "ApiProviderSchemaType": + def value_of(cls, value: str) -> ApiProviderSchemaType: """ Get value of given mode. @@ -102,7 +104,7 @@ class ApiProviderAuthType(StrEnum): API_KEY_QUERY = auto() @classmethod - def value_of(cls, value: str) -> "ApiProviderAuthType": + def value_of(cls, value: str) -> ApiProviderAuthType: """ Get value of given mode. 
@@ -307,7 +309,7 @@ class ToolParameter(PluginParameter): typ: ToolParameterType, required: bool, options: list[str] | None = None, - ) -> "ToolParameter": + ) -> ToolParameter: """ get a simple tool parameter @@ -429,14 +431,14 @@ class ToolInvokeMeta(BaseModel): tool_config: dict | None = None @classmethod - def empty(cls) -> "ToolInvokeMeta": + def empty(cls) -> ToolInvokeMeta: """ Get an empty instance of ToolInvokeMeta """ return cls(time_cost=0.0, error=None, tool_config={}) @classmethod - def error_instance(cls, error: str) -> "ToolInvokeMeta": + def error_instance(cls, error: str) -> ToolInvokeMeta: """ Get an instance of ToolInvokeMeta with error """ diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index 96917045e3..ef9e9c103a 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import json import logging @@ -118,7 +120,7 @@ class MCPTool(Tool): for item in json_list: yield self.create_json_message(item) - def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool": + def fork_tool_runtime(self, runtime: ToolRuntime) -> MCPTool: return MCPTool( entity=self.entity, runtime=runtime, diff --git a/api/core/tools/plugin_tool/tool.py b/api/core/tools/plugin_tool/tool.py index 828dc3b810..d3a2ad488c 100644 --- a/api/core/tools/plugin_tool/tool.py +++ b/api/core/tools/plugin_tool/tool.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections.abc import Generator from typing import Any @@ -46,7 +48,7 @@ class PluginTool(Tool): message_id=message_id, ) - def fork_tool_runtime(self, runtime: ToolRuntime) -> "PluginTool": + def fork_tool_runtime(self, runtime: ToolRuntime) -> PluginTool: return PluginTool( entity=self.entity, runtime=runtime, diff --git a/api/core/tools/signature.py b/api/core/tools/signature.py index fef3157f27..22e099deba 100644 --- a/api/core/tools/signature.py +++ b/api/core/tools/signature.py @@ -7,12 +7,12 @@ import time from configs import dify_config -def sign_tool_file(tool_file_id: str, extension: str) -> str: +def sign_tool_file(tool_file_id: str, extension: str, for_external: bool = True) -> str: """ sign file to get a temporary url for plugin access """ - # Use internal URL for plugin/tool file access in Docker environments - base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL + # Use internal URL for plugin/tool file access in Docker environments, unless for_external is True + base_url = dify_config.FILES_URL if for_external else (dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL) file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}" timestamp = str(int(time.time())) diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index 3486182192..584975de05 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -378,7 +378,7 @@ class ApiBasedToolSchemaParser: @staticmethod def auto_parse_to_tool_bundle( content: str, extra_info: dict | None = None, warning: dict | None = None - ) -> tuple[list[ApiToolBundle], str]: + ) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]: """ auto parse to tool bundle diff --git a/api/core/tools/utils/text_processing_utils.py b/api/core/tools/utils/text_processing_utils.py index 0f9a91a111..4bfaa5e49b 100644 --- a/api/core/tools/utils/text_processing_utils.py +++ b/api/core/tools/utils/text_processing_utils.py @@ -4,6 +4,7 @@ import re def remove_leading_symbols(text: str) -> str: """ Remove leading 
punctuation or symbols from the given text. + Preserves markdown links like [text](url) at the start. Args: text (str): The input text to process. @@ -11,6 +12,11 @@ def remove_leading_symbols(text: str) -> str: Returns: str: The text with leading punctuation or symbols removed. """ + # Check if text starts with a markdown link - preserve it + markdown_link_pattern = r"^\[([^\]]+)\]\((https?://[^)]+)\)" + if re.match(markdown_link_pattern, text): + return text + # Match Unicode ranges for punctuation and symbols # FIXME this pattern is confused quick fix for #11868 maybe refactor it later pattern = r'^[\[\]\u2000-\u2025\u2027-\u206F\u2E00-\u2E7F\u3000-\u300F\u3011-\u303F"#$%&\'()*+,./:;<=>?@^_`~]+' diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index 2bd973f831..a706f101ca 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections.abc import Mapping from pydantic import Field @@ -47,14 +49,13 @@ class WorkflowToolProviderController(ToolProviderController): self.provider_id = provider_id @classmethod - def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController": + def from_db(cls, db_provider: WorkflowToolProvider) -> WorkflowToolProviderController: with session_factory.create_session() as session, session.begin(): app = session.get(App, db_provider.app_id) if not app: raise ValueError("app not found") user = session.get(Account, db_provider.user_id) if db_provider.user_id else None - controller = WorkflowToolProviderController( entity=ToolProviderEntity( identity=ToolProviderIdentity( @@ -67,7 +68,7 @@ class WorkflowToolProviderController(ToolProviderController): credentials_schema=[], plugin_id=None, ), - provider_id="", + provider_id=db_provider.id, ) controller.tools = [ diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 30334f5da8..81a1d54199 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import logging from collections.abc import Generator, Mapping, Sequence @@ -181,7 +183,7 @@ class WorkflowTool(Tool): return found return None - def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool": + def fork_tool_runtime(self, runtime: ToolRuntime) -> WorkflowTool: """ fork a new tool with metadata diff --git a/api/core/variables/types.py b/api/core/variables/types.py index ce71711344..13b926c978 100644 --- a/api/core/variables/types.py +++ b/api/core/variables/types.py @@ -1,6 +1,8 @@ +from __future__ import annotations + from collections.abc import Mapping from enum import StrEnum -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from core.file.models import File @@ -52,7 +54,7 @@ class SegmentType(StrEnum): return self in _ARRAY_TYPES @classmethod - def infer_segment_type(cls, value: Any) -> Optional["SegmentType"]: + def infer_segment_type(cls, value: Any) -> SegmentType | None: """ Attempt to infer the `SegmentType` based on the Python type of the `value` parameter. @@ -173,7 +175,7 @@ class SegmentType(StrEnum): raise AssertionError("this statement should be unreachable.") @staticmethod - def cast_value(value: Any, type_: "SegmentType"): + def cast_value(value: Any, type_: SegmentType): # Cast Python's `bool` type to `int` when the runtime type requires # an integer or number. 
# @@ -193,7 +195,7 @@ class SegmentType(StrEnum): return [int(i) for i in value] return value - def exposed_type(self) -> "SegmentType": + def exposed_type(self) -> SegmentType: """Returns the type exposed to the frontend. The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so these are returned as `NUMBER` here. @@ -202,7 +204,7 @@ class SegmentType(StrEnum): return SegmentType.NUMBER return self - def element_type(self) -> "SegmentType | None": + def element_type(self) -> SegmentType | None: """Return the element type of the current segment type, or `None` if the element type is undefined. Raises: @@ -217,7 +219,7 @@ class SegmentType(StrEnum): return _ARRAY_ELEMENT_TYPES_MAPPING.get(self) @staticmethod - def get_zero_value(t: "SegmentType"): + def get_zero_value(t: SegmentType): # Lazy import to avoid circular dependency from factories import variable_factory diff --git a/api/core/workflow/README.md b/api/core/workflow/README.md index 72f5dbe1e2..9a39f976a6 100644 --- a/api/core/workflow/README.md +++ b/api/core/workflow/README.md @@ -64,6 +64,9 @@ engine.layer(DebugLoggingLayer(level="INFO")) engine.layer(ExecutionLimitsLayer(max_nodes=100)) ``` +`engine.layer()` binds the read-only runtime state before execution, so layer hooks +can assume `graph_runtime_state` is available. + ### Event-Driven Architecture All node executions emit events for monitoring and integration: diff --git a/api/core/workflow/entities/workflow_execution.py b/api/core/workflow/entities/workflow_execution.py index a8a86d3db2..1b3fb36f1f 100644 --- a/api/core/workflow/entities/workflow_execution.py +++ b/api/core/workflow/entities/workflow_execution.py @@ -5,6 +5,8 @@ Models are independent of the storage mechanism and don't contain implementation details like tenant_id, app_id, etc. """ +from __future__ import annotations + from collections.abc import Mapping from datetime import datetime from typing import Any @@ -59,7 +61,7 @@ class WorkflowExecution(BaseModel): graph: Mapping[str, Any], inputs: Mapping[str, Any], started_at: datetime, - ) -> "WorkflowExecution": + ) -> WorkflowExecution: return WorkflowExecution( id_=id_, workflow_id=workflow_id, diff --git a/api/core/workflow/graph/graph.py b/api/core/workflow/graph/graph.py index ba5a01fc94..7be94c2426 100644 --- a/api/core/workflow/graph/graph.py +++ b/api/core/workflow/graph/graph.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging from collections import defaultdict from collections.abc import Mapping, Sequence @@ -175,7 +177,7 @@ class Graph: def _create_node_instances( cls, node_configs_map: dict[str, dict[str, object]], - node_factory: "NodeFactory", + node_factory: NodeFactory, ) -> dict[str, Node]: """ Create node instances from configurations using the node factory. @@ -197,7 +199,7 @@ class Graph: return nodes @classmethod - def new(cls) -> "GraphBuilder": + def new(cls) -> GraphBuilder: """Create a fluent builder for assembling a graph programmatically.""" return GraphBuilder(graph_cls=cls) @@ -284,9 +286,9 @@ class Graph: cls, *, graph_config: Mapping[str, object], - node_factory: "NodeFactory", + node_factory: NodeFactory, root_node_id: str | None = None, - ) -> "Graph": + ) -> Graph: """ Initialize graph @@ -383,7 +385,7 @@ class GraphBuilder: self._edges: list[Edge] = [] self._edge_counter = 0 - def add_root(self, node: Node) -> "GraphBuilder": + def add_root(self, node: Node) -> GraphBuilder: """Register the root node. 
Must be called exactly once.""" if self._nodes: @@ -398,7 +400,7 @@ class GraphBuilder: *, from_node_id: str | None = None, source_handle: str = "source", - ) -> "GraphBuilder": + ) -> GraphBuilder: """Append a node and connect it from the specified predecessor.""" if not self._nodes: @@ -419,7 +421,7 @@ class GraphBuilder: return self - def connect(self, *, tail: str, head: str, source_handle: str = "source") -> "GraphBuilder": + def connect(self, *, tail: str, head: str, source_handle: str = "source") -> GraphBuilder: """Connect two existing nodes without adding a new node.""" if tail not in self._nodes_by_id: diff --git a/api/core/workflow/graph_engine/command_channels/redis_channel.py b/api/core/workflow/graph_engine/command_channels/redis_channel.py index 4be3adb8f8..0fccd4a0fd 100644 --- a/api/core/workflow/graph_engine/command_channels/redis_channel.py +++ b/api/core/workflow/graph_engine/command_channels/redis_channel.py @@ -9,7 +9,7 @@ Each instance uses a unique key for its command queue. import json from typing import TYPE_CHECKING, Any, final -from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand +from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand if TYPE_CHECKING: from extensions.ext_redis import RedisClientWrapper @@ -113,6 +113,8 @@ class RedisChannel: return AbortCommand.model_validate(data) if command_type == CommandType.PAUSE: return PauseCommand.model_validate(data) + if command_type == CommandType.UPDATE_VARIABLES: + return UpdateVariablesCommand.model_validate(data) # For other command types, use base class return GraphEngineCommand.model_validate(data) diff --git a/api/core/workflow/graph_engine/command_processing/__init__.py b/api/core/workflow/graph_engine/command_processing/__init__.py index 837f5e55fd..7b4f0dfff7 100644 --- a/api/core/workflow/graph_engine/command_processing/__init__.py +++ b/api/core/workflow/graph_engine/command_processing/__init__.py @@ -5,11 +5,12 @@ This package handles external commands sent to the engine during execution. 
""" -from .command_handlers import AbortCommandHandler, PauseCommandHandler +from .command_handlers import AbortCommandHandler, PauseCommandHandler, UpdateVariablesCommandHandler from .command_processor import CommandProcessor __all__ = [ "AbortCommandHandler", "CommandProcessor", "PauseCommandHandler", + "UpdateVariablesCommandHandler", ] diff --git a/api/core/workflow/graph_engine/command_processing/command_handlers.py b/api/core/workflow/graph_engine/command_processing/command_handlers.py index e9f109c88c..cfe856d9e8 100644 --- a/api/core/workflow/graph_engine/command_processing/command_handlers.py +++ b/api/core/workflow/graph_engine/command_processing/command_handlers.py @@ -4,9 +4,10 @@ from typing import final from typing_extensions import override from core.workflow.entities.pause_reason import SchedulingPause +from core.workflow.runtime import VariablePool from ..domain.graph_execution import GraphExecution -from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand +from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand, UpdateVariablesCommand from .command_processor import CommandHandler logger = logging.getLogger(__name__) @@ -31,3 +32,25 @@ class PauseCommandHandler(CommandHandler): reason = command.reason pause_reason = SchedulingPause(message=reason) execution.pause(pause_reason) + + +@final +class UpdateVariablesCommandHandler(CommandHandler): + def __init__(self, variable_pool: VariablePool) -> None: + self._variable_pool = variable_pool + + @override + def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: + assert isinstance(command, UpdateVariablesCommand) + for update in command.updates: + try: + variable = update.value + self._variable_pool.add(variable.selector, variable) + logger.debug("Updated variable %s for workflow %s", variable.selector, execution.workflow_id) + except ValueError as exc: + logger.warning( + "Skipping invalid variable selector %s for workflow %s: %s", + getattr(update.value, "selector", None), + execution.workflow_id, + exc, + ) diff --git a/api/core/workflow/graph_engine/entities/commands.py b/api/core/workflow/graph_engine/entities/commands.py index 0d51b2b716..6dce03c94d 100644 --- a/api/core/workflow/graph_engine/entities/commands.py +++ b/api/core/workflow/graph_engine/entities/commands.py @@ -5,17 +5,21 @@ This module defines command types that can be sent to a running GraphEngine instance to control its execution flow. 
""" -from enum import StrEnum +from collections.abc import Sequence +from enum import StrEnum, auto from typing import Any from pydantic import BaseModel, Field +from core.variables.variables import VariableUnion + class CommandType(StrEnum): """Types of commands that can be sent to GraphEngine.""" - ABORT = "abort" - PAUSE = "pause" + ABORT = auto() + PAUSE = auto() + UPDATE_VARIABLES = auto() class GraphEngineCommand(BaseModel): @@ -37,3 +41,16 @@ class PauseCommand(GraphEngineCommand): command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command") reason: str = Field(default="unknown reason", description="reason for pause") + + +class VariableUpdate(BaseModel): + """Represents a single variable update instruction.""" + + value: VariableUnion = Field(description="New variable value") + + +class UpdateVariablesCommand(GraphEngineCommand): + """Command to update a group of variables in the variable pool.""" + + command_type: CommandType = Field(default=CommandType.UPDATE_VARIABLES, description="Type of command") + updates: Sequence[VariableUpdate] = Field(default_factory=list, description="Variable updates") diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 2e8b8f345f..9a870d7bf5 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -5,9 +5,12 @@ This engine uses a modular architecture with separated packages following Domain-Driven Design principles for improved maintainability and testability. """ +from __future__ import annotations + import contextvars import logging import queue +import threading from collections.abc import Generator from typing import TYPE_CHECKING, cast, final @@ -30,8 +33,13 @@ from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWr if TYPE_CHECKING: # pragma: no cover - used only for static analysis from core.workflow.runtime.graph_runtime_state import GraphProtocol -from .command_processing import AbortCommandHandler, CommandProcessor, PauseCommandHandler -from .entities.commands import AbortCommand, PauseCommand +from .command_processing import ( + AbortCommandHandler, + CommandProcessor, + PauseCommandHandler, + UpdateVariablesCommandHandler, +) +from .entities.commands import AbortCommand, PauseCommand, UpdateVariablesCommand from .error_handler import ErrorHandler from .event_management import EventHandler, EventManager from .graph_state_manager import GraphStateManager @@ -70,10 +78,13 @@ class GraphEngine: scale_down_idle_time: float | None = None, ) -> None: """Initialize the graph engine with all subsystems and dependencies.""" + # stop event + self._stop_event = threading.Event() # Bind runtime state to current workflow context self._graph = graph self._graph_runtime_state = graph_runtime_state + self._graph_runtime_state.stop_event = self._stop_event self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph)) self._command_channel = command_channel @@ -140,6 +151,9 @@ class GraphEngine: pause_handler = PauseCommandHandler() self._command_processor.register_handler(PauseCommand, pause_handler) + update_variables_handler = UpdateVariablesCommandHandler(self._graph_runtime_state.variable_pool) + self._command_processor.register_handler(UpdateVariablesCommand, update_variables_handler) + # === Extensibility === # Layers allow plugins to extend engine functionality self._layers: list[GraphEngineLayer] = [] @@ -169,6 +183,7 @@ class GraphEngine: 
max_workers=self._max_workers, scale_up_threshold=self._scale_up_threshold, scale_down_idle_time=self._scale_down_idle_time, + stop_event=self._stop_event, ) # === Orchestration === @@ -199,6 +214,7 @@ class GraphEngine: event_handler=self._event_handler_registry, execution_coordinator=self._execution_coordinator, event_emitter=self._event_manager, + stop_event=self._stop_event, ) # === Validation === @@ -212,9 +228,16 @@ class GraphEngine: if id(node.graph_runtime_state) != expected_state_id: raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance") - def layer(self, layer: GraphEngineLayer) -> "GraphEngine": + def _bind_layer_context( + self, + layer: GraphEngineLayer, + ) -> None: + layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel) + + def layer(self, layer: GraphEngineLayer) -> GraphEngine: """Add a layer for extending functionality.""" self._layers.append(layer) + self._bind_layer_context(layer) return self def run(self) -> Generator[GraphEngineEvent, None, None]: @@ -301,14 +324,7 @@ class GraphEngine: def _initialize_layers(self) -> None: """Initialize layers with context.""" self._event_manager.set_layers(self._layers) - # Create a read-only wrapper for the runtime state - read_only_state = ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state) for layer in self._layers: - try: - layer.initialize(read_only_state, self._command_channel) - except Exception as e: - logger.warning("Failed to initialize layer %s: %s", layer.__class__.__name__, e) - try: layer.on_graph_start() except Exception as e: @@ -316,6 +332,7 @@ class GraphEngine: def _start_execution(self, *, resume: bool = False) -> None: """Start execution subsystems.""" + self._stop_event.clear() paused_nodes: list[str] = [] if resume: paused_nodes = self._graph_runtime_state.consume_paused_nodes() @@ -343,13 +360,12 @@ class GraphEngine: def _stop_execution(self) -> None: """Stop execution subsystems.""" + self._stop_event.set() self._dispatcher.stop() self._worker_pool.stop() # Don't mark complete here as the dispatcher already does it # Notify layers - logger = logging.getLogger(__name__) - for layer in self._layers: try: layer.on_graph_end(self._graph_execution.error) diff --git a/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py b/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py index 78f8ecdcdf..b9c9243963 100644 --- a/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py +++ b/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py @@ -60,6 +60,7 @@ class SkipPropagator: if edge_states["has_taken"]: # Enqueue node self._state_manager.enqueue_node(downstream_node_id) + self._state_manager.start_execution(downstream_node_id) return # All edges are skipped, propagate skip to this node diff --git a/api/core/workflow/graph_engine/layers/README.md b/api/core/workflow/graph_engine/layers/README.md index 17845ee1f0..b0f295037c 100644 --- a/api/core/workflow/graph_engine/layers/README.md +++ b/api/core/workflow/graph_engine/layers/README.md @@ -8,7 +8,7 @@ Pluggable middleware for engine extensions. Abstract base class for layers. 
-- `initialize()` - Receive runtime context +- `initialize()` - Receive runtime context (runtime state is bound here and always available to hooks) - `on_graph_start()` - Execution start hook - `on_event()` - Process all events - `on_graph_end()` - Execution end hook @@ -34,6 +34,9 @@ engine.layer(debug_layer) engine.run() ``` +`engine.layer()` binds the read-only runtime state before execution, so +`graph_runtime_state` is always available inside layer hooks. + ## Custom Layers ```python diff --git a/api/core/workflow/graph_engine/layers/base.py b/api/core/workflow/graph_engine/layers/base.py index 780f92a0f4..89293b9b30 100644 --- a/api/core/workflow/graph_engine/layers/base.py +++ b/api/core/workflow/graph_engine/layers/base.py @@ -13,6 +13,14 @@ from core.workflow.nodes.base.node import Node from core.workflow.runtime import ReadOnlyGraphRuntimeState +class GraphEngineLayerNotInitializedError(Exception): + """Raised when a layer's runtime state is accessed before initialization.""" + + def __init__(self, layer_name: str | None = None) -> None: + name = layer_name or "GraphEngineLayer" + super().__init__(f"{name} runtime state is not initialized. Bind the layer to a GraphEngine before access.") + + class GraphEngineLayer(ABC): """ Abstract base class for GraphEngine layers. @@ -28,22 +36,27 @@ class GraphEngineLayer(ABC): def __init__(self) -> None: """Initialize the layer. Subclasses can override with custom parameters.""" - self.graph_runtime_state: ReadOnlyGraphRuntimeState | None = None + self._graph_runtime_state: ReadOnlyGraphRuntimeState | None = None self.command_channel: CommandChannel | None = None + @property + def graph_runtime_state(self) -> ReadOnlyGraphRuntimeState: + if self._graph_runtime_state is None: + raise GraphEngineLayerNotInitializedError(type(self).__name__) + return self._graph_runtime_state + def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None: """ Initialize the layer with engine dependencies. - Called by GraphEngine before execution starts to inject the read-only runtime state - and command channel. This allows layers to observe engine context and send - commands, but prevents direct state modification. - + Called by GraphEngine to inject the read-only runtime state and command channel. + This is invoked when the layer is registered with a `GraphEngine` instance. + Implementations should be idempotent. 
Args: graph_runtime_state: Read-only view of the runtime state command_channel: Channel for sending commands to the engine """ - self.graph_runtime_state = graph_runtime_state + self._graph_runtime_state = graph_runtime_state self.command_channel = command_channel @abstractmethod diff --git a/api/core/workflow/graph_engine/layers/debug_logging.py b/api/core/workflow/graph_engine/layers/debug_logging.py index 034ebcf54f..e0402cd09c 100644 --- a/api/core/workflow/graph_engine/layers/debug_logging.py +++ b/api/core/workflow/graph_engine/layers/debug_logging.py @@ -109,10 +109,8 @@ class DebugLoggingLayer(GraphEngineLayer): self.logger.info("=" * 80) self.logger.info("🚀 GRAPH EXECUTION STARTED") self.logger.info("=" * 80) - - if self.graph_runtime_state: - # Log initial state - self.logger.info("Initial State:") + # Log initial state + self.logger.info("Initial State:") @override def on_event(self, event: GraphEngineEvent) -> None: @@ -243,8 +241,7 @@ class DebugLoggingLayer(GraphEngineLayer): self.logger.info(" Node retries: %s", self.retry_count) # Log final state if available - if self.graph_runtime_state and self.include_outputs: - if self.graph_runtime_state.outputs: - self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs)) + if self.include_outputs and self.graph_runtime_state.outputs: + self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs)) self.logger.info("=" * 80) diff --git a/api/core/workflow/graph_engine/layers/persistence.py b/api/core/workflow/graph_engine/layers/persistence.py index b70f36ec9e..e81df4f3b7 100644 --- a/api/core/workflow/graph_engine/layers/persistence.py +++ b/api/core/workflow/graph_engine/layers/persistence.py @@ -337,8 +337,6 @@ class WorkflowPersistenceLayer(GraphEngineLayer): if update_finished: execution.finished_at = naive_utc_now() runtime_state = self.graph_runtime_state - if runtime_state is None: - return execution.total_tokens = runtime_state.total_tokens execution.total_steps = runtime_state.node_run_steps execution.outputs = execution.outputs or runtime_state.outputs @@ -404,6 +402,4 @@ class WorkflowPersistenceLayer(GraphEngineLayer): def _system_variables(self) -> Mapping[str, Any]: runtime_state = self.graph_runtime_state - if runtime_state is None: - return {} return runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID) diff --git a/api/core/workflow/graph_engine/manager.py b/api/core/workflow/graph_engine/manager.py index 0577ba8f02..d2cfa755d9 100644 --- a/api/core/workflow/graph_engine/manager.py +++ b/api/core/workflow/graph_engine/manager.py @@ -3,14 +3,20 @@ GraphEngine Manager for sending control commands via Redis channel. This module provides a simplified interface for controlling workflow executions using the new Redis command channel, without requiring user permission checks. -Supports stop, pause, and resume operations. 
""" import logging +from collections.abc import Sequence from typing import final from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel -from core.workflow.graph_engine.entities.commands import AbortCommand, GraphEngineCommand, PauseCommand +from core.workflow.graph_engine.entities.commands import ( + AbortCommand, + GraphEngineCommand, + PauseCommand, + UpdateVariablesCommand, + VariableUpdate, +) from extensions.ext_redis import redis_client logger = logging.getLogger(__name__) @@ -23,7 +29,6 @@ class GraphEngineManager: This class provides a simple interface for controlling workflow executions by sending commands through Redis channels, without user validation. - Supports stop and pause operations. """ @staticmethod @@ -45,6 +50,16 @@ class GraphEngineManager: pause_command = PauseCommand(reason=reason or "User requested pause") GraphEngineManager._send_command(task_id, pause_command) + @staticmethod + def send_update_variables_command(task_id: str, updates: Sequence[VariableUpdate]) -> None: + """Send a command to update variables in a running workflow.""" + + if not updates: + return + + update_command = UpdateVariablesCommand(updates=updates) + GraphEngineManager._send_command(task_id, update_command) + @staticmethod def _send_command(task_id: str, command: GraphEngineCommand) -> None: """Send a command to the workflow-specific Redis channel.""" diff --git a/api/core/workflow/graph_engine/orchestration/dispatcher.py b/api/core/workflow/graph_engine/orchestration/dispatcher.py index 334a3f77bf..27439a2412 100644 --- a/api/core/workflow/graph_engine/orchestration/dispatcher.py +++ b/api/core/workflow/graph_engine/orchestration/dispatcher.py @@ -44,6 +44,7 @@ class Dispatcher: event_queue: queue.Queue[GraphNodeEventBase], event_handler: "EventHandler", execution_coordinator: ExecutionCoordinator, + stop_event: threading.Event, event_emitter: EventManager | None = None, ) -> None: """ @@ -61,7 +62,7 @@ class Dispatcher: self._event_emitter = event_emitter self._thread: threading.Thread | None = None - self._stop_event = threading.Event() + self._stop_event = stop_event self._start_time: float | None = None def start(self) -> None: @@ -69,16 +70,14 @@ class Dispatcher: if self._thread and self._thread.is_alive(): return - self._stop_event.clear() self._start_time = time.time() self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True) self._thread.start() def stop(self) -> None: """Stop the dispatcher thread.""" - self._stop_event.set() if self._thread and self._thread.is_alive(): - self._thread.join(timeout=10.0) + self._thread.join(timeout=2.0) def _dispatcher_loop(self) -> None: """Main dispatcher loop.""" diff --git a/api/core/workflow/graph_engine/ready_queue/factory.py b/api/core/workflow/graph_engine/ready_queue/factory.py index 1144e1de69..a9d4f470e5 100644 --- a/api/core/workflow/graph_engine/ready_queue/factory.py +++ b/api/core/workflow/graph_engine/ready_queue/factory.py @@ -2,6 +2,8 @@ Factory for creating ReadyQueue instances from serialized state. """ +from __future__ import annotations + from typing import TYPE_CHECKING from .in_memory import InMemoryReadyQueue @@ -11,7 +13,7 @@ if TYPE_CHECKING: from .protocol import ReadyQueue -def create_ready_queue_from_state(state: ReadyQueueState) -> "ReadyQueue": +def create_ready_queue_from_state(state: ReadyQueueState) -> ReadyQueue: """ Create a ReadyQueue instance from a serialized state. 
diff --git a/api/core/workflow/graph_engine/response_coordinator/session.py b/api/core/workflow/graph_engine/response_coordinator/session.py index 8b7c2e441e..8ceaa428c3 100644 --- a/api/core/workflow/graph_engine/response_coordinator/session.py +++ b/api/core/workflow/graph_engine/response_coordinator/session.py @@ -5,6 +5,8 @@ This module contains the private ResponseSession class used internally by ResponseStreamCoordinator to manage streaming sessions. """ +from __future__ import annotations + from dataclasses import dataclass from core.workflow.nodes.answer.answer_node import AnswerNode @@ -27,7 +29,7 @@ class ResponseSession: index: int = 0 # Current position in the template segments @classmethod - def from_node(cls, node: Node) -> "ResponseSession": + def from_node(cls, node: Node) -> ResponseSession: """ Create a ResponseSession from an AnswerNode or EndNode. diff --git a/api/core/workflow/graph_engine/worker.py b/api/core/workflow/graph_engine/worker.py index e37a08ae47..83419830b6 100644 --- a/api/core/workflow/graph_engine/worker.py +++ b/api/core/workflow/graph_engine/worker.py @@ -42,6 +42,7 @@ class Worker(threading.Thread): event_queue: queue.Queue[GraphNodeEventBase], graph: Graph, layers: Sequence[GraphEngineLayer], + stop_event: threading.Event, worker_id: int = 0, flask_app: Flask | None = None, context_vars: contextvars.Context | None = None, @@ -65,13 +66,16 @@ class Worker(threading.Thread): self._worker_id = worker_id self._flask_app = flask_app self._context_vars = context_vars - self._stop_event = threading.Event() self._last_task_time = time.time() + self._stop_event = stop_event self._layers = layers if layers is not None else [] def stop(self) -> None: - """Signal the worker to stop processing.""" - self._stop_event.set() + """Worker is controlled via shared stop_event from GraphEngine. + + This method is a no-op retained for backward compatibility. 
+ """ + pass @property def is_idle(self) -> bool: diff --git a/api/core/workflow/graph_engine/worker_management/worker_pool.py b/api/core/workflow/graph_engine/worker_management/worker_pool.py index 5b9234586b..df76ebe882 100644 --- a/api/core/workflow/graph_engine/worker_management/worker_pool.py +++ b/api/core/workflow/graph_engine/worker_management/worker_pool.py @@ -41,6 +41,7 @@ class WorkerPool: event_queue: queue.Queue[GraphNodeEventBase], graph: Graph, layers: list[GraphEngineLayer], + stop_event: threading.Event, flask_app: "Flask | None" = None, context_vars: "Context | None" = None, min_workers: int | None = None, @@ -81,6 +82,7 @@ class WorkerPool: self._worker_counter = 0 self._lock = threading.RLock() self._running = False + self._stop_event = stop_event # No longer tracking worker states with callbacks to avoid lock contention @@ -135,7 +137,7 @@ class WorkerPool: # Wait for workers to finish for worker in self._workers: if worker.is_alive(): - worker.join(timeout=10.0) + worker.join(timeout=2.0) self._workers.clear() @@ -152,6 +154,7 @@ class WorkerPool: worker_id=worker_id, flask_app=self._flask_app, context_vars=self._context_vars, + stop_event=self._stop_event, ) worker.start() diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 4be006de11..234651ce96 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, cast @@ -167,7 +169,7 @@ class AgentNode(Node[AgentNodeData]): variable_pool: VariablePool, node_data: AgentNodeData, for_log: bool = False, - strategy: "PluginAgentStrategy", + strategy: PluginAgentStrategy, ) -> dict[str, Any]: """ Generate parameters based on the given tool parameters, variable pool, and node data. @@ -328,7 +330,7 @@ class AgentNode(Node[AgentNodeData]): def _generate_credentials( self, parameters: dict[str, Any], - ) -> "InvokeCredentials": + ) -> InvokeCredentials: """ Generate credentials based on the given agent parameters. """ @@ -442,9 +444,7 @@ class AgentNode(Node[AgentNodeData]): model_schema.features.remove(feature) return model_schema - def _filter_mcp_type_tool( - self, strategy: "PluginAgentStrategy", tools: list[dict[str, Any]] - ) -> list[dict[str, Any]]: + def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]: """ Filter MCP type tool :param strategy: plugin agent strategy diff --git a/api/core/workflow/nodes/agent/exc.py b/api/core/workflow/nodes/agent/exc.py index 944f5f0b20..ba2c83d8a6 100644 --- a/api/core/workflow/nodes/agent/exc.py +++ b/api/core/workflow/nodes/agent/exc.py @@ -119,3 +119,14 @@ class AgentVariableTypeError(AgentNodeError): self.expected_type = expected_type self.actual_type = actual_type super().__init__(message) + + +class AgentMaxIterationError(AgentNodeError): + """Exception raised when the agent exceeds the maximum iteration limit.""" + + def __init__(self, max_iteration: int): + self.max_iteration = max_iteration + super().__init__( + f"Agent exceeded the maximum iteration limit of {max_iteration}. " + f"The agent was unable to complete the task within the allowed number of iterations." 
+ ) diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index 5aab6bbde4..e5a20c8e91 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json from abc import ABC from builtins import type as type_ @@ -111,7 +113,7 @@ class DefaultValue(BaseModel): raise DefaultValueTypeError(f"Cannot convert to number: {value}") @model_validator(mode="after") - def validate_value_type(self) -> "DefaultValue": + def validate_value_type(self) -> DefaultValue: # Type validation configuration type_validators = { DefaultValueType.STRING: { diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 448e07e78c..06e4c0440d 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import importlib import logging import operator @@ -62,7 +64,7 @@ logger = logging.getLogger(__name__) class Node(Generic[NodeDataT]): - node_type: ClassVar["NodeType"] + node_type: ClassVar[NodeType] execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE _node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData @@ -201,14 +203,14 @@ class Node(Generic[NodeDataT]): return None # Global registry populated via __init_subclass__ - _registry: ClassVar[dict["NodeType", dict[str, type["Node"]]]] = {} + _registry: ClassVar[dict[NodeType, dict[str, type[Node]]]] = {} def __init__( self, id: str, config: Mapping[str, Any], - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, ) -> None: self._graph_init_params = graph_init_params self.id = id @@ -244,7 +246,7 @@ class Node(Generic[NodeDataT]): return @property - def graph_init_params(self) -> "GraphInitParams": + def graph_init_params(self) -> GraphInitParams: return self._graph_init_params @property @@ -267,6 +269,10 @@ class Node(Generic[NodeDataT]): """ raise NotImplementedError + def _should_stop(self) -> bool: + """Check if execution should be stopped.""" + return self.graph_runtime_state.stop_event.is_set() + def run(self) -> Generator[GraphNodeEventBase, None, None]: execution_id = self.ensure_execution_id() self._start_at = naive_utc_now() @@ -335,6 +341,21 @@ class Node(Generic[NodeDataT]): yield event else: yield event + + if self._should_stop(): + error_message = "Execution cancelled" + yield NodeRunFailedEvent( + id=self.execution_id, + node_id=self._node_id, + node_type=self.node_type, + start_at=self._start_at, + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=error_message, + ), + error=error_message, + ) + return except Exception as e: logger.exception("Node %s failed to run", self._node_id) result = NodeRunResult( @@ -441,7 +462,7 @@ class Node(Generic[NodeDataT]): raise NotImplementedError("subclasses of BaseNode must implement `version` method.") @classmethod - def get_node_type_classes_mapping(cls) -> Mapping["NodeType", Mapping[str, type["Node"]]]: + def get_node_type_classes_mapping(cls) -> Mapping[NodeType, Mapping[str, type[Node]]]: """Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry. Import all modules under core.workflow.nodes so subclasses register themselves on import. 
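The `Node.run` hunk above adds `_should_stop()` and yields a terminal `NodeRunFailedEvent` with "Execution cancelled" once the shared stop event is set. Below is a rough sketch of that cooperative-cancellation pattern in a generator-based run loop; the names are illustrative, and the exact placement of the check relative to the event loop is an interpretation of the diff rather than the actual implementation.

```python
import threading
from collections.abc import Generator, Iterable


def run_node(events: Iterable[str], stop_event: threading.Event) -> Generator[str, None, None]:
    """Yield node events, but finish with a 'cancelled' failure event when asked to stop."""
    for event in events:
        yield event
        if stop_event.is_set():
            # Mirrors the diff: cancellation surfaces as a failed node result, then the run ends.
            yield "node_run_failed: Execution cancelled"
            return
    yield "node_run_succeeded"


stop = threading.Event()
gen = run_node(["chunk-1", "chunk-2", "chunk-3"], stop)
print(next(gen))   # chunk-1
stop.set()         # e.g. the engine's _stop_execution() sets the shared event
print(list(gen))   # ['node_run_failed: Execution cancelled']
```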
diff --git a/api/core/workflow/nodes/base/template.py b/api/core/workflow/nodes/base/template.py index ba3e2058cf..81f4b9f6fb 100644 --- a/api/core/workflow/nodes/base/template.py +++ b/api/core/workflow/nodes/base/template.py @@ -4,6 +4,8 @@ This module provides a unified template structure for both Answer and End nodes, similar to SegmentGroup but focused on template representation without values. """ +from __future__ import annotations + from abc import ABC, abstractmethod from collections.abc import Sequence from dataclasses import dataclass @@ -58,7 +60,7 @@ class Template: segments: list[TemplateSegmentUnion] @classmethod - def from_answer_template(cls, template_str: str) -> "Template": + def from_answer_template(cls, template_str: str) -> Template: """Create a Template from an Answer node template string. Example: @@ -107,7 +109,7 @@ class Template: return cls(segments=segments) @classmethod - def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> "Template": + def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> Template: """Create a Template from an End node outputs configuration. End nodes are treated as templates of concatenated variables with newlines. diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index a38e10030a..e3035d3bf0 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,8 +1,7 @@ from collections.abc import Mapping, Sequence from decimal import Decimal -from typing import Any, cast +from typing import TYPE_CHECKING, Any, ClassVar, cast -from configs import dify_config from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage from core.helper.code_executor.code_node_provider import CodeNodeProvider from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider @@ -13,6 +12,7 @@ from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base.node import Node from core.workflow.nodes.code.entities import CodeNodeData +from core.workflow.nodes.code.limits import CodeNodeLimits from .exc import ( CodeNodeError, @@ -20,9 +20,41 @@ from .exc import ( OutputValidationError, ) +if TYPE_CHECKING: + from core.workflow.entities import GraphInitParams + from core.workflow.runtime import GraphRuntimeState + class CodeNode(Node[CodeNodeData]): node_type = NodeType.CODE + _DEFAULT_CODE_PROVIDERS: ClassVar[tuple[type[CodeNodeProvider], ...]] = ( + Python3CodeProvider, + JavascriptCodeProvider, + ) + _limits: CodeNodeLimits + + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph_runtime_state: "GraphRuntimeState", + *, + code_executor: type[CodeExecutor] | None = None, + code_providers: Sequence[type[CodeNodeProvider]] | None = None, + code_limits: CodeNodeLimits, + ) -> None: + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor + self._code_providers: tuple[type[CodeNodeProvider], ...] 
= ( + tuple(code_providers) if code_providers else self._DEFAULT_CODE_PROVIDERS + ) + self._limits = code_limits @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: @@ -35,11 +67,16 @@ class CodeNode(Node[CodeNodeData]): if filters: code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3)) - providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider] - code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language)) + code_provider: type[CodeNodeProvider] = next( + provider for provider in cls._DEFAULT_CODE_PROVIDERS if provider.is_accept_language(code_language) + ) return code_provider.get_default_config() + @classmethod + def default_code_providers(cls) -> tuple[type[CodeNodeProvider], ...]: + return cls._DEFAULT_CODE_PROVIDERS + @classmethod def version(cls) -> str: return "1" @@ -60,7 +97,8 @@ class CodeNode(Node[CodeNodeData]): variables[variable_name] = variable.to_object() if variable else None # Run code try: - result = CodeExecutor.execute_workflow_code_template( + _ = self._select_code_provider(code_language) + result = self._code_executor.execute_workflow_code_template( language=code_language, code=code, inputs=variables, @@ -75,6 +113,12 @@ class CodeNode(Node[CodeNodeData]): return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result) + def _select_code_provider(self, code_language: CodeLanguage) -> type[CodeNodeProvider]: + for provider in self._code_providers: + if provider.is_accept_language(code_language): + return provider + raise CodeNodeError(f"Unsupported code language: {code_language}") + def _check_string(self, value: str | None, variable: str) -> str | None: """ Check string @@ -85,10 +129,10 @@ class CodeNode(Node[CodeNodeData]): if value is None: return None - if len(value) > dify_config.CODE_MAX_STRING_LENGTH: + if len(value) > self._limits.max_string_length: raise OutputValidationError( f"The length of output variable `{variable}` must be" - f" less than {dify_config.CODE_MAX_STRING_LENGTH} characters" + f" less than {self._limits.max_string_length} characters" ) return value.replace("\x00", "") @@ -109,20 +153,20 @@ class CodeNode(Node[CodeNodeData]): if value is None: return None - if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER: + if value > self._limits.max_number or value < self._limits.min_number: raise OutputValidationError( f"Output variable `{variable}` is out of range," - f" it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}." + f" it must be between {self._limits.min_number} and {self._limits.max_number}." ) if isinstance(value, float): decimal_value = Decimal(str(value)).normalize() precision = -decimal_value.as_tuple().exponent if decimal_value.as_tuple().exponent < 0 else 0 # type: ignore[operator] # raise error if precision is too high - if precision > dify_config.CODE_MAX_PRECISION: + if precision > self._limits.max_precision: raise OutputValidationError( f"Output variable `{variable}` has too high precision," - f" it must be less than {dify_config.CODE_MAX_PRECISION} digits." + f" it must be less than {self._limits.max_precision} digits." ) return value @@ -137,8 +181,8 @@ class CodeNode(Node[CodeNodeData]): # TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes. 
# Note that `_transform_result` may produce lists containing `None` values, # which don't conform to the type requirements of `Array*Segment` classes. - if depth > dify_config.CODE_MAX_DEPTH: - raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.") + if depth > self._limits.max_depth: + raise DepthLimitError(f"Depth limit {self._limits.max_depth} reached, object too deep.") transformed_result: dict[str, Any] = {} if output_schema is None: @@ -272,10 +316,10 @@ class CodeNode(Node[CodeNodeData]): f"Output {prefix}{dot}{output_name} is not an array, got {type(value)} instead." ) else: - if len(value) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH: + if len(value) > self._limits.max_number_array_length: raise OutputValidationError( f"The length of output variable `{prefix}{dot}{output_name}` must be" - f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements." + f" less than {self._limits.max_number_array_length} elements." ) for i, inner_value in enumerate(value): @@ -305,10 +349,10 @@ class CodeNode(Node[CodeNodeData]): f" got {type(result.get(output_name))} instead." ) else: - if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH: + if len(result[output_name]) > self._limits.max_string_array_length: raise OutputValidationError( f"The length of output variable `{prefix}{dot}{output_name}` must be" - f" less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements." + f" less than {self._limits.max_string_array_length} elements." ) transformed_result[output_name] = [ @@ -326,10 +370,10 @@ class CodeNode(Node[CodeNodeData]): f" got {type(result.get(output_name))} instead." ) else: - if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH: + if len(result[output_name]) > self._limits.max_object_array_length: raise OutputValidationError( f"The length of output variable `{prefix}{dot}{output_name}` must be" - f" less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements." + f" less than {self._limits.max_object_array_length} elements." ) for i, value in enumerate(result[output_name]): diff --git a/api/core/workflow/nodes/code/limits.py b/api/core/workflow/nodes/code/limits.py new file mode 100644 index 0000000000..a6b9e9e68e --- /dev/null +++ b/api/core/workflow/nodes/code/limits.py @@ -0,0 +1,13 @@ +from dataclasses import dataclass + + +@dataclass(frozen=True) +class CodeNodeLimits: + max_string_length: int + max_number: int | float + min_number: int | float + max_precision: int + max_depth: int + max_number_array_length: int + max_string_array_length: int + max_object_array_length: int diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index f1f07addd3..20de710db3 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import io import json @@ -134,7 +136,7 @@ class LLMNode(Node[LLMNodeData]): # Instance attributes specific to LLMNode. 
# Output variable for file - _file_outputs: list["File"] + _file_outputs: list[File] _llm_file_saver: LLMFileSaver @@ -142,8 +144,8 @@ class LLMNode(Node[LLMNodeData]): self, id: str, config: Mapping[str, Any], - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, *, llm_file_saver: LLMFileSaver | None = None, ): @@ -445,7 +447,7 @@ class LLMNode(Node[LLMNodeData]): structured_output_enabled: bool, structured_output: Mapping[str, Any] | None = None, file_saver: LLMFileSaver, - file_outputs: list["File"], + file_outputs: list[File], node_id: str, node_type: NodeType, reasoning_format: Literal["separated", "tagged"] = "tagged", @@ -499,7 +501,7 @@ class LLMNode(Node[LLMNodeData]): *, invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None], file_saver: LLMFileSaver, - file_outputs: list["File"], + file_outputs: list[File], node_id: str, node_type: NodeType, reasoning_format: Literal["separated", "tagged"] = "tagged", @@ -675,7 +677,7 @@ class LLMNode(Node[LLMNodeData]): ) @staticmethod - def _image_file_to_markdown(file: "File", /): + def _image_file_to_markdown(file: File, /): text_chunk = f"![]({file.generate_url()})" return text_chunk @@ -924,7 +926,7 @@ class LLMNode(Node[LLMNodeData]): def fetch_prompt_messages( *, sys_query: str | None = None, - sys_files: Sequence["File"], + sys_files: Sequence[File], context: str | None = None, memory: TokenBufferMemory | None = None, model_config: ModelConfigWithCredentialsEntity, @@ -935,7 +937,7 @@ class LLMNode(Node[LLMNodeData]): variable_pool: VariablePool, jinja2_variables: Sequence[VariableSelector], tenant_id: str, - context_files: list["File"] | None = None, + context_files: list[File] | None = None, ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: prompt_messages: list[PromptMessage] = [] @@ -1287,7 +1289,7 @@ class LLMNode(Node[LLMNodeData]): *, invoke_result: LLMResult | LLMResultWithStructuredOutput, saver: LLMFileSaver, - file_outputs: list["File"], + file_outputs: list[File], reasoning_format: Literal["separated", "tagged"] = "tagged", request_latency: float | None = None, ) -> ModelInvokeCompletedEvent: @@ -1329,7 +1331,7 @@ class LLMNode(Node[LLMNodeData]): *, content: ImagePromptMessageContent, file_saver: LLMFileSaver, - ) -> "File": + ) -> File: """_save_multimodal_output saves multi-modal contents generated by LLM plugins. There are two kinds of multimodal outputs: @@ -1379,7 +1381,7 @@ class LLMNode(Node[LLMNodeData]): *, contents: str | list[PromptMessageContentUnionTypes] | None, file_saver: LLMFileSaver, - file_outputs: list["File"], + file_outputs: list[File], ) -> Generator[str, None, None]: """Convert intermediate prompt messages into strings and yield them to the caller. 
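(Editor's note, not part of the patch: the code_node.py and limits.py hunks above replace direct `dify_config` reads with an injected `CodeNodeLimits`, and the node_factory.py hunk below shows where the config-backed defaults are built. The sketch below shows how a test or non-default deployment might construct explicit limits; it assumes the `api` package is importable, and the concrete numbers are arbitrary examples.)

```python
# Illustrative sketch: building explicit CodeNodeLimits instead of relying on
# the dify_config-backed defaults assembled by DifyNodeFactory.
from core.workflow.nodes.code.limits import CodeNodeLimits

# Tightened limits a unit test might pass as DifyNodeFactory(code_limits=...)
# or directly as CodeNode(code_limits=...); the values here are made up.
test_limits = CodeNodeLimits(
    max_string_length=1_000,
    max_number=10**9,
    min_number=-(10**9),
    max_precision=6,
    max_depth=5,
    max_number_array_length=100,
    max_string_array_length=100,
    max_object_array_length=10,
)

# The dataclass is declared frozen=True, so a configured limit set is immutable.
assert test_limits.max_depth == 5
```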
diff --git a/api/core/workflow/nodes/node_factory.py b/api/core/workflow/nodes/node_factory.py index c55ad346bf..557d3a330d 100644 --- a/api/core/workflow/nodes/node_factory.py +++ b/api/core/workflow/nodes/node_factory.py @@ -1,10 +1,21 @@ +from collections.abc import Sequence from typing import TYPE_CHECKING, final from typing_extensions import override +from configs import dify_config +from core.helper.code_executor.code_executor import CodeExecutor +from core.helper.code_executor.code_node_provider import CodeNodeProvider from core.workflow.enums import NodeType from core.workflow.graph import NodeFactory from core.workflow.nodes.base.node import Node +from core.workflow.nodes.code.code_node import CodeNode +from core.workflow.nodes.code.limits import CodeNodeLimits +from core.workflow.nodes.template_transform.template_renderer import ( + CodeExecutorJinja2TemplateRenderer, + Jinja2TemplateRenderer, +) +from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode from libs.typing import is_str, is_str_dict from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING @@ -27,9 +38,29 @@ class DifyNodeFactory(NodeFactory): self, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", + *, + code_executor: type[CodeExecutor] | None = None, + code_providers: Sequence[type[CodeNodeProvider]] | None = None, + code_limits: CodeNodeLimits | None = None, + template_renderer: Jinja2TemplateRenderer | None = None, ) -> None: self.graph_init_params = graph_init_params self.graph_runtime_state = graph_runtime_state + self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor + self._code_providers: tuple[type[CodeNodeProvider], ...] = ( + tuple(code_providers) if code_providers else CodeNode.default_code_providers() + ) + self._code_limits = code_limits or CodeNodeLimits( + max_string_length=dify_config.CODE_MAX_STRING_LENGTH, + max_number=dify_config.CODE_MAX_NUMBER, + min_number=dify_config.CODE_MIN_NUMBER, + max_precision=dify_config.CODE_MAX_PRECISION, + max_depth=dify_config.CODE_MAX_DEPTH, + max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH, + max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH, + max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH, + ) + self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer() @override def create_node(self, node_config: dict[str, object]) -> Node: @@ -72,6 +103,25 @@ class DifyNodeFactory(NodeFactory): raise ValueError(f"No latest version class found for node type: {node_type}") # Create node instance + if node_type == NodeType.CODE: + return CodeNode( + id=node_id, + config=node_config, + graph_init_params=self.graph_init_params, + graph_runtime_state=self.graph_runtime_state, + code_executor=self._code_executor, + code_providers=self._code_providers, + code_limits=self._code_limits, + ) + if node_type == NodeType.TEMPLATE_TRANSFORM: + return TemplateTransformNode( + id=node_id, + config=node_config, + graph_init_params=self.graph_init_params, + graph_runtime_state=self.graph_runtime_state, + template_renderer=self._template_renderer, + ) + return node_class( id=node_id, config=node_config, diff --git a/api/core/workflow/nodes/template_transform/template_renderer.py b/api/core/workflow/nodes/template_transform/template_renderer.py new file mode 100644 index 0000000000..a5f06bf2bb --- /dev/null +++ b/api/core/workflow/nodes/template_transform/template_renderer.py @@ -0,0 +1,40 @@ +from __future__ import 
annotations + +from collections.abc import Mapping +from typing import Any, Protocol + +from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage + + +class TemplateRenderError(ValueError): + """Raised when rendering a Jinja2 template fails.""" + + +class Jinja2TemplateRenderer(Protocol): + """Render Jinja2 templates for template transform nodes.""" + + def render_template(self, template: str, variables: Mapping[str, Any]) -> str: + """Render a Jinja2 template with provided variables.""" + raise NotImplementedError + + +class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer): + """Adapter that renders Jinja2 templates via CodeExecutor.""" + + _code_executor: type[CodeExecutor] + + def __init__(self, code_executor: type[CodeExecutor] | None = None) -> None: + self._code_executor = code_executor or CodeExecutor + + def render_template(self, template: str, variables: Mapping[str, Any]) -> str: + try: + result = self._code_executor.execute_workflow_code_template( + language=CodeLanguage.JINJA2, code=template, inputs=variables + ) + except CodeExecutionError as exc: + raise TemplateRenderError(str(exc)) from exc + + rendered = result.get("result") + if not isinstance(rendered, str): + raise TemplateRenderError("Template render result must be a string.") + return rendered diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index 2274323960..f7e0bccccf 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -1,18 +1,44 @@ from collections.abc import Mapping, Sequence -from typing import Any +from typing import TYPE_CHECKING, Any from configs import dify_config -from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base.node import Node from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData +from core.workflow.nodes.template_transform.template_renderer import ( + CodeExecutorJinja2TemplateRenderer, + Jinja2TemplateRenderer, + TemplateRenderError, +) + +if TYPE_CHECKING: + from core.workflow.entities import GraphInitParams + from core.workflow.runtime import GraphRuntimeState MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH class TemplateTransformNode(Node[TemplateTransformNodeData]): node_type = NodeType.TEMPLATE_TRANSFORM + _template_renderer: Jinja2TemplateRenderer + + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph_runtime_state: "GraphRuntimeState", + *, + template_renderer: Jinja2TemplateRenderer | None = None, + ) -> None: + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer() @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: @@ -39,13 +65,11 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]): variables[variable_name] = value.to_object() if value else None # Run code try: - result = CodeExecutor.execute_workflow_code_template( - language=CodeLanguage.JINJA2, code=self.node_data.template, 
inputs=variables - ) - except CodeExecutionError as e: + rendered = self._template_renderer.render_template(self.node_data.template, variables) + except TemplateRenderError as e: return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e)) - if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH: + if len(rendered) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH: return NodeRunResult( inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, @@ -53,7 +77,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]): ) return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": result["result"]} + status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": rendered} ) @classmethod diff --git a/api/core/workflow/nodes/variable_assigner/common/impl.py b/api/core/workflow/nodes/variable_assigner/common/impl.py deleted file mode 100644 index 050e213535..0000000000 --- a/api/core/workflow/nodes/variable_assigner/common/impl.py +++ /dev/null @@ -1,28 +0,0 @@ -from sqlalchemy import select -from sqlalchemy.orm import Session - -from core.variables.variables import Variable -from extensions.ext_database import db -from models import ConversationVariable - -from .exc import VariableOperatorNodeError - - -class ConversationVariableUpdaterImpl: - def update(self, conversation_id: str, variable: Variable): - stmt = select(ConversationVariable).where( - ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id - ) - with Session(db.engine) as session: - row = session.scalar(stmt) - if not row: - raise VariableOperatorNodeError("conversation variable not found in the database") - row.data = variable.model_dump_json() - session.commit() - - def flush(self): - pass - - -def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl: - return ConversationVariableUpdaterImpl() diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index da23207b62..d2ea7d94ea 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -1,9 +1,8 @@ -from collections.abc import Callable, Mapping, Sequence -from typing import TYPE_CHECKING, Any, TypeAlias +from collections.abc import Mapping, Sequence +from typing import TYPE_CHECKING, Any from core.variables import SegmentType, Variable from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID -from core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.entities import GraphInitParams from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult @@ -11,19 +10,14 @@ from core.workflow.nodes.base.node import Node from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from ..common.impl import conversation_variable_updater_factory from .node_data import VariableAssignerData, WriteMode if TYPE_CHECKING: from core.workflow.runtime import GraphRuntimeState -_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater] - - class VariableAssignerNode(Node[VariableAssignerData]): node_type = NodeType.VARIABLE_ASSIGNER - _conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY def __init__( self, @@ -31,7 +25,6 @@ class 
VariableAssignerNode(Node[VariableAssignerData]): config: Mapping[str, Any], graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", - conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory, ): super().__init__( id=id, @@ -39,7 +32,6 @@ class VariableAssignerNode(Node[VariableAssignerData]): graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - self._conv_var_updater_factory = conv_var_updater_factory @classmethod def version(cls) -> str: @@ -96,16 +88,7 @@ class VariableAssignerNode(Node[VariableAssignerData]): # Over write the variable. self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable) - # TODO: Move database operation to the pipeline. - # Update conversation variable. - conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"]) - if not conversation_id: - raise VariableOperatorNodeError("conversation_id not found") - conv_var_updater = self._conv_var_updater_factory() - conv_var_updater.update(conversation_id=conversation_id.text, variable=updated_variable) - conv_var_updater.flush() updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)] - return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={ diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index 389fb54d35..486e6bb6a7 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -1,24 +1,20 @@ import json from collections.abc import Mapping, MutableMapping, Sequence -from typing import Any, cast +from typing import TYPE_CHECKING, Any -from core.app.entities.app_invoke_entities import InvokeFrom from core.variables import SegmentType, Variable from core.variables.consts import SELECTORS_LENGTH from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID -from core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base.node import Node from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory from . 
import helpers from .entities import VariableAssignerNodeData, VariableOperationItem from .enums import InputType, Operation from .exc import ( - ConversationIDNotFoundError, InputTypeNotSupportedError, InvalidDataError, InvalidInputValueError, @@ -26,6 +22,10 @@ from .exc import ( VariableNotFoundError, ) +if TYPE_CHECKING: + from core.workflow.entities import GraphInitParams + from core.workflow.runtime import GraphRuntimeState + def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): selector_node_id = item.variable_selector[0] @@ -53,6 +53,20 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_ class VariableAssignerNode(Node[VariableAssignerNodeData]): node_type = NodeType.VARIABLE_ASSIGNER + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph_runtime_state: "GraphRuntimeState", + ): + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: """ Check if this Variable Assigner node blocks the output of specific variables. @@ -70,9 +84,6 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): return False - def _conv_var_updater_factory(self) -> ConversationVariableUpdater: - return conversation_variable_updater_factory() - @classmethod def version(cls) -> str: return "2" @@ -179,26 +190,12 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): # remove the duplicated items first. updated_variable_selectors = list(set(map(tuple, updated_variable_selectors))) - conv_var_updater = self._conv_var_updater_factory() - # Update variables for selector in updated_variable_selectors: variable = self.graph_runtime_state.variable_pool.get(selector) if not isinstance(variable, Variable): raise VariableNotFoundError(variable_selector=selector) process_data[variable.name] = variable.value - if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID: - conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"]) - if not conversation_id: - if self.invoke_from != InvokeFrom.DEBUGGER: - raise ConversationIDNotFoundError - else: - conversation_id = conversation_id.value - conv_var_updater.update( - conversation_id=cast(str, conversation_id), - variable=variable, - ) - conv_var_updater.flush() updated_variables = [ common_helpers.variable_to_processed_data(selector, seg) for selector in updated_variable_selectors diff --git a/api/core/workflow/repositories/draft_variable_repository.py b/api/core/workflow/repositories/draft_variable_repository.py index 97bfcd5666..66ef714c16 100644 --- a/api/core/workflow/repositories/draft_variable_repository.py +++ b/api/core/workflow/repositories/draft_variable_repository.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import abc from collections.abc import Mapping from typing import Any, Protocol @@ -23,7 +25,7 @@ class DraftVariableSaverFactory(Protocol): node_type: NodeType, node_execution_id: str, enclosing_node_id: str | None = None, - ) -> "DraftVariableSaver": + ) -> DraftVariableSaver: pass diff --git a/api/core/workflow/runtime/graph_runtime_state.py b/api/core/workflow/runtime/graph_runtime_state.py index 1561b789df..401cecc162 100644 --- a/api/core/workflow/runtime/graph_runtime_state.py +++ b/api/core/workflow/runtime/graph_runtime_state.py @@ -2,6 +2,7 @@ from __future__ import annotations import 
importlib import json +import threading from collections.abc import Mapping, Sequence from copy import deepcopy from dataclasses import dataclass @@ -168,6 +169,7 @@ class GraphRuntimeState: self._pending_response_coordinator_dump: str | None = None self._pending_graph_execution_workflow_id: str | None = None self._paused_nodes: set[str] = set() + self.stop_event: threading.Event = threading.Event() if graph is not None: self.attach_graph(graph) diff --git a/api/core/workflow/runtime/graph_runtime_state_protocol.py b/api/core/workflow/runtime/graph_runtime_state_protocol.py index 5e0878e873..bfbb5ba704 100644 --- a/api/core/workflow/runtime/graph_runtime_state_protocol.py +++ b/api/core/workflow/runtime/graph_runtime_state_protocol.py @@ -1,4 +1,4 @@ -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from typing import Any, Protocol from core.model_runtime.entities.llm_entities import LLMUsage @@ -9,7 +9,7 @@ from core.workflow.system_variable import SystemVariableReadOnlyView class ReadOnlyVariablePool(Protocol): """Read-only interface for VariablePool.""" - def get(self, node_id: str, variable_key: str) -> Segment | None: + def get(self, selector: Sequence[str], /) -> Segment | None: """Get a variable value (read-only).""" ... diff --git a/api/core/workflow/runtime/read_only_wrappers.py b/api/core/workflow/runtime/read_only_wrappers.py index 8539727fd6..d3e4c60d9b 100644 --- a/api/core/workflow/runtime/read_only_wrappers.py +++ b/api/core/workflow/runtime/read_only_wrappers.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from copy import deepcopy from typing import Any @@ -18,9 +18,9 @@ class ReadOnlyVariablePoolWrapper: def __init__(self, variable_pool: VariablePool) -> None: self._variable_pool = variable_pool - def get(self, node_id: str, variable_key: str) -> Segment | None: + def get(self, selector: Sequence[str], /) -> Segment | None: """Return a copy of a variable value if present.""" - value = self._variable_pool.get([node_id, variable_key]) + value = self._variable_pool.get(selector) return deepcopy(value) if value is not None else None def get_all_by_node(self, node_id: str) -> Mapping[str, object]: diff --git a/api/core/workflow/runtime/variable_pool.py b/api/core/workflow/runtime/variable_pool.py index 7fbaec9e70..85ceb9d59e 100644 --- a/api/core/workflow/runtime/variable_pool.py +++ b/api/core/workflow/runtime/variable_pool.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import re from collections import defaultdict from collections.abc import Mapping, Sequence @@ -267,6 +269,6 @@ class VariablePool(BaseModel): self.add(selector, value) @classmethod - def empty(cls) -> "VariablePool": + def empty(cls) -> VariablePool: """Create an empty variable pool.""" return cls(system_variables=SystemVariable.empty()) diff --git a/api/core/workflow/system_variable.py b/api/core/workflow/system_variable.py index ad925912a4..cda8091771 100644 --- a/api/core/workflow/system_variable.py +++ b/api/core/workflow/system_variable.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections.abc import Mapping, Sequence from types import MappingProxyType from typing import Any @@ -70,7 +72,7 @@ class SystemVariable(BaseModel): return data @classmethod - def empty(cls) -> "SystemVariable": + def empty(cls) -> SystemVariable: return cls() def to_dict(self) -> dict[SystemVariableKey, Any]: @@ -114,7 +116,7 @@ class SystemVariable(BaseModel): 
d[SystemVariableKey.TIMESTAMP] = self.timestamp return d - def as_view(self) -> "SystemVariableReadOnlyView": + def as_view(self) -> SystemVariableReadOnlyView: return SystemVariableReadOnlyView(self) diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index 5a69eb15ac..c0279f893b 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -3,8 +3,9 @@ set -e # Set UTF-8 encoding to address potential encoding issues in containerized environments -export LANG=${LANG:-en_US.UTF-8} -export LC_ALL=${LC_ALL:-en_US.UTF-8} +# Use C.UTF-8 which is universally available in all containers +export LANG=${LANG:-C.UTF-8} +export LC_ALL=${LC_ALL:-C.UTF-8} export PYTHONIOENCODING=${PYTHONIOENCODING:-utf-8} if [[ "${MIGRATION_ENABLED}" == "true" ]]; then diff --git a/api/extensions/ext_blueprints.py b/api/extensions/ext_blueprints.py index cf994c11df..7d13f0c061 100644 --- a/api/extensions/ext_blueprints.py +++ b/api/extensions/ext_blueprints.py @@ -6,6 +6,7 @@ BASE_CORS_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE, HEAD SERVICE_API_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, "Authorization") AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN) FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN) +EMBED_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE) EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id") @@ -42,10 +43,28 @@ def init_app(app: DifyApp): _apply_cors_once( web_bp, - resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}}, - supports_credentials=True, - allow_headers=list(AUTHENTICATED_HEADERS), - methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + resources={ + # Embedded bot endpoints (unauthenticated, cross-origin safe) + r"^/chat-messages$": { + "origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS, + "supports_credentials": False, + "allow_headers": list(EMBED_HEADERS), + "methods": ["GET", "POST", "OPTIONS"], + }, + r"^/chat-messages/.*": { + "origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS, + "supports_credentials": False, + "allow_headers": list(EMBED_HEADERS), + "methods": ["GET", "POST", "OPTIONS"], + }, + # Default web application endpoints (authenticated) + r"/*": { + "origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS, + "supports_credentials": True, + "allow_headers": list(AUTHENTICATED_HEADERS), + "methods": ["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + }, + }, expose_headers=list(EXPOSED_HEADERS), ) app.register_blueprint(web_bp) diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 5cf4984709..2fbab001d0 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -12,9 +12,8 @@ from dify_app import DifyApp def _get_celery_ssl_options() -> dict[str, Any] | None: """Get SSL configuration for Celery broker/backend connections.""" - # Use REDIS_USE_SSL for consistency with the main Redis client # Only apply SSL if we're using Redis as broker/backend - if not dify_config.REDIS_USE_SSL: + if not dify_config.BROKER_USE_SSL: return None # Check if Celery is actually using Redis @@ -47,7 +46,11 @@ def _get_celery_ssl_options() -> dict[str, Any] | None: def init_app(app: DifyApp) -> Celery: class FlaskTask(Task): def __call__(self, *args: object, **kwargs: object) -> object: + from core.logging.context import init_request_context + with app.app_context(): + # Initialize logging context for this task (similar to before_request in Flask) + init_request_context() return 
self.run(*args, **kwargs) broker_transport_options = {} diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py index 71a63168a5..daa3756dba 100644 --- a/api/extensions/ext_commands.py +++ b/api/extensions/ext_commands.py @@ -11,6 +11,7 @@ def init_app(app: DifyApp): create_tenant, extract_plugins, extract_unique_plugins, + file_usage, fix_app_site_missing, install_plugins, install_rag_pipeline_plugins, @@ -47,6 +48,7 @@ def init_app(app: DifyApp): clear_free_plan_tenant_expired_logs, clear_orphaned_file_records, remove_orphaned_files_on_storage, + file_usage, setup_system_tool_oauth_client, setup_system_trigger_oauth_client, cleanup_orphaned_draft_variables, diff --git a/api/extensions/ext_database.py b/api/extensions/ext_database.py index c90b1d0a9f..2e0d4c889a 100644 --- a/api/extensions/ext_database.py +++ b/api/extensions/ext_database.py @@ -53,3 +53,10 @@ def _setup_gevent_compatibility(): def init_app(app: DifyApp): db.init_app(app) _setup_gevent_compatibility() + + # Eagerly build the engine so pool_size/max_overflow/etc. come from config + try: + with app.app_context(): + _ = db.engine # triggers engine creation with the configured options + except Exception: + logger.exception("Failed to initialize SQLAlchemy engine during app startup") diff --git a/api/extensions/ext_logging.py b/api/extensions/ext_logging.py index 000d03ac41..978a40c503 100644 --- a/api/extensions/ext_logging.py +++ b/api/extensions/ext_logging.py @@ -1,18 +1,19 @@ +"""Logging extension for Dify Flask application.""" + import logging import os import sys -import uuid from logging.handlers import RotatingFileHandler -import flask - from configs import dify_config -from core.helper.trace_id_helper import get_trace_id_from_otel_context from dify_app import DifyApp def init_app(app: DifyApp): + """Initialize logging with support for text or JSON format.""" log_handlers: list[logging.Handler] = [] + + # File handler log_file = dify_config.LOG_FILE if log_file: log_dir = os.path.dirname(log_file) @@ -25,27 +26,53 @@ def init_app(app: DifyApp): ) ) - # Always add StreamHandler to log to console + # Console handler sh = logging.StreamHandler(sys.stdout) log_handlers.append(sh) - # Apply RequestIdFilter to all handlers - for handler in log_handlers: - handler.addFilter(RequestIdFilter()) + # Apply filters to all handlers + from core.logging.filters import IdentityContextFilter, TraceContextFilter + for handler in log_handlers: + handler.addFilter(TraceContextFilter()) + handler.addFilter(IdentityContextFilter()) + + # Configure formatter based on format type + formatter = _create_formatter() + for handler in log_handlers: + handler.setFormatter(formatter) + + # Configure root logger logging.basicConfig( level=dify_config.LOG_LEVEL, - format=dify_config.LOG_FORMAT, - datefmt=dify_config.LOG_DATEFORMAT, handlers=log_handlers, force=True, ) - # Apply RequestIdFormatter to all handlers - apply_request_id_formatter() - # Disable propagation for noisy loggers to avoid duplicate logs logging.getLogger("sqlalchemy.engine").propagate = False + + # Apply timezone if specified (only for text format) + if dify_config.LOG_OUTPUT_FORMAT == "text": + _apply_timezone(log_handlers) + + +def _create_formatter() -> logging.Formatter: + """Create appropriate formatter based on configuration.""" + if dify_config.LOG_OUTPUT_FORMAT == "json": + from core.logging.structured_formatter import StructuredJSONFormatter + + return StructuredJSONFormatter() + else: + # Text format - use existing pattern with backward compatible 
formatter + return _TextFormatter( + fmt=dify_config.LOG_FORMAT, + datefmt=dify_config.LOG_DATEFORMAT, + ) + + +def _apply_timezone(handlers: list[logging.Handler]): + """Apply timezone conversion to text formatters.""" log_tz = dify_config.LOG_TZ if log_tz: from datetime import datetime @@ -57,34 +84,51 @@ def init_app(app: DifyApp): def time_converter(seconds): return datetime.fromtimestamp(seconds, tz=timezone).timetuple() - for handler in logging.root.handlers: + for handler in handlers: if handler.formatter: - handler.formatter.converter = time_converter + handler.formatter.converter = time_converter # type: ignore[attr-defined] -def get_request_id(): - if getattr(flask.g, "request_id", None): - return flask.g.request_id +class _TextFormatter(logging.Formatter): + """Text formatter that ensures trace_id and req_id are always present.""" - new_uuid = uuid.uuid4().hex[:10] - flask.g.request_id = new_uuid - - return new_uuid + def format(self, record: logging.LogRecord) -> str: + if not hasattr(record, "req_id"): + record.req_id = "" + if not hasattr(record, "trace_id"): + record.trace_id = "" + if not hasattr(record, "span_id"): + record.span_id = "" + return super().format(record) +def get_request_id() -> str: + """Get request ID for current request context. + + Deprecated: Use core.logging.context.get_request_id() directly. + """ + from core.logging.context import get_request_id as _get_request_id + + return _get_request_id() + + +# Backward compatibility aliases class RequestIdFilter(logging.Filter): - # This is a logging filter that makes the request ID available for use in - # the logging format. Note that we're checking if we're in a request - # context, as we may want to log things before Flask is fully loaded. - def filter(self, record): - trace_id = get_trace_id_from_otel_context() or "" - record.req_id = get_request_id() if flask.has_request_context() else "" - record.trace_id = trace_id + """Deprecated: Use TraceContextFilter from core.logging.filters instead.""" + + def filter(self, record: logging.LogRecord) -> bool: + from core.logging.context import get_request_id as _get_request_id + from core.logging.context import get_trace_id as _get_trace_id + + record.req_id = _get_request_id() + record.trace_id = _get_trace_id() return True class RequestIdFormatter(logging.Formatter): - def format(self, record): + """Deprecated: Use _TextFormatter instead.""" + + def format(self, record: logging.LogRecord) -> str: if not hasattr(record, "req_id"): record.req_id = "" if not hasattr(record, "trace_id"): @@ -93,6 +137,7 @@ class RequestIdFormatter(logging.Formatter): def apply_request_id_formatter(): + """Deprecated: Formatter is now applied in init_app.""" for handler in logging.root.handlers: if handler.formatter: handler.formatter = RequestIdFormatter(dify_config.LOG_FORMAT, dify_config.LOG_DATEFORMAT) diff --git a/api/extensions/logstore/aliyun_logstore.py b/api/extensions/logstore/aliyun_logstore.py index 22d1f473a3..8c64a25be4 100644 --- a/api/extensions/logstore/aliyun_logstore.py +++ b/api/extensions/logstore/aliyun_logstore.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import os import threading @@ -33,7 +35,7 @@ class AliyunLogStore: Ensures only one instance exists to prevent multiple PG connection pools. 
""" - _instance: "AliyunLogStore | None" = None + _instance: AliyunLogStore | None = None _initialized: bool = False # Track delayed PG connection for newly created projects @@ -66,7 +68,7 @@ class AliyunLogStore: "\t", ] - def __new__(cls) -> "AliyunLogStore": + def __new__(cls) -> AliyunLogStore: """Implement singleton pattern.""" if cls._instance is None: cls._instance = super().__new__(cls) diff --git a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py index 6e6631cfef..1119534d52 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py @@ -22,6 +22,18 @@ from models.enums import WorkflowRunTriggeredFrom logger = logging.getLogger(__name__) +def to_serializable(obj): + """ + Convert non-JSON-serializable objects into JSON-compatible formats. + + - Uses `to_dict()` if it's a callable method. + - Falls back to string representation. + """ + if hasattr(obj, "to_dict") and callable(obj.to_dict): + return obj.to_dict() + return str(obj) + + class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository): def __init__( self, @@ -69,6 +81,11 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository): # Set to True to enable dual-write for safe migration, False to use LogStore only self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true" + # Control flag for whether to write the `graph` field to LogStore. + # If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field; + # otherwise write an empty {} instead. Defaults to writing the `graph` field. + self._enable_put_graph_field = os.environ.get("LOGSTORE_ENABLE_PUT_GRAPH_FIELD", "true").lower() == "true" + def _to_logstore_model(self, domain_model: WorkflowExecution) -> list[tuple[str, str]]: """ Convert a domain model to a logstore model (List[Tuple[str, str]]). @@ -108,9 +125,24 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository): ), ("type", domain_model.workflow_type.value), ("version", domain_model.workflow_version), - ("graph", json.dumps(domain_model.graph, ensure_ascii=False) if domain_model.graph else "{}"), - ("inputs", json.dumps(domain_model.inputs, ensure_ascii=False) if domain_model.inputs else "{}"), - ("outputs", json.dumps(domain_model.outputs, ensure_ascii=False) if domain_model.outputs else "{}"), + ( + "graph", + json.dumps(domain_model.graph, ensure_ascii=False, default=to_serializable) + if domain_model.graph and self._enable_put_graph_field + else "{}", + ), + ( + "inputs", + json.dumps(domain_model.inputs, ensure_ascii=False, default=to_serializable) + if domain_model.inputs + else "{}", + ), + ( + "outputs", + json.dumps(domain_model.outputs, ensure_ascii=False, default=to_serializable) + if domain_model.outputs + else "{}", + ), ("status", domain_model.status.value), ("error_message", domain_model.error_message or ""), ("total_tokens", str(domain_model.total_tokens)), diff --git a/api/extensions/otel/instrumentation.py b/api/extensions/otel/instrumentation.py index 3597110cba..6617f69513 100644 --- a/api/extensions/otel/instrumentation.py +++ b/api/extensions/otel/instrumentation.py @@ -19,26 +19,43 @@ logger = logging.getLogger(__name__) class ExceptionLoggingHandler(logging.Handler): + """ + Handler that records exceptions to the current OpenTelemetry span. 
+ + Unlike creating a new span, this records exceptions on the existing span + to maintain trace context consistency throughout the request lifecycle. + """ + def emit(self, record: logging.LogRecord): with contextlib.suppress(Exception): - if record.exc_info: - tracer = get_tracer_provider().get_tracer("dify.exception.logging") - with tracer.start_as_current_span( - "log.exception", - attributes={ - "log.level": record.levelname, - "log.message": record.getMessage(), - "log.logger": record.name, - "log.file.path": record.pathname, - "log.file.line": record.lineno, - }, - ) as span: - span.set_status(StatusCode.ERROR) - if record.exc_info[1]: - span.record_exception(record.exc_info[1]) - span.set_attribute("exception.message", str(record.exc_info[1])) - if record.exc_info[0]: - span.set_attribute("exception.type", record.exc_info[0].__name__) + if not record.exc_info: + return + + from opentelemetry.trace import get_current_span + + span = get_current_span() + if not span or not span.is_recording(): + return + + # Record exception on the current span instead of creating a new one + span.set_status(StatusCode.ERROR, record.getMessage()) + + # Add log context as span events/attributes + span.add_event( + "log.exception", + attributes={ + "log.level": record.levelname, + "log.message": record.getMessage(), + "log.logger": record.name, + "log.file.path": record.pathname, + "log.file.line": record.lineno, + }, + ) + + if record.exc_info[1]: + span.record_exception(record.exc_info[1]) + if record.exc_info[0]: + span.set_attribute("exception.type", record.exc_info[0].__name__) def instrument_exception_logging() -> None: diff --git a/api/extensions/storage/clickzetta_volume/file_lifecycle.py b/api/extensions/storage/clickzetta_volume/file_lifecycle.py index 51a97b20f8..1d9911465b 100644 --- a/api/extensions/storage/clickzetta_volume/file_lifecycle.py +++ b/api/extensions/storage/clickzetta_volume/file_lifecycle.py @@ -5,6 +5,8 @@ automatic cleanup, backup and restore. Supports complete lifecycle management for knowledge base files. 
""" +from __future__ import annotations + import json import logging import operator @@ -48,7 +50,7 @@ class FileMetadata: return data @classmethod - def from_dict(cls, data: dict) -> "FileMetadata": + def from_dict(cls, data: dict) -> FileMetadata: """Create instance from dictionary""" data = data.copy() data["created_at"] = datetime.fromisoformat(data["created_at"]) diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py index 38835d5ac7..e69306dcb2 100644 --- a/api/fields/annotation_fields.py +++ b/api/fields/annotation_fields.py @@ -1,4 +1,4 @@ -from flask_restx import Api, Namespace, fields +from flask_restx import Namespace, fields from libs.helper import TimestampField @@ -12,7 +12,7 @@ annotation_fields = { } -def build_annotation_model(api_or_ns: Api | Namespace): +def build_annotation_model(api_or_ns: Namespace): """Build the annotation model for the API or Namespace.""" return api_or_ns.model("Annotation", annotation_fields) diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index d5b2574edc..fe59cdcbb4 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -1,237 +1,339 @@ -from flask_restx import Api, Namespace, fields +from __future__ import annotations -from fields.member_fields import simple_account_fields -from libs.helper import TimestampField +from datetime import datetime +from typing import Any, TypeAlias -from .raws import FilesContainedField +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator + +from core.file import File + +JSONValue: TypeAlias = Any -class MessageTextField(fields.Raw): - def format(self, value): - return value[0]["text"] if value else "" +class ResponseModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + extra="ignore", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) -feedback_fields = { - "rating": fields.String, - "content": fields.String, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_account": fields.Nested(simple_account_fields, allow_null=True), -} +class MessageFile(ResponseModel): + id: str + filename: str + type: str + url: str | None = None + mime_type: str | None = None + size: int | None = None + transfer_method: str + belongs_to: str | None = None + upload_file_id: str | None = None -annotation_fields = { - "id": fields.String, - "question": fields.String, - "content": fields.String, - "account": fields.Nested(simple_account_fields, allow_null=True), - "created_at": TimestampField, -} - -annotation_hit_history_fields = { - "annotation_id": fields.String(attribute="id"), - "annotation_create_account": fields.Nested(simple_account_fields, allow_null=True), - "created_at": TimestampField, -} - -message_file_fields = { - "id": fields.String, - "filename": fields.String, - "type": fields.String, - "url": fields.String, - "mime_type": fields.String, - "size": fields.Integer, - "transfer_method": fields.String, - "belongs_to": fields.String(default="user"), - "upload_file_id": fields.String(default=None), -} + @field_validator("transfer_method", mode="before") + @classmethod + def _normalize_transfer_method(cls, value: object) -> str: + if isinstance(value, str): + return value + return str(value) -def build_message_file_model(api_or_ns: Api | Namespace): - """Build the message file fields for the API or Namespace.""" - return api_or_ns.model("MessageFile", message_file_fields) +class SimpleConversation(ResponseModel): + id: str + name: 
str + inputs: dict[str, JSONValue] + status: str + introduction: str | None = None + created_at: int | None = None + updated_at: int | None = None + + @field_validator("inputs", mode="before") + @classmethod + def _normalize_inputs(cls, value: JSONValue) -> JSONValue: + return format_files_contained(value) + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return to_timestamp(value) + return value -agent_thought_fields = { - "id": fields.String, - "chain_id": fields.String, - "message_id": fields.String, - "position": fields.Integer, - "thought": fields.String, - "tool": fields.String, - "tool_labels": fields.Raw, - "tool_input": fields.String, - "created_at": TimestampField, - "observation": fields.String, - "files": fields.List(fields.String), -} - -message_detail_fields = { - "id": fields.String, - "conversation_id": fields.String, - "inputs": FilesContainedField, - "query": fields.String, - "message": fields.Raw, - "message_tokens": fields.Integer, - "answer": fields.String(attribute="re_sign_file_url_answer"), - "answer_tokens": fields.Integer, - "provider_response_latency": fields.Float, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_account_id": fields.String, - "feedbacks": fields.List(fields.Nested(feedback_fields)), - "workflow_run_id": fields.String, - "annotation": fields.Nested(annotation_fields, allow_null=True), - "annotation_hit_history": fields.Nested(annotation_hit_history_fields, allow_null=True), - "created_at": TimestampField, - "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), - "message_files": fields.List(fields.Nested(message_file_fields)), - "metadata": fields.Raw(attribute="message_metadata_dict"), - "status": fields.String, - "error": fields.String, - "parent_message_id": fields.String, - "generation_detail": fields.Raw, -} - -feedback_stat_fields = {"like": fields.Integer, "dislike": fields.Integer} -status_count_fields = {"success": fields.Integer, "failed": fields.Integer, "partial_success": fields.Integer} -model_config_fields = { - "opening_statement": fields.String, - "suggested_questions": fields.Raw, - "model": fields.Raw, - "user_input_form": fields.Raw, - "pre_prompt": fields.String, - "agent_mode": fields.Raw, -} - -simple_model_config_fields = { - "model": fields.Raw(attribute="model_dict"), - "pre_prompt": fields.String, -} - -simple_message_detail_fields = { - "inputs": FilesContainedField, - "query": fields.String, - "message": MessageTextField, - "answer": fields.String, -} - -conversation_fields = { - "id": fields.String, - "status": fields.String, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_end_user_session_id": fields.String(), - "from_account_id": fields.String, - "from_account_name": fields.String, - "read_at": TimestampField, - "created_at": TimestampField, - "updated_at": TimestampField, - "annotation": fields.Nested(annotation_fields, allow_null=True), - "model_config": fields.Nested(simple_model_config_fields), - "user_feedback_stats": fields.Nested(feedback_stat_fields), - "admin_feedback_stats": fields.Nested(feedback_stat_fields), - "message": fields.Nested(simple_message_detail_fields, attribute="first_message"), -} - -conversation_pagination_fields = { - "page": fields.Integer, - "limit": fields.Integer(attribute="per_page"), - "total": fields.Integer, - "has_more": fields.Boolean(attribute="has_next"), - "data": 
fields.List(fields.Nested(conversation_fields), attribute="items"), -} - -conversation_message_detail_fields = { - "id": fields.String, - "status": fields.String, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_account_id": fields.String, - "created_at": TimestampField, - "model_config": fields.Nested(model_config_fields), - "message": fields.Nested(message_detail_fields, attribute="first_message"), -} - -conversation_with_summary_fields = { - "id": fields.String, - "status": fields.String, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_end_user_session_id": fields.String, - "from_account_id": fields.String, - "from_account_name": fields.String, - "name": fields.String, - "summary": fields.String(attribute="summary_or_query"), - "read_at": TimestampField, - "created_at": TimestampField, - "updated_at": TimestampField, - "annotated": fields.Boolean, - "model_config": fields.Nested(simple_model_config_fields), - "message_count": fields.Integer, - "user_feedback_stats": fields.Nested(feedback_stat_fields), - "admin_feedback_stats": fields.Nested(feedback_stat_fields), - "status_count": fields.Nested(status_count_fields), -} - -conversation_with_summary_pagination_fields = { - "page": fields.Integer, - "limit": fields.Integer(attribute="per_page"), - "total": fields.Integer, - "has_more": fields.Boolean(attribute="has_next"), - "data": fields.List(fields.Nested(conversation_with_summary_fields), attribute="items"), -} - -conversation_detail_fields = { - "id": fields.String, - "status": fields.String, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_account_id": fields.String, - "created_at": TimestampField, - "updated_at": TimestampField, - "annotated": fields.Boolean, - "introduction": fields.String, - "model_config": fields.Nested(model_config_fields), - "message_count": fields.Integer, - "user_feedback_stats": fields.Nested(feedback_stat_fields), - "admin_feedback_stats": fields.Nested(feedback_stat_fields), -} - -simple_conversation_fields = { - "id": fields.String, - "name": fields.String, - "inputs": FilesContainedField, - "status": fields.String, - "introduction": fields.String, - "created_at": TimestampField, - "updated_at": TimestampField, -} - -conversation_delete_fields = { - "result": fields.String, -} - -conversation_infinite_scroll_pagination_fields = { - "limit": fields.Integer, - "has_more": fields.Boolean, - "data": fields.List(fields.Nested(simple_conversation_fields)), -} +class ConversationInfiniteScrollPagination(ResponseModel): + limit: int + has_more: bool + data: list[SimpleConversation] -def build_conversation_infinite_scroll_pagination_model(api_or_ns: Api | Namespace): - """Build the conversation infinite scroll pagination model for the API or Namespace.""" - simple_conversation_model = build_simple_conversation_model(api_or_ns) - - copied_fields = conversation_infinite_scroll_pagination_fields.copy() - copied_fields["data"] = fields.List(fields.Nested(simple_conversation_model)) - return api_or_ns.model("ConversationInfiniteScrollPagination", copied_fields) +class ConversationDelete(ResponseModel): + result: str -def build_conversation_delete_model(api_or_ns: Api | Namespace): - """Build the conversation delete model for the API or Namespace.""" - return api_or_ns.model("ConversationDelete", conversation_delete_fields) +class ResultResponse(ResponseModel): + result: str -def build_simple_conversation_model(api_or_ns: Api | Namespace): - """Build the simple conversation model for 
the API or Namespace.""" - return api_or_ns.model("SimpleConversation", simple_conversation_fields) +class SimpleAccount(ResponseModel): + id: str + name: str + email: str + + +class Feedback(ResponseModel): + rating: str + content: str | None = None + from_source: str + from_end_user_id: str | None = None + from_account: SimpleAccount | None = None + + +class Annotation(ResponseModel): + id: str + question: str | None = None + content: str + account: SimpleAccount | None = None + created_at: int | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return to_timestamp(value) + return value + + +class AnnotationHitHistory(ResponseModel): + annotation_id: str + annotation_create_account: SimpleAccount | None = None + created_at: int | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return to_timestamp(value) + return value + + +class AgentThought(ResponseModel): + id: str + chain_id: str | None = None + message_chain_id: str | None = Field(default=None, exclude=True, validation_alias="message_chain_id") + message_id: str + position: int + thought: str | None = None + tool: str | None = None + tool_labels: JSONValue + tool_input: str | None = None + created_at: int | None = None + observation: str | None = None + files: list[str] + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return to_timestamp(value) + return value + + @model_validator(mode="after") + def _fallback_chain_id(self): + if self.chain_id is None and self.message_chain_id: + self.chain_id = self.message_chain_id + return self + + +class MessageDetail(ResponseModel): + id: str + conversation_id: str + inputs: dict[str, JSONValue] + query: str + message: JSONValue + message_tokens: int + answer: str + answer_tokens: int + provider_response_latency: float + from_source: str + from_end_user_id: str | None = None + from_account_id: str | None = None + feedbacks: list[Feedback] + workflow_run_id: str | None = None + annotation: Annotation | None = None + annotation_hit_history: AnnotationHitHistory | None = None + created_at: int | None = None + agent_thoughts: list[AgentThought] + message_files: list[MessageFile] + metadata: JSONValue + status: str + error: str | None = None + parent_message_id: str | None = None + generation_detail: JSONValue | None = Field(default=None, validation_alias="generation_detail_dict") + + @field_validator("inputs", mode="before") + @classmethod + def _normalize_inputs(cls, value: JSONValue) -> JSONValue: + return format_files_contained(value) + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return to_timestamp(value) + return value + + +class FeedbackStat(ResponseModel): + like: int + dislike: int + + +class StatusCount(ResponseModel): + success: int + failed: int + partial_success: int + + +class ModelConfig(ResponseModel): + opening_statement: str | None = None + suggested_questions: JSONValue | None = None + model: JSONValue | None = None + user_input_form: JSONValue | None = None + pre_prompt: str | None = None + agent_mode: JSONValue | None = None + + +class 
SimpleModelConfig(ResponseModel): + model: JSONValue | None = None + pre_prompt: str | None = None + + +class SimpleMessageDetail(ResponseModel): + inputs: dict[str, JSONValue] + query: str + message: str + answer: str + + @field_validator("inputs", mode="before") + @classmethod + def _normalize_inputs(cls, value: JSONValue) -> JSONValue: + return format_files_contained(value) + + +class Conversation(ResponseModel): + id: str + status: str + from_source: str + from_end_user_id: str | None = None + from_end_user_session_id: str | None = None + from_account_id: str | None = None + from_account_name: str | None = None + read_at: int | None = None + created_at: int | None = None + updated_at: int | None = None + annotation: Annotation | None = None + model_config_: SimpleModelConfig | None = Field(default=None, alias="model_config") + user_feedback_stats: FeedbackStat | None = None + admin_feedback_stats: FeedbackStat | None = None + message: SimpleMessageDetail | None = None + + +class ConversationPagination(ResponseModel): + page: int + limit: int + total: int + has_more: bool + data: list[Conversation] + + +class ConversationMessageDetail(ResponseModel): + id: str + status: str + from_source: str + from_end_user_id: str | None = None + from_account_id: str | None = None + created_at: int | None = None + model_config_: ModelConfig | None = Field(default=None, alias="model_config") + message: MessageDetail | None = None + + +class ConversationWithSummary(ResponseModel): + id: str + status: str + from_source: str + from_end_user_id: str | None = None + from_end_user_session_id: str | None = None + from_account_id: str | None = None + from_account_name: str | None = None + name: str + summary: str + read_at: int | None = None + created_at: int | None = None + updated_at: int | None = None + annotated: bool + model_config_: SimpleModelConfig | None = Field(default=None, alias="model_config") + message_count: int + user_feedback_stats: FeedbackStat | None = None + admin_feedback_stats: FeedbackStat | None = None + status_count: StatusCount | None = None + + +class ConversationWithSummaryPagination(ResponseModel): + page: int + limit: int + total: int + has_more: bool + data: list[ConversationWithSummary] + + +class ConversationDetail(ResponseModel): + id: str + status: str + from_source: str + from_end_user_id: str | None = None + from_account_id: str | None = None + created_at: int | None = None + updated_at: int | None = None + annotated: bool + introduction: str | None = None + model_config_: ModelConfig | None = Field(default=None, alias="model_config") + message_count: int + user_feedback_stats: FeedbackStat | None = None + admin_feedback_stats: FeedbackStat | None = None + + +def to_timestamp(value: datetime | None) -> int | None: + if value is None: + return None + return int(value.timestamp()) + + +def format_files_contained(value: JSONValue) -> JSONValue: + if isinstance(value, File): + return value.model_dump() + if isinstance(value, dict): + return {k: format_files_contained(v) for k, v in value.items()} + if isinstance(value, list): + return [format_files_contained(v) for v in value] + return value + + +def message_text(value: JSONValue) -> str: + if isinstance(value, list) and value: + first = value[0] + if isinstance(first, dict): + text = first.get("text") + if isinstance(text, str): + return text + return "" + + +def extract_model_config(value: object | None) -> dict[str, JSONValue]: + if value is None: + return {} + if isinstance(value, dict): + return value + if hasattr(value, 
"to_dict"): + return value.to_dict() + return {} diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py index 7d5e311591..c55014a368 100644 --- a/api/fields/conversation_variable_fields.py +++ b/api/fields/conversation_variable_fields.py @@ -1,4 +1,4 @@ -from flask_restx import Api, Namespace, fields +from flask_restx import Namespace, fields from libs.helper import TimestampField @@ -29,12 +29,12 @@ conversation_variable_infinite_scroll_pagination_fields = { } -def build_conversation_variable_model(api_or_ns: Api | Namespace): +def build_conversation_variable_model(api_or_ns: Namespace): """Build the conversation variable model for the API or Namespace.""" return api_or_ns.model("ConversationVariable", conversation_variable_fields) -def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Api | Namespace): +def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Namespace): """Build the conversation variable infinite scroll pagination model for the API or Namespace.""" # Build the nested variable model first conversation_variable_model = build_conversation_variable_model(api_or_ns) diff --git a/api/fields/end_user_fields.py b/api/fields/end_user_fields.py index ea43e3b5fd..5389b0213a 100644 --- a/api/fields/end_user_fields.py +++ b/api/fields/end_user_fields.py @@ -1,4 +1,4 @@ -from flask_restx import Api, Namespace, fields +from flask_restx import Namespace, fields simple_end_user_fields = { "id": fields.String, @@ -8,5 +8,5 @@ simple_end_user_fields = { } -def build_simple_end_user_model(api_or_ns: Api | Namespace): +def build_simple_end_user_model(api_or_ns: Namespace): return api_or_ns.model("SimpleEndUser", simple_end_user_fields) diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index a707500445..913fb675f9 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -1,93 +1,85 @@ -from flask_restx import Api, Namespace, fields +from __future__ import annotations -from libs.helper import TimestampField +from datetime import datetime -upload_config_fields = { - "file_size_limit": fields.Integer, - "batch_count_limit": fields.Integer, - "image_file_size_limit": fields.Integer, - "video_file_size_limit": fields.Integer, - "audio_file_size_limit": fields.Integer, - "workflow_file_upload_limit": fields.Integer, - "image_file_batch_limit": fields.Integer, - "single_chunk_attachment_limit": fields.Integer, -} +from pydantic import BaseModel, ConfigDict, field_validator -def build_upload_config_model(api_or_ns: Api | Namespace): - """Build the upload config model for the API or Namespace. - - Args: - api_or_ns: Flask-RestX Api or Namespace instance - - Returns: - The registered model - """ - return api_or_ns.model("UploadConfig", upload_config_fields) +class ResponseModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + extra="ignore", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) -file_fields = { - "id": fields.String, - "name": fields.String, - "size": fields.Integer, - "extension": fields.String, - "mime_type": fields.String, - "created_by": fields.String, - "created_at": TimestampField, - "preview_url": fields.String, - "source_url": fields.String, -} +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -def build_file_model(api_or_ns: Api | Namespace): - """Build the file model for the API or Namespace. 
- - Args: - api_or_ns: Flask-RestX Api or Namespace instance - - Returns: - The registered model - """ - return api_or_ns.model("File", file_fields) +class UploadConfig(ResponseModel): + file_size_limit: int + batch_count_limit: int + file_upload_limit: int | None = None + image_file_size_limit: int + video_file_size_limit: int + audio_file_size_limit: int + workflow_file_upload_limit: int + image_file_batch_limit: int + single_chunk_attachment_limit: int + attachment_image_file_size_limit: int | None = None -remote_file_info_fields = { - "file_type": fields.String(attribute="file_type"), - "file_length": fields.Integer(attribute="file_length"), -} +class FileResponse(ResponseModel): + id: str + name: str + size: int + extension: str | None = None + mime_type: str | None = None + created_by: str | None = None + created_at: int | None = None + preview_url: str | None = None + source_url: str | None = None + original_url: str | None = None + user_id: str | None = None + tenant_id: str | None = None + conversation_id: str | None = None + file_key: str | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) -def build_remote_file_info_model(api_or_ns: Api | Namespace): - """Build the remote file info model for the API or Namespace. - - Args: - api_or_ns: Flask-RestX Api or Namespace instance - - Returns: - The registered model - """ - return api_or_ns.model("RemoteFileInfo", remote_file_info_fields) +class RemoteFileInfo(ResponseModel): + file_type: str + file_length: int -file_fields_with_signed_url = { - "id": fields.String, - "name": fields.String, - "size": fields.Integer, - "extension": fields.String, - "url": fields.String, - "mime_type": fields.String, - "created_by": fields.String, - "created_at": TimestampField, -} +class FileWithSignedUrl(ResponseModel): + id: str + name: str + size: int + extension: str | None = None + url: str | None = None + mime_type: str | None = None + created_by: str | None = None + created_at: int | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) -def build_file_with_signed_url_model(api_or_ns: Api | Namespace): - """Build the file with signed URL model for the API or Namespace. 
- - Args: - api_or_ns: Flask-RestX Api or Namespace instance - - Returns: - The registered model - """ - return api_or_ns.model("FileWithSignedUrl", file_fields_with_signed_url) +__all__ = [ + "FileResponse", + "FileWithSignedUrl", + "RemoteFileInfo", + "UploadConfig", +] diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index 08e38a6931..25160927e6 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -1,4 +1,4 @@ -from flask_restx import Api, Namespace, fields +from flask_restx import Namespace, fields from libs.helper import AvatarUrlField, TimestampField @@ -9,7 +9,7 @@ simple_account_fields = { } -def build_simple_account_model(api_or_ns: Api | Namespace): +def build_simple_account_model(api_or_ns: Namespace): return api_or_ns.model("SimpleAccount", simple_account_fields) diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 8b9bcac76f..797f01c00c 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -1,78 +1,138 @@ -from flask_restx import Api, Namespace, fields +from __future__ import annotations -from fields.conversation_fields import message_file_fields -from libs.helper import TimestampField +from datetime import datetime +from typing import TypeAlias -from .raws import FilesContainedField +from pydantic import BaseModel, ConfigDict, Field, field_validator -feedback_fields = { - "rating": fields.String, -} +from core.file import File +from fields.conversation_fields import AgentThought, JSONValue, MessageFile + +JSONValueType: TypeAlias = JSONValue -def build_feedback_model(api_or_ns: Api | Namespace): - """Build the feedback model for the API or Namespace.""" - return api_or_ns.model("Feedback", feedback_fields) +class ResponseModel(BaseModel): + model_config = ConfigDict(from_attributes=True, extra="ignore") -agent_thought_fields = { - "id": fields.String, - "chain_id": fields.String, - "message_id": fields.String, - "position": fields.Integer, - "thought": fields.String, - "tool": fields.String, - "tool_labels": fields.Raw, - "tool_input": fields.String, - "created_at": TimestampField, - "observation": fields.String, - "files": fields.List(fields.String), -} +class SimpleFeedback(ResponseModel): + rating: str | None = None -def build_agent_thought_model(api_or_ns: Api | Namespace): - """Build the agent thought model for the API or Namespace.""" - return api_or_ns.model("AgentThought", agent_thought_fields) +class RetrieverResource(ResponseModel): + id: str + message_id: str + position: int + dataset_id: str | None = None + dataset_name: str | None = None + document_id: str | None = None + document_name: str | None = None + data_source_type: str | None = None + segment_id: str | None = None + score: float | None = None + hit_count: int | None = None + word_count: int | None = None + segment_position: int | None = None + index_node_hash: str | None = None + content: str | None = None + created_at: int | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return to_timestamp(value) + return value -retriever_resource_fields = { - "id": fields.String, - "message_id": fields.String, - "position": fields.Integer, - "dataset_id": fields.String, - "dataset_name": fields.String, - "document_id": fields.String, - "document_name": fields.String, - "data_source_type": fields.String, - "segment_id": fields.String, - "score": fields.Float, - "hit_count": 
fields.Integer, - "word_count": fields.Integer, - "segment_position": fields.Integer, - "index_node_hash": fields.String, - "content": fields.String, - "created_at": TimestampField, -} +class MessageListItem(ResponseModel): + id: str + conversation_id: str + parent_message_id: str | None = None + inputs: dict[str, JSONValueType] + query: str + answer: str = Field(validation_alias="re_sign_file_url_answer") + feedback: SimpleFeedback | None = Field(default=None, validation_alias="user_feedback") + retriever_resources: list[RetrieverResource] + created_at: int | None = None + agent_thoughts: list[AgentThought] + message_files: list[MessageFile] + status: str + error: str | None = None + generation_detail: JSONValueType | None = Field(default=None, validation_alias="generation_detail_dict") -message_fields = { - "id": fields.String, - "conversation_id": fields.String, - "parent_message_id": fields.String, - "inputs": FilesContainedField, - "query": fields.String, - "answer": fields.String(attribute="re_sign_file_url_answer"), - "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), - "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), - "created_at": TimestampField, - "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), - "message_files": fields.List(fields.Nested(message_file_fields)), - "status": fields.String, - "error": fields.String, - "generation_detail": fields.Raw, -} + @field_validator("inputs", mode="before") + @classmethod + def _normalize_inputs(cls, value: JSONValueType) -> JSONValueType: + return format_files_contained(value) -message_infinite_scroll_pagination_fields = { - "limit": fields.Integer, - "has_more": fields.Boolean, - "data": fields.List(fields.Nested(message_fields)), -} + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return to_timestamp(value) + return value + + +class WebMessageListItem(MessageListItem): + metadata: JSONValueType | None = Field(default=None, validation_alias="message_metadata_dict") + + +class MessageInfiniteScrollPagination(ResponseModel): + limit: int + has_more: bool + data: list[MessageListItem] + + +class WebMessageInfiniteScrollPagination(ResponseModel): + limit: int + has_more: bool + data: list[WebMessageListItem] + + +class SavedMessageItem(ResponseModel): + id: str + inputs: dict[str, JSONValueType] + query: str + answer: str + message_files: list[MessageFile] + feedback: SimpleFeedback | None = Field(default=None, validation_alias="user_feedback") + created_at: int | None = None + + @field_validator("inputs", mode="before") + @classmethod + def _normalize_inputs(cls, value: JSONValueType) -> JSONValueType: + return format_files_contained(value) + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return to_timestamp(value) + return value + + +class SavedMessageInfiniteScrollPagination(ResponseModel): + limit: int + has_more: bool + data: list[SavedMessageItem] + + +class SuggestedQuestionsResponse(ResponseModel): + data: list[str] + + +def to_timestamp(value: datetime | None) -> int | None: + if value is None: + return None + return int(value.timestamp()) + + +def format_files_contained(value: JSONValueType) -> JSONValueType: + if isinstance(value, File): + return value.model_dump() + if isinstance(value, dict): + 
return {k: format_files_contained(v) for k, v in value.items()} + if isinstance(value, list): + return [format_files_contained(v) for v in value] + return value diff --git a/api/fields/rag_pipeline_fields.py b/api/fields/rag_pipeline_fields.py index f9e858c68b..97c02e7085 100644 --- a/api/fields/rag_pipeline_fields.py +++ b/api/fields/rag_pipeline_fields.py @@ -1,4 +1,4 @@ -from flask_restx import fields # type: ignore +from flask_restx import fields from fields.workflow_fields import workflow_partial_fields from libs.helper import AppIconUrlField, TimestampField diff --git a/api/fields/tag_fields.py b/api/fields/tag_fields.py index d5b7c86a04..e359a4408c 100644 --- a/api/fields/tag_fields.py +++ b/api/fields/tag_fields.py @@ -1,4 +1,4 @@ -from flask_restx import Api, Namespace, fields +from flask_restx import Namespace, fields dataset_tag_fields = { "id": fields.String, @@ -8,5 +8,5 @@ dataset_tag_fields = { } -def build_dataset_tag_fields(api_or_ns: Api | Namespace): +def build_dataset_tag_fields(api_or_ns: Namespace): return api_or_ns.model("DataSetTag", dataset_tag_fields) diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py index 4cbdf6f0ca..0ebc03a98c 100644 --- a/api/fields/workflow_app_log_fields.py +++ b/api/fields/workflow_app_log_fields.py @@ -1,4 +1,4 @@ -from flask_restx import Api, Namespace, fields +from flask_restx import Namespace, fields from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields from fields.member_fields import build_simple_account_model, simple_account_fields @@ -17,7 +17,7 @@ workflow_app_log_partial_fields = { } -def build_workflow_app_log_partial_model(api_or_ns: Api | Namespace): +def build_workflow_app_log_partial_model(api_or_ns: Namespace): """Build the workflow app log partial model for the API or Namespace.""" workflow_run_model = build_workflow_run_for_log_model(api_or_ns) simple_account_model = build_simple_account_model(api_or_ns) @@ -43,7 +43,7 @@ workflow_app_log_pagination_fields = { } -def build_workflow_app_log_pagination_model(api_or_ns: Api | Namespace): +def build_workflow_app_log_pagination_model(api_or_ns: Namespace): """Build the workflow app log pagination model for the API or Namespace.""" # Build the nested partial model first workflow_app_log_partial_model = build_workflow_app_log_partial_model(api_or_ns) diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 7b878e05c8..1b2948811b 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -1,4 +1,4 @@ -from flask_restx import Api, Namespace, fields +from flask_restx import Namespace, fields from fields.end_user_fields import simple_end_user_fields from fields.member_fields import simple_account_fields @@ -19,7 +19,7 @@ workflow_run_for_log_fields = { } -def build_workflow_run_for_log_model(api_or_ns: Api | Namespace): +def build_workflow_run_for_log_model(api_or_ns: Namespace): return api_or_ns.model("WorkflowRunForLog", workflow_run_for_log_fields) diff --git a/api/libs/archive_storage.py b/api/libs/archive_storage.py new file mode 100644 index 0000000000..f84d226447 --- /dev/null +++ b/api/libs/archive_storage.py @@ -0,0 +1,347 @@ +""" +Archive Storage Client for S3-compatible storage. + +This module provides a dedicated storage client for archiving or exporting logs +to S3-compatible object storage. 
+""" + +import base64 +import datetime +import gzip +import hashlib +import logging +from collections.abc import Generator +from typing import Any, cast + +import boto3 +import orjson +from botocore.client import Config +from botocore.exceptions import ClientError + +from configs import dify_config + +logger = logging.getLogger(__name__) + + +class ArchiveStorageError(Exception): + """Base exception for archive storage operations.""" + + pass + + +class ArchiveStorageNotConfiguredError(ArchiveStorageError): + """Raised when archive storage is not properly configured.""" + + pass + + +class ArchiveStorage: + """ + S3-compatible storage client for archiving or exporting. + + This client provides methods for storing and retrieving archived data in JSONL+gzip format. + """ + + def __init__(self, bucket: str): + if not dify_config.ARCHIVE_STORAGE_ENABLED: + raise ArchiveStorageNotConfiguredError("Archive storage is not enabled") + + if not bucket: + raise ArchiveStorageNotConfiguredError("Archive storage bucket is not configured") + if not all( + [ + dify_config.ARCHIVE_STORAGE_ENDPOINT, + bucket, + dify_config.ARCHIVE_STORAGE_ACCESS_KEY, + dify_config.ARCHIVE_STORAGE_SECRET_KEY, + ] + ): + raise ArchiveStorageNotConfiguredError( + "Archive storage configuration is incomplete. " + "Required: ARCHIVE_STORAGE_ENDPOINT, ARCHIVE_STORAGE_ACCESS_KEY, " + "ARCHIVE_STORAGE_SECRET_KEY, and a bucket name" + ) + + self.bucket = bucket + self.client = boto3.client( + "s3", + endpoint_url=dify_config.ARCHIVE_STORAGE_ENDPOINT, + aws_access_key_id=dify_config.ARCHIVE_STORAGE_ACCESS_KEY, + aws_secret_access_key=dify_config.ARCHIVE_STORAGE_SECRET_KEY, + region_name=dify_config.ARCHIVE_STORAGE_REGION, + config=Config(s3={"addressing_style": "path"}), + ) + + # Verify bucket accessibility + try: + self.client.head_bucket(Bucket=self.bucket) + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code") + if error_code == "404": + raise ArchiveStorageNotConfiguredError(f"Archive bucket '{self.bucket}' does not exist") + elif error_code == "403": + raise ArchiveStorageNotConfiguredError(f"Access denied to archive bucket '{self.bucket}'") + else: + raise ArchiveStorageError(f"Failed to access archive bucket: {e}") + + def put_object(self, key: str, data: bytes) -> str: + """ + Upload an object to the archive storage. + + Args: + key: Object key (path) within the bucket + data: Binary data to upload + + Returns: + MD5 checksum of the uploaded data + + Raises: + ArchiveStorageError: If upload fails + """ + checksum = hashlib.md5(data).hexdigest() + try: + self.client.put_object( + Bucket=self.bucket, + Key=key, + Body=data, + ContentMD5=self._content_md5(data), + ) + logger.debug("Uploaded object: %s (size=%d, checksum=%s)", key, len(data), checksum) + return checksum + except ClientError as e: + raise ArchiveStorageError(f"Failed to upload object '{key}': {e}") + + def get_object(self, key: str) -> bytes: + """ + Download an object from the archive storage. 
+ + Args: + key: Object key (path) within the bucket + + Returns: + Binary data of the object + + Raises: + ArchiveStorageError: If download fails + FileNotFoundError: If object does not exist + """ + try: + response = self.client.get_object(Bucket=self.bucket, Key=key) + return response["Body"].read() + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code") + if error_code == "NoSuchKey": + raise FileNotFoundError(f"Archive object not found: {key}") + raise ArchiveStorageError(f"Failed to download object '{key}': {e}") + + def get_object_stream(self, key: str) -> Generator[bytes, None, None]: + """ + Stream an object from the archive storage. + + Args: + key: Object key (path) within the bucket + + Yields: + Chunks of binary data + + Raises: + ArchiveStorageError: If download fails + FileNotFoundError: If object does not exist + """ + try: + response = self.client.get_object(Bucket=self.bucket, Key=key) + yield from response["Body"].iter_chunks() + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code") + if error_code == "NoSuchKey": + raise FileNotFoundError(f"Archive object not found: {key}") + raise ArchiveStorageError(f"Failed to stream object '{key}': {e}") + + def object_exists(self, key: str) -> bool: + """ + Check if an object exists in the archive storage. + + Args: + key: Object key (path) within the bucket + + Returns: + True if object exists, False otherwise + """ + try: + self.client.head_object(Bucket=self.bucket, Key=key) + return True + except ClientError: + return False + + def delete_object(self, key: str) -> None: + """ + Delete an object from the archive storage. + + Args: + key: Object key (path) within the bucket + + Raises: + ArchiveStorageError: If deletion fails + """ + try: + self.client.delete_object(Bucket=self.bucket, Key=key) + logger.debug("Deleted object: %s", key) + except ClientError as e: + raise ArchiveStorageError(f"Failed to delete object '{key}': {e}") + + def generate_presigned_url(self, key: str, expires_in: int = 3600) -> str: + """ + Generate a pre-signed URL for downloading an object. + + Args: + key: Object key (path) within the bucket + expires_in: URL validity duration in seconds (default: 1 hour) + + Returns: + Pre-signed URL string. + + Raises: + ArchiveStorageError: If generation fails + """ + try: + return self.client.generate_presigned_url( + ClientMethod="get_object", + Params={"Bucket": self.bucket, "Key": key}, + ExpiresIn=expires_in, + ) + except ClientError as e: + raise ArchiveStorageError(f"Failed to generate pre-signed URL for '{key}': {e}") + + def list_objects(self, prefix: str) -> list[str]: + """ + List objects under a given prefix. + + Args: + prefix: Object key prefix to filter by + + Returns: + List of object keys matching the prefix + """ + keys = [] + paginator = self.client.get_paginator("list_objects_v2") + + try: + for page in paginator.paginate(Bucket=self.bucket, Prefix=prefix): + for obj in page.get("Contents", []): + keys.append(obj["Key"]) + except ClientError as e: + raise ArchiveStorageError(f"Failed to list objects with prefix '{prefix}': {e}") + + return keys + + @staticmethod + def _content_md5(data: bytes) -> str: + """Calculate base64-encoded MD5 for Content-MD5 header.""" + return base64.b64encode(hashlib.md5(data).digest()).decode() + + @staticmethod + def serialize_to_jsonl_gz(records: list[dict[str, Any]]) -> bytes: + """ + Serialize records to gzipped JSONL format. 
+ + Args: + records: List of dictionaries to serialize + + Returns: + Gzipped JSONL bytes + """ + lines = [] + for record in records: + # Convert datetime objects to ISO format strings + serialized = ArchiveStorage._serialize_record(record) + lines.append(orjson.dumps(serialized)) + + jsonl_content = b"\n".join(lines) + if jsonl_content: + jsonl_content += b"\n" + + return gzip.compress(jsonl_content) + + @staticmethod + def deserialize_from_jsonl_gz(data: bytes) -> list[dict[str, Any]]: + """ + Deserialize gzipped JSONL data to records. + + Args: + data: Gzipped JSONL bytes + + Returns: + List of dictionaries + """ + jsonl_content = gzip.decompress(data) + records = [] + + for line in jsonl_content.splitlines(): + if line: + records.append(orjson.loads(line)) + + return records + + @staticmethod + def _serialize_record(record: dict[str, Any]) -> dict[str, Any]: + """Serialize a single record, converting special types.""" + + def _serialize(item: Any) -> Any: + if isinstance(item, datetime.datetime): + return item.isoformat() + if isinstance(item, dict): + return {key: _serialize(value) for key, value in item.items()} + if isinstance(item, list): + return [_serialize(value) for value in item] + return item + + return cast(dict[str, Any], _serialize(record)) + + @staticmethod + def compute_checksum(data: bytes) -> str: + """Compute MD5 checksum of data.""" + return hashlib.md5(data).hexdigest() + + +# Singleton instance (lazy initialization) +_archive_storage: ArchiveStorage | None = None +_export_storage: ArchiveStorage | None = None + + +def get_archive_storage() -> ArchiveStorage: + """ + Get the archive storage singleton instance. + + Returns: + ArchiveStorage instance + + Raises: + ArchiveStorageNotConfiguredError: If archive storage is not configured + """ + global _archive_storage + if _archive_storage is None: + archive_bucket = dify_config.ARCHIVE_STORAGE_ARCHIVE_BUCKET + if not archive_bucket: + raise ArchiveStorageNotConfiguredError( + "Archive storage bucket is not configured. Required: ARCHIVE_STORAGE_ARCHIVE_BUCKET" + ) + _archive_storage = ArchiveStorage(bucket=archive_bucket) + return _archive_storage + + +def get_export_storage() -> ArchiveStorage: + """ + Get the export storage singleton instance. + + Returns: + ArchiveStorage instance + """ + global _export_storage + if _export_storage is None: + export_bucket = dify_config.ARCHIVE_STORAGE_EXPORT_BUCKET + if not export_bucket: + raise ArchiveStorageNotConfiguredError( + "Archive export bucket is not configured. Required: ARCHIVE_STORAGE_EXPORT_BUCKET" + ) + _export_storage = ArchiveStorage(bucket=export_bucket) + return _export_storage diff --git a/api/libs/broadcast_channel/channel.py b/api/libs/broadcast_channel/channel.py index 5bbf0c79a3..d4cb3e9971 100644 --- a/api/libs/broadcast_channel/channel.py +++ b/api/libs/broadcast_channel/channel.py @@ -2,6 +2,8 @@ Broadcast channel for Pub/Sub messaging. """ +from __future__ import annotations + import types from abc import abstractmethod from collections.abc import Iterator @@ -129,6 +131,6 @@ class BroadcastChannel(Protocol): """ @abstractmethod - def topic(self, topic: str) -> "Topic": + def topic(self, topic: str) -> Topic: """topic returns a `Topic` instance for the given topic name.""" ... 
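Illustrative usage of the `ArchiveStorage` client added in `api/libs/archive_storage.py` above — a minimal sketch, not part of this diff, assuming `ARCHIVE_STORAGE_ENABLED` plus the endpoint/credential/bucket settings are configured and the bucket is reachable; the object key and record contents are made-up examples:

```python
from libs.archive_storage import ArchiveStorage, get_archive_storage

# Resolve the lazily-initialized singleton (raises ArchiveStorageNotConfiguredError
# if ARCHIVE_STORAGE_ARCHIVE_BUCKET or the other required settings are missing).
storage = get_archive_storage()

# Serialize a batch of records to gzipped JSONL and upload it; put_object returns
# the MD5 checksum of the uploaded bytes.
records = [
    {"id": "1", "event": "message_created"},
    {"id": "2", "event": "message_archived"},
]
payload = ArchiveStorage.serialize_to_jsonl_gz(records)
key = "archives/2025/01/batch-0001.jsonl.gz"  # hypothetical key layout
checksum = storage.put_object(key, payload)
assert checksum == ArchiveStorage.compute_checksum(payload)

# Later, download the object and restore the original records.
restored = ArchiveStorage.deserialize_from_jsonl_gz(storage.get_object(key))
assert restored == records
```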
diff --git a/api/libs/broadcast_channel/redis/channel.py b/api/libs/broadcast_channel/redis/channel.py index 1fc3db8156..5bb4f579c1 100644 --- a/api/libs/broadcast_channel/redis/channel.py +++ b/api/libs/broadcast_channel/redis/channel.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from libs.broadcast_channel.channel import Producer, Subscriber, Subscription from redis import Redis @@ -20,7 +22,7 @@ class BroadcastChannel: ): self._client = redis_client - def topic(self, topic: str) -> "Topic": + def topic(self, topic: str) -> Topic: return Topic(self._client, topic) diff --git a/api/libs/broadcast_channel/redis/sharded_channel.py b/api/libs/broadcast_channel/redis/sharded_channel.py index 16e3a80ee1..d190c51bbc 100644 --- a/api/libs/broadcast_channel/redis/sharded_channel.py +++ b/api/libs/broadcast_channel/redis/sharded_channel.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from libs.broadcast_channel.channel import Producer, Subscriber, Subscription from redis import Redis @@ -18,7 +20,7 @@ class ShardedRedisBroadcastChannel: ): self._client = redis_client - def topic(self, topic: str) -> "ShardedTopic": + def topic(self, topic: str) -> ShardedTopic: return ShardedTopic(self._client, topic) diff --git a/api/libs/email_i18n.py b/api/libs/email_i18n.py index ff74ccbe8e..0828cf80bf 100644 --- a/api/libs/email_i18n.py +++ b/api/libs/email_i18n.py @@ -6,6 +6,8 @@ in Dify. It follows Domain-Driven Design principles with proper type hints and eliminates the need for repetitive language switching logic. """ +from __future__ import annotations + from dataclasses import dataclass from enum import StrEnum, auto from typing import Any, Protocol @@ -53,7 +55,7 @@ class EmailLanguage(StrEnum): ZH_HANS = "zh-Hans" @classmethod - def from_language_code(cls, language_code: str) -> "EmailLanguage": + def from_language_code(cls, language_code: str) -> EmailLanguage: """Convert a language code to EmailLanguage with fallback to English.""" if language_code == "zh-Hans": return cls.ZH_HANS diff --git a/api/libs/external_api.py b/api/libs/external_api.py index 61a90ee4a9..e8592407c3 100644 --- a/api/libs/external_api.py +++ b/api/libs/external_api.py @@ -1,5 +1,4 @@ import re -import sys from collections.abc import Mapping from typing import Any @@ -109,11 +108,8 @@ def register_external_error_handlers(api: Api): data.setdefault("code", "unknown") data.setdefault("status", status_code) - # Log stack - exc_info: Any = sys.exc_info() - if exc_info[1] is None: - exc_info = (None, None, None) - current_app.log_exception(exc_info) + # Note: Exception logging is handled by Flask/Flask-RESTX framework automatically + # Explicit log_exception call removed to avoid duplicate log entries return data, status_code diff --git a/api/libs/helper.py b/api/libs/helper.py index 74e1808e49..07c4823727 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -32,6 +32,38 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def escape_like_pattern(pattern: str) -> str: + """ + Escape special characters in a string for safe use in SQL LIKE patterns. + + This function escapes the special characters used in SQL LIKE patterns: + - Backslash (\\) -> \\ + - Percent (%) -> \\% + - Underscore (_) -> \\_ + + The escaped pattern can then be safely used in SQL LIKE queries with the + ESCAPE '\\' clause to prevent SQL injection via LIKE wildcards. 
+ + Args: + pattern: The string pattern to escape + + Returns: + Escaped string safe for use in SQL LIKE queries + + Examples: + >>> escape_like_pattern("50% discount") + '50\\% discount' + >>> escape_like_pattern("test_data") + 'test\\_data' + >>> escape_like_pattern("path\\to\\file") + 'path\\\\to\\\\file' + """ + if not pattern: + return pattern + # Escape backslash first, then percent and underscore + return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") + + def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None: """ Extract tenant_id from Account or EndUser object. diff --git a/api/migrations/versions/00bacef91f18_rename_api_provider_description.py b/api/migrations/versions/00bacef91f18_rename_api_provider_description.py index 17ed067d81..657d28f896 100644 --- a/api/migrations/versions/00bacef91f18_rename_api_provider_description.py +++ b/api/migrations/versions/00bacef91f18_rename_api_provider_description.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '00bacef91f18' down_revision = '8ec536f3c800' @@ -23,31 +20,17 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.add_column(sa.Column('description', sa.Text(), nullable=False)) - batch_op.drop_column('description_str') - else: - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.add_column(sa.Column('description', models.types.LongText(), nullable=False)) - batch_op.drop_column('description_str') + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('description', models.types.LongText(), nullable=False)) + batch_op.drop_column('description_str') # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.add_column(sa.Column('description_str', sa.TEXT(), autoincrement=False, nullable=False)) - batch_op.drop_column('description') - else: - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.add_column(sa.Column('description_str', models.types.LongText(), autoincrement=False, nullable=False)) - batch_op.drop_column('description') + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('description_str', models.types.LongText(), autoincrement=False, nullable=False)) + batch_op.drop_column('description') # ### end Alembic commands ### diff --git a/api/migrations/versions/114eed84c228_remove_tool_id_from_model_invoke.py b/api/migrations/versions/114eed84c228_remove_tool_id_from_model_invoke.py index ed70bf5d08..912d9dbfa4 100644 --- a/api/migrations/versions/114eed84c228_remove_tool_id_from_model_invoke.py +++ b/api/migrations/versions/114eed84c228_remove_tool_id_from_model_invoke.py @@ -7,14 +7,10 @@ Create Date: 2024-01-10 04:40:57.257824 """ import sqlalchemy as sa from alembic import op -from sqlalchemy.dialects import postgresql import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. 
revision = '114eed84c228' down_revision = 'c71211c8f604' @@ -32,13 +28,7 @@ def upgrade(): def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op: - batch_op.add_column(sa.Column('tool_id', postgresql.UUID(), autoincrement=False, nullable=False)) - else: - with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op: - batch_op.add_column(sa.Column('tool_id', models.types.StringUUID(), autoincrement=False, nullable=False)) + with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op: + batch_op.add_column(sa.Column('tool_id', models.types.StringUUID(), autoincrement=False, nullable=False)) # ### end Alembic commands ### diff --git a/api/migrations/versions/161cadc1af8d_add_dataset_permission_tenant_id.py b/api/migrations/versions/161cadc1af8d_add_dataset_permission_tenant_id.py index 509bd5d0e8..0ca905129d 100644 --- a/api/migrations/versions/161cadc1af8d_add_dataset_permission_tenant_id.py +++ b/api/migrations/versions/161cadc1af8d_add_dataset_permission_tenant_id.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '161cadc1af8d' down_revision = '7e6a8693e07a' @@ -23,16 +20,9 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('dataset_permissions', schema=None) as batch_op: - # Step 1: Add column without NOT NULL constraint - op.add_column('dataset_permissions', sa.Column('tenant_id', sa.UUID(), nullable=False)) - else: - with op.batch_alter_table('dataset_permissions', schema=None) as batch_op: - # Step 1: Add column without NOT NULL constraint - op.add_column('dataset_permissions', sa.Column('tenant_id', models.types.StringUUID(), nullable=False)) + with op.batch_alter_table('dataset_permissions', schema=None) as batch_op: + # Step 1: Add column without NOT NULL constraint + op.add_column('dataset_permissions', sa.Column('tenant_id', models.types.StringUUID(), nullable=False)) # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py b/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py index 0767b725f6..be1b42f883 100644 --- a/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py +++ b/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py @@ -9,11 +9,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" -import sqlalchemy as sa -from sqlalchemy.dialects import postgresql - # revision identifiers, used by Alembic. revision = '6af6a521a53e' down_revision = 'd57ba9ebb251' @@ -23,58 +18,30 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op: - batch_op.alter_column('document_id', - existing_type=sa.UUID(), - nullable=True) - batch_op.alter_column('data_source_type', - existing_type=sa.TEXT(), - nullable=True) - batch_op.alter_column('segment_id', - existing_type=sa.UUID(), - nullable=True) - else: - with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op: - batch_op.alter_column('document_id', - existing_type=models.types.StringUUID(), - nullable=True) - batch_op.alter_column('data_source_type', - existing_type=models.types.LongText(), - nullable=True) - batch_op.alter_column('segment_id', - existing_type=models.types.StringUUID(), - nullable=True) + with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op: + batch_op.alter_column('document_id', + existing_type=models.types.StringUUID(), + nullable=True) + batch_op.alter_column('data_source_type', + existing_type=models.types.LongText(), + nullable=True) + batch_op.alter_column('segment_id', + existing_type=models.types.StringUUID(), + nullable=True) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op: - batch_op.alter_column('segment_id', - existing_type=sa.UUID(), - nullable=False) - batch_op.alter_column('data_source_type', - existing_type=sa.TEXT(), - nullable=False) - batch_op.alter_column('document_id', - existing_type=sa.UUID(), - nullable=False) - else: - with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op: - batch_op.alter_column('segment_id', - existing_type=models.types.StringUUID(), - nullable=False) - batch_op.alter_column('data_source_type', - existing_type=models.types.LongText(), - nullable=False) - batch_op.alter_column('document_id', - existing_type=models.types.StringUUID(), - nullable=False) + with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op: + batch_op.alter_column('segment_id', + existing_type=models.types.StringUUID(), + nullable=False) + batch_op.alter_column('data_source_type', + existing_type=models.types.LongText(), + nullable=False) + batch_op.alter_column('document_id', + existing_type=models.types.StringUUID(), + nullable=False) # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py b/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py index a749c8bddf..5d12419bf7 100644 --- a/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py +++ b/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py @@ -8,7 +8,6 @@ Create Date: 2024-11-01 04:34:23.816198 from alembic import op import models as models import sqlalchemy as sa -from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. 
revision = 'd3f6769a94a3' diff --git a/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py b/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py index 45842295ea..a49d6a52f6 100644 --- a/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py +++ b/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py @@ -28,85 +28,45 @@ def upgrade(): op.execute("UPDATE sites SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL") op.execute("UPDATE tool_api_providers SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL") - if _is_pg(conn): - with op.batch_alter_table('recommended_apps', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=sa.VARCHAR(length=255), - type_=sa.TEXT(), - nullable=False) + with op.batch_alter_table('recommended_apps', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.VARCHAR(length=255), + type_=models.types.LongText(), + nullable=False) - with op.batch_alter_table('sites', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=sa.VARCHAR(length=255), - type_=sa.TEXT(), - nullable=False) + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.VARCHAR(length=255), + type_=models.types.LongText(), + nullable=False) - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=sa.VARCHAR(length=255), - type_=sa.TEXT(), - nullable=False) - else: - with op.batch_alter_table('recommended_apps', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=sa.VARCHAR(length=255), - type_=models.types.LongText(), - nullable=False) - - with op.batch_alter_table('sites', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=sa.VARCHAR(length=255), - type_=models.types.LongText(), - nullable=False) - - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=sa.VARCHAR(length=255), - type_=models.types.LongText(), - nullable=False) + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.VARCHAR(length=255), + type_=models.types.LongText(), + nullable=False) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=sa.TEXT(), - type_=sa.VARCHAR(length=255), - nullable=True) + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=models.types.LongText(), + type_=sa.VARCHAR(length=255), + nullable=True) - with op.batch_alter_table('sites', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=sa.TEXT(), - type_=sa.VARCHAR(length=255), - nullable=True) + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=models.types.LongText(), + type_=sa.VARCHAR(length=255), + nullable=True) - with op.batch_alter_table('recommended_apps', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=sa.TEXT(), - type_=sa.VARCHAR(length=255), - nullable=True) - else: - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=models.types.LongText(), - type_=sa.VARCHAR(length=255), - nullable=True) - - with op.batch_alter_table('sites', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=models.types.LongText(), - type_=sa.VARCHAR(length=255), - nullable=True) - - with op.batch_alter_table('recommended_apps', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=models.types.LongText(), - type_=sa.VARCHAR(length=255), - nullable=True) + with op.batch_alter_table('recommended_apps', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=models.types.LongText(), + type_=sa.VARCHAR(length=255), + nullable=True) # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py b/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py index fdd8984029..8a36c9c4a5 100644 --- a/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py +++ b/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py @@ -49,57 +49,33 @@ def upgrade(): op.execute("UPDATE workflows SET updated_at = created_at WHERE updated_at IS NULL") op.execute("UPDATE workflows SET graph = '' WHERE graph IS NULL") op.execute("UPDATE workflows SET features = '' WHERE features IS NULL") - if _is_pg(conn): - with op.batch_alter_table('workflows', schema=None) as batch_op: - batch_op.alter_column('graph', - existing_type=sa.TEXT(), - nullable=False) - batch_op.alter_column('features', - existing_type=sa.TEXT(), - nullable=False) - batch_op.alter_column('updated_at', - existing_type=postgresql.TIMESTAMP(), - nullable=False) - else: - with op.batch_alter_table('workflows', schema=None) as batch_op: - batch_op.alter_column('graph', - existing_type=models.types.LongText(), - nullable=False) - batch_op.alter_column('features', - existing_type=models.types.LongText(), - nullable=False) - batch_op.alter_column('updated_at', - existing_type=sa.TIMESTAMP(), - nullable=False) + + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.alter_column('graph', + existing_type=models.types.LongText(), + nullable=False) + batch_op.alter_column('features', + existing_type=models.types.LongText(), + nullable=False) + 
batch_op.alter_column('updated_at', + existing_type=sa.TIMESTAMP(), + nullable=False) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('workflows', schema=None) as batch_op: - batch_op.alter_column('updated_at', - existing_type=postgresql.TIMESTAMP(), - nullable=True) - batch_op.alter_column('features', - existing_type=sa.TEXT(), - nullable=True) - batch_op.alter_column('graph', - existing_type=sa.TEXT(), - nullable=True) - else: - with op.batch_alter_table('workflows', schema=None) as batch_op: - batch_op.alter_column('updated_at', - existing_type=sa.TIMESTAMP(), - nullable=True) - batch_op.alter_column('features', - existing_type=models.types.LongText(), - nullable=True) - batch_op.alter_column('graph', - existing_type=models.types.LongText(), - nullable=True) + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.alter_column('updated_at', + existing_type=sa.TIMESTAMP(), + nullable=True) + batch_op.alter_column('features', + existing_type=models.types.LongText(), + nullable=True) + batch_op.alter_column('graph', + existing_type=models.types.LongText(), + nullable=True) if _is_pg(conn): with op.batch_alter_table('messages', schema=None) as batch_op: diff --git a/api/migrations/versions/2025_08_13_1605-0e154742a5fa_add_provider_model_multi_credential.py b/api/migrations/versions/2025_08_13_1605-0e154742a5fa_add_provider_model_multi_credential.py index 16ca902726..1fc4a64df1 100644 --- a/api/migrations/versions/2025_08_13_1605-0e154742a5fa_add_provider_model_multi_credential.py +++ b/api/migrations/versions/2025_08_13_1605-0e154742a5fa_add_provider_model_multi_credential.py @@ -86,57 +86,30 @@ def upgrade(): def migrate_existing_provider_models_data(): """migrate provider_models table data to provider_model_credentials""" - conn = op.get_bind() - # Define table structure for data manipulation - if _is_pg(conn): - provider_models_table = table('provider_models', - column('id', models.types.StringUUID()), - column('tenant_id', models.types.StringUUID()), - column('provider_name', sa.String()), - column('model_name', sa.String()), - column('model_type', sa.String()), - column('encrypted_config', sa.Text()), - column('created_at', sa.DateTime()), - column('updated_at', sa.DateTime()), - column('credential_id', models.types.StringUUID()), - ) - else: - provider_models_table = table('provider_models', - column('id', models.types.StringUUID()), - column('tenant_id', models.types.StringUUID()), - column('provider_name', sa.String()), - column('model_name', sa.String()), - column('model_type', sa.String()), - column('encrypted_config', models.types.LongText()), - column('created_at', sa.DateTime()), - column('updated_at', sa.DateTime()), - column('credential_id', models.types.StringUUID()), - ) + # Define table structure for data manipulation + provider_models_table = table('provider_models', + column('id', models.types.StringUUID()), + column('tenant_id', models.types.StringUUID()), + column('provider_name', sa.String()), + column('model_name', sa.String()), + column('model_type', sa.String()), + column('encrypted_config', models.types.LongText()), + column('created_at', sa.DateTime()), + column('updated_at', sa.DateTime()), + column('credential_id', models.types.StringUUID()), + ) - if _is_pg(conn): - provider_model_credentials_table = table('provider_model_credentials', - column('id', models.types.StringUUID()), - column('tenant_id', 
models.types.StringUUID()), - column('provider_name', sa.String()), - column('model_name', sa.String()), - column('model_type', sa.String()), - column('credential_name', sa.String()), - column('encrypted_config', sa.Text()), - column('created_at', sa.DateTime()), - column('updated_at', sa.DateTime()) - ) - else: - provider_model_credentials_table = table('provider_model_credentials', - column('id', models.types.StringUUID()), - column('tenant_id', models.types.StringUUID()), - column('provider_name', sa.String()), - column('model_name', sa.String()), - column('model_type', sa.String()), - column('credential_name', sa.String()), - column('encrypted_config', models.types.LongText()), - column('created_at', sa.DateTime()), - column('updated_at', sa.DateTime()) - ) + provider_model_credentials_table = table('provider_model_credentials', + column('id', models.types.StringUUID()), + column('tenant_id', models.types.StringUUID()), + column('provider_name', sa.String()), + column('model_name', sa.String()), + column('model_type', sa.String()), + column('credential_name', sa.String()), + column('encrypted_config', models.types.LongText()), + column('created_at', sa.DateTime()), + column('updated_at', sa.DateTime()) + ) # Get database connection @@ -183,14 +156,8 @@ def migrate_existing_provider_models_data(): def downgrade(): # Re-add encrypted_config column to provider_models table - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('provider_models', schema=None) as batch_op: - batch_op.add_column(sa.Column('encrypted_config', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('provider_models', schema=None) as batch_op: - batch_op.add_column(sa.Column('encrypted_config', models.types.LongText(), nullable=True)) + with op.batch_alter_table('provider_models', schema=None) as batch_op: + batch_op.add_column(sa.Column('encrypted_config', models.types.LongText(), nullable=True)) if not context.is_offline_mode(): # Migrate data back from provider_model_credentials to provider_models diff --git a/api/migrations/versions/2025_08_20_1747-8d289573e1da_add_oauth_provider_apps.py b/api/migrations/versions/2025_08_20_1747-8d289573e1da_add_oauth_provider_apps.py index 75b4d61173..79fe9d9bba 100644 --- a/api/migrations/versions/2025_08_20_1747-8d289573e1da_add_oauth_provider_apps.py +++ b/api/migrations/versions/2025_08_20_1747-8d289573e1da_add_oauth_provider_apps.py @@ -8,7 +8,6 @@ Create Date: 2025-08-20 17:47:17.015695 from alembic import op import models as models import sqlalchemy as sa -from libs.uuid_utils import uuidv7 def _is_pg(conn): diff --git a/api/migrations/versions/2025_09_08_1007-c20211f18133_add_headers_to_mcp_provider.py b/api/migrations/versions/2025_09_08_1007-c20211f18133_add_headers_to_mcp_provider.py index 4f472fe4b4..cf2b973d2d 100644 --- a/api/migrations/versions/2025_09_08_1007-c20211f18133_add_headers_to_mcp_provider.py +++ b/api/migrations/versions/2025_09_08_1007-c20211f18133_add_headers_to_mcp_provider.py @@ -9,8 +9,6 @@ from alembic import op import models as models -def _is_pg(conn): - return conn.dialect.name == "postgresql" import sqlalchemy as sa @@ -23,12 +21,7 @@ depends_on = None def upgrade(): # Add encrypted_headers column to tool_mcp_providers table - conn = op.get_bind() - - if _is_pg(conn): - op.add_column('tool_mcp_providers', sa.Column('encrypted_headers', sa.Text(), nullable=True)) - else: - op.add_column('tool_mcp_providers', sa.Column('encrypted_headers', models.types.LongText(), nullable=True)) + 
op.add_column('tool_mcp_providers', sa.Column('encrypted_headers', models.types.LongText(), nullable=True)) def downgrade(): diff --git a/api/migrations/versions/2025_09_17_1515-68519ad5cd18_knowledge_pipeline_migrate.py b/api/migrations/versions/2025_09_17_1515-68519ad5cd18_knowledge_pipeline_migrate.py index 8eac0dee10..bad516dcac 100644 --- a/api/migrations/versions/2025_09_17_1515-68519ad5cd18_knowledge_pipeline_migrate.py +++ b/api/migrations/versions/2025_09_17_1515-68519ad5cd18_knowledge_pipeline_migrate.py @@ -44,6 +44,7 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name='datasource_oauth_config_pkey'), sa.UniqueConstraint('plugin_id', 'provider', name='datasource_oauth_config_datasource_id_provider_idx') ) + if _is_pg(conn): op.create_table('datasource_oauth_tenant_params', sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False), @@ -70,6 +71,7 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name='datasource_oauth_tenant_config_pkey'), sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='datasource_oauth_tenant_config_unique') ) + if _is_pg(conn): op.create_table('datasource_providers', sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False), @@ -104,6 +106,7 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name='datasource_provider_pkey'), sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', 'name', name='datasource_provider_unique_name') ) + with op.batch_alter_table('datasource_providers', schema=None) as batch_op: batch_op.create_index('datasource_provider_auth_type_provider_idx', ['tenant_id', 'plugin_id', 'provider'], unique=False) @@ -133,6 +136,7 @@ def upgrade(): sa.Column('created_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False), sa.PrimaryKeyConstraint('id', name='document_pipeline_execution_log_pkey') ) + with op.batch_alter_table('document_pipeline_execution_logs', schema=None) as batch_op: batch_op.create_index('document_pipeline_execution_logs_document_id_idx', ['document_id'], unique=False) @@ -174,6 +178,7 @@ def upgrade(): sa.Column('updated_by', models.types.StringUUID(), nullable=True), sa.PrimaryKeyConstraint('id', name='pipeline_built_in_template_pkey') ) + if _is_pg(conn): op.create_table('pipeline_customized_templates', sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False), @@ -193,7 +198,6 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name='pipeline_customized_template_pkey') ) else: - # MySQL: Use compatible syntax op.create_table('pipeline_customized_templates', sa.Column('id', models.types.StringUUID(), nullable=False), sa.Column('tenant_id', models.types.StringUUID(), nullable=False), @@ -211,6 +215,7 @@ def upgrade(): sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False), sa.PrimaryKeyConstraint('id', name='pipeline_customized_template_pkey') ) + with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op: batch_op.create_index('pipeline_customized_template_tenant_idx', ['tenant_id'], unique=False) @@ -236,6 +241,7 @@ def upgrade(): sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False), sa.PrimaryKeyConstraint('id', name='pipeline_recommended_plugin_pkey') ) + if _is_pg(conn): op.create_table('pipelines', sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False), @@ -266,6 +272,7 @@ def upgrade(): sa.Column('updated_at', sa.DateTime(), 
server_default=sa.func.current_timestamp(), nullable=False), sa.PrimaryKeyConstraint('id', name='pipeline_pkey') ) + if _is_pg(conn): op.create_table('workflow_draft_variable_files', sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False), @@ -292,6 +299,7 @@ def upgrade(): sa.Column('value_type', sa.String(20), nullable=False), sa.PrimaryKeyConstraint('id', name=op.f('workflow_draft_variable_files_pkey')) ) + if _is_pg(conn): op.create_table('workflow_node_execution_offload', sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False), @@ -316,6 +324,7 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name=op.f('workflow_node_execution_offload_pkey')), sa.UniqueConstraint('node_execution_id', 'type', name=op.f('workflow_node_execution_offload_node_execution_id_key')) ) + if _is_pg(conn): with op.batch_alter_table('datasets', schema=None) as batch_op: batch_op.add_column(sa.Column('keyword_number', sa.Integer(), server_default=sa.text('10'), nullable=True)) @@ -342,6 +351,7 @@ def upgrade(): comment='Indicates whether the current value is the default for a conversation variable. Always `FALSE` for other types of variables.',) ) batch_op.create_index('workflow_draft_variable_file_id_idx', ['file_id'], unique=False) + if _is_pg(conn): with op.batch_alter_table('workflows', schema=None) as batch_op: batch_op.add_column(sa.Column('rag_pipeline_variables', sa.Text(), server_default='{}', nullable=False)) diff --git a/api/migrations/versions/2025_10_21_1430-ae662b25d9bc_remove_builtin_template_user.py b/api/migrations/versions/2025_10_21_1430-ae662b25d9bc_remove_builtin_template_user.py index 0776ab0818..ec0cfbd11d 100644 --- a/api/migrations/versions/2025_10_21_1430-ae662b25d9bc_remove_builtin_template_user.py +++ b/api/migrations/versions/2025_10_21_1430-ae662b25d9bc_remove_builtin_template_user.py @@ -9,8 +9,6 @@ from alembic import op import models as models -def _is_pg(conn): - return conn.dialect.name == "postgresql" import sqlalchemy as sa @@ -33,15 +31,9 @@ def upgrade(): def downgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op: - batch_op.add_column(sa.Column('created_by', sa.UUID(), autoincrement=False, nullable=False)) - batch_op.add_column(sa.Column('updated_by', sa.UUID(), autoincrement=False, nullable=True)) - else: - with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op: - batch_op.add_column(sa.Column('created_by', models.types.StringUUID(), autoincrement=False, nullable=False)) - batch_op.add_column(sa.Column('updated_by', models.types.StringUUID(), autoincrement=False, nullable=True)) + + with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op: + batch_op.add_column(sa.Column('created_by', models.types.StringUUID(), autoincrement=False, nullable=False)) + batch_op.add_column(sa.Column('updated_by', models.types.StringUUID(), autoincrement=False, nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_10_22_1611-03f8dcbc611e_add_workflowpause_model.py b/api/migrations/versions/2025_10_22_1611-03f8dcbc611e_add_workflowpause_model.py index 627219cc4b..12905b3674 100644 --- a/api/migrations/versions/2025_10_22_1611-03f8dcbc611e_add_workflowpause_model.py +++ b/api/migrations/versions/2025_10_22_1611-03f8dcbc611e_add_workflowpause_model.py @@ -9,7 +9,6 @@ Create Date: 2025-10-22 16:11:31.805407 from alembic import op import models as models import sqlalchemy as sa -from libs.uuid_utils import uuidv7 def _is_pg(conn): return conn.dialect.name == "postgresql" diff --git a/api/migrations/versions/2025_10_30_1518-669ffd70119c_introduce_trigger.py b/api/migrations/versions/2025_10_30_1518-669ffd70119c_introduce_trigger.py index 9641a15c89..c27c1058d1 100644 --- a/api/migrations/versions/2025_10_30_1518-669ffd70119c_introduce_trigger.py +++ b/api/migrations/versions/2025_10_30_1518-669ffd70119c_introduce_trigger.py @@ -105,6 +105,7 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name='trigger_oauth_tenant_client_pkey'), sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_trigger_oauth_tenant_client') ) + if _is_pg(conn): op.create_table('trigger_subscriptions', sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), @@ -143,6 +144,7 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name='trigger_provider_pkey'), sa.UniqueConstraint('tenant_id', 'provider_id', 'name', name='unique_trigger_provider') ) + with op.batch_alter_table('trigger_subscriptions', schema=None) as batch_op: batch_op.create_index('idx_trigger_providers_endpoint', ['endpoint_id'], unique=True) batch_op.create_index('idx_trigger_providers_tenant_endpoint', ['tenant_id', 'endpoint_id'], unique=False) @@ -176,6 +178,7 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name='workflow_plugin_trigger_pkey'), sa.UniqueConstraint('app_id', 'node_id', name='uniq_app_node_subscription') ) + with op.batch_alter_table('workflow_plugin_triggers', schema=None) as batch_op: batch_op.create_index('workflow_plugin_trigger_tenant_subscription_idx', ['tenant_id', 'subscription_id', 'event_name'], unique=False) @@ -207,6 +210,7 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name='workflow_schedule_plan_pkey'), sa.UniqueConstraint('app_id', 'node_id', name='uniq_app_node') ) + with op.batch_alter_table('workflow_schedule_plans', schema=None) as batch_op: batch_op.create_index('workflow_schedule_plan_next_idx', ['next_run_at'], unique=False) @@ -264,6 +268,7 @@ def upgrade(): 
sa.Column('finished_at', sa.DateTime(), nullable=True), sa.PrimaryKeyConstraint('id', name='workflow_trigger_log_pkey') ) + with op.batch_alter_table('workflow_trigger_logs', schema=None) as batch_op: batch_op.create_index('workflow_trigger_log_created_at_idx', ['created_at'], unique=False) batch_op.create_index('workflow_trigger_log_status_idx', ['status'], unique=False) @@ -299,6 +304,7 @@ def upgrade(): sa.UniqueConstraint('app_id', 'node_id', name='uniq_node'), sa.UniqueConstraint('webhook_id', name='uniq_webhook_id') ) + with op.batch_alter_table('workflow_webhook_triggers', schema=None) as batch_op: batch_op.create_index('workflow_webhook_trigger_tenant_idx', ['tenant_id'], unique=False) diff --git a/api/migrations/versions/23db93619b9d_add_message_files_into_agent_thought.py b/api/migrations/versions/23db93619b9d_add_message_files_into_agent_thought.py index fae506906b..127ffd5599 100644 --- a/api/migrations/versions/23db93619b9d_add_message_files_into_agent_thought.py +++ b/api/migrations/versions/23db93619b9d_add_message_files_into_agent_thought.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '23db93619b9d' down_revision = '8ae9bc661daa' @@ -23,14 +20,8 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: - batch_op.add_column(sa.Column('message_files', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: - batch_op.add_column(sa.Column('message_files', models.types.LongText(), nullable=True)) + with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: + batch_op.add_column(sa.Column('message_files', models.types.LongText(), nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py b/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py index 2676ef0b94..31829d8e58 100644 --- a/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py +++ b/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py @@ -62,14 +62,8 @@ def upgrade(): def downgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('annotation_reply', sa.TEXT(), autoincrement=False, nullable=True)) - else: - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('annotation_reply', models.types.LongText(), autoincrement=False, nullable=True)) + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('annotation_reply', models.types.LongText(), autoincrement=False, nullable=True)) with op.batch_alter_table('app_annotation_settings', schema=None) as batch_op: batch_op.drop_index('app_annotation_settings_app_idx') diff --git a/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py b/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py index 3362a3a09f..07a8cd86b1 100644 --- a/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py +++ b/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py @@ -11,9 +11,6 @@ from alembic import op import models as models -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '2a3aebbbf4bb' down_revision = 'c031d46af369' @@ -23,14 +20,8 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('apps', schema=None) as batch_op: - batch_op.add_column(sa.Column('tracing', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('apps', schema=None) as batch_op: - batch_op.add_column(sa.Column('tracing', models.types.LongText(), nullable=True)) + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.add_column(sa.Column('tracing', models.types.LongText(), nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py b/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py index 40bd727f66..211b2d8882 100644 --- a/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py +++ b/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py @@ -7,14 +7,10 @@ Create Date: 2023-09-22 15:41:01.243183 """ import sqlalchemy as sa from alembic import op -from sqlalchemy.dialects import postgresql import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '2e9819ca5b28' down_revision = 'ab23c11305d4' @@ -24,35 +20,19 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
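Across these migrations the PostgreSQL/MySQL branches collapse into a single `models.types.LongText()` column because the column type itself can carry the per-dialect variant. As a rough illustration (an assumption about how such a type could be built, not the repository's actual `models/types.py`), SQLAlchemy's `with_variant` is one way to render TEXT on PostgreSQL and LONGTEXT on MySQL from one declaration:

```python
# Hypothetical sketch of a dialect-agnostic "LongText" column type.
# This is an assumption about what models.types.LongText could resemble,
# not the actual implementation in this repository.
import sqlalchemy as sa
from sqlalchemy.dialects.mysql import LONGTEXT

# Renders as TEXT on PostgreSQL and LONGTEXT on MySQL, so a migration can
# declare the column once instead of branching on `if _is_pg(conn)`.
LongText = sa.Text().with_variant(LONGTEXT, "mysql")

# e.g. op.add_column("apps", sa.Column("tracing", LongText, nullable=True))
```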
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('api_tokens', schema=None) as batch_op: - batch_op.add_column(sa.Column('tenant_id', postgresql.UUID(), nullable=True)) - batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False) - batch_op.drop_column('dataset_id') - else: - with op.batch_alter_table('api_tokens', schema=None) as batch_op: - batch_op.add_column(sa.Column('tenant_id', models.types.StringUUID(), nullable=True)) - batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False) - batch_op.drop_column('dataset_id') + with op.batch_alter_table('api_tokens', schema=None) as batch_op: + batch_op.add_column(sa.Column('tenant_id', models.types.StringUUID(), nullable=True)) + batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False) + batch_op.drop_column('dataset_id') # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('api_tokens', schema=None) as batch_op: - batch_op.add_column(sa.Column('dataset_id', postgresql.UUID(), autoincrement=False, nullable=True)) - batch_op.drop_index('api_token_tenant_idx') - batch_op.drop_column('tenant_id') - else: - with op.batch_alter_table('api_tokens', schema=None) as batch_op: - batch_op.add_column(sa.Column('dataset_id', models.types.StringUUID(), autoincrement=False, nullable=True)) - batch_op.drop_index('api_token_tenant_idx') - batch_op.drop_column('tenant_id') + with op.batch_alter_table('api_tokens', schema=None) as batch_op: + batch_op.add_column(sa.Column('dataset_id', models.types.StringUUID(), autoincrement=False, nullable=True)) + batch_op.drop_index('api_token_tenant_idx') + batch_op.drop_column('tenant_id') # ### end Alembic commands ### diff --git a/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py b/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py index 76056a9460..3491c85e2f 100644 --- a/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py +++ b/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py @@ -7,14 +7,10 @@ Create Date: 2024-03-07 08:30:29.133614 """ import sqlalchemy as sa from alembic import op -from sqlalchemy.dialects import postgresql import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '42e85ed5564d' down_revision = 'f9107f83abab' @@ -24,59 +20,31 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('conversations', schema=None) as batch_op: - batch_op.alter_column('app_model_config_id', - existing_type=postgresql.UUID(), - nullable=True) - batch_op.alter_column('model_provider', - existing_type=sa.VARCHAR(length=255), - nullable=True) - batch_op.alter_column('model_id', - existing_type=sa.VARCHAR(length=255), - nullable=True) - else: - with op.batch_alter_table('conversations', schema=None) as batch_op: - batch_op.alter_column('app_model_config_id', - existing_type=models.types.StringUUID(), - nullable=True) - batch_op.alter_column('model_provider', - existing_type=sa.VARCHAR(length=255), - nullable=True) - batch_op.alter_column('model_id', - existing_type=sa.VARCHAR(length=255), - nullable=True) + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.alter_column('app_model_config_id', + existing_type=models.types.StringUUID(), + nullable=True) + batch_op.alter_column('model_provider', + existing_type=sa.VARCHAR(length=255), + nullable=True) + batch_op.alter_column('model_id', + existing_type=sa.VARCHAR(length=255), + nullable=True) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('conversations', schema=None) as batch_op: - batch_op.alter_column('model_id', - existing_type=sa.VARCHAR(length=255), - nullable=False) - batch_op.alter_column('model_provider', - existing_type=sa.VARCHAR(length=255), - nullable=False) - batch_op.alter_column('app_model_config_id', - existing_type=postgresql.UUID(), - nullable=False) - else: - with op.batch_alter_table('conversations', schema=None) as batch_op: - batch_op.alter_column('model_id', - existing_type=sa.VARCHAR(length=255), - nullable=False) - batch_op.alter_column('model_provider', - existing_type=sa.VARCHAR(length=255), - nullable=False) - batch_op.alter_column('app_model_config_id', - existing_type=models.types.StringUUID(), - nullable=False) + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.alter_column('model_id', + existing_type=sa.VARCHAR(length=255), + nullable=False) + batch_op.alter_column('model_provider', + existing_type=sa.VARCHAR(length=255), + nullable=False) + batch_op.alter_column('app_model_config_id', + existing_type=models.types.StringUUID(), + nullable=False) # ### end Alembic commands ### diff --git a/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py b/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py index ef066587b7..8537a87233 100644 --- a/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py +++ b/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py @@ -6,14 +6,10 @@ Create Date: 2024-01-12 03:42:27.362415 """ from alembic import op -from sqlalchemy.dialects import postgresql import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '4829e54d2fee' down_revision = '114eed84c228' @@ -23,39 +19,21 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - # PostgreSQL: Keep original syntax - with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: - batch_op.alter_column('message_chain_id', - existing_type=postgresql.UUID(), - nullable=True) - else: - # MySQL: Use compatible syntax - with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: - batch_op.alter_column('message_chain_id', - existing_type=models.types.StringUUID(), - nullable=True) + + with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: + batch_op.alter_column('message_chain_id', + existing_type=models.types.StringUUID(), + nullable=True) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - # PostgreSQL: Keep original syntax - with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: - batch_op.alter_column('message_chain_id', - existing_type=postgresql.UUID(), - nullable=False) - else: - # MySQL: Use compatible syntax - with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: - batch_op.alter_column('message_chain_id', - existing_type=models.types.StringUUID(), - nullable=False) + + with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: + batch_op.alter_column('message_chain_id', + existing_type=models.types.StringUUID(), + nullable=False) # ### end Alembic commands ### diff --git a/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py b/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py index b080e7680b..22405e3cc8 100644 --- a/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py +++ b/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py @@ -6,14 +6,10 @@ Create Date: 2024-03-14 04:54:56.679506 """ from alembic import op -from sqlalchemy.dialects import postgresql import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '563cf8bf777b' down_revision = 'b5429b71023c' @@ -23,35 +19,19 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('tool_files', schema=None) as batch_op: - batch_op.alter_column('conversation_id', - existing_type=postgresql.UUID(), - nullable=True) - else: - with op.batch_alter_table('tool_files', schema=None) as batch_op: - batch_op.alter_column('conversation_id', - existing_type=models.types.StringUUID(), - nullable=True) + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.alter_column('conversation_id', + existing_type=models.types.StringUUID(), + nullable=True) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('tool_files', schema=None) as batch_op: - batch_op.alter_column('conversation_id', - existing_type=postgresql.UUID(), - nullable=False) - else: - with op.batch_alter_table('tool_files', schema=None) as batch_op: - batch_op.alter_column('conversation_id', - existing_type=models.types.StringUUID(), - nullable=False) + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.alter_column('conversation_id', + existing_type=models.types.StringUUID(), + nullable=False) # ### end Alembic commands ### diff --git a/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py b/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py index 1ace8ea5a0..01d7d5ba21 100644 --- a/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py +++ b/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py @@ -48,12 +48,9 @@ def upgrade(): with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op: batch_op.create_index('provider_model_name_idx', ['provider_name', 'model_name'], unique=False) - if _is_pg(conn): - with op.batch_alter_table('datasets', schema=None) as batch_op: - batch_op.add_column(sa.Column('collection_binding_id', postgresql.UUID(), nullable=True)) - else: - with op.batch_alter_table('datasets', schema=None) as batch_op: - batch_op.add_column(sa.Column('collection_binding_id', models.types.StringUUID(), nullable=True)) + + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.add_column(sa.Column('collection_binding_id', models.types.StringUUID(), nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py b/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py index 457338ef42..0faa48f535 100644 --- a/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py +++ b/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '714aafe25d39' down_revision = 'f2a6fc85e260' @@ -23,16 +20,9 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
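The same consolidation applies to UUID columns: the `postgresql.UUID()` branch and the `models.types.StringUUID()` branch are replaced by `StringUUID` alone, which is also why the `from sqlalchemy.dialects import postgresql` imports can be dropped. A minimal sketch of what a portable, string-backed UUID type might look like, assuming a `TypeDecorator`-based design rather than the project's actual implementation:

```python
# Hypothetical sketch only; the real models.types.StringUUID may differ.
import uuid

import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID as PG_UUID


class StringUUID(sa.types.TypeDecorator):
    """Store UUIDs natively on PostgreSQL and as CHAR(36) elsewhere."""

    impl = sa.CHAR(36)
    cache_ok = True

    def load_dialect_impl(self, dialect):
        if dialect.name == "postgresql":
            return dialect.type_descriptor(PG_UUID())
        return dialect.type_descriptor(sa.CHAR(36))

    def process_bind_param(self, value, dialect):
        # Accept uuid.UUID objects or strings; persist a plain string.
        return None if value is None else str(value)

    def process_result_value(self, value, dialect):
        # Normalize whatever the driver returns back to canonical string form.
        return None if value is None else str(uuid.UUID(str(value)))
```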
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: - batch_op.add_column(sa.Column('annotation_question', sa.Text(), nullable=False)) - batch_op.add_column(sa.Column('annotation_content', sa.Text(), nullable=False)) - else: - with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: - batch_op.add_column(sa.Column('annotation_question', models.types.LongText(), nullable=False)) - batch_op.add_column(sa.Column('annotation_content', models.types.LongText(), nullable=False)) + with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: + batch_op.add_column(sa.Column('annotation_question', models.types.LongText(), nullable=False)) + batch_op.add_column(sa.Column('annotation_content', models.types.LongText(), nullable=False)) # ### end Alembic commands ### diff --git a/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py b/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py index 7bcd1a1be3..aa7b4a21e2 100644 --- a/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py +++ b/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '77e83833755c' down_revision = '6dcb43972bdc' @@ -23,14 +20,8 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('retriever_resource', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('retriever_resource', models.types.LongText(), nullable=True)) + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('retriever_resource', models.types.LongText(), nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/7ce5a52e4eee_add_tool_providers.py b/api/migrations/versions/7ce5a52e4eee_add_tool_providers.py index 3c0aa082d5..34a17697d3 100644 --- a/api/migrations/versions/7ce5a52e4eee_add_tool_providers.py +++ b/api/migrations/versions/7ce5a52e4eee_add_tool_providers.py @@ -27,7 +27,6 @@ def upgrade(): conn = op.get_bind() if _is_pg(conn): - # PostgreSQL: Keep original syntax op.create_table('tool_providers', sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), sa.Column('tenant_id', postgresql.UUID(), nullable=False), @@ -40,7 +39,6 @@ def upgrade(): sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name') ) else: - # MySQL: Use compatible syntax op.create_table('tool_providers', sa.Column('id', models.types.StringUUID(), nullable=False), sa.Column('tenant_id', models.types.StringUUID(), nullable=False), @@ -52,12 +50,9 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'), sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name') ) - if _is_pg(conn): - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('sensitive_word_avoidance', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - 
batch_op.add_column(sa.Column('sensitive_word_avoidance', models.types.LongText(), nullable=True)) + + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('sensitive_word_avoidance', models.types.LongText(), nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/88072f0caa04_add_custom_config_in_tenant.py b/api/migrations/versions/88072f0caa04_add_custom_config_in_tenant.py index beea90b384..884839c010 100644 --- a/api/migrations/versions/88072f0caa04_add_custom_config_in_tenant.py +++ b/api/migrations/versions/88072f0caa04_add_custom_config_in_tenant.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '88072f0caa04' down_revision = '246ba09cbbdb' @@ -23,14 +20,8 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('tenants', schema=None) as batch_op: - batch_op.add_column(sa.Column('custom_config', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('tenants', schema=None) as batch_op: - batch_op.add_column(sa.Column('custom_config', models.types.LongText(), nullable=True)) + with op.batch_alter_table('tenants', schema=None) as batch_op: + batch_op.add_column(sa.Column('custom_config', models.types.LongText(), nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/89c7899ca936_.py b/api/migrations/versions/89c7899ca936_.py index 2420710e74..d26f1e82d6 100644 --- a/api/migrations/versions/89c7899ca936_.py +++ b/api/migrations/versions/89c7899ca936_.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '89c7899ca936' down_revision = '187385f442fc' @@ -23,39 +20,21 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('sites', schema=None) as batch_op: - batch_op.alter_column('description', - existing_type=sa.VARCHAR(length=255), - type_=sa.Text(), - existing_nullable=True) - else: - with op.batch_alter_table('sites', schema=None) as batch_op: - batch_op.alter_column('description', - existing_type=sa.VARCHAR(length=255), - type_=models.types.LongText(), - existing_nullable=True) + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.alter_column('description', + existing_type=sa.VARCHAR(length=255), + type_=models.types.LongText(), + existing_nullable=True) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('sites', schema=None) as batch_op: - batch_op.alter_column('description', - existing_type=sa.Text(), - type_=sa.VARCHAR(length=255), - existing_nullable=True) - else: - with op.batch_alter_table('sites', schema=None) as batch_op: - batch_op.alter_column('description', - existing_type=models.types.LongText(), - type_=sa.VARCHAR(length=255), - existing_nullable=True) + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.alter_column('description', + existing_type=models.types.LongText(), + type_=sa.VARCHAR(length=255), + existing_nullable=True) # ### end Alembic commands ### diff --git a/api/migrations/versions/8ec536f3c800_rename_api_provider_credentails.py b/api/migrations/versions/8ec536f3c800_rename_api_provider_credentails.py index 111e81240b..6022ea2c20 100644 --- a/api/migrations/versions/8ec536f3c800_rename_api_provider_credentails.py +++ b/api/migrations/versions/8ec536f3c800_rename_api_provider_credentails.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '8ec536f3c800' down_revision = 'ad472b61a054' @@ -23,14 +20,8 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.add_column(sa.Column('credentials_str', sa.Text(), nullable=False)) - else: - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.add_column(sa.Column('credentials_str', models.types.LongText(), nullable=False)) + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('credentials_str', models.types.LongText(), nullable=False)) # ### end Alembic commands ### diff --git a/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py b/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py index 1c1c6cacbb..9d6d40114d 100644 --- a/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py +++ b/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py @@ -57,12 +57,9 @@ def upgrade(): batch_op.create_index('message_file_created_by_idx', ['created_by'], unique=False) batch_op.create_index('message_file_message_idx', ['message_id'], unique=False) - if _is_pg(conn): - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('file_upload', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('file_upload', models.types.LongText(), nullable=True)) + + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('file_upload', models.types.LongText(), nullable=True)) if _is_pg(conn): with op.batch_alter_table('upload_files', schema=None) as batch_op: diff --git a/api/migrations/versions/9f4e3427ea84_add_created_by_role.py b/api/migrations/versions/9f4e3427ea84_add_created_by_role.py index 5d29d354f3..0b3f92a12e 100644 --- a/api/migrations/versions/9f4e3427ea84_add_created_by_role.py +++ b/api/migrations/versions/9f4e3427ea84_add_created_by_role.py @@ -24,7 +24,6 @@ def upgrade(): conn = op.get_bind() if _is_pg(conn): - # PostgreSQL: Keep original syntax with op.batch_alter_table('pinned_conversations', schema=None) as batch_op: 
batch_op.add_column(sa.Column('created_by_role', sa.String(length=255), server_default=sa.text("'end_user'::character varying"), nullable=False)) batch_op.drop_index('pinned_conversation_conversation_idx') @@ -35,7 +34,6 @@ def upgrade(): batch_op.drop_index('saved_message_message_idx') batch_op.create_index('saved_message_message_idx', ['app_id', 'message_id', 'created_by_role', 'created_by'], unique=False) else: - # MySQL: Use compatible syntax with op.batch_alter_table('pinned_conversations', schema=None) as batch_op: batch_op.add_column(sa.Column('created_by_role', sa.String(length=255), server_default=sa.text("'end_user'"), nullable=False)) batch_op.drop_index('pinned_conversation_conversation_idx') diff --git a/api/migrations/versions/a5b56fb053ef_app_config_add_speech_to_text.py b/api/migrations/versions/a5b56fb053ef_app_config_add_speech_to_text.py index 616cb2f163..c8747a51f7 100644 --- a/api/migrations/versions/a5b56fb053ef_app_config_add_speech_to_text.py +++ b/api/migrations/versions/a5b56fb053ef_app_config_add_speech_to_text.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = 'a5b56fb053ef' down_revision = 'd3d503a3471c' @@ -23,14 +20,8 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('speech_to_text', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('speech_to_text', models.types.LongText(), nullable=True)) + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('speech_to_text', models.types.LongText(), nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py b/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py index 900ff78036..f56aeb7e66 100644 --- a/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py +++ b/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = 'a9836e3baeee' down_revision = '968fff4c0ab9' @@ -23,14 +20,8 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
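Note that `_is_pg(conn)` is only deleted from files where the two branches ended up identical. Migrations such as 9f4e3427ea84 keep the check because the server defaults genuinely differ (`'end_user'::character varying` on PostgreSQL versus `'end_user'` elsewhere), and the knowledge-pipeline and trigger tables keep it for PostgreSQL-only defaults like `uuid_generate_v4()` and `uuidv7()`. A minimal sketch of that remaining pattern, using a hypothetical table name:

```python
# Sketch of the dialect branch these migrations keep only where the two
# code paths really differ (here: a PostgreSQL-only server-side default).
# The table name is illustrative; models.types is the repository's module
# used throughout these migrations.
import sqlalchemy as sa
from alembic import op

import models.types


def _is_pg(conn) -> bool:
    return conn.dialect.name == "postgresql"


def upgrade():
    conn = op.get_bind()
    if _is_pg(conn):
        # PostgreSQL can fill the primary key with a server-side UUID function.
        id_col = sa.Column(
            "id",
            models.types.StringUUID(),
            server_default=sa.text("uuid_generate_v4()"),
            nullable=False,
        )
    else:
        # Other dialects rely on an application-side default instead.
        id_col = sa.Column("id", models.types.StringUUID(), nullable=False)
    op.create_table("example_table", id_col, sa.PrimaryKeyConstraint("id"))
```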
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('external_data_tools', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('external_data_tools', models.types.LongText(), nullable=True)) + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('external_data_tools', models.types.LongText(), nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/b24be59fbb04_.py b/api/migrations/versions/b24be59fbb04_.py index b0a6d10d8c..ae91eaf1bc 100644 --- a/api/migrations/versions/b24be59fbb04_.py +++ b/api/migrations/versions/b24be59fbb04_.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = 'b24be59fbb04' down_revision = 'de95f5c77138' @@ -23,14 +20,8 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('text_to_speech', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('text_to_speech', models.types.LongText(), nullable=True)) + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('text_to_speech', models.types.LongText(), nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py b/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py index 772395c25b..c02c24c23f 100644 --- a/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py +++ b/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = 'b3a09c049e8e' down_revision = '2e9819ca5b28' @@ -23,20 +20,11 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('prompt_type', sa.String(length=255), nullable=False, server_default='simple')) - batch_op.add_column(sa.Column('chat_prompt_config', sa.Text(), nullable=True)) - batch_op.add_column(sa.Column('completion_prompt_config', sa.Text(), nullable=True)) - batch_op.add_column(sa.Column('dataset_configs', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('prompt_type', sa.String(length=255), nullable=False, server_default='simple')) - batch_op.add_column(sa.Column('chat_prompt_config', models.types.LongText(), nullable=True)) - batch_op.add_column(sa.Column('completion_prompt_config', models.types.LongText(), nullable=True)) - batch_op.add_column(sa.Column('dataset_configs', models.types.LongText(), nullable=True)) + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('prompt_type', sa.String(length=255), nullable=False, server_default='simple')) + batch_op.add_column(sa.Column('chat_prompt_config', models.types.LongText(), nullable=True)) + batch_op.add_column(sa.Column('completion_prompt_config', models.types.LongText(), nullable=True)) + batch_op.add_column(sa.Column('dataset_configs', models.types.LongText(), nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py b/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py index 76be794ff4..fe51d1c78d 100644 --- a/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py +++ b/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py @@ -7,7 +7,6 @@ Create Date: 2024-06-17 10:01:00.255189 """ import sqlalchemy as sa from alembic import op -from sqlalchemy.dialects import postgresql import models.types diff --git a/api/migrations/versions/e1901f623fd0_add_annotation_reply.py b/api/migrations/versions/e1901f623fd0_add_annotation_reply.py index 9e02ec5d84..36e934f0fc 100644 --- a/api/migrations/versions/e1901f623fd0_add_annotation_reply.py +++ b/api/migrations/versions/e1901f623fd0_add_annotation_reply.py @@ -54,12 +54,9 @@ def upgrade(): batch_op.create_index('app_annotation_hit_histories_annotation_idx', ['annotation_id'], unique=False) batch_op.create_index('app_annotation_hit_histories_app_idx', ['app_id'], unique=False) - if _is_pg(conn): - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('annotation_reply', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('annotation_reply', models.types.LongText(), nullable=True)) + + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('annotation_reply', models.types.LongText(), nullable=True)) if _is_pg(conn): with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op: @@ -68,54 +65,31 @@ def upgrade(): with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op: batch_op.add_column(sa.Column('type', sa.String(length=40), server_default=sa.text("'dataset'"), nullable=False)) - if _is_pg(conn): - with op.batch_alter_table('message_annotations', schema=None) as batch_op: - batch_op.add_column(sa.Column('question', sa.Text(), 
nullable=True)) - batch_op.add_column(sa.Column('hit_count', sa.Integer(), server_default=sa.text('0'), nullable=False)) - batch_op.alter_column('conversation_id', - existing_type=postgresql.UUID(), - nullable=True) - batch_op.alter_column('message_id', - existing_type=postgresql.UUID(), - nullable=True) - else: - with op.batch_alter_table('message_annotations', schema=None) as batch_op: - batch_op.add_column(sa.Column('question', models.types.LongText(), nullable=True)) - batch_op.add_column(sa.Column('hit_count', sa.Integer(), server_default=sa.text('0'), nullable=False)) - batch_op.alter_column('conversation_id', - existing_type=models.types.StringUUID(), - nullable=True) - batch_op.alter_column('message_id', - existing_type=models.types.StringUUID(), - nullable=True) + with op.batch_alter_table('message_annotations', schema=None) as batch_op: + batch_op.add_column(sa.Column('question', models.types.LongText(), nullable=True)) + batch_op.add_column(sa.Column('hit_count', sa.Integer(), server_default=sa.text('0'), nullable=False)) + batch_op.alter_column('conversation_id', + existing_type=models.types.StringUUID(), + nullable=True) + batch_op.alter_column('message_id', + existing_type=models.types.StringUUID(), + nullable=True) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - if _is_pg(conn): - with op.batch_alter_table('message_annotations', schema=None) as batch_op: - batch_op.alter_column('message_id', - existing_type=postgresql.UUID(), - nullable=False) - batch_op.alter_column('conversation_id', - existing_type=postgresql.UUID(), - nullable=False) - batch_op.drop_column('hit_count') - batch_op.drop_column('question') - else: - with op.batch_alter_table('message_annotations', schema=None) as batch_op: - batch_op.alter_column('message_id', - existing_type=models.types.StringUUID(), - nullable=False) - batch_op.alter_column('conversation_id', - existing_type=models.types.StringUUID(), - nullable=False) - batch_op.drop_column('hit_count') - batch_op.drop_column('question') + with op.batch_alter_table('message_annotations', schema=None) as batch_op: + batch_op.alter_column('message_id', + existing_type=models.types.StringUUID(), + nullable=False) + batch_op.alter_column('conversation_id', + existing_type=models.types.StringUUID(), + nullable=False) + batch_op.drop_column('hit_count') + batch_op.drop_column('question') with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op: batch_op.drop_column('type') diff --git a/api/migrations/versions/f2a6fc85e260_add_anntation_history_message_id.py b/api/migrations/versions/f2a6fc85e260_add_anntation_history_message_id.py index 02098e91c1..ac1c14e50c 100644 --- a/api/migrations/versions/f2a6fc85e260_add_anntation_history_message_id.py +++ b/api/migrations/versions/f2a6fc85e260_add_anntation_history_message_id.py @@ -12,9 +12,6 @@ from sqlalchemy.dialects import postgresql import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = 'f2a6fc85e260' down_revision = '46976cc39132' @@ -24,16 +21,9 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: - batch_op.add_column(sa.Column('message_id', postgresql.UUID(), nullable=False)) - batch_op.create_index('app_annotation_hit_histories_message_idx', ['message_id'], unique=False) - else: - with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: - batch_op.add_column(sa.Column('message_id', models.types.StringUUID(), nullable=False)) - batch_op.create_index('app_annotation_hit_histories_message_idx', ['message_id'], unique=False) + with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: + batch_op.add_column(sa.Column('message_id', models.types.StringUUID(), nullable=False)) + batch_op.create_index('app_annotation_hit_histories_message_idx', ['message_id'], unique=False) # ### end Alembic commands ### diff --git a/api/models/account.py b/api/models/account.py index 420e6adc6c..f7a9c20026 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -8,7 +8,7 @@ from uuid import uuid4 import sqlalchemy as sa from flask_login import UserMixin from sqlalchemy import DateTime, String, func, select -from sqlalchemy.orm import Mapped, Session, mapped_column +from sqlalchemy.orm import Mapped, Session, mapped_column, validates from typing_extensions import deprecated from .base import TypeBase @@ -116,6 +116,12 @@ class Account(UserMixin, TypeBase): role: TenantAccountRole | None = field(default=None, init=False) _current_tenant: "Tenant | None" = field(default=None, init=False) + @validates("status") + def _normalize_status(self, _key: str, value: str | AccountStatus) -> str: + if isinstance(value, AccountStatus): + return value.value + return value + @property def is_password_set(self): return self.password is not None diff --git a/api/models/model.py b/api/models/model.py index 32be20e60a..617e70b992 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import re import uuid @@ -5,7 +7,7 @@ from collections.abc import Mapping from datetime import datetime from decimal import Decimal from enum import StrEnum, auto -from typing import TYPE_CHECKING, Any, Literal, Optional, cast +from typing import TYPE_CHECKING, Any, Literal, cast from uuid import uuid4 import sqlalchemy as sa @@ -56,7 +58,7 @@ class AppMode(StrEnum): RAG_PIPELINE = "rag-pipeline" @classmethod - def value_of(cls, value: str) -> "AppMode": + def value_of(cls, value: str) -> AppMode: """ Get value of given mode. 
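On the models side, `Account` gains a SQLAlchemy `@validates("status")` hook so that assigning an `AccountStatus` member stores its `.value` string rather than the enum object. A small usage sketch follows; the specific enum members named here are assumptions, since the hunk only shows the hook itself:

```python
# Usage sketch for the new status normalizer on Account. The AccountStatus
# members used below are illustrative assumptions.
from models.account import Account, AccountStatus


def ban(account: Account) -> None:
    # Assigning the enum member is now safe: the @validates("status") hook
    # stores AccountStatus.BANNED.value (a str), not the enum object itself.
    account.status = AccountStatus.BANNED


def is_active(account: Account) -> bool:
    # Plain strings pass through unchanged, so reads keep comparing
    # against the enum's string value.
    return account.status == AccountStatus.ACTIVE.value
```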
@@ -72,6 +74,7 @@ class AppMode(StrEnum): class IconType(StrEnum): IMAGE = auto() EMOJI = auto() + LINK = auto() class App(Base): @@ -83,7 +86,7 @@ class App(Base): name: Mapped[str] = mapped_column(String(255)) description: Mapped[str] = mapped_column(LongText, default=sa.text("''")) mode: Mapped[str] = mapped_column(String(255)) - icon_type: Mapped[str | None] = mapped_column(String(255)) # image, emoji + icon_type: Mapped[str | None] = mapped_column(String(255)) # image, emoji, link icon = mapped_column(String(255)) icon_background: Mapped[str | None] = mapped_column(String(255)) app_model_config_id = mapped_column(StringUUID, nullable=True) @@ -122,19 +125,19 @@ class App(Base): return "" @property - def site(self) -> Optional["Site"]: + def site(self) -> Site | None: site = db.session.query(Site).where(Site.app_id == self.id).first() return site @property - def app_model_config(self) -> Optional["AppModelConfig"]: + def app_model_config(self) -> AppModelConfig | None: if self.app_model_config_id: return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first() return None @property - def workflow(self) -> Optional["Workflow"]: + def workflow(self) -> Workflow | None: if self.workflow_id: from .workflow import Workflow @@ -289,7 +292,7 @@ class App(Base): return deleted_tools @property - def tags(self) -> list["Tag"]: + def tags(self) -> list[Tag]: tags = ( db.session.query(Tag) .join(TagBinding, Tag.id == TagBinding.tag_id) @@ -1195,7 +1198,7 @@ class Message(Base): return json.loads(self.message_metadata) if self.message_metadata else {} @property - def agent_thoughts(self) -> list["MessageAgentThought"]: + def agent_thoughts(self) -> list[MessageAgentThought]: return ( db.session.query(MessageAgentThought) .where(MessageAgentThought.message_id == self.id) @@ -1320,7 +1323,7 @@ class Message(Base): } @classmethod - def from_dict(cls, data: dict[str, Any]) -> "Message": + def from_dict(cls, data: dict[str, Any]) -> Message: return cls( id=data["id"], app_id=data["app_id"], @@ -1433,15 +1436,20 @@ class MessageAnnotation(Base): app_id: Mapped[str] = mapped_column(StringUUID) conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id")) message_id: Mapped[str | None] = mapped_column(StringUUID) - question = mapped_column(LongText, nullable=True) - content = mapped_column(LongText, nullable=False) + question: Mapped[str | None] = mapped_column(LongText, nullable=True) + content: Mapped[str] = mapped_column(LongText, nullable=False) hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) - account_id = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column( + account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column( sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) + @property + def question_text(self) -> str: + """Return a non-null question string, falling back to the answer content.""" + return self.question or self.content + @property def account(self): account = db.session.query(Account).where(Account.id == self.account_id).first() @@ -1542,7 +1550,7 @@ class OperationLog(TypeBase): tenant_id: Mapped[str] = mapped_column(StringUUID, 
nullable=False) account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) action: Mapped[str] = mapped_column(String(255), nullable=False) - content: Mapped[Any] = mapped_column(sa.JSON) + content: Mapped[Any | None] = mapped_column(sa.JSON, nullable=True) created_at: Mapped[datetime] = mapped_column( sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) diff --git a/api/models/provider.py b/api/models/provider.py index 2afd8c5329..441b54c797 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime from enum import StrEnum, auto from functools import cached_property @@ -19,7 +21,7 @@ class ProviderType(StrEnum): SYSTEM = auto() @staticmethod - def value_of(value: str) -> "ProviderType": + def value_of(value: str) -> ProviderType: for member in ProviderType: if member.value == value: return member @@ -37,7 +39,7 @@ class ProviderQuotaType(StrEnum): """hosted trial quota""" @staticmethod - def value_of(value: str) -> "ProviderQuotaType": + def value_of(value: str) -> ProviderQuotaType: for member in ProviderQuotaType: if member.value == value: return member @@ -76,7 +78,7 @@ class Provider(TypeBase): quota_type: Mapped[str | None] = mapped_column(String(40), nullable=True, server_default=text("''"), default="") quota_limit: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True, default=None) - quota_used: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, default=0) + quota_used: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True, default=0) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), init=False diff --git a/api/models/tools.py b/api/models/tools.py index e4f9bcb582..e7b98dcf27 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json from datetime import datetime from decimal import Decimal @@ -167,11 +169,11 @@ class ApiToolProvider(TypeBase): ) @property - def schema_type(self) -> "ApiProviderSchemaType": + def schema_type(self) -> ApiProviderSchemaType: return ApiProviderSchemaType.value_of(self.schema_type_str) @property - def tools(self) -> list["ApiToolBundle"]: + def tools(self) -> list[ApiToolBundle]: return [ApiToolBundle.model_validate(tool) for tool in json.loads(self.tools_str)] @property @@ -267,7 +269,7 @@ class WorkflowToolProvider(TypeBase): return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() @property - def parameter_configurations(self) -> list["WorkflowToolParameterConfiguration"]: + def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]: return [ WorkflowToolParameterConfiguration.model_validate(config) for config in json.loads(self.parameter_configuration) @@ -359,7 +361,7 @@ class MCPToolProvider(TypeBase): except (json.JSONDecodeError, TypeError): return [] - def to_entity(self) -> "MCPProviderEntity": + def to_entity(self) -> MCPProviderEntity: """Convert to domain entity""" from core.entities.mcp_provider import MCPProviderEntity @@ -533,5 +535,5 @@ class DeprecatedPublishedAppTool(TypeBase): ) @property - def description_i18n(self) -> "I18nObject": + def description_i18n(self) -> I18nObject: return I18nObject.model_validate(json.loads(self.description)) diff --git a/api/models/trigger.py b/api/models/trigger.py index 87e2a5ccfc..209345eb84 100644 --- a/api/models/trigger.py +++ b/api/models/trigger.py @@ -415,7 +415,7 @@ class 
AppTrigger(TypeBase): node_id: Mapped[str | None] = mapped_column(String(64), nullable=False) trigger_type: Mapped[str] = mapped_column(EnumText(AppTriggerType, length=50), nullable=False) title: Mapped[str] = mapped_column(String(255), nullable=False) - provider_name: Mapped[str] = mapped_column(String(255), server_default="", default="") # why it is nullable? + provider_name: Mapped[str | None] = mapped_column(String(255), nullable=True, server_default="", default="") status: Mapped[str] = mapped_column( EnumText(AppTriggerStatus, length=50), nullable=False, default=AppTriggerStatus.ENABLED ) diff --git a/api/models/workflow.py b/api/models/workflow.py index 5131177836..7e8a0f7c2e 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import json import logging from collections.abc import Generator, Mapping, Sequence from datetime import datetime from enum import StrEnum -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Union, cast from uuid import uuid4 import sqlalchemy as sa @@ -98,7 +100,7 @@ class WorkflowType(StrEnum): RAG_PIPELINE = "rag-pipeline" @classmethod - def value_of(cls, value: str) -> "WorkflowType": + def value_of(cls, value: str) -> WorkflowType: """ Get value of given mode. @@ -111,7 +113,7 @@ class WorkflowType(StrEnum): raise ValueError(f"invalid workflow type value {value}") @classmethod - def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType": + def from_app_mode(cls, app_mode: Union[str, AppMode]) -> WorkflowType: """ Get workflow type from app mode. @@ -212,7 +214,7 @@ class Workflow(Base): # bug rag_pipeline_variables: list[dict], marked_name: str = "", marked_comment: str = "", - ) -> "Workflow": + ) -> Workflow: workflow = Workflow() workflow.id = str(uuid4()) workflow.tenant_id = tenant_id @@ -650,7 +652,7 @@ class WorkflowRun(Base): finished_at: Mapped[datetime | None] = mapped_column(DateTime) exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) - pause: Mapped[Optional["WorkflowPause"]] = orm.relationship( + pause: Mapped[WorkflowPause | None] = orm.relationship( "WorkflowPause", primaryjoin="WorkflowRun.id == foreign(WorkflowPause.workflow_run_id)", uselist=False, @@ -725,7 +727,7 @@ class WorkflowRun(Base): } @classmethod - def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun": + def from_dict(cls, data: dict[str, Any]) -> WorkflowRun: return cls( id=data.get("id"), tenant_id=data.get("tenant_id"), @@ -877,7 +879,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo created_by: Mapped[str] = mapped_column(StringUUID) finished_at: Mapped[datetime | None] = mapped_column(DateTime) - offload_data: Mapped[list["WorkflowNodeExecutionOffload"]] = orm.relationship( + offload_data: Mapped[list[WorkflowNodeExecutionOffload]] = orm.relationship( "WorkflowNodeExecutionOffload", primaryjoin="WorkflowNodeExecutionModel.id == foreign(WorkflowNodeExecutionOffload.node_execution_id)", uselist=True, @@ -887,13 +889,13 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo @staticmethod def preload_offload_data( - query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"], + query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel], ): return query.options(orm.selectinload(WorkflowNodeExecutionModel.offload_data)) @staticmethod def preload_offload_data_and_files( 
- query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"], + query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel], ): return query.options( orm.selectinload(WorkflowNodeExecutionModel.offload_data).options( @@ -968,7 +970,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo ) return extras - def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]: + def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> WorkflowNodeExecutionOffload | None: return next(iter([i for i in self.offload_data if i.type_ == type_]), None) @property @@ -1082,7 +1084,7 @@ class WorkflowNodeExecutionOffload(Base): back_populates="offload_data", ) - file: Mapped[Optional["UploadFile"]] = orm.relationship( + file: Mapped[UploadFile | None] = orm.relationship( foreign_keys=[file_id], lazy="raise", uselist=False, @@ -1100,7 +1102,7 @@ class WorkflowAppLogCreatedFrom(StrEnum): INSTALLED_APP = "installed-app" @classmethod - def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom": + def value_of(cls, value: str) -> WorkflowAppLogCreatedFrom: """ Get value of given mode. @@ -1217,7 +1219,7 @@ class ConversationVariable(TypeBase): ) @classmethod - def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> "ConversationVariable": + def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> ConversationVariable: obj = cls( id=variable.id, app_id=app_id, @@ -1370,7 +1372,7 @@ class WorkflowDraftVariable(Base): ) # Relationship to WorkflowDraftVariableFile - variable_file: Mapped[Optional["WorkflowDraftVariableFile"]] = orm.relationship( + variable_file: Mapped[WorkflowDraftVariableFile | None] = orm.relationship( foreign_keys=[file_id], lazy="raise", uselist=False, @@ -1540,8 +1542,9 @@ class WorkflowDraftVariable(Base): node_execution_id: str | None, description: str = "", file_id: str | None = None, - ) -> "WorkflowDraftVariable": + ) -> WorkflowDraftVariable: variable = WorkflowDraftVariable() + variable.id = str(uuid4()) variable.created_at = naive_utc_now() variable.updated_at = naive_utc_now() variable.description = description @@ -1562,7 +1565,7 @@ class WorkflowDraftVariable(Base): name: str, value: Segment, description: str = "", - ) -> "WorkflowDraftVariable": + ) -> WorkflowDraftVariable: variable = cls._new( app_id=app_id, node_id=CONVERSATION_VARIABLE_NODE_ID, @@ -1583,7 +1586,7 @@ class WorkflowDraftVariable(Base): value: Segment, node_execution_id: str, editable: bool = False, - ) -> "WorkflowDraftVariable": + ) -> WorkflowDraftVariable: variable = cls._new( app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, @@ -1606,7 +1609,7 @@ class WorkflowDraftVariable(Base): visible: bool = True, editable: bool = True, file_id: str | None = None, - ) -> "WorkflowDraftVariable": + ) -> WorkflowDraftVariable: variable = cls._new( app_id=app_id, node_id=node_id, @@ -1702,7 +1705,7 @@ class WorkflowDraftVariableFile(Base): ) # Relationship to UploadFile - upload_file: Mapped["UploadFile"] = orm.relationship( + upload_file: Mapped[UploadFile] = orm.relationship( foreign_keys=[upload_file_id], lazy="raise", uselist=False, @@ -1769,7 +1772,7 @@ class WorkflowPause(DefaultFieldsMixin, Base): state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False) # Relationship to WorkflowRun - workflow_run: Mapped["WorkflowRun"] = orm.relationship( + workflow_run: Mapped[WorkflowRun] = orm.relationship( 
foreign_keys=[workflow_run_id], # require explicit preloading. lazy="raise", @@ -1825,7 +1828,7 @@ class WorkflowPauseReason(DefaultFieldsMixin, Base): ) @classmethod - def from_entity(cls, pause_reason: PauseReason) -> "WorkflowPauseReason": + def from_entity(cls, pause_reason: PauseReason) -> WorkflowPauseReason: if isinstance(pause_reason, HumanInputRequired): return cls( type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id diff --git a/api/schedule/queue_monitor_task.py b/api/schedule/queue_monitor_task.py index db610df290..77d6b5a138 100644 --- a/api/schedule/queue_monitor_task.py +++ b/api/schedule/queue_monitor_task.py @@ -16,6 +16,11 @@ celery_redis = Redis( port=redis_config.get("port") or 6379, password=redis_config.get("password") or None, db=int(redis_config.get("virtual_host")) if redis_config.get("virtual_host") else 1, + ssl=bool(dify_config.BROKER_USE_SSL), + ssl_ca_certs=dify_config.REDIS_SSL_CA_CERTS if dify_config.BROKER_USE_SSL else None, + ssl_cert_reqs=getattr(dify_config, "REDIS_SSL_CERT_REQS", None) if dify_config.BROKER_USE_SSL else None, + ssl_certfile=getattr(dify_config, "REDIS_SSL_CERTFILE", None) if dify_config.BROKER_USE_SSL else None, + ssl_keyfile=getattr(dify_config, "REDIS_SSL_KEYFILE", None) if dify_config.BROKER_USE_SSL else None, ) logger = logging.getLogger(__name__) diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index d03cbddceb..b73302508a 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -77,7 +77,7 @@ class AppAnnotationService: if annotation_setting: add_annotation_to_index_task.delay( annotation.id, - annotation.question, + question, current_tenant_id, app_id, annotation_setting.collection_binding_id, @@ -137,13 +137,16 @@ class AppAnnotationService: if not app: raise NotFound("App not found") if keyword: + from libs.helper import escape_like_pattern + + escaped_keyword = escape_like_pattern(keyword) stmt = ( select(MessageAnnotation) .where(MessageAnnotation.app_id == app_id) .where( or_( - MessageAnnotation.question.ilike(f"%{keyword}%"), - MessageAnnotation.content.ilike(f"%{keyword}%"), + MessageAnnotation.question.ilike(f"%{escaped_keyword}%", escape="\\"), + MessageAnnotation.content.ilike(f"%{escaped_keyword}%", escape="\\"), ) ) .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()) @@ -253,7 +256,7 @@ class AppAnnotationService: if app_annotation_setting: update_annotation_to_index_task.delay( annotation.id, - annotation.question, + annotation.question_text, current_tenant_id, app_id, app_annotation_setting.collection_binding_id, diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index deba0b79e8..acd2a25a86 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -32,7 +32,7 @@ from extensions.ext_redis import redis_client from factories import variable_factory from libs.datetime_utils import naive_utc_now from models import Account, App, AppMode -from models.model import AppModelConfig +from models.model import AppModelConfig, IconType from models.workflow import Workflow from services.plugin.dependencies_analysis import DependenciesAnalysisService from services.workflow_draft_variable_service import WorkflowDraftVariableService @@ -428,10 +428,10 @@ class AppDslService: # Set icon type icon_type_value = icon_type or app_data.get("icon_type") - if icon_type_value in ["emoji", "link", "image"]: + if icon_type_value in 
[IconType.EMOJI.value, IconType.IMAGE.value, IconType.LINK.value]: icon_type = icon_type_value else: - icon_type = "emoji" + icon_type = IconType.EMOJI.value icon = icon or str(app_data.get("icon", "")) if app: diff --git a/api/services/app_service.py b/api/services/app_service.py index ef89a4fd10..02ebfbace0 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -55,8 +55,11 @@ class AppService: if args.get("is_created_by_me", False): filters.append(App.created_by == user_id) if args.get("name"): + from libs.helper import escape_like_pattern + name = args["name"][:30] - filters.append(App.name.ilike(f"%{name}%")) + escaped_name = escape_like_pattern(name) + filters.append(App.name.ilike(f"%{escaped_name}%", escape="\\")) # Check if tag_ids is not empty to avoid WHERE false condition if args.get("tag_ids") and len(args["tag_ids"]) > 0: target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, args["tag_ids"]) diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 659e7406fb..295d48d8a1 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -11,13 +11,13 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.db.session_factory import session_factory from core.llm_generator.llm_generator import LLMGenerator from core.variables.types import SegmentType -from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory from extensions.ext_database import db from factories import variable_factory from libs.datetime_utils import naive_utc_now from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account, ConversationVariable from models.model import App, Conversation, EndUser, Message +from services.conversation_variable_updater import ConversationVariableUpdater from services.errors.conversation import ( ConversationNotExistsError, ConversationVariableNotExistsError, @@ -218,7 +218,9 @@ class ConversationService: # Apply variable_name filter if provided if variable_name: # Filter using JSON extraction to match variable names case-insensitively - escaped_variable_name = variable_name.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") + from libs.helper import escape_like_pattern + + escaped_variable_name = escape_like_pattern(variable_name) # Filter using JSON extraction to match variable names case-insensitively if dify_config.DB_TYPE in ["mysql", "oceanbase", "seekdb"]: stmt = stmt.where( @@ -335,7 +337,7 @@ class ConversationService: updated_variable = variable_factory.build_conversation_variable_from_mapping(updated_variable_dict) # Use the conversation variable updater to persist the changes - updater = conversation_variable_updater_factory() + updater = ConversationVariableUpdater(session_factory.get_session_maker()) updater.update(conversation_id, updated_variable) updater.flush() diff --git a/api/services/conversation_variable_updater.py b/api/services/conversation_variable_updater.py new file mode 100644 index 0000000000..acc0ec2b22 --- /dev/null +++ b/api/services/conversation_variable_updater.py @@ -0,0 +1,28 @@ +from sqlalchemy import select +from sqlalchemy.orm import Session, sessionmaker + +from core.variables.variables import Variable +from models import ConversationVariable + + +class ConversationVariableNotFoundError(Exception): + pass + + +class ConversationVariableUpdater: + def __init__(self, session_maker: sessionmaker[Session]) -> None: + self._session_maker: 
sessionmaker[Session] = session_maker + + def update(self, conversation_id: str, variable: Variable) -> None: + stmt = select(ConversationVariable).where( + ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id + ) + with self._session_maker() as session: + row = session.scalar(stmt) + if not row: + raise ConversationVariableNotFoundError("conversation variable not found in the database") + row.data = variable.model_dump_json() + session.commit() + + def flush(self) -> None: + pass diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index ac4b25c5dc..18e5613438 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -144,7 +144,8 @@ class DatasetService: query = query.where(Dataset.permission == DatasetPermissionEnum.ALL_TEAM) if search: - query = query.where(Dataset.name.ilike(f"%{search}%")) + escaped_search = helper.escape_like_pattern(search) + query = query.where(Dataset.name.ilike(f"%{escaped_search}%", escape="\\")) # Check if tag_ids is not empty to avoid WHERE false condition if tag_ids and len(tag_ids) > 0: @@ -3423,7 +3424,8 @@ class SegmentService: .order_by(ChildChunk.position.asc()) ) if keyword: - query = query.where(ChildChunk.content.ilike(f"%{keyword}%")) + escaped_keyword = helper.escape_like_pattern(keyword) + query = query.where(ChildChunk.content.ilike(f"%{escaped_keyword}%", escape="\\")) return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) @classmethod @@ -3456,7 +3458,8 @@ class SegmentService: query = query.where(DocumentSegment.status.in_(status_list)) if keyword: - query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) + escaped_keyword = helper.escape_like_pattern(keyword) + query = query.where(DocumentSegment.content.ilike(f"%{escaped_keyword}%", escape="\\")) query = query.order_by(DocumentSegment.position.asc(), DocumentSegment.id.asc()) paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index f405546909..a29d848ac5 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -70,7 +70,6 @@ class ProviderResponse(BaseModel): description: I18nObject | None = None icon_small: I18nObject | None = None icon_small_dark: I18nObject | None = None - icon_large: I18nObject | None = None background: str | None = None help: ProviderHelpEntity | None = None supported_model_types: Sequence[ModelType] @@ -98,11 +97,6 @@ class ProviderResponse(BaseModel): en_US=f"{url_prefix}/icon_small_dark/en_US", zh_Hans=f"{url_prefix}/icon_small_dark/zh_Hans", ) - - if self.icon_large is not None: - self.icon_large = I18nObject( - en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" - ) return self @@ -116,7 +110,6 @@ class ProviderWithModelsResponse(BaseModel): label: I18nObject icon_small: I18nObject | None = None icon_small_dark: I18nObject | None = None - icon_large: I18nObject | None = None status: CustomConfigurationStatus models: list[ProviderModelWithStatusEntity] @@ -134,11 +127,6 @@ class ProviderWithModelsResponse(BaseModel): self.icon_small_dark = I18nObject( en_US=f"{url_prefix}/icon_small_dark/en_US", zh_Hans=f"{url_prefix}/icon_small_dark/zh_Hans" ) - - if self.icon_large is not None: - self.icon_large = I18nObject( - en_US=f"{url_prefix}/icon_large/en_US", 
zh_Hans=f"{url_prefix}/icon_large/zh_Hans" - ) return self @@ -163,11 +151,6 @@ class SimpleProviderEntityResponse(SimpleProviderEntity): self.icon_small_dark = I18nObject( en_US=f"{url_prefix}/icon_small_dark/en_US", zh_Hans=f"{url_prefix}/icon_small_dark/zh_Hans" ) - - if self.icon_large is not None: - self.icon_large = I18nObject( - en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" - ) return self diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 40faa85b9a..65dd41af43 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -35,7 +35,10 @@ class ExternalDatasetService: .order_by(ExternalKnowledgeApis.created_at.desc()) ) if search: - query = query.where(ExternalKnowledgeApis.name.ilike(f"%{search}%")) + from libs.helper import escape_like_pattern + + escaped_search = escape_like_pattern(search) + query = query.where(ExternalKnowledgeApis.name.ilike(f"%{escaped_search}%", escape="\\")) external_knowledge_apis = db.paginate( select=query, page=page, per_page=per_page, max_per_page=100, error_out=False diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index eea382febe..edd1004b82 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -99,7 +99,6 @@ class ModelProviderService: description=provider_configuration.provider.description, icon_small=provider_configuration.provider.icon_small, icon_small_dark=provider_configuration.provider.icon_small_dark, - icon_large=provider_configuration.provider.icon_large, background=provider_configuration.provider.background, help=provider_configuration.provider.help, supported_model_types=provider_configuration.provider.supported_model_types, @@ -423,7 +422,6 @@ class ModelProviderService: label=first_model.provider.label, icon_small=first_model.provider.icon_small, icon_small_dark=first_model.provider.icon_small_dark, - icon_large=first_model.provider.icon_large, status=CustomConfigurationStatus.ACTIVE, models=[ ProviderModelWithStatusEntity( @@ -488,7 +486,6 @@ class ModelProviderService: provider=result.provider.provider, label=result.provider.label, icon_small=result.provider.icon_small, - icon_large=result.provider.icon_large, supported_model_types=result.provider.supported_model_types, ), ) @@ -522,7 +519,7 @@ class ModelProviderService: :param tenant_id: workspace id :param provider: provider name - :param icon_type: icon type (icon_small or icon_large) + :param icon_type: icon type (icon_small or icon_small_dark) :param lang: language (zh_Hans or en_US) :return: """ diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index f53448e7fe..1ba64813ba 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -874,7 +874,7 @@ class RagPipelineService: variable_pool = node_instance.graph_runtime_state.variable_pool invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) if invoke_from: - if invoke_from.value == InvokeFrom.PUBLISHED: + if invoke_from.value == InvokeFrom.PUBLISHED_PIPELINE: document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) if document_id: document = db.session.query(Document).where(Document.id == document_id.value).first() @@ -1318,7 +1318,7 @@ class RagPipelineService: "datasource_info_list": [json.loads(document_pipeline_execution_log.datasource_info)], "original_document_id": 
document.id, }, - invoke_from=InvokeFrom.PUBLISHED, + invoke_from=InvokeFrom.PUBLISHED_PIPELINE, streaming=False, call_depth=0, workflow_thread_pool_id=None, diff --git a/api/services/tag_service.py b/api/services/tag_service.py index 937e6593fe..bd3585acf4 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -19,7 +19,10 @@ class TagService: .where(Tag.type == tag_type, Tag.tenant_id == current_tenant_id) ) if keyword: - query = query.where(sa.and_(Tag.name.ilike(f"%{keyword}%"))) + from libs.helper import escape_like_pattern + + escaped_keyword = escape_like_pattern(keyword) + query = query.where(sa.and_(Tag.name.ilike(f"%{escaped_keyword}%", escape="\\"))) query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at) results: list = query.order_by(Tag.created_at.desc()).all() return results diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 250d29f335..c32157919b 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -85,7 +85,9 @@ class ApiToolManageService: raise ValueError(f"invalid schema: {str(e)}") @staticmethod - def convert_schema_to_tool_bundles(schema: str, extra_info: dict | None = None) -> tuple[list[ApiToolBundle], str]: + def convert_schema_to_tool_bundles( + schema: str, extra_info: dict | None = None + ) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]: """ convert schema to tool bundles @@ -103,7 +105,7 @@ class ApiToolManageService: provider_name: str, icon: dict, credentials: dict, - schema_type: str, + schema_type: ApiProviderSchemaType, schema: str, privacy_policy: str, custom_disclaimer: str, @@ -112,9 +114,6 @@ class ApiToolManageService: """ create api tool provider """ - if schema_type not in [member.value for member in ApiProviderSchemaType]: - raise ValueError(f"invalid schema type {schema}") - provider_name = provider_name.strip() # check if the provider exists @@ -241,18 +240,15 @@ class ApiToolManageService: original_provider: str, icon: dict, credentials: dict, - schema_type: str, + _schema_type: ApiProviderSchemaType, schema: str, - privacy_policy: str, + privacy_policy: str | None, custom_disclaimer: str, labels: list[str], ): """ update api tool provider """ - if schema_type not in [member.value for member in ApiProviderSchemaType]: - raise ValueError(f"invalid schema type {schema}") - provider_name = provider_name.strip() # check if the provider exists @@ -277,7 +273,7 @@ class ApiToolManageService: provider.icon = json.dumps(icon) provider.schema = schema provider.description = extra_info.get("description", "") - provider.schema_type_str = ApiProviderSchemaType.OPENAPI + provider.schema_type_str = schema_type provider.tools_str = json.dumps(jsonable_encoder(tool_bundles)) provider.privacy_policy = privacy_policy provider.custom_disclaimer = custom_disclaimer @@ -356,7 +352,7 @@ class ApiToolManageService: tool_name: str, credentials: dict, parameters: dict, - schema_type: str, + schema_type: ApiProviderSchemaType, schema: str, ): """ diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py index ef77c33c1b..4131d75145 100644 --- a/api/services/trigger/trigger_provider_service.py +++ b/api/services/trigger/trigger_provider_service.py @@ -853,7 +853,7 @@ class TriggerProviderService: """ Create a subscription builder for rebuilding an existing subscription. 
- This method creates a builder pre-filled with data from the rebuild request, + This method rebuilds the subscription by calling the DELETE and CREATE APIs of the third-party provider (e.g. GitHub), keeping the same subscription_id and endpoint_id so the webhook URL remains unchanged. :param tenant_id: Tenant ID @@ -868,111 +868,50 @@ class TriggerProviderService: if not provider_controller: raise ValueError(f"Provider {provider_id} not found") - # Use distributed lock to prevent race conditions on the same subscription - lock_key = f"trigger_subscription_rebuild_lock:{tenant_id}_{subscription_id}" - with redis_client.lock(lock_key, timeout=20): - with Session(db.engine, expire_on_commit=False) as session: - try: - # Get subscription within the transaction - subscription: TriggerSubscription | None = ( - session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() - ) - if not subscription: - raise ValueError(f"Subscription {subscription_id} not found") + subscription = TriggerProviderService.get_subscription_by_id( + tenant_id=tenant_id, + subscription_id=subscription_id, + ) + if not subscription: + raise ValueError(f"Subscription {subscription_id} not found") - credential_type = CredentialType.of(subscription.credential_type) - if credential_type not in [CredentialType.OAUTH2, CredentialType.API_KEY]: - raise ValueError("Credential type not supported for rebuild") + credential_type = CredentialType.of(subscription.credential_type) + if credential_type not in {CredentialType.OAUTH2, CredentialType.API_KEY}: + raise ValueError(f"Credential type {credential_type} not supported for auto creation") - # Decrypt existing credentials for merging - credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription( - tenant_id=tenant_id, - controller=provider_controller, - subscription=subscription, - ) - decrypted_credentials = dict(credential_encrypter.decrypt(subscription.credentials)) + # Delete the previous subscription + user_id = subscription.user_id + unsubscribe_result = TriggerManager.unsubscribe_trigger( + tenant_id=tenant_id, + user_id=user_id, + provider_id=provider_id, + subscription=subscription.to_entity(), + credentials=subscription.credentials, + credential_type=credential_type, + ) + if not unsubscribe_result.success: + raise ValueError(f"Failed to delete previous subscription: {unsubscribe_result.message}") - # Merge credentials: if caller passed HIDDEN_VALUE, retain existing decrypted value - merged_credentials: dict[str, Any] = { - key: value if value != HIDDEN_VALUE else decrypted_credentials.get(key, UNKNOWN_VALUE) - for key, value in credentials.items() - } - - user_id = subscription.user_id - - # TODO: Trying to invoke update api of the plugin trigger provider - - # FALLBACK: If the update api is not implemented, - # delete the previous subscription and create a new one - - # Unsubscribe the previous subscription (external call, but we'll handle errors) - try: - TriggerManager.unsubscribe_trigger( - tenant_id=tenant_id, - user_id=user_id, - provider_id=provider_id, - subscription=subscription.to_entity(), - credentials=decrypted_credentials, - credential_type=credential_type, - ) - except Exception as e: - logger.exception("Error unsubscribing trigger during rebuild", exc_info=e) - # Continue anyway - the subscription might already be deleted externally - - # Create a new subscription with the same subscription_id and endpoint_id (external call) - new_subscription: TriggerSubscriptionEntity = TriggerManager.subscribe_trigger( 
tenant_id=tenant_id, - user_id=user_id, - provider_id=provider_id, - endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id), - parameters=parameters, - credentials=merged_credentials, - credential_type=credential_type, - ) - - # Update the subscription in the same transaction - # Inline update logic to reuse the same session - if name is not None and name != subscription.name: - existing = ( - session.query(TriggerSubscription) - .filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name) - .first() - ) - if existing and existing.id != subscription.id: - raise ValueError(f"Subscription name '{name}' already exists for this provider") - subscription.name = name - - # Update parameters - subscription.parameters = dict(parameters) - - # Update credentials with merged (and encrypted) values - subscription.credentials = dict(credential_encrypter.encrypt(merged_credentials)) - - # Update properties - if new_subscription.properties: - properties_encrypter, _ = create_provider_encrypter( - tenant_id=tenant_id, - config=provider_controller.get_properties_schema(), - cache=NoOpProviderCredentialCache(), - ) - subscription.properties = dict(properties_encrypter.encrypt(dict(new_subscription.properties))) - - # Update expiration timestamp - if new_subscription.expires_at is not None: - subscription.expires_at = new_subscription.expires_at - - # Commit the transaction - session.commit() - - # Clear subscription cache - delete_cache_for_subscription( - tenant_id=tenant_id, - provider_id=subscription.provider_id, - subscription_id=subscription.id, - ) - - except Exception as e: - # Rollback on any error - session.rollback() - logger.exception("Failed to rebuild trigger subscription", exc_info=e) - raise + # Create a new subscription with the same subscription_id and endpoint_id + new_credentials: dict[str, Any] = { + key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE) + for key, value in credentials.items() + } + new_subscription: TriggerSubscriptionEntity = TriggerManager.subscribe_trigger( + tenant_id=tenant_id, + user_id=user_id, + provider_id=provider_id, + endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id), + parameters=parameters, + credentials=new_credentials, + credential_type=credential_type, + ) + TriggerProviderService.update_trigger_subscription( + tenant_id=tenant_id, + subscription_id=subscription.id, + name=name, + parameters=parameters, + credentials=new_credentials, + properties=new_subscription.properties, + expires_at=new_subscription.expires_at, + ) diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py index 0f969207cf..f973361341 100644 --- a/api/services/variable_truncator.py +++ b/api/services/variable_truncator.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dataclasses from abc import ABC, abstractmethod from collections.abc import Mapping @@ -106,7 +108,7 @@ class VariableTruncator(BaseTruncator): self._max_size_bytes = max_size_bytes @classmethod - def default(cls) -> "VariableTruncator": + def default(cls) -> VariableTruncator: return VariableTruncator( max_size_bytes=dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE, array_element_limit=dify_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH, diff --git a/api/services/website_service.py b/api/services/website_service.py index a23f01ec71..fe48c3b08e 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import 
datetime import json from dataclasses import dataclass @@ -78,7 +80,7 @@ class WebsiteCrawlApiRequest: return CrawlRequest(url=self.url, provider=self.provider, options=options) @classmethod - def from_args(cls, args: dict) -> "WebsiteCrawlApiRequest": + def from_args(cls, args: dict) -> WebsiteCrawlApiRequest: """Create from Flask-RESTful parsed arguments.""" provider = args.get("provider") url = args.get("url") @@ -102,7 +104,7 @@ class WebsiteCrawlStatusApiRequest: job_id: str @classmethod - def from_args(cls, args: dict, job_id: str) -> "WebsiteCrawlStatusApiRequest": + def from_args(cls, args: dict, job_id: str) -> WebsiteCrawlStatusApiRequest: """Create from Flask-RESTful parsed arguments.""" provider = args.get("provider") if not provider: diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index 01f0c7a55a..8574d30255 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -86,12 +86,19 @@ class WorkflowAppService: # Join to workflow run for filtering when needed. if keyword: - keyword_like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u") + from libs.helper import escape_like_pattern + + # Escape special characters in keyword to prevent SQL injection via LIKE wildcards + escaped_keyword = escape_like_pattern(keyword[:30]) + keyword_like_val = f"%{escaped_keyword}%" keyword_conditions = [ - WorkflowRun.inputs.ilike(keyword_like_val), - WorkflowRun.outputs.ilike(keyword_like_val), + WorkflowRun.inputs.ilike(keyword_like_val, escape="\\"), + WorkflowRun.outputs.ilike(keyword_like_val, escape="\\"), # filter keyword by end user session id if created by end user role - and_(WorkflowRun.created_by_role == "end_user", EndUser.session_id.ilike(keyword_like_val)), + and_( + WorkflowRun.created_by_role == "end_user", + EndUser.session_id.ilike(keyword_like_val, escape="\\"), + ), ] # filter keyword by workflow run id diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index f299ce3baa..9407a2b3f0 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -679,6 +679,7 @@ def _batch_upsert_draft_variable( def _model_to_insertion_dict(model: WorkflowDraftVariable) -> dict[str, Any]: d: dict[str, Any] = { + "id": model.id, "app_id": model.app_id, "last_edited_at": None, "node_id": model.node_id, diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index cdc07c77a8..be1de3cdd2 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -98,7 +98,7 @@ def enable_annotation_reply_task( if annotations: for annotation in annotations: document = Document( - page_content=annotation.question, + page_content=annotation.question_text, metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id}, ) documents.append(document) diff --git a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py index 1eef361a92..3c5e152520 100644 --- a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py @@ -178,7 +178,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], workflow_id=workflow_id, user=account, application_generate_entity=entity, - invoke_from=InvokeFrom.PUBLISHED, 
+ invoke_from=InvokeFrom.PUBLISHED_PIPELINE, workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, diff --git a/api/tasks/rag_pipeline/rag_pipeline_run_task.py b/api/tasks/rag_pipeline/rag_pipeline_run_task.py index 275f5abe6e..093342d1a3 100644 --- a/api/tasks/rag_pipeline/rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/rag_pipeline_run_task.py @@ -178,7 +178,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], workflow_id=workflow_id, user=account, application_generate_entity=entity, - invoke_from=InvokeFrom.PUBLISHED, + invoke_from=InvokeFrom.PUBLISHED_PIPELINE, workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, diff --git a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py index d59d5dc0fe..5012defdad 100644 --- a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py +++ b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py @@ -48,10 +48,6 @@ class MockModelClass(PluginModelClient): en_US="https://example.com/icon_small.png", zh_Hans="https://example.com/icon_small.png", ), - icon_large=I18nObject( - en_US="https://example.com/icon_large.png", - zh_Hans="https://example.com/icon_large.png", - ), supported_model_types=[ModelType.LLM], configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], models=[ diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index e421e4ff36..9b0bd6275b 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -10,6 +10,7 @@ from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.graph import Graph from core.workflow.node_events import NodeRunResult from core.workflow.nodes.code.code_node import CodeNode +from core.workflow.nodes.code.limits import CodeNodeLimits from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable @@ -67,6 +68,16 @@ def init_code_node(code_config: dict): config=code_config, graph_init_params=init_params, graph_runtime_state=graph_runtime_state, + code_limits=CodeNodeLimits( + max_string_length=dify_config.CODE_MAX_STRING_LENGTH, + max_number=dify_config.CODE_MAX_NUMBER, + min_number=dify_config.CODE_MIN_NUMBER, + max_precision=dify_config.CODE_MAX_PRECISION, + max_depth=dify_config.CODE_MAX_DEPTH, + max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH, + max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH, + max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH, + ), ) return node diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py index 72469ad646..dcf31aeca7 100644 --- a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py @@ -35,6 +35,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage from core.workflow.entities.pause_reason import SchedulingPause from 
core.workflow.enums import WorkflowExecutionStatus from core.workflow.graph_engine.entities.commands import GraphEngineCommand +from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError from core.workflow.graph_events.graph import GraphRunPausedEvent from core.workflow.runtime.graph_runtime_state import GraphRuntimeState from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState @@ -569,10 +570,10 @@ class TestPauseStatePersistenceLayerTestContainers: """Test that layer requires proper initialization before handling events.""" # Arrange layer = self._create_pause_state_persistence_layer() - # Don't initialize - graph_runtime_state should not be set + # Don't initialize - graph_runtime_state should be uninitialized event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")]) - # Act & Assert - Should raise AttributeError - with pytest.raises(AttributeError): + # Act & Assert - Should raise GraphEngineLayerNotInitializedError + with pytest.raises(GraphEngineLayerNotInitializedError): layer.on_event(event) diff --git a/api/tests/test_containers_integration_tests/services/test_annotation_service.py b/api/tests/test_containers_integration_tests/services/test_annotation_service.py index da73122cd7..5555400ca6 100644 --- a/api/tests/test_containers_integration_tests/services/test_annotation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_annotation_service.py @@ -444,6 +444,78 @@ class TestAnnotationService: assert total == 1 assert unique_keyword in annotation_list[0].question or unique_keyword in annotation_list[0].content + def test_get_annotation_list_by_app_id_with_special_characters_in_keyword( + self, db_session_with_containers, mock_external_service_dependencies + ): + r""" + Test retrieval of annotation list with special characters in keyword to verify SQL injection prevention. 
+ + This test verifies: + - Special characters (%, _, \) in keyword are properly escaped + - Search treats special characters as literal characters, not wildcards + - SQL injection via LIKE wildcards is prevented + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotations with special characters in content + annotation_with_percent = { + "question": "Question with 50% discount", + "answer": "Answer about 50% discount offer", + } + AppAnnotationService.insert_app_annotation_directly(annotation_with_percent, app.id) + + annotation_with_underscore = { + "question": "Question with test_data", + "answer": "Answer about test_data value", + } + AppAnnotationService.insert_app_annotation_directly(annotation_with_underscore, app.id) + + annotation_with_backslash = { + "question": "Question with path\\to\\file", + "answer": "Answer about path\\to\\file location", + } + AppAnnotationService.insert_app_annotation_directly(annotation_with_backslash, app.id) + + # Create annotation that should NOT match (contains % but as part of different text) + annotation_no_match = { + "question": "Question with 100% different", + "answer": "Answer about 100% different content", + } + AppAnnotationService.insert_app_annotation_directly(annotation_no_match, app.id) + + # Test 1: Search with % character - should find exact match only + annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id( + app.id, page=1, limit=10, keyword="50%" + ) + assert total == 1 + assert len(annotation_list) == 1 + assert "50%" in annotation_list[0].question or "50%" in annotation_list[0].content + + # Test 2: Search with _ character - should find exact match only + annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id( + app.id, page=1, limit=10, keyword="test_data" + ) + assert total == 1 + assert len(annotation_list) == 1 + assert "test_data" in annotation_list[0].question or "test_data" in annotation_list[0].content + + # Test 3: Search with \ character - should find exact match only + annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id( + app.id, page=1, limit=10, keyword="path\\to\\file" + ) + assert total == 1 + assert len(annotation_list) == 1 + assert "path\\to\\file" in annotation_list[0].question or "path\\to\\file" in annotation_list[0].content + + # Test 4: Search with % should NOT match 100% (verifies escaping works) + annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id( + app.id, page=1, limit=10, keyword="50%" + ) + # Should only find the 50% annotation, not the 100% one + assert total == 1 + assert all("50%" in (item.question or "") or "50%" in (item.content or "") for item in annotation_list) + def test_get_annotation_list_by_app_id_app_not_found( self, db_session_with_containers, mock_external_service_dependencies ): diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index e53392bcef..745d6c97b0 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -7,7 +7,9 @@ from constants.model_template import default_app_templates from models import Account from models.model import App, Site from services.account_service import AccountService, TenantService -from services.app_service import AppService + +# Delay import of 
AppService to avoid circular dependency +# from services.app_service import AppService class TestAppService: @@ -71,6 +73,9 @@ class TestAppService: } # Create app + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -109,6 +114,9 @@ class TestAppService: TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() # Test different app modes @@ -159,6 +167,9 @@ class TestAppService: "icon_background": "#45B7D1", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() created_app = app_service.create_app(tenant.id, app_args, account) @@ -194,6 +205,9 @@ class TestAppService: TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() # Create multiple apps @@ -245,6 +259,9 @@ class TestAppService: TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() # Create apps with different modes @@ -315,6 +332,9 @@ class TestAppService: TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() # Create an app @@ -392,6 +412,9 @@ class TestAppService: "icon_background": "#45B7D1", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -458,6 +481,9 @@ class TestAppService: "icon_background": "#45B7D1", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -508,6 +534,9 @@ class TestAppService: "icon_background": "#45B7D1", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -562,6 +591,9 @@ class TestAppService: "icon_background": "#74B9FF", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -617,6 +649,9 @@ class TestAppService: "icon_background": "#A29BFE", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -672,6 +707,9 @@ class TestAppService: "icon_background": "#FD79A8", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -720,6 +758,9 @@ class TestAppService: "icon_background": "#E17055", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, 
account) @@ -768,6 +809,9 @@ class TestAppService: "icon_background": "#00B894", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -826,6 +870,9 @@ class TestAppService: "icon_background": "#6C5CE7", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -862,6 +909,9 @@ class TestAppService: "icon_background": "#FDCB6E", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -899,6 +949,9 @@ class TestAppService: "icon_background": "#E84393", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -947,8 +1000,132 @@ class TestAppService: "icon_background": "#D63031", } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() # Attempt to create app with invalid mode with pytest.raises(ValueError, match="invalid mode value"): app_service.create_app(tenant.id, app_args, account) + + def test_get_apps_with_special_characters_in_name( + self, db_session_with_containers, mock_external_service_dependencies + ): + r""" + Test app retrieval with special characters in name search to verify SQL injection prevention. + + This test verifies: + - Special characters (%, _, \) in name search are properly escaped + - Search treats special characters as literal characters, not wildcards + - SQL injection via LIKE wildcards is prevented + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Import here to avoid circular dependency + from services.app_service import AppService + + app_service = AppService() + + # Create apps with special characters in names + app_with_percent = app_service.create_app( + tenant.id, + { + "name": "App with 50% discount", + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + }, + account, + ) + + app_with_underscore = app_service.create_app( + tenant.id, + { + "name": "test_data_app", + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + }, + account, + ) + + app_with_backslash = app_service.create_app( + tenant.id, + { + "name": "path\\to\\app", + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + }, + account, + ) + + # Create app that should NOT match + app_no_match = app_service.create_app( + tenant.id, + { + "name": "100% different", + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + }, + account, + ) + + # Test 1: Search with % character + args = {"name": 
"50%", "mode": "chat", "page": 1, "limit": 10} + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + assert paginated_apps is not None + assert paginated_apps.total == 1 + assert len(paginated_apps.items) == 1 + assert paginated_apps.items[0].name == "App with 50% discount" + + # Test 2: Search with _ character + args = {"name": "test_data", "mode": "chat", "page": 1, "limit": 10} + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + assert paginated_apps is not None + assert paginated_apps.total == 1 + assert len(paginated_apps.items) == 1 + assert paginated_apps.items[0].name == "test_data_app" + + # Test 3: Search with \ character + args = {"name": "path\\to\\app", "mode": "chat", "page": 1, "limit": 10} + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + assert paginated_apps is not None + assert paginated_apps.total == 1 + assert len(paginated_apps.items) == 1 + assert paginated_apps.items[0].name == "path\\to\\app" + + # Test 4: Search with % should NOT match 100% (verifies escaping works) + args = {"name": "50%", "mode": "chat", "page": 1, "limit": 10} + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + assert paginated_apps is not None + assert paginated_apps.total == 1 + assert all("50%" in app.name for app in paginated_apps.items) diff --git a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py index 612210ef86..d57ab7428b 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py @@ -228,7 +228,6 @@ class TestModelProviderService: mock_provider_entity.description = {"en_US": "OpenAI provider", "zh_Hans": "OpenAI 提供商"} mock_provider_entity.icon_small = {"en_US": "icon_small.png", "zh_Hans": "icon_small.png"} mock_provider_entity.icon_small_dark = None - mock_provider_entity.icon_large = {"en_US": "icon_large.png", "zh_Hans": "icon_large.png"} mock_provider_entity.background = "#FF6B6B" mock_provider_entity.help = None mock_provider_entity.supported_model_types = [ModelType.LLM, ModelType.TEXT_EMBEDDING] @@ -302,7 +301,6 @@ class TestModelProviderService: mock_provider_entity_llm.description = {"en_US": "OpenAI provider", "zh_Hans": "OpenAI 提供商"} mock_provider_entity_llm.icon_small = {"en_US": "icon_small.png", "zh_Hans": "icon_small.png"} mock_provider_entity_llm.icon_small_dark = None - mock_provider_entity_llm.icon_large = {"en_US": "icon_large.png", "zh_Hans": "icon_large.png"} mock_provider_entity_llm.background = "#FF6B6B" mock_provider_entity_llm.help = None mock_provider_entity_llm.supported_model_types = [ModelType.LLM] @@ -316,7 +314,6 @@ class TestModelProviderService: mock_provider_entity_embedding.description = {"en_US": "Cohere provider", "zh_Hans": "Cohere 提供商"} mock_provider_entity_embedding.icon_small = {"en_US": "icon_small.png", "zh_Hans": "icon_small.png"} mock_provider_entity_embedding.icon_small_dark = None - mock_provider_entity_embedding.icon_large = {"en_US": "icon_large.png", "zh_Hans": "icon_large.png"} mock_provider_entity_embedding.background = "#4ECDC4" mock_provider_entity_embedding.help = None mock_provider_entity_embedding.supported_model_types = [ModelType.TEXT_EMBEDDING] @@ -419,7 +416,6 @@ class TestModelProviderService: provider="openai", label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"), 
icon_small=I18nObject(en_US="icon_small.png", zh_Hans="icon_small.png"), - icon_large=I18nObject(en_US="icon_large.png", zh_Hans="icon_large.png"), supported_model_types=[ModelType.LLM], configurate_methods=[], models=[], @@ -431,7 +427,6 @@ class TestModelProviderService: provider="openai", label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"), icon_small=I18nObject(en_US="icon_small.png", zh_Hans="icon_small.png"), - icon_large=I18nObject(en_US="icon_large.png", zh_Hans="icon_large.png"), supported_model_types=[ModelType.LLM], configurate_methods=[], models=[], @@ -655,7 +650,6 @@ class TestModelProviderService: provider="openai", label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"), icon_small=I18nObject(en_US="icon_small.png", zh_Hans="icon_small.png"), - icon_large=I18nObject(en_US="icon_large.png", zh_Hans="icon_large.png"), supported_model_types=[ModelType.LLM], ), ) @@ -1027,7 +1021,6 @@ class TestModelProviderService: label={"en_US": "OpenAI", "zh_Hans": "OpenAI"}, icon_small={"en_US": "icon_small.png", "zh_Hans": "icon_small.png"}, icon_small_dark=None, - icon_large={"en_US": "icon_large.png", "zh_Hans": "icon_large.png"}, ), model="gpt-3.5-turbo", model_type=ModelType.LLM, @@ -1045,7 +1038,6 @@ class TestModelProviderService: label={"en_US": "OpenAI", "zh_Hans": "OpenAI"}, icon_small={"en_US": "icon_small.png", "zh_Hans": "icon_small.png"}, icon_small_dark=None, - icon_large={"en_US": "icon_large.png", "zh_Hans": "icon_large.png"}, ), model="gpt-4", model_type=ModelType.LLM, diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index 6732b8d558..e8c7f17e0b 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -1,3 +1,4 @@ +import uuid from unittest.mock import create_autospec, patch import pytest @@ -312,6 +313,85 @@ class TestTagService: result_no_match = TagService.get_tags("app", tenant.id, keyword="nonexistent") assert len(result_no_match) == 0 + def test_get_tags_with_special_characters_in_keyword( + self, db_session_with_containers, mock_external_service_dependencies + ): + r""" + Test tag retrieval with special characters in keyword to verify SQL injection prevention. 
+ + This test verifies: + - Special characters (%, _, \) in keyword are properly escaped + - Search treats special characters as literal characters, not wildcards + - SQL injection via LIKE wildcards is prevented + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + from extensions.ext_database import db + + # Create tags with special characters in names + tag_with_percent = Tag( + name="50% discount", + type="app", + tenant_id=tenant.id, + created_by=account.id, + ) + tag_with_percent.id = str(uuid.uuid4()) + db.session.add(tag_with_percent) + + tag_with_underscore = Tag( + name="test_data_tag", + type="app", + tenant_id=tenant.id, + created_by=account.id, + ) + tag_with_underscore.id = str(uuid.uuid4()) + db.session.add(tag_with_underscore) + + tag_with_backslash = Tag( + name="path\\to\\tag", + type="app", + tenant_id=tenant.id, + created_by=account.id, + ) + tag_with_backslash.id = str(uuid.uuid4()) + db.session.add(tag_with_backslash) + + # Create tag that should NOT match + tag_no_match = Tag( + name="100% different", + type="app", + tenant_id=tenant.id, + created_by=account.id, + ) + tag_no_match.id = str(uuid.uuid4()) + db.session.add(tag_no_match) + + db.session.commit() + + # Act & Assert: Test 1 - Search with % character + result = TagService.get_tags("app", tenant.id, keyword="50%") + assert len(result) == 1 + assert result[0].name == "50% discount" + + # Test 2 - Search with _ character + result = TagService.get_tags("app", tenant.id, keyword="test_data") + assert len(result) == 1 + assert result[0].name == "test_data_tag" + + # Test 3 - Search with \ character + result = TagService.get_tags("app", tenant.id, keyword="path\\to\\tag") + assert len(result) == 1 + assert result[0].name == "path\\to\\tag" + + # Test 4 - Search with % should NOT match 100% (verifies escaping works) + result = TagService.get_tags("app", tenant.id, keyword="50%") + assert len(result) == 1 + assert all("50%" in item.name for item in result) + def test_get_tags_empty_result(self, db_session_with_containers, mock_external_service_dependencies): """ Test tag retrieval when no tags exist. diff --git a/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py b/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py index 8322b9414e..5315960d73 100644 --- a/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py @@ -474,64 +474,6 @@ class TestTriggerProviderService: assert subscription.name == original_name assert subscription.parameters == original_parameters - def test_rebuild_trigger_subscription_unsubscribe_error_continues( - self, db_session_with_containers, mock_external_service_dependencies - ): - """ - Test that unsubscribe errors are handled gracefully and operation continues. 
- - This test verifies: - - Unsubscribe errors are caught and logged but don't stop the rebuild - - Rebuild continues even if unsubscribe fails - """ - fake = Faker() - account, tenant = self._create_test_account_and_tenant( - db_session_with_containers, mock_external_service_dependencies - ) - - provider_id = TriggerProviderID("test_org/test_plugin/test_provider") - credential_type = CredentialType.API_KEY - - original_credentials = {"api_key": "original-key"} - subscription = self._create_test_subscription( - db_session_with_containers, - tenant.id, - account.id, - provider_id, - credential_type, - original_credentials, - mock_external_service_dependencies, - ) - - # Make unsubscribe_trigger raise an error (should be caught and continue) - mock_external_service_dependencies["trigger_manager"].unsubscribe_trigger.side_effect = ValueError( - "Unsubscribe failed" - ) - - new_subscription_entity = TriggerSubscriptionEntity( - endpoint=subscription.endpoint_id, - parameters={}, - properties={}, - expires_at=-1, - ) - mock_external_service_dependencies["trigger_manager"].subscribe_trigger.return_value = new_subscription_entity - - # Execute rebuild - should succeed despite unsubscribe error - TriggerProviderService.rebuild_trigger_subscription( - tenant_id=tenant.id, - provider_id=provider_id, - subscription_id=subscription.id, - credentials={"api_key": "new-key"}, - parameters={}, - ) - - # Verify subscribe was still called (operation continued) - mock_external_service_dependencies["trigger_manager"].subscribe_trigger.assert_called_once() - - # Verify subscription was updated - db.session.refresh(subscription) - assert subscription.parameters == {} - def test_rebuild_trigger_subscription_subscription_not_found( self, db_session_with_containers, mock_external_service_dependencies ): @@ -558,70 +500,6 @@ class TestTriggerProviderService: parameters={}, ) - def test_rebuild_trigger_subscription_provider_not_found( - self, db_session_with_containers, mock_external_service_dependencies - ): - """ - Test error when provider is not found. - - This test verifies: - - Proper error is raised when provider doesn't exist - """ - fake = Faker() - account, tenant = self._create_test_account_and_tenant( - db_session_with_containers, mock_external_service_dependencies - ) - - provider_id = TriggerProviderID("non_existent_org/non_existent_plugin/non_existent_provider") - - # Make get_trigger_provider return None - mock_external_service_dependencies["trigger_manager"].get_trigger_provider.return_value = None - - with pytest.raises(ValueError, match="Provider.*not found"): - TriggerProviderService.rebuild_trigger_subscription( - tenant_id=tenant.id, - provider_id=provider_id, - subscription_id=fake.uuid4(), - credentials={}, - parameters={}, - ) - - def test_rebuild_trigger_subscription_unsupported_credential_type( - self, db_session_with_containers, mock_external_service_dependencies - ): - """ - Test error when credential type is not supported for rebuild. 
- - This test verifies: - - Proper error is raised for unsupported credential types (not OAUTH2 or API_KEY) - """ - fake = Faker() - account, tenant = self._create_test_account_and_tenant( - db_session_with_containers, mock_external_service_dependencies - ) - - provider_id = TriggerProviderID("test_org/test_plugin/test_provider") - credential_type = CredentialType.UNAUTHORIZED # Not supported - - subscription = self._create_test_subscription( - db_session_with_containers, - tenant.id, - account.id, - provider_id, - credential_type, - {}, - mock_external_service_dependencies, - ) - - with pytest.raises(ValueError, match="Credential type not supported for rebuild"): - TriggerProviderService.rebuild_trigger_subscription( - tenant_id=tenant.id, - provider_id=provider_id, - subscription_id=subscription.id, - credentials={}, - parameters={}, - ) - def test_rebuild_trigger_subscription_name_uniqueness_check( self, db_session_with_containers, mock_external_service_dependencies ): diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index 7b95944bbe..040fb826e1 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -10,7 +10,9 @@ from core.workflow.entities.workflow_execution import WorkflowExecutionStatus from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun from models.enums import CreatorUserRole from services.account_service import AccountService, TenantService -from services.app_service import AppService + +# Delay import of AppService to avoid circular dependency +# from services.app_service import AppService from services.workflow_app_service import WorkflowAppService @@ -86,6 +88,9 @@ class TestWorkflowAppService: "api_rpm": 10, } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -147,6 +152,9 @@ class TestWorkflowAppService: "api_rpm": 10, } + # Import here to avoid circular dependency + from services.app_service import AppService + app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -308,6 +316,156 @@ class TestWorkflowAppService: assert result_no_match["total"] == 0 assert len(result_no_match["data"]) == 0 + def test_get_paginate_workflow_app_logs_with_special_characters_in_keyword( + self, db_session_with_containers, mock_external_service_dependencies + ): + r""" + Test workflow app logs pagination with special characters in keyword to verify SQL injection prevention. 
+ + This test verifies: + - Special characters (%, _) in keyword are properly escaped + - Search treats special characters as literal characters, not wildcards + - SQL injection via LIKE wildcards is prevented + """ + # Arrange: Create test data + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + workflow, _, _ = self._create_test_workflow_data(db_session_with_containers, app, account) + + from extensions.ext_database import db + + service = WorkflowAppService() + + # Test 1: Search with % character + workflow_run_1 = WorkflowRun( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + type="workflow", + triggered_from="app-run", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + status="succeeded", + inputs=json.dumps({"search_term": "50% discount", "input2": "other_value"}), + outputs=json.dumps({"result": "50% discount applied", "status": "success"}), + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + created_at=datetime.now(UTC), + ) + db.session.add(workflow_run_1) + db.session.flush() + + workflow_app_log_1 = WorkflowAppLog( + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + workflow_run_id=workflow_run_1.id, + created_from="service-api", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + ) + workflow_app_log_1.id = str(uuid.uuid4()) + workflow_app_log_1.created_at = datetime.now(UTC) + db.session.add(workflow_app_log_1) + db.session.commit() + + result = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20 + ) + # Should find the workflow_run_1 entry + assert result["total"] >= 1 + assert len(result["data"]) >= 1 + assert any(log.workflow_run_id == workflow_run_1.id for log in result["data"]) + + # Test 2: Search with _ character + workflow_run_2 = WorkflowRun( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + type="workflow", + triggered_from="app-run", + version="1.0.0", + graph=json.dumps({"nodes": [], "edges": []}), + status="succeeded", + inputs=json.dumps({"search_term": "test_data_value", "input2": "other_value"}), + outputs=json.dumps({"result": "test_data_value found", "status": "success"}), + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + created_at=datetime.now(UTC), + ) + db.session.add(workflow_run_2) + db.session.flush() + + workflow_app_log_2 = WorkflowAppLog( + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + workflow_run_id=workflow_run_2.id, + created_from="service-api", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + ) + workflow_app_log_2.id = str(uuid.uuid4()) + workflow_app_log_2.created_at = datetime.now(UTC) + db.session.add(workflow_app_log_2) + db.session.commit() + + result = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, keyword="test_data", page=1, limit=20 + ) + # Should find the workflow_run_2 entry + assert result["total"] >= 1 + assert len(result["data"]) >= 1 + assert any(log.workflow_run_id == workflow_run_2.id for log in result["data"]) + + # Test 3: Search with % should NOT match 100% (verifies escaping works correctly) + workflow_run_4 = WorkflowRun( + id=str(uuid.uuid4()), + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + type="workflow", + triggered_from="app-run", + version="1.0.0", + 
graph=json.dumps({"nodes": [], "edges": []}), + status="succeeded", + inputs=json.dumps({"search_term": "100% different", "input2": "other_value"}), + outputs=json.dumps({"result": "100% different result", "status": "success"}), + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + created_at=datetime.now(UTC), + ) + db.session.add(workflow_run_4) + db.session.flush() + + workflow_app_log_4 = WorkflowAppLog( + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + workflow_run_id=workflow_run_4.id, + created_from="service-api", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + ) + workflow_app_log_4.id = str(uuid.uuid4()) + workflow_app_log_4.created_at = datetime.now(UTC) + db.session.add(workflow_app_log_4) + db.session.commit() + + result = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20 + ) + # Should only find the 50% entry (workflow_run_1), not the 100% entry (workflow_run_4) + # This verifies that escaping works correctly - 50% should not match 100% + assert result["total"] >= 1 + assert len(result["data"]) >= 1 + # Verify that we found workflow_run_1 (50% discount) but not workflow_run_4 (100% different) + found_run_ids = [log.workflow_run_id for log in result["data"]] + assert workflow_run_1.id in found_run_ids + assert workflow_run_4.id not in found_run_ids + def test_get_paginate_workflow_app_logs_with_status_filter( self, db_session_with_containers, mock_external_service_dependencies ): diff --git a/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py b/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py index e29b98037f..b9977b1fb6 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py +++ b/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py @@ -165,7 +165,7 @@ class TestRagPipelineRunTasks: "files": [], "user_id": account.id, "stream": False, - "invoke_from": "published", + "invoke_from": InvokeFrom.PUBLISHED_PIPELINE.value, "workflow_execution_id": str(uuid.uuid4()), "pipeline_config": { "app_id": str(uuid.uuid4()), @@ -249,7 +249,7 @@ class TestRagPipelineRunTasks: assert call_kwargs["pipeline"].id == pipeline.id assert call_kwargs["workflow_id"] == workflow.id assert call_kwargs["user"].id == account.id - assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED + assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED_PIPELINE assert call_kwargs["streaming"] == False assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) @@ -294,7 +294,7 @@ class TestRagPipelineRunTasks: assert call_kwargs["pipeline"].id == pipeline.id assert call_kwargs["workflow_id"] == workflow.id assert call_kwargs["user"].id == account.id - assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED + assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED_PIPELINE assert call_kwargs["streaming"] == False assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) @@ -743,7 +743,7 @@ class TestRagPipelineRunTasks: assert call_kwargs["pipeline"].id == pipeline.id assert call_kwargs["workflow_id"] == workflow.id assert call_kwargs["user"].id == account.id - assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED + assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED_PIPELINE assert call_kwargs["streaming"] == False assert isinstance(call_kwargs["application_generate_entity"], 
RagPipelineGenerateEntity) diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py index 209b6bf59b..6fce7849f9 100644 --- a/api/tests/unit_tests/configs/test_dify_config.py +++ b/api/tests/unit_tests/configs/test_dify_config.py @@ -16,6 +16,7 @@ def test_dify_config(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") monkeypatch.setenv("HTTP_REQUEST_MAX_WRITE_TIMEOUT", "30") # Custom value for testing + monkeypatch.setenv("DB_TYPE", "postgresql") monkeypatch.setenv("DB_USERNAME", "postgres") monkeypatch.setenv("DB_PASSWORD", "postgres") monkeypatch.setenv("DB_HOST", "localhost") @@ -51,6 +52,7 @@ def test_http_timeout_defaults(monkeypatch: pytest.MonkeyPatch): os.environ.clear() # Set minimal required env vars + monkeypatch.setenv("DB_TYPE", "postgresql") monkeypatch.setenv("DB_USERNAME", "postgres") monkeypatch.setenv("DB_PASSWORD", "postgres") monkeypatch.setenv("DB_HOST", "localhost") @@ -75,6 +77,7 @@ def test_flask_configs(monkeypatch: pytest.MonkeyPatch): # Set environment variables using monkeypatch monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") + monkeypatch.setenv("DB_TYPE", "postgresql") monkeypatch.setenv("DB_USERNAME", "postgres") monkeypatch.setenv("DB_PASSWORD", "postgres") monkeypatch.setenv("DB_HOST", "localhost") @@ -124,6 +127,7 @@ def test_inner_api_config_exist(monkeypatch: pytest.MonkeyPatch): # Set environment variables using monkeypatch monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") + monkeypatch.setenv("DB_TYPE", "postgresql") monkeypatch.setenv("DB_USERNAME", "postgres") monkeypatch.setenv("DB_PASSWORD", "postgres") monkeypatch.setenv("DB_HOST", "localhost") @@ -140,6 +144,7 @@ def test_inner_api_config_exist(monkeypatch: pytest.MonkeyPatch): def test_db_extras_options_merging(monkeypatch: pytest.MonkeyPatch): """Test that DB_EXTRAS options are properly merged with default timezone setting""" # Set environment variables + monkeypatch.setenv("DB_TYPE", "postgresql") monkeypatch.setenv("DB_USERNAME", "postgres") monkeypatch.setenv("DB_PASSWORD", "postgres") monkeypatch.setenv("DB_HOST", "localhost") @@ -199,6 +204,7 @@ def test_celery_broker_url_with_special_chars_password( # Set up basic required environment variables (following existing pattern) monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") + monkeypatch.setenv("DB_TYPE", "postgresql") monkeypatch.setenv("DB_USERNAME", "postgres") monkeypatch.setenv("DB_PASSWORD", "postgres") monkeypatch.setenv("DB_HOST", "localhost") diff --git a/api/tests/unit_tests/controllers/__init__.py b/api/tests/unit_tests/controllers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/controllers/common/test_fields.py b/api/tests/unit_tests/controllers/common/test_fields.py new file mode 100644 index 0000000000..d4dc13127d --- /dev/null +++ b/api/tests/unit_tests/controllers/common/test_fields.py @@ -0,0 +1,69 @@ +import builtins +from types import SimpleNamespace +from unittest.mock import patch + +from flask.views import MethodView as FlaskMethodView + +_NEEDS_METHOD_VIEW_CLEANUP = False +if not hasattr(builtins, "MethodView"): + builtins.MethodView = FlaskMethodView + _NEEDS_METHOD_VIEW_CLEANUP = True +from 
controllers.common.fields import Parameters, Site +from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict +from models.model import IconType + + +def test_parameters_model_round_trip(): + parameters = get_parameters_from_feature_dict(features_dict={}, user_input_form=[]) + + model = Parameters.model_validate(parameters) + + assert model.model_dump(mode="json") == parameters + + +def test_site_icon_url_uses_signed_url_for_image_icon(): + site = SimpleNamespace( + title="Example", + chat_color_theme=None, + chat_color_theme_inverted=False, + icon_type=IconType.IMAGE, + icon="file-id", + icon_background=None, + description=None, + copyright=None, + privacy_policy=None, + custom_disclaimer=None, + default_language="en-US", + show_workflow_steps=True, + use_icon_as_answer_icon=False, + ) + + with patch("controllers.common.fields.file_helpers.get_signed_file_url", return_value="signed") as mock_helper: + model = Site.model_validate(site) + + assert model.icon_url == "signed" + mock_helper.assert_called_once_with("file-id") + + +def test_site_icon_url_is_none_for_non_image_icon(): + site = SimpleNamespace( + title="Example", + chat_color_theme=None, + chat_color_theme_inverted=False, + icon_type=IconType.EMOJI, + icon="file-id", + icon_background=None, + description=None, + copyright=None, + privacy_policy=None, + custom_disclaimer=None, + default_language="en-US", + show_workflow_steps=True, + use_icon_as_answer_icon=False, + ) + + with patch("controllers.common.fields.file_helpers.get_signed_file_url") as mock_helper: + model = Site.model_validate(site) + + assert model.icon_url is None + mock_helper.assert_not_called() diff --git a/api/tests/unit_tests/controllers/console/__init__.py b/api/tests/unit_tests/controllers/console/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/controllers/console/app/test_app_response_models.py b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py new file mode 100644 index 0000000000..40eb59a8f4 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py @@ -0,0 +1,285 @@ +from __future__ import annotations + +import builtins +import sys +from datetime import datetime +from importlib import util +from pathlib import Path +from types import ModuleType, SimpleNamespace +from typing import Any + +import pytest +from flask.views import MethodView + +# kombu references MethodView as a global when importing celery/kombu pools. 
+if not hasattr(builtins, "MethodView"): + builtins.MethodView = MethodView # type: ignore[attr-defined] + + +def _load_app_module(): + module_name = "controllers.console.app.app" + if module_name in sys.modules: + return sys.modules[module_name] + + root = Path(__file__).resolve().parents[5] + module_path = root / "controllers" / "console" / "app" / "app.py" + + class _StubNamespace: + def __init__(self): + self.models: dict[str, Any] = {} + self.payload = None + + def schema_model(self, name, schema): + self.models[name] = schema + + def _decorator(self, obj): + return obj + + def doc(self, *args, **kwargs): + return self._decorator + + def expect(self, *args, **kwargs): + return self._decorator + + def response(self, *args, **kwargs): + return self._decorator + + def route(self, *args, **kwargs): + def decorator(obj): + return obj + + return decorator + + stub_namespace = _StubNamespace() + + original_console = sys.modules.get("controllers.console") + original_app_pkg = sys.modules.get("controllers.console.app") + stubbed_modules: list[tuple[str, ModuleType | None]] = [] + + console_module = ModuleType("controllers.console") + console_module.__path__ = [str(root / "controllers" / "console")] + console_module.console_ns = stub_namespace + console_module.api = None + console_module.bp = None + sys.modules["controllers.console"] = console_module + + app_package = ModuleType("controllers.console.app") + app_package.__path__ = [str(root / "controllers" / "console" / "app")] + sys.modules["controllers.console.app"] = app_package + console_module.app = app_package + + def _stub_module(name: str, attrs: dict[str, Any]): + original = sys.modules.get(name) + module = ModuleType(name) + for key, value in attrs.items(): + setattr(module, key, value) + sys.modules[name] = module + stubbed_modules.append((name, original)) + + class _OpsTraceManager: + @staticmethod + def get_app_tracing_config(app_id: str) -> dict[str, Any]: + return {} + + @staticmethod + def update_app_tracing_config(app_id: str, **kwargs) -> None: + return None + + _stub_module( + "core.ops.ops_trace_manager", + { + "OpsTraceManager": _OpsTraceManager, + "TraceQueueManager": object, + "TraceTask": object, + }, + ) + + spec = util.spec_from_file_location(module_name, module_path) + module = util.module_from_spec(spec) + sys.modules[module_name] = module + + try: + assert spec.loader is not None + spec.loader.exec_module(module) + finally: + for name, original in reversed(stubbed_modules): + if original is not None: + sys.modules[name] = original + else: + sys.modules.pop(name, None) + if original_console is not None: + sys.modules["controllers.console"] = original_console + else: + sys.modules.pop("controllers.console", None) + if original_app_pkg is not None: + sys.modules["controllers.console.app"] = original_app_pkg + else: + sys.modules.pop("controllers.console.app", None) + + return module + + +_app_module = _load_app_module() +AppDetailWithSite = _app_module.AppDetailWithSite +AppPagination = _app_module.AppPagination +AppPartial = _app_module.AppPartial + + +@pytest.fixture(autouse=True) +def patch_signed_url(monkeypatch): + """Ensure icon URL generation uses a deterministic helper for tests.""" + + def _fake_signed_url(key: str | None) -> str | None: + if not key: + return None + return f"signed:{key}" + + monkeypatch.setattr(_app_module.file_helpers, "get_signed_file_url", _fake_signed_url) + + +def _ts(hour: int = 12) -> datetime: + return datetime(2024, 1, 1, hour, 0, 0) + + +def _dummy_model_config(): + return 
SimpleNamespace( + model_dict={"provider": "openai", "name": "gpt-4o"}, + pre_prompt="hello", + created_by="config-author", + created_at=_ts(9), + updated_by="config-editor", + updated_at=_ts(10), + ) + + +def _dummy_workflow(): + return SimpleNamespace( + id="wf-1", + created_by="workflow-author", + created_at=_ts(8), + updated_by="workflow-editor", + updated_at=_ts(9), + ) + + +def test_app_partial_serialization_uses_aliases(): + created_at = _ts() + app_obj = SimpleNamespace( + id="app-1", + name="My App", + desc_or_prompt="Prompt snippet", + mode_compatible_with_agent="chat", + icon_type="image", + icon="icon-key", + icon_background="#fff", + app_model_config=_dummy_model_config(), + workflow=_dummy_workflow(), + created_by="creator", + created_at=created_at, + updated_by="editor", + updated_at=created_at, + tags=[SimpleNamespace(id="tag-1", name="Utilities", type="app")], + access_mode="private", + create_user_name="Creator", + author_name="Author", + has_draft_trigger=True, + ) + + serialized = AppPartial.model_validate(app_obj, from_attributes=True).model_dump(mode="json") + + assert serialized["description"] == "Prompt snippet" + assert serialized["mode"] == "chat" + assert serialized["icon_url"] == "signed:icon-key" + assert serialized["created_at"] == int(created_at.timestamp()) + assert serialized["updated_at"] == int(created_at.timestamp()) + assert serialized["model_config"]["model"] == {"provider": "openai", "name": "gpt-4o"} + assert serialized["workflow"]["id"] == "wf-1" + assert serialized["tags"][0]["name"] == "Utilities" + + +def test_app_detail_with_site_includes_nested_serialization(): + timestamp = _ts(14) + site = SimpleNamespace( + code="site-code", + title="Public Site", + icon_type="image", + icon="site-icon", + created_at=timestamp, + updated_at=timestamp, + ) + app_obj = SimpleNamespace( + id="app-2", + name="Detailed App", + description="Desc", + mode_compatible_with_agent="advanced-chat", + icon_type="image", + icon="detail-icon", + icon_background="#123456", + enable_site=True, + enable_api=True, + app_model_config={ + "opening_statement": "hi", + "model": {"provider": "openai", "name": "gpt-4o"}, + "retriever_resource": {"enabled": True}, + }, + workflow=_dummy_workflow(), + tracing={"enabled": True}, + use_icon_as_answer_icon=True, + created_by="creator", + created_at=timestamp, + updated_by="editor", + updated_at=timestamp, + access_mode="public", + tags=[SimpleNamespace(id="tag-2", name="Prod", type="app")], + api_base_url="https://api.example.com/v1", + max_active_requests=5, + deleted_tools=[{"type": "api", "tool_name": "search", "provider_id": "prov"}], + site=site, + ) + + serialized = AppDetailWithSite.model_validate(app_obj, from_attributes=True).model_dump(mode="json") + + assert serialized["icon_url"] == "signed:detail-icon" + assert serialized["model_config"]["retriever_resource"] == {"enabled": True} + assert serialized["deleted_tools"][0]["tool_name"] == "search" + assert serialized["site"]["icon_url"] == "signed:site-icon" + assert serialized["site"]["created_at"] == int(timestamp.timestamp()) + + +def test_app_pagination_aliases_per_page_and_has_next(): + item_one = SimpleNamespace( + id="app-10", + name="Paginated One", + desc_or_prompt="Summary", + mode_compatible_with_agent="chat", + icon_type="image", + icon="first-icon", + created_at=_ts(15), + updated_at=_ts(15), + ) + item_two = SimpleNamespace( + id="app-11", + name="Paginated Two", + desc_or_prompt="Summary", + mode_compatible_with_agent="agent-chat", + icon_type="emoji", + icon="🙂", 
+ created_at=_ts(16), + updated_at=_ts(16), + ) + pagination = SimpleNamespace( + page=2, + per_page=10, + total=50, + has_next=True, + items=[item_one, item_two], + ) + + serialized = AppPagination.model_validate(pagination, from_attributes=True).model_dump(mode="json") + + assert serialized["page"] == 2 + assert serialized["limit"] == 10 + assert serialized["has_more"] is True + assert len(serialized["data"]) == 2 + assert serialized["data"][0]["icon_url"] == "signed:first-icon" + assert serialized["data"][1]["icon_url"] is None diff --git a/api/tests/unit_tests/controllers/console/app/test_xss_prevention.py b/api/tests/unit_tests/controllers/console/app/test_xss_prevention.py new file mode 100644 index 0000000000..313818547b --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_xss_prevention.py @@ -0,0 +1,254 @@ +""" +Unit tests for XSS prevention in App payloads. + +This test module validates that HTML tags, JavaScript, and other potentially +dangerous content are rejected in App names and descriptions. +""" + +import pytest + +from controllers.console.app.app import CopyAppPayload, CreateAppPayload, UpdateAppPayload + + +class TestXSSPreventionUnit: + """Unit tests for XSS prevention in App payloads.""" + + def test_create_app_valid_names(self): + """Test CreateAppPayload with valid app names.""" + # Normal app names should be valid + valid_names = [ + "My App", + "Test App 123", + "App with - dash", + "App with _ underscore", + "App with + plus", + "App with () parentheses", + "App with [] brackets", + "App with {} braces", + "App with ! exclamation", + "App with @ at", + "App with # hash", + "App with $ dollar", + "App with % percent", + "App with ^ caret", + "App with & ampersand", + "App with * asterisk", + "Unicode: 测试应用", + "Emoji: 🤖", + "Mixed: Test 测试 123", + ] + + for name in valid_names: + payload = CreateAppPayload( + name=name, + mode="chat", + ) + assert payload.name == name + + def test_create_app_xss_script_tags(self): + """Test CreateAppPayload rejects script tags.""" + xss_payloads = [ + "", + "", + "", + "", + "", + "", + ] + + for name in xss_payloads: + with pytest.raises(ValueError) as exc_info: + CreateAppPayload(name=name, mode="chat") + assert "invalid characters or patterns" in str(exc_info.value).lower() + + def test_create_app_xss_iframe_tags(self): + """Test CreateAppPayload rejects iframe tags.""" + xss_payloads = [ + "", + "", + "", + ] + + for name in xss_payloads: + with pytest.raises(ValueError) as exc_info: + CreateAppPayload(name=name, mode="chat") + assert "invalid characters or patterns" in str(exc_info.value).lower() + + def test_create_app_xss_javascript_protocol(self): + """Test CreateAppPayload rejects javascript: protocol.""" + xss_payloads = [ + "javascript:alert(1)", + "JAVASCRIPT:alert(1)", + "JavaScript:alert(document.cookie)", + "javascript:void(0)", + "javascript://comment%0Aalert(1)", + ] + + for name in xss_payloads: + with pytest.raises(ValueError) as exc_info: + CreateAppPayload(name=name, mode="chat") + assert "invalid characters or patterns" in str(exc_info.value).lower() + + def test_create_app_xss_svg_onload(self): + """Test CreateAppPayload rejects SVG with onload.""" + xss_payloads = [ + "", + "", + "", + ] + + for name in xss_payloads: + with pytest.raises(ValueError) as exc_info: + CreateAppPayload(name=name, mode="chat") + assert "invalid characters or patterns" in str(exc_info.value).lower() + + def test_create_app_xss_event_handlers(self): + """Test CreateAppPayload rejects HTML event handlers.""" + 
xss_payloads = [ + "
", + "", + "", + "", + "", + "
", + "", + ] + + for name in xss_payloads: + with pytest.raises(ValueError) as exc_info: + CreateAppPayload(name=name, mode="chat") + assert "invalid characters or patterns" in str(exc_info.value).lower() + + def test_create_app_xss_object_embed(self): + """Test CreateAppPayload rejects object and embed tags.""" + xss_payloads = [ + "", + "", + "", + ] + + for name in xss_payloads: + with pytest.raises(ValueError) as exc_info: + CreateAppPayload(name=name, mode="chat") + assert "invalid characters or patterns" in str(exc_info.value).lower() + + def test_create_app_xss_link_javascript(self): + """Test CreateAppPayload rejects link tags with javascript.""" + xss_payloads = [ + "", + "", + ] + + for name in xss_payloads: + with pytest.raises(ValueError) as exc_info: + CreateAppPayload(name=name, mode="chat") + assert "invalid characters or patterns" in str(exc_info.value).lower() + + def test_create_app_xss_in_description(self): + """Test CreateAppPayload rejects XSS in description.""" + xss_descriptions = [ + "", + "javascript:alert(1)", + "", + ] + + for description in xss_descriptions: + with pytest.raises(ValueError) as exc_info: + CreateAppPayload( + name="Valid Name", + mode="chat", + description=description, + ) + assert "invalid characters or patterns" in str(exc_info.value).lower() + + def test_create_app_valid_descriptions(self): + """Test CreateAppPayload with valid descriptions.""" + valid_descriptions = [ + "A simple description", + "Description with < and > symbols", + "Description with & ampersand", + "Description with 'quotes' and \"double quotes\"", + "Description with / slashes", + "Description with \\ backslashes", + "Description with ; semicolons", + "Unicode: 这是一个描述", + "Emoji: 🎉🚀", + ] + + for description in valid_descriptions: + payload = CreateAppPayload( + name="Valid App Name", + mode="chat", + description=description, + ) + assert payload.description == description + + def test_create_app_none_description(self): + """Test CreateAppPayload with None description.""" + payload = CreateAppPayload( + name="Valid App Name", + mode="chat", + description=None, + ) + assert payload.description is None + + def test_update_app_xss_prevention(self): + """Test UpdateAppPayload also prevents XSS.""" + xss_names = [ + "", + "javascript:alert(1)", + "", + ] + + for name in xss_names: + with pytest.raises(ValueError) as exc_info: + UpdateAppPayload(name=name) + assert "invalid characters or patterns" in str(exc_info.value).lower() + + def test_update_app_valid_names(self): + """Test UpdateAppPayload with valid names.""" + payload = UpdateAppPayload(name="Valid Updated Name") + assert payload.name == "Valid Updated Name" + + def test_copy_app_xss_prevention(self): + """Test CopyAppPayload also prevents XSS.""" + xss_names = [ + "", + "javascript:alert(1)", + "", + ] + + for name in xss_names: + with pytest.raises(ValueError) as exc_info: + CopyAppPayload(name=name) + assert "invalid characters or patterns" in str(exc_info.value).lower() + + def test_copy_app_valid_names(self): + """Test CopyAppPayload with valid names.""" + payload = CopyAppPayload(name="Valid Copy Name") + assert payload.name == "Valid Copy Name" + + def test_copy_app_none_name(self): + """Test CopyAppPayload with None name (should be allowed).""" + payload = CopyAppPayload(name=None) + assert payload.name is None + + def test_edge_case_angle_brackets_content(self): + """Test that angle brackets with actual content are rejected.""" + # Angle brackets without valid HTML-like patterns should be checked + # The regex 
pattern <.*?on\w+\s*= should catch event handlers + # But let's verify other patterns too + + # Valid: angle brackets used as symbols (not matched by our patterns) + # Our patterns specifically look for dangerous constructs + + # Invalid: actual HTML tags with event handlers + invalid_names = [ + "
", + "", + ] + + for name in invalid_names: + with pytest.raises(ValueError) as exc_info: + CreateAppPayload(name=name, mode="chat") + assert "invalid characters or patterns" in str(exc_info.value).lower() diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth.py b/api/tests/unit_tests/controllers/console/auth/test_oauth.py index 399caf8c4d..3ddfcdb832 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/unit_tests/controllers/console/auth/test_oauth.py @@ -171,7 +171,7 @@ class TestOAuthCallback: ): mock_config.CONSOLE_WEB_URL = "http://localhost:3000" mock_get_providers.return_value = {"github": oauth_setup["provider"]} - mock_generate_account.return_value = oauth_setup["account"] + mock_generate_account.return_value = (oauth_setup["account"], True) mock_account_service.login.return_value = oauth_setup["token_pair"] with app.test_request_context("/auth/oauth/github/callback?code=test_code"): @@ -179,7 +179,7 @@ class TestOAuthCallback: oauth_setup["provider"].get_access_token.assert_called_once_with("test_code") oauth_setup["provider"].get_user_info.assert_called_once_with("access_token") - mock_redirect.assert_called_once_with("http://localhost:3000") + mock_redirect.assert_called_once_with("http://localhost:3000?oauth_new_user=true") @pytest.mark.parametrize( ("exception", "expected_error"), @@ -223,7 +223,7 @@ class TestOAuthCallback: # This documents actual behavior. See test_defensive_check_for_closed_account_status for details ( AccountStatus.CLOSED.value, - "http://localhost:3000", + "http://localhost:3000?oauth_new_user=false", ), ], ) @@ -260,7 +260,7 @@ class TestOAuthCallback: account = MagicMock() account.status = account_status account.id = "123" - mock_generate_account.return_value = account + mock_generate_account.return_value = (account, False) # Mock login for CLOSED status mock_token_pair = MagicMock() @@ -296,7 +296,7 @@ class TestOAuthCallback: mock_account = MagicMock() mock_account.status = AccountStatus.PENDING - mock_generate_account.return_value = mock_account + mock_generate_account.return_value = (mock_account, False) mock_token_pair = MagicMock() mock_token_pair.access_token = "jwt_access_token" @@ -360,7 +360,7 @@ class TestOAuthCallback: closed_account.status = AccountStatus.CLOSED closed_account.id = "123" closed_account.name = "Closed Account" - mock_generate_account.return_value = closed_account + mock_generate_account.return_value = (closed_account, False) # Mock successful login (current behavior) mock_token_pair = MagicMock() @@ -374,7 +374,7 @@ class TestOAuthCallback: resource.get("github") # Verify current behavior: login succeeds (this is NOT ideal) - mock_redirect.assert_called_once_with("http://localhost:3000") + mock_redirect.assert_called_once_with("http://localhost:3000?oauth_new_user=false") mock_account_service.login.assert_called_once() # Document expected behavior in comments: @@ -458,8 +458,9 @@ class TestAccountGeneration: with pytest.raises(AccountRegisterError): _generate_account("github", user_info) else: - result = _generate_account("github", user_info) + result, oauth_new_user = _generate_account("github", user_info) assert result == mock_account + assert oauth_new_user == should_create if should_create: mock_register_service.register.assert_called_once_with( @@ -490,9 +491,10 @@ class TestAccountGeneration: mock_tenant_service.create_tenant.return_value = mock_new_tenant with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}): - result = 
_generate_account("github", user_info) + result, oauth_new_user = _generate_account("github", user_info) assert result == mock_account + assert oauth_new_user is False mock_tenant_service.create_tenant.assert_called_once_with("Test User's Workspace") mock_tenant_service.create_tenant_member.assert_called_once_with( mock_new_tenant, mock_account, role="owner" diff --git a/api/tests/unit_tests/controllers/console/test_document_detail_api_data_source_info.py b/api/tests/unit_tests/controllers/console/test_document_detail_api_data_source_info.py new file mode 100644 index 0000000000..f8dd98fdb2 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_document_detail_api_data_source_info.py @@ -0,0 +1,145 @@ +""" +Test for document detail API data_source_info serialization fix. + +This test verifies that the document detail API returns both data_source_info +and data_source_detail_dict for all data_source_type values, including "local_file". +""" + +import json +from typing import Generic, Literal, NotRequired, TypedDict, TypeVar, Union + +from models.dataset import Document + + +class LocalFileInfo(TypedDict): + file_path: str + size: int + created_at: NotRequired[str] + + +class UploadFileInfo(TypedDict): + upload_file_id: str + + +class NotionImportInfo(TypedDict): + notion_page_id: str + workspace_id: str + + +class WebsiteCrawlInfo(TypedDict): + url: str + job_id: str + + +RawInfo = Union[LocalFileInfo, UploadFileInfo, NotionImportInfo, WebsiteCrawlInfo] +T_type = TypeVar("T_type", bound=str) +T_info = TypeVar("T_info", bound=Union[LocalFileInfo, UploadFileInfo, NotionImportInfo, WebsiteCrawlInfo]) + + +class Case(TypedDict, Generic[T_type, T_info]): + data_source_type: T_type + data_source_info: str + expected_raw: T_info + + +LocalFileCase = Case[Literal["local_file"], LocalFileInfo] +UploadFileCase = Case[Literal["upload_file"], UploadFileInfo] +NotionImportCase = Case[Literal["notion_import"], NotionImportInfo] +WebsiteCrawlCase = Case[Literal["website_crawl"], WebsiteCrawlInfo] + +AnyCase = Union[LocalFileCase, UploadFileCase, NotionImportCase, WebsiteCrawlCase] + + +case_1: LocalFileCase = { + "data_source_type": "local_file", + "data_source_info": json.dumps({"file_path": "/tmp/test.txt", "size": 1024}), + "expected_raw": {"file_path": "/tmp/test.txt", "size": 1024}, +} + + +# ERROR: Expected LocalFileInfo, but got WebsiteCrawlInfo +case_2: LocalFileCase = { + "data_source_type": "local_file", + "data_source_info": "...", + "expected_raw": {"file_path": "https://google.com", "size": 123}, +} + +cases: list[AnyCase] = [case_1] + + +class TestDocumentDetailDataSourceInfo: + """Test cases for document detail API data_source_info serialization.""" + + def test_data_source_info_dict_returns_raw_data(self): + """Test that data_source_info_dict returns raw JSON data for all data_source_type values.""" + # Test data for different data_source_type values + for case in cases: + document = Document( + data_source_type=case["data_source_type"], + data_source_info=case["data_source_info"], + ) + + # Test data_source_info_dict (raw data) + raw_result = document.data_source_info_dict + assert raw_result == case["expected_raw"], f"Failed for {case['data_source_type']}" + + # Verify raw_result is always a valid dict + assert isinstance(raw_result, dict) + + def test_local_file_data_source_info_without_db_context(self): + """Test that local_file type data_source_info_dict works without database context.""" + test_data: LocalFileInfo = { + "file_path": "/local/path/document.txt", + "size": 
512, + "created_at": "2024-01-01T00:00:00Z", + } + + document = Document( + data_source_type="local_file", + data_source_info=json.dumps(test_data), + ) + + # data_source_info_dict should return the raw data (this doesn't need DB context) + raw_data = document.data_source_info_dict + assert raw_data == test_data + assert isinstance(raw_data, dict) + + # Verify the data contains expected keys for pipeline mode + assert "file_path" in raw_data + assert "size" in raw_data + + def test_notion_and_website_crawl_data_source_detail(self): + """Test that notion_import and website_crawl return raw data in data_source_detail_dict.""" + # Test notion_import + notion_data: NotionImportInfo = {"notion_page_id": "page-123", "workspace_id": "ws-456"} + document = Document( + data_source_type="notion_import", + data_source_info=json.dumps(notion_data), + ) + + # data_source_detail_dict should return raw data for notion_import + detail_result = document.data_source_detail_dict + assert detail_result == notion_data + + # Test website_crawl + website_data: WebsiteCrawlInfo = {"url": "https://example.com", "job_id": "job-789"} + document = Document( + data_source_type="website_crawl", + data_source_info=json.dumps(website_data), + ) + + # data_source_detail_dict should return raw data for website_crawl + detail_result = document.data_source_detail_dict + assert detail_result == website_data + + def test_local_file_data_source_detail_dict_without_db(self): + """Test that local_file returns empty data_source_detail_dict (this doesn't need DB context).""" + # Test local_file - this should work without database context since it returns {} early + document = Document( + data_source_type="local_file", + data_source_info=json.dumps({"file_path": "/tmp/test.txt"}), + ) + + # Should return empty dict for local_file type (handled in the model) + detail_result = document.data_source_detail_dict + assert detail_result == {} diff --git a/api/tests/unit_tests/controllers/console/test_files_security.py b/api/tests/unit_tests/controllers/console/test_files_security.py index 2630fbcfd0..370bf63fdb 100644 --- a/api/tests/unit_tests/controllers/console/test_files_security.py +++ b/api/tests/unit_tests/controllers/console/test_files_security.py @@ -1,7 +1,9 @@ +import builtins import io from unittest.mock import patch import pytest +from flask.views import MethodView from werkzeug.exceptions import Forbidden from controllers.common.errors import ( @@ -14,6 +16,9 @@ from controllers.common.errors import ( from services.errors.file import FileTooLargeError as ServiceFileTooLargeError from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError +if not hasattr(builtins, "MethodView"): + builtins.MethodView = MethodView # type: ignore[attr-defined] + class TestFileUploadSecurity: """Test file upload security logic without complex framework setup""" @@ -128,7 +133,7 @@ class TestFileUploadSecurity: # Test passes if no exception is raised # Test 4: Service error handling - @patch("services.file_service.FileService.upload_file") + @patch("controllers.console.files.FileService.upload_file") def test_should_handle_file_too_large_error(self, mock_upload): """Test that service FileTooLargeError is properly converted""" mock_upload.side_effect = ServiceFileTooLargeError("File too large") @@ -140,7 +145,7 @@ class TestFileUploadSecurity: with pytest.raises(FileTooLargeError): raise FileTooLargeError(e.description) - @patch("services.file_service.FileService.upload_file") + 
@patch("controllers.console.files.FileService.upload_file") def test_should_handle_unsupported_file_type_error(self, mock_upload): """Test that service UnsupportedFileTypeError is properly converted""" mock_upload.side_effect = ServiceUnsupportedFileTypeError() diff --git a/api/tests/unit_tests/controllers/web/test_message_list.py b/api/tests/unit_tests/controllers/web/test_message_list.py new file mode 100644 index 0000000000..2835f7ffbf --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_message_list.py @@ -0,0 +1,174 @@ +"""Unit tests for controllers.web.message message list mapping.""" + +from __future__ import annotations + +import builtins +from datetime import datetime +from types import ModuleType, SimpleNamespace +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from flask import Flask +from flask.views import MethodView + +# Ensure flask_restx.api finds MethodView during import. +if not hasattr(builtins, "MethodView"): + builtins.MethodView = MethodView # type: ignore[attr-defined] + + +def _load_controller_module(): + """Import controllers.web.message using a stub package.""" + + import importlib + import importlib.util + import sys + + parent_module_name = "controllers.web" + module_name = f"{parent_module_name}.message" + + if parent_module_name not in sys.modules: + from flask_restx import Namespace + + stub = ModuleType(parent_module_name) + stub.__file__ = "controllers/web/__init__.py" + stub.__path__ = ["controllers/web"] + stub.__package__ = "controllers" + stub.__spec__ = importlib.util.spec_from_loader(parent_module_name, loader=None, is_package=True) + stub.web_ns = Namespace("web", description="Web API", path="/") + sys.modules[parent_module_name] = stub + + wraps_module_name = f"{parent_module_name}.wraps" + if wraps_module_name not in sys.modules: + wraps_stub = ModuleType(wraps_module_name) + + class WebApiResource: + pass + + wraps_stub.WebApiResource = WebApiResource + sys.modules[wraps_module_name] = wraps_stub + + return importlib.import_module(module_name) + + +message_module = _load_controller_module() +MessageListApi = message_module.MessageListApi + + +@pytest.fixture +def app() -> Flask: + app = Flask(__name__) + app.config["TESTING"] = True + return app + + +def test_message_list_mapping(app: Flask) -> None: + conversation_id = str(uuid4()) + message_id = str(uuid4()) + + created_at = datetime(2024, 1, 1, 12, 0, 0) + resource_created_at = datetime(2024, 1, 1, 13, 0, 0) + thought_created_at = datetime(2024, 1, 1, 14, 0, 0) + + retriever_resource_obj = SimpleNamespace( + id="res-obj", + message_id=message_id, + position=2, + dataset_id="ds-1", + dataset_name="dataset", + document_id="doc-1", + document_name="document", + data_source_type="file", + segment_id="seg-1", + score=0.9, + hit_count=1, + word_count=10, + segment_position=0, + index_node_hash="hash", + content="content", + created_at=resource_created_at, + ) + + agent_thought = SimpleNamespace( + id="thought-1", + chain_id=None, + message_chain_id="chain-1", + message_id=message_id, + position=1, + thought="thinking", + tool="tool", + tool_labels={"label": "value"}, + tool_input="{}", + created_at=thought_created_at, + observation="observed", + files=["file-a"], + ) + + message_file_obj = SimpleNamespace( + id="file-obj", + filename="b.txt", + type="file", + url=None, + mime_type=None, + size=None, + transfer_method="local", + belongs_to=None, + upload_file_id=None, + ) + + message = SimpleNamespace( + id=message_id, + conversation_id=conversation_id, + 
parent_message_id=None, + inputs={"foo": "bar"}, + query="hello", + re_sign_file_url_answer="answer", + user_feedback=SimpleNamespace(rating="like"), + retriever_resources=[ + {"id": "res-dict", "message_id": message_id, "position": 1}, + retriever_resource_obj, + ], + created_at=created_at, + agent_thoughts=[agent_thought], + message_files=[ + {"id": "file-dict", "filename": "a.txt", "type": "file", "transfer_method": "local"}, + message_file_obj, + ], + status="success", + error=None, + message_metadata_dict={"meta": "value"}, + ) + + pagination = SimpleNamespace(limit=20, has_more=False, data=[message]) + app_model = SimpleNamespace(mode="chat") + end_user = SimpleNamespace() + + with ( + patch.object(message_module.MessageService, "pagination_by_first_id", return_value=pagination) as mock_page, + app.test_request_context(f"/messages?conversation_id={conversation_id}&limit=20"), + ): + response = MessageListApi().get(app_model, end_user) + + mock_page.assert_called_once_with(app_model, end_user, conversation_id, None, 20) + assert response["limit"] == 20 + assert response["has_more"] is False + assert len(response["data"]) == 1 + + item = response["data"][0] + assert item["id"] == message_id + assert item["conversation_id"] == conversation_id + assert item["inputs"] == {"foo": "bar"} + assert item["answer"] == "answer" + assert item["feedback"]["rating"] == "like" + assert item["metadata"] == {"meta": "value"} + assert item["created_at"] == int(created_at.timestamp()) + + assert item["retriever_resources"][0]["id"] == "res-dict" + assert item["retriever_resources"][1]["id"] == "res-obj" + assert item["retriever_resources"][1]["created_at"] == int(resource_created_at.timestamp()) + + assert item["agent_thoughts"][0]["chain_id"] == "chain-1" + assert item["agent_thoughts"][0]["created_at"] == int(thought_created_at.timestamp()) + + assert item["message_files"][0]["id"] == "file-dict" + assert item["message_files"][1]["id"] == "file-obj" diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_answer_node.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_answer_node.py new file mode 100644 index 0000000000..205b157542 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_answer_node.py @@ -0,0 +1,390 @@ +""" +Tests for AdvancedChatAppGenerateTaskPipeline._handle_node_succeeded_event method, +specifically testing the ANSWER node message_replace logic. 
+""" + +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity +from core.app.entities.queue_entities import QueueNodeSucceededEvent +from core.workflow.enums import NodeType +from models import EndUser +from models.model import AppMode + + +class TestAnswerNodeMessageReplace: + """Test cases for ANSWER node message_replace event logic.""" + + @pytest.fixture + def mock_application_generate_entity(self): + """Create a mock application generate entity.""" + entity = Mock(spec=AdvancedChatAppGenerateEntity) + entity.task_id = "test-task-id" + entity.app_id = "test-app-id" + entity.workflow_run_id = "test-workflow-run-id" + # minimal app_config used by pipeline internals + entity.app_config = SimpleNamespace( + tenant_id="test-tenant-id", + app_id="test-app-id", + app_mode=AppMode.ADVANCED_CHAT, + app_model_config_dict={}, + additional_features=None, + sensitive_word_avoidance=None, + ) + entity.query = "test query" + entity.files = [] + entity.extras = {} + entity.trace_manager = None + entity.inputs = {} + entity.invoke_from = "debugger" + return entity + + @pytest.fixture + def mock_workflow(self): + """Create a mock workflow.""" + workflow = Mock() + workflow.id = "test-workflow-id" + workflow.features_dict = {} + return workflow + + @pytest.fixture + def mock_queue_manager(self): + """Create a mock queue manager.""" + manager = Mock() + manager.listen.return_value = [] + manager.graph_runtime_state = None + return manager + + @pytest.fixture + def mock_conversation(self): + """Create a mock conversation.""" + conversation = Mock() + conversation.id = "test-conversation-id" + conversation.mode = "advanced_chat" + return conversation + + @pytest.fixture + def mock_message(self): + """Create a mock message.""" + message = Mock() + message.id = "test-message-id" + message.query = "test query" + message.created_at = Mock() + message.created_at.timestamp.return_value = 1234567890 + return message + + @pytest.fixture + def mock_user(self): + """Create a mock end user.""" + user = MagicMock(spec=EndUser) + user.id = "test-user-id" + user.session_id = "test-session-id" + return user + + @pytest.fixture + def mock_draft_var_saver_factory(self): + """Create a mock draft variable saver factory.""" + return Mock() + + @pytest.fixture + def pipeline( + self, + mock_application_generate_entity, + mock_workflow, + mock_queue_manager, + mock_conversation, + mock_message, + mock_user, + mock_draft_var_saver_factory, + ): + """Create an AdvancedChatAppGenerateTaskPipeline instance with mocked dependencies.""" + from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline + + with patch("core.app.apps.advanced_chat.generate_task_pipeline.db"): + pipeline = AdvancedChatAppGenerateTaskPipeline( + application_generate_entity=mock_application_generate_entity, + workflow=mock_workflow, + queue_manager=mock_queue_manager, + conversation=mock_conversation, + message=mock_message, + user=mock_user, + stream=True, + dialogue_count=1, + draft_var_saver_factory=mock_draft_var_saver_factory, + ) + # Initialize workflow run id to avoid validation errors + pipeline._workflow_run_id = "test-workflow-run-id" + # Mock the message cycle manager methods we need to track + pipeline._message_cycle_manager.message_replace_to_stream_response = Mock() + return pipeline + + def 
test_answer_node_with_different_output_sends_message_replace(self, pipeline, mock_application_generate_entity): + """ + Test that when an ANSWER node's final output differs from accumulated answer, + a message_replace event is sent. + """ + # Arrange: Set initial accumulated answer + pipeline._task_state.answer = "initial answer" + + # Create ANSWER node succeeded event with different final output + event = QueueNodeSucceededEvent( + node_execution_id="test-node-execution-id", + node_id="test-answer-node", + node_type=NodeType.ANSWER, + start_at=datetime.now(), + outputs={"answer": "updated final answer"}, + ) + + # Mock the workflow response converter to avoid extra processing + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None) + pipeline._save_output_for_event = Mock() + + # Act + responses = list(pipeline._handle_node_succeeded_event(event)) + + # Assert + assert pipeline._task_state.answer == "updated final answer" + # Verify message_replace was called + pipeline._message_cycle_manager.message_replace_to_stream_response.assert_called_once_with( + answer="updated final answer", reason="variable_update" + ) + + def test_answer_node_with_same_output_does_not_send_message_replace(self, pipeline): + """ + Test that when an ANSWER node's final output is the same as accumulated answer, + no message_replace event is sent. + """ + # Arrange: Set initial accumulated answer + pipeline._task_state.answer = "same answer" + + # Create ANSWER node succeeded event with same output + event = QueueNodeSucceededEvent( + node_execution_id="test-node-execution-id", + node_id="test-answer-node", + node_type=NodeType.ANSWER, + start_at=datetime.now(), + outputs={"answer": "same answer"}, + ) + + # Mock the workflow response converter + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None) + pipeline._save_output_for_event = Mock() + + # Act + list(pipeline._handle_node_succeeded_event(event)) + + # Assert: answer should remain unchanged + assert pipeline._task_state.answer == "same answer" + # Verify message_replace was NOT called + pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called() + + def test_answer_node_with_none_output_does_not_send_message_replace(self, pipeline): + """ + Test that when an ANSWER node's output is None or missing 'answer' key, + no message_replace event is sent. + """ + # Arrange: Set initial accumulated answer + pipeline._task_state.answer = "existing answer" + + # Create ANSWER node succeeded event with None output + event = QueueNodeSucceededEvent( + node_execution_id="test-node-execution-id", + node_id="test-answer-node", + node_type=NodeType.ANSWER, + start_at=datetime.now(), + outputs={"answer": None}, + ) + + # Mock the workflow response converter + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None) + pipeline._save_output_for_event = Mock() + + # Act + list(pipeline._handle_node_succeeded_event(event)) + + # Assert: answer should remain unchanged + assert pipeline._task_state.answer == "existing answer" + # Verify message_replace was NOT called + pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called() + + def test_answer_node_with_empty_outputs_does_not_send_message_replace(self, pipeline): + """ + Test that when an ANSWER node has empty outputs dict, + no message_replace event is sent. 
+ """ + # Arrange: Set initial accumulated answer + pipeline._task_state.answer = "existing answer" + + # Create ANSWER node succeeded event with empty outputs + event = QueueNodeSucceededEvent( + node_execution_id="test-node-execution-id", + node_id="test-answer-node", + node_type=NodeType.ANSWER, + start_at=datetime.now(), + outputs={}, + ) + + # Mock the workflow response converter + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None) + pipeline._save_output_for_event = Mock() + + # Act + list(pipeline._handle_node_succeeded_event(event)) + + # Assert: answer should remain unchanged + assert pipeline._task_state.answer == "existing answer" + # Verify message_replace was NOT called + pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called() + + def test_answer_node_with_no_answer_key_in_outputs(self, pipeline): + """ + Test that when an ANSWER node's outputs don't contain 'answer' key, + no message_replace event is sent. + """ + # Arrange: Set initial accumulated answer + pipeline._task_state.answer = "existing answer" + + # Create ANSWER node succeeded event without 'answer' key in outputs + event = QueueNodeSucceededEvent( + node_execution_id="test-node-execution-id", + node_id="test-answer-node", + node_type=NodeType.ANSWER, + start_at=datetime.now(), + outputs={"other_key": "some value"}, + ) + + # Mock the workflow response converter + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None) + pipeline._save_output_for_event = Mock() + + # Act + list(pipeline._handle_node_succeeded_event(event)) + + # Assert: answer should remain unchanged + assert pipeline._task_state.answer == "existing answer" + # Verify message_replace was NOT called + pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called() + + def test_non_answer_node_does_not_send_message_replace(self, pipeline): + """ + Test that non-ANSWER nodes (e.g., LLM, END) don't trigger message_replace events. + """ + # Arrange: Set initial accumulated answer + pipeline._task_state.answer = "existing answer" + + # Test with LLM node + llm_event = QueueNodeSucceededEvent( + node_execution_id="test-llm-execution-id", + node_id="test-llm-node", + node_type=NodeType.LLM, + start_at=datetime.now(), + outputs={"answer": "different answer"}, + ) + + # Mock the workflow response converter + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None) + pipeline._save_output_for_event = Mock() + + # Act + list(pipeline._handle_node_succeeded_event(llm_event)) + + # Assert: answer should remain unchanged + assert pipeline._task_state.answer == "existing answer" + # Verify message_replace was NOT called + pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called() + + def test_end_node_does_not_send_message_replace(self, pipeline): + """ + Test that END nodes don't trigger message_replace events even with 'answer' output. 
+ """ + # Arrange: Set initial accumulated answer + pipeline._task_state.answer = "existing answer" + + # Create END node succeeded event with answer output + event = QueueNodeSucceededEvent( + node_execution_id="test-end-execution-id", + node_id="test-end-node", + node_type=NodeType.END, + start_at=datetime.now(), + outputs={"answer": "different answer"}, + ) + + # Mock the workflow response converter + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None) + pipeline._save_output_for_event = Mock() + + # Act + list(pipeline._handle_node_succeeded_event(event)) + + # Assert: answer should remain unchanged + assert pipeline._task_state.answer == "existing answer" + # Verify message_replace was NOT called + pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called() + + def test_answer_node_with_numeric_output_converts_to_string(self, pipeline): + """ + Test that when an ANSWER node's final output is numeric, + it gets converted to string properly. + """ + # Arrange: Set initial accumulated answer + pipeline._task_state.answer = "text answer" + + # Create ANSWER node succeeded event with numeric output + event = QueueNodeSucceededEvent( + node_execution_id="test-node-execution-id", + node_id="test-answer-node", + node_type=NodeType.ANSWER, + start_at=datetime.now(), + outputs={"answer": 12345}, + ) + + # Mock the workflow response converter + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None) + pipeline._save_output_for_event = Mock() + + # Act + list(pipeline._handle_node_succeeded_event(event)) + + # Assert: answer should be converted to string + assert pipeline._task_state.answer == "12345" + # Verify message_replace was called with string + pipeline._message_cycle_manager.message_replace_to_stream_response.assert_called_once_with( + answer="12345", reason="variable_update" + ) + + def test_answer_node_files_are_recorded(self, pipeline): + """ + Test that ANSWER nodes properly record files from outputs. 
+ """ + # Arrange + pipeline._task_state.answer = "existing answer" + + # Create ANSWER node succeeded event with files + event = QueueNodeSucceededEvent( + node_execution_id="test-node-execution-id", + node_id="test-answer-node", + node_type=NodeType.ANSWER, + start_at=datetime.now(), + outputs={ + "answer": "same answer", + "files": [ + {"type": "image", "transfer_method": "remote_url", "remote_url": "http://example.com/img.png"} + ], + }, + ) + + # Mock the workflow response converter + pipeline._workflow_response_converter.fetch_files_from_node_outputs = Mock(return_value=event.outputs["files"]) + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None) + pipeline._save_output_for_event = Mock() + + # Act + list(pipeline._handle_node_succeeded_event(event)) + + # Assert: files should be recorded + assert len(pipeline._recorded_files) == 1 + assert pipeline._recorded_files[0] == event.outputs["files"][0] diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py index 1c9f577a50..6b40bf462b 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py @@ -431,10 +431,10 @@ class TestWorkflowResponseConverterServiceApiTruncation: description="Explore calls should have truncation enabled", ), TestCase( - name="published_truncation_enabled", - invoke_from=InvokeFrom.PUBLISHED, + name="published_pipeline_truncation_enabled", + invoke_from=InvokeFrom.PUBLISHED_PIPELINE, expected_truncation_enabled=True, - description="Published app calls should have truncation enabled", + description="Published pipeline calls should have truncation enabled", ), ], ids=lambda x: x.name, diff --git a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py new file mode 100644 index 0000000000..b6e8cc9c8e --- /dev/null +++ b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py @@ -0,0 +1,144 @@ +from collections.abc import Sequence +from datetime import datetime +from unittest.mock import Mock + +from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer +from core.variables import StringVariable +from core.variables.segments import Segment +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID +from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus +from core.workflow.graph_engine.protocols.command_channel import CommandChannel +from core.workflow.graph_events.node import NodeRunSucceededEvent +from core.workflow.node_events import NodeRunResult +from core.workflow.nodes.variable_assigner.common import helpers as common_helpers +from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState +from core.workflow.system_variable import SystemVariable + + +class MockReadOnlyVariablePool: + def __init__(self, variables: dict[tuple[str, str], Segment] | None = None) -> None: + self._variables = variables or {} + + def get(self, selector: Sequence[str]) -> Segment | None: + if len(selector) < 2: + return None + return self._variables.get((selector[0], selector[1])) + + def get_all_by_node(self, node_id: str) -> dict[str, object]: + return {key: value for (nid, key), value in 
self._variables.items() if nid == node_id} + + def get_by_prefix(self, prefix: str) -> dict[str, object]: + return {key: value for (nid, key), value in self._variables.items() if nid == prefix} + + +def _build_graph_runtime_state( + variable_pool: MockReadOnlyVariablePool, + conversation_id: str | None = None, +) -> ReadOnlyGraphRuntimeState: + graph_runtime_state = Mock(spec=ReadOnlyGraphRuntimeState) + graph_runtime_state.variable_pool = variable_pool + graph_runtime_state.system_variable = SystemVariable(conversation_id=conversation_id).as_view() + return graph_runtime_state + + +def _build_node_run_succeeded_event( + *, + node_type: NodeType, + outputs: dict[str, object] | None = None, + process_data: dict[str, object] | None = None, +) -> NodeRunSucceededEvent: + return NodeRunSucceededEvent( + id="node-exec-id", + node_id="assigner", + node_type=node_type, + start_at=datetime.utcnow(), + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs=outputs or {}, + process_data=process_data or {}, + ), + ) + + +def test_persists_conversation_variables_from_assigner_output(): + conversation_id = "conv-123" + variable = StringVariable( + id="var-1", + name="name", + value="updated", + selector=[CONVERSATION_VARIABLE_NODE_ID, "name"], + ) + process_data = common_helpers.set_updated_variables( + {}, [common_helpers.variable_to_processed_data(variable.selector, variable)] + ) + + variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable}) + + updater = Mock() + layer = ConversationVariablePersistenceLayer(updater) + layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) + + event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER, process_data=process_data) + layer.on_event(event) + + updater.update.assert_called_once_with(conversation_id=conversation_id, variable=variable) + updater.flush.assert_called_once() + + +def test_skips_when_outputs_missing(): + conversation_id = "conv-456" + variable = StringVariable( + id="var-2", + name="name", + value="updated", + selector=[CONVERSATION_VARIABLE_NODE_ID, "name"], + ) + + variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable}) + + updater = Mock() + layer = ConversationVariablePersistenceLayer(updater) + layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) + + event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER) + layer.on_event(event) + + updater.update.assert_not_called() + updater.flush.assert_not_called() + + +def test_skips_non_assigner_nodes(): + updater = Mock() + layer = ConversationVariablePersistenceLayer(updater) + layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool()), Mock(spec=CommandChannel)) + + event = _build_node_run_succeeded_event(node_type=NodeType.LLM) + layer.on_event(event) + + updater.update.assert_not_called() + updater.flush.assert_not_called() + + +def test_skips_non_conversation_variables(): + conversation_id = "conv-789" + non_conversation_variable = StringVariable( + id="var-3", + name="name", + value="updated", + selector=["environment", "name"], + ) + process_data = common_helpers.set_updated_variables( + {}, [common_helpers.variable_to_processed_data(non_conversation_variable.selector, non_conversation_variable)] + ) + + variable_pool = MockReadOnlyVariablePool() + + updater = Mock() + layer = ConversationVariablePersistenceLayer(updater) + 
layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) + + event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER, process_data=process_data) + layer.on_event(event) + + updater.update.assert_not_called() + updater.flush.assert_not_called() diff --git a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py index 534420f21e..1d885f6b2e 100644 --- a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py @@ -1,4 +1,5 @@ import json +from collections.abc import Sequence from time import time from unittest.mock import Mock @@ -15,6 +16,7 @@ from core.app.layers.pause_state_persist_layer import ( from core.variables.segments import Segment from core.workflow.entities.pause_reason import SchedulingPause from core.workflow.graph_engine.entities.commands import GraphEngineCommand +from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError from core.workflow.graph_events.graph import ( GraphRunFailedEvent, GraphRunPausedEvent, @@ -66,8 +68,10 @@ class MockReadOnlyVariablePool: def __init__(self, variables: dict[tuple[str, str], object] | None = None): self._variables = variables or {} - def get(self, node_id: str, variable_key: str) -> Segment | None: - value = self._variables.get((node_id, variable_key)) + def get(self, selector: Sequence[str]) -> Segment | None: + if len(selector) < 2: + return None + value = self._variables.get((selector[0], selector[1])) if value is None: return None mock_segment = Mock(spec=Segment) @@ -209,8 +213,9 @@ class TestPauseStatePersistenceLayer: assert layer._session_maker is session_factory assert layer._state_owner_user_id == state_owner_user_id - assert not hasattr(layer, "graph_runtime_state") - assert not hasattr(layer, "command_channel") + with pytest.raises(GraphEngineLayerNotInitializedError): + _ = layer.graph_runtime_state + assert layer.command_channel is None def test_initialize_sets_dependencies(self): session_factory = Mock(name="session_factory") @@ -295,7 +300,7 @@ class TestPauseStatePersistenceLayer: mock_factory.assert_not_called() mock_repo.create_workflow_pause.assert_not_called() - def test_on_event_raises_attribute_error_when_graph_runtime_state_is_none(self): + def test_on_event_raises_when_graph_runtime_state_is_uninitialized(self): session_factory = Mock(name="session_factory") layer = PauseStatePersistenceLayer( session_factory=session_factory, @@ -305,7 +310,7 @@ class TestPauseStatePersistenceLayer: event = TestDataFactory.create_graph_run_paused_event() - with pytest.raises(AttributeError): + with pytest.raises(GraphEngineLayerNotInitializedError): layer.on_event(event) def test_on_event_asserts_when_workflow_execution_id_missing(self, monkeypatch: pytest.MonkeyPatch): diff --git a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py index beae1d0358..d6d75fb72f 100644 --- a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py +++ b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py @@ -14,12 +14,12 @@ def test_successful_request(mock_get_client): mock_client = MagicMock() mock_response = MagicMock() mock_response.status_code = 200 - mock_client.send.return_value = mock_response mock_client.request.return_value = mock_response mock_get_client.return_value = mock_client response = make_request("GET", 
"http://example.com") assert response.status_code == 200 + mock_client.request.assert_called_once() @patch("core.helper.ssrf_proxy._get_ssrf_client") @@ -27,7 +27,6 @@ def test_retry_exceed_max_retries(mock_get_client): mock_client = MagicMock() mock_response = MagicMock() mock_response.status_code = 500 - mock_client.send.return_value = mock_response mock_client.request.return_value = mock_response mock_get_client.return_value = mock_client @@ -72,34 +71,12 @@ class TestGetUserProvidedHostHeader: assert result in ("first.com", "second.com") -@patch("core.helper.ssrf_proxy._get_ssrf_client") -def test_host_header_preservation_without_user_header(mock_get_client): - """Test that when no Host header is provided, the default behavior is maintained.""" - mock_client = MagicMock() - mock_request = MagicMock() - mock_request.headers = {} - mock_response = MagicMock() - mock_response.status_code = 200 - mock_client.send.return_value = mock_response - mock_client.request.return_value = mock_response - mock_get_client.return_value = mock_client - - response = make_request("GET", "http://example.com") - - assert response.status_code == 200 - # Host should not be set if not provided by user - assert "Host" not in mock_request.headers or mock_request.headers.get("Host") is None - - @patch("core.helper.ssrf_proxy._get_ssrf_client") def test_host_header_preservation_with_user_header(mock_get_client): """Test that user-provided Host header is preserved in the request.""" mock_client = MagicMock() - mock_request = MagicMock() - mock_request.headers = {} mock_response = MagicMock() mock_response.status_code = 200 - mock_client.send.return_value = mock_response mock_client.request.return_value = mock_response mock_get_client.return_value = mock_client @@ -107,3 +84,93 @@ def test_host_header_preservation_with_user_header(mock_get_client): response = make_request("GET", "http://example.com", headers={"Host": custom_host}) assert response.status_code == 200 + # Verify client.request was called with the host header preserved (lowercase) + call_kwargs = mock_client.request.call_args.kwargs + assert call_kwargs["headers"]["host"] == custom_host + + +@patch("core.helper.ssrf_proxy._get_ssrf_client") +@pytest.mark.parametrize("host_key", ["host", "HOST", "Host"]) +def test_host_header_preservation_case_insensitive(mock_get_client, host_key): + """Test that Host header is preserved regardless of case.""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_client.request.return_value = mock_response + mock_get_client.return_value = mock_client + + response = make_request("GET", "http://example.com", headers={host_key: "api.example.com"}) + + assert response.status_code == 200 + # Host header should be normalized to lowercase "host" + call_kwargs = mock_client.request.call_args.kwargs + assert call_kwargs["headers"]["host"] == "api.example.com" + + +class TestFollowRedirectsParameter: + """Tests for follow_redirects parameter handling. + + These tests verify that follow_redirects is correctly passed to client.request(). 
+ """ + + @patch("core.helper.ssrf_proxy._get_ssrf_client") + def test_follow_redirects_passed_to_request(self, mock_get_client): + """Verify follow_redirects IS passed to client.request().""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_client.request.return_value = mock_response + mock_get_client.return_value = mock_client + + make_request("GET", "http://example.com", follow_redirects=True) + + # Verify follow_redirects was passed to request + call_kwargs = mock_client.request.call_args.kwargs + assert call_kwargs.get("follow_redirects") is True + + @patch("core.helper.ssrf_proxy._get_ssrf_client") + def test_allow_redirects_converted_to_follow_redirects(self, mock_get_client): + """Verify allow_redirects (requests-style) is converted to follow_redirects (httpx-style).""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_client.request.return_value = mock_response + mock_get_client.return_value = mock_client + + # Use allow_redirects (requests-style parameter) + make_request("GET", "http://example.com", allow_redirects=True) + + # Verify it was converted to follow_redirects + call_kwargs = mock_client.request.call_args.kwargs + assert call_kwargs.get("follow_redirects") is True + assert "allow_redirects" not in call_kwargs + + @patch("core.helper.ssrf_proxy._get_ssrf_client") + def test_follow_redirects_not_set_when_not_specified(self, mock_get_client): + """Verify follow_redirects is not in kwargs when not specified (httpx default behavior).""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_client.request.return_value = mock_response + mock_get_client.return_value = mock_client + + make_request("GET", "http://example.com") + + # follow_redirects should not be in kwargs, letting httpx use its default + call_kwargs = mock_client.request.call_args.kwargs + assert "follow_redirects" not in call_kwargs + + @patch("core.helper.ssrf_proxy._get_ssrf_client") + def test_follow_redirects_takes_precedence_over_allow_redirects(self, mock_get_client): + """Verify follow_redirects takes precedence when both are specified.""" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_client.request.return_value = mock_response + mock_get_client.return_value = mock_client + + # Both specified - follow_redirects should take precedence + make_request("GET", "http://example.com", allow_redirects=False, follow_redirects=True) + + call_kwargs = mock_client.request.call_args.kwargs + assert call_kwargs.get("follow_redirects") is True diff --git a/api/tests/unit_tests/core/logging/__init__.py b/api/tests/unit_tests/core/logging/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/logging/test_context.py b/api/tests/unit_tests/core/logging/test_context.py new file mode 100644 index 0000000000..f388a3a0b9 --- /dev/null +++ b/api/tests/unit_tests/core/logging/test_context.py @@ -0,0 +1,79 @@ +"""Tests for logging context module.""" + +import uuid + +from core.logging.context import ( + clear_request_context, + get_request_id, + get_trace_id, + init_request_context, +) + + +class TestLoggingContext: + """Tests for the logging context functions.""" + + def test_init_creates_request_id(self): + """init_request_context should create a 10-char request ID.""" + init_request_context() + request_id = get_request_id() + assert len(request_id) == 10 + assert all(c in 
"0123456789abcdef" for c in request_id) + + def test_init_creates_trace_id(self): + """init_request_context should create a 32-char trace ID.""" + init_request_context() + trace_id = get_trace_id() + assert len(trace_id) == 32 + assert all(c in "0123456789abcdef" for c in trace_id) + + def test_trace_id_derived_from_request_id(self): + """trace_id should be deterministically derived from request_id.""" + init_request_context() + request_id = get_request_id() + trace_id = get_trace_id() + + # Verify trace_id is derived using uuid5 + expected_trace = uuid.uuid5(uuid.NAMESPACE_DNS, request_id).hex + assert trace_id == expected_trace + + def test_clear_resets_context(self): + """clear_request_context should reset both IDs to empty strings.""" + init_request_context() + assert get_request_id() != "" + assert get_trace_id() != "" + + clear_request_context() + assert get_request_id() == "" + assert get_trace_id() == "" + + def test_default_values_are_empty(self): + """Default values should be empty strings before init.""" + clear_request_context() + assert get_request_id() == "" + assert get_trace_id() == "" + + def test_multiple_inits_create_different_ids(self): + """Each init should create new unique IDs.""" + init_request_context() + first_request_id = get_request_id() + first_trace_id = get_trace_id() + + init_request_context() + second_request_id = get_request_id() + second_trace_id = get_trace_id() + + assert first_request_id != second_request_id + assert first_trace_id != second_trace_id + + def test_context_isolation(self): + """Context should be isolated per-call (no thread leakage in same thread).""" + init_request_context() + id1 = get_request_id() + + # Simulate another request + init_request_context() + id2 = get_request_id() + + # IDs should be different + assert id1 != id2 diff --git a/api/tests/unit_tests/core/logging/test_filters.py b/api/tests/unit_tests/core/logging/test_filters.py new file mode 100644 index 0000000000..b66ad111d5 --- /dev/null +++ b/api/tests/unit_tests/core/logging/test_filters.py @@ -0,0 +1,114 @@ +"""Tests for logging filters.""" + +import logging +from unittest import mock + +import pytest + + +@pytest.fixture +def log_record(): + return logging.LogRecord( + name="test", + level=logging.INFO, + pathname="", + lineno=0, + msg="test", + args=(), + exc_info=None, + ) + + +class TestTraceContextFilter: + def test_sets_empty_trace_id_without_context(self, log_record): + from core.logging.context import clear_request_context + from core.logging.filters import TraceContextFilter + + # Ensure no context is set + clear_request_context() + + filter = TraceContextFilter() + result = filter.filter(log_record) + + assert result is True + assert hasattr(log_record, "trace_id") + assert hasattr(log_record, "span_id") + assert hasattr(log_record, "req_id") + # Without context, IDs should be empty + assert log_record.trace_id == "" + assert log_record.req_id == "" + + def test_sets_trace_id_from_context(self, log_record): + """Test that trace_id and req_id are set from ContextVar when initialized.""" + from core.logging.context import init_request_context + from core.logging.filters import TraceContextFilter + + # Initialize context (no Flask needed!) 
+ init_request_context() + + filter = TraceContextFilter() + filter.filter(log_record) + + # With context initialized, IDs should be set + assert log_record.trace_id != "" + assert len(log_record.trace_id) == 32 + assert log_record.req_id != "" + assert len(log_record.req_id) == 10 + + def test_filter_always_returns_true(self, log_record): + from core.logging.filters import TraceContextFilter + + filter = TraceContextFilter() + result = filter.filter(log_record) + assert result is True + + def test_sets_trace_id_from_otel_when_available(self, log_record): + from core.logging.filters import TraceContextFilter + + mock_span = mock.MagicMock() + mock_context = mock.MagicMock() + mock_context.trace_id = 0x5B8AA5A2D2C872E8321CF37308D69DF2 + mock_context.span_id = 0x051581BF3BB55C45 + mock_span.get_span_context.return_value = mock_context + + with ( + mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span), + mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0), + mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0), + ): + filter = TraceContextFilter() + filter.filter(log_record) + + assert log_record.trace_id == "5b8aa5a2d2c872e8321cf37308d69df2" + assert log_record.span_id == "051581bf3bb55c45" + + +class TestIdentityContextFilter: + def test_sets_empty_identity_without_request_context(self, log_record): + from core.logging.filters import IdentityContextFilter + + filter = IdentityContextFilter() + result = filter.filter(log_record) + + assert result is True + assert log_record.tenant_id == "" + assert log_record.user_id == "" + assert log_record.user_type == "" + + def test_filter_always_returns_true(self, log_record): + from core.logging.filters import IdentityContextFilter + + filter = IdentityContextFilter() + result = filter.filter(log_record) + assert result is True + + def test_handles_exception_gracefully(self, log_record): + from core.logging.filters import IdentityContextFilter + + filter = IdentityContextFilter() + + # Should not raise even if something goes wrong + with mock.patch("core.logging.filters.flask.has_request_context", side_effect=Exception("Test error")): + result = filter.filter(log_record) + assert result is True + assert log_record.tenant_id == "" diff --git a/api/tests/unit_tests/core/logging/test_structured_formatter.py b/api/tests/unit_tests/core/logging/test_structured_formatter.py new file mode 100644 index 0000000000..94b91d205e --- /dev/null +++ b/api/tests/unit_tests/core/logging/test_structured_formatter.py @@ -0,0 +1,267 @@ +"""Tests for structured JSON formatter.""" + +import logging +import sys + +import orjson + + +class TestStructuredJSONFormatter: + def test_basic_log_format(self): + from core.logging.structured_formatter import StructuredJSONFormatter + + formatter = StructuredJSONFormatter(service_name="test-service") + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="test.py", + lineno=42, + msg="Test message", + args=(), + exc_info=None, + ) + + output = formatter.format(record) + log_dict = orjson.loads(output) + + assert log_dict["severity"] == "INFO" + assert log_dict["service"] == "test-service" + assert log_dict["caller"] == "test.py:42" + assert log_dict["message"] == "Test message" + assert "ts" in log_dict + assert log_dict["ts"].endswith("Z") + + def test_severity_mapping(self): + from core.logging.structured_formatter import StructuredJSONFormatter + + formatter = StructuredJSONFormatter() + + test_cases = [ + (logging.DEBUG, "DEBUG"), + (logging.INFO, "INFO"), + (logging.WARNING, 
"WARN"), + (logging.ERROR, "ERROR"), + (logging.CRITICAL, "ERROR"), + ] + + for level, expected_severity in test_cases: + record = logging.LogRecord( + name="test", + level=level, + pathname="test.py", + lineno=1, + msg="Test", + args=(), + exc_info=None, + ) + output = formatter.format(record) + log_dict = orjson.loads(output) + assert log_dict["severity"] == expected_severity, f"Level {level} should map to {expected_severity}" + + def test_error_with_stack_trace(self): + from core.logging.structured_formatter import StructuredJSONFormatter + + formatter = StructuredJSONFormatter() + + try: + raise ValueError("Test error") + except ValueError: + exc_info = sys.exc_info() + + record = logging.LogRecord( + name="test", + level=logging.ERROR, + pathname="test.py", + lineno=10, + msg="Error occurred", + args=(), + exc_info=exc_info, + ) + + output = formatter.format(record) + log_dict = orjson.loads(output) + + assert log_dict["severity"] == "ERROR" + assert "stack_trace" in log_dict + assert "ValueError: Test error" in log_dict["stack_trace"] + + def test_no_stack_trace_for_info(self): + from core.logging.structured_formatter import StructuredJSONFormatter + + formatter = StructuredJSONFormatter() + + try: + raise ValueError("Test error") + except ValueError: + exc_info = sys.exc_info() + + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="test.py", + lineno=10, + msg="Info message", + args=(), + exc_info=exc_info, + ) + + output = formatter.format(record) + log_dict = orjson.loads(output) + + assert "stack_trace" not in log_dict + + def test_trace_context_included(self): + from core.logging.structured_formatter import StructuredJSONFormatter + + formatter = StructuredJSONFormatter() + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="test.py", + lineno=1, + msg="Test", + args=(), + exc_info=None, + ) + record.trace_id = "5b8aa5a2d2c872e8321cf37308d69df2" + record.span_id = "051581bf3bb55c45" + + output = formatter.format(record) + log_dict = orjson.loads(output) + + assert log_dict["trace_id"] == "5b8aa5a2d2c872e8321cf37308d69df2" + assert log_dict["span_id"] == "051581bf3bb55c45" + + def test_identity_context_included(self): + from core.logging.structured_formatter import StructuredJSONFormatter + + formatter = StructuredJSONFormatter() + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="test.py", + lineno=1, + msg="Test", + args=(), + exc_info=None, + ) + record.tenant_id = "t-global-corp" + record.user_id = "u-admin-007" + record.user_type = "admin" + + output = formatter.format(record) + log_dict = orjson.loads(output) + + assert "identity" in log_dict + assert log_dict["identity"]["tenant_id"] == "t-global-corp" + assert log_dict["identity"]["user_id"] == "u-admin-007" + assert log_dict["identity"]["user_type"] == "admin" + + def test_no_identity_when_empty(self): + from core.logging.structured_formatter import StructuredJSONFormatter + + formatter = StructuredJSONFormatter() + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="test.py", + lineno=1, + msg="Test", + args=(), + exc_info=None, + ) + + output = formatter.format(record) + log_dict = orjson.loads(output) + + assert "identity" not in log_dict + + def test_attributes_included(self): + from core.logging.structured_formatter import StructuredJSONFormatter + + formatter = StructuredJSONFormatter() + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="test.py", + lineno=1, + msg="Test", + args=(), + 
exc_info=None, + ) + record.attributes = {"order_id": "ord-123", "amount": 99.99} + + output = formatter.format(record) + log_dict = orjson.loads(output) + + assert log_dict["attributes"]["order_id"] == "ord-123" + assert log_dict["attributes"]["amount"] == 99.99 + + def test_message_with_args(self): + from core.logging.structured_formatter import StructuredJSONFormatter + + formatter = StructuredJSONFormatter() + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="test.py", + lineno=1, + msg="User %s logged in from %s", + args=("john", "192.168.1.1"), + exc_info=None, + ) + + output = formatter.format(record) + log_dict = orjson.loads(output) + + assert log_dict["message"] == "User john logged in from 192.168.1.1" + + def test_timestamp_format(self): + from core.logging.structured_formatter import StructuredJSONFormatter + + formatter = StructuredJSONFormatter() + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="test.py", + lineno=1, + msg="Test", + args=(), + exc_info=None, + ) + + output = formatter.format(record) + log_dict = orjson.loads(output) + + # Verify ISO 8601 format with Z suffix + ts = log_dict["ts"] + assert ts.endswith("Z") + assert "T" in ts + # Should have milliseconds + assert "." in ts + + def test_fallback_for_non_serializable_attributes(self): + from core.logging.structured_formatter import StructuredJSONFormatter + + formatter = StructuredJSONFormatter() + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="test.py", + lineno=1, + msg="Test with non-serializable", + args=(), + exc_info=None, + ) + # Set is not serializable by orjson + record.attributes = {"items": {1, 2, 3}, "custom": object()} + + # Should not raise, fallback to json.dumps with default=str + output = formatter.format(record) + + # Verify it's valid JSON (parsed by stdlib json since orjson may fail) + import json + + log_dict = json.loads(output) + assert log_dict["message"] == "Test with non-serializable" + assert "attributes" in log_dict diff --git a/api/tests/unit_tests/core/logging/test_trace_helpers.py b/api/tests/unit_tests/core/logging/test_trace_helpers.py new file mode 100644 index 0000000000..aab1753b9b --- /dev/null +++ b/api/tests/unit_tests/core/logging/test_trace_helpers.py @@ -0,0 +1,102 @@ +"""Tests for trace helper functions.""" + +import re +from unittest import mock + + +class TestGetSpanIdFromOtelContext: + def test_returns_none_without_span(self): + from core.helper.trace_id_helper import get_span_id_from_otel_context + + with mock.patch("opentelemetry.trace.get_current_span", return_value=None): + result = get_span_id_from_otel_context() + assert result is None + + def test_returns_span_id_when_available(self): + from core.helper.trace_id_helper import get_span_id_from_otel_context + + mock_span = mock.MagicMock() + mock_context = mock.MagicMock() + mock_context.span_id = 0x051581BF3BB55C45 + mock_span.get_span_context.return_value = mock_context + + with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span): + with mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0): + result = get_span_id_from_otel_context() + assert result == "051581bf3bb55c45" + + def test_returns_none_on_exception(self): + from core.helper.trace_id_helper import get_span_id_from_otel_context + + with mock.patch("opentelemetry.trace.get_current_span", side_effect=Exception("Test error")): + result = get_span_id_from_otel_context() + assert result is None + + +class TestGenerateTraceparentHeader: + def 
test_generates_valid_format(self):
+        from core.helper.trace_id_helper import generate_traceparent_header
+
+        with mock.patch("opentelemetry.trace.get_current_span", return_value=None):
+            result = generate_traceparent_header()
+
+        assert result is not None
+        # Format: 00-{trace_id}-{span_id}-01
+        parts = result.split("-")
+        assert len(parts) == 4
+        assert parts[0] == "00"  # version
+        assert len(parts[1]) == 32  # trace_id (32 hex chars)
+        assert len(parts[2]) == 16  # span_id (16 hex chars)
+        assert parts[3] == "01"  # flags
+
+    def test_uses_otel_context_when_available(self):
+        from core.helper.trace_id_helper import generate_traceparent_header
+
+        mock_span = mock.MagicMock()
+        mock_context = mock.MagicMock()
+        mock_context.trace_id = 0x5B8AA5A2D2C872E8321CF37308D69DF2
+        mock_context.span_id = 0x051581BF3BB55C45
+        mock_span.get_span_context.return_value = mock_context
+
+        with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span):
+            with (
+                mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0),
+                mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0),
+            ):
+                result = generate_traceparent_header()
+
+        assert result == "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01"
+
+    def test_generates_hex_only_values(self):
+        from core.helper.trace_id_helper import generate_traceparent_header
+
+        with mock.patch("opentelemetry.trace.get_current_span", return_value=None):
+            result = generate_traceparent_header()
+
+        parts = result.split("-")
+        # All parts should be valid hex
+        assert re.match(r"^[0-9a-f]+$", parts[1])
+        assert re.match(r"^[0-9a-f]+$", parts[2])
+
+
+class TestParseTraceparentHeader:
+    def test_parses_valid_traceparent(self):
+        from core.helper.trace_id_helper import parse_traceparent_header
+
+        traceparent = "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01"
+        result = parse_traceparent_header(traceparent)
+
+        assert result == "5b8aa5a2d2c872e8321cf37308d69df2"
+
+    def test_returns_none_for_invalid_format(self):
+        from core.helper.trace_id_helper import parse_traceparent_header
+
+        # Wrong number of parts
+        assert parse_traceparent_header("00-abc-def") is None
+        # Wrong trace_id length
+        assert parse_traceparent_header("00-abc-def-01") is None
+
+    def test_returns_none_for_empty_string(self):
+        from core.helper.trace_id_helper import parse_traceparent_header
+
+        assert parse_traceparent_header("") is None
diff --git a/api/tests/unit_tests/core/rag/cleaner/__init__.py b/api/tests/unit_tests/core/rag/cleaner/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/unit_tests/core/rag/cleaner/test_clean_processor.py b/api/tests/unit_tests/core/rag/cleaner/test_clean_processor.py
new file mode 100644
index 0000000000..65ee62b8dd
--- /dev/null
+++ b/api/tests/unit_tests/core/rag/cleaner/test_clean_processor.py
@@ -0,0 +1,213 @@
+from core.rag.cleaner.clean_processor import CleanProcessor
+
+
+class TestCleanProcessor:
+    """Test cases for CleanProcessor.clean method."""
+
+    def test_clean_default_removal_of_invalid_symbols(self):
+        """Test default cleaning removes invalid symbols."""
+        # Test <| replacement
+        assert CleanProcessor.clean("text<|with<|invalid", None) == "text<with<invalid"
+
+        # Test |> replacement
+        assert CleanProcessor.clean("text|>with|>invalid", None) == "text>with>invalid"
+
+        # Test removal of control characters
+        text_with_control = "normal\x00text\x1fwith\x07control\x7fchars"
+        expected = "normaltextwithcontrolchars"
+        assert CleanProcessor.clean(text_with_control, None) == expected
+
+        # Test U+FFFE removal
+        text_with_ufffe = "normal\ufffepadding"
+        expected = "normalpadding"
+        assert CleanProcessor.clean(text_with_ufffe, None) == expected
+
+    def test_clean_with_none_process_rule(self):
+        """Test cleaning with None process_rule - only default cleaning applied."""
+        # <| becomes <, |> becomes >, control chars and U+FFFE are removed
+        text = "Hello<|World\x00"
+        expected = "Hello<World"
+        assert CleanProcessor.clean(text, None) == expected
+        assert CleanProcessor.clean("<|<|\x00|>\ufffe|>", None) == "<<>>"
+
+    def test_clean_multiple_markdown_links_preserved(self):
+        """Test multiple markdown links are all preserved."""
+        process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_urls_emails", "enabled": True}]}}
+
+        text = "[One](https://one.com) [Two](http://two.org) [Three](https://three.net)"
+        expected = "[One](https://one.com) [Two](http://two.org) [Three](https://three.net)"
+        assert CleanProcessor.clean(text, process_rule) == expected
+
+    def test_clean_markdown_link_text_as_url(self):
+        """Test markdown link where the link text itself is a URL."""
+        process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_urls_emails", "enabled": True}]}}
+
+        # Link text that looks like URL should be preserved
+        text = "[https://text-url.com](https://actual-url.com)"
+        expected = "[https://text-url.com](https://actual-url.com)"
+        assert CleanProcessor.clean(text, process_rule) == expected
+
+        # Text URL without markdown should be removed
+        text = "https://text-url.com https://actual-url.com"
+        expected = " "
+        assert CleanProcessor.clean(text, process_rule) == expected
+
+    def test_clean_complex_markdown_link_content(self):
+        """Test markdown links with complex content - known limitation with brackets in link text."""
+        process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_urls_emails", "enabled": True}]}}
+
+        # Note: The regex pattern [^\]]* cannot handle ] within link text
+        # This is a known limitation - the pattern stops at the first ]
+        text = "[Text with [brackets] and (parens)](https://example.com)"
+        # Actual behavior: only matches up to first ], URL gets removed
+        expected = "[Text with [brackets] and (parens)]("
+        assert CleanProcessor.clean(text, process_rule) == expected
+
+        # Test that properly formatted markdown links work
+        text = "[Text with (parens) and symbols](https://example.com)"
+        expected = "[Text with (parens) and symbols](https://example.com)"
+        assert CleanProcessor.clean(text, process_rule) == expected
diff --git a/api/tests/unit_tests/core/rag/extractor/test_pdf_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_pdf_extractor.py
new file mode 100644
index 0000000000..3167a9a301
--- /dev/null
+++ b/api/tests/unit_tests/core/rag/extractor/test_pdf_extractor.py
@@ -0,0 +1,186 @@
+from types import SimpleNamespace
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+import core.rag.extractor.pdf_extractor as pe
+
+
+@pytest.fixture
+def mock_dependencies(monkeypatch):
+    # Mock storage
+    saves = []
+
+    def save(key, data):
+        saves.append((key, data))
+
+    monkeypatch.setattr(pe, "storage", SimpleNamespace(save=save))
+
+    # Mock db
+    class DummySession:
+        def __init__(self):
+            self.added = []
+            self.committed = False
+
+        def add(self, obj):
+            self.added.append(obj)
+
+        def add_all(self, objs):
+            self.added.extend(objs)
+
+        def commit(self):
+            self.committed = True
+
+    db_stub = SimpleNamespace(session=DummySession())
+    monkeypatch.setattr(pe, "db", db_stub)
+
+    # Mock UploadFile
+    class FakeUploadFile:
+        DEFAULT_ID = "test_file_id"
+
+        def __init__(self, **kwargs):
+            # Assign id from DEFAULT_ID, allow override via kwargs if needed
+            self.id = 
self.DEFAULT_ID + for k, v in kwargs.items(): + setattr(self, k, v) + + monkeypatch.setattr(pe, "UploadFile", FakeUploadFile) + + # Mock config + monkeypatch.setattr(pe.dify_config, "FILES_URL", "http://files.local") + monkeypatch.setattr(pe.dify_config, "INTERNAL_FILES_URL", None) + monkeypatch.setattr(pe.dify_config, "STORAGE_TYPE", "local") + + return SimpleNamespace(saves=saves, db=db_stub, UploadFile=FakeUploadFile) + + +@pytest.mark.parametrize( + ("image_bytes", "expected_mime", "expected_ext", "file_id"), + [ + (b"\xff\xd8\xff some jpeg", "image/jpeg", "jpg", "test_file_id_jpeg"), + (b"\x89PNG\r\n\x1a\n some png", "image/png", "png", "test_file_id_png"), + ], +) +def test_extract_images_formats(mock_dependencies, monkeypatch, image_bytes, expected_mime, expected_ext, file_id): + saves = mock_dependencies.saves + db_stub = mock_dependencies.db + + # Customize FakeUploadFile id for this test case. + # Using monkeypatch ensures the class attribute is reset between parameter sets. + monkeypatch.setattr(mock_dependencies.UploadFile, "DEFAULT_ID", file_id) + + # Mock page and image objects + mock_page = MagicMock() + mock_image_obj = MagicMock() + + def mock_extract(buf, fb_format=None): + buf.write(image_bytes) + + mock_image_obj.extract.side_effect = mock_extract + + mock_page.get_objects.return_value = [mock_image_obj] + + extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id="t1", user_id="u1") + + # We need to handle the import inside _extract_images + with patch("pypdfium2.raw") as mock_raw: + mock_raw.FPDF_PAGEOBJ_IMAGE = 1 + result = extractor._extract_images(mock_page) + + assert f"![image](http://files.local/files/{file_id}/file-preview)" in result + assert len(saves) == 1 + assert saves[0][1] == image_bytes + assert len(db_stub.session.added) == 1 + assert db_stub.session.added[0].tenant_id == "t1" + assert db_stub.session.added[0].size == len(image_bytes) + assert db_stub.session.added[0].mime_type == expected_mime + assert db_stub.session.added[0].extension == expected_ext + assert db_stub.session.committed is True + + +@pytest.mark.parametrize( + ("get_objects_side_effect", "get_objects_return_value"), + [ + (None, []), # Empty list + (None, None), # None returned + (Exception("Failed to get objects"), None), # Exception raised + ], +) +def test_extract_images_get_objects_scenarios(mock_dependencies, get_objects_side_effect, get_objects_return_value): + mock_page = MagicMock() + if get_objects_side_effect: + mock_page.get_objects.side_effect = get_objects_side_effect + else: + mock_page.get_objects.return_value = get_objects_return_value + + extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id="t1", user_id="u1") + + with patch("pypdfium2.raw") as mock_raw: + mock_raw.FPDF_PAGEOBJ_IMAGE = 1 + result = extractor._extract_images(mock_page) + + assert result == "" + + +def test_extract_calls_extract_images(mock_dependencies, monkeypatch): + # Mock pypdfium2 + mock_pdf_doc = MagicMock() + mock_page = MagicMock() + mock_pdf_doc.__iter__.return_value = [mock_page] + + # Mock text extraction + mock_text_page = MagicMock() + mock_text_page.get_text_range.return_value = "Page text content" + mock_page.get_textpage.return_value = mock_text_page + + with patch("pypdfium2.PdfDocument", return_value=mock_pdf_doc): + # Mock Blob + mock_blob = MagicMock() + mock_blob.source = "test.pdf" + with patch("core.rag.extractor.pdf_extractor.Blob.from_path", return_value=mock_blob): + extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id="t1", user_id="u1") + + # Mock 
_extract_images to return a known string + monkeypatch.setattr(extractor, "_extract_images", lambda p: "![image](img_url)") + + documents = list(extractor.extract()) + + assert len(documents) == 1 + assert "Page text content" in documents[0].page_content + assert "![image](img_url)" in documents[0].page_content + assert documents[0].metadata["page"] == 0 + + +def test_extract_images_failures(mock_dependencies): + saves = mock_dependencies.saves + db_stub = mock_dependencies.db + + # Mock page and image objects + mock_page = MagicMock() + mock_image_obj_fail = MagicMock() + mock_image_obj_ok = MagicMock() + + # First image raises exception + mock_image_obj_fail.extract.side_effect = Exception("Extraction failure") + + # Second image is OK (JPEG) + jpeg_bytes = b"\xff\xd8\xff some image data" + + def mock_extract(buf, fb_format=None): + buf.write(jpeg_bytes) + + mock_image_obj_ok.extract.side_effect = mock_extract + + mock_page.get_objects.return_value = [mock_image_obj_fail, mock_image_obj_ok] + + extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id="t1", user_id="u1") + + with patch("pypdfium2.raw") as mock_raw: + mock_raw.FPDF_PAGEOBJ_IMAGE = 1 + result = extractor._extract_images(mock_page) + + # Should have one success + assert "![image](http://files.local/files/test_file_id/file-preview)" in result + assert len(saves) == 1 + assert saves[0][1] == jpeg_bytes + assert db_stub.session.committed is True diff --git a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py index 3203aab8c3..f9e59a5f05 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py @@ -1,8 +1,12 @@ """Primarily used for testing merged cell scenarios""" +import os +import tempfile from types import SimpleNamespace from docx import Document +from docx.oxml import OxmlElement +from docx.oxml.ns import qn import core.rag.extractor.word_extractor as we from core.rag.extractor.word_extractor import WordExtractor @@ -165,3 +169,110 @@ def test_extract_images_from_docx_uses_internal_files_url(): dify_config.FILES_URL = original_files_url if original_internal_files_url is not None: dify_config.INTERNAL_FILES_URL = original_internal_files_url + + +def test_extract_hyperlinks(monkeypatch): + # Mock db and storage to avoid issues during image extraction (even if no images are present) + monkeypatch.setattr(we, "storage", SimpleNamespace(save=lambda k, d: None)) + db_stub = SimpleNamespace(session=SimpleNamespace(add=lambda o: None, commit=lambda: None)) + monkeypatch.setattr(we, "db", db_stub) + monkeypatch.setattr(we.dify_config, "FILES_URL", "http://files.local", raising=False) + monkeypatch.setattr(we.dify_config, "STORAGE_TYPE", "local", raising=False) + + doc = Document() + p = doc.add_paragraph("Visit ") + + # Adding a hyperlink manually + r_id = "rId99" + hyperlink = OxmlElement("w:hyperlink") + hyperlink.set(qn("r:id"), r_id) + + new_run = OxmlElement("w:r") + t = OxmlElement("w:t") + t.text = "Dify" + new_run.append(t) + hyperlink.append(new_run) + p._p.append(hyperlink) + + # Add relationship to the part + doc.part.rels.add_relationship( + "http://schemas.openxmlformats.org/officeDocument/2006/relationships/hyperlink", + "https://dify.ai", + r_id, + is_external=True, + ) + + with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp: + doc.save(tmp.name) + tmp_path = tmp.name + + try: + extractor = WordExtractor(tmp_path, "tenant_id", 
"user_id") + docs = extractor.extract() + # Verify modern hyperlink extraction + assert "Visit[Dify](https://dify.ai)" in docs[0].page_content + finally: + if os.path.exists(tmp_path): + os.remove(tmp_path) + + +def test_extract_legacy_hyperlinks(monkeypatch): + # Mock db and storage + monkeypatch.setattr(we, "storage", SimpleNamespace(save=lambda k, d: None)) + db_stub = SimpleNamespace(session=SimpleNamespace(add=lambda o: None, commit=lambda: None)) + monkeypatch.setattr(we, "db", db_stub) + monkeypatch.setattr(we.dify_config, "FILES_URL", "http://files.local", raising=False) + monkeypatch.setattr(we.dify_config, "STORAGE_TYPE", "local", raising=False) + + doc = Document() + p = doc.add_paragraph() + + # Construct a legacy HYPERLINK field: + # 1. w:fldChar (begin) + # 2. w:instrText (HYPERLINK "http://example.com") + # 3. w:fldChar (separate) + # 4. w:r (visible text "Example") + # 5. w:fldChar (end) + + run1 = OxmlElement("w:r") + fldCharBegin = OxmlElement("w:fldChar") + fldCharBegin.set(qn("w:fldCharType"), "begin") + run1.append(fldCharBegin) + p._p.append(run1) + + run2 = OxmlElement("w:r") + instrText = OxmlElement("w:instrText") + instrText.text = ' HYPERLINK "http://example.com" ' + run2.append(instrText) + p._p.append(run2) + + run3 = OxmlElement("w:r") + fldCharSep = OxmlElement("w:fldChar") + fldCharSep.set(qn("w:fldCharType"), "separate") + run3.append(fldCharSep) + p._p.append(run3) + + run4 = OxmlElement("w:r") + t4 = OxmlElement("w:t") + t4.text = "Example" + run4.append(t4) + p._p.append(run4) + + run5 = OxmlElement("w:r") + fldCharEnd = OxmlElement("w:fldChar") + fldCharEnd.set(qn("w:fldCharType"), "end") + run5.append(fldCharEnd) + p._p.append(run5) + + with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp: + doc.save(tmp.name) + tmp_path = tmp.name + + try: + extractor = WordExtractor(tmp_path, "tenant_id", "user_id") + docs = extractor.extract() + # Verify legacy hyperlink extraction + assert "[Example](http://example.com)" in docs[0].page_content + finally: + if os.path.exists(tmp_path): + os.remove(tmp_path) diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index 6306d665e7..ca08cb0591 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -73,6 +73,7 @@ import pytest from core.rag.datasource.retrieval_service import RetrievalService from core.rag.models.document import Document +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from models.dataset import Dataset @@ -1518,6 +1519,282 @@ class TestRetrievalService: call_kwargs = mock_retrieve.call_args.kwargs assert call_kwargs["reranking_model"] == reranking_model + # ==================== Multiple Retrieve Thread Tests ==================== + + @patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor") + @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval._retriever") + def test_multiple_retrieve_thread_skips_second_reranking_with_single_dataset( + self, mock_retriever, mock_data_processor_class, mock_flask_app, mock_dataset + ): + """ + Test that _multiple_retrieve_thread skips second reranking when dataset_count is 1. + + When there is only one dataset, the second reranking is unnecessary + because the documents are already ranked from the first retrieval. 
+ This optimization avoids the overhead of reranking when it won't + provide any benefit. + + Verifies: + - DataPostProcessor is NOT called when dataset_count == 1 + - Documents are still added to all_documents + - Standard scoring logic is applied instead + """ + # Arrange + dataset_retrieval = DatasetRetrieval() + tenant_id = str(uuid4()) + + # Create test documents + doc1 = Document( + page_content="Test content 1", + metadata={"doc_id": "doc1", "score": 0.9, "document_id": str(uuid4()), "dataset_id": mock_dataset.id}, + provider="dify", + ) + doc2 = Document( + page_content="Test content 2", + metadata={"doc_id": "doc2", "score": 0.8, "document_id": str(uuid4()), "dataset_id": mock_dataset.id}, + provider="dify", + ) + + # Mock _retriever to return documents + def side_effect_retriever( + flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids + ): + all_documents.extend([doc1, doc2]) + + mock_retriever.side_effect = side_effect_retriever + + # Set up dataset with high_quality indexing + mock_dataset.indexing_technique = "high_quality" + + all_documents = [] + + # Act - Call with dataset_count = 1 + dataset_retrieval._multiple_retrieve_thread( + flask_app=mock_flask_app, + available_datasets=[mock_dataset], + metadata_condition=None, + metadata_filter_document_ids=None, + all_documents=all_documents, + tenant_id=tenant_id, + reranking_enable=True, + reranking_mode="reranking_model", + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"}, + weights=None, + top_k=5, + score_threshold=0.5, + query="test query", + attachment_id=None, + dataset_count=1, # Single dataset - should skip second reranking + ) + + # Assert + # DataPostProcessor should NOT be called (second reranking skipped) + mock_data_processor_class.assert_not_called() + + # Documents should still be added to all_documents + assert len(all_documents) == 2 + assert all_documents[0].page_content == "Test content 1" + assert all_documents[1].page_content == "Test content 2" + + @patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor") + @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval._retriever") + @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.calculate_vector_score") + def test_multiple_retrieve_thread_performs_second_reranking_with_multiple_datasets( + self, mock_calculate_vector_score, mock_retriever, mock_data_processor_class, mock_flask_app, mock_dataset + ): + """ + Test that _multiple_retrieve_thread performs second reranking when dataset_count > 1. + + When there are multiple datasets, the second reranking is necessary + to merge and re-rank results from different datasets. This ensures + the most relevant documents across all datasets are returned. 
+ + Verifies: + - DataPostProcessor IS called when dataset_count > 1 + - Reranking is applied with correct parameters + - Documents are processed correctly + """ + # Arrange + dataset_retrieval = DatasetRetrieval() + tenant_id = str(uuid4()) + + # Create test documents + doc1 = Document( + page_content="Test content 1", + metadata={"doc_id": "doc1", "score": 0.7, "document_id": str(uuid4()), "dataset_id": mock_dataset.id}, + provider="dify", + ) + doc2 = Document( + page_content="Test content 2", + metadata={"doc_id": "doc2", "score": 0.6, "document_id": str(uuid4()), "dataset_id": mock_dataset.id}, + provider="dify", + ) + + # Mock _retriever to return documents + def side_effect_retriever( + flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids + ): + all_documents.extend([doc1, doc2]) + + mock_retriever.side_effect = side_effect_retriever + + # Set up dataset with high_quality indexing + mock_dataset.indexing_technique = "high_quality" + + # Mock DataPostProcessor instance and its invoke method + mock_processor_instance = Mock() + # Simulate reranking - return documents in reversed order with updated scores + reranked_docs = [ + Document( + page_content="Test content 2", + metadata={"doc_id": "doc2", "score": 0.95, "document_id": str(uuid4()), "dataset_id": mock_dataset.id}, + provider="dify", + ), + Document( + page_content="Test content 1", + metadata={"doc_id": "doc1", "score": 0.85, "document_id": str(uuid4()), "dataset_id": mock_dataset.id}, + provider="dify", + ), + ] + mock_processor_instance.invoke.return_value = reranked_docs + mock_data_processor_class.return_value = mock_processor_instance + + all_documents = [] + + # Create second dataset + mock_dataset2 = Mock(spec=Dataset) + mock_dataset2.id = str(uuid4()) + mock_dataset2.indexing_technique = "high_quality" + mock_dataset2.provider = "dify" + + # Act - Call with dataset_count = 2 + dataset_retrieval._multiple_retrieve_thread( + flask_app=mock_flask_app, + available_datasets=[mock_dataset, mock_dataset2], + metadata_condition=None, + metadata_filter_document_ids=None, + all_documents=all_documents, + tenant_id=tenant_id, + reranking_enable=True, + reranking_mode="reranking_model", + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"}, + weights=None, + top_k=5, + score_threshold=0.5, + query="test query", + attachment_id=None, + dataset_count=2, # Multiple datasets - should perform second reranking + ) + + # Assert + # DataPostProcessor SHOULD be called (second reranking performed) + mock_data_processor_class.assert_called_once_with( + tenant_id, + "reranking_model", + {"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"}, + None, + False, + ) + + # Verify invoke was called with correct parameters + mock_processor_instance.invoke.assert_called_once() + + # Documents should be added to all_documents after reranking + assert len(all_documents) == 2 + # The reranked order should be reflected + assert all_documents[0].page_content == "Test content 2" + assert all_documents[1].page_content == "Test content 1" + + @patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor") + @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval._retriever") + @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.calculate_vector_score") + def test_multiple_retrieve_thread_single_dataset_uses_standard_scoring( + self, mock_calculate_vector_score, mock_retriever, mock_data_processor_class, mock_flask_app, 
mock_dataset + ): + """ + Test that _multiple_retrieve_thread uses standard scoring when dataset_count is 1 + and reranking is enabled. + + When there's only one dataset, instead of using DataPostProcessor, + the method should fall through to the standard scoring logic + (calculate_vector_score for high_quality datasets). + + Verifies: + - DataPostProcessor is NOT called + - calculate_vector_score IS called for high_quality indexing + - Documents are scored correctly + """ + # Arrange + dataset_retrieval = DatasetRetrieval() + tenant_id = str(uuid4()) + + # Create test documents + doc1 = Document( + page_content="Test content 1", + metadata={"doc_id": "doc1", "score": 0.9, "document_id": str(uuid4()), "dataset_id": mock_dataset.id}, + provider="dify", + ) + doc2 = Document( + page_content="Test content 2", + metadata={"doc_id": "doc2", "score": 0.8, "document_id": str(uuid4()), "dataset_id": mock_dataset.id}, + provider="dify", + ) + + # Mock _retriever to return documents + def side_effect_retriever( + flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids + ): + all_documents.extend([doc1, doc2]) + + mock_retriever.side_effect = side_effect_retriever + + # Set up dataset with high_quality indexing + mock_dataset.indexing_technique = "high_quality" + + # Mock calculate_vector_score to return scored documents + scored_docs = [ + Document( + page_content="Test content 1", + metadata={"doc_id": "doc1", "score": 0.95, "document_id": str(uuid4()), "dataset_id": mock_dataset.id}, + provider="dify", + ), + ] + mock_calculate_vector_score.return_value = scored_docs + + all_documents = [] + + # Act - Call with dataset_count = 1 + dataset_retrieval._multiple_retrieve_thread( + flask_app=mock_flask_app, + available_datasets=[mock_dataset], + metadata_condition=None, + metadata_filter_document_ids=None, + all_documents=all_documents, + tenant_id=tenant_id, + reranking_enable=True, # Reranking enabled but should be skipped for single dataset + reranking_mode="reranking_model", + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"}, + weights=None, + top_k=5, + score_threshold=0.5, + query="test query", + attachment_id=None, + dataset_count=1, + ) + + # Assert + # DataPostProcessor should NOT be called + mock_data_processor_class.assert_not_called() + + # calculate_vector_score SHOULD be called for high_quality datasets + mock_calculate_vector_score.assert_called_once() + call_args = mock_calculate_vector_score.call_args + assert call_args[0][1] == 5 # top_k + + # Documents should be added after standard scoring + assert len(all_documents) == 1 + assert all_documents[0].page_content == "Test content 1" + class TestRetrievalMethods: """ diff --git a/api/tests/unit_tests/core/rag/retrieval/test_knowledge_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_knowledge_retrieval.py new file mode 100644 index 0000000000..5f461d53ae --- /dev/null +++ b/api/tests/unit_tests/core/rag/retrieval/test_knowledge_retrieval.py @@ -0,0 +1,113 @@ +import threading +from unittest.mock import Mock, patch +from uuid import uuid4 + +import pytest +from flask import Flask, current_app + +from core.rag.models.document import Document +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval +from models.dataset import Dataset + + +class TestRetrievalService: + @pytest.fixture + def mock_dataset(self) -> Dataset: + dataset = Mock(spec=Dataset) + dataset.id = str(uuid4()) + dataset.tenant_id = str(uuid4()) + dataset.name 
= "test_dataset" + dataset.indexing_technique = "high_quality" + dataset.provider = "dify" + return dataset + + def test_multiple_retrieve_reranking_with_app_context(self, mock_dataset): + """ + Repro test for current bug: + reranking runs after `with flask_app.app_context():` exits. + `_multiple_retrieve_thread` catches exceptions and stores them into `thread_exceptions`, + so we must assert from that list (not from an outer try/except). + """ + dataset_retrieval = DatasetRetrieval() + flask_app = Flask(__name__) + tenant_id = str(uuid4()) + + # second dataset to ensure dataset_count > 1 reranking branch + secondary_dataset = Mock(spec=Dataset) + secondary_dataset.id = str(uuid4()) + secondary_dataset.provider = "dify" + secondary_dataset.indexing_technique = "high_quality" + + # retriever returns 1 doc into internal list (all_documents_item) + document = Document( + page_content="Context aware doc", + metadata={ + "doc_id": "doc1", + "score": 0.95, + "document_id": str(uuid4()), + "dataset_id": mock_dataset.id, + }, + provider="dify", + ) + + def fake_retriever( + flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids + ): + all_documents.append(document) + + called = {"init": 0, "invoke": 0} + + class ContextRequiredPostProcessor: + def __init__(self, *args, **kwargs): + called["init"] += 1 + # will raise RuntimeError if no Flask app context exists + _ = current_app.name + + def invoke(self, *args, **kwargs): + called["invoke"] += 1 + _ = current_app.name + return kwargs.get("documents") or args[1] + + # output list from _multiple_retrieve_thread + all_documents: list[Document] = [] + + # IMPORTANT: _multiple_retrieve_thread swallows exceptions and appends them here + thread_exceptions: list[Exception] = [] + + def target(): + with patch.object(dataset_retrieval, "_retriever", side_effect=fake_retriever): + with patch( + "core.rag.retrieval.dataset_retrieval.DataPostProcessor", + ContextRequiredPostProcessor, + ): + dataset_retrieval._multiple_retrieve_thread( + flask_app=flask_app, + available_datasets=[mock_dataset, secondary_dataset], + metadata_condition=None, + metadata_filter_document_ids=None, + all_documents=all_documents, + tenant_id=tenant_id, + reranking_enable=True, + reranking_mode="reranking_model", + reranking_model={ + "reranking_provider_name": "cohere", + "reranking_model_name": "rerank-v2", + }, + weights=None, + top_k=3, + score_threshold=0.0, + query="test query", + attachment_id=None, + dataset_count=2, # force reranking branch + thread_exceptions=thread_exceptions, # ✅ key + ) + + t = threading.Thread(target=target) + t.start() + t.join() + + # Ensure reranking branch was actually executed + assert called["init"] >= 1, "DataPostProcessor was never constructed; reranking branch may not have run." 
+ + # Current buggy code should record an exception (not raise it) + assert not thread_exceptions, thread_exceptions diff --git a/api/tests/unit_tests/core/test_provider_configuration.py b/api/tests/unit_tests/core/test_provider_configuration.py index 9060cf7b6c..636fac7a40 100644 --- a/api/tests/unit_tests/core/test_provider_configuration.py +++ b/api/tests/unit_tests/core/test_provider_configuration.py @@ -32,7 +32,6 @@ def mock_provider_entity(): label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"), description=I18nObject(en_US="OpenAI provider", zh_Hans="OpenAI 提供商"), icon_small=I18nObject(en_US="icon.png", zh_Hans="icon.png"), - icon_large=I18nObject(en_US="icon.png", zh_Hans="icon.png"), background="background.png", help=None, supported_model_types=[ModelType.LLM], diff --git a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py b/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py index 8677325d4e..f33fd0deeb 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py @@ -3,8 +3,15 @@ import json from unittest.mock import MagicMock +from core.variables import IntegerVariable, StringVariable from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel -from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, GraphEngineCommand +from core.workflow.graph_engine.entities.commands import ( + AbortCommand, + CommandType, + GraphEngineCommand, + UpdateVariablesCommand, + VariableUpdate, +) class TestRedisChannel: @@ -148,6 +155,43 @@ class TestRedisChannel: assert commands[0].command_type == CommandType.ABORT assert isinstance(commands[1], AbortCommand) + def test_fetch_commands_with_update_variables_command(self): + """Test fetching update variables command from Redis.""" + mock_redis = MagicMock() + pending_pipe = MagicMock() + fetch_pipe = MagicMock() + pending_context = MagicMock() + fetch_context = MagicMock() + pending_context.__enter__.return_value = pending_pipe + pending_context.__exit__.return_value = None + fetch_context.__enter__.return_value = fetch_pipe + fetch_context.__exit__.return_value = None + mock_redis.pipeline.side_effect = [pending_context, fetch_context] + + update_command = UpdateVariablesCommand( + updates=[ + VariableUpdate( + value=StringVariable(name="foo", value="bar", selector=["node1", "foo"]), + ), + VariableUpdate( + value=IntegerVariable(name="baz", value=123, selector=["node2", "baz"]), + ), + ] + ) + command_json = json.dumps(update_command.model_dump()) + + pending_pipe.execute.return_value = [b"1", 1] + fetch_pipe.execute.return_value = [[command_json.encode()], 1] + + channel = RedisChannel(mock_redis, "test:key") + commands = channel.fetch_commands() + + assert len(commands) == 1 + assert isinstance(commands[0], UpdateVariablesCommand) + assert isinstance(commands[0].updates[0].value, StringVariable) + assert list(commands[0].updates[0].value.selector) == ["node1", "foo"] + assert commands[0].updates[0].value.value == "bar" + def test_fetch_commands_skips_invalid_json(self): """Test that invalid JSON commands are skipped.""" mock_redis = MagicMock() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/__init__.py b/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/__init__.py new file mode 100644 index 0000000000..cf8811dc2b --- /dev/null +++ 
b/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/__init__.py @@ -0,0 +1 @@ +"""Tests for graph traversal components.""" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py b/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py new file mode 100644 index 0000000000..0019020ede --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py @@ -0,0 +1,307 @@ +"""Unit tests for skip propagator.""" + +from unittest.mock import MagicMock, create_autospec + +from core.workflow.graph import Edge, Graph +from core.workflow.graph_engine.graph_state_manager import GraphStateManager +from core.workflow.graph_engine.graph_traversal.skip_propagator import SkipPropagator + + +class TestSkipPropagator: + """Test suite for SkipPropagator.""" + + def test_propagate_skip_from_edge_with_unknown_edges_stops_processing(self) -> None: + """When there are unknown incoming edges, propagation should stop.""" + # Arrange + mock_graph = create_autospec(Graph) + mock_state_manager = create_autospec(GraphStateManager) + + # Create a mock edge + mock_edge = MagicMock(spec=Edge) + mock_edge.id = "edge_1" + mock_edge.head = "node_2" + + # Setup graph edges dict + mock_graph.edges = {"edge_1": mock_edge} + + # Setup incoming edges + incoming_edges = [MagicMock(spec=Edge), MagicMock(spec=Edge)] + mock_graph.get_incoming_edges.return_value = incoming_edges + + # Setup state manager to return has_unknown=True + mock_state_manager.analyze_edge_states.return_value = { + "has_unknown": True, + "has_taken": False, + "all_skipped": False, + } + + propagator = SkipPropagator(mock_graph, mock_state_manager) + + # Act + propagator.propagate_skip_from_edge("edge_1") + + # Assert + mock_graph.get_incoming_edges.assert_called_once_with("node_2") + mock_state_manager.analyze_edge_states.assert_called_once_with(incoming_edges) + # Should not call any other state manager methods + mock_state_manager.enqueue_node.assert_not_called() + mock_state_manager.start_execution.assert_not_called() + mock_state_manager.mark_node_skipped.assert_not_called() + + def test_propagate_skip_from_edge_with_taken_edge_enqueues_node(self) -> None: + """When there is at least one taken edge, node should be enqueued.""" + # Arrange + mock_graph = create_autospec(Graph) + mock_state_manager = create_autospec(GraphStateManager) + + # Create a mock edge + mock_edge = MagicMock(spec=Edge) + mock_edge.id = "edge_1" + mock_edge.head = "node_2" + + mock_graph.edges = {"edge_1": mock_edge} + incoming_edges = [MagicMock(spec=Edge)] + mock_graph.get_incoming_edges.return_value = incoming_edges + + # Setup state manager to return has_taken=True + mock_state_manager.analyze_edge_states.return_value = { + "has_unknown": False, + "has_taken": True, + "all_skipped": False, + } + + propagator = SkipPropagator(mock_graph, mock_state_manager) + + # Act + propagator.propagate_skip_from_edge("edge_1") + + # Assert + mock_state_manager.enqueue_node.assert_called_once_with("node_2") + mock_state_manager.start_execution.assert_called_once_with("node_2") + mock_state_manager.mark_node_skipped.assert_not_called() + + def test_propagate_skip_from_edge_with_all_skipped_propagates_to_node(self) -> None: + """When all incoming edges are skipped, should propagate skip to node.""" + # Arrange + mock_graph = create_autospec(Graph) + mock_state_manager = create_autospec(GraphStateManager) + + # Create a mock edge + mock_edge = MagicMock(spec=Edge) + 
mock_edge.id = "edge_1" + mock_edge.head = "node_2" + + mock_graph.edges = {"edge_1": mock_edge} + incoming_edges = [MagicMock(spec=Edge)] + mock_graph.get_incoming_edges.return_value = incoming_edges + + # Setup state manager to return all_skipped=True + mock_state_manager.analyze_edge_states.return_value = { + "has_unknown": False, + "has_taken": False, + "all_skipped": True, + } + + propagator = SkipPropagator(mock_graph, mock_state_manager) + + # Act + propagator.propagate_skip_from_edge("edge_1") + + # Assert + mock_state_manager.mark_node_skipped.assert_called_once_with("node_2") + mock_state_manager.enqueue_node.assert_not_called() + mock_state_manager.start_execution.assert_not_called() + + def test_propagate_skip_to_node_marks_node_and_outgoing_edges_skipped(self) -> None: + """_propagate_skip_to_node should mark node and all outgoing edges as skipped.""" + # Arrange + mock_graph = create_autospec(Graph) + mock_state_manager = create_autospec(GraphStateManager) + + # Create outgoing edges + edge1 = MagicMock(spec=Edge) + edge1.id = "edge_2" + edge1.head = "node_downstream_1" # Set head for propagate_skip_from_edge + + edge2 = MagicMock(spec=Edge) + edge2.id = "edge_3" + edge2.head = "node_downstream_2" + + # Setup graph edges dict for propagate_skip_from_edge + mock_graph.edges = {"edge_2": edge1, "edge_3": edge2} + mock_graph.get_outgoing_edges.return_value = [edge1, edge2] + + # Setup get_incoming_edges to return empty list to stop recursion + mock_graph.get_incoming_edges.return_value = [] + + propagator = SkipPropagator(mock_graph, mock_state_manager) + + # Use mock to call private method + # Act + propagator._propagate_skip_to_node("node_1") + + # Assert + mock_state_manager.mark_node_skipped.assert_called_once_with("node_1") + mock_state_manager.mark_edge_skipped.assert_any_call("edge_2") + mock_state_manager.mark_edge_skipped.assert_any_call("edge_3") + assert mock_state_manager.mark_edge_skipped.call_count == 2 + # Should recursively propagate from each edge + # Since propagate_skip_from_edge is called, we need to verify it was called + # But we can't directly verify due to recursion. We'll trust the logic. 
+ + def test_skip_branch_paths_marks_unselected_edges_and_propagates(self) -> None: + """skip_branch_paths should mark all unselected edges as skipped and propagate.""" + # Arrange + mock_graph = create_autospec(Graph) + mock_state_manager = create_autospec(GraphStateManager) + + # Create unselected edges + edge1 = MagicMock(spec=Edge) + edge1.id = "edge_1" + edge1.head = "node_downstream_1" + + edge2 = MagicMock(spec=Edge) + edge2.id = "edge_2" + edge2.head = "node_downstream_2" + + unselected_edges = [edge1, edge2] + + # Setup graph edges dict + mock_graph.edges = {"edge_1": edge1, "edge_2": edge2} + # Setup get_incoming_edges to return empty list to stop recursion + mock_graph.get_incoming_edges.return_value = [] + + propagator = SkipPropagator(mock_graph, mock_state_manager) + + # Act + propagator.skip_branch_paths(unselected_edges) + + # Assert + mock_state_manager.mark_edge_skipped.assert_any_call("edge_1") + mock_state_manager.mark_edge_skipped.assert_any_call("edge_2") + assert mock_state_manager.mark_edge_skipped.call_count == 2 + # propagate_skip_from_edge should be called for each edge + # We can't directly verify due to the mock, but the logic is covered + + def test_propagate_skip_from_edge_recursively_propagates_through_graph(self) -> None: + """Skip propagation should recursively propagate through the graph.""" + # Arrange + mock_graph = create_autospec(Graph) + mock_state_manager = create_autospec(GraphStateManager) + + # Create edge chain: edge_1 -> node_2 -> edge_3 -> node_4 + edge1 = MagicMock(spec=Edge) + edge1.id = "edge_1" + edge1.head = "node_2" + + edge3 = MagicMock(spec=Edge) + edge3.id = "edge_3" + edge3.head = "node_4" + + mock_graph.edges = {"edge_1": edge1, "edge_3": edge3} + + # Setup get_incoming_edges to return different values based on node + def get_incoming_edges_side_effect(node_id): + if node_id == "node_2": + return [edge1] + elif node_id == "node_4": + return [edge3] + return [] + + mock_graph.get_incoming_edges.side_effect = get_incoming_edges_side_effect + + # Setup get_outgoing_edges to return different values based on node + def get_outgoing_edges_side_effect(node_id): + if node_id == "node_2": + return [edge3] + elif node_id == "node_4": + return [] # No outgoing edges, stops recursion + return [] + + mock_graph.get_outgoing_edges.side_effect = get_outgoing_edges_side_effect + + # Setup state manager to return all_skipped for both nodes + mock_state_manager.analyze_edge_states.return_value = { + "has_unknown": False, + "has_taken": False, + "all_skipped": True, + } + + propagator = SkipPropagator(mock_graph, mock_state_manager) + + # Act + propagator.propagate_skip_from_edge("edge_1") + + # Assert + # Should mark node_2 as skipped + mock_state_manager.mark_node_skipped.assert_any_call("node_2") + # Should mark edge_3 as skipped + mock_state_manager.mark_edge_skipped.assert_any_call("edge_3") + # Should propagate to node_4 + mock_state_manager.mark_node_skipped.assert_any_call("node_4") + assert mock_state_manager.mark_node_skipped.call_count == 2 + + def test_propagate_skip_from_edge_with_mixed_edge_states_handles_correctly(self) -> None: + """Test with mixed edge states (some unknown, some taken, some skipped).""" + # Arrange + mock_graph = create_autospec(Graph) + mock_state_manager = create_autospec(GraphStateManager) + + mock_edge = MagicMock(spec=Edge) + mock_edge.id = "edge_1" + mock_edge.head = "node_2" + + mock_graph.edges = {"edge_1": mock_edge} + incoming_edges = [MagicMock(spec=Edge), MagicMock(spec=Edge), MagicMock(spec=Edge)] + 
mock_graph.get_incoming_edges.return_value = incoming_edges + + # Test 1: has_unknown=True, has_taken=False, all_skipped=False + mock_state_manager.analyze_edge_states.return_value = { + "has_unknown": True, + "has_taken": False, + "all_skipped": False, + } + + propagator = SkipPropagator(mock_graph, mock_state_manager) + + # Act + propagator.propagate_skip_from_edge("edge_1") + + # Assert - should stop processing + mock_state_manager.enqueue_node.assert_not_called() + mock_state_manager.mark_node_skipped.assert_not_called() + + # Reset mocks for next test + mock_state_manager.reset_mock() + mock_graph.reset_mock() + + # Test 2: has_unknown=False, has_taken=True, all_skipped=False + mock_state_manager.analyze_edge_states.return_value = { + "has_unknown": False, + "has_taken": True, + "all_skipped": False, + } + + # Act + propagator.propagate_skip_from_edge("edge_1") + + # Assert - should enqueue node + mock_state_manager.enqueue_node.assert_called_once_with("node_2") + mock_state_manager.start_execution.assert_called_once_with("node_2") + + # Reset mocks for next test + mock_state_manager.reset_mock() + mock_graph.reset_mock() + + # Test 3: has_unknown=False, has_taken=False, all_skipped=True + mock_state_manager.analyze_edge_states.return_value = { + "has_unknown": False, + "has_taken": False, + "all_skipped": True, + } + + # Act + propagator.propagate_skip_from_edge("edge_1") + + # Assert - should propagate skip + mock_state_manager.mark_node_skipped.assert_called_once_with("node_2") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py new file mode 100644 index 0000000000..d6ba61c50c --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +import pytest + +from core.workflow.graph_engine import GraphEngine +from core.workflow.graph_engine.command_channels import InMemoryChannel +from core.workflow.graph_engine.layers.base import ( + GraphEngineLayer, + GraphEngineLayerNotInitializedError, +) +from core.workflow.graph_events import GraphEngineEvent + +from ..test_table_runner import WorkflowRunner + + +class LayerForTest(GraphEngineLayer): + def on_graph_start(self) -> None: + pass + + def on_event(self, event: GraphEngineEvent) -> None: + pass + + def on_graph_end(self, error: Exception | None) -> None: + pass + + +def test_layer_runtime_state_raises_when_uninitialized() -> None: + layer = LayerForTest() + + with pytest.raises(GraphEngineLayerNotInitializedError): + _ = layer.graph_runtime_state + + +def test_layer_runtime_state_available_after_engine_layer() -> None: + runner = WorkflowRunner() + fixture_data = runner.load_fixture("simple_passthrough_workflow") + graph, graph_runtime_state = runner.create_graph_from_fixture( + fixture_data, + inputs={"query": "test layer state"}, + ) + engine = GraphEngine( + workflow_id="test_workflow", + graph=graph, + graph_runtime_state=graph_runtime_state, + command_channel=InMemoryChannel(), + ) + + layer = LayerForTest() + engine.layer(layer) + + outputs = layer.graph_runtime_state.outputs + ready_queue_size = layer.graph_runtime_state.ready_queue_size + + assert outputs == {} + assert isinstance(ready_queue_size, int) + assert ready_queue_size >= 0 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py 
b/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py index c1fc4acd73..fe3ea576c1 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py @@ -3,6 +3,7 @@ from __future__ import annotations import queue +import threading from unittest import mock from core.workflow.entities.pause_reason import SchedulingPause @@ -36,6 +37,7 @@ def test_dispatcher_should_consume_remains_events_after_pause(): event_queue=event_queue, event_handler=event_handler, execution_coordinator=execution_coordinator, + stop_event=threading.Event(), ) dispatcher._dispatcher_loop() assert event_queue.empty() @@ -96,6 +98,7 @@ def _run_dispatcher_for_event(event) -> int: event_queue=event_queue, event_handler=event_handler, execution_coordinator=coordinator, + stop_event=threading.Event(), ) dispatcher._dispatcher_loop() @@ -181,6 +184,7 @@ def test_dispatcher_drain_event_queue(): event_queue=event_queue, event_handler=event_handler, execution_coordinator=coordinator, + stop_event=threading.Event(), ) dispatcher._dispatcher_loop() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py index b074a11be9..d826f7a900 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py @@ -4,12 +4,19 @@ import time from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom +from core.variables import IntegerVariable, StringVariable from core.workflow.entities.graph_init_params import GraphInitParams from core.workflow.entities.pause_reason import SchedulingPause from core.workflow.graph import Graph from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine.command_channels import InMemoryChannel -from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand +from core.workflow.graph_engine.entities.commands import ( + AbortCommand, + CommandType, + PauseCommand, + UpdateVariablesCommand, + VariableUpdate, +) from core.workflow.graph_events import GraphRunAbortedEvent, GraphRunPausedEvent, GraphRunStartedEvent from core.workflow.nodes.start.start_node import StartNode from core.workflow.runtime import GraphRuntimeState, VariablePool @@ -180,3 +187,67 @@ def test_pause_command(): graph_execution = engine.graph_runtime_state.graph_execution assert graph_execution.pause_reasons == [SchedulingPause(message="User requested pause")] + + +def test_update_variables_command_updates_pool(): + """Test that GraphEngine updates variable pool via update variables command.""" + + shared_runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + shared_runtime_state.variable_pool.add(("node1", "foo"), "old value") + + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + mock_graph.root_node.id = "start" + + start_node = StartNode( + id="start", + config={"id": "start", "data": {"title": "start", "variables": []}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=shared_runtime_state, + ) + 
mock_graph.nodes["start"] = start_node + + mock_graph.get_outgoing_edges = MagicMock(return_value=[]) + mock_graph.get_incoming_edges = MagicMock(return_value=[]) + + command_channel = InMemoryChannel() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=shared_runtime_state, + command_channel=command_channel, + ) + + update_command = UpdateVariablesCommand( + updates=[ + VariableUpdate( + value=StringVariable(name="foo", value="new value", selector=["node1", "foo"]), + ), + VariableUpdate( + value=IntegerVariable(name="bar", value=123, selector=["node2", "bar"]), + ), + ] + ) + command_channel.send_command(update_command) + + list(engine.run()) + + updated_existing = shared_runtime_state.variable_pool.get(["node1", "foo"]) + added_new = shared_runtime_state.variable_pool.get(["node2", "bar"]) + + assert updated_existing is not None + assert updated_existing.value == "new value" + assert added_new is not None + assert added_new.value == 123 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py index b02f90588b..5ceb8dd7f7 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py @@ -5,6 +5,8 @@ This module provides a flexible configuration system for customizing the behavior of mock nodes during testing. """ +from __future__ import annotations + from collections.abc import Callable from dataclasses import dataclass, field from typing import Any @@ -95,67 +97,67 @@ class MockConfigBuilder: def __init__(self) -> None: self._config = MockConfig() - def with_auto_mock(self, enabled: bool = True) -> "MockConfigBuilder": + def with_auto_mock(self, enabled: bool = True) -> MockConfigBuilder: """Enable or disable auto-mocking.""" self._config.enable_auto_mock = enabled return self - def with_delays(self, enabled: bool = True) -> "MockConfigBuilder": + def with_delays(self, enabled: bool = True) -> MockConfigBuilder: """Enable or disable simulated execution delays.""" self._config.simulate_delays = enabled return self - def with_llm_response(self, response: str) -> "MockConfigBuilder": + def with_llm_response(self, response: str) -> MockConfigBuilder: """Set default LLM response.""" self._config.default_llm_response = response return self - def with_agent_response(self, response: str) -> "MockConfigBuilder": + def with_agent_response(self, response: str) -> MockConfigBuilder: """Set default agent response.""" self._config.default_agent_response = response return self - def with_tool_response(self, response: dict[str, Any]) -> "MockConfigBuilder": + def with_tool_response(self, response: dict[str, Any]) -> MockConfigBuilder: """Set default tool response.""" self._config.default_tool_response = response return self - def with_retrieval_response(self, response: str) -> "MockConfigBuilder": + def with_retrieval_response(self, response: str) -> MockConfigBuilder: """Set default retrieval response.""" self._config.default_retrieval_response = response return self - def with_http_response(self, response: dict[str, Any]) -> "MockConfigBuilder": + def with_http_response(self, response: dict[str, Any]) -> MockConfigBuilder: """Set default HTTP response.""" self._config.default_http_response = response return self - def with_template_transform_response(self, response: str) -> "MockConfigBuilder": + def with_template_transform_response(self, response: str) -> 
MockConfigBuilder: """Set default template transform response.""" self._config.default_template_transform_response = response return self - def with_code_response(self, response: dict[str, Any]) -> "MockConfigBuilder": + def with_code_response(self, response: dict[str, Any]) -> MockConfigBuilder: """Set default code execution response.""" self._config.default_code_response = response return self - def with_node_output(self, node_id: str, outputs: dict[str, Any]) -> "MockConfigBuilder": + def with_node_output(self, node_id: str, outputs: dict[str, Any]) -> MockConfigBuilder: """Set outputs for a specific node.""" self._config.set_node_outputs(node_id, outputs) return self - def with_node_error(self, node_id: str, error: str) -> "MockConfigBuilder": + def with_node_error(self, node_id: str, error: str) -> MockConfigBuilder: """Set error for a specific node.""" self._config.set_node_error(node_id, error) return self - def with_node_config(self, config: NodeMockConfig) -> "MockConfigBuilder": + def with_node_config(self, config: NodeMockConfig) -> MockConfigBuilder: """Add a node-specific configuration.""" self._config.set_node_config(config.node_id, config) return self - def with_default_config(self, node_type: NodeType, config: dict[str, Any]) -> "MockConfigBuilder": + def with_default_config(self, node_type: NodeType, config: dict[str, Any]) -> MockConfigBuilder: """Set default configuration for a node type.""" self._config.set_default_config(node_type, config) return self diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py index eeffdd27fe..6e9a432745 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py @@ -103,13 +103,25 @@ class MockNodeFactory(DifyNodeFactory): # Create mock node instance mock_class = self._mock_node_types[node_type] - mock_instance = mock_class( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - mock_config=self.mock_config, - ) + if node_type == NodeType.CODE: + mock_instance = mock_class( + id=node_id, + config=node_config, + graph_init_params=self.graph_init_params, + graph_runtime_state=self.graph_runtime_state, + mock_config=self.mock_config, + code_executor=self._code_executor, + code_providers=self._code_providers, + code_limits=self._code_limits, + ) + else: + mock_instance = mock_class( + id=node_id, + config=node_config, + graph_init_params=self.graph_init_params, + graph_runtime_state=self.graph_runtime_state, + mock_config=self.mock_config, + ) return mock_instance diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index fd94a5e833..5937bbfb39 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -40,12 +40,14 @@ class MockNodeMixin: graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", mock_config: Optional["MockConfig"] = None, + **kwargs: Any, ): super().__init__( id=id, config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, + **kwargs, ) self.mock_config = mock_config diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py 
b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py index 4fb693a5c2..de08cc3497 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py @@ -5,11 +5,24 @@ This module tests the functionality of MockTemplateTransformNode and MockCodeNod to ensure they work correctly with the TableTestRunner. """ +from configs import dify_config from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus +from core.workflow.nodes.code.limits import CodeNodeLimits from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockCodeNode, MockTemplateTransformNode +DEFAULT_CODE_LIMITS = CodeNodeLimits( + max_string_length=dify_config.CODE_MAX_STRING_LENGTH, + max_number=dify_config.CODE_MAX_NUMBER, + min_number=dify_config.CODE_MIN_NUMBER, + max_precision=dify_config.CODE_MAX_PRECISION, + max_depth=dify_config.CODE_MAX_DEPTH, + max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH, + max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH, + max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH, +) + class TestMockTemplateTransformNode: """Test cases for MockTemplateTransformNode.""" @@ -306,6 +319,7 @@ class TestMockCodeNode: graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, mock_config=mock_config, + code_limits=DEFAULT_CODE_LIMITS, ) # Run the node @@ -370,6 +384,7 @@ class TestMockCodeNode: graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, mock_config=mock_config, + code_limits=DEFAULT_CODE_LIMITS, ) # Run the node @@ -438,6 +453,7 @@ class TestMockCodeNode: graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, mock_config=mock_config, + code_limits=DEFAULT_CODE_LIMITS, ) # Run the node diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_stop_event.py b/api/tests/unit_tests/core/workflow/graph_engine/test_stop_event.py new file mode 100644 index 0000000000..ea8d3a977f --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_stop_event.py @@ -0,0 +1,539 @@ +""" +Unit tests for stop_event functionality in GraphEngine. + +Tests the unified stop_event management by GraphEngine and its propagation +to WorkerPool, Worker, Dispatcher, and Nodes. 
+""" + +import threading +import time +from unittest.mock import MagicMock, Mock, patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.graph_init_params import GraphInitParams +from core.workflow.graph import Graph +from core.workflow.graph_engine import GraphEngine +from core.workflow.graph_engine.command_channels import InMemoryChannel +from core.workflow.graph_events import ( + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunStartedEvent, +) +from core.workflow.nodes.answer.answer_node import AnswerNode +from core.workflow.nodes.start.start_node import StartNode +from core.workflow.runtime import GraphRuntimeState, VariablePool +from models.enums import UserFrom + + +class TestStopEventPropagation: + """Test suite for stop_event propagation through GraphEngine components.""" + + def test_graph_engine_creates_stop_event(self): + """Test that GraphEngine creates a stop_event on initialization.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Verify stop_event was created + assert engine._stop_event is not None + assert isinstance(engine._stop_event, threading.Event) + + # Verify it was set in graph_runtime_state + assert runtime_state.stop_event is not None + assert runtime_state.stop_event is engine._stop_event + + def test_stop_event_cleared_on_start(self): + """Test that stop_event is cleared when execution starts.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + mock_graph.root_node.id = "start" # Set proper id + + start_node = StartNode( + id="start", + config={"id": "start", "data": {"title": "start", "variables": []}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + mock_graph.nodes["start"] = start_node + mock_graph.get_outgoing_edges = MagicMock(return_value=[]) + mock_graph.get_incoming_edges = MagicMock(return_value=[]) + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Set the stop_event before running + engine._stop_event.set() + assert engine._stop_event.is_set() + + # Run the engine (should clear the stop_event) + events = list(engine.run()) + + # After running, stop_event should be set again (by _stop_execution) + # But during start it was cleared + assert any(isinstance(e, GraphRunStartedEvent) for e in events) + assert any(isinstance(e, GraphRunSucceededEvent) for e in events) + + def test_stop_event_set_on_stop(self): + """Test that stop_event is set when execution stops.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + mock_graph.root_node.id = "start" # Set proper id + + start_node = StartNode( + 
id="start", + config={"id": "start", "data": {"title": "start", "variables": []}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + mock_graph.nodes["start"] = start_node + mock_graph.get_outgoing_edges = MagicMock(return_value=[]) + mock_graph.get_incoming_edges = MagicMock(return_value=[]) + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Initially not set + assert not engine._stop_event.is_set() + + # Run the engine + list(engine.run()) + + # After execution completes, stop_event should be set + assert engine._stop_event.is_set() + + def test_stop_event_passed_to_worker_pool(self): + """Test that stop_event is passed to WorkerPool.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Verify WorkerPool has the stop_event + assert engine._worker_pool._stop_event is not None + assert engine._worker_pool._stop_event is engine._stop_event + + def test_stop_event_passed_to_dispatcher(self): + """Test that stop_event is passed to Dispatcher.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Verify Dispatcher has the stop_event + assert engine._dispatcher._stop_event is not None + assert engine._dispatcher._stop_event is engine._stop_event + + +class TestNodeStopCheck: + """Test suite for Node._should_stop() functionality.""" + + def test_node_should_stop_checks_runtime_state(self): + """Test that Node._should_stop() checks GraphRuntimeState.stop_event.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + + answer_node = AnswerNode( + id="answer", + config={"id": "answer", "data": {"title": "answer", "answer": "{{#start.result#}}"}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + + # Initially stop_event is not set + assert not answer_node._should_stop() + + # Set the stop_event + runtime_state.stop_event.set() + + # Now _should_stop should return True + assert answer_node._should_stop() + + def test_node_run_checks_stop_event_between_yields(self): + """Test that Node.run() checks stop_event between yielding events.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + + # Create a simple node + answer_node = AnswerNode( + id="answer", + config={"id": "answer", "data": {"title": "answer", "answer": "hello"}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + 
app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + + # Set stop_event BEFORE running the node + runtime_state.stop_event.set() + + # Run the node - should yield start event then detect stop + # The node should check stop_event before processing + assert answer_node._should_stop(), "stop_event should be set" + + # Run and collect events + events = list(answer_node.run()) + + # Since stop_event is set at the start, we should get: + # 1. NodeRunStartedEvent (always yielded first) + # 2. Either NodeRunFailedEvent (if detected early) or NodeRunSucceededEvent (if too fast) + assert len(events) >= 2 + assert isinstance(events[0], NodeRunStartedEvent) + + # Note: AnswerNode is very simple and might complete before stop check + # The important thing is that _should_stop() returns True when stop_event is set + assert answer_node._should_stop() + + +class TestStopEventIntegration: + """Integration tests for stop_event in workflow execution.""" + + def test_simple_workflow_respects_stop_event(self): + """Test that a simple workflow respects stop_event.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + mock_graph.root_node.id = "start" + + # Create start and answer nodes + start_node = StartNode( + id="start", + config={"id": "start", "data": {"title": "start", "variables": []}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + + answer_node = AnswerNode( + id="answer", + config={"id": "answer", "data": {"title": "answer", "answer": "hello"}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + + mock_graph.nodes["start"] = start_node + mock_graph.nodes["answer"] = answer_node + mock_graph.get_outgoing_edges = MagicMock(return_value=[]) + mock_graph.get_incoming_edges = MagicMock(return_value=[]) + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Set stop_event before running + runtime_state.stop_event.set() + + # Run the engine + events = list(engine.run()) + + # Should get started event but not succeeded (due to stop) + assert any(isinstance(e, GraphRunStartedEvent) for e in events) + # The workflow should still complete (start node runs quickly) + # but answer node might be cancelled depending on timing + + def test_stop_event_with_concurrent_nodes(self): + """Test stop_event behavior with multiple concurrent nodes.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + # Create multiple nodes + for i in range(3): + answer_node = AnswerNode( + id=f"answer_{i}", + config={"id": f"answer_{i}", "data": {"title": 
f"answer_{i}", "answer": f"test{i}"}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + mock_graph.nodes[f"answer_{i}"] = answer_node + + mock_graph.get_outgoing_edges = MagicMock(return_value=[]) + mock_graph.get_incoming_edges = MagicMock(return_value=[]) + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # All nodes should share the same stop_event + for node in mock_graph.nodes.values(): + assert node.graph_runtime_state.stop_event is runtime_state.stop_event + assert node.graph_runtime_state.stop_event is engine._stop_event + + +class TestStopEventTimeoutBehavior: + """Test stop_event behavior with join timeouts.""" + + @patch("core.workflow.graph_engine.orchestration.dispatcher.threading.Thread") + def test_dispatcher_uses_shorter_timeout(self, mock_thread_cls: MagicMock): + """Test that Dispatcher uses 2s timeout instead of 10s.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + dispatcher = engine._dispatcher + dispatcher.start() # This will create and start the mocked thread + + mock_thread_instance = mock_thread_cls.return_value + mock_thread_instance.is_alive.return_value = True + + dispatcher.stop() + + mock_thread_instance.join.assert_called_once_with(timeout=2.0) + + @patch("core.workflow.graph_engine.worker_management.worker_pool.Worker") + def test_worker_pool_uses_shorter_timeout(self, mock_worker_cls: MagicMock): + """Test that WorkerPool uses 2s timeout instead of 10s.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + worker_pool = engine._worker_pool + worker_pool.start(initial_count=1) # Start with one worker + + mock_worker_instance = mock_worker_cls.return_value + mock_worker_instance.is_alive.return_value = True + + worker_pool.stop() + + mock_worker_instance.join.assert_called_once_with(timeout=2.0) + + +class TestStopEventResumeBehavior: + """Test stop_event behavior during workflow resume.""" + + def test_stop_event_cleared_on_resume(self): + """Test that stop_event is cleared when resuming a paused workflow.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + mock_graph.root_node.id = "start" # Set proper id + + start_node = StartNode( + id="start", + config={"id": "start", "data": {"title": "start", "variables": []}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + 
invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=runtime_state, + ) + mock_graph.nodes["start"] = start_node + mock_graph.get_outgoing_edges = MagicMock(return_value=[]) + mock_graph.get_incoming_edges = MagicMock(return_value=[]) + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Simulate a previous execution that set stop_event + engine._stop_event.set() + assert engine._stop_event.is_set() + + # Run the engine (should clear stop_event in _start_execution) + events = list(engine.run()) + + # Execution should complete successfully + assert any(isinstance(e, GraphRunStartedEvent) for e in events) + assert any(isinstance(e, GraphRunSucceededEvent) for e in events) + + +class TestWorkerStopBehavior: + """Test Worker behavior with shared stop_event.""" + + def test_worker_uses_shared_stop_event(self): + """Test that Worker uses shared stop_event from GraphEngine.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + + # Get the worker pool and check workers + worker_pool = engine._worker_pool + + # Start the worker pool to create workers + worker_pool.start() + + # Check that at least one worker was created + assert len(worker_pool._workers) > 0 + + # Verify workers use the shared stop_event + for worker in worker_pool._workers: + assert worker._stop_event is engine._stop_event + + # Clean up + worker_pool.stop() + + def test_worker_stop_is_noop(self): + """Test that Worker.stop() is now a no-op.""" + runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + + # Create a mock worker + from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue + from core.workflow.graph_engine.worker import Worker + + ready_queue = InMemoryReadyQueue() + event_queue = MagicMock() + + # Create a proper mock graph with real dict + mock_graph = Mock(spec=Graph) + mock_graph.nodes = {} # Use real dict + + stop_event = threading.Event() + + worker = Worker( + ready_queue=ready_queue, + event_queue=event_queue, + graph=mock_graph, + layers=[], + stop_event=stop_event, + ) + + # Calling stop() should do nothing (no-op) + # and should NOT set the stop_event + worker.stop() + assert not stop_event.is_set() diff --git a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py index 596e72ddd0..2262d25a14 100644 --- a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py @@ -1,3 +1,4 @@ +from configs import dify_config from core.helper.code_executor.code_executor import CodeLanguage from core.variables.types import SegmentType from core.workflow.nodes.code.code_node import CodeNode @@ -7,6 +8,18 @@ from core.workflow.nodes.code.exc import ( DepthLimitError, OutputValidationError, ) +from core.workflow.nodes.code.limits import CodeNodeLimits + +CodeNode._limits = CodeNodeLimits( + max_string_length=dify_config.CODE_MAX_STRING_LENGTH, + max_number=dify_config.CODE_MAX_NUMBER, + min_number=dify_config.CODE_MIN_NUMBER, + max_precision=dify_config.CODE_MAX_PRECISION, + 
max_depth=dify_config.CODE_MAX_DEPTH, + max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH, + max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH, + max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH, +) class TestCodeNodeExceptions: diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py index e8f257bf2f..1e224d56a5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py @@ -78,7 +78,7 @@ class TestFileSaverImpl: file_binary=_PNG_DATA, mimetype=mime_type, ) - mocked_sign_file.assert_called_once_with(mock_tool_file.id, ".png") + mocked_sign_file.assert_called_once_with(tool_file_id=mock_tool_file.id, extension=".png", for_external=True) def test_save_remote_url_request_failed(self, monkeypatch: pytest.MonkeyPatch): _TEST_URL = "https://example.com/image.png" diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py index 1a67d5c3e3..66d6c3c56b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py @@ -5,8 +5,8 @@ from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState -from core.helper.code_executor.code_executor import CodeExecutionError from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus +from core.workflow.nodes.template_transform.template_renderer import TemplateRenderError from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode from models.workflow import WorkflowType @@ -127,7 +127,9 @@ class TestTemplateTransformNode: """Test version class method.""" assert TemplateTransformNode.version() == "1" - @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template") + @patch( + "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template" + ) def test_run_simple_template( self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params ): @@ -145,7 +147,7 @@ class TestTemplateTransformNode: mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector)) # Setup mock executor - mock_execute.return_value = {"result": "Hello Alice, you are 30 years old!"} + mock_execute.return_value = "Hello Alice, you are 30 years old!" 
node = TemplateTransformNode( id="test_node", @@ -162,7 +164,9 @@ class TestTemplateTransformNode: assert result.inputs["name"] == "Alice" assert result.inputs["age"] == 30 - @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template") + @patch( + "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template" + ) def test_run_with_none_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params): """Test _run with None variable values.""" node_data = { @@ -172,7 +176,7 @@ class TestTemplateTransformNode: } mock_graph_runtime_state.variable_pool.get.return_value = None - mock_execute.return_value = {"result": "Value: "} + mock_execute.return_value = "Value: " node = TemplateTransformNode( id="test_node", @@ -187,13 +191,15 @@ class TestTemplateTransformNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.inputs["value"] is None - @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template") + @patch( + "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template" + ) def test_run_with_code_execution_error( self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params ): """Test _run when code execution fails.""" mock_graph_runtime_state.variable_pool.get.return_value = MagicMock() - mock_execute.side_effect = CodeExecutionError("Template syntax error") + mock_execute.side_effect = TemplateRenderError("Template syntax error") node = TemplateTransformNode( id="test_node", @@ -208,14 +214,16 @@ class TestTemplateTransformNode: assert result.status == WorkflowNodeExecutionStatus.FAILED assert "Template syntax error" in result.error - @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template") + @patch( + "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template" + ) @patch("core.workflow.nodes.template_transform.template_transform_node.MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH", 10) def test_run_output_length_exceeds_limit( self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params ): """Test _run when output exceeds maximum length.""" mock_graph_runtime_state.variable_pool.get.return_value = MagicMock() - mock_execute.return_value = {"result": "This is a very long output that exceeds the limit"} + mock_execute.return_value = "This is a very long output that exceeds the limit" node = TemplateTransformNode( id="test_node", @@ -230,7 +238,9 @@ class TestTemplateTransformNode: assert result.status == WorkflowNodeExecutionStatus.FAILED assert "Output length exceeds" in result.error - @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template") + @patch( + "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template" + ) def test_run_with_complex_jinja2_template( self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params ): @@ -257,7 +267,7 @@ class TestTemplateTransformNode: ("sys", "show_total"): mock_show_total, } mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector)) - mock_execute.return_value = {"result": "apple, banana, orange (Total: 3)"} + mock_execute.return_value = "apple, 
banana, orange (Total: 3)" node = TemplateTransformNode( id="test_node", @@ -292,7 +302,9 @@ class TestTemplateTransformNode: assert mapping["node_123.var1"] == ["sys", "input1"] assert mapping["node_123.var2"] == ["sys", "input2"] - @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template") + @patch( + "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template" + ) def test_run_with_empty_variables(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params): """Test _run with no variables (static template).""" node_data = { @@ -301,7 +313,7 @@ class TestTemplateTransformNode: "template": "This is a static message.", } - mock_execute.return_value = {"result": "This is a static message."} + mock_execute.return_value = "This is a static message." node = TemplateTransformNode( id="test_node", @@ -317,7 +329,9 @@ class TestTemplateTransformNode: assert result.outputs["output"] == "This is a static message." assert result.inputs == {} - @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template") + @patch( + "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template" + ) def test_run_with_numeric_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params): """Test _run with numeric variable values.""" node_data = { @@ -339,7 +353,7 @@ class TestTemplateTransformNode: ("sys", "quantity"): mock_quantity, } mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector)) - mock_execute.return_value = {"result": "Total: $31.5"} + mock_execute.return_value = "Total: $31.5" node = TemplateTransformNode( id="test_node", @@ -354,7 +368,9 @@ class TestTemplateTransformNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["output"] == "Total: $31.5" - @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template") + @patch( + "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template" + ) def test_run_with_dict_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params): """Test _run with dictionary variable values.""" node_data = { @@ -367,7 +383,7 @@ class TestTemplateTransformNode: mock_user.to_object.return_value = {"name": "John Doe", "email": "john@example.com"} mock_graph_runtime_state.variable_pool.get.return_value = mock_user - mock_execute.return_value = {"result": "Name: John Doe, Email: john@example.com"} + mock_execute.return_value = "Name: John Doe, Email: john@example.com" node = TemplateTransformNode( id="test_node", @@ -383,7 +399,9 @@ class TestTemplateTransformNode: assert "John Doe" in result.outputs["output"] assert "john@example.com" in result.outputs["output"] - @patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template") + @patch( + "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template" + ) def test_run_with_list_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params): """Test _run with list variable values.""" node_data = { @@ -396,7 +414,7 @@ class TestTemplateTransformNode: mock_tags.to_object.return_value = ["python", "ai", "workflow"] 
mock_graph_runtime_state.variable_pool.get.return_value = mock_tags - mock_execute.return_value = {"result": "Tags: #python #ai #workflow "} + mock_execute.return_value = "Tags: #python #ai #workflow " node = TemplateTransformNode( id="test_node", diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index 09b8191870..06927cddcf 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import sys import types from collections.abc import Generator @@ -21,7 +23,7 @@ if TYPE_CHECKING: # pragma: no cover - imported for type checking only @pytest.fixture -def tool_node(monkeypatch) -> "ToolNode": +def tool_node(monkeypatch) -> ToolNode: module_name = "core.ops.ops_trace_manager" if module_name not in sys.modules: ops_stub = types.ModuleType(module_name) @@ -85,7 +87,7 @@ def _collect_events(generator: Generator) -> tuple[list[Any], LLMUsage]: return events, stop.value -def _run_transform(tool_node: "ToolNode", message: ToolInvokeMessage) -> tuple[list[Any], LLMUsage]: +def _run_transform(tool_node: ToolNode, message: ToolInvokeMessage) -> tuple[list[Any], LLMUsage]: def _identity_transform(messages, *_args, **_kwargs): return messages @@ -103,7 +105,7 @@ def _run_transform(tool_node: "ToolNode", message: ToolInvokeMessage) -> tuple[l return _collect_events(generator) -def test_link_messages_with_file_populate_files_output(tool_node: "ToolNode"): +def test_link_messages_with_file_populate_files_output(tool_node: ToolNode): file_obj = File( tenant_id="tenant-id", type=FileType.DOCUMENT, @@ -139,7 +141,7 @@ def test_link_messages_with_file_populate_files_output(tool_node: "ToolNode"): assert files_segment.value == [file_obj] -def test_plain_link_messages_remain_links(tool_node: "ToolNode"): +def test_plain_link_messages_remain_links(tool_node: ToolNode): message = ToolInvokeMessage( type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text="https://dify.ai"), diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py index c62fc4d8fe..1df75380af 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py @@ -1,14 +1,14 @@ import time import uuid -from unittest import mock from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom from core.variables import ArrayStringVariable, StringVariable -from core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.entities import GraphInitParams from core.workflow.graph import Graph +from core.workflow.graph_events.node import NodeRunSucceededEvent from core.workflow.nodes.node_factory import DifyNodeFactory +from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode from core.workflow.runtime import GraphRuntimeState, VariablePool @@ -86,9 +86,6 @@ def test_overwrite_string_variable(): ) graph = Graph.init(graph_config=graph_config, node_factory=node_factory) - mock_conv_var_updater = 
mock.Mock(spec=ConversationVariableUpdater) - mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater) - node_config = { "id": "node_id", "data": { @@ -104,20 +101,14 @@ def test_overwrite_string_variable(): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, config=node_config, - conv_var_updater_factory=mock_conv_var_updater_factory, ) - list(node.run()) - expected_var = StringVariable( - id=conversation_variable.id, - name=conversation_variable.name, - description=conversation_variable.description, - selector=conversation_variable.selector, - value_type=conversation_variable.value_type, - value=input_variable.value, - ) - mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var) - mock_conv_var_updater.flush.assert_called_once() + events = list(node.run()) + succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) + updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) + assert updated_variables is not None + assert updated_variables[0].name == conversation_variable.name + assert updated_variables[0].new_value == input_variable.value got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None @@ -191,9 +182,6 @@ def test_append_variable_to_array(): ) graph = Graph.init(graph_config=graph_config, node_factory=node_factory) - mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater) - mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater) - node_config = { "id": "node_id", "data": { @@ -209,22 +197,14 @@ def test_append_variable_to_array(): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, config=node_config, - conv_var_updater_factory=mock_conv_var_updater_factory, ) - list(node.run()) - expected_value = list(conversation_variable.value) - expected_value.append(input_variable.value) - expected_var = ArrayStringVariable( - id=conversation_variable.id, - name=conversation_variable.name, - description=conversation_variable.description, - selector=conversation_variable.selector, - value_type=conversation_variable.value_type, - value=expected_value, - ) - mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var) - mock_conv_var_updater.flush.assert_called_once() + events = list(node.run()) + succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) + updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) + assert updated_variables is not None + assert updated_variables[0].name == conversation_variable.name + assert updated_variables[0].new_value == ["the first value", "the second value"] got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None @@ -287,9 +267,6 @@ def test_clear_array(): ) graph = Graph.init(graph_config=graph_config, node_factory=node_factory) - mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater) - mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater) - node_config = { "id": "node_id", "data": { @@ -305,20 +282,14 @@ def test_clear_array(): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, config=node_config, - conv_var_updater_factory=mock_conv_var_updater_factory, ) - list(node.run()) - expected_var = ArrayStringVariable( - id=conversation_variable.id, - 
name=conversation_variable.name, - description=conversation_variable.description, - selector=conversation_variable.selector, - value_type=conversation_variable.value_type, - value=[], - ) - mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var) - mock_conv_var_updater.flush.assert_called_once() + events = list(node.run()) + succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) + updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) + assert updated_variables is not None + assert updated_variables[0].name == conversation_variable.name + assert updated_variables[0].new_value == [] got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py index caa36734ad..353d56fe25 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py @@ -390,3 +390,42 @@ def test_remove_last_from_empty_array(): got = variable_pool.get(["conversation", conversation_variable.name]) assert got is not None assert got.to_object() == [] + + +def test_node_factory_creates_variable_assigner_node(): + graph_config = { + "edges": [], + "nodes": [ + { + "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []}, + "id": "assigner", + }, + ], + } + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + variable_pool = VariablePool( + system_variables=SystemVariable(conversation_id="conversation_id"), + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + node_factory = DifyNodeFactory( + graph_init_params=init_params, + graph_runtime_state=graph_runtime_state, + ) + + node = node_factory.create_node(graph_config["nodes"][0]) + + assert isinstance(node, VariableAssignerNode) diff --git a/api/tests/unit_tests/extensions/test_celery_ssl.py b/api/tests/unit_tests/extensions/test_celery_ssl.py index fc7a090ef9..d3a4d69f07 100644 --- a/api/tests/unit_tests/extensions/test_celery_ssl.py +++ b/api/tests/unit_tests/extensions/test_celery_ssl.py @@ -8,11 +8,12 @@ class TestCelerySSLConfiguration: """Test suite for Celery SSL configuration.""" def test_get_celery_ssl_options_when_ssl_disabled(self): - """Test SSL options when REDIS_USE_SSL is False.""" - mock_config = MagicMock() - mock_config.REDIS_USE_SSL = False + """Test SSL options when BROKER_USE_SSL is False.""" + from configs import DifyConfig - with patch("extensions.ext_celery.dify_config", mock_config): + dify_config = DifyConfig(CELERY_BROKER_URL="redis://localhost:6379/0") + + with patch("extensions.ext_celery.dify_config", dify_config): from extensions.ext_celery import _get_celery_ssl_options result = _get_celery_ssl_options() @@ -21,7 +22,6 @@ class TestCelerySSLConfiguration: def test_get_celery_ssl_options_when_broker_not_redis(self): """Test SSL options when broker is not Redis.""" mock_config = MagicMock() - mock_config.REDIS_USE_SSL = True 
mock_config.CELERY_BROKER_URL = "amqp://localhost:5672" with patch("extensions.ext_celery.dify_config", mock_config): @@ -33,7 +33,6 @@ class TestCelerySSLConfiguration: def test_get_celery_ssl_options_with_cert_none(self): """Test SSL options with CERT_NONE requirement.""" mock_config = MagicMock() - mock_config.REDIS_USE_SSL = True mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" mock_config.REDIS_SSL_CERT_REQS = "CERT_NONE" mock_config.REDIS_SSL_CA_CERTS = None @@ -53,7 +52,6 @@ class TestCelerySSLConfiguration: def test_get_celery_ssl_options_with_cert_required(self): """Test SSL options with CERT_REQUIRED and certificates.""" mock_config = MagicMock() - mock_config.REDIS_USE_SSL = True mock_config.CELERY_BROKER_URL = "rediss://localhost:6380/0" mock_config.REDIS_SSL_CERT_REQS = "CERT_REQUIRED" mock_config.REDIS_SSL_CA_CERTS = "/path/to/ca.crt" @@ -73,7 +71,6 @@ class TestCelerySSLConfiguration: def test_get_celery_ssl_options_with_cert_optional(self): """Test SSL options with CERT_OPTIONAL requirement.""" mock_config = MagicMock() - mock_config.REDIS_USE_SSL = True mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" mock_config.REDIS_SSL_CERT_REQS = "CERT_OPTIONAL" mock_config.REDIS_SSL_CA_CERTS = "/path/to/ca.crt" @@ -91,7 +88,6 @@ class TestCelerySSLConfiguration: def test_get_celery_ssl_options_with_invalid_cert_reqs(self): """Test SSL options with invalid cert requirement defaults to CERT_NONE.""" mock_config = MagicMock() - mock_config.REDIS_USE_SSL = True mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" mock_config.REDIS_SSL_CERT_REQS = "INVALID_VALUE" mock_config.REDIS_SSL_CA_CERTS = None @@ -108,7 +104,6 @@ class TestCelerySSLConfiguration: def test_celery_init_applies_ssl_to_broker_and_backend(self): """Test that SSL options are applied to both broker and backend when using Redis.""" mock_config = MagicMock() - mock_config.REDIS_USE_SSL = True mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" mock_config.CELERY_BACKEND = "redis" mock_config.CELERY_RESULT_BACKEND = "redis://localhost:6379/0" diff --git a/api/tests/unit_tests/fields/test_file_fields.py b/api/tests/unit_tests/fields/test_file_fields.py new file mode 100644 index 0000000000..8be8df16f4 --- /dev/null +++ b/api/tests/unit_tests/fields/test_file_fields.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from datetime import datetime +from types import SimpleNamespace + +from fields.file_fields import FileResponse, FileWithSignedUrl, RemoteFileInfo, UploadConfig + + +def test_file_response_serializes_datetime() -> None: + created_at = datetime(2024, 1, 1, 12, 0, 0) + file_obj = SimpleNamespace( + id="file-1", + name="example.txt", + size=1024, + extension="txt", + mime_type="text/plain", + created_by="user-1", + created_at=created_at, + preview_url="https://preview", + source_url="https://source", + original_url="https://origin", + user_id="user-1", + tenant_id="tenant-1", + conversation_id="conv-1", + file_key="key-1", + ) + + serialized = FileResponse.model_validate(file_obj, from_attributes=True).model_dump(mode="json") + + assert serialized["id"] == "file-1" + assert serialized["created_at"] == int(created_at.timestamp()) + assert serialized["preview_url"] == "https://preview" + assert serialized["source_url"] == "https://source" + assert serialized["original_url"] == "https://origin" + assert serialized["user_id"] == "user-1" + assert serialized["tenant_id"] == "tenant-1" + assert serialized["conversation_id"] == "conv-1" + assert serialized["file_key"] == "key-1" + + 
+def test_file_with_signed_url_builds_payload() -> None: + payload = FileWithSignedUrl( + id="file-2", + name="remote.pdf", + size=2048, + extension="pdf", + url="https://signed", + mime_type="application/pdf", + created_by="user-2", + created_at=datetime(2024, 1, 2, 0, 0, 0), + ) + + dumped = payload.model_dump(mode="json") + + assert dumped["url"] == "https://signed" + assert dumped["created_at"] == int(datetime(2024, 1, 2, 0, 0, 0).timestamp()) + + +def test_remote_file_info_and_upload_config() -> None: + info = RemoteFileInfo(file_type="text/plain", file_length=123) + assert info.model_dump(mode="json") == {"file_type": "text/plain", "file_length": 123} + + config = UploadConfig( + file_size_limit=1, + batch_count_limit=2, + file_upload_limit=3, + image_file_size_limit=4, + video_file_size_limit=5, + audio_file_size_limit=6, + workflow_file_upload_limit=7, + image_file_batch_limit=8, + single_chunk_attachment_limit=9, + attachment_image_file_size_limit=10, + ) + + dumped = config.model_dump(mode="json") + assert dumped["file_upload_limit"] == 3 + assert dumped["attachment_image_file_size_limit"] == 10 diff --git a/api/tests/unit_tests/libs/test_archive_storage.py b/api/tests/unit_tests/libs/test_archive_storage.py new file mode 100644 index 0000000000..697760e33a --- /dev/null +++ b/api/tests/unit_tests/libs/test_archive_storage.py @@ -0,0 +1,272 @@ +import base64 +import hashlib +from datetime import datetime +from unittest.mock import ANY, MagicMock + +import pytest +from botocore.exceptions import ClientError + +from libs import archive_storage as storage_module +from libs.archive_storage import ( + ArchiveStorage, + ArchiveStorageError, + ArchiveStorageNotConfiguredError, +) + +BUCKET_NAME = "archive-bucket" + + +def _configure_storage(monkeypatch, **overrides): + defaults = { + "ARCHIVE_STORAGE_ENABLED": True, + "ARCHIVE_STORAGE_ENDPOINT": "https://storage.example.com", + "ARCHIVE_STORAGE_ARCHIVE_BUCKET": BUCKET_NAME, + "ARCHIVE_STORAGE_ACCESS_KEY": "access", + "ARCHIVE_STORAGE_SECRET_KEY": "secret", + "ARCHIVE_STORAGE_REGION": "auto", + } + defaults.update(overrides) + for key, value in defaults.items(): + monkeypatch.setattr(storage_module.dify_config, key, value, raising=False) + + +def _client_error(code: str) -> ClientError: + return ClientError({"Error": {"Code": code}}, "Operation") + + +def _mock_client(monkeypatch): + client = MagicMock() + client.head_bucket.return_value = None + boto_client = MagicMock(return_value=client) + monkeypatch.setattr(storage_module.boto3, "client", boto_client) + return client, boto_client + + +def test_init_disabled(monkeypatch): + _configure_storage(monkeypatch, ARCHIVE_STORAGE_ENABLED=False) + with pytest.raises(ArchiveStorageNotConfiguredError, match="not enabled"): + ArchiveStorage(bucket=BUCKET_NAME) + + +def test_init_missing_config(monkeypatch): + _configure_storage(monkeypatch, ARCHIVE_STORAGE_ENDPOINT=None) + with pytest.raises(ArchiveStorageNotConfiguredError, match="incomplete"): + ArchiveStorage(bucket=BUCKET_NAME) + + +def test_init_bucket_not_found(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + client.head_bucket.side_effect = _client_error("404") + + with pytest.raises(ArchiveStorageNotConfiguredError, match="does not exist"): + ArchiveStorage(bucket=BUCKET_NAME) + + +def test_init_bucket_access_denied(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + client.head_bucket.side_effect = _client_error("403") + + with 
pytest.raises(ArchiveStorageNotConfiguredError, match="Access denied"): + ArchiveStorage(bucket=BUCKET_NAME) + + +def test_init_bucket_other_error(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + client.head_bucket.side_effect = _client_error("500") + + with pytest.raises(ArchiveStorageError, match="Failed to access archive bucket"): + ArchiveStorage(bucket=BUCKET_NAME) + + +def test_init_sets_client(monkeypatch): + _configure_storage(monkeypatch) + client, boto_client = _mock_client(monkeypatch) + + storage = ArchiveStorage(bucket=BUCKET_NAME) + + boto_client.assert_called_once_with( + "s3", + endpoint_url="https://storage.example.com", + aws_access_key_id="access", + aws_secret_access_key="secret", + region_name="auto", + config=ANY, + ) + assert storage.client is client + assert storage.bucket == BUCKET_NAME + + +def test_put_object_returns_checksum(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + storage = ArchiveStorage(bucket=BUCKET_NAME) + + data = b"hello" + checksum = storage.put_object("key", data) + + expected_md5 = hashlib.md5(data).hexdigest() + expected_content_md5 = base64.b64encode(hashlib.md5(data).digest()).decode() + client.put_object.assert_called_once_with( + Bucket="archive-bucket", + Key="key", + Body=data, + ContentMD5=expected_content_md5, + ) + assert checksum == expected_md5 + + +def test_put_object_raises_on_error(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + storage = ArchiveStorage(bucket=BUCKET_NAME) + client.put_object.side_effect = _client_error("500") + + with pytest.raises(ArchiveStorageError, match="Failed to upload object"): + storage.put_object("key", b"data") + + +def test_get_object_returns_bytes(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + body = MagicMock() + body.read.return_value = b"payload" + client.get_object.return_value = {"Body": body} + storage = ArchiveStorage(bucket=BUCKET_NAME) + + assert storage.get_object("key") == b"payload" + + +def test_get_object_missing(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + client.get_object.side_effect = _client_error("NoSuchKey") + storage = ArchiveStorage(bucket=BUCKET_NAME) + + with pytest.raises(FileNotFoundError, match="Archive object not found"): + storage.get_object("missing") + + +def test_get_object_stream(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + body = MagicMock() + body.iter_chunks.return_value = [b"a", b"b"] + client.get_object.return_value = {"Body": body} + storage = ArchiveStorage(bucket=BUCKET_NAME) + + assert list(storage.get_object_stream("key")) == [b"a", b"b"] + + +def test_get_object_stream_missing(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + client.get_object.side_effect = _client_error("NoSuchKey") + storage = ArchiveStorage(bucket=BUCKET_NAME) + + with pytest.raises(FileNotFoundError, match="Archive object not found"): + list(storage.get_object_stream("missing")) + + +def test_object_exists(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + storage = ArchiveStorage(bucket=BUCKET_NAME) + + assert storage.object_exists("key") is True + client.head_object.side_effect = _client_error("404") + assert storage.object_exists("missing") is False + + +def test_delete_object_error(monkeypatch): + _configure_storage(monkeypatch) + client, _ = 
_mock_client(monkeypatch) + client.delete_object.side_effect = _client_error("500") + storage = ArchiveStorage(bucket=BUCKET_NAME) + + with pytest.raises(ArchiveStorageError, match="Failed to delete object"): + storage.delete_object("key") + + +def test_list_objects(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + paginator = MagicMock() + paginator.paginate.return_value = [ + {"Contents": [{"Key": "a"}, {"Key": "b"}]}, + {"Contents": [{"Key": "c"}]}, + ] + client.get_paginator.return_value = paginator + storage = ArchiveStorage(bucket=BUCKET_NAME) + + assert storage.list_objects("prefix") == ["a", "b", "c"] + paginator.paginate.assert_called_once_with(Bucket="archive-bucket", Prefix="prefix") + + +def test_list_objects_error(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + paginator = MagicMock() + paginator.paginate.side_effect = _client_error("500") + client.get_paginator.return_value = paginator + storage = ArchiveStorage(bucket=BUCKET_NAME) + + with pytest.raises(ArchiveStorageError, match="Failed to list objects"): + storage.list_objects("prefix") + + +def test_generate_presigned_url(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + client.generate_presigned_url.return_value = "http://signed-url" + storage = ArchiveStorage(bucket=BUCKET_NAME) + + url = storage.generate_presigned_url("key", expires_in=123) + + client.generate_presigned_url.assert_called_once_with( + ClientMethod="get_object", + Params={"Bucket": "archive-bucket", "Key": "key"}, + ExpiresIn=123, + ) + assert url == "http://signed-url" + + +def test_generate_presigned_url_error(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + client.generate_presigned_url.side_effect = _client_error("500") + storage = ArchiveStorage(bucket=BUCKET_NAME) + + with pytest.raises(ArchiveStorageError, match="Failed to generate pre-signed URL"): + storage.generate_presigned_url("key") + + +def test_serialization_roundtrip(): + records = [ + { + "id": "1", + "created_at": datetime(2024, 1, 1, 12, 0, 0), + "payload": {"nested": "value"}, + "items": [{"name": "a"}], + }, + {"id": "2", "value": 123}, + ] + + data = ArchiveStorage.serialize_to_jsonl_gz(records) + decoded = ArchiveStorage.deserialize_from_jsonl_gz(data) + + assert decoded[0]["id"] == "1" + assert decoded[0]["payload"]["nested"] == "value" + assert decoded[0]["items"][0]["name"] == "a" + assert "2024-01-01T12:00:00" in decoded[0]["created_at"] + assert decoded[1]["value"] == 123 + + +def test_content_md5_matches_checksum(): + data = b"checksum" + expected = base64.b64encode(hashlib.md5(data).digest()).decode() + + assert ArchiveStorage._content_md5(data) == expected + assert ArchiveStorage.compute_checksum(data) == hashlib.md5(data).hexdigest() diff --git a/api/tests/unit_tests/libs/test_external_api.py b/api/tests/unit_tests/libs/test_external_api.py index 9aa157a651..5135970bcc 100644 --- a/api/tests/unit_tests/libs/test_external_api.py +++ b/api/tests/unit_tests/libs/test_external_api.py @@ -99,29 +99,20 @@ def test_external_api_json_message_and_bad_request_rewrite(): assert res.get_json()["message"] == "Invalid JSON payload received or JSON payload is empty." 
-def test_external_api_param_mapping_and_quota_and_exc_info_none(): - # Force exc_info() to return (None,None,None) only during request - import libs.external_api as ext +def test_external_api_param_mapping_and_quota(): + app = _create_api_app() + client = app.test_client() - orig_exc_info = ext.sys.exc_info - try: - ext.sys.exc_info = lambda: (None, None, None) + # Param errors mapping payload path + res = client.get("/api/param-errors") + assert res.status_code == 400 + data = res.get_json() + assert data["code"] == "invalid_param" + assert data["params"] == "field" - app = _create_api_app() - client = app.test_client() - - # Param errors mapping payload path - res = client.get("/api/param-errors") - assert res.status_code == 400 - data = res.get_json() - assert data["code"] == "invalid_param" - assert data["params"] == "field" - - # Quota path — depending on Flask-RESTX internals it may be handled - res = client.get("/api/quota") - assert res.status_code in (400, 429) - finally: - ext.sys.exc_info = orig_exc_info # type: ignore[assignment] + # Quota path — depending on Flask-RESTX internals it may be handled + res = client.get("/api/quota") + assert res.status_code in (400, 429) def test_unauthorized_and_force_logout_clears_cookies(): diff --git a/api/tests/unit_tests/libs/test_helper.py b/api/tests/unit_tests/libs/test_helper.py index 85789bfa7e..de74eff82f 100644 --- a/api/tests/unit_tests/libs/test_helper.py +++ b/api/tests/unit_tests/libs/test_helper.py @@ -1,6 +1,6 @@ import pytest -from libs.helper import extract_tenant_id +from libs.helper import escape_like_pattern, extract_tenant_id from models.account import Account from models.model import EndUser @@ -63,3 +63,51 @@ class TestExtractTenantId: with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"): extract_tenant_id(dict_user) + + +class TestEscapeLikePattern: + """Test cases for the escape_like_pattern utility function.""" + + def test_escape_percent_character(self): + """Test escaping percent character.""" + result = escape_like_pattern("50% discount") + assert result == "50\\% discount" + + def test_escape_underscore_character(self): + """Test escaping underscore character.""" + result = escape_like_pattern("test_data") + assert result == "test\\_data" + + def test_escape_backslash_character(self): + """Test escaping backslash character.""" + result = escape_like_pattern("path\\to\\file") + assert result == "path\\\\to\\\\file" + + def test_escape_combined_special_characters(self): + """Test escaping multiple special characters together.""" + result = escape_like_pattern("file_50%\\path") + assert result == "file\\_50\\%\\\\path" + + def test_escape_empty_string(self): + """Test escaping empty string returns empty string.""" + result = escape_like_pattern("") + assert result == "" + + def test_escape_none_handling(self): + """Test escaping None returns None (falsy check handles it).""" + # The function checks `if not pattern`, so None is falsy and returns as-is + result = escape_like_pattern(None) + assert result is None + + def test_escape_normal_string_no_change(self): + """Test that normal strings without special characters are unchanged.""" + result = escape_like_pattern("normal text") + assert result == "normal text" + + def test_escape_order_matters(self): + """Test that backslash is escaped first to prevent double escaping.""" + # If we escape % first, then escape \, we might get wrong results + # This test ensures the order is correct: \ first, then % and _ + result = 
escape_like_pattern("test\\%_value") + # Should be: test\\\%\_value + assert result == "test\\\\\\%\\_value" diff --git a/api/tests/unit_tests/models/test_app_models.py b/api/tests/unit_tests/models/test_app_models.py index e35788660d..8be2eea121 100644 --- a/api/tests/unit_tests/models/test_app_models.py +++ b/api/tests/unit_tests/models/test_app_models.py @@ -114,7 +114,7 @@ class TestAppModelValidation: def test_icon_type_validation(self): """Test icon type enum values.""" # Assert - assert {t.value for t in IconType} == {"image", "emoji"} + assert {t.value for t in IconType} == {"image", "emoji", "link"} def test_app_desc_or_prompt_with_description(self): """Test desc_or_prompt property when description exists.""" diff --git a/api/tests/unit_tests/services/test_model_provider_service_sanitization.py b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py index 9a107da1c7..e2360b116d 100644 --- a/api/tests/unit_tests/services/test_model_provider_service_sanitization.py +++ b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py @@ -27,7 +27,6 @@ def service_with_fake_configurations(): description=None, icon_small=None, icon_small_dark=None, - icon_large=None, background=None, help=None, supported_model_types=[ModelType.LLM], diff --git a/api/tests/unit_tests/utils/test_text_processing.py b/api/tests/unit_tests/utils/test_text_processing.py index 11e017464a..bf61162a66 100644 --- a/api/tests/unit_tests/utils/test_text_processing.py +++ b/api/tests/unit_tests/utils/test_text_processing.py @@ -15,6 +15,11 @@ from core.tools.utils.text_processing_utils import remove_leading_symbols ("", ""), (" ", " "), ("【测试】", "【测试】"), + # Markdown link preservation - should be preserved if text starts with a markdown link + ("[Google](https://google.com) is a search engine", "[Google](https://google.com) is a search engine"), + ("[Example](http://example.com) some text", "[Example](http://example.com) some text"), + # Leading symbols before markdown link are removed, including the opening bracket [ + ("@[Test](https://example.com)", "Test](https://example.com)"), ], ) def test_remove_leading_symbols(input_text, expected_output): diff --git a/api/uv.lock b/api/uv.lock index 4ccd229eec..8e60fad3a7 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1953,14 +1953,14 @@ wheels = [ [[package]] name = "fickling" -version = "0.1.5" +version = "0.1.6" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "stdlib-list" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/41/94/0d0ce455952c036cfee235637f786c1d1d07d1b90f6a4dfb50e0eff929d6/fickling-0.1.5.tar.gz", hash = "sha256:92f9b49e717fa8dbc198b4b7b685587adb652d85aa9ede8131b3e44494efca05", size = 282462, upload-time = "2025-11-18T05:04:30.748Z" } +sdist = { url = "https://files.pythonhosted.org/packages/07/ab/7571453f9365c17c047b5a7b7e82692a7f6be51203f295030886758fd57a/fickling-0.1.6.tar.gz", hash = "sha256:03cb5d7bd09f9169c7583d2079fad4b3b88b25f865ed0049172e5cb68582311d", size = 284033, upload-time = "2025-12-15T18:14:58.721Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/bf/a7/d25912b2e3a5b0a37e6f460050bbc396042b5906a6563a1962c484abc3c6/fickling-0.1.5-py3-none-any.whl", hash = "sha256:6aed7270bfa276e188b0abe043a27b3a042129d28ec1fa6ff389bdcc5ad178bb", size = 46240, upload-time = "2025-11-18T05:04:29.048Z" }, + { url = "https://files.pythonhosted.org/packages/76/99/cc04258dda421bc612cdfe4be8c253f45b922f1c7f268b5a0b9962d9cd12/fickling-0.1.6-py3-none-any.whl", hash = 
"sha256:465d0069548bfc731bdd75a583cb4cf5a4b2666739c0f76287807d724b147ed3", size = 47922, upload-time = "2025-12-15T18:14:57.526Z" }, ] [[package]] @@ -2955,14 +2955,14 @@ wheels = [ [[package]] name = "intersystems-irispython" -version = "5.3.0" +version = "5.3.1" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/5d/56/16d93576b50408d97a5cbbd055d8da024d585e96a360e2adc95b41ae6284/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-macosx_10_9_universal2.whl", hash = "sha256:59d3176a35867a55b1ab69a6b5c75438b460291bccb254c2d2f4173be08b6e55", size = 6594480, upload-time = "2025-10-09T20:47:27.629Z" }, - { url = "https://files.pythonhosted.org/packages/99/bc/19e144ee805ea6ee0df6342a711e722c84347c05a75b3bf040c5fbe19982/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:56bccefd1997c25f9f9f6c4086214c18d4fdaac0a93319d4b21dd9a6c59c9e51", size = 14779928, upload-time = "2025-10-09T20:47:30.564Z" }, - { url = "https://files.pythonhosted.org/packages/e6/fb/59ba563a80b39e9450b4627b5696019aa831dce27dacc3831b8c1e669102/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e160adc0785c55bb64e4264b8e99075691a15b0afa5d8d529f1b4bac7e57b81", size = 14422035, upload-time = "2025-10-09T20:47:32.552Z" }, - { url = "https://files.pythonhosted.org/packages/c1/68/ade8ad43f0ed1e5fba60e1710fa5ddeb01285f031e465e8c006329072e63/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-win32.whl", hash = "sha256:820f2c5729119e5173a5bf6d6ac2a41275c4f1ffba6af6c59ea313ecd8f499cc", size = 2824316, upload-time = "2025-10-09T20:47:28.998Z" }, - { url = "https://files.pythonhosted.org/packages/f4/03/cd45cb94e42c01dc525efebf3c562543a18ee55b67fde4022665ca672351/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-win_amd64.whl", hash = "sha256:fc07ec24bc50b6f01573221cd7d86f2937549effe31c24af8db118e0131e340c", size = 3463297, upload-time = "2025-10-09T20:47:34.636Z" }, + { url = "https://files.pythonhosted.org/packages/33/5b/8eac672a6ef26bef6ef79a7c9557096167b50c4d3577d558ae6999c195fe/intersystems_irispython-5.3.1-cp38.cp39.cp310.cp311.cp312.cp313.cp314-cp38.cp39.cp310.cp311.cp312.cp313.cp314-macosx_10_9_universal2.whl", hash = "sha256:634c9b4ec620837d830ff49543aeb2797a1ce8d8570a0e868398b85330dfcc4d", size = 6736686, upload-time = "2025-12-19T16:24:57.734Z" }, + { url = "https://files.pythonhosted.org/packages/ba/17/bab3e525ffb6711355f7feea18c1b7dced9c2484cecbcdd83f74550398c0/intersystems_irispython-5.3.1-cp38.cp39.cp310.cp311.cp312.cp313.cp314-cp38.cp39.cp310.cp311.cp312.cp313.cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cf912f30f85e2a42f2c2ea77fbeb98a24154d5ea7428a50382786a684ec4f583", size = 16005259, upload-time = "2025-12-19T16:25:05.578Z" }, + { url = "https://files.pythonhosted.org/packages/39/59/9bb79d9e32e3e55fc9aed8071a797b4497924cbc6457cea9255bb09320b7/intersystems_irispython-5.3.1-cp38.cp39.cp310.cp311.cp312.cp313.cp314-cp38.cp39.cp310.cp311.cp312.cp313.cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:be5659a6bb57593910f2a2417eddb9f5dc2f93a337ead6ddca778f557b8a359a", size = 15638040, upload-time = "2025-12-19T16:24:54.429Z" }, + { url = 
"https://files.pythonhosted.org/packages/cf/47/654ccf9c5cca4f5491f070888544165c9e2a6a485e320ea703e4e38d2358/intersystems_irispython-5.3.1-cp38.cp39.cp310.cp311.cp312.cp313.cp314-cp38.cp39.cp310.cp311.cp312.cp313.cp314-win32.whl", hash = "sha256:583e4f17088c1e0530f32efda1c0ccb02993cbc22035bc8b4c71d8693b04ee7e", size = 2879644, upload-time = "2025-12-19T16:24:59.945Z" }, + { url = "https://files.pythonhosted.org/packages/68/95/19cc13d09f1b4120bd41b1434509052e1d02afd27f2679266d7ad9cc1750/intersystems_irispython-5.3.1-cp38.cp39.cp310.cp311.cp312.cp313.cp314-cp38.cp39.cp310.cp311.cp312.cp313.cp314-win_amd64.whl", hash = "sha256:1d5d40450a0cdeec2a1f48d12d946a8a8ffc7c128576fcae7d58e66e3a127eae", size = 3522092, upload-time = "2025-12-19T16:25:01.834Z" }, ] [[package]] diff --git a/docker/.env.example b/docker/.env.example index 0e09d6869d..09ee1060e2 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -58,8 +58,8 @@ FILES_URL= INTERNAL_FILES_URL= # Ensure UTF-8 encoding -LANG=en_US.UTF-8 -LC_ALL=en_US.UTF-8 +LANG=C.UTF-8 +LC_ALL=C.UTF-8 PYTHONIOENCODING=utf-8 # ------------------------------ @@ -69,6 +69,8 @@ PYTHONIOENCODING=utf-8 # The log level for the application. # Supported values are `DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL` LOG_LEVEL=INFO +# Log output format: text or json +LOG_OUTPUT_FORMAT=text # Log file path LOG_FILE=/app/logs/server.log # Log file max size, the unit is MB @@ -231,7 +233,7 @@ NEXT_PUBLIC_ENABLE_SINGLE_DOLLAR_LATEX=false # You can adjust the database configuration according to your needs. # ------------------------------ -# Database type, supported values are `postgresql` and `mysql` +# Database type, supported values are `postgresql`, `mysql`, `oceanbase`, `seekdb` DB_TYPE=postgresql # For MySQL, only `root` user is supported for now DB_USERNAME=postgres @@ -447,6 +449,15 @@ S3_SECRET_KEY= # If set to false, the access key and secret key must be provided. S3_USE_AWS_MANAGED_IAM=false +# Workflow run and Conversation archive storage (S3-compatible) +ARCHIVE_STORAGE_ENABLED=false +ARCHIVE_STORAGE_ENDPOINT= +ARCHIVE_STORAGE_ARCHIVE_BUCKET= +ARCHIVE_STORAGE_EXPORT_BUCKET= +ARCHIVE_STORAGE_ACCESS_KEY= +ARCHIVE_STORAGE_SECRET_KEY= +ARCHIVE_STORAGE_REGION=auto + # Azure Blob Configuration # AZURE_BLOB_ACCOUNT_NAME=difyai @@ -522,7 +533,7 @@ SUPABASE_URL=your-server-url # ------------------------------ # The type of vector store to use. -# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `clickzetta`, `alibabacloud_mysql`, `iris`. +# Supported values are `weaviate`, `oceanbase`, `seekdb`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`, `vastbase`, `tidb`, `tidb_on_qdrant`, `baidu`, `lindorm`, `huawei_cloud`, `upstash`, `matrixone`, `clickzetta`, `alibabacloud_mysql`, `iris`. 
VECTOR_STORE=weaviate # Prefix used to create collection name in vector database VECTOR_INDEX_NAME_PREFIX=Vector_index @@ -533,9 +544,9 @@ WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih WEAVIATE_GRPC_ENDPOINT=grpc://weaviate:50051 WEAVIATE_TOKENIZATION=word -# For OceanBase metadata database configuration, available when `DB_TYPE` is `mysql` and `COMPOSE_PROFILES` includes `oceanbase`. +# For OceanBase metadata database configuration, available when `DB_TYPE` is `oceanbase`. # For OceanBase vector database configuration, available when `VECTOR_STORE` is `oceanbase` -# If you want to use OceanBase as both vector database and metadata database, you need to set `DB_TYPE` to `mysql`, `COMPOSE_PROFILES` is `oceanbase`, and set Database Configuration is the same as the vector database. +# If you want to use OceanBase as both vector database and metadata database, you need to set both `DB_TYPE` and `VECTOR_STORE` to `oceanbase`, and use the same database configuration as the vector database. # seekdb is the lite version of OceanBase and shares the connection configuration with OceanBase. OCEANBASE_VECTOR_HOST=oceanbase OCEANBASE_VECTOR_PORT=2881 @@ -1066,6 +1077,10 @@ LOGSTORE_DUAL_WRITE_ENABLED=false # Enable dual-read fallback to SQL database when LogStore returns no results (default: true) # Useful for migration scenarios where historical data exists only in SQL database LOGSTORE_DUAL_READ_ENABLED=true +# Control flag for whether to write the `graph` field to LogStore. +# If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field; +# otherwise write an empty {} instead. Defaults to writing the `graph` field. +LOGSTORE_ENABLE_PUT_GRAPH_FIELD=true # HTTP request node in workflow configuration HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 3c88cddf8c..709aff23df 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -475,7 +475,8 @@ services: OB_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai} OB_SERVER_IP: 127.0.0.1 MODE: mini - LANG: en_US.UTF-8 + LANG: C.UTF-8 + LC_ALL: C.UTF-8 ports: - "${OCEANBASE_VECTOR_PORT:-2881}:2881" healthcheck: diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index dba61d1816..81c34fc6a2 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -129,6 +129,7 @@ services: - ./middleware.env environment: # Use the shared environment variables. 
+ LOG_OUTPUT_FORMAT: ${LOG_OUTPUT_FORMAT:-text} DB_DATABASE: ${DB_PLUGIN_DATABASE:-dify_plugin} REDIS_HOST: ${REDIS_HOST:-redis} REDIS_PORT: ${REDIS_PORT:-6379} diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 1c8d8d03e3..712de84c62 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -13,10 +13,11 @@ x-shared-env: &shared-api-worker-env APP_WEB_URL: ${APP_WEB_URL:-} FILES_URL: ${FILES_URL:-} INTERNAL_FILES_URL: ${INTERNAL_FILES_URL:-} - LANG: ${LANG:-en_US.UTF-8} - LC_ALL: ${LC_ALL:-en_US.UTF-8} + LANG: ${LANG:-C.UTF-8} + LC_ALL: ${LC_ALL:-C.UTF-8} PYTHONIOENCODING: ${PYTHONIOENCODING:-utf-8} LOG_LEVEL: ${LOG_LEVEL:-INFO} + LOG_OUTPUT_FORMAT: ${LOG_OUTPUT_FORMAT:-text} LOG_FILE: ${LOG_FILE:-/app/logs/server.log} LOG_FILE_MAX_SIZE: ${LOG_FILE_MAX_SIZE:-20} LOG_FILE_BACKUP_COUNT: ${LOG_FILE_BACKUP_COUNT:-5} @@ -122,6 +123,13 @@ x-shared-env: &shared-api-worker-env S3_ACCESS_KEY: ${S3_ACCESS_KEY:-} S3_SECRET_KEY: ${S3_SECRET_KEY:-} S3_USE_AWS_MANAGED_IAM: ${S3_USE_AWS_MANAGED_IAM:-false} + ARCHIVE_STORAGE_ENABLED: ${ARCHIVE_STORAGE_ENABLED:-false} + ARCHIVE_STORAGE_ENDPOINT: ${ARCHIVE_STORAGE_ENDPOINT:-} + ARCHIVE_STORAGE_ARCHIVE_BUCKET: ${ARCHIVE_STORAGE_ARCHIVE_BUCKET:-} + ARCHIVE_STORAGE_EXPORT_BUCKET: ${ARCHIVE_STORAGE_EXPORT_BUCKET:-} + ARCHIVE_STORAGE_ACCESS_KEY: ${ARCHIVE_STORAGE_ACCESS_KEY:-} + ARCHIVE_STORAGE_SECRET_KEY: ${ARCHIVE_STORAGE_SECRET_KEY:-} + ARCHIVE_STORAGE_REGION: ${ARCHIVE_STORAGE_REGION:-auto} AZURE_BLOB_ACCOUNT_NAME: ${AZURE_BLOB_ACCOUNT_NAME:-difyai} AZURE_BLOB_ACCOUNT_KEY: ${AZURE_BLOB_ACCOUNT_KEY:-difyai} AZURE_BLOB_CONTAINER_NAME: ${AZURE_BLOB_CONTAINER_NAME:-difyai-container} @@ -467,6 +475,7 @@ x-shared-env: &shared-api-worker-env ALIYUN_SLS_LOGSTORE_TTL: ${ALIYUN_SLS_LOGSTORE_TTL:-365} LOGSTORE_DUAL_WRITE_ENABLED: ${LOGSTORE_DUAL_WRITE_ENABLED:-false} LOGSTORE_DUAL_READ_ENABLED: ${LOGSTORE_DUAL_READ_ENABLED:-true} + LOGSTORE_ENABLE_PUT_GRAPH_FIELD: ${LOGSTORE_ENABLE_PUT_GRAPH_FIELD:-true} HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760} HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576} HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True} @@ -1148,7 +1157,8 @@ services: OB_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai} OB_SERVER_IP: 127.0.0.1 MODE: mini - LANG: en_US.UTF-8 + LANG: C.UTF-8 + LC_ALL: C.UTF-8 ports: - "${OCEANBASE_VECTOR_PORT:-2881}:2881" healthcheck: diff --git a/docker/middleware.env.example b/docker/middleware.env.example index f7e0252a6f..c88dbe5511 100644 --- a/docker/middleware.env.example +++ b/docker/middleware.env.example @@ -233,4 +233,8 @@ ALIYUN_SLS_LOGSTORE_TTL=365 LOGSTORE_DUAL_WRITE_ENABLED=true # Enable dual-read fallback to SQL database when LogStore returns no results (default: true) # Useful for migration scenarios where historical data exists only in SQL database -LOGSTORE_DUAL_READ_ENABLED=true \ No newline at end of file +LOGSTORE_DUAL_READ_ENABLED=true +# Control flag for whether to write the `graph` field to LogStore. +# If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field; +# otherwise write an empty {} instead. Defaults to writing the `graph` field. 
+LOGSTORE_ENABLE_PUT_GRAPH_FIELD=true \ No newline at end of file diff --git a/docker/ssrf_proxy/squid.conf.template b/docker/ssrf_proxy/squid.conf.template index 1775a1fff9..256e669c8d 100644 --- a/docker/ssrf_proxy/squid.conf.template +++ b/docker/ssrf_proxy/squid.conf.template @@ -54,3 +54,52 @@ http_access allow src_all # Unless the option's size is increased, an error will occur when uploading more than two files. client_request_buffer_max_size 100 MB + +################################## Performance & Concurrency ############################### +# Increase file descriptor limit for high concurrency +max_filedescriptors 65536 + +# Timeout configurations for image requests +connect_timeout 30 seconds +request_timeout 2 minutes +read_timeout 2 minutes +client_lifetime 5 minutes +shutdown_lifetime 30 seconds + +# Persistent connections - improve performance for multiple requests +server_persistent_connections on +client_persistent_connections on +persistent_request_timeout 30 seconds +pconn_timeout 1 minute + +# Connection pool and concurrency limits +client_db on +server_idle_pconn_timeout 2 minutes +client_idle_pconn_timeout 2 minutes + +# Quick abort settings - don't abort requests that are mostly done +quick_abort_min 16 KB +quick_abort_max 16 MB +quick_abort_pct 95 + +# Memory and cache optimization +memory_cache_mode disk +cache_mem 256 MB +maximum_object_size_in_memory 512 KB + +# DNS resolver settings for better performance +dns_timeout 30 seconds +dns_retransmit_interval 5 seconds +# By default, Squid uses the system's configured DNS resolvers. +# If you need to override them, set dns_nameservers to appropriate servers +# for your environment (for example, internal/corporate DNS). The following +# is an example using public DNS and SHOULD be customized before use: +# dns_nameservers 8.8.8.8 8.8.4.4 + +# Logging format for better debugging +logformat dify_log %ts.%03tu %6tr %>a %Ss/%03>Hs % ({ + getI18n: () => ({ + t: (key: string) => key, + language: 'en', + }), +})) + // Mock API functions vi.mock('@/service/base', () => ({ postMarketplace: vi.fn(), diff --git a/web/__tests__/workflow-parallel-limit.test.tsx b/web/__tests__/workflow-parallel-limit.test.tsx index 18657f4bd2..ba3840ac3e 100644 --- a/web/__tests__/workflow-parallel-limit.test.tsx +++ b/web/__tests__/workflow-parallel-limit.test.tsx @@ -64,7 +64,6 @@ vi.mock('i18next', () => ({ // Mock the useConfig hook vi.mock('@/app/components/workflow/nodes/iteration/use-config', () => ({ - __esModule: true, default: () => ({ inputs: { is_parallel: true, diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/date-picker.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/date-picker.tsx index 004f83afc5..368c3dcfc3 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/date-picker.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/date-picker.tsx @@ -4,11 +4,11 @@ import type { FC } from 'react' import type { TriggerProps } from '@/app/components/base/date-and-time-picker/types' import { RiCalendarLine } from '@remixicon/react' import dayjs from 'dayjs' -import { noop } from 'es-toolkit/compat' +import { noop } from 'es-toolkit/function' import * as React from 'react' import { useCallback } from 'react' import Picker from '@/app/components/base/date-and-time-picker/date-picker' -import { useI18N } from '@/context/i18n' +import { useLocale } from '@/context/i18n' import { cn } from 
'@/utils/classnames' import { formatToLocalTime } from '@/utils/format' @@ -26,7 +26,7 @@ const DatePicker: FC = ({ onStartChange, onEndChange, }) => { - const { locale } = useI18N() + const locale = useLocale() const renderDate = useCallback(({ value, handleClickTrigger, isOpen }: TriggerProps) => { return ( diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/index.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/index.tsx index 10209de97b..53794ad8db 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/index.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/index.tsx @@ -7,7 +7,7 @@ import dayjs from 'dayjs' import * as React from 'react' import { useCallback, useState } from 'react' import { HourglassShape } from '@/app/components/base/icons/src/vender/other' -import { useI18N } from '@/context/i18n' +import { useLocale } from '@/context/i18n' import { formatToLocalTime } from '@/utils/format' import DatePicker from './date-picker' import RangeSelector from './range-selector' @@ -27,7 +27,7 @@ const TimeRangePicker: FC = ({ onSelect, queryDateFormat, }) => { - const { locale } = useI18N() + const locale = useLocale() const [isCustomRange, setIsCustomRange] = useState(false) const [start, setStart] = useState(today) diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx index 8080b565cd..1d65e4de53 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx @@ -1,11 +1,8 @@ -/* eslint-disable dify-i18n/require-ns-option */ -import * as React from 'react' +import { useTranslation } from '#i18n' import Form from '@/app/components/datasets/settings/form' -import { getLocaleOnServer, getTranslation } from '@/i18n-config/server' -const Settings = async () => { - const locale = await getLocaleOnServer() - const { t } = await getTranslation(locale, 'dataset-settings') +const Settings = () => { + const { t } = useTranslation('datasetSettings') return (
diff --git a/web/app/(commonLayout)/plugins/page.tsx b/web/app/(commonLayout)/plugins/page.tsx
index 2df9cf23c4..81bda3a8a3 100644
--- a/web/app/(commonLayout)/plugins/page.tsx
+++ b/web/app/(commonLayout)/plugins/page.tsx
@@ -1,14 +1,12 @@
 import Marketplace from '@/app/components/plugins/marketplace'
 import PluginPage from '@/app/components/plugins/plugin-page'
 import PluginsPanel from '@/app/components/plugins/plugin-page/plugins-panel'
-import { getLocaleOnServer } from '@/i18n-config/server'
 const PluginList = async () => {
-  const locale = await getLocaleOnServer()
   return (
     }
-      marketplace={}
+      marketplace={}
     />
   )
 }
diff --git a/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx b/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx
index ac15f1df6d..fbf45259e5 100644
--- a/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx
+++ b/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx
@@ -3,12 +3,12 @@
 import { RiArrowLeftLine, RiMailSendFill } from '@remixicon/react'
 import { useRouter, useSearchParams } from 'next/navigation'
 import { useState } from 'react'
 import { useTranslation } from 'react-i18next'
-import { useContext } from 'use-context-selector'
 import Button from '@/app/components/base/button'
 import Input from '@/app/components/base/input'
 import Toast from '@/app/components/base/toast'
 import Countdown from '@/app/components/signin/countdown'
-import I18NContext from '@/context/i18n'
+
+import { useLocale } from '@/context/i18n'
 import { sendWebAppResetPasswordCode, verifyWebAppResetPasswordCode } from '@/service/common'
 export default function CheckCode() {
@@ -19,7 +19,7 @@ export default function CheckCode() {
   const token = decodeURIComponent(searchParams.get('token') as string)
   const [code, setVerifyCode] = useState('')
   const [loading, setIsLoading] = useState(false)
-  const { locale } = useContext(I18NContext)
+  const locale = useLocale()
   const verify = async () => {
     try {
diff --git a/web/app/(shareLayout)/webapp-reset-password/page.tsx b/web/app/(shareLayout)/webapp-reset-password/page.tsx
index 6acd8d08f4..9b9a853cdd 100644
--- a/web/app/(shareLayout)/webapp-reset-password/page.tsx
+++ b/web/app/(shareLayout)/webapp-reset-password/page.tsx
@@ -1,17 +1,17 @@
 'use client'
 import { RiArrowLeftLine, RiLockPasswordLine } from '@remixicon/react'
-import { noop } from 'es-toolkit/compat'
+import { noop } from 'es-toolkit/function'
 import Link from 'next/link'
 import { useRouter, useSearchParams } from 'next/navigation'
 import { useState } from 'react'
 import { useTranslation } from 'react-i18next'
-import { useContext } from 'use-context-selector'
 import Button from '@/app/components/base/button'
 import Input from '@/app/components/base/input'
 import Toast from '@/app/components/base/toast'
 import { COUNT_DOWN_KEY, COUNT_DOWN_TIME_MS } from '@/app/components/signin/countdown'
 import { emailRegex } from '@/config'
-import I18NContext from '@/context/i18n'
+
+import { useLocale } from '@/context/i18n'
 import useDocumentTitle from '@/hooks/use-document-title'
 import { sendResetPasswordCode } from '@/service/common'
@@ -22,7 +22,7 @@ export default function CheckCode() {
   const router = useRouter()
   const [email, setEmail] = useState('')
   const [loading, setIsLoading] = useState(false)
-  const { locale } = useContext(I18NContext)
+  const locale = useLocale()
   const handleGetEMailVerificationCode = async () => {
     try {
diff --git a/web/app/(shareLayout)/webapp-signin/check-code/page.tsx b/web/app/(shareLayout)/webapp-signin/check-code/page.tsx
index 0ef63dcbd2..bda5484197 100644
--- a/web/app/(shareLayout)/webapp-signin/check-code/page.tsx
+++ b/web/app/(shareLayout)/webapp-signin/check-code/page.tsx
@@ -4,12 +4,12 @@ import { RiArrowLeftLine, RiMailSendFill } from '@remixicon/react'
 import { useRouter, useSearchParams } from 'next/navigation'
 import { useCallback, useEffect, useRef, useState } from 'react'
 import { useTranslation } from 'react-i18next'
-import { useContext } from 'use-context-selector'
 import Button from '@/app/components/base/button'
 import Input from '@/app/components/base/input'
 import Toast from '@/app/components/base/toast'
 import Countdown from '@/app/components/signin/countdown'
-import I18NContext from '@/context/i18n'
+
+import { useLocale } from '@/context/i18n'
 import { useWebAppStore } from '@/context/web-app-context'
 import { sendWebAppEMailLoginCode, webAppEmailLoginWithCode } from '@/service/common'
 import { fetchAccessToken } from '@/service/share'
@@ -23,7 +23,7 @@ export default function CheckCode() {
   const token = decodeURIComponent(searchParams.get('token') as string)
   const [code, setVerifyCode] = useState('')
   const [loading, setIsLoading] = useState(false)
-  const { locale } = useContext(I18NContext)
+  const locale = useLocale()
   const codeInputRef = useRef(null)
   const redirectUrl = searchParams.get('redirect_url')
   const embeddedUserId = useWebAppStore(s => s.embeddedUserId)
diff --git a/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx
index f3e018a1fa..5aa9d9f141 100644
--- a/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx
+++ b/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx
@@ -1,14 +1,13 @@
-import { noop } from 'es-toolkit/compat'
+import { noop } from 'es-toolkit/function'
 import { useRouter, useSearchParams } from 'next/navigation'
 import { useState } from 'react'
 import { useTranslation } from 'react-i18next'
-import { useContext } from 'use-context-selector'
 import Button from '@/app/components/base/button'
 import Input from '@/app/components/base/input'
 import Toast from '@/app/components/base/toast'
 import { COUNT_DOWN_KEY, COUNT_DOWN_TIME_MS } from '@/app/components/signin/countdown'
 import { emailRegex } from '@/config'
-import I18NContext from '@/context/i18n'
+import { useLocale } from '@/context/i18n'
 import { sendWebAppEMailLoginCode } from '@/service/common'
 export default function MailAndCodeAuth() {
@@ -18,7 +17,7 @@ export default function MailAndCodeAuth() {
   const emailFromLink = decodeURIComponent(searchParams.get('email') || '')
   const [email, setEmail] = useState(emailFromLink)
   const [loading, setIsLoading] = useState(false)
-  const { locale } = useContext(I18NContext)
+  const locale = useLocale()
   const handleGetEMailVerificationCode = async () => {
     try {
diff --git a/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx
index 7e76a87250..23ac83e76c 100644
--- a/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx
+++ b/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx
@@ -1,15 +1,14 @@
 'use client'
-import { noop } from 'es-toolkit/compat'
+import { noop } from 'es-toolkit/function'
 import Link from 'next/link'
 import { useRouter, useSearchParams } from 'next/navigation'
 import { useCallback, useState } from 'react'
 import { useTranslation } from 'react-i18next'
-import { useContext } from 'use-context-selector'
 import Button from '@/app/components/base/button'
 import Input from '@/app/components/base/input'
 import Toast from '@/app/components/base/toast'
 import { emailRegex } from '@/config'
-import I18NContext from '@/context/i18n'
+import { useLocale } from '@/context/i18n'
 import { useWebAppStore } from '@/context/web-app-context'
 import { webAppLogin } from '@/service/common'
 import { fetchAccessToken } from '@/service/share'
@@ -21,7 +20,7 @@ type MailAndPasswordAuthProps = {
 export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAuthProps) {
   const { t } = useTranslation()
-  const { locale } = useContext(I18NContext)
+  const locale = useLocale()
   const router = useRouter()
   const searchParams = useSearchParams()
   const [showPassword, setShowPassword] = useState(false)
diff --git a/web/app/account/(commonLayout)/account-page/email-change-modal.tsx b/web/app/account/(commonLayout)/account-page/email-change-modal.tsx
index 6e702770f7..87ca6a689c 100644
--- a/web/app/account/(commonLayout)/account-page/email-change-modal.tsx
+++ b/web/app/account/(commonLayout)/account-page/email-change-modal.tsx
@@ -1,6 +1,6 @@
 import type { ResponseError } from '@/service/fetch'
 import { RiCloseLine } from '@remixicon/react'
-import { noop } from 'es-toolkit/compat'
+import { noop } from 'es-toolkit/function'
 import { useRouter } from 'next/navigation'
 import * as React from 'react'
 import { useState } from 'react'
@@ -214,7 +214,8 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
 {t('account.changeEmail.authTip', { ns: 'common' })}
 }} values={{ email }} />
@@ -244,7 +245,8 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
 }} values={{ email }} />
@@ -333,7 +335,8 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
 }} values={{ email: mail }} />
diff --git a/web/app/components/app-initializer.tsx b/web/app/components/app-initializer.tsx
index 0f710abf39..e30646eb3f 100644
--- a/web/app/components/app-initializer.tsx
+++ b/web/app/components/app-initializer.tsx
@@ -1,14 +1,18 @@
 'use client'
 import type { ReactNode } from 'react'
+import Cookies from 'js-cookie'
 import { usePathname, useRouter, useSearchParams } from 'next/navigation'
+import { parseAsString, useQueryState } from 'nuqs'
 import { useCallback, useEffect, useState } from 'react'
 import {
   EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION,
   EDUCATION_VERIFYING_LOCALSTORAGE_ITEM,
 } from '@/app/education-apply/constants'
 import { fetchSetupStatus } from '@/service/common'
+import { sendGAEvent } from '@/utils/gtag'
 import { resolvePostLoginRedirect } from '../signin/utils/post-login-redirect'
+import { trackEvent } from './base/amplitude'
 type AppInitializerProps = {
   children: ReactNode
@@ -22,6 +26,10 @@ export const AppInitializer = ({
   // Tokens are now stored in cookies, no need to check localStorage
   const pathname = usePathname()
   const [init, setInit] = useState(false)
+  const [oauthNewUser, setOauthNewUser] = useQueryState(
+    'oauth_new_user',
+    parseAsString.withOptions({ history: 'replace' }),
+  )
   const isSetupFinished = useCallback(async () => {
     try {
@@ -45,6 +53,34 @@
     (async () => {
       const action = searchParams.get('action')
+      if (oauthNewUser === 'true') {
+        let utmInfo = null
+        const utmInfoStr = Cookies.get('utm_info')
+        if (utmInfoStr) {
+          try {
+            utmInfo = JSON.parse(utmInfoStr)
+          }
+          catch (e) {
+            console.error('Failed to parse utm_info cookie:', e)
+          }
+        }
+        // Track registration event with UTM params
+        trackEvent(utmInfo ? 'user_registration_success_with_utm' : 'user_registration_success', {
+          method: 'oauth',
+          ...utmInfo,
+        })
+        sendGAEvent(utmInfo ? 'user_registration_success_with_utm' : 'user_registration_success', {
+          method: 'oauth',
+          ...utmInfo,
+        })
+        // Clean up: remove utm_info cookie and URL params
+        Cookies.remove('utm_info')
+        setOauthNewUser(null)
+      }
       if (action === EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION)
         localStorage.setItem(EDUCATION_VERIFYING_LOCALSTORAGE_ITEM, 'yes')
@@ -67,7 +103,7 @@
         router.replace('/signin')
       }
     })()
-  }, [isSetupFinished, router, pathname, searchParams])
+  }, [isSetupFinished, router, pathname, searchParams, oauthNewUser, setOauthNewUser])
   return init ? children : null
 }
diff --git a/web/app/components/app-sidebar/dataset-info/index.spec.tsx b/web/app/components/app-sidebar/dataset-info/index.spec.tsx
index da7eb6d7ff..9996ef2b4d 100644
--- a/web/app/components/app-sidebar/dataset-info/index.spec.tsx
+++ b/web/app/components/app-sidebar/dataset-info/index.spec.tsx
@@ -132,7 +132,6 @@ vi.mock('@/hooks/use-knowledge', () => ({
 }))
 vi.mock('@/app/components/datasets/rename-modal', () => ({
-  __esModule: true,
   default: ({
     show,
     onClose,
diff --git a/web/app/components/app-sidebar/text-squeeze-fix-verification.spec.tsx b/web/app/components/app-sidebar/text-squeeze-fix-verification.spec.tsx
index 7c0c8b3aca..f7e91b3dea 100644
--- a/web/app/components/app-sidebar/text-squeeze-fix-verification.spec.tsx
+++ b/web/app/components/app-sidebar/text-squeeze-fix-verification.spec.tsx
@@ -13,7 +13,6 @@ vi.mock('next/navigation', () => ({
 // Mock classnames utility
 vi.mock('@/utils/classnames', () => ({
-  __esModule: true,
   default: (...classes: any[]) => classes.filter(Boolean).join(' '),
 }))
diff --git a/web/app/components/app/annotation/add-annotation-modal/index.spec.tsx b/web/app/components/app/annotation/add-annotation-modal/index.spec.tsx
index 6837516b3c..bad3ceefdf 100644
--- a/web/app/components/app/annotation/add-annotation-modal/index.spec.tsx
+++ b/web/app/components/app/annotation/add-annotation-modal/index.spec.tsx
@@ -10,7 +10,6 @@ vi.mock('@/context/provider-context', () => ({
 const mockToastNotify = vi.fn()
 vi.mock('@/app/components/base/toast', () => ({
-  __esModule: true,
   default: {
     notify: vi.fn(args => mockToastNotify(args)),
   },
diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/csv-downloader.spec.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/csv-downloader.spec.tsx
index a3ab73b339..2ab0934fe2 100644
--- a/web/app/components/app/annotation/batch-add-annotation-modal/csv-downloader.spec.tsx
+++ b/web/app/components/app/annotation/batch-add-annotation-modal/csv-downloader.spec.tsx
@@ -1,7 +1,8 @@
+import type { Mock } from 'vitest'
 import type { Locale } from '@/i18n-config'
 import { render, screen } from '@testing-library/react'
 import * as React from 'react'
-import I18nContext from '@/context/i18n'
+import { useLocale } from '@/context/i18n'
 import { LanguagesSupported } from '@/i18n-config/language'
 import CSVDownload from './csv-downloader'
@@ -17,17 +18,13 @@ vi.mock('react-papaparse', () => ({
   })),
 }))
+vi.mock('@/context/i18n', () => ({
+  useLocale: vi.fn(() => 'en-US'),
+}))
+
 const renderWithLocale = (locale: Locale) => {
-  return render(
-
-
-    ,
-  )
+  ;(useLocale as Mock).mockReturnValue(locale)
+  return render()
 }
 describe('CSVDownload', () => {
diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/csv-downloader.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/csv-downloader.tsx
index a0c204062b..8db70104bc 100644
--- a/web/app/components/app/annotation/batch-add-annotation-modal/csv-downloader.tsx
+++ b/web/app/components/app/annotation/batch-add-annotation-modal/csv-downloader.tsx
@@ -5,9 +5,9 @@ import { useTranslation } from 'react-i18next'
 import {
   useCSVDownloader,
 } from 'react-papaparse'
-import { useContext } from 'use-context-selector'
 import { Download02 as DownloadIcon } from '@/app/components/base/icons/src/vender/solid/general'
-import I18n from '@/context/i18n'
+
+import { useLocale } from '@/context/i18n'
 import { LanguagesSupported } from '@/i18n-config/language'
 const CSV_TEMPLATE_QA_EN = [
@@ -24,7 +24,7 @@ const CSV_TEMPLATE_QA_CN = [
 const CSVDownload: FC = () => {
   const { t } = useTranslation()
-  const { locale } = useContext(I18n)
+  const locale = useLocale()
   const { CSVDownloader, Type } = useCSVDownloader()
   const getTemplate = () => {
diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/index.spec.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/index.spec.tsx
index d7458d6b90..7fdb99fbab 100644
--- a/web/app/components/app/annotation/batch-add-annotation-modal/index.spec.tsx
+++ b/web/app/components/app/annotation/batch-add-annotation-modal/index.spec.tsx
@@ -8,7 +8,6 @@ import { annotationBatchImport, checkAnnotationBatchImportProgress } from '@/ser
 import BatchModal, { ProcessStatus } from './index'
 vi.mock('@/app/components/base/toast', () => ({
-  __esModule: true,
   default: {
     notify: vi.fn(),
   },
@@ -24,14 +23,12 @@ vi.mock('@/context/provider-context', () => ({
 }))
 vi.mock('./csv-downloader', () => ({
-  __esModule: true,
   default: () => ,
 }))
 let lastUploadedFile: File | undefined
 vi.mock('./csv-uploader', () => ({
-  __esModule: true,
   default: ({ file, updateFile }: { file?: File, updateFile: (file?: File) => void }) => (