diff --git a/.agent/skills b/.agent/skills
new file mode 120000
index 0000000000..454b8427cd
--- /dev/null
+++ b/.agent/skills
@@ -0,0 +1 @@
+../.claude/skills
\ No newline at end of file
diff --git a/.claude/settings.json b/.claude/settings.json
index 509dbe8447..f9e1016d02 100644
--- a/.claude/settings.json
+++ b/.claude/settings.json
@@ -1,9 +1,20 @@
{
+ "hooks": {
+ "PreToolUse": [
+ {
+ "matcher": "Bash",
+ "hooks": [
+ {
+ "type": "command",
+ "command": "npx -y block-no-verify@1.1.1"
+ }
+ ]
+ }
+ ]
+ },
"enabledPlugins": {
"feature-dev@claude-plugins-official": true,
"context7@claude-plugins-official": true,
- "typescript-lsp@claude-plugins-official": true,
- "pyright-lsp@claude-plugins-official": true,
"ralph-loop@claude-plugins-official": true
}
}
diff --git a/.claude/skills/frontend-testing/SKILL.md b/.claude/skills/frontend-testing/SKILL.md
index dd9677a78e..0716c81ef7 100644
--- a/.claude/skills/frontend-testing/SKILL.md
+++ b/.claude/skills/frontend-testing/SKILL.md
@@ -83,6 +83,9 @@ vi.mock('next/navigation', () => ({
usePathname: () => '/test',
}))
+// ✅ Zustand stores: Use real stores (auto-mocked globally)
+// Set test state with: useAppStore.setState({ ... })
+
// Shared state for mocks (if needed)
let mockSharedState = false
@@ -296,7 +299,7 @@ For each test file generated, aim for:
For more detailed information, refer to:
- `references/workflow.md` - **Incremental testing workflow** (MUST READ for multi-file testing)
-- `references/mocking.md` - Mock patterns and best practices
+- `references/mocking.md` - Mock patterns, Zustand store testing, and best practices
- `references/async-testing.md` - Async operations and API calls
- `references/domain-components.md` - Workflow, Dataset, Configuration testing
- `references/common-patterns.md` - Frequently used testing patterns
diff --git a/.claude/skills/frontend-testing/references/mocking.md b/.claude/skills/frontend-testing/references/mocking.md
index c70bcf0ae5..86bd375987 100644
--- a/.claude/skills/frontend-testing/references/mocking.md
+++ b/.claude/skills/frontend-testing/references/mocking.md
@@ -37,16 +37,36 @@ Only mock these categories:
1. **Third-party libraries with side effects** - `next/navigation`, external SDKs
1. **i18n** - Always mock to return keys
+### Zustand Stores - DO NOT Mock Manually
+
+**Zustand is globally mocked** in `web/vitest.setup.ts`. Use real stores with `setState()`:
+
+```typescript
+// ✅ CORRECT: Use real store, set test state
+import { useAppStore } from '@/app/components/app/store'
+
+useAppStore.setState({ appDetail: { id: 'test', name: 'Test' } })
+render(<MyComponent />)
+
+// ❌ WRONG: Don't mock the store module
+vi.mock('@/app/components/app/store', () => ({ ... }))
+```
+
+See [Zustand Store Testing](#zustand-store-testing) section for full details.
+
## Mock Placement
| Location | Purpose |
|----------|---------|
-| `web/vitest.setup.ts` | Global mocks shared by all tests (for example `react-i18next`, `next/image`) |
+| `web/vitest.setup.ts` | Global mocks shared by all tests (`react-i18next`, `next/image`, `zustand`) |
+| `web/__mocks__/zustand.ts` | Zustand mock implementation (auto-resets stores after each test) |
| `web/__mocks__/` | Reusable mock factories shared across multiple test files |
| Test file | Test-specific mocks, inline with `vi.mock()` |
Modules are not mocked automatically. Use `vi.mock` in test files, or add global mocks in `web/vitest.setup.ts`.
+**Note**: Zustand is special - it's globally mocked but you should NOT mock store modules manually. See [Zustand Store Testing](#zustand-store-testing).
+
## Essential Mocks
### 1. i18n (Auto-loaded via Global Mock)
@@ -276,6 +296,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => {
1. **Use real base components** - Import from `@/app/components/base/` directly
1. **Use real project components** - Prefer importing over mocking
+1. **Use real Zustand stores** - Set test state via `store.setState()`
1. **Reset mocks in `beforeEach`**, not `afterEach`
1. **Match actual component behavior** in mocks (when mocking is necessary)
1. **Use factory functions** for complex mock data
@@ -285,6 +306,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => {
### ❌ DON'T
1. **Don't mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
+1. **Don't mock Zustand store modules** - Use real stores with `setState()`
1. Don't mock components you can import directly
1. Don't create overly simplified mocks that miss conditional logic
1. Don't forget to clean up nock after each test
@@ -308,10 +330,151 @@ Need to use a component in test?
├─ Is it a third-party lib with side effects?
│ └─ YES → Mock it (next/navigation, external SDKs)
│
+├─ Is it a Zustand store?
+│ └─ YES → DO NOT mock the module!
+│ Use real store + setState() to set test state
+│ (Global mock handles auto-reset)
+│
└─ Is it i18n?
└─ YES → Uses shared mock (auto-loaded). Override only for custom translations
```
+## Zustand Store Testing
+
+### Global Zustand Mock (Auto-loaded)
+
+Zustand is globally mocked in `web/vitest.setup.ts` following the [official Zustand testing guide](https://zustand.docs.pmnd.rs/guides/testing). The mock in `web/__mocks__/zustand.ts` provides:
+
+- Real store behavior with `getState()`, `setState()`, `subscribe()` methods
+- Automatic store reset after each test via `afterEach`
+- Proper test isolation between tests
+
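+For reference, the global mock generally looks like the following. This is a minimal sketch adapted from the official Zustand testing guide; the actual `web/__mocks__/zustand.ts` may differ in detail (for example, it may also wrap `createStore`):
+
+```typescript
+// Sketch of web/__mocks__/zustand.ts (based on the official Zustand testing guide)
+import { act } from '@testing-library/react'
+import { afterEach, vi } from 'vitest'
+import type * as ZustandExportedTypes from 'zustand'
+export * from 'zustand'
+
+const { create: actualCreate } =
+  await vi.importActual<typeof ZustandExportedTypes>('zustand')
+
+// Track a reset function for every store created during tests
+export const storeResetFns = new Set<() => void>()
+
+const createUncurried = <T>(stateCreator: ZustandExportedTypes.StateCreator<T>) => {
+  const store = actualCreate(stateCreator)
+  const initialState = store.getInitialState()
+  storeResetFns.add(() => store.setState(initialState, true))
+  return store
+}
+
+// Real `create` behavior, plus registration of a reset for test isolation
+export const create = (<T>(stateCreator: ZustandExportedTypes.StateCreator<T>) =>
+  typeof stateCreator === 'function' ? createUncurried(stateCreator) : createUncurried) as typeof ZustandExportedTypes.create
+
+// Restore every store to its initial state after each test
+afterEach(() => {
+  act(() => storeResetFns.forEach(resetFn => resetFn()))
+})
+```
+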
+### ✅ Recommended: Use Real Stores (Official Best Practice)
+
+**DO NOT mock store modules manually.** Import and use the real store, then use `setState()` to set test state:
+
+```typescript
+// ✅ CORRECT: Use real store with setState
+import { useAppStore } from '@/app/components/app/store'
+
+describe('MyComponent', () => {
+ it('should render app details', () => {
+ // Arrange: Set test state via setState
+ useAppStore.setState({
+ appDetail: {
+ id: 'test-app',
+ name: 'Test App',
+ mode: 'chat',
+ },
+ })
+
+ // Act
+ render(<MyComponent />)
+
+ // Assert
+ expect(screen.getByText('Test App')).toBeInTheDocument()
+ // Can also verify store state directly
+ expect(useAppStore.getState().appDetail?.name).toBe('Test App')
+ })
+
+ // No cleanup needed - global mock auto-resets after each test
+})
+```
+
+### ❌ Avoid: Manual Store Module Mocking
+
+Manual mocking conflicts with the global Zustand mock and loses store functionality:
+
+```typescript
+// ❌ WRONG: Don't mock the store module
+vi.mock('@/app/components/app/store', () => ({
+ useStore: (selector) => mockSelector(selector), // Missing getState, setState!
+}))
+
+// ❌ WRONG: This conflicts with global zustand mock
+vi.mock('@/app/components/workflow/store', () => ({
+ useWorkflowStore: vi.fn(() => mockState),
+}))
+```
+
+**Problems with manual mocking:**
+
+1. Loses `getState()`, `setState()`, `subscribe()` methods
+1. Conflicts with global Zustand mock behavior
+1. Requires manual maintenance of store API
+1. Tests don't reflect actual store behavior
+
+### When Manual Store Mocking is Necessary
+
+In rare cases where the store has complex initialization or side effects, you can mock it, but ensure you provide the full store API:
+
+```typescript
+// If you MUST mock (rare), include full store API
+const mockStore = {
+ appDetail: { id: 'test', name: 'Test' },
+ setAppDetail: vi.fn(),
+}
+
+vi.mock('@/app/components/app/store', () => ({
+ useStore: Object.assign(
+ (selector: (state: typeof mockStore) => unknown) => selector(mockStore),
+ {
+ getState: () => mockStore,
+ setState: vi.fn(),
+ subscribe: vi.fn(),
+ },
+ ),
+}))
+```
+
+### Store Testing Decision Tree
+
+```
+Need to test a component using Zustand store?
+│
+├─ Can you use the real store?
+│ └─ YES → Use real store + setState (RECOMMENDED)
+│ useAppStore.setState({ ... })
+│
+├─ Does the store have complex initialization/side effects?
+│ └─ YES → Consider mocking, but include full API
+│ (getState, setState, subscribe)
+│
+└─ Are you testing the store itself (not a component)?
+ └─ YES → Test store directly with getState/setState
+ const store = useMyStore
+ store.setState({ count: 0 })
+ store.getState().increment()
+ expect(store.getState().count).toBe(1)
+```
+
+### Example: Testing Store Actions
+
+```typescript
+import { useCounterStore } from '@/stores/counter'
+
+describe('Counter Store', () => {
+ it('should increment count', () => {
+ // Initial state (auto-reset by global mock)
+ expect(useCounterStore.getState().count).toBe(0)
+
+ // Call action
+ useCounterStore.getState().increment()
+
+ // Verify state change
+ expect(useCounterStore.getState().count).toBe(1)
+ })
+
+ it('should reset to initial state', () => {
+ // Set some state
+ useCounterStore.setState({ count: 100 })
+ expect(useCounterStore.getState().count).toBe(100)
+
+ // After this test, global mock will reset to initial state
+ })
+})
+```
+
## Factory Function Pattern
```typescript
diff --git a/.claude/skills/orpc-contract-first/SKILL.md b/.claude/skills/orpc-contract-first/SKILL.md
new file mode 100644
index 0000000000..4e3bfc7a37
--- /dev/null
+++ b/.claude/skills/orpc-contract-first/SKILL.md
@@ -0,0 +1,46 @@
+---
+name: orpc-contract-first
+description: Guide for implementing oRPC contract-first API patterns in Dify frontend. Triggers when creating new API contracts, adding service endpoints, integrating TanStack Query with typed contracts, or migrating legacy service calls to oRPC. Use for all API layer work in web/contract and web/service directories.
+---
+
+# oRPC Contract-First Development
+
+## Project Structure
+
+```
+web/contract/
+├── base.ts # Base contract (inputStructure: 'detailed')
+├── router.ts # Router composition & type exports
+├── marketplace.ts # Marketplace contracts
+└── console/ # Console contracts by domain
+ ├── system.ts
+ └── billing.ts
+```
+
+## Workflow
+
+1. **Create contract** in `web/contract/console/{domain}.ts`
+ - Import `base` from `../base` and `type` from `@orpc/contract`
+ - Define route with `path`, `method`, `input`, `output`
+
+2. **Register in router** at `web/contract/router.ts`
+ - Import directly from domain file (no barrel files)
+ - Nest by API prefix: `billing: { invoices, bindPartnerStack }`
+
+3. **Create hooks** in `web/service/use-{domain}.ts`
+ - Use `consoleQuery.{group}.{contract}.queryKey()` for query keys
+ - Use `consoleClient.{group}.{contract}()` for API calls
+
+## Key Rules
+
+- **Input structure**: Always use `{ params, query?, body? }` format
+- **Path params**: Use `{paramName}` in path, match in `params` object
+- **Router nesting**: Group by API prefix (e.g., `/billing/*` → `billing: {}`)
+- **No barrel files**: Import directly from specific files
+- **Types**: Import from `@/types/`, use `type()` helper
+
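+A minimal sketch of the three steps (the domain, type names, `consoleQuery`/`consoleClient` helpers, and exact builder/argument shapes are illustrative assumptions):
+
+```typescript
+// Illustrative sketch only — names and signatures may differ in this repo
+
+// 1. web/contract/console/billing.ts — define the contract
+import { type } from '@orpc/contract'
+import { base } from '../base'
+import type { Invoice } from '@/types/billing'
+
+export const invoiceDetail = base
+  .route({ path: '/billing/invoices/{invoiceId}', method: 'GET' })
+  .input(type<{ params: { invoiceId: string } }>())
+  .output(type<Invoice>())
+
+// 2. web/contract/router.ts — register it, nested by API prefix
+import { invoiceDetail } from './console/billing'
+
+export const consoleContract = {
+  billing: { invoiceDetail },
+}
+
+// 3. web/service/use-billing.ts — expose a typed TanStack Query hook
+import { useQuery } from '@tanstack/react-query'
+
+export const useInvoiceDetail = (invoiceId: string) =>
+  useQuery({
+    queryKey: consoleQuery.billing.invoiceDetail.queryKey({ input: { params: { invoiceId } } }),
+    queryFn: () => consoleClient.billing.invoiceDetail({ params: { invoiceId } }),
+  })
+```
+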
+## Type Export
+
+```typescript
+export type ConsoleInputs = InferContractRouterInputs<typeof consoleContract>
+```
diff --git a/.claude/skills/vercel-react-best-practices/AGENTS.md b/.claude/skills/vercel-react-best-practices/AGENTS.md
new file mode 100644
index 0000000000..f9b9e99c44
--- /dev/null
+++ b/.claude/skills/vercel-react-best-practices/AGENTS.md
@@ -0,0 +1,2410 @@
+# React Best Practices
+
+**Version 1.0.0**
+Vercel Engineering
+January 2026
+
+> **Note:**
+> This document is mainly for agents and LLMs to follow when maintaining,
+> generating, or refactoring React and Next.js codebases at Vercel. Humans
+> may also find it useful, but guidance here is optimized for automation
+> and consistency by AI-assisted workflows.
+
+---
+
+## Abstract
+
+Comprehensive performance optimization guide for React and Next.js applications, designed for AI agents and LLMs. Contains 40+ rules across 8 categories, prioritized by impact from critical (eliminating waterfalls, reducing bundle size) to incremental (advanced patterns). Each rule includes detailed explanations, real-world examples comparing incorrect vs. correct implementations, and specific impact metrics to guide automated refactoring and code generation.
+
+---
+
+## Table of Contents
+
+1. [Eliminating Waterfalls](#1-eliminating-waterfalls) — **CRITICAL**
+ - 1.1 [Defer Await Until Needed](#11-defer-await-until-needed)
+ - 1.2 [Dependency-Based Parallelization](#12-dependency-based-parallelization)
+ - 1.3 [Prevent Waterfall Chains in API Routes](#13-prevent-waterfall-chains-in-api-routes)
+ - 1.4 [Promise.all() for Independent Operations](#14-promiseall-for-independent-operations)
+ - 1.5 [Strategic Suspense Boundaries](#15-strategic-suspense-boundaries)
+2. [Bundle Size Optimization](#2-bundle-size-optimization) — **CRITICAL**
+ - 2.1 [Avoid Barrel File Imports](#21-avoid-barrel-file-imports)
+ - 2.2 [Conditional Module Loading](#22-conditional-module-loading)
+ - 2.3 [Defer Non-Critical Third-Party Libraries](#23-defer-non-critical-third-party-libraries)
+ - 2.4 [Dynamic Imports for Heavy Components](#24-dynamic-imports-for-heavy-components)
+ - 2.5 [Preload Based on User Intent](#25-preload-based-on-user-intent)
+3. [Server-Side Performance](#3-server-side-performance) — **HIGH**
+ - 3.1 [Cross-Request LRU Caching](#31-cross-request-lru-caching)
+ - 3.2 [Minimize Serialization at RSC Boundaries](#32-minimize-serialization-at-rsc-boundaries)
+ - 3.3 [Parallel Data Fetching with Component Composition](#33-parallel-data-fetching-with-component-composition)
+ - 3.4 [Per-Request Deduplication with React.cache()](#34-per-request-deduplication-with-reactcache)
+ - 3.5 [Use after() for Non-Blocking Operations](#35-use-after-for-non-blocking-operations)
+4. [Client-Side Data Fetching](#4-client-side-data-fetching) — **MEDIUM-HIGH**
+ - 4.1 [Deduplicate Global Event Listeners](#41-deduplicate-global-event-listeners)
+ - 4.2 [Use Passive Event Listeners for Scrolling Performance](#42-use-passive-event-listeners-for-scrolling-performance)
+ - 4.3 [Use SWR for Automatic Deduplication](#43-use-swr-for-automatic-deduplication)
+ - 4.4 [Version and Minimize localStorage Data](#44-version-and-minimize-localstorage-data)
+5. [Re-render Optimization](#5-re-render-optimization) — **MEDIUM**
+ - 5.1 [Defer State Reads to Usage Point](#51-defer-state-reads-to-usage-point)
+ - 5.2 [Extract to Memoized Components](#52-extract-to-memoized-components)
+ - 5.3 [Narrow Effect Dependencies](#53-narrow-effect-dependencies)
+ - 5.4 [Subscribe to Derived State](#54-subscribe-to-derived-state)
+ - 5.5 [Use Functional setState Updates](#55-use-functional-setstate-updates)
+ - 5.6 [Use Lazy State Initialization](#56-use-lazy-state-initialization)
+ - 5.7 [Use Transitions for Non-Urgent Updates](#57-use-transitions-for-non-urgent-updates)
+6. [Rendering Performance](#6-rendering-performance) — **MEDIUM**
+ - 6.1 [Animate SVG Wrapper Instead of SVG Element](#61-animate-svg-wrapper-instead-of-svg-element)
+ - 6.2 [CSS content-visibility for Long Lists](#62-css-content-visibility-for-long-lists)
+ - 6.3 [Hoist Static JSX Elements](#63-hoist-static-jsx-elements)
+ - 6.4 [Optimize SVG Precision](#64-optimize-svg-precision)
+ - 6.5 [Prevent Hydration Mismatch Without Flickering](#65-prevent-hydration-mismatch-without-flickering)
+ - 6.6 [Use Activity Component for Show/Hide](#66-use-activity-component-for-showhide)
+ - 6.7 [Use Explicit Conditional Rendering](#67-use-explicit-conditional-rendering)
+7. [JavaScript Performance](#7-javascript-performance) — **LOW-MEDIUM**
+ - 7.1 [Batch DOM CSS Changes](#71-batch-dom-css-changes)
+ - 7.2 [Build Index Maps for Repeated Lookups](#72-build-index-maps-for-repeated-lookups)
+ - 7.3 [Cache Property Access in Loops](#73-cache-property-access-in-loops)
+ - 7.4 [Cache Repeated Function Calls](#74-cache-repeated-function-calls)
+ - 7.5 [Cache Storage API Calls](#75-cache-storage-api-calls)
+ - 7.6 [Combine Multiple Array Iterations](#76-combine-multiple-array-iterations)
+ - 7.7 [Early Length Check for Array Comparisons](#77-early-length-check-for-array-comparisons)
+ - 7.8 [Early Return from Functions](#78-early-return-from-functions)
+ - 7.9 [Hoist RegExp Creation](#79-hoist-regexp-creation)
+ - 7.10 [Use Loop for Min/Max Instead of Sort](#710-use-loop-for-minmax-instead-of-sort)
+ - 7.11 [Use Set/Map for O(1) Lookups](#711-use-setmap-for-o1-lookups)
+ - 7.12 [Use toSorted() Instead of sort() for Immutability](#712-use-tosorted-instead-of-sort-for-immutability)
+8. [Advanced Patterns](#8-advanced-patterns) — **LOW**
+ - 8.1 [Store Event Handlers in Refs](#81-store-event-handlers-in-refs)
+ - 8.2 [useLatest for Stable Callback Refs](#82-uselatest-for-stable-callback-refs)
+
+---
+
+## 1. Eliminating Waterfalls
+
+**Impact: CRITICAL**
+
+Waterfalls are the #1 performance killer. Each sequential await adds full network latency. Eliminating them yields the largest gains.
+
+### 1.1 Defer Await Until Needed
+
+**Impact: HIGH (avoids blocking unused code paths)**
+
+Move `await` operations into the branches where they're actually used to avoid blocking code paths that don't need them.
+
+**Incorrect: blocks both branches**
+
+```typescript
+async function handleRequest(userId: string, skipProcessing: boolean) {
+ const userData = await fetchUserData(userId)
+
+ if (skipProcessing) {
+ // Returns immediately but still waited for userData
+ return { skipped: true }
+ }
+
+ // Only this branch uses userData
+ return processUserData(userData)
+}
+```
+
+**Correct: only blocks when needed**
+
+```typescript
+async function handleRequest(userId: string, skipProcessing: boolean) {
+ if (skipProcessing) {
+ // Returns immediately without waiting
+ return { skipped: true }
+ }
+
+ // Fetch only when needed
+ const userData = await fetchUserData(userId)
+ return processUserData(userData)
+}
+```
+
+**Another example: early return optimization**
+
+```typescript
+// Incorrect: always fetches permissions
+async function updateResource(resourceId: string, userId: string) {
+ const permissions = await fetchPermissions(userId)
+ const resource = await getResource(resourceId)
+
+ if (!resource) {
+ return { error: 'Not found' }
+ }
+
+ if (!permissions.canEdit) {
+ return { error: 'Forbidden' }
+ }
+
+ return await updateResourceData(resource, permissions)
+}
+
+// Correct: fetches only when needed
+async function updateResource(resourceId: string, userId: string) {
+ const resource = await getResource(resourceId)
+
+ if (!resource) {
+ return { error: 'Not found' }
+ }
+
+ const permissions = await fetchPermissions(userId)
+
+ if (!permissions.canEdit) {
+ return { error: 'Forbidden' }
+ }
+
+ return await updateResourceData(resource, permissions)
+}
+```
+
+This optimization is especially valuable when the skipped branch is frequently taken, or when the deferred operation is expensive.
+
+### 1.2 Dependency-Based Parallelization
+
+**Impact: CRITICAL (2-10× improvement)**
+
+For operations with partial dependencies, use `better-all` to maximize parallelism. It automatically starts each task at the earliest possible moment.
+
+**Incorrect: profile waits for config unnecessarily**
+
+```typescript
+const [user, config] = await Promise.all([
+ fetchUser(),
+ fetchConfig()
+])
+const profile = await fetchProfile(user.id)
+```
+
+**Correct: config and profile run in parallel**
+
+```typescript
+import { all } from 'better-all'
+
+const { user, config, profile } = await all({
+ async user() { return fetchUser() },
+ async config() { return fetchConfig() },
+ async profile() {
+ return fetchProfile((await this.$.user).id)
+ }
+})
+```
+
+Reference: [https://github.com/shuding/better-all](https://github.com/shuding/better-all)
+
+### 1.3 Prevent Waterfall Chains in API Routes
+
+**Impact: CRITICAL (2-10× improvement)**
+
+In API routes and Server Actions, start independent operations immediately, even if you don't await them yet.
+
+**Incorrect: config waits for auth, data waits for both**
+
+```typescript
+export async function GET(request: Request) {
+ const session = await auth()
+ const config = await fetchConfig()
+ const data = await fetchData(session.user.id)
+ return Response.json({ data, config })
+}
+```
+
+**Correct: auth and config start immediately**
+
+```typescript
+export async function GET(request: Request) {
+ const sessionPromise = auth()
+ const configPromise = fetchConfig()
+ const session = await sessionPromise
+ const [config, data] = await Promise.all([
+ configPromise,
+ fetchData(session.user.id)
+ ])
+ return Response.json({ data, config })
+}
+```
+
+For operations with more complex dependency chains, use `better-all` to automatically maximize parallelism (see Dependency-Based Parallelization).
+
+### 1.4 Promise.all() for Independent Operations
+
+**Impact: CRITICAL (2-10× improvement)**
+
+When async operations have no interdependencies, execute them concurrently using `Promise.all()`.
+
+**Incorrect: sequential execution, 3 round trips**
+
+```typescript
+const user = await fetchUser()
+const posts = await fetchPosts()
+const comments = await fetchComments()
+```
+
+**Correct: parallel execution, 1 round trip**
+
+```typescript
+const [user, posts, comments] = await Promise.all([
+ fetchUser(),
+ fetchPosts(),
+ fetchComments()
+])
+```
+
+### 1.5 Strategic Suspense Boundaries
+
+**Impact: HIGH (faster initial paint)**
+
+Instead of awaiting data in async components before returning JSX, use Suspense boundaries to show the wrapper UI faster while data loads.
+
+**Incorrect: wrapper blocked by data fetching**
+
+```tsx
+async function Page() {
+  const data = await fetchData() // Blocks entire page
+
+  return (
+    <div>
+      <div>Sidebar</div>
+      <div>Header</div>
+      <main>
+        <DataDisplay data={data} />
+      </main>
+      <div>Footer</div>
+    </div>
+  )
+}
+```
+
+The entire layout waits for data even though only the middle section needs it.
+
+**Correct: wrapper shows immediately, data streams in**
+
+```tsx
+function Page() {
+  return (
+    <div>
+      <div>Sidebar</div>
+      <div>Header</div>
+      <main>
+        <Suspense fallback={<Skeleton />}>
+          <DataDisplay />
+        </Suspense>
+      </main>
+      <div>Footer</div>
+    </div>
+  )
+}
+
+async function DataDisplay() {
+  const data = await fetchData() // Only blocks this component
+  return <div>{data.content}</div>
+}
+```
+
+Sidebar, Header, and Footer render immediately. Only DataDisplay waits for data.
+
+**Alternative: share promise across components**
+
+```tsx
+function Page() {
+  // Start fetch immediately, but don't await
+  const dataPromise = fetchData()
+
+  return (
+    <Suspense fallback={<Skeleton />}>
+      <DataDisplay dataPromise={dataPromise} />
+      <DataSummary dataPromise={dataPromise} />
+    </Suspense>
+  )
+}
+
+function DataSummary({ dataPromise }: { dataPromise: Promise<Data> }) {
+  const data = use(dataPromise) // Reuses the same promise
+  return <div>{data.summary}</div>
+}
+```
+
+Both components share the same promise, so only one fetch occurs. Layout renders immediately while both components wait together.
+
+**When NOT to use this pattern:**
+
+- Critical data needed for layout decisions (affects positioning)
+
+- SEO-critical content above the fold
+
+- Small, fast queries where suspense overhead isn't worth it
+
+- When you want to avoid layout shift (loading → content jump)
+
+**Trade-off:** Faster initial paint vs potential layout shift. Choose based on your UX priorities.
+
+---
+
+## 2. Bundle Size Optimization
+
+**Impact: CRITICAL**
+
+Reducing initial bundle size improves Time to Interactive and Largest Contentful Paint.
+
+### 2.1 Avoid Barrel File Imports
+
+**Impact: CRITICAL (200-800ms import cost, slow builds)**
+
+Import directly from source files instead of barrel files to avoid loading thousands of unused modules. **Barrel files** are entry points that re-export multiple modules (e.g., `index.js` that does `export * from './module'`).
+
+Popular icon and component libraries can have **up to 10,000 re-exports** in their entry file. For many React packages, **it takes 200-800ms just to import them**, affecting both development speed and production cold starts.
+
+**Why tree-shaking doesn't help:** When a library is marked as external (not bundled), the bundler can't optimize it. If you bundle it to enable tree-shaking, builds become substantially slower analyzing the entire module graph.
+
+**Incorrect: imports entire library**
+
+```tsx
+import { Check, X, Menu } from 'lucide-react'
+// Loads 1,583 modules, takes ~2.8s extra in dev
+// Runtime cost: 200-800ms on every cold start
+
+import { Button, TextField } from '@mui/material'
+// Loads 2,225 modules, takes ~4.2s extra in dev
+```
+
+**Correct: imports only what you need**
+
+```tsx
+import Check from 'lucide-react/dist/esm/icons/check'
+import X from 'lucide-react/dist/esm/icons/x'
+import Menu from 'lucide-react/dist/esm/icons/menu'
+// Loads only 3 modules (~2KB vs ~1MB)
+
+import Button from '@mui/material/Button'
+import TextField from '@mui/material/TextField'
+// Loads only what you use
+```
+
+**Alternative: Next.js 13.5+**
+
+```js
+// next.config.js - use optimizePackageImports
+module.exports = {
+ experimental: {
+ optimizePackageImports: ['lucide-react', '@mui/material']
+ }
+}
+
+// Then you can keep the ergonomic barrel imports:
+import { Check, X, Menu } from 'lucide-react'
+// Automatically transformed to direct imports at build time
+```
+
+Direct imports provide 15-70% faster dev boot, 28% faster builds, 40% faster cold starts, and significantly faster HMR.
+
+Libraries commonly affected: `lucide-react`, `@mui/material`, `@mui/icons-material`, `@tabler/icons-react`, `react-icons`, `@headlessui/react`, `@radix-ui/react-*`, `lodash`, `ramda`, `date-fns`, `rxjs`, `react-use`.
+
+Reference: [https://vercel.com/blog/how-we-optimized-package-imports-in-next-js](https://vercel.com/blog/how-we-optimized-package-imports-in-next-js)
+
+### 2.2 Conditional Module Loading
+
+**Impact: HIGH (loads large data only when needed)**
+
+Load large data or modules only when a feature is activated.
+
+**Example: lazy-load animation frames**
+
+```tsx
+function AnimationPlayer({ enabled, setEnabled }: { enabled: boolean; setEnabled: React.Dispatch<React.SetStateAction<boolean>> }) {
+  const [frames, setFrames] = useState(null)
+
+  useEffect(() => {
+    if (enabled && !frames && typeof window !== 'undefined') {
+      import('./animation-frames.js')
+        .then(mod => setFrames(mod.frames))
+        .catch(() => setEnabled(false))
+    }
+  }, [enabled, frames, setEnabled])
+
+  if (!frames) return null
+  return <Animation frames={frames} />
+}
+```
+
+The `typeof window !== 'undefined'` check prevents bundling this module for SSR, optimizing server bundle size and build speed.
+
+### 2.3 Defer Non-Critical Third-Party Libraries
+
+**Impact: MEDIUM (loads after hydration)**
+
+Analytics, logging, and error tracking don't block user interaction. Load them after hydration.
+
+**Incorrect: blocks initial bundle**
+
+```tsx
+import { Analytics } from '@vercel/analytics/react'
+
+export default function RootLayout({ children }) {
+  return (
+    <html>
+      <body>
+        {children}
+        <Analytics />
+      </body>
+    </html>
+  )
+}
+```
+
+**Correct: loads after hydration**
+
+```tsx
+import dynamic from 'next/dynamic'
+
+const Analytics = dynamic(
+ () => import('@vercel/analytics/react').then(m => m.Analytics),
+ { ssr: false }
+)
+
+export default function RootLayout({ children }) {
+  return (
+    <html>
+      <body>
+        {children}
+        <Analytics />
+      </body>
+    </html>
+  )
+}
+```
+
+### 2.4 Dynamic Imports for Heavy Components
+
+**Impact: CRITICAL (directly affects TTI and LCP)**
+
+Use `next/dynamic` to lazy-load large components not needed on initial render.
+
+**Incorrect: Monaco bundles with main chunk ~300KB**
+
+```tsx
+import { MonacoEditor } from './monaco-editor'
+
+function CodePanel({ code }: { code: string }) {
+ return <MonacoEditor value={code} />
+}
+```
+
+**Correct: Monaco loads on demand**
+
+```tsx
+import dynamic from 'next/dynamic'
+
+const MonacoEditor = dynamic(
+ () => import('./monaco-editor').then(m => m.MonacoEditor),
+ { ssr: false }
+)
+
+function CodePanel({ code }: { code: string }) {
+ return <MonacoEditor value={code} />
+}
+```
+
+### 2.5 Preload Based on User Intent
+
+**Impact: MEDIUM (reduces perceived latency)**
+
+Preload heavy bundles before they're needed to reduce perceived latency.
+
+**Example: preload on hover/focus**
+
+```tsx
+function EditorButton({ onClick }: { onClick: () => void }) {
+ const preload = () => {
+ if (typeof window !== 'undefined') {
+ void import('./monaco-editor')
+ }
+ }
+
+  return (
+    <button onMouseEnter={preload} onFocus={preload} onClick={onClick}>
+      Open editor
+    </button>
+  )
+}
+```
+
+**Example: preload when feature flag is enabled**
+
+```tsx
+function FlagsProvider({ children, flags }: Props) {
+ useEffect(() => {
+ if (flags.editorEnabled && typeof window !== 'undefined') {
+ void import('./monaco-editor').then(mod => mod.init())
+ }
+ }, [flags.editorEnabled])
+
+  return (
+    <FlagsContext.Provider value={flags}>
+      {children}
+    </FlagsContext.Provider>
+  )
+}
+```
+
+The `typeof window !== 'undefined'` check prevents bundling preloaded modules for SSR, optimizing server bundle size and build speed.
+
+---
+
+## 3. Server-Side Performance
+
+**Impact: HIGH**
+
+Optimizing server-side rendering and data fetching eliminates server-side waterfalls and reduces response times.
+
+### 3.1 Cross-Request LRU Caching
+
+**Impact: HIGH (caches across requests)**
+
+`React.cache()` only works within one request. For data shared across sequential requests (user clicks button A then button B), use an LRU cache.
+
+**Implementation:**
+
+```typescript
+import { LRUCache } from 'lru-cache'
+
+const cache = new LRUCache({
+ max: 1000,
+ ttl: 5 * 60 * 1000 // 5 minutes
+})
+
+export async function getUser(id: string) {
+ const cached = cache.get(id)
+ if (cached) return cached
+
+ const user = await db.user.findUnique({ where: { id } })
+ cache.set(id, user)
+ return user
+}
+
+// Request 1: DB query, result cached
+// Request 2: cache hit, no DB query
+```
+
+Use when sequential user actions hit multiple endpoints needing the same data within seconds.
+
+**With Vercel's [Fluid Compute](https://vercel.com/docs/fluid-compute):** LRU caching is especially effective because multiple concurrent requests can share the same function instance and cache. This means the cache persists across requests without needing external storage like Redis.
+
+**In traditional serverless:** Each invocation runs in isolation, so consider Redis for cross-process caching.
+
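+A minimal sketch of that Redis variant (client setup, key naming, and TTL are assumptions; `db` is the same accessor as above):
+
+```typescript
+import Redis from 'ioredis'
+
+const redis = new Redis(process.env.REDIS_URL!)
+
+export async function getUser(id: string) {
+  const cached = await redis.get(`user:${id}`)
+  if (cached) return JSON.parse(cached)
+
+  const user = await db.user.findUnique({ where: { id } })
+  // Expire after 5 minutes, mirroring the LRU ttl above
+  await redis.set(`user:${id}`, JSON.stringify(user), 'EX', 300)
+  return user
+}
+```
+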
+Reference: [https://github.com/isaacs/node-lru-cache](https://github.com/isaacs/node-lru-cache)
+
+### 3.2 Minimize Serialization at RSC Boundaries
+
+**Impact: HIGH (reduces data transfer size)**
+
+The React Server/Client boundary serializes all object properties into strings and embeds them in the HTML response and subsequent RSC requests. This serialized data directly impacts page weight and load time, so **size matters a lot**. Only pass fields that the client actually uses.
+
+**Incorrect: serializes all 50 fields**
+
+```tsx
+async function Page() {
+  const user = await fetchUser() // 50 fields
+  return <Profile user={user} />
+}
+
+'use client'
+function Profile({ user }: { user: User }) {
+  return <div>{user.name}</div> // uses 1 field
+}
+```
+
+**Correct: serializes only 1 field**
+
+```tsx
+async function Page() {
+  const user = await fetchUser()
+  return <Profile name={user.name} />
+}
+
+'use client'
+function Profile({ name }: { name: string }) {
+  return <div>{name}</div>
+}
+```
+
+### 3.3 Parallel Data Fetching with Component Composition
+
+**Impact: CRITICAL (eliminates server-side waterfalls)**
+
+React Server Components execute sequentially within a tree. Restructure with composition to parallelize data fetching.
+
+**Incorrect: Sidebar waits for Page's fetch to complete**
+
+```tsx
+export default async function Page() {
+  const header = await fetchHeader()
+  return (
+    <div>
+      <header>{header}</header>
+      <Sidebar />
+    </div>
+  )
+}
+
+async function Sidebar() {
+  const items = await fetchSidebarItems()
+  return <Nav items={items} />
+}
+```
+
+**Correct: both fetch simultaneously**
+
+```tsx
+async function Header() {
+  const data = await fetchHeader()
+  return <header>{data}</header>
+}
+
+export default function Page() {
+  return (
+    <div>
+      <Header />
+      <Sidebar />
+    </div>
+  )
+}
+```
+
+This is especially helpful for large and static SVG nodes, which can be expensive to recreate on every render.
+
+**Note:** If your project has [React Compiler](https://react.dev/learn/react-compiler) enabled, the compiler automatically hoists static JSX elements and optimizes component re-renders, making manual hoisting unnecessary.
+
+### 6.4 Optimize SVG Precision
+
+**Impact: LOW (reduces file size)**
+
+Reduce SVG coordinate precision to decrease file size. The optimal precision depends on the viewBox size, but in general reducing precision should be considered.
+
+**Incorrect: excessive precision**
+
+```svg
+<path d="M 12.3456789 45.6789012 L 98.7654321 65.4321098" />
+```
+
+**Correct: 1 decimal place**
+
+```svg
+<path d="M 12.3 45.7 L 98.8 65.4" />
+```
+
+**Automate with SVGO:**
+
+```bash
+npx svgo --precision=1 --multipass icon.svg
+```
+
+### 6.5 Prevent Hydration Mismatch Without Flickering
+
+**Impact: MEDIUM (avoids visual flicker and hydration errors)**
+
+When rendering content that depends on client-side storage (localStorage, cookies), avoid both SSR breakage and post-hydration flickering by injecting a synchronous script that updates the DOM before React hydrates.
+
+**Incorrect: breaks SSR**
+
+```tsx
+function ThemeWrapper({ children }: { children: ReactNode }) {
+ // localStorage is not available on server - throws error
+ const theme = localStorage.getItem('theme') || 'light'
+
+  return (
+    <div className={theme}>{children}</div>
+  )
+}
+```
+
+Component first renders with default value (`light`), then updates after hydration, causing a visible flash of incorrect content.
+
+**Correct: no flicker, no hydration mismatch**
+
+```tsx
+function ThemeWrapper({ children }: { children: ReactNode }) {
+  return (
+    <>
+      <div id="theme-wrapper" suppressHydrationWarning>
+        {children}
+      </div>
+      <script
+        dangerouslySetInnerHTML={{
+          __html: `document.getElementById('theme-wrapper').className =
+            localStorage.getItem('theme') || 'light'`,
+        }}
+      />
+    </>
+  )
+}
+```
+
+The inline script executes synchronously before showing the element, ensuring the DOM already has the correct value. No flickering, no hydration mismatch.
+
+This pattern is especially useful for theme toggles, user preferences, authentication states, and any client-only data that should render immediately without flashing default values.
+
+### 6.6 Use Activity Component for Show/Hide
+
+**Impact: MEDIUM (preserves state/DOM)**
+
+Use React's `` to preserve state/DOM for expensive components that frequently toggle visibility.
+
+**Usage:**
+
+```tsx
+import { Activity } from 'react'
+
+function Dropdown({ isOpen }: Props) {
+  return (
+    <Activity mode={isOpen ? 'visible' : 'hidden'}>
+      <DropdownMenu />
+    </Activity>
+  )
+}
+```
+
+Avoids expensive re-renders and state loss.
+
+### 6.7 Use Explicit Conditional Rendering
+
+**Impact: LOW (prevents rendering 0 or NaN)**
+
+Use explicit ternary operators (`? :`) instead of `&&` for conditional rendering when the condition can be `0`, `NaN`, or other falsy values that render.
+
+**Incorrect: renders "0" when count is 0**
+
+```tsx
+function Badge({ count }: { count: number }) {
+  return (
+    <div>
+      {count && <span>{count}</span>}
+    </div>
+  )
+}
+
+// When count = 0, renders: <div>0</div>
+// When count = 5, renders: <div><span>5</span></div>
+```
+
+**Correct: renders nothing when count is 0**
+
+```tsx
+function Badge({ count }: { count: number }) {
+  return (
+    <div>
+      {count > 0 ? <span>{count}</span> : null}
+    </div>
+  )
+}
+
+// When count = 0, renders: <div></div>
+// When count = 5, renders: <div><span>5</span></div>
+```
+
+---
+
+## 7. JavaScript Performance
+
+**Impact: LOW-MEDIUM**
+
+Micro-optimizations for hot paths can add up to meaningful improvements.
+
+### 7.1 Batch DOM CSS Changes
+
+**Impact: MEDIUM (reduces reflows/repaints)**
+
+Avoid changing styles one property at a time. Group multiple CSS changes together via classes or `cssText` to minimize browser reflows.
+
+**Incorrect: multiple reflows**
+
+```typescript
+function updateElementStyles(element: HTMLElement) {
+ // Each line triggers a reflow
+ element.style.width = '100px'
+ element.style.height = '200px'
+ element.style.backgroundColor = 'blue'
+ element.style.border = '1px solid black'
+}
+```
+
+**Correct: add class - single reflow**
+
+```typescript
+// CSS file
+.highlighted-box {
+ width: 100px;
+ height: 200px;
+ background-color: blue;
+ border: 1px solid black;
+}
+
+// JavaScript
+function updateElementStyles(element: HTMLElement) {
+ element.classList.add('highlighted-box')
+}
+```
+
+**Correct: change cssText - single reflow**
+
+```typescript
+function updateElementStyles(element: HTMLElement) {
+ element.style.cssText = `
+ width: 100px;
+ height: 200px;
+ background-color: blue;
+ border: 1px solid black;
+ `
+}
+```
+
+**React example:**
+
+```tsx
+// Incorrect: changing styles one by one
+function Box({ isHighlighted }: { isHighlighted: boolean }) {
+ const ref = useRef<HTMLDivElement>(null)
+
+ useEffect(() => {
+ if (ref.current && isHighlighted) {
+ ref.current.style.width = '100px'
+ ref.current.style.height = '200px'
+ ref.current.style.backgroundColor = 'blue'
+ }
+ }, [isHighlighted])
+
+ return <div ref={ref} />
+}
+```
+
+**Why this matters in React:**
+
+1. Props/state mutations break React's immutability model - React expects props and state to be treated as read-only
+
+2. Causes stale closure bugs - Mutating arrays inside closures (callbacks, effects) can lead to unexpected behavior
+
+**Browser support: fallback for older browsers**
+
+`.toSorted()` is available in all modern browsers (Chrome 110+, Safari 16+, Firefox 115+, Node.js 20+). For older environments, use the spread operator instead:
+
+```typescript
+// Fallback for older browsers
+const sorted = [...items].sort((a, b) => a.value - b.value)
+```
+
+**Other immutable array methods:**
+
+- `.toSorted()` - immutable sort
+
+- `.toReversed()` - immutable reverse
+
+- `.toSpliced()` - immutable splice
+
+- `.with()` - immutable element replacement
+
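+For example, `toSorted()` returns a new sorted copy and leaves the original array untouched:
+
+```typescript
+const items = [{ value: 3 }, { value: 1 }, { value: 2 }]
+
+// sort() would reorder `items` in place; toSorted() does not
+const sorted = items.toSorted((a, b) => a.value - b.value)
+
+console.log(items.map(i => i.value))  // [3, 1, 2] — original order preserved
+console.log(sorted.map(i => i.value)) // [1, 2, 3]
+```
+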
+---
+
+## 8. Advanced Patterns
+
+**Impact: LOW**
+
+Advanced patterns for specific cases that require careful implementation.
+
+### 8.1 Store Event Handlers in Refs
+
+**Impact: LOW (stable subscriptions)**
+
+Store callbacks in refs when used in effects that shouldn't re-subscribe on callback changes.
+
+**Incorrect: re-subscribes on every render**
+
+```tsx
+function useWindowEvent(event: string, handler: () => void) {
+ useEffect(() => {
+ window.addEventListener(event, handler)
+ return () => window.removeEventListener(event, handler)
+ }, [event, handler])
+}
+```
+
+**Correct: stable subscription**
+
+```tsx
+function useWindowEvent(event: string, handler: () => void) {
+  const handlerRef = useRef(handler)
+  useEffect(() => {
+    handlerRef.current = handler
+  }, [handler])
+
+  useEffect(() => {
+    const listener = () => handlerRef.current()
+    window.addEventListener(event, listener)
+    return () => window.removeEventListener(event, listener)
+  }, [event])
+}
+```
+
+**Alternative: use `useEffectEvent` if you're on latest React:**
+
+`useEffectEvent` provides a cleaner API for the same pattern: it creates a stable function reference that always calls the latest version of the handler.
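+
+In code (mirroring `rules/advanced-event-handler-refs.md`):
+
+```tsx
+import { useEffectEvent } from 'react'
+
+function useWindowEvent(event: string, handler: () => void) {
+  const onEvent = useEffectEvent(handler)
+
+  useEffect(() => {
+    window.addEventListener(event, onEvent)
+    return () => window.removeEventListener(event, onEvent)
+  }, [event])
+}
+```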
+
+### 8.2 useLatest for Stable Callback Refs
+
+**Impact: LOW (prevents effect re-runs)**
+
+Access latest values in callbacks without adding them to dependency arrays. Prevents effect re-runs while avoiding stale closures.
+
+**Implementation:**
+
+```typescript
+function useLatest<T>(value: T) {
+ const ref = useRef(value)
+ useEffect(() => {
+ ref.current = value
+ }, [value])
+ return ref
+}
+```
+
+**Incorrect: effect re-runs on every callback change**
+
+```tsx
+function SearchInput({ onSearch }: { onSearch: (q: string) => void }) {
+ const [query, setQuery] = useState('')
+
+ useEffect(() => {
+ const timeout = setTimeout(() => onSearch(query), 300)
+ return () => clearTimeout(timeout)
+ }, [query, onSearch])
+}
+```
+
+**Correct: stable effect, fresh callback**
+
+```tsx
+function SearchInput({ onSearch }: { onSearch: (q: string) => void }) {
+ const [query, setQuery] = useState('')
+ const onSearchRef = useLatest(onSearch)
+
+ useEffect(() => {
+ const timeout = setTimeout(() => onSearchRef.current(query), 300)
+ return () => clearTimeout(timeout)
+ }, [query])
+}
+```
+
+---
+
+## References
+
+1. [https://react.dev](https://react.dev)
+2. [https://nextjs.org](https://nextjs.org)
+3. [https://swr.vercel.app](https://swr.vercel.app)
+4. [https://github.com/shuding/better-all](https://github.com/shuding/better-all)
+5. [https://github.com/isaacs/node-lru-cache](https://github.com/isaacs/node-lru-cache)
+6. [https://vercel.com/blog/how-we-optimized-package-imports-in-next-js](https://vercel.com/blog/how-we-optimized-package-imports-in-next-js)
+7. [https://vercel.com/blog/how-we-made-the-vercel-dashboard-twice-as-fast](https://vercel.com/blog/how-we-made-the-vercel-dashboard-twice-as-fast)
diff --git a/.claude/skills/vercel-react-best-practices/SKILL.md b/.claude/skills/vercel-react-best-practices/SKILL.md
new file mode 100644
index 0000000000..b064716f60
--- /dev/null
+++ b/.claude/skills/vercel-react-best-practices/SKILL.md
@@ -0,0 +1,125 @@
+---
+name: vercel-react-best-practices
+description: React and Next.js performance optimization guidelines from Vercel Engineering. This skill should be used when writing, reviewing, or refactoring React/Next.js code to ensure optimal performance patterns. Triggers on tasks involving React components, Next.js pages, data fetching, bundle optimization, or performance improvements.
+license: MIT
+metadata:
+ author: vercel
+ version: "1.0.0"
+---
+
+# Vercel React Best Practices
+
+Comprehensive performance optimization guide for React and Next.js applications, maintained by Vercel. Contains 45 rules across 8 categories, prioritized by impact to guide automated refactoring and code generation.
+
+## When to Apply
+
+Reference these guidelines when:
+- Writing new React components or Next.js pages
+- Implementing data fetching (client or server-side)
+- Reviewing code for performance issues
+- Refactoring existing React/Next.js code
+- Optimizing bundle size or load times
+
+## Rule Categories by Priority
+
+| Priority | Category | Impact | Prefix |
+|----------|----------|--------|--------|
+| 1 | Eliminating Waterfalls | CRITICAL | `async-` |
+| 2 | Bundle Size Optimization | CRITICAL | `bundle-` |
+| 3 | Server-Side Performance | HIGH | `server-` |
+| 4 | Client-Side Data Fetching | MEDIUM-HIGH | `client-` |
+| 5 | Re-render Optimization | MEDIUM | `rerender-` |
+| 6 | Rendering Performance | MEDIUM | `rendering-` |
+| 7 | JavaScript Performance | LOW-MEDIUM | `js-` |
+| 8 | Advanced Patterns | LOW | `advanced-` |
+
+## Quick Reference
+
+### 1. Eliminating Waterfalls (CRITICAL)
+
+- `async-defer-await` - Move await into branches where actually used
+- `async-parallel` - Use Promise.all() for independent operations
+- `async-dependencies` - Use better-all for partial dependencies
+- `async-api-routes` - Start promises early, await late in API routes
+- `async-suspense-boundaries` - Use Suspense to stream content
+
+### 2. Bundle Size Optimization (CRITICAL)
+
+- `bundle-barrel-imports` - Import directly, avoid barrel files
+- `bundle-dynamic-imports` - Use next/dynamic for heavy components
+- `bundle-defer-third-party` - Load analytics/logging after hydration
+- `bundle-conditional` - Load modules only when feature is activated
+- `bundle-preload` - Preload on hover/focus for perceived speed
+
+### 3. Server-Side Performance (HIGH)
+
+- `server-cache-react` - Use React.cache() for per-request deduplication
+- `server-cache-lru` - Use LRU cache for cross-request caching
+- `server-serialization` - Minimize data passed to client components
+- `server-parallel-fetching` - Restructure components to parallelize fetches
+- `server-after-nonblocking` - Use after() for non-blocking operations
+
+### 4. Client-Side Data Fetching (MEDIUM-HIGH)
+
+- `client-swr-dedup` - Use SWR for automatic request deduplication
+- `client-event-listeners` - Deduplicate global event listeners
+
+### 5. Re-render Optimization (MEDIUM)
+
+- `rerender-defer-reads` - Don't subscribe to state only used in callbacks
+- `rerender-memo` - Extract expensive work into memoized components
+- `rerender-dependencies` - Use primitive dependencies in effects
+- `rerender-derived-state` - Subscribe to derived booleans, not raw values
+- `rerender-functional-setstate` - Use functional setState for stable callbacks
+- `rerender-lazy-state-init` - Pass function to useState for expensive values
+- `rerender-transitions` - Use startTransition for non-urgent updates
+
+### 6. Rendering Performance (MEDIUM)
+
+- `rendering-animate-svg-wrapper` - Animate div wrapper, not SVG element
+- `rendering-content-visibility` - Use content-visibility for long lists
+- `rendering-hoist-jsx` - Extract static JSX outside components
+- `rendering-svg-precision` - Reduce SVG coordinate precision
+- `rendering-hydration-no-flicker` - Use inline script for client-only data
+- `rendering-activity` - Use Activity component for show/hide
+- `rendering-conditional-render` - Use ternary, not && for conditionals
+
+### 7. JavaScript Performance (LOW-MEDIUM)
+
+- `js-batch-dom-css` - Group CSS changes via classes or cssText
+- `js-index-maps` - Build Map for repeated lookups
+- `js-cache-property-access` - Cache object properties in loops
+- `js-cache-function-results` - Cache function results in module-level Map
+- `js-cache-storage` - Cache localStorage/sessionStorage reads
+- `js-combine-iterations` - Combine multiple filter/map into one loop
+- `js-length-check-first` - Check array length before expensive comparison
+- `js-early-exit` - Return early from functions
+- `js-hoist-regexp` - Hoist RegExp creation outside loops
+- `js-min-max-loop` - Use loop for min/max instead of sort
+- `js-set-map-lookups` - Use Set/Map for O(1) lookups
+- `js-tosorted-immutable` - Use toSorted() for immutability
+
+### 8. Advanced Patterns (LOW)
+
+- `advanced-event-handler-refs` - Store event handlers in refs
+- `advanced-use-latest` - useLatest for stable callback refs
+
+## How to Use
+
+Read individual rule files for detailed explanations and code examples:
+
+```
+rules/async-parallel.md
+rules/bundle-barrel-imports.md
+rules/_sections.md
+```
+
+Each rule file contains:
+- Brief explanation of why it matters
+- Incorrect code example with explanation
+- Correct code example with explanation
+- Additional context and references
+
+## Full Compiled Document
+
+For the complete guide with all rules expanded: `AGENTS.md`
diff --git a/.claude/skills/vercel-react-best-practices/rules/advanced-event-handler-refs.md b/.claude/skills/vercel-react-best-practices/rules/advanced-event-handler-refs.md
new file mode 100644
index 0000000000..97e7ade243
--- /dev/null
+++ b/.claude/skills/vercel-react-best-practices/rules/advanced-event-handler-refs.md
@@ -0,0 +1,55 @@
+---
+title: Store Event Handlers in Refs
+impact: LOW
+impactDescription: stable subscriptions
+tags: advanced, hooks, refs, event-handlers, optimization
+---
+
+## Store Event Handlers in Refs
+
+Store callbacks in refs when used in effects that shouldn't re-subscribe on callback changes.
+
+**Incorrect (re-subscribes on every render):**
+
+```tsx
+function useWindowEvent(event: string, handler: (e) => void) {
+ useEffect(() => {
+ window.addEventListener(event, handler)
+ return () => window.removeEventListener(event, handler)
+ }, [event, handler])
+}
+```
+
+**Correct (stable subscription):**
+
+```tsx
+function useWindowEvent(event: string, handler: (e) => void) {
+ const handlerRef = useRef(handler)
+ useEffect(() => {
+ handlerRef.current = handler
+ }, [handler])
+
+ useEffect(() => {
+ const listener = (e) => handlerRef.current(e)
+ window.addEventListener(event, listener)
+ return () => window.removeEventListener(event, listener)
+ }, [event])
+}
+```
+
+**Alternative: use `useEffectEvent` if you're on latest React:**
+
+```tsx
+import { useEffectEvent } from 'react'
+
+function useWindowEvent(event: string, handler: (e) => void) {
+ const onEvent = useEffectEvent(handler)
+
+ useEffect(() => {
+ window.addEventListener(event, onEvent)
+ return () => window.removeEventListener(event, onEvent)
+ }, [event])
+}
+```
+
+`useEffectEvent` provides a cleaner API for the same pattern: it creates a stable function reference that always calls the latest version of the handler.
diff --git a/.claude/skills/vercel-react-best-practices/rules/advanced-use-latest.md b/.claude/skills/vercel-react-best-practices/rules/advanced-use-latest.md
new file mode 100644
index 0000000000..483c2ef7da
--- /dev/null
+++ b/.claude/skills/vercel-react-best-practices/rules/advanced-use-latest.md
@@ -0,0 +1,49 @@
+---
+title: useLatest for Stable Callback Refs
+impact: LOW
+impactDescription: prevents effect re-runs
+tags: advanced, hooks, useLatest, refs, optimization
+---
+
+## useLatest for Stable Callback Refs
+
+Access latest values in callbacks without adding them to dependency arrays. Prevents effect re-runs while avoiding stale closures.
+
+**Implementation:**
+
+```typescript
+function useLatest<T>(value: T) {
+ const ref = useRef(value)
+ useLayoutEffect(() => {
+ ref.current = value
+ }, [value])
+ return ref
+}
+```
+
+**Incorrect (effect re-runs on every callback change):**
+
+```tsx
+function SearchInput({ onSearch }: { onSearch: (q: string) => void }) {
+ const [query, setQuery] = useState('')
+
+ useEffect(() => {
+ const timeout = setTimeout(() => onSearch(query), 300)
+ return () => clearTimeout(timeout)
+ }, [query, onSearch])
+}
+```
+
+**Correct (stable effect, fresh callback):**
+
+```tsx
+function SearchInput({ onSearch }: { onSearch: (q: string) => void }) {
+ const [query, setQuery] = useState('')
+ const onSearchRef = useLatest(onSearch)
+
+ useEffect(() => {
+ const timeout = setTimeout(() => onSearchRef.current(query), 300)
+ return () => clearTimeout(timeout)
+ }, [query])
+}
+```
diff --git a/.claude/skills/vercel-react-best-practices/rules/async-api-routes.md b/.claude/skills/vercel-react-best-practices/rules/async-api-routes.md
new file mode 100644
index 0000000000..6feda1ef0a
--- /dev/null
+++ b/.claude/skills/vercel-react-best-practices/rules/async-api-routes.md
@@ -0,0 +1,38 @@
+---
+title: Prevent Waterfall Chains in API Routes
+impact: CRITICAL
+impactDescription: 2-10× improvement
+tags: api-routes, server-actions, waterfalls, parallelization
+---
+
+## Prevent Waterfall Chains in API Routes
+
+In API routes and Server Actions, start independent operations immediately, even if you don't await them yet.
+
+**Incorrect (config waits for auth, data waits for both):**
+
+```typescript
+export async function GET(request: Request) {
+ const session = await auth()
+ const config = await fetchConfig()
+ const data = await fetchData(session.user.id)
+ return Response.json({ data, config })
+}
+```
+
+**Correct (auth and config start immediately):**
+
+```typescript
+export async function GET(request: Request) {
+ const sessionPromise = auth()
+ const configPromise = fetchConfig()
+ const session = await sessionPromise
+ const [config, data] = await Promise.all([
+ configPromise,
+ fetchData(session.user.id)
+ ])
+ return Response.json({ data, config })
+}
+```
+
+For operations with more complex dependency chains, use `better-all` to automatically maximize parallelism (see Dependency-Based Parallelization).
diff --git a/.claude/skills/vercel-react-best-practices/rules/async-defer-await.md b/.claude/skills/vercel-react-best-practices/rules/async-defer-await.md
new file mode 100644
index 0000000000..ea7082a362
--- /dev/null
+++ b/.claude/skills/vercel-react-best-practices/rules/async-defer-await.md
@@ -0,0 +1,80 @@
+---
+title: Defer Await Until Needed
+impact: HIGH
+impactDescription: avoids blocking unused code paths
+tags: async, await, conditional, optimization
+---
+
+## Defer Await Until Needed
+
+Move `await` operations into the branches where they're actually used to avoid blocking code paths that don't need them.
+
+**Incorrect (blocks both branches):**
+
+```typescript
+async function handleRequest(userId: string, skipProcessing: boolean) {
+ const userData = await fetchUserData(userId)
+
+ if (skipProcessing) {
+ // Returns immediately but still waited for userData
+ return { skipped: true }
+ }
+
+ // Only this branch uses userData
+ return processUserData(userData)
+}
+```
+
+**Correct (only blocks when needed):**
+
+```typescript
+async function handleRequest(userId: string, skipProcessing: boolean) {
+ if (skipProcessing) {
+ // Returns immediately without waiting
+ return { skipped: true }
+ }
+
+ // Fetch only when needed
+ const userData = await fetchUserData(userId)
+ return processUserData(userData)
+}
+```
+
+**Another example (early return optimization):**
+
+```typescript
+// Incorrect: always fetches permissions
+async function updateResource(resourceId: string, userId: string) {
+ const permissions = await fetchPermissions(userId)
+ const resource = await getResource(resourceId)
+
+ if (!resource) {
+ return { error: 'Not found' }
+ }
+
+ if (!permissions.canEdit) {
+ return { error: 'Forbidden' }
+ }
+
+ return await updateResourceData(resource, permissions)
+}
+
+// Correct: fetches only when needed
+async function updateResource(resourceId: string, userId: string) {
+ const resource = await getResource(resourceId)
+
+ if (!resource) {
+ return { error: 'Not found' }
+ }
+
+ const permissions = await fetchPermissions(userId)
+
+ if (!permissions.canEdit) {
+ return { error: 'Forbidden' }
+ }
+
+ return await updateResourceData(resource, permissions)
+}
+```
+
+This optimization is especially valuable when the skipped branch is frequently taken, or when the deferred operation is expensive.
diff --git a/.claude/skills/vercel-react-best-practices/rules/async-dependencies.md b/.claude/skills/vercel-react-best-practices/rules/async-dependencies.md
new file mode 100644
index 0000000000..fb90d861ac
--- /dev/null
+++ b/.claude/skills/vercel-react-best-practices/rules/async-dependencies.md
@@ -0,0 +1,36 @@
+---
+title: Dependency-Based Parallelization
+impact: CRITICAL
+impactDescription: 2-10× improvement
+tags: async, parallelization, dependencies, better-all
+---
+
+## Dependency-Based Parallelization
+
+For operations with partial dependencies, use `better-all` to maximize parallelism. It automatically starts each task at the earliest possible moment.
+
+**Incorrect (profile waits for config unnecessarily):**
+
+```typescript
+const [user, config] = await Promise.all([
+ fetchUser(),
+ fetchConfig()
+])
+const profile = await fetchProfile(user.id)
+```
+
+**Correct (config and profile run in parallel):**
+
+```typescript
+import { all } from 'better-all'
+
+const { user, config, profile } = await all({
+ async user() { return fetchUser() },
+ async config() { return fetchConfig() },
+ async profile() {
+ return fetchProfile((await this.$.user).id)
+ }
+})
+```
+
+Reference: [https://github.com/shuding/better-all](https://github.com/shuding/better-all)
diff --git a/.claude/skills/vercel-react-best-practices/rules/async-parallel.md b/.claude/skills/vercel-react-best-practices/rules/async-parallel.md
new file mode 100644
index 0000000000..64133f6c31
--- /dev/null
+++ b/.claude/skills/vercel-react-best-practices/rules/async-parallel.md
@@ -0,0 +1,28 @@
+---
+title: Promise.all() for Independent Operations
+impact: CRITICAL
+impactDescription: 2-10× improvement
+tags: async, parallelization, promises, waterfalls
+---
+
+## Promise.all() for Independent Operations
+
+When async operations have no interdependencies, execute them concurrently using `Promise.all()`.
+
+**Incorrect (sequential execution, 3 round trips):**
+
+```typescript
+const user = await fetchUser()
+const posts = await fetchPosts()
+const comments = await fetchComments()
+```
+
+**Correct (parallel execution, 1 round trip):**
+
+```typescript
+const [user, posts, comments] = await Promise.all([
+ fetchUser(),
+ fetchPosts(),
+ fetchComments()
+])
+```
diff --git a/.claude/skills/vercel-react-best-practices/rules/async-suspense-boundaries.md b/.claude/skills/vercel-react-best-practices/rules/async-suspense-boundaries.md
new file mode 100644
index 0000000000..1fbc05b04e
--- /dev/null
+++ b/.claude/skills/vercel-react-best-practices/rules/async-suspense-boundaries.md
@@ -0,0 +1,99 @@
+---
+title: Strategic Suspense Boundaries
+impact: HIGH
+impactDescription: faster initial paint
+tags: async, suspense, streaming, layout-shift
+---
+
+## Strategic Suspense Boundaries
+
+Instead of awaiting data in async components before returning JSX, use Suspense boundaries to show the wrapper UI faster while data loads.
+
+**Incorrect (wrapper blocked by data fetching):**
+
+```tsx
+async function Page() {
+  const data = await fetchData() // Blocks entire page
+
+  return (
+    <div>
+      <div>Sidebar</div>
+      <div>Header</div>
+      <main>
+        <DataDisplay data={data} />
+      </main>
+      <div>Footer</div>
+    </div>
+  )
+}
+```
+
+The entire layout waits for data even though only the middle section needs it.
+
+**Correct (wrapper shows immediately, data streams in):**
+
+```tsx
+function Page() {
+  return (
+    <div>
+      <div>Sidebar</div>
+      <div>Header</div>
+      <main>
+        <Suspense fallback={<Skeleton />}>
+          <DataDisplay />
+        </Suspense>
+      </main>
+      <div>Footer</div>
+    </div>
+  )
+}
+
+async function DataDisplay() {
+  const data = await fetchData() // Only blocks this component
+  return <div>{data.content}</div>
+}
+```
+
+Sidebar, Header, and Footer render immediately. Only DataDisplay waits for data.
+
+**Alternative (share promise across components):**
+
+```tsx
+function Page() {
+  // Start fetch immediately, but don't await
+  const dataPromise = fetchData()
+
+  return (
+    <Suspense fallback={<Skeleton />}>
+      <DataDisplay dataPromise={dataPromise} />
+      <DataSummary dataPromise={dataPromise} />
+    </Suspense>
+  )
+}
+```
+
+This applies to all CSS transforms and transitions (`transform`, `opacity`, `translate`, `scale`, `rotate`). The wrapper div allows browsers to use GPU acceleration for smoother animations.
diff --git a/.claude/skills/vercel-react-best-practices/rules/rendering-conditional-render.md b/.claude/skills/vercel-react-best-practices/rules/rendering-conditional-render.md
new file mode 100644
index 0000000000..7e866f5852
--- /dev/null
+++ b/.claude/skills/vercel-react-best-practices/rules/rendering-conditional-render.md
@@ -0,0 +1,40 @@
+---
+title: Use Explicit Conditional Rendering
+impact: LOW
+impactDescription: prevents rendering 0 or NaN
+tags: rendering, conditional, jsx, falsy-values
+---
+
+## Use Explicit Conditional Rendering
+
+Use explicit ternary operators (`? :`) instead of `&&` for conditional rendering when the condition can be `0`, `NaN`, or other falsy values that render.
+
+**Incorrect (renders "0" when count is 0):**
+
+```tsx
+function Badge({ count }: { count: number }) {
+  return (
+    <div>
+      {count && <span className="badge">{count}</span>}
+    </div>
+  )
+}
+
+// When count = 0, renders: <div>0</div>
+// When count = 5, renders: <div><span class="badge">5</span></div>
+```
+
+**Correct (renders nothing when count is 0):**
+
+```tsx
+function Badge({ count }: { count: number }) {
+  return (
+    <div>
+      {count > 0 ? <span className="badge">{count}</span> : null}
+    </div>
+  )
+}
+
+// When count = 0, renders: <div></div>
+// When count = 5, renders: <div><span class="badge">5</span></div>
+```
+
+This is especially helpful for large and static SVG nodes, which can be expensive to recreate on every render.
+
+**Note:** If your project has [React Compiler](https://react.dev/learn/react-compiler) enabled, the compiler automatically hoists static JSX elements and optimizes component re-renders, making manual hoisting unnecessary.
diff --git a/.claude/skills/vercel-react-best-practices/rules/rendering-hydration-no-flicker.md b/.claude/skills/vercel-react-best-practices/rules/rendering-hydration-no-flicker.md
new file mode 100644
index 0000000000..5cf0e79b69
--- /dev/null
+++ b/.claude/skills/vercel-react-best-practices/rules/rendering-hydration-no-flicker.md
@@ -0,0 +1,82 @@
+---
+title: Prevent Hydration Mismatch Without Flickering
+impact: MEDIUM
+impactDescription: avoids visual flicker and hydration errors
+tags: rendering, ssr, hydration, localStorage, flicker
+---
+
+## Prevent Hydration Mismatch Without Flickering
+
+When rendering content that depends on client-side storage (localStorage, cookies), avoid both SSR breakage and post-hydration flickering by injecting a synchronous script that updates the DOM before React hydrates.
+
+**Incorrect (breaks SSR):**
+
+```tsx
+function ThemeWrapper({ children }: { children: ReactNode }) {
+ // localStorage is not available on server - throws error
+ const theme = localStorage.getItem('theme') || 'light'
+
+  return (
+    <div className={theme}>{children}</div>
+  )
+}
+```
+
+A common workaround, initializing state to a default and reading localStorage in `useEffect`, avoids the crash but first renders the default value (`light`) and then updates after hydration, causing a visible flash of incorrect content.
+
+**Correct (no flicker, no hydration mismatch):**
+
+```tsx
+function ThemeWrapper({ children }: { children: ReactNode }) {
+  return (
+    <>
+      <div id="theme-wrapper" hidden suppressHydrationWarning>
+        {children}
+      </div>
+      <script
+        dangerouslySetInnerHTML={{
+          __html: `
+            var el = document.getElementById('theme-wrapper')
+            el.className = localStorage.getItem('theme') || 'light'
+            el.hidden = false
+          `,
+        }}
+      />
+    </>
+  )
+}
+```
+
+The inline script executes synchronously before showing the element, ensuring the DOM already has the correct value. No flickering, no hydration mismatch.
+
+This pattern is especially useful for theme toggles, user preferences, authentication states, and any client-only data that should render immediately without flashing default values.
diff --git a/.claude/skills/vercel-react-best-practices/rules/rendering-svg-precision.md b/.claude/skills/vercel-react-best-practices/rules/rendering-svg-precision.md
new file mode 100644
index 0000000000..6d77128603
--- /dev/null
+++ b/.claude/skills/vercel-react-best-practices/rules/rendering-svg-precision.md
@@ -0,0 +1,28 @@
+---
+title: Optimize SVG Precision
+impact: LOW
+impactDescription: reduces file size
+tags: rendering, svg, optimization, svgo
+---
+
+## Optimize SVG Precision
+
+Reduce SVG coordinate precision to decrease file size. The optimal precision depends on the viewBox size, but extra decimal places rarely add visible detail, so reducing precision is usually worthwhile.
+
+**Incorrect (excessive precision):**
+
+```svg
+<path d="M12.847316 5.192745 L8.273645 9.918264 L4.538291 5.772934 Z" fill="currentColor" />
+```
+
+**Correct (1 decimal place):**
+
+```svg
+<path d="M12.8 5.2 L8.3 9.9 L4.5 5.8 Z" fill="currentColor" />
+```
+
+**Automate with SVGO:**
+
+```bash
+npx svgo --precision=1 --multipass icon.svg
+```
diff --git a/.claude/skills/vercel-react-best-practices/rules/rerender-defer-reads.md b/.claude/skills/vercel-react-best-practices/rules/rerender-defer-reads.md
new file mode 100644
index 0000000000..e867c95f02
--- /dev/null
+++ b/.claude/skills/vercel-react-best-practices/rules/rerender-defer-reads.md
@@ -0,0 +1,39 @@
+---
+title: Defer State Reads to Usage Point
+impact: MEDIUM
+impactDescription: avoids unnecessary subscriptions
+tags: rerender, searchParams, localStorage, optimization
+---
+
+## Defer State Reads to Usage Point
+
+Don't subscribe to dynamic state (searchParams, localStorage) if you only read it inside callbacks.
+
+**Incorrect (subscribes to all searchParams changes):**
+
+```tsx
+function ShareButton({ chatId }: { chatId: string }) {
+ const searchParams = useSearchParams()
+
+ const handleShare = () => {
+ const ref = searchParams.get('ref')
+ shareChat(chatId, { ref })
+ }
+
+  return <button onClick={handleShare}>Share</button>
+}
+```
+
+**Correct (reads on demand, no subscription):**
+
+```tsx
+function ShareButton({ chatId }: { chatId: string }) {
+ const handleShare = () => {
+ const params = new URLSearchParams(window.location.search)
+ const ref = params.get('ref')
+ shareChat(chatId, { ref })
+ }
+
+  return <button onClick={handleShare}>Share</button>
+}
+```
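+
+The same idea applies to the other dynamic sources mentioned above, such as localStorage. A hedged sketch (the component, storage key, and `exportData` helper are illustrative): read the value inside the handler instead of subscribing to it.
+
+```tsx
+function ExportButton() {
+  const handleExport = () => {
+    // Read on demand; no subscription, no re-render when the stored value changes
+    const format = localStorage.getItem('export-format') ?? 'csv' // illustrative key
+    exportData(format) // illustrative helper
+  }
+
+  return <button onClick={handleExport}>Export</button>
+}
+```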
diff --git a/.claude/skills/vercel-react-best-practices/rules/rerender-dependencies.md b/.claude/skills/vercel-react-best-practices/rules/rerender-dependencies.md
new file mode 100644
index 0000000000..47a4d92685
--- /dev/null
+++ b/.claude/skills/vercel-react-best-practices/rules/rerender-dependencies.md
@@ -0,0 +1,45 @@
+---
+title: Narrow Effect Dependencies
+impact: LOW
+impactDescription: minimizes effect re-runs
+tags: rerender, useEffect, dependencies, optimization
+---
+
+## Narrow Effect Dependencies
+
+Specify primitive dependencies instead of objects to minimize effect re-runs.
+
+**Incorrect (re-runs on any user field change):**
+
+```tsx
+useEffect(() => {
+ console.log(user.id)
+}, [user])
+```
+
+**Correct (re-runs only when id changes):**
+
+```tsx
+useEffect(() => {
+ console.log(user.id)
+}, [user.id])
+```
+
+**For derived state, compute outside effect:**
+
+```tsx
+// Incorrect: runs on width=767, 766, 765...
+useEffect(() => {
+ if (width < 768) {
+ enableMobileMode()
+ }
+}, [width])
+
+// Correct: runs only on boolean transition
+const isMobile = width < 768
+useEffect(() => {
+ if (isMobile) {
+ enableMobileMode()
+ }
+}, [isMobile])
+```
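+
+When an effect needs several fields, the same idea still applies: list each primitive rather than the whole object. A small illustrative sketch (`syncPermissions` is a placeholder helper):
+
+```tsx
+useEffect(() => {
+  syncPermissions(user.id, user.role) // illustrative helper
+}, [user.id, user.role]) // re-runs only when id or role changes, not on any user field change
+```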
diff --git a/.claude/skills/vercel-react-best-practices/rules/rerender-derived-state.md b/.claude/skills/vercel-react-best-practices/rules/rerender-derived-state.md
new file mode 100644
index 0000000000..e5c899f6c0
--- /dev/null
+++ b/.claude/skills/vercel-react-best-practices/rules/rerender-derived-state.md
@@ -0,0 +1,29 @@
+---
+title: Subscribe to Derived State
+impact: MEDIUM
+impactDescription: reduces re-render frequency
+tags: rerender, derived-state, media-query, optimization
+---
+
+## Subscribe to Derived State
+
+Subscribe to derived boolean state instead of continuous values to reduce re-render frequency.
+
+**Incorrect (re-renders on every pixel change):**
+
+```tsx
+function Sidebar() {
+ const width = useWindowWidth() // updates continuously
+ const isMobile = width < 768
+  return <nav className={isMobile ? 'mobile' : 'desktop'} />
+}
+```
+
+**Correct (re-renders only when boolean changes):**
+
+```tsx
+function Sidebar() {
+ const isMobile = useMediaQuery('(max-width: 767px)')
+  return <nav className={isMobile ? 'mobile' : 'desktop'} />
+}
+```
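+
+`useMediaQuery` is not defined by this rule. One minimal way to implement it (a sketch, not a prescribed implementation) is with `matchMedia` and `useSyncExternalStore`, which only notifies React when the boolean match flips:
+
+```tsx
+import { useSyncExternalStore } from 'react'
+
+function useMediaQuery(query: string): boolean {
+  return useSyncExternalStore(
+    (onChange) => {
+      const mql = window.matchMedia(query)
+      mql.addEventListener('change', onChange)
+      return () => mql.removeEventListener('change', onChange)
+    },
+    () => window.matchMedia(query).matches, // client snapshot
+    () => false, // server snapshot
+  )
+}
+```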
diff --git a/.claude/skills/vercel-react-best-practices/rules/rerender-functional-setstate.md b/.claude/skills/vercel-react-best-practices/rules/rerender-functional-setstate.md
new file mode 100644
index 0000000000..b004ef45e3
--- /dev/null
+++ b/.claude/skills/vercel-react-best-practices/rules/rerender-functional-setstate.md
@@ -0,0 +1,74 @@
+---
+title: Use Functional setState Updates
+impact: MEDIUM
+impactDescription: prevents stale closures and unnecessary callback recreations
+tags: react, hooks, useState, useCallback, callbacks, closures
+---
+
+## Use Functional setState Updates
+
+When updating state based on the current state value, use the functional update form of setState instead of directly referencing the state variable. This prevents stale closures, eliminates unnecessary dependencies, and creates stable callback references.
+
+**Incorrect (requires state as dependency):**
+
+```tsx
+function TodoList() {
+ const [items, setItems] = useState(initialItems)
+
+ // Callback must depend on items, recreated on every items change
+ const addItems = useCallback((newItems: Item[]) => {
+ setItems([...items, ...newItems])
+ }, [items]) // ❌ items dependency causes recreations
+
+ // Risk of stale closure if dependency is forgotten
+ const removeItem = useCallback((id: string) => {
+ setItems(items.filter(item => item.id !== id))
+ }, []) // ❌ Missing items dependency - will use stale items!
+
+  return <ItemList items={items} onAdd={addItems} onRemove={removeItem} />
+}
+```
+
+The first callback is recreated every time `items` changes, which can cause child components to re-render unnecessarily. The second callback has a stale closure bug—it will always reference the initial `items` value.
+
+**Correct (stable callbacks, no stale closures):**
+
+```tsx
+function TodoList() {
+ const [items, setItems] = useState(initialItems)
+
+ // Stable callback, never recreated
+ const addItems = useCallback((newItems: Item[]) => {
+ setItems(curr => [...curr, ...newItems])
+ }, []) // ✅ No dependencies needed
+
+ // Always uses latest state, no stale closure risk
+ const removeItem = useCallback((id: string) => {
+ setItems(curr => curr.filter(item => item.id !== id))
+ }, []) // ✅ Safe and stable
+
+  return <ItemList items={items} onAdd={addItems} onRemove={removeItem} />
+}
+```
+
+**Benefits:**
+
+1. **Stable callback references** - Callbacks don't need to be recreated when state changes
+2. **No stale closures** - Always operates on the latest state value
+3. **Fewer dependencies** - Simplifies dependency arrays and avoids holding stale state alive in long-lived closures
+4. **Prevents bugs** - Eliminates the most common source of React closure bugs
+
+**When to use functional updates:**
+
+- Any setState that depends on the current state value
+- Inside useCallback/useMemo when state is needed
+- Event handlers that reference state
+- Async operations that update state
+
+**When direct updates are fine:**
+
+- Setting state to a static value: `setCount(0)`
+- Setting state from props/arguments only: `setName(newName)`
+- State doesn't depend on previous value
+
+**Note:** If your project has [React Compiler](https://react.dev/learn/react-compiler) enabled, the compiler can automatically optimize some cases, but functional updates are still recommended for correctness and to prevent stale closure bugs.
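+
+For the async case listed above, a small hedged sketch (`fetchItems` and the merge logic are illustrative) shows why the functional form matters after an `await`:
+
+```tsx
+async function handleRefresh() {
+  const fresh = await fetchItems() // illustrative fetcher
+  // Any `items` captured before the await may be stale by the time it resolves;
+  // the functional form always merges into the latest state
+  setItems(curr => [...curr, ...fresh.filter(f => !curr.some(c => c.id === f.id))])
+}
+```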
diff --git a/.claude/skills/vercel-react-best-practices/rules/rerender-lazy-state-init.md b/.claude/skills/vercel-react-best-practices/rules/rerender-lazy-state-init.md
new file mode 100644
index 0000000000..4ecb350fba
--- /dev/null
+++ b/.claude/skills/vercel-react-best-practices/rules/rerender-lazy-state-init.md
@@ -0,0 +1,58 @@
+---
+title: Use Lazy State Initialization
+impact: MEDIUM
+impactDescription: avoids wasted computation on every render
+tags: react, hooks, useState, performance, initialization
+---
+
+## Use Lazy State Initialization
+
+Pass a function to `useState` for expensive initial values. Without the function form, the initializer runs on every render even though the value is only used once.
+
+**Incorrect (runs on every render):**
+
+```tsx
+function FilteredList({ items }: { items: Item[] }) {
+ // buildSearchIndex() runs on EVERY render, even after initialization
+ const [searchIndex, setSearchIndex] = useState(buildSearchIndex(items))
+ const [query, setQuery] = useState('')
+
+ // When query changes, buildSearchIndex runs again unnecessarily
+  return <SearchResults index={searchIndex} query={query} onQueryChange={setQuery} />
+}
+
+function UserProfile() {
+ // JSON.parse runs on every render
+ const [settings, setSettings] = useState(
+ JSON.parse(localStorage.getItem('settings') || '{}')
+ )
+
+  return <SettingsPanel settings={settings} onChange={setSettings} />
+}
+```
+
+**Correct (runs only once):**
+
+```tsx
+function FilteredList({ items }: { items: Item[] }) {
+ // buildSearchIndex() runs ONLY on initial render
+ const [searchIndex, setSearchIndex] = useState(() => buildSearchIndex(items))
+ const [query, setQuery] = useState('')
+
+  return <SearchResults index={searchIndex} query={query} onQueryChange={setQuery} />
+}
+
+function UserProfile() {
+ // JSON.parse runs only on initial render
+ const [settings, setSettings] = useState(() => {
+ const stored = localStorage.getItem('settings')
+ return stored ? JSON.parse(stored) : {}
+ })
+
+  return <SettingsPanel settings={settings} onChange={setSettings} />
+}
+```
+
+Use lazy initialization when computing initial values from localStorage/sessionStorage, building data structures (indexes, maps), reading from the DOM, or performing heavy transformations.
+
+For simple primitives (`useState(0)`), direct references (`useState(props.value)`), or cheap literals (`useState({})`), the function form is unnecessary.
diff --git a/.claude/skills/vercel-react-best-practices/rules/rerender-memo.md b/.claude/skills/vercel-react-best-practices/rules/rerender-memo.md
new file mode 100644
index 0000000000..f8982ab612
--- /dev/null
+++ b/.claude/skills/vercel-react-best-practices/rules/rerender-memo.md
@@ -0,0 +1,44 @@
+---
+title: Extract to Memoized Components
+impact: MEDIUM
+impactDescription: enables early returns
+tags: rerender, memo, useMemo, optimization
+---
+
+## Extract to Memoized Components
+
+Extract expensive work into memoized components to enable early returns before computation.
+
+**Incorrect (computes avatar even when loading):**
+
+```tsx
+function Profile({ user, loading }: Props) {
+  const avatar = useMemo(() => {
+    const id = computeAvatarId(user)
+    return <Avatar id={id} />
+  }, [user])
+
+  if (loading) return <Skeleton />
+  return <div className="profile">{avatar}</div>
+}
+```
+
+**Correct (skips computation when loading):**
+
+```tsx
+const UserAvatar = memo(function UserAvatar({ user }: { user: User }) {
+  const id = useMemo(() => computeAvatarId(user), [user])
+  return <Avatar id={id} />
+})
+
+function Profile({ user, loading }: Props) {
+  if (loading) return <Skeleton />
+  return (
+    <div className="profile">
+      <UserAvatar user={user} />
+    </div>
+  )
+}
+```
+
+**Note:** If your project has [React Compiler](https://react.dev/learn/react-compiler) enabled, manual memoization with `memo()` and `useMemo()` is not necessary. The compiler automatically optimizes re-renders.
diff --git a/.claude/skills/vercel-react-best-practices/rules/rerender-transitions.md b/.claude/skills/vercel-react-best-practices/rules/rerender-transitions.md
new file mode 100644
index 0000000000..d99f43f764
--- /dev/null
+++ b/.claude/skills/vercel-react-best-practices/rules/rerender-transitions.md
@@ -0,0 +1,40 @@
+---
+title: Use Transitions for Non-Urgent Updates
+impact: MEDIUM
+impactDescription: maintains UI responsiveness
+tags: rerender, transitions, startTransition, performance
+---
+
+## Use Transitions for Non-Urgent Updates
+
+Mark frequent, non-urgent state updates as transitions to maintain UI responsiveness.
+
+**Incorrect (blocks UI on every scroll):**
+
+```tsx
+function ScrollTracker() {
+ const [scrollY, setScrollY] = useState(0)
+ useEffect(() => {
+ const handler = () => setScrollY(window.scrollY)
+ window.addEventListener('scroll', handler, { passive: true })
+ return () => window.removeEventListener('scroll', handler)
+ }, [])
+}
+```
+
+**Correct (non-blocking updates):**
+
+```tsx
+import { startTransition } from 'react'
+
+function ScrollTracker() {
+ const [scrollY, setScrollY] = useState(0)
+ useEffect(() => {
+ const handler = () => {
+ startTransition(() => setScrollY(window.scrollY))
+ }
+ window.addEventListener('scroll', handler, { passive: true })
+ return () => window.removeEventListener('scroll', handler)
+ }, [])
+}
+```
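+
+If you also need to know whether a low-priority update is still pending (for example, to dim stale content), the hook form `useTransition` provides that. A minimal sketch of the same scroll tracker:
+
+```tsx
+import { useEffect, useState, useTransition } from 'react'
+
+function ScrollTracker() {
+  const [scrollY, setScrollY] = useState(0)
+  const [isPending, startTransition] = useTransition()
+
+  useEffect(() => {
+    const handler = () => startTransition(() => setScrollY(window.scrollY))
+    window.addEventListener('scroll', handler, { passive: true })
+    return () => window.removeEventListener('scroll', handler)
+  }, [])
+
+  return <div style={{ opacity: isPending ? 0.7 : 1 }}>Scrolled {scrollY}px</div>
+}
+```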
diff --git a/.claude/skills/vercel-react-best-practices/rules/server-after-nonblocking.md b/.claude/skills/vercel-react-best-practices/rules/server-after-nonblocking.md
new file mode 100644
index 0000000000..e8f5b260f5
--- /dev/null
+++ b/.claude/skills/vercel-react-best-practices/rules/server-after-nonblocking.md
@@ -0,0 +1,73 @@
+---
+title: Use after() for Non-Blocking Operations
+impact: MEDIUM
+impactDescription: faster response times
+tags: server, async, logging, analytics, side-effects
+---
+
+## Use after() for Non-Blocking Operations
+
+Use Next.js's `after()` to schedule work that should execute after a response is sent. This prevents logging, analytics, and other side effects from blocking the response.
+
+**Incorrect (blocks response):**
+
+```tsx
+import { logUserAction } from '@/app/utils'
+
+export async function POST(request: Request) {
+ // Perform mutation
+ await updateDatabase(request)
+
+ // Logging blocks the response
+ const userAgent = request.headers.get('user-agent') || 'unknown'
+ await logUserAction({ userAgent })
+
+ return new Response(JSON.stringify({ status: 'success' }), {
+ status: 200,
+ headers: { 'Content-Type': 'application/json' }
+ })
+}
+```
+
+**Correct (non-blocking):**
+
+```tsx
+import { after } from 'next/server'
+import { headers, cookies } from 'next/headers'
+import { logUserAction } from '@/app/utils'
+
+export async function POST(request: Request) {
+ // Perform mutation
+ await updateDatabase(request)
+
+ // Log after response is sent
+ after(async () => {
+ const userAgent = (await headers()).get('user-agent') || 'unknown'
+ const sessionCookie = (await cookies()).get('session-id')?.value || 'anonymous'
+
+ logUserAction({ sessionCookie, userAgent })
+ })
+
+ return new Response(JSON.stringify({ status: 'success' }), {
+ status: 200,
+ headers: { 'Content-Type': 'application/json' }
+ })
+}
+```
+
+The response is sent immediately while logging happens in the background.
+
+**Common use cases:**
+
+- Analytics tracking
+- Audit logging
+- Sending notifications
+- Cache invalidation
+- Cleanup tasks
+
+**Important notes:**
+
+- `after()` runs even if the response fails or redirects
+- Works in Server Actions, Route Handlers, and Server Components
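+
+As a rough sketch of the Server Action case mentioned above (`saveSettingsToDb` and `trackEvent` are illustrative helpers, not part of Next.js):
+
+```tsx
+'use server'
+
+import { after } from 'next/server'
+
+export async function saveSettings(formData: FormData) {
+  await saveSettingsToDb(formData) // illustrative mutation
+
+  after(() => {
+    trackEvent('settings_saved') // illustrative analytics call, runs after the response
+  })
+}
+```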
+
+Reference: [https://nextjs.org/docs/app/api-reference/functions/after](https://nextjs.org/docs/app/api-reference/functions/after)
diff --git a/.claude/skills/vercel-react-best-practices/rules/server-cache-lru.md b/.claude/skills/vercel-react-best-practices/rules/server-cache-lru.md
new file mode 100644
index 0000000000..ef6938aa53
--- /dev/null
+++ b/.claude/skills/vercel-react-best-practices/rules/server-cache-lru.md
@@ -0,0 +1,41 @@
+---
+title: Cross-Request LRU Caching
+impact: HIGH
+impactDescription: caches across requests
+tags: server, cache, lru, cross-request
+---
+
+## Cross-Request LRU Caching
+
+`React.cache()` only works within one request. For data shared across sequential requests (user clicks button A then button B), use an LRU cache.
+
+**Implementation:**
+
+```typescript
+import { LRUCache } from 'lru-cache'
+
+const cache = new LRUCache({
+ max: 1000,
+ ttl: 5 * 60 * 1000 // 5 minutes
+})
+
+export async function getUser(id: string) {
+ const cached = cache.get(id)
+ if (cached) return cached
+
+ const user = await db.user.findUnique({ where: { id } })
+ cache.set(id, user)
+ return user
+}
+
+// Request 1: DB query, result cached
+// Request 2: cache hit, no DB query
+```
+
+Use when sequential user actions hit multiple endpoints needing the same data within seconds.
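+
+If the cached rows can be mutated by other endpoints, drop or refresh the entry on write so stale data is not served; `lru-cache` exposes `delete()` and `set()` for this. A sketch reusing the cache above:
+
+```typescript
+export async function updateUser(id: string, data: Record<string, unknown>) {
+  const user = await db.user.update({ where: { id }, data })
+  cache.delete(id) // or cache.set(id, user) to refresh the entry
+  return user
+}
+```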
+
+**With Vercel's [Fluid Compute](https://vercel.com/docs/fluid-compute):** LRU caching is especially effective because multiple concurrent requests can share the same function instance and cache. This means the cache persists across requests without needing external storage like Redis.
+
+**In traditional serverless:** Each invocation runs in isolation, so consider Redis for cross-process caching.
+
+Reference: [https://github.com/isaacs/node-lru-cache](https://github.com/isaacs/node-lru-cache)
diff --git a/.claude/skills/vercel-react-best-practices/rules/server-cache-react.md b/.claude/skills/vercel-react-best-practices/rules/server-cache-react.md
new file mode 100644
index 0000000000..87c9ca3316
--- /dev/null
+++ b/.claude/skills/vercel-react-best-practices/rules/server-cache-react.md
@@ -0,0 +1,76 @@
+---
+title: Per-Request Deduplication with React.cache()
+impact: MEDIUM
+impactDescription: deduplicates within request
+tags: server, cache, react-cache, deduplication
+---
+
+## Per-Request Deduplication with React.cache()
+
+Use `React.cache()` for server-side request deduplication. Authentication and database queries benefit most.
+
+**Usage:**
+
+```typescript
+import { cache } from 'react'
+
+export const getCurrentUser = cache(async () => {
+ const session = await auth()
+ if (!session?.user?.id) return null
+ return await db.user.findUnique({
+ where: { id: session.user.id }
+ })
+})
+```
+
+Within a single request, multiple calls to `getCurrentUser()` execute the query only once.
+
+**Avoid inline objects as arguments:**
+
+`React.cache()` uses shallow equality (`Object.is`) to determine cache hits. Inline objects create new references each call, preventing cache hits.
+
+**Incorrect (always cache miss):**
+
+```typescript
+const getUser = cache(async (params: { uid: number }) => {
+ return await db.user.findUnique({ where: { id: params.uid } })
+})
+
+// Each call creates new object, never hits cache
+getUser({ uid: 1 })
+getUser({ uid: 1 }) // Cache miss, runs query again
+```
+
+**Correct (cache hit):**
+
+```typescript
+const getUser = cache(async (uid: number) => {
+ return await db.user.findUnique({ where: { id: uid } })
+})
+
+// Primitive args use value equality
+getUser(1)
+getUser(1) // Cache hit, returns cached result
+```
+
+If you must pass objects, pass the same reference:
+
+```typescript
+const params = { uid: 1 }
+getUser(params) // Query runs
+getUser(params) // Cache hit (same reference)
+```
+
+**Next.js-Specific Note:**
+
+In Next.js, the `fetch` API is automatically extended with request memoization. Requests with the same URL and options are automatically deduplicated within a single request, so you don't need `React.cache()` for `fetch` calls. However, `React.cache()` is still essential for other async tasks:
+
+- Database queries (Prisma, Drizzle, etc.)
+- Heavy computations
+- Authentication checks
+- File system operations
+- Any non-fetch async work
+
+Use `React.cache()` to deduplicate these operations across your component tree.
+
+Reference: [React.cache documentation](https://react.dev/reference/react/cache)
diff --git a/.claude/skills/vercel-react-best-practices/rules/server-parallel-fetching.md b/.claude/skills/vercel-react-best-practices/rules/server-parallel-fetching.md
new file mode 100644
index 0000000000..1affc835a6
--- /dev/null
+++ b/.claude/skills/vercel-react-best-practices/rules/server-parallel-fetching.md
@@ -0,0 +1,83 @@
+---
+title: Parallel Data Fetching with Component Composition
+impact: CRITICAL
+impactDescription: eliminates server-side waterfalls
+tags: server, rsc, parallel-fetching, composition
+---
+
+## Parallel Data Fetching with Component Composition
+
+React Server Components execute sequentially within a tree. Restructure with composition to parallelize data fetching.
+
+**Incorrect (Sidebar waits for Page's fetch to complete):**
+
+```tsx
+export default async function Page() {
+  const header = await fetchHeader()
+  return (
+    <>
+      <Header data={header} />
+      <Sidebar /> {/* Sidebar's own fetch starts only after fetchHeader() resolves */}
+    </>
+  )
+}
+```
+
+**Correct (Header and Sidebar fetch in parallel):**
+
+```tsx
+export default function Page() {
+  return (
+    <>
+      <Header />  {/* async component, fetches its own data */}
+      <Sidebar /> {/* fetches in parallel with Header */}
+    </>
+  )
+}
+```
diff --git a/.claude/skills/vercel-react-best-practices/rules/server-serialization.md b/.claude/skills/vercel-react-best-practices/rules/server-serialization.md
new file mode 100644
index 0000000000..39c5c4164c
--- /dev/null
+++ b/.claude/skills/vercel-react-best-practices/rules/server-serialization.md
@@ -0,0 +1,38 @@
+---
+title: Minimize Serialization at RSC Boundaries
+impact: HIGH
+impactDescription: reduces data transfer size
+tags: server, rsc, serialization, props
+---
+
+## Minimize Serialization at RSC Boundaries
+
+The React Server/Client boundary serializes all object properties into strings and embeds them in the HTML response and subsequent RSC requests. This serialized data directly impacts page weight and load time, so **size matters a lot**. Only pass fields that the client actually uses.
+
+**Incorrect (serializes all 50 fields):**
+
+```tsx
+async function Page() {
+  const user = await fetchUser() // 50 fields
+  return <Profile user={user} />
+}
+
+'use client'
+function Profile({ user }: { user: User }) {
+  return <div>{user.name}</div> // uses 1 field
+}
+```
+
+**Correct (serializes only 1 field):**
+
+```tsx
+async function Page() {
+  const user = await fetchUser()
+  return <Profile name={user.name} />
+}
+
+'use client'
+function Profile({ name }: { name: string }) {
+  return <div>{name}</div>
+}
+```
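+
+When the client needs a handful of fields rather than one, the same principle applies: pick them explicitly before crossing the boundary. A sketch with illustrative field and prop names (the extra `avatarUrl` prop is assumed, not defined above):
+
+```tsx
+async function Page() {
+  const user = await fetchUser()
+  // Serialize only what the client actually uses
+  return <Profile name={user.name} avatarUrl={user.avatarUrl} />
+}
+```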
diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml
index 152a9caee8..190e00d9fe 100644
--- a/.github/workflows/api-tests.yml
+++ b/.github/workflows/api-tests.yml
@@ -39,12 +39,6 @@ jobs:
- name: Install dependencies
run: uv sync --project api --dev
- - name: Run pyrefly check
- run: |
- cd api
- uv add --dev pyrefly
- uv run pyrefly check || true
-
- name: Run dify config tests
run: uv run --project api dev/pytest/pytest_config_tests.py
diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml
index 5413f83c27..ff006324bb 100644
--- a/.github/workflows/autofix.yml
+++ b/.github/workflows/autofix.yml
@@ -16,14 +16,14 @@ jobs:
- name: Check Docker Compose inputs
id: docker-compose-changes
- uses: tj-actions/changed-files@v46
+ uses: tj-actions/changed-files@v47
with:
files: |
docker/generate_docker_compose
docker/.env.example
docker/docker-compose-template.yaml
docker/docker-compose.yaml
- - uses: actions/setup-python@v5
+ - uses: actions/setup-python@v6
with:
python-version: "3.11"
@@ -82,6 +82,6 @@ jobs:
# mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
- name: mdformat
run: |
- uvx --python 3.13 mdformat . --exclude ".claude/skills/**/SKILL.md"
+ uvx --python 3.13 mdformat . --exclude ".claude/skills/**"
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27
diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml
index bbf89236de..704d896192 100644
--- a/.github/workflows/build-push.yml
+++ b/.github/workflows/build-push.yml
@@ -112,7 +112,7 @@ jobs:
context: "web"
steps:
- name: Download digests
- uses: actions/download-artifact@v4
+ uses: actions/download-artifact@v7
with:
path: /tmp/digests
pattern: digests-${{ matrix.context }}-*
diff --git a/.github/workflows/deploy-trigger-dev.yml b/.github/workflows/deploy-agent-dev.yml
similarity index 69%
rename from .github/workflows/deploy-trigger-dev.yml
rename to .github/workflows/deploy-agent-dev.yml
index 2d9a904fc5..dd759f7ba5 100644
--- a/.github/workflows/deploy-trigger-dev.yml
+++ b/.github/workflows/deploy-agent-dev.yml
@@ -1,4 +1,4 @@
-name: Deploy Trigger Dev
+name: Deploy Agent Dev
permissions:
contents: read
@@ -7,7 +7,7 @@ on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- - "deploy/trigger-dev"
+ - "deploy/agent-dev"
types:
- completed
@@ -16,12 +16,12 @@ jobs:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
- github.event.workflow_run.head_branch == 'deploy/trigger-dev'
+ github.event.workflow_run.head_branch == 'deploy/agent-dev'
steps:
- name: Deploy to server
- uses: appleboy/ssh-action@v0.1.8
+ uses: appleboy/ssh-action@v1
with:
- host: ${{ secrets.TRIGGER_SSH_HOST }}
+ host: ${{ secrets.AGENT_DEV_SSH_HOST }}
username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }}
script: |
diff --git a/.github/workflows/deploy-dev.yml b/.github/workflows/deploy-dev.yml
index cd1c86e668..38fa0b9a7f 100644
--- a/.github/workflows/deploy-dev.yml
+++ b/.github/workflows/deploy-dev.yml
@@ -16,7 +16,7 @@ jobs:
github.event.workflow_run.head_branch == 'deploy/dev'
steps:
- name: Deploy to server
- uses: appleboy/ssh-action@v0.1.8
+ uses: appleboy/ssh-action@v1
with:
host: ${{ secrets.SSH_HOST }}
username: ${{ secrets.SSH_USER }}
diff --git a/.github/workflows/deploy-hitl.yml b/.github/workflows/deploy-hitl.yml
new file mode 100644
index 0000000000..7d5f0a22e7
--- /dev/null
+++ b/.github/workflows/deploy-hitl.yml
@@ -0,0 +1,29 @@
+name: Deploy HITL
+
+on:
+ workflow_run:
+ workflows: ["Build and Push API & Web"]
+ branches:
+ - "feat/hitl-frontend"
+ - "feat/hitl-backend"
+ types:
+ - completed
+
+jobs:
+ deploy:
+ runs-on: ubuntu-latest
+ if: |
+ github.event.workflow_run.conclusion == 'success' &&
+ (
+ github.event.workflow_run.head_branch == 'feat/hitl-frontend' ||
+ github.event.workflow_run.head_branch == 'feat/hitl-backend'
+ )
+ steps:
+ - name: Deploy to server
+ uses: appleboy/ssh-action@v1
+ with:
+ host: ${{ secrets.HITL_SSH_HOST }}
+ username: ${{ secrets.SSH_USER }}
+ key: ${{ secrets.SSH_PRIVATE_KEY }}
+ script: |
+ ${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }}
diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml
index 1870b1f670..b6df1d7e93 100644
--- a/.github/workflows/stale.yml
+++ b/.github/workflows/stale.yml
@@ -18,7 +18,7 @@ jobs:
pull-requests: write
steps:
- - uses: actions/stale@v5
+ - uses: actions/stale@v10
with:
days-before-issue-stale: 15
days-before-issue-close: 3
diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml
index 462ece303e..debf4ba648 100644
--- a/.github/workflows/style.yml
+++ b/.github/workflows/style.yml
@@ -65,6 +65,9 @@ jobs:
defaults:
run:
working-directory: ./web
+ permissions:
+ checks: write
+ pull-requests: read
steps:
- name: Checkout code
@@ -90,7 +93,7 @@ jobs:
uses: actions/setup-node@v6
if: steps.changed-files.outputs.any_changed == 'true'
with:
- node-version: 22
+ node-version: 24
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
@@ -103,7 +106,16 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: |
- pnpm run lint
+ pnpm run lint:ci
+ # pnpm run lint:report
+ # continue-on-error: true
+
+ # - name: Annotate Code
+ # if: steps.changed-files.outputs.any_changed == 'true' && github.event_name == 'pull_request'
+ # uses: DerLev/eslint-annotations@51347b3a0abfb503fc8734d5ae31c4b151297fae
+ # with:
+ # eslint-report: web/eslint_report.json
+ # github-token: ${{ secrets.GITHUB_TOKEN }}
- name: Web type check
if: steps.changed-files.outputs.any_changed == 'true'
@@ -115,11 +127,6 @@ jobs:
working-directory: ./web
run: pnpm run knip
- - name: Web build check
- if: steps.changed-files.outputs.any_changed == 'true'
- working-directory: ./web
- run: pnpm run build
-
superlinter:
name: SuperLinter
runs-on: ubuntu-latest
diff --git a/.github/workflows/tool-test-sdks.yaml b/.github/workflows/tool-test-sdks.yaml
index 0259ef2232..ec392cb3b2 100644
--- a/.github/workflows/tool-test-sdks.yaml
+++ b/.github/workflows/tool-test-sdks.yaml
@@ -16,10 +16,6 @@ jobs:
name: unit test for Node.js SDK
runs-on: ubuntu-latest
- strategy:
- matrix:
- node-version: [16, 18, 20, 22]
-
defaults:
run:
working-directory: sdks/nodejs-client
@@ -29,10 +25,10 @@ jobs:
with:
persist-credentials: false
- - name: Use Node.js ${{ matrix.node-version }}
+ - name: Use Node.js
uses: actions/setup-node@v6
with:
- node-version: ${{ matrix.node-version }}
+ node-version: 24
cache: ''
cache-dependency-path: 'pnpm-lock.yaml'
diff --git a/.github/workflows/translate-i18n-claude.yml b/.github/workflows/translate-i18n-claude.yml
index 003e7ffc6e..8344af9890 100644
--- a/.github/workflows/translate-i18n-claude.yml
+++ b/.github/workflows/translate-i18n-claude.yml
@@ -57,7 +57,7 @@ jobs:
- name: Set up Node.js
uses: actions/setup-node@v6
with:
- node-version: 'lts/*'
+ node-version: 24
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
diff --git a/.github/workflows/trigger-i18n-sync.yml b/.github/workflows/trigger-i18n-sync.yml
index de093c9235..66a29453b4 100644
--- a/.github/workflows/trigger-i18n-sync.yml
+++ b/.github/workflows/trigger-i18n-sync.yml
@@ -21,7 +21,7 @@ jobs:
steps:
- name: Checkout repository
- uses: actions/checkout@v4
+ uses: actions/checkout@v6
with:
fetch-depth: 0
diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml
index 0fd1d5d22b..191ce56aaa 100644
--- a/.github/workflows/web-tests.yml
+++ b/.github/workflows/web-tests.yml
@@ -31,7 +31,7 @@ jobs:
- name: Setup Node.js
uses: actions/setup-node@v6
with:
- node-version: 22
+ node-version: 24
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
@@ -366,3 +366,48 @@ jobs:
path: web/coverage
retention-days: 30
if-no-files-found: error
+
+ web-build:
+ name: Web Build
+ runs-on: ubuntu-latest
+ defaults:
+ run:
+ working-directory: ./web
+
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v6
+ with:
+ persist-credentials: false
+
+ - name: Check changed files
+ id: changed-files
+ uses: tj-actions/changed-files@v47
+ with:
+ files: |
+ web/**
+ .github/workflows/web-tests.yml
+
+ - name: Install pnpm
+ uses: pnpm/action-setup@v4
+ with:
+ package_json_file: web/package.json
+ run_install: false
+
+ - name: Setup NodeJS
+ uses: actions/setup-node@v6
+ if: steps.changed-files.outputs.any_changed == 'true'
+ with:
+ node-version: 24
+ cache: pnpm
+ cache-dependency-path: ./web/pnpm-lock.yaml
+
+ - name: Web dependencies
+ if: steps.changed-files.outputs.any_changed == 'true'
+ working-directory: ./web
+ run: pnpm install --frozen-lockfile
+
+ - name: Web build check
+ if: steps.changed-files.outputs.any_changed == 'true'
+ working-directory: ./web
+ run: pnpm run build
diff --git a/.nvmrc b/.nvmrc
deleted file mode 100644
index 7af24b7ddb..0000000000
--- a/.nvmrc
+++ /dev/null
@@ -1 +0,0 @@
-22.11.0
diff --git a/AGENTS.md b/AGENTS.md
index 782861ad36..deab7c8629 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -12,12 +12,8 @@ The codebase is split into:
## Backend Workflow
+- Read `api/AGENTS.md` for details
- Run backend CLI commands through `uv run --project api `.
-
-- Before submission, all backend modifications must pass local checks: `make lint`, `make type-check`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`.
-
-- Use Makefile targets for linting and formatting; `make lint` and `make type-check` cover the required checks.
-
- Integration tests are CI-only and are not expected to run in the local environment.
## Frontend Workflow
diff --git a/Makefile b/Makefile
index 60c32948b9..e92a7b1314 100644
--- a/Makefile
+++ b/Makefile
@@ -61,7 +61,8 @@ check:
lint:
@echo "🔧 Running ruff format, check with fixes, import linter, and dotenv-linter..."
- @uv run --project api --dev sh -c 'ruff format ./api && ruff check --fix ./api'
+ @uv run --project api --dev ruff format ./api
+ @uv run --project api --dev ruff check --fix ./api
@uv run --directory api --dev lint-imports
@uv run --project api --dev dotenv-linter ./api/.env.example ./web/.env.example
@echo "✅ Linting complete"
@@ -73,7 +74,12 @@ type-check:
test:
@echo "🧪 Running backend unit tests..."
- @uv run --project api --dev dev/pytest/pytest_unit_tests.sh
+ @if [ -n "$(TARGET_TESTS)" ]; then \
+ echo "Target: $(TARGET_TESTS)"; \
+ uv run --project api --dev pytest $(TARGET_TESTS); \
+ else \
+ uv run --project api --dev dev/pytest/pytest_unit_tests.sh; \
+ fi
@echo "✅ Tests complete"
# Build Docker images
@@ -125,7 +131,7 @@ help:
@echo " make check - Check code with ruff"
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
@echo " make type-check - Run type checking with basedpyright"
- @echo " make test - Run backend unit tests"
+ @echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/)"
@echo ""
@echo "Docker Build Targets:"
@echo " make build-web - Build web Docker image"
diff --git a/agent-notes/.gitkeep b/agent-notes/.gitkeep
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/.env.example b/api/.env.example
index 44d770ed70..15981c14b8 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -417,6 +417,8 @@ SMTP_USERNAME=123
SMTP_PASSWORD=abc
SMTP_USE_TLS=true
SMTP_OPPORTUNISTIC_TLS=false
+# Optional: override the local hostname used for SMTP HELO/EHLO
+SMTP_LOCAL_HOSTNAME=
# Sendgid configuration
SENDGRID_API_KEY=
# Sentry configuration
@@ -589,6 +591,7 @@ ENABLE_CLEAN_UNUSED_DATASETS_TASK=false
ENABLE_CREATE_TIDB_SERVERLESS_TASK=false
ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK=false
ENABLE_CLEAN_MESSAGES=false
+ENABLE_WORKFLOW_RUN_CLEANUP_TASK=false
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
ENABLE_DATASETS_QUEUE_MONITOR=false
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
@@ -712,3 +715,4 @@ ANNOTATION_IMPORT_MAX_CONCURRENT=5
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
+
diff --git a/api/AGENTS.md b/api/AGENTS.md
index 17398ec4b8..6ce419828b 100644
--- a/api/AGENTS.md
+++ b/api/AGENTS.md
@@ -1,62 +1,236 @@
-# Agent Skill Index
+# API Agent Guide
+
+## Agent Notes (must-check)
+
+Before you start work on any backend file under `api/`, you MUST check whether a related note exists under:
+
+- `agent-notes/<path>/<file>.md`
+
+Rules:
+
+- **Path mapping**: for a target file `<path>/<file>.py`, the note must be `agent-notes/<path>/<file>.py.md` (same folder structure, same filename, plus `.md`).
+- **Before working**:
+ - If the note exists, read it first and follow any constraints/decisions recorded there.
+ - If the note conflicts with the current code, or references an "origin" file/path that has been deleted, renamed, or migrated, treat the **code as the single source of truth** and update the note to match reality.
+ - If the note does not exist, create it with a short architecture/intent summary and any relevant invariants/edge cases.
+- **During working**:
+ - Keep the note in sync as you discover constraints, make decisions, or change approach.
+ - If you move/rename a file, migrate its note to the new mapped path (and fix any outdated references inside the note).
+ - Record non-obvious edge cases, trade-offs, and the test/verification plan as you go (not just at the end).
+ - Keep notes **coherent**: integrate new findings into the relevant sections and rewrite for clarity; avoid append-only “recent fix” / changelog-style additions unless the note is explicitly intended to be a changelog.
+- **When finishing work**:
+ - Update the related note(s) to reflect what changed, why, and any new edge cases/tests.
+ - If a file is deleted, remove or clearly deprecate the corresponding note so it cannot be mistaken as current guidance.
+ - Keep notes concise and accurate; they are meant to prevent repeated rediscovery.
+
+## Skill Index
Start with the section that best matches your need. Each entry lists the problems it solves plus key files/concepts so you know what to expect before opening it.
-______________________________________________________________________
+### Platform Foundations
-## Platform Foundations
-
-- **[Infrastructure Overview](agent_skills/infra.md)**\
- When to read this:
+#### [Infrastructure Overview](agent_skills/infra.md)
+- **When to read this**
- You need to understand where a feature belongs in the architecture.
- You’re wiring storage, Redis, vector stores, or OTEL.
- - You’re about to add CLI commands or async jobs.\
- What it covers: configuration stack (`configs/app_config.py`, remote settings), storage entry points (`extensions/ext_storage.py`, `core/file/file_manager.py`), Redis conventions (`extensions/ext_redis.py`), plugin runtime topology, vector-store factory (`core/rag/datasource/vdb/*`), observability hooks, SSRF proxy usage, and core CLI commands.
+ - You’re about to add CLI commands or async jobs.
+- **What it covers**
+ - Configuration stack (`configs/app_config.py`, remote settings)
+ - Storage entry points (`extensions/ext_storage.py`, `core/file/file_manager.py`)
+ - Redis conventions (`extensions/ext_redis.py`)
+ - Plugin runtime topology
+ - Vector-store factory (`core/rag/datasource/vdb/*`)
+ - Observability hooks
+ - SSRF proxy usage
+ - Core CLI commands
-- **[Coding Style](agent_skills/coding_style.md)**\
- When to read this:
+### Plugin & Extension Development
- - You’re writing or reviewing backend code and need the authoritative checklist.
- - You’re unsure about Pydantic validators, SQLAlchemy session usage, or logging patterns.
- - You want the exact lint/type/test commands used in PRs.\
- Includes: Ruff & BasedPyright commands, no-annotation policy, session examples (`with Session(db.engine, ...)`), `@field_validator` usage, logging expectations, and the rule set for file size, helpers, and package management.
-
-______________________________________________________________________
-
-## Plugin & Extension Development
-
-- **[Plugin Systems](agent_skills/plugin.md)**\
- When to read this:
+#### [Plugin Systems](agent_skills/plugin.md)
+- **When to read this**
- You’re building or debugging a marketplace plugin.
- - You need to know how manifests, providers, daemons, and migrations fit together.\
- What it covers: plugin manifests (`core/plugin/entities/plugin.py`), installation/upgrade flows (`services/plugin/plugin_service.py`, CLI commands), runtime adapters (`core/plugin/impl/*` for tool/model/datasource/trigger/endpoint/agent), daemon coordination (`core/plugin/entities/plugin_daemon.py`), and how provider registries surface capabilities to the rest of the platform.
+ - You need to know how manifests, providers, daemons, and migrations fit together.
+- **What it covers**
+ - Plugin manifests (`core/plugin/entities/plugin.py`)
+ - Installation/upgrade flows (`services/plugin/plugin_service.py`, CLI commands)
+ - Runtime adapters (`core/plugin/impl/*` for tool/model/datasource/trigger/endpoint/agent)
+ - Daemon coordination (`core/plugin/entities/plugin_daemon.py`)
+ - How provider registries surface capabilities to the rest of the platform
-- **[Plugin OAuth](agent_skills/plugin_oauth.md)**\
- When to read this:
+#### [Plugin OAuth](agent_skills/plugin_oauth.md)
+- **When to read this**
- You must integrate OAuth for a plugin or datasource.
- - You’re handling credential encryption or refresh flows.\
- Topics: credential storage, encryption helpers (`core/helper/provider_encryption.py`), OAuth client bootstrap (`services/plugin/oauth_service.py`, `services/plugin/plugin_parameter_service.py`), and how console/API layers expose the flows.
+ - You’re handling credential encryption or refresh flows.
+- **Topics**
+ - Credential storage
+ - Encryption helpers (`core/helper/provider_encryption.py`)
+ - OAuth client bootstrap (`services/plugin/oauth_service.py`, `services/plugin/plugin_parameter_service.py`)
+ - How console/API layers expose the flows
-______________________________________________________________________
+### Workflow Entry & Execution
-## Workflow Entry & Execution
+#### [Trigger Concepts](agent_skills/trigger.md)
-- **[Trigger Concepts](agent_skills/trigger.md)**\
- When to read this:
+- **When to read this**
- You’re debugging why a workflow didn’t start.
- You’re adding a new trigger type or hook.
- - You need to trace async execution, draft debugging, or webhook/schedule pipelines.\
- Details: Start-node taxonomy, webhook & schedule internals (`core/workflow/nodes/trigger_*`, `services/trigger/*`), async orchestration (`services/async_workflow_service.py`, Celery queues), debug event bus, and storage/logging interactions.
+ - You need to trace async execution, draft debugging, or webhook/schedule pipelines.
+- **Details**
+ - Start-node taxonomy
+ - Webhook & schedule internals (`core/workflow/nodes/trigger_*`, `services/trigger/*`)
+ - Async orchestration (`services/async_workflow_service.py`, Celery queues)
+ - Debug event bus
+ - Storage/logging interactions
-______________________________________________________________________
+## General Reminders
-## Additional Notes for Agents
-
-- All skill docs assume you follow the coding style guide—run Ruff/BasedPyright/tests listed there before submitting changes.
+- All skill docs assume you follow the coding style rules below—run the lint/type/test commands before submitting changes.
- When you cannot find an answer in these briefs, search the codebase using the paths referenced (e.g., `core/plugin/impl/tool.py`, `services/dataset_service.py`).
- If you run into cross-cutting concerns (tenancy, configuration, storage), check the infrastructure guide first; it links to most supporting modules.
- Keep multi-tenancy and configuration central: everything flows through `configs.dify_config` and `tenant_id`.
- When touching plugins or triggers, consult both the system overview and the specialised doc to ensure you adjust lifecycle, storage, and observability consistently.
+
+## Coding Style
+
+This is the default standard for backend code in this repo. Follow it for new code and use it as the checklist when reviewing changes.
+
+### Linting & Formatting
+
+- Use Ruff for formatting and linting (follow `.ruff.toml`).
+- Keep each line under 120 characters (including spaces).
+
+### Naming Conventions
+
+- Use `snake_case` for variables and functions.
+- Use `PascalCase` for classes.
+- Use `UPPER_CASE` for constants.
+
+### Typing & Class Layout
+
+- Code should usually include type annotations that match the repo’s current Python version (avoid untyped public APIs and “mystery” values).
+- Prefer modern typing forms (e.g. `list[str]`, `dict[str, int]`) and avoid `Any` unless there’s a strong reason.
+- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance:
+
+```python
+from datetime import datetime
+
+
+class Example:
+ user_id: str
+ created_at: datetime
+
+ def __init__(self, user_id: str, created_at: datetime) -> None:
+ self.user_id = user_id
+ self.created_at = created_at
+```
+
+### General Rules
+
+- Use Pydantic v2 conventions.
+- Use `uv` for Python package management in this repo (usually with `--project api`).
+- Prefer simple functions over small “utility classes” for lightweight helpers.
+- Avoid implementing dunder methods unless it’s clearly needed and matches existing patterns.
+- Never start long-running services as part of agent work (`uv run app.py`, `flask run`, etc.); running tests is allowed.
+- Keep files below ~800 lines; split when necessary.
+- Keep code readable and explicit—avoid clever hacks.
+
+### Architecture & Boundaries
+
+- Mirror the layered architecture: controller → service → core/domain.
+- Reuse existing helpers in `core/`, `services/`, and `libs/` before creating new abstractions.
+- Optimise for observability: deterministic control flow, clear logging, actionable errors.
+
+### Logging & Errors
+
+- Never use `print`; use a module-level logger:
+ - `logger = logging.getLogger(__name__)`
+- Include tenant/app/workflow identifiers in log context when relevant.
+- Raise domain-specific exceptions (`services/errors`, `core/errors`) and translate them into HTTP responses in controllers.
+- Log retryable events at `warning`, terminal failures at `error`.
+
+### SQLAlchemy Patterns
+
+- Models inherit from `models.base.TypeBase`; do not create ad-hoc metadata or engines.
+- Open sessions with context managers:
+
+```python
+from sqlalchemy.orm import Session
+
+with Session(db.engine, expire_on_commit=False) as session:
+ stmt = select(Workflow).where(
+ Workflow.id == workflow_id,
+ Workflow.tenant_id == tenant_id,
+ )
+ workflow = session.execute(stmt).scalar_one_or_none()
+```
+
+- Prefer SQLAlchemy expressions; avoid raw SQL unless necessary.
+- Always scope queries by `tenant_id` and protect write paths with safeguards (`FOR UPDATE`, row counts, etc.).
+- Introduce repository abstractions only for very large tables (e.g., workflow executions) or when alternative storage strategies are required.
+
+### Storage & External I/O
+
+- Access storage via `extensions.ext_storage.storage`.
+- Use `core.helper.ssrf_proxy` for outbound HTTP fetches.
+- Background tasks that touch storage must be idempotent, and should log relevant object identifiers.
+
+### Pydantic Usage
+
+- Define DTOs with Pydantic v2 models and forbid extras by default.
+- Use `@field_validator` / `@model_validator` for domain rules.
+
+Example:
+
+```python
+from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator
+
+
+class TriggerConfig(BaseModel):
+ endpoint: HttpUrl
+ secret: str
+
+ model_config = ConfigDict(extra="forbid")
+
+ @field_validator("secret")
+ def ensure_secret_prefix(cls, value: str) -> str:
+ if not value.startswith("dify_"):
+ raise ValueError("secret must start with dify_")
+ return value
+```
+
+### Generics & Protocols
+
+- Use `typing.Protocol` to define behavioural contracts (e.g., cache interfaces).
+- Apply generics (`TypeVar`, `Generic`) for reusable utilities like caches or providers.
+- Validate dynamic inputs at runtime when generics cannot enforce safety alone.
+
+### Tooling & Checks
+
+Quick checks while iterating:
+
+- Format: `make format`
+- Lint (includes auto-fix): `make lint`
+- Type check: `make type-check`
+- Targeted tests: `make test TARGET_TESTS=./api/tests/`
+
+Before opening a PR / submitting:
+
+- `make lint`
+- `make type-check`
+- `make test`
+
+### Controllers & Services
+
+- Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic.
+- Services: coordinate repositories, providers, background tasks; keep side effects explicit.
+- Document non-obvious behaviour with concise comments.
+
+### Miscellaneous
+
+- Use `configs.dify_config` for configuration—never read environment variables directly.
+- Maintain tenant awareness end-to-end; `tenant_id` must flow through every layer touching shared resources.
+- Queue async work through `services/async_workflow_service`; implement tasks under `tasks/` with explicit queue selection.
+- Keep experimental scripts under `dev/`; do not ship them in production builds.
diff --git a/api/agent_skills/coding_style.md b/api/agent_skills/coding_style.md
deleted file mode 100644
index a2b66f0bd5..0000000000
--- a/api/agent_skills/coding_style.md
+++ /dev/null
@@ -1,115 +0,0 @@
-## Linter
-
-- Always follow `.ruff.toml`.
-- Run `uv run ruff check --fix --unsafe-fixes`.
-- Keep each line under 100 characters (including spaces).
-
-## Code Style
-
-- `snake_case` for variables and functions.
-- `PascalCase` for classes.
-- `UPPER_CASE` for constants.
-
-## Rules
-
-- Use Pydantic v2 standard.
-- Use `uv` for package management.
-- Do not override dunder methods like `__init__`, `__iadd__`, etc.
-- Never launch services (`uv run app.py`, `flask run`, etc.); running tests under `tests/` is allowed.
-- Prefer simple functions over classes for lightweight helpers.
-- Keep files below 800 lines; split when necessary.
-- Keep code readable—no clever hacks.
-- Never use `print`; log with `logger = logging.getLogger(__name__)`.
-
-## Guiding Principles
-
-- Mirror the project’s layered architecture: controller → service → core/domain.
-- Reuse existing helpers in `core/`, `services/`, and `libs/` before creating new abstractions.
-- Optimise for observability: deterministic control flow, clear logging, actionable errors.
-
-## SQLAlchemy Patterns
-
-- Models inherit from `models.base.Base`; never create ad-hoc metadata or engines.
-
-- Open sessions with context managers:
-
- ```python
- from sqlalchemy.orm import Session
-
- with Session(db.engine, expire_on_commit=False) as session:
- stmt = select(Workflow).where(
- Workflow.id == workflow_id,
- Workflow.tenant_id == tenant_id,
- )
- workflow = session.execute(stmt).scalar_one_or_none()
- ```
-
-- Use SQLAlchemy expressions; avoid raw SQL unless necessary.
-
-- Introduce repository abstractions only for very large tables (e.g., workflow executions) to support alternative storage strategies.
-
-- Always scope queries by `tenant_id` and protect write paths with safeguards (`FOR UPDATE`, row counts, etc.).
-
-## Storage & External IO
-
-- Access storage via `extensions.ext_storage.storage`.
-- Use `core.helper.ssrf_proxy` for outbound HTTP fetches.
-- Background tasks that touch storage must be idempotent and log the relevant object identifiers.
-
-## Pydantic Usage
-
-- Define DTOs with Pydantic v2 models and forbid extras by default.
-
-- Use `@field_validator` / `@model_validator` for domain rules.
-
-- Example:
-
- ```python
- from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator
-
- class TriggerConfig(BaseModel):
- endpoint: HttpUrl
- secret: str
-
- model_config = ConfigDict(extra="forbid")
-
- @field_validator("secret")
- def ensure_secret_prefix(cls, value: str) -> str:
- if not value.startswith("dify_"):
- raise ValueError("secret must start with dify_")
- return value
- ```
-
-## Generics & Protocols
-
-- Use `typing.Protocol` to define behavioural contracts (e.g., cache interfaces).
-- Apply generics (`TypeVar`, `Generic`) for reusable utilities like caches or providers.
-- Validate dynamic inputs at runtime when generics cannot enforce safety alone.
-
-## Error Handling & Logging
-
-- Raise domain-specific exceptions (`services/errors`, `core/errors`) and translate to HTTP responses in controllers.
-- Declare `logger = logging.getLogger(__name__)` at module top.
-- Include tenant/app/workflow identifiers in log context.
-- Log retryable events at `warning`, terminal failures at `error`.
-
-## Tooling & Checks
-
-- Format/lint: `uv run --project api --dev ruff format ./api` and `uv run --project api --dev ruff check --fix --unsafe-fixes ./api`.
-- Type checks: `uv run --directory api --dev basedpyright`.
-- Tests: `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`.
-- Run all of the above before submitting your work.
-
-## Controllers & Services
-
-- Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic.
-- Services: coordinate repositories, providers, background tasks; keep side effects explicit.
-- Avoid repositories unless necessary; direct SQLAlchemy usage is preferred for typical tables.
-- Document non-obvious behaviour with concise comments.
-
-## Miscellaneous
-
-- Use `configs.dify_config` for configuration—never read environment variables directly.
-- Maintain tenant awareness end-to-end; `tenant_id` must flow through every layer touching shared resources.
-- Queue async work through `services/async_workflow_service`; implement tasks under `tasks/` with explicit queue selection.
-- Keep experimental scripts under `dev/`; do not ship them in production builds.
diff --git a/api/app_factory.py b/api/app_factory.py
index f827842d68..1fb01d2e91 100644
--- a/api/app_factory.py
+++ b/api/app_factory.py
@@ -71,6 +71,8 @@ def create_app() -> DifyApp:
def initialize_extensions(app: DifyApp):
+ # Initialize Flask context capture for workflow execution
+ from context.flask_app_context import init_flask_context
from extensions import (
ext_app_metrics,
ext_blueprints,
@@ -100,6 +102,8 @@ def initialize_extensions(app: DifyApp):
ext_warnings,
)
+ init_flask_context()
+
extensions = [
ext_timezone,
ext_logging,
diff --git a/api/commands.py b/api/commands.py
index 7ebf5b4874..aa7b731a27 100644
--- a/api/commands.py
+++ b/api/commands.py
@@ -1,7 +1,9 @@
import base64
+import datetime
import json
import logging
import secrets
+import time
from typing import Any
import click
@@ -34,7 +36,7 @@ from libs.rsa import generate_key_pair
from models import Tenant
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
-from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
+from models.model import App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
from models.provider import Provider, ProviderModel
from models.provider_ids import DatasourceProviderID, ToolProviderID
@@ -45,6 +47,9 @@ from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpi
from services.plugin.data_migration import PluginDataMigration
from services.plugin.plugin_migration import PluginMigration
from services.plugin.plugin_service import PluginService
+from services.retention.conversation.messages_clean_policy import create_message_clean_policy
+from services.retention.conversation.messages_clean_service import MessagesCleanService
+from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup
from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
logger = logging.getLogger(__name__)
@@ -62,8 +67,10 @@ def reset_password(email, new_password, password_confirm):
if str(new_password).strip() != str(password_confirm).strip():
click.echo(click.style("Passwords do not match.", fg="red"))
return
+ normalized_email = email.strip().lower()
+
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
- account = session.query(Account).where(Account.email == email).one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
@@ -84,7 +91,7 @@ def reset_password(email, new_password, password_confirm):
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
- AccountService.reset_login_error_rate_limit(email)
+ AccountService.reset_login_error_rate_limit(normalized_email)
click.echo(click.style("Password reset successfully.", fg="green"))
@@ -100,20 +107,22 @@ def reset_email(email, new_email, email_confirm):
if str(new_email).strip() != str(email_confirm).strip():
click.echo(click.style("New emails do not match.", fg="red"))
return
+ normalized_new_email = new_email.strip().lower()
+
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
- account = session.query(Account).where(Account.email == email).one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
try:
- email_validate(new_email)
+ email_validate(normalized_new_email)
except:
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
return
- account.email = new_email
+ account.email = normalized_new_email
click.echo(click.style("Email updated successfully.", fg="green"))
@@ -658,7 +667,7 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
return
# Create account
- email = email.strip()
+ email = email.strip().lower()
if "@" not in email:
click.echo(click.style("Invalid email address.", fg="red"))
@@ -852,6 +861,95 @@ def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[
click.echo(click.style("Clear free plan tenant expired logs completed.", fg="green"))
+@click.command("clean-workflow-runs", help="Clean expired workflow runs and related data for free tenants.")
+@click.option(
+ "--before-days",
+ "--days",
+ default=30,
+ show_default=True,
+ type=click.IntRange(min=0),
+ help="Delete workflow runs created before N days ago.",
+)
+@click.option("--batch-size", default=200, show_default=True, help="Batch size for selecting workflow runs.")
+@click.option(
+ "--from-days-ago",
+ default=None,
+ type=click.IntRange(min=0),
+ help="Lower bound in days ago (older). Must be paired with --to-days-ago.",
+)
+@click.option(
+ "--to-days-ago",
+ default=None,
+ type=click.IntRange(min=0),
+ help="Upper bound in days ago (newer). Must be paired with --from-days-ago.",
+)
+@click.option(
+ "--start-from",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ default=None,
+ help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.",
+)
+@click.option(
+ "--end-before",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ default=None,
+ help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.",
+)
+@click.option(
+ "--dry-run",
+ is_flag=True,
+ help="Preview cleanup results without deleting any workflow run data.",
+)
+def clean_workflow_runs(
+ before_days: int,
+ batch_size: int,
+ from_days_ago: int | None,
+ to_days_ago: int | None,
+ start_from: datetime.datetime | None,
+ end_before: datetime.datetime | None,
+ dry_run: bool,
+):
+ """
+ Clean workflow runs and related workflow data for free tenants.
+ """
+ if (start_from is None) ^ (end_before is None):
+ raise click.UsageError("--start-from and --end-before must be provided together.")
+
+ if (from_days_ago is None) ^ (to_days_ago is None):
+ raise click.UsageError("--from-days-ago and --to-days-ago must be provided together.")
+
+ if from_days_ago is not None and to_days_ago is not None:
+ if start_from or end_before:
+ raise click.UsageError("Choose either day offsets or explicit dates, not both.")
+ if from_days_ago <= to_days_ago:
+ raise click.UsageError("--from-days-ago must be greater than --to-days-ago.")
+ now = datetime.datetime.now()
+ start_from = now - datetime.timedelta(days=from_days_ago)
+ end_before = now - datetime.timedelta(days=to_days_ago)
+ before_days = 0
+
+ start_time = datetime.datetime.now(datetime.UTC)
+ click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white"))
+
+ WorkflowRunCleanup(
+ days=before_days,
+ batch_size=batch_size,
+ start_from=start_from,
+ end_before=end_before,
+ dry_run=dry_run,
+ ).run()
+
+ end_time = datetime.datetime.now(datetime.UTC)
+ elapsed = end_time - start_time
+ click.echo(
+ click.style(
+ f"Workflow run cleanup completed. start={start_time.isoformat()} "
+ f"end={end_time.isoformat()} duration={elapsed}",
+ fg="green",
+ )
+ )
+
+
@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.")
@click.command("clear-orphaned-file-records", help="Clear orphaned file records.")
def clear_orphaned_file_records(force: bool):
@@ -2111,3 +2209,79 @@ def migrate_oss(
except Exception as e:
db.session.rollback()
click.echo(click.style(f"Failed to update DB storage_type: {str(e)}", fg="red"))
+
+
+@click.command("clean-expired-messages", help="Clean expired messages.")
+@click.option(
+ "--start-from",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ required=True,
+ help="Lower bound (inclusive) for created_at.",
+)
+@click.option(
+ "--end-before",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ required=True,
+ help="Upper bound (exclusive) for created_at.",
+)
+@click.option("--batch-size", default=1000, show_default=True, help="Batch size for selecting messages.")
+@click.option(
+ "--graceful-period",
+ default=21,
+ show_default=True,
+ help="Graceful period in days after subscription expiration, will be ignored when billing is disabled.",
+)
+@click.option("--dry-run", is_flag=True, default=False, help="Show messages logs would be cleaned without deleting")
+def clean_expired_messages(
+ batch_size: int,
+ graceful_period: int,
+ start_from: datetime.datetime,
+ end_before: datetime.datetime,
+ dry_run: bool,
+):
+ """
+ Clean expired messages and related data for tenants based on clean policy.
+ """
+ click.echo(click.style("clean_messages: start clean messages.", fg="green"))
+
+ start_at = time.perf_counter()
+
+ try:
+ # Create policy based on billing configuration
+ # NOTE: graceful_period will be ignored when billing is disabled.
+ policy = create_message_clean_policy(graceful_period_days=graceful_period)
+
+ # Create and run the cleanup service
+ service = MessagesCleanService.from_time_range(
+ policy=policy,
+ start_from=start_from,
+ end_before=end_before,
+ batch_size=batch_size,
+ dry_run=dry_run,
+ )
+ stats = service.run()
+
+ end_at = time.perf_counter()
+ click.echo(
+ click.style(
+ f"clean_messages: completed successfully\n"
+ f" - Latency: {end_at - start_at:.2f}s\n"
+ f" - Batches processed: {stats['batches']}\n"
+ f" - Total messages scanned: {stats['total_messages']}\n"
+ f" - Messages filtered: {stats['filtered_messages']}\n"
+ f" - Messages deleted: {stats['total_deleted']}",
+ fg="green",
+ )
+ )
+ except Exception as e:
+ end_at = time.perf_counter()
+ logger.exception("clean_messages failed")
+ click.echo(
+ click.style(
+ f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}",
+ fg="red",
+ )
+ )
+ raise
+
+ click.echo(click.style("messages cleanup completed.", fg="green"))
diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py
index 1869e407b4..7f9f90a590 100644
--- a/api/configs/feature/__init__.py
+++ b/api/configs/feature/__init__.py
@@ -949,6 +949,12 @@ class MailConfig(BaseSettings):
default=False,
)
+ SMTP_LOCAL_HOSTNAME: str | None = Field(
+ description="Override the local hostname used in SMTP HELO/EHLO. "
+ "Useful behind NAT or when the default hostname causes rejections.",
+ default=None,
+ )
+
EMAIL_SEND_IP_LIMIT_PER_MINUTE: PositiveInt = Field(
description="Maximum number of emails allowed to be sent from the same IP address in a minute",
default=50,
@@ -1111,6 +1117,10 @@ class CeleryScheduleTasksConfig(BaseSettings):
description="Enable clean messages task",
default=False,
)
+ ENABLE_WORKFLOW_RUN_CLEANUP_TASK: bool = Field(
+ description="Enable scheduled workflow run cleanup task",
+ default=False,
+ )
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field(
description="Enable mail clean document notify task",
default=False,
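Both new settings are ordinary pydantic-settings fields, so they resolve from environment variables like the rest of the configuration. A minimal sketch of that behavior, using a stand-alone demo class rather than Dify's full config stack:

```python
# Sketch: how the new flags resolve from the environment with pydantic-settings.
# DemoScheduleConfig is a stand-in, not Dify's actual configuration class.
import os

from pydantic import Field
from pydantic_settings import BaseSettings


class DemoScheduleConfig(BaseSettings):
    ENABLE_WORKFLOW_RUN_CLEANUP_TASK: bool = Field(default=False)
    SMTP_LOCAL_HOSTNAME: str | None = Field(default=None)


os.environ["ENABLE_WORKFLOW_RUN_CLEANUP_TASK"] = "true"
os.environ["SMTP_LOCAL_HOSTNAME"] = "mail.internal.example"

config = DemoScheduleConfig()
print(config.ENABLE_WORKFLOW_RUN_CLEANUP_TASK)  # True
print(config.SMTP_LOCAL_HOSTNAME)               # mail.internal.example
```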
diff --git a/api/configs/middleware/storage/volcengine_tos_storage_config.py b/api/configs/middleware/storage/volcengine_tos_storage_config.py
index be01f2dc36..2a35300401 100644
--- a/api/configs/middleware/storage/volcengine_tos_storage_config.py
+++ b/api/configs/middleware/storage/volcengine_tos_storage_config.py
@@ -4,7 +4,7 @@ from pydantic_settings import BaseSettings
class VolcengineTOSStorageConfig(BaseSettings):
"""
- Configuration settings for Volcengine Tinder Object Storage (TOS)
+ Configuration settings for Volcengine Torch Object Storage (TOS)
"""
VOLCENGINE_TOS_BUCKET_NAME: str | None = Field(
diff --git a/api/context/__init__.py b/api/context/__init__.py
new file mode 100644
index 0000000000..aebf9750ce
--- /dev/null
+++ b/api/context/__init__.py
@@ -0,0 +1,74 @@
+"""
+Core Context - Framework-agnostic context management.
+
+This module provides context management that is independent of any specific
+web framework. Framework-specific implementations register their context
+capture functions at application initialization time.
+
+This ensures the workflow layer remains completely decoupled from Flask
+or any other web framework.
+"""
+
+import contextvars
+from collections.abc import Callable
+
+from core.workflow.context.execution_context import (
+ ExecutionContext,
+ IExecutionContext,
+ NullAppContext,
+)
+
+# Global capturer function - set by framework-specific modules
+_capturer: Callable[[], IExecutionContext] | None = None
+
+
+def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None:
+ """
+ Register a context capture function.
+
+ This should be called by framework-specific modules (e.g., Flask)
+ during application initialization.
+
+ Args:
+ capturer: Function that captures current context and returns IExecutionContext
+ """
+ global _capturer
+ _capturer = capturer
+
+
+def capture_current_context() -> IExecutionContext:
+ """
+ Capture current execution context.
+
+ This function uses the registered context capturer. If no capturer
+ is registered, it returns a minimal context with only contextvars
+ (suitable for non-framework environments like tests or standalone scripts).
+
+ Returns:
+ IExecutionContext with captured context
+ """
+ if _capturer is None:
+ # No framework registered - return minimal context
+ return ExecutionContext(
+ app_context=NullAppContext(),
+ context_vars=contextvars.copy_context(),
+ )
+
+ return _capturer()
+
+
+def reset_context_provider() -> None:
+ """
+ Reset the context capturer.
+
+ This is primarily useful for testing to ensure a clean state.
+ """
+ global _capturer
+ _capturer = None
+
+
+__all__ = [
+ "capture_current_context",
+ "register_context_capturer",
+ "reset_context_provider",
+]
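The module above is a small registry: a framework registers a capturer once, and everything else calls `capture_current_context()` without knowing which framework is behind it. A minimal sketch of that flow, assuming the module is importable as `context` (as in `app_factory.py` above) and using a fake capturer in place of a real framework integration:

```python
# Sketch of the capturer registry. `fake_capturer` stands in for a framework
# integration; in production the Flask module registers capture_flask_context.
import contextvars

from context import capture_current_context, register_context_capturer, reset_context_provider
from core.workflow.context.execution_context import ExecutionContext, NullAppContext


def fake_capturer() -> ExecutionContext:
    # A real framework would capture its own app context here.
    return ExecutionContext(
        app_context=NullAppContext(),
        context_vars=contextvars.copy_context(),
    )


default_ctx = capture_current_context()   # no capturer yet: contextvars-only fallback

register_context_capturer(fake_capturer)
captured_ctx = capture_current_context()  # now produced by fake_capturer

reset_context_provider()                  # back to the fallback (handy in tests)
```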
diff --git a/api/context/flask_app_context.py b/api/context/flask_app_context.py
new file mode 100644
index 0000000000..4b693cd91f
--- /dev/null
+++ b/api/context/flask_app_context.py
@@ -0,0 +1,198 @@
+"""
+Flask App Context - Flask implementation of AppContext interface.
+"""
+
+import contextvars
+from collections.abc import Generator
+from contextlib import contextmanager
+from typing import Any, final
+
+from flask import Flask, current_app, g
+
+from context import register_context_capturer
+from core.workflow.context.execution_context import (
+ AppContext,
+ IExecutionContext,
+)
+
+
+@final
+class FlaskAppContext(AppContext):
+ """
+ Flask implementation of AppContext.
+
+ This adapts Flask's app context to the AppContext interface.
+ """
+
+ def __init__(self, flask_app: Flask) -> None:
+ """
+ Initialize Flask app context.
+
+ Args:
+ flask_app: The Flask application instance
+ """
+ self._flask_app = flask_app
+
+ def get_config(self, key: str, default: Any = None) -> Any:
+ """Get configuration value from Flask app config."""
+ return self._flask_app.config.get(key, default)
+
+ def get_extension(self, name: str) -> Any:
+ """Get Flask extension by name."""
+ return self._flask_app.extensions.get(name)
+
+ @contextmanager
+ def enter(self) -> Generator[None, None, None]:
+ """Enter Flask app context."""
+ with self._flask_app.app_context():
+ yield
+
+ @property
+ def flask_app(self) -> Flask:
+ """Get the underlying Flask app instance."""
+ return self._flask_app
+
+
+def capture_flask_context(user: Any = None) -> IExecutionContext:
+ """
+ Capture current Flask execution context.
+
+ This function captures the Flask app context and contextvars from the
+ current environment. It should be called from within a Flask request or
+ app context.
+
+ Args:
+ user: Optional user object to include in context
+
+ Returns:
+ IExecutionContext with captured Flask context
+
+ Raises:
+ RuntimeError: If called outside Flask context
+ """
+ # Get Flask app instance
+ flask_app = current_app._get_current_object() # type: ignore
+
+ # Save current user if available
+ saved_user = user
+ if saved_user is None:
+ # Check for user in g (flask-login)
+ if hasattr(g, "_login_user"):
+ saved_user = g._login_user
+
+ # Capture contextvars
+ context_vars = contextvars.copy_context()
+
+ return FlaskExecutionContext(
+ flask_app=flask_app,
+ context_vars=context_vars,
+ user=saved_user,
+ )
+
+
+@final
+class FlaskExecutionContext:
+ """
+ Flask-specific execution context.
+
+ This is a specialized version of ExecutionContext that includes Flask app
+ context. It provides the same interface as ExecutionContext but with a
+ Flask-specific implementation.
+ """
+
+ def __init__(
+ self,
+ flask_app: Flask,
+ context_vars: contextvars.Context,
+ user: Any = None,
+ ) -> None:
+ """
+ Initialize Flask execution context.
+
+ Args:
+ flask_app: Flask application instance
+ context_vars: Python contextvars
+ user: Optional user object
+ """
+ self._app_context = FlaskAppContext(flask_app)
+ self._context_vars = context_vars
+ self._user = user
+ self._flask_app = flask_app
+
+ @property
+ def app_context(self) -> FlaskAppContext:
+ """Get Flask app context."""
+ return self._app_context
+
+ @property
+ def context_vars(self) -> contextvars.Context:
+ """Get context variables."""
+ return self._context_vars
+
+ @property
+ def user(self) -> Any:
+ """Get user object."""
+ return self._user
+
+ def __enter__(self) -> "FlaskExecutionContext":
+ """Enter the Flask execution context."""
+ # Restore context variables
+ for var, val in self._context_vars.items():
+ var.set(val)
+
+ # Save current user from g if available
+ saved_user = None
+ if hasattr(g, "_login_user"):
+ saved_user = g._login_user
+
+ # Enter Flask app context
+ self._cm = self._app_context.enter()
+ self._cm.__enter__()
+
+ # Restore user in new app context
+ if saved_user is not None:
+ g._login_user = saved_user
+
+ return self
+
+ def __exit__(self, *args: Any) -> None:
+ """Exit the Flask execution context."""
+ if hasattr(self, "_cm"):
+ self._cm.__exit__(*args)
+
+ @contextmanager
+ def enter(self) -> Generator[None, None, None]:
+ """Enter Flask execution context as context manager."""
+ # Restore context variables
+ for var, val in self._context_vars.items():
+ var.set(val)
+
+ # Save current user from g if available
+ saved_user = None
+ if hasattr(g, "_login_user"):
+ saved_user = g._login_user
+
+ # Enter Flask app context
+ with self._flask_app.app_context():
+ # Restore user in new app context
+ if saved_user is not None:
+ g._login_user = saved_user
+ yield
+
+
+def init_flask_context() -> None:
+ """
+ Initialize Flask context capture by registering the capturer.
+
+ This function should be called during Flask application initialization
+ to register the Flask-specific context capturer with the core context module.
+
+ Example:
+ app = Flask(__name__)
+ init_flask_context() # Register Flask context capturer
+
+ Note:
+ This function does not need the app instance as it uses Flask's
+ `current_app` to get the app when capturing context.
+ """
+ register_context_capturer(capture_flask_context)
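Putting the pieces together, the intended flow is: register the capturer at startup, capture inside a request, and re-enter the captured context from a background thread. The sketch below assumes only the APIs shown above (`init_flask_context`, `capture_current_context`, and the `enter()` context manager on the returned object); the route and thread are illustrative.

```python
# Sketch: capture the Flask context in a request, re-enter it in a worker thread.
import threading

from flask import Flask, current_app

from context import capture_current_context
from context.flask_app_context import init_flask_context

app = Flask(__name__)
init_flask_context()  # normally done once in initialize_extensions()


def worker(ctx) -> None:
    # Re-enter the captured Flask app context outside the request thread.
    with ctx.enter():
        print(current_app.name)


@app.route("/run")
def run():
    ctx = capture_current_context()  # FlaskExecutionContext when Flask is registered
    threading.Thread(target=worker, args=(ctx,)).start()
    return "started"
```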
diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py
index 56816dd462..55fdcb51e4 100644
--- a/api/controllers/console/app/conversation.py
+++ b/api/controllers/console/app/conversation.py
@@ -592,9 +592,12 @@ def _get_conversation(app_model, conversation_id):
if not conversation:
raise NotFound("Conversation Not Exists.")
- if not conversation.read_at:
- conversation.read_at = naive_utc_now()
- conversation.read_account_id = current_user.id
- db.session.commit()
+ db.session.execute(
+ sa.update(Conversation)
+ .where(Conversation.id == conversation_id, Conversation.read_at.is_(None))
+ .values(read_at=naive_utc_now(), read_account_id=current_user.id)
+ )
+ db.session.commit()
+ db.session.refresh(conversation)
return conversation
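The conversation change above replaces a read-modify-write (`if not conversation.read_at: ...`) with a single conditional UPDATE, so two concurrent readers cannot both overwrite `read_at`. A minimal sketch of the same pattern on a toy table (the model and `mark_read` helper below are stand-ins, not Dify's code):

```python
# Sketch of the atomic "mark as read" pattern: the WHERE clause makes the
# update a no-op when another request already set read_at.
import datetime

import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Conversation(Base):  # toy stand-in for the real model
    __tablename__ = "conversations"
    id: Mapped[str] = mapped_column(primary_key=True)
    read_at: Mapped[datetime.datetime | None] = mapped_column(default=None)
    read_account_id: Mapped[str | None] = mapped_column(default=None)


def mark_read(session: Session, conversation_id: str, account_id: str) -> None:
    session.execute(
        sa.update(Conversation)
        .where(Conversation.id == conversation_id, Conversation.read_at.is_(None))
        .values(read_at=datetime.datetime.utcnow(), read_account_id=account_id)
    )
    session.commit()
```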
diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py
index fe70d930fb..f741107b87 100644
--- a/api/controllers/console/auth/activate.py
+++ b/api/controllers/console/auth/activate.py
@@ -63,13 +63,19 @@ class ActivateCheckApi(Resource):
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
workspaceId = args.workspace_id
- reg_email = args.email
token = args.token
- invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
+ invitation = RegisterService.get_invitation_with_case_fallback(workspaceId, args.email, token)
if invitation:
data = invitation.get("data", {})
tenant = invitation.get("tenant", None)
+
+ # Check workspace permission
+ if tenant:
+ from libs.workspace_permission import check_workspace_member_invite_permission
+
+ check_workspace_member_invite_permission(tenant.id)
+
workspace_name = tenant.name if tenant else None
workspace_id = tenant.id if tenant else None
invitee_email = data.get("email") if data else None
@@ -100,11 +106,12 @@ class ActivateApi(Resource):
def post(self):
args = ActivatePayload.model_validate(console_ns.payload)
- invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token)
+ normalized_request_email = args.email.lower() if args.email else None
+ invitation = RegisterService.get_invitation_with_case_fallback(args.workspace_id, args.email, args.token)
if invitation is None:
raise AlreadyActivateError()
- RegisterService.revoke_token(args.workspace_id, args.email, args.token)
+ RegisterService.revoke_token(args.workspace_id, normalized_request_email, args.token)
account = invitation["account"]
account.name = args.name
diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py
index fa082c735d..c2a95ddad2 100644
--- a/api/controllers/console/auth/email_register.py
+++ b/api/controllers/console/auth/email_register.py
@@ -1,7 +1,6 @@
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
-from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
@@ -62,6 +61,7 @@ class EmailRegisterSendEmailApi(Resource):
@email_register_enabled
def post(self):
args = EmailRegisterSendPayload.model_validate(console_ns.payload)
+ normalized_email = args.email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
@@ -70,13 +70,12 @@ class EmailRegisterSendEmailApi(Resource):
if args.language in languages:
language = args.language
- if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
+ if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountInFreezeError()
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
- token = None
- token = AccountService.send_email_register_email(email=args.email, account=account, language=language)
+ account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
+ token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
return {"result": "success", "data": token}
@@ -88,9 +87,9 @@ class EmailRegisterCheckApi(Resource):
def post(self):
args = EmailRegisterValidityPayload.model_validate(console_ns.payload)
- user_email = args.email
+ user_email = args.email.lower()
- is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email)
+ is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(user_email)
if is_email_register_error_rate_limit:
raise EmailRegisterLimitError()
@@ -98,11 +97,14 @@ class EmailRegisterCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
- if user_email != token_data.get("email"):
+ token_email = token_data.get("email")
+ normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
+
+ if user_email != normalized_token_email:
raise InvalidEmailError()
if args.code != token_data.get("code"):
- AccountService.add_email_register_error_rate_limit(args.email)
+ AccountService.add_email_register_error_rate_limit(user_email)
raise EmailCodeError()
# Verified, revoke the first token
@@ -113,8 +115,8 @@ class EmailRegisterCheckApi(Resource):
user_email, code=args.code, additional_data={"phase": "register"}
)
- AccountService.reset_email_register_error_rate_limit(args.email)
- return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
+ AccountService.reset_email_register_error_rate_limit(user_email)
+ return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@console_ns.route("/email-register")
@@ -141,22 +143,23 @@ class EmailRegisterResetApi(Resource):
AccountService.revoke_email_register_token(args.token)
email = register_data.get("email", "")
+ normalized_email = email.lower()
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account:
raise EmailAlreadyInUseError()
else:
- account = self._create_new_account(email, args.password_confirm)
+ account = self._create_new_account(normalized_email, args.password_confirm)
if not account:
raise AccountNotFoundError()
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
- AccountService.reset_login_error_rate_limit(email)
+ AccountService.reset_login_error_rate_limit(normalized_email)
return {"result": "success", "data": token_pair.model_dump()}
- def _create_new_account(self, email, password) -> Account | None:
+ def _create_new_account(self, email: str, password: str) -> Account | None:
# Create new account if allowed
account = None
try:
diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py
index 661f591182..394f205d93 100644
--- a/api/controllers/console/auth/forgot_password.py
+++ b/api/controllers/console/auth/forgot_password.py
@@ -4,7 +4,6 @@ import secrets
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
-from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.console import console_ns
@@ -21,7 +20,6 @@ from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password
-from models import Account
from services.account_service import AccountService, TenantService
from services.feature_service import FeatureService
@@ -76,6 +74,7 @@ class ForgotPasswordSendEmailApi(Resource):
@email_password_login_enabled
def post(self):
args = ForgotPasswordSendPayload.model_validate(console_ns.payload)
+ normalized_email = args.email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
@@ -87,11 +86,11 @@ class ForgotPasswordSendEmailApi(Resource):
language = "en-US"
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
token = AccountService.send_reset_password_email(
account=account,
- email=args.email,
+ email=normalized_email,
language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register,
)
@@ -122,9 +121,9 @@ class ForgotPasswordCheckApi(Resource):
def post(self):
args = ForgotPasswordCheckPayload.model_validate(console_ns.payload)
- user_email = args.email
+ user_email = args.email.lower()
- is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email)
+ is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email)
if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError()
@@ -132,11 +131,16 @@ class ForgotPasswordCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
- if user_email != token_data.get("email"):
+ token_email = token_data.get("email")
+ if not isinstance(token_email, str):
+ raise InvalidEmailError()
+ normalized_token_email = token_email.lower()
+
+ if user_email != normalized_token_email:
raise InvalidEmailError()
if args.code != token_data.get("code"):
- AccountService.add_forgot_password_error_rate_limit(args.email)
+ AccountService.add_forgot_password_error_rate_limit(user_email)
raise EmailCodeError()
# Verified, revoke the first token
@@ -144,11 +148,11 @@ class ForgotPasswordCheckApi(Resource):
# Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token(
- user_email, code=args.code, additional_data={"phase": "reset"}
+ token_email, code=args.code, additional_data={"phase": "reset"}
)
- AccountService.reset_forgot_password_error_rate_limit(args.email)
- return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
+ AccountService.reset_forgot_password_error_rate_limit(user_email)
+ return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@console_ns.route("/forgot-password/resets")
@@ -187,9 +191,8 @@ class ForgotPasswordResetApi(Resource):
password_hashed = hash_password(args.new_password, salt)
email = reset_data.get("email", "")
-
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account:
self._update_existing_account(account, password_hashed, salt, session)
diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py
index 4a52bf8abe..400df138b8 100644
--- a/api/controllers/console/auth/login.py
+++ b/api/controllers/console/auth/login.py
@@ -90,32 +90,38 @@ class LoginApi(Resource):
def post(self):
"""Authenticate user and login."""
args = LoginPayload.model_validate(console_ns.payload)
+ request_email = args.email
+ normalized_email = request_email.lower()
- if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
+ if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountInFreezeError()
- is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email)
+ is_login_error_rate_limit = AccountService.is_login_error_rate_limit(normalized_email)
if is_login_error_rate_limit:
raise EmailPasswordLoginLimitError()
+ invite_token = args.invite_token
invitation_data: dict[str, Any] | None = None
- if args.invite_token:
- invitation_data = RegisterService.get_invitation_if_token_valid(None, args.email, args.invite_token)
+ if invite_token:
+ invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token)
+ if invitation_data is None:
+ invite_token = None
try:
if invitation_data:
data = invitation_data.get("data", {})
invitee_email = data.get("email") if data else None
- if invitee_email != args.email:
+ invitee_email_normalized = invitee_email.lower() if isinstance(invitee_email, str) else invitee_email
+ if invitee_email_normalized != normalized_email:
raise InvalidEmailError()
- account = AccountService.authenticate(args.email, args.password, args.invite_token)
- else:
- account = AccountService.authenticate(args.email, args.password)
+ account = _authenticate_account_with_case_fallback(
+ request_email, normalized_email, args.password, invite_token
+ )
except services.errors.account.AccountLoginError:
raise AccountBannedError()
- except services.errors.account.AccountPasswordError:
- AccountService.add_login_error_rate_limit(args.email)
- raise AuthenticationFailedError()
+ except services.errors.account.AccountPasswordError as exc:
+ AccountService.add_login_error_rate_limit(normalized_email)
+ raise AuthenticationFailedError() from exc
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
if len(tenants) == 0:
@@ -130,7 +136,7 @@ class LoginApi(Resource):
}
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
- AccountService.reset_login_error_rate_limit(args.email)
+ AccountService.reset_login_error_rate_limit(normalized_email)
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
@@ -170,18 +176,19 @@ class ResetPasswordSendEmailApi(Resource):
@console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self):
args = EmailPayload.model_validate(console_ns.payload)
+ normalized_email = args.email.lower()
if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
try:
- account = AccountService.get_user_through_email(args.email)
+ account = _get_account_with_case_fallback(args.email)
except AccountRegisterError:
raise AccountInFreezeError()
token = AccountService.send_reset_password_email(
- email=args.email,
+ email=normalized_email,
account=account,
language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register,
@@ -196,6 +203,7 @@ class EmailCodeLoginSendEmailApi(Resource):
@console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self):
args = EmailPayload.model_validate(console_ns.payload)
+ normalized_email = args.email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
@@ -206,13 +214,13 @@ class EmailCodeLoginSendEmailApi(Resource):
else:
language = "en-US"
try:
- account = AccountService.get_user_through_email(args.email)
+ account = _get_account_with_case_fallback(args.email)
except AccountRegisterError:
raise AccountInFreezeError()
if account is None:
if FeatureService.get_system_features().is_allow_register:
- token = AccountService.send_email_code_login_email(email=args.email, language=language)
+ token = AccountService.send_email_code_login_email(email=normalized_email, language=language)
else:
raise AccountNotFound()
else:
@@ -229,14 +237,17 @@ class EmailCodeLoginApi(Resource):
def post(self):
args = EmailCodeLoginPayload.model_validate(console_ns.payload)
- user_email = args.email
+ original_email = args.email
+ user_email = original_email.lower()
language = args.language
token_data = AccountService.get_email_code_login_data(args.token)
if token_data is None:
raise InvalidTokenError()
- if token_data["email"] != args.email:
+ token_email = token_data.get("email")
+ normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
+ if normalized_token_email != user_email:
raise InvalidEmailError()
if token_data["code"] != args.code:
@@ -244,7 +255,7 @@ class EmailCodeLoginApi(Resource):
AccountService.revoke_email_code_login_token(args.token)
try:
- account = AccountService.get_user_through_email(user_email)
+ account = _get_account_with_case_fallback(original_email)
except AccountRegisterError:
raise AccountInFreezeError()
if account:
@@ -275,7 +286,7 @@ class EmailCodeLoginApi(Resource):
except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
- AccountService.reset_login_error_rate_limit(args.email)
+ AccountService.reset_login_error_rate_limit(user_email)
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
@@ -309,3 +320,22 @@ class RefreshTokenApi(Resource):
return response
except Exception as e:
return {"result": "fail", "message": str(e)}, 401
+
+
+def _get_account_with_case_fallback(email: str):
+ account = AccountService.get_user_through_email(email)
+ if account or email == email.lower():
+ return account
+
+ return AccountService.get_user_through_email(email.lower())
+
+
+def _authenticate_account_with_case_fallback(
+ original_email: str, normalized_email: str, password: str, invite_token: str | None
+):
+ try:
+ return AccountService.authenticate(original_email, password, invite_token)
+ except services.errors.account.AccountPasswordError:
+ if original_email == normalized_email:
+ raise
+ return AccountService.authenticate(normalized_email, password, invite_token)
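Across the auth endpoints, lookups now normalize the email and fall back to the lowercased form only when the raw value fails, so existing mixed-case accounts keep working. The shape of that fallback, sketched with a toy lookup in place of `AccountService`:

```python
# Sketch of the case-fallback lookup shape; `find` stands in for AccountService.
def lookup_with_case_fallback(email: str, find):
    """Try the raw email first, then retry with the lowercased form."""
    account = find(email)
    if account is not None or email == email.lower():
        return account
    return find(email.lower())


accounts = {"alice@example.com": {"id": "a-1"}}  # toy store keyed by normalized email
print(lookup_with_case_fallback("Alice@Example.com", accounts.get))  # {'id': 'a-1'}
print(lookup_with_case_fallback("alice@example.com", accounts.get))  # {'id': 'a-1'}
```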
diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py
index c20e83d36f..112e152432 100644
--- a/api/controllers/console/auth/oauth.py
+++ b/api/controllers/console/auth/oauth.py
@@ -3,7 +3,6 @@ import logging
import httpx
from flask import current_app, redirect, request
from flask_restx import Resource
-from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized
@@ -118,7 +117,10 @@ class OAuthCallback(Resource):
invitation = RegisterService.get_invitation_by_token(token=invite_token)
if invitation:
invitation_email = invitation.get("email", None)
- if invitation_email != user_info.email:
+ invitation_email_normalized = (
+ invitation_email.lower() if isinstance(invitation_email, str) else invitation_email
+ )
+ if invitation_email_normalized != user_info.email.lower():
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.")
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
@@ -175,7 +177,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
if not account:
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)
return account
@@ -197,9 +199,10 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
tenant_was_created.send(new_tenant)
if not account:
+ normalized_email = user_info.email.lower()
oauth_new_user = True
if not FeatureService.get_system_features().is_allow_register:
- if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email):
+ if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountRegisterError(
description=(
"This email account has been deleted within the past "
@@ -210,7 +213,11 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
raise AccountRegisterError(description=("Invalid email or password"))
account_name = user_info.name or "Dify"
account = RegisterService.register(
- email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
+ email=normalized_email,
+ name=account_name,
+ password=None,
+ open_id=user_info.id,
+ provider=provider,
)
# Set interface language
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index ac78d3854b..707d90f044 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -7,7 +7,7 @@ from typing import Literal, cast
import sqlalchemy as sa
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound
@@ -104,6 +104,15 @@ class DocumentRenamePayload(BaseModel):
name: str
+class DocumentDatasetListParam(BaseModel):
+ page: int = Field(1, title="Page", description="Page number.")
+ limit: int = Field(20, title="Limit", description="Page size.")
+ search: str | None = Field(None, alias="keyword", title="Search", description="Search keyword.")
+ sort_by: str = Field("-created_at", alias="sort", title="SortBy", description="Sort by field.")
+ status: str | None = Field(None, title="Status", description="Document status.")
+ fetch_val: str = Field("false", alias="fetch")
+
+
register_schema_models(
console_ns,
KnowledgeConfig,
@@ -225,14 +234,16 @@ class DatasetDocumentListApi(Resource):
def get(self, dataset_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
- page = request.args.get("page", default=1, type=int)
- limit = request.args.get("limit", default=20, type=int)
- search = request.args.get("keyword", default=None, type=str)
- sort = request.args.get("sort", default="-created_at", type=str)
- status = request.args.get("status", default=None, type=str)
+ raw_args = request.args.to_dict()
+ param = DocumentDatasetListParam.model_validate(raw_args)
+ page = param.page
+ limit = param.limit
+ search = param.search
+ sort = param.sort_by
+ status = param.status
# "yes", "true", "t", "y", "1" convert to True, while others convert to False.
try:
- fetch_val = request.args.get("fetch", default="false")
+ fetch_val = param.fetch_val
if isinstance(fetch_val, bool):
fetch = fetch_val
else:
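Query-string parsing for the document list now goes through a pydantic model whose aliases map the external parameter names (`keyword`, `sort`, `fetch`) onto the internal field names. A self-contained sketch of that aliased validation (the model is re-declared here without titles and descriptions):

```python
# Sketch of alias-based query parsing; mirrors DocumentDatasetListParam above.
from pydantic import BaseModel, Field


class DocumentDatasetListParam(BaseModel):
    page: int = Field(1)
    limit: int = Field(20)
    search: str | None = Field(None, alias="keyword")
    sort_by: str = Field("-created_at", alias="sort")
    status: str | None = Field(None)
    fetch_val: str = Field("false", alias="fetch")


raw_args = {"page": "2", "limit": "50", "keyword": "invoices", "sort": "-updated_at", "fetch": "true"}
param = DocumentDatasetListParam.model_validate(raw_args)
print(param.page, param.limit)                        # 2 50 (strings coerced to int)
print(param.search, param.sort_by, param.fetch_val)   # invoices -updated_at true
```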
diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py
index 89c9fcad36..a70a7ce480 100644
--- a/api/controllers/console/datasets/external.py
+++ b/api/controllers/console/datasets/external.py
@@ -81,7 +81,7 @@ class ExternalKnowledgeApiPayload(BaseModel):
class ExternalDatasetCreatePayload(BaseModel):
external_knowledge_api_id: str
external_knowledge_id: str
- name: str = Field(..., min_length=1, max_length=40)
+ name: str = Field(..., min_length=1, max_length=100)
description: str | None = Field(None, max_length=400)
external_retrieval_model: dict[str, object] | None = None
diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py
index 7fa02ae280..ed22ef045d 100644
--- a/api/controllers/console/setup.py
+++ b/api/controllers/console/setup.py
@@ -84,10 +84,11 @@ class SetupApi(Resource):
raise NotInitValidateError()
args = SetupRequestPayload.model_validate(console_ns.payload)
+ normalized_email = args.email.lower()
# setup
RegisterService.setup(
- email=args.email,
+ email=normalized_email,
name=args.name,
password=args.password,
ip_address=extract_remote_ip(request),
diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py
index e9fbb515e4..023ffc991a 100644
--- a/api/controllers/console/tag/tags.py
+++ b/api/controllers/console/tag/tags.py
@@ -30,6 +30,11 @@ class TagBindingRemovePayload(BaseModel):
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
+class TagListQueryParam(BaseModel):
+ type: Literal["knowledge", "app", ""] = Field("", description="Tag type filter")
+ keyword: str | None = Field(None, description="Search keyword")
+
+
register_schema_models(
console_ns,
TagBasePayload,
@@ -43,12 +48,15 @@ class TagListApi(Resource):
@setup_required
@login_required
@account_initialization_required
+ @console_ns.doc(
+ params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."}
+ )
@marshal_with(dataset_tag_fields)
def get(self):
_, current_tenant_id = current_account_with_tenant()
- tag_type = request.args.get("type", type=str, default="")
- keyword = request.args.get("keyword", default=None, type=str)
- tags = TagService.get_tags(tag_type, current_tenant_id, keyword)
+ raw_args = request.args.to_dict()
+ param = TagListQueryParam.model_validate(raw_args)
+ tags = TagService.get_tags(param.type, current_tenant_id, param.keyword)
return tags, 200
diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py
index 03ad0f423b..527aabbc3d 100644
--- a/api/controllers/console/workspace/account.py
+++ b/api/controllers/console/workspace/account.py
@@ -41,7 +41,7 @@ from fields.member_fields import account_fields
from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
from libs.login import current_account_with_tenant, login_required
-from models import Account, AccountIntegrate, InvitationCode
+from models import AccountIntegrate, InvitationCode
from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
@@ -536,7 +536,8 @@ class ChangeEmailSendEmailApi(Resource):
else:
language = "en-US"
account = None
- user_email = args.email
+ user_email = None
+ email_for_sending = args.email.lower()
if args.phase is not None and args.phase == "new_email":
if args.token is None:
raise InvalidTokenError()
@@ -546,16 +547,24 @@ class ChangeEmailSendEmailApi(Resource):
raise InvalidTokenError()
user_email = reset_data.get("email", "")
- if user_email != current_user.email:
+ if user_email.lower() != current_user.email.lower():
raise InvalidEmailError()
+
+ user_email = current_user.email
else:
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
if account is None:
raise AccountNotFound()
+ email_for_sending = account.email
+ user_email = account.email
token = AccountService.send_change_email_email(
- account=account, email=args.email, old_email=user_email, language=language, phase=args.phase
+ account=account,
+ email=email_for_sending,
+ old_email=user_email,
+ language=language,
+ phase=args.phase,
)
return {"result": "success", "data": token}
@@ -571,9 +580,9 @@ class ChangeEmailCheckApi(Resource):
payload = console_ns.payload or {}
args = ChangeEmailValidityPayload.model_validate(payload)
- user_email = args.email
+ user_email = args.email.lower()
- is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args.email)
+ is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(user_email)
if is_change_email_error_rate_limit:
raise EmailChangeLimitError()
@@ -581,11 +590,13 @@ class ChangeEmailCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
- if user_email != token_data.get("email"):
+ token_email = token_data.get("email")
+ normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
+ if user_email != normalized_token_email:
raise InvalidEmailError()
if args.code != token_data.get("code"):
- AccountService.add_change_email_error_rate_limit(args.email)
+ AccountService.add_change_email_error_rate_limit(user_email)
raise EmailCodeError()
# Verified, revoke the first token
@@ -596,8 +607,8 @@ class ChangeEmailCheckApi(Resource):
user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
)
- AccountService.reset_change_email_error_rate_limit(args.email)
- return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
+ AccountService.reset_change_email_error_rate_limit(user_email)
+ return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@console_ns.route("/account/change-email/reset")
@@ -611,11 +622,12 @@ class ChangeEmailResetApi(Resource):
def post(self):
payload = console_ns.payload or {}
args = ChangeEmailResetPayload.model_validate(payload)
+ normalized_new_email = args.new_email.lower()
- if AccountService.is_account_in_freeze(args.new_email):
+ if AccountService.is_account_in_freeze(normalized_new_email):
raise AccountInFreezeError()
- if not AccountService.check_email_unique(args.new_email):
+ if not AccountService.check_email_unique(normalized_new_email):
raise EmailAlreadyInUseError()
reset_data = AccountService.get_change_email_data(args.token)
@@ -626,13 +638,13 @@ class ChangeEmailResetApi(Resource):
old_email = reset_data.get("old_email", "")
current_user, _ = current_account_with_tenant()
- if current_user.email != old_email:
+ if current_user.email.lower() != old_email.lower():
raise AccountNotFound()
- updated_account = AccountService.update_account_email(current_user, email=args.new_email)
+ updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)
AccountService.send_change_email_completed_notify_email(
- email=args.new_email,
+ email=normalized_new_email,
)
return updated_account
@@ -645,8 +657,9 @@ class CheckEmailUnique(Resource):
def post(self):
payload = console_ns.payload or {}
args = CheckEmailUniquePayload.model_validate(payload)
- if AccountService.is_account_in_freeze(args.email):
+ normalized_email = args.email.lower()
+ if AccountService.is_account_in_freeze(normalized_email):
raise AccountInFreezeError()
- if not AccountService.check_email_unique(args.email):
+ if not AccountService.check_email_unique(normalized_email):
raise EmailAlreadyInUseError()
return {"result": "success"}
diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py
index 0142e14fb0..01cca2a8a0 100644
--- a/api/controllers/console/workspace/members.py
+++ b/api/controllers/console/workspace/members.py
@@ -107,6 +107,12 @@ class MemberInviteEmailApi(Resource):
inviter = current_user
if not inviter.current_tenant:
raise ValueError("No current tenant")
+
+ # Check workspace permission for member invitations
+ from libs.workspace_permission import check_workspace_member_invite_permission
+
+ check_workspace_member_invite_permission(inviter.current_tenant.id)
+
invitation_results = []
console_web_url = dify_config.CONSOLE_WEB_URL
@@ -116,26 +122,31 @@ class MemberInviteEmailApi(Resource):
raise WorkspaceMembersLimitExceeded()
for invitee_email in invitee_emails:
+ normalized_invitee_email = invitee_email.lower()
try:
if not inviter.current_tenant:
raise ValueError("No current tenant")
token = RegisterService.invite_new_member(
- inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter
+ tenant=inviter.current_tenant,
+ email=invitee_email,
+ language=interface_language,
+ role=invitee_role,
+ inviter=inviter,
)
- encoded_invitee_email = parse.quote(invitee_email)
+ encoded_invitee_email = parse.quote(normalized_invitee_email)
invitation_results.append(
{
"status": "success",
- "email": invitee_email,
+ "email": normalized_invitee_email,
"url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
}
)
except AccountAlreadyInTenantError:
invitation_results.append(
- {"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"}
+ {"status": "success", "email": normalized_invitee_email, "url": f"{console_web_url}/signin"}
)
except Exception as e:
- invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)})
+ invitation_results.append({"status": "failed", "email": normalized_invitee_email, "message": str(e)})
return {
"result": "success",
diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py
index 52e6f7d737..94be81d94f 100644
--- a/api/controllers/console/workspace/workspace.py
+++ b/api/controllers/console/workspace/workspace.py
@@ -20,6 +20,7 @@ from controllers.console.error import AccountNotLinkTenantError
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
+ only_edition_enterprise,
setup_required,
)
from enums.cloud_plan import CloudPlan
@@ -28,6 +29,7 @@ from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantStatus
from services.account_service import TenantService
+from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.file_service import FileService
from services.workspace_service import WorkspaceService
@@ -288,3 +290,31 @@ class WorkspaceInfoApi(Resource):
db.session.commit()
return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
+
+
+@console_ns.route("/workspaces/current/permission")
+class WorkspacePermissionApi(Resource):
+ """Get workspace permissions for the current workspace."""
+
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @only_edition_enterprise
+ def get(self):
+ """
+ Get workspace permission settings.
+ Returns permission flags that control workspace features like member invitations and owner transfer.
+ """
+ _, current_tenant_id = current_account_with_tenant()
+
+ if not current_tenant_id:
+ raise ValueError("No current tenant")
+
+ # Get workspace permissions from enterprise service
+ permission = EnterpriseService.WorkspacePermissionService.get_permission(current_tenant_id)
+
+ return {
+ "workspace_id": permission.workspace_id,
+ "allow_member_invite": permission.allow_member_invite,
+ "allow_owner_transfer": permission.allow_owner_transfer,
+ }, 200
diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py
index 95fc006a12..fd928b077d 100644
--- a/api/controllers/console/wraps.py
+++ b/api/controllers/console/wraps.py
@@ -286,13 +286,12 @@ def enable_change_email(view: Callable[P, R]):
def is_allow_transfer_owner(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
- _, current_tenant_id = current_account_with_tenant()
- features = FeatureService.get_features(current_tenant_id)
- if features.is_allow_transfer_workspace:
- return view(*args, **kwargs)
+ from libs.workspace_permission import check_workspace_owner_transfer_permission
- # otherwise, return 403
- abort(403)
+ _, current_tenant_id = current_account_with_tenant()
+ # Check both billing/plan level and workspace policy level permissions
+ check_workspace_owner_transfer_permission(current_tenant_id)
+ return view(*args, **kwargs)
return decorated
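The decorator above now funnels both the plan-level and workspace-policy checks through one helper that aborts on failure, so the view body stays clean. A generic sketch of that decorator shape (`check_permission` below is a stand-in, not Dify's `libs.workspace_permission` helper):

```python
# Sketch of the "check then run view" decorator shape used above.
from functools import wraps

from flask import abort


def check_permission(tenant_id: str) -> None:
    allowed = {"tenant-1"}  # stand-in policy source
    if tenant_id not in allowed:
        abort(403)


def require_owner_transfer_permission(view):
    @wraps(view)
    def decorated(tenant_id: str, *args, **kwargs):
        check_permission(tenant_id)  # aborts with 403 when not allowed
        return view(tenant_id, *args, **kwargs)

    return decorated
```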
diff --git a/api/controllers/web/forgot_password.py b/api/controllers/web/forgot_password.py
index 690b76655f..91d206f727 100644
--- a/api/controllers/web/forgot_password.py
+++ b/api/controllers/web/forgot_password.py
@@ -4,7 +4,6 @@ import secrets
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
-from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
@@ -22,7 +21,7 @@ from controllers.web import web_ns
from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password
-from models import Account
+from models.account import Account
from services.account_service import AccountService
@@ -70,6 +69,9 @@ class ForgotPasswordSendEmailApi(Resource):
def post(self):
payload = ForgotPasswordSendPayload.model_validate(web_ns.payload or {})
+ request_email = payload.email
+ normalized_email = request_email.lower()
+
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
@@ -80,12 +82,12 @@ class ForgotPasswordSendEmailApi(Resource):
language = "en-US"
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=payload.email)).scalar_one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session)
token = None
if account is None:
raise AuthenticationFailedError()
else:
- token = AccountService.send_reset_password_email(account=account, email=payload.email, language=language)
+ token = AccountService.send_reset_password_email(account=account, email=normalized_email, language=language)
return {"result": "success", "data": token}
@@ -104,9 +106,9 @@ class ForgotPasswordCheckApi(Resource):
def post(self):
payload = ForgotPasswordCheckPayload.model_validate(web_ns.payload or {})
- user_email = payload.email
+ user_email = payload.email.lower()
- is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(payload.email)
+ is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email)
if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError()
@@ -114,11 +116,16 @@ class ForgotPasswordCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
- if user_email != token_data.get("email"):
+ token_email = token_data.get("email")
+ if not isinstance(token_email, str):
+ raise InvalidEmailError()
+ normalized_token_email = token_email.lower()
+
+ if user_email != normalized_token_email:
raise InvalidEmailError()
if payload.code != token_data.get("code"):
- AccountService.add_forgot_password_error_rate_limit(payload.email)
+ AccountService.add_forgot_password_error_rate_limit(user_email)
raise EmailCodeError()
# Verified, revoke the first token
@@ -126,11 +133,11 @@ class ForgotPasswordCheckApi(Resource):
# Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token(
- user_email, code=payload.code, additional_data={"phase": "reset"}
+ token_email, code=payload.code, additional_data={"phase": "reset"}
)
- AccountService.reset_forgot_password_error_rate_limit(payload.email)
- return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
+ AccountService.reset_forgot_password_error_rate_limit(user_email)
+ return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@web_ns.route("/forgot-password/resets")
@@ -174,7 +181,7 @@ class ForgotPasswordResetApi(Resource):
email = reset_data.get("email", "")
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account:
self._update_existing_account(account, password_hashed, salt, session)
diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py
index 5847f4ae3a..e8053acdfd 100644
--- a/api/controllers/web/login.py
+++ b/api/controllers/web/login.py
@@ -197,25 +197,29 @@ class EmailCodeLoginApi(Resource):
)
args = parser.parse_args()
- user_email = args["email"]
+ user_email = args["email"].lower()
token_data = WebAppAuthService.get_email_code_login_data(args["token"])
if token_data is None:
raise InvalidTokenError()
- if token_data["email"] != args["email"]:
+ token_email = token_data.get("email")
+ if not isinstance(token_email, str):
+ raise InvalidEmailError()
+ normalized_token_email = token_email.lower()
+ if normalized_token_email != user_email:
raise InvalidEmailError()
if token_data["code"] != args["code"]:
raise EmailCodeError()
WebAppAuthService.revoke_email_code_login_token(args["token"])
- account = WebAppAuthService.get_user_through_email(user_email)
+ account = WebAppAuthService.get_user_through_email(token_email)
if not account:
raise AuthenticationFailedError()
token = WebAppAuthService.login(account=account)
- AccountService.reset_login_error_rate_limit(args["email"])
+ AccountService.reset_login_error_rate_limit(user_email)
response = make_response({"result": "success", "data": {"access_token": token}})
# set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
return response
diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py
index c196dbbdf1..3c6d36afe4 100644
--- a/api/core/agent/base_agent_runner.py
+++ b/api/core/agent/base_agent_runner.py
@@ -1,6 +1,7 @@
import json
import logging
import uuid
+from decimal import Decimal
from typing import Union, cast
from sqlalchemy import select
@@ -41,6 +42,7 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from factories import file_factory
+from models.enums import CreatorUserRole
from models.model import Conversation, Message, MessageAgentThought, MessageFile
logger = logging.getLogger(__name__)
@@ -289,6 +291,7 @@ class BaseAgentRunner(AppRunner):
thought = MessageAgentThought(
message_id=message_id,
message_chain_id=None,
+ tool_process_data=None,
thought="",
tool=tool_name,
tool_labels_str="{}",
@@ -296,20 +299,20 @@ class BaseAgentRunner(AppRunner):
tool_input=tool_input,
message=message,
message_token=0,
- message_unit_price=0,
- message_price_unit=0,
+ message_unit_price=Decimal(0),
+ message_price_unit=Decimal("0.001"),
message_files=json.dumps(messages_ids) if messages_ids else "",
answer="",
observation="",
answer_token=0,
- answer_unit_price=0,
- answer_price_unit=0,
+ answer_unit_price=Decimal(0),
+ answer_price_unit=Decimal("0.001"),
tokens=0,
- total_price=0,
+ total_price=Decimal(0),
position=self.agent_thought_count + 1,
currency="USD",
latency=0,
- created_by_role="account",
+ created_by_role=CreatorUserRole.ACCOUNT,
created_by=self.user_id,
)
@@ -342,7 +345,8 @@ class BaseAgentRunner(AppRunner):
raise ValueError("agent thought not found")
if thought:
- agent_thought.thought += thought
+ existing_thought = agent_thought.thought or ""
+ agent_thought.thought = f"{existing_thought}{thought}"
if tool_name:
agent_thought.tool = tool_name
@@ -440,21 +444,30 @@ class BaseAgentRunner(AppRunner):
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
if agent_thoughts:
for agent_thought in agent_thoughts:
- tools = agent_thought.tool
- if tools:
- tools = tools.split(";")
+ tool_names_raw = agent_thought.tool
+ if tool_names_raw:
+ tool_names = tool_names_raw.split(";")
tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_call_response: list[ToolPromptMessage] = []
- try:
- tool_inputs = json.loads(agent_thought.tool_input)
- except Exception:
- tool_inputs = {tool: {} for tool in tools}
- try:
- tool_responses = json.loads(agent_thought.observation)
- except Exception:
- tool_responses = dict.fromkeys(tools, agent_thought.observation)
+ tool_input_payload = agent_thought.tool_input
+ if tool_input_payload:
+ try:
+ tool_inputs = json.loads(tool_input_payload)
+ except Exception:
+ tool_inputs = {tool: {} for tool in tool_names}
+ else:
+ tool_inputs = {tool: {} for tool in tool_names}
- for tool in tools:
+ observation_payload = agent_thought.observation
+ if observation_payload:
+ try:
+ tool_responses = json.loads(observation_payload)
+ except Exception:
+ tool_responses = dict.fromkeys(tool_names, observation_payload)
+ else:
+ tool_responses = dict.fromkeys(tool_names, observation_payload)
+
+ for tool in tool_names:
# generate a uuid for tool call
tool_call_id = str(uuid.uuid4())
tool_calls.append(
@@ -484,7 +497,7 @@ class BaseAgentRunner(AppRunner):
*tool_call_response,
]
)
- if not tools:
+ if not tool_names_raw:
result.append(AssistantPromptMessage(content=agent_thought.thought))
else:
if message.answer:
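The rewritten loop above only attempts `json.loads` when the stored payload is non-empty and falls back to per-tool empty mappings otherwise. A hedged stand-alone sketch of the tool-input branch:

```python
import json

# Defensive tool_input parsing: empty or malformed payloads degrade to an empty
# mapping per tool instead of raising.
def parse_tool_inputs(payload: str | None, tool_names: list[str]) -> dict:
    if not payload:
        return {tool: {} for tool in tool_names}
    try:
        return json.loads(payload)
    except Exception:
        return {tool: {} for tool in tool_names}

print(parse_tool_inputs(None, ["search"]))                          # {'search': {}}
print(parse_tool_inputs('{"search": {"q": "dify"}}', ["search"]))   # parsed as-is
```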
diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py
index 68d14ad027..7c5c9136a7 100644
--- a/api/core/agent/fc_agent_runner.py
+++ b/api/core/agent/fc_agent_runner.py
@@ -188,7 +188,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
),
)
- assistant_message = AssistantPromptMessage(content="", tool_calls=[])
+ assistant_message = AssistantPromptMessage(content=response, tool_calls=[])
if tool_calls:
assistant_message.tool_calls = [
AssistantPromptMessage.ToolCall(
@@ -200,8 +200,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
)
for tool_call in tool_calls
]
- else:
- assistant_message.content = response
self._current_thoughts.append(assistant_message)
diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py
index 307af3747c..13c51529cc 100644
--- a/api/core/app/app_config/entities.py
+++ b/api/core/app/app_config/entities.py
@@ -1,4 +1,3 @@
-import json
from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any, Literal
@@ -121,7 +120,7 @@ class VariableEntity(BaseModel):
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
- json_schema: str | None = Field(default=None)
+ json_schema: dict | None = Field(default=None)
@field_validator("description", mode="before")
@classmethod
@@ -135,17 +134,11 @@ class VariableEntity(BaseModel):
@field_validator("json_schema")
@classmethod
- def validate_json_schema(cls, schema: str | None) -> str | None:
+ def validate_json_schema(cls, schema: dict | None) -> dict | None:
if schema is None:
return None
-
try:
- json_schema = json.loads(schema)
- except json.JSONDecodeError:
- raise ValueError(f"invalid json_schema value {schema}")
-
- try:
- Draft7Validator.check_schema(json_schema)
+ Draft7Validator.check_schema(schema)
except SchemaError as e:
raise ValueError(f"Invalid JSON schema: {e.message}")
return schema
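With `json_schema` now typed as `dict`, the validator hands the value straight to `Draft7Validator.check_schema` instead of round-tripping through `json.loads`. A minimal sketch of the same check outside the Pydantic model:

```python
from jsonschema import Draft7Validator
from jsonschema.exceptions import SchemaError

# Mirrors the validator above: accept None, otherwise require a well-formed
# Draft 7 schema dict.
def validate_json_schema(schema: dict | None) -> dict | None:
    if schema is None:
        return None
    try:
        Draft7Validator.check_schema(schema)
    except SchemaError as e:
        raise ValueError(f"Invalid JSON schema: {e.message}")
    return schema

validate_json_schema({"type": "object", "properties": {"name": {"type": "string"}}})
```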
diff --git a/api/core/app/apps/advanced_chat/app_config_manager.py b/api/core/app/apps/advanced_chat/app_config_manager.py
index e4b308a6f6..c21c494efe 100644
--- a/api/core/app/apps/advanced_chat/app_config_manager.py
+++ b/api/core/app/apps/advanced_chat/app_config_manager.py
@@ -26,7 +26,6 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig:
features_dict = workflow.features_dict
-
app_mode = AppMode.value_of(app_model.mode)
app_config = AdvancedChatAppConfig(
tenant_id=app_model.tenant_id,
diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py
index d636548f2b..a258144d35 100644
--- a/api/core/app/apps/advanced_chat/app_runner.py
+++ b/api/core/app/apps/advanced_chat/app_runner.py
@@ -24,7 +24,7 @@ from core.app.layers.conversation_variable_persist_layer import ConversationVari
from core.db.session_factory import session_factory
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
-from core.variables.variables import VariableUnion
+from core.variables.variables import Variable
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.base import GraphEngineLayer
@@ -149,8 +149,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=self._workflow.environment_variables,
- # Based on the definition of `VariableUnion`,
- # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
+ # Based on the definition of `Variable`,
+ # `list[VariableBase]` can be safely used as `list[Variable]` since they are compatible.
conversation_variables=conversation_variables,
)
@@ -318,7 +318,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
trace_manager=app_generate_entity.trace_manager,
)
- def _initialize_conversation_variables(self) -> list[VariableUnion]:
+ def _initialize_conversation_variables(self) -> list[Variable]:
"""
Initialize conversation variables for the current conversation.
@@ -343,7 +343,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
conversation_variables = [var.to_variable() for var in existing_variables]
session.commit()
- return cast(list[VariableUnion], conversation_variables)
+ return cast(list[Variable], conversation_variables)
def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]:
"""
diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
index 4dd95be52d..da1e9f19b6 100644
--- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py
+++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
@@ -358,25 +358,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if node_finish_resp:
yield node_finish_resp
- # For ANSWER nodes, check if we need to send a message_replace event
- # Only send if the final output differs from the accumulated task_state.answer
- # This happens when variables were updated by variable_assigner during workflow execution
- if event.node_type == NodeType.ANSWER and event.outputs:
- final_answer = event.outputs.get("answer")
- if final_answer is not None and final_answer != self._task_state.answer:
- logger.info(
- "ANSWER node final output '%s' differs from accumulated answer '%s', sending message_replace event",
- final_answer,
- self._task_state.answer,
- )
- # Update the task state answer
- self._task_state.answer = str(final_answer)
- # Send message_replace event to update the UI
- yield self._message_cycle_manager.message_replace_to_stream_response(
- answer=str(final_answer),
- reason="variable_update",
- )
-
def _handle_node_failed_events(
self,
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py
index a6aace168e..07bae66867 100644
--- a/api/core/app/apps/base_app_generator.py
+++ b/api/core/app/apps/base_app_generator.py
@@ -1,4 +1,3 @@
-import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Union, final
@@ -76,12 +75,24 @@ class BaseAppGenerator:
user_inputs = {**user_inputs, **files_inputs, **file_list_inputs}
# Check if all files are converted to File
- if any(filter(lambda v: isinstance(v, dict), user_inputs.values())):
- raise ValueError("Invalid input type")
- if any(
- filter(lambda v: isinstance(v, dict), filter(lambda item: isinstance(item, list), user_inputs.values()))
- ):
- raise ValueError("Invalid input type")
+ invalid_dict_keys = [
+ k
+ for k, v in user_inputs.items()
+ if isinstance(v, dict)
+ and entity_dictionary[k].type not in {VariableEntityType.FILE, VariableEntityType.JSON_OBJECT}
+ ]
+ if invalid_dict_keys:
+ raise ValueError(f"Invalid input type for {invalid_dict_keys}")
+
+ invalid_list_dict_keys = [
+ k
+ for k, v in user_inputs.items()
+ if isinstance(v, list)
+ and any(isinstance(item, dict) for item in v)
+ and entity_dictionary[k].type != VariableEntityType.FILE_LIST
+ ]
+ if invalid_list_dict_keys:
+ raise ValueError(f"Invalid input type for {invalid_list_dict_keys}")
return user_inputs
@@ -178,12 +189,8 @@ class BaseAppGenerator:
elif value == 0:
value = False
case VariableEntityType.JSON_OBJECT:
- if not isinstance(value, str):
- raise ValueError(f"{variable_entity.variable} in input form must be a string")
- try:
- json.loads(value)
- except json.JSONDecodeError:
- raise ValueError(f"{variable_entity.variable} in input form must be a valid JSON object")
+ if value and not isinstance(value, dict):
+ raise ValueError(f"{variable_entity.variable} in input form must be a dict")
case _:
raise AssertionError("this statement should be unreachable.")
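The relaxed checks above allow dict inputs only for FILE and JSON_OBJECT variables, and dicts inside lists only for FILE_LIST. A hedged sketch with stand-in enum members (the real code uses `VariableEntityType` and the declared-entity mapping):

```python
from enum import StrEnum, auto

# Stand-in for VariableEntityType; only the members needed for the check.
class VarType(StrEnum):
    FILE = auto()
    FILE_LIST = auto()
    JSON_OBJECT = auto()
    TEXT_INPUT = auto()

def check_inputs(user_inputs: dict, declared: dict[str, VarType]) -> None:
    invalid_dict_keys = [
        k for k, v in user_inputs.items()
        if isinstance(v, dict) and declared[k] not in {VarType.FILE, VarType.JSON_OBJECT}
    ]
    if invalid_dict_keys:
        raise ValueError(f"Invalid input type for {invalid_dict_keys}")
    invalid_list_dict_keys = [
        k for k, v in user_inputs.items()
        if isinstance(v, list) and any(isinstance(i, dict) for i in v) and declared[k] != VarType.FILE_LIST
    ]
    if invalid_list_dict_keys:
        raise ValueError(f"Invalid input type for {invalid_list_dict_keys}")

check_inputs({"payload": {"a": 1}}, {"payload": VarType.JSON_OBJECT})  # accepted
```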
diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py
index 0165c74295..2be773f103 100644
--- a/api/core/app/apps/workflow/app_generator.py
+++ b/api/core/app/apps/workflow/app_generator.py
@@ -8,7 +8,7 @@ from typing import Any, Literal, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy import select
-from sqlalchemy.orm import Session, sessionmaker
+from sqlalchemy.orm import sessionmaker
import contexts
from configs import dify_config
@@ -23,6 +23,7 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
+from core.db.session_factory import session_factory
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
@@ -476,7 +477,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
:return:
"""
with preserve_flask_contexts(flask_app, context_vars=context):
- with Session(db.engine, expire_on_commit=False) as session:
+ with session_factory.create_session() as session:
workflow = session.scalar(
select(Workflow).where(
Workflow.tenant_id == application_generate_entity.app_config.tenant_id,
diff --git a/api/core/app/layers/conversation_variable_persist_layer.py b/api/core/app/layers/conversation_variable_persist_layer.py
index 77cc00bdc9..c070845b73 100644
--- a/api/core/app/layers/conversation_variable_persist_layer.py
+++ b/api/core/app/layers/conversation_variable_persist_layer.py
@@ -1,6 +1,6 @@
import logging
-from core.variables import Variable
+from core.variables import VariableBase
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.enums import NodeType
@@ -44,7 +44,7 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
if selector[0] != CONVERSATION_VARIABLE_NODE_ID:
continue
variable = self.graph_runtime_state.variable_pool.get(selector)
- if not isinstance(variable, Variable):
+ if not isinstance(variable, VariableBase):
logger.warning(
"Conversation variable not found in variable pool. selector=%s",
selector,
diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py
index 1785cbde4c..128c64ff2c 100644
--- a/api/core/helper/ssrf_proxy.py
+++ b/api/core/helper/ssrf_proxy.py
@@ -33,6 +33,10 @@ class MaxRetriesExceededError(ValueError):
pass
+request_error = httpx.RequestError
+max_retries_exceeded_error = MaxRetriesExceededError
+
+
def _create_proxy_mounts() -> dict[str, httpx.HTTPTransport]:
return {
"http://": httpx.HTTPTransport(
diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py
index b4c3ec1caf..be1e306d47 100644
--- a/api/core/llm_generator/llm_generator.py
+++ b/api/core/llm_generator/llm_generator.py
@@ -71,8 +71,8 @@ class LLMGenerator:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
)
- answer = cast(str, response.message.content)
- if answer is None:
+ answer = response.message.get_text_content()
+ if answer == "":
return ""
try:
result_dict = json.loads(answer)
@@ -184,7 +184,7 @@ class LLMGenerator:
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
- rule_config["prompt"] = cast(str, response.message.content)
+ rule_config["prompt"] = response.message.get_text_content()
except InvokeError as e:
error = str(e)
@@ -237,13 +237,11 @@ class LLMGenerator:
return rule_config
- rule_config["prompt"] = cast(str, prompt_content.message.content)
+ rule_config["prompt"] = prompt_content.message.get_text_content()
- if not isinstance(prompt_content.message.content, str):
- raise NotImplementedError("prompt content is not a string")
parameter_generate_prompt = parameter_template.format(
inputs={
- "INPUT_TEXT": prompt_content.message.content,
+ "INPUT_TEXT": prompt_content.message.get_text_content(),
},
remove_template_variables=False,
)
@@ -253,7 +251,7 @@ class LLMGenerator:
statement_generate_prompt = statement_template.format(
inputs={
"TASK_DESCRIPTION": instruction,
- "INPUT_TEXT": prompt_content.message.content,
+ "INPUT_TEXT": prompt_content.message.get_text_content(),
},
remove_template_variables=False,
)
@@ -263,7 +261,7 @@ class LLMGenerator:
parameter_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
)
- rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content))
+ rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', parameter_content.message.get_text_content())
except InvokeError as e:
error = str(e)
error_step = "generate variables"
@@ -272,7 +270,7 @@ class LLMGenerator:
statement_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
)
- rule_config["opening_statement"] = cast(str, statement_content.message.content)
+ rule_config["opening_statement"] = statement_content.message.get_text_content()
except InvokeError as e:
error = str(e)
error_step = "generate conversation opener"
@@ -315,7 +313,7 @@ class LLMGenerator:
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
- generated_code = cast(str, response.message.content)
+ generated_code = response.message.get_text_content()
return {"code": generated_code, "language": code_language, "error": ""}
except InvokeError as e:
@@ -351,7 +349,7 @@ class LLMGenerator:
raise TypeError("Expected LLMResult when stream=False")
response = result
- answer = cast(str, response.message.content)
+ answer = response.message.get_text_content()
return answer.strip()
@classmethod
@@ -375,10 +373,7 @@ class LLMGenerator:
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
- raw_content = response.message.content
-
- if not isinstance(raw_content, str):
- raise ValueError(f"LLM response content must be a string, got: {type(raw_content)}")
+ raw_content = response.message.get_text_content()
try:
parsed_content = json.loads(raw_content)
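Every `cast(str, response.message.content)` above is replaced with `response.message.get_text_content()`. The cast never converted anything at runtime, so list-shaped content could leak through to callers expecting a string; a small illustration of that failure mode:

```python
from typing import cast

# `cast` is a no-op at runtime, so the old code could hand a list of content
# blocks to callers expecting `str`; get_text_content() is expected to return
# plain text instead (behavior assumed from the diff above).
content: str | list = [{"type": "text", "text": "hello"}]
as_str = cast(str, content)
print(isinstance(as_str, str))  # False
```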
diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py
index 3ac83b4c96..9e46d72893 100644
--- a/api/core/model_runtime/entities/message_entities.py
+++ b/api/core/model_runtime/entities/message_entities.py
@@ -251,10 +251,7 @@ class AssistantPromptMessage(PromptMessage):
:return: True if prompt message is empty, False otherwise
"""
- if not super().is_empty() and not self.tool_calls:
- return False
-
- return True
+ return super().is_empty() and not self.tool_calls
class SystemPromptMessage(PromptMessage):
diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py
index d6bd4d2015..22ad756c91 100644
--- a/api/core/ops/aliyun_trace/aliyun_trace.py
+++ b/api/core/ops/aliyun_trace/aliyun_trace.py
@@ -1,6 +1,7 @@
import logging
from collections.abc import Sequence
+from opentelemetry.trace import SpanKind
from sqlalchemy.orm import sessionmaker
from core.ops.aliyun_trace.data_exporter.traceclient import (
@@ -54,7 +55,7 @@ from core.ops.entities.trace_entity import (
ToolTraceInfo,
WorkflowTraceInfo,
)
-from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.repositories import DifyCoreRepositoryFactory
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
@@ -151,6 +152,7 @@ class AliyunDataTrace(BaseTraceInstance):
),
status=status,
links=trace_metadata.links,
+ span_kind=SpanKind.SERVER,
)
self.trace_client.add_span(message_span)
@@ -273,7 +275,7 @@ class AliyunDataTrace(BaseTraceInstance):
service_account = self.get_service_account_with_tenant(app_id)
session_factory = sessionmaker(bind=db.engine)
- workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
app_id=app_id,
@@ -456,6 +458,7 @@ class AliyunDataTrace(BaseTraceInstance):
),
status=status,
links=trace_metadata.links,
+ span_kind=SpanKind.SERVER,
)
self.trace_client.add_span(message_span)
@@ -475,6 +478,7 @@ class AliyunDataTrace(BaseTraceInstance):
),
status=status,
links=trace_metadata.links,
+ span_kind=SpanKind.SERVER if message_span_id is None else SpanKind.INTERNAL,
)
self.trace_client.add_span(workflow_span)
diff --git a/api/core/ops/aliyun_trace/data_exporter/traceclient.py b/api/core/ops/aliyun_trace/data_exporter/traceclient.py
index d3324f8f82..7624586367 100644
--- a/api/core/ops/aliyun_trace/data_exporter/traceclient.py
+++ b/api/core/ops/aliyun_trace/data_exporter/traceclient.py
@@ -166,7 +166,7 @@ class SpanBuilder:
attributes=span_data.attributes,
events=span_data.events,
links=span_data.links,
- kind=trace_api.SpanKind.INTERNAL,
+ kind=span_data.span_kind,
status=span_data.status,
start_time=span_data.start_time,
end_time=span_data.end_time,
diff --git a/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py b/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py
index 20ff2d0875..9078031490 100644
--- a/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py
+++ b/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py
@@ -4,7 +4,7 @@ from typing import Any
from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import Event
-from opentelemetry.trace import Status, StatusCode
+from opentelemetry.trace import SpanKind, Status, StatusCode
from pydantic import BaseModel, Field
@@ -34,3 +34,4 @@ class SpanData(BaseModel):
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
start_time: int | None = Field(..., description="The start time of the span in nanoseconds.")
end_time: int | None = Field(..., description="The end time of the span in nanoseconds.")
+ span_kind: SpanKind = Field(default=SpanKind.INTERNAL, description="The OpenTelemetry SpanKind for this span.")
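With `span_kind` on `SpanData`, top-level message and workflow spans can be exported as `SpanKind.SERVER` while nested spans keep the `INTERNAL` default. A trimmed-down hedged sketch of the new field:

```python
from opentelemetry.trace import SpanKind
from pydantic import BaseModel, Field

# Reduced stand-in for SpanData showing only the new field and its default.
class MiniSpan(BaseModel):
    name: str
    span_kind: SpanKind = Field(default=SpanKind.INTERNAL)

root = MiniSpan(name="message", span_kind=SpanKind.SERVER)
child = MiniSpan(name="workflow-node")
print(root.span_kind, child.span_kind)  # SpanKind.SERVER SpanKind.INTERNAL
```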
diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py
index 0e49824ad0..7a6a598a2f 100644
--- a/api/core/plugin/impl/base.py
+++ b/api/core/plugin/impl/base.py
@@ -320,18 +320,17 @@ class BasePluginClient:
case PluginInvokeError.__name__:
error_object = json.loads(message)
invoke_error_type = error_object.get("error_type")
- args = error_object.get("args")
match invoke_error_type:
case InvokeRateLimitError.__name__:
- raise InvokeRateLimitError(description=args.get("description"))
+ raise InvokeRateLimitError(description=error_object.get("message"))
case InvokeAuthorizationError.__name__:
- raise InvokeAuthorizationError(description=args.get("description"))
+ raise InvokeAuthorizationError(description=error_object.get("message"))
case InvokeBadRequestError.__name__:
- raise InvokeBadRequestError(description=args.get("description"))
+ raise InvokeBadRequestError(description=error_object.get("message"))
case InvokeConnectionError.__name__:
- raise InvokeConnectionError(description=args.get("description"))
+ raise InvokeConnectionError(description=error_object.get("message"))
case InvokeServerUnavailableError.__name__:
- raise InvokeServerUnavailableError(description=args.get("description"))
+ raise InvokeServerUnavailableError(description=error_object.get("message"))
case CredentialsValidateFailedError.__name__:
raise CredentialsValidateFailedError(error_object.get("message"))
case EndpointSetupFailedError.__name__:
@@ -339,11 +338,11 @@ class BasePluginClient:
case TriggerProviderCredentialValidationError.__name__:
raise TriggerProviderCredentialValidationError(error_object.get("message"))
case TriggerPluginInvokeError.__name__:
- raise TriggerPluginInvokeError(description=error_object.get("description"))
+ raise TriggerPluginInvokeError(description=error_object.get("message"))
case TriggerInvokeError.__name__:
raise TriggerInvokeError(error_object.get("message"))
case EventIgnoreError.__name__:
- raise EventIgnoreError(description=error_object.get("description"))
+ raise EventIgnoreError(description=error_object.get("message"))
case _:
raise PluginInvokeError(description=message)
case PluginDaemonInternalServerError.__name__:
diff --git a/api/core/plugin/impl/endpoint.py b/api/core/plugin/impl/endpoint.py
index 5b88742be5..2db5185a2c 100644
--- a/api/core/plugin/impl/endpoint.py
+++ b/api/core/plugin/impl/endpoint.py
@@ -1,5 +1,6 @@
from core.plugin.entities.endpoint import EndpointEntityWithInstance
from core.plugin.impl.base import BasePluginClient
+from core.plugin.impl.exc import PluginDaemonInternalServerError
class PluginEndpointClient(BasePluginClient):
@@ -70,18 +71,27 @@ class PluginEndpointClient(BasePluginClient):
def delete_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str):
"""
Delete the given endpoint.
+
+ This operation is idempotent: if the endpoint is already deleted (record not found),
+ it will return True instead of raising an error.
"""
- return self._request_with_plugin_daemon_response(
- "POST",
- f"plugin/{tenant_id}/endpoint/remove",
- bool,
- data={
- "endpoint_id": endpoint_id,
- },
- headers={
- "Content-Type": "application/json",
- },
- )
+ try:
+ return self._request_with_plugin_daemon_response(
+ "POST",
+ f"plugin/{tenant_id}/endpoint/remove",
+ bool,
+ data={
+ "endpoint_id": endpoint_id,
+ },
+ headers={
+ "Content-Type": "application/json",
+ },
+ )
+ except PluginDaemonInternalServerError as e:
+ # Make delete idempotent: if record is not found, consider it a success
+ if "record not found" in str(e.description).lower():
+ return True
+ raise
def enable_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str):
"""
diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py
index 81a1d54199..283744b43b 100644
--- a/api/core/tools/workflow_as_tool/tool.py
+++ b/api/core/tools/workflow_as_tool/tool.py
@@ -5,10 +5,9 @@ import logging
from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast
-from flask import has_request_context
from sqlalchemy import select
-from sqlalchemy.orm import Session
+from core.db.session_factory import session_factory
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.tools.__base.tool import Tool
@@ -20,7 +19,6 @@ from core.tools.entities.tool_entities import (
ToolProviderType,
)
from core.tools.errors import ToolInvokeError
-from extensions.ext_database import db
from factories.file_factory import build_from_mapping
from libs.login import current_user
from models import Account, Tenant
@@ -30,6 +28,21 @@ from models.workflow import Workflow
logger = logging.getLogger(__name__)
+def _try_resolve_user_from_request() -> Account | EndUser | None:
+ """
+ Try to resolve user from Flask request context.
+
+ Returns None when no valid user is available (for example, outside a request context).
+ """
+ # Note: `current_user` is a LocalProxy. Never compare it with None directly.
+ # Use `_get_current_object()` to dereference the proxy.
+ user = getattr(current_user, "_get_current_object", lambda: current_user)()
+ # Check if we got a valid user object
+ if user is not None and hasattr(user, "id"):
+ return user
+ return None
+
+
class WorkflowTool(Tool):
"""
Workflow tool.
@@ -210,50 +223,44 @@ class WorkflowTool(Tool):
Returns:
Account | EndUser | None: The resolved user object, or None if resolution fails.
"""
- if has_request_context():
- return self._resolve_user_from_request()
- else:
- return self._resolve_user_from_database(user_id=user_id)
+ # Try to resolve user from request context first
+ user = _try_resolve_user_from_request()
+ if user is not None:
+ return user
- def _resolve_user_from_request(self) -> Account | EndUser | None:
- """
- Resolve user from Flask request context.
- """
- try:
- # Note: `current_user` is a LocalProxy. Never compare it with None directly.
- return getattr(current_user, "_get_current_object", lambda: current_user)()
- except Exception as e:
- logger.warning("Failed to resolve user from request context: %s", e)
- return None
+ # Fall back to database resolution
+ return self._resolve_user_from_database(user_id=user_id)
def _resolve_user_from_database(self, user_id: str) -> Account | EndUser | None:
"""
Resolve user from database (worker/Celery context).
"""
+ with session_factory.create_session() as session:
+ tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id)
+ tenant = session.scalar(tenant_stmt)
+ if not tenant:
+ return None
+
+ user_stmt = select(Account).where(Account.id == user_id)
+ user = session.scalar(user_stmt)
+ if user:
+ user.current_tenant = tenant
+ session.expunge(user)
+ return user
+
+ end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id)
+ end_user = session.scalar(end_user_stmt)
+ if end_user:
+ session.expunge(end_user)
+ return end_user
- tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id)
- tenant = db.session.scalar(tenant_stmt)
- if not tenant:
return None
- user_stmt = select(Account).where(Account.id == user_id)
- user = db.session.scalar(user_stmt)
- if user:
- user.current_tenant = tenant
- return user
-
- end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id)
- end_user = db.session.scalar(end_user_stmt)
- if end_user:
- return end_user
-
- return None
-
def _get_workflow(self, app_id: str, version: str) -> Workflow:
"""
get the workflow by app id and version
"""
- with Session(db.engine, expire_on_commit=False) as session, session.begin():
+ with session_factory.create_session() as session, session.begin():
if not version:
stmt = (
select(Workflow)
@@ -265,22 +272,24 @@ class WorkflowTool(Tool):
stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version)
workflow = session.scalar(stmt)
- if not workflow:
- raise ValueError("workflow not found or not published")
+ if not workflow:
+ raise ValueError("workflow not found or not published")
- return workflow
+ session.expunge(workflow)
+ return workflow
def _get_app(self, app_id: str) -> App:
"""
get the app by app id
"""
stmt = select(App).where(App.id == app_id)
- with Session(db.engine, expire_on_commit=False) as session, session.begin():
+ with session_factory.create_session() as session, session.begin():
app = session.scalar(stmt)
- if not app:
- raise ValueError("app not found")
+ if not app:
+ raise ValueError("app not found")
- return app
+ session.expunge(app)
+ return app
def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]:
"""
diff --git a/api/core/variables/__init__.py b/api/core/variables/__init__.py
index 7a1cbf9940..7498224923 100644
--- a/api/core/variables/__init__.py
+++ b/api/core/variables/__init__.py
@@ -30,6 +30,7 @@ from .variables import (
SecretVariable,
StringVariable,
Variable,
+ VariableBase,
)
__all__ = [
@@ -62,4 +63,5 @@ __all__ = [
"StringSegment",
"StringVariable",
"Variable",
+ "VariableBase",
]
diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py
index 406b4e6f93..8330f1fe19 100644
--- a/api/core/variables/segments.py
+++ b/api/core/variables/segments.py
@@ -232,7 +232,7 @@ def get_segment_discriminator(v: Any) -> SegmentType | None:
# - All variants in `SegmentUnion` must inherit from the `Segment` class.
# - The union must include all non-abstract subclasses of `Segment`, except:
# - `SegmentGroup`, which is not added to the variable pool.
-# - `Variable` and its subclasses, which are handled by `VariableUnion`.
+# - `VariableBase` and its subclasses, which are handled by `Variable`.
SegmentUnion: TypeAlias = Annotated[
(
Annotated[NoneSegment, Tag(SegmentType.NONE)]
diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py
index 9fd0bbc5b2..a19c53918d 100644
--- a/api/core/variables/variables.py
+++ b/api/core/variables/variables.py
@@ -27,7 +27,7 @@ from .segments import (
from .types import SegmentType
-class Variable(Segment):
+class VariableBase(Segment):
"""
A variable is a segment that has a name.
@@ -45,23 +45,23 @@ class Variable(Segment):
selector: Sequence[str] = Field(default_factory=list)
-class StringVariable(StringSegment, Variable):
+class StringVariable(StringSegment, VariableBase):
pass
-class FloatVariable(FloatSegment, Variable):
+class FloatVariable(FloatSegment, VariableBase):
pass
-class IntegerVariable(IntegerSegment, Variable):
+class IntegerVariable(IntegerSegment, VariableBase):
pass
-class ObjectVariable(ObjectSegment, Variable):
+class ObjectVariable(ObjectSegment, VariableBase):
pass
-class ArrayVariable(ArraySegment, Variable):
+class ArrayVariable(ArraySegment, VariableBase):
pass
@@ -89,16 +89,16 @@ class SecretVariable(StringVariable):
return encrypter.obfuscated_token(self.value)
-class NoneVariable(NoneSegment, Variable):
+class NoneVariable(NoneSegment, VariableBase):
value_type: SegmentType = SegmentType.NONE
value: None = None
-class FileVariable(FileSegment, Variable):
+class FileVariable(FileSegment, VariableBase):
pass
-class BooleanVariable(BooleanSegment, Variable):
+class BooleanVariable(BooleanSegment, VariableBase):
pass
@@ -139,13 +139,13 @@ class RAGPipelineVariableInput(BaseModel):
value: Any
-# The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic.
-# Use `Variable` for type hinting when serialization is not required.
+# The `Variable` type is used to enable serialization and deserialization with Pydantic.
+# Use `VariableBase` for type hinting when serialization is not required.
#
# Note:
-# - All variants in `VariableUnion` must inherit from the `Variable` class.
-# - The union must include all non-abstract subclasses of `Segment`, except:
-VariableUnion: TypeAlias = Annotated[
+# - All variants in `Variable` must inherit from the `VariableBase` class.
+# - The union must include all non-abstract subclasses of `VariableBase`.
+Variable: TypeAlias = Annotated[
(
Annotated[NoneVariable, Tag(SegmentType.NONE)]
| Annotated[StringVariable, Tag(SegmentType.STRING)]
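After the rename, `VariableBase` is the runtime base class (use it for `isinstance` checks and plain annotations) and `Variable` is the Pydantic discriminated-union alias used where (de)serialization matters. A hedged sketch of that split, assuming the alias keeps the same discriminator metadata as the old `VariableUnion`:

```python
from pydantic import BaseModel, Field

from core.variables import VariableBase
from core.variables.variables import Variable  # discriminated-union alias

def describe(variable: VariableBase) -> str:
    # Plain type hints and runtime checks go through the base class.
    return f"{variable.name}={variable.value!r}"

class ConversationSnapshot(BaseModel):
    # Serialized models use the union alias so Pydantic can pick the concrete type.
    variables: list[Variable] = Field(default_factory=list)
```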
diff --git a/api/core/workflow/context/__init__.py b/api/core/workflow/context/__init__.py
new file mode 100644
index 0000000000..31e1f2c8d9
--- /dev/null
+++ b/api/core/workflow/context/__init__.py
@@ -0,0 +1,22 @@
+"""
+Execution Context - Context management for workflow execution.
+
+This package provides Flask-independent context management for workflow
+execution in multi-threaded environments.
+"""
+
+from core.workflow.context.execution_context import (
+ AppContext,
+ ExecutionContext,
+ IExecutionContext,
+ NullAppContext,
+ capture_current_context,
+)
+
+__all__ = [
+ "AppContext",
+ "ExecutionContext",
+ "IExecutionContext",
+ "NullAppContext",
+ "capture_current_context",
+]
diff --git a/api/core/workflow/context/execution_context.py b/api/core/workflow/context/execution_context.py
new file mode 100644
index 0000000000..5a4203be93
--- /dev/null
+++ b/api/core/workflow/context/execution_context.py
@@ -0,0 +1,216 @@
+"""
+Execution Context - Abstracted context management for workflow execution.
+"""
+
+import contextvars
+from abc import ABC, abstractmethod
+from collections.abc import Generator
+from contextlib import AbstractContextManager, contextmanager
+from typing import Any, Protocol, final, runtime_checkable
+
+
+class AppContext(ABC):
+ """
+ Abstract application context interface.
+
+ This abstraction allows workflow execution to work with or without Flask
+ by providing a common interface for application context management.
+ """
+
+ @abstractmethod
+ def get_config(self, key: str, default: Any = None) -> Any:
+ """Get configuration value by key."""
+ pass
+
+ @abstractmethod
+ def get_extension(self, name: str) -> Any:
+ """Get Flask extension by name (e.g., 'db', 'cache')."""
+ pass
+
+ @abstractmethod
+ def enter(self) -> AbstractContextManager[None]:
+ """Enter the application context."""
+ pass
+
+
+@runtime_checkable
+class IExecutionContext(Protocol):
+ """
+ Protocol for execution context.
+
+ This protocol defines the interface that all execution contexts must implement,
+ allowing both ExecutionContext and FlaskExecutionContext to be used interchangeably.
+ """
+
+ def __enter__(self) -> "IExecutionContext":
+ """Enter the execution context."""
+ ...
+
+ def __exit__(self, *args: Any) -> None:
+ """Exit the execution context."""
+ ...
+
+ @property
+ def user(self) -> Any:
+ """Get user object."""
+ ...
+
+
+@final
+class ExecutionContext:
+ """
+ Execution context for workflow execution in worker threads.
+
+ This class encapsulates all context needed for workflow execution:
+ - Application context (Flask app or standalone)
+ - Context variables for Python contextvars
+ - User information (optional)
+
+ It is designed to be captured once in the submitting thread and passed to worker threads.
+ """
+
+ def __init__(
+ self,
+ app_context: AppContext | None = None,
+ context_vars: contextvars.Context | None = None,
+ user: Any = None,
+ ) -> None:
+ """
+ Initialize execution context.
+
+ Args:
+ app_context: Application context (Flask or standalone)
+ context_vars: Python contextvars to preserve
+ user: User object (optional)
+ """
+ self._app_context = app_context
+ self._context_vars = context_vars
+ self._user = user
+
+ @property
+ def app_context(self) -> AppContext | None:
+ """Get application context."""
+ return self._app_context
+
+ @property
+ def context_vars(self) -> contextvars.Context | None:
+ """Get context variables."""
+ return self._context_vars
+
+ @property
+ def user(self) -> Any:
+ """Get user object."""
+ return self._user
+
+ @contextmanager
+ def enter(self) -> Generator[None, None, None]:
+ """
+ Enter this execution context.
+
+ Restores any captured context variables, then enters the application context if one was provided.
+ """
+ # Restore context variables if provided
+ if self._context_vars:
+ for var, val in self._context_vars.items():
+ var.set(val)
+
+ # Enter app context if available
+ if self._app_context is not None:
+ with self._app_context.enter():
+ yield
+ else:
+ yield
+
+ def __enter__(self) -> "ExecutionContext":
+ """Enter the execution context."""
+ self._cm = self.enter()
+ self._cm.__enter__()
+ return self
+
+ def __exit__(self, *args: Any) -> None:
+ """Exit the execution context."""
+ if hasattr(self, "_cm"):
+ self._cm.__exit__(*args)
+
+
+class NullAppContext(AppContext):
+ """
+ Null implementation of AppContext for non-Flask environments.
+
+ This is used when running without Flask (e.g., in tests or standalone mode).
+ """
+
+ def __init__(self, config: dict[str, Any] | None = None) -> None:
+ """
+ Initialize null app context.
+
+ Args:
+ config: Optional configuration dictionary
+ """
+ self._config = config or {}
+ self._extensions: dict[str, Any] = {}
+
+ def get_config(self, key: str, default: Any = None) -> Any:
+ """Get configuration value by key."""
+ return self._config.get(key, default)
+
+ def get_extension(self, name: str) -> Any:
+ """Get extension by name."""
+ return self._extensions.get(name)
+
+ def set_extension(self, name: str, extension: Any) -> None:
+ """Set extension by name."""
+ self._extensions[name] = extension
+
+ @contextmanager
+ def enter(self) -> Generator[None, None, None]:
+ """Enter null context (no-op)."""
+ yield
+
+
+class ExecutionContextBuilder:
+ """
+ Builder for creating ExecutionContext instances.
+
+ This provides a fluent API for building execution contexts.
+ """
+
+ def __init__(self) -> None:
+ self._app_context: AppContext | None = None
+ self._context_vars: contextvars.Context | None = None
+ self._user: Any = None
+
+ def with_app_context(self, app_context: AppContext) -> "ExecutionContextBuilder":
+ """Set application context."""
+ self._app_context = app_context
+ return self
+
+ def with_context_vars(self, context_vars: contextvars.Context) -> "ExecutionContextBuilder":
+ """Set context variables."""
+ self._context_vars = context_vars
+ return self
+
+ def with_user(self, user: Any) -> "ExecutionContextBuilder":
+ """Set user."""
+ self._user = user
+ return self
+
+ def build(self) -> ExecutionContext:
+ """Build the execution context."""
+ return ExecutionContext(
+ app_context=self._app_context,
+ context_vars=self._context_vars,
+ user=self._user,
+ )
+
+
+def capture_current_context() -> IExecutionContext:
+ """
+ Capture current execution context from the calling environment.
+
+ Returns:
+ IExecutionContext with captured context
+ """
+ from context import capture_current_context
+
+ return capture_current_context()
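A hedged usage sketch of the new module: build (or capture) a context in the submitting thread, then enter it inside the worker so contextvars and app state are restored. The config key below is arbitrary, chosen only for illustration:

```python
import contextvars

from core.workflow.context.execution_context import (
    ExecutionContextBuilder,
    NullAppContext,
)

# Build a Flask-free context in the submitting thread...
ctx = (
    ExecutionContextBuilder()
    .with_app_context(NullAppContext(config={"SOME_SETTING": "value"}))
    .with_context_vars(contextvars.copy_context())
    .build()
)

# ...and enter it in the worker thread around node execution.
with ctx:
    pass  # node.run() would execute here with the captured context active
```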
diff --git a/api/core/workflow/conversation_variable_updater.py b/api/core/workflow/conversation_variable_updater.py
index fd78248c17..75f47691da 100644
--- a/api/core/workflow/conversation_variable_updater.py
+++ b/api/core/workflow/conversation_variable_updater.py
@@ -1,7 +1,7 @@
import abc
from typing import Protocol
-from core.variables import Variable
+from core.variables import VariableBase
class ConversationVariableUpdater(Protocol):
@@ -20,12 +20,12 @@ class ConversationVariableUpdater(Protocol):
"""
@abc.abstractmethod
- def update(self, conversation_id: str, variable: "Variable"):
+ def update(self, conversation_id: str, variable: "VariableBase"):
"""
Updates the value of the specified conversation variable in the underlying storage.
:param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`.
- :param variable: The `Variable` instance containing the updated value.
+ :param variable: The `VariableBase` instance containing the updated value.
"""
pass
diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py
index c08b62a253..bb3b13e8c6 100644
--- a/api/core/workflow/enums.py
+++ b/api/core/workflow/enums.py
@@ -211,6 +211,10 @@ class WorkflowExecutionStatus(StrEnum):
def is_ended(self) -> bool:
return self in _END_STATE
+ @classmethod
+ def ended_values(cls) -> list[str]:
+ return [status.value for status in _END_STATE]
+
_END_STATE = frozenset(
[
diff --git a/api/core/workflow/graph_engine/entities/commands.py b/api/core/workflow/graph_engine/entities/commands.py
index 6dce03c94d..41276eb444 100644
--- a/api/core/workflow/graph_engine/entities/commands.py
+++ b/api/core/workflow/graph_engine/entities/commands.py
@@ -11,7 +11,7 @@ from typing import Any
from pydantic import BaseModel, Field
-from core.variables.variables import VariableUnion
+from core.variables.variables import Variable
class CommandType(StrEnum):
@@ -46,7 +46,7 @@ class PauseCommand(GraphEngineCommand):
class VariableUpdate(BaseModel):
"""Represents a single variable update instruction."""
- value: VariableUnion = Field(description="New variable value")
+ value: Variable = Field(description="New variable value")
class UpdateVariablesCommand(GraphEngineCommand):
diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py
index 9a870d7bf5..dbb2727c98 100644
--- a/api/core/workflow/graph_engine/graph_engine.py
+++ b/api/core/workflow/graph_engine/graph_engine.py
@@ -7,15 +7,13 @@ Domain-Driven Design principles for improved maintainability and testability.
from __future__ import annotations
-import contextvars
import logging
import queue
import threading
from collections.abc import Generator
from typing import TYPE_CHECKING, cast, final
-from flask import Flask, current_app
-
+from core.workflow.context import capture_current_context
from core.workflow.enums import NodeExecutionType
from core.workflow.graph import Graph
from core.workflow.graph_events import (
@@ -159,17 +157,8 @@ class GraphEngine:
self._layers: list[GraphEngineLayer] = []
# === Worker Pool Setup ===
- # Capture Flask app context for worker threads
- flask_app: Flask | None = None
- try:
- app = current_app._get_current_object() # type: ignore
- if isinstance(app, Flask):
- flask_app = app
- except RuntimeError:
- pass
-
- # Capture context variables for worker threads
- context_vars = contextvars.copy_context()
+ # Capture execution context for worker threads
+ execution_context = capture_current_context()
# Create worker pool for parallel node execution
self._worker_pool = WorkerPool(
@@ -177,8 +166,7 @@ class GraphEngine:
event_queue=self._event_queue,
graph=self._graph,
layers=self._layers,
- flask_app=flask_app,
- context_vars=context_vars,
+ execution_context=execution_context,
min_workers=self._min_workers,
max_workers=self._max_workers,
scale_up_threshold=self._scale_up_threshold,
diff --git a/api/core/workflow/graph_engine/worker.py b/api/core/workflow/graph_engine/worker.py
index 83419830b6..95db5c5c92 100644
--- a/api/core/workflow/graph_engine/worker.py
+++ b/api/core/workflow/graph_engine/worker.py
@@ -5,26 +5,27 @@ Workers pull node IDs from the ready_queue, execute nodes, and push events
to the event_queue for the dispatcher to process.
"""
-import contextvars
import queue
import threading
import time
from collections.abc import Sequence
from datetime import datetime
-from typing import final
+from typing import TYPE_CHECKING, final
from uuid import uuid4
-from flask import Flask
from typing_extensions import override
+from core.workflow.context import IExecutionContext
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
from core.workflow.nodes.base.node import Node
-from libs.flask_utils import preserve_flask_contexts
from .ready_queue import ReadyQueue
+if TYPE_CHECKING:
+ pass
+
@final
class Worker(threading.Thread):
@@ -44,8 +45,7 @@ class Worker(threading.Thread):
layers: Sequence[GraphEngineLayer],
stop_event: threading.Event,
worker_id: int = 0,
- flask_app: Flask | None = None,
- context_vars: contextvars.Context | None = None,
+ execution_context: IExecutionContext | None = None,
) -> None:
"""
Initialize worker thread.
@@ -56,19 +56,17 @@ class Worker(threading.Thread):
graph: Graph containing nodes to execute
layers: Graph engine layers for node execution hooks
worker_id: Unique identifier for this worker
- flask_app: Optional Flask application for context preservation
- context_vars: Optional context variables to preserve in worker thread
+ execution_context: Optional execution context for context preservation
"""
super().__init__(name=f"GraphWorker-{worker_id}", daemon=True)
self._ready_queue = ready_queue
self._event_queue = event_queue
self._graph = graph
self._worker_id = worker_id
- self._flask_app = flask_app
- self._context_vars = context_vars
- self._last_task_time = time.time()
+ self._execution_context = execution_context
self._stop_event = stop_event
self._layers = layers if layers is not None else []
+ self._last_task_time = time.time()
def stop(self) -> None:
"""Worker is controlled via shared stop_event from GraphEngine.
@@ -135,11 +133,9 @@ class Worker(threading.Thread):
error: Exception | None = None
- if self._flask_app and self._context_vars:
- with preserve_flask_contexts(
- flask_app=self._flask_app,
- context_vars=self._context_vars,
- ):
+ # Execute the node with preserved context if execution context is provided
+ if self._execution_context is not None:
+ with self._execution_context:
self._invoke_node_run_start_hooks(node)
try:
node_events = node.run()
diff --git a/api/core/workflow/graph_engine/worker_management/worker_pool.py b/api/core/workflow/graph_engine/worker_management/worker_pool.py
index df76ebe882..9ce7d16e93 100644
--- a/api/core/workflow/graph_engine/worker_management/worker_pool.py
+++ b/api/core/workflow/graph_engine/worker_management/worker_pool.py
@@ -8,9 +8,10 @@ DynamicScaler, and WorkerFactory into a single class.
import logging
import queue
import threading
-from typing import TYPE_CHECKING, final
+from typing import final
from configs import dify_config
+from core.workflow.context import IExecutionContext
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase
@@ -20,11 +21,6 @@ from ..worker import Worker
logger = logging.getLogger(__name__)
-if TYPE_CHECKING:
- from contextvars import Context
-
- from flask import Flask
-
@final
class WorkerPool:
@@ -42,8 +38,7 @@ class WorkerPool:
graph: Graph,
layers: list[GraphEngineLayer],
stop_event: threading.Event,
- flask_app: "Flask | None" = None,
- context_vars: "Context | None" = None,
+ execution_context: IExecutionContext | None = None,
min_workers: int | None = None,
max_workers: int | None = None,
scale_up_threshold: int | None = None,
@@ -57,8 +52,7 @@ class WorkerPool:
event_queue: Queue for worker events
graph: The workflow graph
layers: Graph engine layers for node execution hooks
- flask_app: Optional Flask app for context preservation
- context_vars: Optional context variables
+ execution_context: Optional execution context for context preservation
min_workers: Minimum number of workers
max_workers: Maximum number of workers
scale_up_threshold: Queue depth to trigger scale up
@@ -67,8 +61,7 @@ class WorkerPool:
self._ready_queue = ready_queue
self._event_queue = event_queue
self._graph = graph
- self._flask_app = flask_app
- self._context_vars = context_vars
+ self._execution_context = execution_context
self._layers = layers
# Scaling parameters with defaults
@@ -152,8 +145,7 @@ class WorkerPool:
graph=self._graph,
layers=self._layers,
worker_id=worker_id,
- flask_app=self._flask_app,
- context_vars=self._context_vars,
+ execution_context=self._execution_context,
stop_event=self._stop_event,
)
diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py
index 931c6113a7..429f8411a6 100644
--- a/api/core/workflow/nodes/http_request/executor.py
+++ b/api/core/workflow/nodes/http_request/executor.py
@@ -17,6 +17,7 @@ from core.helper import ssrf_proxy
from core.variables.segments import ArrayFileSegment, FileSegment
from core.workflow.runtime import VariablePool
+from ..protocols import FileManagerProtocol, HttpClientProtocol
from .entities import (
HttpRequestNodeAuthorization,
HttpRequestNodeData,
@@ -78,6 +79,8 @@ class Executor:
timeout: HttpRequestNodeTimeout,
variable_pool: VariablePool,
max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES,
+ http_client: HttpClientProtocol = ssrf_proxy,
+ file_manager: FileManagerProtocol = file_manager,
):
# If authorization API key is present, convert the API key using the variable pool
if node_data.authorization.type == "api-key":
@@ -104,6 +107,8 @@ class Executor:
self.data = None
self.json = None
self.max_retries = max_retries
+ self._http_client = http_client
+ self._file_manager = file_manager
# init template
self.variable_pool = variable_pool
@@ -200,7 +205,7 @@ class Executor:
if file_variable is None:
raise FileFetchError(f"cannot fetch file with selector {file_selector}")
file = file_variable.value
- self.content = file_manager.download(file)
+ self.content = self._file_manager.download(file)
case "x-www-form-urlencoded":
form_data = {
self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template(
@@ -239,7 +244,7 @@ class Executor:
):
file_tuple = (
file.filename,
- file_manager.download(file),
+ self._file_manager.download(file),
file.mime_type or "application/octet-stream",
)
if key not in files:
@@ -332,19 +337,18 @@ class Executor:
do http request depending on api bundle
"""
_METHOD_MAP = {
- "get": ssrf_proxy.get,
- "head": ssrf_proxy.head,
- "post": ssrf_proxy.post,
- "put": ssrf_proxy.put,
- "delete": ssrf_proxy.delete,
- "patch": ssrf_proxy.patch,
+ "get": self._http_client.get,
+ "head": self._http_client.head,
+ "post": self._http_client.post,
+ "put": self._http_client.put,
+ "delete": self._http_client.delete,
+ "patch": self._http_client.patch,
}
method_lc = self.method.lower()
if method_lc not in _METHOD_MAP:
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
request_args = {
- "url": self.url,
"data": self.data,
"files": self.files,
"json": self.json,
@@ -357,8 +361,12 @@ class Executor:
}
# request_args = {k: v for k, v in request_args.items() if v is not None}
try:
- response: httpx.Response = _METHOD_MAP[method_lc](**request_args, max_retries=self.max_retries)
- except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
+ response: httpx.Response = _METHOD_MAP[method_lc](
+ url=self.url,
+ **request_args,
+ max_retries=self.max_retries,
+ )
+ except (self._http_client.max_retries_exceeded_error, self._http_client.request_error) as e:
raise HttpRequestNodeError(str(e)) from e
# FIXME: fix type ignore, this maybe httpx type issue
return response
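Because `Executor` now takes its HTTP client and file manager as constructor arguments, tests can inject doubles that expose the same surface as `ssrf_proxy` (the verb methods plus the two error attributes). A hedged sketch of such a double; the remaining `Executor(...)` arguments are omitted:

```python
import httpx

# Minimal stand-in satisfying the client surface the executor relies on.
class RecordingClient:
    request_error = httpx.RequestError
    max_retries_exceeded_error = RuntimeError  # stand-in for MaxRetriesExceededError

    def __init__(self) -> None:
        self.calls: list[tuple[str, str]] = []

    def post(self, url, max_retries=0, **kwargs):
        self.calls.append(("post", url))
        return httpx.Response(204, request=httpx.Request("POST", url))
    # get/head/put/delete/patch would follow the same shape in a fuller double.

fake = RecordingClient()
# Executor(node_data=..., timeout=..., variable_pool=..., http_client=fake, file_manager=...)
print(fake.post("https://internal.example/api").status_code, fake.calls)
```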
diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py
index 9bd1cb9761..964e53e03c 100644
--- a/api/core/workflow/nodes/http_request/node.py
+++ b/api/core/workflow/nodes/http_request/node.py
@@ -1,10 +1,11 @@
import logging
import mimetypes
-from collections.abc import Mapping, Sequence
-from typing import Any
+from collections.abc import Callable, Mapping, Sequence
+from typing import TYPE_CHECKING, Any
from configs import dify_config
-from core.file import File, FileTransferMethod
+from core.file import File, FileTransferMethod, file_manager
+from core.helper import ssrf_proxy
from core.tools.tool_file_manager import ToolFileManager
from core.variables.segments import ArrayFileSegment
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@@ -13,6 +14,7 @@ from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.http_request.executor import Executor
+from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
from factories import file_factory
from .entities import (
@@ -30,10 +32,35 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
logger = logging.getLogger(__name__)
+if TYPE_CHECKING:
+ from core.workflow.entities import GraphInitParams
+ from core.workflow.runtime import GraphRuntimeState
+
class HttpRequestNode(Node[HttpRequestNodeData]):
node_type = NodeType.HTTP_REQUEST
+ def __init__(
+ self,
+ id: str,
+ config: Mapping[str, Any],
+ graph_init_params: "GraphInitParams",
+ graph_runtime_state: "GraphRuntimeState",
+ *,
+ http_client: HttpClientProtocol = ssrf_proxy,
+ tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
+ file_manager: FileManagerProtocol = file_manager,
+ ) -> None:
+ super().__init__(
+ id=id,
+ config=config,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ )
+ self._http_client = http_client
+ self._tool_file_manager_factory = tool_file_manager_factory
+ self._file_manager = file_manager
+
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
@@ -71,6 +98,8 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
timeout=self._get_request_timeout(self.node_data),
variable_pool=self.graph_runtime_state.variable_pool,
max_retries=0,
+ http_client=self._http_client,
+ file_manager=self._file_manager,
)
process_data["request"] = http_executor.to_log()
@@ -199,7 +228,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
mime_type = (
content_disposition_type or content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
)
- tool_file_manager = ToolFileManager()
+ tool_file_manager = self._tool_file_manager_factory()
tool_file = tool_file_manager.create_file_by_raw(
user_id=self.user_id,
diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py
index e5d86414c1..569a4196fb 100644
--- a/api/core/workflow/nodes/iteration/iteration_node.py
+++ b/api/core/workflow/nodes/iteration/iteration_node.py
@@ -1,17 +1,15 @@
-import contextvars
import logging
from collections.abc import Generator, Mapping, Sequence
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, NewType, cast
-from flask import Flask, current_app
from typing_extensions import TypeIs
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
-from core.variables.variables import VariableUnion
+from core.variables.variables import Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.enums import (
NodeExecutionType,
@@ -39,7 +37,6 @@ from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from core.workflow.runtime import VariablePool
from libs.datetime_utils import naive_utc_now
-from libs.flask_utils import preserve_flask_contexts
from .exc import (
InvalidIteratorValueError,
@@ -51,6 +48,7 @@ from .exc import (
)
if TYPE_CHECKING:
+ from core.workflow.context import IExecutionContext
from core.workflow.graph_engine import GraphEngine
logger = logging.getLogger(__name__)
@@ -240,7 +238,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
datetime,
list[GraphNodeEventBase],
object | None,
- dict[str, VariableUnion],
+ dict[str, Variable],
LLMUsage,
]
],
@@ -252,8 +250,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
self._execute_single_iteration_parallel,
index=index,
item=item,
- flask_app=current_app._get_current_object(), # type: ignore
- context_vars=contextvars.copy_context(),
+ execution_context=self._capture_execution_context(),
)
future_to_index[future] = index
@@ -306,11 +303,10 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
self,
index: int,
item: object,
- flask_app: Flask,
- context_vars: contextvars.Context,
- ) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, VariableUnion], LLMUsage]:
+ execution_context: "IExecutionContext",
+ ) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
"""Execute a single iteration in parallel mode and return results."""
- with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
+ with execution_context:
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
events: list[GraphNodeEventBase] = []
outputs_temp: list[object] = []
@@ -339,6 +335,12 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
graph_engine.graph_runtime_state.llm_usage,
)
+ def _capture_execution_context(self) -> "IExecutionContext":
+ """Capture current execution context for parallel iterations."""
+ from core.workflow.context import capture_current_context
+
+ return capture_current_context()
+
def _handle_iteration_success(
self,
started_at: datetime,
@@ -515,11 +517,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
return variable_mapping
- def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, VariableUnion]:
+ def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, Variable]:
conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()}
- def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, VariableUnion]) -> None:
+ def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, Variable]) -> None:
parent_pool = self.graph_runtime_state.variable_pool
parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
diff --git a/api/core/workflow/nodes/node_factory.py b/api/core/workflow/nodes/node_factory.py
index 557d3a330d..5c04e5110f 100644
--- a/api/core/workflow/nodes/node_factory.py
+++ b/api/core/workflow/nodes/node_factory.py
@@ -1,16 +1,21 @@
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, final
from typing_extensions import override
from configs import dify_config
+from core.file import file_manager
+from core.helper import ssrf_proxy
from core.helper.code_executor.code_executor import CodeExecutor
from core.helper.code_executor.code_node_provider import CodeNodeProvider
+from core.tools.tool_file_manager import ToolFileManager
from core.workflow.enums import NodeType
from core.workflow.graph import NodeFactory
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.limits import CodeNodeLimits
+from core.workflow.nodes.http_request.node import HttpRequestNode
+from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
from core.workflow.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
Jinja2TemplateRenderer,
@@ -43,6 +48,9 @@ class DifyNodeFactory(NodeFactory):
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
code_limits: CodeNodeLimits | None = None,
template_renderer: Jinja2TemplateRenderer | None = None,
+ http_request_http_client: HttpClientProtocol = ssrf_proxy,
+ http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
+ http_request_file_manager: FileManagerProtocol = file_manager,
) -> None:
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
@@ -61,6 +69,9 @@ class DifyNodeFactory(NodeFactory):
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
+ self._http_request_http_client = http_request_http_client
+ self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
+ self._http_request_file_manager = http_request_file_manager
@override
def create_node(self, node_config: dict[str, object]) -> Node:
@@ -113,6 +124,7 @@ class DifyNodeFactory(NodeFactory):
code_providers=self._code_providers,
code_limits=self._code_limits,
)
+
if node_type == NodeType.TEMPLATE_TRANSFORM:
return TemplateTransformNode(
id=node_id,
@@ -122,6 +134,17 @@ class DifyNodeFactory(NodeFactory):
template_renderer=self._template_renderer,
)
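+ # HTTP request nodes receive their collaborators (HTTP client, tool file manager factory,
+ # file manager) from the factory; the defaults set in __init__ are ssrf_proxy,
+ # ToolFileManager, and file_manager, so callers and tests can substitute their own.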
+ if node_type == NodeType.HTTP_REQUEST:
+ return HttpRequestNode(
+ id=node_id,
+ config=node_config,
+ graph_init_params=self.graph_init_params,
+ graph_runtime_state=self.graph_runtime_state,
+ http_client=self._http_request_http_client,
+ tool_file_manager_factory=self._http_request_tool_file_manager_factory,
+ file_manager=self._http_request_file_manager,
+ )
+
return node_class(
id=node_id,
config=node_config,
diff --git a/api/core/workflow/nodes/protocols.py b/api/core/workflow/nodes/protocols.py
new file mode 100644
index 0000000000..e7dcf62fcf
--- /dev/null
+++ b/api/core/workflow/nodes/protocols.py
@@ -0,0 +1,29 @@
+from typing import Protocol
+
+import httpx
+
+from core.file import File
+
+
+class HttpClientProtocol(Protocol):
+ @property
+ def max_retries_exceeded_error(self) -> type[Exception]: ...
+
+ @property
+ def request_error(self) -> type[Exception]: ...
+
+ def get(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
+
+ def head(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
+
+ def post(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
+
+ def put(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
+
+ def delete(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
+
+ def patch(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
+
+
+class FileManagerProtocol(Protocol):
+ def download(self, f: File, /) -> bytes: ...
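+
+# Both protocols are structural: the defaults wired into DifyNodeFactory (the ssrf_proxy
+# and file_manager modules) satisfy them without inheriting from them, and tests can pass
+# any object that exposes the same methods.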
diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py
index 36fc5078c5..53c1b4ee6b 100644
--- a/api/core/workflow/nodes/start/start_node.py
+++ b/api/core/workflow/nodes/start/start_node.py
@@ -1,4 +1,3 @@
-import json
from typing import Any
from jsonschema import Draft7Validator, ValidationError
@@ -43,25 +42,22 @@ class StartNode(Node[StartNodeData]):
if value is None and variable.required:
raise ValueError(f"{key} is required in input form")
+ # If no value provided, skip further processing for this key
+ if not value:
+ continue
+
+ if not isinstance(value, dict):
+ raise ValueError(f"JSON object for '{key}' must be an object")
+
+ # Overwrite with normalized dict to ensure downstream consistency
+ node_inputs[key] = value
+
+ # If a schema is defined, validate the value against it
schema = variable.json_schema
if not schema:
continue
- if not value:
- continue
-
try:
- json_schema = json.loads(schema)
- except json.JSONDecodeError as e:
- raise ValueError(f"{schema} must be a valid JSON object")
-
- try:
- json_value = json.loads(value)
- except json.JSONDecodeError as e:
- raise ValueError(f"{value} must be a valid JSON object")
-
- try:
- Draft7Validator(json_schema).validate(json_value)
+ Draft7Validator(schema).validate(value)
except ValidationError as e:
raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}")
- node_inputs[key] = json_value
diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py
index d2ea7d94ea..9f5818f4bb 100644
--- a/api/core/workflow/nodes/variable_assigner/v1/node.py
+++ b/api/core/workflow/nodes/variable_assigner/v1/node.py
@@ -1,7 +1,7 @@
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
-from core.variables import SegmentType, Variable
+from core.variables import SegmentType, VariableBase
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@@ -33,6 +33,15 @@ class VariableAssignerNode(Node[VariableAssignerData]):
graph_runtime_state=graph_runtime_state,
)
+ def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
+ """
+ Check if this Variable Assigner node blocks the output of specific variables.
+
+ Returns True if this node updates any of the requested conversation variables.
+ """
+ assigned_selector = tuple(self.node_data.assigned_variable_selector)
+ return assigned_selector in variable_selectors
+
@classmethod
def version(cls) -> str:
return "1"
@@ -64,7 +73,7 @@ class VariableAssignerNode(Node[VariableAssignerData]):
assigned_variable_selector = self.node_data.assigned_variable_selector
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
- if not isinstance(original_variable, Variable):
+ if not isinstance(original_variable, VariableBase):
raise VariableOperatorNodeError("assigned variable not found")
match self.node_data.write_mode:
diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py
index 486e6bb6a7..5857702e72 100644
--- a/api/core/workflow/nodes/variable_assigner/v2/node.py
+++ b/api/core/workflow/nodes/variable_assigner/v2/node.py
@@ -2,7 +2,7 @@ import json
from collections.abc import Mapping, MutableMapping, Sequence
from typing import TYPE_CHECKING, Any
-from core.variables import SegmentType, Variable
+from core.variables import SegmentType, VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@@ -118,7 +118,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
# ==================== Validation Part
# Check if variable exists
- if not isinstance(variable, Variable):
+ if not isinstance(variable, VariableBase):
raise VariableNotFoundError(variable_selector=item.variable_selector)
# Check if operation is supported
@@ -192,7 +192,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
for selector in updated_variable_selectors:
variable = self.graph_runtime_state.variable_pool.get(selector)
- if not isinstance(variable, Variable):
+ if not isinstance(variable, VariableBase):
raise VariableNotFoundError(variable_selector=selector)
process_data[variable.name] = variable.value
@@ -213,7 +213,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
def _handle_item(
self,
*,
- variable: Variable,
+ variable: VariableBase,
operation: Operation,
value: Any,
):
diff --git a/api/core/workflow/runtime/variable_pool.py b/api/core/workflow/runtime/variable_pool.py
index 85ceb9d59e..d205c6ac8f 100644
--- a/api/core/workflow/runtime/variable_pool.py
+++ b/api/core/workflow/runtime/variable_pool.py
@@ -9,10 +9,10 @@ from typing import Annotated, Any, Union, cast
from pydantic import BaseModel, Field
from core.file import File, FileAttribute, file_manager
-from core.variables import Segment, SegmentGroup, Variable
+from core.variables import Segment, SegmentGroup, VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import FileSegment, ObjectSegment
-from core.variables.variables import RAGPipelineVariableInput, VariableUnion
+from core.variables.variables import RAGPipelineVariableInput, Variable
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
@@ -32,7 +32,7 @@ class VariablePool(BaseModel):
# The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one.
- variable_dictionary: defaultdict[str, Annotated[dict[str, VariableUnion], Field(default_factory=dict)]] = Field(
+ variable_dictionary: defaultdict[str, Annotated[dict[str, Variable], Field(default_factory=dict)]] = Field(
description="Variables mapping",
default=defaultdict(dict),
)
@@ -46,13 +46,13 @@ class VariablePool(BaseModel):
description="System variables",
default_factory=SystemVariable.empty,
)
- environment_variables: Sequence[VariableUnion] = Field(
+ environment_variables: Sequence[Variable] = Field(
description="Environment variables.",
- default_factory=list[VariableUnion],
+ default_factory=list[Variable],
)
- conversation_variables: Sequence[VariableUnion] = Field(
+ conversation_variables: Sequence[Variable] = Field(
description="Conversation variables.",
- default_factory=list[VariableUnion],
+ default_factory=list[Variable],
)
rag_pipeline_variables: list[RAGPipelineVariableInput] = Field(
description="RAG pipeline variables.",
@@ -105,7 +105,7 @@ class VariablePool(BaseModel):
f"got {len(selector)} elements"
)
- if isinstance(value, Variable):
+ if isinstance(value, VariableBase):
variable = value
elif isinstance(value, Segment):
variable = variable_factory.segment_to_variable(segment=value, selector=selector)
@@ -114,9 +114,9 @@ class VariablePool(BaseModel):
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
node_id, name = self._selector_to_keys(selector)
- # Based on the definition of `VariableUnion`,
- # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
- self.variable_dictionary[node_id][name] = cast(VariableUnion, variable)
+ # Based on the definition of `Variable`,
+ # `VariableBase` instances can be safely used as `Variable` since they are compatible.
+ self.variable_dictionary[node_id][name] = cast(Variable, variable)
@classmethod
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]:
diff --git a/api/core/workflow/variable_loader.py b/api/core/workflow/variable_loader.py
index ea0bdc3537..7992785fe1 100644
--- a/api/core/workflow/variable_loader.py
+++ b/api/core/workflow/variable_loader.py
@@ -2,7 +2,7 @@ import abc
from collections.abc import Mapping, Sequence
from typing import Any, Protocol
-from core.variables import Variable
+from core.variables import VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.runtime import VariablePool
@@ -26,7 +26,7 @@ class VariableLoader(Protocol):
"""
@abc.abstractmethod
- def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
+ def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
"""Load variables based on the provided selectors. If the selectors are empty,
this method should return an empty list.
@@ -36,7 +36,7 @@ class VariableLoader(Protocol):
:param: selectors: a list of string list, each inner list should have at least two elements:
- the first element is the node ID,
- the second element is the variable name.
- :return: a list of Variable objects that match the provided selectors.
+ :return: a list of VariableBase objects that match the provided selectors.
"""
pass
@@ -46,7 +46,7 @@ class _DummyVariableLoader(VariableLoader):
Serves as a placeholder when no variable loading is needed.
"""
- def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
+ def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
return []
diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py
index ddf545bb34..ee37314721 100644
--- a/api/core/workflow/workflow_entry.py
+++ b/api/core/workflow/workflow_entry.py
@@ -19,6 +19,7 @@ from core.workflow.graph_engine.protocols.command_channel import CommandChannel
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
from core.workflow.nodes import NodeType
from core.workflow.nodes.base.node import Node
+from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
@@ -136,13 +137,11 @@ class WorkflowEntry:
:param user_inputs: user inputs
:return:
"""
- node_config = workflow.get_node_config_by_id(node_id)
+ node_config = dict(workflow.get_node_config_by_id(node_id))
node_config_data = node_config.get("data", {})
- # Get node class
+ # Get node type
node_type = NodeType(node_config_data.get("type"))
- node_version = node_config_data.get("version", "1")
- node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
# init graph init params and runtime state
graph_init_params = GraphInitParams(
@@ -158,12 +157,12 @@ class WorkflowEntry:
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init workflow run state
- node = node_cls(
- id=str(uuid.uuid4()),
- config=node_config,
+ node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
+ node = node_factory.create_node(node_config)
+ node_cls = type(node)
try:
# variable selector to variable mapping
@@ -190,8 +189,7 @@ class WorkflowEntry:
)
try:
- # run node
- generator = node.run()
+ generator = cls._traced_node_run(node)
except Exception as e:
logger.exception(
"error while running node, workflow_id=%s, node_id=%s, node_type=%s, node_version=%s",
@@ -324,8 +322,7 @@ class WorkflowEntry:
tenant_id=tenant_id,
)
- # run node
- generator = node.run()
+ generator = cls._traced_node_run(node)
return node, generator
except Exception as e:
@@ -431,3 +428,26 @@ class WorkflowEntry:
input_value = current_variable.value | input_value
variable_pool.add([variable_node_id] + variable_key_list, input_value)
+
+ @staticmethod
+ def _traced_node_run(node: Node) -> Generator[GraphNodeEventBase, None, None]:
+ """
+ Wraps a node's run method with OpenTelemetry tracing and returns a generator.
+ """
+ # Wrap node.run() with ObservabilityLayer hooks to produce node-level spans
+ layer = ObservabilityLayer()
+ layer.on_graph_start()
+ node.ensure_execution_id()
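+ # on_node_run_end fires from the generator's finally block, i.e. once the caller
+ # exhausts or closes the generator, not when _traced_node_run itself returns.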
+
+ def _gen():
+ error: Exception | None = None
+ layer.on_node_run_start(node)
+ try:
+ yield from node.run()
+ except Exception as exc:
+ error = exc
+ raise
+ finally:
+ layer.on_node_run_end(node, error)
+
+ return _gen()
diff --git a/api/events/event_handlers/__init__.py b/api/events/event_handlers/__init__.py
index c79764983b..d37217e168 100644
--- a/api/events/event_handlers/__init__.py
+++ b/api/events/event_handlers/__init__.py
@@ -6,6 +6,7 @@ from .create_site_record_when_app_created import handle as handle_create_site_re
from .delete_tool_parameters_cache_when_sync_draft_workflow import (
handle as handle_delete_tool_parameters_cache_when_sync_draft_workflow,
)
+from .queue_credential_sync_when_tenant_created import handle as handle_queue_credential_sync_when_tenant_created
from .sync_plugin_trigger_when_app_created import handle as handle_sync_plugin_trigger_when_app_created
from .sync_webhook_when_app_created import handle as handle_sync_webhook_when_app_created
from .sync_workflow_schedule_when_app_published import handle as handle_sync_workflow_schedule_when_app_published
@@ -30,6 +31,7 @@ __all__ = [
"handle_create_installed_app_when_app_created",
"handle_create_site_record_when_app_created",
"handle_delete_tool_parameters_cache_when_sync_draft_workflow",
+ "handle_queue_credential_sync_when_tenant_created",
"handle_sync_plugin_trigger_when_app_created",
"handle_sync_webhook_when_app_created",
"handle_sync_workflow_schedule_when_app_published",
diff --git a/api/events/event_handlers/queue_credential_sync_when_tenant_created.py b/api/events/event_handlers/queue_credential_sync_when_tenant_created.py
new file mode 100644
index 0000000000..6566c214b0
--- /dev/null
+++ b/api/events/event_handlers/queue_credential_sync_when_tenant_created.py
@@ -0,0 +1,19 @@
+from configs import dify_config
+from events.tenant_event import tenant_was_created
+from services.enterprise.workspace_sync import WorkspaceSyncService
+
+
+@tenant_was_created.connect
+def handle(sender, **kwargs):
+ """Queue credential sync when a tenant/workspace is created."""
+ # Only queue sync tasks when enterprise features are enabled
+ if not dify_config.ENTERPRISE_ENABLED:
+ return
+
+ tenant = sender
+
+ # Determine the source from kwargs if available, otherwise fall back to a generic default
+ source = kwargs.get("source", "tenant_created")
+
+ # Queue the credential sync task in Redis for the enterprise backend to process
+ WorkspaceSyncService.queue_credential_sync(tenant.id, source=source)
diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py
index 2fbab001d0..08cf96c1c1 100644
--- a/api/extensions/ext_celery.py
+++ b/api/extensions/ext_celery.py
@@ -163,6 +163,13 @@ def init_app(app: DifyApp) -> Celery:
"task": "schedule.clean_workflow_runlogs_precise.clean_workflow_runlogs_precise",
"schedule": crontab(minute="0", hour="2"),
}
+ if dify_config.ENABLE_WORKFLOW_RUN_CLEANUP_TASK:
+ # For SaaS deployments only
+ imports.append("schedule.clean_workflow_runs_task")
+ beat_schedule["clean_workflow_runs_task"] = {
+ "task": "schedule.clean_workflow_runs_task.clean_workflow_runs_task",
+ "schedule": crontab(minute="0", hour="0"),
+ }
if dify_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK:
imports.append("schedule.workflow_schedule_task")
beat_schedule["workflow_schedule_task"] = {
diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py
index daa3756dba..51e2c6cdd5 100644
--- a/api/extensions/ext_commands.py
+++ b/api/extensions/ext_commands.py
@@ -4,6 +4,8 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
from commands import (
add_qdrant_index,
+ clean_expired_messages,
+ clean_workflow_runs,
cleanup_orphaned_draft_variables,
clear_free_plan_tenant_expired_logs,
clear_orphaned_file_records,
@@ -56,6 +58,8 @@ def init_app(app: DifyApp):
setup_datasource_oauth_client,
transform_datasource_credentials,
install_rag_pipeline_plugins,
+ clean_workflow_runs,
+ clean_expired_messages,
]
for cmd in cmds_to_register:
app.cli.add_command(cmd)
diff --git a/api/extensions/ext_logstore.py b/api/extensions/ext_logstore.py
index 502f0bb46b..cda2d1ad1e 100644
--- a/api/extensions/ext_logstore.py
+++ b/api/extensions/ext_logstore.py
@@ -10,6 +10,7 @@ import os
from dotenv import load_dotenv
+from configs import dify_config
from dify_app import DifyApp
logger = logging.getLogger(__name__)
@@ -19,12 +20,17 @@ def is_enabled() -> bool:
"""
Check if logstore extension is enabled.
+ Logstore is considered enabled when:
+ 1. All required Aliyun SLS environment variables are set
+ 2. At least one repository configuration points to a logstore implementation
+
Returns:
- True if all required Aliyun SLS environment variables are set, False otherwise
+ True if logstore should be initialized, False otherwise
"""
# Load environment variables from .env file
load_dotenv()
+ # Check if Aliyun SLS connection parameters are configured
required_vars = [
"ALIYUN_SLS_ACCESS_KEY_ID",
"ALIYUN_SLS_ACCESS_KEY_SECRET",
@@ -33,24 +39,32 @@ def is_enabled() -> bool:
"ALIYUN_SLS_PROJECT_NAME",
]
- all_set = all(os.environ.get(var) for var in required_vars)
+ sls_vars_set = all(os.environ.get(var) for var in required_vars)
- if not all_set:
- logger.info("Logstore extension disabled: required Aliyun SLS environment variables not set")
+ if not sls_vars_set:
+ return False
- return all_set
+ # Check if any repository configuration points to logstore implementation
+ repository_configs = [
+ dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY,
+ dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY,
+ dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY,
+ dify_config.API_WORKFLOW_RUN_REPOSITORY,
+ ]
+
+ uses_logstore = any("logstore" in config.lower() for config in repository_configs)
+
+ if not uses_logstore:
+ return False
+
+ logger.info("Logstore extension enabled: SLS variables set and repository configured to use logstore")
+ return True
def init_app(app: DifyApp):
"""
Initialize logstore on application startup.
-
- This function:
- 1. Creates Aliyun SLS project if it doesn't exist
- 2. Creates logstores (workflow_execution, workflow_node_execution) if they don't exist
- 3. Creates indexes with field configurations based on PostgreSQL table structures
-
- This operation is idempotent and only executes once during application startup.
+ If initialization fails, the application continues running without logstore features.
Args:
app: The Dify application instance
@@ -58,17 +72,23 @@ def init_app(app: DifyApp):
try:
from extensions.logstore.aliyun_logstore import AliyunLogStore
- logger.info("Initializing logstore...")
+ logger.info("Initializing Aliyun SLS Logstore...")
- # Create logstore client and initialize project/logstores/indexes
+ # Create logstore client and initialize resources
logstore_client = AliyunLogStore()
logstore_client.init_project_logstore()
- # Attach to app for potential later use
app.extensions["logstore"] = logstore_client
logger.info("Logstore initialized successfully")
+
except Exception:
- logger.exception("Failed to initialize logstore")
- # Don't raise - allow application to continue even if logstore init fails
- # This ensures that the application can still run if logstore is misconfigured
+ logger.exception(
+ "Logstore initialization failed. Configuration: endpoint=%s, region=%s, project=%s, timeout=%ss. "
+ "Application will continue but logstore features will NOT work.",
+ os.environ.get("ALIYUN_SLS_ENDPOINT"),
+ os.environ.get("ALIYUN_SLS_REGION"),
+ os.environ.get("ALIYUN_SLS_PROJECT_NAME"),
+ os.environ.get("ALIYUN_SLS_CHECK_CONNECTIVITY_TIMEOUT", "30"),
+ )
+ # Don't raise - allow application to continue even if logstore setup fails
diff --git a/api/extensions/logstore/aliyun_logstore.py b/api/extensions/logstore/aliyun_logstore.py
index 8c64a25be4..f6a4765f14 100644
--- a/api/extensions/logstore/aliyun_logstore.py
+++ b/api/extensions/logstore/aliyun_logstore.py
@@ -2,6 +2,7 @@ from __future__ import annotations
import logging
import os
+import socket
import threading
import time
from collections.abc import Sequence
@@ -179,9 +180,18 @@ class AliyunLogStore:
self.region: str = os.environ.get("ALIYUN_SLS_REGION", "")
self.project_name: str = os.environ.get("ALIYUN_SLS_PROJECT_NAME", "")
self.logstore_ttl: int = int(os.environ.get("ALIYUN_SLS_LOGSTORE_TTL", 365))
- self.log_enabled: bool = os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true"
+ self.log_enabled: bool = (
+ os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true"
+ or os.environ.get("LOGSTORE_SQL_ECHO", "false").lower() == "true"
+ )
self.pg_mode_enabled: bool = os.environ.get("LOGSTORE_PG_MODE_ENABLED", "true").lower() == "true"
+ # Get timeout configuration
+ check_timeout = int(os.environ.get("ALIYUN_SLS_CHECK_CONNECTIVITY_TIMEOUT", 30))
+
+ # Pre-check endpoint connectivity to prevent indefinite hangs
+ self._check_endpoint_connectivity(self.endpoint, check_timeout)
+
# Initialize SDK client
self.client = LogClient(
self.endpoint, self.access_key_id, self.access_key_secret, auth_version=AUTH_VERSION_4, region=self.region
@@ -199,6 +209,49 @@ class AliyunLogStore:
self.__class__._initialized = True
+ @staticmethod
+ def _check_endpoint_connectivity(endpoint: str, timeout: int) -> None:
+ """
+ Check if the SLS endpoint is reachable before creating LogClient.
+ Prevents indefinite hangs when the endpoint is unreachable.
+
+ Args:
+ endpoint: SLS endpoint URL
+ timeout: Connection timeout in seconds
+
+ Raises:
+ ConnectionError: If endpoint is not reachable
+ """
+ # Parse endpoint URL to extract hostname and port
+ from urllib.parse import urlparse
+
+ parsed_url = urlparse(endpoint if "://" in endpoint else f"http://{endpoint}")
+ hostname = parsed_url.hostname
+ port = parsed_url.port or (443 if parsed_url.scheme == "https" else 80)
+
+ if not hostname:
+ raise ConnectionError(f"Invalid endpoint URL: {endpoint}")
+
+ sock = None
+ try:
+ # Create socket and set timeout
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.settimeout(timeout)
+ sock.connect((hostname, port))
+ except Exception as e:
+ # Catch all exceptions and provide clear error message
+ error_type = type(e).__name__
+ raise ConnectionError(
+ f"Cannot connect to {hostname}:{port} (timeout={timeout}s): [{error_type}] {e}"
+ ) from e
+ finally:
+ # Ensure socket is properly closed
+ if sock:
+ try:
+ sock.close()
+ except Exception: # noqa: S110
+ pass # Ignore errors during cleanup
+
@property
def supports_pg_protocol(self) -> bool:
"""Check if PG protocol is supported and enabled."""
@@ -220,19 +273,16 @@ class AliyunLogStore:
try:
self._use_pg_protocol = self._pg_client.init_connection()
if self._use_pg_protocol:
- logger.info("Successfully connected to project %s using PG protocol", self.project_name)
+ logger.info("Using PG protocol for project %s", self.project_name)
# Check if scan_index is enabled for all logstores
self._check_and_disable_pg_if_scan_index_disabled()
return True
else:
- logger.info("PG connection failed for project %s. Will use SDK mode.", self.project_name)
+ logger.info("Using SDK mode for project %s", self.project_name)
return False
except Exception as e:
- logger.warning(
- "Failed to establish PG connection for project %s: %s. Will use SDK mode.",
- self.project_name,
- str(e),
- )
+ logger.info("Using SDK mode for project %s", self.project_name)
+ logger.debug("PG connection details: %s", str(e))
self._use_pg_protocol = False
return False
@@ -246,10 +296,6 @@ class AliyunLogStore:
if self._use_pg_protocol:
return
- logger.info(
- "Attempting delayed PG connection for newly created project %s ...",
- self.project_name,
- )
self._attempt_pg_connection_init()
self.__class__._pg_connection_timer = None
@@ -284,11 +330,7 @@ class AliyunLogStore:
if project_is_new:
# For newly created projects, schedule delayed PG connection
self._use_pg_protocol = False
- logger.info(
- "Project %s is newly created. Will use SDK mode and schedule PG connection attempt in %d seconds.",
- self.project_name,
- self.__class__._pg_connection_delay,
- )
+ logger.info("Using SDK mode for project %s (newly created)", self.project_name)
if self.__class__._pg_connection_timer is not None:
self.__class__._pg_connection_timer.cancel()
self.__class__._pg_connection_timer = threading.Timer(
@@ -299,7 +341,6 @@ class AliyunLogStore:
self.__class__._pg_connection_timer.start()
else:
# For existing projects, attempt PG connection immediately
- logger.info("Project %s already exists. Attempting PG connection...", self.project_name)
self._attempt_pg_connection_init()
def _check_and_disable_pg_if_scan_index_disabled(self) -> None:
@@ -318,9 +359,9 @@ class AliyunLogStore:
existing_config = self.get_existing_index_config(logstore_name)
if existing_config and not existing_config.scan_index:
logger.info(
- "Logstore %s has scan_index=false, USE SDK mode for read/write operations. "
- "PG protocol requires scan_index to be enabled.",
+ "Logstore %s requires scan_index enabled, using SDK mode for project %s",
logstore_name,
+ self.project_name,
)
self._use_pg_protocol = False
# Close PG connection if it was initialized
@@ -748,7 +789,6 @@ class AliyunLogStore:
reverse=reverse,
)
- # Log query info if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore] GET_LOGS | logstore=%s | project=%s | query=%s | "
@@ -770,7 +810,6 @@ class AliyunLogStore:
for log in logs:
result.append(log.get_contents())
- # Log result count if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore] GET_LOGS RESULT | logstore=%s | returned_count=%d",
@@ -845,7 +884,6 @@ class AliyunLogStore:
query=full_query,
)
- # Log query info if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore-SDK] EXECUTE_SQL | logstore=%s | project=%s | from_time=%d | to_time=%d | full_query=%s",
@@ -853,8 +891,7 @@ class AliyunLogStore:
self.project_name,
from_time,
to_time,
- query,
- sql,
+ full_query,
)
try:
@@ -865,7 +902,6 @@ class AliyunLogStore:
for log in logs:
result.append(log.get_contents())
- # Log result count if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore-SDK] EXECUTE_SQL RESULT | logstore=%s | returned_count=%d",
diff --git a/api/extensions/logstore/aliyun_logstore_pg.py b/api/extensions/logstore/aliyun_logstore_pg.py
index 35aa51ce53..874c20d144 100644
--- a/api/extensions/logstore/aliyun_logstore_pg.py
+++ b/api/extensions/logstore/aliyun_logstore_pg.py
@@ -7,8 +7,7 @@ from contextlib import contextmanager
from typing import Any
import psycopg2
-import psycopg2.pool
-from psycopg2 import InterfaceError, OperationalError
+from sqlalchemy import create_engine
from configs import dify_config
@@ -16,11 +15,7 @@ logger = logging.getLogger(__name__)
class AliyunLogStorePG:
- """
- PostgreSQL protocol support for Aliyun SLS LogStore.
-
- Handles PG connection pooling and operations for regions that support PG protocol.
- """
+ """PostgreSQL protocol support for Aliyun SLS LogStore using SQLAlchemy connection pool."""
def __init__(self, access_key_id: str, access_key_secret: str, endpoint: str, project_name: str):
"""
@@ -36,24 +31,11 @@ class AliyunLogStorePG:
self._access_key_secret = access_key_secret
self._endpoint = endpoint
self.project_name = project_name
- self._pg_pool: psycopg2.pool.SimpleConnectionPool | None = None
+ self._engine: Any = None # SQLAlchemy Engine
self._use_pg_protocol = False
def _check_port_connectivity(self, host: str, port: int, timeout: float = 2.0) -> bool:
- """
- Check if a TCP port is reachable using socket connection.
-
- This provides a fast check before attempting full database connection,
- preventing long waits when connecting to unsupported regions.
-
- Args:
- host: Hostname or IP address
- port: Port number
- timeout: Connection timeout in seconds (default: 2.0)
-
- Returns:
- True if port is reachable, False otherwise
- """
+ """Fast TCP port check to avoid long waits on unsupported regions."""
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(timeout)
@@ -65,166 +47,101 @@ class AliyunLogStorePG:
return False
def init_connection(self) -> bool:
- """
- Initialize PostgreSQL connection pool for SLS PG protocol support.
-
- Attempts to connect to SLS using PostgreSQL protocol. If successful, sets
- _use_pg_protocol to True and creates a connection pool. If connection fails
- (region doesn't support PG protocol or other errors), returns False.
-
- Returns:
- True if PG protocol is supported and initialized, False otherwise
- """
+ """Initialize SQLAlchemy connection pool with pool_recycle and TCP keepalive support."""
try:
- # Extract hostname from endpoint (remove protocol if present)
pg_host = self._endpoint.replace("http://", "").replace("https://", "")
- # Get pool configuration
- pg_max_connections = int(os.environ.get("ALIYUN_SLS_PG_MAX_CONNECTIONS", 10))
+ # Pool configuration
+ pool_size = int(os.environ.get("ALIYUN_SLS_PG_POOL_SIZE", 5))
+ max_overflow = int(os.environ.get("ALIYUN_SLS_PG_MAX_OVERFLOW", 5))
+ pool_recycle = int(os.environ.get("ALIYUN_SLS_PG_POOL_RECYCLE", 3600))
+ pool_pre_ping = os.environ.get("ALIYUN_SLS_PG_POOL_PRE_PING", "false").lower() == "true"
- logger.debug(
- "Check PG protocol connection to SLS: host=%s, project=%s",
- pg_host,
- self.project_name,
- )
+ logger.debug("Check PG protocol connection to SLS: host=%s, project=%s", pg_host, self.project_name)
- # Fast port connectivity check before attempting full connection
- # This prevents long waits when connecting to unsupported regions
+ # Fast port check to avoid long waits
if not self._check_port_connectivity(pg_host, 5432, timeout=1.0):
- logger.info(
- "USE SDK mode for read/write operations, host=%s",
- pg_host,
- )
+ logger.debug("Using SDK mode for host=%s", pg_host)
return False
- # Create connection pool
- self._pg_pool = psycopg2.pool.SimpleConnectionPool(
- minconn=1,
- maxconn=pg_max_connections,
- host=pg_host,
- port=5432,
- database=self.project_name,
- user=self._access_key_id,
- password=self._access_key_secret,
- sslmode="require",
- connect_timeout=5,
- application_name=f"Dify-{dify_config.project.version}",
+ # Build connection URL
+ from urllib.parse import quote_plus
+
+ username = quote_plus(self._access_key_id)
+ password = quote_plus(self._access_key_secret)
+ database_url = (
+ f"postgresql+psycopg2://{username}:{password}@{pg_host}:5432/{self.project_name}?sslmode=require"
)
- # Note: Skip test query because SLS PG protocol only supports SELECT/INSERT on actual tables
- # Connection pool creation success already indicates connectivity
+ # Create SQLAlchemy engine with connection pool
+ self._engine = create_engine(
+ database_url,
+ pool_size=pool_size,
+ max_overflow=max_overflow,
+ pool_recycle=pool_recycle,
+ pool_pre_ping=pool_pre_ping,
+ pool_timeout=30,
+ connect_args={
+ "connect_timeout": 5,
+ "application_name": f"Dify-{dify_config.project.version}-fixautocommit",
+ "keepalives": 1,
+ "keepalives_idle": 60,
+ "keepalives_interval": 10,
+ "keepalives_count": 5,
+ },
+ )
self._use_pg_protocol = True
logger.info(
- "PG protocol initialized successfully for SLS project=%s. Will use PG for read/write operations.",
+ "PG protocol initialized for SLS project=%s (pool_size=%d, pool_recycle=%ds)",
self.project_name,
+ pool_size,
+ pool_recycle,
)
return True
except Exception as e:
- # PG connection failed - fallback to SDK mode
self._use_pg_protocol = False
- if self._pg_pool:
+ if self._engine:
try:
- self._pg_pool.closeall()
+ self._engine.dispose()
except Exception:
- logger.debug("Failed to close PG connection pool during cleanup, ignoring")
- self._pg_pool = None
+ logger.debug("Failed to dispose engine during cleanup, ignoring")
+ self._engine = None
- logger.info(
- "PG protocol connection failed (region may not support PG protocol): %s. "
- "Falling back to SDK mode for read/write operations.",
- str(e),
- )
- return False
-
- def _is_connection_valid(self, conn: Any) -> bool:
- """
- Check if a connection is still valid.
-
- Args:
- conn: psycopg2 connection object
-
- Returns:
- True if connection is valid, False otherwise
- """
- try:
- # Check if connection is closed
- if conn.closed:
- return False
-
- # Quick ping test - execute a lightweight query
- # For SLS PG protocol, we can't use SELECT 1 without FROM,
- # so we just check the connection status
- with conn.cursor() as cursor:
- cursor.execute("SELECT 1")
- cursor.fetchone()
- return True
- except Exception:
+ logger.debug("Using SDK mode for region: %s", str(e))
return False
@contextmanager
def _get_connection(self):
- """
- Context manager to get a PostgreSQL connection from the pool.
+ """Get connection from SQLAlchemy pool. Pool handles recycle, invalidation, and keepalive automatically."""
+ if not self._engine:
+ raise RuntimeError("SQLAlchemy engine is not initialized")
- Automatically validates and refreshes stale connections.
-
- Note: Aliyun SLS PG protocol does not support transactions, so we always
- use autocommit mode.
-
- Yields:
- psycopg2 connection object
-
- Raises:
- RuntimeError: If PG pool is not initialized
- """
- if not self._pg_pool:
- raise RuntimeError("PG connection pool is not initialized")
-
- conn = self._pg_pool.getconn()
+ connection = self._engine.raw_connection()
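+ # raw_connection() yields a pooled DBAPI (psycopg2) connection; close() in the finally
+ # block returns it to the pool rather than closing the underlying socket.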
try:
- # Validate connection and get a fresh one if needed
- if not self._is_connection_valid(conn):
- logger.debug("Connection is stale, marking as bad and getting a new one")
- # Mark connection as bad and get a new one
- self._pg_pool.putconn(conn, close=True)
- conn = self._pg_pool.getconn()
-
- # Aliyun SLS PG protocol does not support transactions, always use autocommit
- conn.autocommit = True
- yield conn
+ connection.autocommit = True # SLS PG protocol does not support transactions
+ yield connection
finally:
- # Return connection to pool (or close if it's bad)
- if self._is_connection_valid(conn):
- self._pg_pool.putconn(conn)
- else:
- self._pg_pool.putconn(conn, close=True)
+ connection.close()
def close(self) -> None:
- """Close the PostgreSQL connection pool."""
- if self._pg_pool:
+ """Dispose SQLAlchemy engine and close all connections."""
+ if self._engine:
try:
- self._pg_pool.closeall()
- logger.info("PG connection pool closed")
+ self._engine.dispose()
+ logger.info("SQLAlchemy engine disposed")
except Exception:
- logger.exception("Failed to close PG connection pool")
+ logger.exception("Failed to dispose engine")
def _is_retriable_error(self, error: Exception) -> bool:
- """
- Check if an error is retriable (connection-related issues).
-
- Args:
- error: Exception to check
-
- Returns:
- True if the error is retriable, False otherwise
- """
- # Retry on connection-related errors
- if isinstance(error, (OperationalError, InterfaceError)):
+ """Check if error is retriable (connection-related issues)."""
+ # Check for psycopg2 connection errors directly
+ if isinstance(error, (psycopg2.OperationalError, psycopg2.InterfaceError)):
return True
- # Check error message for specific connection issues
error_msg = str(error).lower()
retriable_patterns = [
"connection",
@@ -234,34 +151,18 @@ class AliyunLogStorePG:
"reset by peer",
"no route to host",
"network",
+ "operational error",
+ "interface error",
]
return any(pattern in error_msg for pattern in retriable_patterns)
def put_log(self, logstore: str, contents: Sequence[tuple[str, str]], log_enabled: bool = False) -> None:
- """
- Write log to SLS using PostgreSQL protocol with automatic retry.
-
- Note: SLS PG protocol only supports INSERT (not UPDATE). This uses append-only
- writes with log_version field for versioning, same as SDK implementation.
-
- Args:
- logstore: Name of the logstore table
- contents: List of (field_name, value) tuples
- log_enabled: Whether to enable logging
-
- Raises:
- psycopg2.Error: If database operation fails after all retries
- """
+ """Write log to SLS using INSERT with automatic retry (3 attempts with exponential backoff)."""
if not contents:
return
- # Extract field names and values from contents
fields = [field_name for field_name, _ in contents]
values = [value for _, value in contents]
-
- # Build INSERT statement with literal values
- # Note: Aliyun SLS PG protocol doesn't support parameterized queries,
- # so we need to use mogrify to safely create literal values
field_list = ", ".join([f'"{field}"' for field in fields])
if log_enabled:
@@ -272,67 +173,40 @@ class AliyunLogStorePG:
len(contents),
)
- # Retry configuration
max_retries = 3
- retry_delay = 0.1 # Start with 100ms
+ retry_delay = 0.1
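+ # Exponential backoff: sleeps 0.1s and then 0.2s between the three attempts; the final failure re-raises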
for attempt in range(max_retries):
try:
with self._get_connection() as conn:
with conn.cursor() as cursor:
- # Use mogrify to safely convert values to SQL literals
placeholders = ", ".join(["%s"] * len(fields))
values_literal = cursor.mogrify(f"({placeholders})", values).decode("utf-8")
insert_sql = f'INSERT INTO "{logstore}" ({field_list}) VALUES {values_literal}'
cursor.execute(insert_sql)
- # Success - exit retry loop
return
except psycopg2.Error as e:
- # Check if error is retriable
if not self._is_retriable_error(e):
- # Not a retriable error (e.g., data validation error), fail immediately
- logger.exception(
- "Failed to put logs to logstore %s via PG protocol (non-retriable error)",
- logstore,
- )
+ logger.exception("Failed to put logs to logstore %s (non-retriable error)", logstore)
raise
- # Retriable error - log and retry if we have attempts left
if attempt < max_retries - 1:
logger.warning(
- "Failed to put logs to logstore %s via PG protocol (attempt %d/%d): %s. Retrying...",
+ "Failed to put logs to logstore %s (attempt %d/%d): %s. Retrying...",
logstore,
attempt + 1,
max_retries,
str(e),
)
time.sleep(retry_delay)
- retry_delay *= 2 # Exponential backoff
+ retry_delay *= 2
else:
- # Last attempt failed
- logger.exception(
- "Failed to put logs to logstore %s via PG protocol after %d attempts",
- logstore,
- max_retries,
- )
+ logger.exception("Failed to put logs to logstore %s after %d attempts", logstore, max_retries)
raise
def execute_sql(self, sql: str, logstore: str, log_enabled: bool = False) -> list[dict[str, Any]]:
- """
- Execute SQL query using PostgreSQL protocol with automatic retry.
-
- Args:
- sql: SQL query string
- logstore: Name of the logstore (for logging purposes)
- log_enabled: Whether to enable logging
-
- Returns:
- List of result rows as dictionaries
-
- Raises:
- psycopg2.Error: If database operation fails after all retries
- """
+ """Execute SQL query with automatic retry (3 attempts with exponential backoff)."""
if log_enabled:
logger.info(
"[LogStore-PG] EXECUTE_SQL | logstore=%s | project=%s | sql=%s",
@@ -341,20 +215,16 @@ class AliyunLogStorePG:
sql,
)
- # Retry configuration
max_retries = 3
- retry_delay = 0.1 # Start with 100ms
+ retry_delay = 0.1
for attempt in range(max_retries):
try:
with self._get_connection() as conn:
with conn.cursor() as cursor:
cursor.execute(sql)
-
- # Get column names from cursor description
columns = [desc[0] for desc in cursor.description]
- # Fetch all results and convert to list of dicts
result = []
for row in cursor.fetchall():
row_dict = {}
@@ -372,36 +242,31 @@ class AliyunLogStorePG:
return result
except psycopg2.Error as e:
- # Check if error is retriable
if not self._is_retriable_error(e):
- # Not a retriable error (e.g., SQL syntax error), fail immediately
logger.exception(
- "Failed to execute SQL query on logstore %s via PG protocol (non-retriable error): sql=%s",
+ "Failed to execute SQL on logstore %s (non-retriable error): sql=%s",
logstore,
sql,
)
raise
- # Retriable error - log and retry if we have attempts left
if attempt < max_retries - 1:
logger.warning(
- "Failed to execute SQL query on logstore %s via PG protocol (attempt %d/%d): %s. Retrying...",
+ "Failed to execute SQL on logstore %s (attempt %d/%d): %s. Retrying...",
logstore,
attempt + 1,
max_retries,
str(e),
)
time.sleep(retry_delay)
- retry_delay *= 2 # Exponential backoff
+ retry_delay *= 2
else:
- # Last attempt failed
logger.exception(
- "Failed to execute SQL query on logstore %s via PG protocol after %d attempts: sql=%s",
+ "Failed to execute SQL on logstore %s after %d attempts: sql=%s",
logstore,
max_retries,
sql,
)
raise
- # This line should never be reached due to raise above, but makes type checker happy
return []
diff --git a/api/extensions/logstore/repositories/__init__.py b/api/extensions/logstore/repositories/__init__.py
index e69de29bb2..b5a4fcf844 100644
--- a/api/extensions/logstore/repositories/__init__.py
+++ b/api/extensions/logstore/repositories/__init__.py
@@ -0,0 +1,29 @@
+"""
+LogStore repository utilities.
+"""
+
+from typing import Any
+
+
+def safe_float(value: Any, default: float = 0.0) -> float:
+ """
+ Safely convert a value to float, handling 'null' strings and None.
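+
+ Examples: safe_float(None) -> 0.0; safe_float("null") -> 0.0; safe_float("1.5") -> 1.5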
+ """
+ if value is None or value in {"null", ""}:
+ return default
+ try:
+ return float(value)
+ except (ValueError, TypeError):
+ return default
+
+
+def safe_int(value: Any, default: int = 0) -> int:
+ """
+ Safely convert a value to int, handling 'null' strings and None.
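+
+ Examples: safe_int(None) -> 0; safe_int("null") -> 0; safe_int("2.9") -> 2 (truncates via float)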
+ """
+ if value is None or value in {"null", ""}:
+ return default
+ try:
+ return int(float(value))
+ except (ValueError, TypeError):
+ return default
diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py
index 8c804d6bb5..f67723630b 100644
--- a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py
+++ b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py
@@ -14,6 +14,8 @@ from typing import Any
from sqlalchemy.orm import sessionmaker
from extensions.logstore.aliyun_logstore import AliyunLogStore
+from extensions.logstore.repositories import safe_float, safe_int
+from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value
from models.workflow import WorkflowNodeExecutionModel
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
@@ -52,9 +54,8 @@ def _dict_to_workflow_node_execution_model(data: dict[str, Any]) -> WorkflowNode
model.created_by_role = data.get("created_by_role") or ""
model.created_by = data.get("created_by") or ""
- # Numeric fields with defaults
- model.index = int(data.get("index", 0))
- model.elapsed_time = float(data.get("elapsed_time", 0))
+ model.index = safe_int(data.get("index", 0))
+ model.elapsed_time = safe_float(data.get("elapsed_time", 0))
# Optional fields
model.workflow_run_id = data.get("workflow_run_id")
@@ -130,6 +131,12 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
node_id,
)
try:
+ # Escape parameters to prevent SQL injection
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+ escaped_workflow_id = escape_identifier(workflow_id)
+ escaped_node_id = escape_identifier(node_id)
+
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of each record)
@@ -138,10 +145,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
- WHERE tenant_id = '{tenant_id}'
- AND app_id = '{app_id}'
- AND workflow_id = '{workflow_id}'
- AND node_id = '{node_id}'
+ WHERE tenant_id = '{escaped_tenant_id}'
+ AND app_id = '{escaped_app_id}'
+ AND workflow_id = '{escaped_workflow_id}'
+ AND node_id = '{escaped_node_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
@@ -153,7 +160,8 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
else:
# Use SDK with LogStore query syntax
query = (
- f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_id: {workflow_id} and node_id: {node_id}"
+ f"tenant_id: {escaped_tenant_id} and app_id: {escaped_app_id} "
+ f"and workflow_id: {escaped_workflow_id} and node_id: {escaped_node_id}"
)
from_time = 0
to_time = int(time.time()) # now
@@ -227,6 +235,11 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
workflow_run_id,
)
try:
+ # Escape parameters to prevent SQL injection
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+ escaped_workflow_run_id = escape_identifier(workflow_run_id)
+
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of each record)
@@ -235,9 +248,9 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
- WHERE tenant_id = '{tenant_id}'
- AND app_id = '{app_id}'
- AND workflow_run_id = '{workflow_run_id}'
+ WHERE tenant_id = '{escaped_tenant_id}'
+ AND app_id = '{escaped_app_id}'
+ AND workflow_run_id = '{escaped_workflow_run_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 1000
@@ -248,7 +261,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
)
else:
# Use SDK with LogStore query syntax
- query = f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_run_id: {workflow_run_id}"
+ query = (
+ f"tenant_id: {escaped_tenant_id} and app_id: {escaped_app_id} "
+ f"and workflow_run_id: {escaped_workflow_run_id}"
+ )
from_time = 0
to_time = int(time.time()) # now
@@ -313,16 +329,24 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
"""
logger.debug("get_execution_by_id: execution_id=%s, tenant_id=%s", execution_id, tenant_id)
try:
+ # Escape parameters to prevent SQL injection
+ escaped_execution_id = escape_identifier(execution_id)
+
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
- tenant_filter = f"AND tenant_id = '{tenant_id}'" if tenant_id else ""
+ if tenant_id:
+ escaped_tenant_id = escape_identifier(tenant_id)
+ tenant_filter = f"AND tenant_id = '{escaped_tenant_id}'"
+ else:
+ tenant_filter = ""
+
sql_query = f"""
SELECT * FROM (
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
- WHERE id = '{execution_id}' {tenant_filter} AND __time__ > 0
+ WHERE id = '{escaped_execution_id}' {tenant_filter} AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 1
"""
@@ -332,10 +356,14 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
)
else:
# Use SDK with LogStore query syntax
+ # Note: Values must be quoted in LogStore query syntax to prevent injection
if tenant_id:
- query = f"id: {execution_id} and tenant_id: {tenant_id}"
+ query = (
+ f"id:{escape_logstore_query_value(execution_id)} "
+ f"and tenant_id:{escape_logstore_query_value(tenant_id)}"
+ )
else:
- query = f"id: {execution_id}"
+ query = f"id:{escape_logstore_query_value(execution_id)}"
from_time = 0
to_time = int(time.time()) # now
diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py
index 252cdcc4df..14382ed876 100644
--- a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py
+++ b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py
@@ -10,6 +10,7 @@ Key Features:
- Optimized deduplication using finished_at IS NOT NULL filter
- Window functions only when necessary (running status queries)
- Multi-tenant data isolation and security
+- SQL injection prevention via parameter escaping
"""
import logging
@@ -22,6 +23,8 @@ from typing import Any, cast
from sqlalchemy.orm import sessionmaker
from extensions.logstore.aliyun_logstore import AliyunLogStore
+from extensions.logstore.repositories import safe_float, safe_int
+from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowRun
@@ -63,10 +66,9 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun:
model.created_by_role = data.get("created_by_role") or ""
model.created_by = data.get("created_by") or ""
- # Numeric fields with defaults
- model.total_tokens = int(data.get("total_tokens", 0))
- model.total_steps = int(data.get("total_steps", 0))
- model.exceptions_count = int(data.get("exceptions_count", 0))
+ model.total_tokens = safe_int(data.get("total_tokens", 0))
+ model.total_steps = safe_int(data.get("total_steps", 0))
+ model.exceptions_count = safe_int(data.get("exceptions_count", 0))
# Optional fields
model.graph = data.get("graph")
@@ -101,7 +103,8 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun:
if model.finished_at and model.created_at:
model.elapsed_time = (model.finished_at - model.created_at).total_seconds()
else:
- model.elapsed_time = float(data.get("elapsed_time", 0))
+ # Use safe conversion to handle 'null' strings and None values
+ model.elapsed_time = safe_float(data.get("elapsed_time", 0))
return model
@@ -165,16 +168,26 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
status,
)
# Convert triggered_from to list if needed
- if isinstance(triggered_from, WorkflowRunTriggeredFrom):
+ if isinstance(triggered_from, (WorkflowRunTriggeredFrom, str)):
triggered_from_list = [triggered_from]
else:
triggered_from_list = list(triggered_from)
- # Build triggered_from filter
- triggered_from_filter = " OR ".join([f"triggered_from='{tf.value}'" for tf in triggered_from_list])
+ # Escape parameters to prevent SQL injection
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
- # Build status filter
- status_filter = f"AND status='{status}'" if status else ""
+ # Build triggered_from filter with escaped values
+ # Support both enum and string values for triggered_from
+ triggered_from_filter = " OR ".join(
+ [
+ f"triggered_from='{escape_sql_string(tf.value if isinstance(tf, WorkflowRunTriggeredFrom) else tf)}'"
+ for tf in triggered_from_list
+ ]
+ )
+
+ # Build status filter with escaped value
+ status_filter = f"AND status='{escape_sql_string(status)}'" if status else ""
# Build last_id filter for pagination
# Note: This is simplified. In production, you'd need to track created_at from last record
@@ -188,8 +201,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
SELECT * FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
- WHERE tenant_id='{tenant_id}'
- AND app_id='{app_id}'
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
AND ({triggered_from_filter})
{status_filter}
{last_id_filter}
@@ -232,6 +245,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.debug("get_workflow_run_by_id: tenant_id=%s, app_id=%s, run_id=%s", tenant_id, app_id, run_id)
try:
+ # Escape parameters to prevent SQL injection
+ escaped_run_id = escape_identifier(run_id)
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
@@ -240,7 +258,10 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_execution_logstore}"
- WHERE id = '{run_id}' AND tenant_id = '{tenant_id}' AND app_id = '{app_id}' AND __time__ > 0
+ WHERE id = '{escaped_run_id}'
+ AND tenant_id = '{escaped_tenant_id}'
+ AND app_id = '{escaped_app_id}'
+ AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
"""
@@ -250,7 +271,12 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
)
else:
# Use SDK with LogStore query syntax
- query = f"id: {run_id} and tenant_id: {tenant_id} and app_id: {app_id}"
+ # Note: Values must be quoted in LogStore query syntax to prevent injection
+ query = (
+ f"id:{escape_logstore_query_value(run_id)} "
+ f"and tenant_id:{escape_logstore_query_value(tenant_id)} "
+ f"and app_id:{escape_logstore_query_value(app_id)}"
+ )
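+ # e.g. (illustrative) run_id = 'x" or tenant_id:"other' is escaped to
+ #   id:"x\" or tenant_id:\"other" and tenant_id:"..." and app_id:"..."
+ # so the injected operators stay inside the quoted literal instead of widening the query.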
from_time = 0
to_time = int(time.time()) # now
@@ -323,6 +349,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.debug("get_workflow_run_by_id_without_tenant: run_id=%s", run_id)
try:
+ # Escape parameter to prevent SQL injection
+ escaped_run_id = escape_identifier(run_id)
+
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
@@ -331,7 +360,7 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_execution_logstore}"
- WHERE id = '{run_id}' AND __time__ > 0
+ WHERE id = '{escaped_run_id}' AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
"""
@@ -341,7 +370,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
)
else:
# Use SDK with LogStore query syntax
- query = f"id: {run_id}"
+ # Note: Values must be quoted in LogStore query syntax
+ query = f"id:{escape_logstore_query_value(run_id)}"
from_time = 0
to_time = int(time.time()) # now
@@ -410,6 +440,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
triggered_from,
status,
)
+ # Escape parameters to prevent SQL injection
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+ escaped_triggered_from = escape_sql_string(triggered_from)
+
# Build time range filter
time_filter = ""
if time_range:
@@ -418,6 +453,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
# If status is provided, simple count
if status:
+ escaped_status = escape_sql_string(status)
+
if status == "running":
# Running status requires window function
sql = f"""
@@ -425,9 +462,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
- WHERE tenant_id='{tenant_id}'
- AND app_id='{app_id}'
- AND triggered_from='{triggered_from}'
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND triggered_from='{escaped_triggered_from}'
AND status='running'
{time_filter}
) t
@@ -438,10 +475,10 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f"""
SELECT COUNT(DISTINCT id) as count
FROM {AliyunLogStore.workflow_execution_logstore}
- WHERE tenant_id='{tenant_id}'
- AND app_id='{app_id}'
- AND triggered_from='{triggered_from}'
- AND status='{status}'
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND triggered_from='{escaped_triggered_from}'
+ AND status='{escaped_status}'
AND finished_at IS NOT NULL
{time_filter}
"""
@@ -467,13 +504,14 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
# No status filter - get counts grouped by status
# Use optimized query for finished runs, separate query for running
try:
+ # Parameters were already escaped above; reuse the escaped variables
# Count finished runs grouped by status
finished_sql = f"""
SELECT status, COUNT(DISTINCT id) as count
FROM {AliyunLogStore.workflow_execution_logstore}
- WHERE tenant_id='{tenant_id}'
- AND app_id='{app_id}'
- AND triggered_from='{triggered_from}'
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY status
@@ -485,9 +523,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
- WHERE tenant_id='{tenant_id}'
- AND app_id='{app_id}'
- AND triggered_from='{triggered_from}'
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND triggered_from='{escaped_triggered_from}'
AND status='running'
{time_filter}
) t
@@ -546,7 +584,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.debug(
"get_daily_runs_statistics: tenant_id=%s, app_id=%s, triggered_from=%s", tenant_id, app_id, triggered_from
)
- # Build time range filter
+
+ # Escape parameters to prevent SQL injection
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+ escaped_triggered_from = escape_sql_string(triggered_from)
+
+ # Build time range filter (datetime.isoformat() is safe)
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@@ -557,9 +601,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT id) as runs
FROM {AliyunLogStore.workflow_execution_logstore}
- WHERE tenant_id='{tenant_id}'
- AND app_id='{app_id}'
- AND triggered_from='{triggered_from}'
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
@@ -601,7 +645,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
app_id,
triggered_from,
)
- # Build time range filter
+
+ # Escape parameters to prevent SQL injection
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+ escaped_triggered_from = escape_sql_string(triggered_from)
+
+ # Build time range filter (datetime.isoformat() is safe)
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@@ -611,9 +661,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT created_by) as terminal_count
FROM {AliyunLogStore.workflow_execution_logstore}
- WHERE tenant_id='{tenant_id}'
- AND app_id='{app_id}'
- AND triggered_from='{triggered_from}'
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
@@ -655,7 +705,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
app_id,
triggered_from,
)
- # Build time range filter
+
+ # Escape parameters to prevent SQL injection
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+ escaped_triggered_from = escape_sql_string(triggered_from)
+
+ # Build time range filter (datetime.isoformat() is safe)
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@@ -665,9 +721,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, SUM(total_tokens) as token_count
FROM {AliyunLogStore.workflow_execution_logstore}
- WHERE tenant_id='{tenant_id}'
- AND app_id='{app_id}'
- AND triggered_from='{triggered_from}'
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
@@ -709,7 +765,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
app_id,
triggered_from,
)
- # Build time range filter
+
+ # Escape parameters to prevent SQL injection
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+ escaped_triggered_from = escape_sql_string(triggered_from)
+
+ # Build time range filter (datetime.isoformat() is safe)
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@@ -726,9 +788,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
created_by,
COUNT(DISTINCT id) AS interactions
FROM {AliyunLogStore.workflow_execution_logstore}
- WHERE tenant_id='{tenant_id}'
- AND app_id='{app_id}'
- AND triggered_from='{triggered_from}'
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date, created_by
diff --git a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py
index 1119534d52..9928879a7b 100644
--- a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py
+++ b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py
@@ -10,6 +10,7 @@ from sqlalchemy.orm import sessionmaker
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.entities import WorkflowExecution
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
+from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.logstore.aliyun_logstore import AliyunLogStore
from libs.helper import extract_tenant_id
from models import (
@@ -22,18 +23,6 @@ from models.enums import WorkflowRunTriggeredFrom
logger = logging.getLogger(__name__)
-def to_serializable(obj):
- """
- Convert non-JSON-serializable objects into JSON-compatible formats.
-
- - Uses `to_dict()` if it's a callable method.
- - Falls back to string representation.
- """
- if hasattr(obj, "to_dict") and callable(obj.to_dict):
- return obj.to_dict()
- return str(obj)
-
-
class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
def __init__(
self,
@@ -79,7 +68,7 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
# Control flag for dual-write (write to both LogStore and SQL database)
# Set to True to enable dual-write for safe migration, False to use LogStore only
- self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true"
+ self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "false").lower() == "true"
# Control flag for whether to write the `graph` field to LogStore.
# If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field;
@@ -113,6 +102,9 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
# Generate log_version as nanosecond timestamp for record versioning
log_version = str(time.time_ns())
+ # Use WorkflowRuntimeTypeConverter to handle complex types (Segment, File, etc.)
+ json_converter = WorkflowRuntimeTypeConverter()
+
logstore_model = [
("id", domain_model.id_),
("log_version", log_version), # Add log_version field for append-only writes
@@ -127,19 +119,19 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
("version", domain_model.workflow_version),
(
"graph",
- json.dumps(domain_model.graph, ensure_ascii=False, default=to_serializable)
+ json.dumps(json_converter.to_json_encodable(domain_model.graph), ensure_ascii=False)
if domain_model.graph and self._enable_put_graph_field
else "{}",
),
(
"inputs",
- json.dumps(domain_model.inputs, ensure_ascii=False, default=to_serializable)
+ json.dumps(json_converter.to_json_encodable(domain_model.inputs), ensure_ascii=False)
if domain_model.inputs
else "{}",
),
(
"outputs",
- json.dumps(domain_model.outputs, ensure_ascii=False, default=to_serializable)
+ json.dumps(json_converter.to_json_encodable(domain_model.outputs), ensure_ascii=False)
if domain_model.outputs
else "{}",
),
diff --git a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py
index 400a089516..4897171b12 100644
--- a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py
+++ b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py
@@ -24,6 +24,8 @@ from core.workflow.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.logstore.aliyun_logstore import AliyunLogStore
+from extensions.logstore.repositories import safe_float, safe_int
+from extensions.logstore.sql_escape import escape_identifier
from libs.helper import extract_tenant_id
from models import (
Account,
@@ -73,7 +75,7 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut
node_execution_id=data.get("node_execution_id"),
workflow_id=data.get("workflow_id", ""),
workflow_execution_id=data.get("workflow_run_id"),
- index=int(data.get("index", 0)),
+ index=safe_int(data.get("index", 0)),
predecessor_node_id=data.get("predecessor_node_id"),
node_id=data.get("node_id", ""),
node_type=NodeType(data.get("node_type", "start")),
@@ -83,7 +85,7 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut
outputs=outputs,
status=status,
error=data.get("error"),
- elapsed_time=float(data.get("elapsed_time", 0.0)),
+ elapsed_time=safe_float(data.get("elapsed_time", 0.0)),
metadata=domain_metadata,
created_at=created_at,
finished_at=finished_at,
@@ -147,7 +149,7 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
# Control flag for dual-write (write to both LogStore and SQL database)
# Set to True to enable dual-write for safe migration, False to use LogStore only
- self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true"
+ self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "false").lower() == "true"
def _to_logstore_model(self, domain_model: WorkflowNodeExecution) -> Sequence[tuple[str, str]]:
logger.debug(
@@ -274,16 +276,34 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
Save or update the inputs, process_data, or outputs associated with a specific
node_execution record.
- For LogStore implementation, this is similar to save() since we always write
- complete records. We append a new record with updated data fields.
+ For the LogStore implementation, this is a no-op for the LogStore write: save()
+ already persists all fields, including inputs, process_data, and outputs. The caller
+ typically calls save() first to persist status/metadata, then calls save_execution_data()
+ to persist the data fields. Because LogStore records are complete and append-only, a
+ second write here would only create a duplicate record, so it is skipped.
+
+ However, if dual-write is enabled, we still call the SQL repository's
+ save_execution_data() so the SQL database is updated as well.
Args:
execution: The NodeExecution instance with data to save
"""
- logger.debug("save_execution_data: id=%s, node_execution_id=%s", execution.id, execution.node_execution_id)
- # In LogStore, we simply write a new complete record with the data
- # The log_version timestamp will ensure this is treated as the latest version
- self.save(execution)
+ logger.debug(
+ "save_execution_data: no-op for LogStore (data already saved by save()): id=%s, node_execution_id=%s",
+ execution.id,
+ execution.node_execution_id,
+ )
+ # No-op for LogStore: save() already writes all fields including inputs, process_data, and outputs
+ # Calling save() again would create a duplicate record in the append-only LogStore
+
+ # Dual-write to SQL database if enabled (for safe migration)
+ if self._enable_dual_write:
+ try:
+ self.sql_repository.save_execution_data(execution)
+ logger.debug("Dual-write: saved node execution data to SQL database: id=%s", execution.id)
+ except Exception:
+ logger.exception("Failed to dual-write node execution data to SQL database: id=%s", execution.id)
+ # Don't raise - LogStore write succeeded, SQL is just a backup
def get_by_workflow_run(
self,
@@ -292,8 +312,8 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all NodeExecution instances for a specific workflow run.
- Uses LogStore SQL query with finished_at IS NOT NULL filter for deduplication.
- This ensures we only get the final version of each node execution.
+ Uses a LogStore SQL query with a window function so that only the latest version
+ (highest log_version) of each node execution record is returned.
Args:
workflow_run_id: The workflow run ID
order_config: Optional configuration for ordering results
@@ -304,16 +324,19 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
A list of NodeExecution instances
Note:
- This method filters by finished_at IS NOT NULL to avoid duplicates from
- version updates. For complete history including intermediate states,
- a different query strategy would be needed.
+ This method uses ROW_NUMBER() window function partitioned by node_execution_id
+ to get the latest version (highest log_version) of each node execution.
"""
logger.debug("get_by_workflow_run: workflow_run_id=%s, order_config=%s", workflow_run_id, order_config)
- # Build SQL query with deduplication using finished_at IS NOT NULL
- # This optimization avoids window functions for common case where we only
- # want the final state of each node execution
+ # Build SQL query with deduplication using window function
+ # ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC)
+ # ensures we get the latest version of each node execution
- # Build ORDER BY clause
+ # Escape parameters to prevent SQL injection
+ escaped_workflow_run_id = escape_identifier(workflow_run_id)
+ escaped_tenant_id = escape_identifier(self._tenant_id)
+
+ # Build ORDER BY clause for outer query
order_clause = ""
if order_config and order_config.order_by:
order_fields = []
@@ -327,16 +350,23 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
if order_fields:
order_clause = "ORDER BY " + ", ".join(order_fields)
- sql = f"""
- SELECT *
- FROM {AliyunLogStore.workflow_node_execution_logstore}
- WHERE workflow_run_id='{workflow_run_id}'
- AND tenant_id='{self._tenant_id}'
- AND finished_at IS NOT NULL
- """
-
+ # Build app_id filter for subquery
+ app_id_filter = ""
if self._app_id:
- sql += f" AND app_id='{self._app_id}'"
+ escaped_app_id = escape_identifier(self._app_id)
+ app_id_filter = f" AND app_id='{escaped_app_id}'"
+
+ # Use window function to get latest version of each node execution
+ sql = f"""
+ SELECT * FROM (
+ SELECT *, ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC) AS rn
+ FROM {AliyunLogStore.workflow_node_execution_logstore}
+ WHERE workflow_run_id='{escaped_workflow_run_id}'
+ AND tenant_id='{escaped_tenant_id}'
+ {app_id_filter}
+ ) t
+ WHERE rn = 1
+ """
if order_clause:
sql += f" {order_clause}"
diff --git a/api/extensions/logstore/sql_escape.py b/api/extensions/logstore/sql_escape.py
new file mode 100644
index 0000000000..d88d6bd959
--- /dev/null
+++ b/api/extensions/logstore/sql_escape.py
@@ -0,0 +1,134 @@
+"""
+SQL Escape Utility for LogStore Queries
+
+This module provides escaping utilities to prevent injection attacks in LogStore queries.
+
+LogStore supports two query modes:
+1. PG Protocol Mode: Uses SQL syntax with single quotes for strings
+2. SDK Mode: Uses LogStore query syntax (key: value) with double quotes
+
+Key Security Concerns:
+- Prevent tenant A from accessing tenant B's data via injection
+- SLS queries are read-only, so the primary concern is unauthorized data access rather than data modification
+- Different escaping strategies for SQL vs LogStore query syntax
+"""
+
+
+def escape_sql_string(value: str) -> str:
+ """
+ Escape a string value for safe use in SQL queries.
+
+ This function escapes single quotes by doubling them, which is the standard
+ SQL escaping method. This prevents SQL injection by ensuring that user input
+ cannot break out of string literals.
+
+ Args:
+ value: The string value to escape
+
+ Returns:
+ Escaped string safe for use in SQL queries
+
+ Examples:
+ >>> escape_sql_string("normal_value")
+ 'normal_value'
+ >>> escape_sql_string("value' OR '1'='1")
+ "value'' OR ''1''=''1"
+ >>> escape_sql_string("tenant's_id")
+ "tenant''s_id"
+
+ Security:
+ - Prevents breaking out of string literals
+ - Stops injection attacks like: ' OR '1'='1
+ - Protects against cross-tenant data access
+ """
+ if not value:
+ return value
+
+ # Escape single quotes by doubling them (standard SQL escaping)
+ # This prevents breaking out of string literals in SQL queries
+ return value.replace("'", "''")
+
+
+def escape_identifier(value: str) -> str:
+ """
+ Escape an identifier (tenant_id, app_id, run_id, etc.) for safe SQL use.
+
+ This function is for PG protocol mode (SQL syntax).
+ For SDK mode, use escape_logstore_query_value() instead.
+
+ Args:
+ value: The identifier value to escape
+
+ Returns:
+ Escaped identifier safe for use in SQL queries
+
+ Examples:
+ >>> escape_identifier("550e8400-e29b-41d4-a716-446655440000")
+ '550e8400-e29b-41d4-a716-446655440000'
+ >>> escape_identifier("tenant_id' OR '1'='1")
+ "tenant_id'' OR ''1''=''1"
+
+ Security:
+ - Prevents SQL injection via identifiers
+ - Stops cross-tenant access attempts
+ - Works for UUIDs, alphanumeric IDs, and similar identifiers
+ """
+ # For identifiers, use the same escaping as strings
+ # This is simple and effective for preventing injection
+ return escape_sql_string(value)
+
+
+def escape_logstore_query_value(value: str) -> str:
+ """
+ Escape value for LogStore query syntax (SDK mode).
+
+ LogStore query syntax rules:
+ 1. Keywords (and/or/not) are case-insensitive
+ 2. Single quotes are ordinary characters (no special meaning)
+ 3. Double quotes wrap values: key:"value"
+ 4. Backslash is the escape character:
+ - \" for double quote inside value
+ - \\ for backslash itself
+ 5. Parentheses can change query structure
+
+ To prevent injection:
+ - Wrap value in double quotes to treat special chars as literals
+ - Escape backslashes and double quotes using backslash
+
+ Args:
+ value: The value to escape for LogStore query syntax
+
+ Returns:
+ Quoted and escaped value safe for LogStore query syntax (includes the quotes)
+
+ Examples:
+ >>> escape_logstore_query_value("normal_value")
+ '"normal_value"'
+ >>> escape_logstore_query_value("value or field:evil")
+ '"value or field:evil"' # 'or' and ':' are now literals
+ >>> escape_logstore_query_value('value"test')
+ '"value\\"test"' # Internal double quote escaped
+ >>> escape_logstore_query_value('value\\test')
+ '"value\\\\test"' # Backslash escaped
+
+ Security:
+ - Prevents injection via and/or/not keywords
+ - Prevents injection via colons (:)
+ - Prevents injection via parentheses
+ - Protects against cross-tenant data access
+
+ Note:
+ Escape order is critical: backslash first, then double quotes.
+ Otherwise, we'd double-escape the escape character itself.
+ """
+ if not value:
+ return '""'
+
+ # IMPORTANT: Escape backslashes FIRST, then double quotes
+ # This prevents double-escaping (e.g., " -> \" -> \\" incorrectly)
+ escaped = value.replace("\\", "\\\\") # \ -> \\
+ escaped = escaped.replace('"', '\\"') # " -> \"
+
+ # Wrap in double quotes to treat as literal string
+ # This prevents and/or/not/:/() from being interpreted as operators
+ return f'"{escaped}"'
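+
+
+# Illustrative usage (a sketch of how the repository modules above are expected to embed
+# the escaped values; the variable names here are hypothetical):
+#
+#   tenant_id = "abc' OR '1'='1"
+#   sql = f"SELECT * FROM logstore WHERE tenant_id='{escape_identifier(tenant_id)}'"
+#   # -> ... WHERE tenant_id='abc'' OR ''1''=''1'   (input stays inside the string literal)
+#
+#   run_id = 'x" or tenant_id:"other'
+#   query = f"id:{escape_logstore_query_value(run_id)}"
+#   # -> id:"x\" or tenant_id:\"other"              (operators stay inside the quoted value)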
diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py
index bd71f18af2..0be836c8f1 100644
--- a/api/factories/file_factory.py
+++ b/api/factories/file_factory.py
@@ -115,7 +115,18 @@ def build_from_mappings(
# TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query.
# Implement batch processing to reduce database load when handling multiple files.
# Filter out None/empty mappings to avoid errors
- valid_mappings = [m for m in mappings if m and m.get("transfer_method")]
+ def is_valid_mapping(m: Mapping[str, Any]) -> bool:
+ if not m or not m.get("transfer_method"):
+ return False
+ # For REMOTE_URL transfer method, ensure url or remote_url is provided and not None
+ transfer_method = m.get("transfer_method")
+ if transfer_method == FileTransferMethod.REMOTE_URL:
+ url = m.get("url") or m.get("remote_url")
+ if not url:
+ return False
+ return True
+
+ valid_mappings = [m for m in mappings if is_valid_mapping(m)]
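+ # e.g. (illustrative) a mapping like {"transfer_method": "remote_url"} with no url/remote_url
+ # is now filtered out here rather than surfacing as an error downstream.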
files = [
build_from_mapping(
mapping=mapping,
diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py
index 494194369a..3f030ae127 100644
--- a/api/factories/variable_factory.py
+++ b/api/factories/variable_factory.py
@@ -38,7 +38,7 @@ from core.variables.variables import (
ObjectVariable,
SecretVariable,
StringVariable,
- Variable,
+ VariableBase,
)
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
@@ -72,25 +72,25 @@ SEGMENT_TO_VARIABLE_MAP = {
}
-def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
+def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
if not mapping.get("name"):
raise VariableError("missing name")
return _build_variable_from_mapping(mapping=mapping, selector=[CONVERSATION_VARIABLE_NODE_ID, mapping["name"]])
-def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
+def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
if not mapping.get("name"):
raise VariableError("missing name")
return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]])
-def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
+def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
if not mapping.get("variable"):
raise VariableError("missing variable")
return mapping["variable"]
-def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
+def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> VariableBase:
"""
This factory function is used to create the environment variable or the conversation variable,
not support the File type.
@@ -100,7 +100,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
if (value := mapping.get("value")) is None:
raise VariableError("missing value")
- result: Variable
+ result: VariableBase
match value_type:
case SegmentType.STRING:
result = StringVariable.model_validate(mapping)
@@ -134,7 +134,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
if not result.selector:
result = result.model_copy(update={"selector": selector})
- return cast(Variable, result)
+ return cast(VariableBase, result)
def build_segment(value: Any, /) -> Segment:
@@ -285,8 +285,8 @@ def segment_to_variable(
id: str | None = None,
name: str | None = None,
description: str = "",
-) -> Variable:
- if isinstance(segment, Variable):
+) -> VariableBase:
+ if isinstance(segment, VariableBase):
return segment
name = name or selector[-1]
id = id or str(uuid4())
@@ -297,7 +297,7 @@ def segment_to_variable(
variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
return cast(
- Variable,
+ VariableBase,
variable_class(
id=id,
name=name,
diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py
index 2bba198fa8..c81e482f73 100644
--- a/api/fields/message_fields.py
+++ b/api/fields/message_fields.py
@@ -2,6 +2,7 @@ from __future__ import annotations
from datetime import datetime
from typing import TypeAlias
+from uuid import uuid4
from pydantic import BaseModel, ConfigDict, Field, field_validator
@@ -20,8 +21,8 @@ class SimpleFeedback(ResponseModel):
class RetrieverResource(ResponseModel):
- id: str
- message_id: str
+ id: str = Field(default_factory=lambda: str(uuid4()))
+ message_id: str = Field(default_factory=lambda: str(uuid4()))
position: int
dataset_id: str | None = None
dataset_name: str | None = None
diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py
index d037b0c442..2755f77f61 100644
--- a/api/fields/workflow_fields.py
+++ b/api/fields/workflow_fields.py
@@ -1,7 +1,7 @@
from flask_restx import fields
from core.helper import encrypter
-from core.variables import SecretVariable, SegmentType, Variable
+from core.variables import SecretVariable, SegmentType, VariableBase
from fields.member_fields import simple_account_fields
from libs.helper import TimestampField
@@ -21,7 +21,7 @@ class EnvironmentVariableField(fields.Raw):
"value_type": value.value_type.value,
"description": value.description,
}
- if isinstance(value, Variable):
+ if isinstance(value, VariableBase):
return {
"id": value.id,
"name": value.name,
diff --git a/api/libs/smtp.py b/api/libs/smtp.py
index 4044c6f7ed..6f82f1440a 100644
--- a/api/libs/smtp.py
+++ b/api/libs/smtp.py
@@ -3,6 +3,8 @@ import smtplib
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
+from configs import dify_config
+
logger = logging.getLogger(__name__)
@@ -19,20 +21,21 @@ class SMTPClient:
self.opportunistic_tls = opportunistic_tls
def send(self, mail: dict):
- smtp = None
+ smtp: smtplib.SMTP | None = None
+ local_host = dify_config.SMTP_LOCAL_HOSTNAME
try:
- if self.use_tls:
- if self.opportunistic_tls:
- smtp = smtplib.SMTP(self.server, self.port, timeout=10)
- # Send EHLO command with the HELO domain name as the server address
- smtp.ehlo(self.server)
- smtp.starttls()
- # Resend EHLO command to identify the TLS session
- smtp.ehlo(self.server)
- else:
- smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10)
+ if self.use_tls and not self.opportunistic_tls:
+ # SMTP with SSL (implicit TLS)
+ smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10, local_hostname=local_host)
else:
- smtp = smtplib.SMTP(self.server, self.port, timeout=10)
+ # Plain SMTP or SMTP with STARTTLS (explicit TLS)
+ smtp = smtplib.SMTP(self.server, self.port, timeout=10, local_hostname=local_host)
+
+ assert smtp is not None
+ if self.use_tls and self.opportunistic_tls:
+ smtp.ehlo(self.server)
+ smtp.starttls()
+ smtp.ehlo(self.server)
# Only authenticate if both username and password are non-empty
if self.username and self.password and self.username.strip() and self.password.strip():
diff --git a/api/libs/workspace_permission.py b/api/libs/workspace_permission.py
new file mode 100644
index 0000000000..dd42a7facf
--- /dev/null
+++ b/api/libs/workspace_permission.py
@@ -0,0 +1,74 @@
+"""
+Workspace permission helper functions.
+
+These helpers check workspace permissions at two levels:
+1. Billing/plan level - via FeatureService (e.g., SANDBOX plan restrictions)
+2. Workspace policy level - via EnterpriseService (admin-configured per workspace)
+"""
+
+import logging
+
+from werkzeug.exceptions import Forbidden
+
+from configs import dify_config
+from services.enterprise.enterprise_service import EnterpriseService
+from services.feature_service import FeatureService
+
+logger = logging.getLogger(__name__)
+
+
+def check_workspace_member_invite_permission(workspace_id: str) -> None:
+ """
+ Check if workspace allows member invitations at both billing and policy levels.
+
+ Checks performed:
+ 1. Billing/plan level - For future expansion (currently no plan-level restriction)
+ 2. Enterprise policy level - Admin-configured workspace permission
+
+ Args:
+ workspace_id: The workspace ID to check permissions for
+
+ Raises:
+ Forbidden: If either billing plan or workspace policy prohibits member invitations
+ """
+ # Check enterprise workspace policy level (only if enterprise enabled)
+ if dify_config.ENTERPRISE_ENABLED:
+ try:
+ permission = EnterpriseService.WorkspacePermissionService.get_permission(workspace_id)
+ if not permission.allow_member_invite:
+ raise Forbidden("Workspace policy prohibits member invitations")
+ except Forbidden:
+ raise
+ except Exception:
+ logger.exception("Failed to check workspace invite permission for %s", workspace_id)
+
+
+def check_workspace_owner_transfer_permission(workspace_id: str) -> None:
+ """
+ Check if workspace allows owner transfer at both billing and policy levels.
+
+ Checks performed:
+ 1. Billing/plan level - SANDBOX plan blocks owner transfer
+ 2. Enterprise policy level - Admin-configured workspace permission
+
+ Args:
+ workspace_id: The workspace ID to check permissions for
+
+ Raises:
+ Forbidden: If either billing plan or workspace policy prohibits ownership transfer
+ """
+ features = FeatureService.get_features(workspace_id)
+ if not features.is_allow_transfer_workspace:
+ raise Forbidden("Your current plan does not allow workspace ownership transfer")
+
+ # Check enterprise workspace policy level (only if enterprise enabled)
+ if dify_config.ENTERPRISE_ENABLED:
+ try:
+ permission = EnterpriseService.WorkspacePermissionService.get_permission(workspace_id)
+ if not permission.allow_owner_transfer:
+ raise Forbidden("Workspace policy prohibits ownership transfer")
+ except Forbidden:
+ raise
+ except Exception:
+ logger.exception("Failed to check workspace transfer permission for %s", workspace_id)
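+
+
+# Illustrative call site (hypothetical; actual controllers may differ):
+#
+#   from libs.workspace_permission import check_workspace_member_invite_permission
+#
+#   check_workspace_member_invite_permission(tenant_id)  # raises Forbidden if prohibited
+#   # ... proceed to create and send the member invitation ...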
diff --git a/api/migrations/versions/2026_01_09_1630-905527cc8fd3_add_workflow_run_created_at_id_idx.py b/api/migrations/versions/2026_01_09_1630-905527cc8fd3_add_workflow_run_created_at_id_idx.py
new file mode 100644
index 0000000000..7e0cc8ec9d
--- /dev/null
+++ b/api/migrations/versions/2026_01_09_1630-905527cc8fd3_add_workflow_run_created_at_id_idx.py
@@ -0,0 +1,30 @@
+"""add workflow_run_created_at_id_idx
+
+Revision ID: 905527cc8fd3
+Revises: 7df29de0f6be
+Create Date: 2025-01-09 16:30:02.462084
+
+"""
+from alembic import op
+import models as models
+
+# revision identifiers, used by Alembic.
+revision = '905527cc8fd3'
+down_revision = '7df29de0f6be'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
+ batch_op.create_index('workflow_run_created_at_id_idx', ['created_at', 'id'], unique=False)
+
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
+ batch_op.drop_index('workflow_run_created_at_id_idx')
+ # ### end Alembic commands ###
diff --git a/api/migrations/versions/2026_01_12_1729-3334862ee907_feat_add_created_at_id_index_to_messages.py b/api/migrations/versions/2026_01_12_1729-3334862ee907_feat_add_created_at_id_index_to_messages.py
new file mode 100644
index 0000000000..758369ba99
--- /dev/null
+++ b/api/migrations/versions/2026_01_12_1729-3334862ee907_feat_add_created_at_id_index_to_messages.py
@@ -0,0 +1,33 @@
+"""feat: add created_at id index to messages
+
+Revision ID: 3334862ee907
+Revises: 905527cc8fd3
+Create Date: 2026-01-12 17:29:44.846544
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '3334862ee907'
+down_revision = '905527cc8fd3'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('messages', schema=None) as batch_op:
+ batch_op.create_index('message_created_at_id_idx', ['created_at', 'id'], unique=False)
+
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('messages', schema=None) as batch_op:
+ batch_op.drop_index('message_created_at_id_idx')
+
+ # ### end Alembic commands ###
diff --git a/api/migrations/versions/2026_01_16_1715-288345cd01d1_change_workflow_node_execution_run_index.py b/api/migrations/versions/2026_01_16_1715-288345cd01d1_change_workflow_node_execution_run_index.py
new file mode 100644
index 0000000000..2e1af0c83f
--- /dev/null
+++ b/api/migrations/versions/2026_01_16_1715-288345cd01d1_change_workflow_node_execution_run_index.py
@@ -0,0 +1,35 @@
+"""change workflow node execution workflow_run index
+
+Revision ID: 288345cd01d1
+Revises: 3334862ee907
+Create Date: 2026-01-16 17:15:00.000000
+
+"""
+from alembic import op
+
+
+# revision identifiers, used by Alembic.
+revision = "288345cd01d1"
+down_revision = "3334862ee907"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ with op.batch_alter_table("workflow_node_executions", schema=None) as batch_op:
+ batch_op.drop_index("workflow_node_execution_workflow_run_idx")
+ batch_op.create_index(
+ "workflow_node_execution_workflow_run_id_idx",
+ ["workflow_run_id"],
+ unique=False,
+ )
+
+
+def downgrade():
+ with op.batch_alter_table("workflow_node_executions", schema=None) as batch_op:
+ batch_op.drop_index("workflow_node_execution_workflow_run_id_idx")
+ batch_op.create_index(
+ "workflow_node_execution_workflow_run_idx",
+ ["tenant_id", "app_id", "workflow_id", "triggered_from", "workflow_run_id"],
+ unique=False,
+ )
diff --git a/api/models/dataset.py b/api/models/dataset.py
index 445ac6086f..62f11b8c72 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -1149,7 +1149,7 @@ class DatasetCollectionBinding(TypeBase):
)
-class TidbAuthBinding(Base):
+class TidbAuthBinding(TypeBase):
__tablename__ = "tidb_auth_bindings"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
@@ -1158,7 +1158,13 @@ class TidbAuthBinding(Base):
sa.Index("tidb_auth_bindings_created_at_idx", "created_at"),
sa.Index("tidb_auth_bindings_status_idx", "status"),
)
- id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()))
+ id: Mapped[str] = mapped_column(
+ StringUUID,
+ primary_key=True,
+ insert_default=lambda: str(uuid4()),
+ default_factory=lambda: str(uuid4()),
+ init=False,
+ )
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
cluster_id: Mapped[str] = mapped_column(String(255), nullable=False)
cluster_name: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -1166,7 +1172,9 @@ class TidbAuthBinding(Base):
status: Mapped[str] = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'CREATING'"))
account: Mapped[str] = mapped_column(String(255), nullable=False)
password: Mapped[str] = mapped_column(String(255), nullable=False)
- created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
+ created_at: Mapped[datetime] = mapped_column(
+ DateTime, nullable=False, server_default=func.current_timestamp(), init=False
+ )
class Whitelist(TypeBase):
diff --git a/api/models/model.py b/api/models/model.py
index 11549cb132..f41ca1567f 100644
--- a/api/models/model.py
+++ b/api/models/model.py
@@ -1026,6 +1026,7 @@ class Message(Base):
Index("message_workflow_run_id_idx", "conversation_id", "workflow_run_id"),
Index("message_created_at_idx", "created_at"),
Index("message_app_mode_idx", "app_mode"),
+ Index("message_created_at_id_idx", "created_at", "id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
@@ -1505,7 +1506,7 @@ class MessageAnnotation(Base):
return account
-class AppAnnotationHitHistory(Base):
+class AppAnnotationHitHistory(TypeBase):
__tablename__ = "app_annotation_hit_histories"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"),
@@ -1515,17 +1516,19 @@ class AppAnnotationHitHistory(Base):
sa.Index("app_annotation_hit_histories_message_idx", "message_id"),
)
- id = mapped_column(StringUUID, default=lambda: str(uuid4()))
- app_id = mapped_column(StringUUID, nullable=False)
+ id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
annotation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
- source = mapped_column(LongText, nullable=False)
- question = mapped_column(LongText, nullable=False)
- account_id = mapped_column(StringUUID, nullable=False)
- created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
- score = mapped_column(Float, nullable=False, server_default=sa.text("0"))
- message_id = mapped_column(StringUUID, nullable=False)
- annotation_question = mapped_column(LongText, nullable=False)
- annotation_content = mapped_column(LongText, nullable=False)
+ source: Mapped[str] = mapped_column(LongText, nullable=False)
+ question: Mapped[str] = mapped_column(LongText, nullable=False)
+ account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ created_at: Mapped[datetime] = mapped_column(
+ sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
+ )
+ score: Mapped[float] = mapped_column(Float, nullable=False, server_default=sa.text("0"))
+ message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ annotation_question: Mapped[str] = mapped_column(LongText, nullable=False)
+ annotation_content: Mapped[str] = mapped_column(LongText, nullable=False)
@property
def account(self):
@@ -1901,7 +1904,7 @@ class MessageChain(TypeBase):
)
-class MessageAgentThought(Base):
+class MessageAgentThought(TypeBase):
__tablename__ = "message_agent_thoughts"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"),
@@ -1909,34 +1912,42 @@ class MessageAgentThought(Base):
sa.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"),
)
- id = mapped_column(StringUUID, default=lambda: str(uuid4()))
- message_id = mapped_column(StringUUID, nullable=False)
- message_chain_id = mapped_column(StringUUID, nullable=True)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
+ message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
- thought = mapped_column(LongText, nullable=True)
- tool = mapped_column(LongText, nullable=True)
- tool_labels_str = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
- tool_meta_str = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
- tool_input = mapped_column(LongText, nullable=True)
- observation = mapped_column(LongText, nullable=True)
+ created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
+ created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ message_chain_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
+ thought: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
+ tool: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
+ tool_labels_str: Mapped[str] = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
+ tool_meta_str: Mapped[str] = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
+ tool_input: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
+ observation: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
# plugin_id = mapped_column(StringUUID, nullable=True) ## for future design
- tool_process_data = mapped_column(LongText, nullable=True)
- message = mapped_column(LongText, nullable=True)
- message_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
- message_unit_price = mapped_column(sa.Numeric, nullable=True)
- message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
- message_files = mapped_column(LongText, nullable=True)
- answer = mapped_column(LongText, nullable=True)
- answer_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
- answer_unit_price = mapped_column(sa.Numeric, nullable=True)
- answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
- tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
- total_price = mapped_column(sa.Numeric, nullable=True)
- currency = mapped_column(String(255), nullable=True)
- latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True)
- created_by_role = mapped_column(String(255), nullable=False)
- created_by = mapped_column(StringUUID, nullable=False)
- created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp())
+ tool_process_data: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
+ message: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
+ message_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
+ message_unit_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None)
+ message_price_unit: Mapped[Decimal] = mapped_column(
+ sa.Numeric(10, 7), nullable=False, default=Decimal("0.001"), server_default=sa.text("0.001")
+ )
+ message_files: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
+ answer: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
+ answer_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
+ answer_unit_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None)
+ answer_price_unit: Mapped[Decimal] = mapped_column(
+ sa.Numeric(10, 7), nullable=False, default=Decimal("0.001"), server_default=sa.text("0.001")
+ )
+ tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
+ total_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None)
+ currency: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
+ latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True, default=None)
+ created_at: Mapped[datetime] = mapped_column(
+ sa.DateTime, nullable=False, init=False, server_default=sa.func.current_timestamp()
+ )
@property
def files(self) -> list[Any]:
@@ -2133,7 +2144,7 @@ class TraceAppConfig(TypeBase):
}
-class TenantCreditPool(Base):
+class TenantCreditPool(TypeBase):
__tablename__ = "tenant_credit_pools"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="tenant_credit_pool_pkey"),
@@ -2141,14 +2152,20 @@ class TenantCreditPool(Base):
sa.Index("tenant_credit_pool_pool_type_idx", "pool_type"),
)
- id = mapped_column(StringUUID, primary_key=True, server_default=text("uuid_generate_v4()"))
- tenant_id = mapped_column(StringUUID, nullable=False)
- pool_type = mapped_column(String(40), nullable=False, default="trial", server_default="trial")
- quota_limit = mapped_column(BigInteger, nullable=False, default=0)
- quota_used = mapped_column(BigInteger, nullable=False, default=0)
- created_at = mapped_column(sa.DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP"))
- updated_at = mapped_column(
- sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
+ id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=text("uuid_generate_v4()"), init=False)
+ tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ pool_type: Mapped[str] = mapped_column(String(40), nullable=False, default="trial", server_default="trial")
+ quota_limit: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
+ quota_used: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
+ created_at: Mapped[datetime] = mapped_column(
+ sa.DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP"), init=False
+ )
+ updated_at: Mapped[datetime] = mapped_column(
+ sa.DateTime,
+ nullable=False,
+ server_default=func.current_timestamp(),
+ onupdate=func.current_timestamp(),
+ init=False,
)
@property
diff --git a/api/models/workflow.py b/api/models/workflow.py
index a18939523b..2ff47e87b9 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -1,11 +1,9 @@
-from __future__ import annotations
-
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from datetime import datetime
from enum import StrEnum
-from typing import TYPE_CHECKING, Any, Union, cast
+from typing import TYPE_CHECKING, Any, Optional, Union, cast
from uuid import uuid4
import sqlalchemy as sa
@@ -46,7 +44,7 @@ if TYPE_CHECKING:
from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE
from core.helper import encrypter
-from core.variables import SecretVariable, Segment, SegmentType, Variable
+from core.variables import SecretVariable, Segment, SegmentType, VariableBase
from factories import variable_factory
from libs import helper
@@ -69,7 +67,7 @@ class WorkflowType(StrEnum):
RAG_PIPELINE = "rag-pipeline"
@classmethod
- def value_of(cls, value: str) -> WorkflowType:
+ def value_of(cls, value: str) -> "WorkflowType":
"""
Get value of given mode.
@@ -82,7 +80,7 @@ class WorkflowType(StrEnum):
raise ValueError(f"invalid workflow type value {value}")
@classmethod
- def from_app_mode(cls, app_mode: Union[str, AppMode]) -> WorkflowType:
+ def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType":
"""
Get workflow type from app mode.
@@ -178,12 +176,12 @@ class Workflow(Base): # bug
graph: str,
features: str,
created_by: str,
- environment_variables: Sequence[Variable],
- conversation_variables: Sequence[Variable],
+ environment_variables: Sequence[VariableBase],
+ conversation_variables: Sequence[VariableBase],
rag_pipeline_variables: list[dict],
marked_name: str = "",
marked_comment: str = "",
- ) -> Workflow:
+ ) -> "Workflow":
workflow = Workflow()
workflow.id = str(uuid4())
workflow.tenant_id = tenant_id
@@ -447,7 +445,7 @@ class Workflow(Base): # bug
# decrypt secret variables value
def decrypt_func(
- var: Variable,
+ var: VariableBase,
) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable:
if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
@@ -463,7 +461,7 @@ class Workflow(Base): # bug
return decrypted_results
@environment_variables.setter
- def environment_variables(self, value: Sequence[Variable]):
+ def environment_variables(self, value: Sequence[VariableBase]):
if not value:
self._environment_variables = "{}"
return
@@ -487,7 +485,7 @@ class Workflow(Base): # bug
value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name})
# encrypt secret variables value
- def encrypt_func(var: Variable) -> Variable:
+ def encrypt_func(var: VariableBase) -> VariableBase:
if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)})
else:
@@ -517,7 +515,7 @@ class Workflow(Base): # bug
return result
@property
- def conversation_variables(self) -> Sequence[Variable]:
+ def conversation_variables(self) -> Sequence[VariableBase]:
# TODO: find some way to init `self._conversation_variables` when instance created.
if self._conversation_variables is None:
self._conversation_variables = "{}"
@@ -527,7 +525,7 @@ class Workflow(Base): # bug
return results
@conversation_variables.setter
- def conversation_variables(self, value: Sequence[Variable]):
+ def conversation_variables(self, value: Sequence[VariableBase]):
self._conversation_variables = json.dumps(
{var.name: var.model_dump() for var in value},
ensure_ascii=False,
@@ -597,6 +595,7 @@ class WorkflowRun(Base):
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="workflow_run_pkey"),
sa.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"),
+ sa.Index("workflow_run_created_at_id_idx", "created_at", "id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
@@ -621,7 +620,7 @@ class WorkflowRun(Base):
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
- pause: Mapped[WorkflowPause | None] = orm.relationship(
+ pause: Mapped[Optional["WorkflowPause"]] = orm.relationship(
"WorkflowPause",
primaryjoin="WorkflowRun.id == foreign(WorkflowPause.workflow_run_id)",
uselist=False,
@@ -691,7 +690,7 @@ class WorkflowRun(Base):
}
@classmethod
- def from_dict(cls, data: dict[str, Any]) -> WorkflowRun:
+ def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun":
return cls(
id=data.get("id"),
tenant_id=data.get("tenant_id"),
@@ -782,11 +781,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
return (
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
Index(
- "workflow_node_execution_workflow_run_idx",
- "tenant_id",
- "app_id",
- "workflow_id",
- "triggered_from",
+ "workflow_node_execution_workflow_run_id_idx",
"workflow_run_id",
),
Index(
@@ -843,7 +838,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
created_by: Mapped[str] = mapped_column(StringUUID)
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
- offload_data: Mapped[list[WorkflowNodeExecutionOffload]] = orm.relationship(
+ offload_data: Mapped[list["WorkflowNodeExecutionOffload"]] = orm.relationship(
"WorkflowNodeExecutionOffload",
primaryjoin="WorkflowNodeExecutionModel.id == foreign(WorkflowNodeExecutionOffload.node_execution_id)",
uselist=True,
@@ -853,13 +848,13 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
@staticmethod
def preload_offload_data(
- query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel],
+ query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"],
):
return query.options(orm.selectinload(WorkflowNodeExecutionModel.offload_data))
@staticmethod
def preload_offload_data_and_files(
- query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel],
+ query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"],
):
return query.options(
orm.selectinload(WorkflowNodeExecutionModel.offload_data).options(
@@ -934,7 +929,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
)
return extras
- def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> WorkflowNodeExecutionOffload | None:
+ def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]:
return next(iter([i for i in self.offload_data if i.type_ == type_]), None)
@property
@@ -1048,7 +1043,7 @@ class WorkflowNodeExecutionOffload(Base):
back_populates="offload_data",
)
- file: Mapped[UploadFile | None] = orm.relationship(
+ file: Mapped[Optional["UploadFile"]] = orm.relationship(
foreign_keys=[file_id],
lazy="raise",
uselist=False,
@@ -1066,7 +1061,7 @@ class WorkflowAppLogCreatedFrom(StrEnum):
INSTALLED_APP = "installed-app"
@classmethod
- def value_of(cls, value: str) -> WorkflowAppLogCreatedFrom:
+ def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom":
"""
Get value of given mode.
@@ -1183,7 +1178,7 @@ class ConversationVariable(TypeBase):
)
@classmethod
- def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> ConversationVariable:
+ def from_variable(cls, *, app_id: str, conversation_id: str, variable: VariableBase) -> "ConversationVariable":
obj = cls(
id=variable.id,
app_id=app_id,
@@ -1192,7 +1187,7 @@ class ConversationVariable(TypeBase):
)
return obj
- def to_variable(self) -> Variable:
+ def to_variable(self) -> VariableBase:
mapping = json.loads(self.data)
return variable_factory.build_conversation_variable_from_mapping(mapping)
@@ -1336,7 +1331,7 @@ class WorkflowDraftVariable(Base):
)
# Relationship to WorkflowDraftVariableFile
- variable_file: Mapped[WorkflowDraftVariableFile | None] = orm.relationship(
+ variable_file: Mapped[Optional["WorkflowDraftVariableFile"]] = orm.relationship(
foreign_keys=[file_id],
lazy="raise",
uselist=False,
@@ -1506,7 +1501,7 @@ class WorkflowDraftVariable(Base):
node_execution_id: str | None,
description: str = "",
file_id: str | None = None,
- ) -> WorkflowDraftVariable:
+ ) -> "WorkflowDraftVariable":
variable = WorkflowDraftVariable()
variable.id = str(uuid4())
variable.created_at = naive_utc_now()
@@ -1529,7 +1524,7 @@ class WorkflowDraftVariable(Base):
name: str,
value: Segment,
description: str = "",
- ) -> WorkflowDraftVariable:
+ ) -> "WorkflowDraftVariable":
variable = cls._new(
app_id=app_id,
node_id=CONVERSATION_VARIABLE_NODE_ID,
@@ -1550,7 +1545,7 @@ class WorkflowDraftVariable(Base):
value: Segment,
node_execution_id: str,
editable: bool = False,
- ) -> WorkflowDraftVariable:
+ ) -> "WorkflowDraftVariable":
variable = cls._new(
app_id=app_id,
node_id=SYSTEM_VARIABLE_NODE_ID,
@@ -1573,7 +1568,7 @@ class WorkflowDraftVariable(Base):
visible: bool = True,
editable: bool = True,
file_id: str | None = None,
- ) -> WorkflowDraftVariable:
+ ) -> "WorkflowDraftVariable":
variable = cls._new(
app_id=app_id,
node_id=node_id,
@@ -1669,7 +1664,7 @@ class WorkflowDraftVariableFile(Base):
)
# Relationship to UploadFile
- upload_file: Mapped[UploadFile] = orm.relationship(
+ upload_file: Mapped["UploadFile"] = orm.relationship(
foreign_keys=[upload_file_id],
lazy="raise",
uselist=False,
@@ -1736,7 +1731,7 @@ class WorkflowPause(DefaultFieldsMixin, Base):
state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False)
# Relationship to WorkflowRun
- workflow_run: Mapped[WorkflowRun] = orm.relationship(
+ workflow_run: Mapped["WorkflowRun"] = orm.relationship(
foreign_keys=[workflow_run_id],
# require explicit preloading.
lazy="raise",
@@ -1792,7 +1787,7 @@ class WorkflowPauseReason(DefaultFieldsMixin, Base):
)
@classmethod
- def from_entity(cls, pause_reason: PauseReason) -> WorkflowPauseReason:
+ def from_entity(cls, pause_reason: PauseReason) -> "WorkflowPauseReason":
if isinstance(pause_reason, HumanInputRequired):
return cls(
type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id
diff --git a/api/pyproject.toml b/api/pyproject.toml
index dbc6a2eb83..d025a92846 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "dify-api"
-version = "1.11.2"
+version = "1.11.4"
requires-python = ">=3.11,<3.13"
dependencies = [
@@ -189,7 +189,7 @@ storage = [
"opendal~=0.46.0",
"oss2==2.18.5",
"supabase~=2.18.1",
- "tos~=2.7.1",
+ "tos~=2.9.0",
]
############################################################
diff --git a/api/repositories/api_workflow_node_execution_repository.py b/api/repositories/api_workflow_node_execution_repository.py
index fa2c94b623..479eb1ff54 100644
--- a/api/repositories/api_workflow_node_execution_repository.py
+++ b/api/repositories/api_workflow_node_execution_repository.py
@@ -13,6 +13,8 @@ from collections.abc import Sequence
from datetime import datetime
from typing import Protocol
+from sqlalchemy.orm import Session
+
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from models.workflow import WorkflowNodeExecutionModel
@@ -130,6 +132,18 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr
"""
...
+ def count_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
+ """
+ Count node executions and offloads for the given workflow run ids.
+ """
+ ...
+
+ def delete_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
+ """
+ Delete node executions and offloads for the given workflow run ids.
+ """
+ ...
+
def delete_executions_by_app(
self,
tenant_id: str,
diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py
index fd547c78ba..1a2b84fdf9 100644
--- a/api/repositories/api_workflow_run_repository.py
+++ b/api/repositories/api_workflow_run_repository.py
@@ -34,11 +34,14 @@ Example:
```
"""
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
from datetime import datetime
from typing import Protocol
+from sqlalchemy.orm import Session
+
from core.workflow.entities.pause_reason import PauseReason
+from core.workflow.enums import WorkflowType
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
@@ -253,6 +256,44 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
"""
...
+ def get_runs_batch_by_time_range(
+ self,
+ start_from: datetime | None,
+ end_before: datetime,
+ last_seen: tuple[datetime, str] | None,
+ batch_size: int,
+ run_types: Sequence[WorkflowType] | None = None,
+ tenant_ids: Sequence[str] | None = None,
+ ) -> Sequence[WorkflowRun]:
+ """
+ Fetch ended workflow runs in a time window, in batches, for archival and cleanup.
+ """
+ ...
+
+ def delete_runs_with_related(
+ self,
+ runs: Sequence[WorkflowRun],
+ delete_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None,
+ delete_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None,
+ ) -> dict[str, int]:
+ """
+ Delete workflow runs and their related records (node executions, offloads, app logs,
+ trigger logs, pauses, pause reasons).
+ """
+ ...
+
+ def count_runs_with_related(
+ self,
+ runs: Sequence[WorkflowRun],
+ count_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None,
+ count_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None,
+ ) -> dict[str, int]:
+ """
+ Count workflow runs and their related records (node executions, offloads, app logs,
+ trigger logs, pauses, pause reasons) without deleting data.
+ """
+ ...
+
def create_workflow_pause(
self,
workflow_run_id: str,
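The batch-fetch and delete/count methods added to this protocol are meant to be used together: page with a `(created_at, id)` keyset cursor, then delete (or count) each page. A minimal driver sketch, assuming a repository built via `DifyAPIRepositoryFactory` as elsewhere in this change; the loop itself is illustrative, and the real scheduler is `WorkflowRunCleanup` further down in this diff:

```python
from datetime import datetime, timedelta

from sqlalchemy.orm import sessionmaker

from extensions.ext_database import db
from repositories.factory import DifyAPIRepositoryFactory

session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)

end_before = datetime.now() - timedelta(days=30)
last_seen = None  # (created_at, id) keyset cursor

while True:
    runs = repo.get_runs_batch_by_time_range(
        start_from=None,
        end_before=end_before,
        last_seen=last_seen,
        batch_size=1000,
    )
    if not runs:
        break
    # Advance the cursor before deleting so the next page starts after this batch.
    last_seen = (runs[-1].created_at, runs[-1].id)
    counts = repo.delete_runs_with_related(runs)  # node-execution/trigger-log callbacks omitted here
    print(counts["runs"], "runs deleted in this batch")
```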
diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py
index 7e2173acdd..4a7c975d2c 100644
--- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py
+++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py
@@ -9,11 +9,14 @@ from collections.abc import Sequence
from datetime import datetime
from typing import cast
-from sqlalchemy import asc, delete, desc, select
+from sqlalchemy import asc, delete, desc, func, select
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session, sessionmaker
-from models.workflow import WorkflowNodeExecutionModel
+from models.workflow import (
+ WorkflowNodeExecutionModel,
+ WorkflowNodeExecutionOffload,
+)
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
@@ -290,3 +293,61 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
result = cast(CursorResult, session.execute(stmt))
session.commit()
return result.rowcount
+
+ def delete_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
+ """
+ Delete node executions (and offloads) for the given workflow runs using workflow_run_id.
+ """
+ if not run_ids:
+ return 0, 0
+
+ run_ids = list(run_ids)
+ run_id_filter = WorkflowNodeExecutionModel.workflow_run_id.in_(run_ids)
+ node_execution_ids = select(WorkflowNodeExecutionModel.id).where(run_id_filter)
+
+ offloads_deleted = (
+ cast(
+ CursorResult,
+ session.execute(
+ delete(WorkflowNodeExecutionOffload).where(
+ WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids)
+ )
+ ),
+ ).rowcount
+ or 0
+ )
+
+ node_executions_deleted = (
+ cast(
+ CursorResult,
+ session.execute(delete(WorkflowNodeExecutionModel).where(run_id_filter)),
+ ).rowcount
+ or 0
+ )
+
+ return node_executions_deleted, offloads_deleted
+
+ def count_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
+ """
+ Count node executions (and offloads) for the given workflow runs using workflow_run_id.
+ """
+ if not run_ids:
+ return 0, 0
+
+ run_ids = list(run_ids)
+ run_id_filter = WorkflowNodeExecutionModel.workflow_run_id.in_(run_ids)
+
+ node_executions_count = (
+ session.scalar(select(func.count()).select_from(WorkflowNodeExecutionModel).where(run_id_filter)) or 0
+ )
+ node_execution_ids = select(WorkflowNodeExecutionModel.id).where(run_id_filter)
+ offloads_count = (
+ session.scalar(
+ select(func.count())
+ .select_from(WorkflowNodeExecutionOffload)
+ .where(WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids))
+ )
+ or 0
+ )
+
+ return int(node_executions_count), int(offloads_count)
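Because `delete_by_runs` and `count_by_runs` take an explicit `Session` plus run IDs, they can be adapted into the callbacks that `delete_runs_with_related` / `count_runs_with_related` expect. A sketch of such an adapter (the helper itself is illustrative, not part of this diff):

```python
from collections.abc import Sequence

from sqlalchemy.orm import Session

from models.workflow import WorkflowRun
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository


def make_node_execution_callbacks(node_repo: DifyAPIWorkflowNodeExecutionRepository):
    """Adapt delete_by_runs/count_by_runs to the (Session, runs) callbacks used by the run repository."""

    def delete_nodes(session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]:
        # Map the batch of runs to their IDs and reuse the caller's session/transaction.
        return node_repo.delete_by_runs(session, [r.id for r in runs])

    def count_nodes(session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]:
        return node_repo.count_by_runs(session, [r.id for r in runs])

    return delete_nodes, count_nodes
```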
diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py
index b172c6a3ac..9d2d06e99f 100644
--- a/api/repositories/sqlalchemy_api_workflow_run_repository.py
+++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py
@@ -21,7 +21,7 @@ Implementation Notes:
import logging
import uuid
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
from datetime import datetime
from decimal import Decimal
from typing import Any, cast
@@ -32,7 +32,7 @@ from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session, selectinload, sessionmaker
from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, SchedulingPause
-from core.workflow.enums import WorkflowExecutionStatus
+from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from libs.helper import convert_datetime_to_date
@@ -40,8 +40,14 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.time_parser import get_time_threshold
from libs.uuid_utils import uuidv7
from models.enums import WorkflowRunTriggeredFrom
-from models.workflow import WorkflowPause as WorkflowPauseModel
-from models.workflow import WorkflowPauseReason, WorkflowRun
+from models.workflow import (
+ WorkflowAppLog,
+ WorkflowPauseReason,
+ WorkflowRun,
+)
+from models.workflow import (
+ WorkflowPause as WorkflowPauseModel,
+)
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.types import (
@@ -314,6 +320,171 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
logger.info("Total deleted %s workflow runs for app %s", total_deleted, app_id)
return total_deleted
+ def get_runs_batch_by_time_range(
+ self,
+ start_from: datetime | None,
+ end_before: datetime,
+ last_seen: tuple[datetime, str] | None,
+ batch_size: int,
+ run_types: Sequence[WorkflowType] | None = None,
+ tenant_ids: Sequence[str] | None = None,
+ ) -> Sequence[WorkflowRun]:
+ """
+ Fetch ended workflow runs in a time window, in batches, for archival and cleanup.
+
+ Query scope:
+ - created_at in [start_from, end_before)
+ - type in run_types (when provided)
+ - status is an ended state
+ - optional tenant_id filter and cursor (last_seen) for pagination
+ """
+ with self._session_maker() as session:
+ stmt = (
+ select(WorkflowRun)
+ .where(
+ WorkflowRun.created_at < end_before,
+ WorkflowRun.status.in_(WorkflowExecutionStatus.ended_values()),
+ )
+ .order_by(WorkflowRun.created_at.asc(), WorkflowRun.id.asc())
+ .limit(batch_size)
+ )
+ if run_types is not None:
+ if not run_types:
+ return []
+ stmt = stmt.where(WorkflowRun.type.in_(run_types))
+
+ if start_from:
+ stmt = stmt.where(WorkflowRun.created_at >= start_from)
+
+ if tenant_ids:
+ stmt = stmt.where(WorkflowRun.tenant_id.in_(tenant_ids))
+
+ if last_seen:
+ stmt = stmt.where(
+ or_(
+ WorkflowRun.created_at > last_seen[0],
+ and_(WorkflowRun.created_at == last_seen[0], WorkflowRun.id > last_seen[1]),
+ )
+ )
+
+ return session.scalars(stmt).all()
+
+ def delete_runs_with_related(
+ self,
+ runs: Sequence[WorkflowRun],
+ delete_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None,
+ delete_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None,
+ ) -> dict[str, int]:
+ if not runs:
+ return {
+ "runs": 0,
+ "node_executions": 0,
+ "offloads": 0,
+ "app_logs": 0,
+ "trigger_logs": 0,
+ "pauses": 0,
+ "pause_reasons": 0,
+ }
+
+ with self._session_maker() as session:
+ run_ids = [run.id for run in runs]
+ if delete_node_executions:
+ node_executions_deleted, offloads_deleted = delete_node_executions(session, runs)
+ else:
+ node_executions_deleted, offloads_deleted = 0, 0
+
+ app_logs_result = session.execute(delete(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(run_ids)))
+ app_logs_deleted = cast(CursorResult, app_logs_result).rowcount or 0
+
+ pause_ids = session.scalars(
+ select(WorkflowPauseModel.id).where(WorkflowPauseModel.workflow_run_id.in_(run_ids))
+ ).all()
+ pause_reasons_deleted = 0
+ pauses_deleted = 0
+
+ if pause_ids:
+ pause_reasons_result = session.execute(
+ delete(WorkflowPauseReason).where(WorkflowPauseReason.pause_id.in_(pause_ids))
+ )
+ pause_reasons_deleted = cast(CursorResult, pause_reasons_result).rowcount or 0
+ pauses_result = session.execute(delete(WorkflowPauseModel).where(WorkflowPauseModel.id.in_(pause_ids)))
+ pauses_deleted = cast(CursorResult, pauses_result).rowcount or 0
+
+ trigger_logs_deleted = delete_trigger_logs(session, run_ids) if delete_trigger_logs else 0
+
+ runs_result = session.execute(delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids)))
+ runs_deleted = cast(CursorResult, runs_result).rowcount or 0
+
+ session.commit()
+
+ return {
+ "runs": runs_deleted,
+ "node_executions": node_executions_deleted,
+ "offloads": offloads_deleted,
+ "app_logs": app_logs_deleted,
+ "trigger_logs": trigger_logs_deleted,
+ "pauses": pauses_deleted,
+ "pause_reasons": pause_reasons_deleted,
+ }
+
+ def count_runs_with_related(
+ self,
+ runs: Sequence[WorkflowRun],
+ count_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None,
+ count_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None,
+ ) -> dict[str, int]:
+ if not runs:
+ return {
+ "runs": 0,
+ "node_executions": 0,
+ "offloads": 0,
+ "app_logs": 0,
+ "trigger_logs": 0,
+ "pauses": 0,
+ "pause_reasons": 0,
+ }
+
+ with self._session_maker() as session:
+ run_ids = [run.id for run in runs]
+ if count_node_executions:
+ node_executions_count, offloads_count = count_node_executions(session, runs)
+ else:
+ node_executions_count, offloads_count = 0, 0
+
+ app_logs_count = (
+ session.scalar(
+ select(func.count()).select_from(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(run_ids))
+ )
+ or 0
+ )
+
+ pause_ids = session.scalars(
+ select(WorkflowPauseModel.id).where(WorkflowPauseModel.workflow_run_id.in_(run_ids))
+ ).all()
+ pauses_count = len(pause_ids)
+ pause_reasons_count = 0
+ if pause_ids:
+ pause_reasons_count = (
+ session.scalar(
+ select(func.count())
+ .select_from(WorkflowPauseReason)
+ .where(WorkflowPauseReason.pause_id.in_(pause_ids))
+ )
+ or 0
+ )
+
+ trigger_logs_count = count_trigger_logs(session, run_ids) if count_trigger_logs else 0
+
+ return {
+ "runs": len(runs),
+ "node_executions": node_executions_count,
+ "offloads": offloads_count,
+ "app_logs": int(app_logs_count),
+ "trigger_logs": trigger_logs_count,
+ "pauses": pauses_count,
+ "pause_reasons": int(pause_reasons_count),
+ }
+
def create_workflow_pause(
self,
workflow_run_id: str,
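The `last_seen` cursor in `get_runs_batch_by_time_range` is keyset pagination over the composite `(created_at, id)` ordering, which keeps batches stable even as earlier rows are deleted. A standalone sketch of the predicate, for illustration only; the implementation above builds the same expression inline:

```python
from datetime import datetime

from sqlalchemy import Select, and_, or_

from models.workflow import WorkflowRun


def apply_keyset_cursor(stmt: Select, last_seen: tuple[datetime, str] | None) -> Select:
    # Rendered SQL is roughly:
    #   created_at > :ts OR (created_at = :ts AND id > :run_id)
    if last_seen is None:
        return stmt
    return stmt.where(
        or_(
            WorkflowRun.created_at > last_seen[0],
            and_(WorkflowRun.created_at == last_seen[0], WorkflowRun.id > last_seen[1]),
        )
    )
```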
diff --git a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py
index 0d67e286b0..ebd3745d18 100644
--- a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py
+++ b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py
@@ -4,8 +4,10 @@ SQLAlchemy implementation of WorkflowTriggerLogRepository.
from collections.abc import Sequence
from datetime import UTC, datetime, timedelta
+from typing import cast
-from sqlalchemy import and_, select
+from sqlalchemy import and_, delete, func, select
+from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session
from models.enums import WorkflowTriggerStatus
@@ -84,3 +86,37 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository):
)
return list(self.session.scalars(query).all())
+
+ def delete_by_run_ids(self, run_ids: Sequence[str]) -> int:
+ """
+ Delete trigger logs associated with the given workflow run ids.
+
+ Args:
+ run_ids: Collection of workflow run identifiers.
+
+ Returns:
+ Number of rows deleted.
+ """
+ if not run_ids:
+ return 0
+
+ result = self.session.execute(delete(WorkflowTriggerLog).where(WorkflowTriggerLog.workflow_run_id.in_(run_ids)))
+ return cast(CursorResult, result).rowcount or 0
+
+ def count_by_run_ids(self, run_ids: Sequence[str]) -> int:
+ """
+ Count trigger logs associated with the given workflow run ids.
+
+ Args:
+ run_ids: Collection of workflow run identifiers.
+
+ Returns:
+ Number of rows matched.
+ """
+ if not run_ids:
+ return 0
+
+ count = self.session.scalar(
+ select(func.count()).select_from(WorkflowTriggerLog).where(WorkflowTriggerLog.workflow_run_id.in_(run_ids))
+ )
+ return int(count or 0)
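Like the node-execution repository, these methods can be plugged into `delete_runs_with_related` / `count_runs_with_related` through thin adapters. A sketch, assuming the trigger-log repository is constructed with a `Session` (as its use of `self.session` suggests); the wrapper function is illustrative:

```python
from collections.abc import Sequence

from sqlalchemy.orm import Session

from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository


def delete_trigger_logs(session: Session, run_ids: Sequence[str]) -> int:
    # Reuse the caller's session so trigger-log deletion joins the same transaction
    # as the workflow-run deletion and is committed (or rolled back) together.
    return SQLAlchemyWorkflowTriggerLogRepository(session).delete_by_run_ids(run_ids)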
diff --git a/api/repositories/workflow_trigger_log_repository.py b/api/repositories/workflow_trigger_log_repository.py
index 138b8779ac..b0009e398d 100644
--- a/api/repositories/workflow_trigger_log_repository.py
+++ b/api/repositories/workflow_trigger_log_repository.py
@@ -109,3 +109,15 @@ class WorkflowTriggerLogRepository(Protocol):
A sequence of recent WorkflowTriggerLog instances
"""
...
+
+ def delete_by_run_ids(self, run_ids: Sequence[str]) -> int:
+ """
+ Delete trigger logs for workflow run IDs.
+
+ Args:
+ run_ids: Workflow run IDs to delete
+
+ Returns:
+ Number of rows deleted
+ """
+ ...
diff --git a/api/schedule/clean_messages.py b/api/schedule/clean_messages.py
index 352a84b592..e85bba8823 100644
--- a/api/schedule/clean_messages.py
+++ b/api/schedule/clean_messages.py
@@ -1,90 +1,62 @@
-import datetime
import logging
import time
import click
-from sqlalchemy.exc import SQLAlchemyError
import app
from configs import dify_config
-from enums.cloud_plan import CloudPlan
-from extensions.ext_database import db
-from extensions.ext_redis import redis_client
-from models.model import (
- App,
- Message,
- MessageAgentThought,
- MessageAnnotation,
- MessageChain,
- MessageFeedback,
- MessageFile,
-)
-from models.web import SavedMessage
-from services.feature_service import FeatureService
+from services.retention.conversation.messages_clean_policy import create_message_clean_policy
+from services.retention.conversation.messages_clean_service import MessagesCleanService
logger = logging.getLogger(__name__)
-@app.celery.task(queue="dataset")
+@app.celery.task(queue="retention")
def clean_messages():
- click.echo(click.style("Start clean messages.", fg="green"))
- start_at = time.perf_counter()
- plan_sandbox_clean_message_day = datetime.datetime.now() - datetime.timedelta(
- days=dify_config.PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING
- )
- while True:
- try:
- # Main query with join and filter
- messages = (
- db.session.query(Message)
- .where(Message.created_at < plan_sandbox_clean_message_day)
- .order_by(Message.created_at.desc())
- .limit(100)
- .all()
- )
+ """
+ Clean expired messages based on the configured clean policy.
- except SQLAlchemyError:
- raise
- if not messages:
- break
- for message in messages:
- app = db.session.query(App).filter_by(id=message.app_id).first()
- if not app:
- logger.warning(
- "Expected App record to exist, but none was found, app_id=%s, message_id=%s",
- message.app_id,
- message.id,
- )
- continue
- features_cache_key = f"features:{app.tenant_id}"
- plan_cache = redis_client.get(features_cache_key)
- if plan_cache is None:
- features = FeatureService.get_features(app.tenant_id)
- redis_client.setex(features_cache_key, 600, features.billing.subscription.plan)
- plan = features.billing.subscription.plan
- else:
- plan = plan_cache.decode()
- if plan == CloudPlan.SANDBOX:
- # clean related message
- db.session.query(MessageFeedback).where(MessageFeedback.message_id == message.id).delete(
- synchronize_session=False
- )
- db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message.id).delete(
- synchronize_session=False
- )
- db.session.query(MessageChain).where(MessageChain.message_id == message.id).delete(
- synchronize_session=False
- )
- db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message.id).delete(
- synchronize_session=False
- )
- db.session.query(MessageFile).where(MessageFile.message_id == message.id).delete(
- synchronize_session=False
- )
- db.session.query(SavedMessage).where(SavedMessage.message_id == message.id).delete(
- synchronize_session=False
- )
- db.session.query(Message).where(Message.id == message.id).delete()
- db.session.commit()
- end_at = time.perf_counter()
- click.echo(click.style(f"Cleaned messages from db success latency: {end_at - start_at}", fg="green"))
+ This task uses MessagesCleanService to efficiently clean messages in batches.
+ The behavior depends on BILLING_ENABLED configuration:
+ - BILLING_ENABLED=True: only delete messages from sandbox tenants (with whitelist/grace period)
+ - BILLING_ENABLED=False: delete all messages within the time range
+ """
+ click.echo(click.style("clean_messages: start clean messages.", fg="green"))
+ start_at = time.perf_counter()
+
+ try:
+ # Create policy based on billing configuration
+ policy = create_message_clean_policy(
+ graceful_period_days=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD,
+ )
+
+ # Create and run the cleanup service
+ service = MessagesCleanService.from_days(
+ policy=policy,
+ days=dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS,
+ batch_size=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE,
+ )
+ stats = service.run()
+
+ end_at = time.perf_counter()
+ click.echo(
+ click.style(
+ f"clean_messages: completed successfully\n"
+ f" - Latency: {end_at - start_at:.2f}s\n"
+ f" - Batches processed: {stats['batches']}\n"
+ f" - Total messages scanned: {stats['total_messages']}\n"
+ f" - Messages filtered: {stats['filtered_messages']}\n"
+ f" - Messages deleted: {stats['total_deleted']}",
+ fg="green",
+ )
+ )
+ except Exception as e:
+ end_at = time.perf_counter()
+ logger.exception("clean_messages failed")
+ click.echo(
+ click.style(
+ f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}",
+ fg="red",
+ )
+ )
+ raise
diff --git a/api/schedule/clean_workflow_runs_task.py b/api/schedule/clean_workflow_runs_task.py
new file mode 100644
index 0000000000..9f5bf8e150
--- /dev/null
+++ b/api/schedule/clean_workflow_runs_task.py
@@ -0,0 +1,43 @@
+from datetime import UTC, datetime
+
+import click
+
+import app
+from configs import dify_config
+from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup
+
+
+@app.celery.task(queue="retention")
+def clean_workflow_runs_task() -> None:
+ """
+ Scheduled cleanup for workflow runs and related records (sandbox tenants only).
+ """
+ click.echo(
+ click.style(
+ (
+ "Scheduled workflow run cleanup starting: "
+ f"cutoff={dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS} days, "
+ f"batch={dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE}"
+ ),
+ fg="green",
+ )
+ )
+
+ start_time = datetime.now(UTC)
+
+ WorkflowRunCleanup(
+ days=dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS,
+ batch_size=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE,
+ start_from=None,
+ end_before=None,
+ ).run()
+
+ end_time = datetime.now(UTC)
+ elapsed = end_time - start_time
+ click.echo(
+ click.style(
+ f"Scheduled workflow run cleanup finished. start={start_time.isoformat()} "
+ f"end={end_time.isoformat()} duration={elapsed}",
+ fg="green",
+ )
+ )
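For the task to run on a cadence it needs a beat entry. A hypothetical registration sketch, assuming a Celery beat schedule keyed by the task's module path; the key name, cadence, and registration location are illustrative and outside this diff:

```python
from celery.schedules import crontab

# Hypothetical beat entry; the real registration point and cadence are outside this diff.
beat_schedule = {
    "clean_workflow_runs_task": {
        "task": "schedule.clean_workflow_runs_task.clean_workflow_runs_task",
        "schedule": crontab(minute="0", hour="3"),  # e.g. daily at 03:00
    },
}
```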
diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py
index c343063fae..ed46c1c70a 100644
--- a/api/schedule/create_tidb_serverless_task.py
+++ b/api/schedule/create_tidb_serverless_task.py
@@ -50,10 +50,13 @@ def create_clusters(batch_size):
)
for new_cluster in new_clusters:
tidb_auth_binding = TidbAuthBinding(
+ tenant_id=None,
cluster_id=new_cluster["cluster_id"],
cluster_name=new_cluster["cluster_name"],
account=new_cluster["account"],
password=new_cluster["password"],
+ active=False,
+ status="CREATING",
)
db.session.add(tidb_auth_binding)
db.session.commit()
diff --git a/api/services/account_service.py b/api/services/account_service.py
index d38c9d5a66..35e4a505af 100644
--- a/api/services/account_service.py
+++ b/api/services/account_service.py
@@ -8,7 +8,7 @@ from hashlib import sha256
from typing import Any, cast
from pydantic import BaseModel
-from sqlalchemy import func
+from sqlalchemy import func, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized
@@ -748,6 +748,21 @@ class AccountService:
cls.email_code_login_rate_limiter.increment_rate_limit(email)
return token
+ @staticmethod
+ def get_account_by_email_with_case_fallback(email: str, session: Session | None = None) -> Account | None:
+ """
+ Retrieve an account by email and fall back to the lowercase email if the original lookup fails.
+
+ This keeps backward compatibility for older records that stored uppercase emails while the
+ rest of the system gradually normalizes new inputs.
+ """
+ query_session = session or db.session
+ account = query_session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
+ if account or email == email.lower():
+ return account
+
+ return query_session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()
+
@classmethod
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "email_code_login")
@@ -1363,16 +1378,27 @@ class RegisterService:
if not inviter:
raise ValueError("Inviter is required")
+ normalized_email = email.lower()
+
"""Invite new member"""
+ # Check workspace permission for member invitations
+ from libs.workspace_permission import check_workspace_member_invite_permission
+
+ check_workspace_member_invite_permission(tenant.id)
+
with Session(db.engine) as session:
- account = session.query(Account).filter_by(email=email).first()
+ account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if not account:
TenantService.check_member_permission(tenant, inviter, None, "add")
- name = email.split("@")[0]
+ name = normalized_email.split("@")[0]
account = cls.register(
- email=email, name=name, language=language, status=AccountStatus.PENDING, is_setup=True
+ email=normalized_email,
+ name=name,
+ language=language,
+ status=AccountStatus.PENDING,
+ is_setup=True,
)
# Create new tenant member for invited tenant
TenantService.create_tenant_member(tenant, account, role)
@@ -1394,7 +1420,7 @@ class RegisterService:
# send email
send_invite_member_mail_task.delay(
language=language,
- to=email,
+ to=account.email,
token=token,
inviter_name=inviter.name if inviter else "Dify",
workspace_name=tenant.name,
@@ -1493,6 +1519,16 @@ class RegisterService:
invitation: dict = json.loads(data)
return invitation
+ @classmethod
+ def get_invitation_with_case_fallback(
+ cls, workspace_id: str | None, email: str | None, token: str
+ ) -> dict[str, Any] | None:
+ invitation = cls.get_invitation_if_token_valid(workspace_id, email, token)
+ if invitation or not email or email == email.lower():
+ return invitation
+ normalized_email = email.lower()
+ return cls.get_invitation_if_token_valid(workspace_id, normalized_email, token)
+
def _generate_refresh_token(length: int = 64):
token = secrets.token_hex(length)
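The case-fallback lookup performs at most two queries: the raw email first, then the lowercased form only when the raw form differs and nothing was found. A minimal usage sketch (the email value is illustrative):

```python
from services.account_service import AccountService

# First lookup uses the email as given; the lowercase fallback only fires when
# the given email contains uppercase characters and the first lookup misses.
account = AccountService.get_account_by_email_with_case_fallback("User@Example.com")
if account is None:
    # Neither "User@Example.com" nor "user@example.com" exists.
    ...
```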
diff --git a/api/services/conversation_variable_updater.py b/api/services/conversation_variable_updater.py
index acc0ec2b22..92008d5ff1 100644
--- a/api/services/conversation_variable_updater.py
+++ b/api/services/conversation_variable_updater.py
@@ -1,7 +1,7 @@
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
-from core.variables.variables import Variable
+from core.variables.variables import VariableBase
from models import ConversationVariable
@@ -13,7 +13,7 @@ class ConversationVariableUpdater:
def __init__(self, session_maker: sessionmaker[Session]) -> None:
self._session_maker: sessionmaker[Session] = session_maker
- def update(self, conversation_id: str, variable: Variable) -> None:
+ def update(self, conversation_id: str, variable: VariableBase) -> None:
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py
index bdc960aa2d..e3832475aa 100644
--- a/api/services/enterprise/base.py
+++ b/api/services/enterprise/base.py
@@ -1,9 +1,14 @@
+import logging
import os
from collections.abc import Mapping
from typing import Any
import httpx
+from core.helper.trace_id_helper import generate_traceparent_header
+
+logger = logging.getLogger(__name__)
+
class BaseRequest:
proxies: Mapping[str, str] | None = {
@@ -38,6 +43,15 @@ class BaseRequest:
headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key}
url = f"{cls.base_url}{endpoint}"
mounts = cls._build_mounts()
+
+ try:
+ # Ensure a traceparent header is sent even when OTEL is disabled
+ traceparent = generate_traceparent_header()
+ if traceparent:
+ headers["traceparent"] = traceparent
+ except Exception:
+ logger.debug("Failed to generate traceparent header", exc_info=True)
+
with httpx.Client(mounts=mounts) as client:
response = client.request(method, url, json=json, params=params, headers=headers)
return response.json()
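The injected `traceparent` value follows the W3C Trace Context format (version, 16-byte trace ID, 8-byte parent span ID, flags), which lets the enterprise backend correlate its logs with the originating API request even when OpenTelemetry export is off. A sketch of the resulting header shape; the hex values are illustrative:

```python
# W3C Trace Context: "00-<trace-id>-<parent-span-id>-<flags>"
headers = {
    "Content-Type": "application/json",
    "traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
}
```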
diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py
index c0cc0e5233..a5133dfcb4 100644
--- a/api/services/enterprise/enterprise_service.py
+++ b/api/services/enterprise/enterprise_service.py
@@ -13,6 +13,23 @@ class WebAppSettings(BaseModel):
)
+class WorkspacePermission(BaseModel):
+ workspace_id: str = Field(
+ description="The ID of the workspace.",
+ alias="workspaceId",
+ )
+ allow_member_invite: bool = Field(
+ description="Whether to allow members to invite new members to the workspace.",
+ default=False,
+ alias="allowMemberInvite",
+ )
+ allow_owner_transfer: bool = Field(
+ description="Whether to allow owners to transfer ownership of the workspace.",
+ default=False,
+ alias="allowOwnerTransfer",
+ )
+
+
class EnterpriseService:
@classmethod
def get_info(cls):
@@ -44,6 +61,16 @@ class EnterpriseService:
except ValueError as e:
raise ValueError(f"Invalid date format: {data}") from e
+ class WorkspacePermissionService:
+ @classmethod
+ def get_permission(cls, workspace_id: str):
+ if not workspace_id:
+ raise ValueError("workspace_id must be provided.")
+ data = EnterpriseRequest.send_request("GET", f"/workspaces/{workspace_id}/permission")
+ if not data or "permission" not in data:
+ raise ValueError("No data found.")
+ return WorkspacePermission.model_validate(data["permission"])
+
class WebAppAuth:
@classmethod
def is_user_allowed_to_access_webapp(cls, user_id: str, app_id: str):
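The `WorkspacePermission` model uses camelCase aliases, so the enterprise API is expected to return a payload such as `{"permission": {"workspaceId": "...", "allowMemberInvite": true, "allowOwnerTransfer": false}}`. A small usage sketch; the workspace ID and the error handling are illustrative:

```python
from services.enterprise.enterprise_service import EnterpriseService

permission = EnterpriseService.WorkspacePermissionService.get_permission(workspace_id="tenant-123")
if not permission.allow_member_invite:
    raise ValueError("Member invitations are disabled for this workspace.")
```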
diff --git a/api/services/enterprise/workspace_sync.py b/api/services/enterprise/workspace_sync.py
new file mode 100644
index 0000000000..acfe325397
--- /dev/null
+++ b/api/services/enterprise/workspace_sync.py
@@ -0,0 +1,58 @@
+import json
+import logging
+import uuid
+from datetime import UTC, datetime
+
+from redis import RedisError
+
+from extensions.ext_redis import redis_client
+
+logger = logging.getLogger(__name__)
+
+WORKSPACE_SYNC_QUEUE = "enterprise:workspace:sync:queue"
+WORKSPACE_SYNC_PROCESSING = "enterprise:workspace:sync:processing"
+
+
+class WorkspaceSyncService:
+ """Service to publish workspace sync tasks to Redis queue for enterprise backend consumption"""
+
+ @staticmethod
+ def queue_credential_sync(workspace_id: str, *, source: str) -> bool:
+ """
+ Queue a credential sync task for a newly created workspace.
+
+ This publishes a task to Redis that will be consumed by the enterprise backend
+ worker to sync credentials with the plugin-manager.
+
+ Args:
+ workspace_id: The workspace/tenant ID to sync credentials for
+ source: Source of the sync request (for debugging/tracking)
+
+ Returns:
+ bool: True if task was queued successfully, False otherwise
+ """
+ try:
+ task = {
+ "task_id": str(uuid.uuid4()),
+ "workspace_id": workspace_id,
+ "retry_count": 0,
+ "created_at": datetime.now(UTC).isoformat(),
+ "source": source,
+ }
+
+ # Push to Redis list (queue) - LPUSH adds to the head, worker consumes from tail with RPOP
+ redis_client.lpush(WORKSPACE_SYNC_QUEUE, json.dumps(task))
+
+ logger.info(
+ "Queued credential sync task for workspace %s, task_id: %s, source: %s",
+ workspace_id,
+ task["task_id"],
+ source,
+ )
+ return True
+
+ except (RedisError, TypeError) as e:
+ logger.error("Failed to queue credential sync for workspace %s: %s", workspace_id, str(e), exc_info=True)
+ # Don't raise - we don't want to fail workspace creation if queueing fails
+ # The scheduled task will catch it later
+ return False
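Since `queue_credential_sync` uses `LPUSH`, a consumer pops from the opposite end with `RPOP`. The real worker lives in the enterprise backend; the Python loop below is only an illustrative equivalent of that consumption pattern:

```python
import json

from extensions.ext_redis import redis_client
from services.enterprise.workspace_sync import WORKSPACE_SYNC_QUEUE

# Illustrative consumer: pop one task from the tail of the queue, if any.
raw = redis_client.rpop(WORKSPACE_SYNC_QUEUE)
if raw:
    task = json.loads(raw)
    print("sync credentials for workspace", task["workspace_id"], "task", task["task_id"])
```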
diff --git a/api/services/plugin/plugin_service.py b/api/services/plugin/plugin_service.py
index b8303eb724..411c335c17 100644
--- a/api/services/plugin/plugin_service.py
+++ b/api/services/plugin/plugin_service.py
@@ -3,6 +3,7 @@ from collections.abc import Mapping, Sequence
from mimetypes import guess_type
from pydantic import BaseModel
+from sqlalchemy import select
from yarl import URL
from configs import dify_config
@@ -25,7 +26,9 @@ from core.plugin.entities.plugin_daemon import (
from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.debugging import PluginDebuggingClient
from core.plugin.impl.plugin import PluginInstaller
+from extensions.ext_database import db
from extensions.ext_redis import redis_client
+from models.provider import ProviderCredential
from models.provider_ids import GenericProviderID
from services.errors.plugin import PluginInstallationForbiddenError
from services.feature_service import FeatureService, PluginInstallationScope
@@ -506,6 +509,33 @@ class PluginService:
@staticmethod
def uninstall(tenant_id: str, plugin_installation_id: str) -> bool:
manager = PluginInstaller()
+
+ # Get plugin info before uninstalling to delete associated credentials
+ try:
+ plugins = manager.list_plugins(tenant_id)
+ plugin = next((p for p in plugins if p.installation_id == plugin_installation_id), None)
+
+ if plugin:
+ plugin_id = plugin.plugin_id
+ logger.info("Deleting credentials for plugin: %s", plugin_id)
+
+ # Delete provider credentials that match this plugin
+ credentials = db.session.scalars(
+ select(ProviderCredential).where(
+ ProviderCredential.tenant_id == tenant_id,
+ ProviderCredential.provider_name.like(f"{plugin_id}/%"),
+ )
+ ).all()
+
+ for cred in credentials:
+ db.session.delete(cred)
+
+ db.session.commit()
+ logger.info("Deleted %d credentials for plugin: %s", len(credentials), plugin_id)
+ except Exception as e:
+ logger.warning("Failed to delete credentials: %s", e)
+ # Continue with uninstall even if credential deletion fails
+
return manager.uninstall(tenant_id, plugin_installation_id)
@staticmethod
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index 1ba64813ba..2d8418900c 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -36,7 +36,7 @@ from core.rag.entities.event import (
)
from core.repositories.factory import DifyCoreRepositoryFactory
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
-from core.variables.variables import Variable
+from core.variables.variables import VariableBase
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
@@ -270,8 +270,8 @@ class RagPipelineService:
graph: dict,
unique_hash: str | None,
account: Account,
- environment_variables: Sequence[Variable],
- conversation_variables: Sequence[Variable],
+ environment_variables: Sequence[VariableBase],
+ conversation_variables: Sequence[VariableBase],
rag_pipeline_variables: list,
) -> Workflow:
"""
diff --git a/api/services/retention/__init__.py b/api/services/retention/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/services/retention/conversation/messages_clean_policy.py b/api/services/retention/conversation/messages_clean_policy.py
new file mode 100644
index 0000000000..6e647b983b
--- /dev/null
+++ b/api/services/retention/conversation/messages_clean_policy.py
@@ -0,0 +1,216 @@
+import datetime
+import logging
+from abc import ABC, abstractmethod
+from collections.abc import Callable, Sequence
+from dataclasses import dataclass
+
+from configs import dify_config
+from enums.cloud_plan import CloudPlan
+from services.billing_service import BillingService, SubscriptionPlan
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class SimpleMessage:
+ id: str
+ app_id: str
+ created_at: datetime.datetime
+
+
+class MessagesCleanPolicy(ABC):
+ """
+ Abstract base class for message cleanup policies.
+
+ A policy determines which messages from a batch should be deleted.
+ """
+
+ @abstractmethod
+ def filter_message_ids(
+ self,
+ messages: Sequence[SimpleMessage],
+ app_to_tenant: dict[str, str],
+ ) -> Sequence[str]:
+ """
+ Filter messages and return IDs of messages that should be deleted.
+
+ Args:
+ messages: Batch of messages to evaluate
+ app_to_tenant: Mapping from app_id to tenant_id
+
+ Returns:
+ List of message IDs that should be deleted
+ """
+ ...
+
+
+class BillingDisabledPolicy(MessagesCleanPolicy):
+ """
+ Policy for community or enterprise edition (billing disabled).
+
+ No special filtering logic; all message IDs are returned.
+ """
+
+ def filter_message_ids(
+ self,
+ messages: Sequence[SimpleMessage],
+ app_to_tenant: dict[str, str],
+ ) -> Sequence[str]:
+ return [msg.id for msg in messages]
+
+
+class BillingSandboxPolicy(MessagesCleanPolicy):
+ """
+ Policy for sandbox plan tenants in cloud edition (billing enabled).
+
+ Filters messages based on sandbox plan expiration rules:
+ - Skip tenants in the whitelist
+ - Only delete messages from sandbox plan tenants
+ - Respect grace period after subscription expiration
+ - Safe default: if tenant mapping or plan is missing, do NOT delete
+ """
+
+ def __init__(
+ self,
+ plan_provider: Callable[[Sequence[str]], dict[str, SubscriptionPlan]],
+ graceful_period_days: int = 21,
+ tenant_whitelist: Sequence[str] | None = None,
+ current_timestamp: int | None = None,
+ ) -> None:
+ self._graceful_period_days = graceful_period_days
+ self._tenant_whitelist: Sequence[str] = tenant_whitelist or []
+ self._plan_provider = plan_provider
+ self._current_timestamp = current_timestamp
+
+ def filter_message_ids(
+ self,
+ messages: Sequence[SimpleMessage],
+ app_to_tenant: dict[str, str],
+ ) -> Sequence[str]:
+ """
+ Filter messages based on sandbox plan expiration rules.
+
+ Args:
+ messages: Batch of messages to evaluate
+ app_to_tenant: Mapping from app_id to tenant_id
+
+ Returns:
+ List of message IDs that should be deleted
+ """
+ if not messages or not app_to_tenant:
+ return []
+
+ # Get unique tenant_ids and fetch subscription plans
+ tenant_ids = list(set(app_to_tenant.values()))
+ tenant_plans = self._plan_provider(tenant_ids)
+
+ if not tenant_plans:
+ return []
+
+ # Apply sandbox deletion rules
+ return self._filter_expired_sandbox_messages(
+ messages=messages,
+ app_to_tenant=app_to_tenant,
+ tenant_plans=tenant_plans,
+ )
+
+ def _filter_expired_sandbox_messages(
+ self,
+ messages: Sequence[SimpleMessage],
+ app_to_tenant: dict[str, str],
+ tenant_plans: dict[str, SubscriptionPlan],
+ ) -> list[str]:
+ """
+ Filter messages that should be deleted based on sandbox plan expiration.
+
+ A message should be deleted if:
+ 1. It belongs to a sandbox tenant AND
+ 2. Either:
+ a) The tenant has no previous subscription (expiration_date == -1), OR
+ b) The subscription expired more than graceful_period_days ago
+
+ Args:
+ messages: List of message objects with id and app_id attributes
+ app_to_tenant: Mapping from app_id to tenant_id
+ tenant_plans: Mapping from tenant_id to subscription plan info
+
+ Returns:
+ List of message IDs that should be deleted
+ """
+ current_timestamp = self._current_timestamp
+ if current_timestamp is None:
+ current_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
+
+ sandbox_message_ids: list[str] = []
+ graceful_period_seconds = self._graceful_period_days * 24 * 60 * 60
+
+ for msg in messages:
+ # Get tenant_id for this message's app
+ tenant_id = app_to_tenant.get(msg.app_id)
+ if not tenant_id:
+ continue
+
+ # Skip tenant messages in whitelist
+ if tenant_id in self._tenant_whitelist:
+ continue
+
+ # Get subscription plan for this tenant
+ tenant_plan = tenant_plans.get(tenant_id)
+ if not tenant_plan:
+ continue
+
+ plan = str(tenant_plan["plan"])
+ expiration_date = int(tenant_plan["expiration_date"])
+
+ # Only process sandbox plans
+ if plan != CloudPlan.SANDBOX:
+ continue
+
+ # Case 1: No previous subscription (-1 means never had a paid subscription)
+ if expiration_date == -1:
+ sandbox_message_ids.append(msg.id)
+ continue
+
+ # Case 2: Subscription expired beyond grace period
+ if current_timestamp - expiration_date > graceful_period_seconds:
+ sandbox_message_ids.append(msg.id)
+
+ return sandbox_message_ids
+
+
+def create_message_clean_policy(
+ graceful_period_days: int = 21,
+ current_timestamp: int | None = None,
+) -> MessagesCleanPolicy:
+ """
+ Factory function to create the appropriate message clean policy.
+
+ Determines which policy to use based on BILLING_ENABLED configuration:
+ - If BILLING_ENABLED is True: returns BillingSandboxPolicy
+ - If BILLING_ENABLED is False: returns BillingDisabledPolicy
+
+ Args:
+ graceful_period_days: Grace period in days after subscription expiration (default: 21)
+ current_timestamp: Current Unix timestamp for testing (default: None, uses current time)
+ """
+ if not dify_config.BILLING_ENABLED:
+ logger.info("create_message_clean_policy: billing disabled, using BillingDisabledPolicy")
+ return BillingDisabledPolicy()
+
+ # Billing enabled - fetch whitelist from BillingService
+ tenant_whitelist = BillingService.get_expired_subscription_cleanup_whitelist()
+ plan_provider = BillingService.get_plan_bulk_with_cache
+
+ logger.info(
+ "create_message_clean_policy: billing enabled, using BillingSandboxPolicy "
+ "(graceful_period_days=%s, whitelist=%s)",
+ graceful_period_days,
+ tenant_whitelist,
+ )
+
+ return BillingSandboxPolicy(
+ plan_provider=plan_provider,
+ graceful_period_days=graceful_period_days,
+ tenant_whitelist=tenant_whitelist,
+ current_timestamp=current_timestamp,
+ )
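A minimal sketch of exercising the sandbox policy with a stubbed plan provider. It assumes `SubscriptionPlan` can be treated as a mapping with `plan` and `expiration_date` keys, as the lookups above suggest; the tenant, app, and message values are illustrative:

```python
import datetime

from enums.cloud_plan import CloudPlan
from services.retention.conversation.messages_clean_policy import BillingSandboxPolicy, SimpleMessage


def fake_plans(tenant_ids):
    # Stub plan provider: every tenant is on sandbox with no prior subscription.
    return {tid: {"plan": CloudPlan.SANDBOX, "expiration_date": -1} for tid in tenant_ids}


policy = BillingSandboxPolicy(plan_provider=fake_plans, graceful_period_days=21)
messages = [SimpleMessage(id="m1", app_id="app-1", created_at=datetime.datetime(2024, 1, 1))]
assert policy.filter_message_ids(messages, {"app-1": "tenant-a"}) == ["m1"]
```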
diff --git a/api/services/retention/conversation/messages_clean_service.py b/api/services/retention/conversation/messages_clean_service.py
new file mode 100644
index 0000000000..3ca5d82860
--- /dev/null
+++ b/api/services/retention/conversation/messages_clean_service.py
@@ -0,0 +1,334 @@
+import datetime
+import logging
+import random
+from collections.abc import Sequence
+from typing import cast
+
+from sqlalchemy import delete, select
+from sqlalchemy.engine import CursorResult
+from sqlalchemy.orm import Session
+
+from extensions.ext_database import db
+from models.model import (
+ App,
+ AppAnnotationHitHistory,
+ DatasetRetrieverResource,
+ Message,
+ MessageAgentThought,
+ MessageAnnotation,
+ MessageChain,
+ MessageFeedback,
+ MessageFile,
+)
+from models.web import SavedMessage
+from services.retention.conversation.messages_clean_policy import (
+ MessagesCleanPolicy,
+ SimpleMessage,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class MessagesCleanService:
+ """
+ Service for cleaning expired messages based on retention policies.
+
+ When billing is disabled (non-cloud editions), all messages in the time range are deleted.
+ When billing is enabled, only messages from sandbox-plan tenants are deleted (with whitelist and grace-period support).
+ """
+
+ def __init__(
+ self,
+ policy: MessagesCleanPolicy,
+ end_before: datetime.datetime,
+ start_from: datetime.datetime | None = None,
+ batch_size: int = 1000,
+ dry_run: bool = False,
+ ) -> None:
+ """
+ Initialize the service with cleanup parameters.
+
+ Args:
+ policy: The policy that determines which messages to delete
+ end_before: End time (exclusive) of the range
+ start_from: Optional start time (inclusive) of the range
+ batch_size: Number of messages to process per batch
+ dry_run: Whether to perform a dry run (no actual deletion)
+ """
+ self._policy = policy
+ self._end_before = end_before
+ self._start_from = start_from
+ self._batch_size = batch_size
+ self._dry_run = dry_run
+
+ @classmethod
+ def from_time_range(
+ cls,
+ policy: MessagesCleanPolicy,
+ start_from: datetime.datetime,
+ end_before: datetime.datetime,
+ batch_size: int = 1000,
+ dry_run: bool = False,
+ ) -> "MessagesCleanService":
+ """
+ Create a service instance for cleaning messages within a specific time range.
+
+ Time range is [start_from, end_before).
+
+ Args:
+ policy: The policy that determines which messages to delete
+ start_from: Start time (inclusive) of the range
+ end_before: End time (exclusive) of the range
+ batch_size: Number of messages to process per batch
+ dry_run: Whether to perform a dry run (no actual deletion)
+
+ Returns:
+ MessagesCleanService instance
+
+ Raises:
+ ValueError: If start_from >= end_before or invalid parameters
+ """
+ if start_from >= end_before:
+ raise ValueError(f"start_from ({start_from}) must be less than end_before ({end_before})")
+
+ if batch_size <= 0:
+ raise ValueError(f"batch_size ({batch_size}) must be greater than 0")
+
+ logger.info(
+ "clean_messages: start_from=%s, end_before=%s, batch_size=%s, policy=%s",
+ start_from,
+ end_before,
+ batch_size,
+ policy.__class__.__name__,
+ )
+
+ return cls(
+ policy=policy,
+ end_before=end_before,
+ start_from=start_from,
+ batch_size=batch_size,
+ dry_run=dry_run,
+ )
+
+ @classmethod
+ def from_days(
+ cls,
+ policy: MessagesCleanPolicy,
+ days: int = 30,
+ batch_size: int = 1000,
+ dry_run: bool = False,
+ ) -> "MessagesCleanService":
+ """
+ Create a service instance for cleaning messages older than specified days.
+
+ Args:
+ policy: The policy that determines which messages to delete
+ days: Number of days to look back from now
+ batch_size: Number of messages to process per batch
+ dry_run: Whether to perform a dry run (no actual deletion)
+
+ Returns:
+ MessagesCleanService instance
+
+ Raises:
+ ValueError: If invalid parameters
+ """
+ if days < 0:
+ raise ValueError(f"days ({days}) must be greater than or equal to 0")
+
+ if batch_size <= 0:
+ raise ValueError(f"batch_size ({batch_size}) must be greater than 0")
+
+ end_before = datetime.datetime.now() - datetime.timedelta(days=days)
+
+ logger.info(
+ "clean_messages: days=%s, end_before=%s, batch_size=%s, policy=%s",
+ days,
+ end_before,
+ batch_size,
+ policy.__class__.__name__,
+ )
+
+ return cls(policy=policy, end_before=end_before, start_from=None, batch_size=batch_size, dry_run=dry_run)
+
+ def run(self) -> dict[str, int]:
+ """
+ Execute the message cleanup operation.
+
+ Returns:
+ Dict with statistics: batches, total_messages, filtered_messages, total_deleted
+ """
+ return self._clean_messages_by_time_range()
+
+ def _clean_messages_by_time_range(self) -> dict[str, int]:
+ """
+ Clean messages within a time range using cursor-based pagination.
+
+ Time range is [start_from, end_before)
+
+ Steps:
+ 1. Iterate messages using cursor pagination (by created_at, id)
+ 2. Query app_id -> tenant_id mapping
+ 3. Delegate to policy to determine which messages to delete
+ 4. Batch delete messages and their relations
+
+ Returns:
+ Dict with statistics: batches, total_messages, filtered_messages, total_deleted
+ """
+ stats = {
+ "batches": 0,
+ "total_messages": 0,
+ "filtered_messages": 0,
+ "total_deleted": 0,
+ }
+
+ # Cursor-based pagination using (created_at, id) to avoid infinite loops
+ # and ensure proper ordering with time-based filtering
+ _cursor: tuple[datetime.datetime, str] | None = None
+
+ logger.info(
+ "clean_messages: start cleaning messages (dry_run=%s), start_from=%s, end_before=%s",
+ self._dry_run,
+ self._start_from,
+ self._end_before,
+ )
+
+ while True:
+ stats["batches"] += 1
+
+ # Step 1: Fetch a batch of messages using cursor
+ with Session(db.engine, expire_on_commit=False) as session:
+ msg_stmt = (
+ select(Message.id, Message.app_id, Message.created_at)
+ .where(Message.created_at < self._end_before)
+ .order_by(Message.created_at, Message.id)
+ .limit(self._batch_size)
+ )
+
+ if self._start_from:
+ msg_stmt = msg_stmt.where(Message.created_at >= self._start_from)
+
+ # Apply cursor condition: (created_at, id) > (last_created_at, last_message_id)
+ # This translates to:
+ # created_at > last_created_at OR (created_at = last_created_at AND id > last_message_id)
+ if _cursor:
+ # Continuing from previous batch
+ msg_stmt = msg_stmt.where(
+ (Message.created_at > _cursor[0])
+ | ((Message.created_at == _cursor[0]) & (Message.id > _cursor[1]))
+ )
+
+ raw_messages = list(session.execute(msg_stmt).all())
+ messages = [
+ SimpleMessage(id=msg_id, app_id=app_id, created_at=msg_created_at)
+ for msg_id, app_id, msg_created_at in raw_messages
+ ]
+
+ # Track total messages fetched across all batches
+ stats["total_messages"] += len(messages)
+
+ if not messages:
+ logger.info("clean_messages (batch %s): no more messages to process", stats["batches"])
+ break
+
+ # Update cursor to the last message's (created_at, id)
+ _cursor = (messages[-1].created_at, messages[-1].id)
+
+ # Step 2: Extract app_ids and query tenant_ids
+ app_ids = list({msg.app_id for msg in messages})
+
+ if not app_ids:
+ logger.info("clean_messages (batch %s): no app_ids found, skip", stats["batches"])
+ continue
+
+ app_stmt = select(App.id, App.tenant_id).where(App.id.in_(app_ids))
+ apps = list(session.execute(app_stmt).all())
+
+ if not apps:
+ logger.info("clean_messages (batch %s): no apps found, skip", stats["batches"])
+ continue
+
+ # Build app_id -> tenant_id mapping
+ app_to_tenant: dict[str, str] = {app.id: app.tenant_id for app in apps}
+
+ # Step 3: Delegate to policy to determine which messages to delete
+ message_ids_to_delete = self._policy.filter_message_ids(messages, app_to_tenant)
+
+ if not message_ids_to_delete:
+ logger.info("clean_messages (batch %s): no messages to delete, skip", stats["batches"])
+ continue
+
+ stats["filtered_messages"] += len(message_ids_to_delete)
+
+ # Step 4: Batch delete messages and their relations
+ if not self._dry_run:
+ with Session(db.engine, expire_on_commit=False) as session:
+ # Delete related records first
+ self._batch_delete_message_relations(session, message_ids_to_delete)
+
+ # Delete messages
+ delete_stmt = delete(Message).where(Message.id.in_(message_ids_to_delete))
+ delete_result = cast(CursorResult, session.execute(delete_stmt))
+ messages_deleted = delete_result.rowcount
+ session.commit()
+
+ stats["total_deleted"] += messages_deleted
+
+ logger.info(
+ "clean_messages (batch %s): processed %s messages, deleted %s messages",
+ stats["batches"],
+ len(messages),
+ messages_deleted,
+ )
+ else:
+ # Log random sample of message IDs that would be deleted (up to 10)
+ sample_size = min(10, len(message_ids_to_delete))
+ sampled_ids = random.sample(list(message_ids_to_delete), sample_size)
+
+ logger.info(
+ "clean_messages (batch %s, dry_run): would delete %s messages, sampling %s ids:",
+ stats["batches"],
+ len(message_ids_to_delete),
+ sample_size,
+ )
+ for msg_id in sampled_ids:
+ logger.info("clean_messages (batch %s, dry_run) sample: message_id=%s", stats["batches"], msg_id)
+
+ logger.info(
+ "clean_messages completed: total batches: %s, total messages: %s, filtered messages: %s, total deleted: %s",
+ stats["batches"],
+ stats["total_messages"],
+ stats["filtered_messages"],
+ stats["total_deleted"],
+ )
+
+ return stats
+
+ @staticmethod
+ def _batch_delete_message_relations(session: Session, message_ids: Sequence[str]) -> None:
+ """
+ Batch delete all related records for given message IDs.
+
+ Args:
+ session: Database session
+ message_ids: List of message IDs to delete relations for
+ """
+ if not message_ids:
+ return
+
+ # Delete all related records in batch
+ session.execute(delete(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)))
+
+ session.execute(delete(MessageAnnotation).where(MessageAnnotation.message_id.in_(message_ids)))
+
+ session.execute(delete(MessageChain).where(MessageChain.message_id.in_(message_ids)))
+
+ session.execute(delete(MessageAgentThought).where(MessageAgentThought.message_id.in_(message_ids)))
+
+ session.execute(delete(MessageFile).where(MessageFile.message_id.in_(message_ids)))
+
+ session.execute(delete(SavedMessage).where(SavedMessage.message_id.in_(message_ids)))
+
+ session.execute(delete(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id.in_(message_ids)))
+
+ session.execute(delete(DatasetRetrieverResource).where(DatasetRetrieverResource.message_id.in_(message_ids)))
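A short sketch of driving the service directly, mirroring what the `clean_messages` task does but in dry-run mode so nothing is deleted; the `days` and `batch_size` values are illustrative:

```python
from services.retention.conversation.messages_clean_policy import create_message_clean_policy
from services.retention.conversation.messages_clean_service import MessagesCleanService

# Dry run: messages are scanned and filtered, but no rows are deleted.
policy = create_message_clean_policy(graceful_period_days=21)
service = MessagesCleanService.from_days(policy=policy, days=30, batch_size=1000, dry_run=True)
stats = service.run()
print(stats["batches"], stats["total_messages"], stats["filtered_messages"], stats["total_deleted"])
```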
diff --git a/api/services/retention/workflow_run/__init__.py b/api/services/retention/workflow_run/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py b/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py
new file mode 100644
index 0000000000..c3e0dce399
--- /dev/null
+++ b/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py
@@ -0,0 +1,293 @@
+import datetime
+import logging
+from collections.abc import Iterable, Sequence
+
+import click
+from sqlalchemy.orm import Session, sessionmaker
+
+from configs import dify_config
+from enums.cloud_plan import CloudPlan
+from extensions.ext_database import db
+from models.workflow import WorkflowRun
+from repositories.api_workflow_run_repository import APIWorkflowRunRepository
+from repositories.factory import DifyAPIRepositoryFactory
+from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
+from services.billing_service import BillingService, SubscriptionPlan
+
+logger = logging.getLogger(__name__)
+
+
+class WorkflowRunCleanup:
+ def __init__(
+ self,
+ days: int,
+ batch_size: int,
+ start_from: datetime.datetime | None = None,
+ end_before: datetime.datetime | None = None,
+ workflow_run_repo: APIWorkflowRunRepository | None = None,
+ dry_run: bool = False,
+ ):
+ if (start_from is None) ^ (end_before is None):
+ raise ValueError("start_from and end_before must be both set or both omitted.")
+
+ computed_cutoff = datetime.datetime.now() - datetime.timedelta(days=days)
+ self.window_start = start_from
+ self.window_end = end_before or computed_cutoff
+
+ if self.window_start and self.window_end <= self.window_start:
+ raise ValueError("end_before must be greater than start_from.")
+
+ if batch_size <= 0:
+ raise ValueError("batch_size must be greater than 0.")
+
+ self.batch_size = batch_size
+ self._cleanup_whitelist: set[str] | None = None
+ self.dry_run = dry_run
+ self.free_plan_grace_period_days = dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD
+ self.workflow_run_repo: APIWorkflowRunRepository
+ if workflow_run_repo:
+ self.workflow_run_repo = workflow_run_repo
+ else:
+ # Lazy import to avoid circular dependencies during module import
+ from repositories.factory import DifyAPIRepositoryFactory
+
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
+
+ def run(self) -> None:
+ click.echo(
+ click.style(
+ f"{'Inspecting' if self.dry_run else 'Cleaning'} workflow runs "
+ f"{'between ' + self.window_start.isoformat() + ' and ' if self.window_start else 'before '}"
+ f"{self.window_end.isoformat()} (batch={self.batch_size})",
+ fg="white",
+ )
+ )
+ if self.dry_run:
+ click.echo(click.style("Dry run mode enabled. No data will be deleted.", fg="yellow"))
+
+ total_runs_deleted = 0
+ total_runs_targeted = 0
+ related_totals = self._empty_related_counts() if self.dry_run else None
+ batch_index = 0
+ last_seen: tuple[datetime.datetime, str] | None = None
+
+ while True:
+ run_rows = self.workflow_run_repo.get_runs_batch_by_time_range(
+ start_from=self.window_start,
+ end_before=self.window_end,
+ last_seen=last_seen,
+ batch_size=self.batch_size,
+ )
+ if not run_rows:
+ break
+
+ batch_index += 1
+ last_seen = (run_rows[-1].created_at, run_rows[-1].id)
+ tenant_ids = {row.tenant_id for row in run_rows}
+ free_tenants = self._filter_free_tenants(tenant_ids)
+ free_runs = [row for row in run_rows if row.tenant_id in free_tenants]
+ paid_or_skipped = len(run_rows) - len(free_runs)
+
+ if not free_runs:
+ skipped_message = (
+ f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)"
+ )
+ click.echo(
+ click.style(
+ skipped_message,
+ fg="yellow",
+ )
+ )
+ continue
+
+ total_runs_targeted += len(free_runs)
+
+ if self.dry_run:
+ batch_counts = self.workflow_run_repo.count_runs_with_related(
+ free_runs,
+ count_node_executions=self._count_node_executions,
+ count_trigger_logs=self._count_trigger_logs,
+ )
+ if related_totals is not None:
+ for key in related_totals:
+ related_totals[key] += batch_counts.get(key, 0)
+ sample_ids = ", ".join(run.id for run in free_runs[:5])
+ click.echo(
+ click.style(
+ f"[batch #{batch_index}] would delete {len(free_runs)} runs "
+ f"(sample ids: {sample_ids}) and skip {paid_or_skipped} paid/unknown",
+ fg="yellow",
+ )
+ )
+ continue
+
+ try:
+ counts = self.workflow_run_repo.delete_runs_with_related(
+ free_runs,
+ delete_node_executions=self._delete_node_executions,
+ delete_trigger_logs=self._delete_trigger_logs,
+ )
+ except Exception:
+ logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0])
+ raise
+
+ total_runs_deleted += counts["runs"]
+ click.echo(
+ click.style(
+ f"[batch #{batch_index}] deleted runs: {counts['runs']} "
+ f"(nodes {counts['node_executions']}, offloads {counts['offloads']}, "
+ f"app_logs {counts['app_logs']}, trigger_logs {counts['trigger_logs']}, "
+ f"pauses {counts['pauses']}, pause_reasons {counts['pause_reasons']}); "
+ f"skipped {paid_or_skipped} paid/unknown",
+ fg="green",
+ )
+ )
+
+ if self.dry_run:
+ if self.window_start:
+ summary_message = (
+ f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
+ f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
+ )
+ else:
+ summary_message = (
+ f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
+ f"before {self.window_end.isoformat()}"
+ )
+ if related_totals is not None:
+ summary_message = f"{summary_message}; related records: {self._format_related_counts(related_totals)}"
+ summary_color = "yellow"
+ else:
+ if self.window_start:
+ summary_message = (
+ f"Cleanup complete. Deleted {total_runs_deleted} workflow runs "
+ f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
+ )
+ else:
+ summary_message = (
+ f"Cleanup complete. Deleted {total_runs_deleted} workflow runs before {self.window_end.isoformat()}"
+ )
+ summary_color = "white"
+
+ click.echo(click.style(summary_message, fg=summary_color))
+
+ def _filter_free_tenants(self, tenant_ids: Iterable[str]) -> set[str]:
+ tenant_id_list = list(tenant_ids)
+
+ if not dify_config.BILLING_ENABLED:
+ return set(tenant_id_list)
+
+ if not tenant_id_list:
+ return set()
+
+ cleanup_whitelist = self._get_cleanup_whitelist()
+
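+ # If the bulk billing lookup fails, fall back to an empty mapping so every tenant
+ # in this batch is treated as non-free and skipped rather than deleted.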
+ try:
+ bulk_info = BillingService.get_plan_bulk_with_cache(tenant_id_list)
+ except Exception:
+ bulk_info = {}
+ logger.exception("Failed to fetch billing plans in bulk for tenants: %s", tenant_id_list)
+
+ eligible_free_tenants: set[str] = set()
+ for tenant_id in tenant_id_list:
+ if tenant_id in cleanup_whitelist:
+ continue
+
+ info = bulk_info.get(tenant_id)
+ if info is None:
+ logger.warning("Missing billing info for tenant %s in bulk resp; treating as non-free", tenant_id)
+ continue
+
+ if info.get("plan") != CloudPlan.SANDBOX:
+ continue
+
+ if self._is_within_grace_period(tenant_id, info):
+ continue
+
+ eligible_free_tenants.add(tenant_id)
+
+ return eligible_free_tenants
+
+ def _expiration_datetime(self, tenant_id: str, expiration_value: int) -> datetime.datetime | None:
+ if expiration_value < 0:
+ return None
+
+ try:
+ return datetime.datetime.fromtimestamp(expiration_value, datetime.UTC)
+ except (OverflowError, OSError, ValueError):
+ logger.exception("Failed to parse expiration timestamp for tenant %s", tenant_id)
+ return None
+
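+ # A tenant whose subscription expired less than `free_plan_grace_period_days` ago
+ # is still inside the grace period and is excluded from cleanup.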
+ def _is_within_grace_period(self, tenant_id: str, info: SubscriptionPlan) -> bool:
+ if self.free_plan_grace_period_days <= 0:
+ return False
+
+ expiration_value = info.get("expiration_date", -1)
+ expiration_at = self._expiration_datetime(tenant_id, expiration_value)
+ if expiration_at is None:
+ return False
+
+ grace_deadline = expiration_at + datetime.timedelta(days=self.free_plan_grace_period_days)
+ return datetime.datetime.now(datetime.UTC) < grace_deadline
+
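+ # The whitelist is fetched once per run and cached on the instance; if the fetch
+ # fails, it falls back to an empty whitelist and cleanup proceeds without it.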
+ def _get_cleanup_whitelist(self) -> set[str]:
+ if self._cleanup_whitelist is not None:
+ return self._cleanup_whitelist
+
+ if not dify_config.BILLING_ENABLED:
+ self._cleanup_whitelist = set()
+ return self._cleanup_whitelist
+
+ try:
+ whitelist_ids = BillingService.get_expired_subscription_cleanup_whitelist()
+ except Exception:
+ logger.exception("Failed to fetch cleanup whitelist from billing service")
+ whitelist_ids = []
+
+ self._cleanup_whitelist = set(whitelist_ids)
+ return self._cleanup_whitelist
+
+ def _delete_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int:
+ trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
+ return trigger_repo.delete_by_run_ids(run_ids)
+
+ def _count_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int:
+ trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
+ return trigger_repo.count_by_run_ids(run_ids)
+
+ @staticmethod
+ def _empty_related_counts() -> dict[str, int]:
+ return {
+ "node_executions": 0,
+ "offloads": 0,
+ "app_logs": 0,
+ "trigger_logs": 0,
+ "pauses": 0,
+ "pause_reasons": 0,
+ }
+
+ @staticmethod
+ def _format_related_counts(counts: dict[str, int]) -> str:
+ return (
+ f"node_executions {counts['node_executions']}, "
+ f"offloads {counts['offloads']}, "
+ f"app_logs {counts['app_logs']}, "
+ f"trigger_logs {counts['trigger_logs']}, "
+ f"pauses {counts['pauses']}, "
+ f"pause_reasons {counts['pause_reasons']}"
+ )
+
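+ # Node executions are handled through their own repository, built against the same
+ # database bind as the caller's session.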
+ def _count_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]:
+ run_ids = [run.id for run in runs]
+ repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
+ session_maker=sessionmaker(bind=session.get_bind(), expire_on_commit=False)
+ )
+ return repo.count_by_runs(session, run_ids)
+
+ def _delete_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]:
+ run_ids = [run.id for run in runs]
+ repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
+ session_maker=sessionmaker(bind=session.get_bind(), expire_on_commit=False)
+ )
+ return repo.delete_by_runs(session, run_ids)
diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py
index 9bd797a45f..5ca0b63001 100644
--- a/api/services/webapp_auth_service.py
+++ b/api/services/webapp_auth_service.py
@@ -12,6 +12,7 @@ from libs.passport import PassportService
from libs.password import compare_password
from models import Account, AccountStatus
from models.model import App, EndUser, Site
+from services.account_service import AccountService
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
@@ -32,7 +33,7 @@ class WebAppAuthService:
@staticmethod
def authenticate(email: str, password: str) -> Account:
"""authenticate account with email and password"""
- account = db.session.query(Account).filter_by(email=email).first()
+ account = AccountService.get_account_by_email_with_case_fallback(email)
if not account:
raise AccountNotFoundError()
@@ -52,7 +53,7 @@ class WebAppAuthService:
@classmethod
def get_user_through_email(cls, email: str):
- account = db.session.query(Account).where(Account.email == email).first()
+ account = AccountService.get_account_by_email_with_case_fallback(email)
if not account:
return None
diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py
index 9407a2b3f0..70b0190231 100644
--- a/api/services/workflow_draft_variable_service.py
+++ b/api/services/workflow_draft_variable_service.py
@@ -15,7 +15,7 @@ from sqlalchemy.sql.expression import and_, or_
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
-from core.variables import Segment, StringSegment, Variable
+from core.variables import Segment, StringSegment, VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import (
ArrayFileSegment,
@@ -77,14 +77,14 @@ class DraftVarLoader(VariableLoader):
# Application ID for which variables are being loaded.
_app_id: str
_tenant_id: str
- _fallback_variables: Sequence[Variable]
+ _fallback_variables: Sequence[VariableBase]
def __init__(
self,
engine: Engine,
app_id: str,
tenant_id: str,
- fallback_variables: Sequence[Variable] | None = None,
+ fallback_variables: Sequence[VariableBase] | None = None,
):
self._engine = engine
self._app_id = app_id
@@ -94,12 +94,12 @@ class DraftVarLoader(VariableLoader):
def _selector_to_tuple(self, selector: Sequence[str]) -> tuple[str, str]:
return (selector[0], selector[1])
- def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
+ def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
if not selectors:
return []
- # Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding Variable instance.
- variable_by_selector: dict[tuple[str, str], Variable] = {}
+ # Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding variable instance.
+ variable_by_selector: dict[tuple[str, str], VariableBase] = {}
with Session(bind=self._engine, expire_on_commit=False) as session:
srv = WorkflowDraftVariableService(session)
@@ -145,7 +145,7 @@ class DraftVarLoader(VariableLoader):
return list(variable_by_selector.values())
- def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], Variable]:
+ def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], VariableBase]:
# This logic is closely tied to `WorkflowDraftVaribleService._try_offload_large_variable`
# and must remain synchronized with it.
# Ideally, these should be co-located for better maintainability.
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index b45a167b73..d8c3159178 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -13,8 +13,8 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.file import File
from core.repositories import DifyCoreRepositoryFactory
-from core.variables import Variable
-from core.variables.variables import VariableUnion
+from core.variables import VariableBase
+from core.variables.variables import Variable
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.errors import WorkflowNodeRunFailedError
@@ -198,8 +198,8 @@ class WorkflowService:
features: dict,
unique_hash: str | None,
account: Account,
- environment_variables: Sequence[Variable],
- conversation_variables: Sequence[Variable],
+ environment_variables: Sequence[VariableBase],
+ conversation_variables: Sequence[VariableBase],
) -> Workflow:
"""
Sync draft workflow
@@ -1044,7 +1044,7 @@ def _setup_variable_pool(
workflow: Workflow,
node_type: NodeType,
conversation_id: str,
- conversation_variables: list[Variable],
+ conversation_variables: list[VariableBase],
):
# Only inject system variables for START node type.
if node_type == NodeType.START or node_type.is_trigger_node:
@@ -1070,9 +1070,9 @@ def _setup_variable_pool(
system_variables=system_variable,
user_inputs=user_inputs,
environment_variables=workflow.environment_variables,
- # Based on the definition of `VariableUnion`,
- # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
- conversation_variables=cast(list[VariableUnion], conversation_variables), #
+ # Based on the definition of `Variable`,
+ # `VariableBase` instances can be safely used as `Variable` since they are compatible.
+ conversation_variables=cast(list[Variable], conversation_variables), #
)
return variable_pool
diff --git a/api/templates/invite_member_mail_template_en-US.html b/api/templates/invite_member_mail_template_en-US.html
index a07c5f4b16..7b296519f0 100644
--- a/api/templates/invite_member_mail_template_en-US.html
+++ b/api/templates/invite_member_mail_template_en-US.html
@@ -83,7 +83,30 @@
Dear {{ to }},
{{ inviter_name }} is pleased to invite you to join our workspace on Dify, a platform specifically designed for LLM application development. On Dify, you can explore, create, and collaborate to build and operate AI applications.
Click the button below to log in to Dify and join the workspace.
diff --git a/api/templates/register_email_when_account_exist_template_en-US.html b/api/templates/register_email_when_account_exist_template_en-US.html
index ac5042c274..e2bb99c989 100644
--- a/api/templates/register_email_when_account_exist_template_en-US.html
+++ b/api/templates/register_email_when_account_exist_template_en-US.html
@@ -115,7 +115,30 @@
We noticed you tried to sign up, but this email is already registered with an existing account.
Please log in here:
+ If the button doesn't work, copy and paste this link into your browser:
+
+ {{ login_url }}
+
+
+
If you forgot your password, you can reset it here: Reset Password
diff --git a/api/templates/register_email_when_account_exist_template_zh-CN.html b/api/templates/register_email_when_account_exist_template_zh-CN.html
index 326b58343a..6a5bbd135b 100644
--- a/api/templates/register_email_when_account_exist_template_zh-CN.html
+++ b/api/templates/register_email_when_account_exist_template_zh-CN.html
@@ -115,7 +115,30 @@
我们注意到您尝试注册,但此电子邮件已注册。
请在此登录:
diff --git a/api/templates/without-brand/invite_member_mail_template_en-US.html b/api/templates/without-brand/invite_member_mail_template_en-US.html
index f9157284fa..687ece617a 100644
--- a/api/templates/without-brand/invite_member_mail_template_en-US.html
+++ b/api/templates/without-brand/invite_member_mail_template_en-US.html
@@ -92,12 +92,34 @@
platform specifically designed for LLM application development. On {{application_title}}, you can explore,
create, and collaborate to build and operate AI applications.
Click the button below to log in to {{application_title}} and join the workspace.