Mirror of https://github.com/langgenius/dify.git (synced 2026-04-26 02:06:35 +08:00)
Commit a19c0023f9: Merge branch 'main' into feat/hitl-frontend

.claude/skills/skill-creator/SKILL.md (new file, +355 lines)
---
name: skill-creator
description: Guide for creating effective skills. This skill should be used when users want to create a new skill (or update an existing skill) that extends Claude's capabilities with specialized knowledge, workflows, or tool integrations.
---

# Skill Creator

This skill provides guidance for creating effective skills.

## About Skills

Skills are modular, self-contained packages that extend Claude's capabilities by providing specialized knowledge, workflows, and tools. Think of them as "onboarding guides" for specific domains or tasks: they transform Claude from a general-purpose agent into a specialized agent equipped with procedural knowledge that no model can fully possess.

### What Skills Provide

1. Specialized workflows - Multi-step procedures for specific domains
2. Tool integrations - Instructions for working with specific file formats or APIs
3. Domain expertise - Company-specific knowledge, schemas, business logic
4. Bundled resources - Scripts, references, and assets for complex and repetitive tasks

## Core Principles

### Concise is Key

The context window is a public good. Skills share the context window with everything else Claude needs: system prompt, conversation history, other Skills' metadata, and the actual user request.

**Default assumption: Claude is already very smart.** Only add context Claude doesn't already have. Challenge each piece of information: "Does Claude really need this explanation?" and "Does this paragraph justify its token cost?"

Prefer concise examples over verbose explanations.

### Set Appropriate Degrees of Freedom

Match the level of specificity to the task's fragility and variability:

**High freedom (text-based instructions)**: Use when multiple approaches are valid, decisions depend on context, or heuristics guide the approach.

**Medium freedom (pseudocode or scripts with parameters)**: Use when a preferred pattern exists, some variation is acceptable, or configuration affects behavior.

**Low freedom (specific scripts, few parameters)**: Use when operations are fragile and error-prone, consistency is critical, or a specific sequence must be followed.

Think of Claude as exploring a path: a narrow bridge with cliffs needs specific guardrails (low freedom), while an open field allows many routes (high freedom).

### Anatomy of a Skill

Every skill consists of a required SKILL.md file and optional bundled resources:

```
skill-name/
├── SKILL.md (required)
│   ├── YAML frontmatter metadata (required)
│   │   ├── name: (required)
│   │   └── description: (required)
│   └── Markdown instructions (required)
└── Bundled Resources (optional)
    ├── scripts/ - Executable code (Python/Bash/etc.)
    ├── references/ - Documentation intended to be loaded into context as needed
    └── assets/ - Files used in output (templates, icons, fonts, etc.)
```

#### SKILL.md (required)

Every SKILL.md consists of:

- **Frontmatter** (YAML): Contains `name` and `description` fields. These are the only fields Claude reads to determine when the skill gets used, so it is very important to describe clearly and comprehensively what the skill is and when it should be used.
- **Body** (Markdown): Instructions and guidance for using the skill. Only loaded AFTER the skill triggers (if at all).
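
Put together, a minimal SKILL.md can be quite short. The sketch below is illustrative only; the `pdf-editor` name, description, and script path are hypothetical:

```markdown
---
name: pdf-editor
description: Rotate, merge, and split PDF files. Use when the user asks to modify or transform PDF documents.
---

# PDF Editor

## Rotating a PDF

Run `scripts/rotate_pdf.py <input.pdf> <degrees>` to rotate all pages.
```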

#### Bundled Resources (optional)

##### Scripts (`scripts/`)

Executable code (Python/Bash/etc.) for tasks that require deterministic reliability or are repeatedly rewritten.

- **When to include**: When the same code is being rewritten repeatedly or deterministic reliability is needed
- **Example**: `scripts/rotate_pdf.py` for PDF rotation tasks
- **Benefits**: Token efficient, deterministic, may be executed without loading into context
- **Note**: Scripts may still need to be read by Claude for patching or environment-specific adjustments
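
As a concrete illustration, a bundled script is typically a small, self-contained command-line tool. The sketch below (a hypothetical `scripts/extract_links.py`, not part of any existing skill) lists the markdown links in a file using only the standard library:

```python
#!/usr/bin/env python3
"""Hypothetical scripts/extract_links.py: list the markdown links in a file.

Illustrative sketch only; not part of any existing skill.
"""
import re
import sys

# Matches [label](target) pairs; good enough for plain markdown links.
LINK_RE = re.compile(r"\[([^\]]+)\]\(([^)]+)\)")


def extract_links(text: str) -> list[tuple[str, str]]:
    """Return (label, target) pairs for every markdown link in text."""
    return LINK_RE.findall(text)


def main(path: str) -> None:
    with open(path, encoding="utf-8") as f:
        for label, target in extract_links(f.read()):
            print(f"{label} -> {target}")


if __name__ == "__main__" and len(sys.argv) > 1:
    main(sys.argv[1])
```

Because the script is deterministic and self-contained, another Claude instance can run it directly (`scripts/extract_links.py SKILL.md`) without loading its source into context.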

##### References (`references/`)

Documentation and reference material intended to be loaded as needed into context to inform Claude's process and thinking.

- **When to include**: For documentation that Claude should reference while working
- **Examples**: `references/finance.md` for financial schemas, `references/mnda.md` for a company NDA template, `references/policies.md` for company policies, `references/api_docs.md` for API specifications
- **Use cases**: Database schemas, API documentation, domain knowledge, company policies, detailed workflow guides
- **Benefits**: Keeps SKILL.md lean, loaded only when Claude determines it's needed
- **Best practice**: If files are large (>10k words), include grep search patterns in SKILL.md
- **Avoid duplication**: Information should live in either SKILL.md or references files, not both. Prefer references files for detailed information unless it's truly core to the skill; this keeps SKILL.md lean while making information discoverable without hogging the context window. Keep only essential procedural instructions and workflow guidance in SKILL.md; move detailed reference material, schemas, and examples to references files.
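
For the grep best practice above, the pointer in SKILL.md can be as small as a short list of search commands. The file and pattern names here are illustrative, not from any existing skill:

```markdown
## Searching the API reference

`references/api_docs.md` is large; grep it instead of reading it whole:

- Endpoint details: `grep -n "^### " references/api_docs.md`
- Error codes: `grep -n "ERROR_" references/api_docs.md`
```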

##### Assets (`assets/`)

Files not intended to be loaded into context, but rather used within the output Claude produces.

- **When to include**: When the skill needs files that will be used in the final output
- **Examples**: `assets/logo.png` for brand assets, `assets/slides.pptx` for PowerPoint templates, `assets/frontend-template/` for HTML/React boilerplate, `assets/font.ttf` for typography
- **Use cases**: Templates, images, icons, boilerplate code, fonts, sample documents that get copied or modified
- **Benefits**: Separates output resources from documentation, enables Claude to use files without loading them into context

#### What Not to Include in a Skill

A skill should only contain essential files that directly support its functionality. Do NOT create extraneous documentation or auxiliary files, including:

- README.md
- INSTALLATION_GUIDE.md
- QUICK_REFERENCE.md
- CHANGELOG.md
- etc.

The skill should only contain the information needed for an AI agent to do the job at hand. It should not contain auxiliary context about the process that went into creating it, setup and testing procedures, user-facing documentation, etc. Creating additional documentation files just adds clutter and confusion.

### Progressive Disclosure Design Principle

Skills use a three-level loading system to manage context efficiently:

1. **Metadata (name + description)** - Always in context (~100 words)
2. **SKILL.md body** - Loaded when the skill triggers (<5k words)
3. **Bundled resources** - Loaded as needed by Claude (effectively unlimited, because scripts can be executed without being read into the context window)

#### Progressive Disclosure Patterns

Keep the SKILL.md body to the essentials and under 500 lines to minimize context bloat. Split content into separate files when approaching this limit. When splitting content out into other files, it is very important to reference them from SKILL.md and describe clearly when to read them, so the reader of the skill knows they exist and when to use them.

**Key principle:** When a skill supports multiple variations, frameworks, or options, keep only the core workflow and selection guidance in SKILL.md. Move variant-specific details (patterns, examples, configuration) into separate reference files.

**Pattern 1: High-level guide with references**

```markdown
# PDF Processing

## Quick start

Extract text with pdfplumber:
[code example]

## Advanced features

- **Form filling**: See [FORMS.md](FORMS.md) for complete guide
- **API reference**: See [REFERENCE.md](REFERENCE.md) for all methods
- **Examples**: See [EXAMPLES.md](EXAMPLES.md) for common patterns
```

Claude loads FORMS.md, REFERENCE.md, or EXAMPLES.md only when needed.

**Pattern 2: Domain-specific organization**

For Skills with multiple domains, organize content by domain to avoid loading irrelevant context:

```
bigquery-skill/
├── SKILL.md (overview and navigation)
└── reference/
    ├── finance.md (revenue, billing metrics)
    ├── sales.md (opportunities, pipeline)
    ├── product.md (API usage, features)
    └── marketing.md (campaigns, attribution)
```

When a user asks about sales metrics, Claude only reads sales.md.

Similarly, for skills supporting multiple frameworks or variants, organize by variant:

```
cloud-deploy/
├── SKILL.md (workflow + provider selection)
└── references/
    ├── aws.md (AWS deployment patterns)
    ├── gcp.md (GCP deployment patterns)
    └── azure.md (Azure deployment patterns)
```

When the user chooses AWS, Claude only reads aws.md.

**Pattern 3: Conditional details**

Show basic content, link to advanced content:

```markdown
# DOCX Processing

## Creating documents

Use docx-js for new documents. See [DOCX-JS.md](DOCX-JS.md).

## Editing documents

For simple edits, modify the XML directly.

**For tracked changes**: See [REDLINING.md](REDLINING.md)
**For OOXML details**: See [OOXML.md](OOXML.md)
```

Claude reads REDLINING.md or OOXML.md only when the user needs those features.

**Important guidelines:**

- **Avoid deeply nested references** - Keep references one level deep from SKILL.md. All reference files should link directly from SKILL.md.
- **Structure longer reference files** - For files longer than 100 lines, include a table of contents at the top so Claude can see the full scope when previewing.

## Skill Creation Process

Skill creation involves these steps:

1. Understand the skill with concrete examples
2. Plan reusable skill contents (scripts, references, assets)
3. Initialize the skill (run init_skill.py)
4. Edit the skill (implement resources and write SKILL.md)
5. Package the skill (run package_skill.py)
6. Iterate based on real usage

Follow these steps in order, skipping a step only when there is a clear reason it does not apply.

### Step 1: Understanding the Skill with Concrete Examples

Skip this step only when the skill's usage patterns are already clearly understood. It remains valuable even when working with an existing skill.

To create an effective skill, clearly understand concrete examples of how the skill will be used. This understanding can come from either direct user examples or generated examples that are validated with user feedback.

For example, when building an image-editor skill, relevant questions include:

- "What functionality should the image-editor skill support? Editing, rotating, anything else?"
- "Can you give some examples of how this skill would be used?"
- "I can imagine users asking for things like 'Remove the red-eye from this image' or 'Rotate this image'. Are there other ways you imagine this skill being used?"
- "What would a user say that should trigger this skill?"

To avoid overwhelming users, avoid asking too many questions in a single message. Start with the most important questions and follow up as needed.

Conclude this step when there is a clear sense of the functionality the skill should support.

### Step 2: Planning the Reusable Skill Contents

To turn concrete examples into an effective skill, analyze each example by:

1. Considering how to execute on the example from scratch
2. Identifying what scripts, references, and assets would be helpful when executing these workflows repeatedly

Example: When building a `pdf-editor` skill to handle queries like "Help me rotate this PDF," the analysis shows:

1. Rotating a PDF requires re-writing the same code each time
2. A `scripts/rotate_pdf.py` script would be helpful to store in the skill

Example: When designing a `frontend-webapp-builder` skill for queries like "Build me a todo app" or "Build me a dashboard to track my steps," the analysis shows:

1. Writing a frontend webapp requires the same boilerplate HTML/React each time
2. An `assets/hello-world/` template containing the boilerplate HTML/React project files would be helpful to store in the skill

Example: When building a `big-query` skill to handle queries like "How many users have logged in today?" the analysis shows:

1. Querying BigQuery requires re-discovering the table schemas and relationships each time
2. A `references/schema.md` file documenting the table schemas would be helpful to store in the skill

To establish the skill's contents, analyze each concrete example to create a list of the reusable resources to include: scripts, references, and assets.

### Step 3: Initializing the Skill

At this point, it is time to actually create the skill.

Skip this step only if the skill being developed already exists and iteration or packaging is needed. In this case, continue to the next step.

When creating a new skill from scratch, always run the `init_skill.py` script. The script generates a new template skill directory that automatically includes everything a skill requires, making the skill creation process much more efficient and reliable.

Usage:

```bash
scripts/init_skill.py <skill-name> --path <output-directory>
```

The script:

- Creates the skill directory at the specified path
- Generates a SKILL.md template with proper frontmatter and TODO placeholders
- Creates example resource directories: `scripts/`, `references/`, and `assets/`
- Adds example files in each directory that can be customized or deleted

After initialization, customize or remove the generated SKILL.md and example files as needed.
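
After running the initializer, the resulting directory should look roughly like this (the example file names are illustrative; the actual generated names may differ):

```
my-new-skill/
├── SKILL.md
├── scripts/
│   └── example_script.py
├── references/
│   └── example_reference.md
└── assets/
    └── example_asset.md
```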

### Step 4: Edit the Skill

When editing the (newly generated or existing) skill, remember that the skill is being created for another instance of Claude to use. Include information that would be beneficial and non-obvious to Claude. Consider what procedural knowledge, domain-specific details, or reusable assets would help another Claude instance execute these tasks more effectively.

#### Learn Proven Design Patterns

Consult these helpful guides based on the skill's needs:

- **Multi-step processes**: See references/workflows.md for sequential workflows and conditional logic
- **Specific output formats or quality standards**: See references/output-patterns.md for template and example patterns

These files contain established best practices for effective skill design.

#### Start with Reusable Skill Contents

To begin implementation, start with the reusable resources identified above: `scripts/`, `references/`, and `assets/` files. Note that this step may require user input. For example, when implementing a `brand-guidelines` skill, the user may need to provide brand assets or templates to store in `assets/`, or documentation to store in `references/`.

Added scripts must be tested by actually running them to ensure there are no bugs and that the output matches what is expected. If there are many similar scripts, only a representative sample needs to be tested, balancing confidence that they all work against time to completion.

Any example files and directories not needed for the skill should be deleted. The initialization script creates example files in `scripts/`, `references/`, and `assets/` to demonstrate structure, but most skills won't need all of them.

#### Update SKILL.md

**Writing Guidelines:** Always use imperative/infinitive form.

##### Frontmatter

Write the YAML frontmatter with `name` and `description`:

- `name`: The skill name
- `description`: This is the primary triggering mechanism for the skill, and helps Claude understand when to use it.
  - Include both what the skill does and specific triggers/contexts for when to use it.
  - Include all "when to use" information here, not in the body. The body is only loaded after triggering, so "When to Use This Skill" sections in the body are not helpful to Claude.
  - Example description for a `docx` skill: "Comprehensive document creation, editing, and analysis with support for tracked changes, comments, formatting preservation, and text extraction. Use when Claude needs to work with professional documents (.docx files) for: (1) Creating new documents, (2) Modifying or editing content, (3) Working with tracked changes, (4) Adding comments, or any other document tasks"

Do not include any other fields in the YAML frontmatter.

##### Body

Write instructions for using the skill and its bundled resources.

### Step 5: Packaging a Skill

Once development of the skill is complete, it must be packaged into a distributable .skill file that is shared with the user. The packaging process automatically validates the skill first to ensure it meets all requirements:

```bash
scripts/package_skill.py <path/to/skill-folder>
```

Optional output directory specification:

```bash
scripts/package_skill.py <path/to/skill-folder> ./dist
```

The packaging script will:

1. **Validate** the skill automatically, checking:
   - YAML frontmatter format and required fields
   - Skill naming conventions and directory structure
   - Description completeness and quality
   - File organization and resource references

2. **Package** the skill if validation passes, creating a .skill file named after the skill (e.g., `my-skill.skill`) that includes all files and maintains the proper directory structure for distribution. The .skill file is a zip file with a .skill extension.

If validation fails, the script reports the errors and exits without creating a package. Fix any validation errors and run the packaging command again.
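
The frontmatter portion of that validation can be sketched with the standard library alone. This is an illustrative sketch of the shape of the check, not the actual `package_skill.py` implementation:

```python
#!/usr/bin/env python3
"""Illustrative sketch of the frontmatter checks a packaging step might run.

Not the actual package_skill.py implementation.
"""
import re


def validate_frontmatter(skill_md: str) -> list[str]:
    """Return a list of validation errors for a SKILL.md's frontmatter."""
    errors: list[str] = []
    match = re.match(r"^---\n(.*?)\n---\n", skill_md, re.DOTALL)
    if not match:
        return ["missing YAML frontmatter block"]
    # Parse simple "key: value" lines; real YAML needs a proper parser.
    fields = dict(
        line.split(":", 1) for line in match.group(1).splitlines() if ":" in line
    )
    for required in ("name", "description"):
        if required not in fields or not fields[required].strip():
            errors.append(f"missing required field: {required}")
    return errors
```

A SKILL.md missing its `description` field would fail this check and the script would exit without producing a package.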

### Step 6: Iterate

After testing the skill, users may request improvements. Often this happens right after using the skill, with fresh context on how the skill performed.

**Iteration workflow:**

1. Use the skill on real tasks
2. Notice struggles or inefficiencies
3. Identify how SKILL.md or bundled resources should be updated
4. Implement changes and test again

.claude/skills/skill-creator/references/output-patterns.md (new file, +86 lines)

# Output Patterns

Use these patterns when skills need to produce consistent, high-quality output.

## Template Pattern

Provide templates for output format. Match the level of strictness to your needs.

**For strict requirements (like API responses or data formats):**

```markdown
## Report structure

ALWAYS use this exact template structure:

# [Analysis Title]

## Executive summary
[One-paragraph overview of key findings]

## Key findings
- Finding 1 with supporting data
- Finding 2 with supporting data
- Finding 3 with supporting data

## Recommendations
1. Specific actionable recommendation
2. Specific actionable recommendation
```

**For flexible guidance (when adaptation is useful):**

```markdown
## Report structure

Here is a sensible default format, but use your best judgment:

# [Analysis Title]

## Executive summary
[Overview]

## Key findings
[Adapt sections based on what you discover]

## Recommendations
[Tailor to the specific context]

Adjust sections as needed for the specific analysis type.
```

## Examples Pattern

For skills where output quality depends on seeing examples, provide input/output pairs:

```markdown
## Commit message format

Generate commit messages following these examples:

**Example 1:**
Input: Added user authentication with JWT tokens
Output:
```
feat(auth): implement JWT-based authentication

Add login endpoint and token validation middleware
```

**Example 2:**
Input: Fixed bug where dates displayed incorrectly in reports
Output:
```
fix(reports): correct date formatting in timezone conversion

Use UTC timestamps consistently across report generation
```

Follow this style: type(scope): brief description, then detailed explanation.
```

Examples help Claude understand the desired style and level of detail more clearly than descriptions alone.

.claude/skills/skill-creator/references/workflows.md (new file, +28 lines)

# Workflow Patterns

## Sequential Workflows

For complex tasks, break operations into clear, sequential steps. It is often helpful to give Claude an overview of the process towards the beginning of SKILL.md:

```markdown
Filling a PDF form involves these steps:

1. Analyze the form (run analyze_form.py)
2. Create field mapping (edit fields.json)
3. Validate mapping (run validate_fields.py)
4. Fill the form (run fill_form.py)
5. Verify output (run verify_output.py)
```

## Conditional Workflows

For tasks with branching logic, guide Claude through decision points:

```markdown
1. Determine the modification type:
   **Creating new content?** → Follow "Creation workflow" below
   **Editing existing content?** → Follow "Editing workflow" below

2. Creation workflow: [steps]
3. Editing workflow: [steps]
```

.claude/skills/skill-creator/scripts/init_skill.py (new executable file, +300 lines)

#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Skill Initializer - Creates a new skill from template
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
init_skill.py <skill-name> --path <path>
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
init_skill.py my-new-skill --path skills/public
|
||||||
|
init_skill.py my-api-helper --path skills/private
|
||||||
|
init_skill.py custom-skill --path /custom/location
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
SKILL_TEMPLATE = """---
|
||||||
|
name: {skill_name}
|
||||||
|
description: [TODO: Complete and informative explanation of what the skill does and when to use it. Include WHEN to use this skill - specific scenarios, file types, or tasks that trigger it.]
|
||||||
|
---
|
||||||
|
|
||||||
|
# {skill_title}
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
[TODO: 1-2 sentences explaining what this skill enables]
|
||||||
|
|
||||||
|
## Structuring This Skill
|
||||||
|
|
||||||
|
[TODO: Choose the structure that best fits this skill's purpose. Common patterns:
|
||||||
|
|
||||||
|
**1. Workflow-Based** (best for sequential processes)
|
||||||
|
- Works well when there are clear step-by-step procedures
|
||||||
|
- Example: DOCX skill with "Workflow Decision Tree" → "Reading" → "Creating" → "Editing"
|
||||||
|
- Structure: ## Overview → ## Workflow Decision Tree → ## Step 1 → ## Step 2...
|
||||||
|
|
||||||
|
**2. Task-Based** (best for tool collections)
|
||||||
|
- Works well when the skill offers different operations/capabilities
|
||||||
|
- Example: PDF skill with "Quick Start" → "Merge PDFs" → "Split PDFs" → "Extract Text"
|
||||||
|
- Structure: ## Overview → ## Quick Start → ## Task Category 1 → ## Task Category 2...
|
||||||
|
|
||||||
|
**3. Reference/Guidelines** (best for standards or specifications)
|
||||||
|
- Works well for brand guidelines, coding standards, or requirements
|
||||||
|
- Example: Brand styling with "Brand Guidelines" → "Colors" → "Typography" → "Features"
|
||||||
|
- Structure: ## Overview → ## Guidelines → ## Specifications → ## Usage...
|
||||||
|
|
||||||
|
**4. Capabilities-Based** (best for integrated systems)
|
||||||
|
- Works well when the skill provides multiple interrelated features
|
||||||
|
- Example: Product Management with "Core Capabilities" → numbered capability list
|
||||||
|
- Structure: ## Overview → ## Core Capabilities → ### 1. Feature → ### 2. Feature...
|
||||||
|
|
||||||
|
Patterns can be mixed and matched as needed. Most skills combine patterns (e.g., start with task-based, add workflow for complex operations).
|
||||||
|
|
||||||
|
Delete this entire "Structuring This Skill" section when done - it's just guidance.]
|
||||||
|
|
||||||
|
## [TODO: Replace with the first main section based on chosen structure]
|
||||||
|
|
||||||
|
[TODO: Add content here. See examples in existing skills:
|
||||||
|
- Code samples for technical skills
|
||||||
|
- Decision trees for complex workflows
|
||||||
|
- Concrete examples with realistic user requests
|
||||||
|
- References to scripts/templates/references as needed]
|
||||||
|
|
||||||
|
## Resources
|
||||||
|
|
||||||
|
This skill includes example resource directories that demonstrate how to organize different types of bundled resources:
|
||||||
|
|
||||||
|
### scripts/
|
||||||
|
Executable code (Python/Bash/etc.) that can be run directly to perform specific operations.
|
||||||
|
|
||||||
|
**Examples from other skills:**
|
||||||
|
- PDF skill: `fill_fillable_fields.py`, `extract_form_field_info.py` - utilities for PDF manipulation
|
||||||
|
- DOCX skill: `document.py`, `utilities.py` - Python modules for document processing
|
||||||
|
|
||||||
|
**Appropriate for:** Python scripts, shell scripts, or any executable code that performs automation, data processing, or specific operations.
|
||||||
|
|
||||||
|
**Note:** Scripts may be executed without loading into context, but can still be read by Claude for patching or environment adjustments.
|
||||||
|
|
||||||
|
### references/
|
||||||
|
Documentation and reference material intended to be loaded into context to inform Claude's process and thinking.
|
||||||
|
|
||||||
|
**Examples from other skills:**
|
||||||
|
- Product management: `communication.md`, `context_building.md` - detailed workflow guides
|
||||||
|
- BigQuery: API reference documentation and query examples
|
||||||
|
- Finance: Schema documentation, company policies
|
||||||
|
|
||||||
|
**Appropriate for:** In-depth documentation, API references, database schemas, comprehensive guides, or any detailed information that Claude should reference while working.
|
||||||
|
|
||||||
|
### assets/
|
||||||
|
Files not intended to be loaded into context, but rather used within the output Claude produces.
|
||||||
|
|
||||||
|
**Examples from other skills:**
|
||||||
|
- Brand styling: PowerPoint template files (.pptx), logo files
|
||||||
|
- Frontend builder: HTML/React boilerplate project directories
|
||||||
|
- Typography: Font files (.ttf, .woff2)
|
||||||
|
|
||||||
|
**Appropriate for:** Templates, boilerplate code, document templates, images, icons, fonts, or any files meant to be copied or used in the final output.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Any unneeded directories can be deleted.** Not every skill requires all three types of resources.
|
||||||
|
"""
|
||||||
|
|
||||||
|
EXAMPLE_SCRIPT = '''#!/usr/bin/env python3
"""
Example helper script for {skill_name}

This is a placeholder script that can be executed directly.
Replace with actual implementation or delete if not needed.

Example real scripts from other skills:
- pdf/scripts/fill_fillable_fields.py - Fills PDF form fields
- pdf/scripts/convert_pdf_to_images.py - Converts PDF pages to images
"""


def main():
    print("This is an example script for {skill_name}")
    # TODO: Add actual script logic here
    # This could be data processing, file conversion, API calls, etc.


if __name__ == "__main__":
    main()
'''


EXAMPLE_REFERENCE = """# Reference Documentation for {skill_title}

This is a placeholder for detailed reference documentation.
Replace with actual reference content or delete if not needed.

Example real reference docs from other skills:
- product-management/references/communication.md - Comprehensive guide for status updates
- product-management/references/context_building.md - Deep-dive on gathering context
- bigquery/references/ - API references and query examples

## When Reference Docs Are Useful

Reference docs are ideal for:
- Comprehensive API documentation
- Detailed workflow guides
- Complex multi-step processes
- Information too lengthy for main SKILL.md
- Content that's only needed for specific use cases

## Structure Suggestions

### API Reference Example
- Overview
- Authentication
- Endpoints with examples
- Error codes
- Rate limits

### Workflow Guide Example
- Prerequisites
- Step-by-step instructions
- Common patterns
- Troubleshooting
- Best practices
"""


EXAMPLE_ASSET = """# Example Asset File

This placeholder represents where asset files would be stored.
Replace with actual asset files (templates, images, fonts, etc.) or delete if not needed.

Asset files are NOT intended to be loaded into context, but rather used within
the output Claude produces.

Example asset files from other skills:
- Brand guidelines: logo.png, slides_template.pptx
- Frontend builder: hello-world/ directory with HTML/React boilerplate
- Typography: custom-font.ttf, font-family.woff2
- Data: sample_data.csv, test_dataset.json

## Common Asset Types

- Templates: .pptx, .docx, boilerplate directories
- Images: .png, .jpg, .svg, .gif
- Fonts: .ttf, .otf, .woff, .woff2
- Boilerplate code: Project directories, starter files
- Icons: .ico, .svg
- Data files: .csv, .json, .xml, .yaml

Note: This is a text placeholder. Actual assets can be any file type.
"""


def title_case_skill_name(skill_name):
    """Convert hyphenated skill name to Title Case for display."""
    return " ".join(word.capitalize() for word in skill_name.split("-"))


def init_skill(skill_name, path):
    """
    Initialize a new skill directory with template SKILL.md.

    Args:
        skill_name: Name of the skill
        path: Path where the skill directory should be created

    Returns:
        Path to created skill directory, or None if error
    """
    # Determine skill directory path
    skill_dir = Path(path).resolve() / skill_name

    # Check if directory already exists
    if skill_dir.exists():
        print(f"❌ Error: Skill directory already exists: {skill_dir}")
        return None

    # Create skill directory
    try:
        skill_dir.mkdir(parents=True, exist_ok=False)
        print(f"✅ Created skill directory: {skill_dir}")
    except Exception as e:
        print(f"❌ Error creating directory: {e}")
        return None

    # Create SKILL.md from template
    skill_title = title_case_skill_name(skill_name)
    skill_content = SKILL_TEMPLATE.format(skill_name=skill_name, skill_title=skill_title)

    skill_md_path = skill_dir / "SKILL.md"
    try:
        skill_md_path.write_text(skill_content)
        print("✅ Created SKILL.md")
    except Exception as e:
        print(f"❌ Error creating SKILL.md: {e}")
        return None

    # Create resource directories with example files
    try:
        # Create scripts/ directory with example script
        scripts_dir = skill_dir / "scripts"
        scripts_dir.mkdir(exist_ok=True)
        example_script = scripts_dir / "example.py"
        example_script.write_text(EXAMPLE_SCRIPT.format(skill_name=skill_name))
        example_script.chmod(0o755)
        print("✅ Created scripts/example.py")

        # Create references/ directory with example reference doc
        references_dir = skill_dir / "references"
        references_dir.mkdir(exist_ok=True)
        example_reference = references_dir / "api_reference.md"
        example_reference.write_text(EXAMPLE_REFERENCE.format(skill_title=skill_title))
        print("✅ Created references/api_reference.md")

        # Create assets/ directory with example asset placeholder
        assets_dir = skill_dir / "assets"
        assets_dir.mkdir(exist_ok=True)
        example_asset = assets_dir / "example_asset.txt"
        example_asset.write_text(EXAMPLE_ASSET)
        print("✅ Created assets/example_asset.txt")
    except Exception as e:
        print(f"❌ Error creating resource directories: {e}")
        return None

    # Print next steps
    print(f"\n✅ Skill '{skill_name}' initialized successfully at {skill_dir}")
    print("\nNext steps:")
    print("1. Edit SKILL.md to complete the TODO items and update the description")
    print("2. Customize or delete the example files in scripts/, references/, and assets/")
    print("3. Run the validator when ready to check the skill structure")

    return skill_dir


def main():
    if len(sys.argv) < 4 or sys.argv[2] != "--path":
        print("Usage: init_skill.py <skill-name> --path <path>")
        print("\nSkill name requirements:")
        print("  - Hyphen-case identifier (e.g., 'data-analyzer')")
        print("  - Lowercase letters, digits, and hyphens only")
        print("  - Max 40 characters")
        print("  - Must match directory name exactly")
        print("\nExamples:")
        print("  init_skill.py my-new-skill --path skills/public")
        print("  init_skill.py my-api-helper --path skills/private")
        print("  init_skill.py custom-skill --path /custom/location")
        sys.exit(1)

    skill_name = sys.argv[1]
    path = sys.argv[3]

    print(f"🚀 Initializing skill: {skill_name}")
    print(f"   Location: {path}")
    print()

    result = init_skill(skill_name, path)

    if result:
        sys.exit(0)
    else:
        sys.exit(1)


if __name__ == "__main__":
    main()
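The `title_case_skill_name` helper in init_skill.py above drives the `{skill_title}` placeholder in the templates; a quick standalone sketch of its expected behavior:

```python
def title_case_skill_name(skill_name):
    # Split on hyphens and capitalize each word, mirroring the helper above
    return " ".join(word.capitalize() for word in skill_name.split("-"))

print(title_case_skill_name("data-analyzer"))  # → Data Analyzer
```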
110
.claude/skills/skill-creator/scripts/package_skill.py
Executable file
@ -0,0 +1,110 @@
#!/usr/bin/env python3
"""
Skill Packager - Creates a distributable .skill file of a skill folder

Usage:
    python utils/package_skill.py <path/to/skill-folder> [output-directory]

Example:
    python utils/package_skill.py skills/public/my-skill
    python utils/package_skill.py skills/public/my-skill ./dist
"""

import sys
import zipfile
from pathlib import Path

from quick_validate import validate_skill


def package_skill(skill_path, output_dir=None):
    """
    Package a skill folder into a .skill file.

    Args:
        skill_path: Path to the skill folder
        output_dir: Optional output directory for the .skill file (defaults to current directory)

    Returns:
        Path to the created .skill file, or None if error
    """
    skill_path = Path(skill_path).resolve()

    # Validate skill folder exists
    if not skill_path.exists():
        print(f"❌ Error: Skill folder not found: {skill_path}")
        return None

    if not skill_path.is_dir():
        print(f"❌ Error: Path is not a directory: {skill_path}")
        return None

    # Validate SKILL.md exists
    skill_md = skill_path / "SKILL.md"
    if not skill_md.exists():
        print(f"❌ Error: SKILL.md not found in {skill_path}")
        return None

    # Run validation before packaging
    print("🔍 Validating skill...")
    valid, message = validate_skill(skill_path)
    if not valid:
        print(f"❌ Validation failed: {message}")
        print("   Please fix the validation errors before packaging.")
        return None
    print(f"✅ {message}\n")

    # Determine output location
    skill_name = skill_path.name
    if output_dir:
        output_path = Path(output_dir).resolve()
        output_path.mkdir(parents=True, exist_ok=True)
    else:
        output_path = Path.cwd()

    skill_filename = output_path / f"{skill_name}.skill"

    # Create the .skill file (zip format)
    try:
        with zipfile.ZipFile(skill_filename, "w", zipfile.ZIP_DEFLATED) as zipf:
            # Walk through the skill directory
            for file_path in skill_path.rglob("*"):
                if file_path.is_file():
                    # Calculate the relative path within the zip
                    arcname = file_path.relative_to(skill_path.parent)
                    zipf.write(file_path, arcname)
                    print(f"  Added: {arcname}")

        print(f"\n✅ Successfully packaged skill to: {skill_filename}")
        return skill_filename

    except Exception as e:
        print(f"❌ Error creating .skill file: {e}")
        return None


def main():
    if len(sys.argv) < 2:
        print("Usage: python utils/package_skill.py <path/to/skill-folder> [output-directory]")
        print("\nExample:")
        print("  python utils/package_skill.py skills/public/my-skill")
        print("  python utils/package_skill.py skills/public/my-skill ./dist")
        sys.exit(1)

    skill_path = sys.argv[1]
    output_dir = sys.argv[2] if len(sys.argv) > 2 else None

    print(f"📦 Packaging skill: {skill_path}")
    if output_dir:
        print(f"   Output directory: {output_dir}")
    print()

    result = package_skill(skill_path, output_dir)

    if result:
        sys.exit(0)
    else:
        sys.exit(1)


if __name__ == "__main__":
    main()
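Note how `package_skill` computes each archive member's path relative to the skill folder's *parent*, so the `.skill` zip unpacks into a top-level `<skill-name>/` directory rather than spilling files loose. A minimal sketch of that arcname computation, using hypothetical paths:

```python
from pathlib import PurePosixPath

# Hypothetical paths illustrating the relative_to() call in package_skill
skill_path = PurePosixPath("skills/public/my-skill")
file_path = skill_path / "scripts" / "example.py"

# Relative to the parent, so the skill folder name is kept inside the archive
arcname = file_path.relative_to(skill_path.parent)
print(arcname)  # my-skill/scripts/example.py
```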
97
.claude/skills/skill-creator/scripts/quick_validate.py
Executable file
@ -0,0 +1,97 @@
#!/usr/bin/env python3
"""
Quick validation script for skills - minimal version
"""

import sys
import os
import re
import yaml
from pathlib import Path


def validate_skill(skill_path):
    """Basic validation of a skill"""
    skill_path = Path(skill_path)

    # Check SKILL.md exists
    skill_md = skill_path / "SKILL.md"
    if not skill_md.exists():
        return False, "SKILL.md not found"

    # Read and validate frontmatter
    content = skill_md.read_text()
    if not content.startswith("---"):
        return False, "No YAML frontmatter found"

    # Extract frontmatter
    match = re.match(r"^---\n(.*?)\n---", content, re.DOTALL)
    if not match:
        return False, "Invalid frontmatter format"

    frontmatter_text = match.group(1)

    # Parse YAML frontmatter
    try:
        frontmatter = yaml.safe_load(frontmatter_text)
        if not isinstance(frontmatter, dict):
            return False, "Frontmatter must be a YAML dictionary"
    except yaml.YAMLError as e:
        return False, f"Invalid YAML in frontmatter: {e}"

    # Define allowed properties
    ALLOWED_PROPERTIES = {"name", "description", "license", "allowed-tools", "metadata"}

    # Check for unexpected properties (excluding nested keys under metadata)
    unexpected_keys = set(frontmatter.keys()) - ALLOWED_PROPERTIES
    if unexpected_keys:
        return False, (
            f"Unexpected key(s) in SKILL.md frontmatter: {', '.join(sorted(unexpected_keys))}. "
            f"Allowed properties are: {', '.join(sorted(ALLOWED_PROPERTIES))}"
        )

    # Check required fields
    if "name" not in frontmatter:
        return False, "Missing 'name' in frontmatter"
    if "description" not in frontmatter:
        return False, "Missing 'description' in frontmatter"

    # Extract name for validation
    name = frontmatter.get("name", "")
    if not isinstance(name, str):
        return False, f"Name must be a string, got {type(name).__name__}"
    name = name.strip()
    if name:
        # Check naming convention (hyphen-case: lowercase with hyphens)
        if not re.match(r"^[a-z0-9-]+$", name):
            return False, f"Name '{name}' should be hyphen-case (lowercase letters, digits, and hyphens only)"
        if name.startswith("-") or name.endswith("-") or "--" in name:
            return False, f"Name '{name}' cannot start/end with hyphen or contain consecutive hyphens"
        # Check name length (max 64 characters per spec)
        if len(name) > 64:
            return False, f"Name is too long ({len(name)} characters). Maximum is 64 characters."

    # Extract and validate description
    description = frontmatter.get("description", "")
    if not isinstance(description, str):
        return False, f"Description must be a string, got {type(description).__name__}"
    description = description.strip()
    if description:
        # Check for angle brackets
        if "<" in description or ">" in description:
            return False, "Description cannot contain angle brackets (< or >)"
        # Check description length (max 1024 characters per spec)
        if len(description) > 1024:
            return False, f"Description is too long ({len(description)} characters). Maximum is 1024 characters."

    return True, "Skill is valid!"


if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("Usage: python quick_validate.py <skill_directory>")
        sys.exit(1)

    valid, message = validate_skill(sys.argv[1])
    print(message)
    sys.exit(0 if valid else 1)
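The frontmatter extraction in quick_validate.py relies on a non-greedy regex anchored at the start of the file, so it captures only the first `---`-delimited block. A small standalone check of that pattern:

```python
import re

# Sample SKILL.md content with YAML frontmatter followed by a markdown body
content = "---\nname: my-skill\ndescription: Validates things.\n---\n# My Skill\n"

# Non-greedy capture stops at the first closing "---" line
match = re.match(r"^---\n(.*?)\n---", content, re.DOTALL)
frontmatter_text = match.group(1) if match else None
print(frontmatter_text)
```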
@ -53,7 +53,6 @@ ignore_imports =
|
|||||||
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
|
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
|
||||||
core.workflow.nodes.llm.node -> extensions.ext_database
|
core.workflow.nodes.llm.node -> extensions.ext_database
|
||||||
core.workflow.nodes.tool.tool_node -> extensions.ext_database
|
core.workflow.nodes.tool.tool_node -> extensions.ext_database
|
||||||
core.workflow.nodes.variable_assigner.common.impl -> extensions.ext_database
|
|
||||||
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
|
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
|
||||||
core.workflow.graph_engine.manager -> extensions.ext_redis
|
core.workflow.graph_engine.manager -> extensions.ext_redis
|
||||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
|
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
|
||||||
|
|||||||
@ -1,3 +1,5 @@
+from typing import Any
+
 import flask_login
 from flask import make_response, request
 from flask_restx import Resource
@ -96,14 +98,13 @@ class LoginApi(Resource):
         if is_login_error_rate_limit:
             raise EmailPasswordLoginLimitError()

-        # TODO: why invitation is re-assigned with different type?
-        invitation = args.invite_token  # type: ignore
-        if invitation:
-            invitation = RegisterService.get_invitation_if_token_valid(None, args.email, invitation)  # type: ignore
+        invitation_data: dict[str, Any] | None = None
+        if args.invite_token:
+            invitation_data = RegisterService.get_invitation_if_token_valid(None, args.email, args.invite_token)

         try:
-            if invitation:
-                data = invitation.get("data", {})  # type: ignore
+            if invitation_data:
+                data = invitation_data.get("data", {})
             invitee_email = data.get("email") if data else None
             if invitee_email != args.email:
                 raise InvalidEmailError()
@ -355,7 +355,7 @@ class PublishedRagPipelineRunApi(Resource):
             pipeline=pipeline,
             user=current_user,
             args=args,
-            invoke_from=InvokeFrom.DEBUGGER if payload.is_preview else InvokeFrom.PUBLISHED,
+            invoke_from=InvokeFrom.DEBUGGER if payload.is_preview else InvokeFrom.PUBLISHED_PIPELINE,
             streaming=streaming,
         )

@ -1,7 +1,7 @@
 from typing import Literal

 from flask import request
-from flask_restx import Resource, marshal_with
+from flask_restx import Resource
 from werkzeug.exceptions import Forbidden

 import services
@ -15,18 +15,21 @@ from controllers.common.errors import (
     TooManyFilesError,
     UnsupportedFileTypeError,
 )
+from controllers.common.schema import register_schema_models
 from controllers.console.wraps import (
     account_initialization_required,
     cloud_edition_billing_resource_check,
     setup_required,
 )
 from extensions.ext_database import db
-from fields.file_fields import file_fields, upload_config_fields
+from fields.file_fields import FileResponse, UploadConfig
 from libs.login import current_account_with_tenant, login_required
 from services.file_service import FileService

 from . import console_ns

+register_schema_models(console_ns, UploadConfig, FileResponse)
+
 PREVIEW_WORDS_LIMIT = 3000


@ -35,26 +38,27 @@ class FileApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @marshal_with(upload_config_fields)
+    @console_ns.response(200, "Success", console_ns.models[UploadConfig.__name__])
     def get(self):
-        return {
-            "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
-            "batch_count_limit": dify_config.UPLOAD_FILE_BATCH_LIMIT,
-            "file_upload_limit": dify_config.BATCH_UPLOAD_LIMIT,
-            "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
-            "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
-            "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
-            "workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
-            "image_file_batch_limit": dify_config.IMAGE_FILE_BATCH_LIMIT,
-            "single_chunk_attachment_limit": dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT,
-            "attachment_image_file_size_limit": dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT,
-        }, 200
+        config = UploadConfig(
+            file_size_limit=dify_config.UPLOAD_FILE_SIZE_LIMIT,
+            batch_count_limit=dify_config.UPLOAD_FILE_BATCH_LIMIT,
+            file_upload_limit=dify_config.BATCH_UPLOAD_LIMIT,
+            image_file_size_limit=dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
+            video_file_size_limit=dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
+            audio_file_size_limit=dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
+            workflow_file_upload_limit=dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
+            image_file_batch_limit=dify_config.IMAGE_FILE_BATCH_LIMIT,
+            single_chunk_attachment_limit=dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT,
+            attachment_image_file_size_limit=dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT,
+        )
+        return config.model_dump(mode="json"), 200

     @setup_required
     @login_required
     @account_initialization_required
-    @marshal_with(file_fields)
     @cloud_edition_billing_resource_check("documents")
+    @console_ns.response(201, "File uploaded successfully", console_ns.models[FileResponse.__name__])
     def post(self):
         current_user, _ = current_account_with_tenant()
         source_str = request.form.get("source")
@ -90,7 +94,8 @@ class FileApi(Resource):
         except services.errors.file.BlockedFileExtensionError as blocked_extension_error:
             raise BlockedFileExtensionError(blocked_extension_error.description)

-        return upload_file, 201
+        response = FileResponse.model_validate(upload_file, from_attributes=True)
+        return response.model_dump(mode="json"), 201


 @console_ns.route("/files/<uuid:file_id>/preview")
@ -1,7 +1,7 @@
 import urllib.parse

 import httpx
-from flask_restx import Resource, marshal_with
+from flask_restx import Resource
 from pydantic import BaseModel, Field

 import services
@ -11,19 +11,22 @@ from controllers.common.errors import (
     RemoteFileUploadError,
     UnsupportedFileTypeError,
 )
+from controllers.common.schema import register_schema_models
 from core.file import helpers as file_helpers
 from core.helper import ssrf_proxy
 from extensions.ext_database import db
-from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
+from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
 from libs.login import current_account_with_tenant
 from services.file_service import FileService

 from . import console_ns

+register_schema_models(console_ns, RemoteFileInfo, FileWithSignedUrl)
+
+
 @console_ns.route("/remote-files/<path:url>")
 class RemoteFileInfoApi(Resource):
-    @marshal_with(remote_file_info_fields)
+    @console_ns.response(200, "Remote file info", console_ns.models[RemoteFileInfo.__name__])
     def get(self, url):
         decoded_url = urllib.parse.unquote(url)
         resp = ssrf_proxy.head(decoded_url)
@ -31,10 +34,11 @@ class RemoteFileInfoApi(Resource):
             # failed back to get method
             resp = ssrf_proxy.get(decoded_url, timeout=3)
         resp.raise_for_status()
-        return {
-            "file_type": resp.headers.get("Content-Type", "application/octet-stream"),
-            "file_length": int(resp.headers.get("Content-Length", 0)),
-        }
+        info = RemoteFileInfo(
+            file_type=resp.headers.get("Content-Type", "application/octet-stream"),
+            file_length=int(resp.headers.get("Content-Length", 0)),
+        )
+        return info.model_dump(mode="json")


 class RemoteFileUploadPayload(BaseModel):
@ -50,7 +54,7 @@ console_ns.schema_model(
 @console_ns.route("/remote-files/upload")
 class RemoteFileUploadApi(Resource):
     @console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__])
-    @marshal_with(file_fields_with_signed_url)
+    @console_ns.response(201, "Remote file uploaded", console_ns.models[FileWithSignedUrl.__name__])
     def post(self):
         args = RemoteFileUploadPayload.model_validate(console_ns.payload)
         url = args.url
@ -85,13 +89,14 @@ class RemoteFileUploadApi(Resource):
         except services.errors.file.UnsupportedFileTypeError:
             raise UnsupportedFileTypeError()

-        return {
-            "id": upload_file.id,
-            "name": upload_file.name,
-            "size": upload_file.size,
-            "extension": upload_file.extension,
-            "url": file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
-            "mime_type": upload_file.mime_type,
-            "created_by": upload_file.created_by,
-            "created_at": upload_file.created_at,
-        }, 201
+        payload = FileWithSignedUrl(
+            id=upload_file.id,
+            name=upload_file.name,
+            size=upload_file.size,
+            extension=upload_file.extension,
+            url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
+            mime_type=upload_file.mime_type,
+            created_by=upload_file.created_by,
+            created_at=int(upload_file.created_at.timestamp()),
+        )
+        return payload.model_dump(mode="json"), 201
@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from datetime import datetime
 from typing import Literal

@ -99,7 +101,7 @@ class AccountPasswordPayload(BaseModel):
     repeat_new_password: str

     @model_validator(mode="after")
-    def check_passwords_match(self) -> "AccountPasswordPayload":
+    def check_passwords_match(self) -> AccountPasswordPayload:
         if self.new_password != self.repeat_new_password:
             raise RepeatPasswordNotMatchError()
         return self
@ -4,18 +4,18 @@ from flask import request
 from flask_restx import Resource
 from flask_restx.api import HTTPStatus
 from pydantic import BaseModel, Field
-from werkzeug.datastructures import FileStorage
 from werkzeug.exceptions import Forbidden

 import services
 from core.file.helpers import verify_plugin_file_signature
 from core.tools.tool_file_manager import ToolFileManager
-from fields.file_fields import build_file_model
+from fields.file_fields import FileResponse

 from ..common.errors import (
     FileTooLargeError,
     UnsupportedFileTypeError,
 )
+from ..common.schema import register_schema_models
 from ..console.wraps import setup_required
 from ..files import files_ns
 from ..inner_api.plugin.wraps import get_user
@ -35,6 +35,8 @@ files_ns.schema_model(
     PluginUploadQuery.__name__, PluginUploadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
 )

+register_schema_models(files_ns, FileResponse)
+

 @files_ns.route("/upload/for-plugin")
 class PluginUploadFileApi(Resource):
@ -51,7 +53,7 @@ class PluginUploadFileApi(Resource):
             415: "Unsupported file type",
         }
     )
-    @files_ns.marshal_with(build_file_model(files_ns), code=HTTPStatus.CREATED)
+    @files_ns.response(HTTPStatus.CREATED, "File uploaded", files_ns.models[FileResponse.__name__])
     def post(self):
         """Upload a file for plugin usage.

@ -69,7 +71,7 @@ class PluginUploadFileApi(Resource):
         """
         args = PluginUploadQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore

-        file: FileStorage | None = request.files.get("file")
+        file = request.files.get("file")
         if file is None:
             raise Forbidden("File is required.")

@ -80,8 +82,8 @@ class PluginUploadFileApi(Resource):
         user_id = args.user_id
         user = get_user(tenant_id, user_id)

-        filename: str | None = file.filename
-        mimetype: str | None = file.mimetype
+        filename = file.filename
+        mimetype = file.mimetype

         if not filename or not mimetype:
             raise Forbidden("Invalid request.")

@ -111,22 +113,22 @@ class PluginUploadFileApi(Resource):
         preview_url = ToolFileManager.sign_file(tool_file_id=tool_file.id, extension=extension)

         # Create a dictionary with all the necessary attributes
-        result = {
-            "id": tool_file.id,
-            "user_id": tool_file.user_id,
-            "tenant_id": tool_file.tenant_id,
-            "conversation_id": tool_file.conversation_id,
-            "file_key": tool_file.file_key,
-            "mimetype": tool_file.mimetype,
-            "original_url": tool_file.original_url,
-            "name": tool_file.name,
-            "size": tool_file.size,
-            "mime_type": mimetype,
-            "extension": extension,
-            "preview_url": preview_url,
-        }
+        result = FileResponse(
+            id=tool_file.id,
+            name=tool_file.name,
+            size=tool_file.size,
+            extension=extension,
+            mime_type=mimetype,
+            preview_url=preview_url,
+            source_url=tool_file.original_url,
+            original_url=tool_file.original_url,
+            user_id=tool_file.user_id,
+            tenant_id=tool_file.tenant_id,
+            conversation_id=tool_file.conversation_id,
+            file_key=tool_file.file_key,
+        )

-        return result, 201
+        return result.model_dump(mode="json"), 201
|
||||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||||
raise FileTooLargeError(file_too_large_error.description)
|
raise FileTooLargeError(file_too_large_error.description)
|
||||||
except services.errors.file.UnsupportedFileTypeError:
|
except services.errors.file.UnsupportedFileTypeError:
|
||||||
|
|||||||
@@ -10,13 +10,16 @@ from controllers.common.errors import (
     TooManyFilesError,
     UnsupportedFileTypeError,
 )
+from controllers.common.schema import register_schema_models
 from controllers.service_api import service_api_ns
 from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from extensions.ext_database import db
-from fields.file_fields import build_file_model
+from fields.file_fields import FileResponse
 from models import App, EndUser
 from services.file_service import FileService
 
+register_schema_models(service_api_ns, FileResponse)
+
 
 @service_api_ns.route("/files/upload")
 class FileApi(Resource):
@@ -31,8 +34,8 @@ class FileApi(Resource):
             415: "Unsupported file type",
         }
     )
-    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
-    @service_api_ns.marshal_with(build_file_model(service_api_ns), code=HTTPStatus.CREATED)
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))  # type: ignore
+    @service_api_ns.response(HTTPStatus.CREATED, "File uploaded", service_api_ns.models[FileResponse.__name__])
     def post(self, app_model: App, end_user: EndUser):
         """Upload a file for use in conversations.
 
@@ -64,4 +67,5 @@ class FileApi(Resource):
         except services.errors.file.UnsupportedFileTypeError:
             raise UnsupportedFileTypeError()
 
-        return upload_file, 201
+        response = FileResponse.model_validate(upload_file, from_attributes=True)
+        return response.model_dump(mode="json"), 201
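The pattern repeated across these controllers — replace flask-restx `marshal_with` with explicit Pydantic serialization, validating an ORM-style object into a response model and dumping it to JSON-safe types — can be sketched in isolation. The `FileResponse` fields below are illustrative stand-ins, not the project's full schema:

```python
from pydantic import BaseModel


class FileResponse(BaseModel):
    id: str
    name: str
    size: int


class UploadFileRow:
    """Stand-in for a SQLAlchemy row; only attribute access matters here."""

    def __init__(self) -> None:
        self.id = "file-1"
        self.name = "report.pdf"
        self.size = 2048
        self.storage_key = "ignored"  # attributes not in the model are skipped


row = UploadFileRow()
# from_attributes=True lets Pydantic read plain attributes instead of dict keys,
# mirroring FileResponse.model_validate(upload_file, from_attributes=True) above.
response = FileResponse.model_validate(row, from_attributes=True)
print(response.model_dump(mode="json"))
```

`model_dump(mode="json")` converts values such as datetimes to JSON-serializable types, which is why the controllers return it directly with a status code.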
@@ -174,7 +174,7 @@ class PipelineRunApi(DatasetApiResource):
             pipeline=pipeline,
             user=current_user,
             args=payload.model_dump(),
-            invoke_from=InvokeFrom.PUBLISHED if payload.is_published else InvokeFrom.DEBUGGER,
+            invoke_from=InvokeFrom.PUBLISHED_PIPELINE if payload.is_published else InvokeFrom.DEBUGGER,
             streaming=payload.response_mode == "streaming",
         )
 
@@ -1,5 +1,4 @@
 from flask import request
-from flask_restx import marshal_with
 
 import services
 from controllers.common.errors import (
@@ -9,12 +8,15 @@ from controllers.common.errors import (
     TooManyFilesError,
     UnsupportedFileTypeError,
 )
+from controllers.common.schema import register_schema_models
 from controllers.web import web_ns
 from controllers.web.wraps import WebApiResource
 from extensions.ext_database import db
-from fields.file_fields import build_file_model
+from fields.file_fields import FileResponse
 from services.file_service import FileService
 
+register_schema_models(web_ns, FileResponse)
+
 
 @web_ns.route("/files/upload")
 class FileApi(WebApiResource):
@@ -28,7 +30,7 @@ class FileApi(WebApiResource):
             415: "Unsupported file type",
         }
     )
-    @marshal_with(build_file_model(web_ns))
+    @web_ns.response(201, "File uploaded successfully", web_ns.models[FileResponse.__name__])
     def post(self, app_model, end_user):
         """Upload a file for use in web applications.
 
@@ -81,4 +83,5 @@ class FileApi(WebApiResource):
         except services.errors.file.UnsupportedFileTypeError:
             raise UnsupportedFileTypeError()
 
-        return upload_file, 201
+        response = FileResponse.model_validate(upload_file, from_attributes=True)
+        return response.model_dump(mode="json"), 201
@@ -1,7 +1,6 @@
 import urllib.parse
 
 import httpx
-from flask_restx import marshal_with
 from pydantic import BaseModel, Field, HttpUrl
 
 import services
@@ -14,7 +13,7 @@ from controllers.common.errors import (
 from core.file import helpers as file_helpers
 from core.helper import ssrf_proxy
 from extensions.ext_database import db
-from fields.file_fields import build_file_with_signed_url_model, build_remote_file_info_model
+from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
 from services.file_service import FileService
 
 from ..common.schema import register_schema_models
@@ -26,7 +25,7 @@ class RemoteFileUploadPayload(BaseModel):
     url: HttpUrl = Field(description="Remote file URL")
 
 
-register_schema_models(web_ns, RemoteFileUploadPayload)
+register_schema_models(web_ns, RemoteFileUploadPayload, RemoteFileInfo, FileWithSignedUrl)
 
 
 @web_ns.route("/remote-files/<path:url>")
@@ -41,7 +40,7 @@ class RemoteFileInfoApi(WebApiResource):
             500: "Failed to fetch remote file",
         }
     )
-    @marshal_with(build_remote_file_info_model(web_ns))
+    @web_ns.response(200, "Remote file info", web_ns.models[RemoteFileInfo.__name__])
     def get(self, app_model, end_user, url):
         """Get information about a remote file.
 
@@ -65,10 +64,11 @@ class RemoteFileInfoApi(WebApiResource):
             # failed back to get method
             resp = ssrf_proxy.get(decoded_url, timeout=3)
             resp.raise_for_status()
-            return {
-                "file_type": resp.headers.get("Content-Type", "application/octet-stream"),
-                "file_length": int(resp.headers.get("Content-Length", -1)),
-            }
+            info = RemoteFileInfo(
+                file_type=resp.headers.get("Content-Type", "application/octet-stream"),
+                file_length=int(resp.headers.get("Content-Length", -1)),
+            )
+            return info.model_dump(mode="json")
 
 
 @web_ns.route("/remote-files/upload")
@@ -84,7 +84,7 @@ class RemoteFileUploadApi(WebApiResource):
             500: "Failed to fetch remote file",
         }
     )
-    @marshal_with(build_file_with_signed_url_model(web_ns))
+    @web_ns.response(201, "Remote file uploaded", web_ns.models[FileWithSignedUrl.__name__])
     def post(self, app_model, end_user):
         """Upload a file from a remote URL.
 
@@ -139,13 +139,14 @@ class RemoteFileUploadApi(WebApiResource):
         except services.errors.file.UnsupportedFileTypeError:
             raise UnsupportedFileTypeError
 
-        return {
-            "id": upload_file.id,
-            "name": upload_file.name,
-            "size": upload_file.size,
-            "extension": upload_file.extension,
-            "url": file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
-            "mime_type": upload_file.mime_type,
-            "created_by": upload_file.created_by,
-            "created_at": upload_file.created_at,
-        }, 201
+        payload1 = FileWithSignedUrl(
+            id=upload_file.id,
+            name=upload_file.name,
+            size=upload_file.size,
+            extension=upload_file.extension,
+            url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
+            mime_type=upload_file.mime_type,
+            created_by=upload_file.created_by,
+            created_at=int(upload_file.created_at.timestamp()),
+        )
+        return payload1.model_dump(mode="json"), 201
@@ -20,6 +20,8 @@ from core.app.entities.queue_entities import (
     QueueTextChunkEvent,
 )
 from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
+from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
+from core.db.session_factory import session_factory
 from core.moderation.base import ModerationError
 from core.moderation.input_moderation import InputModeration
 from core.variables.variables import VariableUnion
@@ -40,6 +42,7 @@ from models import Workflow
 from models.enums import UserFrom
 from models.model import App, Conversation, Message, MessageAnnotation
 from models.workflow import ConversationVariable
+from services.conversation_variable_updater import ConversationVariableUpdater
 
 logger = logging.getLogger(__name__)
 
@@ -200,6 +203,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         )
 
         workflow_entry.graph_engine.layer(persistence_layer)
+        conversation_variable_layer = ConversationVariablePersistenceLayer(
+            ConversationVariableUpdater(session_factory.get_session_maker())
+        )
+        workflow_entry.graph_engine.layer(conversation_variable_layer)
         for layer in self._graph_engine_layers:
             workflow_entry.graph_engine.layer(layer)
 
@@ -130,7 +130,7 @@ class PipelineGenerator(BaseAppGenerator):
             pipeline=pipeline, workflow=workflow, start_node_id=start_node_id
         )
         documents: list[Document] = []
-        if invoke_from == InvokeFrom.PUBLISHED and not is_retry and not args.get("original_document_id"):
+        if invoke_from == InvokeFrom.PUBLISHED_PIPELINE and not is_retry and not args.get("original_document_id"):
             from services.dataset_service import DocumentService
 
             for datasource_info in datasource_info_list:
@@ -156,7 +156,7 @@ class PipelineGenerator(BaseAppGenerator):
         for i, datasource_info in enumerate(datasource_info_list):
             workflow_run_id = str(uuid.uuid4())
             document_id = args.get("original_document_id") or None
-            if invoke_from == InvokeFrom.PUBLISHED and not is_retry:
+            if invoke_from == InvokeFrom.PUBLISHED_PIPELINE and not is_retry:
                 document_id = document_id or documents[i].id
                 document_pipeline_execution_log = DocumentPipelineExecutionLog(
                     document_id=document_id,
@@ -42,7 +42,8 @@ class InvokeFrom(StrEnum):
     # DEBUGGER indicates that this invocation is from
     # the workflow (or chatflow) edit page.
     DEBUGGER = "debugger"
-    PUBLISHED = "published"
+    # PUBLISHED_PIPELINE indicates that this invocation runs a published RAG pipeline workflow.
+    PUBLISHED_PIPELINE = "published"
 
     # VALIDATION indicates that this invocation is from validation.
     VALIDATION = "validation"
api/core/app/layers/conversation_variable_persist_layer.py (new file, 60 lines)
@@ -0,0 +1,60 @@
+import logging
+
+from core.variables import Variable
+from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
+from core.workflow.conversation_variable_updater import ConversationVariableUpdater
+from core.workflow.enums import NodeType
+from core.workflow.graph_engine.layers.base import GraphEngineLayer
+from core.workflow.graph_events import GraphEngineEvent, NodeRunSucceededEvent
+from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
+
+logger = logging.getLogger(__name__)
+
+
+class ConversationVariablePersistenceLayer(GraphEngineLayer):
+    def __init__(self, conversation_variable_updater: ConversationVariableUpdater) -> None:
+        super().__init__()
+        self._conversation_variable_updater = conversation_variable_updater
+
+    def on_graph_start(self) -> None:
+        pass
+
+    def on_event(self, event: GraphEngineEvent) -> None:
+        if not isinstance(event, NodeRunSucceededEvent):
+            return
+        if event.node_type != NodeType.VARIABLE_ASSIGNER:
+            return
+        if self.graph_runtime_state is None:
+            return
+
+        updated_variables = common_helpers.get_updated_variables(event.node_run_result.process_data) or []
+        if not updated_variables:
+            return
+
+        conversation_id = self.graph_runtime_state.system_variable.conversation_id
+        if conversation_id is None:
+            return
+
+        updated_any = False
+        for item in updated_variables:
+            selector = item.selector
+            if len(selector) < 2:
+                logger.warning("Conversation variable selector invalid. selector=%s", selector)
+                continue
+            if selector[0] != CONVERSATION_VARIABLE_NODE_ID:
+                continue
+            variable = self.graph_runtime_state.variable_pool.get(selector)
+            if not isinstance(variable, Variable):
+                logger.warning(
+                    "Conversation variable not found in variable pool. selector=%s",
+                    selector,
+                )
+                continue
+            self._conversation_variable_updater.update(conversation_id=conversation_id, variable=variable)
+            updated_any = True
+
+        if updated_any:
+            self._conversation_variable_updater.flush()
+
+    def on_graph_end(self, error: Exception | None) -> None:
+        pass
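The new layer's core behavior — listen to all engine events, react only to one event type, batch variable writes, then flush once — can be sketched without the project's internals. Class names below mirror the diff but are simplified stand-ins, not the real `GraphEngineLayer` API:

```python
class NodeRunSucceededEvent:
    """Stand-in for the engine event carrying a node's results."""

    def __init__(self, node_type: str, updates: list[str]) -> None:
        self.node_type = node_type
        self.updates = updates


class RecordingUpdater:
    """Stand-in for ConversationVariableUpdater: records writes, flushes once."""

    def __init__(self) -> None:
        self.updated: list[str] = []
        self.flushed = False

    def update(self, variable: str) -> None:
        self.updated.append(variable)

    def flush(self) -> None:
        self.flushed = True


class PersistenceLayerSketch:
    def __init__(self, updater: RecordingUpdater) -> None:
        self._updater = updater

    def on_event(self, event: object) -> None:
        # Ignore everything except a successful variable-assigner node run.
        if not isinstance(event, NodeRunSucceededEvent):
            return
        if event.node_type != "variable-assigner":
            return
        for variable in event.updates:
            self._updater.update(variable)
        if event.updates:
            # Batch the writes, then persist once, as the real layer does.
            self._updater.flush()


updater = RecordingUpdater()
layer = PersistenceLayerSketch(updater)
layer.on_event("not-an-event")  # silently ignored
layer.on_event(NodeRunSucceededEvent("variable-assigner", ["greeting"]))
print(updater.updated, updater.flushed)
```

Keeping persistence in a layer means the graph engine stays unaware of conversation variables; the runner simply stacks this layer alongside the others.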
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from abc import ABC, abstractmethod
 
 from configs import dify_config
@@ -30,7 +32,7 @@ class DatasourcePlugin(ABC):
         """
         return DatasourceProviderType.LOCAL_FILE
 
-    def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
+    def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> DatasourcePlugin:
         return self.__class__(
             entity=self.entity.model_copy(),
             runtime=runtime,
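Many of the remaining hunks make the same two-line change: add `from __future__ import annotations` and drop the quotes around self-referential type hints. A minimal illustration of why the import makes the quotes unnecessary:

```python
from __future__ import annotations  # annotations are stored as strings, evaluated lazily


class Node:
    # Without the future import, `Node` here would raise NameError at class
    # definition time unless written as the string "Node".
    def clone(self) -> Node:
        return Node()


print(Node().clone().__class__.__name__)
```

With the import, every annotation in the module is kept as a string (PEP 563), so forward references to a class from inside its own body just work, and the quoted `"DatasourcePlugin"`-style workarounds can be removed.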
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import enum
 from enum import StrEnum
 from typing import Any
@@ -31,7 +33,7 @@ class DatasourceProviderType(enum.StrEnum):
     ONLINE_DRIVE = "online_drive"
 
     @classmethod
-    def value_of(cls, value: str) -> "DatasourceProviderType":
+    def value_of(cls, value: str) -> DatasourceProviderType:
         """
         Get value of given mode.
 
@@ -81,7 +83,7 @@ class DatasourceParameter(PluginParameter):
         typ: DatasourceParameterType,
         required: bool,
         options: list[str] | None = None,
-    ) -> "DatasourceParameter":
+    ) -> DatasourceParameter:
         """
         get a simple datasource parameter
 
@@ -187,14 +189,14 @@ class DatasourceInvokeMeta(BaseModel):
     tool_config: dict | None = None
 
     @classmethod
-    def empty(cls) -> "DatasourceInvokeMeta":
+    def empty(cls) -> DatasourceInvokeMeta:
         """
         Get an empty instance of DatasourceInvokeMeta
         """
         return cls(time_cost=0.0, error=None, tool_config={})
 
     @classmethod
-    def error_instance(cls, error: str) -> "DatasourceInvokeMeta":
+    def error_instance(cls, error: str) -> DatasourceInvokeMeta:
         """
         Get an instance of DatasourceInvokeMeta with error
         """
@@ -1,7 +1,7 @@
 from sqlalchemy import Engine
 from sqlalchemy.orm import Session, sessionmaker
 
-_session_maker: sessionmaker | None = None
+_session_maker: sessionmaker[Session] | None = None
 
 
 def configure_session_factory(engine: Engine, expire_on_commit: bool = False):
@@ -10,7 +10,7 @@ def configure_session_factory(engine: Engine, expire_on_commit: bool = False):
     _session_maker = sessionmaker(bind=engine, expire_on_commit=expire_on_commit)
 
 
-def get_session_maker() -> sessionmaker:
+def get_session_maker() -> sessionmaker[Session]:
     if _session_maker is None:
         raise RuntimeError("Session factory not configured. Call configure_session_factory() first.")
     return _session_maker
@@ -27,7 +27,7 @@ class SessionFactory:
         configure_session_factory(engine, expire_on_commit)
 
     @staticmethod
-    def get_session_maker() -> sessionmaker:
+    def get_session_maker() -> sessionmaker[Session]:
         return get_session_maker()
 
     @staticmethod
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import json
 from datetime import datetime
 from enum import StrEnum
@@ -75,7 +77,7 @@ class MCPProviderEntity(BaseModel):
     updated_at: datetime
 
     @classmethod
-    def from_db_model(cls, db_provider: "MCPToolProvider") -> "MCPProviderEntity":
+    def from_db_model(cls, db_provider: MCPToolProvider) -> MCPProviderEntity:
         """Create entity from database model with decryption"""
 
         return cls(
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from enum import StrEnum, auto
 from typing import Union
 
@@ -178,7 +180,7 @@ class BasicProviderConfig(BaseModel):
         TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR
 
         @classmethod
-        def value_of(cls, value: str) -> "ProviderConfig.Type":
+        def value_of(cls, value: str) -> ProviderConfig.Type:
             """
             Get value of given mode.
 
@@ -8,8 +8,9 @@ import urllib.parse
 from configs import dify_config
 
 
-def get_signed_file_url(upload_file_id: str, as_attachment=False) -> str:
-    url = f"{dify_config.FILES_URL}/files/{upload_file_id}/file-preview"
+def get_signed_file_url(upload_file_id: str, as_attachment=False, for_external: bool = True) -> str:
+    base_url = dify_config.FILES_URL if for_external else (dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL)
+    url = f"{base_url}/files/{upload_file_id}/file-preview"
 
     timestamp = str(int(time.time()))
     nonce = os.urandom(16).hex()
@@ -112,17 +112,17 @@ class File(BaseModel):
 
         return text
 
-    def generate_url(self) -> str | None:
+    def generate_url(self, for_external: bool = True) -> str | None:
         if self.transfer_method == FileTransferMethod.REMOTE_URL:
             return self.remote_url
         elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
             if self.related_id is None:
                 raise ValueError("Missing file related_id")
-            return helpers.get_signed_file_url(upload_file_id=self.related_id)
+            return helpers.get_signed_file_url(upload_file_id=self.related_id, for_external=for_external)
         elif self.transfer_method in [FileTransferMethod.TOOL_FILE, FileTransferMethod.DATASOURCE_FILE]:
             assert self.related_id is not None
             assert self.extension is not None
-            return sign_tool_file(tool_file_id=self.related_id, extension=self.extension)
+            return sign_tool_file(tool_file_id=self.related_id, extension=self.extension, for_external=for_external)
         return None
 
     def to_plugin_parameter(self) -> dict[str, Any]:
@@ -133,7 +133,7 @@ class File(BaseModel):
             "extension": self.extension,
             "size": self.size,
             "type": self.type,
-            "url": self.generate_url(),
+            "url": self.generate_url(for_external=False),
         }
 
     @model_validator(mode="after")
@@ -76,7 +76,7 @@ class TemplateTransformer(ABC):
         Post-process the result to convert scientific notation strings back to numbers
         """
 
-        def convert_scientific_notation(value):
+        def convert_scientific_notation(value: Any) -> Any:
             if isinstance(value, str):
                 # Check if the string looks like scientific notation
                 if re.match(r"^-?\d+\.?\d*e[+-]\d+$", value, re.IGNORECASE):
@@ -90,7 +90,7 @@ class TemplateTransformer(ABC):
                 return [convert_scientific_notation(v) for v in value]
             return value
 
-        return convert_scientific_notation(result)  # type: ignore[no-any-return]
+        return convert_scientific_notation(result)
 
     @classmethod
     @abstractmethod
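The helper being annotated above recursively walks a result, turning scientific-notation strings back into floats. A self-contained version of that walk, reconstructed from the visible fragments (the dict branch is not shown in the hunk, so treat the exact branch set as illustrative):

```python
import re


def convert_scientific_notation(value):
    if isinstance(value, str):
        # Only strings shaped like scientific notation are converted;
        # everything else passes through untouched.
        if re.match(r"^-?\d+\.?\d*e[+-]\d+$", value, re.IGNORECASE):
            return float(value)
        return value
    if isinstance(value, dict):
        return {k: convert_scientific_notation(v) for k, v in value.items()}
    if isinstance(value, list):
        return [convert_scientific_notation(v) for v in value]
    return value


print(convert_scientific_notation({"x": "1.5e+3", "y": ["2e-2", "plain"]}))
```

The anchored regex is deliberately strict: a bare `"e+3"` or a string merely containing an exponent stays a string, so ordinary text is never coerced.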
@@ -68,13 +68,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
         request_id: RequestId,
         request_meta: RequestParams.Meta | None,
         request: ReceiveRequestT,
-        session: """BaseSession[
-            SendRequestT,
-            SendNotificationT,
-            SendResultT,
-            ReceiveRequestT,
-            ReceiveNotificationT
-        ]""",
+        session: """BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]""",
         on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
     ):
         self.request_id = request_id
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from abc import ABC
 from collections.abc import Mapping, Sequence
 from enum import StrEnum, auto
@@ -17,7 +19,7 @@ class PromptMessageRole(StrEnum):
     TOOL = auto()
 
     @classmethod
-    def value_of(cls, value: str) -> "PromptMessageRole":
+    def value_of(cls, value: str) -> PromptMessageRole:
         """
         Get value of given mode.
 
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from decimal import Decimal
 from enum import StrEnum, auto
 from typing import Any
@@ -20,7 +22,7 @@ class ModelType(StrEnum):
     TTS = auto()
 
     @classmethod
-    def value_of(cls, origin_model_type: str) -> "ModelType":
+    def value_of(cls, origin_model_type: str) -> ModelType:
         """
         Get model type from origin model type.
 
@@ -103,7 +105,7 @@ class DefaultParameterName(StrEnum):
     JSON_SCHEMA = auto()
 
     @classmethod
-    def value_of(cls, value: Any) -> "DefaultParameterName":
+    def value_of(cls, value: Any) -> DefaultParameterName:
         """
         Get parameter name from value.
 
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
@ -38,7 +40,7 @@ class ModelProviderFactory:
|
|||||||
plugin_providers = self.get_plugin_model_providers()
|
plugin_providers = self.get_plugin_model_providers()
|
||||||
return [provider.declaration for provider in plugin_providers]
|
return [provider.declaration for provider in plugin_providers]
|
||||||
|
|
||||||
def get_plugin_model_providers(self) -> Sequence["PluginModelProviderEntity"]:
|
def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]:
|
||||||
"""
|
"""
|
||||||
Get all plugin model providers
|
Get all plugin model providers
|
||||||
:return: list of plugin model providers
|
:return: list of plugin model providers
|
||||||
@ -76,7 +78,7 @@ class ModelProviderFactory:
|
|||||||
plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
|
plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
|
||||||
return plugin_model_provider_entity.declaration
|
return plugin_model_provider_entity.declaration
|
||||||
|
|
||||||
def get_plugin_model_provider(self, provider: str) -> "PluginModelProviderEntity":
|
def get_plugin_model_provider(self, provider: str) -> PluginModelProviderEntity:
|
||||||
"""
|
"""
|
||||||
Get plugin model provider
|
Get plugin model provider
|
||||||
:param provider: provider name
|
:param provider: provider name
|
||||||
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import enum
 from collections.abc import Mapping, Sequence
 from datetime import datetime
@@ -242,7 +244,7 @@ class CredentialType(enum.StrEnum):
         return [item.value for item in cls]

     @classmethod
-    def of(cls, credential_type: str) -> "CredentialType":
+    def of(cls, credential_type: str) -> CredentialType:
         type_name = credential_type.lower()
         if type_name in {"api-key", "api_key"}:
             return cls.API_KEY
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import contextlib
 import json
 import logging
@@ -6,7 +8,7 @@ import re
 import threading
 import time
 import uuid
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any

 import clickzetta  # type: ignore
 from pydantic import BaseModel, model_validator
@@ -76,7 +78,7 @@ class ClickzettaConnectionPool:
    Manages connection reuse across ClickzettaVector instances.
    """

-    _instance: Optional["ClickzettaConnectionPool"] = None
+    _instance: ClickzettaConnectionPool | None = None
     _lock = threading.Lock()

     def __init__(self):
@@ -89,7 +91,7 @@ class ClickzettaConnectionPool:
         self._start_cleanup_thread()

     @classmethod
-    def get_instance(cls) -> "ClickzettaConnectionPool":
+    def get_instance(cls) -> ClickzettaConnectionPool:
         """Get singleton instance of connection pool."""
         if cls._instance is None:
             with cls._lock:
@@ -104,7 +106,7 @@ class ClickzettaConnectionPool:
            f"{config.workspace}:{config.vcluster}:{config.schema_name}"
        )

-    def _create_connection(self, config: ClickzettaConfig) -> "Connection":
+    def _create_connection(self, config: ClickzettaConfig) -> Connection:
         """Create a new ClickZetta connection."""
         max_retries = 3
         retry_delay = 1.0
@@ -134,7 +136,7 @@ class ClickzettaConnectionPool:

         raise RuntimeError(f"Failed to create ClickZetta connection after {max_retries} attempts")

-    def _configure_connection(self, connection: "Connection"):
+    def _configure_connection(self, connection: Connection):
         """Configure connection session settings."""
         try:
             with connection.cursor() as cursor:
@@ -181,7 +183,7 @@ class ClickzettaConnectionPool:
         except Exception:
             logger.exception("Failed to configure connection, continuing with defaults")

-    def _is_connection_valid(self, connection: "Connection") -> bool:
+    def _is_connection_valid(self, connection: Connection) -> bool:
         """Check if connection is still valid."""
         try:
             with connection.cursor() as cursor:
@@ -190,7 +192,7 @@ class ClickzettaConnectionPool:
         except Exception:
             return False

-    def get_connection(self, config: ClickzettaConfig) -> "Connection":
+    def get_connection(self, config: ClickzettaConfig) -> Connection:
         """Get a connection from the pool or create a new one."""
         config_key = self._get_config_key(config)

@@ -221,7 +223,7 @@ class ClickzettaConnectionPool:
         # No valid connection found, create new one
         return self._create_connection(config)

-    def return_connection(self, config: ClickzettaConfig, connection: "Connection"):
+    def return_connection(self, config: ClickzettaConfig, connection: Connection):
         """Return a connection to the pool."""
         config_key = self._get_config_key(config)

@@ -315,22 +317,22 @@ class ClickzettaVector(BaseVector):
         self._connection_pool = ClickzettaConnectionPool.get_instance()
         self._init_write_queue()

-    def _get_connection(self) -> "Connection":
+    def _get_connection(self) -> Connection:
         """Get a connection from the pool."""
         return self._connection_pool.get_connection(self._config)

-    def _return_connection(self, connection: "Connection"):
+    def _return_connection(self, connection: Connection):
         """Return a connection to the pool."""
         self._connection_pool.return_connection(self._config, connection)

     class ConnectionContext:
         """Context manager for borrowing and returning connections."""

-        def __init__(self, vector_instance: "ClickzettaVector"):
+        def __init__(self, vector_instance: ClickzettaVector):
             self.vector = vector_instance
             self.connection: Connection | None = None

-        def __enter__(self) -> "Connection":
+        def __enter__(self) -> Connection:
             self.connection = self.vector._get_connection()
             return self.connection

@@ -338,7 +340,7 @@ class ClickzettaVector(BaseVector):
             if self.connection:
                 self.vector._return_connection(self.connection)

-    def get_connection_context(self) -> "ClickzettaVector.ConnectionContext":
+    def get_connection_context(self) -> ClickzettaVector.ConnectionContext:
         """Get a connection context manager."""
         return self.ConnectionContext(self)

@@ -437,7 +439,7 @@ class ClickzettaVector(BaseVector):
         """Return the vector database type."""
         return "clickzetta"

-    def _ensure_connection(self) -> "Connection":
+    def _ensure_connection(self) -> Connection:
         """Get a connection from the pool."""
         return self._get_connection()

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from collections.abc import Sequence
 from typing import Any

@@ -22,7 +24,7 @@ class DatasetDocumentStore:
         self._document_id = document_id

     @classmethod
-    def from_dict(cls, config_dict: dict[str, Any]) -> "DatasetDocumentStore":
+    def from_dict(cls, config_dict: dict[str, Any]) -> DatasetDocumentStore:
         return cls(**config_dict)

     def to_dict(self) -> dict[str, Any]:
@@ -7,10 +7,11 @@ import re
 import tempfile
 import uuid
 from urllib.parse import urlparse
-from xml.etree import ElementTree

 import httpx
 from docx import Document as DocxDocument
+from docx.oxml.ns import qn
+from docx.text.run import Run

 from configs import dify_config
 from core.helper import ssrf_proxy
@@ -229,44 +230,20 @@ class WordExtractor(BaseExtractor):

         image_map = self._extract_images_from_docx(doc)

-        hyperlinks_url = None
-        url_pattern = re.compile(r"http://[^\s+]+//|https://[^\s+]+")
-        for para in doc.paragraphs:
-            for run in para.runs:
-                if run.text and hyperlinks_url:
-                    result = f" [{run.text}]({hyperlinks_url}) "
-                    run.text = result
-                    hyperlinks_url = None
-                if "HYPERLINK" in run.element.xml:
-                    try:
-                        xml = ElementTree.XML(run.element.xml)
-                        x_child = [c for c in xml.iter() if c is not None]
-                        for x in x_child:
-                            if x is None:
-                                continue
-                            if x.tag.endswith("instrText"):
-                                if x.text is None:
-                                    continue
-                                for i in url_pattern.findall(x.text):
-                                    hyperlinks_url = str(i)
-                    except Exception:
-                        logger.exception("Failed to parse HYPERLINK xml")
-
         def parse_paragraph(paragraph):
-            paragraph_content = []
-
-            def append_image_link(image_id, has_drawing):
+            def append_image_link(image_id, has_drawing, target_buffer):
                 """Helper to append image link from image_map based on relationship type."""
                 rel = doc.part.rels[image_id]
                 if rel.is_external:
                     if image_id in image_map and not has_drawing:
-                        paragraph_content.append(image_map[image_id])
+                        target_buffer.append(image_map[image_id])
                 else:
                     image_part = rel.target_part
                     if image_part in image_map and not has_drawing:
-                        paragraph_content.append(image_map[image_part])
+                        target_buffer.append(image_map[image_part])

-            for run in paragraph.runs:
+            def process_run(run, target_buffer):
+                # Helper to extract text and embedded images from a run element and append them to target_buffer
                 if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"):
                     # Process drawing type images
                     drawing_elements = run.element.findall(
@@ -287,13 +264,13 @@ class WordExtractor(BaseExtractor):
                         # External image: use embed_id as key
                         if embed_id in image_map:
                             has_drawing = True
-                            paragraph_content.append(image_map[embed_id])
+                            target_buffer.append(image_map[embed_id])
                         else:
                             # Internal image: use target_part as key
                             image_part = doc.part.related_parts.get(embed_id)
                             if image_part in image_map:
                                 has_drawing = True
-                                paragraph_content.append(image_map[image_part])
+                                target_buffer.append(image_map[image_part])
                     # Process pict type images
                     shape_elements = run.element.findall(
                         ".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}pict"
@@ -308,7 +285,7 @@ class WordExtractor(BaseExtractor):
                                 "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
                             )
                             if image_id and image_id in doc.part.rels:
-                                append_image_link(image_id, has_drawing)
+                                append_image_link(image_id, has_drawing, target_buffer)
                         # Find imagedata element in VML
                         image_data = shape.find(".//{urn:schemas-microsoft-com:vml}imagedata")
                         if image_data is not None:
@@ -316,9 +293,93 @@ class WordExtractor(BaseExtractor):
                                 "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
                             )
                             if image_id and image_id in doc.part.rels:
-                                append_image_link(image_id, has_drawing)
+                                append_image_link(image_id, has_drawing, target_buffer)
                 if run.text.strip():
-                    paragraph_content.append(run.text.strip())
+                    target_buffer.append(run.text.strip())
+
+            def process_hyperlink(hyperlink_elem, target_buffer):
+                # Helper to extract text from a hyperlink element and append it to target_buffer
+                r_id = hyperlink_elem.get(qn("r:id"))
+
+                # Extract text from runs inside the hyperlink
+                link_text_parts = []
+                for run_elem in hyperlink_elem.findall(qn("w:r")):
+                    run = Run(run_elem, paragraph)
+                    # Hyperlink text may be split across multiple runs (e.g., with different formatting),
+                    # so collect all run texts first
+                    if run.text:
+                        link_text_parts.append(run.text)
+
+                link_text = "".join(link_text_parts).strip()
+
+                # Resolve URL
+                if r_id:
+                    try:
+                        rel = doc.part.rels.get(r_id)
+                        if rel and rel.is_external:
+                            link_text = f"[{link_text or rel.target_ref}]({rel.target_ref})"
+                    except Exception:
+                        logger.exception("Failed to resolve URL for hyperlink with r:id: %s", r_id)
+
+                if link_text:
+                    target_buffer.append(link_text)
+
+            paragraph_content = []
+            # State for legacy HYPERLINK fields
+            hyperlink_field_url = None
+            hyperlink_field_text_parts: list = []
+            is_collecting_field_text = False
+            # Iterate through paragraph elements in document order
+            for child in paragraph._element:
+                tag = child.tag
+                if tag == qn("w:r"):
+                    # Regular run
+                    run = Run(child, paragraph)
+
+                    # Check for fldChar (begin/end/separate) and instrText for legacy hyperlinks
+                    fld_chars = child.findall(qn("w:fldChar"))
+                    instr_texts = child.findall(qn("w:instrText"))
+
+                    # Handle Fields
+                    if fld_chars or instr_texts:
+                        # Process instrText to find HYPERLINK "url"
+                        for instr in instr_texts:
+                            if instr.text and "HYPERLINK" in instr.text:
+                                # Quick regex to extract URL
+                                match = re.search(r'HYPERLINK\s+"([^"]+)"', instr.text, re.IGNORECASE)
+                                if match:
+                                    hyperlink_field_url = match.group(1)
+
+                        # Process fldChar
+                        for fld_char in fld_chars:
+                            fld_char_type = fld_char.get(qn("w:fldCharType"))
+                            if fld_char_type == "begin":
+                                # Start of a field: reset legacy link state
+                                hyperlink_field_url = None
+                                hyperlink_field_text_parts = []
+                                is_collecting_field_text = False
+                            elif fld_char_type == "separate":
+                                # Separator: if we found a URL, start collecting visible text
+                                if hyperlink_field_url:
+                                    is_collecting_field_text = True
+                            elif fld_char_type == "end":
+                                # End of field
+                                if is_collecting_field_text and hyperlink_field_url:
+                                    # Create markdown link and append to main content
+                                    display_text = "".join(hyperlink_field_text_parts).strip()
+                                    if display_text:
+                                        link_md = f"[{display_text}]({hyperlink_field_url})"
+                                        paragraph_content.append(link_md)
+                                    # Reset state
+                                    hyperlink_field_url = None
+                                    hyperlink_field_text_parts = []
+                                    is_collecting_field_text = False
+
+                    # Decide where to append content
+                    target_buffer = hyperlink_field_text_parts if is_collecting_field_text else paragraph_content
+                    process_run(run, target_buffer)
+                elif tag == qn("w:hyperlink"):
+                    process_hyperlink(child, paragraph_content)
             return "".join(paragraph_content) if paragraph_content else ""

         paragraphs = doc.paragraphs.copy()
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import json
 from collections.abc import Sequence
 from typing import Any
@@ -16,7 +18,7 @@ class TaskWrapper(BaseModel):
         return self.model_dump_json()

     @classmethod
-    def deserialize(cls, serialized_data: str) -> "TaskWrapper":
+    def deserialize(cls, serialized_data: str) -> TaskWrapper:
         return cls.model_validate_json(serialized_data)


@@ -1,9 +1,11 @@
+from __future__ import annotations
+
 import json
 import logging
 import threading
 from collections.abc import Mapping, MutableMapping
 from pathlib import Path
-from typing import Any, ClassVar, Optional
+from typing import Any, ClassVar


 class SchemaRegistry:
@@ -11,7 +13,7 @@ class SchemaRegistry:

     logger: ClassVar[logging.Logger] = logging.getLogger(__name__)

-    _default_instance: ClassVar[Optional["SchemaRegistry"]] = None
+    _default_instance: ClassVar[SchemaRegistry | None] = None
     _lock: ClassVar[threading.Lock] = threading.Lock()

     def __init__(self, base_dir: str):
@@ -20,7 +22,7 @@ class SchemaRegistry:
         self.metadata: MutableMapping[str, MutableMapping[str, Any]] = {}

     @classmethod
-    def default_registry(cls) -> "SchemaRegistry":
+    def default_registry(cls) -> SchemaRegistry:
         """Returns the default schema registry for builtin schemas (thread-safe singleton)"""
         if cls._default_instance is None:
             with cls._lock:
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from abc import ABC, abstractmethod
 from collections.abc import Generator
 from copy import deepcopy
@@ -24,7 +26,7 @@ class Tool(ABC):
         self.entity = entity
         self.runtime = runtime

-    def fork_tool_runtime(self, runtime: ToolRuntime) -> "Tool":
+    def fork_tool_runtime(self, runtime: ToolRuntime) -> Tool:
         """
         fork a new tool with metadata
         :return: the new tool
@@ -166,7 +168,7 @@ class Tool(ABC):
             type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image)
         )

-    def create_file_message(self, file: "File") -> ToolInvokeMessage:
+    def create_file_message(self, file: File) -> ToolInvokeMessage:
         return ToolInvokeMessage(
             type=ToolInvokeMessage.MessageType.FILE,
             message=ToolInvokeMessage.FileMessage(),
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from core.model_runtime.entities.llm_entities import LLMResult
 from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
 from core.tools.__base.tool import Tool
@@ -24,7 +26,7 @@ class BuiltinTool(Tool):
         super().__init__(**kwargs)
         self.provider = provider

-    def fork_tool_runtime(self, runtime: ToolRuntime) -> "BuiltinTool":
+    def fork_tool_runtime(self, runtime: ToolRuntime) -> BuiltinTool:
         """
         fork a new tool with metadata
         :return: the new tool
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from pydantic import Field
 from sqlalchemy import select

@@ -32,7 +34,7 @@ class ApiToolProviderController(ToolProviderController):
         self.tools = []

     @classmethod
-    def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
+    def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> ApiToolProviderController:
         credentials_schema = [
             ProviderConfig(
                 name="auth_type",
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import base64
 import contextlib
 from collections.abc import Mapping
@@ -55,7 +57,7 @@ class ToolProviderType(StrEnum):
     MCP = auto()

     @classmethod
-    def value_of(cls, value: str) -> "ToolProviderType":
+    def value_of(cls, value: str) -> ToolProviderType:
         """
         Get value of given mode.

@@ -79,7 +81,7 @@ class ApiProviderSchemaType(StrEnum):
     OPENAI_ACTIONS = auto()

     @classmethod
-    def value_of(cls, value: str) -> "ApiProviderSchemaType":
+    def value_of(cls, value: str) -> ApiProviderSchemaType:
         """
         Get value of given mode.

@@ -102,7 +104,7 @@ class ApiProviderAuthType(StrEnum):
     API_KEY_QUERY = auto()

     @classmethod
-    def value_of(cls, value: str) -> "ApiProviderAuthType":
+    def value_of(cls, value: str) -> ApiProviderAuthType:
         """
         Get value of given mode.

@@ -307,7 +309,7 @@ class ToolParameter(PluginParameter):
         typ: ToolParameterType,
         required: bool,
         options: list[str] | None = None,
-    ) -> "ToolParameter":
+    ) -> ToolParameter:
         """
         get a simple tool parameter

@@ -429,14 +431,14 @@ class ToolInvokeMeta(BaseModel):
     tool_config: dict | None = None

     @classmethod
-    def empty(cls) -> "ToolInvokeMeta":
+    def empty(cls) -> ToolInvokeMeta:
         """
         Get an empty instance of ToolInvokeMeta
         """
         return cls(time_cost=0.0, error=None, tool_config={})

     @classmethod
-    def error_instance(cls, error: str) -> "ToolInvokeMeta":
+    def error_instance(cls, error: str) -> ToolInvokeMeta:
         """
         Get an instance of ToolInvokeMeta with error
         """
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import base64
 import json
 import logging
@@ -118,7 +120,7 @@ class MCPTool(Tool):
             for item in json_list:
                 yield self.create_json_message(item)

-    def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
+    def fork_tool_runtime(self, runtime: ToolRuntime) -> MCPTool:
         return MCPTool(
             entity=self.entity,
             runtime=runtime,
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from collections.abc import Generator
 from typing import Any

@@ -46,7 +48,7 @@ class PluginTool(Tool):
             message_id=message_id,
         )

-    def fork_tool_runtime(self, runtime: ToolRuntime) -> "PluginTool":
+    def fork_tool_runtime(self, runtime: ToolRuntime) -> PluginTool:
         return PluginTool(
             entity=self.entity,
             runtime=runtime,
@@ -7,12 +7,12 @@ import time
 from configs import dify_config


-def sign_tool_file(tool_file_id: str, extension: str) -> str:
+def sign_tool_file(tool_file_id: str, extension: str, for_external: bool = True) -> str:
     """
     sign file to get a temporary url for plugin access
     """
-    # Use internal URL for plugin/tool file access in Docker environments
-    base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
+    # Use internal URL for plugin/tool file access in Docker environments, unless for_external is True
+    base_url = dify_config.FILES_URL if for_external else (dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL)
     file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}"

     timestamp = str(int(time.time()))
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from collections.abc import Mapping

 from pydantic import Field
@@ -47,7 +49,7 @@ class WorkflowToolProviderController(ToolProviderController):
         self.provider_id = provider_id

     @classmethod
-    def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController":
+    def from_db(cls, db_provider: WorkflowToolProvider) -> WorkflowToolProviderController:
         with session_factory.create_session() as session, session.begin():
             app = session.get(App, db_provider.app_id)
             if not app:
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import json
 import logging
 from collections.abc import Generator, Mapping, Sequence
@@ -181,7 +183,7 @@ class WorkflowTool(Tool):
                 return found
         return None

-    def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
+    def fork_tool_runtime(self, runtime: ToolRuntime) -> WorkflowTool:
         """
         fork a new tool with metadata
         :return: the new tool

|
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
 from collections.abc import Mapping
 from enum import StrEnum
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any

 from core.file.models import File

@@ -52,7 +54,7 @@ class SegmentType(StrEnum):
         return self in _ARRAY_TYPES

     @classmethod
-    def infer_segment_type(cls, value: Any) -> Optional["SegmentType"]:
+    def infer_segment_type(cls, value: Any) -> SegmentType | None:
         """
         Attempt to infer the `SegmentType` based on the Python type of the `value` parameter.

@@ -173,7 +175,7 @@ class SegmentType(StrEnum):
         raise AssertionError("this statement should be unreachable.")

     @staticmethod
-    def cast_value(value: Any, type_: "SegmentType"):
+    def cast_value(value: Any, type_: SegmentType):
         # Cast Python's `bool` type to `int` when the runtime type requires
         # an integer or number.
         #
@@ -193,7 +195,7 @@ class SegmentType(StrEnum):
             return [int(i) for i in value]
         return value

-    def exposed_type(self) -> "SegmentType":
+    def exposed_type(self) -> SegmentType:
         """Returns the type exposed to the frontend.

         The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so these are returned as `NUMBER` here.
@@ -202,7 +204,7 @@ class SegmentType(StrEnum):
             return SegmentType.NUMBER
         return self

-    def element_type(self) -> "SegmentType | None":
+    def element_type(self) -> SegmentType | None:
         """Return the element type of the current segment type, or `None` if the element type is undefined.

         Raises:
@@ -217,7 +219,7 @@ class SegmentType(StrEnum):
         return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)

     @staticmethod
-    def get_zero_value(t: "SegmentType"):
+    def get_zero_value(t: SegmentType):
         # Lazy import to avoid circular dependency
         from factories import variable_factory

@@ -5,6 +5,8 @@ Models are independent of the storage mechanism and don't contain
 implementation details like tenant_id, app_id, etc.
 """

+from __future__ import annotations
+
 from collections.abc import Mapping
 from datetime import datetime
 from typing import Any
@@ -59,7 +61,7 @@ class WorkflowExecution(BaseModel):
         graph: Mapping[str, Any],
         inputs: Mapping[str, Any],
         started_at: datetime,
-    ) -> "WorkflowExecution":
+    ) -> WorkflowExecution:
         return WorkflowExecution(
             id_=id_,
             workflow_id=workflow_id,
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import logging
 from collections import defaultdict
 from collections.abc import Mapping, Sequence
@@ -175,7 +177,7 @@ class Graph:
     def _create_node_instances(
         cls,
         node_configs_map: dict[str, dict[str, object]],
-        node_factory: "NodeFactory",
+        node_factory: NodeFactory,
     ) -> dict[str, Node]:
         """
         Create node instances from configurations using the node factory.
@@ -197,7 +199,7 @@ class Graph:
         return nodes

     @classmethod
-    def new(cls) -> "GraphBuilder":
+    def new(cls) -> GraphBuilder:
         """Create a fluent builder for assembling a graph programmatically."""

         return GraphBuilder(graph_cls=cls)
@@ -284,9 +286,9 @@ class Graph:
         cls,
         *,
         graph_config: Mapping[str, object],
-        node_factory: "NodeFactory",
+        node_factory: NodeFactory,
         root_node_id: str | None = None,
-    ) -> "Graph":
+    ) -> Graph:
         """
         Initialize graph

@@ -383,7 +385,7 @@ class GraphBuilder:
         self._edges: list[Edge] = []
         self._edge_counter = 0

-    def add_root(self, node: Node) -> "GraphBuilder":
+    def add_root(self, node: Node) -> GraphBuilder:
         """Register the root node. Must be called exactly once."""

         if self._nodes:
@@ -398,7 +400,7 @@ class GraphBuilder:
         *,
         from_node_id: str | None = None,
         source_handle: str = "source",
-    ) -> "GraphBuilder":
+    ) -> GraphBuilder:
         """Append a node and connect it from the specified predecessor."""

         if not self._nodes:
@@ -419,7 +421,7 @@ class GraphBuilder:

         return self

-    def connect(self, *, tail: str, head: str, source_handle: str = "source") -> "GraphBuilder":
+    def connect(self, *, tail: str, head: str, source_handle: str = "source") -> GraphBuilder:
         """Connect two existing nodes without adding a new node."""

         if tail not in self._nodes_by_id:
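Every `GraphBuilder` method in the hunks above returns the builder itself, which is exactly what the now-unquoted `-> GraphBuilder` annotations describe. A simplified sketch of that fluent pattern (not the real Dify API):

```python
from __future__ import annotations


class GraphBuilder:
    """Fluent builder: every mutator returns self so calls can be chained."""

    def __init__(self) -> None:
        self.nodes: list[str] = []
        self.edges: list[tuple[str, str]] = []

    def add_root(self, node: str) -> GraphBuilder:
        self.nodes.append(node)
        return self

    def add_node(self, node: str, *, from_node: str) -> GraphBuilder:
        self.nodes.append(node)
        self.edges.append((from_node, node))
        return self


graph = GraphBuilder().add_root("start").add_node("llm", from_node="start")
assert graph.edges == [("start", "llm")]
```

Because each method's return annotation names the class being defined, this is precisely the case where `from __future__ import annotations` removes the need for quoting.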
@@ -5,6 +5,8 @@ This engine uses a modular architecture with separated packages following
 Domain-Driven Design principles for improved maintainability and testability.
 """

+from __future__ import annotations
+
 import contextvars
 import logging
 import queue
@@ -232,7 +234,7 @@ class GraphEngine:
     ) -> None:
         layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel)

-    def layer(self, layer: GraphEngineLayer) -> "GraphEngine":
+    def layer(self, layer: GraphEngineLayer) -> GraphEngine:
         """Add a layer for extending functionality."""
         self._layers.append(layer)
         self._bind_layer_context(layer)
@@ -2,6 +2,8 @@
 Factory for creating ReadyQueue instances from serialized state.
 """

+from __future__ import annotations
+
 from typing import TYPE_CHECKING

 from .in_memory import InMemoryReadyQueue
@@ -11,7 +13,7 @@ if TYPE_CHECKING:
     from .protocol import ReadyQueue


-def create_ready_queue_from_state(state: ReadyQueueState) -> "ReadyQueue":
+def create_ready_queue_from_state(state: ReadyQueueState) -> ReadyQueue:
    """
    Create a ReadyQueue instance from a serialized state.

@@ -5,6 +5,8 @@ This module contains the private ResponseSession class used internally
 by ResponseStreamCoordinator to manage streaming sessions.
 """

+from __future__ import annotations
+
 from dataclasses import dataclass

 from core.workflow.nodes.answer.answer_node import AnswerNode
@@ -27,7 +29,7 @@ class ResponseSession:
     index: int = 0  # Current position in the template segments

     @classmethod
-    def from_node(cls, node: Node) -> "ResponseSession":
+    def from_node(cls, node: Node) -> ResponseSession:
         """
         Create a ResponseSession from an AnswerNode or EndNode.

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import json
 from collections.abc import Generator, Mapping, Sequence
 from typing import TYPE_CHECKING, Any, cast
@@ -167,7 +169,7 @@ class AgentNode(Node[AgentNodeData]):
         variable_pool: VariablePool,
         node_data: AgentNodeData,
         for_log: bool = False,
-        strategy: "PluginAgentStrategy",
+        strategy: PluginAgentStrategy,
     ) -> dict[str, Any]:
         """
         Generate parameters based on the given tool parameters, variable pool, and node data.
@@ -328,7 +330,7 @@ class AgentNode(Node[AgentNodeData]):
     def _generate_credentials(
         self,
         parameters: dict[str, Any],
-    ) -> "InvokeCredentials":
+    ) -> InvokeCredentials:
         """
         Generate credentials based on the given agent parameters.
         """
@@ -442,9 +444,7 @@ class AgentNode(Node[AgentNodeData]):
                 model_schema.features.remove(feature)
         return model_schema

-    def _filter_mcp_type_tool(
-        self, strategy: "PluginAgentStrategy", tools: list[dict[str, Any]]
-    ) -> list[dict[str, Any]]:
+    def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
         """
         Filter MCP type tool
         :param strategy: plugin agent strategy
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import json
 from abc import ABC
 from builtins import type as type_
@@ -111,7 +113,7 @@ class DefaultValue(BaseModel):
             raise DefaultValueTypeError(f"Cannot convert to number: {value}")

     @model_validator(mode="after")
-    def validate_value_type(self) -> "DefaultValue":
+    def validate_value_type(self) -> DefaultValue:
         # Type validation configuration
         type_validators = {
             DefaultValueType.STRING: {
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import importlib
 import logging
 import operator
@@ -59,7 +61,7 @@ logger = logging.getLogger(__name__)


 class Node(Generic[NodeDataT]):
-    node_type: ClassVar["NodeType"]
+    node_type: ClassVar[NodeType]
     execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
     _node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData

@@ -198,14 +200,14 @@ class Node(Generic[NodeDataT]):
         return None

     # Global registry populated via __init_subclass__
-    _registry: ClassVar[dict["NodeType", dict[str, type["Node"]]]] = {}
+    _registry: ClassVar[dict[NodeType, dict[str, type[Node]]]] = {}

     def __init__(
         self,
         id: str,
         config: Mapping[str, Any],
-        graph_init_params: "GraphInitParams",
-        graph_runtime_state: "GraphRuntimeState",
+        graph_init_params: GraphInitParams,
+        graph_runtime_state: GraphRuntimeState,
     ) -> None:
         self._graph_init_params = graph_init_params
         self.id = id
@@ -241,7 +243,7 @@ class Node(Generic[NodeDataT]):
             return

     @property
-    def graph_init_params(self) -> "GraphInitParams":
+    def graph_init_params(self) -> GraphInitParams:
         return self._graph_init_params

     @property
@@ -457,7 +459,7 @@ class Node(Generic[NodeDataT]):
         raise NotImplementedError("subclasses of BaseNode must implement `version` method.")

     @classmethod
-    def get_node_type_classes_mapping(cls) -> Mapping["NodeType", Mapping[str, type["Node"]]]:
+    def get_node_type_classes_mapping(cls) -> Mapping[NodeType, Mapping[str, type[Node]]]:
         """Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry.

         Import all modules under core.workflow.nodes so subclasses register themselves on import.
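The `_registry` touched by the base-node hunk is described as "populated via `__init_subclass__`". A minimal sketch of that registration pattern (simplified types and names, not the actual implementation):

```python
from __future__ import annotations

from typing import ClassVar


class Node:
    # {node_type: {version: subclass}}, filled in automatically as
    # subclasses are defined.
    _registry: ClassVar[dict[str, dict[str, type[Node]]]] = {}

    node_type: ClassVar[str] = "base"

    @classmethod
    def version(cls) -> str:
        return "1"

    def __init_subclass__(cls, **kwargs):
        # Called once for each subclass, after its class body has run.
        super().__init_subclass__(**kwargs)
        Node._registry.setdefault(cls.node_type, {})[cls.version()] = cls


class LLMNode(Node):
    node_type = "llm"


assert Node._registry["llm"]["1"] is LLMNode
```

The base class itself never triggers `__init_subclass__`, so only concrete subclasses end up in the registry.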
@@ -4,6 +4,8 @@ This module provides a unified template structure for both Answer and End nodes,
 similar to SegmentGroup but focused on template representation without values.
 """

+from __future__ import annotations
+
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
 from dataclasses import dataclass
@@ -58,7 +60,7 @@ class Template:
     segments: list[TemplateSegmentUnion]

     @classmethod
-    def from_answer_template(cls, template_str: str) -> "Template":
+    def from_answer_template(cls, template_str: str) -> Template:
         """Create a Template from an Answer node template string.

         Example:
@@ -107,7 +109,7 @@ class Template:
         return cls(segments=segments)

     @classmethod
-    def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> "Template":
+    def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> Template:
         """Create a Template from an End node outputs configuration.

         End nodes are treated as templates of concatenated variables with newlines.
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import base64
 import io
 import json
@@ -113,7 +115,7 @@ class LLMNode(Node[LLMNodeData]):

     # Instance attributes specific to LLMNode.
     # Output variable for file
-    _file_outputs: list["File"]
+    _file_outputs: list[File]

     _llm_file_saver: LLMFileSaver

@@ -121,8 +123,8 @@ class LLMNode(Node[LLMNodeData]):
         self,
         id: str,
         config: Mapping[str, Any],
-        graph_init_params: "GraphInitParams",
-        graph_runtime_state: "GraphRuntimeState",
+        graph_init_params: GraphInitParams,
+        graph_runtime_state: GraphRuntimeState,
         *,
         llm_file_saver: LLMFileSaver | None = None,
     ):
@@ -361,7 +363,7 @@ class LLMNode(Node[LLMNodeData]):
         structured_output_enabled: bool,
         structured_output: Mapping[str, Any] | None = None,
         file_saver: LLMFileSaver,
-        file_outputs: list["File"],
+        file_outputs: list[File],
         node_id: str,
         node_type: NodeType,
         reasoning_format: Literal["separated", "tagged"] = "tagged",
@@ -415,7 +417,7 @@ class LLMNode(Node[LLMNodeData]):
         *,
         invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None],
         file_saver: LLMFileSaver,
-        file_outputs: list["File"],
+        file_outputs: list[File],
         node_id: str,
         node_type: NodeType,
         reasoning_format: Literal["separated", "tagged"] = "tagged",
@@ -525,7 +527,7 @@ class LLMNode(Node[LLMNodeData]):
         )

     @staticmethod
-    def _image_file_to_markdown(file: "File", /):
+    def _image_file_to_markdown(file: File, /):
         text_chunk = f"})"
         return text_chunk

@@ -774,7 +776,7 @@ class LLMNode(Node[LLMNodeData]):
     def fetch_prompt_messages(
         *,
         sys_query: str | None = None,
-        sys_files: Sequence["File"],
+        sys_files: Sequence[File],
         context: str | None = None,
         memory: TokenBufferMemory | None = None,
         model_config: ModelConfigWithCredentialsEntity,
@@ -785,7 +787,7 @@ class LLMNode(Node[LLMNodeData]):
         variable_pool: VariablePool,
         jinja2_variables: Sequence[VariableSelector],
         tenant_id: str,
-        context_files: list["File"] | None = None,
+        context_files: list[File] | None = None,
     ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
         prompt_messages: list[PromptMessage] = []

@@ -1137,7 +1139,7 @@ class LLMNode(Node[LLMNodeData]):
         *,
         invoke_result: LLMResult | LLMResultWithStructuredOutput,
         saver: LLMFileSaver,
-        file_outputs: list["File"],
+        file_outputs: list[File],
         reasoning_format: Literal["separated", "tagged"] = "tagged",
         request_latency: float | None = None,
     ) -> ModelInvokeCompletedEvent:
@@ -1179,7 +1181,7 @@ class LLMNode(Node[LLMNodeData]):
         *,
         content: ImagePromptMessageContent,
         file_saver: LLMFileSaver,
-    ) -> "File":
+    ) -> File:
         """_save_multimodal_output saves multi-modal contents generated by LLM plugins.

         There are two kinds of multimodal outputs:
@@ -1229,7 +1231,7 @@ class LLMNode(Node[LLMNodeData]):
         *,
         contents: str | list[PromptMessageContentUnionTypes] | None,
         file_saver: LLMFileSaver,
-        file_outputs: list["File"],
+        file_outputs: list[File],
     ) -> Generator[str, None, None]:
         """Convert intermediate prompt messages into strings and yield them to the caller.

@@ -113,7 +113,6 @@ class DifyNodeFactory(NodeFactory):
                 code_providers=self._code_providers,
                 code_limits=self._code_limits,
             )

         if node_type == NodeType.TEMPLATE_TRANSFORM:
             return TemplateTransformNode(
                 id=node_id,
@@ -1,28 +0,0 @@
-from sqlalchemy import select
-from sqlalchemy.orm import Session
-
-from core.variables.variables import Variable
-from extensions.ext_database import db
-from models import ConversationVariable
-
-from .exc import VariableOperatorNodeError
-
-
-class ConversationVariableUpdaterImpl:
-    def update(self, conversation_id: str, variable: Variable):
-        stmt = select(ConversationVariable).where(
-            ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
-        )
-        with Session(db.engine) as session:
-            row = session.scalar(stmt)
-            if not row:
-                raise VariableOperatorNodeError("conversation variable not found in the database")
-            row.data = variable.model_dump_json()
-            session.commit()
-
-    def flush(self):
-        pass
-
-
-def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl:
-    return ConversationVariableUpdaterImpl()
@@ -1,9 +1,8 @@
-from collections.abc import Callable, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, TypeAlias
+from collections.abc import Mapping, Sequence
+from typing import TYPE_CHECKING, Any

 from core.variables import SegmentType, Variable
 from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
-from core.workflow.conversation_variable_updater import ConversationVariableUpdater
 from core.workflow.entities import GraphInitParams
 from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
 from core.workflow.node_events import NodeRunResult
@@ -11,19 +10,14 @@ from core.workflow.nodes.base.node import Node
 from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
 from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError

-from ..common.impl import conversation_variable_updater_factory
 from .node_data import VariableAssignerData, WriteMode

 if TYPE_CHECKING:
     from core.workflow.runtime import GraphRuntimeState


-_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
-
-
 class VariableAssignerNode(Node[VariableAssignerData]):
     node_type = NodeType.VARIABLE_ASSIGNER
-    _conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY

     def __init__(
         self,
@@ -31,7 +25,6 @@ class VariableAssignerNode(Node[VariableAssignerData]):
         config: Mapping[str, Any],
         graph_init_params: "GraphInitParams",
         graph_runtime_state: "GraphRuntimeState",
-        conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory,
     ):
         super().__init__(
             id=id,
@@ -39,7 +32,6 @@ class VariableAssignerNode(Node[VariableAssignerData]):
             graph_init_params=graph_init_params,
             graph_runtime_state=graph_runtime_state,
         )
-        self._conv_var_updater_factory = conv_var_updater_factory

     @classmethod
     def version(cls) -> str:
@@ -96,16 +88,7 @@ class VariableAssignerNode(Node[VariableAssignerData]):
         # Over write the variable.
         self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)
-
-        # TODO: Move database operation to the pipeline.
-        # Update conversation variable.
-        conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
-        if not conversation_id:
-            raise VariableOperatorNodeError("conversation_id not found")
-        conv_var_updater = self._conv_var_updater_factory()
-        conv_var_updater.update(conversation_id=conversation_id.text, variable=updated_variable)
-        conv_var_updater.flush()
         updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)]

         return NodeRunResult(
             status=WorkflowNodeExecutionStatus.SUCCEEDED,
             inputs={
@@ -1,24 +1,20 @@
 import json
 from collections.abc import Mapping, MutableMapping, Sequence
-from typing import Any, cast
+from typing import TYPE_CHECKING, Any

-from core.app.entities.app_invoke_entities import InvokeFrom
 from core.variables import SegmentType, Variable
 from core.variables.consts import SELECTORS_LENGTH
 from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
-from core.workflow.conversation_variable_updater import ConversationVariableUpdater
 from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
 from core.workflow.node_events import NodeRunResult
 from core.workflow.nodes.base.node import Node
 from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
 from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
-from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory

 from . import helpers
 from .entities import VariableAssignerNodeData, VariableOperationItem
 from .enums import InputType, Operation
 from .exc import (
-    ConversationIDNotFoundError,
     InputTypeNotSupportedError,
     InvalidDataError,
     InvalidInputValueError,
@@ -26,6 +22,10 @@ from .exc import (
     VariableNotFoundError,
 )

+if TYPE_CHECKING:
+    from core.workflow.entities import GraphInitParams
+    from core.workflow.runtime import GraphRuntimeState
+

 def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem):
     selector_node_id = item.variable_selector[0]
@@ -53,6 +53,20 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
 class VariableAssignerNode(Node[VariableAssignerNodeData]):
     node_type = NodeType.VARIABLE_ASSIGNER

+    def __init__(
+        self,
+        id: str,
+        config: Mapping[str, Any],
+        graph_init_params: "GraphInitParams",
+        graph_runtime_state: "GraphRuntimeState",
+    ):
+        super().__init__(
+            id=id,
+            config=config,
+            graph_init_params=graph_init_params,
+            graph_runtime_state=graph_runtime_state,
+        )
+
     def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
         """
         Check if this Variable Assigner node blocks the output of specific variables.
@@ -70,9 +84,6 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):

         return False

-    def _conv_var_updater_factory(self) -> ConversationVariableUpdater:
-        return conversation_variable_updater_factory()
-

     @classmethod
     def version(cls) -> str:
|
||||||
return "2"
|
return "2"
|
||||||
@ -179,26 +190,12 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
|||||||
# remove the duplicated items first.
|
# remove the duplicated items first.
|
||||||
updated_variable_selectors = list(set(map(tuple, updated_variable_selectors)))
|
updated_variable_selectors = list(set(map(tuple, updated_variable_selectors)))
|
||||||
|
|
||||||
conv_var_updater = self._conv_var_updater_factory()
|
|
||||||
# Update variables
|
|
||||||
for selector in updated_variable_selectors:
|
for selector in updated_variable_selectors:
|
||||||
variable = self.graph_runtime_state.variable_pool.get(selector)
|
variable = self.graph_runtime_state.variable_pool.get(selector)
|
||||||
if not isinstance(variable, Variable):
|
if not isinstance(variable, Variable):
|
||||||
raise VariableNotFoundError(variable_selector=selector)
|
raise VariableNotFoundError(variable_selector=selector)
|
||||||
process_data[variable.name] = variable.value
|
process_data[variable.name] = variable.value
|
||||||
|
|
||||||
if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID:
|
|
||||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
|
|
||||||
if not conversation_id:
|
|
||||||
if self.invoke_from != InvokeFrom.DEBUGGER:
|
|
||||||
raise ConversationIDNotFoundError
|
|
||||||
else:
|
|
||||||
conversation_id = conversation_id.value
|
|
||||||
conv_var_updater.update(
|
|
||||||
conversation_id=cast(str, conversation_id),
|
|
||||||
variable=variable,
|
|
||||||
)
|
|
||||||
conv_var_updater.flush()
|
|
||||||
updated_variables = [
|
updated_variables = [
|
||||||
common_helpers.variable_to_processed_data(selector, seg)
|
common_helpers.variable_to_processed_data(selector, seg)
|
||||||
for selector in updated_variable_selectors
|
for selector in updated_variable_selectors
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import abc
 from collections.abc import Mapping
 from typing import Any, Protocol
@@ -23,7 +25,7 @@ class DraftVariableSaverFactory(Protocol):
         node_type: NodeType,
         node_execution_id: str,
         enclosing_node_id: str | None = None,
-    ) -> "DraftVariableSaver":
+    ) -> DraftVariableSaver:
         pass
 
 
@@ -1,4 +1,4 @@
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
 from typing import Any, Protocol
 
 from core.model_runtime.entities.llm_entities import LLMUsage
@@ -9,7 +9,7 @@ from core.workflow.system_variable import SystemVariableReadOnlyView
 class ReadOnlyVariablePool(Protocol):
     """Read-only interface for VariablePool."""
 
-    def get(self, node_id: str, variable_key: str) -> Segment | None:
+    def get(self, selector: Sequence[str], /) -> Segment | None:
         """Get a variable value (read-only)."""
         ...
 
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
 from copy import deepcopy
 from typing import Any
 
@@ -18,9 +18,9 @@ class ReadOnlyVariablePoolWrapper:
     def __init__(self, variable_pool: VariablePool) -> None:
         self._variable_pool = variable_pool
 
-    def get(self, node_id: str, variable_key: str) -> Segment | None:
+    def get(self, selector: Sequence[str], /) -> Segment | None:
         """Return a copy of a variable value if present."""
-        value = self._variable_pool.get([node_id, variable_key])
+        value = self._variable_pool.get(selector)
         return deepcopy(value) if value is not None else None
 
     def get_all_by_node(self, node_id: str) -> Mapping[str, object]:
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import re
 from collections import defaultdict
 from collections.abc import Mapping, Sequence
@@ -267,6 +269,6 @@ class VariablePool(BaseModel):
         self.add(selector, value)
 
     @classmethod
-    def empty(cls) -> "VariablePool":
+    def empty(cls) -> VariablePool:
         """Create an empty variable pool."""
         return cls(system_variables=SystemVariable.empty())
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from collections.abc import Mapping, Sequence
 from types import MappingProxyType
 from typing import Any
@@ -70,7 +72,7 @@ class SystemVariable(BaseModel):
         return data
 
     @classmethod
-    def empty(cls) -> "SystemVariable":
+    def empty(cls) -> SystemVariable:
         return cls()
 
     def to_dict(self) -> dict[SystemVariableKey, Any]:
@@ -114,7 +116,7 @@ class SystemVariable(BaseModel):
             d[SystemVariableKey.TIMESTAMP] = self.timestamp
         return d
 
-    def as_view(self) -> "SystemVariableReadOnlyView":
+    def as_view(self) -> SystemVariableReadOnlyView:
         return SystemVariableReadOnlyView(self)
 
 
@@ -3,8 +3,9 @@
 set -e
 
 # Set UTF-8 encoding to address potential encoding issues in containerized environments
-export LANG=${LANG:-en_US.UTF-8}
-export LC_ALL=${LC_ALL:-en_US.UTF-8}
+# Use C.UTF-8 which is universally available in all containers
+export LANG=${LANG:-C.UTF-8}
+export LC_ALL=${LC_ALL:-C.UTF-8}
 export PYTHONIOENCODING=${PYTHONIOENCODING:-utf-8}
 
 if [[ "${MIGRATION_ENABLED}" == "true" ]]; then
@@ -6,6 +6,7 @@ BASE_CORS_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE, HEAD
 SERVICE_API_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, "Authorization")
 AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN)
 FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)
+EMBED_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE)
 EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id")
 
 
@@ -42,10 +43,28 @@ def init_app(app: DifyApp):
 
     _apply_cors_once(
         web_bp,
-        resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
-        supports_credentials=True,
-        allow_headers=list(AUTHENTICATED_HEADERS),
-        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
+        resources={
+            # Embedded bot endpoints (unauthenticated, cross-origin safe)
+            r"^/chat-messages$": {
+                "origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS,
+                "supports_credentials": False,
+                "allow_headers": list(EMBED_HEADERS),
+                "methods": ["GET", "POST", "OPTIONS"],
+            },
+            r"^/chat-messages/.*": {
+                "origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS,
+                "supports_credentials": False,
+                "allow_headers": list(EMBED_HEADERS),
+                "methods": ["GET", "POST", "OPTIONS"],
+            },
+            # Default web application endpoints (authenticated)
+            r"/*": {
+                "origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS,
+                "supports_credentials": True,
+                "allow_headers": list(AUTHENTICATED_HEADERS),
+                "methods": ["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
+            },
+        },
         expose_headers=list(EXPOSED_HEADERS),
     )
     app.register_blueprint(web_bp)
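The CORS change above replaces a single blanket rule with an ordered pattern-to-policy map, so the credential-free embed routes must be listed before the catch-all `/*` rule. A minimal stdlib sketch of that first-match dispatch (the policy dicts are illustrative, not flask-cors's internal representation):

```python
import re

# Ordered pattern -> policy map; first matching pattern wins, so the
# specific /chat-messages rules must precede the catch-all "/*" rule.
CORS_POLICIES = {
    r"^/chat-messages$": {"supports_credentials": False, "methods": ["GET", "POST", "OPTIONS"]},
    r"^/chat-messages/.*": {"supports_credentials": False, "methods": ["GET", "POST", "OPTIONS"]},
    # "/*" as a regex matches any path (zero or more leading slashes)
    r"/*": {"supports_credentials": True, "methods": ["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"]},
}


def policy_for(path: str) -> dict:
    """Return the policy of the first pattern that matches the request path."""
    for pattern, policy in CORS_POLICIES.items():
        if re.match(pattern, path):
            return policy
    return {}


print(policy_for("/chat-messages")["supports_credentials"])  # False
print(policy_for("/apps")["supports_credentials"])  # True
```

Python dicts preserve insertion order, which is what makes the "most specific rule first" ordering reliable here.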
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import logging
 import os
 import threading
@@ -33,7 +35,7 @@ class AliyunLogStore:
     Ensures only one instance exists to prevent multiple PG connection pools.
     """
 
-    _instance: "AliyunLogStore | None" = None
+    _instance: AliyunLogStore | None = None
     _initialized: bool = False
 
     # Track delayed PG connection for newly created projects
@@ -66,7 +68,7 @@ class AliyunLogStore:
         "\t",
     ]
 
-    def __new__(cls) -> "AliyunLogStore":
+    def __new__(cls) -> AliyunLogStore:
        """Implement singleton pattern."""
         if cls._instance is None:
             cls._instance = super().__new__(cls)
@@ -5,6 +5,8 @@ automatic cleanup, backup and restore.
 Supports complete lifecycle management for knowledge base files.
 """
 
+from __future__ import annotations
+
 import json
 import logging
 import operator
@@ -48,7 +50,7 @@ class FileMetadata:
         return data
 
     @classmethod
-    def from_dict(cls, data: dict) -> "FileMetadata":
+    def from_dict(cls, data: dict) -> FileMetadata:
         """Create instance from dictionary"""
         data = data.copy()
         data["created_at"] = datetime.fromisoformat(data["created_at"])
@@ -1,93 +1,85 @@
-from flask_restx import Namespace, fields
-
-from libs.helper import TimestampField
-
-upload_config_fields = {
-    "file_size_limit": fields.Integer,
-    "batch_count_limit": fields.Integer,
-    "image_file_size_limit": fields.Integer,
-    "video_file_size_limit": fields.Integer,
-    "audio_file_size_limit": fields.Integer,
-    "workflow_file_upload_limit": fields.Integer,
-    "image_file_batch_limit": fields.Integer,
-    "single_chunk_attachment_limit": fields.Integer,
-}
-
-
-def build_upload_config_model(api_or_ns: Namespace):
-    """Build the upload config model for the API or Namespace.
-
-    Args:
-        api_or_ns: Flask-RestX Api or Namespace instance
-
-    Returns:
-        The registered model
-    """
-    return api_or_ns.model("UploadConfig", upload_config_fields)
-
-
-file_fields = {
-    "id": fields.String,
-    "name": fields.String,
-    "size": fields.Integer,
-    "extension": fields.String,
-    "mime_type": fields.String,
-    "created_by": fields.String,
-    "created_at": TimestampField,
-    "preview_url": fields.String,
-    "source_url": fields.String,
-}
-
-
-def build_file_model(api_or_ns: Namespace):
-    """Build the file model for the API or Namespace.
-
-    Args:
-        api_or_ns: Flask-RestX Api or Namespace instance
-
-    Returns:
-        The registered model
-    """
-    return api_or_ns.model("File", file_fields)
-
-
-remote_file_info_fields = {
-    "file_type": fields.String(attribute="file_type"),
-    "file_length": fields.Integer(attribute="file_length"),
-}
-
-
-def build_remote_file_info_model(api_or_ns: Namespace):
-    """Build the remote file info model for the API or Namespace.
-
-    Args:
-        api_or_ns: Flask-RestX Api or Namespace instance
-
-    Returns:
-        The registered model
-    """
-    return api_or_ns.model("RemoteFileInfo", remote_file_info_fields)
-
-
-file_fields_with_signed_url = {
-    "id": fields.String,
-    "name": fields.String,
-    "size": fields.Integer,
-    "extension": fields.String,
-    "url": fields.String,
-    "mime_type": fields.String,
-    "created_by": fields.String,
-    "created_at": TimestampField,
-}
-
-
-def build_file_with_signed_url_model(api_or_ns: Namespace):
-    """Build the file with signed URL model for the API or Namespace.
-
-    Args:
-        api_or_ns: Flask-RestX Api or Namespace instance
-
-    Returns:
-        The registered model
-    """
-    return api_or_ns.model("FileWithSignedUrl", file_fields_with_signed_url)
+from __future__ import annotations
+
+from datetime import datetime
+
+from pydantic import BaseModel, ConfigDict, field_validator
+
+
+class ResponseModel(BaseModel):
+    model_config = ConfigDict(
+        from_attributes=True,
+        extra="ignore",
+        populate_by_name=True,
+        serialize_by_alias=True,
+        protected_namespaces=(),
+    )
+
+
+def _to_timestamp(value: datetime | int | None) -> int | None:
+    if isinstance(value, datetime):
+        return int(value.timestamp())
+    return value
+
+
+class UploadConfig(ResponseModel):
+    file_size_limit: int
+    batch_count_limit: int
+    file_upload_limit: int | None = None
+    image_file_size_limit: int
+    video_file_size_limit: int
+    audio_file_size_limit: int
+    workflow_file_upload_limit: int
+    image_file_batch_limit: int
+    single_chunk_attachment_limit: int
+    attachment_image_file_size_limit: int | None = None
+
+
+class FileResponse(ResponseModel):
+    id: str
+    name: str
+    size: int
+    extension: str | None = None
+    mime_type: str | None = None
+    created_by: str | None = None
+    created_at: int | None = None
+    preview_url: str | None = None
+    source_url: str | None = None
+    original_url: str | None = None
+    user_id: str | None = None
+    tenant_id: str | None = None
+    conversation_id: str | None = None
+    file_key: str | None = None
+
+    @field_validator("created_at", mode="before")
+    @classmethod
+    def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
+        return _to_timestamp(value)
+
+
+class RemoteFileInfo(ResponseModel):
+    file_type: str
+    file_length: int
+
+
+class FileWithSignedUrl(ResponseModel):
+    id: str
+    name: str
+    size: int
+    extension: str | None = None
+    url: str | None = None
+    mime_type: str | None = None
+    created_by: str | None = None
+    created_at: int | None = None
+
+    @field_validator("created_at", mode="before")
+    @classmethod
+    def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
+        return _to_timestamp(value)
+
+
+__all__ = [
+    "FileResponse",
+    "FileWithSignedUrl",
+    "RemoteFileInfo",
+    "UploadConfig",
+]
@@ -2,6 +2,8 @@
 Broadcast channel for Pub/Sub messaging.
 """
 
+from __future__ import annotations
+
 import types
 from abc import abstractmethod
 from collections.abc import Iterator
@@ -129,6 +131,6 @@ class BroadcastChannel(Protocol):
     """
 
     @abstractmethod
-    def topic(self, topic: str) -> "Topic":
+    def topic(self, topic: str) -> Topic:
         """topic returns a `Topic` instance for the given topic name."""
         ...
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
 from redis import Redis
 
@@ -20,7 +22,7 @@ class BroadcastChannel:
     ):
         self._client = redis_client
 
-    def topic(self, topic: str) -> "Topic":
+    def topic(self, topic: str) -> Topic:
         return Topic(self._client, topic)
 
 
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
 from redis import Redis
 
@@ -18,7 +20,7 @@ class ShardedRedisBroadcastChannel:
     ):
         self._client = redis_client
 
-    def topic(self, topic: str) -> "ShardedTopic":
+    def topic(self, topic: str) -> ShardedTopic:
         return ShardedTopic(self._client, topic)
 
 
@@ -6,6 +6,8 @@ in Dify. It follows Domain-Driven Design principles with proper type hints and
 eliminates the need for repetitive language switching logic.
 """
 
+from __future__ import annotations
+
 from dataclasses import dataclass
 from enum import StrEnum, auto
 from typing import Any, Protocol
@@ -53,7 +55,7 @@ class EmailLanguage(StrEnum):
     ZH_HANS = "zh-Hans"
 
     @classmethod
-    def from_language_code(cls, language_code: str) -> "EmailLanguage":
+    def from_language_code(cls, language_code: str) -> EmailLanguage:
         """Convert a language code to EmailLanguage with fallback to English."""
         if language_code == "zh-Hans":
             return cls.ZH_HANS
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import json
 import re
 import uuid
@@ -5,7 +7,7 @@ from collections.abc import Mapping
 from datetime import datetime
 from decimal import Decimal
 from enum import StrEnum, auto
-from typing import TYPE_CHECKING, Any, Literal, Optional, cast
+from typing import TYPE_CHECKING, Any, Literal, cast
 from uuid import uuid4
 
 import sqlalchemy as sa
@@ -54,7 +56,7 @@ class AppMode(StrEnum):
     RAG_PIPELINE = "rag-pipeline"
 
     @classmethod
-    def value_of(cls, value: str) -> "AppMode":
+    def value_of(cls, value: str) -> AppMode:
         """
         Get value of given mode.
 
@@ -121,19 +123,19 @@ class App(Base):
         return ""
 
     @property
-    def site(self) -> Optional["Site"]:
+    def site(self) -> Site | None:
         site = db.session.query(Site).where(Site.app_id == self.id).first()
         return site
 
     @property
-    def app_model_config(self) -> Optional["AppModelConfig"]:
+    def app_model_config(self) -> AppModelConfig | None:
         if self.app_model_config_id:
             return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
 
         return None
 
     @property
-    def workflow(self) -> Optional["Workflow"]:
+    def workflow(self) -> Workflow | None:
         if self.workflow_id:
             from .workflow import Workflow
 
@@ -288,7 +290,7 @@ class App(Base):
         return deleted_tools
 
     @property
-    def tags(self) -> list["Tag"]:
+    def tags(self) -> list[Tag]:
         tags = (
             db.session.query(Tag)
             .join(TagBinding, Tag.id == TagBinding.tag_id)
@@ -1194,7 +1196,7 @@ class Message(Base):
         return json.loads(self.message_metadata) if self.message_metadata else {}
 
     @property
-    def agent_thoughts(self) -> list["MessageAgentThought"]:
+    def agent_thoughts(self) -> list[MessageAgentThought]:
         return (
             db.session.query(MessageAgentThought)
             .where(MessageAgentThought.message_id == self.id)
@@ -1307,7 +1309,7 @@ class Message(Base):
         }
 
     @classmethod
-    def from_dict(cls, data: dict[str, Any]) -> "Message":
+    def from_dict(cls, data: dict[str, Any]) -> Message:
         return cls(
             id=data["id"],
             app_id=data["app_id"],
@@ -1534,7 +1536,7 @@ class OperationLog(TypeBase):
     tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     action: Mapped[str] = mapped_column(String(255), nullable=False)
-    content: Mapped[Any] = mapped_column(sa.JSON)
+    content: Mapped[Any | None] = mapped_column(sa.JSON, nullable=True)
     created_at: Mapped[datetime] = mapped_column(
         sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
     )
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from datetime import datetime
 from enum import StrEnum, auto
 from functools import cached_property
@@ -19,7 +21,7 @@ class ProviderType(StrEnum):
     SYSTEM = auto()
 
     @staticmethod
-    def value_of(value: str) -> "ProviderType":
+    def value_of(value: str) -> ProviderType:
         for member in ProviderType:
             if member.value == value:
                 return member
@@ -37,7 +39,7 @@ class ProviderQuotaType(StrEnum):
     """hosted trial quota"""
 
    @staticmethod
-    def value_of(value: str) -> "ProviderQuotaType":
+    def value_of(value: str) -> ProviderQuotaType:
         for member in ProviderQuotaType:
             if member.value == value:
                 return member
@@ -76,7 +78,7 @@ class Provider(TypeBase):
 
     quota_type: Mapped[str | None] = mapped_column(String(40), nullable=True, server_default=text("''"), default="")
     quota_limit: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True, default=None)
-    quota_used: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, default=0)
+    quota_used: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True, default=0)
 
     created_at: Mapped[datetime] = mapped_column(
         DateTime, nullable=False, server_default=func.current_timestamp(), init=False
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import json
 from datetime import datetime
 from decimal import Decimal
@@ -167,11 +169,11 @@ class ApiToolProvider(TypeBase):
     )
 
     @property
-    def schema_type(self) -> "ApiProviderSchemaType":
+    def schema_type(self) -> ApiProviderSchemaType:
         return ApiProviderSchemaType.value_of(self.schema_type_str)
 
     @property
-    def tools(self) -> list["ApiToolBundle"]:
+    def tools(self) -> list[ApiToolBundle]:
         return [ApiToolBundle.model_validate(tool) for tool in json.loads(self.tools_str)]
 
     @property
@@ -267,7 +269,7 @@ class WorkflowToolProvider(TypeBase):
         return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
 
     @property
-    def parameter_configurations(self) -> list["WorkflowToolParameterConfiguration"]:
+    def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]:
         return [
             WorkflowToolParameterConfiguration.model_validate(config)
             for config in json.loads(self.parameter_configuration)
@@ -359,7 +361,7 @@ class MCPToolProvider(TypeBase):
         except (json.JSONDecodeError, TypeError):
             return []
 
-    def to_entity(self) -> "MCPProviderEntity":
+    def to_entity(self) -> MCPProviderEntity:
         """Convert to domain entity"""
         from core.entities.mcp_provider import MCPProviderEntity
 
@@ -533,5 +535,5 @@ class DeprecatedPublishedAppTool(TypeBase):
     )
 
     @property
-    def description_i18n(self) -> "I18nObject":
+    def description_i18n(self) -> I18nObject:
         return I18nObject.model_validate(json.loads(self.description))
@@ -415,7 +415,7 @@ class AppTrigger(TypeBase):
     node_id: Mapped[str | None] = mapped_column(String(64), nullable=False)
     trigger_type: Mapped[str] = mapped_column(EnumText(AppTriggerType, length=50), nullable=False)
     title: Mapped[str] = mapped_column(String(255), nullable=False)
-    provider_name: Mapped[str] = mapped_column(String(255), server_default="", default="")  # why it is nullable?
+    provider_name: Mapped[str | None] = mapped_column(String(255), nullable=True, server_default="", default="")
     status: Mapped[str] = mapped_column(
         EnumText(AppTriggerStatus, length=50), nullable=False, default=AppTriggerStatus.ENABLED
     )
@@ -1,9 +1,11 @@
+from __future__ import annotations
+
 import json
 import logging
 from collections.abc import Generator, Mapping, Sequence
 from datetime import datetime
 from enum import StrEnum
-from typing import TYPE_CHECKING, Any, Optional, Union, cast
+from typing import TYPE_CHECKING, Any, Union, cast
 from uuid import uuid4

 import sqlalchemy as sa
@@ -67,7 +69,7 @@ class WorkflowType(StrEnum):
     RAG_PIPELINE = "rag-pipeline"

     @classmethod
-    def value_of(cls, value: str) -> "WorkflowType":
+    def value_of(cls, value: str) -> WorkflowType:
         """
         Get value of given mode.

@@ -80,7 +82,7 @@ class WorkflowType(StrEnum):
         raise ValueError(f"invalid workflow type value {value}")

     @classmethod
-    def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType":
+    def from_app_mode(cls, app_mode: Union[str, AppMode]) -> WorkflowType:
         """
         Get workflow type from app mode.

@@ -181,7 +183,7 @@ class Workflow(Base): # bug
         rag_pipeline_variables: list[dict],
         marked_name: str = "",
         marked_comment: str = "",
-    ) -> "Workflow":
+    ) -> Workflow:
         workflow = Workflow()
         workflow.id = str(uuid4())
         workflow.tenant_id = tenant_id
@@ -619,7 +621,7 @@ class WorkflowRun(Base):
     finished_at: Mapped[datetime | None] = mapped_column(DateTime)
     exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)

-    pause: Mapped[Optional["WorkflowPause"]] = orm.relationship(
+    pause: Mapped[WorkflowPause | None] = orm.relationship(
         "WorkflowPause",
         primaryjoin="WorkflowRun.id == foreign(WorkflowPause.workflow_run_id)",
         uselist=False,
@@ -689,7 +691,7 @@ class WorkflowRun(Base):
         }

     @classmethod
-    def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun":
+    def from_dict(cls, data: dict[str, Any]) -> WorkflowRun:
         return cls(
             id=data.get("id"),
             tenant_id=data.get("tenant_id"),
@@ -841,7 +843,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
     created_by: Mapped[str] = mapped_column(StringUUID)
     finished_at: Mapped[datetime | None] = mapped_column(DateTime)

-    offload_data: Mapped[list["WorkflowNodeExecutionOffload"]] = orm.relationship(
+    offload_data: Mapped[list[WorkflowNodeExecutionOffload]] = orm.relationship(
         "WorkflowNodeExecutionOffload",
         primaryjoin="WorkflowNodeExecutionModel.id == foreign(WorkflowNodeExecutionOffload.node_execution_id)",
         uselist=True,
@@ -851,13 +853,13 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo

     @staticmethod
     def preload_offload_data(
-        query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"],
+        query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel],
     ):
         return query.options(orm.selectinload(WorkflowNodeExecutionModel.offload_data))

     @staticmethod
     def preload_offload_data_and_files(
-        query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"],
+        query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel],
     ):
         return query.options(
             orm.selectinload(WorkflowNodeExecutionModel.offload_data).options(
@@ -932,7 +934,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
         )
         return extras

-    def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]:
+    def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> WorkflowNodeExecutionOffload | None:
         return next(iter([i for i in self.offload_data if i.type_ == type_]), None)

     @property
@@ -1046,7 +1048,7 @@ class WorkflowNodeExecutionOffload(Base):
         back_populates="offload_data",
     )

-    file: Mapped[Optional["UploadFile"]] = orm.relationship(
+    file: Mapped[UploadFile | None] = orm.relationship(
         foreign_keys=[file_id],
         lazy="raise",
         uselist=False,
@@ -1064,7 +1066,7 @@ class WorkflowAppLogCreatedFrom(StrEnum):
     INSTALLED_APP = "installed-app"

     @classmethod
-    def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom":
+    def value_of(cls, value: str) -> WorkflowAppLogCreatedFrom:
         """
         Get value of given mode.

@@ -1181,7 +1183,7 @@ class ConversationVariable(TypeBase):
     )

     @classmethod
-    def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> "ConversationVariable":
+    def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> ConversationVariable:
         obj = cls(
             id=variable.id,
             app_id=app_id,
@@ -1334,7 +1336,7 @@ class WorkflowDraftVariable(Base):
     )

     # Relationship to WorkflowDraftVariableFile
-    variable_file: Mapped[Optional["WorkflowDraftVariableFile"]] = orm.relationship(
+    variable_file: Mapped[WorkflowDraftVariableFile | None] = orm.relationship(
         foreign_keys=[file_id],
         lazy="raise",
         uselist=False,
@@ -1504,8 +1506,9 @@ class WorkflowDraftVariable(Base):
         node_execution_id: str | None,
         description: str = "",
         file_id: str | None = None,
-    ) -> "WorkflowDraftVariable":
+    ) -> WorkflowDraftVariable:
         variable = WorkflowDraftVariable()
+        variable.id = str(uuid4())
         variable.created_at = naive_utc_now()
         variable.updated_at = naive_utc_now()
         variable.description = description
@@ -1526,7 +1529,7 @@ class WorkflowDraftVariable(Base):
         name: str,
         value: Segment,
         description: str = "",
-    ) -> "WorkflowDraftVariable":
+    ) -> WorkflowDraftVariable:
         variable = cls._new(
             app_id=app_id,
             node_id=CONVERSATION_VARIABLE_NODE_ID,
@@ -1547,7 +1550,7 @@ class WorkflowDraftVariable(Base):
         value: Segment,
         node_execution_id: str,
         editable: bool = False,
-    ) -> "WorkflowDraftVariable":
+    ) -> WorkflowDraftVariable:
         variable = cls._new(
             app_id=app_id,
             node_id=SYSTEM_VARIABLE_NODE_ID,
@@ -1570,7 +1573,7 @@ class WorkflowDraftVariable(Base):
         visible: bool = True,
         editable: bool = True,
         file_id: str | None = None,
-    ) -> "WorkflowDraftVariable":
+    ) -> WorkflowDraftVariable:
         variable = cls._new(
             app_id=app_id,
             node_id=node_id,
@@ -1666,7 +1669,7 @@ class WorkflowDraftVariableFile(Base):
     )

     # Relationship to UploadFile
-    upload_file: Mapped["UploadFile"] = orm.relationship(
+    upload_file: Mapped[UploadFile] = orm.relationship(
         foreign_keys=[upload_file_id],
         lazy="raise",
         uselist=False,
@@ -1733,7 +1736,7 @@ class WorkflowPause(DefaultFieldsMixin, Base):
     state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False)

     # Relationship to WorkflowRun
-    workflow_run: Mapped["WorkflowRun"] = orm.relationship(
+    workflow_run: Mapped[WorkflowRun] = orm.relationship(
         foreign_keys=[workflow_run_id],
         # require explicit preloading.
         lazy="raise",
@@ -1789,7 +1792,7 @@ class WorkflowPauseReason(DefaultFieldsMixin, Base):
     )

     @classmethod
-    def from_entity(cls, pause_reason: PauseReason) -> "WorkflowPauseReason":
+    def from_entity(cls, pause_reason: PauseReason) -> WorkflowPauseReason:
         if isinstance(pause_reason, HumanInputRequired):
             return cls(
                 type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id
@@ -11,13 +11,13 @@ from core.app.entities.app_invoke_entities import InvokeFrom
 from core.db.session_factory import session_factory
 from core.llm_generator.llm_generator import LLMGenerator
 from core.variables.types import SegmentType
-from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
 from extensions.ext_database import db
 from factories import variable_factory
 from libs.datetime_utils import naive_utc_now
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from models import Account, ConversationVariable
 from models.model import App, Conversation, EndUser, Message
+from services.conversation_variable_updater import ConversationVariableUpdater
 from services.errors.conversation import (
     ConversationNotExistsError,
     ConversationVariableNotExistsError,
@@ -337,7 +337,7 @@ class ConversationService:
         updated_variable = variable_factory.build_conversation_variable_from_mapping(updated_variable_dict)

         # Use the conversation variable updater to persist the changes
-        updater = conversation_variable_updater_factory()
+        updater = ConversationVariableUpdater(session_factory.get_session_maker())
         updater.update(conversation_id, updated_variable)
         updater.flush()

api/services/conversation_variable_updater.py (new file, 28 lines)
@@ -0,0 +1,28 @@
+from sqlalchemy import select
+from sqlalchemy.orm import Session, sessionmaker
+
+from core.variables.variables import Variable
+from models import ConversationVariable
+
+
+class ConversationVariableNotFoundError(Exception):
+    pass
+
+
+class ConversationVariableUpdater:
+    def __init__(self, session_maker: sessionmaker[Session]) -> None:
+        self._session_maker: sessionmaker[Session] = session_maker
+
+    def update(self, conversation_id: str, variable: Variable) -> None:
+        stmt = select(ConversationVariable).where(
+            ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
+        )
+        with self._session_maker() as session:
+            row = session.scalar(stmt)
+            if not row:
+                raise ConversationVariableNotFoundError("conversation variable not found in the database")
+            row.data = variable.model_dump_json()
+            session.commit()
+
+    def flush(self) -> None:
+        pass
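As an aside (not part of the commit): the new `ConversationVariableUpdater` above follows a look-up, fail-loudly, then write-and-commit pattern. A minimal stdlib sketch of that same flow, using `sqlite3` in place of SQLAlchemy; the table and column names here are simplified assumptions for illustration only:

```python
import json
import sqlite3


class ConversationVariableNotFoundError(Exception):
    pass


def update_conversation_variable(conn, conversation_id, variable_id, data):
    # Mirror of the updater's flow: select the row first, raise if it is
    # missing, then persist the serialized value and commit.
    row = conn.execute(
        "SELECT id FROM conversation_variables WHERE id = ? AND conversation_id = ?",
        (variable_id, conversation_id),
    ).fetchone()
    if row is None:
        raise ConversationVariableNotFoundError("conversation variable not found")
    conn.execute(
        "UPDATE conversation_variables SET data = ? WHERE id = ?",
        (json.dumps(data), variable_id),
    )
    conn.commit()


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE conversation_variables (id TEXT, conversation_id TEXT, data TEXT)")
conn.execute("INSERT INTO conversation_variables VALUES ('v1', 'c1', '{}')")
update_conversation_variable(conn, "c1", "v1", {"name": "topic", "value": "billing"})
```

Raising instead of upserting matches the real class: a missing row indicates a programming error upstream, not a state to silently create.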
@@ -874,7 +874,7 @@ class RagPipelineService:
         variable_pool = node_instance.graph_runtime_state.variable_pool
         invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
         if invoke_from:
-            if invoke_from.value == InvokeFrom.PUBLISHED:
+            if invoke_from.value == InvokeFrom.PUBLISHED_PIPELINE:
                 document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
                 if document_id:
                     document = db.session.query(Document).where(Document.id == document_id.value).first()
@@ -1318,7 +1318,7 @@ class RagPipelineService:
                 "datasource_info_list": [json.loads(document_pipeline_execution_log.datasource_info)],
                 "original_document_id": document.id,
             },
-            invoke_from=InvokeFrom.PUBLISHED,
+            invoke_from=InvokeFrom.PUBLISHED_PIPELINE,
             streaming=False,
             call_depth=0,
             workflow_thread_pool_id=None,
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import dataclasses
 from abc import ABC, abstractmethod
 from collections.abc import Mapping
@@ -106,7 +108,7 @@ class VariableTruncator(BaseTruncator):
         self._max_size_bytes = max_size_bytes

     @classmethod
-    def default(cls) -> "VariableTruncator":
+    def default(cls) -> VariableTruncator:
         return VariableTruncator(
             max_size_bytes=dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE,
             array_element_limit=dify_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH,
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import datetime
 import json
 from dataclasses import dataclass
@@ -78,7 +80,7 @@ class WebsiteCrawlApiRequest:
         return CrawlRequest(url=self.url, provider=self.provider, options=options)

     @classmethod
-    def from_args(cls, args: dict) -> "WebsiteCrawlApiRequest":
+    def from_args(cls, args: dict) -> WebsiteCrawlApiRequest:
         """Create from Flask-RESTful parsed arguments."""
         provider = args.get("provider")
         url = args.get("url")
@@ -102,7 +104,7 @@ class WebsiteCrawlStatusApiRequest:
     job_id: str

     @classmethod
-    def from_args(cls, args: dict, job_id: str) -> "WebsiteCrawlStatusApiRequest":
+    def from_args(cls, args: dict, job_id: str) -> WebsiteCrawlStatusApiRequest:
         """Create from Flask-RESTful parsed arguments."""
         provider = args.get("provider")
         if not provider:
@@ -679,6 +679,7 @@ def _batch_upsert_draft_variable(

 def _model_to_insertion_dict(model: WorkflowDraftVariable) -> dict[str, Any]:
     d: dict[str, Any] = {
+        "id": model.id,
         "app_id": model.app_id,
         "last_edited_at": None,
         "node_id": model.node_id,
@@ -178,7 +178,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
         workflow_id=workflow_id,
         user=account,
         application_generate_entity=entity,
-        invoke_from=InvokeFrom.PUBLISHED,
+        invoke_from=InvokeFrom.PUBLISHED_PIPELINE,
         workflow_execution_repository=workflow_execution_repository,
         workflow_node_execution_repository=workflow_node_execution_repository,
         streaming=streaming,
@@ -178,7 +178,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
         workflow_id=workflow_id,
         user=account,
         application_generate_entity=entity,
-        invoke_from=InvokeFrom.PUBLISHED,
+        invoke_from=InvokeFrom.PUBLISHED_PIPELINE,
         workflow_execution_repository=workflow_execution_repository,
         workflow_node_execution_repository=workflow_node_execution_repository,
         streaming=streaming,
@@ -165,7 +165,7 @@ class TestRagPipelineRunTasks:
             "files": [],
             "user_id": account.id,
             "stream": False,
-            "invoke_from": "published",
+            "invoke_from": InvokeFrom.PUBLISHED_PIPELINE.value,
             "workflow_execution_id": str(uuid.uuid4()),
             "pipeline_config": {
                 "app_id": str(uuid.uuid4()),
@@ -249,7 +249,7 @@ class TestRagPipelineRunTasks:
         assert call_kwargs["pipeline"].id == pipeline.id
         assert call_kwargs["workflow_id"] == workflow.id
         assert call_kwargs["user"].id == account.id
-        assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED
+        assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED_PIPELINE
         assert call_kwargs["streaming"] == False
         assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)

@@ -294,7 +294,7 @@ class TestRagPipelineRunTasks:
         assert call_kwargs["pipeline"].id == pipeline.id
         assert call_kwargs["workflow_id"] == workflow.id
         assert call_kwargs["user"].id == account.id
-        assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED
+        assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED_PIPELINE
         assert call_kwargs["streaming"] == False
         assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)

@@ -743,7 +743,7 @@ class TestRagPipelineRunTasks:
         assert call_kwargs["pipeline"].id == pipeline.id
         assert call_kwargs["workflow_id"] == workflow.id
         assert call_kwargs["user"].id == account.id
-        assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED
+        assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED_PIPELINE
         assert call_kwargs["streaming"] == False
         assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)

@@ -16,6 +16,7 @@ def test_dify_config(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
     monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
     monkeypatch.setenv("HTTP_REQUEST_MAX_WRITE_TIMEOUT", "30")  # Custom value for testing
+    monkeypatch.setenv("DB_TYPE", "postgresql")
     monkeypatch.setenv("DB_USERNAME", "postgres")
     monkeypatch.setenv("DB_PASSWORD", "postgres")
     monkeypatch.setenv("DB_HOST", "localhost")
@@ -51,6 +52,7 @@ def test_http_timeout_defaults(monkeypatch: pytest.MonkeyPatch):
     os.environ.clear()

     # Set minimal required env vars
+    monkeypatch.setenv("DB_TYPE", "postgresql")
     monkeypatch.setenv("DB_USERNAME", "postgres")
     monkeypatch.setenv("DB_PASSWORD", "postgres")
     monkeypatch.setenv("DB_HOST", "localhost")
@@ -75,6 +77,7 @@ def test_flask_configs(monkeypatch: pytest.MonkeyPatch):
     # Set environment variables using monkeypatch
     monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
     monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
+    monkeypatch.setenv("DB_TYPE", "postgresql")
     monkeypatch.setenv("DB_USERNAME", "postgres")
     monkeypatch.setenv("DB_PASSWORD", "postgres")
     monkeypatch.setenv("DB_HOST", "localhost")
@@ -124,6 +127,7 @@ def test_inner_api_config_exist(monkeypatch: pytest.MonkeyPatch):
     # Set environment variables using monkeypatch
     monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
     monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
+    monkeypatch.setenv("DB_TYPE", "postgresql")
     monkeypatch.setenv("DB_USERNAME", "postgres")
     monkeypatch.setenv("DB_PASSWORD", "postgres")
     monkeypatch.setenv("DB_HOST", "localhost")
@@ -140,6 +144,7 @@ def test_inner_api_config_exist(monkeypatch: pytest.MonkeyPatch):
 def test_db_extras_options_merging(monkeypatch: pytest.MonkeyPatch):
     """Test that DB_EXTRAS options are properly merged with default timezone setting"""
     # Set environment variables
+    monkeypatch.setenv("DB_TYPE", "postgresql")
     monkeypatch.setenv("DB_USERNAME", "postgres")
     monkeypatch.setenv("DB_PASSWORD", "postgres")
     monkeypatch.setenv("DB_HOST", "localhost")
@@ -199,6 +204,7 @@ def test_celery_broker_url_with_special_chars_password(
     # Set up basic required environment variables (following existing pattern)
     monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
     monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
+    monkeypatch.setenv("DB_TYPE", "postgresql")
     monkeypatch.setenv("DB_USERNAME", "postgres")
     monkeypatch.setenv("DB_PASSWORD", "postgres")
     monkeypatch.setenv("DB_HOST", "localhost")
@@ -1,7 +1,9 @@
+import builtins
 import io
 from unittest.mock import patch

 import pytest
+from flask.views import MethodView
 from werkzeug.exceptions import Forbidden

 from controllers.common.errors import (
@@ -14,6 +16,9 @@ from controllers.common.errors import (
 from services.errors.file import FileTooLargeError as ServiceFileTooLargeError
 from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError

+if not hasattr(builtins, "MethodView"):
+    builtins.MethodView = MethodView  # type: ignore[attr-defined]
+

 class TestFileUploadSecurity:
     """Test file upload security logic without complex framework setup"""
@@ -128,7 +133,7 @@ class TestFileUploadSecurity:
         # Test passes if no exception is raised

     # Test 4: Service error handling
-    @patch("services.file_service.FileService.upload_file")
+    @patch("controllers.console.files.FileService.upload_file")
     def test_should_handle_file_too_large_error(self, mock_upload):
         """Test that service FileTooLargeError is properly converted"""
         mock_upload.side_effect = ServiceFileTooLargeError("File too large")
@@ -140,7 +145,7 @@ class TestFileUploadSecurity:
         with pytest.raises(FileTooLargeError):
             raise FileTooLargeError(e.description)

-    @patch("services.file_service.FileService.upload_file")
+    @patch("controllers.console.files.FileService.upload_file")
     def test_should_handle_unsupported_file_type_error(self, mock_upload):
         """Test that service UnsupportedFileTypeError is properly converted"""
         mock_upload.side_effect = ServiceUnsupportedFileTypeError()
@ -431,10 +431,10 @@ class TestWorkflowResponseConverterServiceApiTruncation:
|
|||||||
description="Explore calls should have truncation enabled",
|
description="Explore calls should have truncation enabled",
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
name="published_truncation_enabled",
|
name="published_pipeline_truncation_enabled",
|
||||||
invoke_from=InvokeFrom.PUBLISHED,
|
invoke_from=InvokeFrom.PUBLISHED_PIPELINE,
|
||||||
expected_truncation_enabled=True,
|
expected_truncation_enabled=True,
|
||||||
description="Published app calls should have truncation enabled",
|
description="Published pipeline calls should have truncation enabled",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
ids=lambda x: x.name,
|
ids=lambda x: x.name,
|
||||||
|
|||||||
@@ -0,0 +1,144 @@
+from collections.abc import Sequence
+from datetime import datetime
+from unittest.mock import Mock
+
+from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
+from core.variables import StringVariable
+from core.variables.segments import Segment
+from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
+from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
+from core.workflow.graph_engine.protocols.command_channel import CommandChannel
+from core.workflow.graph_events.node import NodeRunSucceededEvent
+from core.workflow.node_events import NodeRunResult
+from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
+from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
+from core.workflow.system_variable import SystemVariable
+
+
+class MockReadOnlyVariablePool:
+    def __init__(self, variables: dict[tuple[str, str], Segment] | None = None) -> None:
+        self._variables = variables or {}
+
+    def get(self, selector: Sequence[str]) -> Segment | None:
+        if len(selector) < 2:
+            return None
+        return self._variables.get((selector[0], selector[1]))
+
+    def get_all_by_node(self, node_id: str) -> dict[str, object]:
+        return {key: value for (nid, key), value in self._variables.items() if nid == node_id}
+
+    def get_by_prefix(self, prefix: str) -> dict[str, object]:
+        return {key: value for (nid, key), value in self._variables.items() if nid == prefix}
+
+
+def _build_graph_runtime_state(
+    variable_pool: MockReadOnlyVariablePool,
+    conversation_id: str | None = None,
+) -> ReadOnlyGraphRuntimeState:
+    graph_runtime_state = Mock(spec=ReadOnlyGraphRuntimeState)
+    graph_runtime_state.variable_pool = variable_pool
+    graph_runtime_state.system_variable = SystemVariable(conversation_id=conversation_id).as_view()
+    return graph_runtime_state
+
+
+def _build_node_run_succeeded_event(
+    *,
+    node_type: NodeType,
+    outputs: dict[str, object] | None = None,
+    process_data: dict[str, object] | None = None,
+) -> NodeRunSucceededEvent:
+    return NodeRunSucceededEvent(
+        id="node-exec-id",
+        node_id="assigner",
+        node_type=node_type,
+        start_at=datetime.utcnow(),
+        node_run_result=NodeRunResult(
+            status=WorkflowNodeExecutionStatus.SUCCEEDED,
+            outputs=outputs or {},
+            process_data=process_data or {},
+        ),
+    )
+
+
+def test_persists_conversation_variables_from_assigner_output():
+    conversation_id = "conv-123"
+    variable = StringVariable(
+        id="var-1",
+        name="name",
+        value="updated",
+        selector=[CONVERSATION_VARIABLE_NODE_ID, "name"],
+    )
+    process_data = common_helpers.set_updated_variables(
+        {}, [common_helpers.variable_to_processed_data(variable.selector, variable)]
+    )
+
+    variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable})
+
+    updater = Mock()
+    layer = ConversationVariablePersistenceLayer(updater)
+    layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel))
+
+    event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER, process_data=process_data)
+    layer.on_event(event)
+
+    updater.update.assert_called_once_with(conversation_id=conversation_id, variable=variable)
+    updater.flush.assert_called_once()
+
+
+def test_skips_when_outputs_missing():
+    conversation_id = "conv-456"
+    variable = StringVariable(
+        id="var-2",
+        name="name",
+        value="updated",
+        selector=[CONVERSATION_VARIABLE_NODE_ID, "name"],
+    )
+
+    variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable})
+
+    updater = Mock()
+    layer = ConversationVariablePersistenceLayer(updater)
+    layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel))
+
+    event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER)
+    layer.on_event(event)
+
+    updater.update.assert_not_called()
+    updater.flush.assert_not_called()
+
+
+def test_skips_non_assigner_nodes():
+    updater = Mock()
+    layer = ConversationVariablePersistenceLayer(updater)
+    layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool()), Mock(spec=CommandChannel))
+
+    event = _build_node_run_succeeded_event(node_type=NodeType.LLM)
+    layer.on_event(event)
+
+    updater.update.assert_not_called()
+    updater.flush.assert_not_called()
+
+
+def test_skips_non_conversation_variables():
+    conversation_id = "conv-789"
+    non_conversation_variable = StringVariable(
+        id="var-3",
+        name="name",
+        value="updated",
+        selector=["environment", "name"],
+    )
+    process_data = common_helpers.set_updated_variables(
+        {}, [common_helpers.variable_to_processed_data(non_conversation_variable.selector, non_conversation_variable)]
+    )
+
+    variable_pool = MockReadOnlyVariablePool()
+
+    updater = Mock()
+    layer = ConversationVariablePersistenceLayer(updater)
+    layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel))
+
+    event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER, process_data=process_data)
+    layer.on_event(event)
+
+    updater.update.assert_not_called()
+    updater.flush.assert_not_called()
@@ -1,4 +1,5 @@
 import json
+from collections.abc import Sequence
 from time import time
 from unittest.mock import Mock

@@ -67,8 +68,10 @@ class MockReadOnlyVariablePool:
     def __init__(self, variables: dict[tuple[str, str], object] | None = None):
         self._variables = variables or {}

-    def get(self, node_id: str, variable_key: str) -> Segment | None:
-        value = self._variables.get((node_id, variable_key))
+    def get(self, selector: Sequence[str]) -> Segment | None:
+        if len(selector) < 2:
+            return None
+        value = self._variables.get((selector[0], selector[1]))
         if value is None:
             return None
         mock_segment = Mock(spec=Segment)
@@ -1,8 +1,12 @@
 """Primarily used for testing merged cell scenarios"""

+import os
+import tempfile
 from types import SimpleNamespace

 from docx import Document
+from docx.oxml import OxmlElement
+from docx.oxml.ns import qn

 import core.rag.extractor.word_extractor as we
 from core.rag.extractor.word_extractor import WordExtractor
@@ -165,3 +169,110 @@ def test_extract_images_from_docx_uses_internal_files_url():
         dify_config.FILES_URL = original_files_url
     if original_internal_files_url is not None:
         dify_config.INTERNAL_FILES_URL = original_internal_files_url
+
+
+def test_extract_hyperlinks(monkeypatch):
+    # Mock db and storage to avoid issues during image extraction (even if no images are present)
+    monkeypatch.setattr(we, "storage", SimpleNamespace(save=lambda k, d: None))
+    db_stub = SimpleNamespace(session=SimpleNamespace(add=lambda o: None, commit=lambda: None))
+    monkeypatch.setattr(we, "db", db_stub)
+    monkeypatch.setattr(we.dify_config, "FILES_URL", "http://files.local", raising=False)
+    monkeypatch.setattr(we.dify_config, "STORAGE_TYPE", "local", raising=False)
+
+    doc = Document()
+    p = doc.add_paragraph("Visit ")
+
+    # Adding a hyperlink manually
+    r_id = "rId99"
+    hyperlink = OxmlElement("w:hyperlink")
+    hyperlink.set(qn("r:id"), r_id)
+
+    new_run = OxmlElement("w:r")
+    t = OxmlElement("w:t")
+    t.text = "Dify"
+    new_run.append(t)
+    hyperlink.append(new_run)
+    p._p.append(hyperlink)
+
+    # Add relationship to the part
+    doc.part.rels.add_relationship(
+        "http://schemas.openxmlformats.org/officeDocument/2006/relationships/hyperlink",
+        "https://dify.ai",
+        r_id,
+        is_external=True,
+    )
+
+    with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp:
+        doc.save(tmp.name)
+        tmp_path = tmp.name
+
+    try:
+        extractor = WordExtractor(tmp_path, "tenant_id", "user_id")
+        docs = extractor.extract()
+        # Verify modern hyperlink extraction
+        assert "Visit[Dify](https://dify.ai)" in docs[0].page_content
+    finally:
+        if os.path.exists(tmp_path):
+            os.remove(tmp_path)
+
+
+def test_extract_legacy_hyperlinks(monkeypatch):
+    # Mock db and storage
+    monkeypatch.setattr(we, "storage", SimpleNamespace(save=lambda k, d: None))
+    db_stub = SimpleNamespace(session=SimpleNamespace(add=lambda o: None, commit=lambda: None))
+    monkeypatch.setattr(we, "db", db_stub)
+    monkeypatch.setattr(we.dify_config, "FILES_URL", "http://files.local", raising=False)
+    monkeypatch.setattr(we.dify_config, "STORAGE_TYPE", "local", raising=False)
+
+    doc = Document()
+    p = doc.add_paragraph()
+
+    # Construct a legacy HYPERLINK field:
+    # 1. w:fldChar (begin)
+    # 2. w:instrText (HYPERLINK "http://example.com")
+    # 3. w:fldChar (separate)
+    # 4. w:r (visible text "Example")
+    # 5. w:fldChar (end)
+
+    run1 = OxmlElement("w:r")
+    fldCharBegin = OxmlElement("w:fldChar")
+    fldCharBegin.set(qn("w:fldCharType"), "begin")
+    run1.append(fldCharBegin)
+    p._p.append(run1)
+
+    run2 = OxmlElement("w:r")
+    instrText = OxmlElement("w:instrText")
+    instrText.text = ' HYPERLINK "http://example.com" '
+    run2.append(instrText)
+    p._p.append(run2)
+
+    run3 = OxmlElement("w:r")
+    fldCharSep = OxmlElement("w:fldChar")
+    fldCharSep.set(qn("w:fldCharType"), "separate")
+    run3.append(fldCharSep)
+    p._p.append(run3)
+
+    run4 = OxmlElement("w:r")
+    t4 = OxmlElement("w:t")
+    t4.text = "Example"
+    run4.append(t4)
+    p._p.append(run4)
+
+    run5 = OxmlElement("w:r")
+    fldCharEnd = OxmlElement("w:fldChar")
+    fldCharEnd.set(qn("w:fldCharType"), "end")
+    run5.append(fldCharEnd)
+    p._p.append(run5)
+
+    with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp:
+        doc.save(tmp.name)
+        tmp_path = tmp.name
+
+    try:
+        extractor = WordExtractor(tmp_path, "tenant_id", "user_id")
+        docs = extractor.extract()
+        # Verify legacy hyperlink extraction
+        assert "[Example](http://example.com)" in docs[0].page_content
+    finally:
+        if os.path.exists(tmp_path):
+            os.remove(tmp_path)
@@ -5,6 +5,8 @@ This module provides a flexible configuration system for customizing
 the behavior of mock nodes during testing.
 """

+from __future__ import annotations
+
 from collections.abc import Callable
 from dataclasses import dataclass, field
 from typing import Any
@@ -95,67 +97,67 @@ class MockConfigBuilder:
     def __init__(self) -> None:
         self._config = MockConfig()

-    def with_auto_mock(self, enabled: bool = True) -> "MockConfigBuilder":
+    def with_auto_mock(self, enabled: bool = True) -> MockConfigBuilder:
         """Enable or disable auto-mocking."""
         self._config.enable_auto_mock = enabled
         return self

-    def with_delays(self, enabled: bool = True) -> "MockConfigBuilder":
+    def with_delays(self, enabled: bool = True) -> MockConfigBuilder:
         """Enable or disable simulated execution delays."""
         self._config.simulate_delays = enabled
         return self

-    def with_llm_response(self, response: str) -> "MockConfigBuilder":
+    def with_llm_response(self, response: str) -> MockConfigBuilder:
         """Set default LLM response."""
         self._config.default_llm_response = response
         return self

-    def with_agent_response(self, response: str) -> "MockConfigBuilder":
+    def with_agent_response(self, response: str) -> MockConfigBuilder:
         """Set default agent response."""
         self._config.default_agent_response = response
         return self

-    def with_tool_response(self, response: dict[str, Any]) -> "MockConfigBuilder":
+    def with_tool_response(self, response: dict[str, Any]) -> MockConfigBuilder:
         """Set default tool response."""
         self._config.default_tool_response = response
         return self

-    def with_retrieval_response(self, response: str) -> "MockConfigBuilder":
+    def with_retrieval_response(self, response: str) -> MockConfigBuilder:
         """Set default retrieval response."""
         self._config.default_retrieval_response = response
         return self

-    def with_http_response(self, response: dict[str, Any]) -> "MockConfigBuilder":
+    def with_http_response(self, response: dict[str, Any]) -> MockConfigBuilder:
         """Set default HTTP response."""
         self._config.default_http_response = response
         return self

-    def with_template_transform_response(self, response: str) -> "MockConfigBuilder":
+    def with_template_transform_response(self, response: str) -> MockConfigBuilder:
         """Set default template transform response."""
         self._config.default_template_transform_response = response
         return self

-    def with_code_response(self, response: dict[str, Any]) -> "MockConfigBuilder":
+    def with_code_response(self, response: dict[str, Any]) -> MockConfigBuilder:
         """Set default code execution response."""
         self._config.default_code_response = response
         return self

-    def with_node_output(self, node_id: str, outputs: dict[str, Any]) -> "MockConfigBuilder":
+    def with_node_output(self, node_id: str, outputs: dict[str, Any]) -> MockConfigBuilder:
         """Set outputs for a specific node."""
         self._config.set_node_outputs(node_id, outputs)
         return self

-    def with_node_error(self, node_id: str, error: str) -> "MockConfigBuilder":
+    def with_node_error(self, node_id: str, error: str) -> MockConfigBuilder:
         """Set error for a specific node."""
         self._config.set_node_error(node_id, error)
         return self

-    def with_node_config(self, config: NodeMockConfig) -> "MockConfigBuilder":
+    def with_node_config(self, config: NodeMockConfig) -> MockConfigBuilder:
         """Add a node-specific configuration."""
         self._config.set_node_config(config.node_id, config)
         return self

-    def with_default_config(self, node_type: NodeType, config: dict[str, Any]) -> "MockConfigBuilder":
+    def with_default_config(self, node_type: NodeType, config: dict[str, Any]) -> MockConfigBuilder:
         """Set default configuration for a node type."""
         self._config.set_default_config(node_type, config)
         return self
@@ -78,7 +78,7 @@ class TestFileSaverImpl:
             file_binary=_PNG_DATA,
             mimetype=mime_type,
         )
-        mocked_sign_file.assert_called_once_with(mock_tool_file.id, ".png")
+        mocked_sign_file.assert_called_once_with(tool_file_id=mock_tool_file.id, extension=".png", for_external=True)

    def test_save_remote_url_request_failed(self, monkeypatch: pytest.MonkeyPatch):
        _TEST_URL = "https://example.com/image.png"
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import sys
 import types
 from collections.abc import Generator
@@ -21,7 +23,7 @@ if TYPE_CHECKING:  # pragma: no cover - imported for type checking only


 @pytest.fixture
-def tool_node(monkeypatch) -> "ToolNode":
+def tool_node(monkeypatch) -> ToolNode:
     module_name = "core.ops.ops_trace_manager"
     if module_name not in sys.modules:
         ops_stub = types.ModuleType(module_name)
@@ -85,7 +87,7 @@ def _collect_events(generator: Generator) -> tuple[list[Any], LLMUsage]:
     return events, stop.value


-def _run_transform(tool_node: "ToolNode", message: ToolInvokeMessage) -> tuple[list[Any], LLMUsage]:
+def _run_transform(tool_node: ToolNode, message: ToolInvokeMessage) -> tuple[list[Any], LLMUsage]:
     def _identity_transform(messages, *_args, **_kwargs):
         return messages

@@ -103,7 +105,7 @@ def _run_transform(tool_node: "ToolNode", message: ToolInvokeMessage) -> tuple[l
     return _collect_events(generator)


-def test_link_messages_with_file_populate_files_output(tool_node: "ToolNode"):
+def test_link_messages_with_file_populate_files_output(tool_node: ToolNode):
     file_obj = File(
         tenant_id="tenant-id",
         type=FileType.DOCUMENT,
@@ -139,7 +141,7 @@ def test_link_messages_with_file_populate_files_output(tool_node: "ToolNode"):
     assert files_segment.value == [file_obj]


-def test_plain_link_messages_remain_links(tool_node: "ToolNode"):
+def test_plain_link_messages_remain_links(tool_node: ToolNode):
     message = ToolInvokeMessage(
         type=ToolInvokeMessage.MessageType.LINK,
         message=ToolInvokeMessage.TextMessage(text="https://dify.ai"),
Some files were not shown because too many files have changed in this diff.