)
+vi.mock('@/app/components/base/button', () => ({ children }: any) => <button>{children}</button>)
+
+// ✅ CORRECT: Import and use real base components
+import Loading from '@/app/components/base/loading'
+import Button from '@/app/components/base/button'
+// They will render normally in tests
+```
+
+### What TO Mock
+
+Only mock these categories:
+
+1. **API services** (`@/service/*`) - Network calls
+1. **Complex context providers** - When setup is too difficult
+1. **Third-party libraries with side effects** - `next/navigation`, external SDKs
+1. **i18n** - Always mock to return keys
+
+### Zustand Stores - DO NOT Mock Manually
+
+**Zustand is globally mocked** in `web/vitest.setup.ts`. Use real stores with `setState()`:
+
+```typescript
+// ✅ CORRECT: Use real store, set test state
+import { useAppStore } from '@/app/components/app/store'
+
+useAppStore.setState({ appDetail: { id: 'test', name: 'Test' } })
+render(<MyComponent />)
+
+// ❌ WRONG: Don't mock the store module
+vi.mock('@/app/components/app/store', () => ({ ... }))
+```
+
+See [Zustand Store Testing](#zustand-store-testing) section for full details.
+
+## Mock Placement
+
+| Location | Purpose |
+|----------|---------|
+| `web/vitest.setup.ts` | Global mocks shared by all tests (`react-i18next`, `next/image`, `zustand`) |
+| `web/__mocks__/zustand.ts` | Zustand mock implementation (auto-resets stores after each test) |
+| `web/__mocks__/` | Reusable mock factories shared across multiple test files |
+| Test file | Test-specific mocks, inline with `vi.mock()` |
+
+Modules are not mocked automatically. Use `vi.mock` in test files, or add global mocks in `web/vitest.setup.ts`.
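+
+For illustration, a minimal sketch of a test-file mock next to a reusable factory from `web/__mocks__/` (the service shape and factory fields here are assumptions for the example):
+
+```typescript
+// Test-specific mock: inline vi.mock() inside the test file
+vi.mock('@/service/api', () => ({
+  fetchData: vi.fn().mockResolvedValue({ data: [] }),
+}))
+
+// Reusable mock data: a shared factory that lives in web/__mocks__/
+import { createMockUser } from '@/__mocks__/data-factories'
+
+const user = createMockUser({ name: 'Alice' })
+```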
+
+**Note**: Zustand is special - it's globally mocked but you should NOT mock store modules manually. See [Zustand Store Testing](#zustand-store-testing).
+
+## Essential Mocks
+
+### 1. i18n (Auto-loaded via Global Mock)
+
+A global i18n mock is defined in `web/vitest.setup.ts` and is loaded automatically for every test run.
+
+The global mock provides:
+
+- `useTranslation` - returns translation keys with namespace prefix
+- `Trans` component - renders i18nKey and components
+- `useMixedTranslation` (from `@/app/components/plugins/marketplace/hooks`)
+- `useGetLanguage` (from `@/context/i18n`) - returns `'en-US'`
+
+**Default behavior**: Most tests should use the global mock (no local override needed).
+
+**For custom translations**: Use the helper function from `@/test/i18n-mock`:
+
+```typescript
+import { createReactI18nextMock } from '@/test/i18n-mock'
+
+vi.mock('react-i18next', () => createReactI18nextMock({
+ 'my.custom.key': 'Custom translation',
+ 'button.save': 'Save',
+}))
+```
+
+**Avoid**: Manually defining `useTranslation` mocks that just return the key - the global mock already does this.
+
+### 2. Next.js Router
+
+```typescript
+const mockPush = vi.fn()
+const mockReplace = vi.fn()
+
+vi.mock('next/navigation', () => ({
+  useRouter: () => ({
+    push: mockPush,
+    replace: mockReplace,
+    back: vi.fn(),
+    prefetch: vi.fn(),
+  }),
+  usePathname: () => '/current-path',
+  useSearchParams: () => new URLSearchParams('?key=value'),
+}))
+
+describe('Component', () => {
+  beforeEach(() => {
+    vi.clearAllMocks()
+  })
+
+  it('should navigate on click', () => {
+    render(<Component />)
+    fireEvent.click(screen.getByRole('button'))
+    expect(mockPush).toHaveBeenCalledWith('/expected-path')
+  })
+})
+```
+
+### 3. Portal Components (with Shared State)
+
+```typescript
+// ⚠️ Important: Use shared state for components that depend on each other
+let mockPortalOpenState = false
+
+vi.mock('@/app/components/base/portal-to-follow-elem', () => ({
+  PortalToFollowElem: ({ children, open, ...props }: any) => {
+    mockPortalOpenState = open || false // Update shared state
+    return <div {...props}>{children}</div>
+  },
+  PortalToFollowElemContent: ({ children }: any) => {
+    // ✅ Matches actual: returns null when portal is closed
+    if (!mockPortalOpenState) return null
+    return <div>{children}</div>
+  },
+}))
+
+describe('Component', () => {
+  beforeEach(() => {
+    vi.clearAllMocks()
+    mockPortalOpenState = false // ✅ Reset shared state
+  })
+})
+```
+
+### 4. API Service Mocks
+
+```typescript
+import * as api from '@/service/api'
+
+vi.mock('@/service/api')
+
+const mockedApi = vi.mocked(api)
+
+describe('Component', () => {
+  beforeEach(() => {
+    vi.clearAllMocks()
+
+    // Setup default mock implementation
+    mockedApi.fetchData.mockResolvedValue({ data: [] })
+  })
+
+  it('should show data on success', async () => {
+    mockedApi.fetchData.mockResolvedValue({ data: [{ id: 1 }] })
+
+    render(<Component />)
+
+    await waitFor(() => {
+      expect(screen.getByText('1')).toBeInTheDocument()
+    })
+  })
+
+  it('should show error on failure', async () => {
+    mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
+
+    render(<Component />)
+
+    await waitFor(() => {
+      expect(screen.getByText(/error/i)).toBeInTheDocument()
+    })
+  })
+})
+```
+
+### 5. HTTP Mocking with Nock
+
+```typescript
+import nock from 'nock'
+
+const GITHUB_HOST = 'https://api.github.com'
+const GITHUB_PATH = '/repos/owner/repo'
+
+const mockGithubApi = (status: number, body: Record<string, unknown>, delayMs = 0) => {
+  return nock(GITHUB_HOST)
+    .get(GITHUB_PATH)
+    .delay(delayMs)
+    .reply(status, body)
+}
+
+describe('GithubComponent', () => {
+  afterEach(() => {
+    nock.cleanAll()
+  })
+
+  it('should display repo info', async () => {
+    mockGithubApi(200, { name: 'dify', stars: 1000 })
+
+    render(<GithubComponent />)
+
+    await waitFor(() => {
+      expect(screen.getByText('dify')).toBeInTheDocument()
+    })
+  })
+
+  it('should handle API error', async () => {
+    mockGithubApi(500, { message: 'Server error' })
+
+    render(<GithubComponent />)
+
+    await waitFor(() => {
+      expect(screen.getByText(/error/i)).toBeInTheDocument()
+    })
+  })
+})
+```
+
+### 6. Context Providers
+
+```typescript
+import { ProviderContext } from '@/context/provider-context'
+import { createMockProviderContextValue, createMockPlan } from '@/__mocks__/provider-context'
+
+describe('Component with Context', () => {
+  it('should render for free plan', () => {
+    const mockContext = createMockPlan('sandbox')
+
+    render(
+      <ProviderContext.Provider value={mockContext}>
+        <Component />
+      </ProviderContext.Provider>,
+    )
+
+    expect(screen.getByText('Upgrade')).toBeInTheDocument()
+  })
+
+  it('should render for pro plan', () => {
+    const mockContext = createMockPlan('professional')
+
+    render(
+      <ProviderContext.Provider value={mockContext}>
+        <Component />
+      </ProviderContext.Provider>,
+    )
+
+    expect(screen.queryByText('Upgrade')).not.toBeInTheDocument()
+  })
+})
+```
+
+### 7. React Query
+
+```typescript
+import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
+
+const createTestQueryClient = () => new QueryClient({
+  defaultOptions: {
+    queries: { retry: false },
+    mutations: { retry: false },
+  },
+})
+
+const renderWithQueryClient = (ui: React.ReactElement) => {
+  const queryClient = createTestQueryClient()
+  return render(
+    <QueryClientProvider client={queryClient}>
+      {ui}
+    </QueryClientProvider>,
+  )
+}
+```
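+
+A usage sketch for the helper above — `ItemList` and the asserted text are hypothetical stand-ins for a component that fetches data with `useQuery`:
+
+```typescript
+it('should render fetched items', async () => {
+  renderWithQueryClient(<ItemList />)
+
+  await waitFor(() => {
+    expect(screen.getByText('Item 1')).toBeInTheDocument()
+  })
+})
+```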
+
+## Mock Best Practices
+
+### ✅ DO
+
+1. **Use real base components** - Import from `@/app/components/base/` directly
+1. **Use real project components** - Prefer importing over mocking
+1. **Use real Zustand stores** - Set test state via `store.setState()`
+1. **Reset mocks in `beforeEach`**, not `afterEach`
+1. **Match actual component behavior** in mocks (when mocking is necessary)
+1. **Use factory functions** for complex mock data
+1. **Import actual types** for type safety
+1. **Reset shared mock state** in `beforeEach`
+
+### ❌ DON'T
+
+1. **Don't mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
+1. **Don't mock Zustand store modules** - Use real stores with `setState()`
+1. Don't mock components you can import directly
+1. Don't create overly simplified mocks that miss conditional logic
+1. Don't forget to clean up nock after each test
+1. Don't use `any` types in mocks without necessity
+
+### Mock Decision Tree
+
+```
+Need to use a component in test?
+│
+├─ Is it from @/app/components/base/*?
+│ └─ YES → Import real component, DO NOT mock
+│
+├─ Is it a project component?
+│ └─ YES → Prefer importing real component
+│ Only mock if setup is extremely complex
+│
+├─ Is it an API service (@/service/*)?
+│ └─ YES → Mock it
+│
+├─ Is it a third-party lib with side effects?
+│ └─ YES → Mock it (next/navigation, external SDKs)
+│
+├─ Is it a Zustand store?
+│ └─ YES → DO NOT mock the module!
+│ Use real store + setState() to set test state
+│ (Global mock handles auto-reset)
+│
+└─ Is it i18n?
+ └─ YES → Uses shared mock (auto-loaded). Override only for custom translations
+```
+
+## Zustand Store Testing
+
+### Global Zustand Mock (Auto-loaded)
+
+Zustand is globally mocked in `web/vitest.setup.ts` following the [official Zustand testing guide](https://zustand.docs.pmnd.rs/guides/testing). The mock in `web/__mocks__/zustand.ts` provides:
+
+- Real store behavior with `getState()`, `setState()`, `subscribe()` methods
+- Automatic store reset after each test via `afterEach`
+- Proper test isolation between tests
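+
+For reference, a condensed sketch of the pattern described in the official guide — the actual `web/__mocks__/zustand.ts` may differ in detail:
+
+```typescript
+// web/__mocks__/zustand.ts (condensed sketch of the official guide's pattern)
+import { act } from '@testing-library/react'
+import { afterEach, vi } from 'vitest'
+import type * as zustand from 'zustand'
+
+export * from 'zustand'
+
+const { create: actualCreate } = await vi.importActual<typeof zustand>('zustand')
+
+// Collect a reset function for every store created during the test run
+export const storeResetFns = new Set<() => void>()
+
+const createUncurried = <T>(stateCreator: zustand.StateCreator<T>) => {
+  const store = actualCreate(stateCreator)
+  const initialState = store.getInitialState()
+  storeResetFns.add(() => store.setState(initialState, true))
+  return store
+}
+
+// Support both create(fn) and the curried create()(fn) form
+export const create = (<T>(stateCreator: zustand.StateCreator<T>) =>
+  typeof stateCreator === 'function' ? createUncurried(stateCreator) : createUncurried
+) as typeof zustand.create
+
+// Reset every store after each test so tests stay isolated
+afterEach(() => {
+  act(() => {
+    storeResetFns.forEach(resetFn => resetFn())
+  })
+})
+```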
+
+### ✅ Recommended: Use Real Stores (Official Best Practice)
+
+**DO NOT mock store modules manually.** Import and use the real store, then use `setState()` to set test state:
+
+```typescript
+// ✅ CORRECT: Use real store with setState
+import { useAppStore } from '@/app/components/app/store'
+
+describe('MyComponent', () => {
+  it('should render app details', () => {
+    // Arrange: Set test state via setState
+    useAppStore.setState({
+      appDetail: {
+        id: 'test-app',
+        name: 'Test App',
+        mode: 'chat',
+      },
+    })
+
+    // Act
+    render(<MyComponent />)
+
+    // Assert
+    expect(screen.getByText('Test App')).toBeInTheDocument()
+    // Can also verify store state directly
+    expect(useAppStore.getState().appDetail?.name).toBe('Test App')
+  })
+
+  // No cleanup needed - global mock auto-resets after each test
+})
+```
+
+### ❌ Avoid: Manual Store Module Mocking
+
+Manual mocking conflicts with the global Zustand mock and loses store functionality:
+
+```typescript
+// ❌ WRONG: Don't mock the store module
+vi.mock('@/app/components/app/store', () => ({
+ useStore: (selector) => mockSelector(selector), // Missing getState, setState!
+}))
+
+// ❌ WRONG: This conflicts with global zustand mock
+vi.mock('@/app/components/workflow/store', () => ({
+ useWorkflowStore: vi.fn(() => mockState),
+}))
+```
+
+**Problems with manual mocking:**
+
+1. Loses `getState()`, `setState()`, `subscribe()` methods
+1. Conflicts with global Zustand mock behavior
+1. Requires manual maintenance of store API
+1. Tests don't reflect actual store behavior
+
+### When Manual Store Mocking is Necessary
+
+In rare cases where the store has complex initialization or side effects, you can mock it, but ensure you provide the full store API:
+
+```typescript
+// If you MUST mock (rare), include full store API
+const mockStore = {
+ appDetail: { id: 'test', name: 'Test' },
+ setAppDetail: vi.fn(),
+}
+
+vi.mock('@/app/components/app/store', () => ({
+ useStore: Object.assign(
+ (selector: (state: typeof mockStore) => unknown) => selector(mockStore),
+ {
+ getState: () => mockStore,
+ setState: vi.fn(),
+ subscribe: vi.fn(),
+ },
+ ),
+}))
+```
+
+### Store Testing Decision Tree
+
+```
+Need to test a component using Zustand store?
+│
+├─ Can you use the real store?
+│ └─ YES → Use real store + setState (RECOMMENDED)
+│ useAppStore.setState({ ... })
+│
+├─ Does the store have complex initialization/side effects?
+│ └─ YES → Consider mocking, but include full API
+│ (getState, setState, subscribe)
+│
+└─ Are you testing the store itself (not a component)?
+ └─ YES → Test store directly with getState/setState
+ const store = useMyStore
+ store.setState({ count: 0 })
+ store.getState().increment()
+ expect(store.getState().count).toBe(1)
+```
+
+### Example: Testing Store Actions
+
+```typescript
+import { useCounterStore } from '@/stores/counter'
+
+describe('Counter Store', () => {
+ it('should increment count', () => {
+ // Initial state (auto-reset by global mock)
+ expect(useCounterStore.getState().count).toBe(0)
+
+ // Call action
+ useCounterStore.getState().increment()
+
+ // Verify state change
+ expect(useCounterStore.getState().count).toBe(1)
+ })
+
+ it('should reset to initial state', () => {
+ // Set some state
+ useCounterStore.setState({ count: 100 })
+ expect(useCounterStore.getState().count).toBe(100)
+
+ // After this test, global mock will reset to initial state
+ })
+})
+```
+
+## Factory Function Pattern
+
+```typescript
+// __mocks__/data-factories.ts
+import type { User, Project } from '@/types'
+
+export const createMockUser = (overrides: Partial<User> = {}): User => ({
+  id: 'user-1',
+  name: 'Test User',
+  email: 'test@example.com',
+  role: 'member',
+  createdAt: new Date().toISOString(),
+  ...overrides,
+})
+
+export const createMockProject = (overrides: Partial<Project> = {}): Project => ({
+  id: 'project-1',
+  name: 'Test Project',
+  description: 'A test project',
+  owner: createMockUser(),
+  members: [],
+  createdAt: new Date().toISOString(),
+  ...overrides,
+})
+
+// Usage in tests
+it('should display project owner', () => {
+  const project = createMockProject({
+    owner: createMockUser({ name: 'John Doe' }),
+  })
+
+  render(<ProjectCard project={project} />) // ProjectCard is an illustrative component
+  expect(screen.getByText('John Doe')).toBeInTheDocument()
+})
+```
diff --git a/.agents/skills/frontend-testing/references/workflow.md b/.agents/skills/frontend-testing/references/workflow.md
new file mode 100644
index 0000000000..bc4ed8285a
--- /dev/null
+++ b/.agents/skills/frontend-testing/references/workflow.md
@@ -0,0 +1,269 @@
+# Testing Workflow Guide
+
+This guide defines the workflow for generating tests, especially for complex components or directories with multiple files.
+
+## Scope Clarification
+
+This guide addresses **multi-file workflow** (how to process multiple test files). For coverage requirements within a single test file, see `web/docs/test.md` § Coverage Goals.
+
+| Scope | Rule |
+|-------|------|
+| **Single file** | Complete coverage in one generation (100% function, >95% branch) |
+| **Multi-file directory** | Process one file at a time, verify each before proceeding |
+
+## ⚠️ Critical Rule: Incremental Approach for Multi-File Testing
+
+When testing a **directory with multiple files**, **NEVER generate all test files at once.** Use an incremental, verify-as-you-go approach.
+
+### Why Incremental?
+
+| Batch Approach (❌) | Incremental Approach (✅) |
+|---------------------|---------------------------|
+| Generate 5+ tests at once | Generate 1 test at a time |
+| Run tests only at the end | Run test immediately after each file |
+| Multiple failures compound | Single point of failure, easy to debug |
+| Hard to identify root cause | Clear cause-effect relationship |
+| Mock issues affect many files | Mock issues caught early |
+| Messy git history | Clean, atomic commits possible |
+
+## Single File Workflow
+
+When testing a **single component, hook, or utility**:
+
+```
+1. Read source code completely
+2. Run `pnpm analyze-component <path>` (if available)
+3. Check complexity score and features detected
+4. Write the test file
+5. Run test: `pnpm test <file>.spec.tsx`
+6. Fix any failures
+7. Verify coverage meets goals (100% function, >95% branch)
+```
+
+## Directory/Multi-File Workflow (MUST FOLLOW)
+
+When testing a **directory or multiple files**, follow this strict workflow:
+
+### Step 1: Analyze and Plan
+
+1. **List all files** that need tests in the directory
+1. **Categorize by complexity**:
+ - 🟢 **Simple**: Utility functions, simple hooks, presentational components
+ - 🟡 **Medium**: Components with state, effects, or event handlers
+ - 🔴 **Complex**: Components with API calls, routing, or many dependencies
+1. **Order by dependency**: Test dependencies before dependents
+1. **Create a todo list** to track progress
+
+### Step 2: Determine Processing Order
+
+Process files in this recommended order:
+
+```
+1. Utility functions (simplest, no React)
+2. Custom hooks (isolated logic)
+3. Simple presentational components (few/no props)
+4. Medium complexity components (state, effects)
+5. Complex components (API, routing, many deps)
+6. Container/index components (integration tests - last)
+```
+
+**Rationale**:
+
+- Simpler files help establish mock patterns
+- Hooks used by components should be tested first
+- Integration tests (index files) depend on child components working
+
+### Step 3: Process Each File Incrementally
+
+**For EACH file in the ordered list:**
+
+```
+┌─────────────────────────────────────────────┐
+│ 1. Write test file │
+│ 2. Run: pnpm test .spec.tsx │
+│ 3. If FAIL → Fix immediately, re-run │
+│ 4. If PASS → Mark complete in todo list │
+│ 5. ONLY THEN proceed to next file │
+└─────────────────────────────────────────────┘
+```
+
+**DO NOT proceed to the next file until the current one passes.**
+
+### Step 4: Final Verification
+
+After all individual tests pass:
+
+```bash
+# Run all tests in the directory together
+pnpm test path/to/directory/
+
+# Check coverage
+pnpm test:coverage path/to/directory/
+```
+
+## Component Complexity Guidelines
+
+Use `pnpm analyze-component <path>` to assess complexity before testing.
+
+### 🔴 Very Complex Components (Complexity > 50)
+
+**Consider refactoring BEFORE testing:**
+
+- Break component into smaller, testable pieces
+- Extract complex logic into custom hooks
+- Separate container and presentational layers
+
+**If testing as-is:**
+
+- Use integration tests for complex workflows
+- Use `test.each()` for data-driven testing
+- Multiple `describe` blocks for organization
+- Consider testing major sections separately
+
+### 🟡 Medium Complexity (Complexity 30-50)
+
+- Group related tests in `describe` blocks
+- Test integration scenarios between internal parts
+- Focus on state transitions and side effects
+- Use helper functions to reduce test complexity
+
+### 🟢 Simple Components (Complexity < 30)
+
+- Standard test structure
+- Focus on props, rendering, and edge cases
+- Usually straightforward to test
+
+### 📏 Large Files (500+ lines)
+
+Regardless of complexity score:
+
+- **Strongly consider refactoring** before testing
+- If testing as-is, test major sections separately
+- Create helper functions for test setup
+- May need multiple test files
+
+## Todo List Format
+
+When testing multiple files, use a todo list like this:
+
+```
+Testing: path/to/directory/
+
+Ordered by complexity (simple → complex):
+
+☐ utils/helper.ts [utility, simple]
+☐ hooks/use-custom-hook.ts [hook, simple]
+☐ empty-state.tsx [component, simple]
+☐ item-card.tsx [component, medium]
+☐ list.tsx [component, complex]
+☐ index.tsx [integration]
+
+Progress: 0/6 complete
+```
+
+Update status as you complete each:
+
+- ☐ → ⏳ (in progress)
+- ⏳ → ✅ (complete and verified)
+- ⏳ → ❌ (blocked, needs attention)
+
+## When to Stop and Verify
+
+**Always run tests after:**
+
+- Completing a test file
+- Making changes to fix a failure
+- Modifying shared mocks
+- Updating test utilities or helpers
+
+**Signs you should pause:**
+
+- More than 2 consecutive test failures
+- Mock-related errors appearing
+- Unclear why a test is failing
+- Test passing but coverage unexpectedly low
+
+## Common Pitfalls to Avoid
+
+### ❌ Don't: Generate Everything First
+
+```
+# BAD: Writing all files then testing
+Write component-a.spec.tsx
+Write component-b.spec.tsx
+Write component-c.spec.tsx
+Write component-d.spec.tsx
+Run pnpm test ← Multiple failures, hard to debug
+```
+
+### ✅ Do: Verify Each Step
+
+```
+# GOOD: Incremental with verification
+Write component-a.spec.tsx
+Run pnpm test component-a.spec.tsx ✅
+Write component-b.spec.tsx
+Run pnpm test component-b.spec.tsx ✅
+...continue...
+```
+
+### ❌ Don't: Skip Verification for "Simple" Components
+
+Even simple components can have:
+
+- Import errors
+- Missing mock setup
+- Incorrect assumptions about props
+
+**Always verify, regardless of perceived simplicity.**
+
+### ❌ Don't: Continue When Tests Fail
+
+Failing tests compound:
+
+- A mock issue in file A affects files B, C, D
+- Fixing A later requires revisiting all dependent tests
+- Time wasted on debugging cascading failures
+
+**Fix failures immediately before proceeding.**
+
+## Integration with Claude's Todo Feature
+
+When using Claude for multi-file testing:
+
+1. **Ask Claude to create a todo list** before starting
+1. **Request one file at a time** or ensure Claude processes incrementally
+1. **Verify each test passes** before asking for the next
+1. **Mark todos complete** as you progress
+
+Example prompt:
+
+```
+Test all components in `path/to/directory/`.
+First, analyze the directory and create a todo list ordered by complexity.
+Then, process ONE file at a time, waiting for my confirmation that tests pass
+before proceeding to the next.
+```
+
+## Summary Checklist
+
+Before starting multi-file testing:
+
+- [ ] Listed all files needing tests
+- [ ] Ordered by complexity (simple → complex)
+- [ ] Created todo list for tracking
+- [ ] Understand dependencies between files
+
+During testing:
+
+- [ ] Processing ONE file at a time
+- [ ] Running tests after EACH file
+- [ ] Fixing failures BEFORE proceeding
+- [ ] Updating todo list progress
+
+After completion:
+
+- [ ] All individual tests pass
+- [ ] Full directory test run passes
+- [ ] Coverage goals met
+- [ ] Todo list shows all complete
diff --git a/.agents/skills/orpc-contract-first/SKILL.md b/.agents/skills/orpc-contract-first/SKILL.md
new file mode 100644
index 0000000000..4e3bfc7a37
--- /dev/null
+++ b/.agents/skills/orpc-contract-first/SKILL.md
@@ -0,0 +1,46 @@
+---
+name: orpc-contract-first
+description: Guide for implementing oRPC contract-first API patterns in Dify frontend. Triggers when creating new API contracts, adding service endpoints, integrating TanStack Query with typed contracts, or migrating legacy service calls to oRPC. Use for all API layer work in web/contract and web/service directories.
+---
+
+# oRPC Contract-First Development
+
+## Project Structure
+
+```
+web/contract/
+├── base.ts # Base contract (inputStructure: 'detailed')
+├── router.ts # Router composition & type exports
+├── marketplace.ts # Marketplace contracts
+└── console/ # Console contracts by domain
+ ├── system.ts
+ └── billing.ts
+```
+
+## Workflow
+
+1. **Create contract** in `web/contract/console/{domain}.ts`
+ - Import `base` from `../base` and `type` from `@orpc/contract`
+ - Define route with `path`, `method`, `input`, `output`
+
+2. **Register in router** at `web/contract/router.ts`
+ - Import directly from domain file (no barrel files)
+ - Nest by API prefix: `billing: { invoices, bindPartnerStack }`
+
+3. **Create hooks** in `web/service/use-{domain}.ts`
+ - Use `consoleQuery.{group}.{contract}.queryKey()` for query keys
+ - Use `consoleClient.{group}.{contract}()` for API calls
+
+## Key Rules
+
+- **Input structure**: Always use `{ params, query?, body? }` format
+- **Path params**: Use `{paramName}` in path, match in `params` object
+- **Router nesting**: Group by API prefix (e.g., `/billing/*` → `billing: {}`)
+- **No barrel files**: Import directly from specific files
+- **Types**: Import from `@/types/`, use `type()` helper
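+
+A rough sketch of steps 1 and 3 under these rules — the contract, field names, and import paths below are illustrative assumptions, and the exact builder chain should follow the existing files in `web/contract/`:
+
+```typescript
+// web/contract/console/billing.ts (illustrative)
+import { type } from '@orpc/contract'
+import { base } from '../base'
+
+export const invoiceDetail = base
+  .route({ path: '/billing/invoices/{invoiceId}', method: 'GET' })
+  .input(type<{ params: { invoiceId: string } }>())
+  .output(type<{ id: string, total: number }>())
+
+// web/service/use-billing.ts (illustrative)
+import { useQuery } from '@tanstack/react-query'
+// consoleQuery / consoleClient come from the project's oRPC client setup (paths assumed)
+
+export const useInvoiceDetail = (invoiceId: string) => {
+  return useQuery({
+    queryKey: consoleQuery.billing.invoiceDetail.queryKey({ input: { params: { invoiceId } } }),
+    queryFn: () => consoleClient.billing.invoiceDetail({ params: { invoiceId } }),
+  })
+}
+```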
+
+## Type Export
+
+```typescript
+export type ConsoleInputs = InferContractRouterInputs
+```
diff --git a/.claude/settings.json b/.claude/settings.json
new file mode 100644
index 0000000000..fe108722be
--- /dev/null
+++ b/.claude/settings.json
@@ -0,0 +1,15 @@
+{
+ "hooks": {
+ "PreToolUse": [
+ {
+ "matcher": "Bash",
+ "hooks": [
+ {
+ "type": "command",
+ "command": "npx -y block-no-verify@1.1.1"
+ }
+ ]
+ }
+ ]
+ }
+}
diff --git a/.claude/settings.json.example b/.claude/settings.json.example
deleted file mode 100644
index 1149895340..0000000000
--- a/.claude/settings.json.example
+++ /dev/null
@@ -1,19 +0,0 @@
-{
- "permissions": {
- "allow": [],
- "deny": []
- },
- "env": {
- "__comment": "Environment variables for MCP servers. Override in .claude/settings.local.json with actual values.",
- "GITHUB_PERSONAL_ACCESS_TOKEN": "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
- },
- "enabledMcpjsonServers": [
- "context7",
- "sequential-thinking",
- "github",
- "fetch",
- "playwright",
- "ide"
- ],
- "enableAllProjectMcpServers": true
- }
\ No newline at end of file
diff --git a/.claude/skills/component-refactoring b/.claude/skills/component-refactoring
new file mode 120000
index 0000000000..53ae67e2f2
--- /dev/null
+++ b/.claude/skills/component-refactoring
@@ -0,0 +1 @@
+../../.agents/skills/component-refactoring
\ No newline at end of file
diff --git a/.claude/skills/frontend-code-review b/.claude/skills/frontend-code-review
new file mode 120000
index 0000000000..55654ffbd7
--- /dev/null
+++ b/.claude/skills/frontend-code-review
@@ -0,0 +1 @@
+../../.agents/skills/frontend-code-review
\ No newline at end of file
diff --git a/.claude/skills/frontend-testing b/.claude/skills/frontend-testing
new file mode 120000
index 0000000000..092cec7745
--- /dev/null
+++ b/.claude/skills/frontend-testing
@@ -0,0 +1 @@
+../../.agents/skills/frontend-testing
\ No newline at end of file
diff --git a/.claude/skills/orpc-contract-first b/.claude/skills/orpc-contract-first
new file mode 120000
index 0000000000..da47b335c7
--- /dev/null
+++ b/.claude/skills/orpc-contract-first
@@ -0,0 +1 @@
+../../.agents/skills/orpc-contract-first
\ No newline at end of file
diff --git a/.codex/skills/component-refactoring b/.codex/skills/component-refactoring
new file mode 120000
index 0000000000..53ae67e2f2
--- /dev/null
+++ b/.codex/skills/component-refactoring
@@ -0,0 +1 @@
+../../.agents/skills/component-refactoring
\ No newline at end of file
diff --git a/.codex/skills/frontend-code-review b/.codex/skills/frontend-code-review
new file mode 120000
index 0000000000..55654ffbd7
--- /dev/null
+++ b/.codex/skills/frontend-code-review
@@ -0,0 +1 @@
+../../.agents/skills/frontend-code-review
\ No newline at end of file
diff --git a/.codex/skills/frontend-testing b/.codex/skills/frontend-testing
new file mode 120000
index 0000000000..092cec7745
--- /dev/null
+++ b/.codex/skills/frontend-testing
@@ -0,0 +1 @@
+../../.agents/skills/frontend-testing
\ No newline at end of file
diff --git a/.codex/skills/orpc-contract-first b/.codex/skills/orpc-contract-first
new file mode 120000
index 0000000000..da47b335c7
--- /dev/null
+++ b/.codex/skills/orpc-contract-first
@@ -0,0 +1 @@
+../../.agents/skills/orpc-contract-first
\ No newline at end of file
diff --git a/.coveragerc b/.coveragerc
new file mode 100644
index 0000000000..190c0c185b
--- /dev/null
+++ b/.coveragerc
@@ -0,0 +1,5 @@
+[run]
+omit =
+ api/tests/*
+ api/migrations/*
+ api/core/rag/datasource/vdb/*
diff --git a/.cursorrules b/.cursorrules
deleted file mode 100644
index cdfb8b17a3..0000000000
--- a/.cursorrules
+++ /dev/null
@@ -1,6 +0,0 @@
-# Cursor Rules for Dify Project
-
-## Automated Test Generation
-
-- Use `web/testing/testing.md` as the canonical instruction set for generating frontend automated tests.
-- When proposing or saving tests, re-read that document and follow every requirement.
diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
index ddec42e0ee..3998a69c36 100644
--- a/.devcontainer/devcontainer.json
+++ b/.devcontainer/devcontainer.json
@@ -6,6 +6,9 @@
"context": "..",
"dockerfile": "Dockerfile"
},
+ "mounts": [
+ "source=dify-dev-tmp,target=/tmp,type=volume"
+ ],
"features": {
"ghcr.io/devcontainers/features/node:1": {
"nodeGypDependencies": true,
@@ -34,19 +37,13 @@
},
"postStartCommand": "./.devcontainer/post_start_command.sh",
"postCreateCommand": "./.devcontainer/post_create_command.sh"
-
// Features to add to the dev container. More info: https://containers.dev/features.
// "features": {},
-
// Use 'forwardPorts' to make a list of ports inside the container available locally.
// "forwardPorts": [],
-
// Use 'postCreateCommand' to run commands after the container is created.
// "postCreateCommand": "python --version",
-
// Configure tool-specific properties.
// "customizations": {},
-
// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
- // "remoteUser": "root"
-}
+}
\ No newline at end of file
diff --git a/.devcontainer/post_create_command.sh b/.devcontainer/post_create_command.sh
index a26fd076ed..637593b9de 100755
--- a/.devcontainer/post_create_command.sh
+++ b/.devcontainer/post_create_command.sh
@@ -1,13 +1,14 @@
#!/bin/bash
WORKSPACE_ROOT=$(pwd)
+export COREPACK_ENABLE_DOWNLOAD_PROMPT=0
corepack enable
cd web && pnpm install
pipx install uv
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
-echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor\"" >> ~/.bashrc
-echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
+echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc
+echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev:inspect\"" >> ~/.bashrc
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
echo "alias stop-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down\"" >> ~/.bashrc
diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index d6f326d4dc..106c26bbed 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -6,229 +6,244 @@
* @crazywoola @laipz8200 @Yeuoly
+# CODEOWNERS file
+/.github/CODEOWNERS @laipz8200 @crazywoola
+
+# Docs
+/docs/ @crazywoola
+
# Backend (default owner, more specific rules below will override)
-api/ @QuantumGhost
+/api/ @QuantumGhost
# Backend - MCP
-api/core/mcp/ @Nov1c444
-api/core/entities/mcp_provider.py @Nov1c444
-api/services/tools/mcp_tools_manage_service.py @Nov1c444
-api/controllers/mcp/ @Nov1c444
-api/controllers/console/app/mcp_server.py @Nov1c444
-api/tests/**/*mcp* @Nov1c444
+/api/core/mcp/ @Nov1c444
+/api/core/entities/mcp_provider.py @Nov1c444
+/api/services/tools/mcp_tools_manage_service.py @Nov1c444
+/api/controllers/mcp/ @Nov1c444
+/api/controllers/console/app/mcp_server.py @Nov1c444
+/api/tests/**/*mcp* @Nov1c444
# Backend - Workflow - Engine (Core graph execution engine)
-api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
-api/core/workflow/runtime/ @laipz8200 @QuantumGhost
-api/core/workflow/graph/ @laipz8200 @QuantumGhost
-api/core/workflow/graph_events/ @laipz8200 @QuantumGhost
-api/core/workflow/node_events/ @laipz8200 @QuantumGhost
-api/core/model_runtime/ @laipz8200 @QuantumGhost
+/api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
+/api/core/workflow/runtime/ @laipz8200 @QuantumGhost
+/api/core/workflow/graph/ @laipz8200 @QuantumGhost
+/api/core/workflow/graph_events/ @laipz8200 @QuantumGhost
+/api/core/workflow/node_events/ @laipz8200 @QuantumGhost
+/api/core/model_runtime/ @laipz8200 @QuantumGhost
# Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM)
-api/core/workflow/nodes/agent/ @Nov1c444
-api/core/workflow/nodes/iteration/ @Nov1c444
-api/core/workflow/nodes/loop/ @Nov1c444
-api/core/workflow/nodes/llm/ @Nov1c444
+/api/core/workflow/nodes/agent/ @Nov1c444
+/api/core/workflow/nodes/iteration/ @Nov1c444
+/api/core/workflow/nodes/loop/ @Nov1c444
+/api/core/workflow/nodes/llm/ @Nov1c444
# Backend - RAG (Retrieval Augmented Generation)
-api/core/rag/ @JohnJyong
-api/services/rag_pipeline/ @JohnJyong
-api/services/dataset_service.py @JohnJyong
-api/services/knowledge_service.py @JohnJyong
-api/services/external_knowledge_service.py @JohnJyong
-api/services/hit_testing_service.py @JohnJyong
-api/services/metadata_service.py @JohnJyong
-api/services/vector_service.py @JohnJyong
-api/services/entities/knowledge_entities/ @JohnJyong
-api/services/entities/external_knowledge_entities/ @JohnJyong
-api/controllers/console/datasets/ @JohnJyong
-api/controllers/service_api/dataset/ @JohnJyong
-api/models/dataset.py @JohnJyong
-api/tasks/rag_pipeline/ @JohnJyong
-api/tasks/add_document_to_index_task.py @JohnJyong
-api/tasks/batch_clean_document_task.py @JohnJyong
-api/tasks/clean_document_task.py @JohnJyong
-api/tasks/clean_notion_document_task.py @JohnJyong
-api/tasks/document_indexing_task.py @JohnJyong
-api/tasks/document_indexing_sync_task.py @JohnJyong
-api/tasks/document_indexing_update_task.py @JohnJyong
-api/tasks/duplicate_document_indexing_task.py @JohnJyong
-api/tasks/recover_document_indexing_task.py @JohnJyong
-api/tasks/remove_document_from_index_task.py @JohnJyong
-api/tasks/retry_document_indexing_task.py @JohnJyong
-api/tasks/sync_website_document_indexing_task.py @JohnJyong
-api/tasks/batch_create_segment_to_index_task.py @JohnJyong
-api/tasks/create_segment_to_index_task.py @JohnJyong
-api/tasks/delete_segment_from_index_task.py @JohnJyong
-api/tasks/disable_segment_from_index_task.py @JohnJyong
-api/tasks/disable_segments_from_index_task.py @JohnJyong
-api/tasks/enable_segment_to_index_task.py @JohnJyong
-api/tasks/enable_segments_to_index_task.py @JohnJyong
-api/tasks/clean_dataset_task.py @JohnJyong
-api/tasks/deal_dataset_index_update_task.py @JohnJyong
-api/tasks/deal_dataset_vector_index_task.py @JohnJyong
+/api/core/rag/ @JohnJyong
+/api/services/rag_pipeline/ @JohnJyong
+/api/services/dataset_service.py @JohnJyong
+/api/services/knowledge_service.py @JohnJyong
+/api/services/external_knowledge_service.py @JohnJyong
+/api/services/hit_testing_service.py @JohnJyong
+/api/services/metadata_service.py @JohnJyong
+/api/services/vector_service.py @JohnJyong
+/api/services/entities/knowledge_entities/ @JohnJyong
+/api/services/entities/external_knowledge_entities/ @JohnJyong
+/api/controllers/console/datasets/ @JohnJyong
+/api/controllers/service_api/dataset/ @JohnJyong
+/api/models/dataset.py @JohnJyong
+/api/tasks/rag_pipeline/ @JohnJyong
+/api/tasks/add_document_to_index_task.py @JohnJyong
+/api/tasks/batch_clean_document_task.py @JohnJyong
+/api/tasks/clean_document_task.py @JohnJyong
+/api/tasks/clean_notion_document_task.py @JohnJyong
+/api/tasks/document_indexing_task.py @JohnJyong
+/api/tasks/document_indexing_sync_task.py @JohnJyong
+/api/tasks/document_indexing_update_task.py @JohnJyong
+/api/tasks/duplicate_document_indexing_task.py @JohnJyong
+/api/tasks/recover_document_indexing_task.py @JohnJyong
+/api/tasks/remove_document_from_index_task.py @JohnJyong
+/api/tasks/retry_document_indexing_task.py @JohnJyong
+/api/tasks/sync_website_document_indexing_task.py @JohnJyong
+/api/tasks/batch_create_segment_to_index_task.py @JohnJyong
+/api/tasks/create_segment_to_index_task.py @JohnJyong
+/api/tasks/delete_segment_from_index_task.py @JohnJyong
+/api/tasks/disable_segment_from_index_task.py @JohnJyong
+/api/tasks/disable_segments_from_index_task.py @JohnJyong
+/api/tasks/enable_segment_to_index_task.py @JohnJyong
+/api/tasks/enable_segments_to_index_task.py @JohnJyong
+/api/tasks/clean_dataset_task.py @JohnJyong
+/api/tasks/deal_dataset_index_update_task.py @JohnJyong
+/api/tasks/deal_dataset_vector_index_task.py @JohnJyong
# Backend - Plugins
-api/core/plugin/ @Mairuis @Yeuoly @Stream29
-api/services/plugin/ @Mairuis @Yeuoly @Stream29
-api/controllers/console/workspace/plugin.py @Mairuis @Yeuoly @Stream29
-api/controllers/inner_api/plugin/ @Mairuis @Yeuoly @Stream29
-api/tasks/process_tenant_plugin_autoupgrade_check_task.py @Mairuis @Yeuoly @Stream29
+/api/core/plugin/ @Mairuis @Yeuoly @Stream29
+/api/services/plugin/ @Mairuis @Yeuoly @Stream29
+/api/controllers/console/workspace/plugin.py @Mairuis @Yeuoly @Stream29
+/api/controllers/inner_api/plugin/ @Mairuis @Yeuoly @Stream29
+/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @Mairuis @Yeuoly @Stream29
# Backend - Trigger/Schedule/Webhook
-api/controllers/trigger/ @Mairuis @Yeuoly
-api/controllers/console/app/workflow_trigger.py @Mairuis @Yeuoly
-api/controllers/console/workspace/trigger_providers.py @Mairuis @Yeuoly
-api/core/trigger/ @Mairuis @Yeuoly
-api/core/app/layers/trigger_post_layer.py @Mairuis @Yeuoly
-api/services/trigger/ @Mairuis @Yeuoly
-api/models/trigger.py @Mairuis @Yeuoly
-api/fields/workflow_trigger_fields.py @Mairuis @Yeuoly
-api/repositories/workflow_trigger_log_repository.py @Mairuis @Yeuoly
-api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis @Yeuoly
-api/libs/schedule_utils.py @Mairuis @Yeuoly
-api/services/workflow/scheduler.py @Mairuis @Yeuoly
-api/schedule/trigger_provider_refresh_task.py @Mairuis @Yeuoly
-api/schedule/workflow_schedule_task.py @Mairuis @Yeuoly
-api/tasks/trigger_processing_tasks.py @Mairuis @Yeuoly
-api/tasks/trigger_subscription_refresh_tasks.py @Mairuis @Yeuoly
-api/tasks/workflow_schedule_tasks.py @Mairuis @Yeuoly
-api/tasks/workflow_cfs_scheduler/ @Mairuis @Yeuoly
-api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis @Yeuoly
-api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis @Yeuoly
-api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis @Yeuoly
-api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis @Yeuoly
+/api/controllers/trigger/ @Mairuis @Yeuoly
+/api/controllers/console/app/workflow_trigger.py @Mairuis @Yeuoly
+/api/controllers/console/workspace/trigger_providers.py @Mairuis @Yeuoly
+/api/core/trigger/ @Mairuis @Yeuoly
+/api/core/app/layers/trigger_post_layer.py @Mairuis @Yeuoly
+/api/services/trigger/ @Mairuis @Yeuoly
+/api/models/trigger.py @Mairuis @Yeuoly
+/api/fields/workflow_trigger_fields.py @Mairuis @Yeuoly
+/api/repositories/workflow_trigger_log_repository.py @Mairuis @Yeuoly
+/api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis @Yeuoly
+/api/libs/schedule_utils.py @Mairuis @Yeuoly
+/api/services/workflow/scheduler.py @Mairuis @Yeuoly
+/api/schedule/trigger_provider_refresh_task.py @Mairuis @Yeuoly
+/api/schedule/workflow_schedule_task.py @Mairuis @Yeuoly
+/api/tasks/trigger_processing_tasks.py @Mairuis @Yeuoly
+/api/tasks/trigger_subscription_refresh_tasks.py @Mairuis @Yeuoly
+/api/tasks/workflow_schedule_tasks.py @Mairuis @Yeuoly
+/api/tasks/workflow_cfs_scheduler/ @Mairuis @Yeuoly
+/api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis @Yeuoly
+/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis @Yeuoly
+/api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis @Yeuoly
+/api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis @Yeuoly
# Backend - Async Workflow
-api/services/async_workflow_service.py @Mairuis @Yeuoly
-api/tasks/async_workflow_tasks.py @Mairuis @Yeuoly
+/api/services/async_workflow_service.py @Mairuis @Yeuoly
+/api/tasks/async_workflow_tasks.py @Mairuis @Yeuoly
# Backend - Billing
-api/services/billing_service.py @hj24 @zyssyz123
-api/controllers/console/billing/ @hj24 @zyssyz123
+/api/services/billing_service.py @hj24 @zyssyz123
+/api/controllers/console/billing/ @hj24 @zyssyz123
# Backend - Enterprise
-api/configs/enterprise/ @GarfieldDai @GareArc
-api/services/enterprise/ @GarfieldDai @GareArc
-api/services/feature_service.py @GarfieldDai @GareArc
-api/controllers/console/feature.py @GarfieldDai @GareArc
-api/controllers/web/feature.py @GarfieldDai @GareArc
+/api/configs/enterprise/ @GarfieldDai @GareArc
+/api/services/enterprise/ @GarfieldDai @GareArc
+/api/services/feature_service.py @GarfieldDai @GareArc
+/api/controllers/console/feature.py @GarfieldDai @GareArc
+/api/controllers/web/feature.py @GarfieldDai @GareArc
# Backend - Database Migrations
-api/migrations/ @snakevash @laipz8200
+/api/migrations/ @snakevash @laipz8200 @MRZHUH
+
+# Backend - Vector DB Middleware
+/api/configs/middleware/vdb/* @JohnJyong
# Frontend
-web/ @iamjoel
+/web/ @iamjoel
+
+# Frontend - Web Tests
+/.github/workflows/web-tests.yml @iamjoel
# Frontend - App - Orchestration
-web/app/components/workflow/ @iamjoel @zxhlyh
-web/app/components/workflow-app/ @iamjoel @zxhlyh
-web/app/components/app/configuration/ @iamjoel @zxhlyh
-web/app/components/app/app-publisher/ @iamjoel @zxhlyh
+/web/app/components/workflow/ @iamjoel @zxhlyh
+/web/app/components/workflow-app/ @iamjoel @zxhlyh
+/web/app/components/app/configuration/ @iamjoel @zxhlyh
+/web/app/components/app/app-publisher/ @iamjoel @zxhlyh
# Frontend - WebApp - Chat
-web/app/components/base/chat/ @iamjoel @zxhlyh
+/web/app/components/base/chat/ @iamjoel @zxhlyh
# Frontend - WebApp - Completion
-web/app/components/share/text-generation/ @iamjoel @zxhlyh
+/web/app/components/share/text-generation/ @iamjoel @zxhlyh
# Frontend - App - List and Creation
-web/app/components/apps/ @JzoNgKVO @iamjoel
-web/app/components/app/create-app-dialog/ @JzoNgKVO @iamjoel
-web/app/components/app/create-app-modal/ @JzoNgKVO @iamjoel
-web/app/components/app/create-from-dsl-modal/ @JzoNgKVO @iamjoel
+/web/app/components/apps/ @JzoNgKVO @iamjoel
+/web/app/components/app/create-app-dialog/ @JzoNgKVO @iamjoel
+/web/app/components/app/create-app-modal/ @JzoNgKVO @iamjoel
+/web/app/components/app/create-from-dsl-modal/ @JzoNgKVO @iamjoel
# Frontend - App - API Documentation
-web/app/components/develop/ @JzoNgKVO @iamjoel
+/web/app/components/develop/ @JzoNgKVO @iamjoel
# Frontend - App - Logs and Annotations
-web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
-web/app/components/app/log/ @JzoNgKVO @iamjoel
-web/app/components/app/log-annotation/ @JzoNgKVO @iamjoel
-web/app/components/app/annotation/ @JzoNgKVO @iamjoel
+/web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
+/web/app/components/app/log/ @JzoNgKVO @iamjoel
+/web/app/components/app/log-annotation/ @JzoNgKVO @iamjoel
+/web/app/components/app/annotation/ @JzoNgKVO @iamjoel
# Frontend - App - Monitoring
-web/app/(commonLayout)/app/(appDetailLayout)/\[appId\]/overview/ @JzoNgKVO @iamjoel
-web/app/components/app/overview/ @JzoNgKVO @iamjoel
+/web/app/(commonLayout)/app/(appDetailLayout)/\[appId\]/overview/ @JzoNgKVO @iamjoel
+/web/app/components/app/overview/ @JzoNgKVO @iamjoel
# Frontend - App - Settings
-web/app/components/app-sidebar/ @JzoNgKVO @iamjoel
+/web/app/components/app-sidebar/ @JzoNgKVO @iamjoel
# Frontend - RAG - Hit Testing
-web/app/components/datasets/hit-testing/ @JzoNgKVO @iamjoel
+/web/app/components/datasets/hit-testing/ @JzoNgKVO @iamjoel
# Frontend - RAG - List and Creation
-web/app/components/datasets/list/ @iamjoel @WTW0313
-web/app/components/datasets/create/ @iamjoel @WTW0313
-web/app/components/datasets/create-from-pipeline/ @iamjoel @WTW0313
-web/app/components/datasets/external-knowledge-base/ @iamjoel @WTW0313
+/web/app/components/datasets/list/ @iamjoel @WTW0313
+/web/app/components/datasets/create/ @iamjoel @WTW0313
+/web/app/components/datasets/create-from-pipeline/ @iamjoel @WTW0313
+/web/app/components/datasets/external-knowledge-base/ @iamjoel @WTW0313
# Frontend - RAG - Orchestration (general rule first, specific rules below override)
-web/app/components/rag-pipeline/ @iamjoel @WTW0313
-web/app/components/rag-pipeline/components/rag-pipeline-main.tsx @iamjoel @zxhlyh
-web/app/components/rag-pipeline/store/ @iamjoel @zxhlyh
+/web/app/components/rag-pipeline/ @iamjoel @WTW0313
+/web/app/components/rag-pipeline/components/rag-pipeline-main.tsx @iamjoel @zxhlyh
+/web/app/components/rag-pipeline/store/ @iamjoel @zxhlyh
# Frontend - RAG - Documents List
-web/app/components/datasets/documents/list.tsx @iamjoel @WTW0313
-web/app/components/datasets/documents/create-from-pipeline/ @iamjoel @WTW0313
+/web/app/components/datasets/documents/list.tsx @iamjoel @WTW0313
+/web/app/components/datasets/documents/create-from-pipeline/ @iamjoel @WTW0313
# Frontend - RAG - Segments List
-web/app/components/datasets/documents/detail/ @iamjoel @WTW0313
+/web/app/components/datasets/documents/detail/ @iamjoel @WTW0313
# Frontend - RAG - Settings
-web/app/components/datasets/settings/ @iamjoel @WTW0313
+/web/app/components/datasets/settings/ @iamjoel @WTW0313
# Frontend - Ecosystem - Plugins
-web/app/components/plugins/ @iamjoel @zhsama
+/web/app/components/plugins/ @iamjoel @zhsama
# Frontend - Ecosystem - Tools
-web/app/components/tools/ @iamjoel @Yessenia-d
+/web/app/components/tools/ @iamjoel @Yessenia-d
# Frontend - Ecosystem - MarketPlace
-web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d
+/web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d
# Frontend - Login and Registration
-web/app/signin/ @douxc @iamjoel
-web/app/signup/ @douxc @iamjoel
-web/app/reset-password/ @douxc @iamjoel
-web/app/install/ @douxc @iamjoel
-web/app/init/ @douxc @iamjoel
-web/app/forgot-password/ @douxc @iamjoel
-web/app/account/ @douxc @iamjoel
+/web/app/signin/ @douxc @iamjoel
+/web/app/signup/ @douxc @iamjoel
+/web/app/reset-password/ @douxc @iamjoel
+/web/app/install/ @douxc @iamjoel
+/web/app/init/ @douxc @iamjoel
+/web/app/forgot-password/ @douxc @iamjoel
+/web/app/account/ @douxc @iamjoel
# Frontend - Service Authentication
-web/service/base.ts @douxc @iamjoel
+/web/service/base.ts @douxc @iamjoel
# Frontend - WebApp Authentication and Access Control
-web/app/(shareLayout)/components/ @douxc @iamjoel
-web/app/(shareLayout)/webapp-signin/ @douxc @iamjoel
-web/app/(shareLayout)/webapp-reset-password/ @douxc @iamjoel
-web/app/components/app/app-access-control/ @douxc @iamjoel
+/web/app/(shareLayout)/components/ @douxc @iamjoel
+/web/app/(shareLayout)/webapp-signin/ @douxc @iamjoel
+/web/app/(shareLayout)/webapp-reset-password/ @douxc @iamjoel
+/web/app/components/app/app-access-control/ @douxc @iamjoel
# Frontend - Explore Page
-web/app/components/explore/ @CodingOnStar @iamjoel
+/web/app/components/explore/ @CodingOnStar @iamjoel
# Frontend - Personal Settings
-web/app/components/header/account-setting/ @CodingOnStar @iamjoel
-web/app/components/header/account-dropdown/ @CodingOnStar @iamjoel
+/web/app/components/header/account-setting/ @CodingOnStar @iamjoel
+/web/app/components/header/account-dropdown/ @CodingOnStar @iamjoel
# Frontend - Analytics
-web/app/components/base/ga/ @CodingOnStar @iamjoel
+/web/app/components/base/ga/ @CodingOnStar @iamjoel
# Frontend - Base Components
-web/app/components/base/ @iamjoel @zxhlyh
+/web/app/components/base/ @iamjoel @zxhlyh
# Frontend - Utils and Hooks
-web/utils/classnames.ts @iamjoel @zxhlyh
-web/utils/time.ts @iamjoel @zxhlyh
-web/utils/format.ts @iamjoel @zxhlyh
-web/utils/clipboard.ts @iamjoel @zxhlyh
-web/hooks/use-document-title.ts @iamjoel @zxhlyh
+/web/utils/classnames.ts @iamjoel @zxhlyh
+/web/utils/time.ts @iamjoel @zxhlyh
+/web/utils/format.ts @iamjoel @zxhlyh
+/web/utils/clipboard.ts @iamjoel @zxhlyh
+/web/hooks/use-document-title.ts @iamjoel @zxhlyh
# Frontend - Billing and Education
-web/app/components/billing/ @iamjoel @zxhlyh
-web/app/education-apply/ @iamjoel @zxhlyh
+/web/app/components/billing/ @iamjoel @zxhlyh
+/web/app/education-apply/ @iamjoel @zxhlyh
# Frontend - Workspace
-web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh
+/web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh
+
+# Docker
+/docker/* @laipz8200
diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md
deleted file mode 100644
index 53afcbda1e..0000000000
--- a/.github/copilot-instructions.md
+++ /dev/null
@@ -1,12 +0,0 @@
-# Copilot Instructions
-
-GitHub Copilot must follow the unified frontend testing requirements documented in `web/testing/testing.md`.
-
-Key reminders:
-
-- Generate tests using the mandated tech stack, naming, and code style (AAA pattern, `fireEvent`, descriptive test names, cleans up mocks).
-- Cover rendering, prop combinations, and edge cases by default; extend coverage for hooks, routing, async flows, and domain-specific components when applicable.
-- Target >95% line and branch coverage and 100% function/statement coverage.
-- Apply the project's mocking conventions for i18n, toast notifications, and Next.js utilities.
-
-Any suggestions from Copilot that conflict with `web/testing/testing.md` should be revised before acceptance.
diff --git a/.github/labeler.yml b/.github/labeler.yml
new file mode 100644
index 0000000000..d1d324d381
--- /dev/null
+++ b/.github/labeler.yml
@@ -0,0 +1,3 @@
+web:
+ - changed-files:
+ - any-glob-to-any-file: 'web/**'
diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
index aa5a50918a..50dbde2aee 100644
--- a/.github/pull_request_template.md
+++ b/.github/pull_request_template.md
@@ -20,4 +20,4 @@
- [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
- [x] I've updated the documentation accordingly.
-- [x] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods
+- [x] I ran `make lint` and `make type-check` (backend) and `cd web && npx lint-staged` (frontend) to appease the lint gods
diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml
index 557d747a8c..52e3272f99 100644
--- a/.github/workflows/api-tests.yml
+++ b/.github/workflows/api-tests.yml
@@ -22,12 +22,12 @@ jobs:
steps:
- name: Checkout code
- uses: actions/checkout@v4
+ uses: actions/checkout@v6
with:
persist-credentials: false
- name: Setup UV and Python
- uses: astral-sh/setup-uv@v6
+ uses: astral-sh/setup-uv@v7
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@@ -39,12 +39,6 @@ jobs:
- name: Install dependencies
run: uv sync --project api --dev
- - name: Run pyrefly check
- run: |
- cd api
- uv add --dev pyrefly
- uv run pyrefly check || true
-
- name: Run dify config tests
run: uv run --project api dev/pytest/pytest_config_tests.py
@@ -57,7 +51,7 @@ jobs:
run: sh .github/workflows/expose_service_ports.sh
- name: Set up Sandbox
- uses: hoverkraft-tech/compose-action@v2.0.2
+ uses: hoverkraft-tech/compose-action@v2
with:
compose-file: |
docker/docker-compose.middleware.yaml
@@ -71,18 +65,19 @@ jobs:
run: |
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
- - name: Run Workflow
- run: uv run --project api bash dev/pytest/pytest_workflow.sh
-
- - name: Run Tool
- run: uv run --project api bash dev/pytest/pytest_tools.sh
-
- - name: Run TestContainers
- run: uv run --project api bash dev/pytest/pytest_testcontainers.sh
-
- - name: Run Unit tests
+ - name: Run API Tests
+ env:
+ STORAGE_TYPE: opendal
+ OPENDAL_SCHEME: fs
+ OPENDAL_FS_ROOT: /tmp/dify-storage
run: |
- uv run --project api bash dev/pytest/pytest_unit_tests.sh
+ uv run --project api pytest \
+ -n auto \
+ --timeout "${PYTEST_TIMEOUT:-180}" \
+ api/tests/integration_tests/workflow \
+ api/tests/integration_tests/tools \
+ api/tests/test_containers_integration_tests \
+ api/tests/unit_tests
- name: Coverage Summary
run: |
@@ -93,5 +88,12 @@ jobs:
# Create a detailed coverage summary
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
- uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
-
+ {
+ echo ""
+ echo "File-level coverage (click to expand)"
+ echo ""
+ echo '```'
+ uv run --project api coverage report -m
+ echo '```'
+ echo ""
+ } >> $GITHUB_STEP_SUMMARY
diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml
index 81392a9734..4a8c61e7d2 100644
--- a/.github/workflows/autofix.yml
+++ b/.github/workflows/autofix.yml
@@ -12,12 +12,29 @@ jobs:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@v6
- # Use uv to ensure we have the same ruff version in CI and locally.
- - uses: astral-sh/setup-uv@v6
+ - name: Check Docker Compose inputs
+ id: docker-compose-changes
+ uses: tj-actions/changed-files@v47
+ with:
+ files: |
+ docker/generate_docker_compose
+ docker/.env.example
+ docker/docker-compose-template.yaml
+ docker/docker-compose.yaml
+ - uses: actions/setup-python@v6
with:
python-version: "3.11"
+
+ - uses: astral-sh/setup-uv@v7
+
+ - name: Generate Docker Compose
+ if: steps.docker-compose-changes.outputs.any_changed == 'true'
+ run: |
+ cd docker
+ ./generate_docker_compose
+
- run: |
cd api
uv sync --dev
@@ -35,10 +52,11 @@ jobs:
- name: ast-grep
run: |
- uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
- uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
- uvx --from ast-grep-cli sg -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all
- uvx --from ast-grep-cli sg -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all
+ # ast-grep exits 1 if no matches are found; allow idempotent runs.
+ uvx --from ast-grep-cli ast-grep --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all || true
+ uvx --from ast-grep-cli ast-grep --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all || true
+ uvx --from ast-grep-cli ast-grep -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all || true
+ uvx --from ast-grep-cli ast-grep -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all || true
# Convert Optional[T] to T | None (ignoring quoted types)
cat > /tmp/optional-rule.yml << 'EOF'
id: convert-optional-to-union
@@ -56,35 +74,37 @@ jobs:
pattern: $T
fix: $T | None
EOF
- uvx --from ast-grep-cli sg scan --inline-rules "$(cat /tmp/optional-rule.yml)" --update-all
+ uvx --from ast-grep-cli ast-grep scan . --inline-rules "$(cat /tmp/optional-rule.yml)" --update-all
# Fix forward references that were incorrectly converted (Python doesn't support "Type" | None syntax)
find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \;
find . -name "*.py.bak" -type f -delete
- - name: mdformat
- run: |
- uvx mdformat .
-
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
package_json_file: web/package.json
run_install: false
- - name: Setup NodeJS
- uses: actions/setup-node@v4
+ - name: Setup Node.js
+ uses: actions/setup-node@v6
with:
- node-version: 22
+ node-version: 24
cache: pnpm
- cache-dependency-path: ./web/package.json
+ cache-dependency-path: ./web/pnpm-lock.yaml
- - name: Web dependencies
- working-directory: ./web
- run: pnpm install --frozen-lockfile
-
- - name: oxlint
- working-directory: ./web
+ - name: Install web dependencies
run: |
- pnpx oxlint --fix
+ cd web
+ pnpm install --frozen-lockfile
+
+ - name: ESLint autofix
+ run: |
+ cd web
+ pnpm lint:fix || true
+
+ # mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
+ - name: mdformat
+ run: |
+ uvx --python 3.13 mdformat . --exclude ".agents/skills/**"
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27
diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml
index f7f464a601..704d896192 100644
--- a/.github/workflows/build-push.yml
+++ b/.github/workflows/build-push.yml
@@ -90,7 +90,7 @@ jobs:
touch "/tmp/digests/${sanitized_digest}"
- name: Upload digest
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v6
with:
name: digests-${{ matrix.context }}-${{ env.PLATFORM_PAIR }}
path: /tmp/digests/*
@@ -112,7 +112,7 @@ jobs:
context: "web"
steps:
- name: Download digests
- uses: actions/download-artifact@v4
+ uses: actions/download-artifact@v7
with:
path: /tmp/digests
pattern: digests-${{ matrix.context }}-*
diff --git a/.github/workflows/db-migration-test.yml b/.github/workflows/db-migration-test.yml
index 101d973466..e20cf9850b 100644
--- a/.github/workflows/db-migration-test.yml
+++ b/.github/workflows/db-migration-test.yml
@@ -13,13 +13,13 @@ jobs:
steps:
- name: Checkout code
- uses: actions/checkout@v4
+ uses: actions/checkout@v6
with:
fetch-depth: 0
persist-credentials: false
- name: Setup UV and Python
- uses: astral-sh/setup-uv@v6
+ uses: astral-sh/setup-uv@v7
with:
enable-cache: true
python-version: "3.12"
@@ -63,13 +63,13 @@ jobs:
steps:
- name: Checkout code
- uses: actions/checkout@v4
+ uses: actions/checkout@v6
with:
fetch-depth: 0
persist-credentials: false
- name: Setup UV and Python
- uses: astral-sh/setup-uv@v6
+ uses: astral-sh/setup-uv@v7
with:
enable-cache: true
python-version: "3.12"
diff --git a/.github/workflows/deploy-agent-dev.yml b/.github/workflows/deploy-agent-dev.yml
new file mode 100644
index 0000000000..dd759f7ba5
--- /dev/null
+++ b/.github/workflows/deploy-agent-dev.yml
@@ -0,0 +1,28 @@
+name: Deploy Agent Dev
+
+permissions:
+ contents: read
+
+on:
+ workflow_run:
+ workflows: ["Build and Push API & Web"]
+ branches:
+ - "deploy/agent-dev"
+ types:
+ - completed
+
+jobs:
+ deploy:
+ runs-on: ubuntu-latest
+ if: |
+ github.event.workflow_run.conclusion == 'success' &&
+ github.event.workflow_run.head_branch == 'deploy/agent-dev'
+ steps:
+ - name: Deploy to server
+ uses: appleboy/ssh-action@v1
+ with:
+ host: ${{ secrets.AGENT_DEV_SSH_HOST }}
+ username: ${{ secrets.SSH_USER }}
+ key: ${{ secrets.SSH_PRIVATE_KEY }}
+ script: |
+ ${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }}
diff --git a/.github/workflows/deploy-dev.yml b/.github/workflows/deploy-dev.yml
index cd1c86e668..38fa0b9a7f 100644
--- a/.github/workflows/deploy-dev.yml
+++ b/.github/workflows/deploy-dev.yml
@@ -16,7 +16,7 @@ jobs:
github.event.workflow_run.head_branch == 'deploy/dev'
steps:
- name: Deploy to server
- uses: appleboy/ssh-action@v0.1.8
+ uses: appleboy/ssh-action@v1
with:
host: ${{ secrets.SSH_HOST }}
username: ${{ secrets.SSH_USER }}
diff --git a/.github/workflows/deploy-hitl.yml b/.github/workflows/deploy-hitl.yml
new file mode 100644
index 0000000000..7d5f0a22e7
--- /dev/null
+++ b/.github/workflows/deploy-hitl.yml
@@ -0,0 +1,29 @@
+name: Deploy HITL
+
+on:
+ workflow_run:
+ workflows: ["Build and Push API & Web"]
+ branches:
+ - "feat/hitl-frontend"
+ - "feat/hitl-backend"
+ types:
+ - completed
+
+jobs:
+ deploy:
+ runs-on: ubuntu-latest
+ if: |
+ github.event.workflow_run.conclusion == 'success' &&
+ (
+ github.event.workflow_run.head_branch == 'feat/hitl-frontend' ||
+ github.event.workflow_run.head_branch == 'feat/hitl-backend'
+ )
+ steps:
+ - name: Deploy to server
+ uses: appleboy/ssh-action@v1
+ with:
+ host: ${{ secrets.HITL_SSH_HOST }}
+ username: ${{ secrets.SSH_USER }}
+ key: ${{ secrets.SSH_PRIVATE_KEY }}
+ script: |
+ ${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }}
diff --git a/.github/workflows/deploy-trigger-dev.yml b/.github/workflows/deploy-trigger-dev.yml
deleted file mode 100644
index 2d9a904fc5..0000000000
--- a/.github/workflows/deploy-trigger-dev.yml
+++ /dev/null
@@ -1,28 +0,0 @@
-name: Deploy Trigger Dev
-
-permissions:
- contents: read
-
-on:
- workflow_run:
- workflows: ["Build and Push API & Web"]
- branches:
- - "deploy/trigger-dev"
- types:
- - completed
-
-jobs:
- deploy:
- runs-on: ubuntu-latest
- if: |
- github.event.workflow_run.conclusion == 'success' &&
- github.event.workflow_run.head_branch == 'deploy/trigger-dev'
- steps:
- - name: Deploy to server
- uses: appleboy/ssh-action@v0.1.8
- with:
- host: ${{ secrets.TRIGGER_SSH_HOST }}
- username: ${{ secrets.SSH_USER }}
- key: ${{ secrets.SSH_PRIVATE_KEY }}
- script: |
- ${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }}
diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml
new file mode 100644
index 0000000000..06782b53c1
--- /dev/null
+++ b/.github/workflows/labeler.yml
@@ -0,0 +1,14 @@
+name: "Pull Request Labeler"
+on:
+ pull_request_target:
+
+jobs:
+ labeler:
+ permissions:
+ contents: read
+ pull-requests: write
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/labeler@v6
+ with:
+ sync-labels: true
diff --git a/.github/workflows/main-ci.yml b/.github/workflows/main-ci.yml
index 876ec23a3d..d6653de950 100644
--- a/.github/workflows/main-ci.yml
+++ b/.github/workflows/main-ci.yml
@@ -27,7 +27,7 @@ jobs:
vdb-changed: ${{ steps.changes.outputs.vdb }}
migration-changed: ${{ steps.changes.outputs.migration }}
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@v6
- uses: dorny/paths-filter@v3
id: changes
with:
@@ -38,6 +38,7 @@ jobs:
- '.github/workflows/api-tests.yml'
web:
- 'web/**'
+ - '.github/workflows/web-tests.yml'
vdb:
- 'api/core/rag/datasource/**'
- 'docker/**'
diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml
index 1870b1f670..b6df1d7e93 100644
--- a/.github/workflows/stale.yml
+++ b/.github/workflows/stale.yml
@@ -18,7 +18,7 @@ jobs:
pull-requests: write
steps:
- - uses: actions/stale@v5
+ - uses: actions/stale@v10
with:
days-before-issue-stale: 15
days-before-issue-close: 3
diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml
index 5a8a34be79..cbd6edf94b 100644
--- a/.github/workflows/style.yml
+++ b/.github/workflows/style.yml
@@ -19,13 +19,13 @@ jobs:
steps:
- name: Checkout code
- uses: actions/checkout@v4
+ uses: actions/checkout@v6
with:
persist-credentials: false
- name: Check changed files
id: changed-files
- uses: tj-actions/changed-files@v46
+ uses: tj-actions/changed-files@v47
with:
files: |
api/**
@@ -33,7 +33,7 @@ jobs:
- name: Setup UV and Python
if: steps.changed-files.outputs.any_changed == 'true'
- uses: astral-sh/setup-uv@v6
+ uses: astral-sh/setup-uv@v7
with:
enable-cache: false
python-version: "3.12"
@@ -47,13 +47,9 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
run: uv run --directory api --dev lint-imports
- - name: Run Basedpyright Checks
+ - name: Run Type Checks
if: steps.changed-files.outputs.any_changed == 'true'
- run: dev/basedpyright-check
-
- - name: Run Mypy Type Checks
- if: steps.changed-files.outputs.any_changed == 'true'
- run: uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
+ run: make type-check
- name: Dotenv check
if: steps.changed-files.outputs.any_changed == 'true'
@@ -65,18 +61,23 @@ jobs:
defaults:
run:
working-directory: ./web
+ permissions:
+ checks: write
+ pull-requests: read
steps:
- name: Checkout code
- uses: actions/checkout@v4
+ uses: actions/checkout@v6
with:
persist-credentials: false
- name: Check changed files
id: changed-files
- uses: tj-actions/changed-files@v46
+ uses: tj-actions/changed-files@v47
with:
- files: web/**
+ files: |
+ web/**
+ .github/workflows/style.yml
- name: Install pnpm
uses: pnpm/action-setup@v4
@@ -85,12 +86,12 @@ jobs:
run_install: false
- name: Setup NodeJS
- uses: actions/setup-node@v4
+ uses: actions/setup-node@v6
if: steps.changed-files.outputs.any_changed == 'true'
with:
- node-version: 22
+ node-version: 24
cache: pnpm
- cache-dependency-path: ./web/package.json
+ cache-dependency-path: ./web/pnpm-lock.yaml
- name: Web dependencies
if: steps.changed-files.outputs.any_changed == 'true'
@@ -101,42 +102,31 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: |
- pnpm run lint
+ pnpm run lint:ci
+ # pnpm run lint:report
+ # continue-on-error: true
+
+ # - name: Annotate Code
+ # if: steps.changed-files.outputs.any_changed == 'true' && github.event_name == 'pull_request'
+ # uses: DerLev/eslint-annotations@51347b3a0abfb503fc8734d5ae31c4b151297fae
+ # with:
+ # eslint-report: web/eslint_report.json
+ # github-token: ${{ secrets.GITHUB_TOKEN }}
+
+ - name: Web tsslint
+ if: steps.changed-files.outputs.any_changed == 'true'
+ working-directory: ./web
+ run: pnpm run lint:tss
- name: Web type check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
- run: pnpm run type-check:tsgo
+ run: pnpm run type-check
- docker-compose-template:
- name: Docker Compose Template
- runs-on: ubuntu-latest
-
- steps:
- - name: Checkout code
- uses: actions/checkout@v4
- with:
- persist-credentials: false
-
- - name: Check changed files
- id: changed-files
- uses: tj-actions/changed-files@v46
- with:
- files: |
- docker/generate_docker_compose
- docker/.env.example
- docker/docker-compose-template.yaml
- docker/docker-compose.yaml
-
- - name: Generate Docker Compose
+ - name: Web dead code check
if: steps.changed-files.outputs.any_changed == 'true'
- run: |
- cd docker
- ./generate_docker_compose
-
- - name: Check for changes
- if: steps.changed-files.outputs.any_changed == 'true'
- run: git diff --exit-code
+ working-directory: ./web
+ run: pnpm run knip
superlinter:
name: SuperLinter
@@ -144,14 +134,14 @@ jobs:
steps:
- name: Checkout code
- uses: actions/checkout@v4
+ uses: actions/checkout@v6
with:
fetch-depth: 0
persist-credentials: false
- name: Check changed files
id: changed-files
- uses: tj-actions/changed-files@v46
+ uses: tj-actions/changed-files@v47
with:
files: |
**.sh
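For the web job in style.yml, the lint step now runs the CI lint profile plus tsslint, the old tsgo type check becomes `type-check`, the Docker Compose template job is dropped, and a knip dead-code check is added. A local approximation of the new gate, assuming the scripts referenced by the workflow exist in `web/package.json`:

```bash
cd web
pnpm install --frozen-lockfile
pnpm run lint:ci      # ESLint with the CI profile
pnpm run lint:tss     # tsslint
pnpm run type-check   # TypeScript check (replaces type-check:tsgo)
pnpm run knip         # dead code / unused export check
```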
diff --git a/.github/workflows/tool-test-sdks.yaml b/.github/workflows/tool-test-sdks.yaml
index b1ccd7417a..ec392cb3b2 100644
--- a/.github/workflows/tool-test-sdks.yaml
+++ b/.github/workflows/tool-test-sdks.yaml
@@ -16,23 +16,19 @@ jobs:
name: unit test for Node.js SDK
runs-on: ubuntu-latest
- strategy:
- matrix:
- node-version: [16, 18, 20, 22]
-
defaults:
run:
working-directory: sdks/nodejs-client
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@v6
with:
persist-credentials: false
- - name: Use Node.js ${{ matrix.node-version }}
- uses: actions/setup-node@v4
+ - name: Use Node.js
+ uses: actions/setup-node@v6
with:
- node-version: ${{ matrix.node-version }}
+ node-version: 24
cache: ''
cache-dependency-path: 'pnpm-lock.yaml'
diff --git a/.github/workflows/translate-i18n-base-on-english.yml b/.github/workflows/translate-i18n-base-on-english.yml
deleted file mode 100644
index fe8e2ebc2b..0000000000
--- a/.github/workflows/translate-i18n-base-on-english.yml
+++ /dev/null
@@ -1,91 +0,0 @@
-name: Check i18n Files and Create PR
-
-on:
- push:
- branches: [main]
- paths:
- - 'web/i18n/en-US/*.ts'
-
-permissions:
- contents: write
- pull-requests: write
-
-jobs:
- check-and-update:
- if: github.repository == 'langgenius/dify'
- runs-on: ubuntu-latest
- defaults:
- run:
- working-directory: web
- steps:
- - uses: actions/checkout@v4
- with:
- fetch-depth: 0
- token: ${{ secrets.GITHUB_TOKEN }}
-
- - name: Check for file changes in i18n/en-US
- id: check_files
- run: |
- git fetch origin "${{ github.event.before }}" || true
- git fetch origin "${{ github.sha }}" || true
- changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.ts')
- echo "Changed files: $changed_files"
- if [ -n "$changed_files" ]; then
- echo "FILES_CHANGED=true" >> $GITHUB_ENV
- file_args=""
- for file in $changed_files; do
- filename=$(basename "$file" .ts)
- file_args="$file_args --file $filename"
- done
- echo "FILE_ARGS=$file_args" >> $GITHUB_ENV
- echo "File arguments: $file_args"
- else
- echo "FILES_CHANGED=false" >> $GITHUB_ENV
- fi
-
- - name: Install pnpm
- uses: pnpm/action-setup@v4
- with:
- package_json_file: web/package.json
- run_install: false
-
- - name: Set up Node.js
- if: env.FILES_CHANGED == 'true'
- uses: actions/setup-node@v4
- with:
- node-version: 'lts/*'
- cache: pnpm
- cache-dependency-path: ./web/package.json
-
- - name: Install dependencies
- if: env.FILES_CHANGED == 'true'
- working-directory: ./web
- run: pnpm install --frozen-lockfile
-
- - name: Generate i18n translations
- if: env.FILES_CHANGED == 'true'
- working-directory: ./web
- run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }}
-
- - name: Generate i18n type definitions
- if: env.FILES_CHANGED == 'true'
- working-directory: ./web
- run: pnpm run gen:i18n-types
-
- - name: Create Pull Request
- if: env.FILES_CHANGED == 'true'
- uses: peter-evans/create-pull-request@v6
- with:
- token: ${{ secrets.GITHUB_TOKEN }}
- commit-message: 'chore(i18n): update translations based on en-US changes'
- title: 'chore(i18n): translate i18n files and update type definitions'
- body: |
- This PR was automatically created to update i18n files and TypeScript type definitions based on changes in en-US locale.
-
- **Triggered by:** ${{ github.sha }}
-
- **Changes included:**
- - Updated translation files for all locales
- - Regenerated TypeScript type definitions for type safety
- branch: chore/automated-i18n-updates-${{ github.sha }}
- delete-branch: true
diff --git a/.github/workflows/translate-i18n-claude.yml b/.github/workflows/translate-i18n-claude.yml
new file mode 100644
index 0000000000..5d9440ff35
--- /dev/null
+++ b/.github/workflows/translate-i18n-claude.yml
@@ -0,0 +1,440 @@
+name: Translate i18n Files with Claude Code
+
+# Note: claude-code-action doesn't support push events directly.
+# Push events are handled by trigger-i18n-sync.yml which sends repository_dispatch.
+# See: https://github.com/langgenius/dify/issues/30743
+
+on:
+ repository_dispatch:
+ types: [i18n-sync]
+ workflow_dispatch:
+ inputs:
+ files:
+ description: 'Specific files to translate (space-separated, e.g., "app common"). Leave empty for all files.'
+ required: false
+ type: string
+ languages:
+ description: 'Specific languages to translate (space-separated, e.g., "zh-Hans ja-JP"). Leave empty for all supported languages.'
+ required: false
+ type: string
+ mode:
+ description: 'Sync mode: incremental (only changes) or full (re-check all keys)'
+ required: false
+ default: 'incremental'
+ type: choice
+ options:
+ - incremental
+ - full
+
+permissions:
+ contents: write
+ pull-requests: write
+
+jobs:
+ translate:
+ if: github.repository == 'langgenius/dify'
+ runs-on: ubuntu-latest
+ timeout-minutes: 60
+
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v6
+ with:
+ fetch-depth: 0
+ token: ${{ secrets.GITHUB_TOKEN }}
+
+ - name: Configure Git
+ run: |
+ git config --global user.name "github-actions[bot]"
+ git config --global user.email "github-actions[bot]@users.noreply.github.com"
+
+ - name: Install pnpm
+ uses: pnpm/action-setup@v4
+ with:
+ package_json_file: web/package.json
+ run_install: false
+
+ - name: Set up Node.js
+ uses: actions/setup-node@v6
+ with:
+ node-version: 24
+ cache: pnpm
+ cache-dependency-path: ./web/pnpm-lock.yaml
+
+ - name: Detect changed files and generate diff
+ id: detect_changes
+ run: |
+ if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
+ # Manual trigger
+ if [ -n "${{ github.event.inputs.files }}" ]; then
+ echo "CHANGED_FILES=${{ github.event.inputs.files }}" >> $GITHUB_OUTPUT
+ else
+ # Get all JSON files in en-US directory
+ files=$(ls web/i18n/en-US/*.json 2>/dev/null | xargs -n1 basename | sed 's/.json$//' | tr '\n' ' ')
+ echo "CHANGED_FILES=$files" >> $GITHUB_OUTPUT
+ fi
+ echo "TARGET_LANGS=${{ github.event.inputs.languages }}" >> $GITHUB_OUTPUT
+ echo "SYNC_MODE=${{ github.event.inputs.mode || 'incremental' }}" >> $GITHUB_OUTPUT
+
+ # For manual trigger with incremental mode, get diff from last commit
+ # For full mode, we'll do a complete check anyway
+ if [ "${{ github.event.inputs.mode }}" == "full" ]; then
+ echo "Full mode: will check all keys" > /tmp/i18n-diff.txt
+ echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
+ else
+ git diff HEAD~1..HEAD -- 'web/i18n/en-US/*.json' > /tmp/i18n-diff.txt 2>/dev/null || echo "" > /tmp/i18n-diff.txt
+ if [ -s /tmp/i18n-diff.txt ]; then
+ echo "DIFF_AVAILABLE=true" >> $GITHUB_OUTPUT
+ else
+ echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
+ fi
+ fi
+ elif [ "${{ github.event_name }}" == "repository_dispatch" ]; then
+ # Triggered by push via trigger-i18n-sync.yml workflow
+ # Validate required payload fields
+ if [ -z "${{ github.event.client_payload.changed_files }}" ]; then
+ echo "Error: repository_dispatch payload missing required 'changed_files' field" >&2
+ exit 1
+ fi
+ echo "CHANGED_FILES=${{ github.event.client_payload.changed_files }}" >> $GITHUB_OUTPUT
+ echo "TARGET_LANGS=" >> $GITHUB_OUTPUT
+ echo "SYNC_MODE=${{ github.event.client_payload.sync_mode || 'incremental' }}" >> $GITHUB_OUTPUT
+
+ # Decode the base64-encoded diff from the trigger workflow
+ if [ -n "${{ github.event.client_payload.diff_base64 }}" ]; then
+ if ! echo "${{ github.event.client_payload.diff_base64 }}" | base64 -d > /tmp/i18n-diff.txt 2>&1; then
+ echo "Warning: Failed to decode base64 diff payload" >&2
+ echo "" > /tmp/i18n-diff.txt
+ echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
+ elif [ -s /tmp/i18n-diff.txt ]; then
+ echo "DIFF_AVAILABLE=true" >> $GITHUB_OUTPUT
+ else
+ echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
+ fi
+ else
+ echo "" > /tmp/i18n-diff.txt
+ echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
+ fi
+ else
+ echo "Unsupported event type: ${{ github.event_name }}"
+ exit 1
+ fi
+
+ # Truncate diff if too large (keep first 50KB)
+ if [ -f /tmp/i18n-diff.txt ]; then
+ head -c 50000 /tmp/i18n-diff.txt > /tmp/i18n-diff-truncated.txt
+ mv /tmp/i18n-diff-truncated.txt /tmp/i18n-diff.txt
+ fi
+
+ echo "Detected files: $(cat $GITHUB_OUTPUT | grep CHANGED_FILES || echo 'none')"
+
+ - name: Run Claude Code for Translation Sync
+ if: steps.detect_changes.outputs.CHANGED_FILES != ''
+ uses: anthropics/claude-code-action@v1
+ with:
+ anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
+ github_token: ${{ secrets.GITHUB_TOKEN }}
+ # Allow github-actions bot to trigger this workflow via repository_dispatch
+ # See: https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md
+ allowed_bots: 'github-actions[bot]'
+ prompt: |
+ You are a professional i18n synchronization engineer for the Dify project.
+ Your task is to keep all language translations in sync with the English source (en-US).
+
+ ## CRITICAL TOOL RESTRICTIONS
+ - Use **Read** tool to read files (NOT cat or bash)
+ - Use **Edit** tool to modify JSON files (NOT node, jq, or bash scripts)
+ - Use **Bash** ONLY for: git commands, gh commands, pnpm commands
+ - Run bash commands ONE BY ONE, never combine with && or ||
+ - NEVER use `$()` command substitution - it's not supported. Split into separate commands instead.
+
+ ## WORKING DIRECTORY & ABSOLUTE PATHS
+ Claude Code sandbox working directory may vary. Always use absolute paths:
+ - For pnpm: `pnpm --dir ${{ github.workspace }}/web <command>`
+ - For git: `git -C ${{ github.workspace }} <command>`
+ - For gh: `gh --repo ${{ github.repository }} <command>`
+ - For file paths: `${{ github.workspace }}/web/i18n/`
+
+ ## EFFICIENCY RULES
+ - **ONE Edit per language file** - batch all key additions into a single Edit
+ - Insert new keys at the beginning of JSON (after `{`), lint:fix will sort them
+ - Translate ALL keys for a language mentally first, then do ONE Edit
+
+ ## Context
+ - Changed/target files: ${{ steps.detect_changes.outputs.CHANGED_FILES }}
+ - Target languages (empty means all supported): ${{ steps.detect_changes.outputs.TARGET_LANGS }}
+ - Sync mode: ${{ steps.detect_changes.outputs.SYNC_MODE }}
+ - Translation files are located in: ${{ github.workspace }}/web/i18n/{locale}/{filename}.json
+ - Language configuration is in: ${{ github.workspace }}/web/i18n-config/languages.ts
+ - Git diff is available: ${{ steps.detect_changes.outputs.DIFF_AVAILABLE }}
+
+ ## CRITICAL DESIGN: Verify First, Then Sync
+
+ You MUST follow this three-phase approach:
+
+ ═══════════════════════════════════════════════════════════════
+ ║ PHASE 1: VERIFY - Analyze and Generate Change Report ║
+ ═══════════════════════════════════════════════════════════════
+
+ ### Step 1.1: Analyze Git Diff (for incremental mode)
+ Use the Read tool to read `/tmp/i18n-diff.txt` to see the git diff.
+
+ Parse the diff to categorize changes:
+ - Lines with `+` (not `+++`): Added or modified values
+ - Lines with `-` (not `---`): Removed or old values
+ - Identify specific keys for each category:
+ * ADD: Keys that appear only in `+` lines (new keys)
+ * UPDATE: Keys that appear in both `-` and `+` lines (value changed)
+ * DELETE: Keys that appear only in `-` lines (removed keys)
+
+ ### Step 1.2: Read Language Configuration
+ Use the Read tool to read `${{ github.workspace }}/web/i18n-config/languages.ts`.
+ Extract all languages with `supported: true`.
+
+ ### Step 1.3: Run i18n:check for Each Language
+ ```bash
+ pnpm --dir ${{ github.workspace }}/web install --frozen-lockfile
+ ```
+ ```bash
+ pnpm --dir ${{ github.workspace }}/web run i18n:check
+ ```
+
+ This will report:
+ - Missing keys (need to ADD)
+ - Extra keys (need to DELETE)
+
+ ### Step 1.4: Generate Change Report
+
+ Create a structured report identifying:
+ ```
+ ╔══════════════════════════════════════════════════════════════╗
+ ║ I18N SYNC CHANGE REPORT ║
+ ╠══════════════════════════════════════════════════════════════╣
+ ║ Files to process: [list] ║
+ ║ Languages to sync: [list] ║
+ ╠══════════════════════════════════════════════════════════════╣
+ ║ ADD (New Keys): ║
+ ║ - [filename].[key]: "English value" ║
+ ║ ... ║
+ ╠══════════════════════════════════════════════════════════════╣
+ ║ UPDATE (Modified Keys - MUST re-translate): ║
+ ║ - [filename].[key]: "Old value" → "New value" ║
+ ║ ... ║
+ ╠══════════════════════════════════════════════════════════════╣
+ ║ DELETE (Extra Keys): ║
+ ║ - [language]/[filename].[key] ║
+ ║ ... ║
+ ╚══════════════════════════════════════════════════════════════╝
+ ```
+
+ **IMPORTANT**: For UPDATE detection, compare git diff to find keys where
+ the English value changed. These MUST be re-translated even if target
+ language already has a translation (it's now stale!).
+
+ ═══════════════════════════════════════════════════════════════
+ ║ PHASE 2: SYNC - Execute Changes Based on Report ║
+ ═══════════════════════════════════════════════════════════════
+
+ ### Step 2.1: Process ADD Operations (BATCH per language file)
+
+ **CRITICAL WORKFLOW for efficiency:**
+ 1. First, translate ALL new keys for ALL languages mentally
+ 2. Then, for EACH language file, do ONE Edit operation:
+ - Read the file once
+ - Insert ALL new keys at the beginning (right after the opening `{`)
+ - Don't worry about alphabetical order - lint:fix will sort them later
+
+ Example Edit (adding 3 keys to zh-Hans/app.json):
+ ```
+ old_string: '{\n "accessControl"'
+ new_string: '{\n "newKey1": "translation1",\n "newKey2": "translation2",\n "newKey3": "translation3",\n "accessControl"'
+ ```
+
+ **IMPORTANT**:
+ - ONE Edit per language file (not one Edit per key!)
+ - Always use the Edit tool. NEVER use bash scripts, node, or jq.
+
+ ### Step 2.2: Process UPDATE Operations
+
+ **IMPORTANT: Special handling for zh-Hans and ja-JP**
+ If zh-Hans or ja-JP files were ALSO modified in the same push:
+ - Run: `git -C ${{ github.workspace }} diff HEAD~1 --name-only` and check for zh-Hans or ja-JP files
+ - If found, it means someone manually translated them. Apply these rules:
+
+ 1. **Missing keys**: Still ADD them (completeness required)
+ 2. **Existing translations**: Compare with the NEW English value:
+ - If translation is **completely wrong** or **unrelated** → Update it
+ - If translation is **roughly correct** (captures the meaning) → Keep it, respect manual work
+ - When in doubt, **keep the manual translation**
+
+ Example:
+ - English changed: "Save" → "Save Changes"
+ - Manual translation: "保存更改" → Keep it (correct meaning)
+ - Manual translation: "删除" → Update it (completely wrong)
+
+ For other languages:
+ Use Edit tool to replace the old value with the new translation.
+ You can batch multiple updates in one Edit if they are adjacent.
+
+ ### Step 2.3: Process DELETE Operations
+ For extra keys reported by i18n:check:
+ - Run: `pnpm --dir ${{ github.workspace }}/web run i18n:check --auto-remove`
+ - Or manually remove from target language JSON files
+
+ ## Translation Guidelines
+
+ - PRESERVE all placeholders exactly as-is:
+ - `{{variable}}` - Mustache interpolation
+ - `${variable}` - Template literal
+ - `<tag>content</tag>` - HTML tags
+ - `_one`, `_other` - Pluralization suffixes (these are KEY suffixes, not values)
+
+ **CRITICAL: Variable names and tag names MUST stay in English - NEVER translate them**
+
+ ✅ CORRECT examples:
+ - English: "{{count}} items" → Japanese: "{{count}} 個のアイテム"
+ - English: "{{name}} updated" → Korean: "{{name}} 업데이트됨"
+ - English: "{{email}}" → Chinese: "{{email}}"
+ - English: "Marketplace" → Japanese: "マーケットプレイス"
+
+ ❌ WRONG examples (NEVER do this - will break the application):
+ - "{{count}}" → "{{カウント}}" ❌ (variable name translated to Japanese)
+ - "{{name}}" → "{{이름}}" ❌ (variable name translated to Korean)
+ - "{{email}}" → "{{邮箱}}" ❌ (variable name translated to Chinese)
+ - "" → "<メール>" ❌ (tag name translated)
+ - "" → "<自定义链接>" ❌ (component name translated)
+
+ - Use appropriate language register (formal/informal) based on existing translations
+ - Match existing translation style in each language
+ - Technical terms: check existing conventions per language
+ - For CJK languages: no spaces between characters unless necessary
+ - For RTL languages (ar-TN, fa-IR): ensure proper text handling
+
+ ## Output Format Requirements
+ - Alphabetical key ordering (if original file uses it)
+ - 2-space indentation
+ - Trailing newline at end of file
+ - Valid JSON (use proper escaping for special characters)
+
+ ═══════════════════════════════════════════════════════════════
+ ║ PHASE 3: RE-VERIFY - Confirm All Issues Resolved ║
+ ═══════════════════════════════════════════════════════════════
+
+ ### Step 3.1: Run Lint Fix (IMPORTANT!)
+ ```bash
+ pnpm --dir ${{ github.workspace }}/web lint:fix --quiet -- 'i18n/**/*.json'
+ ```
+ This ensures:
+ - JSON keys are sorted alphabetically (jsonc/sort-keys rule)
+ - Valid i18n keys (dify-i18n/valid-i18n-keys rule)
+ - No extra keys (dify-i18n/no-extra-keys rule)
+
+ ### Step 3.2: Run Final i18n Check
+ ```bash
+ pnpm --dir ${{ github.workspace }}/web run i18n:check
+ ```
+
+ ### Step 3.3: Fix Any Remaining Issues
+ If check reports issues:
+ - Go back to PHASE 2 for unresolved items
+ - Repeat until check passes
+
+ ### Step 3.4: Generate Final Summary
+ ```
+ ╔══════════════════════════════════════════════════════════════╗
+ ║ SYNC COMPLETED SUMMARY ║
+ ╠══════════════════════════════════════════════════════════════╣
+ ║ Language │ Added │ Updated │ Deleted │ Status ║
+ ╠══════════════════════════════════════════════════════════════╣
+ ║ zh-Hans │ 5 │ 2 │ 1 │ ✓ Complete ║
+ ║ ja-JP │ 5 │ 2 │ 1 │ ✓ Complete ║
+ ║ ... │ ... │ ... │ ... │ ... ║
+ ╠══════════════════════════════════════════════════════════════╣
+ ║ i18n:check │ PASSED - All keys in sync ║
+ ╚══════════════════════════════════════════════════════════════╝
+ ```
+
+ ## Mode-Specific Behavior
+
+ **SYNC_MODE = "incremental"** (default):
+ - Focus on keys identified from git diff
+ - Also check i18n:check output for any missing/extra keys
+ - Efficient for small changes
+
+ **SYNC_MODE = "full"**:
+ - Compare ALL keys between en-US and each language
+ - Run i18n:check to identify all discrepancies
+ - Use for first-time sync or fixing historical issues
+
+ ## Important Notes
+
+ 1. Always run i18n:check BEFORE and AFTER making changes
+ 2. The check script is the source of truth for missing/extra keys
+ 3. For UPDATE scenario: git diff is the source of truth for changed values
+ 4. Create a single commit with all translation changes
+ 5. If any translation fails, continue with others and report failures
+
+ ═══════════════════════════════════════════════════════════════
+ ║ PHASE 4: COMMIT AND CREATE PR ║
+ ═══════════════════════════════════════════════════════════════
+
+ After all translations are complete and verified:
+
+ ### Step 4.1: Check for changes
+ ```bash
+ git -C ${{ github.workspace }} status --porcelain
+ ```
+
+ If there are changes:
+
+ ### Step 4.2: Create a new branch and commit
+ Run these git commands ONE BY ONE (not combined with &&).
+ **IMPORTANT**: Do NOT use `$()` command substitution. Use two separate commands:
+
+ 1. First, get the timestamp:
+ ```bash
+ date +%Y%m%d-%H%M%S
+ ```
+ (Note the output, e.g., "20260115-143052")
+
+ 2. Then create branch using the timestamp value:
+ ```bash
+ git -C ${{ github.workspace }} checkout -b chore/i18n-sync-20260115-143052
+ ```
+ (Replace "20260115-143052" with the actual timestamp from step 1)
+
+ 3. Stage changes:
+ ```bash
+ git -C ${{ github.workspace }} add web/i18n/
+ ```
+
+ 4. Commit:
+ ```bash
+ git -C ${{ github.workspace }} commit -m "chore(i18n): sync translations with en-US - Mode: ${{ steps.detect_changes.outputs.SYNC_MODE }}"
+ ```
+
+ 5. Push:
+ ```bash
+ git -C ${{ github.workspace }} push origin HEAD
+ ```
+
+ ### Step 4.3: Create Pull Request
+ ```bash
+ gh pr create --repo ${{ github.repository }} --title "chore(i18n): sync translations with en-US" --body "## Summary
+
+ This PR was automatically generated to sync i18n translation files.
+
+ ### Changes
+ - Mode: ${{ steps.detect_changes.outputs.SYNC_MODE }}
+ - Files processed: ${{ steps.detect_changes.outputs.CHANGED_FILES }}
+
+ ### Verification
+ - [x] \`i18n:check\` passed
+ - [x] \`lint:fix\` applied
+
+ 🤖 Generated with Claude Code GitHub Action" --base main
+ ```
+
+ claude_args: |
+ --max-turns 150
+ --allowedTools "Read,Write,Edit,Bash(git *),Bash(git:*),Bash(gh *),Bash(gh:*),Bash(pnpm *),Bash(pnpm:*),Bash(date *),Bash(date:*),Glob,Grep"
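Outside the Claude sandbox, the verify → sync → re-verify loop the prompt describes reduces to two package scripts plus an optional auto-remove pass. A hedged local equivalent, assuming `i18n:check` and `lint:fix` behave as the prompt states:

```bash
# Phase 1 / Phase 3: report missing and extra keys per locale
pnpm --dir web run i18n:check

# Phase 3: sort keys and apply the i18n-specific lint rules after editing
pnpm --dir web lint:fix --quiet -- 'i18n/**/*.json'

# Step 2.3: drop extra keys flagged by the checker
pnpm --dir web run i18n:check --auto-remove
```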
diff --git a/.github/workflows/trigger-i18n-sync.yml b/.github/workflows/trigger-i18n-sync.yml
new file mode 100644
index 0000000000..66a29453b4
--- /dev/null
+++ b/.github/workflows/trigger-i18n-sync.yml
@@ -0,0 +1,66 @@
+name: Trigger i18n Sync on Push
+
+# This workflow bridges the push event to repository_dispatch
+# because claude-code-action doesn't support push events directly.
+# See: https://github.com/langgenius/dify/issues/30743
+
+on:
+ push:
+ branches: [main]
+ paths:
+ - 'web/i18n/en-US/*.json'
+
+permissions:
+ contents: write
+
+jobs:
+ trigger:
+ if: github.repository == 'langgenius/dify'
+ runs-on: ubuntu-latest
+ timeout-minutes: 5
+
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v6
+ with:
+ fetch-depth: 0
+
+ - name: Detect changed files and generate diff
+ id: detect
+ run: |
+ BEFORE_SHA="${{ github.event.before }}"
+ # Handle edge case: force push may have null/zero SHA
+ if [ -z "$BEFORE_SHA" ] || [ "$BEFORE_SHA" = "0000000000000000000000000000000000000000" ]; then
+ BEFORE_SHA="HEAD~1"
+ fi
+
+ # Detect changed i18n files
+ changed=$(git diff --name-only "$BEFORE_SHA" "${{ github.sha }}" -- 'web/i18n/en-US/*.json' 2>/dev/null | xargs -n1 basename 2>/dev/null | sed 's/.json$//' | tr '\n' ' ' || echo "")
+ echo "changed_files=$changed" >> $GITHUB_OUTPUT
+
+ # Generate diff for context
+ git diff "$BEFORE_SHA" "${{ github.sha }}" -- 'web/i18n/en-US/*.json' > /tmp/i18n-diff.txt 2>/dev/null || echo "" > /tmp/i18n-diff.txt
+
+ # Truncate if too large (keep first 50KB to match receiving workflow)
+ head -c 50000 /tmp/i18n-diff.txt > /tmp/i18n-diff-truncated.txt
+ mv /tmp/i18n-diff-truncated.txt /tmp/i18n-diff.txt
+
+ # Base64 encode the diff for safe JSON transport (portable, single-line)
+ diff_base64=$(base64 < /tmp/i18n-diff.txt | tr -d '\n')
+ echo "diff_base64=$diff_base64" >> $GITHUB_OUTPUT
+
+ if [ -n "$changed" ]; then
+ echo "has_changes=true" >> $GITHUB_OUTPUT
+ echo "Detected changed files: $changed"
+ else
+ echo "has_changes=false" >> $GITHUB_OUTPUT
+ echo "No i18n changes detected"
+ fi
+
+ - name: Trigger i18n sync workflow
+ if: steps.detect.outputs.has_changes == 'true'
+ uses: peter-evans/repository-dispatch@v3
+ with:
+ token: ${{ secrets.GITHUB_TOKEN }}
+ event-type: i18n-sync
+ client-payload: '{"changed_files": "${{ steps.detect.outputs.changed_files }}", "diff_base64": "${{ steps.detect.outputs.diff_base64 }}", "sync_mode": "incremental", "trigger_sha": "${{ github.sha }}"}'
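Because `client-payload` must be a single JSON string, the diff is truncated to 50 KB and base64-encoded onto one line before dispatch; the receiving workflow decodes it back into `/tmp/i18n-diff.txt`. A condensed round-trip sketch, using the file names from the two workflows:

```bash
# Sender (trigger-i18n-sync.yml): truncate, then encode to a single JSON-safe line
git diff "$BEFORE_SHA" "$GITHUB_SHA" -- 'web/i18n/en-US/*.json' \
  | head -c 50000 > /tmp/i18n-diff.txt
diff_base64=$(base64 < /tmp/i18n-diff.txt | tr -d '\n')

# Receiver (translate-i18n-claude.yml): decode back into a plain diff file
echo "$diff_base64" | base64 -d > /tmp/i18n-diff.decoded.txt
cmp /tmp/i18n-diff.txt /tmp/i18n-diff.decoded.txt && echo "round trip intact"
```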
diff --git a/.github/workflows/vdb-tests.yml b/.github/workflows/vdb-tests.yml
index 291171e5c7..7735afdaca 100644
--- a/.github/workflows/vdb-tests.yml
+++ b/.github/workflows/vdb-tests.yml
@@ -19,19 +19,19 @@ jobs:
steps:
- name: Checkout code
- uses: actions/checkout@v4
+ uses: actions/checkout@v6
with:
persist-credentials: false
- name: Free Disk Space
- uses: endersonmenezes/free-disk-space@v2
+ uses: endersonmenezes/free-disk-space@v3
with:
remove_dotnet: true
remove_haskell: true
remove_tool_cache: true
- name: Setup UV and Python
- uses: astral-sh/setup-uv@v6
+ uses: astral-sh/setup-uv@v7
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml
index 3313e58614..191ce56aaa 100644
--- a/.github/workflows/web-tests.yml
+++ b/.github/workflows/web-tests.yml
@@ -13,46 +13,401 @@ jobs:
runs-on: ubuntu-latest
defaults:
run:
+ shell: bash
working-directory: ./web
steps:
- name: Checkout code
- uses: actions/checkout@v4
+ uses: actions/checkout@v6
with:
persist-credentials: false
- - name: Check changed files
- id: changed-files
- uses: tj-actions/changed-files@v46
- with:
- files: web/**
-
- name: Install pnpm
- if: steps.changed-files.outputs.any_changed == 'true'
uses: pnpm/action-setup@v4
with:
package_json_file: web/package.json
run_install: false
- name: Setup Node.js
- uses: actions/setup-node@v4
- if: steps.changed-files.outputs.any_changed == 'true'
+ uses: actions/setup-node@v6
with:
- node-version: 22
+ node-version: 24
cache: pnpm
- cache-dependency-path: ./web/package.json
+ cache-dependency-path: ./web/pnpm-lock.yaml
- name: Install dependencies
+ run: pnpm install --frozen-lockfile
+
+ - name: Run tests
+ run: pnpm test:coverage
+
+ - name: Coverage Summary
+ if: always()
+ id: coverage-summary
+ run: |
+ set -eo pipefail
+
+ COVERAGE_FILE="coverage/coverage-final.json"
+ COVERAGE_SUMMARY_FILE="coverage/coverage-summary.json"
+
+ if [ ! -f "$COVERAGE_FILE" ] && [ ! -f "$COVERAGE_SUMMARY_FILE" ]; then
+ echo "has_coverage=false" >> "$GITHUB_OUTPUT"
+ echo "### 🚨 Test Coverage Report :test_tube:" >> "$GITHUB_STEP_SUMMARY"
+ echo "Coverage data not found. Ensure Vitest runs with coverage enabled." >> "$GITHUB_STEP_SUMMARY"
+ exit 0
+ fi
+
+ echo "has_coverage=true" >> "$GITHUB_OUTPUT"
+
+ node <<'NODE' >> "$GITHUB_STEP_SUMMARY"
+ const fs = require('fs');
+ const path = require('path');
+ let libCoverage = null;
+
+ try {
+ libCoverage = require('istanbul-lib-coverage');
+ } catch (error) {
+ libCoverage = null;
+ }
+
+ const summaryPath = path.join('coverage', 'coverage-summary.json');
+ const finalPath = path.join('coverage', 'coverage-final.json');
+
+ const hasSummary = fs.existsSync(summaryPath);
+ const hasFinal = fs.existsSync(finalPath);
+
+ if (!hasSummary && !hasFinal) {
+ console.log('### Test Coverage Summary :test_tube:');
+ console.log('');
+ console.log('No coverage data found.');
+ process.exit(0);
+ }
+
+ const summary = hasSummary
+ ? JSON.parse(fs.readFileSync(summaryPath, 'utf8'))
+ : null;
+ const coverage = hasFinal
+ ? JSON.parse(fs.readFileSync(finalPath, 'utf8'))
+ : null;
+
+ const getLineCoverageFromStatements = (statementMap, statementHits) => {
+ const lineHits = {};
+
+ if (!statementMap || !statementHits) {
+ return lineHits;
+ }
+
+ Object.entries(statementMap).forEach(([key, statement]) => {
+ const line = statement?.start?.line;
+ if (!line) {
+ return;
+ }
+ const hits = statementHits[key] ?? 0;
+ const previous = lineHits[line];
+ lineHits[line] = previous === undefined ? hits : Math.max(previous, hits);
+ });
+
+ return lineHits;
+ };
+
+ const getFileCoverage = (entry) => (
+ libCoverage ? libCoverage.createFileCoverage(entry) : null
+ );
+
+ const getLineHits = (entry, fileCoverage) => {
+ const lineHits = entry.l ?? {};
+ if (Object.keys(lineHits).length > 0) {
+ return lineHits;
+ }
+ if (fileCoverage) {
+ return fileCoverage.getLineCoverage();
+ }
+ return getLineCoverageFromStatements(entry.statementMap ?? {}, entry.s ?? {});
+ };
+
+ const getUncoveredLines = (entry, fileCoverage, lineHits) => {
+ if (lineHits && Object.keys(lineHits).length > 0) {
+ return Object.entries(lineHits)
+ .filter(([, count]) => count === 0)
+ .map(([line]) => Number(line))
+ .sort((a, b) => a - b);
+ }
+ if (fileCoverage) {
+ return fileCoverage.getUncoveredLines();
+ }
+ return [];
+ };
+
+ const totals = {
+ lines: { covered: 0, total: 0 },
+ statements: { covered: 0, total: 0 },
+ branches: { covered: 0, total: 0 },
+ functions: { covered: 0, total: 0 },
+ };
+ const fileSummaries = [];
+
+ if (summary) {
+ const totalEntry = summary.total ?? {};
+ ['lines', 'statements', 'branches', 'functions'].forEach((key) => {
+ if (totalEntry[key]) {
+ totals[key].covered = totalEntry[key].covered ?? 0;
+ totals[key].total = totalEntry[key].total ?? 0;
+ }
+ });
+
+ Object.entries(summary)
+ .filter(([file]) => file !== 'total')
+ .forEach(([file, data]) => {
+ fileSummaries.push({
+ file,
+ pct: data.lines?.pct ?? data.statements?.pct ?? 0,
+ lines: {
+ covered: data.lines?.covered ?? 0,
+ total: data.lines?.total ?? 0,
+ },
+ });
+ });
+ } else if (coverage) {
+ Object.entries(coverage).forEach(([file, entry]) => {
+ const fileCoverage = getFileCoverage(entry);
+ const lineHits = getLineHits(entry, fileCoverage);
+ const statementHits = entry.s ?? {};
+ const branchHits = entry.b ?? {};
+ const functionHits = entry.f ?? {};
+
+ const lineTotal = Object.keys(lineHits).length;
+ const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
+
+ const statementTotal = Object.keys(statementHits).length;
+ const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
+
+ const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
+ const branchCovered = Object.values(branchHits).reduce(
+ (acc, branches) => acc + branches.filter((n) => n > 0).length,
+ 0,
+ );
+
+ const functionTotal = Object.keys(functionHits).length;
+ const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
+
+ totals.lines.total += lineTotal;
+ totals.lines.covered += lineCovered;
+ totals.statements.total += statementTotal;
+ totals.statements.covered += statementCovered;
+ totals.branches.total += branchTotal;
+ totals.branches.covered += branchCovered;
+ totals.functions.total += functionTotal;
+ totals.functions.covered += functionCovered;
+
+ const pct = (covered, tot) => (tot > 0 ? (covered / tot) * 100 : 0);
+
+ fileSummaries.push({
+ file,
+ pct: pct(lineCovered || statementCovered, lineTotal || statementTotal),
+ lines: {
+ covered: lineCovered || statementCovered,
+ total: lineTotal || statementTotal,
+ },
+ });
+ });
+ }
+
+ const pct = (covered, tot) => (tot > 0 ? ((covered / tot) * 100).toFixed(2) : '0.00');
+
+ console.log('### Test Coverage Summary :test_tube:');
+ console.log('');
+ console.log('| Metric | Coverage | Covered / Total |');
+ console.log('|--------|----------|-----------------|');
+ console.log(`| Lines | ${pct(totals.lines.covered, totals.lines.total)}% | ${totals.lines.covered} / ${totals.lines.total} |`);
+ console.log(`| Statements | ${pct(totals.statements.covered, totals.statements.total)}% | ${totals.statements.covered} / ${totals.statements.total} |`);
+ console.log(`| Branches | ${pct(totals.branches.covered, totals.branches.total)}% | ${totals.branches.covered} / ${totals.branches.total} |`);
+ console.log(`| Functions | ${pct(totals.functions.covered, totals.functions.total)}% | ${totals.functions.covered} / ${totals.functions.total} |`);
+
+ console.log('');
+ console.log('File coverage (lowest lines first)');
+ console.log('');
+ console.log('```');
+ fileSummaries
+ .sort((a, b) => (a.pct - b.pct) || (b.lines.total - a.lines.total))
+ .slice(0, 25)
+ .forEach(({ file, pct, lines }) => {
+ console.log(`${pct.toFixed(2)}%\t${lines.covered}/${lines.total}\t${file}`);
+ });
+ console.log('```');
+ console.log('');
+
+ if (coverage) {
+ const pctValue = (covered, tot) => {
+ if (tot === 0) {
+ return '0';
+ }
+ return ((covered / tot) * 100)
+ .toFixed(2)
+ .replace(/\.?0+$/, '');
+ };
+
+ const formatLineRanges = (lines) => {
+ if (lines.length === 0) {
+ return '';
+ }
+ const ranges = [];
+ let start = lines[0];
+ let end = lines[0];
+
+ for (let i = 1; i < lines.length; i += 1) {
+ const current = lines[i];
+ if (current === end + 1) {
+ end = current;
+ continue;
+ }
+ ranges.push(start === end ? `${start}` : `${start}-${end}`);
+ start = current;
+ end = current;
+ }
+ ranges.push(start === end ? `${start}` : `${start}-${end}`);
+ return ranges.join(',');
+ };
+
+ const tableTotals = {
+ statements: { covered: 0, total: 0 },
+ branches: { covered: 0, total: 0 },
+ functions: { covered: 0, total: 0 },
+ lines: { covered: 0, total: 0 },
+ };
+ const tableRows = Object.entries(coverage)
+ .map(([file, entry]) => {
+ const fileCoverage = getFileCoverage(entry);
+ const lineHits = getLineHits(entry, fileCoverage);
+ const statementHits = entry.s ?? {};
+ const branchHits = entry.b ?? {};
+ const functionHits = entry.f ?? {};
+
+ const lineTotal = Object.keys(lineHits).length;
+ const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
+ const statementTotal = Object.keys(statementHits).length;
+ const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
+ const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
+ const branchCovered = Object.values(branchHits).reduce(
+ (acc, branches) => acc + branches.filter((n) => n > 0).length,
+ 0,
+ );
+ const functionTotal = Object.keys(functionHits).length;
+ const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
+
+ tableTotals.lines.total += lineTotal;
+ tableTotals.lines.covered += lineCovered;
+ tableTotals.statements.total += statementTotal;
+ tableTotals.statements.covered += statementCovered;
+ tableTotals.branches.total += branchTotal;
+ tableTotals.branches.covered += branchCovered;
+ tableTotals.functions.total += functionTotal;
+ tableTotals.functions.covered += functionCovered;
+
+ const uncoveredLines = getUncoveredLines(entry, fileCoverage, lineHits);
+
+ const filePath = entry.path ?? file;
+ const relativePath = path.isAbsolute(filePath)
+ ? path.relative(process.cwd(), filePath)
+ : filePath;
+
+ return {
+ file: relativePath || file,
+ statements: pctValue(statementCovered, statementTotal),
+ branches: pctValue(branchCovered, branchTotal),
+ functions: pctValue(functionCovered, functionTotal),
+ lines: pctValue(lineCovered, lineTotal),
+ uncovered: formatLineRanges(uncoveredLines),
+ };
+ })
+ .sort((a, b) => a.file.localeCompare(b.file));
+
+ const columns = [
+ { key: 'file', header: 'File', align: 'left' },
+ { key: 'statements', header: '% Stmts', align: 'right' },
+ { key: 'branches', header: '% Branch', align: 'right' },
+ { key: 'functions', header: '% Funcs', align: 'right' },
+ { key: 'lines', header: '% Lines', align: 'right' },
+ { key: 'uncovered', header: 'Uncovered Line #s', align: 'left' },
+ ];
+
+ const allFilesRow = {
+ file: 'All files',
+ statements: pctValue(tableTotals.statements.covered, tableTotals.statements.total),
+ branches: pctValue(tableTotals.branches.covered, tableTotals.branches.total),
+ functions: pctValue(tableTotals.functions.covered, tableTotals.functions.total),
+ lines: pctValue(tableTotals.lines.covered, tableTotals.lines.total),
+ uncovered: '',
+ };
+
+ const rowsForOutput = [allFilesRow, ...tableRows];
+ const formatRow = (row) => `| ${columns
+ .map(({ key }) => String(row[key] ?? ''))
+ .join(' | ')} |`;
+ const headerRow = `| ${columns.map(({ header }) => header).join(' | ')} |`;
+ const dividerRow = `| ${columns
+ .map(({ align }) => (align === 'right' ? '---:' : ':---'))
+ .join(' | ')} |`;
+
+ console.log('');
+ console.log('Vitest coverage table');
+ console.log('');
+ console.log(headerRow);
+ console.log(dividerRow);
+ rowsForOutput.forEach((row) => console.log(formatRow(row)));
+ console.log('');
+ }
+ NODE
+
+ - name: Upload Coverage Artifact
+ if: steps.coverage-summary.outputs.has_coverage == 'true'
+ uses: actions/upload-artifact@v6
+ with:
+ name: web-coverage-report
+ path: web/coverage
+ retention-days: 30
+ if-no-files-found: error
+
+ web-build:
+ name: Web Build
+ runs-on: ubuntu-latest
+ defaults:
+ run:
+ working-directory: ./web
+
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v6
+ with:
+ persist-credentials: false
+
+ - name: Check changed files
+ id: changed-files
+ uses: tj-actions/changed-files@v47
+ with:
+ files: |
+ web/**
+ .github/workflows/web-tests.yml
+
+ - name: Install pnpm
+ uses: pnpm/action-setup@v4
+ with:
+ package_json_file: web/package.json
+ run_install: false
+
+ - name: Setup NodeJS
+ uses: actions/setup-node@v6
+ if: steps.changed-files.outputs.any_changed == 'true'
+ with:
+ node-version: 24
+ cache: pnpm
+ cache-dependency-path: ./web/pnpm-lock.yaml
+
+ - name: Web dependencies
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm install --frozen-lockfile
- - name: Check i18n types synchronization
+ - name: Web build check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
- run: pnpm run check:i18n-types
-
- - name: Run tests
- if: steps.changed-files.outputs.any_changed == 'true'
- working-directory: ./web
- run: pnpm test
+ run: pnpm run build
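The coverage step prefers `coverage/coverage-summary.json` and falls back to reconstructing totals from `coverage-final.json`. To sanity-check the same numbers locally — assuming the Vitest config enables the `json-summary` reporter, as the workflow expects:

```bash
cd web
pnpm test:coverage

# Print the overall metrics the workflow writes to the job summary
node -e '
const s = require("./coverage/coverage-summary.json").total;
for (const k of ["lines", "statements", "branches", "functions"])
  console.log(`${k}: ${s[k].pct}% (${s[k].covered}/${s[k].total})`);
'
```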
diff --git a/.gitignore b/.gitignore
index 79ba44b207..7bd919f095 100644
--- a/.gitignore
+++ b/.gitignore
@@ -139,7 +139,6 @@ pyrightconfig.json
.idea/'
.DS_Store
-web/.vscode/settings.json
# Intellij IDEA Files
.idea/*
@@ -189,12 +188,14 @@ docker/volumes/matrixone/*
docker/volumes/mysql/*
docker/volumes/seekdb/*
!docker/volumes/oceanbase/init.d
+docker/volumes/iris/*
docker/nginx/conf.d/default.conf
docker/nginx/ssl/*
!docker/nginx/ssl/.gitkeep
docker/middleware.env
docker/docker-compose.override.yaml
+docker/env-backup/*
sdks/python-client/build
sdks/python-client/dist
@@ -204,7 +205,6 @@ sdks/python-client/dify_client.egg-info
!.vscode/launch.json.template
!.vscode/README.md
api/.vscode
-web/.vscode
# vscode Code History Extension
.history
@@ -219,15 +219,6 @@ plugins.jsonl
# mise
mise.toml
-# Next.js build output
-.next/
-
-# PWA generated files
-web/public/sw.js
-web/public/sw.js.map
-web/public/workbox-*.js
-web/public/workbox-*.js.map
-web/public/fallback-*.js
# AI Assistant
.roo/
@@ -244,3 +235,4 @@ scripts/stress-test/reports/
# settings
*.local.json
+*.local.md
diff --git a/.mcp.json b/.mcp.json
deleted file mode 100644
index 8eceaf9ead..0000000000
--- a/.mcp.json
+++ /dev/null
@@ -1,34 +0,0 @@
-{
- "mcpServers": {
- "context7": {
- "type": "http",
- "url": "https://mcp.context7.com/mcp"
- },
- "sequential-thinking": {
- "type": "stdio",
- "command": "npx",
- "args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
- "env": {}
- },
- "github": {
- "type": "stdio",
- "command": "npx",
- "args": ["-y", "@modelcontextprotocol/server-github"],
- "env": {
- "GITHUB_PERSONAL_ACCESS_TOKEN": "${GITHUB_PERSONAL_ACCESS_TOKEN}"
- }
- },
- "fetch": {
- "type": "stdio",
- "command": "uvx",
- "args": ["mcp-server-fetch"],
- "env": {}
- },
- "playwright": {
- "type": "stdio",
- "command": "npx",
- "args": ["-y", "@playwright/mcp@latest"],
- "env": {}
- }
- }
- }
\ No newline at end of file
diff --git a/.nvmrc b/.nvmrc
deleted file mode 100644
index 7af24b7ddb..0000000000
--- a/.nvmrc
+++ /dev/null
@@ -1 +0,0 @@
-22.11.0
diff --git a/.vscode/launch.json.template b/.vscode/launch.json.template
index cb934d01b5..bdded1e73e 100644
--- a/.vscode/launch.json.template
+++ b/.vscode/launch.json.template
@@ -37,7 +37,7 @@
"-c",
"1",
"-Q",
- "dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor",
+ "dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention",
"--loglevel",
"INFO"
],
diff --git a/.windsurf/rules/testing.md b/.windsurf/rules/testing.md
deleted file mode 100644
index 64fec20cb8..0000000000
--- a/.windsurf/rules/testing.md
+++ /dev/null
@@ -1,5 +0,0 @@
-# Windsurf Testing Rules
-
-- Use `web/testing/testing.md` as the single source of truth for frontend automated testing.
-- Honor every requirement in that document when generating or accepting tests.
-- When proposing or saving tests, re-read that document and follow every requirement.
diff --git a/AGENTS.md b/AGENTS.md
index 782861ad36..51fa6e4527 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -7,27 +7,18 @@ Dify is an open-source platform for developing LLM applications with an intuitiv
The codebase is split into:
- **Backend API** (`/api`): Python Flask application organized with Domain-Driven Design
-- **Frontend Web** (`/web`): Next.js 15 application using TypeScript and React 19
+- **Frontend Web** (`/web`): Next.js application using TypeScript and React
- **Docker deployment** (`/docker`): Containerized deployment configurations
## Backend Workflow
+- Read `api/AGENTS.md` for details
- Run backend CLI commands through `uv run --project api <command>`.
-
-- Before submission, all backend modifications must pass local checks: `make lint`, `make type-check`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`.
-
-- Use Makefile targets for linting and formatting; `make lint` and `make type-check` cover the required checks.
-
- Integration tests are CI-only and are not expected to run in the local environment.
## Frontend Workflow
-```bash
-cd web
-pnpm lint:fix
-pnpm type-check:tsgo
-pnpm test
-```
+- Read `web/AGENTS.md` for details
## Testing & Quality Practices
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 20a7d6c6f6..d7f007af67 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -77,7 +77,7 @@ How we prioritize:
For setting up the frontend service, please refer to our comprehensive [guide](https://github.com/langgenius/dify/blob/main/web/README.md) in the `web/README.md` file. This document provides detailed instructions to help you set up the frontend environment properly.
-**Testing**: All React components must have comprehensive test coverage. See [web/testing/testing.md](https://github.com/langgenius/dify/blob/main/web/testing/testing.md) for the canonical frontend testing guidelines and follow every requirement described there.
+**Testing**: All React components must have comprehensive test coverage. See [web/docs/test.md](https://github.com/langgenius/dify/blob/main/web/docs/test.md) for the canonical frontend testing guidelines and follow every requirement described there.
#### Backend
diff --git a/Makefile b/Makefile
index 07afd8187e..984e8676ee 100644
--- a/Makefile
+++ b/Makefile
@@ -60,19 +60,28 @@ check:
@echo "✅ Code check complete"
lint:
- @echo "🔧 Running ruff format, check with fixes, and import linter..."
- @uv run --project api --dev sh -c 'ruff format ./api && ruff check --fix ./api'
+ @echo "🔧 Running ruff format, check with fixes, import linter, and dotenv-linter..."
+ @uv run --project api --dev ruff format ./api
+ @uv run --project api --dev ruff check --fix ./api
@uv run --directory api --dev lint-imports
+ @uv run --project api --dev dotenv-linter ./api/.env.example ./web/.env.example
@echo "✅ Linting complete"
type-check:
- @echo "📝 Running type check with basedpyright..."
- @uv run --directory api --dev basedpyright
- @echo "✅ Type check complete"
+ @echo "📝 Running type checks (basedpyright + mypy + ty)..."
+ @./dev/basedpyright-check $(PATH_TO_CHECK)
+ @uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
+ @cd api && uv run ty check
+ @echo "✅ Type checks complete"
test:
@echo "🧪 Running backend unit tests..."
- @uv run --project api --dev dev/pytest/pytest_unit_tests.sh
+ @if [ -n "$(TARGET_TESTS)" ]; then \
+ echo "Target: $(TARGET_TESTS)"; \
+ uv run --project api --dev pytest $(TARGET_TESTS); \
+ else \
+ PYTEST_XDIST_ARGS="-n auto" uv run --project api --dev dev/pytest/pytest_unit_tests.sh; \
+ fi
@echo "✅ Tests complete"
# Build Docker images
@@ -122,9 +131,9 @@ help:
@echo "Backend Code Quality:"
@echo " make format - Format code with ruff"
@echo " make check - Check code with ruff"
- @echo " make lint - Format and fix code with ruff"
- @echo " make type-check - Run type checking with basedpyright"
- @echo " make test - Run backend unit tests"
+ @echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
+ @echo " make type-check - Run type checks (basedpyright, mypy, ty)"
+ @echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/)"
@echo ""
@echo "Docker Build Targets:"
@echo " make build-web - Build web Docker image"
diff --git a/api/.env.example b/api/.env.example
index 516a119d98..8bd2c706c1 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -101,6 +101,15 @@ S3_ACCESS_KEY=your-access-key
S3_SECRET_KEY=your-secret-key
S3_REGION=your-region
+# Workflow run and Conversation archive storage (S3-compatible)
+ARCHIVE_STORAGE_ENABLED=false
+ARCHIVE_STORAGE_ENDPOINT=
+ARCHIVE_STORAGE_ARCHIVE_BUCKET=
+ARCHIVE_STORAGE_EXPORT_BUCKET=
+ARCHIVE_STORAGE_ACCESS_KEY=
+ARCHIVE_STORAGE_SECRET_KEY=
+ARCHIVE_STORAGE_REGION=auto
+
# Azure Blob Storage configuration
AZURE_BLOB_ACCOUNT_NAME=your-account-name
AZURE_BLOB_ACCOUNT_KEY=your-account-key
@@ -116,6 +125,7 @@ ALIYUN_OSS_AUTH_VERSION=v1
ALIYUN_OSS_REGION=your-region
# Don't start with '/'. OSS doesn't support leading slash in object names.
ALIYUN_OSS_PATH=your-path
+ALIYUN_CLOUDBOX_ID=your-cloudbox-id
# Google Storage configuration
GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name
@@ -127,12 +137,14 @@ TENCENT_COS_SECRET_KEY=your-secret-key
TENCENT_COS_SECRET_ID=your-secret-id
TENCENT_COS_REGION=your-region
TENCENT_COS_SCHEME=your-scheme
+TENCENT_COS_CUSTOM_DOMAIN=your-custom-domain
# Huawei OBS Storage Configuration
HUAWEI_OBS_BUCKET_NAME=your-bucket-name
HUAWEI_OBS_SECRET_KEY=your-secret-key
HUAWEI_OBS_ACCESS_KEY=your-access-key
HUAWEI_OBS_SERVER=your-server-url
+HUAWEI_OBS_PATH_STYLE=false
# Baidu OBS Storage Configuration
BAIDU_OBS_BUCKET_NAME=your-bucket-name
@@ -405,6 +417,8 @@ SMTP_USERNAME=123
SMTP_PASSWORD=abc
SMTP_USE_TLS=true
SMTP_OPPORTUNISTIC_TLS=false
+# Optional: override the local hostname used for SMTP HELO/EHLO
+SMTP_LOCAL_HOSTNAME=
# Sendgrid configuration
SENDGRID_API_KEY=
# Sentry configuration
@@ -490,6 +504,8 @@ LOG_FILE_BACKUP_COUNT=5
LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S
# Log Timezone
LOG_TZ=UTC
+# Log output format: text or json
+LOG_OUTPUT_FORMAT=text
# Log format
LOG_FORMAT=%(asctime)s,%(msecs)d %(levelname)-2s [%(filename)s:%(lineno)d] %(req_id)s %(message)s
@@ -543,6 +559,29 @@ APP_MAX_EXECUTION_TIME=1200
APP_DEFAULT_ACTIVE_REQUESTS=0
APP_MAX_ACTIVE_REQUESTS=0
+# Aliyun SLS Logstore Configuration
+# Aliyun Access Key ID
+ALIYUN_SLS_ACCESS_KEY_ID=
+# Aliyun Access Key Secret
+ALIYUN_SLS_ACCESS_KEY_SECRET=
+# Aliyun SLS Endpoint (e.g., cn-hangzhou.log.aliyuncs.com)
+ALIYUN_SLS_ENDPOINT=
+# Aliyun SLS Region (e.g., cn-hangzhou)
+ALIYUN_SLS_REGION=
+# Aliyun SLS Project Name
+ALIYUN_SLS_PROJECT_NAME=
+# Number of days to retain workflow run logs (default: 365 days, 3650 for permanent storage)
+ALIYUN_SLS_LOGSTORE_TTL=365
+# Enable dual-write to both SLS LogStore and SQL database (default: false)
+LOGSTORE_DUAL_WRITE_ENABLED=false
+# Enable dual-read fallback to SQL database when LogStore returns no results (default: true)
+# Useful for migration scenarios where historical data exists only in SQL database
+LOGSTORE_DUAL_READ_ENABLED=true
+# Control flag for whether to write the `graph` field to LogStore.
+# If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field;
+# otherwise write an empty {} instead. Defaults to writing the `graph` field.
+LOGSTORE_ENABLE_PUT_GRAPH_FIELD=true
+
# Celery beat configuration
CELERY_BEAT_SCHEDULER_TIME=1
@@ -552,6 +591,7 @@ ENABLE_CLEAN_UNUSED_DATASETS_TASK=false
ENABLE_CREATE_TIDB_SERVERLESS_TASK=false
ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK=false
ENABLE_CLEAN_MESSAGES=false
+ENABLE_WORKFLOW_RUN_CLEANUP_TASK=false
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
ENABLE_DATASETS_QUEUE_MONITOR=false
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
@@ -577,6 +617,7 @@ PLUGIN_DAEMON_URL=http://127.0.0.1:5002
PLUGIN_REMOTE_INSTALL_PORT=5003
PLUGIN_REMOTE_INSTALL_HOST=localhost
PLUGIN_MAX_PACKAGE_SIZE=15728640
+PLUGIN_MODEL_SCHEMA_CACHE_TTL=3600
INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
# Marketplace configuration
@@ -660,3 +701,19 @@ SINGLE_CHUNK_ATTACHMENT_LIMIT=10
ATTACHMENT_IMAGE_FILE_SIZE_LIMIT=2
ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT=60
IMAGE_FILE_BATCH_LIMIT=10
+
+# Maximum allowed CSV file size for annotation import in megabytes
+ANNOTATION_IMPORT_FILE_SIZE_LIMIT=2
+# Maximum number of annotation records allowed in a single import
+ANNOTATION_IMPORT_MAX_RECORDS=10000
+# Minimum number of annotation records required in a single import
+ANNOTATION_IMPORT_MIN_RECORDS=1
+ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE=5
+ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20
+# Maximum number of concurrent annotation import tasks per tenant
+ANNOTATION_IMPORT_MAX_CONCURRENT=5
+# Sandbox expired records clean configuration
+SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
+SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
+SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
+SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000
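The logstore comments above sketch a migration story: dual-write keeps SQL and SLS in step while data is backfilled, and dual-read falls back to SQL for rows that only exist there. A hedged `.env` fragment for that rollout phase, with placeholder values:

```bash
# Migration phase: write to both stores; reads fall back to SQL when SLS is empty
ALIYUN_SLS_ENDPOINT=cn-hangzhou.log.aliyuncs.com   # placeholder endpoint
ALIYUN_SLS_REGION=cn-hangzhou
ALIYUN_SLS_PROJECT_NAME=dify-workflow-logs         # placeholder project
LOGSTORE_DUAL_WRITE_ENABLED=true
LOGSTORE_DUAL_READ_ENABLED=true
```

Once historical rows are fully migrated, the dual-read fallback can presumably be switched back off.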
diff --git a/api/.importlinter b/api/.importlinter
index 24ece72b30..9dad254560 100644
--- a/api/.importlinter
+++ b/api/.importlinter
@@ -3,9 +3,11 @@ root_packages =
core
configs
controllers
+ extensions
models
tasks
services
+include_external_packages = True
[importlinter:contract:workflow]
name = Workflow
@@ -25,7 +27,9 @@ ignore_imports =
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_events
core.workflow.nodes.loop.loop_node -> core.workflow.graph_events
- core.workflow.nodes.node_factory -> core.workflow.graph
+ core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
+ core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
+
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine.command_channels
@@ -33,6 +37,324 @@ ignore_imports =
core.workflow.nodes.loop.loop_node -> core.workflow.graph
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
+[importlinter:contract:workflow-infrastructure-dependencies]
+name = Workflow Infrastructure Dependencies
+type = forbidden
+source_modules =
+ core.workflow
+forbidden_modules =
+ extensions.ext_database
+ extensions.ext_redis
+allow_indirect_imports = True
+ignore_imports =
+ core.workflow.nodes.agent.agent_node -> extensions.ext_database
+ core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
+ core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database
+ core.workflow.nodes.llm.file_saver -> extensions.ext_database
+ core.workflow.nodes.llm.llm_utils -> extensions.ext_database
+ core.workflow.nodes.llm.node -> extensions.ext_database
+ core.workflow.nodes.tool.tool_node -> extensions.ext_database
+ core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
+ core.workflow.graph_engine.manager -> extensions.ext_redis
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
+
+[importlinter:contract:workflow-external-imports]
+name = Workflow External Imports
+type = forbidden
+source_modules =
+ core.workflow
+forbidden_modules =
+ configs
+ controllers
+ extensions
+ models
+ services
+ tasks
+ core.agent
+ core.app
+ core.base
+ core.callback_handler
+ core.datasource
+ core.db
+ core.entities
+ core.errors
+ core.extension
+ core.external_data_tool
+ core.file
+ core.helper
+ core.hosting_configuration
+ core.indexing_runner
+ core.llm_generator
+ core.logging
+ core.mcp
+ core.memory
+ core.model_manager
+ core.moderation
+ core.ops
+ core.plugin
+ core.prompt
+ core.provider_manager
+ core.rag
+ core.repositories
+ core.schemas
+ core.tools
+ core.trigger
+ core.variables
+ignore_imports =
+ core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
+ core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
+ core.workflow.workflow_entry -> core.app.workflow.layers.observability
+ core.workflow.nodes.agent.agent_node -> core.model_manager
+ core.workflow.nodes.agent.agent_node -> core.provider_manager
+ core.workflow.nodes.agent.agent_node -> core.tools.tool_manager
+ core.workflow.nodes.code.code_node -> core.helper.code_executor.code_executor
+ core.workflow.nodes.datasource.datasource_node -> models.model
+ core.workflow.nodes.datasource.datasource_node -> models.tools
+ core.workflow.nodes.datasource.datasource_node -> services.datasource_provider_service
+ core.workflow.nodes.document_extractor.node -> configs
+ core.workflow.nodes.document_extractor.node -> core.file.file_manager
+ core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy
+ core.workflow.nodes.http_request.entities -> configs
+ core.workflow.nodes.http_request.executor -> configs
+ core.workflow.nodes.http_request.executor -> core.file.file_manager
+ core.workflow.nodes.http_request.node -> configs
+ core.workflow.nodes.http_request.node -> core.tools.tool_file_manager
+ core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
+ core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.datasource.retrieval_service
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.dataset_retrieval
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> models.dataset
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> services.feature_service
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_runtime.model_providers.__base.large_language_model
+ core.workflow.nodes.llm.llm_utils -> configs
+ core.workflow.nodes.llm.llm_utils -> core.app.entities.app_invoke_entities
+ core.workflow.nodes.llm.llm_utils -> core.file.models
+ core.workflow.nodes.llm.llm_utils -> core.model_manager
+ core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model
+ core.workflow.nodes.llm.llm_utils -> models.model
+ core.workflow.nodes.llm.llm_utils -> models.provider
+ core.workflow.nodes.llm.llm_utils -> services.credit_pool_service
+ core.workflow.nodes.llm.node -> core.tools.signature
+ core.workflow.nodes.template_transform.template_transform_node -> configs
+ core.workflow.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
+ core.workflow.nodes.tool.tool_node -> core.tools.tool_engine
+ core.workflow.nodes.tool.tool_node -> core.tools.tool_manager
+ core.workflow.workflow_entry -> configs
+ core.workflow.workflow_entry -> models.workflow
+ core.workflow.nodes.agent.agent_node -> core.agent.entities
+ core.workflow.nodes.agent.agent_node -> core.agent.plugin_entities
+ core.workflow.nodes.base.node -> core.app.entities.app_invoke_entities
+ core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.entities.app_invoke_entities
+ core.workflow.nodes.llm.node -> core.app.entities.app_invoke_entities
+ core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.app.entities.app_invoke_entities
+ core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
+ core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
+ core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.model_providers.__base.large_language_model
+ core.workflow.nodes.question_classifier.question_classifier_node -> core.app.entities.app_invoke_entities
+ core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.advanced_prompt_transform
+ core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform
+ core.workflow.nodes.start.entities -> core.app.app_config.entities
+ core.workflow.nodes.start.start_node -> core.app.app_config.entities
+ core.workflow.workflow_entry -> core.app.apps.exc
+ core.workflow.workflow_entry -> core.app.entities.app_invoke_entities
+ core.workflow.workflow_entry -> core.app.workflow.node_factory
+ core.workflow.nodes.datasource.datasource_node -> core.datasource.datasource_manager
+ core.workflow.nodes.datasource.datasource_node -> core.datasource.utils.message_transformer
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.agent_entities
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.model_entities
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_manager
+ core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
+ core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
+ core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
+ core.workflow.node_events.node -> core.file
+ core.workflow.nodes.agent.agent_node -> core.file
+ core.workflow.nodes.datasource.datasource_node -> core.file
+ core.workflow.nodes.datasource.datasource_node -> core.file.enums
+ core.workflow.nodes.document_extractor.node -> core.file
+ core.workflow.nodes.http_request.executor -> core.file.enums
+ core.workflow.nodes.http_request.node -> core.file
+ core.workflow.nodes.http_request.node -> core.file.file_manager
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.file.models
+ core.workflow.nodes.list_operator.node -> core.file
+ core.workflow.nodes.llm.file_saver -> core.file
+ core.workflow.nodes.llm.llm_utils -> core.variables.segments
+ core.workflow.nodes.llm.node -> core.file
+ core.workflow.nodes.llm.node -> core.file.file_manager
+ core.workflow.nodes.llm.node -> core.file.models
+ core.workflow.nodes.loop.entities -> core.variables.types
+ core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.file
+ core.workflow.nodes.protocols -> core.file
+ core.workflow.nodes.question_classifier.question_classifier_node -> core.file.models
+ core.workflow.nodes.tool.tool_node -> core.file
+ core.workflow.nodes.tool.tool_node -> core.tools.utils.message_transformer
+ core.workflow.nodes.tool.tool_node -> models
+ core.workflow.nodes.trigger_webhook.node -> core.file
+ core.workflow.runtime.variable_pool -> core.file
+ core.workflow.runtime.variable_pool -> core.file.file_manager
+ core.workflow.system_variable -> core.file.models
+ core.workflow.utils.condition.processor -> core.file
+ core.workflow.utils.condition.processor -> core.file.file_manager
+ core.workflow.workflow_entry -> core.file.models
+ core.workflow.workflow_type_encoder -> core.file.models
+ core.workflow.nodes.agent.agent_node -> models.model
+ core.workflow.nodes.code.code_node -> core.helper.code_executor.code_node_provider
+ core.workflow.nodes.code.code_node -> core.helper.code_executor.javascript.javascript_code_provider
+ core.workflow.nodes.code.code_node -> core.helper.code_executor.python3.python3_code_provider
+ core.workflow.nodes.code.entities -> core.helper.code_executor.code_executor
+ core.workflow.nodes.datasource.datasource_node -> core.variables.variables
+ core.workflow.nodes.http_request.executor -> core.helper.ssrf_proxy
+ core.workflow.nodes.http_request.node -> core.helper.ssrf_proxy
+ core.workflow.nodes.llm.file_saver -> core.helper.ssrf_proxy
+ core.workflow.nodes.llm.node -> core.helper.code_executor
+ core.workflow.nodes.template_transform.template_renderer -> core.helper.code_executor.code_executor
+ core.workflow.nodes.llm.node -> core.llm_generator.output_parser.errors
+ core.workflow.nodes.llm.node -> core.llm_generator.output_parser.structured_output
+ core.workflow.nodes.llm.node -> core.model_manager
+ core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.prompt.simple_prompt_transform
+ core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
+ core.workflow.nodes.llm.llm_utils -> core.prompt.entities.advanced_prompt_entities
+ core.workflow.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
+ core.workflow.nodes.llm.node -> core.prompt.utils.prompt_message_util
+ core.workflow.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities
+ core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.entities.advanced_prompt_entities
+ core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util
+ core.workflow.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities
+ core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util
+ core.workflow.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods
+ core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.retrieval.retrieval_methods
+ core.workflow.nodes.knowledge_index.knowledge_index_node -> models.dataset
+ core.workflow.nodes.knowledge_index.knowledge_index_node -> services.summary_index_service
+ core.workflow.nodes.knowledge_index.knowledge_index_node -> tasks.generate_summary_index_task
+ core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.processor.paragraph_index_processor
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.retrieval_methods
+ core.workflow.nodes.llm.node -> models.dataset
+ core.workflow.nodes.agent.agent_node -> core.tools.utils.message_transformer
+ core.workflow.nodes.llm.file_saver -> core.tools.signature
+ core.workflow.nodes.llm.file_saver -> core.tools.tool_file_manager
+ core.workflow.nodes.tool.tool_node -> core.tools.errors
+ core.workflow.conversation_variable_updater -> core.variables
+ core.workflow.graph_engine.entities.commands -> core.variables.variables
+ core.workflow.nodes.agent.agent_node -> core.variables.segments
+ core.workflow.nodes.answer.answer_node -> core.variables
+ core.workflow.nodes.code.code_node -> core.variables.segments
+ core.workflow.nodes.code.code_node -> core.variables.types
+ core.workflow.nodes.code.entities -> core.variables.types
+ core.workflow.nodes.datasource.datasource_node -> core.variables.segments
+ core.workflow.nodes.document_extractor.node -> core.variables
+ core.workflow.nodes.document_extractor.node -> core.variables.segments
+ core.workflow.nodes.http_request.executor -> core.variables.segments
+ core.workflow.nodes.http_request.node -> core.variables.segments
+ core.workflow.nodes.iteration.iteration_node -> core.variables
+ core.workflow.nodes.iteration.iteration_node -> core.variables.segments
+ core.workflow.nodes.iteration.iteration_node -> core.variables.variables
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.variables
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.variables.segments
+ core.workflow.nodes.list_operator.node -> core.variables
+ core.workflow.nodes.list_operator.node -> core.variables.segments
+ core.workflow.nodes.llm.node -> core.variables
+ core.workflow.nodes.loop.loop_node -> core.variables
+ core.workflow.nodes.parameter_extractor.entities -> core.variables.types
+ core.workflow.nodes.parameter_extractor.exc -> core.variables.types
+ core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.variables.types
+ core.workflow.nodes.tool.tool_node -> core.variables.segments
+ core.workflow.nodes.tool.tool_node -> core.variables.variables
+ core.workflow.nodes.trigger_webhook.node -> core.variables.types
+ core.workflow.nodes.trigger_webhook.node -> core.variables.variables
+ core.workflow.nodes.variable_aggregator.entities -> core.variables.types
+ core.workflow.nodes.variable_aggregator.variable_aggregator_node -> core.variables.segments
+ core.workflow.nodes.variable_assigner.common.helpers -> core.variables
+ core.workflow.nodes.variable_assigner.common.helpers -> core.variables.consts
+ core.workflow.nodes.variable_assigner.common.helpers -> core.variables.types
+ core.workflow.nodes.variable_assigner.v1.node -> core.variables
+ core.workflow.nodes.variable_assigner.v2.helpers -> core.variables
+ core.workflow.nodes.variable_assigner.v2.node -> core.variables
+ core.workflow.nodes.variable_assigner.v2.node -> core.variables.consts
+ core.workflow.runtime.graph_runtime_state_protocol -> core.variables.segments
+ core.workflow.runtime.read_only_wrappers -> core.variables.segments
+ core.workflow.runtime.variable_pool -> core.variables
+ core.workflow.runtime.variable_pool -> core.variables.consts
+ core.workflow.runtime.variable_pool -> core.variables.segments
+ core.workflow.runtime.variable_pool -> core.variables.variables
+ core.workflow.utils.condition.processor -> core.variables
+ core.workflow.utils.condition.processor -> core.variables.segments
+ core.workflow.variable_loader -> core.variables
+ core.workflow.variable_loader -> core.variables.consts
+ core.workflow.workflow_type_encoder -> core.variables
+ core.workflow.graph_engine.manager -> extensions.ext_redis
+ core.workflow.nodes.agent.agent_node -> extensions.ext_database
+ core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
+ core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database
+ core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
+ core.workflow.nodes.llm.file_saver -> extensions.ext_database
+ core.workflow.nodes.llm.llm_utils -> extensions.ext_database
+ core.workflow.nodes.llm.node -> extensions.ext_database
+ core.workflow.nodes.tool.tool_node -> extensions.ext_database
+ core.workflow.workflow_entry -> extensions.otel.runtime
+ core.workflow.nodes.agent.agent_node -> models
+ core.workflow.nodes.base.node -> models.enums
+ core.workflow.nodes.llm.llm_utils -> models.provider_ids
+ core.workflow.nodes.llm.node -> models.model
+ core.workflow.workflow_entry -> models.enums
+ core.workflow.nodes.agent.agent_node -> services
+ core.workflow.nodes.tool.tool_node -> services
+
+[importlinter:contract:model-runtime-no-internal-imports]
+name = Model Runtime Internal Imports
+type = forbidden
+source_modules =
+ core.model_runtime
+forbidden_modules =
+ configs
+ controllers
+ extensions
+ models
+ services
+ tasks
+ core.agent
+ core.app
+ core.base
+ core.callback_handler
+ core.datasource
+ core.db
+ core.entities
+ core.errors
+ core.extension
+ core.external_data_tool
+ core.file
+ core.helper
+ core.hosting_configuration
+ core.indexing_runner
+ core.llm_generator
+ core.logging
+ core.mcp
+ core.memory
+ core.model_manager
+ core.moderation
+ core.ops
+ core.plugin
+ core.prompt
+ core.provider_manager
+ core.rag
+ core.repositories
+ core.schemas
+ core.tools
+ core.trigger
+ core.variables
+ core.workflow
+ignore_imports =
+ core.model_runtime.model_providers.__base.ai_model -> configs
+ core.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
+ core.model_runtime.model_providers.__base.large_language_model -> configs
+ core.model_runtime.model_providers.__base.text_embedding_model -> core.entities.embedding_type
+ core.model_runtime.model_providers.model_provider_factory -> configs
+ core.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
+ core.model_runtime.model_providers.model_provider_factory -> models.provider_ids
+
[importlinter:contract:rsc]
name = RSC
type = layers
diff --git a/api/.ruff.toml b/api/.ruff.toml
index 7206f7fa0f..8db0cbcb21 100644
--- a/api/.ruff.toml
+++ b/api/.ruff.toml
@@ -1,4 +1,8 @@
-exclude = ["migrations/*"]
+exclude = [
+ "migrations/*",
+ ".git",
+ ".git/**",
+]
line-length = 120
[format]
diff --git a/api/AGENTS.md b/api/AGENTS.md
index 17398ec4b8..13adb42276 100644
--- a/api/AGENTS.md
+++ b/api/AGENTS.md
@@ -1,62 +1,186 @@
-# Agent Skill Index
+# API Agent Guide
-Start with the section that best matches your need. Each entry lists the problems it solves plus key files/concepts so you know what to expect before opening it.
+## Notes for Agent (must-check)
-______________________________________________________________________
+Before changing any backend code under `api/`, you MUST read the surrounding docstrings and comments. These notes contain required context (invariants, edge cases, trade-offs) and are treated as part of the spec.
-## Platform Foundations
+Look for:
-- **[Infrastructure Overview](agent_skills/infra.md)**\
- When to read this:
+- The module (file) docstring at the top of a source code file
+- Docstrings on classes and functions/methods
+- Paragraph/block comments for non-obvious logic
- - You need to understand where a feature belongs in the architecture.
- - You’re wiring storage, Redis, vector stores, or OTEL.
- - You’re about to add CLI commands or async jobs.\
- What it covers: configuration stack (`configs/app_config.py`, remote settings), storage entry points (`extensions/ext_storage.py`, `core/file/file_manager.py`), Redis conventions (`extensions/ext_redis.py`), plugin runtime topology, vector-store factory (`core/rag/datasource/vdb/*`), observability hooks, SSRF proxy usage, and core CLI commands.
+### What to write where
-- **[Coding Style](agent_skills/coding_style.md)**\
- When to read this:
+- Keep notes scoped: module notes cover module-wide context, class notes cover class-wide context, function/method notes cover behavioural contracts, and paragraph/block comments cover local “why”. Avoid duplicating the same content across scopes unless repetition prevents misuse.
+- **Module (file) docstring**: purpose, boundaries, key invariants, and “gotchas” that a new reader must know before editing.
+ - Include cross-links to the key collaborators (modules/services) when discovery is otherwise hard.
+ - Prefer stable facts (invariants, contracts) over ephemeral “today we…” notes.
+- **Class docstring**: responsibility, lifecycle, invariants, and how it should be used (or not used).
+ - If the class is intentionally stateful, note what state exists and what methods mutate it.
+ - If concurrency/async assumptions matter, state them explicitly.
+- **Function/method docstring**: behavioural contract.
+ - Document arguments, return shape, side effects (DB writes, external I/O, task dispatch), and raised domain exceptions.
+ - Add examples only when they prevent misuse.
+- **Paragraph/block comments**: explain *why* (trade-offs, historical constraints, surprising edge cases), not what the code already states.
+ - Keep comments adjacent to the logic they justify; delete or rewrite comments that no longer match reality.
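+
+A minimal sketch of how these notes layer together (module, class, and method docstrings plus a block comment for the local "why"); the module and names below are purely illustrative, not existing Dify code:
+
+```python
+"""Quota ledger helpers (illustrative module, not an existing Dify module).
+
+Invariant: every read/write is scoped by ``tenant_id``; quota never goes negative.
+"""
+
+
+class QuotaLedger:
+    """Tracks remaining quota for a single tenant.
+
+    Stateful: ``remaining`` is mutated only by ``consume()``; not thread-safe.
+    """
+
+    tenant_id: str
+    remaining: int
+
+    def __init__(self, tenant_id: str, remaining: int) -> None:
+        self.tenant_id = tenant_id
+        self.remaining = remaining
+
+    def consume(self, amount: int) -> int:
+        """Deduct ``amount`` and return the remaining quota.
+
+        Raises ValueError when the deduction would exceed the remaining quota.
+        """
+        if amount > self.remaining:
+            raise ValueError("quota exhausted")
+        # Why: deduct before any side effect so a retried caller can never double-spend.
+        self.remaining -= amount
+        return self.remaining
+```
+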
- - You’re writing or reviewing backend code and need the authoritative checklist.
- - You’re unsure about Pydantic validators, SQLAlchemy session usage, or logging patterns.
- - You want the exact lint/type/test commands used in PRs.\
- Includes: Ruff & BasedPyright commands, no-annotation policy, session examples (`with Session(db.engine, ...)`), `@field_validator` usage, logging expectations, and the rule set for file size, helpers, and package management.
+### Rules (must follow)
-______________________________________________________________________
+In this section, “notes” means module/class/function docstrings plus any relevant paragraph/block comments.
-## Plugin & Extension Development
+- **Before working**
+ - Read the notes in the area you’ll touch; treat them as part of the spec.
+ - If a docstring or comment conflicts with the current code, treat the **code as the single source of truth** and update the docstring or comment to match reality.
+ - If important intent/invariants/edge cases are missing, add them in the closest docstring or comment (module for overall scope, function for behaviour).
+- **During working**
+ - Keep the notes in sync as you discover constraints, make decisions, or change approach.
+ - If you move/rename responsibilities across modules/classes, update the affected docstrings and comments so readers can still find the “why” and the invariants.
+ - Record non-obvious edge cases, trade-offs, and the test/verification plan in the nearest docstring or comment that will stay correct.
+ - Keep the notes **coherent**: integrate new findings into the relevant docstrings and comments; avoid append-only “recent fix” / changelog-style additions.
+- **When finishing**
+ - Update the notes to reflect what changed, why, and any new edge cases/tests.
+ - Remove or rewrite any comments that could be mistaken as current guidance but no longer apply.
+ - Keep docstrings and comments concise and accurate; they are meant to prevent repeated rediscovery.
-- **[Plugin Systems](agent_skills/plugin.md)**\
- When to read this:
+## Coding Style
- - You’re building or debugging a marketplace plugin.
- - You need to know how manifests, providers, daemons, and migrations fit together.\
- What it covers: plugin manifests (`core/plugin/entities/plugin.py`), installation/upgrade flows (`services/plugin/plugin_service.py`, CLI commands), runtime adapters (`core/plugin/impl/*` for tool/model/datasource/trigger/endpoint/agent), daemon coordination (`core/plugin/entities/plugin_daemon.py`), and how provider registries surface capabilities to the rest of the platform.
+This is the default standard for backend code in this repo. Follow it for new code and use it as the checklist when reviewing changes.
-- **[Plugin OAuth](agent_skills/plugin_oauth.md)**\
- When to read this:
+### Linting & Formatting
- - You must integrate OAuth for a plugin or datasource.
- - You’re handling credential encryption or refresh flows.\
- Topics: credential storage, encryption helpers (`core/helper/provider_encryption.py`), OAuth client bootstrap (`services/plugin/oauth_service.py`, `services/plugin/plugin_parameter_service.py`), and how console/API layers expose the flows.
+- Use Ruff for formatting and linting (follow `.ruff.toml`).
+- Keep each line under 120 characters (including spaces).
-______________________________________________________________________
+### Naming Conventions
-## Workflow Entry & Execution
+- Use `snake_case` for variables and functions.
+- Use `PascalCase` for classes.
+- Use `UPPER_CASE` for constants.
-- **[Trigger Concepts](agent_skills/trigger.md)**\
- When to read this:
- - You’re debugging why a workflow didn’t start.
- - You’re adding a new trigger type or hook.
- - You need to trace async execution, draft debugging, or webhook/schedule pipelines.\
- Details: Start-node taxonomy, webhook & schedule internals (`core/workflow/nodes/trigger_*`, `services/trigger/*`), async orchestration (`services/async_workflow_service.py`, Celery queues), debug event bus, and storage/logging interactions.
+### Typing & Class Layout
-______________________________________________________________________
+- Code should usually include type annotations that match the repo’s current Python version (avoid untyped public APIs and “mystery” values).
+- Prefer modern typing forms (e.g. `list[str]`, `dict[str, int]`) and avoid `Any` unless there’s a strong reason.
+- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance:
-## Additional Notes for Agents
+```python
+from datetime import datetime
-- All skill docs assume you follow the coding style guide—run Ruff/BasedPyright/tests listed there before submitting changes.
-- When you cannot find an answer in these briefs, search the codebase using the paths referenced (e.g., `core/plugin/impl/tool.py`, `services/dataset_service.py`).
-- If you run into cross-cutting concerns (tenancy, configuration, storage), check the infrastructure guide first; it links to most supporting modules.
-- Keep multi-tenancy and configuration central: everything flows through `configs.dify_config` and `tenant_id`.
-- When touching plugins or triggers, consult both the system overview and the specialised doc to ensure you adjust lifecycle, storage, and observability consistently.
+
+class Example:
+ user_id: str
+ created_at: datetime
+
+ def __init__(self, user_id: str, created_at: datetime) -> None:
+ self.user_id = user_id
+ self.created_at = created_at
+```
+
+### General Rules
+
+- Use Pydantic v2 conventions.
+- Use `uv` for Python package management in this repo (usually with `--project api`).
+- Prefer simple functions over small “utility classes” for lightweight helpers.
+- Avoid implementing dunder methods unless they're clearly needed and match existing patterns.
+- Never start long-running services as part of agent work (`uv run app.py`, `flask run`, etc.); running tests is allowed.
+- Keep files below ~800 lines; split when necessary.
+- Keep code readable and explicit—avoid clever hacks.
+
+### Architecture & Boundaries
+
+- Mirror the layered architecture: controller → service → core/domain.
+- Reuse existing helpers in `core/`, `services/`, and `libs/` before creating new abstractions.
+- Optimise for observability: deterministic control flow, clear logging, actionable errors.
+
+### Logging & Errors
+
+- Never use `print`; use a module-level logger:
+ - `logger = logging.getLogger(__name__)`
+- Include tenant/app/workflow identifiers in log context when relevant.
+- Raise domain-specific exceptions (`services/errors`, `core/errors`) and translate them into HTTP responses in controllers.
+- Log retryable events at `warning`, terminal failures at `error`.
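+
+A rough sketch of these rules combined; the error class and function are placeholders, not existing Dify symbols:
+
+```python
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class WorkflowQuotaExceededError(Exception):
+    """Placeholder for a domain-specific error that would live under services/errors."""
+
+
+def charge_workflow_run(tenant_id: str, workflow_id: str, remaining: int) -> None:
+    if remaining <= 0:
+        # Terminal failure: include tenant/workflow identifiers, log at error, raise a domain error.
+        logger.error("quota exhausted, tenant_id=%s workflow_id=%s", tenant_id, workflow_id)
+        raise WorkflowQuotaExceededError(f"tenant {tenant_id} has no quota left")
+    # A retryable condition (e.g. a transient provider timeout) would be logged at warning instead.
+    logger.info("charging workflow run, tenant_id=%s workflow_id=%s", tenant_id, workflow_id)
+```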
+
+### SQLAlchemy Patterns
+
+- Models inherit from `models.base.TypeBase`; do not create ad-hoc metadata or engines.
+- Open sessions with context managers:
+
+```python
+from sqlalchemy.orm import Session
+
+with Session(db.engine, expire_on_commit=False) as session:
+ stmt = select(Workflow).where(
+ Workflow.id == workflow_id,
+ Workflow.tenant_id == tenant_id,
+ )
+ workflow = session.execute(stmt).scalar_one_or_none()
+```
+
+- Prefer SQLAlchemy expressions; avoid raw SQL unless necessary.
+- Always scope queries by `tenant_id` and protect write paths with safeguards (`FOR UPDATE`, row counts, etc.).
+- Introduce repository abstractions only for very large tables (e.g., workflow executions) or when alternative storage strategies are required.
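+
+One possible shape for a guarded write path following the bullets above, reusing the names from the session example; the written column is hypothetical and shown only to complete the pattern:
+
+```python
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+# Assumes db.engine, Workflow, workflow_id and tenant_id as in the session example above.
+with Session(db.engine, expire_on_commit=False) as session:
+    stmt = (
+        select(Workflow)
+        .where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
+        .with_for_update()  # lock the row so concurrent writers serialize
+    )
+    workflow = session.execute(stmt).scalar_one_or_none()
+    if workflow is None:
+        raise ValueError("workflow not found for this tenant")
+    workflow.marked_name = "renamed"  # hypothetical write, shown only to illustrate the guarded update
+    session.commit()
+```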
+
+### Storage & External I/O
+
+- Access storage via `extensions.ext_storage.storage`.
+- Use `core.helper.ssrf_proxy` for outbound HTTP fetches.
+- Background tasks that touch storage must be idempotent, and should log relevant object identifiers.
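+
+A hedged sketch of the intended call pattern; the `ssrf_proxy.get` / `storage.save` signatures are assumed here, so verify them against the actual modules before relying on this:
+
+```python
+import logging
+
+from core.helper import ssrf_proxy
+from extensions.ext_storage import storage
+
+logger = logging.getLogger(__name__)
+
+
+def mirror_remote_file(tenant_id: str, url: str, object_key: str) -> None:
+    """Idempotent: re-running overwrites the same object_key instead of duplicating it."""
+    response = ssrf_proxy.get(url)  # outbound fetch goes through the SSRF-safe client
+    storage.save(object_key, response.content)  # blob IO via the configured backend
+    logger.info("stored remote file, tenant_id=%s object_key=%s", tenant_id, object_key)
+```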
+
+### Pydantic Usage
+
+- Define DTOs with Pydantic v2 models and forbid extras by default.
+- Use `@field_validator` / `@model_validator` for domain rules.
+
+Example:
+
+```python
+from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator
+
+
+class TriggerConfig(BaseModel):
+ endpoint: HttpUrl
+ secret: str
+
+ model_config = ConfigDict(extra="forbid")
+
+ @field_validator("secret")
+ def ensure_secret_prefix(cls, value: str) -> str:
+ if not value.startswith("dify_"):
+ raise ValueError("secret must start with dify_")
+ return value
+```
+
+### Generics & Protocols
+
+- Use `typing.Protocol` to define behavioural contracts (e.g., cache interfaces).
+- Apply generics (`TypeVar`, `Generic`) for reusable utilities like caches or providers.
+- Validate dynamic inputs at runtime when generics cannot enforce safety alone.
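+
+For example, a cache contract could be expressed roughly like this (illustrative names only, not an existing interface in the codebase):
+
+```python
+from typing import Generic, Protocol, TypeVar
+
+T = TypeVar("T")
+
+
+class Cache(Protocol[T]):
+    """Behavioural contract: implementations may be Redis-backed, in-memory, etc."""
+
+    def get(self, key: str) -> T | None: ...
+
+    def set(self, key: str, value: T, ttl_seconds: int) -> None: ...
+
+
+class InMemoryCache(Generic[T]):
+    """Trivial implementation used here only to show the contract being satisfied."""
+
+    values: dict[str, T]
+
+    def __init__(self) -> None:
+        self.values = {}
+
+    def get(self, key: str) -> T | None:
+        return self.values.get(key)
+
+    def set(self, key: str, value: T, ttl_seconds: int) -> None:
+        # TTL is ignored in this sketch; a real backend would honour it.
+        self.values[key] = value
+```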
+
+### Tooling & Checks
+
+Quick checks while iterating:
+
+- Format: `make format`
+- Lint (includes auto-fix): `make lint`
+- Type check: `make type-check`
+- Targeted tests: `make test TARGET_TESTS=./api/tests/`
+
+Before opening a PR / submitting:
+
+- `make lint`
+- `make type-check`
+- `make test`
+
+### Controllers & Services
+
+- Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic.
+- Services: coordinate repositories, providers, background tasks; keep side effects explicit.
+- Document non-obvious behaviour with concise docstrings and comments.
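+
+A schematic of that split with invented names; real controllers live under `controllers/` and attach to the HTTP framework, so this only shows the layering:
+
+```python
+from pydantic import BaseModel
+
+
+class CreateTagPayload(BaseModel):
+    name: str
+
+
+def create_tag_service(tenant_id: str, name: str) -> dict:
+    """Service layer: owns the business rule and any side effects (DB write, task dispatch)."""
+    # ... persist the tag, enqueue follow-up work, etc.
+    return {"tenant_id": tenant_id, "name": name}
+
+
+def create_tag_controller(tenant_id: str, raw_body: dict) -> dict:
+    """Controller layer: validate input, call the service, serialise the response."""
+    payload = CreateTagPayload.model_validate(raw_body)
+    return create_tag_service(tenant_id, payload.name)
+```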
+
+### Miscellaneous
+
+- Use `configs.dify_config` for configuration—never read environment variables directly.
+- Maintain tenant awareness end-to-end; `tenant_id` must flow through every layer touching shared resources.
+- Queue async work through `services/async_workflow_service`; implement tasks under `tasks/` with explicit queue selection.
+- Keep experimental scripts under `dev/`; do not ship them in production builds.
diff --git a/api/Dockerfile b/api/Dockerfile
index 02df91bfc1..a08d4e3aab 100644
--- a/api/Dockerfile
+++ b/api/Dockerfile
@@ -50,16 +50,33 @@ WORKDIR /app/api
# Create non-root user
ARG dify_uid=1001
+ARG NODE_MAJOR=22
+ARG NODE_PACKAGE_VERSION=22.21.0-1nodesource1
+ARG NODESOURCE_KEY_FPR=6F71F525282841EEDAF851B42F59B5F99B1BE0B4
RUN groupadd -r -g ${dify_uid} dify && \
useradd -r -u ${dify_uid} -g ${dify_uid} -s /bin/bash dify && \
chown -R dify:dify /app
RUN \
apt-get update \
+ && apt-get install -y --no-install-recommends \
+ ca-certificates \
+ curl \
+ gnupg \
+ && mkdir -p /etc/apt/keyrings \
+ && curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key -o /tmp/nodesource.gpg \
+ && gpg --show-keys --with-colons /tmp/nodesource.gpg \
+ | awk -F: '/^fpr:/ {print $10}' \
+ | grep -Fx "${NODESOURCE_KEY_FPR}" \
+ && gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg /tmp/nodesource.gpg \
+ && rm -f /tmp/nodesource.gpg \
+ && echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_${NODE_MAJOR}.x nodistro main" \
+ > /etc/apt/sources.list.d/nodesource.list \
+ && apt-get update \
# Install dependencies
&& apt-get install -y --no-install-recommends \
# basic environment
- curl nodejs \
+ nodejs=${NODE_PACKAGE_VERSION} \
# for gmpy2 \
libgmp-dev libmpfr-dev libmpc-dev \
# For Security
@@ -79,7 +96,8 @@ COPY --from=packages --chown=dify:dify ${VIRTUAL_ENV} ${VIRTUAL_ENV}
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
# Download nltk data
-RUN mkdir -p /usr/local/share/nltk_data && NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords')" \
+RUN mkdir -p /usr/local/share/nltk_data \
+ && NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; from unstructured.nlp.tokenize import download_nltk_packages; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords'); download_nltk_packages()" \
&& chmod -R 755 /usr/local/share/nltk_data
ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache
diff --git a/api/README.md b/api/README.md
index 2dab2ec6e6..9d89b490b0 100644
--- a/api/README.md
+++ b/api/README.md
@@ -1,6 +1,6 @@
# Dify Backend API
-## Usage
+## Setup and Run
> [!IMPORTANT]
>
@@ -8,48 +8,77 @@
> [`uv`](https://docs.astral.sh/uv/) as the package manager
> for Dify API backend service.
-1. Start the docker-compose stack
+`uv` and `pnpm` are required to run the setup and development commands below.
- The backend require some middleware, including PostgreSQL, Redis, and Weaviate, which can be started together using `docker-compose`.
+### Using scripts (recommended)
+
+The scripts resolve paths relative to their location, so you can run them from anywhere.
+
+1. Run setup (copies env files and installs dependencies).
```bash
- cd ../docker
- cp middleware.env.example middleware.env
- # change the profile to mysql if you are not using postgres,change the profile to other vector database if you are not using weaviate
- docker compose -f docker-compose.middleware.yaml --profile postgresql --profile weaviate -p dify up -d
- cd ../api
+ ./dev/setup
```
-1. Copy `.env.example` to `.env`
+1. Review `api/.env`, `web/.env.local`, and `docker/middleware.env` values (see the `SECRET_KEY` note below).
- ```cli
- cp .env.example .env
+1. Start middleware (PostgreSQL/Redis/Weaviate).
+
+ ```bash
+ ./dev/start-docker-compose
```
-> [!IMPORTANT]
->
-> When the frontend and backend run on different subdomains, set COOKIE_DOMAIN to the site’s top-level domain (e.g., `example.com`). The frontend and backend must be under the same top-level domain in order to share authentication cookies.
+1. Start backend (runs migrations first).
-1. Generate a `SECRET_KEY` in the `.env` file.
-
- bash for Linux
-
- ```bash for Linux
- sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env
+ ```bash
+ ./dev/start-api
```
- bash for Mac
+1. Start Dify [web](../web) service.
- ```bash for Mac
- secret_key=$(openssl rand -base64 42)
- sed -i '' "/^SECRET_KEY=/c\\
- SECRET_KEY=${secret_key}" .env
+ ```bash
+ ./dev/start-web
```
-1. Create environment.
+1. Set up your application by visiting `http://localhost:3000`.
- Dify API service uses [UV](https://docs.astral.sh/uv/) to manage dependencies.
- First, you need to add the uv package manager, if you don't have it already.
+1. Optional: start the worker service (async tasks, runs from `api`).
+
+ ```bash
+ ./dev/start-worker
+ ```
+
+1. Optional: start Celery Beat (scheduled tasks).
+
+ ```bash
+ ./dev/start-beat
+ ```
+
+### Manual commands
+
+
+<details>
+<summary>Show manual setup and run steps</summary>
+
+These commands assume you start from the repository root.
+
+1. Start the docker-compose stack.
+
+ The backend requires middleware, including PostgreSQL, Redis, and Weaviate, which can be started together using `docker-compose`.
+
+ ```bash
+ cp docker/middleware.env.example docker/middleware.env
+ # Use mysql or another vector database profile if you are not using postgres/weaviate.
+ docker compose -f docker/docker-compose.middleware.yaml --profile postgresql --profile weaviate -p dify up -d
+ ```
+
+1. Copy env files.
+
+ ```bash
+ cp api/.env.example api/.env
+ cp web/.env.example web/.env.local
+ ```
+
+1. Install UV if needed.
```bash
pip install uv
@@ -57,60 +86,96 @@
brew install uv
```
-1. Install dependencies
+1. Install API dependencies.
```bash
- uv sync --dev
+ cd api
+ uv sync --group dev
```
-1. Run migrate
-
- Before the first launch, migrate the database to the latest version.
+1. Install web dependencies.
```bash
+ cd web
+ pnpm install
+ cd ..
+ ```
+
+1. Start backend (runs migrations first, in a new terminal).
+
+ ```bash
+ cd api
uv run flask db upgrade
- ```
-
-1. Start backend
-
- ```bash
uv run flask run --host 0.0.0.0 --port=5001 --debug
```
-1. Start Dify [web](../web) service.
+1. Start Dify [web](../web) service (in a new terminal).
-1. Setup your application by visiting `http://localhost:3000`.
+ ```bash
+ cd web
+ pnpm dev:inspect
+ ```
-1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
+1. Set up your application by visiting `http://localhost:3000`.
-```bash
-uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor
-```
+1. Optional: start the worker service (async tasks, in a new terminal).
-Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:
+ ```bash
+ cd api
+ uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
+ ```
-```bash
-uv run celery -A app.celery beat
-```
+1. Optional: start Celery Beat (scheduled tasks, in a new terminal).
+
+ ```bash
+ cd api
+ uv run celery -A app.celery beat
+ ```
+
+</details>
+
+### Environment notes
+
+> [!IMPORTANT]
+>
+> When the frontend and backend run on different subdomains, set COOKIE_DOMAIN to the site’s top-level domain (e.g., `example.com`). The frontend and backend must be under the same top-level domain in order to share authentication cookies.
+
+- Generate a `SECRET_KEY` in the `.env` file.
+
+ bash for Linux
+
+ ```bash
+ sed -i "/^SECRET_KEY=/c\\SECRET_KEY=$(openssl rand -base64 42)" .env
+ ```
+
+ bash for Mac
+
+ ```bash
+ secret_key=$(openssl rand -base64 42)
+ sed -i '' "/^SECRET_KEY=/c\\
+ SECRET_KEY=${secret_key}" .env
+ ```
## Testing
1. Install dependencies for both the backend and the test environment
```bash
- uv sync --dev
+ cd api
+ uv sync --group dev
```
1. Run the tests locally with mocked system environment variables from the `tool.pytest_env` section in `pyproject.toml`; see [Claude.md](../CLAUDE.md) for more details
```bash
+ cd api
uv run pytest # Run all tests
uv run pytest tests/unit_tests/ # Unit tests only
uv run pytest tests/integration_tests/ # Integration tests
# Code quality
- ../dev/reformat # Run all formatters and linters
- uv run ruff check --fix ./ # Fix linting issues
- uv run ruff format ./ # Format code
- uv run basedpyright . # Type checking
+ ./dev/reformat # Run all formatters and linters
+ uv run ruff check --fix ./ # Fix linting issues
+ uv run ruff format ./ # Format code
+ uv run basedpyright . # Type checking
```
diff --git a/api/agent_skills/coding_style.md b/api/agent_skills/coding_style.md
deleted file mode 100644
index a2b66f0bd5..0000000000
--- a/api/agent_skills/coding_style.md
+++ /dev/null
@@ -1,115 +0,0 @@
-## Linter
-
-- Always follow `.ruff.toml`.
-- Run `uv run ruff check --fix --unsafe-fixes`.
-- Keep each line under 100 characters (including spaces).
-
-## Code Style
-
-- `snake_case` for variables and functions.
-- `PascalCase` for classes.
-- `UPPER_CASE` for constants.
-
-## Rules
-
-- Use Pydantic v2 standard.
-- Use `uv` for package management.
-- Do not override dunder methods like `__init__`, `__iadd__`, etc.
-- Never launch services (`uv run app.py`, `flask run`, etc.); running tests under `tests/` is allowed.
-- Prefer simple functions over classes for lightweight helpers.
-- Keep files below 800 lines; split when necessary.
-- Keep code readable—no clever hacks.
-- Never use `print`; log with `logger = logging.getLogger(__name__)`.
-
-## Guiding Principles
-
-- Mirror the project’s layered architecture: controller → service → core/domain.
-- Reuse existing helpers in `core/`, `services/`, and `libs/` before creating new abstractions.
-- Optimise for observability: deterministic control flow, clear logging, actionable errors.
-
-## SQLAlchemy Patterns
-
-- Models inherit from `models.base.Base`; never create ad-hoc metadata or engines.
-
-- Open sessions with context managers:
-
- ```python
- from sqlalchemy.orm import Session
-
- with Session(db.engine, expire_on_commit=False) as session:
- stmt = select(Workflow).where(
- Workflow.id == workflow_id,
- Workflow.tenant_id == tenant_id,
- )
- workflow = session.execute(stmt).scalar_one_or_none()
- ```
-
-- Use SQLAlchemy expressions; avoid raw SQL unless necessary.
-
-- Introduce repository abstractions only for very large tables (e.g., workflow executions) to support alternative storage strategies.
-
-- Always scope queries by `tenant_id` and protect write paths with safeguards (`FOR UPDATE`, row counts, etc.).
-
-## Storage & External IO
-
-- Access storage via `extensions.ext_storage.storage`.
-- Use `core.helper.ssrf_proxy` for outbound HTTP fetches.
-- Background tasks that touch storage must be idempotent and log the relevant object identifiers.
-
-## Pydantic Usage
-
-- Define DTOs with Pydantic v2 models and forbid extras by default.
-
-- Use `@field_validator` / `@model_validator` for domain rules.
-
-- Example:
-
- ```python
- from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator
-
- class TriggerConfig(BaseModel):
- endpoint: HttpUrl
- secret: str
-
- model_config = ConfigDict(extra="forbid")
-
- @field_validator("secret")
- def ensure_secret_prefix(cls, value: str) -> str:
- if not value.startswith("dify_"):
- raise ValueError("secret must start with dify_")
- return value
- ```
-
-## Generics & Protocols
-
-- Use `typing.Protocol` to define behavioural contracts (e.g., cache interfaces).
-- Apply generics (`TypeVar`, `Generic`) for reusable utilities like caches or providers.
-- Validate dynamic inputs at runtime when generics cannot enforce safety alone.
-
-## Error Handling & Logging
-
-- Raise domain-specific exceptions (`services/errors`, `core/errors`) and translate to HTTP responses in controllers.
-- Declare `logger = logging.getLogger(__name__)` at module top.
-- Include tenant/app/workflow identifiers in log context.
-- Log retryable events at `warning`, terminal failures at `error`.
-
-## Tooling & Checks
-
-- Format/lint: `uv run --project api --dev ruff format ./api` and `uv run --project api --dev ruff check --fix --unsafe-fixes ./api`.
-- Type checks: `uv run --directory api --dev basedpyright`.
-- Tests: `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`.
-- Run all of the above before submitting your work.
-
-## Controllers & Services
-
-- Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic.
-- Services: coordinate repositories, providers, background tasks; keep side effects explicit.
-- Avoid repositories unless necessary; direct SQLAlchemy usage is preferred for typical tables.
-- Document non-obvious behaviour with concise comments.
-
-## Miscellaneous
-
-- Use `configs.dify_config` for configuration—never read environment variables directly.
-- Maintain tenant awareness end-to-end; `tenant_id` must flow through every layer touching shared resources.
-- Queue async work through `services/async_workflow_service`; implement tasks under `tasks/` with explicit queue selection.
-- Keep experimental scripts under `dev/`; do not ship them in production builds.
diff --git a/api/agent_skills/infra.md b/api/agent_skills/infra.md
deleted file mode 100644
index bc36c7bf64..0000000000
--- a/api/agent_skills/infra.md
+++ /dev/null
@@ -1,96 +0,0 @@
-## Configuration
-
-- Import `configs.dify_config` for every runtime toggle. Do not read environment variables directly.
-- Add new settings to the proper mixin inside `configs/` (deployment, feature, middleware, etc.) so they load through `DifyConfig`.
-- Remote overrides come from the optional providers in `configs/remote_settings_sources`; keep defaults in code safe when the value is missing.
-- Example: logging pulls targets from `extensions/ext_logging.py`, and model provider URLs are assembled in `services/entities/model_provider_entities.py`.
-
-## Dependencies
-
-- Runtime dependencies live in `[project].dependencies` inside `pyproject.toml`. Optional clients go into the `storage`, `tools`, or `vdb` groups under `[dependency-groups]`.
-- Always pin versions and keep the list alphabetised. Shared tooling (lint, typing, pytest) belongs in the `dev` group.
-- When code needs a new package, explain why in the PR and run `uv lock` so the lockfile stays current.
-
-## Storage & Files
-
-- Use `extensions.ext_storage.storage` for all blob IO; it already respects the configured backend.
-- Convert files for workflows with helpers in `core/file/file_manager.py`; they handle signed URLs and multimodal payloads.
-- When writing controller logic, delegate upload quotas and metadata to `services/file_service.py` instead of touching storage directly.
-- All outbound HTTP fetches (webhooks, remote files) must go through the SSRF-safe client in `core/helper/ssrf_proxy.py`; it wraps `httpx` with the allow/deny rules configured for the platform.
-
-## Redis & Shared State
-
-- Access Redis through `extensions.ext_redis.redis_client`. For locking, reuse `redis_client.lock`.
-- Prefer higher-level helpers when available: rate limits use `libs.helper.RateLimiter`, provider metadata uses caches in `core/helper/provider_cache.py`.
-
-## Models
-
-- SQLAlchemy models sit in `models/` and inherit from the shared declarative `Base` defined in `models/base.py` (metadata configured via `models/engine.py`).
-- `models/__init__.py` exposes grouped aggregates: account/tenant models, app and conversation tables, datasets, providers, workflow runs, triggers, etc. Import from there to avoid deep path churn.
-- Follow the DDD boundary: persistence objects live in `models/`, repositories under `repositories/` translate them into domain entities, and services consume those repositories.
-- When adding a table, create the model class, register it in `models/__init__.py`, wire a repository if needed, and generate an Alembic migration as described below.
-
-## Vector Stores
-
-- Vector client implementations live in `core/rag/datasource/vdb/`, with a common factory in `core/rag/datasource/vdb/vector_factory.py` and enums in `core/rag/datasource/vdb/vector_type.py`.
-- Retrieval pipelines call these providers through `core/rag/datasource/retrieval_service.py` and dataset ingestion flows in `services/dataset_service.py`.
-- The CLI helper `flask vdb-migrate` orchestrates bulk migrations using routines in `commands.py`; reuse that pattern when adding new backend transitions.
-- To add another store, mirror the provider layout, register it with the factory, and include any schema changes in Alembic migrations.
-
-## Observability & OTEL
-
-- OpenTelemetry settings live under the observability mixin in `configs/observability`. Toggle exporters and sampling via `dify_config`, not ad-hoc env reads.
-- HTTP, Celery, Redis, SQLAlchemy, and httpx instrumentation is initialised in `extensions/ext_app_metrics.py` and `extensions/ext_request_logging.py`; reuse these hooks when adding new workers or entrypoints.
-- When creating background tasks or external calls, propagate tracing context with helpers in the existing instrumented clients (e.g. use the shared `httpx` session from `core/helper/http_client_pooling.py`).
-- If you add a new external integration, ensure spans and metrics are emitted by wiring the appropriate OTEL instrumentation package in `pyproject.toml` and configuring it in `extensions/`.
-
-## Ops Integrations
-
-- Langfuse support and other tracing bridges live under `core/ops/opik_trace`. Config toggles sit in `configs/observability`, while exporters are initialised in the OTEL extensions mentioned above.
-- External monitoring services should follow this pattern: keep client code in `core/ops`, expose switches via `dify_config`, and hook initialisation in `extensions/ext_app_metrics.py` or sibling modules.
-- Before instrumenting new code paths, check whether existing context helpers (e.g. `extensions/ext_request_logging.py`) already capture the necessary metadata.
-
-## Controllers, Services, Core
-
-- Controllers only parse HTTP input and call a service method. Keep business rules in `services/`.
-- Services enforce tenant rules, quotas, and orchestration, then call into `core/` engines (workflow execution, tools, LLMs).
-- When adding a new endpoint, search for an existing service to extend before introducing a new layer. Example: workflow APIs pipe through `services/workflow_service.py` into `core/workflow`.
-
-## Plugins, Tools, Providers
-
-- In Dify a plugin is a tenant-installable bundle that declares one or more providers (tool, model, datasource, trigger, endpoint, agent strategy) plus its resource needs and version metadata. The manifest (`core/plugin/entities/plugin.py`) mirrors what you see in the marketplace documentation.
-- Installation, upgrades, and migrations are orchestrated by `services/plugin/plugin_service.py` together with helpers such as `services/plugin/plugin_migration.py`.
-- Runtime loading happens through the implementations under `core/plugin/impl/*` (tool/model/datasource/trigger/endpoint/agent). These modules normalise plugin providers so that downstream systems (`core/tools/tool_manager.py`, `services/model_provider_service.py`, `services/trigger/*`) can treat builtin and plugin capabilities the same way.
-- For remote execution, plugin daemons (`core/plugin/entities/plugin_daemon.py`, `core/plugin/impl/plugin.py`) manage lifecycle hooks, credential forwarding, and background workers that keep plugin processes in sync with the main application.
-- Acquire tool implementations through `core/tools/tool_manager.py`; it resolves builtin, plugin, and workflow-as-tool providers uniformly, injecting the right context (tenant, credentials, runtime config).
-- To add a new plugin capability, extend the relevant `core/plugin/entities` schema and register the implementation in the matching `core/plugin/impl` module rather than importing the provider directly.
-
-## Async Workloads
-
-see `agent_skills/trigger.md` for more detailed documentation.
-
-- Enqueue background work through `services/async_workflow_service.py`. It routes jobs to the tiered Celery queues defined in `tasks/`.
-- Workers boot from `celery_entrypoint.py` and execute functions in `tasks/workflow_execution_tasks.py`, `tasks/trigger_processing_tasks.py`, etc.
-- Scheduled workflows poll from `schedule/workflow_schedule_tasks.py`. Follow the same pattern if you need new periodic jobs.
-
-## Database & Migrations
-
-- SQLAlchemy models live under `models/` and map directly to migration files in `migrations/versions`.
-- Generate migrations with `uv run --project api flask db revision --autogenerate -m ""`, then review the diff; never hand-edit the database outside Alembic.
-- Apply migrations locally using `uv run --project api flask db upgrade`; production deploys expect the same history.
-- If you add tenant-scoped data, confirm the upgrade includes tenant filters or defaults consistent with the service logic touching those tables.
-
-## CLI Commands
-
-- Maintenance commands from `commands.py` are registered on the Flask CLI. Run them via `uv run --project api flask `.
-- Use the built-in `db` commands from Flask-Migrate for schema operations (`flask db upgrade`, `flask db stamp`, etc.). Only fall back to custom helpers if you need their extra behaviour.
-- Custom entries such as `flask reset-password`, `flask reset-email`, and `flask vdb-migrate` handle self-hosted account recovery and vector database migrations.
-- Before adding a new command, check whether an existing service can be reused and ensure the command guards edition-specific behaviour (many enforce `SELF_HOSTED`). Document any additions in the PR.
-- Ruff helpers are run directly with `uv`: `uv run --project api --dev ruff format ./api` for formatting and `uv run --project api --dev ruff check ./api` (add `--fix` if you want automatic fixes).
-
-## When You Add Features
-
-- Check for an existing helper or service before writing a new util.
-- Uphold tenancy: every service method should receive the tenant ID from controller wrappers such as `controllers/console/wraps.py`.
-- Update or create tests alongside behaviour changes (`tests/unit_tests` for fast coverage, `tests/integration_tests` when touching orchestrations).
-- Run `uv run --project api --dev ruff check ./api`, `uv run --directory api --dev basedpyright`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh` before submitting changes.
diff --git a/api/agent_skills/plugin.md b/api/agent_skills/plugin.md
deleted file mode 100644
index 954ddd236b..0000000000
--- a/api/agent_skills/plugin.md
+++ /dev/null
@@ -1 +0,0 @@
-// TBD
diff --git a/api/agent_skills/plugin_oauth.md b/api/agent_skills/plugin_oauth.md
deleted file mode 100644
index 954ddd236b..0000000000
--- a/api/agent_skills/plugin_oauth.md
+++ /dev/null
@@ -1 +0,0 @@
-// TBD
diff --git a/api/agent_skills/trigger.md b/api/agent_skills/trigger.md
deleted file mode 100644
index f4b076332c..0000000000
--- a/api/agent_skills/trigger.md
+++ /dev/null
@@ -1,53 +0,0 @@
-## Overview
-
-Trigger is a collection of nodes that we called `Start` nodes, also, the concept of `Start` is the same as `RootNode` in the workflow engine `core/workflow/graph_engine`, On the other hand, `Start` node is the entry point of workflows, every workflow run always starts from a `Start` node.
-
-## Trigger nodes
-
-- `UserInput`
-- `Trigger Webhook`
-- `Trigger Schedule`
-- `Trigger Plugin`
-
-### UserInput
-
-Before `Trigger` concept is introduced, it's what we called `Start` node, but now, to avoid confusion, it was renamed to `UserInput` node, has a strong relation with `ServiceAPI` in `controllers/service_api/app`
-
-1. `UserInput` node introduces a list of arguments that need to be provided by the user, finally it will be converted into variables in the workflow variable pool.
-1. `ServiceAPI` accept those arguments, and pass through them into `UserInput` node.
-1. For its detailed implementation, please refer to `core/workflow/nodes/start`
-
-### Trigger Webhook
-
-Inside Webhook Node, Dify provided a UI panel that allows user define a HTTP manifest `core/workflow/nodes/trigger_webhook/entities.py`.`WebhookData`, also, Dify generates a random webhook id for each `Trigger Webhook` node, the implementation was implemented in `core/trigger/utils/endpoint.py`, as you can see, `webhook-debug` is a debug mode for webhook, you may find it in `controllers/trigger/webhook.py`.
-
-Finally, requests to `webhook` endpoint will be converted into variables in workflow variable pool during workflow execution.
-
-### Trigger Schedule
-
-`Trigger Schedule` node is a node that allows user define a schedule to trigger the workflow, detailed manifest is here `core/workflow/nodes/trigger_schedule/entities.py`, we have a poller and executor to handle millions of schedules, see `docker/entrypoint.sh` / `schedule/workflow_schedule_task.py` for help.
-
-To Achieve this, a `WorkflowSchedulePlan` model was introduced in `models/trigger.py`, and a `events/event_handlers/sync_workflow_schedule_when_app_published.py` was used to sync workflow schedule plans when app is published.
-
-### Trigger Plugin
-
-The `Trigger Plugin` node lets users define their own distributed trigger plugin. Whenever a request is received, Dify forwards it to the plugin and waits for the parsed variables it returns.
-
-1. Requests are saved in storage by `services/trigger/trigger_request_service.py`, referenced by `services/trigger/trigger_service.py`.`TriggerService`.`process_endpoint`.
-1. Plugins accept those requests and parse variables from them; see `core/plugin/impl/trigger.py` for details.
-
-Dify also introduces a `subscription` concept: an endpoint address from Dify is bound to a third-party webhook service such as `Github`, `Slack`, `Linear`, `GoogleDrive`, or `Gmail`. Once a subscription is created, Dify continually receives requests from those platforms and handles them one by one.
-
-## Worker Pool / Async Task
-
-Every event that triggers a new workflow run is handled asynchronously; the unified entry point is `services/async_workflow_service.py`.`AsyncWorkflowService`.`trigger_workflow_async`.
-
-The underlying infrastructure is `celery`, configured in `docker/entrypoint.sh`; the consumers live in `tasks/async_workflow_tasks.py`. Three queues handle the different user tiers: `PROFESSIONAL_QUEUE`, `TEAM_QUEUE`, and `SANDBOX_QUEUE`.
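-
-A minimal routing sketch (illustrative only: the queue names match the list above, while the task name, broker URL, and tier keys are assumptions):
-
-```python
-from celery import Celery
-
-app = Celery("dify", broker="redis://localhost:6379/0")
-
-TIER_QUEUES = {
-    "professional": "PROFESSIONAL_QUEUE",
-    "team": "TEAM_QUEUE",
-    "sandbox": "SANDBOX_QUEUE",
-}
-
-
-def enqueue_workflow_run(tier: str, payload: dict) -> None:
-    # Route the async workflow run to the queue that matches the tenant's tier.
-    app.send_task("tasks.async_workflow_tasks.run_workflow", args=[payload], queue=TIER_QUEUES[tier])
-```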
-
-## Debug Strategy
-
-Dify divides users into two groups: builders and end users.
-
-Builders are the users who create workflows, and for them debugging is a critical part of workflow development. As the start nodes of a workflow, trigger nodes can `listen` for events from `WebhookDebug`, `Schedule`, and `Plugin`; the debugging flow is implemented in `controllers/console/app/workflow.py`.`DraftWorkflowTriggerNodeApi`.
-
-A polling process can be seen as a sequence of single `poll` operations: each `poll` fetches events cached in `Redis` and returns `None` if no event is found. In more detail, `core/trigger/debug/event_bus.py` handles the polling process and `core/trigger/debug/event_selectors.py` selects the event poller based on the trigger type.
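-
-A minimal sketch of a single `poll` step (illustrative: the Redis key layout and event shape are assumptions; only the "return `None` when no event is found" behaviour comes from the description above):
-
-```python
-import json
-
-import redis
-
-r = redis.Redis()
-
-
-def poll(node_id: str) -> dict | None:
-    # Fetch one cached debug event for this trigger node; None means "no event yet".
-    raw = r.lpop(f"trigger:debug:{node_id}")
-    return json.loads(raw) if raw else None
-```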
diff --git a/api/app.py b/api/app.py
index 99f70f32d5..c018c8a045 100644
--- a/api/app.py
+++ b/api/app.py
@@ -1,4 +1,12 @@
+from __future__ import annotations
+
import sys
+from typing import TYPE_CHECKING, cast
+
+if TYPE_CHECKING:
+ from celery import Celery
+
+ celery: Celery
def is_db_command() -> bool:
@@ -23,7 +31,7 @@ else:
from app_factory import create_app
app = create_app()
- celery = app.extensions["celery"]
+ celery = cast("Celery", app.extensions["celery"])
if __name__ == "__main__":
app.run(host="0.0.0.0", port=5001)
diff --git a/api/app_factory.py b/api/app_factory.py
index 3a3ee03cff..dcbc821687 100644
--- a/api/app_factory.py
+++ b/api/app_factory.py
@@ -2,9 +2,11 @@ import logging
import time
from opentelemetry.trace import get_current_span
+from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
from configs import dify_config
from contexts.wrapper import RecyclableContextVar
+from core.logging.context import init_request_context
from dify_app import DifyApp
logger = logging.getLogger(__name__)
@@ -25,28 +27,35 @@ def create_flask_app_with_configs() -> DifyApp:
# add before request hook
@dify_app.before_request
def before_request():
- # add an unique identifier to each request
+ # Initialize logging context for this request
+ init_request_context()
RecyclableContextVar.increment_thread_recycles()
- # add after request hook for injecting X-Trace-Id header from OpenTelemetry span context
+ # add after request hook for injecting trace headers from OpenTelemetry span context
+ # Only adds headers when OTEL is enabled and has valid context
@dify_app.after_request
- def add_trace_id_header(response):
+ def add_trace_headers(response):
try:
span = get_current_span()
ctx = span.get_span_context() if span else None
- if ctx and ctx.is_valid:
- trace_id_hex = format(ctx.trace_id, "032x")
- # Avoid duplicates if some middleware added it
- if "X-Trace-Id" not in response.headers:
- response.headers["X-Trace-Id"] = trace_id_hex
+
+ if not ctx or not ctx.is_valid:
+ return response
+
+ # Inject trace headers from OTEL context
+ if ctx.trace_id != INVALID_TRACE_ID and "X-Trace-Id" not in response.headers:
+ response.headers["X-Trace-Id"] = format(ctx.trace_id, "032x")
+ if ctx.span_id != INVALID_SPAN_ID and "X-Span-Id" not in response.headers:
+ response.headers["X-Span-Id"] = format(ctx.span_id, "016x")
+
except Exception:
# Never break the response due to tracing header injection
- logger.warning("Failed to add trace ID to response header", exc_info=True)
+ logger.warning("Failed to add trace headers to response", exc_info=True)
return response
# Capture the decorator's return value to avoid pyright reportUnusedFunction
_ = before_request
- _ = add_trace_id_header
+ _ = add_trace_headers
return dify_app
@@ -62,6 +71,8 @@ def create_app() -> DifyApp:
def initialize_extensions(app: DifyApp):
+ # Initialize Flask context capture for workflow execution
+ from context.flask_app_context import init_flask_context
from extensions import (
ext_app_metrics,
ext_blueprints,
@@ -70,11 +81,13 @@ def initialize_extensions(app: DifyApp):
ext_commands,
ext_compress,
ext_database,
+ ext_fastopenapi,
ext_forward_refs,
ext_hosting_provider,
ext_import_modules,
ext_logging,
ext_login,
+ ext_logstore,
ext_mail,
ext_migrate,
ext_orjson,
@@ -83,12 +96,15 @@ def initialize_extensions(app: DifyApp):
ext_redis,
ext_request_logging,
ext_sentry,
+ ext_session_factory,
ext_set_secretkey,
ext_storage,
ext_timezone,
ext_warnings,
)
+ init_flask_context()
+
extensions = [
ext_timezone,
ext_logging,
@@ -104,6 +120,7 @@ def initialize_extensions(app: DifyApp):
ext_migrate,
ext_redis,
ext_storage,
+ ext_logstore, # Initialize logstore after storage, before celery
ext_celery,
ext_login,
ext_mail,
@@ -112,8 +129,10 @@ def initialize_extensions(app: DifyApp):
ext_proxy_fix,
ext_blueprints,
ext_commands,
+ ext_fastopenapi,
ext_otel,
ext_request_logging,
+ ext_session_factory,
]
for ext in extensions:
short_name = ext.__name__.split(".")[-1]
@@ -130,7 +149,7 @@ def initialize_extensions(app: DifyApp):
logger.info("Loaded %s (%s ms)", short_name, round((end_time - start_time) * 1000, 2))
-def create_migrations_app():
+def create_migrations_app() -> DifyApp:
app = create_flask_app_with_configs()
from extensions import ext_database, ext_migrate
diff --git a/api/commands.py b/api/commands.py
index a8d89ac200..4b811fb1e6 100644
--- a/api/commands.py
+++ b/api/commands.py
@@ -1,7 +1,9 @@
import base64
+import datetime
import json
import logging
import secrets
+import time
from typing import Any
import click
@@ -20,7 +22,7 @@ from core.plugin.impl.plugin import PluginInstaller
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.index_processor.constant.built_in_field import BuiltInField
-from core.rag.models.document import Document
+from core.rag.models.document import ChildDocument, Document
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
from events.app_event import app_was_created
from extensions.ext_database import db
@@ -34,7 +36,7 @@ from libs.rsa import generate_key_pair
from models import Tenant
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
-from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
+from models.model import App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
from models.provider import Provider, ProviderModel
from models.provider_ids import DatasourceProviderID, ToolProviderID
@@ -45,6 +47,9 @@ from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpi
from services.plugin.data_migration import PluginDataMigration
from services.plugin.plugin_migration import PluginMigration
from services.plugin.plugin_service import PluginService
+from services.retention.conversation.messages_clean_policy import create_message_clean_policy
+from services.retention.conversation.messages_clean_service import MessagesCleanService
+from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup
from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
logger = logging.getLogger(__name__)
@@ -62,8 +67,10 @@ def reset_password(email, new_password, password_confirm):
if str(new_password).strip() != str(password_confirm).strip():
click.echo(click.style("Passwords do not match.", fg="red"))
return
+ normalized_email = email.strip().lower()
+
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
- account = session.query(Account).where(Account.email == email).one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
@@ -84,7 +91,7 @@ def reset_password(email, new_password, password_confirm):
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
- AccountService.reset_login_error_rate_limit(email)
+ AccountService.reset_login_error_rate_limit(normalized_email)
click.echo(click.style("Password reset successfully.", fg="green"))
@@ -100,20 +107,22 @@ def reset_email(email, new_email, email_confirm):
if str(new_email).strip() != str(email_confirm).strip():
click.echo(click.style("New emails do not match.", fg="red"))
return
+ normalized_new_email = new_email.strip().lower()
+
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
- account = session.query(Account).where(Account.email == email).one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
try:
- email_validate(new_email)
+ email_validate(normalized_new_email)
except:
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
return
- account.email = new_email
+ account.email = normalized_new_email
click.echo(click.style("Email updated successfully.", fg="green"))
@@ -235,7 +244,7 @@ def migrate_annotation_vector_database():
if annotations:
for annotation in annotations:
document = Document(
- page_content=annotation.question,
+ page_content=annotation.question_text,
metadata={"annotation_id": annotation.id, "app_id": app.id, "doc_id": annotation.id},
)
documents.append(document)
@@ -409,6 +418,22 @@ def migrate_knowledge_vector_database():
"dataset_id": segment.dataset_id,
},
)
+ if dataset_document.doc_form == "hierarchical_model":
+ child_chunks = segment.get_child_chunks()
+ if child_chunks:
+ child_documents = []
+ for child_chunk in child_chunks:
+ child_document = ChildDocument(
+ page_content=child_chunk.content,
+ metadata={
+ "doc_id": child_chunk.index_node_id,
+ "doc_hash": child_chunk.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
+ child_documents.append(child_document)
+ document.children = child_documents
documents.append(document)
segments_count = segments_count + 1
@@ -422,7 +447,13 @@ def migrate_knowledge_vector_database():
fg="green",
)
)
+ all_child_documents = []
+ for doc in documents:
+ if doc.children:
+ all_child_documents.extend(doc.children)
vector.create(documents)
+ if all_child_documents:
+ vector.create(all_child_documents)
click.echo(click.style(f"Created vector index for dataset {dataset.id}.", fg="green"))
except Exception as e:
click.echo(click.style(f"Failed to created vector index for dataset {dataset.id}.", fg="red"))
@@ -658,7 +689,7 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
return
# Create account
- email = email.strip()
+ email = email.strip().lower()
if "@" not in email:
click.echo(click.style("Invalid email address.", fg="red"))
@@ -852,6 +883,435 @@ def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[
click.echo(click.style("Clear free plan tenant expired logs completed.", fg="green"))
+@click.command("clean-workflow-runs", help="Clean expired workflow runs and related data for free tenants.")
+@click.option(
+ "--before-days",
+ "--days",
+ default=30,
+ show_default=True,
+ type=click.IntRange(min=0),
+ help="Delete workflow runs created before N days ago.",
+)
+@click.option("--batch-size", default=200, show_default=True, help="Batch size for selecting workflow runs.")
+@click.option(
+ "--from-days-ago",
+ default=None,
+ type=click.IntRange(min=0),
+ help="Lower bound in days ago (older). Must be paired with --to-days-ago.",
+)
+@click.option(
+ "--to-days-ago",
+ default=None,
+ type=click.IntRange(min=0),
+ help="Upper bound in days ago (newer). Must be paired with --from-days-ago.",
+)
+@click.option(
+ "--start-from",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ default=None,
+ help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.",
+)
+@click.option(
+ "--end-before",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ default=None,
+ help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.",
+)
+@click.option(
+ "--dry-run",
+ is_flag=True,
+ help="Preview cleanup results without deleting any workflow run data.",
+)
+def clean_workflow_runs(
+ before_days: int,
+ batch_size: int,
+ from_days_ago: int | None,
+ to_days_ago: int | None,
+ start_from: datetime.datetime | None,
+ end_before: datetime.datetime | None,
+ dry_run: bool,
+):
+ """
+ Clean workflow runs and related workflow data for free tenants.
+ """
+ if (start_from is None) ^ (end_before is None):
+ raise click.UsageError("--start-from and --end-before must be provided together.")
+
+ if (from_days_ago is None) ^ (to_days_ago is None):
+ raise click.UsageError("--from-days-ago and --to-days-ago must be provided together.")
+
+ if from_days_ago is not None and to_days_ago is not None:
+ if start_from or end_before:
+ raise click.UsageError("Choose either day offsets or explicit dates, not both.")
+ if from_days_ago <= to_days_ago:
+ raise click.UsageError("--from-days-ago must be greater than --to-days-ago.")
+ now = datetime.datetime.now()
+ start_from = now - datetime.timedelta(days=from_days_ago)
+ end_before = now - datetime.timedelta(days=to_days_ago)
+ before_days = 0
+
+ start_time = datetime.datetime.now(datetime.UTC)
+ click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white"))
+
+ WorkflowRunCleanup(
+ days=before_days,
+ batch_size=batch_size,
+ start_from=start_from,
+ end_before=end_before,
+ dry_run=dry_run,
+ ).run()
+
+ end_time = datetime.datetime.now(datetime.UTC)
+ elapsed = end_time - start_time
+ click.echo(
+ click.style(
+ f"Workflow run cleanup completed. start={start_time.isoformat()} "
+ f"end={end_time.isoformat()} duration={elapsed}",
+ fg="green",
+ )
+ )
+
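+# Example invocations (illustrative; assumes the Flask CLI entry point used for the existing commands):
+#   flask clean-workflow-runs --before-days 30 --batch-size 200 --dry-run
+#   flask clean-workflow-runs --from-days-ago 60 --to-days-ago 30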
+
+@click.command(
+ "archive-workflow-runs",
+ help="Archive workflow runs for paid plan tenants to S3-compatible storage.",
+)
+@click.option("--tenant-ids", default=None, help="Optional comma-separated tenant IDs for grayscale rollout.")
+@click.option("--before-days", default=90, show_default=True, help="Archive runs older than N days.")
+@click.option(
+ "--from-days-ago",
+ default=None,
+ type=click.IntRange(min=0),
+ help="Lower bound in days ago (older). Must be paired with --to-days-ago.",
+)
+@click.option(
+ "--to-days-ago",
+ default=None,
+ type=click.IntRange(min=0),
+ help="Upper bound in days ago (newer). Must be paired with --from-days-ago.",
+)
+@click.option(
+ "--start-from",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ default=None,
+ help="Archive runs created at or after this timestamp (UTC if no timezone).",
+)
+@click.option(
+ "--end-before",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ default=None,
+ help="Archive runs created before this timestamp (UTC if no timezone).",
+)
+@click.option("--batch-size", default=100, show_default=True, help="Batch size for processing.")
+@click.option("--workers", default=1, show_default=True, type=int, help="Concurrent workflow runs to archive.")
+@click.option("--limit", default=None, type=int, help="Maximum number of runs to archive.")
+@click.option("--dry-run", is_flag=True, help="Preview without archiving.")
+@click.option("--delete-after-archive", is_flag=True, help="Delete runs and related data after archiving.")
+def archive_workflow_runs(
+ tenant_ids: str | None,
+ before_days: int,
+ from_days_ago: int | None,
+ to_days_ago: int | None,
+ start_from: datetime.datetime | None,
+ end_before: datetime.datetime | None,
+ batch_size: int,
+ workers: int,
+ limit: int | None,
+ dry_run: bool,
+ delete_after_archive: bool,
+):
+ """
+ Archive workflow runs for paid plan tenants older than the specified days.
+
+ This command archives the following tables to storage:
+ - workflow_node_executions
+ - workflow_node_execution_offload
+ - workflow_pauses
+ - workflow_pause_reasons
+ - workflow_trigger_logs
+
+ The workflow_runs and workflow_app_logs tables are preserved for UI listing.
+ """
+ from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver
+
+ run_started_at = datetime.datetime.now(datetime.UTC)
+ click.echo(
+ click.style(
+ f"Starting workflow run archiving at {run_started_at.isoformat()}.",
+ fg="white",
+ )
+ )
+
+ if (start_from is None) ^ (end_before is None):
+ click.echo(click.style("start-from and end-before must be provided together.", fg="red"))
+ return
+
+ if (from_days_ago is None) ^ (to_days_ago is None):
+ click.echo(click.style("from-days-ago and to-days-ago must be provided together.", fg="red"))
+ return
+
+ if from_days_ago is not None and to_days_ago is not None:
+ if start_from or end_before:
+ click.echo(click.style("Choose either day offsets or explicit dates, not both.", fg="red"))
+ return
+ if from_days_ago <= to_days_ago:
+ click.echo(click.style("from-days-ago must be greater than to-days-ago.", fg="red"))
+ return
+ now = datetime.datetime.now()
+ start_from = now - datetime.timedelta(days=from_days_ago)
+ end_before = now - datetime.timedelta(days=to_days_ago)
+ before_days = 0
+
+ if start_from and end_before and start_from >= end_before:
+ click.echo(click.style("start-from must be earlier than end-before.", fg="red"))
+ return
+ if workers < 1:
+ click.echo(click.style("workers must be at least 1.", fg="red"))
+ return
+
+ archiver = WorkflowRunArchiver(
+ days=before_days,
+ batch_size=batch_size,
+ start_from=start_from,
+ end_before=end_before,
+ workers=workers,
+ tenant_ids=[tid.strip() for tid in tenant_ids.split(",")] if tenant_ids else None,
+ limit=limit,
+ dry_run=dry_run,
+ delete_after_archive=delete_after_archive,
+ )
+ summary = archiver.run()
+ click.echo(
+ click.style(
+ f"Summary: processed={summary.total_runs_processed}, archived={summary.runs_archived}, "
+ f"skipped={summary.runs_skipped}, failed={summary.runs_failed}, "
+ f"time={summary.total_elapsed_time:.2f}s",
+ fg="cyan",
+ )
+ )
+
+ run_finished_at = datetime.datetime.now(datetime.UTC)
+ elapsed = run_finished_at - run_started_at
+ click.echo(
+ click.style(
+ f"Workflow run archiving completed. start={run_started_at.isoformat()} "
+ f"end={run_finished_at.isoformat()} duration={elapsed}",
+ fg="green",
+ )
+ )
+
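+# Example invocation (illustrative; Flask CLI assumed):
+#   flask archive-workflow-runs --tenant-ids tenant-a,tenant-b --before-days 90 --delete-after-archive --dry-run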
+
+@click.command(
+ "restore-workflow-runs",
+ help="Restore archived workflow runs from S3-compatible storage.",
+)
+@click.option(
+ "--tenant-ids",
+ required=False,
+ help="Tenant IDs (comma-separated).",
+)
+@click.option("--run-id", required=False, help="Workflow run ID to restore.")
+@click.option(
+ "--start-from",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ default=None,
+ help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.",
+)
+@click.option(
+ "--end-before",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ default=None,
+ help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.",
+)
+@click.option("--workers", default=1, show_default=True, type=int, help="Concurrent workflow runs to restore.")
+@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of runs to restore.")
+@click.option("--dry-run", is_flag=True, help="Preview without restoring.")
+def restore_workflow_runs(
+ tenant_ids: str | None,
+ run_id: str | None,
+ start_from: datetime.datetime | None,
+ end_before: datetime.datetime | None,
+ workers: int,
+ limit: int,
+ dry_run: bool,
+):
+ """
+ Restore an archived workflow run from storage to the database.
+
+ This restores the following tables:
+ - workflow_node_executions
+ - workflow_node_execution_offload
+ - workflow_pauses
+ - workflow_pause_reasons
+ - workflow_trigger_logs
+ """
+ from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore
+
+ parsed_tenant_ids = None
+ if tenant_ids:
+ parsed_tenant_ids = [tid.strip() for tid in tenant_ids.split(",") if tid.strip()]
+ if not parsed_tenant_ids:
+ raise click.BadParameter("tenant-ids must not be empty")
+
+ if (start_from is None) ^ (end_before is None):
+ raise click.UsageError("--start-from and --end-before must be provided together.")
+ if run_id is None and (start_from is None or end_before is None):
+ raise click.UsageError("--start-from and --end-before are required for batch restore.")
+ if workers < 1:
+ raise click.BadParameter("workers must be at least 1")
+
+ start_time = datetime.datetime.now(datetime.UTC)
+    target_desc = f"workflow run {run_id}" if run_id else "workflow runs"
+    click.echo(
+        click.style(
+            f"Starting restore of {target_desc} at {start_time.isoformat()}.",
+            fg="white",
+        )
+    )
+
+ restorer = WorkflowRunRestore(dry_run=dry_run, workers=workers)
+ if run_id:
+ results = [restorer.restore_by_run_id(run_id)]
+ else:
+ assert start_from is not None
+ assert end_before is not None
+ results = restorer.restore_batch(
+ parsed_tenant_ids,
+ start_date=start_from,
+ end_date=end_before,
+ limit=limit,
+ )
+
+ end_time = datetime.datetime.now(datetime.UTC)
+ elapsed = end_time - start_time
+
+ successes = sum(1 for result in results if result.success)
+ failures = len(results) - successes
+
+ if failures == 0:
+ click.echo(
+ click.style(
+ f"Restore completed successfully. success={successes} duration={elapsed}",
+ fg="green",
+ )
+ )
+ else:
+ click.echo(
+ click.style(
+ f"Restore completed with failures. success={successes} failed={failures} duration={elapsed}",
+ fg="red",
+ )
+ )
+
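+# Example invocations (illustrative; Flask CLI assumed):
+#   flask restore-workflow-runs --run-id <workflow_run_id>
+#   flask restore-workflow-runs --tenant-ids tenant-a --start-from 2024-01-01 --end-before 2024-02-01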
+
+@click.command(
+ "delete-archived-workflow-runs",
+ help="Delete archived workflow runs from the database.",
+)
+@click.option(
+ "--tenant-ids",
+ required=False,
+ help="Tenant IDs (comma-separated).",
+)
+@click.option("--run-id", required=False, help="Workflow run ID to delete.")
+@click.option(
+ "--start-from",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ default=None,
+ help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.",
+)
+@click.option(
+ "--end-before",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ default=None,
+ help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.",
+)
+@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of runs to delete.")
+@click.option("--dry-run", is_flag=True, help="Preview without deleting.")
+def delete_archived_workflow_runs(
+ tenant_ids: str | None,
+ run_id: str | None,
+ start_from: datetime.datetime | None,
+ end_before: datetime.datetime | None,
+ limit: int,
+ dry_run: bool,
+):
+ """
+ Delete archived workflow runs from the database.
+ """
+ from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
+
+ parsed_tenant_ids = None
+ if tenant_ids:
+ parsed_tenant_ids = [tid.strip() for tid in tenant_ids.split(",") if tid.strip()]
+ if not parsed_tenant_ids:
+ raise click.BadParameter("tenant-ids must not be empty")
+
+ if (start_from is None) ^ (end_before is None):
+ raise click.UsageError("--start-from and --end-before must be provided together.")
+ if run_id is None and (start_from is None or end_before is None):
+ raise click.UsageError("--start-from and --end-before are required for batch delete.")
+
+ start_time = datetime.datetime.now(datetime.UTC)
+ target_desc = f"workflow run {run_id}" if run_id else "workflow runs"
+ click.echo(
+ click.style(
+ f"Starting delete of {target_desc} at {start_time.isoformat()}.",
+ fg="white",
+ )
+ )
+
+ deleter = ArchivedWorkflowRunDeletion(dry_run=dry_run)
+ if run_id:
+ results = [deleter.delete_by_run_id(run_id)]
+ else:
+ assert start_from is not None
+ assert end_before is not None
+ results = deleter.delete_batch(
+ parsed_tenant_ids,
+ start_date=start_from,
+ end_date=end_before,
+ limit=limit,
+ )
+
+ for result in results:
+ if result.success:
+ click.echo(
+ click.style(
+ f"{'[DRY RUN] Would delete' if dry_run else 'Deleted'} "
+ f"workflow run {result.run_id} (tenant={result.tenant_id})",
+ fg="green",
+ )
+ )
+ else:
+ click.echo(
+ click.style(
+ f"Failed to delete workflow run {result.run_id}: {result.error}",
+ fg="red",
+ )
+ )
+
+ end_time = datetime.datetime.now(datetime.UTC)
+ elapsed = end_time - start_time
+
+ successes = sum(1 for result in results if result.success)
+ failures = len(results) - successes
+
+ if failures == 0:
+ click.echo(
+ click.style(
+ f"Delete completed successfully. success={successes} duration={elapsed}",
+ fg="green",
+ )
+ )
+ else:
+ click.echo(
+ click.style(
+ f"Delete completed with failures. success={successes} failed={failures} duration={elapsed}",
+ fg="red",
+ )
+ )
+
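+# Example invocation (illustrative; Flask CLI assumed):
+#   flask delete-archived-workflow-runs --run-id <workflow_run_id> --dry-run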
+
@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.")
@click.command("clear-orphaned-file-records", help="Clear orphaned file records.")
def clear_orphaned_file_records(force: bool):
@@ -1184,6 +1644,217 @@ def remove_orphaned_files_on_storage(force: bool):
click.echo(click.style(f"Removed {removed_files} orphaned files, with {error_files} errors.", fg="yellow"))
+@click.command("file-usage", help="Query file usages and show where files are referenced.")
+@click.option("--file-id", type=str, default=None, help="Filter by file UUID.")
+@click.option("--key", type=str, default=None, help="Filter by storage key.")
+@click.option("--src", type=str, default=None, help="Filter by table.column pattern (e.g., 'documents.%' or '%.icon').")
+@click.option("--limit", type=int, default=100, help="Limit number of results (default: 100).")
+@click.option("--offset", type=int, default=0, help="Offset for pagination (default: 0).")
+@click.option("--json", "output_json", is_flag=True, help="Output results in JSON format.")
+def file_usage(
+ file_id: str | None,
+ key: str | None,
+ src: str | None,
+ limit: int,
+ offset: int,
+ output_json: bool,
+):
+ """
+ Query file usages and show where files are referenced in the database.
+
+ This command reuses the same reference checking logic as clear-orphaned-file-records
+ and displays detailed information about where each file is referenced.
+ """
+ # define tables and columns to process
+ files_tables = [
+ {"table": "upload_files", "id_column": "id", "key_column": "key"},
+ {"table": "tool_files", "id_column": "id", "key_column": "file_key"},
+ ]
+ ids_tables = [
+ {"type": "uuid", "table": "message_files", "column": "upload_file_id", "pk_column": "id"},
+ {"type": "text", "table": "documents", "column": "data_source_info", "pk_column": "id"},
+ {"type": "text", "table": "document_segments", "column": "content", "pk_column": "id"},
+ {"type": "text", "table": "messages", "column": "answer", "pk_column": "id"},
+ {"type": "text", "table": "workflow_node_executions", "column": "inputs", "pk_column": "id"},
+ {"type": "text", "table": "workflow_node_executions", "column": "process_data", "pk_column": "id"},
+ {"type": "text", "table": "workflow_node_executions", "column": "outputs", "pk_column": "id"},
+ {"type": "text", "table": "conversations", "column": "introduction", "pk_column": "id"},
+ {"type": "text", "table": "conversations", "column": "system_instruction", "pk_column": "id"},
+ {"type": "text", "table": "accounts", "column": "avatar", "pk_column": "id"},
+ {"type": "text", "table": "apps", "column": "icon", "pk_column": "id"},
+ {"type": "text", "table": "sites", "column": "icon", "pk_column": "id"},
+ {"type": "json", "table": "messages", "column": "inputs", "pk_column": "id"},
+ {"type": "json", "table": "messages", "column": "message", "pk_column": "id"},
+ ]
+
+ # Stream file usages with pagination to avoid holding all results in memory
+ paginated_usages = []
+ total_count = 0
+
+ # First, build a mapping of file_id -> storage_key from the base tables
+ file_key_map = {}
+ for files_table in files_tables:
+ query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}"
+ with db.engine.begin() as conn:
+ rs = conn.execute(sa.text(query))
+ for row in rs:
+ file_key_map[str(row[0])] = f"{files_table['table']}:{row[1]}"
+
+ # If filtering by key or file_id, verify it exists
+ if file_id and file_id not in file_key_map:
+ if output_json:
+ click.echo(json.dumps({"error": f"File ID {file_id} not found in base tables"}))
+ else:
+ click.echo(click.style(f"File ID {file_id} not found in base tables.", fg="red"))
+ return
+
+ if key:
+ valid_prefixes = {f"upload_files:{key}", f"tool_files:{key}"}
+ matching_file_ids = [fid for fid, fkey in file_key_map.items() if fkey in valid_prefixes]
+ if not matching_file_ids:
+ if output_json:
+ click.echo(json.dumps({"error": f"Key {key} not found in base tables"}))
+ else:
+ click.echo(click.style(f"Key {key} not found in base tables.", fg="red"))
+ return
+
+ guid_regexp = "[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
+
+ # For each reference table/column, find matching file IDs and record the references
+ for ids_table in ids_tables:
+ src_filter = f"{ids_table['table']}.{ids_table['column']}"
+
+ # Skip if src filter doesn't match (use fnmatch for wildcard patterns)
+ if src:
+ if "%" in src or "_" in src:
+ import fnmatch
+
+ # Convert SQL LIKE wildcards to fnmatch wildcards (% -> *, _ -> ?)
+ pattern = src.replace("%", "*").replace("_", "?")
+ if not fnmatch.fnmatch(src_filter, pattern):
+ continue
+ else:
+ if src_filter != src:
+ continue
+
+ if ids_table["type"] == "uuid":
+ # Direct UUID match
+ query = (
+ f"SELECT {ids_table['pk_column']}, {ids_table['column']} "
+ f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
+ )
+ with db.engine.begin() as conn:
+ rs = conn.execute(sa.text(query))
+ for row in rs:
+ record_id = str(row[0])
+ ref_file_id = str(row[1])
+ if ref_file_id not in file_key_map:
+ continue
+ storage_key = file_key_map[ref_file_id]
+
+ # Apply filters
+ if file_id and ref_file_id != file_id:
+ continue
+ if key and not storage_key.endswith(key):
+ continue
+
+ # Only collect items within the requested page range
+ if offset <= total_count < offset + limit:
+ paginated_usages.append(
+ {
+ "src": f"{ids_table['table']}.{ids_table['column']}",
+ "record_id": record_id,
+ "file_id": ref_file_id,
+ "key": storage_key,
+ }
+ )
+ total_count += 1
+
+ elif ids_table["type"] in ("text", "json"):
+ # Extract UUIDs from text/json content
+ column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"]
+ query = (
+ f"SELECT {ids_table['pk_column']}, {column_cast} "
+ f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
+ )
+ with db.engine.begin() as conn:
+ rs = conn.execute(sa.text(query))
+ for row in rs:
+ record_id = str(row[0])
+ content = str(row[1])
+
+ # Find all UUIDs in the content
+ import re
+
+ uuid_pattern = re.compile(guid_regexp, re.IGNORECASE)
+ matches = uuid_pattern.findall(content)
+
+ for ref_file_id in matches:
+ if ref_file_id not in file_key_map:
+ continue
+ storage_key = file_key_map[ref_file_id]
+
+ # Apply filters
+ if file_id and ref_file_id != file_id:
+ continue
+ if key and not storage_key.endswith(key):
+ continue
+
+ # Only collect items within the requested page range
+ if offset <= total_count < offset + limit:
+ paginated_usages.append(
+ {
+ "src": f"{ids_table['table']}.{ids_table['column']}",
+ "record_id": record_id,
+ "file_id": ref_file_id,
+ "key": storage_key,
+ }
+ )
+ total_count += 1
+
+ # Output results
+ if output_json:
+ result = {
+ "total": total_count,
+ "offset": offset,
+ "limit": limit,
+ "usages": paginated_usages,
+ }
+ click.echo(json.dumps(result, indent=2))
+ else:
+ click.echo(
+ click.style(f"Found {total_count} file usages (showing {len(paginated_usages)} results)", fg="white")
+ )
+ click.echo("")
+
+ if not paginated_usages:
+ click.echo(click.style("No file usages found matching the specified criteria.", fg="yellow"))
+ return
+
+ # Print table header
+ click.echo(
+ click.style(
+ f"{'Src (Table.Column)':<50} {'Record ID':<40} {'File ID':<40} {'Storage Key':<60}",
+ fg="cyan",
+ )
+ )
+ click.echo(click.style("-" * 190, fg="white"))
+
+ # Print each usage
+ for usage in paginated_usages:
+ click.echo(f"{usage['src']:<50} {usage['record_id']:<40} {usage['file_id']:<40} {usage['key']:<60}")
+
+ # Show pagination info
+ if offset + limit < total_count:
+ click.echo("")
+ click.echo(
+ click.style(
+ f"Showing {offset + 1}-{offset + len(paginated_usages)} of {total_count} results", fg="white"
+ )
+ )
+ click.echo(click.style(f"Use --offset {offset + limit} to see next page", fg="white"))
+
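+# Example invocations (illustrative; Flask CLI assumed):
+#   flask file-usage --file-id 00000000-0000-0000-0000-000000000000 --json
+#   flask file-usage --src 'messages.%' --limit 50 --offset 50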
+
@click.command("setup-system-tool-oauth-client", help="Setup system tool oauth client.")
@click.option("--provider", prompt=True, help="Provider name")
@click.option("--client-params", prompt=True, help="Client Params")
@@ -1900,3 +2571,79 @@ def migrate_oss(
except Exception as e:
db.session.rollback()
click.echo(click.style(f"Failed to update DB storage_type: {str(e)}", fg="red"))
+
+
+@click.command("clean-expired-messages", help="Clean expired messages.")
+@click.option(
+ "--start-from",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ required=True,
+ help="Lower bound (inclusive) for created_at.",
+)
+@click.option(
+ "--end-before",
+ type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+ required=True,
+ help="Upper bound (exclusive) for created_at.",
+)
+@click.option("--batch-size", default=1000, show_default=True, help="Batch size for selecting messages.")
+@click.option(
+ "--graceful-period",
+ default=21,
+ show_default=True,
+ help="Graceful period in days after subscription expiration, will be ignored when billing is disabled.",
+)
+@click.option("--dry-run", is_flag=True, default=False, help="Show messages logs would be cleaned without deleting")
+def clean_expired_messages(
+ batch_size: int,
+ graceful_period: int,
+ start_from: datetime.datetime,
+ end_before: datetime.datetime,
+ dry_run: bool,
+):
+ """
+ Clean expired messages and related data for tenants based on clean policy.
+ """
+ click.echo(click.style("clean_messages: start clean messages.", fg="green"))
+
+ start_at = time.perf_counter()
+
+ try:
+ # Create policy based on billing configuration
+ # NOTE: graceful_period will be ignored when billing is disabled.
+ policy = create_message_clean_policy(graceful_period_days=graceful_period)
+
+ # Create and run the cleanup service
+ service = MessagesCleanService.from_time_range(
+ policy=policy,
+ start_from=start_from,
+ end_before=end_before,
+ batch_size=batch_size,
+ dry_run=dry_run,
+ )
+ stats = service.run()
+
+ end_at = time.perf_counter()
+ click.echo(
+ click.style(
+ f"clean_messages: completed successfully\n"
+ f" - Latency: {end_at - start_at:.2f}s\n"
+ f" - Batches processed: {stats['batches']}\n"
+ f" - Total messages scanned: {stats['total_messages']}\n"
+ f" - Messages filtered: {stats['filtered_messages']}\n"
+ f" - Messages deleted: {stats['total_deleted']}",
+ fg="green",
+ )
+ )
+ except Exception as e:
+ end_at = time.perf_counter()
+ logger.exception("clean_messages failed")
+ click.echo(
+ click.style(
+ f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}",
+ fg="red",
+ )
+ )
+ raise
+
+ click.echo(click.style("messages cleanup completed.", fg="green"))
diff --git a/api/configs/extra/__init__.py b/api/configs/extra/__init__.py
index 4543b5389d..de97adfc0e 100644
--- a/api/configs/extra/__init__.py
+++ b/api/configs/extra/__init__.py
@@ -1,9 +1,11 @@
+from configs.extra.archive_config import ArchiveStorageConfig
from configs.extra.notion_config import NotionConfig
from configs.extra.sentry_config import SentryConfig
class ExtraServiceConfig(
# place the configs in alphabet order
+ ArchiveStorageConfig,
NotionConfig,
SentryConfig,
):
diff --git a/api/configs/extra/archive_config.py b/api/configs/extra/archive_config.py
new file mode 100644
index 0000000000..a85628fa61
--- /dev/null
+++ b/api/configs/extra/archive_config.py
@@ -0,0 +1,43 @@
+from pydantic import Field
+from pydantic_settings import BaseSettings
+
+
+class ArchiveStorageConfig(BaseSettings):
+ """
+ Configuration settings for workflow run logs archiving storage.
+ """
+
+ ARCHIVE_STORAGE_ENABLED: bool = Field(
+ description="Enable workflow run logs archiving to S3-compatible storage",
+ default=False,
+ )
+
+ ARCHIVE_STORAGE_ENDPOINT: str | None = Field(
+ description="URL of the S3-compatible storage endpoint (e.g., 'https://storage.example.com')",
+ default=None,
+ )
+
+ ARCHIVE_STORAGE_ARCHIVE_BUCKET: str | None = Field(
+ description="Name of the bucket to store archived workflow logs",
+ default=None,
+ )
+
+ ARCHIVE_STORAGE_EXPORT_BUCKET: str | None = Field(
+ description="Name of the bucket to store exported workflow runs",
+ default=None,
+ )
+
+ ARCHIVE_STORAGE_ACCESS_KEY: str | None = Field(
+ description="Access key ID for authenticating with storage",
+ default=None,
+ )
+
+ ARCHIVE_STORAGE_SECRET_KEY: str | None = Field(
+ description="Secret access key for authenticating with storage",
+ default=None,
+ )
+
+ ARCHIVE_STORAGE_REGION: str = Field(
+ description="Region for storage (use 'auto' if the provider supports it)",
+ default="auto",
+ )
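+
+
+# Example environment configuration (illustrative values, not defaults):
+#   ARCHIVE_STORAGE_ENABLED=true
+#   ARCHIVE_STORAGE_ENDPOINT=https://storage.example.com
+#   ARCHIVE_STORAGE_ARCHIVE_BUCKET=dify-workflow-archive
+#   ARCHIVE_STORAGE_ACCESS_KEY=<access key>
+#   ARCHIVE_STORAGE_SECRET_KEY=<secret key>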
diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py
index a5916241df..d97e9a0440 100644
--- a/api/configs/feature/__init__.py
+++ b/api/configs/feature/__init__.py
@@ -218,7 +218,7 @@ class PluginConfig(BaseSettings):
PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field(
description="Timeout in seconds for requests to the plugin daemon (set to None to disable)",
- default=300.0,
+ default=600.0,
)
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
@@ -243,6 +243,11 @@ class PluginConfig(BaseSettings):
default=15728640 * 12,
)
+ PLUGIN_MODEL_SCHEMA_CACHE_TTL: PositiveInt = Field(
+ description="TTL in seconds for caching plugin model schemas in Redis",
+ default=60 * 60,
+ )
+
class MarketplaceConfig(BaseSettings):
"""
@@ -380,6 +385,37 @@ class FileUploadConfig(BaseSettings):
default=60,
)
+ # Annotation Import Security Configurations
+ ANNOTATION_IMPORT_FILE_SIZE_LIMIT: NonNegativeInt = Field(
+ description="Maximum allowed CSV file size for annotation import in megabytes",
+ default=2,
+ )
+
+ ANNOTATION_IMPORT_MAX_RECORDS: PositiveInt = Field(
+ description="Maximum number of annotation records allowed in a single import",
+ default=10000,
+ )
+
+ ANNOTATION_IMPORT_MIN_RECORDS: PositiveInt = Field(
+ description="Minimum number of annotation records required in a single import",
+ default=1,
+ )
+
+ ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE: PositiveInt = Field(
+ description="Maximum number of annotation import requests per minute per tenant",
+ default=5,
+ )
+
+ ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR: PositiveInt = Field(
+ description="Maximum number of annotation import requests per hour per tenant",
+ default=20,
+ )
+
+ ANNOTATION_IMPORT_MAX_CONCURRENT: PositiveInt = Field(
+ description="Maximum number of concurrent annotation import tasks per tenant",
+ default=2,
+ )
+
inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
description=(
"Comma-separated list of file extensions that are blocked from upload. "
@@ -556,6 +592,11 @@ class LoggingConfig(BaseSettings):
default="INFO",
)
+ LOG_OUTPUT_FORMAT: Literal["text", "json"] = Field(
+ description="Log output format: 'text' for human-readable, 'json' for structured JSON logs.",
+ default="text",
+ )
+
LOG_FILE: str | None = Field(
description="File path for log output.",
default=None,
@@ -913,6 +954,12 @@ class MailConfig(BaseSettings):
default=False,
)
+ SMTP_LOCAL_HOSTNAME: str | None = Field(
+ description="Override the local hostname used in SMTP HELO/EHLO. "
+ "Useful behind NAT or when the default hostname causes rejections.",
+ default=None,
+ )
+
EMAIL_SEND_IP_LIMIT_PER_MINUTE: PositiveInt = Field(
description="Maximum number of emails allowed to be sent from the same IP address in a minute",
default=50,
@@ -923,6 +970,16 @@ class MailConfig(BaseSettings):
default=None,
)
+ ENABLE_TRIAL_APP: bool = Field(
+ description="Enable trial app",
+ default=False,
+ )
+
+ ENABLE_EXPLORE_BANNER: bool = Field(
+ description="Enable explore banner",
+ default=False,
+ )
+
class RagEtlConfig(BaseSettings):
"""
@@ -1065,6 +1122,10 @@ class CeleryScheduleTasksConfig(BaseSettings):
description="Enable clean messages task",
default=False,
)
+ ENABLE_WORKFLOW_RUN_CLEANUP_TASK: bool = Field(
+ description="Enable scheduled workflow run cleanup task",
+ default=False,
+ )
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field(
description="Enable mail clean document notify task",
default=False,
@@ -1239,6 +1300,25 @@ class TenantIsolatedTaskQueueConfig(BaseSettings):
)
+class SandboxExpiredRecordsCleanConfig(BaseSettings):
+ SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: NonNegativeInt = Field(
+ description="Graceful period in days for sandbox records clean after subscription expiration",
+ default=21,
+ )
+ SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: PositiveInt = Field(
+ description="Maximum number of records to process in each batch",
+ default=1000,
+ )
+ SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: PositiveInt = Field(
+ description="Retention days for sandbox expired workflow_run records and message records",
+ default=30,
+ )
+ SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL: PositiveInt = Field(
+ description="Lock TTL for sandbox expired records clean task in seconds",
+ default=90000,
+ )
+
+
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
@@ -1264,6 +1344,7 @@ class FeatureConfig(
PositionConfig,
RagEtlConfig,
RepositoryConfig,
+ SandboxExpiredRecordsCleanConfig,
SecurityConfig,
TenantIsolatedTaskQueueConfig,
ToolConfig,
diff --git a/api/configs/feature/hosted_service/__init__.py b/api/configs/feature/hosted_service/__init__.py
index 4ad30014c7..42ede718c4 100644
--- a/api/configs/feature/hosted_service/__init__.py
+++ b/api/configs/feature/hosted_service/__init__.py
@@ -8,6 +8,11 @@ class HostedCreditConfig(BaseSettings):
default="",
)
+ HOSTED_POOL_CREDITS: int = Field(
+ description="Pool credits for hosted service",
+ default=200,
+ )
+
def get_model_credits(self, model_name: str) -> int:
"""
Get credit value for a specific model name.
@@ -60,19 +65,46 @@ class HostedOpenAiConfig(BaseSettings):
HOSTED_OPENAI_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
- default="gpt-3.5-turbo,"
- "gpt-3.5-turbo-1106,"
- "gpt-3.5-turbo-instruct,"
+ default="gpt-4,"
+ "gpt-4-turbo-preview,"
+ "gpt-4-turbo-2024-04-09,"
+ "gpt-4-1106-preview,"
+ "gpt-4-0125-preview,"
+ "gpt-4-turbo,"
+ "gpt-4.1,"
+ "gpt-4.1-2025-04-14,"
+ "gpt-4.1-mini,"
+ "gpt-4.1-mini-2025-04-14,"
+ "gpt-4.1-nano,"
+ "gpt-4.1-nano-2025-04-14,"
+ "gpt-3.5-turbo,"
"gpt-3.5-turbo-16k,"
"gpt-3.5-turbo-16k-0613,"
+ "gpt-3.5-turbo-1106,"
"gpt-3.5-turbo-0613,"
"gpt-3.5-turbo-0125,"
- "text-davinci-003",
- )
-
- HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
- description="Quota limit for hosted OpenAI service usage",
- default=200,
+ "gpt-3.5-turbo-instruct,"
+ "text-davinci-003,"
+ "chatgpt-4o-latest,"
+ "gpt-4o,"
+ "gpt-4o-2024-05-13,"
+ "gpt-4o-2024-08-06,"
+ "gpt-4o-2024-11-20,"
+ "gpt-4o-audio-preview,"
+ "gpt-4o-audio-preview-2025-06-03,"
+ "gpt-4o-mini,"
+ "gpt-4o-mini-2024-07-18,"
+ "o3-mini,"
+ "o3-mini-2025-01-31,"
+ "gpt-5-mini-2025-08-07,"
+ "gpt-5-mini,"
+ "o4-mini,"
+ "o4-mini-2025-04-16,"
+ "gpt-5-chat-latest,"
+ "gpt-5,"
+ "gpt-5-2025-08-07,"
+ "gpt-5-nano,"
+ "gpt-5-nano-2025-08-07",
)
HOSTED_OPENAI_PAID_ENABLED: bool = Field(
@@ -87,6 +119,13 @@ class HostedOpenAiConfig(BaseSettings):
"gpt-4-turbo-2024-04-09,"
"gpt-4-1106-preview,"
"gpt-4-0125-preview,"
+ "gpt-4-turbo,"
+ "gpt-4.1,"
+ "gpt-4.1-2025-04-14,"
+ "gpt-4.1-mini,"
+ "gpt-4.1-mini-2025-04-14,"
+ "gpt-4.1-nano,"
+ "gpt-4.1-nano-2025-04-14,"
"gpt-3.5-turbo,"
"gpt-3.5-turbo-16k,"
"gpt-3.5-turbo-16k-0613,"
@@ -94,7 +133,150 @@ class HostedOpenAiConfig(BaseSettings):
"gpt-3.5-turbo-0613,"
"gpt-3.5-turbo-0125,"
"gpt-3.5-turbo-instruct,"
- "text-davinci-003",
+ "text-davinci-003,"
+ "chatgpt-4o-latest,"
+ "gpt-4o,"
+ "gpt-4o-2024-05-13,"
+ "gpt-4o-2024-08-06,"
+ "gpt-4o-2024-11-20,"
+ "gpt-4o-audio-preview,"
+ "gpt-4o-audio-preview-2025-06-03,"
+ "gpt-4o-mini,"
+ "gpt-4o-mini-2024-07-18,"
+ "o3-mini,"
+ "o3-mini-2025-01-31,"
+ "gpt-5-mini-2025-08-07,"
+ "gpt-5-mini,"
+ "o4-mini,"
+ "o4-mini-2025-04-16,"
+ "gpt-5-chat-latest,"
+ "gpt-5,"
+ "gpt-5-2025-08-07,"
+ "gpt-5-nano,"
+ "gpt-5-nano-2025-08-07",
+ )
+
+
+class HostedGeminiConfig(BaseSettings):
+ """
+ Configuration for fetching Gemini service
+ """
+
+ HOSTED_GEMINI_API_KEY: str | None = Field(
+ description="API key for hosted Gemini service",
+ default=None,
+ )
+
+ HOSTED_GEMINI_API_BASE: str | None = Field(
+ description="Base URL for hosted Gemini API",
+ default=None,
+ )
+
+ HOSTED_GEMINI_API_ORGANIZATION: str | None = Field(
+ description="Organization ID for hosted Gemini service",
+ default=None,
+ )
+
+ HOSTED_GEMINI_TRIAL_ENABLED: bool = Field(
+ description="Enable trial access to hosted Gemini service",
+ default=False,
+ )
+
+ HOSTED_GEMINI_TRIAL_MODELS: str = Field(
+ description="Comma-separated list of available models for trial access",
+ default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,",
+ )
+
+ HOSTED_GEMINI_PAID_ENABLED: bool = Field(
+ description="Enable paid access to hosted gemini service",
+ default=False,
+ )
+
+ HOSTED_GEMINI_PAID_MODELS: str = Field(
+ description="Comma-separated list of available models for paid access",
+ default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,",
+ )
+
+
+class HostedXAIConfig(BaseSettings):
+ """
+ Configuration for fetching XAI service
+ """
+
+ HOSTED_XAI_API_KEY: str | None = Field(
+ description="API key for hosted XAI service",
+ default=None,
+ )
+
+ HOSTED_XAI_API_BASE: str | None = Field(
+ description="Base URL for hosted XAI API",
+ default=None,
+ )
+
+ HOSTED_XAI_API_ORGANIZATION: str | None = Field(
+ description="Organization ID for hosted XAI service",
+ default=None,
+ )
+
+ HOSTED_XAI_TRIAL_ENABLED: bool = Field(
+ description="Enable trial access to hosted XAI service",
+ default=False,
+ )
+
+ HOSTED_XAI_TRIAL_MODELS: str = Field(
+ description="Comma-separated list of available models for trial access",
+ default="grok-3,grok-3-mini,grok-3-mini-fast",
+ )
+
+ HOSTED_XAI_PAID_ENABLED: bool = Field(
+ description="Enable paid access to hosted XAI service",
+ default=False,
+ )
+
+ HOSTED_XAI_PAID_MODELS: str = Field(
+ description="Comma-separated list of available models for paid access",
+ default="grok-3,grok-3-mini,grok-3-mini-fast",
+ )
+
+
+class HostedDeepseekConfig(BaseSettings):
+ """
+ Configuration for fetching Deepseek service
+ """
+
+ HOSTED_DEEPSEEK_API_KEY: str | None = Field(
+ description="API key for hosted Deepseek service",
+ default=None,
+ )
+
+ HOSTED_DEEPSEEK_API_BASE: str | None = Field(
+ description="Base URL for hosted Deepseek API",
+ default=None,
+ )
+
+ HOSTED_DEEPSEEK_API_ORGANIZATION: str | None = Field(
+ description="Organization ID for hosted Deepseek service",
+ default=None,
+ )
+
+ HOSTED_DEEPSEEK_TRIAL_ENABLED: bool = Field(
+ description="Enable trial access to hosted Deepseek service",
+ default=False,
+ )
+
+ HOSTED_DEEPSEEK_TRIAL_MODELS: str = Field(
+ description="Comma-separated list of available models for trial access",
+ default="deepseek-chat,deepseek-reasoner",
+ )
+
+ HOSTED_DEEPSEEK_PAID_ENABLED: bool = Field(
+ description="Enable paid access to hosted Deepseek service",
+ default=False,
+ )
+
+ HOSTED_DEEPSEEK_PAID_MODELS: str = Field(
+ description="Comma-separated list of available models for paid access",
+ default="deepseek-chat,deepseek-reasoner",
)
@@ -144,16 +326,66 @@ class HostedAnthropicConfig(BaseSettings):
default=False,
)
- HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field(
- description="Quota limit for hosted Anthropic service usage",
- default=600000,
- )
-
HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted Anthropic service",
default=False,
)
+ HOSTED_ANTHROPIC_TRIAL_MODELS: str = Field(
+ description="Comma-separated list of available models for paid access",
+ default="claude-opus-4-20250514,"
+ "claude-sonnet-4-20250514,"
+ "claude-3-5-haiku-20241022,"
+ "claude-3-opus-20240229,"
+ "claude-3-7-sonnet-20250219,"
+ "claude-3-haiku-20240307",
+ )
+ HOSTED_ANTHROPIC_PAID_MODELS: str = Field(
+ description="Comma-separated list of available models for paid access",
+ default="claude-opus-4-20250514,"
+ "claude-sonnet-4-20250514,"
+ "claude-3-5-haiku-20241022,"
+ "claude-3-opus-20240229,"
+ "claude-3-7-sonnet-20250219,"
+ "claude-3-haiku-20240307",
+ )
+
+
+class HostedTongyiConfig(BaseSettings):
+ """
+ Configuration for hosted Tongyi service
+ """
+
+ HOSTED_TONGYI_API_KEY: str | None = Field(
+ description="API key for hosted Tongyi service",
+ default=None,
+ )
+
+ HOSTED_TONGYI_USE_INTERNATIONAL_ENDPOINT: bool = Field(
+ description="Use international endpoint for hosted Tongyi service",
+ default=False,
+ )
+
+ HOSTED_TONGYI_TRIAL_ENABLED: bool = Field(
+ description="Enable trial access to hosted Tongyi service",
+ default=False,
+ )
+
+ HOSTED_TONGYI_PAID_ENABLED: bool = Field(
+ description="Enable paid access to hosted Anthropic service",
+ default=False,
+ )
+
+ HOSTED_TONGYI_TRIAL_MODELS: str = Field(
+ description="Comma-separated list of available models for trial access",
+ default="",
+ )
+
+ HOSTED_TONGYI_PAID_MODELS: str = Field(
+ description="Comma-separated list of available models for paid access",
+ default="",
+ )
+
class HostedMinmaxConfig(BaseSettings):
"""
@@ -246,9 +478,13 @@ class HostedServiceConfig(
HostedOpenAiConfig,
HostedSparkConfig,
HostedZhipuAIConfig,
+ HostedTongyiConfig,
# moderation
HostedModerationConfig,
# credit config
HostedCreditConfig,
+ HostedGeminiConfig,
+ HostedXAIConfig,
+ HostedDeepseekConfig,
):
pass
diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py
index a5e35c99ca..63f75924bf 100644
--- a/api/configs/middleware/__init__.py
+++ b/api/configs/middleware/__init__.py
@@ -26,6 +26,7 @@ from .vdb.clickzetta_config import ClickzettaConfig
from .vdb.couchbase_config import CouchbaseConfig
from .vdb.elasticsearch_config import ElasticsearchConfig
from .vdb.huawei_cloud_config import HuaweiCloudConfig
+from .vdb.iris_config import IrisVectorConfig
from .vdb.lindorm_config import LindormConfig
from .vdb.matrixone_config import MatrixoneConfig
from .vdb.milvus_config import MilvusConfig
@@ -106,7 +107,7 @@ class KeywordStoreConfig(BaseSettings):
class DatabaseConfig(BaseSettings):
# Database type selector
- DB_TYPE: Literal["postgresql", "mysql", "oceanbase"] = Field(
+ DB_TYPE: Literal["postgresql", "mysql", "oceanbase", "seekdb"] = Field(
description="Database type to use. OceanBase is MySQL-compatible.",
default="postgresql",
)
@@ -336,6 +337,7 @@ class MiddlewareConfig(
ChromaConfig,
ClickzettaConfig,
HuaweiCloudConfig,
+ IrisVectorConfig,
MilvusConfig,
AlibabaCloudMySQLConfig,
MyScaleConfig,
diff --git a/api/configs/middleware/storage/aliyun_oss_storage_config.py b/api/configs/middleware/storage/aliyun_oss_storage_config.py
index 331c486d54..6df14175ae 100644
--- a/api/configs/middleware/storage/aliyun_oss_storage_config.py
+++ b/api/configs/middleware/storage/aliyun_oss_storage_config.py
@@ -41,3 +41,8 @@ class AliyunOSSStorageConfig(BaseSettings):
description="Base path within the bucket to store objects (e.g., 'my-app-data/')",
default=None,
)
+
+ ALIYUN_CLOUDBOX_ID: str | None = Field(
+ description="Cloudbox id for aliyun cloudbox service",
+ default=None,
+ )
diff --git a/api/configs/middleware/storage/huawei_obs_storage_config.py b/api/configs/middleware/storage/huawei_obs_storage_config.py
index 5b5cd2f750..46b6f2e68d 100644
--- a/api/configs/middleware/storage/huawei_obs_storage_config.py
+++ b/api/configs/middleware/storage/huawei_obs_storage_config.py
@@ -26,3 +26,8 @@ class HuaweiCloudOBSStorageConfig(BaseSettings):
description="Endpoint URL for Huawei Cloud OBS (e.g., 'https://obs.cn-north-4.myhuaweicloud.com')",
default=None,
)
+
+ HUAWEI_OBS_PATH_STYLE: bool = Field(
+ description="Flag to indicate whether to use path-style URLs for OBS requests",
+ default=False,
+ )
diff --git a/api/configs/middleware/storage/tencent_cos_storage_config.py b/api/configs/middleware/storage/tencent_cos_storage_config.py
index e297e748e9..cdd10740f8 100644
--- a/api/configs/middleware/storage/tencent_cos_storage_config.py
+++ b/api/configs/middleware/storage/tencent_cos_storage_config.py
@@ -31,3 +31,8 @@ class TencentCloudCOSStorageConfig(BaseSettings):
description="Protocol scheme for COS requests: 'https' (recommended) or 'http'",
default=None,
)
+
+ TENCENT_COS_CUSTOM_DOMAIN: str | None = Field(
+ description="Tencent Cloud COS custom domain setting",
+ default=None,
+ )
diff --git a/api/configs/middleware/storage/volcengine_tos_storage_config.py b/api/configs/middleware/storage/volcengine_tos_storage_config.py
index be01f2dc36..2a35300401 100644
--- a/api/configs/middleware/storage/volcengine_tos_storage_config.py
+++ b/api/configs/middleware/storage/volcengine_tos_storage_config.py
@@ -4,7 +4,7 @@ from pydantic_settings import BaseSettings
class VolcengineTOSStorageConfig(BaseSettings):
"""
- Configuration settings for Volcengine Tinder Object Storage (TOS)
+ Configuration settings for Volcengine Torch Object Storage (TOS)
"""
VOLCENGINE_TOS_BUCKET_NAME: str | None = Field(
diff --git a/api/configs/middleware/vdb/iris_config.py b/api/configs/middleware/vdb/iris_config.py
new file mode 100644
index 0000000000..c532d191c3
--- /dev/null
+++ b/api/configs/middleware/vdb/iris_config.py
@@ -0,0 +1,91 @@
+"""Configuration for InterSystems IRIS vector database."""
+
+from pydantic import Field, PositiveInt, model_validator
+from pydantic_settings import BaseSettings
+
+
+class IrisVectorConfig(BaseSettings):
+ """Configuration settings for IRIS vector database connection and pooling."""
+
+ IRIS_HOST: str | None = Field(
+ description="Hostname or IP address of the IRIS server.",
+ default="localhost",
+ )
+
+ IRIS_SUPER_SERVER_PORT: PositiveInt | None = Field(
+ description="Port number for IRIS connection.",
+ default=1972,
+ )
+
+ IRIS_USER: str | None = Field(
+ description="Username for IRIS authentication.",
+ default="_SYSTEM",
+ )
+
+ IRIS_PASSWORD: str | None = Field(
+ description="Password for IRIS authentication.",
+ default="Dify@1234",
+ )
+
+ IRIS_SCHEMA: str | None = Field(
+ description="Schema name for IRIS tables.",
+ default="dify",
+ )
+
+ IRIS_DATABASE: str | None = Field(
+ description="Database namespace for IRIS connection.",
+ default="USER",
+ )
+
+ IRIS_CONNECTION_URL: str | None = Field(
+ description="Full connection URL for IRIS (overrides individual fields if provided).",
+ default=None,
+ )
+
+ IRIS_MIN_CONNECTION: PositiveInt = Field(
+ description="Minimum number of connections in the pool.",
+ default=1,
+ )
+
+ IRIS_MAX_CONNECTION: PositiveInt = Field(
+ description="Maximum number of connections in the pool.",
+ default=3,
+ )
+
+ IRIS_TEXT_INDEX: bool = Field(
+ description="Enable full-text search index using %iFind.Index.Basic.",
+ default=True,
+ )
+
+ IRIS_TEXT_INDEX_LANGUAGE: str = Field(
+ description="Language for full-text search index (e.g., 'en', 'ja', 'zh', 'de').",
+ default="en",
+ )
+
+ @model_validator(mode="before")
+ @classmethod
+ def validate_config(cls, values: dict) -> dict:
+ """Validate IRIS configuration values.
+
+ Args:
+ values: Configuration dictionary
+
+ Returns:
+ Validated configuration dictionary
+
+ Raises:
+ ValueError: If required fields are missing or pool settings are invalid
+ """
+        # Required-field existence checks are intentionally omitted here: Pydantic
+        # defaults cover any values missing from the environment, so this config can
+        # still be loaded even when IRIS is not the active vector store. Only the
+        # connection-pool range is validated.
+
+ min_conn = values.get("IRIS_MIN_CONNECTION", 1)
+ max_conn = values.get("IRIS_MAX_CONNECTION", 3)
+ if min_conn > max_conn:
+ raise ValueError("IRIS_MIN_CONNECTION must be less than or equal to IRIS_MAX_CONNECTION")
+
+ return values
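
The pool-size check above is the only cross-field validation the model performs. A minimal usage sketch (illustrative only, not part of the patch; Pydantic wraps the `ValueError` in a `ValidationError`):

```python
from pydantic import ValidationError

from configs.middleware.vdb.iris_config import IrisVectorConfig

# Defaults satisfy the constraint (min 1 <= max 3).
config = IrisVectorConfig()

# An inverted pool range is rejected by validate_config.
try:
    IrisVectorConfig(IRIS_MIN_CONNECTION=5, IRIS_MAX_CONNECTION=3)
except ValidationError as exc:
    print(exc)  # IRIS_MIN_CONNECTION must be less than or equal to IRIS_MAX_CONNECTION
```
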
diff --git a/api/configs/middleware/vdb/milvus_config.py b/api/configs/middleware/vdb/milvus_config.py
index 05cee51cc9..eb9b0ac2ab 100644
--- a/api/configs/middleware/vdb/milvus_config.py
+++ b/api/configs/middleware/vdb/milvus_config.py
@@ -16,7 +16,6 @@ class MilvusConfig(BaseSettings):
description="Authentication token for Milvus, if token-based authentication is enabled",
default=None,
)
-
MILVUS_USER: str | None = Field(
description="Username for authenticating with Milvus, if username/password authentication is enabled",
default=None,
diff --git a/api/constants/languages.py b/api/constants/languages.py
index 0312a558c9..8c1ce368ac 100644
--- a/api/constants/languages.py
+++ b/api/constants/languages.py
@@ -20,6 +20,7 @@ language_timezone_mapping = {
"sl-SI": "Europe/Ljubljana",
"th-TH": "Asia/Bangkok",
"id-ID": "Asia/Jakarta",
+ "ar-TN": "Africa/Tunis",
}
languages = list(language_timezone_mapping.keys())
diff --git a/api/context/__init__.py b/api/context/__init__.py
new file mode 100644
index 0000000000..aebf9750ce
--- /dev/null
+++ b/api/context/__init__.py
@@ -0,0 +1,74 @@
+"""
+Core Context - Framework-agnostic context management.
+
+This module provides context management that is independent of any specific
+web framework. Framework-specific implementations register their context
+capture functions at application initialization time.
+
+This ensures the workflow layer remains completely decoupled from Flask
+or any other web framework.
+"""
+
+import contextvars
+from collections.abc import Callable
+
+from core.workflow.context.execution_context import (
+ ExecutionContext,
+ IExecutionContext,
+ NullAppContext,
+)
+
+# Global capturer function - set by framework-specific modules
+_capturer: Callable[[], IExecutionContext] | None = None
+
+
+def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None:
+ """
+ Register a context capture function.
+
+ This should be called by framework-specific modules (e.g., Flask)
+ during application initialization.
+
+ Args:
+ capturer: Function that captures current context and returns IExecutionContext
+ """
+ global _capturer
+ _capturer = capturer
+
+
+def capture_current_context() -> IExecutionContext:
+ """
+ Capture current execution context.
+
+ This function uses the registered context capturer. If no capturer
+ is registered, it returns a minimal context with only contextvars
+ (suitable for non-framework environments like tests or standalone scripts).
+
+ Returns:
+ IExecutionContext with captured context
+ """
+ if _capturer is None:
+ # No framework registered - return minimal context
+ return ExecutionContext(
+ app_context=NullAppContext(),
+ context_vars=contextvars.copy_context(),
+ )
+
+ return _capturer()
+
+
+def reset_context_provider() -> None:
+ """
+ Reset the context capturer.
+
+ This is primarily useful for testing to ensure a clean state.
+ """
+ global _capturer
+ _capturer = None
+
+
+__all__ = [
+ "capture_current_context",
+ "register_context_capturer",
+ "reset_context_provider",
+]
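
A hedged sketch of how this registration hook is meant to be wired (not part of the patch; it assumes the package above is importable as `context` and reuses `ExecutionContext` / `NullAppContext` exactly as imported there):

```python
import contextvars

from context import capture_current_context, register_context_capturer
from core.workflow.context.execution_context import ExecutionContext, NullAppContext


def capture_minimal_context() -> ExecutionContext:
    # A framework integration would snapshot its own app context here; this
    # minimal capturer mirrors the fallback path and only copies contextvars.
    return ExecutionContext(
        app_context=NullAppContext(),
        context_vars=contextvars.copy_context(),
    )


register_context_capturer(capture_minimal_context)  # once, at application start-up
ctx = capture_current_context()                     # later, e.g. before handing work to a thread
```
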
diff --git a/api/context/flask_app_context.py b/api/context/flask_app_context.py
new file mode 100644
index 0000000000..2d465c8cf4
--- /dev/null
+++ b/api/context/flask_app_context.py
@@ -0,0 +1,192 @@
+"""
+Flask App Context - Flask implementation of AppContext interface.
+"""
+
+import contextvars
+import threading
+from collections.abc import Generator
+from contextlib import contextmanager
+from typing import Any, final
+
+from flask import Flask, current_app, g
+
+from core.workflow.context import register_context_capturer
+from core.workflow.context.execution_context import (
+ AppContext,
+ IExecutionContext,
+)
+
+
+@final
+class FlaskAppContext(AppContext):
+ """
+ Flask implementation of AppContext.
+
+ This adapts Flask's app context to the AppContext interface.
+ """
+
+ def __init__(self, flask_app: Flask) -> None:
+ """
+ Initialize Flask app context.
+
+ Args:
+ flask_app: The Flask application instance
+ """
+ self._flask_app = flask_app
+
+ def get_config(self, key: str, default: Any = None) -> Any:
+ """Get configuration value from Flask app config."""
+ return self._flask_app.config.get(key, default)
+
+ def get_extension(self, name: str) -> Any:
+ """Get Flask extension by name."""
+ return self._flask_app.extensions.get(name)
+
+ @contextmanager
+ def enter(self) -> Generator[None, None, None]:
+ """Enter Flask app context."""
+ with self._flask_app.app_context():
+ yield
+
+ @property
+ def flask_app(self) -> Flask:
+ """Get the underlying Flask app instance."""
+ return self._flask_app
+
+
+def capture_flask_context(user: Any = None) -> IExecutionContext:
+ """
+ Capture current Flask execution context.
+
+ This function captures the Flask app context and contextvars from the
+ current environment. It should be called from within a Flask request or
+ app context.
+
+ Args:
+ user: Optional user object to include in context
+
+ Returns:
+ IExecutionContext with captured Flask context
+
+ Raises:
+ RuntimeError: If called outside Flask context
+ """
+ # Get Flask app instance
+ flask_app = current_app._get_current_object() # type: ignore
+
+ # Save current user if available
+ saved_user = user
+ if saved_user is None:
+ # Check for user in g (flask-login)
+ if hasattr(g, "_login_user"):
+ saved_user = g._login_user
+
+ # Capture contextvars
+ context_vars = contextvars.copy_context()
+
+ return FlaskExecutionContext(
+ flask_app=flask_app,
+ context_vars=context_vars,
+ user=saved_user,
+ )
+
+
+@final
+class FlaskExecutionContext:
+ """
+ Flask-specific execution context.
+
+ This is a specialized version of ExecutionContext that includes Flask app
+ context. It provides the same interface as ExecutionContext but with
+ Flask-specific implementation.
+ """
+
+ def __init__(
+ self,
+ flask_app: Flask,
+ context_vars: contextvars.Context,
+ user: Any = None,
+ ) -> None:
+ """
+ Initialize Flask execution context.
+
+ Args:
+ flask_app: Flask application instance
+ context_vars: Python contextvars
+ user: Optional user object
+ """
+ self._app_context = FlaskAppContext(flask_app)
+ self._context_vars = context_vars
+ self._user = user
+ self._flask_app = flask_app
+ self._local = threading.local()
+
+ @property
+ def app_context(self) -> FlaskAppContext:
+ """Get Flask app context."""
+ return self._app_context
+
+ @property
+ def context_vars(self) -> contextvars.Context:
+ """Get context variables."""
+ return self._context_vars
+
+ @property
+ def user(self) -> Any:
+ """Get user object."""
+ return self._user
+
+ def __enter__(self) -> "FlaskExecutionContext":
+ """Enter the Flask execution context."""
+        # Restore the captured context variables in the current thread; the Flask
+        # app context itself is entered separately below.
+ for var, val in self._context_vars.items():
+ var.set(val)
+
+ # Enter Flask app context
+ cm = self._app_context.enter()
+ self._local.cm = cm
+ cm.__enter__()
+
+ # Restore user in new app context
+ if self._user is not None:
+ g._login_user = self._user
+
+ return self
+
+ def __exit__(self, *args: Any) -> None:
+ """Exit the Flask execution context."""
+ cm = getattr(self._local, "cm", None)
+ if cm is not None:
+ cm.__exit__(*args)
+
+ @contextmanager
+ def enter(self) -> Generator[None, None, None]:
+ """Enter Flask execution context as context manager."""
+        # Restore the captured context variables in the current thread; the Flask
+        # app context is entered via app_context() below.
+ for var, val in self._context_vars.items():
+ var.set(val)
+
+ # Enter Flask app context
+ with self._flask_app.app_context():
+ # Restore user in new app context
+ if self._user is not None:
+ g._login_user = self._user
+ yield
+
+
+def init_flask_context() -> None:
+ """
+ Initialize Flask context capture by registering the capturer.
+
+ This function should be called during Flask application initialization
+ to register the Flask-specific context capturer with the core context module.
+
+ Example:
+ app = Flask(__name__)
+ init_flask_context() # Register Flask context capturer
+
+ Note:
+ This function does not need the app instance as it uses Flask's
+ `current_app` to get the app when capturing context.
+ """
+ register_context_capturer(capture_flask_context)
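
An illustrative sketch of the intended flow (not part of the patch; the `context.flask_app_context` import path is inferred from the file location): capture the context in the request thread, then re-enter it inside a worker thread.

```python
import threading

from flask import Flask

from context.flask_app_context import capture_flask_context, init_flask_context

app = Flask(__name__)
init_flask_context()  # register the Flask capturer during app initialization


@app.route("/kickoff")
def kickoff():
    ctx = capture_flask_context()  # must run inside a Flask request/app context

    def worker() -> None:
        # Re-enters the Flask app context and restores contextvars and the saved user.
        with ctx.enter():
            ...  # code here can use current_app, g, extensions, etc.

    threading.Thread(target=worker).start()
    return {"result": "started"}
```
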
diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py
index 7c16bc231f..c52dcf8a57 100644
--- a/api/contexts/__init__.py
+++ b/api/contexts/__init__.py
@@ -6,7 +6,6 @@ from contexts.wrapper import RecyclableContextVar
if TYPE_CHECKING:
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
- from core.model_runtime.entities.model_entities import AIModelEntity
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.trigger.provider import PluginTriggerProviderController
@@ -29,12 +28,6 @@ plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
ContextVar("plugin_model_providers_lock")
)
-plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_model_schema_lock"))
-
-plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
- ContextVar("plugin_model_schemas")
-)
-
datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = (
RecyclableContextVar(ContextVar("datasource_plugin_providers"))
)
diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py
index df9de825de..c16a23fac8 100644
--- a/api/controllers/common/fields.py
+++ b/api/controllers/common/fields.py
@@ -1,62 +1,59 @@
-from flask_restx import Api, Namespace, fields
+from __future__ import annotations
-from libs.helper import AppIconUrlField
+from typing import Any, TypeAlias
-parameters__system_parameters = {
- "image_file_size_limit": fields.Integer,
- "video_file_size_limit": fields.Integer,
- "audio_file_size_limit": fields.Integer,
- "file_size_limit": fields.Integer,
- "workflow_file_upload_limit": fields.Integer,
-}
+from pydantic import BaseModel, ConfigDict, computed_field
+
+from core.file import helpers as file_helpers
+from models.model import IconType
+
+JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
+JSONObject: TypeAlias = dict[str, Any]
-def build_system_parameters_model(api_or_ns: Api | Namespace):
- """Build the system parameters model for the API or Namespace."""
- return api_or_ns.model("SystemParameters", parameters__system_parameters)
+class SystemParameters(BaseModel):
+ image_file_size_limit: int
+ video_file_size_limit: int
+ audio_file_size_limit: int
+ file_size_limit: int
+ workflow_file_upload_limit: int
-parameters_fields = {
- "opening_statement": fields.String,
- "suggested_questions": fields.Raw,
- "suggested_questions_after_answer": fields.Raw,
- "speech_to_text": fields.Raw,
- "text_to_speech": fields.Raw,
- "retriever_resource": fields.Raw,
- "annotation_reply": fields.Raw,
- "more_like_this": fields.Raw,
- "user_input_form": fields.Raw,
- "sensitive_word_avoidance": fields.Raw,
- "file_upload": fields.Raw,
- "system_parameters": fields.Nested(parameters__system_parameters),
-}
+class Parameters(BaseModel):
+ opening_statement: str | None = None
+ suggested_questions: list[str]
+ suggested_questions_after_answer: JSONObject
+ speech_to_text: JSONObject
+ text_to_speech: JSONObject
+ retriever_resource: JSONObject
+ annotation_reply: JSONObject
+ more_like_this: JSONObject
+ user_input_form: list[JSONObject]
+ sensitive_word_avoidance: JSONObject
+ file_upload: JSONObject
+ system_parameters: SystemParameters
-def build_parameters_model(api_or_ns: Api | Namespace):
- """Build the parameters model for the API or Namespace."""
- copied_fields = parameters_fields.copy()
- copied_fields["system_parameters"] = fields.Nested(build_system_parameters_model(api_or_ns))
- return api_or_ns.model("Parameters", copied_fields)
+class Site(BaseModel):
+ model_config = ConfigDict(from_attributes=True)
+ title: str
+ chat_color_theme: str | None = None
+ chat_color_theme_inverted: bool
+ icon_type: str | None = None
+ icon: str | None = None
+ icon_background: str | None = None
+ description: str | None = None
+ copyright: str | None = None
+ privacy_policy: str | None = None
+ custom_disclaimer: str | None = None
+ default_language: str
+ show_workflow_steps: bool
+ use_icon_as_answer_icon: bool
-site_fields = {
- "title": fields.String,
- "chat_color_theme": fields.String,
- "chat_color_theme_inverted": fields.Boolean,
- "icon_type": fields.String,
- "icon": fields.String,
- "icon_background": fields.String,
- "icon_url": AppIconUrlField,
- "description": fields.String,
- "copyright": fields.String,
- "privacy_policy": fields.String,
- "custom_disclaimer": fields.String,
- "default_language": fields.String,
- "show_workflow_steps": fields.Boolean,
- "use_icon_as_answer_icon": fields.Boolean,
-}
-
-
-def build_site_model(api_or_ns: Api | Namespace):
- """Build the site model for the API or Namespace."""
- return api_or_ns.model("Site", site_fields)
+ @computed_field(return_type=str | None) # type: ignore
+ @property
+ def icon_url(self) -> str | None:
+ if self.icon and self.icon_type == IconType.IMAGE:
+ return file_helpers.get_signed_file_url(self.icon)
+ return None
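
A hedged sketch of how the new Pydantic `Site` model is consumed in place of the old `fields` dicts (illustrative values only; `icon_url` is computed and stays `None` unless `icon_type` is `image`):

```python
from controllers.common.fields import Site

payload = Site.model_validate({
    "title": "Support Bot",
    "chat_color_theme_inverted": False,
    "icon_type": "emoji",
    "icon": "🤖",
    "default_language": "en-US",
    "show_workflow_steps": True,
    "use_icon_as_answer_icon": False,
})
data = payload.model_dump(mode="json")  # includes the computed "icon_url" key
```
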
diff --git a/api/controllers/common/file_response.py b/api/controllers/common/file_response.py
new file mode 100644
index 0000000000..ca8ea3d52e
--- /dev/null
+++ b/api/controllers/common/file_response.py
@@ -0,0 +1,57 @@
+import os
+from email.message import Message
+from urllib.parse import quote
+
+from flask import Response
+
+HTML_MIME_TYPES = frozenset({"text/html", "application/xhtml+xml"})
+HTML_EXTENSIONS = frozenset({"html", "htm"})
+
+
+def _normalize_mime_type(mime_type: str | None) -> str:
+ if not mime_type:
+ return ""
+ message = Message()
+ message["Content-Type"] = mime_type
+ return message.get_content_type().strip().lower()
+
+
+def _is_html_extension(extension: str | None) -> bool:
+ if not extension:
+ return False
+ return extension.lstrip(".").lower() in HTML_EXTENSIONS
+
+
+def is_html_content(mime_type: str | None, filename: str | None, extension: str | None = None) -> bool:
+ normalized_mime_type = _normalize_mime_type(mime_type)
+ if normalized_mime_type in HTML_MIME_TYPES:
+ return True
+
+ if _is_html_extension(extension):
+ return True
+
+ if filename:
+ return _is_html_extension(os.path.splitext(filename)[1])
+
+ return False
+
+
+def enforce_download_for_html(
+ response: Response,
+ *,
+ mime_type: str | None,
+ filename: str | None,
+ extension: str | None = None,
+) -> bool:
+ if not is_html_content(mime_type, filename, extension):
+ return False
+
+ if filename:
+ encoded_filename = quote(filename)
+ response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
+ else:
+ response.headers["Content-Disposition"] = "attachment"
+
+ response.headers["Content-Type"] = "application/octet-stream"
+ response.headers["X-Content-Type-Options"] = "nosniff"
+ return True
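
A short usage sketch of the helper above (not part of the patch): when the payload is detected as HTML, the response headers are rewritten so browsers download it instead of rendering it.

```python
from flask import Flask, Response

from controllers.common.file_response import enforce_download_for_html

app = Flask(__name__)


@app.route("/preview")
def preview() -> Response:
    response = Response("<html><body>report</body></html>", mimetype="text/html")
    # Returns False and leaves the response untouched for non-HTML content.
    enforce_download_for_html(response, mime_type=response.mimetype, filename="report.html")
    return response
```
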
diff --git a/api/controllers/common/schema.py b/api/controllers/common/schema.py
index e0896a8dc2..a5a3e4ebbd 100644
--- a/api/controllers/common/schema.py
+++ b/api/controllers/common/schema.py
@@ -1,7 +1,11 @@
"""Helpers for registering Pydantic models with Flask-RESTX namespaces."""
+from enum import StrEnum
+
from flask_restx import Namespace
-from pydantic import BaseModel
+from pydantic import BaseModel, TypeAdapter
+
+from controllers.console import console_ns
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@@ -19,8 +23,25 @@ def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> No
register_schema_model(namespace, model)
+def get_or_create_model(model_name: str, field_def):
+ existing = console_ns.models.get(model_name)
+ if existing is None:
+ existing = console_ns.model(model_name, field_def)
+ return existing
+
+
+def register_enum_models(namespace: Namespace, *models: type[StrEnum]) -> None:
+ """Register multiple StrEnum with a namespace."""
+ for model in models:
+ namespace.schema_model(
+ model.__name__, TypeAdapter(model).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+ )
+
+
__all__ = [
"DEFAULT_REF_TEMPLATE_SWAGGER_2_0",
+ "get_or_create_model",
+ "register_enum_models",
"register_schema_model",
"register_schema_models",
]
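
A hedged sketch of the two new helpers (not part of the patch; `ExampleStatus` and `ExampleItem` are hypothetical names):

```python
from enum import StrEnum

from flask_restx import fields

from controllers.common.schema import get_or_create_model, register_enum_models
from controllers.console import console_ns


class ExampleStatus(StrEnum):
    ACTIVE = "active"
    ARCHIVED = "archived"


# Registers the enum's JSON schema under its class name.
register_enum_models(console_ns, ExampleStatus)

# Returns the existing "ExampleItem" model if already registered, otherwise creates it.
example_model = get_or_create_model("ExampleItem", {"name": fields.String})
```
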
diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py
index ad878fc266..fdc9aabc83 100644
--- a/api/controllers/console/__init__.py
+++ b/api/controllers/console/__init__.py
@@ -107,10 +107,12 @@ from .datasets.rag_pipeline import (
# Import explore controllers
from .explore import (
+ banner,
installed_app,
parameter,
recommended_app,
saved_message,
+ trial,
)
# Import tag controllers
@@ -145,6 +147,7 @@ __all__ = [
"apikey",
"app",
"audio",
+ "banner",
"billing",
"bp",
"completion",
@@ -198,6 +201,7 @@ __all__ = [
"statistic",
"tags",
"tool_providers",
+ "trial",
"trigger_providers",
"version",
"website",
diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py
index 7aa1e6dbd8..03b602f6e8 100644
--- a/api/controllers/console/admin.py
+++ b/api/controllers/console/admin.py
@@ -6,18 +6,19 @@ from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
-from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized
-P = ParamSpec("P")
-R = TypeVar("R")
from configs import dify_config
from constants.languages import supported_language
from controllers.console import console_ns
from controllers.console.wraps import only_edition_cloud
+from core.db.session_factory import session_factory
from extensions.ext_database import db
from libs.token import extract_access_token
-from models.model import App, InstalledApp, RecommendedApp
+from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
+
+P = ParamSpec("P")
+R = TypeVar("R")
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@@ -31,6 +32,8 @@ class InsertExploreAppPayload(BaseModel):
language: str = Field(...)
category: str = Field(...)
position: int = Field(...)
+ can_trial: bool = Field(default=False)
+ trial_limit: int = Field(default=0)
@field_validator("language")
@classmethod
@@ -38,11 +41,33 @@ class InsertExploreAppPayload(BaseModel):
return supported_language(value)
+class InsertExploreBannerPayload(BaseModel):
+ category: str = Field(...)
+ title: str = Field(...)
+ description: str = Field(...)
+ img_src: str = Field(..., alias="img-src")
+ language: str = Field(default="en-US")
+ link: str = Field(...)
+ sort: int = Field(...)
+
+ @field_validator("language")
+ @classmethod
+ def validate_language(cls, value: str) -> str:
+ return supported_language(value)
+
+ model_config = {"populate_by_name": True}
+
+
console_ns.schema_model(
InsertExploreAppPayload.__name__,
InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
+console_ns.schema_model(
+ InsertExploreBannerPayload.__name__,
+ InsertExploreBannerPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
def admin_required(view: Callable[P, R]):
@wraps(view)
@@ -90,7 +115,7 @@ class InsertExploreAppListApi(Resource):
privacy_policy = site.privacy_policy or payload.privacy_policy or ""
custom_disclaimer = site.custom_disclaimer or payload.custom_disclaimer or ""
- with Session(db.engine) as session:
+ with session_factory.create_session() as session:
recommended_app = session.execute(
select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id)
).scalar_one_or_none()
@@ -108,6 +133,20 @@ class InsertExploreAppListApi(Resource):
)
db.session.add(recommended_app)
+ if payload.can_trial:
+ trial_app = db.session.execute(
+ select(TrialApp).where(TrialApp.app_id == payload.app_id)
+ ).scalar_one_or_none()
+ if not trial_app:
+ db.session.add(
+ TrialApp(
+ app_id=payload.app_id,
+ tenant_id=app.tenant_id,
+ trial_limit=payload.trial_limit,
+ )
+ )
+ else:
+ trial_app.trial_limit = payload.trial_limit
app.is_public = True
db.session.commit()
@@ -122,6 +161,20 @@ class InsertExploreAppListApi(Resource):
recommended_app.category = payload.category
recommended_app.position = payload.position
+ if payload.can_trial:
+ trial_app = db.session.execute(
+ select(TrialApp).where(TrialApp.app_id == payload.app_id)
+ ).scalar_one_or_none()
+ if not trial_app:
+ db.session.add(
+ TrialApp(
+ app_id=payload.app_id,
+ tenant_id=app.tenant_id,
+ trial_limit=payload.trial_limit,
+ )
+ )
+ else:
+ trial_app.trial_limit = payload.trial_limit
app.is_public = True
db.session.commit()
@@ -138,7 +191,7 @@ class InsertExploreAppApi(Resource):
@only_edition_cloud
@admin_required
def delete(self, app_id):
- with Session(db.engine) as session:
+ with session_factory.create_session() as session:
recommended_app = session.execute(
select(RecommendedApp).where(RecommendedApp.app_id == str(app_id))
).scalar_one_or_none()
@@ -146,13 +199,13 @@ class InsertExploreAppApi(Resource):
if not recommended_app:
return {"result": "success"}, 204
- with Session(db.engine) as session:
+ with session_factory.create_session() as session:
app = session.execute(select(App).where(App.id == recommended_app.app_id)).scalar_one_or_none()
if app:
app.is_public = False
- with Session(db.engine) as session:
+ with session_factory.create_session() as session:
installed_apps = (
session.execute(
select(InstalledApp).where(
@@ -167,7 +220,60 @@ class InsertExploreAppApi(Resource):
for installed_app in installed_apps:
session.delete(installed_app)
+ trial_app = session.execute(
+ select(TrialApp).where(TrialApp.app_id == recommended_app.app_id)
+ ).scalar_one_or_none()
+ if trial_app:
+ session.delete(trial_app)
+
db.session.delete(recommended_app)
db.session.commit()
return {"result": "success"}, 204
+
+
+@console_ns.route("/admin/insert-explore-banner")
+class InsertExploreBannerApi(Resource):
+ @console_ns.doc("insert_explore_banner")
+ @console_ns.doc(description="Insert an explore banner")
+ @console_ns.expect(console_ns.models[InsertExploreBannerPayload.__name__])
+ @console_ns.response(201, "Banner inserted successfully")
+ @only_edition_cloud
+ @admin_required
+ def post(self):
+ payload = InsertExploreBannerPayload.model_validate(console_ns.payload)
+
+ banner = ExporleBanner(
+ content={
+ "category": payload.category,
+ "title": payload.title,
+ "description": payload.description,
+ "img-src": payload.img_src,
+ },
+ link=payload.link,
+ sort=payload.sort,
+ language=payload.language,
+ )
+ db.session.add(banner)
+ db.session.commit()
+
+ return {"result": "success"}, 201
+
+
+@console_ns.route("/admin/delete-explore-banner/")
+class DeleteExploreBannerApi(Resource):
+ @console_ns.doc("delete_explore_banner")
+ @console_ns.doc(description="Delete an explore banner")
+ @console_ns.doc(params={"banner_id": "Banner ID to delete"})
+ @console_ns.response(204, "Banner deleted successfully")
+ @only_edition_cloud
+ @admin_required
+ def delete(self, banner_id):
+ banner = db.session.execute(select(ExporleBanner).where(ExporleBanner.id == banner_id)).scalar_one_or_none()
+ if not banner:
+ raise NotFound(f"Banner '{banner_id}' is not found")
+
+ db.session.delete(banner)
+ db.session.commit()
+
+ return {"result": "success"}, 204
diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py
index 9b0d4b1a78..c81709e985 100644
--- a/api/controllers/console/apikey.py
+++ b/api/controllers/console/apikey.py
@@ -22,10 +22,10 @@ api_key_fields = {
"created_at": TimestampField,
}
-api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="items")}
-
api_key_item_model = console_ns.model("ApiKeyItem", api_key_fields)
+api_key_list = {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
+
api_key_list_model = console_ns.model(
"ApiKeyList", {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
)
diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py
index 3b6fb58931..6a4c1528b0 100644
--- a/api/controllers/console/app/annotation.py
+++ b/api/controllers/console/app/annotation.py
@@ -1,6 +1,6 @@
from typing import Any, Literal
-from flask import request
+from flask import abort, make_response, request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
@@ -8,6 +8,8 @@ from controllers.common.errors import NoFileUploadedError, TooManyFilesError
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
+ annotation_import_concurrency_limit,
+ annotation_import_rate_limit,
cloud_edition_billing_resource_check,
edit_permission_required,
setup_required,
@@ -257,7 +259,7 @@ class AnnotationApi(Resource):
@console_ns.route("/apps//annotations/export")
class AnnotationExportApi(Resource):
@console_ns.doc("export_annotations")
- @console_ns.doc(description="Export all annotations for an app")
+ @console_ns.doc(description="Export all annotations for an app with CSV injection protection")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(
200,
@@ -272,8 +274,14 @@ class AnnotationExportApi(Resource):
def get(self, app_id):
app_id = str(app_id)
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
- response = {"data": marshal(annotation_list, annotation_fields)}
- return response, 200
+ response_data = {"data": marshal(annotation_list, annotation_fields)}
+
+ # Create response with secure headers for CSV export
+ response = make_response(response_data, 200)
+ response.headers["Content-Type"] = "application/json; charset=utf-8"
+ response.headers["X-Content-Type-Options"] = "nosniff"
+
+ return response
@console_ns.route("/apps//annotations/")
@@ -314,18 +322,25 @@ class AnnotationUpdateDeleteApi(Resource):
@console_ns.route("/apps//annotations/batch-import")
class AnnotationBatchImportApi(Resource):
@console_ns.doc("batch_import_annotations")
- @console_ns.doc(description="Batch import annotations from CSV file")
+ @console_ns.doc(description="Batch import annotations from CSV file with rate limiting and security checks")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Batch import started successfully")
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(400, "No file uploaded or too many files")
+ @console_ns.response(413, "File too large")
+ @console_ns.response(429, "Too many requests or concurrent imports")
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
+ @annotation_import_rate_limit
+ @annotation_import_concurrency_limit
@edit_permission_required
def post(self, app_id):
+ from configs import dify_config
+
app_id = str(app_id)
+
# check file
if "file" not in request.files:
raise NoFileUploadedError()
@@ -335,9 +350,27 @@ class AnnotationBatchImportApi(Resource):
# get file from request
file = request.files["file"]
+
# check file type
if not file.filename or not file.filename.lower().endswith(".csv"):
raise ValueError("Invalid file type. Only CSV files are allowed")
+
+ # Check file size before processing
+ file.seek(0, 2) # Seek to end of file
+ file_size = file.tell()
+ file.seek(0) # Reset to beginning
+
+ max_size_bytes = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024
+ if file_size > max_size_bytes:
+ abort(
+ 413,
+ f"File size exceeds maximum limit of {dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT}MB. "
+ f"Please reduce the file size and try again.",
+ )
+
+ if file_size == 0:
+ raise ValueError("The uploaded file is empty")
+
return AppAnnotationService.batch_import_app_annotations(app_id, file)
diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py
index 62e997dae2..8c371da596 100644
--- a/api/controllers/console/app/app.py
+++ b/api/controllers/console/app/app.py
@@ -1,15 +1,19 @@
import uuid
-from typing import Literal
+from datetime import datetime
+from typing import Any, Literal, TypeAlias
from flask import request
-from flask_restx import Resource, fields, marshal, marshal_with
-from pydantic import BaseModel, Field, field_validator
+from flask_restx import Resource
+from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest
+from controllers.common.helpers import FileInfo
+from controllers.common.schema import register_enum_models, register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
+from controllers.console.workspace.models import LoadBalancingPayload
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
@@ -18,27 +22,37 @@ from controllers.console.wraps import (
is_admin_or_owner_required,
setup_required,
)
+from core.file import helpers as file_helpers
from core.ops.ops_trace_manager import OpsTraceManager
-from core.workflow.enums import NodeType
+from core.rag.retrieval.retrieval_methods import RetrievalMethod
+from core.workflow.enums import NodeType, WorkflowExecutionStatus
from extensions.ext_database import db
-from fields.app_fields import (
- deleted_tool_fields,
- model_config_fields,
- model_config_partial_fields,
- site_fields,
- tag_fields,
-)
-from fields.workflow_fields import workflow_partial_fields as _workflow_partial_fields_dict
-from libs.helper import AppIconUrlField, TimestampField
from libs.login import current_account_with_tenant, login_required
-from models import App, Workflow
+from models import App, DatasetPermissionEnum, Workflow
+from models.model import IconType
from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
+from services.entities.knowledge_entities.knowledge_entities import (
+ DataSource,
+ InfoList,
+ NotionIcon,
+ NotionInfo,
+ NotionPage,
+ PreProcessingRule,
+ RerankingModel,
+ Rule,
+ Segmentation,
+ WebsiteInfo,
+ WeightKeywordSetting,
+ WeightModel,
+ WeightVectorSetting,
+)
from services.feature_service import FeatureService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
-DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+register_enum_models(console_ns, IconType)
class AppListQuery(BaseModel):
@@ -134,124 +148,310 @@ class AppTracePayload(BaseModel):
return value
-def reg(cls: type[BaseModel]):
- console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+JSONValue: TypeAlias = Any
-reg(AppListQuery)
-reg(CreateAppPayload)
-reg(UpdateAppPayload)
-reg(CopyAppPayload)
-reg(AppExportQuery)
-reg(AppNamePayload)
-reg(AppIconPayload)
-reg(AppSiteStatusPayload)
-reg(AppApiStatusPayload)
-reg(AppTracePayload)
+class ResponseModel(BaseModel):
+ model_config = ConfigDict(
+ from_attributes=True,
+ extra="ignore",
+ populate_by_name=True,
+ serialize_by_alias=True,
+ protected_namespaces=(),
+ )
-# Register models for flask_restx to avoid dict type issues in Swagger
-# Register base models first
-tag_model = console_ns.model("Tag", tag_fields)
-workflow_partial_model = console_ns.model("WorkflowPartial", _workflow_partial_fields_dict)
+def _to_timestamp(value: datetime | int | None) -> int | None:
+ if isinstance(value, datetime):
+ return int(value.timestamp())
+ return value
-model_config_model = console_ns.model("ModelConfig", model_config_fields)
-model_config_partial_model = console_ns.model("ModelConfigPartial", model_config_partial_fields)
+def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None:
+ if icon is None or icon_type is None:
+ return None
+ icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type)
+ if icon_type_value.lower() != IconType.IMAGE:
+ return None
+ return file_helpers.get_signed_file_url(icon)
-deleted_tool_model = console_ns.model("DeletedTool", deleted_tool_fields)
-site_model = console_ns.model("Site", site_fields)
+class Tag(ResponseModel):
+ id: str
+ name: str
+ type: str
-app_partial_model = console_ns.model(
- "AppPartial",
- {
- "id": fields.String,
- "name": fields.String,
- "max_active_requests": fields.Raw(),
- "description": fields.String(attribute="desc_or_prompt"),
- "mode": fields.String(attribute="mode_compatible_with_agent"),
- "icon_type": fields.String,
- "icon": fields.String,
- "icon_background": fields.String,
- "icon_url": AppIconUrlField,
- "model_config": fields.Nested(model_config_partial_model, attribute="app_model_config", allow_null=True),
- "workflow": fields.Nested(workflow_partial_model, allow_null=True),
- "use_icon_as_answer_icon": fields.Boolean,
- "created_by": fields.String,
- "created_at": TimestampField,
- "updated_by": fields.String,
- "updated_at": TimestampField,
- "tags": fields.List(fields.Nested(tag_model)),
- "access_mode": fields.String,
- "create_user_name": fields.String,
- "author_name": fields.String,
- "has_draft_trigger": fields.Boolean,
- },
-)
-app_detail_model = console_ns.model(
- "AppDetail",
- {
- "id": fields.String,
- "name": fields.String,
- "description": fields.String,
- "mode": fields.String(attribute="mode_compatible_with_agent"),
- "icon": fields.String,
- "icon_background": fields.String,
- "enable_site": fields.Boolean,
- "enable_api": fields.Boolean,
- "model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True),
- "workflow": fields.Nested(workflow_partial_model, allow_null=True),
- "tracing": fields.Raw,
- "use_icon_as_answer_icon": fields.Boolean,
- "created_by": fields.String,
- "created_at": TimestampField,
- "updated_by": fields.String,
- "updated_at": TimestampField,
- "access_mode": fields.String,
- "tags": fields.List(fields.Nested(tag_model)),
- },
-)
+class WorkflowPartial(ResponseModel):
+ id: str
+ created_by: str | None = None
+ created_at: int | None = None
+ updated_by: str | None = None
+ updated_at: int | None = None
-app_detail_with_site_model = console_ns.model(
- "AppDetailWithSite",
- {
- "id": fields.String,
- "name": fields.String,
- "description": fields.String,
- "mode": fields.String(attribute="mode_compatible_with_agent"),
- "icon_type": fields.String,
- "icon": fields.String,
- "icon_background": fields.String,
- "icon_url": AppIconUrlField,
- "enable_site": fields.Boolean,
- "enable_api": fields.Boolean,
- "model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True),
- "workflow": fields.Nested(workflow_partial_model, allow_null=True),
- "api_base_url": fields.String,
- "use_icon_as_answer_icon": fields.Boolean,
- "max_active_requests": fields.Integer,
- "created_by": fields.String,
- "created_at": TimestampField,
- "updated_by": fields.String,
- "updated_at": TimestampField,
- "deleted_tools": fields.List(fields.Nested(deleted_tool_model)),
- "access_mode": fields.String,
- "tags": fields.List(fields.Nested(tag_model)),
- "site": fields.Nested(site_model),
- },
-)
+ @field_validator("created_at", "updated_at", mode="before")
+ @classmethod
+ def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+ return _to_timestamp(value)
-app_pagination_model = console_ns.model(
- "AppPagination",
- {
- "page": fields.Integer,
- "limit": fields.Integer(attribute="per_page"),
- "total": fields.Integer,
- "has_more": fields.Boolean(attribute="has_next"),
- "data": fields.List(fields.Nested(app_partial_model), attribute="items"),
- },
+
+class ModelConfigPartial(ResponseModel):
+ model: JSONValue | None = Field(default=None, validation_alias=AliasChoices("model_dict", "model"))
+ pre_prompt: str | None = None
+ created_by: str | None = None
+ created_at: int | None = None
+ updated_by: str | None = None
+ updated_at: int | None = None
+
+ @field_validator("created_at", "updated_at", mode="before")
+ @classmethod
+ def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+ return _to_timestamp(value)
+
+
+class ModelConfig(ResponseModel):
+ opening_statement: str | None = None
+ suggested_questions: JSONValue | None = Field(
+ default=None, validation_alias=AliasChoices("suggested_questions_list", "suggested_questions")
+ )
+ suggested_questions_after_answer: JSONValue | None = Field(
+ default=None,
+ validation_alias=AliasChoices("suggested_questions_after_answer_dict", "suggested_questions_after_answer"),
+ )
+ speech_to_text: JSONValue | None = Field(
+ default=None, validation_alias=AliasChoices("speech_to_text_dict", "speech_to_text")
+ )
+ text_to_speech: JSONValue | None = Field(
+ default=None, validation_alias=AliasChoices("text_to_speech_dict", "text_to_speech")
+ )
+ retriever_resource: JSONValue | None = Field(
+ default=None, validation_alias=AliasChoices("retriever_resource_dict", "retriever_resource")
+ )
+ annotation_reply: JSONValue | None = Field(
+ default=None, validation_alias=AliasChoices("annotation_reply_dict", "annotation_reply")
+ )
+ more_like_this: JSONValue | None = Field(
+ default=None, validation_alias=AliasChoices("more_like_this_dict", "more_like_this")
+ )
+ sensitive_word_avoidance: JSONValue | None = Field(
+ default=None, validation_alias=AliasChoices("sensitive_word_avoidance_dict", "sensitive_word_avoidance")
+ )
+ external_data_tools: JSONValue | None = Field(
+ default=None, validation_alias=AliasChoices("external_data_tools_list", "external_data_tools")
+ )
+ model: JSONValue | None = Field(default=None, validation_alias=AliasChoices("model_dict", "model"))
+ user_input_form: JSONValue | None = Field(
+ default=None, validation_alias=AliasChoices("user_input_form_list", "user_input_form")
+ )
+ dataset_query_variable: str | None = None
+ pre_prompt: str | None = None
+ agent_mode: JSONValue | None = Field(default=None, validation_alias=AliasChoices("agent_mode_dict", "agent_mode"))
+ prompt_type: str | None = None
+ chat_prompt_config: JSONValue | None = Field(
+ default=None, validation_alias=AliasChoices("chat_prompt_config_dict", "chat_prompt_config")
+ )
+ completion_prompt_config: JSONValue | None = Field(
+ default=None, validation_alias=AliasChoices("completion_prompt_config_dict", "completion_prompt_config")
+ )
+ dataset_configs: JSONValue | None = Field(
+ default=None, validation_alias=AliasChoices("dataset_configs_dict", "dataset_configs")
+ )
+ file_upload: JSONValue | None = Field(
+ default=None, validation_alias=AliasChoices("file_upload_dict", "file_upload")
+ )
+ created_by: str | None = None
+ created_at: int | None = None
+ updated_by: str | None = None
+ updated_at: int | None = None
+
+ @field_validator("created_at", "updated_at", mode="before")
+ @classmethod
+ def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+ return _to_timestamp(value)
+
+
+class Site(ResponseModel):
+ access_token: str | None = Field(default=None, validation_alias="code")
+ code: str | None = None
+ title: str | None = None
+ icon_type: str | IconType | None = None
+ icon: str | None = None
+ icon_background: str | None = None
+ description: str | None = None
+ default_language: str | None = None
+ chat_color_theme: str | None = None
+ chat_color_theme_inverted: bool | None = None
+ customize_domain: str | None = None
+ copyright: str | None = None
+ privacy_policy: str | None = None
+ custom_disclaimer: str | None = None
+ customize_token_strategy: str | None = None
+ prompt_public: bool | None = None
+ app_base_url: str | None = None
+ show_workflow_steps: bool | None = None
+ use_icon_as_answer_icon: bool | None = None
+ created_by: str | None = None
+ created_at: int | None = None
+ updated_by: str | None = None
+ updated_at: int | None = None
+
+ @computed_field(return_type=str | None) # type: ignore
+ @property
+ def icon_url(self) -> str | None:
+ return _build_icon_url(self.icon_type, self.icon)
+
+ @field_validator("icon_type", mode="before")
+ @classmethod
+ def _normalize_icon_type(cls, value: str | IconType | None) -> str | None:
+ if isinstance(value, IconType):
+ return value.value
+ return value
+
+ @field_validator("created_at", "updated_at", mode="before")
+ @classmethod
+ def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+ return _to_timestamp(value)
+
+
+class DeletedTool(ResponseModel):
+ type: str
+ tool_name: str
+ provider_id: str
+
+
+class AppPartial(ResponseModel):
+ id: str
+ name: str
+ max_active_requests: int | None = None
+ description: str | None = Field(default=None, validation_alias=AliasChoices("desc_or_prompt", "description"))
+ mode: str = Field(validation_alias="mode_compatible_with_agent")
+ icon_type: str | None = None
+ icon: str | None = None
+ icon_background: str | None = None
+ model_config_: ModelConfigPartial | None = Field(
+ default=None,
+ validation_alias=AliasChoices("app_model_config", "model_config"),
+ alias="model_config",
+ )
+ workflow: WorkflowPartial | None = None
+ use_icon_as_answer_icon: bool | None = None
+ created_by: str | None = None
+ created_at: int | None = None
+ updated_by: str | None = None
+ updated_at: int | None = None
+ tags: list[Tag] = Field(default_factory=list)
+ access_mode: str | None = None
+ create_user_name: str | None = None
+ author_name: str | None = None
+ has_draft_trigger: bool | None = None
+
+ @computed_field(return_type=str | None) # type: ignore
+ @property
+ def icon_url(self) -> str | None:
+ return _build_icon_url(self.icon_type, self.icon)
+
+ @field_validator("created_at", "updated_at", mode="before")
+ @classmethod
+ def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+ return _to_timestamp(value)
+
+
+class AppDetail(ResponseModel):
+ id: str
+ name: str
+ description: str | None = None
+ mode: str = Field(validation_alias="mode_compatible_with_agent")
+ icon: str | None = None
+ icon_background: str | None = None
+ enable_site: bool
+ enable_api: bool
+ model_config_: ModelConfig | None = Field(
+ default=None,
+ validation_alias=AliasChoices("app_model_config", "model_config"),
+ alias="model_config",
+ )
+ workflow: WorkflowPartial | None = None
+ tracing: JSONValue | None = None
+ use_icon_as_answer_icon: bool | None = None
+ created_by: str | None = None
+ created_at: int | None = None
+ updated_by: str | None = None
+ updated_at: int | None = None
+ access_mode: str | None = None
+ tags: list[Tag] = Field(default_factory=list)
+
+ @field_validator("created_at", "updated_at", mode="before")
+ @classmethod
+ def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+ return _to_timestamp(value)
+
+
+class AppDetailWithSite(AppDetail):
+ icon_type: str | None = None
+ api_base_url: str | None = None
+ max_active_requests: int | None = None
+ deleted_tools: list[DeletedTool] = Field(default_factory=list)
+ site: Site | None = None
+
+ @computed_field(return_type=str | None) # type: ignore
+ @property
+ def icon_url(self) -> str | None:
+ return _build_icon_url(self.icon_type, self.icon)
+
+
+class AppPagination(ResponseModel):
+ page: int
+ limit: int = Field(validation_alias=AliasChoices("per_page", "limit"))
+ total: int
+ has_more: bool = Field(validation_alias=AliasChoices("has_next", "has_more"))
+ data: list[AppPartial] = Field(validation_alias=AliasChoices("items", "data"))
+
+
+class AppExportResponse(ResponseModel):
+ data: str
+
+
+register_enum_models(console_ns, RetrievalMethod, WorkflowExecutionStatus, DatasetPermissionEnum)
+
+register_schema_models(
+ console_ns,
+ AppListQuery,
+ CreateAppPayload,
+ UpdateAppPayload,
+ CopyAppPayload,
+ AppExportQuery,
+ AppNamePayload,
+ AppIconPayload,
+ AppSiteStatusPayload,
+ AppApiStatusPayload,
+ AppTracePayload,
+ Tag,
+ WorkflowPartial,
+ ModelConfigPartial,
+ ModelConfig,
+ Site,
+ DeletedTool,
+ AppPartial,
+ AppDetail,
+ AppDetailWithSite,
+ AppPagination,
+ AppExportResponse,
+ Segmentation,
+ PreProcessingRule,
+ Rule,
+ WeightVectorSetting,
+ WeightKeywordSetting,
+ WeightModel,
+ RerankingModel,
+ InfoList,
+ NotionInfo,
+ FileInfo,
+ WebsiteInfo,
+ NotionPage,
+ NotionIcon,
+ DataSource,
+ LoadBalancingPayload,
)
@@ -260,7 +460,7 @@ class AppListApi(Resource):
@console_ns.doc("list_apps")
@console_ns.doc(description="Get list of applications with pagination and filtering")
@console_ns.expect(console_ns.models[AppListQuery.__name__])
- @console_ns.response(200, "Success", app_pagination_model)
+ @console_ns.response(200, "Success", console_ns.models[AppPagination.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -276,7 +476,8 @@ class AppListApi(Resource):
app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args_dict)
if not app_pagination:
- return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
+ empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
+ return empty.model_dump(mode="json"), 200
if FeatureService.get_system_features().webapp_auth.enabled:
app_ids = [str(app.id) for app in app_pagination.items]
@@ -320,18 +521,18 @@ class AppListApi(Resource):
for app in app_pagination.items:
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
- return marshal(app_pagination, app_pagination_model), 200
+ pagination_model = AppPagination.model_validate(app_pagination, from_attributes=True)
+ return pagination_model.model_dump(mode="json"), 200
@console_ns.doc("create_app")
@console_ns.doc(description="Create a new application")
@console_ns.expect(console_ns.models[CreateAppPayload.__name__])
- @console_ns.response(201, "App created successfully", app_detail_model)
+ @console_ns.response(201, "App created successfully", console_ns.models[AppDetail.__name__])
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
- @marshal_with(app_detail_model)
@cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self):
@@ -341,8 +542,8 @@ class AppListApi(Resource):
app_service = AppService()
app = app_service.create_app(current_tenant_id, args.model_dump(), current_user)
-
- return app, 201
+ app_detail = AppDetail.model_validate(app, from_attributes=True)
+ return app_detail.model_dump(mode="json"), 201
@console_ns.route("/apps/")
@@ -350,13 +551,12 @@ class AppApi(Resource):
@console_ns.doc("get_app_detail")
@console_ns.doc(description="Get application details")
@console_ns.doc(params={"app_id": "Application ID"})
- @console_ns.response(200, "Success", app_detail_with_site_model)
+ @console_ns.response(200, "Success", console_ns.models[AppDetailWithSite.__name__])
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
- @get_app_model
- @marshal_with(app_detail_with_site_model)
+ @get_app_model(mode=None)
def get(self, app_model):
"""Get app detail"""
app_service = AppService()
@@ -367,21 +567,21 @@ class AppApi(Resource):
app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id))
app_model.access_mode = app_setting.access_mode
- return app_model
+ response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True)
+ return response_model.model_dump(mode="json")
@console_ns.doc("update_app")
@console_ns.doc(description="Update application details")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[UpdateAppPayload.__name__])
- @console_ns.response(200, "App updated successfully", app_detail_with_site_model)
+ @console_ns.response(200, "App updated successfully", console_ns.models[AppDetailWithSite.__name__])
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
- @get_app_model
+ @get_app_model(mode=None)
@edit_permission_required
- @marshal_with(app_detail_with_site_model)
def put(self, app_model):
"""Update app"""
args = UpdateAppPayload.model_validate(console_ns.payload)
@@ -398,8 +598,8 @@ class AppApi(Resource):
"max_active_requests": args.max_active_requests or 0,
}
app_model = app_service.update_app(app_model, args_dict)
-
- return app_model
+ response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True)
+ return response_model.model_dump(mode="json")
@console_ns.doc("delete_app")
@console_ns.doc(description="Delete application")
@@ -425,14 +625,13 @@ class AppCopyApi(Resource):
@console_ns.doc(description="Create a copy of an existing application")
@console_ns.doc(params={"app_id": "Application ID to copy"})
@console_ns.expect(console_ns.models[CopyAppPayload.__name__])
- @console_ns.response(201, "App copied successfully", app_detail_with_site_model)
+ @console_ns.response(201, "App copied successfully", console_ns.models[AppDetailWithSite.__name__])
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
- @get_app_model
+ @get_app_model(mode=None)
@edit_permission_required
- @marshal_with(app_detail_with_site_model)
def post(self, app_model):
"""Copy app"""
# The role of the current user in the ta table must be admin, owner, or editor
@@ -458,7 +657,8 @@ class AppCopyApi(Resource):
stmt = select(App).where(App.id == result.app_id)
app = session.scalar(stmt)
- return app, 201
+ response_model = AppDetailWithSite.model_validate(app, from_attributes=True)
+ return response_model.model_dump(mode="json"), 201
@console_ns.route("/apps//export")
@@ -467,11 +667,7 @@ class AppExportApi(Resource):
@console_ns.doc(description="Export application configuration as DSL")
@console_ns.doc(params={"app_id": "Application ID to export"})
@console_ns.expect(console_ns.models[AppExportQuery.__name__])
- @console_ns.response(
- 200,
- "App exported successfully",
- console_ns.model("AppExportResponse", {"data": fields.String(description="DSL export data")}),
- )
+ @console_ns.response(200, "App exported successfully", console_ns.models[AppExportResponse.__name__])
@console_ns.response(403, "Insufficient permissions")
@get_app_model
@setup_required
@@ -482,13 +678,14 @@ class AppExportApi(Resource):
"""Export app"""
args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
- return {
- "data": AppDslService.export_dsl(
+ payload = AppExportResponse(
+ data=AppDslService.export_dsl(
app_model=app_model,
include_secret=args.include_secret,
workflow_id=args.workflow_id,
)
- }
+ )
+ return payload.model_dump(mode="json")
@console_ns.route("/apps//name")
@@ -497,20 +694,19 @@ class AppNameApi(Resource):
@console_ns.doc(description="Check if app name is available")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppNamePayload.__name__])
- @console_ns.response(200, "Name availability checked")
+ @console_ns.response(200, "Name availability checked", console_ns.models[AppDetail.__name__])
@setup_required
@login_required
@account_initialization_required
- @get_app_model
- @marshal_with(app_detail_model)
+ @get_app_model(mode=None)
@edit_permission_required
def post(self, app_model):
args = AppNamePayload.model_validate(console_ns.payload)
app_service = AppService()
app_model = app_service.update_app_name(app_model, args.name)
-
- return app_model
+ response_model = AppDetail.model_validate(app_model, from_attributes=True)
+ return response_model.model_dump(mode="json")
@console_ns.route("/apps//icon")
@@ -524,16 +720,15 @@ class AppIconApi(Resource):
@setup_required
@login_required
@account_initialization_required
- @get_app_model
- @marshal_with(app_detail_model)
+ @get_app_model(mode=None)
@edit_permission_required
def post(self, app_model):
args = AppIconPayload.model_validate(console_ns.payload or {})
app_service = AppService()
app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "")
-
- return app_model
+ response_model = AppDetail.model_validate(app_model, from_attributes=True)
+ return response_model.model_dump(mode="json")
@console_ns.route("/apps//site-enable")
@@ -542,21 +737,20 @@ class AppSiteStatus(Resource):
@console_ns.doc(description="Enable or disable app site")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppSiteStatusPayload.__name__])
- @console_ns.response(200, "Site status updated successfully", app_detail_model)
+ @console_ns.response(200, "Site status updated successfully", console_ns.models[AppDetail.__name__])
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
- @get_app_model
- @marshal_with(app_detail_model)
+ @get_app_model(mode=None)
@edit_permission_required
def post(self, app_model):
args = AppSiteStatusPayload.model_validate(console_ns.payload)
app_service = AppService()
app_model = app_service.update_app_site_status(app_model, args.enable_site)
-
- return app_model
+ response_model = AppDetail.model_validate(app_model, from_attributes=True)
+ return response_model.model_dump(mode="json")
@console_ns.route("/apps//api-enable")
@@ -565,21 +759,20 @@ class AppApiStatus(Resource):
@console_ns.doc(description="Enable or disable app API")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppApiStatusPayload.__name__])
- @console_ns.response(200, "API status updated successfully", app_detail_model)
+ @console_ns.response(200, "API status updated successfully", console_ns.models[AppDetail.__name__])
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
- @get_app_model
- @marshal_with(app_detail_model)
+ @get_app_model(mode=None)
def post(self, app_model):
args = AppApiStatusPayload.model_validate(console_ns.payload)
app_service = AppService()
app_model = app_service.update_app_api_status(app_model, args.enable_api)
-
- return app_model
+ response_model = AppDetail.model_validate(app_model, from_attributes=True)
+ return response_model.model_dump(mode="json")
@console_ns.route("/apps//trace")
diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py
index 22e2aeb720..fdef54ba5a 100644
--- a/api/controllers/console/app/app_import.py
+++ b/api/controllers/console/app/app_import.py
@@ -41,14 +41,14 @@ DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AppImportPayload(BaseModel):
mode: str = Field(..., description="Import mode")
- yaml_content: str | None = None
- yaml_url: str | None = None
- name: str | None = None
- description: str | None = None
- icon_type: str | None = None
- icon: str | None = None
- icon_background: str | None = None
- app_id: str | None = None
+ yaml_content: str | None = Field(None)
+ yaml_url: str | None = Field(None)
+ name: str | None = Field(None)
+ description: str | None = Field(None)
+ icon_type: str | None = Field(None)
+ icon: str | None = Field(None)
+ icon_background: str | None = Field(None)
+ app_id: str | None = Field(None)
console_ns.schema_model(
diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py
index c16dcfd91f..55fdcb51e4 100644
--- a/api/controllers/console/app/conversation.py
+++ b/api/controllers/console/app/conversation.py
@@ -13,7 +13,6 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
-from fields.conversation_fields import MessageTextField
from fields.raws import FilesContainedField
from libs.datetime_utils import naive_utc_now, parse_time_range
from libs.helper import TimestampField
@@ -177,6 +176,12 @@ annotation_hit_history_model = console_ns.model(
},
)
+
+class MessageTextField(fields.Raw):
+ def format(self, value):
+ return value[0]["text"] if value else ""
+
+
# Simple message detail model
simple_message_detail_model = console_ns.model(
"SimpleMessageDetail",
@@ -343,10 +348,13 @@ class CompletionConversationApi(Resource):
)
if args.keyword:
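+ # Escape SQL LIKE wildcards in the user keyword so "%" and "_" match literally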
+ from libs.helper import escape_like_pattern
+
+ escaped_keyword = escape_like_pattern(args.keyword)
query = query.join(Message, Message.conversation_id == Conversation.id).where(
or_(
- Message.query.ilike(f"%{args.keyword}%"),
- Message.answer.ilike(f"%{args.keyword}%"),
+ Message.query.ilike(f"%{escaped_keyword}%", escape="\\"),
+ Message.answer.ilike(f"%{escaped_keyword}%", escape="\\"),
)
)
@@ -455,7 +463,10 @@ class ChatConversationApi(Resource):
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
if args.keyword:
- keyword_filter = f"%{args.keyword}%"
+ from libs.helper import escape_like_pattern
+
+ escaped_keyword = escape_like_pattern(args.keyword)
+ keyword_filter = f"%{escaped_keyword}%"
query = (
query.join(
Message,
@@ -464,11 +475,11 @@ class ChatConversationApi(Resource):
.join(subquery, subquery.c.conversation_id == Conversation.id)
.where(
or_(
- Message.query.ilike(keyword_filter),
- Message.answer.ilike(keyword_filter),
- Conversation.name.ilike(keyword_filter),
- Conversation.introduction.ilike(keyword_filter),
- subquery.c.from_end_user_session_id.ilike(keyword_filter),
+ Message.query.ilike(keyword_filter, escape="\\"),
+ Message.answer.ilike(keyword_filter, escape="\\"),
+ Conversation.name.ilike(keyword_filter, escape="\\"),
+ Conversation.introduction.ilike(keyword_filter, escape="\\"),
+ subquery.c.from_end_user_session_id.ilike(keyword_filter, escape="\\"),
),
)
.group_by(Conversation.id)
@@ -581,9 +592,12 @@ def _get_conversation(app_model, conversation_id):
if not conversation:
raise NotFound("Conversation Not Exists.")
- if not conversation.read_at:
- conversation.read_at = naive_utc_now()
- conversation.read_account_id = current_user.id
- db.session.commit()
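+ # Mark the conversation as read with a conditional UPDATE (only while read_at is still NULL), then refresh the ORM object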
+ db.session.execute(
+ sa.update(Conversation)
+ .where(Conversation.id == conversation_id, Conversation.read_at.is_(None))
+ .values(read_at=naive_utc_now(), read_account_id=current_user.id)
+ )
+ db.session.commit()
+ db.session.refresh(conversation)
return conversation
diff --git a/api/controllers/console/app/error.py b/api/controllers/console/app/error.py
index fbd7901646..3fa15d6d6d 100644
--- a/api/controllers/console/app/error.py
+++ b/api/controllers/console/app/error.py
@@ -82,13 +82,13 @@ class ProviderNotSupportSpeechToTextError(BaseHTTPException):
class DraftWorkflowNotExist(BaseHTTPException):
error_code = "draft_workflow_not_exist"
description = "Draft workflow need to be initialized."
- code = 400
+ code = 404
class DraftWorkflowNotSync(BaseHTTPException):
error_code = "draft_workflow_not_sync"
description = "Workflow graph might have been modified, please refresh and resubmit."
- code = 400
+ code = 409
class TracingConfigNotExist(BaseHTTPException):
@@ -115,3 +115,9 @@ class InvokeRateLimitError(BaseHTTPException):
error_code = "rate_limit_error"
description = "Rate Limit Error"
code = 429
+
+
+class NeedAddIdsError(BaseHTTPException):
+ error_code = "need_add_ids"
+ description = "Need to add ids."
+ code = 400
diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py
index b4fc44767a..1ac55b5e8d 100644
--- a/api/controllers/console/app/generator.py
+++ b/api/controllers/console/app/generator.py
@@ -1,5 +1,4 @@
from collections.abc import Sequence
-from typing import Any
from flask_restx import Resource
from pydantic import BaseModel, Field
@@ -12,10 +11,12 @@ from controllers.console.app.error import (
ProviderQuotaExceededError,
)
from controllers.console.wraps import account_initialization_required, setup_required
+from core.app.app_config.entities import ModelConfig
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
+from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
@@ -26,28 +27,13 @@ from services.workflow_service import WorkflowService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
-class RuleGeneratePayload(BaseModel):
- instruction: str = Field(..., description="Rule generation instruction")
- model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
- no_variable: bool = Field(default=False, description="Whether to exclude variables")
-
-
-class RuleCodeGeneratePayload(RuleGeneratePayload):
- code_language: str = Field(default="javascript", description="Programming language for code generation")
-
-
-class RuleStructuredOutputPayload(BaseModel):
- instruction: str = Field(..., description="Structured output generation instruction")
- model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
-
-
class InstructionGeneratePayload(BaseModel):
flow_id: str = Field(..., description="Workflow/Flow ID")
node_id: str = Field(default="", description="Node ID for workflow context")
current: str = Field(default="", description="Current instruction text")
language: str = Field(default="javascript", description="Programming language (javascript/python)")
instruction: str = Field(..., description="Instruction for generation")
- model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
+ model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
ideal_output: str = Field(default="", description="Expected ideal output")
@@ -64,6 +50,7 @@ reg(RuleCodeGeneratePayload)
reg(RuleStructuredOutputPayload)
reg(InstructionGeneratePayload)
reg(InstructionTemplatePayload)
+reg(ModelConfig)
@console_ns.route("/rule-generate")
@@ -82,12 +69,7 @@ class RuleGenerateApi(Resource):
_, current_tenant_id = current_account_with_tenant()
try:
- rules = LLMGenerator.generate_rule_config(
- tenant_id=current_tenant_id,
- instruction=args.instruction,
- model_config=args.model_config_data,
- no_variable=args.no_variable,
- )
+ rules = LLMGenerator.generate_rule_config(tenant_id=current_tenant_id, args=args)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
@@ -118,9 +100,7 @@ class RuleCodeGenerateApi(Resource):
try:
code_result = LLMGenerator.generate_code(
tenant_id=current_tenant_id,
- instruction=args.instruction,
- model_config=args.model_config_data,
- code_language=args.code_language,
+ args=args,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -152,8 +132,7 @@ class RuleStructuredOutputGenerateApi(Resource):
try:
structured_output = LLMGenerator.generate_structured_output(
tenant_id=current_tenant_id,
- instruction=args.instruction,
- model_config=args.model_config_data,
+ args=args,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -204,23 +183,29 @@ class InstructionGenerateApi(Resource):
case "llm":
return LLMGenerator.generate_rule_config(
current_tenant_id,
- instruction=args.instruction,
- model_config=args.model_config_data,
- no_variable=True,
+ args=RuleGeneratePayload(
+ instruction=args.instruction,
+ model_config=args.model_config_data,
+ no_variable=True,
+ ),
)
case "agent":
return LLMGenerator.generate_rule_config(
current_tenant_id,
- instruction=args.instruction,
- model_config=args.model_config_data,
- no_variable=True,
+ args=RuleGeneratePayload(
+ instruction=args.instruction,
+ model_config=args.model_config_data,
+ no_variable=True,
+ ),
)
case "code":
return LLMGenerator.generate_code(
tenant_id=current_tenant_id,
- instruction=args.instruction,
- model_config=args.model_config_data,
- code_language=args.language,
+ args=RuleCodeGeneratePayload(
+ instruction=args.instruction,
+ model_config=args.model_config_data,
+ code_language=args.language,
+ ),
)
case _:
return {"error": f"invalid node type: {node_type}"}
diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py
index b4f2ef0ba8..755463cb70 100644
--- a/api/controllers/console/app/workflow.py
+++ b/api/controllers/console/app/workflow.py
@@ -12,6 +12,7 @@ from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.console import console_ns
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
+from controllers.console.app.workflow_run import workflow_run_node_execution_model
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
@@ -35,7 +36,6 @@ from extensions.ext_database import db
from factories import file_factory, variable_factory
from fields.member_fields import simple_account_fields
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
-from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.helper import TimestampField, uuid_value
@@ -88,26 +88,6 @@ workflow_pagination_fields_copy = workflow_pagination_fields.copy()
workflow_pagination_fields_copy["items"] = fields.List(fields.Nested(workflow_model), attribute="items")
workflow_pagination_model = console_ns.model("WorkflowPagination", workflow_pagination_fields_copy)
-# Reuse workflow_run_node_execution_model from workflow_run.py if already registered
-# Otherwise register it here
-from fields.end_user_fields import simple_end_user_fields
-
-simple_end_user_model = None
-try:
- simple_end_user_model = console_ns.models.get("SimpleEndUser")
-except AttributeError:
- pass
-if simple_end_user_model is None:
- simple_end_user_model = console_ns.model("SimpleEndUser", simple_end_user_fields)
-
-workflow_run_node_execution_model = None
-try:
- workflow_run_node_execution_model = console_ns.models.get("WorkflowRunNodeExecution")
-except AttributeError:
- pass
-if workflow_run_node_execution_model is None:
- workflow_run_node_execution_model = console_ns.model("WorkflowRunNodeExecution", workflow_run_node_execution_fields)
-
class SyncDraftWorkflowPayload(BaseModel):
graph: dict[str, Any]
@@ -470,7 +450,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
Run draft workflow loop node
"""
current_user, _ = current_account_with_tenant()
- args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
+ args = LoopNodeRunPayload.model_validate(console_ns.payload or {})
try:
response = AppGenerateService.generate_single_loop(
@@ -508,7 +488,7 @@ class WorkflowDraftRunLoopNodeApi(Resource):
Run draft workflow loop node
"""
current_user, _ = current_account_with_tenant()
- args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
+ args = LoopNodeRunPayload.model_validate(console_ns.payload or {})
try:
response = AppGenerateService.generate_single_loop(
@@ -999,6 +979,7 @@ class DraftWorkflowTriggerRunApi(Resource):
if not event:
return jsonable_encoder({"status": "waiting", "retry_in": LISTENING_RETRY_IN})
workflow_args = dict(event.workflow_args)
+
workflow_args[SKIP_PREPARE_USER_INPUTS_KEY] = True
return helper.compact_generate_response(
AppGenerateService.generate(
@@ -1147,6 +1128,7 @@ class DraftWorkflowTriggerRunAllApi(Resource):
try:
workflow_args = dict(trigger_debug_event.workflow_args)
+
workflow_args[SKIP_PREPARE_USER_INPUTS_KEY] = True
response = AppGenerateService.generate(
app_model=app_model,
diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py
index fa67fb8154..6736f24a2e 100644
--- a/api/controllers/console/app/workflow_app_log.py
+++ b/api/controllers/console/app/workflow_app_log.py
@@ -11,7 +11,10 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.workflow.enums import WorkflowExecutionStatus
from extensions.ext_database import db
-from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
+from fields.workflow_app_log_fields import (
+ build_workflow_app_log_pagination_model,
+ build_workflow_archived_log_pagination_model,
+)
from libs.login import login_required
from models import App
from models.model import AppMode
@@ -61,6 +64,7 @@ console_ns.schema_model(
# Register model for flask_restx to avoid dict type issues in Swagger
workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)
+workflow_archived_log_pagination_model = build_workflow_archived_log_pagination_model(console_ns)
@console_ns.route("/apps/<uuid:app_id>/workflow-app-logs")
@@ -99,3 +103,33 @@ class WorkflowAppLogApi(Resource):
)
return workflow_app_log_pagination
+
+
+@console_ns.route("/apps/<uuid:app_id>/workflow-archived-logs")
+class WorkflowArchivedLogApi(Resource):
+ @console_ns.doc("get_workflow_archived_logs")
+ @console_ns.doc(description="Get workflow archived execution logs")
+ @console_ns.doc(params={"app_id": "Application ID"})
+ @console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
+ @console_ns.response(200, "Workflow archived logs retrieved successfully", workflow_archived_log_pagination_model)
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @get_app_model(mode=[AppMode.WORKFLOW])
+ @marshal_with(workflow_archived_log_pagination_model)
+ def get(self, app_model: App):
+ """
+ Get workflow archived logs
+ """
+ args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
+
+ workflow_app_service = WorkflowAppService()
+ with Session(db.engine) as session:
+ workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_archive_logs(
+ session=session,
+ app_model=app_model,
+ page=args.page,
+ limit=args.limit,
+ )
+
+ return workflow_app_log_pagination
diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py
index 8f1871f1e9..fa74f8aea1 100644
--- a/api/controllers/console/app/workflow_run.py
+++ b/api/controllers/console/app/workflow_run.py
@@ -1,12 +1,15 @@
+from datetime import UTC, datetime, timedelta
from typing import Literal, cast
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
+from sqlalchemy import select
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
+from extensions.ext_database import db
from fields.end_user_fields import simple_end_user_fields
from fields.member_fields import simple_account_fields
from fields.workflow_run_fields import (
@@ -19,14 +22,17 @@ from fields.workflow_run_fields import (
workflow_run_node_execution_list_fields,
workflow_run_pagination_fields,
)
+from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage
from libs.custom_inputs import time_duration
from libs.helper import uuid_value
from libs.login import current_user, login_required
-from models import Account, App, AppMode, EndUser, WorkflowRunTriggeredFrom
+from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowRunTriggeredFrom
+from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
from services.workflow_run_service import WorkflowRunService
# Workflow run status choices for filtering
WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
+EXPORT_SIGNED_URL_EXPIRE_SECONDS = 3600
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
@@ -93,6 +99,15 @@ workflow_run_node_execution_list_model = console_ns.model(
"WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy
)
+workflow_run_export_fields = console_ns.model(
+ "WorkflowRunExport",
+ {
+ "status": fields.String(description="Export status: success/failed"),
+ "presigned_url": fields.String(description="Pre-signed URL for download", required=False),
+ "presigned_url_expires_at": fields.String(description="Pre-signed URL expiration time", required=False),
+ },
+)
+
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@@ -181,6 +196,56 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
return result
+@console_ns.route("/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/export")
+class WorkflowRunExportApi(Resource):
+ @console_ns.doc("get_workflow_run_export_url")
+ @console_ns.doc(description="Generate a download URL for an archived workflow run.")
+ @console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
+ @console_ns.response(200, "Export URL generated", workflow_run_export_fields)
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @get_app_model()
+ def get(self, app_model: App, run_id: str):
+ tenant_id = str(app_model.tenant_id)
+ app_id = str(app_model.id)
+ run_id_str = str(run_id)
+
+ run_created_at = db.session.scalar(
+ select(WorkflowArchiveLog.run_created_at)
+ .where(
+ WorkflowArchiveLog.tenant_id == tenant_id,
+ WorkflowArchiveLog.app_id == app_id,
+ WorkflowArchiveLog.workflow_run_id == run_id_str,
+ )
+ .limit(1)
+ )
+ if not run_created_at:
+ return {"code": "archive_log_not_found", "message": "workflow run archive not found"}, 404
+
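+ # Archive bundles are stored under a tenant/app/year/month/run-id prefix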
+ prefix = (
+ f"{tenant_id}/app_id={app_id}/year={run_created_at.strftime('%Y')}/"
+ f"month={run_created_at.strftime('%m')}/workflow_run_id={run_id_str}"
+ )
+ archive_key = f"{prefix}/{ARCHIVE_BUNDLE_NAME}"
+
+ try:
+ archive_storage = get_archive_storage()
+ except ArchiveStorageNotConfiguredError as e:
+ return {"code": "archive_storage_not_configured", "message": str(e)}, 500
+
+ presigned_url = archive_storage.generate_presigned_url(
+ archive_key,
+ expires_in=EXPORT_SIGNED_URL_EXPIRE_SECONDS,
+ )
+ expires_at = datetime.now(UTC) + timedelta(seconds=EXPORT_SIGNED_URL_EXPIRE_SECONDS)
+ return {
+ "status": "success",
+ "presigned_url": presigned_url,
+ "presigned_url_expires_at": expires_at.isoformat(),
+ }, 200
+
+
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs/count")
class AdvancedChatAppWorkflowRunCountApi(Resource):
@console_ns.doc("get_advanced_chat_workflow_runs_count")
diff --git a/api/controllers/console/app/workflow_trigger.py b/api/controllers/console/app/workflow_trigger.py
index 5d16e4f979..8236e766ae 100644
--- a/api/controllers/console/app/workflow_trigger.py
+++ b/api/controllers/console/app/workflow_trigger.py
@@ -1,13 +1,14 @@
import logging
from flask import request
-from flask_restx import Resource, marshal_with
+from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from configs import dify_config
+from controllers.common.schema import get_or_create_model
from extensions.ext_database import db
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
from libs.login import current_user, login_required
@@ -22,6 +23,14 @@ from ..wraps import account_initialization_required, edit_permission_required, s
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+trigger_model = get_or_create_model("WorkflowTrigger", trigger_fields)
+
+triggers_list_fields_copy = triggers_list_fields.copy()
+triggers_list_fields_copy["data"] = fields.List(fields.Nested(trigger_model))
+triggers_list_model = get_or_create_model("WorkflowTriggerList", triggers_list_fields_copy)
+
+webhook_trigger_model = get_or_create_model("WebhookTrigger", webhook_trigger_fields)
+
class Parser(BaseModel):
node_id: str
@@ -48,7 +57,7 @@ class WebhookTriggerApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
- @marshal_with(webhook_trigger_fields)
+ @marshal_with(webhook_trigger_model)
def get(self, app_model: App):
"""Get webhook trigger for a node"""
args = Parser.model_validate(request.args.to_dict(flat=True)) # type: ignore
@@ -80,7 +89,7 @@ class AppTriggersApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
- @marshal_with(triggers_list_fields)
+ @marshal_with(triggers_list_model)
def get(self, app_model: App):
"""Get app triggers list"""
assert isinstance(current_user, Account)
@@ -114,13 +123,13 @@ class AppTriggersApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/trigger-enable")
class AppTriggerEnableApi(Resource):
- @console_ns.expect(console_ns.models[ParserEnable.__name__], validate=True)
+ @console_ns.expect(console_ns.models[ParserEnable.__name__])
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@get_app_model(mode=AppMode.WORKFLOW)
- @marshal_with(trigger_fields)
+ @marshal_with(trigger_model)
def post(self, app_model: App):
"""Update app trigger (enable/disable)"""
args = ParserEnable.model_validate(console_ns.payload)
diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py
index 9bb2718f89..e687d980fa 100644
--- a/api/controllers/console/app/wraps.py
+++ b/api/controllers/console/app/wraps.py
@@ -23,6 +23,11 @@ def _load_app_model(app_id: str) -> App | None:
return app_model
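+# Load an app by id and "normal" status for trial access (no tenant scoping applied here)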
+def _load_app_model_with_trial(app_id: str) -> App | None:
+ app_model = db.session.query(App).where(App.id == app_id, App.status == "normal").first()
+ return app_model
+
+
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func: Callable[P1, R1]):
@wraps(view_func)
@@ -62,3 +67,44 @@ def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, li
return decorator
else:
return decorator(view)
+
+
+def get_app_model_with_trial(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
+ def decorator(view_func: Callable[P, R]):
+ @wraps(view_func)
+ def decorated_view(*args: P.args, **kwargs: P.kwargs):
+ if not kwargs.get("app_id"):
+ raise ValueError("missing app_id in path parameters")
+
+ app_id = kwargs.get("app_id")
+ app_id = str(app_id)
+
+ del kwargs["app_id"]
+
+ app_model = _load_app_model_with_trial(app_id)
+
+ if not app_model:
+ raise AppNotFoundError()
+
+ app_mode = AppMode.value_of(app_model.mode)
+
+ if mode is not None:
+ if isinstance(mode, list):
+ modes = mode
+ else:
+ modes = [mode]
+
+ if app_mode not in modes:
+ mode_values = {m.value for m in modes}
+ raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}")
+
+ kwargs["app_model"] = app_model
+
+ return view_func(*args, **kwargs)
+
+ return decorated_view
+
+ if view is None:
+ return decorator
+ else:
+ return decorator(view)
diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py
index 6834656a7f..f741107b87 100644
--- a/api/controllers/console/auth/activate.py
+++ b/api/controllers/console/auth/activate.py
@@ -7,9 +7,9 @@ from controllers.console import console_ns
from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
-from libs.helper import EmailStr, extract_remote_ip, timezone
+from libs.helper import EmailStr, timezone
from models import AccountStatus
-from services.account_service import AccountService, RegisterService
+from services.account_service import RegisterService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@@ -63,13 +63,19 @@ class ActivateCheckApi(Resource):
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
workspaceId = args.workspace_id
- reg_email = args.email
token = args.token
- invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
+ invitation = RegisterService.get_invitation_with_case_fallback(workspaceId, args.email, token)
if invitation:
data = invitation.get("data", {})
tenant = invitation.get("tenant", None)
+
+ # Check workspace permission
+ if tenant:
+ from libs.workspace_permission import check_workspace_member_invite_permission
+
+ check_workspace_member_invite_permission(tenant.id)
+
workspace_name = tenant.name if tenant else None
workspace_id = tenant.id if tenant else None
invitee_email = data.get("email") if data else None
@@ -93,7 +99,6 @@ class ActivateApi(Resource):
"ActivationResponse",
{
"result": fields.String(description="Operation result"),
- "data": fields.Raw(description="Login token data"),
},
),
)
@@ -101,11 +106,12 @@ class ActivateApi(Resource):
def post(self):
args = ActivatePayload.model_validate(console_ns.payload)
- invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token)
+ normalized_request_email = args.email.lower() if args.email else None
+ invitation = RegisterService.get_invitation_with_case_fallback(args.workspace_id, args.email, args.token)
if invitation is None:
raise AlreadyActivateError()
- RegisterService.revoke_token(args.workspace_id, args.email, args.token)
+ RegisterService.revoke_token(args.workspace_id, normalized_request_email, args.token)
account = invitation["account"]
account.name = args.name
@@ -117,6 +123,4 @@ class ActivateApi(Resource):
account.initialized_at = naive_utc_now()
db.session.commit()
- token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
-
- return {"result": "success", "data": token_pair.model_dump()}
+ return {"result": "success"}
diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py
index fa082c735d..c2a95ddad2 100644
--- a/api/controllers/console/auth/email_register.py
+++ b/api/controllers/console/auth/email_register.py
@@ -1,7 +1,6 @@
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
-from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
@@ -62,6 +61,7 @@ class EmailRegisterSendEmailApi(Resource):
@email_register_enabled
def post(self):
args = EmailRegisterSendPayload.model_validate(console_ns.payload)
+ normalized_email = args.email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
@@ -70,13 +70,12 @@ class EmailRegisterSendEmailApi(Resource):
if args.language in languages:
language = args.language
- if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
+ if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountInFreezeError()
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
- token = None
- token = AccountService.send_email_register_email(email=args.email, account=account, language=language)
+ account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
+ token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
return {"result": "success", "data": token}
@@ -88,9 +87,9 @@ class EmailRegisterCheckApi(Resource):
def post(self):
args = EmailRegisterValidityPayload.model_validate(console_ns.payload)
- user_email = args.email
+ user_email = args.email.lower()
- is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email)
+ is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(user_email)
if is_email_register_error_rate_limit:
raise EmailRegisterLimitError()
@@ -98,11 +97,14 @@ class EmailRegisterCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
- if user_email != token_data.get("email"):
+ token_email = token_data.get("email")
+ normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
+
+ if user_email != normalized_token_email:
raise InvalidEmailError()
if args.code != token_data.get("code"):
- AccountService.add_email_register_error_rate_limit(args.email)
+ AccountService.add_email_register_error_rate_limit(user_email)
raise EmailCodeError()
# Verified, revoke the first token
@@ -113,8 +115,8 @@ class EmailRegisterCheckApi(Resource):
user_email, code=args.code, additional_data={"phase": "register"}
)
- AccountService.reset_email_register_error_rate_limit(args.email)
- return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
+ AccountService.reset_email_register_error_rate_limit(user_email)
+ return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@console_ns.route("/email-register")
@@ -141,22 +143,23 @@ class EmailRegisterResetApi(Resource):
AccountService.revoke_email_register_token(args.token)
email = register_data.get("email", "")
+ normalized_email = email.lower()
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account:
raise EmailAlreadyInUseError()
else:
- account = self._create_new_account(email, args.password_confirm)
+ account = self._create_new_account(normalized_email, args.password_confirm)
if not account:
raise AccountNotFoundError()
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
- AccountService.reset_login_error_rate_limit(email)
+ AccountService.reset_login_error_rate_limit(normalized_email)
return {"result": "success", "data": token_pair.model_dump()}
- def _create_new_account(self, email, password) -> Account | None:
+ def _create_new_account(self, email: str, password: str) -> Account | None:
# Create new account if allowed
account = None
try:
diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py
index 661f591182..394f205d93 100644
--- a/api/controllers/console/auth/forgot_password.py
+++ b/api/controllers/console/auth/forgot_password.py
@@ -4,7 +4,6 @@ import secrets
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
-from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.console import console_ns
@@ -21,7 +20,6 @@ from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password
-from models import Account
from services.account_service import AccountService, TenantService
from services.feature_service import FeatureService
@@ -76,6 +74,7 @@ class ForgotPasswordSendEmailApi(Resource):
@email_password_login_enabled
def post(self):
args = ForgotPasswordSendPayload.model_validate(console_ns.payload)
+ normalized_email = args.email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
@@ -87,11 +86,11 @@ class ForgotPasswordSendEmailApi(Resource):
language = "en-US"
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
token = AccountService.send_reset_password_email(
account=account,
- email=args.email,
+ email=normalized_email,
language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register,
)
@@ -122,9 +121,9 @@ class ForgotPasswordCheckApi(Resource):
def post(self):
args = ForgotPasswordCheckPayload.model_validate(console_ns.payload)
- user_email = args.email
+ user_email = args.email.lower()
- is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email)
+ is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email)
if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError()
@@ -132,11 +131,16 @@ class ForgotPasswordCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
- if user_email != token_data.get("email"):
+ token_email = token_data.get("email")
+ if not isinstance(token_email, str):
+ raise InvalidEmailError()
+ normalized_token_email = token_email.lower()
+
+ if user_email != normalized_token_email:
raise InvalidEmailError()
if args.code != token_data.get("code"):
- AccountService.add_forgot_password_error_rate_limit(args.email)
+ AccountService.add_forgot_password_error_rate_limit(user_email)
raise EmailCodeError()
# Verified, revoke the first token
@@ -144,11 +148,11 @@ class ForgotPasswordCheckApi(Resource):
# Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token(
- user_email, code=args.code, additional_data={"phase": "reset"}
+ token_email, code=args.code, additional_data={"phase": "reset"}
)
- AccountService.reset_forgot_password_error_rate_limit(args.email)
- return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
+ AccountService.reset_forgot_password_error_rate_limit(user_email)
+ return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@console_ns.route("/forgot-password/resets")
@@ -187,9 +191,8 @@ class ForgotPasswordResetApi(Resource):
password_hashed = hash_password(args.new_password, salt)
email = reset_data.get("email", "")
-
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account:
self._update_existing_account(account, password_hashed, salt, session)
diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py
index f486f4c313..400df138b8 100644
--- a/api/controllers/console/auth/login.py
+++ b/api/controllers/console/auth/login.py
@@ -1,3 +1,5 @@
+from typing import Any
+
import flask_login
from flask import make_response, request
from flask_restx import Resource
@@ -22,7 +24,12 @@ from controllers.console.error import (
NotAllowedCreateWorkspace,
WorkspacesLimitExceeded,
)
-from controllers.console.wraps import email_password_login_enabled, setup_required
+from controllers.console.wraps import (
+ decrypt_code_field,
+ decrypt_password_field,
+ email_password_login_enabled,
+ setup_required,
+)
from events.tenant_event import tenant_was_created
from libs.helper import EmailStr, extract_remote_ip
from libs.login import current_account_with_tenant
@@ -79,36 +86,42 @@ class LoginApi(Resource):
@setup_required
@email_password_login_enabled
@console_ns.expect(console_ns.models[LoginPayload.__name__])
+ @decrypt_password_field
def post(self):
"""Authenticate user and login."""
args = LoginPayload.model_validate(console_ns.payload)
+ request_email = args.email
+ normalized_email = request_email.lower()
- if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
+ if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountInFreezeError()
- is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email)
+ is_login_error_rate_limit = AccountService.is_login_error_rate_limit(normalized_email)
if is_login_error_rate_limit:
raise EmailPasswordLoginLimitError()
- # TODO: why invitation is re-assigned with different type?
- invitation = args.invite_token # type: ignore
- if invitation:
- invitation = RegisterService.get_invitation_if_token_valid(None, args.email, invitation) # type: ignore
+ invite_token = args.invite_token
+ invitation_data: dict[str, Any] | None = None
+ if invite_token:
+ invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token)
+ if invitation_data is None:
+ invite_token = None
try:
- if invitation:
- data = invitation.get("data", {}) # type: ignore
+ if invitation_data:
+ data = invitation_data.get("data", {})
invitee_email = data.get("email") if data else None
- if invitee_email != args.email:
+ invitee_email_normalized = invitee_email.lower() if isinstance(invitee_email, str) else invitee_email
+ if invitee_email_normalized != normalized_email:
raise InvalidEmailError()
- account = AccountService.authenticate(args.email, args.password, args.invite_token)
- else:
- account = AccountService.authenticate(args.email, args.password)
+ account = _authenticate_account_with_case_fallback(
+ request_email, normalized_email, args.password, invite_token
+ )
except services.errors.account.AccountLoginError:
raise AccountBannedError()
- except services.errors.account.AccountPasswordError:
- AccountService.add_login_error_rate_limit(args.email)
- raise AuthenticationFailedError()
+ except services.errors.account.AccountPasswordError as exc:
+ AccountService.add_login_error_rate_limit(normalized_email)
+ raise AuthenticationFailedError() from exc
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
if len(tenants) == 0:
@@ -123,7 +136,7 @@ class LoginApi(Resource):
}
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
- AccountService.reset_login_error_rate_limit(args.email)
+ AccountService.reset_login_error_rate_limit(normalized_email)
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
@@ -163,18 +176,19 @@ class ResetPasswordSendEmailApi(Resource):
@console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self):
args = EmailPayload.model_validate(console_ns.payload)
+ normalized_email = args.email.lower()
if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
try:
- account = AccountService.get_user_through_email(args.email)
+ account = _get_account_with_case_fallback(args.email)
except AccountRegisterError:
raise AccountInFreezeError()
token = AccountService.send_reset_password_email(
- email=args.email,
+ email=normalized_email,
account=account,
language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register,
@@ -189,6 +203,7 @@ class EmailCodeLoginSendEmailApi(Resource):
@console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self):
args = EmailPayload.model_validate(console_ns.payload)
+ normalized_email = args.email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
@@ -199,13 +214,13 @@ class EmailCodeLoginSendEmailApi(Resource):
else:
language = "en-US"
try:
- account = AccountService.get_user_through_email(args.email)
+ account = _get_account_with_case_fallback(args.email)
except AccountRegisterError:
raise AccountInFreezeError()
if account is None:
if FeatureService.get_system_features().is_allow_register:
- token = AccountService.send_email_code_login_email(email=args.email, language=language)
+ token = AccountService.send_email_code_login_email(email=normalized_email, language=language)
else:
raise AccountNotFound()
else:
@@ -218,17 +233,21 @@ class EmailCodeLoginSendEmailApi(Resource):
class EmailCodeLoginApi(Resource):
@setup_required
@console_ns.expect(console_ns.models[EmailCodeLoginPayload.__name__])
+ @decrypt_code_field
def post(self):
args = EmailCodeLoginPayload.model_validate(console_ns.payload)
- user_email = args.email
+ original_email = args.email
+ user_email = original_email.lower()
language = args.language
token_data = AccountService.get_email_code_login_data(args.token)
if token_data is None:
raise InvalidTokenError()
- if token_data["email"] != args.email:
+ token_email = token_data.get("email")
+ normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
+ if normalized_token_email != user_email:
raise InvalidEmailError()
if token_data["code"] != args.code:
@@ -236,7 +255,7 @@ class EmailCodeLoginApi(Resource):
AccountService.revoke_email_code_login_token(args.token)
try:
- account = AccountService.get_user_through_email(user_email)
+ account = _get_account_with_case_fallback(original_email)
except AccountRegisterError:
raise AccountInFreezeError()
if account:
@@ -267,7 +286,7 @@ class EmailCodeLoginApi(Resource):
except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
- AccountService.reset_login_error_rate_limit(args.email)
+ AccountService.reset_login_error_rate_limit(user_email)
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
@@ -301,3 +320,22 @@ class RefreshTokenApi(Resource):
return response
except Exception as e:
return {"result": "fail", "message": str(e)}, 401
+
+
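+# Look up the account with the email exactly as provided, then fall back to the lowercased form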
+def _get_account_with_case_fallback(email: str):
+ account = AccountService.get_user_through_email(email)
+ if account or email == email.lower():
+ return account
+
+ return AccountService.get_user_through_email(email.lower())
+
+
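+# Authenticate with the email as provided; on a password mismatch, retry once with the normalized (lowercased) email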
+def _authenticate_account_with_case_fallback(
+ original_email: str, normalized_email: str, password: str, invite_token: str | None
+):
+ try:
+ return AccountService.authenticate(original_email, password, invite_token)
+ except services.errors.account.AccountPasswordError:
+ if original_email == normalized_email:
+ raise
+ return AccountService.authenticate(normalized_email, password, invite_token)
diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py
index 7ad1e56373..112e152432 100644
--- a/api/controllers/console/auth/oauth.py
+++ b/api/controllers/console/auth/oauth.py
@@ -3,7 +3,6 @@ import logging
import httpx
from flask import current_app, redirect, request
from flask_restx import Resource
-from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized
@@ -118,13 +117,16 @@ class OAuthCallback(Resource):
invitation = RegisterService.get_invitation_by_token(token=invite_token)
if invitation:
invitation_email = invitation.get("email", None)
- if invitation_email != user_info.email:
+ invitation_email_normalized = (
+ invitation_email.lower() if isinstance(invitation_email, str) else invitation_email
+ )
+ if invitation_email_normalized != user_info.email.lower():
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.")
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
try:
- account = _generate_account(provider, user_info)
+ account, oauth_new_user = _generate_account(provider, user_info)
except AccountNotFoundError:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.")
except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError):
@@ -159,7 +161,10 @@ class OAuthCallback(Resource):
ip_address=extract_remote_ip(request),
)
- response = redirect(f"{dify_config.CONSOLE_WEB_URL}")
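+ # Tell the console whether this OAuth login just created a new account via an oauth_new_user query parameter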
+ base_url = dify_config.CONSOLE_WEB_URL
+ query_char = "&" if "?" in base_url else "?"
+ target_url = f"{base_url}{query_char}oauth_new_user={str(oauth_new_user).lower()}"
+ response = redirect(target_url)
set_access_token_to_cookie(request, response, token_pair.access_token)
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
@@ -172,14 +177,15 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
if not account:
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)
return account
-def _generate_account(provider: str, user_info: OAuthUserInfo):
+def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account, bool]:
# Get account by openid or email.
account = _get_account_by_openid_or_email(provider, user_info)
+ oauth_new_user = False
if account:
tenants = TenantService.get_join_tenants(account)
@@ -193,8 +199,10 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
tenant_was_created.send(new_tenant)
if not account:
+ normalized_email = user_info.email.lower()
+ oauth_new_user = True
if not FeatureService.get_system_features().is_allow_register:
- if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email):
+ if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountRegisterError(
description=(
"This email account has been deleted within the past "
@@ -205,7 +213,11 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
raise AccountRegisterError(description=("Invalid email or password"))
account_name = user_info.name or "Dify"
account = RegisterService.register(
- email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
+ email=normalized_email,
+ name=account_name,
+ password=None,
+ open_id=user_info.id,
+ provider=provider,
)
# Set interface language
@@ -220,4 +232,4 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
# Link account
AccountService.link_account_integrate(provider, user_info.id, account)
- return account
+ return account, oauth_new_user
diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py
index 7f907dc420..ac039f9c5d 100644
--- a/api/controllers/console/billing/billing.py
+++ b/api/controllers/console/billing/billing.py
@@ -1,8 +1,9 @@
import base64
+from typing import Literal
from flask import request
from flask_restx import Resource, fields
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field
from werkzeug.exceptions import BadRequest
from controllers.console import console_ns
@@ -15,22 +16,8 @@ DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class SubscriptionQuery(BaseModel):
- plan: str = Field(..., description="Subscription plan")
- interval: str = Field(..., description="Billing interval")
-
- @field_validator("plan")
- @classmethod
- def validate_plan(cls, value: str) -> str:
- if value not in [CloudPlan.PROFESSIONAL, CloudPlan.TEAM]:
- raise ValueError("Invalid plan")
- return value
-
- @field_validator("interval")
- @classmethod
- def validate_interval(cls, value: str) -> str:
- if value not in {"month", "year"}:
- raise ValueError("Invalid interval")
- return value
+ plan: Literal[CloudPlan.PROFESSIONAL, CloudPlan.TEAM] = Field(..., description="Subscription plan")
+ interval: Literal["month", "year"] = Field(..., description="Billing interval")
class PartnerTenantsPayload(BaseModel):
diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py
index 01f268d94d..01e9bf77c0 100644
--- a/api/controllers/console/datasets/data_source.py
+++ b/api/controllers/console/datasets/data_source.py
@@ -3,13 +3,13 @@ from collections.abc import Generator
from typing import Any, cast
from flask import request
-from flask_restx import Resource, marshal_with
+from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
-from controllers.common.schema import register_schema_model
+from controllers.common.schema import get_or_create_model, register_schema_model
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.indexing_runner import IndexingRunner
@@ -17,7 +17,14 @@ from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
-from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
+from fields.data_source_fields import (
+ integrate_fields,
+ integrate_icon_fields,
+ integrate_list_fields,
+ integrate_notion_info_list_fields,
+ integrate_page_fields,
+ integrate_workspace_fields,
+)
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import DataSourceOauthBinding, Document
@@ -36,9 +43,62 @@ class NotionEstimatePayload(BaseModel):
doc_language: str = Field(default="English")
+class DataSourceNotionListQuery(BaseModel):
+ dataset_id: str | None = Field(default=None, description="Dataset ID")
+ credential_id: str = Field(..., description="Credential ID", min_length=1)
+ datasource_parameters: dict[str, Any] | None = Field(default=None, description="Datasource parameters JSON string")
+
+
+class DataSourceNotionPreviewQuery(BaseModel):
+ credential_id: str = Field(..., description="Credential ID", min_length=1)
+
+
register_schema_model(console_ns, NotionEstimatePayload)
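+# Build nested flask_restx models (icon -> page -> workspace -> integrate list, plus Notion pages) for response marshalling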
+integrate_icon_model = get_or_create_model("DataSourceIntegrateIcon", integrate_icon_fields)
+
+integrate_page_fields_copy = integrate_page_fields.copy()
+integrate_page_fields_copy["page_icon"] = fields.Nested(integrate_icon_model, allow_null=True)
+integrate_page_model = get_or_create_model("DataSourceIntegratePage", integrate_page_fields_copy)
+
+integrate_workspace_fields_copy = integrate_workspace_fields.copy()
+integrate_workspace_fields_copy["pages"] = fields.List(fields.Nested(integrate_page_model))
+integrate_workspace_model = get_or_create_model("DataSourceIntegrateWorkspace", integrate_workspace_fields_copy)
+
+integrate_fields_copy = integrate_fields.copy()
+integrate_fields_copy["source_info"] = fields.Nested(integrate_workspace_model)
+integrate_model = get_or_create_model("DataSourceIntegrate", integrate_fields_copy)
+
+integrate_list_fields_copy = integrate_list_fields.copy()
+integrate_list_fields_copy["data"] = fields.List(fields.Nested(integrate_model))
+integrate_list_model = get_or_create_model("DataSourceIntegrateList", integrate_list_fields_copy)
+
+notion_page_fields = {
+ "page_name": fields.String,
+ "page_id": fields.String,
+ "page_icon": fields.Nested(integrate_icon_model, allow_null=True),
+ "is_bound": fields.Boolean,
+ "parent_id": fields.String,
+ "type": fields.String,
+}
+notion_page_model = get_or_create_model("NotionIntegratePage", notion_page_fields)
+
+notion_workspace_fields = {
+ "workspace_name": fields.String,
+ "workspace_id": fields.String,
+ "workspace_icon": fields.String,
+ "pages": fields.List(fields.Nested(notion_page_model)),
+}
+notion_workspace_model = get_or_create_model("NotionIntegrateWorkspace", notion_workspace_fields)
+
+integrate_notion_info_list_fields_copy = integrate_notion_info_list_fields.copy()
+integrate_notion_info_list_fields_copy["notion_info"] = fields.List(fields.Nested(notion_workspace_model))
+integrate_notion_info_list_model = get_or_create_model(
+ "NotionIntegrateInfoList", integrate_notion_info_list_fields_copy
+)
+
+
@console_ns.route(
"/data-source/integrates",
"/data-source/integrates//",
@@ -47,7 +107,7 @@ class DataSourceApi(Resource):
@setup_required
@login_required
@account_initialization_required
- @marshal_with(integrate_list_fields)
+ @marshal_with(integrate_list_model)
def get(self):
_, current_tenant_id = current_account_with_tenant()
@@ -132,18 +192,19 @@ class DataSourceNotionListApi(Resource):
@setup_required
@login_required
@account_initialization_required
- @marshal_with(integrate_notion_info_list_fields)
+ @marshal_with(integrate_notion_info_list_model)
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
- dataset_id = request.args.get("dataset_id", default=None, type=str)
- credential_id = request.args.get("credential_id", default=None, type=str)
- if not credential_id:
- raise ValueError("Credential id is required.")
+ query = DataSourceNotionListQuery.model_validate(request.args.to_dict())
+
+ # Get datasource_parameters from query string (optional, for GitHub and other datasources)
+ datasource_parameters = query.datasource_parameters or {}
+
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_tenant_id,
- credential_id=credential_id,
+ credential_id=query.credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
@@ -152,8 +213,8 @@ class DataSourceNotionListApi(Resource):
exist_page_ids = []
with Session(db.engine) as session:
# import notion in the exist dataset
- if dataset_id:
- dataset = DatasetService.get_dataset(dataset_id)
+ if query.dataset_id:
+ dataset = DatasetService.get_dataset(query.dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
if dataset.data_source_type != "notion_import":
@@ -161,7 +222,7 @@ class DataSourceNotionListApi(Resource):
documents = session.scalars(
select(Document).filter_by(
- dataset_id=dataset_id,
+ dataset_id=query.dataset_id,
tenant_id=current_tenant_id,
data_source_type="notion_import",
enabled=True,
@@ -187,7 +248,7 @@ class DataSourceNotionListApi(Resource):
online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
datasource_runtime.get_online_document_pages(
user_id=current_user.id,
- datasource_parameters={},
+ datasource_parameters=datasource_parameters,
provider_type=datasource_runtime.datasource_provider_type(),
)
)
@@ -218,32 +279,30 @@ class DataSourceNotionListApi(Resource):
@console_ns.route(
-    "/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview",
+    "/notion/pages/<uuid:page_id>/<string:page_type>/preview",
"/datasets/notion-indexing-estimate",
)
class DataSourceNotionApi(Resource):
@setup_required
@login_required
@account_initialization_required
- def get(self, workspace_id, page_id, page_type):
+ def get(self, page_id, page_type):
_, current_tenant_id = current_account_with_tenant()
- credential_id = request.args.get("credential_id", default=None, type=str)
- if not credential_id:
- raise ValueError("Credential id is required.")
+ query = DataSourceNotionPreviewQuery.model_validate(request.args.to_dict())
+
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_tenant_id,
- credential_id=credential_id,
+ credential_id=query.credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
- workspace_id = str(workspace_id)
page_id = str(page_id)
extractor = NotionExtractor(
- notion_workspace_id=workspace_id,
+ notion_workspace_id="",
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=credential.get("integration_secret"),
diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index 70b6e932e9..30e4ed1119 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -8,7 +8,7 @@ from werkzeug.exceptions import Forbidden, NotFound
import services
from configs import dify_config
-from controllers.common.schema import register_schema_models
+from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
from controllers.console.apikey import (
api_key_item_model,
@@ -34,6 +34,7 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from fields.app_fields import app_detail_kernel_fields, related_app_list
from fields.dataset_fields import (
+ content_fields,
dataset_detail_fields,
dataset_fields,
dataset_query_detail_fields,
@@ -41,6 +42,7 @@ from fields.dataset_fields import (
doc_metadata_fields,
external_knowledge_info_fields,
external_retrieval_model_fields,
+ file_info_fields,
icon_info_fields,
keyword_setting_fields,
reranking_model_fields,
@@ -55,41 +57,33 @@ from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
-
-def _get_or_create_model(model_name: str, field_def):
- existing = console_ns.models.get(model_name)
- if existing is None:
- existing = console_ns.model(model_name, field_def)
- return existing
-
-
# Register models for flask_restx to avoid dict type issues in Swagger
-dataset_base_model = _get_or_create_model("DatasetBase", dataset_fields)
+dataset_base_model = get_or_create_model("DatasetBase", dataset_fields)
-tag_model = _get_or_create_model("Tag", tag_fields)
+tag_model = get_or_create_model("Tag", tag_fields)
-keyword_setting_model = _get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
-vector_setting_model = _get_or_create_model("DatasetVectorSetting", vector_setting_fields)
+keyword_setting_model = get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
+vector_setting_model = get_or_create_model("DatasetVectorSetting", vector_setting_fields)
weighted_score_fields_copy = weighted_score_fields.copy()
weighted_score_fields_copy["keyword_setting"] = fields.Nested(keyword_setting_model)
weighted_score_fields_copy["vector_setting"] = fields.Nested(vector_setting_model)
-weighted_score_model = _get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
+weighted_score_model = get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
-reranking_model = _get_or_create_model("DatasetRerankingModel", reranking_model_fields)
+reranking_model = get_or_create_model("DatasetRerankingModel", reranking_model_fields)
dataset_retrieval_model_fields_copy = dataset_retrieval_model_fields.copy()
dataset_retrieval_model_fields_copy["reranking_model"] = fields.Nested(reranking_model)
dataset_retrieval_model_fields_copy["weights"] = fields.Nested(weighted_score_model, allow_null=True)
-dataset_retrieval_model = _get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
+dataset_retrieval_model = get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
-external_knowledge_info_model = _get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
+external_knowledge_info_model = get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
-external_retrieval_model = _get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
+external_retrieval_model = get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
-doc_metadata_model = _get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
+doc_metadata_model = get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
-icon_info_model = _get_or_create_model("DatasetIconInfo", icon_info_fields)
+icon_info_model = get_or_create_model("DatasetIconInfo", icon_info_fields)
dataset_detail_fields_copy = dataset_detail_fields.copy()
dataset_detail_fields_copy["retrieval_model_dict"] = fields.Nested(dataset_retrieval_model)
@@ -98,14 +92,22 @@ dataset_detail_fields_copy["external_knowledge_info"] = fields.Nested(external_k
dataset_detail_fields_copy["external_retrieval_model"] = fields.Nested(external_retrieval_model, allow_null=True)
dataset_detail_fields_copy["doc_metadata"] = fields.List(fields.Nested(doc_metadata_model))
dataset_detail_fields_copy["icon_info"] = fields.Nested(icon_info_model)
-dataset_detail_model = _get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
+dataset_detail_model = get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
-dataset_query_detail_model = _get_or_create_model("DatasetQueryDetail", dataset_query_detail_fields)
+file_info_model = get_or_create_model("DatasetFileInfo", file_info_fields)
-app_detail_kernel_model = _get_or_create_model("AppDetailKernel", app_detail_kernel_fields)
+content_fields_copy = content_fields.copy()
+content_fields_copy["file_info"] = fields.Nested(file_info_model, allow_null=True)
+content_model = get_or_create_model("DatasetContent", content_fields_copy)
+
+dataset_query_detail_fields_copy = dataset_query_detail_fields.copy()
+dataset_query_detail_fields_copy["queries"] = fields.Nested(content_model)
+dataset_query_detail_model = get_or_create_model("DatasetQueryDetail", dataset_query_detail_fields_copy)
+
+app_detail_kernel_model = get_or_create_model("AppDetailKernel", app_detail_kernel_fields)
related_app_list_copy = related_app_list.copy()
related_app_list_copy["data"] = fields.List(fields.Nested(app_detail_kernel_model))
-related_app_list_model = _get_or_create_model("RelatedAppList", related_app_list_copy)
+related_app_list_model = get_or_create_model("RelatedAppList", related_app_list_copy)
def _validate_indexing_technique(value: str | None) -> str | None:
@@ -146,7 +148,8 @@ class DatasetUpdatePayload(BaseModel):
embedding_model: str | None = None
embedding_model_provider: str | None = None
retrieval_model: dict[str, Any] | None = None
- partial_member_list: list[str] | None = None
+ summary_index_setting: dict[str, Any] | None = None
+ partial_member_list: list[dict[str, str]] | None = None
external_retrieval_model: dict[str, Any] | None = None
external_knowledge_id: str | None = None
external_knowledge_api_id: str | None = None
@@ -176,7 +179,18 @@ class IndexingEstimatePayload(BaseModel):
return result
-register_schema_models(console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload)
+class ConsoleDatasetListQuery(BaseModel):
+ page: int = Field(default=1, description="Page number")
+ limit: int = Field(default=20, description="Number of items per page")
+ keyword: str | None = Field(default=None, description="Search keyword")
+ include_all: bool = Field(default=False, description="Include all datasets")
+ ids: list[str] = Field(default_factory=list, description="Filter by dataset IDs")
+ tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs")
+
+
+register_schema_models(
+ console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload, ConsoleDatasetListQuery
+)
def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
@@ -223,6 +237,7 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
VectorType.COUCHBASE,
VectorType.OPENGAUSS,
VectorType.OCEANBASE,
+ VectorType.SEEKDB,
VectorType.TABLESTORE,
VectorType.HUAWEI_CLOUD,
VectorType.TENCENT,
@@ -230,6 +245,7 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
VectorType.CLICKZETTA,
VectorType.BAIDU,
VectorType.ALIBABACLOUD_MYSQL,
+ VectorType.IRIS,
}
semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
@@ -273,18 +289,26 @@ class DatasetListApi(Resource):
@enterprise_license_required
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
- page = request.args.get("page", default=1, type=int)
- limit = request.args.get("limit", default=20, type=int)
- ids = request.args.getlist("ids")
+ # Convert query parameters to dict, handling list parameters correctly
+ query_params: dict[str, str | list[str]] = dict(request.args.to_dict())
+ # Handle ids and tag_ids as lists (Flask's request.args.getlist returns a list even for a single value)
+ if "ids" in request.args:
+ query_params["ids"] = request.args.getlist("ids")
+ if "tag_ids" in request.args:
+ query_params["tag_ids"] = request.args.getlist("tag_ids")
+ query = ConsoleDatasetListQuery.model_validate(query_params)
# provider = request.args.get("provider", default="vendor")
- search = request.args.get("keyword", default=None, type=str)
- tag_ids = request.args.getlist("tag_ids")
- include_all = request.args.get("include_all", default="false").lower() == "true"
- if ids:
- datasets, total = DatasetService.get_datasets_by_ids(ids, current_tenant_id)
+ if query.ids:
+ datasets, total = DatasetService.get_datasets_by_ids(query.ids, current_tenant_id)
else:
datasets, total = DatasetService.get_datasets(
- page, limit, current_tenant_id, current_user, search, tag_ids, include_all
+ query.page,
+ query.limit,
+ current_tenant_id,
+ current_user,
+ query.keyword,
+ query.tag_ids,
+ query.include_all,
)
# check embedding setting
@@ -316,7 +340,13 @@ class DatasetListApi(Resource):
else:
item.update({"partial_member_list": []})
- response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
+ response = {
+ "data": data,
+ "has_more": len(datasets) == query.limit,
+ "limit": query.limit,
+ "total": total,
+ "page": query.page,
+ }
return response, 200
@console_ns.doc("create_dataset")
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index 6145da31a5..6e3c0db8a3 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -2,17 +2,19 @@ import json
import logging
from argparse import ArgumentTypeError
from collections.abc import Sequence
-from typing import Literal, cast
+from contextlib import ExitStack
+from typing import Any, Literal, cast
+from uuid import UUID
import sqlalchemy as sa
-from flask import request
+from flask import request, send_file
from flask_restx import Resource, fields, marshal, marshal_with
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound
import services
-from controllers.common.schema import register_schema_models
+from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
from core.errors.error import (
LLMBadRequestError,
@@ -42,6 +44,8 @@ from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
+from services.file_service import FileService
+from tasks.generate_summary_index_task import generate_summary_index_task
from ..app.error import (
ProviderModelCurrentlyNotSupportError,
@@ -65,35 +69,31 @@ from ..wraps import (
logger = logging.getLogger(__name__)
-
-def _get_or_create_model(model_name: str, field_def):
- existing = console_ns.models.get(model_name)
- if existing is None:
- existing = console_ns.model(model_name, field_def)
- return existing
+# NOTE: Keep constants near the top of the module for discoverability.
+DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100
# Register models for flask_restx to avoid dict type issues in Swagger
-dataset_model = _get_or_create_model("Dataset", dataset_fields)
+dataset_model = get_or_create_model("Dataset", dataset_fields)
-document_metadata_model = _get_or_create_model("DocumentMetadata", document_metadata_fields)
+document_metadata_model = get_or_create_model("DocumentMetadata", document_metadata_fields)
document_fields_copy = document_fields.copy()
document_fields_copy["doc_metadata"] = fields.List(
fields.Nested(document_metadata_model), attribute="doc_metadata_details"
)
-document_model = _get_or_create_model("Document", document_fields_copy)
+document_model = get_or_create_model("Document", document_fields_copy)
document_with_segments_fields_copy = document_with_segments_fields.copy()
document_with_segments_fields_copy["doc_metadata"] = fields.List(
fields.Nested(document_metadata_model), attribute="doc_metadata_details"
)
-document_with_segments_model = _get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy)
+document_with_segments_model = get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy)
dataset_and_document_fields_copy = dataset_and_document_fields.copy()
dataset_and_document_fields_copy["dataset"] = fields.Nested(dataset_model)
dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(document_model))
-dataset_and_document_model = _get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
+dataset_and_document_model = get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
class DocumentRetryPayload(BaseModel):
@@ -104,6 +104,25 @@ class DocumentRenamePayload(BaseModel):
name: str
+class GenerateSummaryPayload(BaseModel):
+ document_list: list[str]
+
+
+class DocumentBatchDownloadZipPayload(BaseModel):
+ """Request payload for bulk downloading documents as a zip archive."""
+
+ document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS)
+
+
+class DocumentDatasetListParam(BaseModel):
+ page: int = Field(1, title="Page", description="Page number.")
+ limit: int = Field(20, title="Limit", description="Page size.")
+ search: str | None = Field(None, alias="keyword", title="Search", description="Search keyword.")
+ sort_by: str = Field("-created_at", alias="sort", title="SortBy", description="Sort by field.")
+ status: str | None = Field(None, title="Status", description="Document status.")
+ fetch_val: str = Field("false", alias="fetch")
+
+
register_schema_models(
console_ns,
KnowledgeConfig,
@@ -111,6 +130,8 @@ register_schema_models(
RetrievalModel,
DocumentRetryPayload,
DocumentRenamePayload,
+ GenerateSummaryPayload,
+ DocumentBatchDownloadZipPayload,
)
@@ -225,14 +246,16 @@ class DatasetDocumentListApi(Resource):
def get(self, dataset_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
- page = request.args.get("page", default=1, type=int)
- limit = request.args.get("limit", default=20, type=int)
- search = request.args.get("keyword", default=None, type=str)
- sort = request.args.get("sort", default="-created_at", type=str)
- status = request.args.get("status", default=None, type=str)
+ raw_args = request.args.to_dict()
+ param = DocumentDatasetListParam.model_validate(raw_args)
+ page = param.page
+ limit = param.limit
+ search = param.search
+ sort = param.sort_by
+ status = param.status
# "yes", "true", "t", "y", "1" convert to True, while others convert to False.
try:
- fetch_val = request.args.get("fetch", default="false")
+ fetch_val = param.fetch_val
if isinstance(fetch_val, bool):
fetch = fetch_val
else:
@@ -295,6 +318,13 @@ class DatasetDocumentListApi(Resource):
paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items
+
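+ # Enrich each document with its summary index status before building the response.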
+ DocumentService.enrich_documents_with_summary_index_status(
+ documents=documents,
+ dataset=dataset,
+ tenant_id=current_tenant_id,
+ )
+
if fetch:
for document in documents:
completed_segments = (
@@ -572,7 +602,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
- "credential_id": data_source_info["credential_id"],
+ "credential_id": data_source_info.get("credential_id"),
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
@@ -751,12 +781,12 @@ class DocumentApi(DocumentResource):
elif metadata == "without":
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
- data_source_info = document.data_source_detail_dict
response = {
"id": document.id,
"position": document.position,
"data_source_type": document.data_source_type,
- "data_source_info": data_source_info,
+ "data_source_info": document.data_source_info_dict,
+ "data_source_detail_dict": document.data_source_detail_dict,
"dataset_process_rule_id": document.dataset_process_rule_id,
"dataset_process_rule": dataset_process_rules,
"document_process_rule": document_process_rules,
@@ -780,16 +810,17 @@ class DocumentApi(DocumentResource):
"display_status": document.display_status,
"doc_form": document.doc_form,
"doc_language": document.doc_language,
+ "need_summary": document.need_summary if document.need_summary is not None else False,
}
else:
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
- data_source_info = document.data_source_detail_dict
response = {
"id": document.id,
"position": document.position,
"data_source_type": document.data_source_type,
- "data_source_info": data_source_info,
+ "data_source_info": document.data_source_info_dict,
+ "data_source_detail_dict": document.data_source_detail_dict,
"dataset_process_rule_id": document.dataset_process_rule_id,
"dataset_process_rule": dataset_process_rules,
"document_process_rule": document_process_rules,
@@ -815,6 +846,7 @@ class DocumentApi(DocumentResource):
"display_status": document.display_status,
"doc_form": document.doc_form,
"doc_language": document.doc_language,
+ "need_summary": document.need_summary if document.need_summary is not None else False,
}
return response, 200
@@ -842,6 +874,62 @@ class DocumentApi(DocumentResource):
return {"result": "success"}, 204
+@console_ns.route("/datasets//documents//download")
+class DocumentDownloadApi(DocumentResource):
+ """Return a signed download URL for a dataset document's original uploaded file."""
+
+ @console_ns.doc("get_dataset_document_download_url")
+ @console_ns.doc(description="Get a signed download URL for a dataset document's original uploaded file")
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @cloud_edition_billing_rate_limit_check("knowledge")
+ def get(self, dataset_id: str, document_id: str) -> dict[str, Any]:
+ # Reuse the shared permission/tenant checks implemented in DocumentResource.
+ document = self.get_document(str(dataset_id), str(document_id))
+ return {"url": DocumentService.get_document_download_url(document)}
+
+
+@console_ns.route("/datasets//documents/download-zip")
+class DocumentBatchDownloadZipApi(DocumentResource):
+ """Download multiple uploaded-file documents as a single ZIP (avoids browser multi-download limits)."""
+
+ @console_ns.doc("download_dataset_documents_as_zip")
+ @console_ns.doc(description="Download selected dataset documents as a single ZIP archive (upload-file only)")
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @cloud_edition_billing_rate_limit_check("knowledge")
+ @console_ns.expect(console_ns.models[DocumentBatchDownloadZipPayload.__name__])
+ def post(self, dataset_id: str):
+ """Stream a ZIP archive containing the requested uploaded documents."""
+ # Parse and validate request payload.
+ payload = DocumentBatchDownloadZipPayload.model_validate(console_ns.payload or {})
+
+ current_user, current_tenant_id = current_account_with_tenant()
+ dataset_id = str(dataset_id)
+ document_ids: list[str] = [str(document_id) for document_id in payload.document_ids]
+ upload_files, download_name = DocumentService.prepare_document_batch_download_zip(
+ dataset_id=dataset_id,
+ document_ids=document_ids,
+ tenant_id=current_tenant_id,
+ current_user=current_user,
+ )
+
+ # Delegate ZIP packing to FileService, but keep Flask response+cleanup in the route.
+ with ExitStack() as stack:
+ zip_path = stack.enter_context(FileService.build_upload_files_zip_tempfile(upload_files=upload_files))
+ response = send_file(
+ zip_path,
+ mimetype="application/zip",
+ as_attachment=True,
+ download_name=download_name,
+ )
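+ # Transfer cleanup callbacks out of the with-block so they run when the response is closed, not before it is sent.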
+ cleanup = stack.pop_all()
+ response.call_on_close(cleanup.close)
+ return response
+
+
@console_ns.route("/datasets//documents//processing/")
class DocumentProcessingApi(DocumentResource):
@console_ns.doc("update_document_processing")
@@ -1098,7 +1186,7 @@ class DocumentRenameApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
- @marshal_with(document_fields)
+ @marshal_with(document_model)
@console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
def post(self, dataset_id, document_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
@@ -1182,3 +1270,137 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
"input_data": log.input_data,
"datasource_node_id": log.datasource_node_id,
}, 200
+
+
+@console_ns.route("/datasets//documents/generate-summary")
+class DocumentGenerateSummaryApi(Resource):
+ @console_ns.doc("generate_summary_for_documents")
+ @console_ns.doc(description="Generate summary index for documents")
+ @console_ns.doc(params={"dataset_id": "Dataset ID"})
+ @console_ns.expect(console_ns.models[GenerateSummaryPayload.__name__])
+ @console_ns.response(200, "Summary generation started successfully")
+ @console_ns.response(400, "Invalid request or dataset configuration")
+ @console_ns.response(403, "Permission denied")
+ @console_ns.response(404, "Dataset not found")
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @cloud_edition_billing_rate_limit_check("knowledge")
+ def post(self, dataset_id):
+ """
+ Generate summary index for specified documents.
+
+ This endpoint checks if the dataset configuration supports summary generation
+ (indexing_technique must be 'high_quality' and summary_index_setting.enable must be true),
+ then asynchronously generates summary indexes for the provided documents.
+ """
+ current_user, _ = current_account_with_tenant()
+ dataset_id = str(dataset_id)
+
+ # Get dataset
+ dataset = DatasetService.get_dataset(dataset_id)
+ if not dataset:
+ raise NotFound("Dataset not found.")
+
+ # Check permissions
+ if not current_user.is_dataset_editor:
+ raise Forbidden()
+
+ try:
+ DatasetService.check_dataset_permission(dataset, current_user)
+ except services.errors.account.NoPermissionError as e:
+ raise Forbidden(str(e))
+
+ # Validate request payload
+ payload = GenerateSummaryPayload.model_validate(console_ns.payload or {})
+ document_list = payload.document_list
+
+ if not document_list:
+ from werkzeug.exceptions import BadRequest
+
+ raise BadRequest("document_list cannot be empty.")
+
+ # Check if dataset configuration supports summary generation
+ if dataset.indexing_technique != "high_quality":
+ raise ValueError(
+ f"Summary generation is only available for 'high_quality' indexing technique. "
+ f"Current indexing technique: {dataset.indexing_technique}"
+ )
+
+ summary_index_setting = dataset.summary_index_setting
+ if not summary_index_setting or not summary_index_setting.get("enable"):
+ raise ValueError("Summary index is not enabled for this dataset. Please enable it in the dataset settings.")
+
+ # Verify all documents exist and belong to the dataset
+ documents = DocumentService.get_documents_by_ids(dataset_id, document_list)
+
+ if len(documents) != len(document_list):
+ found_ids = {doc.id for doc in documents}
+ missing_ids = set(document_list) - found_ids
+ raise NotFound(f"Some documents not found: {list(missing_ids)}")
+
+ # Dispatch async tasks for each document
+ for document in documents:
+ # Skip qa_model documents as they don't generate summaries
+ if document.doc_form == "qa_model":
+ logger.info("Skipping summary generation for qa_model document %s", document.id)
+ continue
+
+ # Dispatch async task
+ generate_summary_index_task.delay(dataset_id, document.id)
+ logger.info(
+ "Dispatched summary generation task for document %s in dataset %s",
+ document.id,
+ dataset_id,
+ )
+
+ return {"result": "success"}, 200
+
+
+@console_ns.route("/datasets//documents//summary-status")
+class DocumentSummaryStatusApi(DocumentResource):
+ @console_ns.doc("get_document_summary_status")
+ @console_ns.doc(description="Get summary index generation status for a document")
+ @console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
+ @console_ns.response(200, "Summary status retrieved successfully")
+ @console_ns.response(404, "Document not found")
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def get(self, dataset_id, document_id):
+ """
+ Get summary index generation status for a document.
+
+ Returns:
+ - total_segments: Total number of segments in the document
+ - summary_status: Dictionary with status counts
+ - completed: Number of summaries completed
+ - generating: Number of summaries being generated
+ - error: Number of summaries with errors
+ - not_started: Number of segments without summary records
+ - summaries: List of summary records with status and content preview
+ """
+ current_user, _ = current_account_with_tenant()
+ dataset_id = str(dataset_id)
+ document_id = str(document_id)
+
+ # Get dataset
+ dataset = DatasetService.get_dataset(dataset_id)
+ if not dataset:
+ raise NotFound("Dataset not found.")
+
+ # Check permissions
+ try:
+ DatasetService.check_dataset_permission(dataset, current_user)
+ except services.errors.account.NoPermissionError as e:
+ raise Forbidden(str(e))
+
+ # Get summary status detail from service
+ from services.summary_index_service import SummaryIndexService
+
+ result = SummaryIndexService.get_document_summary_status_detail(
+ document_id=document_id,
+ dataset_id=dataset_id,
+ )
+
+ return result, 200
diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py
index e73abc2555..23a668112d 100644
--- a/api/controllers/console/datasets/datasets_segments.py
+++ b/api/controllers/console/datasets/datasets_segments.py
@@ -3,10 +3,12 @@ import uuid
from flask import request
from flask_restx import Resource, marshal
from pydantic import BaseModel, Field
-from sqlalchemy import select
+from sqlalchemy import String, cast, func, or_, select
+from sqlalchemy.dialects.postgresql import JSONB
from werkzeug.exceptions import Forbidden, NotFound
import services
+from configs import dify_config
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import ProviderNotInitializeError
@@ -28,6 +30,7 @@ from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.segment_fields import child_chunk_fields, segment_fields
+from libs.helper import escape_like_pattern
from libs.login import current_account_with_tenant, login_required
from models.dataset import ChildChunk, DocumentSegment
from models.model import UploadFile
@@ -38,6 +41,17 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
+def _get_segment_with_summary(segment, dataset_id):
+ """Helper function to marshal segment and add summary information."""
+ from services.summary_index_service import SummaryIndexService
+
+ segment_dict = dict(marshal(segment, segment_fields))
+ # Query summary for this segment (only enabled summaries)
+ summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
+ segment_dict["summary"] = summary.summary_content if summary else None
+ return segment_dict
+
+
class SegmentListQuery(BaseModel):
limit: int = Field(default=20, ge=1, le=100)
status: list[str] = Field(default_factory=list)
@@ -60,6 +74,7 @@ class SegmentUpdatePayload(BaseModel):
keywords: list[str] | None = None
regenerate_child_chunks: bool = False
attachment_ids: list[str] | None = None
+ summary: str | None = None # Summary content for summary index
class BatchImportPayload(BaseModel):
@@ -87,6 +102,7 @@ register_schema_models(
ChildChunkCreatePayload,
ChildChunkUpdatePayload,
ChildChunkBatchUpdatePayload,
+ ChildChunkUpdateArgs,
)
@@ -143,7 +159,31 @@ class DatasetDocumentSegmentListApi(Resource):
query = query.where(DocumentSegment.hit_count >= hit_count_gte)
if keyword:
- query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
+ # Escape LIKE wildcard characters in the keyword so user input is matched literally, not as a pattern
+ escaped_keyword = escape_like_pattern(keyword)
+ # Search in both content and keywords fields
+ # Use database-specific methods for JSON array search
+ if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
+ # PostgreSQL: Use jsonb_array_elements_text to properly handle Unicode/Chinese text
+ keywords_condition = func.array_to_string(
+ func.array(
+ select(func.jsonb_array_elements_text(cast(DocumentSegment.keywords, JSONB)))
+ .correlate(DocumentSegment)
+ .scalar_subquery()
+ ),
+ ",",
+ ).ilike(f"%{escaped_keyword}%", escape="\\")
+ else:
+ # MySQL: Cast JSON to string for pattern matching
+ # MySQL stores Chinese text directly in JSON without Unicode escaping
+ keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{escaped_keyword}%", escape="\\")
+
+ query = query.where(
+ or_(
+ DocumentSegment.content.ilike(f"%{escaped_keyword}%", escape="\\"),
+ keywords_condition,
+ )
+ )
if args.enabled.lower() != "all":
if args.enabled.lower() == "true":
@@ -153,8 +193,25 @@ class DatasetDocumentSegmentListApi(Resource):
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
+ # Query summaries for all segments in this page (batch query for efficiency)
+ segment_ids = [segment.id for segment in segments.items]
+ summaries = {}
+ if segment_ids:
+ from services.summary_index_service import SummaryIndexService
+
+ summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
+ # Only include enabled summaries (already filtered by service)
+ summaries = {chunk_id: summary.summary_content for chunk_id, summary in summary_records.items()}
+
+ # Add summary to each segment
+ segments_with_summary = []
+ for segment in segments.items:
+ segment_dict = dict(marshal(segment, segment_fields))
+ segment_dict["summary"] = summaries.get(segment.id)
+ segments_with_summary.append(segment_dict)
+
response = {
- "data": marshal(segments.items, segment_fields),
+ "data": segments_with_summary,
"limit": limit,
"total": segments.total,
"total_pages": segments.pages,
@@ -300,7 +357,7 @@ class DatasetDocumentSegmentAddApi(Resource):
payload_dict = payload.model_dump(exclude_none=True)
SegmentService.segment_create_args_validate(payload_dict, document)
segment = SegmentService.create_segment(payload_dict, document, dataset)
- return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
+ return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
@console_ns.route("/datasets//documents//segments/")
@@ -362,10 +419,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
payload = SegmentUpdatePayload.model_validate(console_ns.payload or {})
payload_dict = payload.model_dump(exclude_none=True)
SegmentService.segment_create_args_validate(payload_dict, document)
+
+ # Update segment (summary update with change detection is handled in SegmentService.update_segment)
segment = SegmentService.update_segment(
SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
)
- return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
+ return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
@setup_required
@login_required
diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py
index 89c9fcad36..86090bcd10 100644
--- a/api/controllers/console/datasets/external.py
+++ b/api/controllers/console/datasets/external.py
@@ -4,7 +4,7 @@ from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
-from controllers.common.schema import register_schema_models
+from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
@@ -28,34 +28,27 @@ from services.hit_testing_service import HitTestingService
from services.knowledge_service import ExternalDatasetTestService
-def _get_or_create_model(model_name: str, field_def):
- existing = console_ns.models.get(model_name)
- if existing is None:
- existing = console_ns.model(model_name, field_def)
- return existing
-
-
def _build_dataset_detail_model():
- keyword_setting_model = _get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
- vector_setting_model = _get_or_create_model("DatasetVectorSetting", vector_setting_fields)
+ keyword_setting_model = get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
+ vector_setting_model = get_or_create_model("DatasetVectorSetting", vector_setting_fields)
weighted_score_fields_copy = weighted_score_fields.copy()
weighted_score_fields_copy["keyword_setting"] = fields.Nested(keyword_setting_model)
weighted_score_fields_copy["vector_setting"] = fields.Nested(vector_setting_model)
- weighted_score_model = _get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
+ weighted_score_model = get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
- reranking_model = _get_or_create_model("DatasetRerankingModel", reranking_model_fields)
+ reranking_model = get_or_create_model("DatasetRerankingModel", reranking_model_fields)
dataset_retrieval_model_fields_copy = dataset_retrieval_model_fields.copy()
dataset_retrieval_model_fields_copy["reranking_model"] = fields.Nested(reranking_model)
dataset_retrieval_model_fields_copy["weights"] = fields.Nested(weighted_score_model, allow_null=True)
- dataset_retrieval_model = _get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
+ dataset_retrieval_model = get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
- tag_model = _get_or_create_model("Tag", tag_fields)
- doc_metadata_model = _get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
- external_knowledge_info_model = _get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
- external_retrieval_model = _get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
- icon_info_model = _get_or_create_model("DatasetIconInfo", icon_info_fields)
+ tag_model = get_or_create_model("Tag", tag_fields)
+ doc_metadata_model = get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
+ external_knowledge_info_model = get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
+ external_retrieval_model = get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
+ icon_info_model = get_or_create_model("DatasetIconInfo", icon_info_fields)
dataset_detail_fields_copy = dataset_detail_fields.copy()
dataset_detail_fields_copy["retrieval_model_dict"] = fields.Nested(dataset_retrieval_model)
@@ -64,7 +57,7 @@ def _build_dataset_detail_model():
dataset_detail_fields_copy["external_retrieval_model"] = fields.Nested(external_retrieval_model, allow_null=True)
dataset_detail_fields_copy["doc_metadata"] = fields.List(fields.Nested(doc_metadata_model))
dataset_detail_fields_copy["icon_info"] = fields.Nested(icon_info_model)
- return _get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
+ return get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
try:
@@ -81,7 +74,7 @@ class ExternalKnowledgeApiPayload(BaseModel):
class ExternalDatasetCreatePayload(BaseModel):
external_knowledge_api_id: str
external_knowledge_id: str
- name: str = Field(..., min_length=1, max_length=40)
+ name: str = Field(..., min_length=1, max_length=100)
description: str | None = Field(None, max_length=400)
external_retrieval_model: dict[str, object] | None = None
@@ -98,12 +91,19 @@ class BedrockRetrievalPayload(BaseModel):
knowledge_id: str
+class ExternalApiTemplateListQuery(BaseModel):
+ page: int = Field(default=1, description="Page number")
+ limit: int = Field(default=20, description="Number of items per page")
+ keyword: str | None = Field(default=None, description="Search keyword")
+
+
register_schema_models(
console_ns,
ExternalKnowledgeApiPayload,
ExternalDatasetCreatePayload,
ExternalHitTestingPayload,
BedrockRetrievalPayload,
+ ExternalApiTemplateListQuery,
)
@@ -124,19 +124,17 @@ class ExternalApiTemplateListApi(Resource):
@account_initialization_required
def get(self):
_, current_tenant_id = current_account_with_tenant()
- page = request.args.get("page", default=1, type=int)
- limit = request.args.get("limit", default=20, type=int)
- search = request.args.get("keyword", default=None, type=str)
+ query = ExternalApiTemplateListQuery.model_validate(request.args.to_dict())
external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis(
- page, limit, current_tenant_id, search
+ query.page, query.limit, current_tenant_id, query.keyword
)
response = {
"data": [item.to_dict() for item in external_knowledge_apis],
- "has_more": len(external_knowledge_apis) == limit,
- "limit": limit,
+ "has_more": len(external_knowledge_apis) == query.limit,
+ "limit": query.limit,
"total": total,
- "page": page,
+ "page": query.page,
}
return response, 200
diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py
index 932cb4fcce..e62be13c2f 100644
--- a/api/controllers/console/datasets/hit_testing.py
+++ b/api/controllers/console/datasets/hit_testing.py
@@ -1,6 +1,13 @@
-from flask_restx import Resource
+from flask_restx import Resource, fields
from controllers.common.schema import register_schema_model
+from fields.hit_testing_fields import (
+ child_chunk_fields,
+ document_fields,
+ files_fields,
+ hit_testing_record_fields,
+ segment_fields,
+)
from libs.login import login_required
from .. import console_ns
@@ -14,13 +21,45 @@ from ..wraps import (
register_schema_model(console_ns, HitTestingPayload)
+def _get_or_create_model(model_name: str, field_def):
+ """Get or create a flask_restx model to avoid dict type issues in Swagger."""
+ existing = console_ns.models.get(model_name)
+ if existing is None:
+ existing = console_ns.model(model_name, field_def)
+ return existing
+
+
+# Register models for flask_restx to avoid dict type issues in Swagger
+document_model = _get_or_create_model("HitTestingDocument", document_fields)
+
+segment_fields_copy = segment_fields.copy()
+segment_fields_copy["document"] = fields.Nested(document_model)
+segment_model = _get_or_create_model("HitTestingSegment", segment_fields_copy)
+
+child_chunk_model = _get_or_create_model("HitTestingChildChunk", child_chunk_fields)
+files_model = _get_or_create_model("HitTestingFile", files_fields)
+
+hit_testing_record_fields_copy = hit_testing_record_fields.copy()
+hit_testing_record_fields_copy["segment"] = fields.Nested(segment_model)
+hit_testing_record_fields_copy["child_chunks"] = fields.List(fields.Nested(child_chunk_model))
+hit_testing_record_fields_copy["files"] = fields.List(fields.Nested(files_model))
+hit_testing_record_model = _get_or_create_model("HitTestingRecord", hit_testing_record_fields_copy)
+
+# Response model for hit testing API
+hit_testing_response_fields = {
+ "query": fields.String,
+ "records": fields.List(fields.Nested(hit_testing_record_model)),
+}
+hit_testing_response_model = _get_or_create_model("HitTestingResponse", hit_testing_response_fields)
+
+
@console_ns.route("/datasets//hit-testing")
class HitTestingApi(Resource, DatasetsHitTestingBase):
@console_ns.doc("test_dataset_retrieval")
@console_ns.doc(description="Test dataset knowledge retrieval")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect(console_ns.models[HitTestingPayload.__name__])
- @console_ns.response(200, "Hit testing completed successfully")
+ @console_ns.response(200, "Hit testing completed successfully", model=hit_testing_response_model)
@console_ns.response(404, "Dataset not found")
@console_ns.response(400, "Invalid parameters")
@setup_required
diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py
index db7c50f422..db1a874437 100644
--- a/api/controllers/console/datasets/hit_testing_base.py
+++ b/api/controllers/console/datasets/hit_testing_base.py
@@ -1,7 +1,7 @@
import logging
from typing import Any
-from flask_restx import marshal, reqparse
+from flask_restx import marshal
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@@ -56,15 +56,10 @@ class DatasetsHitTestingBase:
HitTestingService.hit_testing_args_check(args)
@staticmethod
- def parse_args():
- parser = (
- reqparse.RequestParser()
- .add_argument("query", type=str, required=False, location="json")
- .add_argument("attachment_ids", type=list, required=False, location="json")
- .add_argument("retrieval_model", type=dict, required=False, location="json")
- .add_argument("external_retrieval_model", type=dict, required=False, location="json")
- )
- return parser.parse_args()
+ def parse_args(payload: dict[str, Any]) -> dict[str, Any]:
+ """Validate and return hit-testing arguments from an incoming payload."""
+ hit_testing_payload = HitTestingPayload.model_validate(payload or {})
+ return hit_testing_payload.model_dump(exclude_none=True)
@staticmethod
def perform_hit_testing(dataset, args):
diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py
index 8eead1696a..05fc4cd714 100644
--- a/api/controllers/console/datasets/metadata.py
+++ b/api/controllers/console/datasets/metadata.py
@@ -4,14 +4,16 @@ from flask_restx import Resource, marshal_with
from pydantic import BaseModel
from werkzeug.exceptions import NotFound
-from controllers.common.schema import register_schema_model, register_schema_models
+from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from fields.dataset_fields import dataset_metadata_fields
from libs.login import current_account_with_tenant, login_required
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
+ DocumentMetadataOperation,
MetadataArgs,
+ MetadataDetail,
MetadataOperationData,
)
from services.metadata_service import MetadataService
@@ -21,8 +23,9 @@ class MetadataUpdatePayload(BaseModel):
name: str
-register_schema_models(console_ns, MetadataArgs, MetadataOperationData)
-register_schema_model(console_ns, MetadataUpdatePayload)
+register_schema_models(
+ console_ns, MetadataArgs, MetadataOperationData, MetadataUpdatePayload, DocumentMetadataOperation, MetadataDetail
+)
@console_ns.route("/datasets//metadata")
diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py
index 42387557d6..7caf5b52ed 100644
--- a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py
+++ b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py
@@ -26,7 +26,7 @@ console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=D
@console_ns.route("/rag/pipelines//workflows/published/datasource/nodes//preview")
class DataSourceContentPreviewApi(Resource):
- @console_ns.expect(console_ns.models[Parser.__name__], validate=True)
+ @console_ns.expect(console_ns.models[Parser.__name__])
@setup_required
@login_required
@account_initialization_required
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py
index 720e2ce365..2911b1cf18 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py
@@ -2,7 +2,7 @@ import logging
from typing import Any, NoReturn
from flask import Response, request
-from flask_restx import Resource, fields, marshal, marshal_with
+from flask_restx import Resource, marshal, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
@@ -14,7 +14,9 @@ from controllers.console.app.error import (
)
from controllers.console.app.workflow_draft_variable import (
_WORKFLOW_DRAFT_VARIABLE_FIELDS, # type: ignore[private-usage]
- _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, # type: ignore[private-usage]
+ workflow_draft_variable_list_model,
+ workflow_draft_variable_list_without_value_model,
+ workflow_draft_variable_model,
)
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required
@@ -27,7 +29,6 @@ from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required
from models import Account
from models.dataset import Pipeline
-from models.workflow import WorkflowDraftVariable
from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
@@ -52,20 +53,6 @@ class WorkflowDraftVariablePatchPayload(BaseModel):
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
-def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
- return var_list.variables
-
-
-_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = {
- "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items),
- "total": fields.Raw(),
-}
-
-_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
- "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
-}
-
-
def _api_prerequisite(f):
"""Common prerequisites for all draft workflow variable APIs.
@@ -92,7 +79,7 @@ def _api_prerequisite(f):
@console_ns.route("/rag/pipelines//workflows/draft/variables")
class RagPipelineVariableCollectionApi(Resource):
@_api_prerequisite
- @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
+ @marshal_with(workflow_draft_variable_list_without_value_model)
def get(self, pipeline: Pipeline):
"""
Get draft workflow
@@ -150,7 +137,7 @@ def validate_node_id(node_id: str) -> NoReturn | None:
@console_ns.route("/rag/pipelines//workflows/draft/nodes//variables")
class RagPipelineNodeVariableCollectionApi(Resource):
@_api_prerequisite
- @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
+ @marshal_with(workflow_draft_variable_list_model)
def get(self, pipeline: Pipeline, node_id: str):
validate_node_id(node_id)
with Session(bind=db.engine, expire_on_commit=False) as session:
@@ -176,7 +163,7 @@ class RagPipelineVariableApi(Resource):
_PATCH_VALUE_FIELD = "value"
@_api_prerequisite
- @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
+ @marshal_with(workflow_draft_variable_model)
def get(self, pipeline: Pipeline, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
@@ -189,7 +176,7 @@ class RagPipelineVariableApi(Resource):
return variable
@_api_prerequisite
- @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
+ @marshal_with(workflow_draft_variable_model)
@console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__])
def patch(self, pipeline: Pipeline, variable_id: str):
# Request payload for file types:
@@ -307,7 +294,7 @@ def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList
@console_ns.route("/rag/pipelines//workflows/draft/system-variables")
class RagPipelineSystemVariableCollectionApi(Resource):
@_api_prerequisite
- @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
+ @marshal_with(workflow_draft_variable_list_model)
def get(self, pipeline: Pipeline):
return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID)
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py
index d43ee9a6e0..af142b4646 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py
@@ -1,9 +1,9 @@
from flask import request
-from flask_restx import Resource, marshal_with # type: ignore
+from flask_restx import Resource, fields, marshal_with # type: ignore
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
-from controllers.common.schema import register_schema_models
+from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import (
@@ -12,7 +12,11 @@ from controllers.console.wraps import (
setup_required,
)
from extensions.ext_database import db
-from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
+from fields.rag_pipeline_fields import (
+ leaked_dependency_fields,
+ pipeline_import_check_dependencies_fields,
+ pipeline_import_fields,
+)
from libs.login import current_account_with_tenant, login_required
from models.dataset import Pipeline
from services.app_dsl_service import ImportStatus
@@ -38,13 +42,25 @@ class IncludeSecretQuery(BaseModel):
register_schema_models(console_ns, RagPipelineImportPayload, IncludeSecretQuery)
+pipeline_import_model = get_or_create_model("RagPipelineImport", pipeline_import_fields)
+
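+# Build nested flask_restx models so leaked dependencies render as objects (not raw dicts) in Swagger.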
+leaked_dependency_model = get_or_create_model("RagPipelineLeakedDependency", leaked_dependency_fields)
+pipeline_import_check_dependencies_fields_copy = pipeline_import_check_dependencies_fields.copy()
+pipeline_import_check_dependencies_fields_copy["leaked_dependencies"] = fields.List(
+ fields.Nested(leaked_dependency_model)
+)
+pipeline_import_check_dependencies_model = get_or_create_model(
+ "RagPipelineImportCheckDependencies", pipeline_import_check_dependencies_fields_copy
+)
+
+
@console_ns.route("/rag/pipelines/imports")
class RagPipelineImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
- @marshal_with(pipeline_import_fields)
+ @marshal_with(pipeline_import_model)
@console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__])
def post(self):
# Check user role first
@@ -81,7 +97,7 @@ class RagPipelineImportConfirmApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
- @marshal_with(pipeline_import_fields)
+ @marshal_with(pipeline_import_model)
def post(self, import_id):
current_user, _ = current_account_with_tenant()
@@ -106,7 +122,7 @@ class RagPipelineImportCheckDependenciesApi(Resource):
@get_rag_pipeline
@account_initialization_required
@edit_permission_required
- @marshal_with(pipeline_import_check_dependencies_fields)
+ @marshal_with(pipeline_import_check_dependencies_model)
def get(self, pipeline: Pipeline):
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
index debe8eed97..d34fd5088d 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
@@ -4,7 +4,7 @@ from typing import Any, Literal, cast
from uuid import UUID
from flask import abort, request
-from flask_restx import Resource, marshal_with # type: ignore
+from flask_restx import Resource, marshal_with, reqparse # type: ignore
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@@ -17,6 +17,13 @@ from controllers.console.app.error import (
DraftWorkflowNotExist,
DraftWorkflowNotSync,
)
+from controllers.console.app.workflow import workflow_model, workflow_pagination_model
+from controllers.console.app.workflow_run import (
+ workflow_run_detail_model,
+ workflow_run_node_execution_list_model,
+ workflow_run_node_execution_model,
+ workflow_run_pagination_model,
+)
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import (
account_initialization_required,
@@ -30,13 +37,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from factories import variable_factory
-from fields.workflow_fields import workflow_fields, workflow_pagination_fields
-from fields.workflow_run_fields import (
- workflow_run_detail_fields,
- workflow_run_node_execution_fields,
- workflow_run_node_execution_list_fields,
- workflow_run_pagination_fields,
-)
from libs import helper
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, current_user, login_required
@@ -145,7 +145,7 @@ class DraftRagPipelineApi(Resource):
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
- @marshal_with(workflow_fields)
+ @marshal_with(workflow_model)
def get(self, pipeline: Pipeline):
"""
Get draft rag pipeline's workflow
@@ -355,7 +355,7 @@ class PublishedRagPipelineRunApi(Resource):
pipeline=pipeline,
user=current_user,
args=args,
- invoke_from=InvokeFrom.DEBUGGER if payload.is_preview else InvokeFrom.PUBLISHED,
+ invoke_from=InvokeFrom.DEBUGGER if payload.is_preview else InvokeFrom.PUBLISHED_PIPELINE,
streaming=streaming,
)
@@ -521,7 +521,7 @@ class RagPipelineDraftNodeRunApi(Resource):
@edit_permission_required
@account_initialization_required
@get_rag_pipeline
- @marshal_with(workflow_run_node_execution_fields)
+ @marshal_with(workflow_run_node_execution_model)
def post(self, pipeline: Pipeline, node_id: str):
"""
Run draft workflow node
@@ -569,7 +569,7 @@ class PublishedRagPipelineApi(Resource):
@account_initialization_required
@edit_permission_required
@get_rag_pipeline
- @marshal_with(workflow_fields)
+ @marshal_with(workflow_model)
def get(self, pipeline: Pipeline):
"""
Get published pipeline
@@ -664,7 +664,7 @@ class PublishedAllRagPipelineApi(Resource):
@account_initialization_required
@edit_permission_required
@get_rag_pipeline
- @marshal_with(workflow_pagination_fields)
+ @marshal_with(workflow_pagination_model)
def get(self, pipeline: Pipeline):
"""
Get published workflows
@@ -708,7 +708,7 @@ class RagPipelineByIdApi(Resource):
@account_initialization_required
@edit_permission_required
@get_rag_pipeline
- @marshal_with(workflow_fields)
+ @marshal_with(workflow_model)
def patch(self, pipeline: Pipeline, workflow_id: str):
"""
Update workflow attributes
@@ -830,7 +830,7 @@ class RagPipelineWorkflowRunListApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
- @marshal_with(workflow_run_pagination_fields)
+ @marshal_with(workflow_run_pagination_model)
def get(self, pipeline: Pipeline):
"""
Get workflow run list
@@ -858,7 +858,7 @@ class RagPipelineWorkflowRunDetailApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
- @marshal_with(workflow_run_detail_fields)
+ @marshal_with(workflow_run_detail_model)
def get(self, pipeline: Pipeline, run_id):
"""
Get workflow run detail
@@ -877,7 +877,7 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
- @marshal_with(workflow_run_node_execution_list_fields)
+ @marshal_with(workflow_run_node_execution_list_model)
def get(self, pipeline: Pipeline, run_id: str):
"""
Get workflow run node execution list
@@ -911,7 +911,7 @@ class RagPipelineWorkflowLastRunApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
- @marshal_with(workflow_run_node_execution_fields)
+ @marshal_with(workflow_run_node_execution_model)
def get(self, pipeline: Pipeline, node_id: str):
rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
@@ -952,7 +952,7 @@ class RagPipelineDatasourceVariableApi(Resource):
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
- @marshal_with(workflow_run_node_execution_fields)
+ @marshal_with(workflow_run_node_execution_model)
def post(self, pipeline: Pipeline):
"""
Set datasource variables
@@ -975,6 +975,11 @@ class RagPipelineRecommendedPluginApi(Resource):
@login_required
@account_initialization_required
def get(self):
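+ # Optional "type" query arg is forwarded to get_recommended_plugins (defaults to "all").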
+ parser = reqparse.RequestParser()
+ parser.add_argument("type", type=str, location="args", required=False, default="all")
+ args = parser.parse_args()
+ type = args["type"]
+
rag_pipeline_service = RagPipelineService()
- recommended_plugins = rag_pipeline_service.get_recommended_plugins()
+ recommended_plugins = rag_pipeline_service.get_recommended_plugins(type)
return recommended_plugins
diff --git a/api/controllers/console/explore/banner.py b/api/controllers/console/explore/banner.py
new file mode 100644
index 0000000000..da306fbc9d
--- /dev/null
+++ b/api/controllers/console/explore/banner.py
@@ -0,0 +1,43 @@
+from flask import request
+from flask_restx import Resource
+
+from controllers.console import api
+from controllers.console.explore.wraps import explore_banner_enabled
+from extensions.ext_database import db
+from models.model import ExporleBanner
+
+
+class BannerApi(Resource):
+ """Resource for banner list."""
+
+ @explore_banner_enabled
+ def get(self):
+ """Get banner list."""
+ language = request.args.get("language", "en-US")
+
+ # Build base query for enabled banners
+ base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == "enabled")
+
+ # Try to get banners in the requested language
+ banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all()
+
+ # Fallback to en-US if no banners found and language is not en-US
+ if not banners and language != "en-US":
+ banners = base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort).all()
+ # Convert banners to serializable format
+ result = []
+ for banner in banners:
+ banner_data = {
+ "id": banner.id,
+ "content": banner.content, # Already parsed as JSON by SQLAlchemy
+ "link": banner.link,
+ "sort": banner.sort,
+ "status": banner.status,
+ "created_at": banner.created_at.isoformat() if banner.created_at else None,
+ }
+ result.append(banner_data)
+
+ return result
+
+
+api.add_resource(BannerApi, "/explore/banners")
diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py
index 5901eca915..a6e5b2822a 100644
--- a/api/controllers/console/explore/completion.py
+++ b/api/controllers/console/explore/completion.py
@@ -40,7 +40,7 @@ from .. import console_ns
logger = logging.getLogger(__name__)
-class CompletionMessagePayload(BaseModel):
+class CompletionMessageExplorePayload(BaseModel):
inputs: dict[str, Any]
query: str = ""
files: list[dict[str, Any]] | None = None
@@ -71,7 +71,7 @@ class ChatMessagePayload(BaseModel):
raise ValueError("must be a valid UUID") from exc
-register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload)
+register_schema_models(console_ns, CompletionMessageExplorePayload, ChatMessagePayload)
# define completion api for user
@@ -80,13 +80,13 @@ register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload)
endpoint="installed_app_completion",
)
class CompletionApi(InstalledAppResource):
- @console_ns.expect(console_ns.models[CompletionMessagePayload.__name__])
+ @console_ns.expect(console_ns.models[CompletionMessageExplorePayload.__name__])
def post(self, installed_app):
app_model = installed_app.app
if app_model.mode != AppMode.COMPLETION:
raise NotCompletionAppError()
- payload = CompletionMessagePayload.model_validate(console_ns.payload or {})
+ payload = CompletionMessageExplorePayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
streaming = payload.response_mode == "streaming"
diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py
index 92da591ab4..933c80f509 100644
--- a/api/controllers/console/explore/conversation.py
+++ b/api/controllers/console/explore/conversation.py
@@ -1,9 +1,7 @@
from typing import Any
-from uuid import UUID
from flask import request
-from flask_restx import marshal_with
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Field, TypeAdapter, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
@@ -12,7 +10,12 @@ from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
-from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
+from fields.conversation_fields import (
+ ConversationInfiniteScrollPagination,
+ ResultResponse,
+ SimpleConversation,
+)
+from libs.helper import UUIDStrOrEmpty
from libs.login import current_user
from models import Account
from models.model import AppMode
@@ -24,7 +27,7 @@ from .. import console_ns
class ConversationListQuery(BaseModel):
- last_id: UUID | None = None
+ last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
pinned: bool | None = None
@@ -49,7 +52,6 @@ register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayl
endpoint="installed_app_conversations",
)
class ConversationListApi(InstalledAppResource):
- @marshal_with(conversation_infinite_scroll_pagination_fields)
@console_ns.expect(console_ns.models[ConversationListQuery.__name__])
def get(self, installed_app):
app_model = installed_app.app
@@ -73,7 +75,7 @@ class ConversationListApi(InstalledAppResource):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
with Session(db.engine) as session:
- return WebConversationService.pagination_by_last_id(
+ pagination = WebConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=current_user,
@@ -82,6 +84,13 @@ class ConversationListApi(InstalledAppResource):
invoke_from=InvokeFrom.EXPLORE,
pinned=args.pinned,
)
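+ # Serialize the ORM conversations into the Pydantic pagination model (replaces the removed marshal_with decorator)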
+ adapter = TypeAdapter(SimpleConversation)
+ conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data]
+ return ConversationInfiniteScrollPagination(
+ limit=pagination.limit,
+ has_more=pagination.has_more,
+ data=conversations,
+ ).model_dump(mode="json")
except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
@@ -105,7 +114,7 @@ class ConversationApi(InstalledAppResource):
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
- return {"result": "success"}, 204
+ return ResultResponse(result="success").model_dump(mode="json"), 204
@console_ns.route(
@@ -113,7 +122,6 @@ class ConversationApi(InstalledAppResource):
endpoint="installed_app_conversation_rename",
)
class ConversationRenameApi(InstalledAppResource):
- @marshal_with(simple_conversation_fields)
@console_ns.expect(console_ns.models[ConversationRenamePayload.__name__])
def post(self, installed_app, c_id):
app_model = installed_app.app
@@ -128,9 +136,14 @@ class ConversationRenameApi(InstalledAppResource):
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
- return ConversationService.rename(
+ conversation = ConversationService.rename(
app_model, conversation_id, current_user, payload.name, payload.auto_generate
)
+ return (
+ TypeAdapter(SimpleConversation)
+ .validate_python(conversation, from_attributes=True)
+ .model_dump(mode="json")
+ )
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@@ -155,7 +168,7 @@ class ConversationPinApi(InstalledAppResource):
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
- return {"result": "success"}
+ return ResultResponse(result="success").model_dump(mode="json")
@console_ns.route(
@@ -174,4 +187,4 @@ class ConversationUnPinApi(InstalledAppResource):
raise ValueError("current_user must be an Account instance")
WebConversationService.unpin(app_model, conversation_id, current_user)
- return {"result": "success"}
+ return ResultResponse(result="success").model_dump(mode="json")
diff --git a/api/controllers/console/explore/error.py b/api/controllers/console/explore/error.py
index 1e05ff4206..e96fa64f84 100644
--- a/api/controllers/console/explore/error.py
+++ b/api/controllers/console/explore/error.py
@@ -29,3 +29,25 @@ class AppAccessDeniedError(BaseHTTPException):
error_code = "access_denied"
description = "App access denied."
code = 403
+
+
+class TrialAppNotAllowed(BaseHTTPException):
+ """*403* `Trial App Not Allowed`
+
+ Raise if the app is not available for trial.
+ """
+
+ error_code = "trial_app_not_allowed"
+ code = 403
+ description = "The app is not allowed to be trialed."
+
+
+class TrialAppLimitExceeded(BaseHTTPException):
+ """*403* `Trial App Limit Exceeded`
+
+ Raise if the user has exceeded the trial app limit.
+ """
+
+ error_code = "trial_app_limit_exceeded"
+ code = 403
+ description = "The user has exceeded the trial app limit."
diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py
index 3c95779475..aca766567f 100644
--- a/api/controllers/console/explore/installed_app.py
+++ b/api/controllers/console/explore/installed_app.py
@@ -2,15 +2,17 @@ import logging
from typing import Any
from flask import request
-from flask_restx import Resource, inputs, marshal_with, reqparse
+from flask_restx import Resource, fields, marshal_with
+from pydantic import BaseModel, Field
from sqlalchemy import and_, select
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
+from controllers.common.schema import get_or_create_model
from controllers.console import console_ns
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from extensions.ext_database import db
-from fields.installed_app_fields import installed_app_list_fields
+from fields.installed_app_fields import app_fields, installed_app_fields, installed_app_list_fields
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import App, InstalledApp, RecommendedApp
@@ -18,22 +20,46 @@ from services.account_service import TenantService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
+
+class InstalledAppCreatePayload(BaseModel):
+ app_id: str
+
+
+class InstalledAppUpdatePayload(BaseModel):
+ is_pinned: bool | None = None
+
+
+class InstalledAppsListQuery(BaseModel):
+ app_id: str | None = Field(default=None, description="App ID to filter by")
+
+
logger = logging.getLogger(__name__)
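+# Build flask-restx models from the shared installed-app field dicts so they can be used with marshal_with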
+app_model = get_or_create_model("InstalledAppInfo", app_fields)
+
+installed_app_fields_copy = installed_app_fields.copy()
+installed_app_fields_copy["app"] = fields.Nested(app_model)
+installed_app_model = get_or_create_model("InstalledApp", installed_app_fields_copy)
+
+installed_app_list_fields_copy = installed_app_list_fields.copy()
+installed_app_list_fields_copy["installed_apps"] = fields.List(fields.Nested(installed_app_model))
+installed_app_list_model = get_or_create_model("InstalledAppList", installed_app_list_fields_copy)
+
+
@console_ns.route("/installed-apps")
class InstalledAppsListApi(Resource):
@login_required
@account_initialization_required
- @marshal_with(installed_app_list_fields)
+ @marshal_with(installed_app_list_model)
def get(self):
- app_id = request.args.get("app_id", default=None, type=str)
+ query = InstalledAppsListQuery.model_validate(request.args.to_dict())
current_user, current_tenant_id = current_account_with_tenant()
- if app_id:
+ if query.app_id:
installed_apps = db.session.scalars(
select(InstalledApp).where(
- and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id)
+ and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == query.app_id)
)
).all()
else:
@@ -105,26 +131,25 @@ class InstalledAppsListApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("apps")
def post(self):
- parser = reqparse.RequestParser().add_argument("app_id", type=str, required=True, help="Invalid app_id")
- args = parser.parse_args()
+ payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {})
- recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first()
+ recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).first()
if recommended_app is None:
- raise NotFound("App not found")
+ raise NotFound("Recommended app not found")
_, current_tenant_id = current_account_with_tenant()
- app = db.session.query(App).where(App.id == args["app_id"]).first()
+ app = db.session.query(App).where(App.id == payload.app_id).first()
if app is None:
- raise NotFound("App not found")
+ raise NotFound("App entity not found")
if not app.is_public:
raise Forbidden("You can't install a non-public app")
installed_app = (
db.session.query(InstalledApp)
- .where(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id))
+ .where(and_(InstalledApp.app_id == payload.app_id, InstalledApp.tenant_id == current_tenant_id))
.first()
)
@@ -133,7 +158,7 @@ class InstalledAppsListApi(Resource):
recommended_app.install_count += 1
new_installed_app = InstalledApp(
- app_id=args["app_id"],
+ app_id=payload.app_id,
tenant_id=current_tenant_id,
app_owner_tenant_id=app.tenant_id,
is_pinned=False,
@@ -163,12 +188,11 @@ class InstalledAppApi(InstalledAppResource):
return {"result": "success", "message": "App uninstalled successfully"}, 204
def patch(self, installed_app):
- parser = reqparse.RequestParser().add_argument("is_pinned", type=inputs.boolean)
- args = parser.parse_args()
+ payload = InstalledAppUpdatePayload.model_validate(console_ns.payload or {})
commit_args = False
- if "is_pinned" in args:
- installed_app.is_pinned = args["is_pinned"]
+ if payload.is_pinned is not None:
+ installed_app.is_pinned = payload.is_pinned
commit_args = True
if commit_args:
diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py
index 229b7c8865..88487ac96f 100644
--- a/api/controllers/console/explore/message.py
+++ b/api/controllers/console/explore/message.py
@@ -1,10 +1,8 @@
import logging
from typing import Literal
-from uuid import UUID
from flask import request
-from flask_restx import marshal_with
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.schema import register_schema_models
@@ -24,8 +22,10 @@ from controllers.console.explore.wraps import InstalledAppResource
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
-from fields.message_fields import message_infinite_scroll_pagination_fields
+from fields.conversation_fields import ResultResponse
+from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse
from libs import helper
+from libs.helper import UUIDStrOrEmpty
from libs.login import current_account_with_tenant
from models.model import AppMode
from services.app_generate_service import AppGenerateService
@@ -44,8 +44,8 @@ logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel):
- conversation_id: UUID
- first_id: UUID | None = None
+ conversation_id: UUIDStrOrEmpty
+ first_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
@@ -66,7 +66,6 @@ register_schema_models(console_ns, MessageListQuery, MessageFeedbackPayload, Mor
endpoint="installed_app_messages",
)
class MessageListApi(InstalledAppResource):
- @marshal_with(message_infinite_scroll_pagination_fields)
@console_ns.expect(console_ns.models[MessageListQuery.__name__])
def get(self, installed_app):
current_user, _ = current_account_with_tenant()
@@ -78,13 +77,20 @@ class MessageListApi(InstalledAppResource):
args = MessageListQuery.model_validate(request.args.to_dict())
try:
- return MessageService.pagination_by_first_id(
+ pagination = MessageService.pagination_by_first_id(
app_model,
current_user,
str(args.conversation_id),
str(args.first_id) if args.first_id else None,
args.limit,
)
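+ # Convert the ORM messages into the Pydantic pagination model (marshal_with is no longer applied here)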
+ adapter = TypeAdapter(MessageListItem)
+ items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
+ return MessageInfiniteScrollPagination(
+ limit=pagination.limit,
+ has_more=pagination.has_more,
+ data=items,
+ ).model_dump(mode="json")
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except FirstMessageNotExistsError:
@@ -116,7 +122,7 @@ class MessageFeedbackApi(InstalledAppResource):
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
- return {"result": "success"}
+ return ResultResponse(result="success").model_dump(mode="json")
@console_ns.route(
@@ -201,4 +207,4 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
logger.exception("internal server error.")
raise InternalServerError()
- return {"data": questions}
+ return SuggestedQuestionsResponse(data=questions).model_dump(mode="json")
diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py
index 9c6b2aedfb..660a4d5aea 100644
--- a/api/controllers/console/explore/parameter.py
+++ b/api/controllers/console/explore/parameter.py
@@ -1,5 +1,3 @@
-from flask_restx import marshal_with
-
from controllers.common import fields
from controllers.console import console_ns
from controllers.console.app.error import AppUnavailableError
@@ -13,7 +11,6 @@ from services.app_service import AppService
class AppParameterApi(InstalledAppResource):
"""Resource for app variables."""
- @marshal_with(fields.parameters_fields)
def get(self, installed_app: InstalledApp):
"""Retrieve app parameters."""
app_model = installed_app.app
@@ -37,7 +34,8 @@ class AppParameterApi(InstalledAppResource):
user_input_form = features_dict.get("user_input_form", [])
- return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
+ parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
+ return fields.Parameters.model_validate(parameters).model_dump(mode="json")
@console_ns.route("/installed-apps//meta", endpoint="installed_app_meta")
diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py
index 2b2f807694..c9920c97cf 100644
--- a/api/controllers/console/explore/recommended_app.py
+++ b/api/controllers/console/explore/recommended_app.py
@@ -3,6 +3,7 @@ from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from constants.languages import languages
+from controllers.common.schema import get_or_create_model
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required
from libs.helper import AppIconUrlField
@@ -19,8 +20,10 @@ app_fields = {
"icon_background": fields.String,
}
+app_model = get_or_create_model("RecommendedAppInfo", app_fields)
+
recommended_app_fields = {
- "app": fields.Nested(app_fields, attribute="app"),
+ "app": fields.Nested(app_model, attribute="app"),
"app_id": fields.String,
"description": fields.String(attribute="description"),
"copyright": fields.String,
@@ -29,13 +32,18 @@ recommended_app_fields = {
"category": fields.String,
"position": fields.Integer,
"is_listed": fields.Boolean,
+ "can_trial": fields.Boolean,
}
+recommended_app_model = get_or_create_model("RecommendedApp", recommended_app_fields)
+
recommended_app_list_fields = {
- "recommended_apps": fields.List(fields.Nested(recommended_app_fields)),
+ "recommended_apps": fields.List(fields.Nested(recommended_app_model)),
"categories": fields.List(fields.String),
}
+recommended_app_list_model = get_or_create_model("RecommendedAppList", recommended_app_list_fields)
+
class RecommendedAppsQuery(BaseModel):
language: str | None = Field(default=None)
@@ -52,7 +60,7 @@ class RecommendedAppListApi(Resource):
@console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__])
@login_required
@account_initialization_required
- @marshal_with(recommended_app_list_fields)
+ @marshal_with(recommended_app_list_model)
def get(self):
# language args
args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py
index 6a9e274a0e..ea3de91741 100644
--- a/api/controllers/console/explore/saved_message.py
+++ b/api/controllers/console/explore/saved_message.py
@@ -1,55 +1,33 @@
-from uuid import UUID
-
from flask import request
-from flask_restx import fields, marshal_with
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
-from fields.conversation_fields import message_file_fields
-from libs.helper import TimestampField
+from fields.conversation_fields import ResultResponse
+from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
+from libs.helper import UUIDStrOrEmpty
from libs.login import current_account_with_tenant
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
class SavedMessageListQuery(BaseModel):
- last_id: UUID | None = None
+ last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class SavedMessageCreatePayload(BaseModel):
- message_id: UUID
+ message_id: UUIDStrOrEmpty
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)
-feedback_fields = {"rating": fields.String}
-
-message_fields = {
- "id": fields.String,
- "inputs": fields.Raw,
- "query": fields.String,
- "answer": fields.String,
- "message_files": fields.List(fields.Nested(message_file_fields)),
- "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
- "created_at": TimestampField,
-}
-
-
@console_ns.route("/installed-apps//saved-messages", endpoint="installed_app_saved_messages")
class SavedMessageListApi(InstalledAppResource):
- saved_message_infinite_scroll_pagination_fields = {
- "limit": fields.Integer,
- "has_more": fields.Boolean,
- "data": fields.List(fields.Nested(message_fields)),
- }
-
- @marshal_with(saved_message_infinite_scroll_pagination_fields)
@console_ns.expect(console_ns.models[SavedMessageListQuery.__name__])
def get(self, installed_app):
current_user, _ = current_account_with_tenant()
@@ -59,12 +37,19 @@ class SavedMessageListApi(InstalledAppResource):
args = SavedMessageListQuery.model_validate(request.args.to_dict())
- return SavedMessageService.pagination_by_last_id(
+ pagination = SavedMessageService.pagination_by_last_id(
app_model,
current_user,
str(args.last_id) if args.last_id else None,
args.limit,
)
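+ # Convert the saved messages into the Pydantic pagination model before returning JSON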
+ adapter = TypeAdapter(SavedMessageItem)
+ items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
+ return SavedMessageInfiniteScrollPagination(
+ limit=pagination.limit,
+ has_more=pagination.has_more,
+ data=items,
+ ).model_dump(mode="json")
@console_ns.expect(console_ns.models[SavedMessageCreatePayload.__name__])
def post(self, installed_app):
@@ -80,7 +65,7 @@ class SavedMessageListApi(InstalledAppResource):
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
- return {"result": "success"}
+ return ResultResponse(result="success").model_dump(mode="json")
@console_ns.route(
@@ -98,4 +83,4 @@ class SavedMessageApi(InstalledAppResource):
SavedMessageService.delete(app_model, current_user, message_id)
- return {"result": "success"}, 204
+ return ResultResponse(result="success").model_dump(mode="json"), 204
diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py
new file mode 100644
index 0000000000..1eb0cdb019
--- /dev/null
+++ b/api/controllers/console/explore/trial.py
@@ -0,0 +1,555 @@
+import logging
+from typing import Any, cast
+
+from flask import request
+from flask_restx import Resource, fields, marshal, marshal_with, reqparse
+from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
+
+import services
+from controllers.common.fields import Parameters as ParametersResponse
+from controllers.common.fields import Site as SiteResponse
+from controllers.common.schema import get_or_create_model
+from controllers.console import api, console_ns
+from controllers.console.app.error import (
+ AppUnavailableError,
+ AudioTooLargeError,
+ CompletionRequestError,
+ ConversationCompletedError,
+ NeedAddIdsError,
+ NoAudioUploadedError,
+ ProviderModelCurrentlyNotSupportError,
+ ProviderNotInitializeError,
+ ProviderNotSupportSpeechToTextError,
+ ProviderQuotaExceededError,
+ UnsupportedAudioTypeError,
+)
+from controllers.console.app.wraps import get_app_model_with_trial
+from controllers.console.explore.error import (
+ AppSuggestedQuestionsAfterAnswerDisabledError,
+ NotChatAppError,
+ NotCompletionAppError,
+ NotWorkflowAppError,
+)
+from controllers.console.explore.wraps import TrialAppResource, trial_feature_enable
+from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
+from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
+from core.app.apps.base_app_queue_manager import AppQueueManager
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.errors.error import (
+ ModelCurrentlyNotSupportError,
+ ProviderTokenNotInitError,
+ QuotaExceededError,
+)
+from core.model_runtime.errors.invoke import InvokeError
+from core.workflow.graph_engine.manager import GraphEngineManager
+from extensions.ext_database import db
+from fields.app_fields import (
+ app_detail_fields_with_site,
+ deleted_tool_fields,
+ model_config_fields,
+ site_fields,
+ tag_fields,
+)
+from fields.dataset_fields import dataset_fields
+from fields.member_fields import build_simple_account_model
+from fields.workflow_fields import (
+ conversation_variable_fields,
+ pipeline_variable_fields,
+ workflow_fields,
+ workflow_partial_fields,
+)
+from libs import helper
+from libs.helper import uuid_value
+from libs.login import current_user
+from models import Account
+from models.account import TenantStatus
+from models.model import AppMode, Site
+from models.workflow import Workflow
+from services.app_generate_service import AppGenerateService
+from services.app_service import AppService
+from services.audio_service import AudioService
+from services.dataset_service import DatasetService
+from services.errors.audio import (
+ AudioTooLargeServiceError,
+ NoAudioUploadedServiceError,
+ ProviderNotSupportSpeechToTextServiceError,
+ UnsupportedAudioTypeServiceError,
+)
+from services.errors.conversation import ConversationNotExistsError
+from services.errors.llm import InvokeRateLimitError
+from services.errors.message import (
+ MessageNotExistsError,
+ SuggestedQuestionsAfterAnswerDisabledError,
+)
+from services.message_service import MessageService
+from services.recommended_app_service import RecommendedAppService
+
+logger = logging.getLogger(__name__)
+
+
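+# Register flask-restx models for the trial endpoints, reusing the shared field definitions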
+model_config_model = get_or_create_model("TrialAppModelConfig", model_config_fields)
+workflow_partial_model = get_or_create_model("TrialWorkflowPartial", workflow_partial_fields)
+deleted_tool_model = get_or_create_model("TrialDeletedTool", deleted_tool_fields)
+tag_model = get_or_create_model("TrialTag", tag_fields)
+site_model = get_or_create_model("TrialSite", site_fields)
+
+app_detail_fields_with_site_copy = app_detail_fields_with_site.copy()
+app_detail_fields_with_site_copy["model_config"] = fields.Nested(
+ model_config_model, attribute="app_model_config", allow_null=True
+)
+app_detail_fields_with_site_copy["workflow"] = fields.Nested(workflow_partial_model, allow_null=True)
+app_detail_fields_with_site_copy["deleted_tools"] = fields.List(fields.Nested(deleted_tool_model))
+app_detail_fields_with_site_copy["tags"] = fields.List(fields.Nested(tag_model))
+app_detail_fields_with_site_copy["site"] = fields.Nested(site_model)
+app_detail_with_site_model = get_or_create_model("TrialAppDetailWithSite", app_detail_fields_with_site_copy)
+
+simple_account_model = build_simple_account_model(console_ns)
+conversation_variable_model = get_or_create_model("TrialConversationVariable", conversation_variable_fields)
+pipeline_variable_model = get_or_create_model("TrialPipelineVariable", pipeline_variable_fields)
+
+workflow_fields_copy = workflow_fields.copy()
+workflow_fields_copy["created_by"] = fields.Nested(simple_account_model, attribute="created_by_account")
+workflow_fields_copy["updated_by"] = fields.Nested(
+ simple_account_model, attribute="updated_by_account", allow_null=True
+)
+workflow_fields_copy["conversation_variables"] = fields.List(fields.Nested(conversation_variable_model))
+workflow_fields_copy["rag_pipeline_variables"] = fields.List(fields.Nested(pipeline_variable_model))
+workflow_model = get_or_create_model("TrialWorkflow", workflow_fields_copy)
+
+
+class TrialAppWorkflowRunApi(TrialAppResource):
+ def post(self, trial_app):
+ """
+ Run workflow
+ """
+ app_model = trial_app
+ if not app_model:
+ raise NotWorkflowAppError()
+ app_mode = AppMode.value_of(app_model.mode)
+ if app_mode != AppMode.WORKFLOW:
+ raise NotWorkflowAppError()
+
+ parser = reqparse.RequestParser()
+ parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
+ parser.add_argument("files", type=list, required=False, location="json")
+ args = parser.parse_args()
+ assert current_user is not None
+ try:
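+ # Get IDs before they might be detached from session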
+ app_id = app_model.id
+ user_id = current_user.id
+ response = AppGenerateService.generate(
+ app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
+ )
+ RecommendedAppService.add_trial_app_record(app_id, user_id)
+ return helper.compact_generate_response(response)
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
+ except QuotaExceededError:
+ raise ProviderQuotaExceededError()
+ except ModelCurrentlyNotSupportError:
+ raise ProviderModelCurrentlyNotSupportError()
+ except InvokeError as e:
+ raise CompletionRequestError(e.description)
+ except InvokeRateLimitError as ex:
+ raise InvokeRateLimitHttpError(ex.description)
+ except ValueError as e:
+ raise e
+ except Exception:
+ logger.exception("internal server error.")
+ raise InternalServerError()
+
+
+class TrialAppWorkflowTaskStopApi(TrialAppResource):
+ def post(self, trial_app, task_id: str):
+ """
+ Stop workflow task
+ """
+ app_model = trial_app
+ if not app_model:
+ raise NotWorkflowAppError()
+ app_mode = AppMode.value_of(app_model.mode)
+ if app_mode != AppMode.WORKFLOW:
+ raise NotWorkflowAppError()
+ assert current_user is not None
+
+ # Stop using both mechanisms for backward compatibility
+ # Legacy stop flag mechanism (without user check)
+ AppQueueManager.set_stop_flag_no_user_check(task_id)
+
+ # New graph engine command channel mechanism
+ GraphEngineManager.send_stop_command(task_id)
+
+ return {"result": "success"}
+
+
+class TrialChatApi(TrialAppResource):
+ @trial_feature_enable
+ def post(self, trial_app):
+ app_model = trial_app
+ app_mode = AppMode.value_of(app_model.mode)
+ if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+ raise NotChatAppError()
+
+ parser = reqparse.RequestParser()
+ parser.add_argument("inputs", type=dict, required=True, location="json")
+ parser.add_argument("query", type=str, required=True, location="json")
+ parser.add_argument("files", type=list, required=False, location="json")
+ parser.add_argument("conversation_id", type=uuid_value, location="json")
+ parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
+ parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
+ args = parser.parse_args()
+
+ args["auto_generate_name"] = False
+
+ try:
+ if not isinstance(current_user, Account):
+ raise ValueError("current_user must be an Account instance")
+
+ # Get IDs before they might be detached from session
+ app_id = app_model.id
+ user_id = current_user.id
+
+ response = AppGenerateService.generate(
+ app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
+ )
+ RecommendedAppService.add_trial_app_record(app_id, user_id)
+ return helper.compact_generate_response(response)
+ except services.errors.conversation.ConversationNotExistsError:
+ raise NotFound("Conversation Not Exists.")
+ except services.errors.conversation.ConversationCompletedError:
+ raise ConversationCompletedError()
+ except services.errors.app_model_config.AppModelConfigBrokenError:
+ logger.exception("App model config broken.")
+ raise AppUnavailableError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
+ except QuotaExceededError:
+ raise ProviderQuotaExceededError()
+ except ModelCurrentlyNotSupportError:
+ raise ProviderModelCurrentlyNotSupportError()
+ except InvokeError as e:
+ raise CompletionRequestError(e.description)
+ except InvokeRateLimitError as ex:
+ raise InvokeRateLimitHttpError(ex.description)
+ except ValueError as e:
+ raise e
+ except Exception:
+ logger.exception("internal server error.")
+ raise InternalServerError()
+
+
+class TrialMessageSuggestedQuestionApi(TrialAppResource):
+ @trial_feature_enable
+ def get(self, trial_app, message_id):
+ app_model = trial_app
+ app_mode = AppMode.value_of(app_model.mode)
+ if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+ raise NotChatAppError()
+
+ message_id = str(message_id)
+
+ try:
+ if not isinstance(current_user, Account):
+ raise ValueError("current_user must be an Account instance")
+ questions = MessageService.get_suggested_questions_after_answer(
+ app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
+ )
+ except MessageNotExistsError:
+ raise NotFound("Message not found")
+ except ConversationNotExistsError:
+ raise NotFound("Conversation not found")
+ except SuggestedQuestionsAfterAnswerDisabledError:
+ raise AppSuggestedQuestionsAfterAnswerDisabledError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
+ except QuotaExceededError:
+ raise ProviderQuotaExceededError()
+ except ModelCurrentlyNotSupportError:
+ raise ProviderModelCurrentlyNotSupportError()
+ except InvokeError as e:
+ raise CompletionRequestError(e.description)
+ except Exception:
+ logger.exception("internal server error.")
+ raise InternalServerError()
+
+ return {"data": questions}
+
+
+class TrialChatAudioApi(TrialAppResource):
+ @trial_feature_enable
+ def post(self, trial_app):
+ app_model = trial_app
+
+ file = request.files["file"]
+
+ try:
+ if not isinstance(current_user, Account):
+ raise ValueError("current_user must be an Account instance")
+
+ # Get IDs before they might be detached from session
+ app_id = app_model.id
+ user_id = current_user.id
+
+ response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None)
+ RecommendedAppService.add_trial_app_record(app_id, user_id)
+ return response
+ except services.errors.app_model_config.AppModelConfigBrokenError:
+ logger.exception("App model config broken.")
+ raise AppUnavailableError()
+ except NoAudioUploadedServiceError:
+ raise NoAudioUploadedError()
+ except AudioTooLargeServiceError as e:
+ raise AudioTooLargeError(str(e))
+ except UnsupportedAudioTypeServiceError:
+ raise UnsupportedAudioTypeError()
+ except ProviderNotSupportSpeechToTextServiceError:
+ raise ProviderNotSupportSpeechToTextError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
+ except QuotaExceededError:
+ raise ProviderQuotaExceededError()
+ except ModelCurrentlyNotSupportError:
+ raise ProviderModelCurrentlyNotSupportError()
+ except InvokeError as e:
+ raise CompletionRequestError(e.description)
+ except ValueError as e:
+ raise e
+ except Exception:
+ logger.exception("internal server error.")
+ raise InternalServerError()
+
+
+class TrialChatTextApi(TrialAppResource):
+ @trial_feature_enable
+ def post(self, trial_app):
+ app_model = trial_app
+ try:
+ parser = reqparse.RequestParser()
+ parser.add_argument("message_id", type=str, required=False, location="json")
+ parser.add_argument("voice", type=str, location="json")
+ parser.add_argument("text", type=str, location="json")
+ parser.add_argument("streaming", type=bool, location="json")
+ args = parser.parse_args()
+
+ message_id = args.get("message_id", None)
+ text = args.get("text", None)
+ voice = args.get("voice", None)
+ if not isinstance(current_user, Account):
+ raise ValueError("current_user must be an Account instance")
+
+ # Get IDs before they might be detached from session
+ app_id = app_model.id
+ user_id = current_user.id
+
+ response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
+ RecommendedAppService.add_trial_app_record(app_id, user_id)
+ return response
+ except services.errors.app_model_config.AppModelConfigBrokenError:
+ logger.exception("App model config broken.")
+ raise AppUnavailableError()
+ except NoAudioUploadedServiceError:
+ raise NoAudioUploadedError()
+ except AudioTooLargeServiceError as e:
+ raise AudioTooLargeError(str(e))
+ except UnsupportedAudioTypeServiceError:
+ raise UnsupportedAudioTypeError()
+ except ProviderNotSupportSpeechToTextServiceError:
+ raise ProviderNotSupportSpeechToTextError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
+ except QuotaExceededError:
+ raise ProviderQuotaExceededError()
+ except ModelCurrentlyNotSupportError:
+ raise ProviderModelCurrentlyNotSupportError()
+ except InvokeError as e:
+ raise CompletionRequestError(e.description)
+ except ValueError as e:
+ raise e
+ except Exception:
+ logger.exception("internal server error.")
+ raise InternalServerError()
+
+
+class TrialCompletionApi(TrialAppResource):
+ @trial_feature_enable
+ def post(self, trial_app):
+ app_model = trial_app
+ if app_model.mode != "completion":
+ raise NotCompletionAppError()
+
+ parser = reqparse.RequestParser()
+ parser.add_argument("inputs", type=dict, required=True, location="json")
+ parser.add_argument("query", type=str, location="json", default="")
+ parser.add_argument("files", type=list, required=False, location="json")
+ parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
+ parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
+ args = parser.parse_args()
+
+ streaming = args["response_mode"] == "streaming"
+ args["auto_generate_name"] = False
+
+ try:
+ if not isinstance(current_user, Account):
+ raise ValueError("current_user must be an Account instance")
+
+ # Get IDs before they might be detached from session
+ app_id = app_model.id
+ user_id = current_user.id
+
+ response = AppGenerateService.generate(
+ app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
+ )
+
+ RecommendedAppService.add_trial_app_record(app_id, user_id)
+ return helper.compact_generate_response(response)
+ except services.errors.conversation.ConversationNotExistsError:
+ raise NotFound("Conversation Not Exists.")
+ except services.errors.conversation.ConversationCompletedError:
+ raise ConversationCompletedError()
+ except services.errors.app_model_config.AppModelConfigBrokenError:
+ logger.exception("App model config broken.")
+ raise AppUnavailableError()
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
+ except QuotaExceededError:
+ raise ProviderQuotaExceededError()
+ except ModelCurrentlyNotSupportError:
+ raise ProviderModelCurrentlyNotSupportError()
+ except InvokeError as e:
+ raise CompletionRequestError(e.description)
+ except ValueError as e:
+ raise e
+ except Exception:
+ logger.exception("internal server error.")
+ raise InternalServerError()
+
+
+class TrialSitApi(Resource):
+ """Resource for trial app sites."""
+
+ @trial_feature_enable
+ @get_app_model_with_trial
+ def get(self, app_model):
+ """Retrieve app site info.
+
+ Returns the site configuration for the application including theme, icons, and text.
+ """
+ site = db.session.query(Site).where(Site.app_id == app_model.id).first()
+
+ if not site:
+ raise Forbidden()
+
+ assert app_model.tenant
+ if app_model.tenant.status == TenantStatus.ARCHIVE:
+ raise Forbidden()
+
+ return SiteResponse.model_validate(site).model_dump(mode="json")
+
+
+class TrialAppParameterApi(Resource):
+ """Resource for app variables."""
+
+ @trial_feature_enable
+ @get_app_model_with_trial
+ def get(self, app_model):
+ """Retrieve app parameters."""
+
+ if app_model is None:
+ raise AppUnavailableError()
+
+ if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
+ workflow = app_model.workflow
+ if workflow is None:
+ raise AppUnavailableError()
+
+ features_dict = workflow.features_dict
+ user_input_form = workflow.user_input_form(to_old_structure=True)
+ else:
+ app_model_config = app_model.app_model_config
+ if app_model_config is None:
+ raise AppUnavailableError()
+
+ features_dict = app_model_config.to_dict()
+
+ user_input_form = features_dict.get("user_input_form", [])
+
+ parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
+ return ParametersResponse.model_validate(parameters).model_dump(mode="json")
+
+
+class AppApi(Resource):
+ @trial_feature_enable
+ @get_app_model_with_trial
+ @marshal_with(app_detail_with_site_model)
+ def get(self, app_model):
+ """Get app detail"""
+
+ app_service = AppService()
+ app_model = app_service.get_app(app_model)
+
+ return app_model
+
+
+class AppWorkflowApi(Resource):
+ @trial_feature_enable
+ @get_app_model_with_trial
+ @marshal_with(workflow_model)
+ def get(self, app_model):
+ """Get workflow detail"""
+ if not app_model.workflow_id:
+ raise AppUnavailableError()
+
+ workflow = (
+ db.session.query(Workflow)
+ .where(
+ Workflow.id == app_model.workflow_id,
+ )
+ .first()
+ )
+ return workflow
+
+
+class DatasetListApi(Resource):
+ @trial_feature_enable
+ @get_app_model_with_trial
+ def get(self, app_model):
+ page = request.args.get("page", default=1, type=int)
+ limit = request.args.get("limit", default=20, type=int)
+ ids = request.args.getlist("ids")
+
+ tenant_id = app_model.tenant_id
+ if ids:
+ datasets, total = DatasetService.get_datasets_by_ids(ids, tenant_id)
+ else:
+ raise NeedAddIdsError()
+
+ data = cast(list[dict[str, Any]], marshal(datasets, dataset_fields))
+
+ response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
+ return response
+
+
+api.add_resource(TrialChatApi, "/trial-apps//chat-messages", endpoint="trial_app_chat_completion")
+
+api.add_resource(
+ TrialMessageSuggestedQuestionApi,
+ "/trial-apps//messages//suggested-questions",
+ endpoint="trial_app_suggested_question",
+)
+
+api.add_resource(TrialChatAudioApi, "/trial-apps//audio-to-text", endpoint="trial_app_audio")
+api.add_resource(TrialChatTextApi, "/trial-apps//text-to-audio", endpoint="trial_app_text")
+
+api.add_resource(TrialCompletionApi, "/trial-apps//completion-messages", endpoint="trial_app_completion")
+
+api.add_resource(TrialSitApi, "/trial-apps//site")
+
+api.add_resource(TrialAppParameterApi, "/trial-apps//parameters", endpoint="trial_app_parameters")
+
+api.add_resource(AppApi, "/trial-apps/", endpoint="trial_app")
+
+api.add_resource(TrialAppWorkflowRunApi, "/trial-apps//workflows/run", endpoint="trial_app_workflow_run")
+api.add_resource(TrialAppWorkflowTaskStopApi, "/trial-apps//workflows/tasks//stop")
+
+api.add_resource(AppWorkflowApi, "/trial-apps//workflows", endpoint="trial_app_workflow")
+api.add_resource(DatasetListApi, "/trial-apps//datasets", endpoint="trial_app_datasets")
diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py
index 2a97d312aa..38f0a04904 100644
--- a/api/controllers/console/explore/wraps.py
+++ b/api/controllers/console/explore/wraps.py
@@ -2,14 +2,15 @@ from collections.abc import Callable
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar
+from flask import abort
from flask_restx import Resource
from werkzeug.exceptions import NotFound
-from controllers.console.explore.error import AppAccessDeniedError
+from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
-from models import InstalledApp
+from models import AccountTrialAppRecord, App, InstalledApp, TrialApp
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
@@ -71,6 +72,61 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
return decorator
+def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
+ def decorator(view: Callable[Concatenate[App, P], R]):
+ @wraps(view)
+ def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
+ current_user, _ = current_account_with_tenant()
+
+ trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first()
+
+ if trial_app is None:
+ raise TrialAppNotAllowed()
+ app = trial_app.app
+
+ if app is None:
+ raise TrialAppNotAllowed()
+
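+ # Enforce the per-account trial usage limit for this app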
+ account_trial_app_record = (
+ db.session.query(AccountTrialAppRecord)
+ .where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id)
+ .first()
+ )
+ if account_trial_app_record:
+ if account_trial_app_record.count >= trial_app.trial_limit:
+ raise TrialAppLimitExceeded()
+
+ return view(app, *args, **kwargs)
+
+ return decorated
+
+ if view:
+ return decorator(view)
+ return decorator
+
+
+def trial_feature_enable(view: Callable[..., R]) -> Callable[..., R]:
+ @wraps(view)
+ def decorated(*args, **kwargs):
+ features = FeatureService.get_system_features()
+ if not features.enable_trial_app:
+ abort(403, "Trial app feature is not enabled.")
+ return view(*args, **kwargs)
+
+ return decorated
+
+
+def explore_banner_enabled(view: Callable[..., R]) -> Callable[..., R]:
+ @wraps(view)
+ def decorated(*args, **kwargs):
+ features = FeatureService.get_system_features()
+ if not features.enable_explore_banner:
+ abort(403, "Explore banner feature is not enabled.")
+ return view(*args, **kwargs)
+
+ return decorated
+
+
class InstalledAppResource(Resource):
# must be reversed if there are multiple decorators
@@ -80,3 +136,13 @@ class InstalledAppResource(Resource):
account_initialization_required,
login_required,
]
+
+
+class TrialAppResource(Resource):
+ # must be reversed if there are multiple decorators
+
+ method_decorators = [
+ trial_app_required,
+ account_initialization_required,
+ login_required,
+ ]
diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py
index 08f29b4655..efa46c9779 100644
--- a/api/controllers/console/extension.py
+++ b/api/controllers/console/extension.py
@@ -1,14 +1,32 @@
-from flask_restx import Resource, fields, marshal_with, reqparse
+from flask import request
+from flask_restx import Resource, fields, marshal_with
+from pydantic import BaseModel, Field
from constants import HIDDEN_VALUE
-from controllers.console import console_ns
-from controllers.console.wraps import account_initialization_required, setup_required
from fields.api_based_extension_fields import api_based_extension_fields
from libs.login import current_account_with_tenant, login_required
from models.api_based_extension import APIBasedExtension
from services.api_based_extension_service import APIBasedExtensionService
from services.code_based_extension_service import CodeBasedExtensionService
+from ..common.schema import register_schema_models
+from . import console_ns
+from .wraps import account_initialization_required, setup_required
+
+
+class CodeBasedExtensionQuery(BaseModel):
+ module: str
+
+
+class APIBasedExtensionPayload(BaseModel):
+ name: str = Field(description="Extension name")
+ api_endpoint: str = Field(description="API endpoint URL")
+ api_key: str = Field(description="API key for authentication")
+
+
+register_schema_models(console_ns, APIBasedExtensionPayload)
+
+
api_based_extension_model = console_ns.model("ApiBasedExtensionModel", api_based_extension_fields)
api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_model))
@@ -18,11 +36,7 @@ api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_m
class CodeBasedExtensionAPI(Resource):
@console_ns.doc("get_code_based_extension")
@console_ns.doc(description="Get code-based extension data by module name")
- @console_ns.expect(
- console_ns.parser().add_argument(
- "module", type=str, required=True, location="args", help="Extension module name"
- )
- )
+ @console_ns.doc(params={"module": "Extension module name"})
@console_ns.response(
200,
"Success",
@@ -35,10 +49,9 @@ class CodeBasedExtensionAPI(Resource):
@login_required
@account_initialization_required
def get(self):
- parser = reqparse.RequestParser().add_argument("module", type=str, required=True, location="args")
- args = parser.parse_args()
+ query = CodeBasedExtensionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
- return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])}
+ return {"module": query.module, "data": CodeBasedExtensionService.get_code_based_extension(query.module)}
@console_ns.route("/api-based-extension")
@@ -56,30 +69,21 @@ class APIBasedExtensionAPI(Resource):
@console_ns.doc("create_api_based_extension")
@console_ns.doc(description="Create a new API-based extension")
- @console_ns.expect(
- console_ns.model(
- "CreateAPIBasedExtensionRequest",
- {
- "name": fields.String(required=True, description="Extension name"),
- "api_endpoint": fields.String(required=True, description="API endpoint URL"),
- "api_key": fields.String(required=True, description="API key for authentication"),
- },
- )
- )
+ @console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__])
@console_ns.response(201, "Extension created successfully", api_based_extension_model)
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_based_extension_model)
def post(self):
- args = console_ns.payload
+ payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {})
_, current_tenant_id = current_account_with_tenant()
extension_data = APIBasedExtension(
tenant_id=current_tenant_id,
- name=args["name"],
- api_endpoint=args["api_endpoint"],
- api_key=args["api_key"],
+ name=payload.name,
+ api_endpoint=payload.api_endpoint,
+ api_key=payload.api_key,
)
return APIBasedExtensionService.save(extension_data)
@@ -104,16 +108,7 @@ class APIBasedExtensionDetailAPI(Resource):
@console_ns.doc("update_api_based_extension")
@console_ns.doc(description="Update API-based extension")
@console_ns.doc(params={"id": "Extension ID"})
- @console_ns.expect(
- console_ns.model(
- "UpdateAPIBasedExtensionRequest",
- {
- "name": fields.String(required=True, description="Extension name"),
- "api_endpoint": fields.String(required=True, description="API endpoint URL"),
- "api_key": fields.String(required=True, description="API key for authentication"),
- },
- )
- )
+ @console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__])
@console_ns.response(200, "Extension updated successfully", api_based_extension_model)
@setup_required
@login_required
@@ -125,13 +120,13 @@ class APIBasedExtensionDetailAPI(Resource):
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
- args = console_ns.payload
+ payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {})
- extension_data_from_db.name = args["name"]
- extension_data_from_db.api_endpoint = args["api_endpoint"]
+ extension_data_from_db.name = payload.name
+ extension_data_from_db.api_endpoint = payload.api_endpoint
- if args["api_key"] != HIDDEN_VALUE:
- extension_data_from_db.api_key = args["api_key"]
+ if payload.api_key != HIDDEN_VALUE:
+ extension_data_from_db.api_key = payload.api_key
return APIBasedExtensionService.save(extension_data_from_db)
diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py
index 6951c906e9..1e98d622fe 100644
--- a/api/controllers/console/feature.py
+++ b/api/controllers/console/feature.py
@@ -1,43 +1,58 @@
-from flask_restx import Resource, fields
+from pydantic import BaseModel, Field
+from werkzeug.exceptions import Unauthorized
-from libs.login import current_account_with_tenant, login_required
-from services.feature_service import FeatureService
+from controllers.fastopenapi import console_router
+from libs.login import current_account_with_tenant, current_user, login_required
+from services.feature_service import FeatureModel, FeatureService, SystemFeatureModel
-from . import console_ns
from .wraps import account_initialization_required, cloud_utm_record, setup_required
-@console_ns.route("/features")
-class FeatureApi(Resource):
- @console_ns.doc("get_tenant_features")
- @console_ns.doc(description="Get feature configuration for current tenant")
- @console_ns.response(
- 200,
- "Success",
- console_ns.model("FeatureResponse", {"features": fields.Raw(description="Feature configuration object")}),
- )
- @setup_required
- @login_required
- @account_initialization_required
- @cloud_utm_record
- def get(self):
- """Get feature configuration for current tenant"""
- _, current_tenant_id = current_account_with_tenant()
-
- return FeatureService.get_features(current_tenant_id).model_dump()
+class FeatureResponse(BaseModel):
+ features: FeatureModel = Field(description="Feature configuration object")
-@console_ns.route("/system-features")
-class SystemFeatureApi(Resource):
- @console_ns.doc("get_system_features")
- @console_ns.doc(description="Get system-wide feature configuration")
- @console_ns.response(
- 200,
- "Success",
- console_ns.model(
- "SystemFeatureResponse", {"features": fields.Raw(description="System feature configuration object")}
- ),
- )
- def get(self):
- """Get system-wide feature configuration"""
- return FeatureService.get_system_features().model_dump()
+class SystemFeatureResponse(BaseModel):
+ features: SystemFeatureModel = Field(description="System feature configuration object")
+
+
+@console_router.get(
+ "/features",
+ response_model=FeatureResponse,
+ tags=["console"],
+)
+@setup_required
+@login_required
+@account_initialization_required
+@cloud_utm_record
+def get_tenant_features() -> FeatureResponse:
+ """Get feature configuration for current tenant."""
+ _, current_tenant_id = current_account_with_tenant()
+
+ return FeatureResponse(features=FeatureService.get_features(current_tenant_id))
+
+
+@console_router.get(
+ "/system-features",
+ response_model=SystemFeatureResponse,
+ tags=["console"],
+)
+def get_system_features() -> SystemFeatureResponse:
+ """Get system-wide feature configuration
+
+ NOTE: This endpoint is unauthenticated by design, as it provides system features
+ data required for dashboard initialization.
+
+ Authentication would create a circular dependency (you can't log in without the dashboard loading).
+
+ Only non-sensitive configuration data should be returned by this endpoint.
+ """
+ # NOTE(QuantumGhost): ideally we should access `current_user.is_authenticated`
+ # without a try-catch. However, due to the implementation of the user loader (the `load_user_from_request`
+ # in api/extensions/ext_login.py), accessing `current_user.is_authenticated` will
+ # raise an `Unauthorized` exception if no authentication token is provided.
+ try:
+ is_authenticated = current_user.is_authenticated
+ except Unauthorized:
+ is_authenticated = False
+ return SystemFeatureResponse(features=FeatureService.get_system_features(is_authenticated=is_authenticated))
diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py
index 29417dc896..109a3cd0d3 100644
--- a/api/controllers/console/files.py
+++ b/api/controllers/console/files.py
@@ -1,7 +1,7 @@
from typing import Literal
from flask import request
-from flask_restx import Resource, marshal_with
+from flask_restx import Resource
from werkzeug.exceptions import Forbidden
import services
@@ -15,18 +15,21 @@ from controllers.common.errors import (
TooManyFilesError,
UnsupportedFileTypeError,
)
+from controllers.common.schema import register_schema_models
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
setup_required,
)
from extensions.ext_database import db
-from fields.file_fields import file_fields, upload_config_fields
+from fields.file_fields import FileResponse, UploadConfig
from libs.login import current_account_with_tenant, login_required
from services.file_service import FileService
from . import console_ns
+register_schema_models(console_ns, UploadConfig, FileResponse)
+
PREVIEW_WORDS_LIMIT = 3000
@@ -35,26 +38,27 @@ class FileApi(Resource):
@setup_required
@login_required
@account_initialization_required
- @marshal_with(upload_config_fields)
+ @console_ns.response(200, "Success", console_ns.models[UploadConfig.__name__])
def get(self):
- return {
- "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
- "batch_count_limit": dify_config.UPLOAD_FILE_BATCH_LIMIT,
- "file_upload_limit": dify_config.BATCH_UPLOAD_LIMIT,
- "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
- "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
- "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
- "workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
- "image_file_batch_limit": dify_config.IMAGE_FILE_BATCH_LIMIT,
- "single_chunk_attachment_limit": dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT,
- "attachment_image_file_size_limit": dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT,
- }, 200
+ config = UploadConfig(
+ file_size_limit=dify_config.UPLOAD_FILE_SIZE_LIMIT,
+ batch_count_limit=dify_config.UPLOAD_FILE_BATCH_LIMIT,
+ file_upload_limit=dify_config.BATCH_UPLOAD_LIMIT,
+ image_file_size_limit=dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
+ video_file_size_limit=dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
+ audio_file_size_limit=dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
+ workflow_file_upload_limit=dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
+ image_file_batch_limit=dify_config.IMAGE_FILE_BATCH_LIMIT,
+ single_chunk_attachment_limit=dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT,
+ attachment_image_file_size_limit=dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT,
+ )
+ return config.model_dump(mode="json"), 200
@setup_required
@login_required
@account_initialization_required
- @marshal_with(file_fields)
@cloud_edition_billing_resource_check("documents")
+ @console_ns.response(201, "File uploaded successfully", console_ns.models[FileResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
source_str = request.form.get("source")
@@ -90,7 +94,8 @@ class FileApi(Resource):
except services.errors.file.BlockedFileExtensionError as blocked_extension_error:
raise BlockedFileExtensionError(blocked_extension_error.description)
- return upload_file, 201
+ response = FileResponse.model_validate(upload_file, from_attributes=True)
+ return response.model_dump(mode="json"), 201
@console_ns.route("/files/<uuid:file_id>/preview")
diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py
index 2bebe79eac..f086bf1862 100644
--- a/api/controllers/console/init_validate.py
+++ b/api/controllers/console/init_validate.py
@@ -1,87 +1,74 @@
import os
+from typing import Literal
from flask import session
-from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
+from controllers.fastopenapi import console_router
from extensions.ext_database import db
from models.model import DifySetup
from services.account_service import TenantService
-from . import console_ns
from .error import AlreadySetupError, InitValidateFailedError
from .wraps import only_edition_self_hosted
-DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
-
class InitValidatePayload(BaseModel):
- password: str = Field(..., max_length=30)
+ password: str = Field(..., max_length=30, description="Initialization password")
-console_ns.schema_model(
- InitValidatePayload.__name__,
- InitValidatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+class InitStatusResponse(BaseModel):
+ status: Literal["finished", "not_started"] = Field(..., description="Initialization status")
+
+
+class InitValidateResponse(BaseModel):
+ result: str = Field(description="Operation result", examples=["success"])
+
+
+@console_router.get(
+ "/init",
+ response_model=InitStatusResponse,
+ tags=["console"],
)
+def get_init_status() -> InitStatusResponse:
+ """Get initialization validation status."""
+ init_status = get_init_validate_status()
+ if init_status:
+ return InitStatusResponse(status="finished")
+ return InitStatusResponse(status="not_started")
-@console_ns.route("/init")
-class InitValidateAPI(Resource):
- @console_ns.doc("get_init_status")
- @console_ns.doc(description="Get initialization validation status")
- @console_ns.response(
- 200,
- "Success",
- model=console_ns.model(
- "InitStatusResponse",
- {"status": fields.String(description="Initialization status", enum=["finished", "not_started"])},
- ),
- )
- def get(self):
- """Get initialization validation status"""
- init_status = get_init_validate_status()
- if init_status:
- return {"status": "finished"}
- return {"status": "not_started"}
+@console_router.post(
+ "/init",
+ response_model=InitValidateResponse,
+ tags=["console"],
+ status_code=201,
+)
+@only_edition_self_hosted
+def validate_init_password(payload: InitValidatePayload) -> InitValidateResponse:
+ """Validate initialization password."""
+ tenant_count = TenantService.get_tenant_count()
+ if tenant_count > 0:
+ raise AlreadySetupError()
- @console_ns.doc("validate_init_password")
- @console_ns.doc(description="Validate initialization password for self-hosted edition")
- @console_ns.expect(console_ns.models[InitValidatePayload.__name__])
- @console_ns.response(
- 201,
- "Success",
- model=console_ns.model("InitValidateResponse", {"result": fields.String(description="Operation result")}),
- )
- @console_ns.response(400, "Already setup or validation failed")
- @only_edition_self_hosted
- def post(self):
- """Validate initialization password"""
- # is tenant created
- tenant_count = TenantService.get_tenant_count()
- if tenant_count > 0:
- raise AlreadySetupError()
+ if payload.password != os.environ.get("INIT_PASSWORD"):
+ session["is_init_validated"] = False
+ raise InitValidateFailedError()
- payload = InitValidatePayload.model_validate(console_ns.payload)
- input_password = payload.password
-
- if input_password != os.environ.get("INIT_PASSWORD"):
- session["is_init_validated"] = False
- raise InitValidateFailedError()
-
- session["is_init_validated"] = True
- return {"result": "success"}, 201
+ session["is_init_validated"] = True
+ return InitValidateResponse(result="success")
-def get_init_validate_status():
+def get_init_validate_status() -> bool:
if dify_config.EDITION == "SELF_HOSTED":
if os.environ.get("INIT_PASSWORD"):
if session.get("is_init_validated"):
return True
with Session(db.engine) as db_session:
- return db_session.execute(select(DifySetup)).scalar_one_or_none()
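+ # Coerce the query result to a strict bool so callers get True/False rather than a DifySetup instance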
+ return db_session.execute(select(DifySetup)).scalar_one_or_none() is not None
return True
diff --git a/api/controllers/console/ping.py b/api/controllers/console/ping.py
index 25a3d80522..d480af312b 100644
--- a/api/controllers/console/ping.py
+++ b/api/controllers/console/ping.py
@@ -1,17 +1,17 @@
-from flask_restx import Resource, fields
+from pydantic import BaseModel, Field
-from . import console_ns
+from controllers.fastopenapi import console_router
-@console_ns.route("/ping")
-class PingApi(Resource):
- @console_ns.doc("health_check")
- @console_ns.doc(description="Health check endpoint for connection testing")
- @console_ns.response(
- 200,
- "Success",
- console_ns.model("PingResponse", {"result": fields.String(description="Health check result", example="pong")}),
- )
- def get(self):
- """Health check endpoint for connection testing"""
- return {"result": "pong"}
+class PingResponse(BaseModel):
+ result: str = Field(description="Health check result", examples=["pong"])
+
+
+@console_router.get(
+ "/ping",
+ response_model=PingResponse,
+ tags=["console"],
+)
+def ping() -> PingResponse:
+ """Health check endpoint for connection testing."""
+ return PingResponse(result="pong")
diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py
index 47eef7eb7e..88a9ce3a79 100644
--- a/api/controllers/console/remote_files.py
+++ b/api/controllers/console/remote_files.py
@@ -1,7 +1,6 @@
import urllib.parse
import httpx
-from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
import services
@@ -11,87 +10,82 @@ from controllers.common.errors import (
RemoteFileUploadError,
UnsupportedFileTypeError,
)
+from controllers.fastopenapi import console_router
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from extensions.ext_database import db
-from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
+from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
from libs.login import current_account_with_tenant
from services.file_service import FileService
-from . import console_ns
-
-
-@console_ns.route("/remote-files/<path:url>")
-class RemoteFileInfoApi(Resource):
- @marshal_with(remote_file_info_fields)
- def get(self, url):
- decoded_url = urllib.parse.unquote(url)
- resp = ssrf_proxy.head(decoded_url)
- if resp.status_code != httpx.codes.OK:
- # failed back to get method
- resp = ssrf_proxy.get(decoded_url, timeout=3)
- resp.raise_for_status()
- return {
- "file_type": resp.headers.get("Content-Type", "application/octet-stream"),
- "file_length": int(resp.headers.get("Content-Length", 0)),
- }
-
class RemoteFileUploadPayload(BaseModel):
url: str = Field(..., description="URL to fetch")
-console_ns.schema_model(
- RemoteFileUploadPayload.__name__,
- RemoteFileUploadPayload.model_json_schema(ref_template="#/definitions/{model}"),
+@console_router.get(
+ "/remote-files/",
+ response_model=RemoteFileInfo,
+ tags=["console"],
)
+def get_remote_file_info(url: str) -> RemoteFileInfo:
+ decoded_url = urllib.parse.unquote(url)
+ resp = ssrf_proxy.head(decoded_url)
+ if resp.status_code != httpx.codes.OK:
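+ # Fall back to a GET request when the remote server rejects or does not support HEAD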
+ resp = ssrf_proxy.get(decoded_url, timeout=3)
+ resp.raise_for_status()
+ return RemoteFileInfo(
+ file_type=resp.headers.get("Content-Type", "application/octet-stream"),
+ file_length=int(resp.headers.get("Content-Length", 0)),
+ )
-@console_ns.route("/remote-files/upload")
-class RemoteFileUploadApi(Resource):
- @console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__])
- @marshal_with(file_fields_with_signed_url)
- def post(self):
- args = RemoteFileUploadPayload.model_validate(console_ns.payload)
- url = args.url
+@console_router.post(
+ "/remote-files/upload",
+ response_model=FileWithSignedUrl,
+ tags=["console"],
+ status_code=201,
+)
+def upload_remote_file(payload: RemoteFileUploadPayload) -> FileWithSignedUrl:
+ url = payload.url
- try:
- resp = ssrf_proxy.head(url=url)
- if resp.status_code != httpx.codes.OK:
- resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
- if resp.status_code != httpx.codes.OK:
- raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
- except httpx.RequestError as e:
- raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}")
+ try:
+ resp = ssrf_proxy.head(url=url)
+ if resp.status_code != httpx.codes.OK:
+ resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
+ if resp.status_code != httpx.codes.OK:
+ raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
+ except httpx.RequestError as e:
+ raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}")
- file_info = helpers.guess_file_info_from_response(resp)
+ file_info = helpers.guess_file_info_from_response(resp)
- if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
- raise FileTooLargeError
+ if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
+ raise FileTooLargeError
- content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
+ content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
- try:
- user, _ = current_account_with_tenant()
- upload_file = FileService(db.engine).upload_file(
- filename=file_info.filename,
- content=content,
- mimetype=file_info.mimetype,
- user=user,
- source_url=url,
- )
- except services.errors.file.FileTooLargeError as file_too_large_error:
- raise FileTooLargeError(file_too_large_error.description)
- except services.errors.file.UnsupportedFileTypeError:
- raise UnsupportedFileTypeError()
+ try:
+ user, _ = current_account_with_tenant()
+ upload_file = FileService(db.engine).upload_file(
+ filename=file_info.filename,
+ content=content,
+ mimetype=file_info.mimetype,
+ user=user,
+ source_url=url,
+ )
+ except services.errors.file.FileTooLargeError as file_too_large_error:
+ raise FileTooLargeError(file_too_large_error.description)
+ except services.errors.file.UnsupportedFileTypeError:
+ raise UnsupportedFileTypeError()
- return {
- "id": upload_file.id,
- "name": upload_file.name,
- "size": upload_file.size,
- "extension": upload_file.extension,
- "url": file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
- "mime_type": upload_file.mime_type,
- "created_by": upload_file.created_by,
- "created_at": upload_file.created_at,
- }, 201
+ return FileWithSignedUrl(
+ id=upload_file.id,
+ name=upload_file.name,
+ size=upload_file.size,
+ extension=upload_file.extension,
+ url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
+ mime_type=upload_file.mime_type,
+ created_by=upload_file.created_by,
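+ # created_at is converted from datetime to an integer Unix timestamp for the API response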
+ created_at=int(upload_file.created_at.timestamp()),
+ )
diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py
index 7fa02ae280..e1ea007232 100644
--- a/api/controllers/console/setup.py
+++ b/api/controllers/console/setup.py
@@ -1,20 +1,19 @@
+from typing import Literal
+
from flask import request
-from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from configs import dify_config
+from controllers.fastopenapi import console_router
from libs.helper import EmailStr, extract_remote_ip
from libs.password import valid_password
from models.model import DifySetup, db
from services.account_service import RegisterService, TenantService
-from . import console_ns
from .error import AlreadySetupError, NotInitValidateError
from .init_validate import get_init_validate_status
from .wraps import only_edition_self_hosted
-DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
-
class SetupRequestPayload(BaseModel):
email: EmailStr = Field(..., description="Admin email address")
@@ -28,77 +27,66 @@ class SetupRequestPayload(BaseModel):
return valid_password(value)
-console_ns.schema_model(
- SetupRequestPayload.__name__,
- SetupRequestPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+class SetupStatusResponse(BaseModel):
+ step: Literal["not_started", "finished"] = Field(description="Setup step status")
+ setup_at: str | None = Field(default=None, description="Setup completion time (ISO format)")
+
+
+class SetupResponse(BaseModel):
+ result: str = Field(description="Setup result", examples=["success"])
+
+
+@console_router.get(
+ "/setup",
+ response_model=SetupStatusResponse,
+ tags=["console"],
)
+def get_setup_status_api() -> SetupStatusResponse:
+ """Get system setup status."""
+ if dify_config.EDITION == "SELF_HOSTED":
+ setup_status = get_setup_status()
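+ # setup_status is a DifySetup record (with setup_at) once setup has completed, otherwise a bool or None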
+ if setup_status and not isinstance(setup_status, bool):
+ return SetupStatusResponse(step="finished", setup_at=setup_status.setup_at.isoformat())
+ if setup_status:
+ return SetupStatusResponse(step="finished")
+ return SetupStatusResponse(step="not_started")
+ return SetupStatusResponse(step="finished")
-@console_ns.route("/setup")
-class SetupApi(Resource):
- @console_ns.doc("get_setup_status")
- @console_ns.doc(description="Get system setup status")
- @console_ns.response(
- 200,
- "Success",
- console_ns.model(
- "SetupStatusResponse",
- {
- "step": fields.String(description="Setup step status", enum=["not_started", "finished"]),
- "setup_at": fields.String(description="Setup completion time (ISO format)", required=False),
- },
- ),
+@console_router.post(
+ "/setup",
+ response_model=SetupResponse,
+ tags=["console"],
+ status_code=201,
+)
+@only_edition_self_hosted
+def setup_system(payload: SetupRequestPayload) -> SetupResponse:
+ """Initialize system setup with admin account."""
+ if get_setup_status():
+ raise AlreadySetupError()
+
+ tenant_count = TenantService.get_tenant_count()
+ if tenant_count > 0:
+ raise AlreadySetupError()
+
+ if not get_init_validate_status():
+ raise NotInitValidateError()
+
+ normalized_email = payload.email.lower()
+
+ RegisterService.setup(
+ email=normalized_email,
+ name=payload.name,
+ password=payload.password,
+ ip_address=extract_remote_ip(request),
+ language=payload.language,
)
- def get(self):
- """Get system setup status"""
- if dify_config.EDITION == "SELF_HOSTED":
- setup_status = get_setup_status()
- # Check if setup_status is a DifySetup object rather than a bool
- if setup_status and not isinstance(setup_status, bool):
- return {"step": "finished", "setup_at": setup_status.setup_at.isoformat()}
- elif setup_status:
- return {"step": "finished"}
- return {"step": "not_started"}
- return {"step": "finished"}
- @console_ns.doc("setup_system")
- @console_ns.doc(description="Initialize system setup with admin account")
- @console_ns.expect(console_ns.models[SetupRequestPayload.__name__])
- @console_ns.response(
- 201, "Success", console_ns.model("SetupResponse", {"result": fields.String(description="Setup result")})
- )
- @console_ns.response(400, "Already setup or validation failed")
- @only_edition_self_hosted
- def post(self):
- """Initialize system setup with admin account"""
- # is set up
- if get_setup_status():
- raise AlreadySetupError()
-
- # is tenant created
- tenant_count = TenantService.get_tenant_count()
- if tenant_count > 0:
- raise AlreadySetupError()
-
- if not get_init_validate_status():
- raise NotInitValidateError()
-
- args = SetupRequestPayload.model_validate(console_ns.payload)
-
- # setup
- RegisterService.setup(
- email=args.email,
- name=args.name,
- password=args.password,
- ip_address=extract_remote_ip(request),
- language=args.language,
- )
-
- return {"result": "success"}, 201
+ return SetupResponse(result="success")
-def get_setup_status():
+def get_setup_status() -> DifySetup | bool | None:
if dify_config.EDITION == "SELF_HOSTED":
return db.session.query(DifySetup).first()
- else:
- return True
+
+ return True
diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py
index 17cfc3ff4b..e828d54ff4 100644
--- a/api/controllers/console/tag/tags.py
+++ b/api/controllers/console/tag/tags.py
@@ -1,152 +1,160 @@
-from flask import request
-from flask_restx import Resource, marshal_with, reqparse
+from typing import Literal
+from uuid import UUID
+
+from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden
-from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
-from fields.tag_fields import dataset_tag_fields
+from controllers.fastopenapi import console_router
from libs.login import current_account_with_tenant, login_required
-from models.model import Tag
from services.tag_service import TagService
-def _validate_name(name):
- if not name or len(name) < 1 or len(name) > 50:
- raise ValueError("Name must be between 1 to 50 characters.")
- return name
+class TagBasePayload(BaseModel):
+ name: str = Field(description="Tag name", min_length=1, max_length=50)
+ type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
-parser_tags = (
- reqparse.RequestParser()
- .add_argument(
- "name",
- nullable=False,
- required=True,
- help="Name must be between 1 to 50 characters.",
- type=_validate_name,
- )
- .add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
+class TagBindingPayload(BaseModel):
+ tag_ids: list[str] = Field(description="Tag IDs to bind")
+ target_id: str = Field(description="Target ID to bind tags to")
+ type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
+
+
+class TagBindingRemovePayload(BaseModel):
+ tag_id: str = Field(description="Tag ID to remove")
+ target_id: str = Field(description="Target ID to unbind tag from")
+ type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
+
+
+class TagListQueryParam(BaseModel):
+ type: Literal["knowledge", "app", ""] = Field("", description="Tag type filter")
+ keyword: str | None = Field(None, description="Search keyword")
+
+
+class TagResponse(BaseModel):
+ id: str = Field(description="Tag ID")
+ name: str = Field(description="Tag name")
+ type: str = Field(description="Tag type")
+ binding_count: int = Field(description="Number of bindings")
+
+
+class TagBindingResult(BaseModel):
+ result: Literal["success"] = Field(description="Operation result", examples=["success"])
+
+
+@console_router.get(
+ "/tags",
+ response_model=list[TagResponse],
+ tags=["console"],
)
+@setup_required
+@login_required
+@account_initialization_required
+def list_tags(query: TagListQueryParam) -> list[TagResponse]:
+ _, current_tenant_id = current_account_with_tenant()
+ tags = TagService.get_tags(query.type, current_tenant_id, query.keyword)
+
+ return [
+ TagResponse(
+ id=tag.id,
+ name=tag.name,
+ type=tag.type,
+ binding_count=int(tag.binding_count),
+ )
+ for tag in tags
+ ]
-@console_ns.route("/tags")
-class TagListApi(Resource):
- @setup_required
- @login_required
- @account_initialization_required
- @marshal_with(dataset_tag_fields)
- def get(self):
- _, current_tenant_id = current_account_with_tenant()
- tag_type = request.args.get("type", type=str, default="")
- keyword = request.args.get("keyword", default=None, type=str)
- tags = TagService.get_tags(tag_type, current_tenant_id, keyword)
-
- return tags, 200
-
- @console_ns.expect(parser_tags)
- @setup_required
- @login_required
- @account_initialization_required
- def post(self):
- current_user, _ = current_account_with_tenant()
- # The role of the current user in the ta table must be admin, owner, or editor
- if not (current_user.has_edit_permission or current_user.is_dataset_editor):
- raise Forbidden()
-
- args = parser_tags.parse_args()
- tag = TagService.save_tags(args)
-
- response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
-
- return response, 200
-
-
-parser_tag_id = reqparse.RequestParser().add_argument(
- "name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
+@console_router.post(
+ "/tags",
+ response_model=TagResponse,
+ tags=["console"],
)
+@setup_required
+@login_required
+@account_initialization_required
+def create_tag(payload: TagBasePayload) -> TagResponse:
+ current_user, _ = current_account_with_tenant()
+ # The role of the current user in the tag table must be admin, owner, or editor
+ if not (current_user.has_edit_permission or current_user.is_dataset_editor):
+ raise Forbidden()
+
+ tag = TagService.save_tags(payload.model_dump())
+
+ return TagResponse(id=tag.id, name=tag.name, type=tag.type, binding_count=0)
-@console_ns.route("/tags/<uuid:tag_id>")
-class TagUpdateDeleteApi(Resource):
- @console_ns.expect(parser_tag_id)
- @setup_required
- @login_required
- @account_initialization_required
- def patch(self, tag_id):
- current_user, _ = current_account_with_tenant()
- tag_id = str(tag_id)
- # The role of the current user in the ta table must be admin, owner, or editor
- if not (current_user.has_edit_permission or current_user.is_dataset_editor):
- raise Forbidden()
-
- args = parser_tag_id.parse_args()
- tag = TagService.update_tags(args, tag_id)
-
- binding_count = TagService.get_tag_binding_count(tag_id)
-
- response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
-
- return response, 200
-
- @setup_required
- @login_required
- @account_initialization_required
- @edit_permission_required
- def delete(self, tag_id):
- tag_id = str(tag_id)
-
- TagService.delete_tag(tag_id)
-
- return 204
-
-
-parser_create = (
- reqparse.RequestParser()
- .add_argument("tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required.")
- .add_argument("target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required.")
- .add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
+@console_router.patch(
+ "/tags/",
+ response_model=TagResponse,
+ tags=["console"],
)
+@setup_required
+@login_required
+@account_initialization_required
+def update_tag(tag_id: UUID, payload: TagBasePayload) -> TagResponse:
+ current_user, _ = current_account_with_tenant()
+ tag_id_str = str(tag_id)
+ # The role of the current user in the tag table must be admin, owner, or editor
+ if not (current_user.has_edit_permission or current_user.is_dataset_editor):
+ raise Forbidden()
+
+ tag = TagService.update_tags(payload.model_dump(), tag_id_str)
+
+ binding_count = TagService.get_tag_binding_count(tag_id_str)
+
+ return TagResponse(id=tag.id, name=tag.name, type=tag.type, binding_count=binding_count)
-@console_ns.route("/tag-bindings/create")
-class TagBindingCreateApi(Resource):
- @console_ns.expect(parser_create)
- @setup_required
- @login_required
- @account_initialization_required
- def post(self):
- current_user, _ = current_account_with_tenant()
- # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
- if not (current_user.has_edit_permission or current_user.is_dataset_editor):
- raise Forbidden()
-
- args = parser_create.parse_args()
- TagService.save_tag_binding(args)
-
- return {"result": "success"}, 200
-
-
-parser_remove = (
- reqparse.RequestParser()
- .add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
- .add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
- .add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
+@console_router.delete(
+ "/tags/",
+ tags=["console"],
+ status_code=204,
)
+@setup_required
+@login_required
+@account_initialization_required
+@edit_permission_required
+def delete_tag(tag_id: UUID) -> None:
+ tag_id_str = str(tag_id)
+
+ TagService.delete_tag(tag_id_str)
-@console_ns.route("/tag-bindings/remove")
-class TagBindingDeleteApi(Resource):
- @console_ns.expect(parser_remove)
- @setup_required
- @login_required
- @account_initialization_required
- def post(self):
- current_user, _ = current_account_with_tenant()
- # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
- if not (current_user.has_edit_permission or current_user.is_dataset_editor):
- raise Forbidden()
+@console_router.post(
+ "/tag-bindings/create",
+ response_model=TagBindingResult,
+ tags=["console"],
+)
+@setup_required
+@login_required
+@account_initialization_required
+def create_tag_binding(payload: TagBindingPayload) -> TagBindingResult:
+ current_user, _ = current_account_with_tenant()
+ # The role of the current user in the tag table must be admin, owner, editor, or dataset_operator
+ if not (current_user.has_edit_permission or current_user.is_dataset_editor):
+ raise Forbidden()
- args = parser_remove.parse_args()
- TagService.delete_tag_binding(args)
+ TagService.save_tag_binding(payload.model_dump())
- return {"result": "success"}, 200
+ return TagBindingResult(result="success")
+
+
+@console_router.post(
+ "/tag-bindings/remove",
+ response_model=TagBindingResult,
+ tags=["console"],
+)
+@setup_required
+@login_required
+@account_initialization_required
+def delete_tag_binding(payload: TagBindingRemovePayload) -> TagBindingResult:
+ current_user, _ = current_account_with_tenant()
+ # The role of the current user in the tag table must be admin, owner, editor, or dataset_operator
+ if not (current_user.has_edit_permission or current_user.is_dataset_editor):
+ raise Forbidden()
+
+ TagService.delete_tag_binding(payload.model_dump())
+
+ return TagBindingResult(result="success")
diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py
index 419261ba2a..fdb23acf52 100644
--- a/api/controllers/console/version.py
+++ b/api/controllers/console/version.py
@@ -1,15 +1,11 @@
-import json
import logging
import httpx
-from flask import request
-from flask_restx import Resource, fields
from packaging import version
from pydantic import BaseModel, Field
from configs import dify_config
-
-from . import console_ns
+from controllers.fastopenapi import console_router
logger = logging.getLogger(__name__)
@@ -18,69 +14,61 @@ class VersionQuery(BaseModel):
current_version: str = Field(..., description="Current application version")
-console_ns.schema_model(
- VersionQuery.__name__,
- VersionQuery.model_json_schema(ref_template="#/definitions/{model}"),
+class VersionFeatures(BaseModel):
+ can_replace_logo: bool = Field(description="Whether logo replacement is supported")
+ model_load_balancing_enabled: bool = Field(description="Whether model load balancing is enabled")
+
+
+class VersionResponse(BaseModel):
+ version: str = Field(description="Latest version number")
+ release_date: str = Field(description="Release date of latest version")
+ release_notes: str = Field(description="Release notes for latest version")
+ can_auto_update: bool = Field(description="Whether auto-update is supported")
+ features: VersionFeatures = Field(description="Feature flags and capabilities")
+
+
+@console_router.get(
+ "/version",
+ response_model=VersionResponse,
+ tags=["console"],
)
+def check_version_update(query: VersionQuery) -> VersionResponse:
+ """Check for application version updates."""
+ check_update_url = dify_config.CHECK_UPDATE_URL
-
-@console_ns.route("/version")
-class VersionApi(Resource):
- @console_ns.doc("check_version_update")
- @console_ns.doc(description="Check for application version updates")
- @console_ns.expect(console_ns.models[VersionQuery.__name__])
- @console_ns.response(
- 200,
- "Success",
- console_ns.model(
- "VersionResponse",
- {
- "version": fields.String(description="Latest version number"),
- "release_date": fields.String(description="Release date of latest version"),
- "release_notes": fields.String(description="Release notes for latest version"),
- "can_auto_update": fields.Boolean(description="Whether auto-update is supported"),
- "features": fields.Raw(description="Feature flags and capabilities"),
- },
+ result = VersionResponse(
+ version=dify_config.project.version,
+ release_date="",
+ release_notes="",
+ can_auto_update=False,
+ features=VersionFeatures(
+ can_replace_logo=dify_config.CAN_REPLACE_LOGO,
+ model_load_balancing_enabled=dify_config.MODEL_LB_ENABLED,
),
)
- def get(self):
- """Check for application version updates"""
- args = VersionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
- check_update_url = dify_config.CHECK_UPDATE_URL
- result = {
- "version": dify_config.project.version,
- "release_date": "",
- "release_notes": "",
- "can_auto_update": False,
- "features": {
- "can_replace_logo": dify_config.CAN_REPLACE_LOGO,
- "model_load_balancing_enabled": dify_config.MODEL_LB_ENABLED,
- },
- }
-
- if not check_update_url:
- return result
-
- try:
- response = httpx.get(
- check_update_url,
- params={"current_version": args.current_version},
- timeout=httpx.Timeout(timeout=10.0, connect=3.0),
- )
- except Exception as error:
- logger.warning("Check update version error: %s.", str(error))
- result["version"] = args.current_version
- return result
-
- content = json.loads(response.content)
- if _has_new_version(latest_version=content["version"], current_version=f"{args.current_version}"):
- result["version"] = content["version"]
- result["release_date"] = content["releaseDate"]
- result["release_notes"] = content["releaseNotes"]
- result["can_auto_update"] = content["canAutoUpdate"]
+ if not check_update_url:
return result
+ try:
+ response = httpx.get(
+ check_update_url,
+ params={"current_version": query.current_version},
+ timeout=httpx.Timeout(timeout=10.0, connect=3.0),
+ )
+ content = response.json()
+ except Exception as error:
+ logger.warning("Check update version error: %s.", str(error))
+ result.version = query.current_version
+ return result
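+ # Read update metadata defensively with .get() so a malformed response cannot raise KeyError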
+ latest_version = content.get("version", result.version)
+ if _has_new_version(latest_version=latest_version, current_version=f"{query.current_version}"):
+ result.version = latest_version
+ result.release_date = content.get("releaseDate", "")
+ result.release_notes = content.get("releaseNotes", "")
+ result.can_auto_update = content.get("canAutoUpdate", False)
+ return result
+
def _has_new_version(*, latest_version: str, current_version: str) -> bool:
try:
diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py
index 55eaa2f09f..38c66525b3 100644
--- a/api/controllers/console/workspace/account.py
+++ b/api/controllers/console/workspace/account.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from datetime import datetime
from typing import Literal
@@ -39,7 +41,7 @@ from fields.member_fields import account_fields
from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
from libs.login import current_account_with_tenant, login_required
-from models import Account, AccountIntegrate, InvitationCode
+from models import AccountIntegrate, InvitationCode
from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
@@ -99,7 +101,7 @@ class AccountPasswordPayload(BaseModel):
repeat_new_password: str
@model_validator(mode="after")
- def check_passwords_match(self) -> "AccountPasswordPayload":
+ def check_passwords_match(self) -> AccountPasswordPayload:
if self.new_password != self.repeat_new_password:
raise RepeatPasswordNotMatchError()
return self
@@ -169,6 +171,19 @@ reg(ChangeEmailValidityPayload)
reg(ChangeEmailResetPayload)
reg(CheckEmailUniquePayload)
+integrate_fields = {
+ "provider": fields.String,
+ "created_at": TimestampField,
+ "is_bound": fields.Boolean,
+ "link": fields.String,
+}
+
+integrate_model = console_ns.model("AccountIntegrate", integrate_fields)
+integrate_list_model = console_ns.model(
+ "AccountIntegrateList",
+ {"data": fields.List(fields.Nested(integrate_model))},
+)
+
@console_ns.route("/account/init")
class AccountInitApi(Resource):
@@ -334,21 +349,10 @@ class AccountPasswordApi(Resource):
@console_ns.route("/account/integrates")
class AccountIntegrateApi(Resource):
- integrate_fields = {
- "provider": fields.String,
- "created_at": TimestampField,
- "is_bound": fields.Boolean,
- "link": fields.String,
- }
-
- integrate_list_fields = {
- "data": fields.List(fields.Nested(integrate_fields)),
- }
-
@setup_required
@login_required
@account_initialization_required
- @marshal_with(integrate_list_fields)
+ @marshal_with(integrate_list_model)
def get(self):
account, _ = current_account_with_tenant()
@@ -534,7 +538,8 @@ class ChangeEmailSendEmailApi(Resource):
else:
language = "en-US"
account = None
- user_email = args.email
+ user_email = None
+ email_for_sending = args.email.lower()
if args.phase is not None and args.phase == "new_email":
if args.token is None:
raise InvalidTokenError()
@@ -544,16 +549,24 @@ class ChangeEmailSendEmailApi(Resource):
raise InvalidTokenError()
user_email = reset_data.get("email", "")
- if user_email != current_user.email:
+ if user_email.lower() != current_user.email.lower():
raise InvalidEmailError()
+
+ user_email = current_user.email
else:
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
if account is None:
raise AccountNotFound()
+ email_for_sending = account.email
+ user_email = account.email
token = AccountService.send_change_email_email(
- account=account, email=args.email, old_email=user_email, language=language, phase=args.phase
+ account=account,
+ email=email_for_sending,
+ old_email=user_email,
+ language=language,
+ phase=args.phase,
)
return {"result": "success", "data": token}
@@ -569,9 +582,9 @@ class ChangeEmailCheckApi(Resource):
payload = console_ns.payload or {}
args = ChangeEmailValidityPayload.model_validate(payload)
- user_email = args.email
+ user_email = args.email.lower()
- is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args.email)
+ is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(user_email)
if is_change_email_error_rate_limit:
raise EmailChangeLimitError()
@@ -579,11 +592,13 @@ class ChangeEmailCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
- if user_email != token_data.get("email"):
+ token_email = token_data.get("email")
+ normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
+ if user_email != normalized_token_email:
raise InvalidEmailError()
if args.code != token_data.get("code"):
- AccountService.add_change_email_error_rate_limit(args.email)
+ AccountService.add_change_email_error_rate_limit(user_email)
raise EmailCodeError()
# Verified, revoke the first token
@@ -594,8 +609,8 @@ class ChangeEmailCheckApi(Resource):
user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
)
- AccountService.reset_change_email_error_rate_limit(args.email)
- return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
+ AccountService.reset_change_email_error_rate_limit(user_email)
+ return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@console_ns.route("/account/change-email/reset")
@@ -609,11 +624,12 @@ class ChangeEmailResetApi(Resource):
def post(self):
payload = console_ns.payload or {}
args = ChangeEmailResetPayload.model_validate(payload)
+ normalized_new_email = args.new_email.lower()
- if AccountService.is_account_in_freeze(args.new_email):
+ if AccountService.is_account_in_freeze(normalized_new_email):
raise AccountInFreezeError()
- if not AccountService.check_email_unique(args.new_email):
+ if not AccountService.check_email_unique(normalized_new_email):
raise EmailAlreadyInUseError()
reset_data = AccountService.get_change_email_data(args.token)
@@ -624,13 +640,13 @@ class ChangeEmailResetApi(Resource):
old_email = reset_data.get("old_email", "")
current_user, _ = current_account_with_tenant()
- if current_user.email != old_email:
+ if current_user.email.lower() != old_email.lower():
raise AccountNotFound()
- updated_account = AccountService.update_account_email(current_user, email=args.new_email)
+ updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)
AccountService.send_change_email_completed_notify_email(
- email=args.new_email,
+ email=normalized_new_email,
)
return updated_account
@@ -643,8 +659,9 @@ class CheckEmailUnique(Resource):
def post(self):
payload = console_ns.payload or {}
args = CheckEmailUniquePayload.model_validate(payload)
- if AccountService.is_account_in_freeze(args.email):
+ normalized_email = args.email.lower()
+ if AccountService.is_account_in_freeze(normalized_email):
raise AccountInFreezeError()
- if not AccountService.check_email_unique(args.email):
+ if not AccountService.check_email_unique(normalized_email):
raise EmailAlreadyInUseError()
return {"result": "success"}
diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py
index 9bf393ea2e..ccb60b1461 100644
--- a/api/controllers/console/workspace/load_balancing_config.py
+++ b/api/controllers/console/workspace/load_balancing_config.py
@@ -1,6 +1,8 @@
-from flask_restx import Resource, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel
from werkzeug.exceptions import Forbidden
+from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
@@ -10,10 +12,20 @@ from models import TenantAccountRole
from services.model_load_balancing_service import ModelLoadBalancingService
+class LoadBalancingCredentialPayload(BaseModel):
+ model: str
+ model_type: ModelType
+ credentials: dict[str, object]
+
+
+register_schema_models(console_ns, LoadBalancingCredentialPayload)
+
+
@console_ns.route(
"/workspaces/current/model-providers//models/load-balancing-configs/credentials-validate"
)
class LoadBalancingCredentialsValidateApi(Resource):
+ @console_ns.expect(console_ns.models[LoadBalancingCredentialPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -24,20 +36,7 @@ class LoadBalancingCredentialsValidateApi(Resource):
tenant_id = current_tenant_id
- parser = (
- reqparse.RequestParser()
- .add_argument("model", type=str, required=True, nullable=False, location="json")
- .add_argument(
- "model_type",
- type=str,
- required=True,
- nullable=False,
- choices=[mt.value for mt in ModelType],
- location="json",
- )
- .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
- )
- args = parser.parse_args()
+ payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {})
# validate model load balancing credentials
model_load_balancing_service = ModelLoadBalancingService()
@@ -49,9 +48,9 @@ class LoadBalancingCredentialsValidateApi(Resource):
model_load_balancing_service.validate_load_balancing_credentials(
tenant_id=tenant_id,
provider=provider,
- model=args["model"],
- model_type=args["model_type"],
- credentials=args["credentials"],
+ model=payload.model,
+ model_type=payload.model_type,
+ credentials=payload.credentials,
)
except CredentialsValidateFailedError as ex:
result = False
@@ -69,6 +68,7 @@ class LoadBalancingCredentialsValidateApi(Resource):
"/workspaces/current/model-providers//models/load-balancing-configs//credentials-validate"
)
class LoadBalancingConfigCredentialsValidateApi(Resource):
+ @console_ns.expect(console_ns.models[LoadBalancingCredentialPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@@ -79,20 +79,7 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
tenant_id = current_tenant_id
- parser = (
- reqparse.RequestParser()
- .add_argument("model", type=str, required=True, nullable=False, location="json")
- .add_argument(
- "model_type",
- type=str,
- required=True,
- nullable=False,
- choices=[mt.value for mt in ModelType],
- location="json",
- )
- .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
- )
- args = parser.parse_args()
+ payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {})
# validate model load balancing config credentials
model_load_balancing_service = ModelLoadBalancingService()
@@ -104,9 +91,9 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
model_load_balancing_service.validate_load_balancing_credentials(
tenant_id=tenant_id,
provider=provider,
- model=args["model"],
- model_type=args["model_type"],
- credentials=args["credentials"],
+ model=payload.model,
+ model_type=payload.model_type,
+ credentials=payload.credentials,
config_id=config_id,
)
except CredentialsValidateFailedError as ex:
diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py
index 0142e14fb0..271cdce3c3 100644
--- a/api/controllers/console/workspace/members.py
+++ b/api/controllers/console/workspace/members.py
@@ -1,11 +1,12 @@
from urllib import parse
from flask import abort, request
-from flask_restx import Resource, marshal_with
+from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
import services
from configs import dify_config
+from controllers.common.schema import get_or_create_model, register_enum_models
from controllers.console import console_ns
from controllers.console.auth.error import (
CannotTransferOwnerToSelfError,
@@ -24,7 +25,7 @@ from controllers.console.wraps import (
setup_required,
)
from extensions.ext_database import db
-from fields.member_fields import account_with_role_list_fields
+from fields.member_fields import account_with_role_fields, account_with_role_list_fields
from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant, login_required
from models.account import Account, TenantAccountRole
@@ -67,6 +68,13 @@ reg(MemberRoleUpdatePayload)
reg(OwnerTransferEmailPayload)
reg(OwnerTransferCheckPayload)
reg(OwnerTransferPayload)
+register_enum_models(console_ns, TenantAccountRole)
+
+account_with_role_model = get_or_create_model("AccountWithRole", account_with_role_fields)
+
+account_with_role_list_fields_copy = account_with_role_list_fields.copy()
+account_with_role_list_fields_copy["accounts"] = fields.List(fields.Nested(account_with_role_model))
+account_with_role_list_model = get_or_create_model("AccountWithRoleList", account_with_role_list_fields_copy)
@console_ns.route("/workspaces/current/members")
@@ -76,7 +84,7 @@ class MemberListApi(Resource):
@setup_required
@login_required
@account_initialization_required
- @marshal_with(account_with_role_list_fields)
+ @marshal_with(account_with_role_list_model)
def get(self):
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
@@ -107,6 +115,12 @@ class MemberInviteEmailApi(Resource):
inviter = current_user
if not inviter.current_tenant:
raise ValueError("No current tenant")
+
+ # Check workspace permission for member invitations
+ from libs.workspace_permission import check_workspace_member_invite_permission
+
+ check_workspace_member_invite_permission(inviter.current_tenant.id)
+
invitation_results = []
console_web_url = dify_config.CONSOLE_WEB_URL
@@ -116,26 +130,31 @@ class MemberInviteEmailApi(Resource):
raise WorkspaceMembersLimitExceeded()
for invitee_email in invitee_emails:
+ normalized_invitee_email = invitee_email.lower()
try:
if not inviter.current_tenant:
raise ValueError("No current tenant")
token = RegisterService.invite_new_member(
- inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter
+ tenant=inviter.current_tenant,
+ email=invitee_email,
+ language=interface_language,
+ role=invitee_role,
+ inviter=inviter,
)
- encoded_invitee_email = parse.quote(invitee_email)
+ encoded_invitee_email = parse.quote(normalized_invitee_email)
invitation_results.append(
{
"status": "success",
- "email": invitee_email,
+ "email": normalized_invitee_email,
"url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
}
)
except AccountAlreadyInTenantError:
invitation_results.append(
- {"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"}
+ {"status": "success", "email": normalized_invitee_email, "url": f"{console_web_url}/signin"}
)
except Exception as e:
- invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)})
+ invitation_results.append({"status": "failed", "email": normalized_invitee_email, "message": str(e)})
return {
"result": "success",
@@ -216,7 +235,7 @@ class DatasetOperatorMemberListApi(Resource):
@setup_required
@login_required
@account_initialization_required
- @marshal_with(account_with_role_list_fields)
+ @marshal_with(account_with_role_list_model)
def get(self):
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py
index a5b45ef514..583e3e3057 100644
--- a/api/controllers/console/workspace/models.py
+++ b/api/controllers/console/workspace/models.py
@@ -5,6 +5,7 @@ from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
+from controllers.common.schema import register_enum_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
@@ -23,12 +24,13 @@ class ParserGetDefault(BaseModel):
model_type: ModelType
-class ParserPostDefault(BaseModel):
- class Inner(BaseModel):
- model_type: ModelType
- model: str | None = None
- provider: str | None = None
+class Inner(BaseModel):
+ model_type: ModelType
+ model: str | None = None
+ provider: str | None = None
+
+class ParserPostDefault(BaseModel):
model_settings: list[Inner]
@@ -105,19 +107,21 @@ class ParserParameter(BaseModel):
model: str
-def reg(cls: type[BaseModel]):
- console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+register_schema_models(
+ console_ns,
+ ParserGetDefault,
+ ParserPostDefault,
+ ParserDeleteModels,
+ ParserPostModels,
+ ParserGetCredentials,
+ ParserCreateCredential,
+ ParserUpdateCredential,
+ ParserDeleteCredential,
+ ParserParameter,
+ Inner,
+)
-
-reg(ParserGetDefault)
-reg(ParserPostDefault)
-reg(ParserDeleteModels)
-reg(ParserPostModels)
-reg(ParserGetCredentials)
-reg(ParserCreateCredential)
-reg(ParserUpdateCredential)
-reg(ParserDeleteCredential)
-reg(ParserParameter)
+register_enum_models(console_ns, ModelType)
@console_ns.route("/workspaces/current/default-model")
@@ -282,9 +286,10 @@ class ModelProviderModelCredentialApi(Resource):
tenant_id=tenant_id, provider_name=provider
)
else:
- model_type = args.model_type
+ # Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM)
+ normalized_model_type = args.model_type.to_origin_model_type()
available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials(
- tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args.model
+ tenant_id=tenant_id, provider_name=provider, model_type=normalized_model_type, model_name=args.model
)
return jsonable_encoder(
diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py
index c5624e0fc2..d1485bc1c0 100644
--- a/api/controllers/console/workspace/plugin.py
+++ b/api/controllers/console/workspace/plugin.py
@@ -1,5 +1,6 @@
import io
-from typing import Literal
+from collections.abc import Mapping
+from typing import Any, Literal
from flask import request, send_file
from flask_restx import Resource
@@ -7,6 +8,7 @@ from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden
from configs import dify_config
+from controllers.common.schema import register_enum_models, register_schema_models
from controllers.console import console_ns
from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
@@ -19,55 +21,10 @@ from services.plugin.plugin_parameter_service import PluginParameterService
from services.plugin.plugin_permission_service import PluginPermissionService
from services.plugin.plugin_service import PluginService
-DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
-
-
-def reg(cls: type[BaseModel]):
- console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
-
-
-@console_ns.route("/workspaces/current/plugin/debugging-key")
-class PluginDebuggingKeyApi(Resource):
- @setup_required
- @login_required
- @account_initialization_required
- @plugin_permission_required(debug_required=True)
- def get(self):
- _, tenant_id = current_account_with_tenant()
-
- try:
- return {
- "key": PluginService.get_debugging_key(tenant_id),
- "host": dify_config.PLUGIN_REMOTE_INSTALL_HOST,
- "port": dify_config.PLUGIN_REMOTE_INSTALL_PORT,
- }
- except PluginDaemonClientSideError as e:
- raise ValueError(e)
-
class ParserList(BaseModel):
- page: int = Field(default=1)
- page_size: int = Field(default=256)
-
-
-reg(ParserList)
-
-
-@console_ns.route("/workspaces/current/plugin/list")
-class PluginListApi(Resource):
- @console_ns.expect(console_ns.models[ParserList.__name__])
- @setup_required
- @login_required
- @account_initialization_required
- def get(self):
- _, tenant_id = current_account_with_tenant()
- args = ParserList.model_validate(request.args.to_dict(flat=True)) # type: ignore
- try:
- plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size)
- except PluginDaemonClientSideError as e:
- raise ValueError(e)
-
- return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total})
+ page: int = Field(default=1, ge=1, description="Page number")
+ page_size: int = Field(default=256, ge=1, le=256, description="Page size (1-256)")
class ParserLatest(BaseModel):
@@ -106,8 +63,8 @@ class ParserPluginIdentifierQuery(BaseModel):
class ParserTasks(BaseModel):
- page: int
- page_size: int
+ page: int = Field(default=1, ge=1, description="Page number")
+ page_size: int = Field(default=256, ge=1, le=256, description="Page size (1-256)")
class ParserMarketplaceUpgrade(BaseModel):
@@ -141,6 +98,15 @@ class ParserDynamicOptions(BaseModel):
provider_type: Literal["tool", "trigger"]
+class ParserDynamicOptionsWithCredentials(BaseModel):
+ plugin_id: str
+ provider: str
+ action: str
+ parameter: str
+ credential_id: str
+ credentials: Mapping[str, Any]
+
+
class PluginPermissionSettingsPayload(BaseModel):
install_permission: TenantPluginPermission.InstallPermission = TenantPluginPermission.InstallPermission.EVERYONE
debug_permission: TenantPluginPermission.DebugPermission = TenantPluginPermission.DebugPermission.EVERYONE
@@ -170,22 +136,73 @@ class ParserReadme(BaseModel):
language: str = Field(default="en-US")
-reg(ParserLatest)
-reg(ParserIcon)
-reg(ParserAsset)
-reg(ParserGithubUpload)
-reg(ParserPluginIdentifiers)
-reg(ParserGithubInstall)
-reg(ParserPluginIdentifierQuery)
-reg(ParserTasks)
-reg(ParserMarketplaceUpgrade)
-reg(ParserGithubUpgrade)
-reg(ParserUninstall)
-reg(ParserPermissionChange)
-reg(ParserDynamicOptions)
-reg(ParserPreferencesChange)
-reg(ParserExcludePlugin)
-reg(ParserReadme)
+register_schema_models(
+ console_ns,
+ ParserList,
+ PluginAutoUpgradeSettingsPayload,
+ PluginPermissionSettingsPayload,
+ ParserLatest,
+ ParserIcon,
+ ParserAsset,
+ ParserGithubUpload,
+ ParserPluginIdentifiers,
+ ParserGithubInstall,
+ ParserPluginIdentifierQuery,
+ ParserTasks,
+ ParserMarketplaceUpgrade,
+ ParserGithubUpgrade,
+ ParserUninstall,
+ ParserPermissionChange,
+ ParserDynamicOptions,
+ ParserDynamicOptionsWithCredentials,
+ ParserPreferencesChange,
+ ParserExcludePlugin,
+ ParserReadme,
+)
+
+register_enum_models(
+ console_ns,
+ TenantPluginPermission.DebugPermission,
+ TenantPluginAutoUpgradeStrategy.UpgradeMode,
+ TenantPluginAutoUpgradeStrategy.StrategySetting,
+ TenantPluginPermission.InstallPermission,
+)
+
+
+@console_ns.route("/workspaces/current/plugin/debugging-key")
+class PluginDebuggingKeyApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @plugin_permission_required(debug_required=True)
+ def get(self):
+ _, tenant_id = current_account_with_tenant()
+
+ try:
+ return {
+ "key": PluginService.get_debugging_key(tenant_id),
+ "host": dify_config.PLUGIN_REMOTE_INSTALL_HOST,
+ "port": dify_config.PLUGIN_REMOTE_INSTALL_PORT,
+ }
+ except PluginDaemonClientSideError as e:
+ raise ValueError(e)
+
+
+@console_ns.route("/workspaces/current/plugin/list")
+class PluginListApi(Resource):
+ @console_ns.expect(console_ns.models[ParserList.__name__])
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def get(self):
+ _, tenant_id = current_account_with_tenant()
+ args = ParserList.model_validate(request.args.to_dict(flat=True)) # type: ignore
+ try:
+ plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size)
+ except PluginDaemonClientSideError as e:
+ raise ValueError(e)
+
+ return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total})
@console_ns.route("/workspaces/current/plugin/list/latest-versions")
@@ -657,6 +674,37 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
return jsonable_encoder({"options": options})
+@console_ns.route("/workspaces/current/plugin/parameters/dynamic-options-with-credentials")
+class PluginFetchDynamicSelectOptionsWithCredentialsApi(Resource):
+ @console_ns.expect(console_ns.models[ParserDynamicOptionsWithCredentials.__name__])
+ @setup_required
+ @login_required
+ @is_admin_or_owner_required
+ @account_initialization_required
+ def post(self):
+ """Fetch dynamic options using credentials directly (for edit mode)."""
+ current_user, tenant_id = current_account_with_tenant()
+ user_id = current_user.id
+
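+        # Edit mode: credentials are taken directly from the request payload rather than from a previously saved credential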
+ args = ParserDynamicOptionsWithCredentials.model_validate(console_ns.payload)
+
+ try:
+ options = PluginParameterService.get_dynamic_select_options_with_credentials(
+ tenant_id=tenant_id,
+ user_id=user_id,
+ plugin_id=args.plugin_id,
+ provider=args.provider,
+ action=args.action,
+ parameter=args.parameter,
+ credential_id=args.credential_id,
+ credentials=args.credentials,
+ )
+ except PluginDaemonClientSideError as e:
+ raise ValueError(e)
+
+ return jsonable_encoder({"options": options})
+
+
@console_ns.route("/workspaces/current/plugin/preferences/change")
class PluginChangePreferencesApi(Resource):
@console_ns.expect(console_ns.models[ParserPreferencesChange.__name__])
diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py
index 2c54aa5a20..e9e7b72718 100644
--- a/api/controllers/console/workspace/tool_providers.py
+++ b/api/controllers/console/workspace/tool_providers.py
@@ -1,4 +1,5 @@
import io
+import logging
from urllib.parse import urlparse
from flask import make_response, redirect, request, send_file
@@ -17,6 +18,7 @@ from controllers.console.wraps import (
is_admin_or_owner_required,
setup_required,
)
+from core.db.session_factory import session_factory
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
@@ -39,6 +41,8 @@ from services.tools.tools_manage_service import ToolCommonService
from services.tools.tools_transform_service import ToolTransformService
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
+logger = logging.getLogger(__name__)
+
def is_valid_url(url: str) -> bool:
if not url:
@@ -944,8 +948,8 @@ class ToolProviderMCPApi(Resource):
configuration = MCPConfiguration.model_validate(args["configuration"])
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
- # Create provider
- with Session(db.engine) as session, session.begin():
+ # 1) Create provider in a short transaction (no network I/O inside)
+ with session_factory.create_session() as session, session.begin():
service = MCPToolManageService(session=session)
result = service.create_provider(
tenant_id=tenant_id,
@@ -960,7 +964,29 @@ class ToolProviderMCPApi(Resource):
configuration=configuration,
authentication=authentication,
)
- return jsonable_encoder(result)
+
+ # 2) Try to fetch tools immediately after creation so they appear without a second save.
+ # Perform network I/O outside any DB session to avoid holding locks.
+ try:
+ reconnect = MCPToolManageService.reconnect_with_url(
+ server_url=args["server_url"],
+ headers=args.get("headers") or {},
+ timeout=configuration.timeout,
+ sse_read_timeout=configuration.sse_read_timeout,
+ )
+ # Update just-created provider with authed/tools in a new short transaction
+ with session_factory.create_session() as session, session.begin():
+ service = MCPToolManageService(session=session)
+ db_provider = service.get_provider(provider_id=result.id, tenant_id=tenant_id)
+ db_provider.authed = reconnect.authed
+ db_provider.tools = reconnect.tools
+
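+            # Re-serialize the provider so the response includes the freshly fetched tools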
+ result = ToolTransformService.mcp_provider_to_user_provider(db_provider, for_list=True)
+ except Exception:
+ # Best-effort: if initial fetch fails (e.g., auth required), return created provider as-is
+ logger.warning("Failed to fetch MCP tools after creation", exc_info=True)
+
+ return jsonable_encoder(result)
@console_ns.expect(parser_mcp_put)
@setup_required
@@ -972,17 +998,23 @@ class ToolProviderMCPApi(Resource):
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
_, current_tenant_id = current_account_with_tenant()
- # Step 1: Validate server URL change if needed (includes URL format validation and network operation)
- validation_result = None
+ # Step 1: Get provider data for URL validation (short-lived session, no network I/O)
+ validation_data = None
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
- validation_result = service.validate_server_url_change(
- tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"]
+ validation_data = service.get_provider_for_url_validation(
+ tenant_id=current_tenant_id, provider_id=args["provider_id"]
)
- # No need to check for errors here, exceptions will be raised directly
+ # Step 2: Perform URL validation with network I/O OUTSIDE of any database session
+ # This prevents holding database locks during potentially slow network operations
+ validation_result = MCPToolManageService.validate_server_url_standalone(
+ tenant_id=current_tenant_id,
+ new_server_url=args["server_url"],
+ validation_data=validation_data,
+ )
- # Step 2: Perform database update in a transaction
+ # Step 3: Perform database update in a transaction
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.update_provider(
@@ -999,7 +1031,8 @@ class ToolProviderMCPApi(Resource):
authentication=authentication,
validation_result=validation_result,
)
- return {"result": "success"}
+
+ return {"result": "success"}
@console_ns.expect(parser_mcp_delete)
@setup_required
@@ -1012,7 +1045,8 @@ class ToolProviderMCPApi(Resource):
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"])
- return {"result": "success"}
+
+ return {"result": "success"}
parser_auth = (
diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py
index 268473d6d1..6b642af613 100644
--- a/api/controllers/console/workspace/trigger_providers.py
+++ b/api/controllers/console/workspace/trigger_providers.py
@@ -1,11 +1,14 @@
import logging
+from typing import Any
from flask import make_response, redirect, request
-from flask_restx import Resource, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden
from configs import dify_config
+from controllers.common.schema import register_schema_models
from controllers.web.error import NotFoundError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
@@ -32,6 +35,41 @@ from ..wraps import (
logger = logging.getLogger(__name__)
+class TriggerSubscriptionBuilderCreatePayload(BaseModel):
+ credential_type: str = CredentialType.UNAUTHORIZED
+
+
+class TriggerSubscriptionBuilderVerifyPayload(BaseModel):
+ credentials: dict[str, Any]
+
+
+class TriggerSubscriptionBuilderUpdatePayload(BaseModel):
+ name: str | None = None
+ parameters: dict[str, Any] | None = None
+ properties: dict[str, Any] | None = None
+ credentials: dict[str, Any] | None = None
+
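+    # Reject empty update payloads: at least one field must be provided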
+ @model_validator(mode="after")
+ def check_at_least_one_field(self):
+ if all(v is None for v in self.model_dump().values()):
+ raise ValueError("At least one of name, credentials, parameters, or properties must be provided")
+ return self
+
+
+class TriggerOAuthClientPayload(BaseModel):
+ client_params: dict[str, Any] | None = None
+ enabled: bool | None = None
+
+
+register_schema_models(
+ console_ns,
+ TriggerSubscriptionBuilderCreatePayload,
+ TriggerSubscriptionBuilderVerifyPayload,
+ TriggerSubscriptionBuilderUpdatePayload,
+ TriggerOAuthClientPayload,
+)
+
+
@console_ns.route("/workspaces/current/trigger-provider//icon")
class TriggerProviderIconApi(Resource):
@setup_required
@@ -97,16 +135,11 @@ class TriggerSubscriptionListApi(Resource):
raise
-parser = reqparse.RequestParser().add_argument(
- "credential_type", type=str, required=False, nullable=True, location="json"
-)
-
-
@console_ns.route(
"/workspaces/current/trigger-provider//subscriptions/builder/create",
)
class TriggerSubscriptionBuilderCreateApi(Resource):
- @console_ns.expect(parser)
+ @console_ns.expect(console_ns.models[TriggerSubscriptionBuilderCreatePayload.__name__])
@setup_required
@login_required
@edit_permission_required
@@ -116,10 +149,10 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
user = current_user
assert user.current_tenant_id is not None
- args = parser.parse_args()
+ payload = TriggerSubscriptionBuilderCreatePayload.model_validate(console_ns.payload or {})
try:
- credential_type = CredentialType.of(args.get("credential_type") or CredentialType.UNAUTHORIZED.value)
+ credential_type = CredentialType.of(payload.credential_type)
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
tenant_id=user.current_tenant_id,
user_id=user.id,
@@ -147,28 +180,21 @@ class TriggerSubscriptionBuilderGetApi(Resource):
)
-parser_api = (
- reqparse.RequestParser()
- # The credentials of the subscription builder
- .add_argument("credentials", type=dict, required=False, nullable=True, location="json")
-)
-
-
@console_ns.route(
- "/workspaces/current/trigger-provider//subscriptions/builder/verify/",
+ "/workspaces/current/trigger-provider//subscriptions/builder/verify-and-update/",
)
class TriggerSubscriptionBuilderVerifyApi(Resource):
- @console_ns.expect(parser_api)
+ @console_ns.expect(console_ns.models[TriggerSubscriptionBuilderVerifyPayload.__name__])
@setup_required
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
- """Verify a subscription instance for a trigger provider"""
+ """Verify and update a subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
- args = parser_api.parse_args()
+ payload = TriggerSubscriptionBuilderVerifyPayload.model_validate(console_ns.payload or {})
try:
# Use atomic update_and_verify to prevent race conditions
@@ -178,7 +204,7 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
- credentials=args.get("credentials", None),
+ credentials=payload.credentials,
),
)
except Exception as e:
@@ -186,24 +212,11 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
raise ValueError(str(e)) from e
-parser_update_api = (
- reqparse.RequestParser()
- # The name of the subscription builder
- .add_argument("name", type=str, required=False, nullable=True, location="json")
- # The parameters of the subscription builder
- .add_argument("parameters", type=dict, required=False, nullable=True, location="json")
- # The properties of the subscription builder
- .add_argument("properties", type=dict, required=False, nullable=True, location="json")
- # The credentials of the subscription builder
- .add_argument("credentials", type=dict, required=False, nullable=True, location="json")
-)
-
-
@console_ns.route(
"/workspaces/current/trigger-provider//subscriptions/builder/update/",
)
class TriggerSubscriptionBuilderUpdateApi(Resource):
- @console_ns.expect(parser_update_api)
+ @console_ns.expect(console_ns.models[TriggerSubscriptionBuilderUpdatePayload.__name__])
@setup_required
@login_required
@edit_permission_required
@@ -214,7 +227,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
assert isinstance(user, Account)
assert user.current_tenant_id is not None
- args = parser_update_api.parse_args()
+ payload = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {})
try:
return jsonable_encoder(
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
@@ -222,10 +235,10 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
- name=args.get("name", None),
- parameters=args.get("parameters", None),
- properties=args.get("properties", None),
- credentials=args.get("credentials", None),
+ name=payload.name,
+ parameters=payload.parameters,
+ properties=payload.properties,
+ credentials=payload.credentials,
),
)
)
@@ -260,7 +273,7 @@ class TriggerSubscriptionBuilderLogsApi(Resource):
"/workspaces/current/trigger-provider//subscriptions/builder/build/",
)
class TriggerSubscriptionBuilderBuildApi(Resource):
- @console_ns.expect(parser_update_api)
+ @console_ns.expect(console_ns.models[TriggerSubscriptionBuilderUpdatePayload.__name__])
@setup_required
@login_required
@edit_permission_required
@@ -269,7 +282,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
"""Build a subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
- args = parser_update_api.parse_args()
+ payload = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {})
try:
# Use atomic update_and_build to prevent race conditions
TriggerSubscriptionBuilderService.update_and_build_builder(
@@ -278,9 +291,9 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
- name=args.get("name", None),
- parameters=args.get("parameters", None),
- properties=args.get("properties", None),
+ name=payload.name,
+ parameters=payload.parameters,
+ properties=payload.properties,
),
)
return 200
@@ -289,6 +302,65 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
raise ValueError(str(e)) from e
+@console_ns.route(
+ "/workspaces/current/trigger-provider//subscriptions/update",
+)
+class TriggerSubscriptionUpdateApi(Resource):
+ @console_ns.expect(console_ns.models[TriggerSubscriptionBuilderUpdatePayload.__name__])
+ @setup_required
+ @login_required
+ @edit_permission_required
+ @account_initialization_required
+ def post(self, subscription_id: str):
+ """Update a subscription instance"""
+ user = current_user
+ assert user.current_tenant_id is not None
+
+ request = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {})
+
+ subscription = TriggerProviderService.get_subscription_by_id(
+ tenant_id=user.current_tenant_id,
+ subscription_id=subscription_id,
+ )
+ if not subscription:
+ raise NotFoundError(f"Subscription {subscription_id} not found")
+
+ provider_id = TriggerProviderID(subscription.provider_id)
+
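+        # Two update paths: a plain name/properties update for renames and manually created
+        # subscriptions, or a full rebuild via the provider for API_KEY/OAUTH2 subscriptions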
+ try:
+ # For rename only, just update the name
+ rename = request.name is not None and not any((request.credentials, request.parameters, request.properties))
+            # When the credential type is UNAUTHORIZED, the subscription was created manually.
+            # Manually created subscriptions don't have credentials or parameters;
+            # they only have a name and properties (which are input by the user).
+ manually_created = subscription.credential_type == CredentialType.UNAUTHORIZED
+ if rename or manually_created:
+ TriggerProviderService.update_trigger_subscription(
+ tenant_id=user.current_tenant_id,
+ subscription_id=subscription_id,
+ name=request.name,
+ properties=request.properties,
+ )
+ return 200
+
+            # For the remaining cases (API_KEY, OAUTH2),
+            # we need to call the third-party provider (e.g. GitHub) to rebuild the subscription
+ TriggerProviderService.rebuild_trigger_subscription(
+ tenant_id=user.current_tenant_id,
+ name=request.name,
+ provider_id=provider_id,
+ subscription_id=subscription_id,
+ credentials=request.credentials or subscription.credentials,
+ parameters=request.parameters or subscription.parameters,
+ )
+ return 200
+ except ValueError as e:
+ raise BadRequest(str(e))
+ except Exception as e:
+ logger.exception("Error updating subscription", exc_info=e)
+ raise
+
+
@console_ns.route(
"/workspaces/current/trigger-provider//subscriptions/delete",
)
@@ -474,13 +546,6 @@ class TriggerOAuthCallbackApi(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
-parser_oauth_client = (
- reqparse.RequestParser()
- .add_argument("client_params", type=dict, required=False, nullable=True, location="json")
- .add_argument("enabled", type=bool, required=False, nullable=True, location="json")
-)
-
-
@console_ns.route("/workspaces/current/trigger-provider//oauth/client")
class TriggerOAuthClientManageApi(Resource):
@setup_required
@@ -528,7 +593,7 @@ class TriggerOAuthClientManageApi(Resource):
logger.exception("Error getting OAuth client", exc_info=e)
raise
- @console_ns.expect(parser_oauth_client)
+ @console_ns.expect(console_ns.models[TriggerOAuthClientPayload.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@@ -538,15 +603,15 @@ class TriggerOAuthClientManageApi(Resource):
user = current_user
assert user.current_tenant_id is not None
- args = parser_oauth_client.parse_args()
+ payload = TriggerOAuthClientPayload.model_validate(console_ns.payload or {})
try:
provider_id = TriggerProviderID(provider)
return TriggerProviderService.save_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
- client_params=args.get("client_params"),
- enabled=args.get("enabled"),
+ client_params=payload.client_params,
+ enabled=payload.enabled,
)
except ValueError as e:
@@ -576,3 +641,36 @@ class TriggerOAuthClientManageApi(Resource):
except Exception as e:
logger.exception("Error removing OAuth client", exc_info=e)
raise
+
+
+@console_ns.route(
+ "/workspaces/current/trigger-provider//subscriptions/verify/",
+)
+class TriggerSubscriptionVerifyApi(Resource):
+ @console_ns.expect(console_ns.models[TriggerSubscriptionBuilderVerifyPayload.__name__])
+ @setup_required
+ @login_required
+ @edit_permission_required
+ @account_initialization_required
+ def post(self, provider, subscription_id):
+ """Verify credentials for an existing subscription (edit mode only)"""
+ user = current_user
+ assert user.current_tenant_id is not None
+
+ verify_request = TriggerSubscriptionBuilderVerifyPayload.model_validate(console_ns.payload or {})
+
+ try:
+ result = TriggerProviderService.verify_subscription_credentials(
+ tenant_id=user.current_tenant_id,
+ user_id=user.id,
+ provider_id=TriggerProviderID(provider),
+ subscription_id=subscription_id,
+ credentials=verify_request.credentials,
+ )
+ return result
+ except ValueError as e:
+ logger.warning("Credential verification failed", exc_info=e)
+ raise BadRequest(str(e)) from e
+ except Exception as e:
+ logger.exception("Error verifying subscription credentials", exc_info=e)
+ raise BadRequest(str(e)) from e
diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py
index 909a5ce201..94be81d94f 100644
--- a/api/controllers/console/workspace/workspace.py
+++ b/api/controllers/console/workspace/workspace.py
@@ -20,6 +20,7 @@ from controllers.console.error import AccountNotLinkTenantError
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
+ only_edition_enterprise,
setup_required,
)
from enums.cloud_plan import CloudPlan
@@ -28,6 +29,7 @@ from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantStatus
from services.account_service import TenantService
+from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.file_service import FileService
from services.workspace_service import WorkspaceService
@@ -80,6 +82,9 @@ tenant_fields = {
"in_trial": fields.Boolean,
"trial_end_reason": fields.String,
"custom_config": fields.Raw(attribute="custom_config"),
+ "trial_credits": fields.Integer,
+ "trial_credits_used": fields.Integer,
+ "next_credit_reset_date": fields.Integer,
}
tenants_fields = {
@@ -285,3 +290,31 @@ class WorkspaceInfoApi(Resource):
db.session.commit()
return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
+
+
+@console_ns.route("/workspaces/current/permission")
+class WorkspacePermissionApi(Resource):
+ """Get workspace permissions for the current workspace."""
+
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @only_edition_enterprise
+ def get(self):
+ """
+ Get workspace permission settings.
+ Returns permission flags that control workspace features like member invitations and owner transfer.
+ """
+ _, current_tenant_id = current_account_with_tenant()
+
+ if not current_tenant_id:
+ raise ValueError("No current tenant")
+
+ # Get workspace permissions from enterprise service
+ permission = EnterpriseService.WorkspacePermissionService.get_permission(current_tenant_id)
+
+ return {
+ "workspace_id": permission.workspace_id,
+ "allow_member_invite": permission.allow_member_invite,
+ "allow_owner_transfer": permission.allow_owner_transfer,
+ }, 200
diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py
index f40f566a36..fd928b077d 100644
--- a/api/controllers/console/wraps.py
+++ b/api/controllers/console/wraps.py
@@ -9,10 +9,12 @@ from typing import ParamSpec, TypeVar
from flask import abort, request
from configs import dify_config
+from controllers.console.auth.error import AuthenticationFailedError, EmailCodeError
from controllers.console.workspace.error import AccountNotInitializedError
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
+from libs.encryption import FieldEncryption
from libs.login import current_account_with_tenant
from models.account import AccountStatus
from models.dataset import RateLimitLog
@@ -25,6 +27,14 @@ from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogo
P = ParamSpec("P")
R = TypeVar("R")
+# Field names for decryption
+FIELD_NAME_PASSWORD = "password"
+FIELD_NAME_CODE = "code"
+
+# Error messages for decryption failures
+ERROR_MSG_INVALID_ENCRYPTED_DATA = "Invalid encrypted data"
+ERROR_MSG_INVALID_ENCRYPTED_CODE = "Invalid encrypted code"
+
def account_initialization_required(view: Callable[P, R]):
@wraps(view)
@@ -276,13 +286,12 @@ def enable_change_email(view: Callable[P, R]):
def is_allow_transfer_owner(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
- _, current_tenant_id = current_account_with_tenant()
- features = FeatureService.get_features(current_tenant_id)
- if features.is_allow_transfer_workspace:
- return view(*args, **kwargs)
+ from libs.workspace_permission import check_workspace_owner_transfer_permission
- # otherwise, return 403
- abort(403)
+ _, current_tenant_id = current_account_with_tenant()
+ # Check both billing/plan level and workspace policy level permissions
+ check_workspace_owner_transfer_permission(current_tenant_id)
+ return view(*args, **kwargs)
return decorated
@@ -331,3 +340,163 @@ def is_admin_or_owner_required(f: Callable[P, R]):
return f(*args, **kwargs)
return decorated_function
+
+
+def annotation_import_rate_limit(view: Callable[P, R]):
+ """
+ Rate limiting decorator for annotation import operations.
+
+ Implements sliding window rate limiting with two tiers:
+ - Short-term: Configurable requests per minute (default: 5)
+ - Long-term: Configurable requests per hour (default: 20)
+
+ Uses Redis ZSET for distributed rate limiting across multiple instances.
+ """
+
+ @wraps(view)
+ def decorated(*args: P.args, **kwargs: P.kwargs):
+ _, current_tenant_id = current_account_with_tenant()
+ current_time = int(time.time() * 1000)
+
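+        # Sliding window: record this request's timestamp in a ZSET, evict entries older than
+        # the window, then count what remains against the configured limit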
+ # Check per-minute rate limit
+ minute_key = f"annotation_import_rate_limit:{current_tenant_id}:1min"
+ redis_client.zadd(minute_key, {current_time: current_time})
+ redis_client.zremrangebyscore(minute_key, 0, current_time - 60000)
+ minute_count = redis_client.zcard(minute_key)
+ redis_client.expire(minute_key, 120) # 2 minutes TTL
+
+ if minute_count > dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE:
+ abort(
+ 429,
+ f"Too many annotation import requests. Maximum {dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE} "
+ f"requests per minute allowed. Please try again later.",
+ )
+
+ # Check per-hour rate limit
+ hour_key = f"annotation_import_rate_limit:{current_tenant_id}:1hour"
+ redis_client.zadd(hour_key, {current_time: current_time})
+ redis_client.zremrangebyscore(hour_key, 0, current_time - 3600000)
+ hour_count = redis_client.zcard(hour_key)
+ redis_client.expire(hour_key, 7200) # 2 hours TTL
+
+ if hour_count > dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR:
+ abort(
+ 429,
+ f"Too many annotation import requests. Maximum {dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR} "
+ f"requests per hour allowed. Please try again later.",
+ )
+
+ return view(*args, **kwargs)
+
+ return decorated
+
+
+def annotation_import_concurrency_limit(view: Callable[P, R]):
+ """
+ Concurrency control decorator for annotation import operations.
+
+ Limits the number of concurrent import tasks per tenant to prevent
+ resource exhaustion and ensure fair resource allocation.
+
+ Uses Redis ZSET to track active import jobs with automatic cleanup
+ of stale entries (jobs older than 2 minutes).
+ """
+
+ @wraps(view)
+ def decorated(*args: P.args, **kwargs: P.kwargs):
+ _, current_tenant_id = current_account_with_tenant()
+ current_time = int(time.time() * 1000)
+
+ active_jobs_key = f"annotation_import_active:{current_tenant_id}"
+
+ # Clean up stale entries (jobs that should have completed or timed out)
+ stale_threshold = current_time - 120000 # 2 minutes ago
+ redis_client.zremrangebyscore(active_jobs_key, 0, stale_threshold)
+
+ # Check current active job count
+ active_count = redis_client.zcard(active_jobs_key)
+
+ if active_count >= dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT:
+ abort(
+ 429,
+ f"Too many concurrent import tasks. Maximum {dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT} "
+ f"concurrent imports allowed per workspace. Please wait for existing imports to complete.",
+ )
+
+ # Allow the request to proceed
+ # The actual job registration will happen in the service layer
+ return view(*args, **kwargs)
+
+ return decorated
+
+
+def _decrypt_field(field_name: str, error_class: type[Exception], error_message: str) -> None:
+ """
+ Helper to decode a Base64 encoded field in the request payload.
+
+ Args:
+ field_name: Name of the field to decode
+ error_class: Exception class to raise on decoding failure
+ error_message: Error message to include in the exception
+ """
+ if not request or not request.is_json:
+ return
+ # Get the payload dict - it's cached and mutable
+ payload = request.get_json()
+ if not payload or field_name not in payload:
+ return
+ encoded_value = payload[field_name]
+ decoded_value = FieldEncryption.decrypt_field(encoded_value)
+
+ # If decoding failed, raise error immediately
+ if decoded_value is None:
+ raise error_class(error_message)
+
+ # Update payload dict in-place with decoded value
+ # Since payload is a mutable dict and get_json() returns the cached reference,
+ # modifying it will affect all subsequent accesses including console_ns.payload
+ payload[field_name] = decoded_value
+
+
+def decrypt_password_field(view: Callable[P, R]):
+ """
+ Decorator to decrypt password field in request payload.
+
+ Automatically decrypts the 'password' field if encryption is enabled.
+ If decryption fails, raises AuthenticationFailedError.
+
+ Usage:
+ @decrypt_password_field
+ def post(self):
+ args = LoginPayload.model_validate(console_ns.payload)
+ # args.password is now decrypted
+ """
+
+ @wraps(view)
+ def decorated(*args: P.args, **kwargs: P.kwargs):
+ _decrypt_field(FIELD_NAME_PASSWORD, AuthenticationFailedError, ERROR_MSG_INVALID_ENCRYPTED_DATA)
+ return view(*args, **kwargs)
+
+ return decorated
+
+
+def decrypt_code_field(view: Callable[P, R]):
+ """
+ Decorator to decrypt verification code field in request payload.
+
+ Automatically decrypts the 'code' field if encryption is enabled.
+ If decryption fails, raises EmailCodeError.
+
+ Usage:
+ @decrypt_code_field
+ def post(self):
+ args = EmailCodeLoginPayload.model_validate(console_ns.payload)
+ # args.code is now decrypted
+ """
+
+ @wraps(view)
+ def decorated(*args: P.args, **kwargs: P.kwargs):
+ _decrypt_field(FIELD_NAME_CODE, EmailCodeError, ERROR_MSG_INVALID_ENCRYPTED_CODE)
+ return view(*args, **kwargs)
+
+ return decorated
diff --git a/api/controllers/fastopenapi.py b/api/controllers/fastopenapi.py
new file mode 100644
index 0000000000..c13f22338b
--- /dev/null
+++ b/api/controllers/fastopenapi.py
@@ -0,0 +1,3 @@
+from fastopenapi.routers import FlaskRouter
+
+console_router = FlaskRouter()
diff --git a/api/controllers/files/image_preview.py b/api/controllers/files/image_preview.py
index 64f47f426a..04db1c67cb 100644
--- a/api/controllers/files/image_preview.py
+++ b/api/controllers/files/image_preview.py
@@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound
import services
from controllers.common.errors import UnsupportedFileTypeError
+from controllers.common.file_response import enforce_download_for_html
from controllers.files import files_ns
from extensions.ext_database import db
from services.account_service import TenantService
@@ -138,6 +139,13 @@ class FilePreviewApi(Resource):
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
response.headers["Content-Type"] = "application/octet-stream"
+ enforce_download_for_html(
+ response,
+ mime_type=upload_file.mime_type,
+ filename=upload_file.name,
+ extension=upload_file.extension,
+ )
+
return response
diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py
index c487a0a915..89aa472015 100644
--- a/api/controllers/files/tool_files.py
+++ b/api/controllers/files/tool_files.py
@@ -6,6 +6,7 @@ from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, NotFound
from controllers.common.errors import UnsupportedFileTypeError
+from controllers.common.file_response import enforce_download_for_html
from controllers.files import files_ns
from core.tools.signature import verify_tool_file_signature
from core.tools.tool_file_manager import ToolFileManager
@@ -78,4 +79,11 @@ class ToolFileApi(Resource):
encoded_filename = quote(tool_file.name)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
+ enforce_download_for_html(
+ response,
+ mime_type=tool_file.mimetype,
+ filename=tool_file.name,
+ extension=extension,
+ )
+
return response
diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py
index 6096a87c56..28ec4b3935 100644
--- a/api/controllers/files/upload.py
+++ b/api/controllers/files/upload.py
@@ -4,18 +4,18 @@ from flask import request
from flask_restx import Resource
from flask_restx.api import HTTPStatus
from pydantic import BaseModel, Field
-from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import Forbidden
import services
from core.file.helpers import verify_plugin_file_signature
from core.tools.tool_file_manager import ToolFileManager
-from fields.file_fields import build_file_model
+from fields.file_fields import FileResponse
from ..common.errors import (
FileTooLargeError,
UnsupportedFileTypeError,
)
+from ..common.schema import register_schema_models
from ..console.wraps import setup_required
from ..files import files_ns
from ..inner_api.plugin.wraps import get_user
@@ -35,6 +35,8 @@ files_ns.schema_model(
PluginUploadQuery.__name__, PluginUploadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
+register_schema_models(files_ns, FileResponse)
+
@files_ns.route("/upload/for-plugin")
class PluginUploadFileApi(Resource):
@@ -51,7 +53,7 @@ class PluginUploadFileApi(Resource):
415: "Unsupported file type",
}
)
- @files_ns.marshal_with(build_file_model(files_ns), code=HTTPStatus.CREATED)
+ @files_ns.response(HTTPStatus.CREATED, "File uploaded", files_ns.models[FileResponse.__name__])
def post(self):
"""Upload a file for plugin usage.
@@ -69,7 +71,7 @@ class PluginUploadFileApi(Resource):
"""
args = PluginUploadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
- file: FileStorage | None = request.files.get("file")
+ file = request.files.get("file")
if file is None:
raise Forbidden("File is required.")
@@ -80,8 +82,8 @@ class PluginUploadFileApi(Resource):
user_id = args.user_id
user = get_user(tenant_id, user_id)
- filename: str | None = file.filename
- mimetype: str | None = file.mimetype
+ filename = file.filename
+ mimetype = file.mimetype
if not filename or not mimetype:
raise Forbidden("Invalid request.")
@@ -111,22 +113,22 @@ class PluginUploadFileApi(Resource):
preview_url = ToolFileManager.sign_file(tool_file_id=tool_file.id, extension=extension)
# Create a dictionary with all the necessary attributes
- result = {
- "id": tool_file.id,
- "user_id": tool_file.user_id,
- "tenant_id": tool_file.tenant_id,
- "conversation_id": tool_file.conversation_id,
- "file_key": tool_file.file_key,
- "mimetype": tool_file.mimetype,
- "original_url": tool_file.original_url,
- "name": tool_file.name,
- "size": tool_file.size,
- "mime_type": mimetype,
- "extension": extension,
- "preview_url": preview_url,
- }
+ result = FileResponse(
+ id=tool_file.id,
+ name=tool_file.name,
+ size=tool_file.size,
+ extension=extension,
+ mime_type=mimetype,
+ preview_url=preview_url,
+ source_url=tool_file.original_url,
+ original_url=tool_file.original_url,
+ user_id=tool_file.user_id,
+ tenant_id=tool_file.tenant_id,
+ conversation_id=tool_file.conversation_id,
+ file_key=tool_file.file_key,
+ )
- return result, 201
+ return result.model_dump(mode="json"), 201
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py
index 63c373b50f..85ac9336d6 100644
--- a/api/controllers/service_api/app/annotation.py
+++ b/api/controllers/service_api/app/annotation.py
@@ -1,7 +1,7 @@
from typing import Literal
from flask import request
-from flask_restx import Api, Namespace, Resource, fields
+from flask_restx import Namespace, Resource, fields
from flask_restx.api import HTTPStatus
from pydantic import BaseModel, Field
@@ -92,7 +92,7 @@ annotation_list_fields = {
}
-def build_annotation_list_model(api_or_ns: Api | Namespace):
+def build_annotation_list_model(api_or_ns: Namespace):
"""Build the annotation list model for the API or Namespace."""
copied_annotation_list_fields = annotation_list_fields.copy()
copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns)))
diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py
index 25d7ccccec..562f5e33cc 100644
--- a/api/controllers/service_api/app/app.py
+++ b/api/controllers/service_api/app/app.py
@@ -1,6 +1,6 @@
from flask_restx import Resource
-from controllers.common.fields import build_parameters_model
+from controllers.common.fields import Parameters
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import AppUnavailableError
from controllers.service_api.wraps import validate_app_token
@@ -23,7 +23,6 @@ class AppParameterApi(Resource):
}
)
@validate_app_token
- @service_api_ns.marshal_with(build_parameters_model(service_api_ns))
def get(self, app_model: App):
"""Retrieve app parameters.
@@ -45,7 +44,8 @@ class AppParameterApi(Resource):
user_input_form = features_dict.get("user_input_form", [])
- return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
+ parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
+ return Parameters.model_validate(parameters).model_dump(mode="json")
@service_api_ns.route("/meta")
diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py
index b7fb01c6fe..b3836f3a47 100644
--- a/api/controllers/service_api/app/completion.py
+++ b/api/controllers/service_api/app/completion.py
@@ -61,6 +61,9 @@ class ChatRequestPayload(BaseModel):
@classmethod
def normalize_conversation_id(cls, value: str | UUID | None) -> str | None:
"""Allow missing or blank conversation IDs; enforce UUID format when provided."""
+ if isinstance(value, str):
+ value = value.strip()
+
if not value:
return None
diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py
index be6d837032..62e8258e25 100644
--- a/api/controllers/service_api/app/conversation.py
+++ b/api/controllers/service_api/app/conversation.py
@@ -3,8 +3,7 @@ from uuid import UUID
from flask import request
from flask_restx import Resource
-from flask_restx._http import HTTPStatus
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound
@@ -16,9 +15,9 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import (
- build_conversation_delete_model,
- build_conversation_infinite_scroll_pagination_model,
- build_simple_conversation_model,
+ ConversationDelete,
+ ConversationInfiniteScrollPagination,
+ SimpleConversation,
)
from fields.conversation_variable_fields import (
build_conversation_variable_infinite_scroll_pagination_model,
@@ -51,6 +50,32 @@ class ConversationRenamePayload(BaseModel):
class ConversationVariablesQuery(BaseModel):
last_id: UUID | None = Field(default=None, description="Last variable ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
+ variable_name: str | None = Field(
+ default=None, description="Filter variables by name", min_length=1, max_length=255
+ )
+
+ @field_validator("variable_name", mode="before")
+ @classmethod
+ def validate_variable_name(cls, v: str | None) -> str | None:
+ """
+ Validate variable_name to prevent injection attacks.
+ """
+ if v is None:
+ return v
+
+ # Only allow safe characters: alphanumeric, underscore, hyphen, period
+ if not v.replace("-", "").replace("_", "").replace(".", "").isalnum():
+ raise ValueError(
+ "Variable name can only contain letters, numbers, hyphens (-), underscores (_), and periods (.)"
+ )
+
+ # Prevent SQL injection patterns
+ dangerous_patterns = ["'", '"', ";", "--", "/*", "*/", "xp_", "sp_"]
+ for pattern in dangerous_patterns:
+ if pattern in v.lower():
+ raise ValueError(f"Variable name contains invalid characters: {pattern}")
+
+ return v
class ConversationVariableUpdatePayload(BaseModel):
@@ -79,7 +104,6 @@ class ConversationApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
- @service_api_ns.marshal_with(build_conversation_infinite_scroll_pagination_model(service_api_ns))
def get(self, app_model: App, end_user: EndUser):
"""List all conversations for the current user.
@@ -94,7 +118,7 @@ class ConversationApi(Resource):
try:
with Session(db.engine) as session:
- return ConversationService.pagination_by_last_id(
+ pagination = ConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=end_user,
@@ -103,6 +127,13 @@ class ConversationApi(Resource):
invoke_from=InvokeFrom.SERVICE_API,
sort_by=query_args.sort_by,
)
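+                # Convert the ORM results into SimpleConversation models (replaces the removed marshal_with serialization)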
+ adapter = TypeAdapter(SimpleConversation)
+ conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data]
+ return ConversationInfiniteScrollPagination(
+ limit=pagination.limit,
+ has_more=pagination.has_more,
+ data=conversations,
+ ).model_dump(mode="json")
except services.errors.conversation.LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
@@ -120,7 +151,6 @@ class ConversationDetailApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
- @service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=HTTPStatus.NO_CONTENT)
def delete(self, app_model: App, end_user: EndUser, c_id):
"""Delete a specific conversation."""
app_mode = AppMode.value_of(app_model.mode)
@@ -133,7 +163,7 @@ class ConversationDetailApi(Resource):
ConversationService.delete(app_model, conversation_id, end_user)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
- return {"result": "success"}, 204
+ return ConversationDelete(result="success").model_dump(mode="json"), 204
@service_api_ns.route("/conversations//name")
@@ -150,7 +180,6 @@ class ConversationRenameApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
- @service_api_ns.marshal_with(build_simple_conversation_model(service_api_ns))
def post(self, app_model: App, end_user: EndUser, c_id):
"""Rename a conversation or auto-generate a name."""
app_mode = AppMode.value_of(app_model.mode)
@@ -162,7 +191,14 @@ class ConversationRenameApi(Resource):
payload = ConversationRenamePayload.model_validate(service_api_ns.payload or {})
try:
- return ConversationService.rename(app_model, conversation_id, end_user, payload.name, payload.auto_generate)
+ conversation = ConversationService.rename(
+ app_model, conversation_id, end_user, payload.name, payload.auto_generate
+ )
+ return (
+ TypeAdapter(SimpleConversation)
+ .validate_python(conversation, from_attributes=True)
+ .model_dump(mode="json")
+ )
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@@ -199,7 +235,7 @@ class ConversationVariablesApi(Resource):
try:
return ConversationService.get_conversational_variable(
- app_model, conversation_id, end_user, query_args.limit, last_id
+ app_model, conversation_id, end_user, query_args.limit, last_id, query_args.variable_name
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py
index ffe4e0b492..6f6dadf768 100644
--- a/api/controllers/service_api/app/file.py
+++ b/api/controllers/service_api/app/file.py
@@ -10,13 +10,16 @@ from controllers.common.errors import (
TooManyFilesError,
UnsupportedFileTypeError,
)
+from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from extensions.ext_database import db
-from fields.file_fields import build_file_model
+from fields.file_fields import FileResponse
from models import App, EndUser
from services.file_service import FileService
+register_schema_models(service_api_ns, FileResponse)
+
@service_api_ns.route("/files/upload")
class FileApi(Resource):
@@ -31,8 +34,8 @@ class FileApi(Resource):
415: "Unsupported file type",
}
)
- @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
- @service_api_ns.marshal_with(build_file_model(service_api_ns), code=HTTPStatus.CREATED)
+ @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) # type: ignore
+ @service_api_ns.response(HTTPStatus.CREATED, "File uploaded", service_api_ns.models[FileResponse.__name__])
def post(self, app_model: App, end_user: EndUser):
"""Upload a file for use in conversations.
@@ -64,4 +67,5 @@ class FileApi(Resource):
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
- return upload_file, 201
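+        # Serialize via the FileResponse Pydantic model (replaces the removed flask-restx marshalling)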
+ response = FileResponse.model_validate(upload_file, from_attributes=True)
+ return response.model_dump(mode="json"), 201
diff --git a/api/controllers/service_api/app/file_preview.py b/api/controllers/service_api/app/file_preview.py
index 60f422b88e..f853a124ef 100644
--- a/api/controllers/service_api/app/file_preview.py
+++ b/api/controllers/service_api/app/file_preview.py
@@ -5,6 +5,7 @@ from flask import Response, request
from flask_restx import Resource
from pydantic import BaseModel, Field
+from controllers.common.file_response import enforce_download_for_html
from controllers.common.schema import register_schema_model
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import (
@@ -183,6 +184,13 @@ class FilePreviewApi(Resource):
# Override content-type for downloads to force download
response.headers["Content-Type"] = "application/octet-stream"
+ enforce_download_for_html(
+ response,
+ mime_type=upload_file.mime_type,
+ filename=upload_file.name,
+ extension=upload_file.extension,
+ )
+
# Add caching headers for performance
response.headers["Cache-Control"] = "public, max-age=3600" # Cache for 1 hour
diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py
index d342f4e661..8981bbd7d5 100644
--- a/api/controllers/service_api/app/message.py
+++ b/api/controllers/service_api/app/message.py
@@ -1,11 +1,10 @@
-import json
import logging
from typing import Literal
from uuid import UUID
from flask import request
-from flask_restx import Namespace, Resource, fields
-from pydantic import BaseModel, Field
+from flask_restx import Resource
+from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services
@@ -14,10 +13,8 @@ from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.entities.app_invoke_entities import InvokeFrom
-from fields.conversation_fields import build_message_file_model
-from fields.message_fields import build_agent_thought_model, build_feedback_model
-from fields.raws import FilesContainedField
-from libs.helper import TimestampField
+from fields.conversation_fields import ResultResponse
+from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
from models.model import App, AppMode, EndUser
from services.errors.message import (
FirstMessageNotExistsError,
@@ -48,49 +45,6 @@ class FeedbackListQuery(BaseModel):
register_schema_models(service_api_ns, MessageListQuery, MessageFeedbackPayload, FeedbackListQuery)
-def build_message_model(api_or_ns: Namespace):
- """Build the message model for the API or Namespace."""
- # First build the nested models
- feedback_model = build_feedback_model(api_or_ns)
- agent_thought_model = build_agent_thought_model(api_or_ns)
- message_file_model = build_message_file_model(api_or_ns)
-
- # Then build the message fields with nested models
- message_fields = {
- "id": fields.String,
- "conversation_id": fields.String,
- "parent_message_id": fields.String,
- "inputs": FilesContainedField,
- "query": fields.String,
- "answer": fields.String(attribute="re_sign_file_url_answer"),
- "message_files": fields.List(fields.Nested(message_file_model)),
- "feedback": fields.Nested(feedback_model, attribute="user_feedback", allow_null=True),
- "retriever_resources": fields.Raw(
- attribute=lambda obj: json.loads(obj.message_metadata).get("retriever_resources", [])
- if obj.message_metadata
- else []
- ),
- "created_at": TimestampField,
- "agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
- "status": fields.String,
- "error": fields.String,
- }
- return api_or_ns.model("Message", message_fields)
-
-
-def build_message_infinite_scroll_pagination_model(api_or_ns: Namespace):
- """Build the message infinite scroll pagination model for the API or Namespace."""
- # Build the nested message model first
- message_model = build_message_model(api_or_ns)
-
- message_infinite_scroll_pagination_fields = {
- "limit": fields.Integer,
- "has_more": fields.Boolean,
- "data": fields.List(fields.Nested(message_model)),
- }
- return api_or_ns.model("MessageInfiniteScrollPagination", message_infinite_scroll_pagination_fields)
-
-
@service_api_ns.route("/messages")
class MessageListApi(Resource):
@service_api_ns.expect(service_api_ns.models[MessageListQuery.__name__])
@@ -104,7 +58,6 @@ class MessageListApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
- @service_api_ns.marshal_with(build_message_infinite_scroll_pagination_model(service_api_ns))
def get(self, app_model: App, end_user: EndUser):
"""List messages in a conversation.
@@ -119,9 +72,16 @@ class MessageListApi(Resource):
first_id = str(query_args.first_id) if query_args.first_id else None
try:
- return MessageService.pagination_by_first_id(
+ pagination = MessageService.pagination_by_first_id(
app_model, end_user, conversation_id, first_id, query_args.limit
)
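+            # Build the response from MessageListItem models, which replace the removed flask-restx builders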
+ adapter = TypeAdapter(MessageListItem)
+ items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
+ return MessageInfiniteScrollPagination(
+ limit=pagination.limit,
+ has_more=pagination.has_more,
+ data=items,
+ ).model_dump(mode="json")
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except FirstMessageNotExistsError:
@@ -162,7 +122,7 @@ class MessageFeedbackApi(Resource):
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
- return {"result": "success"}
+ return ResultResponse(result="success").model_dump(mode="json")
@service_api_ns.route("/app/feedbacks")
diff --git a/api/controllers/service_api/app/site.py b/api/controllers/service_api/app/site.py
index 9f8324a84e..8b47a887bb 100644
--- a/api/controllers/service_api/app/site.py
+++ b/api/controllers/service_api/app/site.py
@@ -1,7 +1,7 @@
from flask_restx import Resource
from werkzeug.exceptions import Forbidden
-from controllers.common.fields import build_site_model
+from controllers.common.fields import Site as SiteResponse
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import validate_app_token
from extensions.ext_database import db
@@ -23,7 +23,6 @@ class AppSiteApi(Resource):
}
)
@validate_app_token
- @service_api_ns.marshal_with(build_site_model(service_api_ns))
def get(self, app_model: App):
"""Retrieve app site info.
@@ -38,4 +37,4 @@ class AppSiteApi(Resource):
if app_model.tenant.status == TenantStatus.ARCHIVE:
raise Forbidden()
- return site
+ return SiteResponse.model_validate(site).model_dump(mode="json")
diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py
index 4964888fd6..6a549fc926 100644
--- a/api/controllers/service_api/app/workflow.py
+++ b/api/controllers/service_api/app/workflow.py
@@ -3,7 +3,7 @@ from typing import Any, Literal
from dateutil.parser import isoparse
from flask import request
-from flask_restx import Api, Namespace, Resource, fields
+from flask_restx import Namespace, Resource, fields
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session, sessionmaker
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
@@ -78,7 +78,7 @@ workflow_run_fields = {
}
-def build_workflow_run_model(api_or_ns: Api | Namespace):
+def build_workflow_run_model(api_or_ns: Namespace):
"""Build the workflow run model for the API or Namespace."""
return api_or_ns.model("WorkflowRun", workflow_run_fields)
diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py
index 7692aeed23..c11f64585a 100644
--- a/api/controllers/service_api/dataset/dataset.py
+++ b/api/controllers/service_api/dataset/dataset.py
@@ -2,7 +2,7 @@ from typing import Any, Literal, cast
from flask import request
from flask_restx import marshal
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field, TypeAdapter, field_validator
from werkzeug.exceptions import Forbidden, NotFound
import services
@@ -13,7 +13,6 @@ from controllers.service_api.dataset.error import DatasetInUseError, DatasetName
from controllers.service_api.wraps import (
DatasetApiResource,
cloud_edition_billing_rate_limit_check,
- validate_dataset_token,
)
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
@@ -27,6 +26,14 @@ from services.dataset_service import DatasetPermissionService, DatasetService, D
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
from services.tag_service import TagService
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+service_api_ns.schema_model(
+ DatasetPermissionEnum.__name__,
+ TypeAdapter(DatasetPermissionEnum).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
class DatasetCreatePayload(BaseModel):
name: str = Field(..., min_length=1, max_length=40)
@@ -39,6 +46,7 @@ class DatasetCreatePayload(BaseModel):
retrieval_model: RetrievalModel | None = None
embedding_model: str | None = None
embedding_model_provider: str | None = None
+ summary_index_setting: dict | None = None
class DatasetUpdatePayload(BaseModel):
@@ -49,7 +57,7 @@ class DatasetUpdatePayload(BaseModel):
embedding_model: str | None = None
embedding_model_provider: str | None = None
retrieval_model: RetrievalModel | None = None
- partial_member_list: list[str] | None = None
+ partial_member_list: list[dict[str, str]] | None = None
external_retrieval_model: dict[str, Any] | None = None
external_knowledge_id: str | None = None
external_knowledge_api_id: str | None = None
@@ -88,6 +96,14 @@ class TagUnbindingPayload(BaseModel):
target_id: str
+class DatasetListQuery(BaseModel):
+ page: int = Field(default=1, description="Page number")
+ limit: int = Field(default=20, description="Number of items per page")
+ keyword: str | None = Field(default=None, description="Search keyword")
+ include_all: bool = Field(default=False, description="Include all datasets")
+ tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs")
+
+
register_schema_models(
service_api_ns,
DatasetCreatePayload,
@@ -97,6 +113,7 @@ register_schema_models(
TagDeletePayload,
TagBindingPayload,
TagUnbindingPayload,
+ DatasetListQuery,
)
@@ -114,15 +131,11 @@ class DatasetListApi(DatasetApiResource):
)
def get(self, tenant_id):
"""Resource for getting datasets."""
- page = request.args.get("page", default=1, type=int)
- limit = request.args.get("limit", default=20, type=int)
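+        # Parse and validate pagination/filter params with the DatasetListQuery model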
+ query = DatasetListQuery.model_validate(request.args.to_dict())
# provider = request.args.get("provider", default="vendor")
- search = request.args.get("keyword", default=None, type=str)
- tag_ids = request.args.getlist("tag_ids")
- include_all = request.args.get("include_all", default="false").lower() == "true"
datasets, total = DatasetService.get_datasets(
- page, limit, tenant_id, current_user, search, tag_ids, include_all
+ query.page, query.limit, tenant_id, current_user, query.keyword, query.tag_ids, query.include_all
)
# check embedding setting
provider_manager = ProviderManager()
@@ -148,7 +161,13 @@ class DatasetListApi(DatasetApiResource):
item["embedding_available"] = False
else:
item["embedding_available"] = True
- response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
+ response = {
+ "data": data,
+ "has_more": len(datasets) == query.limit,
+ "limit": query.limit,
+ "total": total,
+ "page": query.page,
+ }
return response, 200
@service_api_ns.expect(service_api_ns.models[DatasetCreatePayload.__name__])
@@ -199,6 +218,7 @@ class DatasetListApi(DatasetApiResource):
embedding_model_provider=payload.embedding_model_provider,
embedding_model_name=payload.embedding_model,
retrieval_model=payload.retrieval_model,
+ summary_index_setting=payload.summary_index_setting,
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
@@ -460,9 +480,8 @@ class DatasetTagsApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
- @validate_dataset_token
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
- def get(self, _, dataset_id):
+ def get(self, _):
"""Get all knowledge type tags."""
assert isinstance(current_user, Account)
cid = current_user.current_tenant_id
@@ -482,8 +501,7 @@ class DatasetTagsApi(DatasetApiResource):
}
)
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
- @validate_dataset_token
- def post(self, _, dataset_id):
+ def post(self, _):
"""Add a knowledge type tag."""
assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@@ -506,8 +524,7 @@ class DatasetTagsApi(DatasetApiResource):
}
)
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
- @validate_dataset_token
- def patch(self, _, dataset_id):
+ def patch(self, _):
assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
@@ -533,9 +550,8 @@ class DatasetTagsApi(DatasetApiResource):
403: "Forbidden - insufficient permissions",
}
)
- @validate_dataset_token
@edit_permission_required
- def delete(self, _, dataset_id):
+ def delete(self, _):
"""Delete a knowledge type tag."""
payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag(payload.tag_id)
@@ -555,8 +571,7 @@ class DatasetTagBindingApi(DatasetApiResource):
403: "Forbidden - insufficient permissions",
}
)
- @validate_dataset_token
- def post(self, _, dataset_id):
+ def post(self, _):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@@ -580,8 +595,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
403: "Forbidden - insufficient permissions",
}
)
- @validate_dataset_token
- def post(self, _, dataset_id):
+ def post(self, _):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@@ -604,7 +618,6 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
- @validate_dataset_token
def get(self, _, *args, **kwargs):
"""Get all knowledge type tags."""
dataset_id = kwargs.get("dataset_id")
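The hunks above replace hand-rolled `request.args` parsing with a Pydantic query model (`DatasetListQuery`). A minimal sketch of that pattern follows, assuming Pydantic v2 lax coercion; the model is redeclared locally and the query values are hypothetical, standing in for `request.args.to_dict()`.

```python
from pydantic import BaseModel, Field


class DatasetListQuery(BaseModel):
    """Local sketch mirroring the query model added above."""
    page: int = Field(default=1, description="Page number")
    limit: int = Field(default=20, description="Number of items per page")
    keyword: str | None = Field(default=None, description="Search keyword")
    include_all: bool = Field(default=False, description="Include all datasets")


# request.args.to_dict() yields plain strings; Pydantic coerces them to the typed fields.
raw_args = {"page": "2", "limit": "50", "include_all": "true"}  # hypothetical query string
query = DatasetListQuery.model_validate(raw_args)

assert query.page == 2 and query.limit == 50
assert query.include_all is True
assert query.keyword is None  # absent keys fall back to their declared defaults
```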
diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py
index c800c0e4e1..a01524f1bc 100644
--- a/api/controllers/service_api/dataset/document.py
+++ b/api/controllers/service_api/dataset/document.py
@@ -16,6 +16,7 @@ from controllers.common.errors import (
TooManyFilesError,
UnsupportedFileTypeError,
)
+from controllers.common.schema import register_enum_models, register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.dataset.error import (
@@ -29,13 +30,22 @@ from controllers.service_api.wraps import (
cloud_edition_billing_resource_check,
)
from core.errors.error import ProviderTokenNotInitError
+from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from fields.document_fields import document_fields, document_status_fields
from libs.login import current_user
from models.dataset import Dataset, Document, DocumentSegment
from services.dataset_service import DatasetService, DocumentService
-from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
+from services.entities.knowledge_entities.knowledge_entities import (
+ KnowledgeConfig,
+ PreProcessingRule,
+ ProcessRule,
+ RetrievalModel,
+ Rule,
+ Segmentation,
+)
from services.file_service import FileService
+from services.summary_index_service import SummaryIndexService
class DocumentTextCreatePayload(BaseModel):
@@ -69,8 +79,26 @@ class DocumentTextUpdate(BaseModel):
return self
-for m in [ProcessRule, RetrievalModel, DocumentTextCreatePayload, DocumentTextUpdate]:
- service_api_ns.schema_model(m.__name__, m.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) # type: ignore
+class DocumentListQuery(BaseModel):
+ page: int = Field(default=1, description="Page number")
+ limit: int = Field(default=20, description="Number of items per page")
+ keyword: str | None = Field(default=None, description="Search keyword")
+ status: str | None = Field(default=None, description="Document status filter")
+
+
+register_enum_models(service_api_ns, RetrievalMethod)
+
+register_schema_models(
+ service_api_ns,
+ ProcessRule,
+ RetrievalModel,
+ DocumentTextCreatePayload,
+ DocumentTextUpdate,
+ DocumentListQuery,
+ Rule,
+ PreProcessingRule,
+ Segmentation,
+)
@service_api_ns.route(
@@ -261,17 +289,6 @@ class DocumentAddByFileApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id):
"""Create document by upload file."""
- args = {}
- if "data" in request.form:
- args = json.loads(request.form["data"])
- if "doc_form" not in args:
- args["doc_form"] = "text_model"
- if "doc_language" not in args:
- args["doc_language"] = "English"
-
- # get dataset info
- dataset_id = str(dataset_id)
- tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
@@ -280,6 +297,18 @@ class DocumentAddByFileApi(DatasetApiResource):
if dataset.provider == "external":
raise ValueError("External datasets are not supported.")
+ args = {}
+ if "data" in request.form:
+ args = json.loads(request.form["data"])
+ if "doc_form" not in args:
+ args["doc_form"] = dataset.chunk_structure or "text_model"
+ if "doc_language" not in args:
+ args["doc_language"] = "English"
+
+ # get dataset info
+ dataset_id = str(dataset_id)
+ tenant_id = str(tenant_id)
+
indexing_technique = args.get("indexing_technique") or dataset.indexing_technique
if not indexing_technique:
raise ValueError("indexing_technique is required.")
@@ -370,17 +399,6 @@ class DocumentUpdateByFileApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, document_id):
"""Update document by upload file."""
- args = {}
- if "data" in request.form:
- args = json.loads(request.form["data"])
- if "doc_form" not in args:
- args["doc_form"] = "text_model"
- if "doc_language" not in args:
- args["doc_language"] = "English"
-
- # get dataset info
- dataset_id = str(dataset_id)
- tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
@@ -389,6 +407,18 @@ class DocumentUpdateByFileApi(DatasetApiResource):
if dataset.provider == "external":
raise ValueError("External datasets are not supported.")
+ args = {}
+ if "data" in request.form:
+ args = json.loads(request.form["data"])
+ if "doc_form" not in args:
+ args["doc_form"] = dataset.chunk_structure or "text_model"
+ if "doc_language" not in args:
+ args["doc_language"] = "English"
+
+ # get dataset info
+ dataset_id = str(dataset_id)
+ tenant_id = str(tenant_id)
+
# indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique
@@ -458,34 +488,39 @@ class DocumentListApi(DatasetApiResource):
def get(self, tenant_id, dataset_id):
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
- page = request.args.get("page", default=1, type=int)
- limit = request.args.get("limit", default=20, type=int)
- search = request.args.get("keyword", default=None, type=str)
- status = request.args.get("status", default=None, type=str)
+ query_params = DocumentListQuery.model_validate(request.args.to_dict())
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id)
- if status:
- query = DocumentService.apply_display_status_filter(query, status)
+ if query_params.status:
+ query = DocumentService.apply_display_status_filter(query, query_params.status)
- if search:
- search = f"%{search}%"
+ if query_params.keyword:
+ search = f"%{query_params.keyword}%"
query = query.where(Document.name.like(search))
query = query.order_by(desc(Document.created_at), desc(Document.position))
- paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
+ paginated_documents = db.paginate(
+ select=query, page=query_params.page, per_page=query_params.limit, max_per_page=100, error_out=False
+ )
documents = paginated_documents.items
+ DocumentService.enrich_documents_with_summary_index_status(
+ documents=documents,
+ dataset=dataset,
+ tenant_id=tenant_id,
+ )
+
response = {
"data": marshal(documents, document_fields),
- "has_more": len(documents) == limit,
- "limit": limit,
+ "has_more": len(documents) == query_params.limit,
+ "limit": query_params.limit,
"total": paginated_documents.total,
- "page": page,
+ "page": query_params.page,
}
return response
@@ -584,6 +619,16 @@ class DocumentApi(DatasetApiResource):
if metadata not in self.METADATA_CHOICES:
raise InvalidMetadataError(f"Invalid metadata value: {metadata}")
+ # Calculate summary_index_status if needed
+ summary_index_status = None
+ has_summary_index = dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True
+ if has_summary_index and document.need_summary is True:
+ summary_index_status = SummaryIndexService.get_document_summary_index_status(
+ document_id=document_id,
+ dataset_id=dataset_id,
+ tenant_id=tenant_id,
+ )
+
if metadata == "only":
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
elif metadata == "without":
@@ -618,6 +663,8 @@ class DocumentApi(DatasetApiResource):
"display_status": document.display_status,
"doc_form": document.doc_form,
"doc_language": document.doc_language,
+ "summary_index_status": summary_index_status,
+ "need_summary": document.need_summary if document.need_summary is not None else False,
}
else:
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
@@ -653,6 +700,8 @@ class DocumentApi(DatasetApiResource):
"display_status": document.display_status,
"doc_form": document.doc_form,
"doc_language": document.doc_language,
+ "summary_index_status": summary_index_status,
+ "need_summary": document.need_summary if document.need_summary is not None else False,
}
return response
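The `DocumentApi.get` hunk only queries the summary-index status when the dataset enables summary indexing and the document itself is flagged with `need_summary`. Below is a small sketch of that guard in isolation; the real status lookup lives in `SummaryIndexService`, which is stubbed out here.

```python
def should_fetch_summary_status(summary_index_setting: dict | None, need_summary: bool | None) -> bool:
    # Guard mirrored from the hunk above: both conditions must hold before
    # SummaryIndexService.get_document_summary_index_status would be called.
    has_summary_index = bool(summary_index_setting and summary_index_setting.get("enable") is True)
    return has_summary_index and need_summary is True


assert should_fetch_summary_status({"enable": True}, True) is True
assert should_fetch_summary_status({"enable": False}, True) is False
assert should_fetch_summary_status(None, True) is False
assert should_fetch_summary_status({"enable": True}, None) is False
```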
diff --git a/api/controllers/service_api/dataset/hit_testing.py b/api/controllers/service_api/dataset/hit_testing.py
index d81287d56f..8dbb690901 100644
--- a/api/controllers/service_api/dataset/hit_testing.py
+++ b/api/controllers/service_api/dataset/hit_testing.py
@@ -24,7 +24,7 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
dataset_id_str = str(dataset_id)
dataset = self.get_and_validate_dataset(dataset_id_str)
- args = self.parse_args()
+ args = self.parse_args(service_api_ns.payload)
self.hit_testing_args_check(args)
return self.perform_hit_testing(dataset, args)
diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py
index aab25c1af3..b8d9508004 100644
--- a/api/controllers/service_api/dataset/metadata.py
+++ b/api/controllers/service_api/dataset/metadata.py
@@ -11,7 +11,9 @@ from controllers.service_api.wraps import DatasetApiResource, cloud_edition_bill
from fields.dataset_fields import dataset_metadata_fields
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
+ DocumentMetadataOperation,
MetadataArgs,
+ MetadataDetail,
MetadataOperationData,
)
from services.metadata_service import MetadataService
@@ -22,7 +24,13 @@ class MetadataUpdatePayload(BaseModel):
register_schema_model(service_api_ns, MetadataUpdatePayload)
-register_schema_models(service_api_ns, MetadataArgs, MetadataOperationData)
+register_schema_models(
+ service_api_ns,
+ MetadataArgs,
+ MetadataDetail,
+ DocumentMetadataOperation,
+ MetadataOperationData,
+)
@service_api_ns.route("/datasets//metadata")
diff --git a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py
index 0a2017e2bd..70b5030237 100644
--- a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py
@@ -174,7 +174,7 @@ class PipelineRunApi(DatasetApiResource):
pipeline=pipeline,
user=current_user,
args=payload.model_dump(),
- invoke_from=InvokeFrom.PUBLISHED if payload.is_published else InvokeFrom.DEBUGGER,
+ invoke_from=InvokeFrom.PUBLISHED_PIPELINE if payload.is_published else InvokeFrom.DEBUGGER,
streaming=payload.response_mode == "streaming",
)
diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py
index b242fd2c3e..95679e6fcb 100644
--- a/api/controllers/service_api/dataset/segment.py
+++ b/api/controllers/service_api/dataset/segment.py
@@ -60,6 +60,7 @@ register_schema_models(
service_api_ns,
SegmentCreatePayload,
SegmentListQuery,
+ SegmentUpdateArgs,
SegmentUpdatePayload,
ChildChunkCreatePayload,
ChildChunkListQuery,
diff --git a/api/controllers/trigger/trigger.py b/api/controllers/trigger/trigger.py
index e69b22d880..c10b94050c 100644
--- a/api/controllers/trigger/trigger.py
+++ b/api/controllers/trigger/trigger.py
@@ -33,7 +33,7 @@ def trigger_endpoint(endpoint_id: str):
if response:
break
if not response:
- logger.error("Endpoint not found for {endpoint_id}")
+ logger.info("Endpoint not found for %s", endpoint_id)
return jsonify({"error": "Endpoint not found"}), 404
return response
except ValueError as e:
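The trigger.py change replaces a log call that was missing its `f` prefix with lazy `%s` formatting and downgrades the level to info. A standalone illustration of the difference, using a made-up endpoint id:

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

endpoint_id = "ep_123"  # hypothetical value for illustration

# Broken: a plain (non-f) string, so the literal text "{endpoint_id}" is logged.
logger.error("Endpoint not found for {endpoint_id}")

# Fixed, as in the diff: %-style arguments are interpolated lazily by the logging
# framework, only when the record is actually emitted.
logger.info("Endpoint not found for %s", endpoint_id)
```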
diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py
index 60193f5f15..62ea532eac 100644
--- a/api/controllers/web/app.py
+++ b/api/controllers/web/app.py
@@ -1,14 +1,13 @@
import logging
from flask import request
-from flask_restx import Resource, marshal_with, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel, ConfigDict, Field
from werkzeug.exceptions import Unauthorized
from constants import HEADER_NAME_APP_CODE
from controllers.common import fields
-from controllers.web import web_ns
-from controllers.web.error import AppUnavailableError
-from controllers.web.wraps import WebApiResource
+from controllers.common.schema import register_schema_models
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from libs.passport import PassportService
from libs.token import extract_webapp_passport
@@ -18,9 +17,23 @@ from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService
+from . import web_ns
+from .error import AppUnavailableError
+from .wraps import WebApiResource
+
logger = logging.getLogger(__name__)
+class AppAccessModeQuery(BaseModel):
+ model_config = ConfigDict(populate_by_name=True)
+
+ app_id: str | None = Field(default=None, alias="appId", description="Application ID")
+ app_code: str | None = Field(default=None, alias="appCode", description="Application code")
+
+
+register_schema_models(web_ns, AppAccessModeQuery)
+
+
@web_ns.route("/parameters")
class AppParameterApi(WebApiResource):
"""Resource for app variables."""
@@ -37,7 +50,6 @@ class AppParameterApi(WebApiResource):
500: "Internal Server Error",
}
)
- @marshal_with(fields.parameters_fields)
def get(self, app_model: App, end_user):
"""Retrieve app parameters."""
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
@@ -56,7 +68,8 @@ class AppParameterApi(WebApiResource):
user_input_form = features_dict.get("user_input_form", [])
- return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
+ parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
+ return fields.Parameters.model_validate(parameters).model_dump(mode="json")
@web_ns.route("/meta")
@@ -96,21 +109,16 @@ class AppAccessMode(Resource):
}
)
def get(self):
- parser = (
- reqparse.RequestParser()
- .add_argument("appId", type=str, required=False, location="args")
- .add_argument("appCode", type=str, required=False, location="args")
- )
- args = parser.parse_args()
+ raw_args = request.args.to_dict()
+ args = AppAccessModeQuery.model_validate(raw_args)
features = FeatureService.get_system_features()
if not features.webapp_auth.enabled:
return {"accessMode": "public"}
- app_id = args.get("appId")
- if args.get("appCode"):
- app_code = args["appCode"]
- app_id = AppService.get_app_id_by_code(app_code)
+ app_id = args.app_id
+ if args.app_code:
+ app_id = AppService.get_app_id_by_code(args.app_code)
if not app_id:
raise ValueError("appId or appCode must be provided")
diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py
index b9fef48c4d..15828cc208 100644
--- a/api/controllers/web/audio.py
+++ b/api/controllers/web/audio.py
@@ -1,7 +1,8 @@
import logging
from flask import request
-from flask_restx import fields, marshal_with, reqparse
+from flask_restx import fields, marshal_with
+from pydantic import BaseModel, field_validator
from werkzeug.exceptions import InternalServerError
import services
@@ -20,6 +21,7 @@ from controllers.web.error import (
from controllers.web.wraps import WebApiResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
+from libs.helper import uuid_value
from models.model import App
from services.audio_service import AudioService
from services.errors.audio import (
@@ -29,6 +31,25 @@ from services.errors.audio import (
UnsupportedAudioTypeServiceError,
)
+from ..common.schema import register_schema_models
+
+
+class TextToAudioPayload(BaseModel):
+ message_id: str | None = None
+ voice: str | None = None
+ text: str | None = None
+ streaming: bool | None = None
+
+ @field_validator("message_id")
+ @classmethod
+ def validate_message_id(cls, value: str | None) -> str | None:
+ if value is None:
+ return value
+ return uuid_value(value)
+
+
+register_schema_models(web_ns, TextToAudioPayload)
+
logger = logging.getLogger(__name__)
@@ -88,6 +109,7 @@ class AudioApi(WebApiResource):
@web_ns.route("/text-to-audio")
class TextApi(WebApiResource):
+ @web_ns.expect(web_ns.models[TextToAudioPayload.__name__])
@web_ns.doc("Text to Audio")
@web_ns.doc(description="Convert text to audio using text-to-speech service.")
@web_ns.doc(
@@ -102,18 +124,11 @@ class TextApi(WebApiResource):
def post(self, app_model: App, end_user):
"""Convert text to audio"""
try:
- parser = (
- reqparse.RequestParser()
- .add_argument("message_id", type=str, required=False, location="json")
- .add_argument("voice", type=str, location="json")
- .add_argument("text", type=str, location="json")
- .add_argument("streaming", type=bool, location="json")
- )
- args = parser.parse_args()
+ payload = TextToAudioPayload.model_validate(web_ns.payload or {})
- message_id = args.get("message_id", None)
- text = args.get("text", None)
- voice = args.get("voice", None)
+ message_id = payload.message_id
+ text = payload.text
+ voice = payload.voice
response = AudioService.transcript_tts(
app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
)
diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py
index e8a4698375..a97d745471 100644
--- a/api/controllers/web/completion.py
+++ b/api/controllers/web/completion.py
@@ -1,9 +1,11 @@
import logging
+from typing import Any, Literal
-from flask_restx import reqparse
+from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import InternalServerError, NotFound
import services
+from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import (
AppUnavailableError,
@@ -34,25 +36,44 @@ from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__)
+class CompletionMessagePayload(BaseModel):
+ inputs: dict[str, Any] = Field(description="Input variables for the completion")
+ query: str = Field(default="", description="Query text for completion")
+ files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed")
+ response_mode: Literal["blocking", "streaming"] | None = Field(
+ default=None, description="Response mode: blocking or streaming"
+ )
+ retriever_from: str = Field(default="web_app", description="Source of retriever")
+
+
+class ChatMessagePayload(BaseModel):
+ inputs: dict[str, Any] = Field(description="Input variables for the chat")
+ query: str = Field(description="User query/message")
+ files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed")
+ response_mode: Literal["blocking", "streaming"] | None = Field(
+ default=None, description="Response mode: blocking or streaming"
+ )
+ conversation_id: str | None = Field(default=None, description="Conversation ID")
+ parent_message_id: str | None = Field(default=None, description="Parent message ID")
+ retriever_from: str = Field(default="web_app", description="Source of retriever")
+
+ @field_validator("conversation_id", "parent_message_id")
+ @classmethod
+ def validate_uuid(cls, value: str | None) -> str | None:
+ if value is None:
+ return value
+ return uuid_value(value)
+
+
+register_schema_models(web_ns, CompletionMessagePayload, ChatMessagePayload)
+
+
# define completion api for user
@web_ns.route("/completion-messages")
class CompletionApi(WebApiResource):
@web_ns.doc("Create Completion Message")
@web_ns.doc(description="Create a completion message for text generation applications.")
- @web_ns.doc(
- params={
- "inputs": {"description": "Input variables for the completion", "type": "object", "required": True},
- "query": {"description": "Query text for completion", "type": "string", "required": False},
- "files": {"description": "Files to be processed", "type": "array", "required": False},
- "response_mode": {
- "description": "Response mode: blocking or streaming",
- "type": "string",
- "enum": ["blocking", "streaming"],
- "required": False,
- },
- "retriever_from": {"description": "Source of retriever", "type": "string", "required": False},
- }
- )
+ @web_ns.expect(web_ns.models[CompletionMessagePayload.__name__])
@web_ns.doc(
responses={
200: "Success",
@@ -67,18 +88,10 @@ class CompletionApi(WebApiResource):
if app_model.mode != AppMode.COMPLETION:
raise NotCompletionAppError()
- parser = (
- reqparse.RequestParser()
- .add_argument("inputs", type=dict, required=True, location="json")
- .add_argument("query", type=str, location="json", default="")
- .add_argument("files", type=list, required=False, location="json")
- .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
- .add_argument("retriever_from", type=str, required=False, default="web_app", location="json")
- )
+ payload = CompletionMessagePayload.model_validate(web_ns.payload or {})
+ args = payload.model_dump(exclude_none=True)
- args = parser.parse_args()
-
- streaming = args["response_mode"] == "streaming"
+ streaming = payload.response_mode == "streaming"
args["auto_generate_name"] = False
try:
@@ -142,22 +155,7 @@ class CompletionStopApi(WebApiResource):
class ChatApi(WebApiResource):
@web_ns.doc("Create Chat Message")
@web_ns.doc(description="Create a chat message for conversational applications.")
- @web_ns.doc(
- params={
- "inputs": {"description": "Input variables for the chat", "type": "object", "required": True},
- "query": {"description": "User query/message", "type": "string", "required": True},
- "files": {"description": "Files to be processed", "type": "array", "required": False},
- "response_mode": {
- "description": "Response mode: blocking or streaming",
- "type": "string",
- "enum": ["blocking", "streaming"],
- "required": False,
- },
- "conversation_id": {"description": "Conversation UUID", "type": "string", "required": False},
- "parent_message_id": {"description": "Parent message UUID", "type": "string", "required": False},
- "retriever_from": {"description": "Source of retriever", "type": "string", "required": False},
- }
- )
+ @web_ns.expect(web_ns.models[ChatMessagePayload.__name__])
@web_ns.doc(
responses={
200: "Success",
@@ -173,20 +171,10 @@ class ChatApi(WebApiResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
- parser = (
- reqparse.RequestParser()
- .add_argument("inputs", type=dict, required=True, location="json")
- .add_argument("query", type=str, required=True, location="json")
- .add_argument("files", type=list, required=False, location="json")
- .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
- .add_argument("conversation_id", type=uuid_value, location="json")
- .add_argument("parent_message_id", type=uuid_value, required=False, location="json")
- .add_argument("retriever_from", type=str, required=False, default="web_app", location="json")
- )
+ payload = ChatMessagePayload.model_validate(web_ns.payload or {})
+ args = payload.model_dump(exclude_none=True)
- args = parser.parse_args()
-
- streaming = args["response_mode"] == "streaming"
+ streaming = payload.response_mode == "streaming"
args["auto_generate_name"] = False
try:
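Both completion handlers now build the service `args` dict with `model_dump(exclude_none=True)`, so optional fields the client never sent do not show up as explicit `None` keys. A trimmed-down sketch of that behaviour; the field set is illustrative, not the full `ChatMessagePayload`:

```python
from typing import Any, Literal

from pydantic import BaseModel, Field


class ChatMessagePayload(BaseModel):
    inputs: dict[str, Any] = Field(description="Input variables for the chat")
    query: str = Field(description="User query/message")
    response_mode: Literal["blocking", "streaming"] | None = None
    conversation_id: str | None = None


payload = ChatMessagePayload.model_validate({"inputs": {}, "query": "hi", "response_mode": "streaming"})
args = payload.model_dump(exclude_none=True)

# conversation_id was never provided, so it is absent rather than None.
assert args == {"inputs": {}, "query": "hi", "response_mode": "streaming"}

# The streaming flag is derived from the typed field instead of a raw dict lookup.
streaming = payload.response_mode == "streaming"
assert streaming is True
```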
diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py
index 86e19423e5..e76649495a 100644
--- a/api/controllers/web/conversation.py
+++ b/api/controllers/web/conversation.py
@@ -1,14 +1,21 @@
-from flask_restx import fields, marshal_with, reqparse
-from flask_restx.inputs import int_range
+from typing import Literal
+
+from flask import request
+from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
+from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import NotChatAppError
from controllers.web.wraps import WebApiResource
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
-from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
+from fields.conversation_fields import (
+ ConversationInfiniteScrollPagination,
+ ResultResponse,
+ SimpleConversation,
+)
from libs.helper import uuid_value
from models.model import AppMode
from services.conversation_service import ConversationService
@@ -16,6 +23,35 @@ from services.errors.conversation import ConversationNotExistsError, LastConvers
from services.web_conversation_service import WebConversationService
+class ConversationListQuery(BaseModel):
+ last_id: str | None = None
+ limit: int = Field(default=20, ge=1, le=100)
+ pinned: bool | None = None
+ sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = "-updated_at"
+
+ @field_validator("last_id")
+ @classmethod
+ def validate_last_id(cls, value: str | None) -> str | None:
+ if value is None:
+ return value
+ return uuid_value(value)
+
+
+class ConversationRenamePayload(BaseModel):
+ name: str | None = None
+ auto_generate: bool = False
+
+ @model_validator(mode="after")
+ def validate_name_requirement(self):
+ if not self.auto_generate:
+ if self.name is None or not self.name.strip():
+ raise ValueError("name is required when auto_generate is false")
+ return self
+
+
+register_schema_models(web_ns, ConversationListQuery, ConversationRenamePayload)
+
+
@web_ns.route("/conversations")
class ConversationListApi(WebApiResource):
@web_ns.doc("Get Conversation List")
@@ -54,54 +90,39 @@ class ConversationListApi(WebApiResource):
500: "Internal Server Error",
}
)
- @marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
- parser = (
- reqparse.RequestParser()
- .add_argument("last_id", type=uuid_value, location="args")
- .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
- .add_argument("pinned", type=str, choices=["true", "false", None], location="args")
- .add_argument(
- "sort_by",
- type=str,
- choices=["created_at", "-created_at", "updated_at", "-updated_at"],
- required=False,
- default="-updated_at",
- location="args",
- )
- )
- args = parser.parse_args()
-
- pinned = None
- if "pinned" in args and args["pinned"] is not None:
- pinned = args["pinned"] == "true"
+ raw_args = request.args.to_dict()
+ query = ConversationListQuery.model_validate(raw_args)
try:
with Session(db.engine) as session:
- return WebConversationService.pagination_by_last_id(
+ pagination = WebConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=end_user,
- last_id=args["last_id"],
- limit=args["limit"],
+ last_id=query.last_id,
+ limit=query.limit,
invoke_from=InvokeFrom.WEB_APP,
- pinned=pinned,
- sort_by=args["sort_by"],
+ pinned=query.pinned,
+ sort_by=query.sort_by,
)
+ adapter = TypeAdapter(SimpleConversation)
+ conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data]
+ return ConversationInfiniteScrollPagination(
+ limit=pagination.limit,
+ has_more=pagination.has_more,
+ data=conversations,
+ ).model_dump(mode="json")
except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
@web_ns.route("/conversations/")
class ConversationApi(WebApiResource):
- delete_response_fields = {
- "result": fields.String,
- }
-
@web_ns.doc("Delete Conversation")
@web_ns.doc(description="Delete a specific conversation.")
@web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}})
@@ -115,7 +136,6 @@ class ConversationApi(WebApiResource):
500: "Internal Server Error",
}
)
- @marshal_with(delete_response_fields)
def delete(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@@ -126,7 +146,7 @@ class ConversationApi(WebApiResource):
ConversationService.delete(app_model, conversation_id, end_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
- return {"result": "success"}, 204
+ return ResultResponse(result="success").model_dump(mode="json"), 204
@web_ns.route("/conversations//name")
@@ -155,7 +175,6 @@ class ConversationRenameApi(WebApiResource):
500: "Internal Server Error",
}
)
- @marshal_with(simple_conversation_fields)
def post(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@@ -163,25 +182,23 @@ class ConversationRenameApi(WebApiResource):
conversation_id = str(c_id)
- parser = (
- reqparse.RequestParser()
- .add_argument("name", type=str, required=False, location="json")
- .add_argument("auto_generate", type=bool, required=False, default=False, location="json")
- )
- args = parser.parse_args()
+ payload = ConversationRenamePayload.model_validate(web_ns.payload or {})
try:
- return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"])
+ conversation = ConversationService.rename(
+ app_model, conversation_id, end_user, payload.name, payload.auto_generate
+ )
+ return (
+ TypeAdapter(SimpleConversation)
+ .validate_python(conversation, from_attributes=True)
+ .model_dump(mode="json")
+ )
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@web_ns.route("/conversations//pin")
class ConversationPinApi(WebApiResource):
- pin_response_fields = {
- "result": fields.String,
- }
-
@web_ns.doc("Pin Conversation")
@web_ns.doc(description="Pin a specific conversation to keep it at the top of the list.")
@web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}})
@@ -195,7 +212,6 @@ class ConversationPinApi(WebApiResource):
500: "Internal Server Error",
}
)
- @marshal_with(pin_response_fields)
def patch(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@@ -208,15 +224,11 @@ class ConversationPinApi(WebApiResource):
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
- return {"result": "success"}
+ return ResultResponse(result="success").model_dump(mode="json")
@web_ns.route("/conversations//unpin")
class ConversationUnPinApi(WebApiResource):
- unpin_response_fields = {
- "result": fields.String,
- }
-
@web_ns.doc("Unpin Conversation")
@web_ns.doc(description="Unpin a specific conversation to remove it from the top of the list.")
@web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}})
@@ -230,7 +242,6 @@ class ConversationUnPinApi(WebApiResource):
500: "Internal Server Error",
}
)
- @marshal_with(unpin_response_fields)
def patch(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@@ -239,4 +250,4 @@ class ConversationUnPinApi(WebApiResource):
conversation_id = str(c_id)
WebConversationService.unpin(app_model, conversation_id, end_user)
- return {"result": "success"}
+ return ResultResponse(result="success").model_dump(mode="json")
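The conversation list and rename handlers drop `marshal_with` in favour of Pydantic serialization: ORM rows are validated into response models with `TypeAdapter(..., from_attributes=True)` and dumped to JSON-safe dicts. A sketch with a stand-in attribute object, since the real `SimpleConversation` fields live in `fields.conversation_fields`; this assumes Pydantic v2 semantics.

```python
from types import SimpleNamespace

from pydantic import BaseModel, TypeAdapter


class SimpleConversation(BaseModel):
    """Illustrative stand-in; the real model is defined in fields.conversation_fields."""
    id: str
    name: str


# An ORM-style object exposes attributes rather than dict keys.
row = SimpleNamespace(id="conv-1", name="First chat", extra_column="ignored")

adapter = TypeAdapter(SimpleConversation)
conversation = adapter.validate_python(row, from_attributes=True)

# model_dump(mode="json") yields JSON-serializable primitives for the Flask response body.
assert conversation.model_dump(mode="json") == {"id": "conv-1", "name": "First chat"}
```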
diff --git a/api/controllers/web/feature.py b/api/controllers/web/feature.py
index cce3dae95d..2540bf02f4 100644
--- a/api/controllers/web/feature.py
+++ b/api/controllers/web/feature.py
@@ -17,5 +17,15 @@ class SystemFeatureApi(Resource):
Returns:
dict: System feature configuration object
+
+ This endpoint is akin to the `SystemFeatureApi` endpoint in api/controllers/console/feature.py,
+ except it is intended for use by the web app instead of the console dashboard.
+
+ NOTE: This endpoint is unauthenticated by design, as it provides the system feature
+ data required for webapp initialization.
+
+ Authentication would create a circular dependency (the webapp cannot authenticate before it has loaded).
+
+ Only non-sensitive configuration data should be returned by this endpoint.
"""
return FeatureService.get_system_features().model_dump()
diff --git a/api/controllers/web/files.py b/api/controllers/web/files.py
index 80ad61e549..0036c90800 100644
--- a/api/controllers/web/files.py
+++ b/api/controllers/web/files.py
@@ -1,5 +1,4 @@
from flask import request
-from flask_restx import marshal_with
import services
from controllers.common.errors import (
@@ -9,12 +8,15 @@ from controllers.common.errors import (
TooManyFilesError,
UnsupportedFileTypeError,
)
+from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.wraps import WebApiResource
from extensions.ext_database import db
-from fields.file_fields import build_file_model
+from fields.file_fields import FileResponse
from services.file_service import FileService
+register_schema_models(web_ns, FileResponse)
+
@web_ns.route("/files/upload")
class FileApi(WebApiResource):
@@ -28,7 +30,7 @@ class FileApi(WebApiResource):
415: "Unsupported file type",
}
)
- @marshal_with(build_file_model(web_ns))
+ @web_ns.response(201, "File uploaded successfully", web_ns.models[FileResponse.__name__])
def post(self, app_model, end_user):
"""Upload a file for use in web applications.
@@ -81,4 +83,5 @@ class FileApi(WebApiResource):
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
- return upload_file, 201
+ response = FileResponse.model_validate(upload_file, from_attributes=True)
+ return response.model_dump(mode="json"), 201
diff --git a/api/controllers/web/forgot_password.py b/api/controllers/web/forgot_password.py
index b9e391e049..91d206f727 100644
--- a/api/controllers/web/forgot_password.py
+++ b/api/controllers/web/forgot_password.py
@@ -2,10 +2,11 @@ import base64
import secrets
from flask import request
-from flask_restx import Resource, reqparse
-from sqlalchemy import select
+from flask_restx import Resource
+from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import Session
+from controllers.common.schema import register_schema_models
from controllers.console.auth.error import (
AuthenticationFailedError,
EmailCodeError,
@@ -18,14 +19,40 @@ from controllers.console.error import EmailSendIpLimitError
from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required
from controllers.web import web_ns
from extensions.ext_database import db
-from libs.helper import email, extract_remote_ip
+from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password
-from models import Account
+from models.account import Account
from services.account_service import AccountService
+class ForgotPasswordSendPayload(BaseModel):
+ email: EmailStr
+ language: str | None = None
+
+
+class ForgotPasswordCheckPayload(BaseModel):
+ email: EmailStr
+ code: str
+ token: str = Field(min_length=1)
+
+
+class ForgotPasswordResetPayload(BaseModel):
+ token: str = Field(min_length=1)
+ new_password: str
+ password_confirm: str
+
+ @field_validator("new_password", "password_confirm")
+ @classmethod
+ def validate_password(cls, value: str) -> str:
+ return valid_password(value)
+
+
+register_schema_models(web_ns, ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload)
+
+
@web_ns.route("/forgot-password")
class ForgotPasswordSendEmailApi(Resource):
+ @web_ns.expect(web_ns.models[ForgotPasswordSendPayload.__name__])
@only_edition_enterprise
@setup_required
@email_password_login_enabled
@@ -40,35 +67,34 @@ class ForgotPasswordSendEmailApi(Resource):
}
)
def post(self):
- parser = (
- reqparse.RequestParser()
- .add_argument("email", type=email, required=True, location="json")
- .add_argument("language", type=str, required=False, location="json")
- )
- args = parser.parse_args()
+ payload = ForgotPasswordSendPayload.model_validate(web_ns.payload or {})
+
+ request_email = payload.email
+ normalized_email = request_email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
- if args["language"] is not None and args["language"] == "zh-Hans":
+ if payload.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session)
token = None
if account is None:
raise AuthenticationFailedError()
else:
- token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language)
+ token = AccountService.send_reset_password_email(account=account, email=normalized_email, language=language)
return {"result": "success", "data": token}
@web_ns.route("/forgot-password/validity")
class ForgotPasswordCheckApi(Resource):
+ @web_ns.expect(web_ns.models[ForgotPasswordCheckPayload.__name__])
@only_edition_enterprise
@setup_required
@email_password_login_enabled
@@ -78,45 +104,45 @@ class ForgotPasswordCheckApi(Resource):
responses={200: "Token is valid", 400: "Bad request - invalid token format", 401: "Invalid or expired token"}
)
def post(self):
- parser = (
- reqparse.RequestParser()
- .add_argument("email", type=str, required=True, location="json")
- .add_argument("code", type=str, required=True, location="json")
- .add_argument("token", type=str, required=True, nullable=False, location="json")
- )
- args = parser.parse_args()
+ payload = ForgotPasswordCheckPayload.model_validate(web_ns.payload or {})
- user_email = args["email"]
+ user_email = payload.email.lower()
- is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"])
+ is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email)
if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError()
- token_data = AccountService.get_reset_password_data(args["token"])
+ token_data = AccountService.get_reset_password_data(payload.token)
if token_data is None:
raise InvalidTokenError()
- if user_email != token_data.get("email"):
+ token_email = token_data.get("email")
+ if not isinstance(token_email, str):
+ raise InvalidEmailError()
+ normalized_token_email = token_email.lower()
+
+ if user_email != normalized_token_email:
raise InvalidEmailError()
- if args["code"] != token_data.get("code"):
- AccountService.add_forgot_password_error_rate_limit(args["email"])
+ if payload.code != token_data.get("code"):
+ AccountService.add_forgot_password_error_rate_limit(user_email)
raise EmailCodeError()
# Verified, revoke the first token
- AccountService.revoke_reset_password_token(args["token"])
+ AccountService.revoke_reset_password_token(payload.token)
# Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token(
- user_email, code=args["code"], additional_data={"phase": "reset"}
+ token_email, code=payload.code, additional_data={"phase": "reset"}
)
- AccountService.reset_forgot_password_error_rate_limit(args["email"])
- return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
+ AccountService.reset_forgot_password_error_rate_limit(user_email)
+ return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@web_ns.route("/forgot-password/resets")
class ForgotPasswordResetApi(Resource):
+ @web_ns.expect(web_ns.models[ForgotPasswordResetPayload.__name__])
@only_edition_enterprise
@setup_required
@email_password_login_enabled
@@ -131,20 +157,14 @@ class ForgotPasswordResetApi(Resource):
}
)
def post(self):
- parser = (
- reqparse.RequestParser()
- .add_argument("token", type=str, required=True, nullable=False, location="json")
- .add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
- .add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
- )
- args = parser.parse_args()
+ payload = ForgotPasswordResetPayload.model_validate(web_ns.payload or {})
# Validate passwords match
- if args["new_password"] != args["password_confirm"]:
+ if payload.new_password != payload.password_confirm:
raise PasswordMismatchError()
# Validate token and get reset data
- reset_data = AccountService.get_reset_password_data(args["token"])
+ reset_data = AccountService.get_reset_password_data(payload.token)
if not reset_data:
raise InvalidTokenError()
# Must use token in reset phase
@@ -152,16 +172,16 @@ class ForgotPasswordResetApi(Resource):
raise InvalidTokenError()
# Revoke token to prevent reuse
- AccountService.revoke_reset_password_token(args["token"])
+ AccountService.revoke_reset_password_token(payload.token)
# Generate secure salt and hash password
salt = secrets.token_bytes(16)
- password_hashed = hash_password(args["new_password"], salt)
+ password_hashed = hash_password(payload.new_password, salt)
email = reset_data.get("email", "")
with Session(db.engine) as session:
- account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
+ account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account:
self._update_existing_account(account, password_hashed, salt, session)
@@ -170,7 +190,7 @@ class ForgotPasswordResetApi(Resource):
return {"result": "success"}
- def _update_existing_account(self, account, password_hashed, salt, session):
+ def _update_existing_account(self, account: Account, password_hashed, salt, session):
# Update existing account credentials
account.password = base64.b64encode(password_hashed).decode()
account.password_salt = base64.b64encode(salt).decode()
diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py
index 538d0c44be..a824f6d487 100644
--- a/api/controllers/web/login.py
+++ b/api/controllers/web/login.py
@@ -1,19 +1,26 @@
from flask import make_response, request
-from flask_restx import Resource, reqparse
+from flask_restx import Resource
from jwt import InvalidTokenError
+from pydantic import BaseModel, Field, field_validator
import services
from configs import dify_config
+from controllers.common.schema import register_schema_models
from controllers.console.auth.error import (
AuthenticationFailedError,
EmailCodeError,
InvalidEmailError,
)
from controllers.console.error import AccountBannedError
-from controllers.console.wraps import only_edition_enterprise, setup_required
+from controllers.console.wraps import (
+ decrypt_code_field,
+ decrypt_password_field,
+ only_edition_enterprise,
+ setup_required,
+)
from controllers.web import web_ns
from controllers.web.wraps import decode_jwt_token
-from libs.helper import email
+from libs.helper import EmailStr
from libs.passport import PassportService
from libs.password import valid_password
from libs.token import (
@@ -25,10 +32,35 @@ from services.app_service import AppService
from services.webapp_auth_service import WebAppAuthService
+class LoginPayload(BaseModel):
+ email: EmailStr
+ password: str
+
+ @field_validator("password")
+ @classmethod
+ def validate_password(cls, value: str) -> str:
+ return valid_password(value)
+
+
+class EmailCodeLoginSendPayload(BaseModel):
+ email: EmailStr
+ language: str | None = None
+
+
+class EmailCodeLoginVerifyPayload(BaseModel):
+ email: EmailStr
+ code: str
+ token: str = Field(min_length=1)
+
+
+register_schema_models(web_ns, LoginPayload, EmailCodeLoginSendPayload, EmailCodeLoginVerifyPayload)
+
+
@web_ns.route("/login")
class LoginApi(Resource):
"""Resource for web app email/password login."""
+ @web_ns.expect(web_ns.models[LoginPayload.__name__])
@setup_required
@only_edition_enterprise
@web_ns.doc("web_app_login")
@@ -42,17 +74,13 @@ class LoginApi(Resource):
404: "Account not found",
}
)
+ @decrypt_password_field
def post(self):
"""Authenticate user and login."""
- parser = (
- reqparse.RequestParser()
- .add_argument("email", type=email, required=True, location="json")
- .add_argument("password", type=valid_password, required=True, location="json")
- )
- args = parser.parse_args()
+ payload = LoginPayload.model_validate(web_ns.payload or {})
try:
- account = WebAppAuthService.authenticate(args["email"], args["password"])
+ account = WebAppAuthService.authenticate(payload.email, payload.password)
except services.errors.account.AccountLoginError:
raise AccountBannedError()
except services.errors.account.AccountPasswordError:
@@ -139,6 +167,7 @@ class EmailCodeLoginSendEmailApi(Resource):
@only_edition_enterprise
@web_ns.doc("send_email_code_login")
@web_ns.doc(description="Send email verification code for login")
+ @web_ns.expect(web_ns.models[EmailCodeLoginSendPayload.__name__])
@web_ns.doc(
responses={
200: "Email code sent successfully",
@@ -147,19 +176,14 @@ class EmailCodeLoginSendEmailApi(Resource):
}
)
def post(self):
- parser = (
- reqparse.RequestParser()
- .add_argument("email", type=email, required=True, location="json")
- .add_argument("language", type=str, required=False, location="json")
- )
- args = parser.parse_args()
+ payload = EmailCodeLoginSendPayload.model_validate(web_ns.payload or {})
- if args["language"] is not None and args["language"] == "zh-Hans":
+ if payload.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
- account = WebAppAuthService.get_user_through_email(args["email"])
+ account = WebAppAuthService.get_user_through_email(payload.email)
if account is None:
raise AuthenticationFailedError()
else:
@@ -173,6 +197,7 @@ class EmailCodeLoginApi(Resource):
@only_edition_enterprise
@web_ns.doc("verify_email_code_login")
@web_ns.doc(description="Verify email code and complete login")
+ @web_ns.expect(web_ns.models[EmailCodeLoginVerifyPayload.__name__])
@web_ns.doc(
responses={
200: "Email code verified and login successful",
@@ -181,34 +206,33 @@ class EmailCodeLoginApi(Resource):
404: "Account not found",
}
)
+ @decrypt_code_field
def post(self):
- parser = (
- reqparse.RequestParser()
- .add_argument("email", type=str, required=True, location="json")
- .add_argument("code", type=str, required=True, location="json")
- .add_argument("token", type=str, required=True, location="json")
- )
- args = parser.parse_args()
+ payload = EmailCodeLoginVerifyPayload.model_validate(web_ns.payload or {})
- user_email = args["email"]
+ user_email = payload.email.lower()
- token_data = WebAppAuthService.get_email_code_login_data(args["token"])
+ token_data = WebAppAuthService.get_email_code_login_data(payload.token)
if token_data is None:
raise InvalidTokenError()
- if token_data["email"] != args["email"]:
+ token_email = token_data.get("email")
+ if not isinstance(token_email, str):
+ raise InvalidEmailError()
+ normalized_token_email = token_email.lower()
+ if normalized_token_email != user_email:
raise InvalidEmailError()
- if token_data["code"] != args["code"]:
+ if token_data["code"] != payload.code:
raise EmailCodeError()
- WebAppAuthService.revoke_email_code_login_token(args["token"])
- account = WebAppAuthService.get_user_through_email(user_email)
+ WebAppAuthService.revoke_email_code_login_token(payload.token)
+ account = WebAppAuthService.get_user_through_email(token_email)
if not account:
raise AuthenticationFailedError()
token = WebAppAuthService.login(account=account)
- AccountService.reset_login_error_rate_limit(args["email"])
+ AccountService.reset_login_error_rate_limit(user_email)
response = make_response({"result": "success", "data": {"access_token": token}})
# set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
return response
diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py
index 9f9aa4838c..80035ba818 100644
--- a/api/controllers/web/message.py
+++ b/api/controllers/web/message.py
@@ -1,9 +1,11 @@
import logging
+from typing import Literal
-from flask_restx import fields, marshal_with, reqparse
-from flask_restx.inputs import int_range
+from flask import request
+from pydantic import BaseModel, Field, TypeAdapter, field_validator
from werkzeug.exceptions import InternalServerError, NotFound
+from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import (
AppMoreLikeThisDisabledError,
@@ -19,11 +21,10 @@ from controllers.web.wraps import WebApiResource
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
-from fields.conversation_fields import message_file_fields
-from fields.message_fields import agent_thought_fields, feedback_fields, retriever_resource_fields
-from fields.raws import FilesContainedField
+from fields.conversation_fields import ResultResponse
+from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem
from libs import helper
-from libs.helper import TimestampField, uuid_value
+from libs.helper import uuid_value
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import MoreLikeThisDisabledError
@@ -38,37 +39,45 @@ from services.message_service import MessageService
logger = logging.getLogger(__name__)
+class MessageListQuery(BaseModel):
+ conversation_id: str = Field(description="Conversation UUID")
+ first_id: str | None = Field(default=None, description="First message ID for pagination")
+ limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
+
+ @field_validator("conversation_id", "first_id")
+ @classmethod
+ def validate_uuid(cls, value: str | None) -> str | None:
+ if value is None:
+ return value
+ return uuid_value(value)
+
+
+class MessageFeedbackPayload(BaseModel):
+ rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
+ content: str | None = Field(default=None, description="Feedback content")
+
+
+class MessageMoreLikeThisQuery(BaseModel):
+ response_mode: Literal["blocking", "streaming"] = Field(
+ description="Response mode",
+ )
+
+
+register_schema_models(web_ns, MessageListQuery, MessageFeedbackPayload, MessageMoreLikeThisQuery)
+
+
@web_ns.route("/messages")
class MessageListApi(WebApiResource):
- message_fields = {
- "id": fields.String,
- "conversation_id": fields.String,
- "parent_message_id": fields.String,
- "inputs": FilesContainedField,
- "query": fields.String,
- "answer": fields.String(attribute="re_sign_file_url_answer"),
- "message_files": fields.List(fields.Nested(message_file_fields)),
- "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
- "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
- "created_at": TimestampField,
- "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
- "metadata": fields.Raw(attribute="message_metadata_dict"),
- "status": fields.String,
- "error": fields.String,
- }
-
- message_infinite_scroll_pagination_fields = {
- "limit": fields.Integer,
- "has_more": fields.Boolean,
- "data": fields.List(fields.Nested(message_fields)),
- }
-
@web_ns.doc("Get Message List")
@web_ns.doc(description="Retrieve paginated list of messages from a conversation in a chat application.")
@web_ns.doc(
params={
"conversation_id": {"description": "Conversation UUID", "type": "string", "required": True},
- "first_id": {"description": "First message ID for pagination", "type": "string", "required": False},
+ "first_id": {
+ "description": "First message ID for pagination",
+ "type": "string",
+ "required": False,
+ },
"limit": {
"description": "Number of messages to return (1-100)",
"type": "integer",
@@ -87,24 +96,25 @@ class MessageListApi(WebApiResource):
500: "Internal Server Error",
}
)
- @marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
- parser = (
- reqparse.RequestParser()
- .add_argument("conversation_id", required=True, type=uuid_value, location="args")
- .add_argument("first_id", type=uuid_value, location="args")
- .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
- )
- args = parser.parse_args()
+ raw_args = request.args.to_dict()
+ query = MessageListQuery.model_validate(raw_args)
try:
- return MessageService.pagination_by_first_id(
- app_model, end_user, args["conversation_id"], args["first_id"], args["limit"]
+ pagination = MessageService.pagination_by_first_id(
+ app_model, end_user, query.conversation_id, query.first_id, query.limit
)
+ adapter = TypeAdapter(WebMessageListItem)
+ items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
+ return WebMessageInfiniteScrollPagination(
+ limit=pagination.limit,
+ has_more=pagination.has_more,
+ data=items,
+ ).model_dump(mode="json")
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except FirstMessageNotExistsError:
@@ -113,10 +123,6 @@ class MessageListApi(WebApiResource):
@web_ns.route("/messages//feedbacks")
class MessageFeedbackApi(WebApiResource):
- feedback_response_fields = {
- "result": fields.String,
- }
-
@web_ns.doc("Create Message Feedback")
@web_ns.doc(description="Submit feedback (like/dislike) for a specific message.")
@web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}})
@@ -128,7 +134,7 @@ class MessageFeedbackApi(WebApiResource):
"enum": ["like", "dislike"],
"required": False,
},
- "content": {"description": "Feedback content/comment", "type": "string", "required": False},
+ "content": {"description": "Feedback content", "type": "string", "required": False},
}
)
@web_ns.doc(
@@ -141,46 +147,30 @@ class MessageFeedbackApi(WebApiResource):
500: "Internal Server Error",
}
)
- @marshal_with(feedback_response_fields)
def post(self, app_model, end_user, message_id):
message_id = str(message_id)
- parser = (
- reqparse.RequestParser()
- .add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
- .add_argument("content", type=str, location="json", default=None)
- )
- args = parser.parse_args()
+ payload = MessageFeedbackPayload.model_validate(web_ns.payload or {})
try:
MessageService.create_feedback(
app_model=app_model,
message_id=message_id,
user=end_user,
- rating=args.get("rating"),
- content=args.get("content"),
+ rating=payload.rating,
+ content=payload.content,
)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
- return {"result": "success"}
+ return ResultResponse(result="success").model_dump(mode="json")
@web_ns.route("/messages//more-like-this")
class MessageMoreLikeThisApi(WebApiResource):
@web_ns.doc("Generate More Like This")
@web_ns.doc(description="Generate a new completion similar to an existing message (completion apps only).")
- @web_ns.doc(
- params={
- "message_id": {"description": "Message UUID", "type": "string", "required": True},
- "response_mode": {
- "description": "Response mode",
- "type": "string",
- "enum": ["blocking", "streaming"],
- "required": True,
- },
- }
- )
+ @web_ns.expect(web_ns.models[MessageMoreLikeThisQuery.__name__])
@web_ns.doc(
responses={
200: "Success",
@@ -197,12 +187,10 @@ class MessageMoreLikeThisApi(WebApiResource):
message_id = str(message_id)
- parser = reqparse.RequestParser().add_argument(
- "response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args"
- )
- args = parser.parse_args()
+ raw_args = request.args.to_dict()
+ query = MessageMoreLikeThisQuery.model_validate(raw_args)
- streaming = args["response_mode"] == "streaming"
+ streaming = query.response_mode == "streaming"
try:
response = AppGenerateService.generate_more_like_this(
@@ -235,10 +223,6 @@ class MessageMoreLikeThisApi(WebApiResource):
@web_ns.route("/messages//suggested-questions")
class MessageSuggestedQuestionApi(WebApiResource):
- suggested_questions_response_fields = {
- "data": fields.List(fields.String),
- }
-
@web_ns.doc("Get Suggested Questions")
@web_ns.doc(description="Get suggested follow-up questions after a message (chat apps only).")
@web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}})
@@ -252,7 +236,6 @@ class MessageSuggestedQuestionApi(WebApiResource):
500: "Internal Server Error",
}
)
- @marshal_with(suggested_questions_response_fields)
def get(self, app_model, end_user, message_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@@ -265,7 +248,6 @@ class MessageSuggestedQuestionApi(WebApiResource):
app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.WEB_APP
)
# questions is a list of strings, not a list of Message objects
- # so we can directly return it
except MessageNotExistsError:
raise NotFound("Message not found")
except ConversationNotExistsError:
@@ -284,4 +266,4 @@ class MessageSuggestedQuestionApi(WebApiResource):
logger.exception("internal server error.")
raise InternalServerError()
- return {"data": questions}
+ return SuggestedQuestionsResponse(data=questions).model_dump(mode="json")
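
The message controller changes above all follow one pattern: query strings are validated through a Pydantic model built from `request.args` instead of `reqparse`, and responses are serialized with `model_dump(mode="json")` instead of `@marshal_with`. A minimal, self-contained sketch of that pattern (the route, model, and field names here are illustrative, not taken from the diff):

```python
from flask import Flask, request
from pydantic import BaseModel, Field

app = Flask(__name__)

class ExampleListQuery(BaseModel):
    # Mirrors the reqparse defaults: limit defaults to 20 and is range-checked.
    last_id: str | None = None
    limit: int = Field(default=20, ge=1, le=100)

@app.get("/example")
def list_example():
    # request.args is a MultiDict of strings; Pydantic coerces "20" -> 20
    # and rejects out-of-range values with a ValidationError.
    query = ExampleListQuery.model_validate(request.args.to_dict())
    return {"limit": query.limit, "last_id": query.last_id}
```

How validation errors are translated into HTTP 400 responses is not shown in this diff; presumably a shared error handler elsewhere in the codebase takes care of that.
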
diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py
index dac4b3da94..b08b3fe858 100644
--- a/api/controllers/web/remote_files.py
+++ b/api/controllers/web/remote_files.py
@@ -1,7 +1,7 @@
import urllib.parse
import httpx
-from flask_restx import marshal_with, reqparse
+from pydantic import BaseModel, Field, HttpUrl
import services
from controllers.common import helpers
@@ -10,14 +10,23 @@ from controllers.common.errors import (
RemoteFileUploadError,
UnsupportedFileTypeError,
)
-from controllers.web import web_ns
-from controllers.web.wraps import WebApiResource
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from extensions.ext_database import db
-from fields.file_fields import build_file_with_signed_url_model, build_remote_file_info_model
+from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
from services.file_service import FileService
+from ..common.schema import register_schema_models
+from . import web_ns
+from .wraps import WebApiResource
+
+
+class RemoteFileUploadPayload(BaseModel):
+ url: HttpUrl = Field(description="Remote file URL")
+
+
+register_schema_models(web_ns, RemoteFileUploadPayload, RemoteFileInfo, FileWithSignedUrl)
+
@web_ns.route("/remote-files/")
class RemoteFileInfoApi(WebApiResource):
@@ -31,7 +40,7 @@ class RemoteFileInfoApi(WebApiResource):
500: "Failed to fetch remote file",
}
)
- @marshal_with(build_remote_file_info_model(web_ns))
+ @web_ns.response(200, "Remote file info", web_ns.models[RemoteFileInfo.__name__])
def get(self, app_model, end_user, url):
"""Get information about a remote file.
@@ -55,10 +64,11 @@ class RemoteFileInfoApi(WebApiResource):
            # fall back to GET method
resp = ssrf_proxy.get(decoded_url, timeout=3)
resp.raise_for_status()
- return {
- "file_type": resp.headers.get("Content-Type", "application/octet-stream"),
- "file_length": int(resp.headers.get("Content-Length", -1)),
- }
+ info = RemoteFileInfo(
+ file_type=resp.headers.get("Content-Type", "application/octet-stream"),
+ file_length=int(resp.headers.get("Content-Length", -1)),
+ )
+ return info.model_dump(mode="json")
@web_ns.route("/remote-files/upload")
@@ -74,7 +84,7 @@ class RemoteFileUploadApi(WebApiResource):
500: "Failed to fetch remote file",
}
)
- @marshal_with(build_file_with_signed_url_model(web_ns))
+ @web_ns.response(201, "Remote file uploaded", web_ns.models[FileWithSignedUrl.__name__])
def post(self, app_model, end_user):
"""Upload a file from a remote URL.
@@ -97,10 +107,8 @@ class RemoteFileUploadApi(WebApiResource):
FileTooLargeError: File exceeds size limit
UnsupportedFileTypeError: File type not supported
"""
- parser = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required")
- args = parser.parse_args()
-
- url = args["url"]
+ payload = RemoteFileUploadPayload.model_validate(web_ns.payload or {})
+ url = str(payload.url)
try:
resp = ssrf_proxy.head(url=url)
@@ -131,13 +139,14 @@ class RemoteFileUploadApi(WebApiResource):
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError
- return {
- "id": upload_file.id,
- "name": upload_file.name,
- "size": upload_file.size,
- "extension": upload_file.extension,
- "url": file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
- "mime_type": upload_file.mime_type,
- "created_by": upload_file.created_by,
- "created_at": upload_file.created_at,
- }, 201
+    file_payload = FileWithSignedUrl(
+ id=upload_file.id,
+ name=upload_file.name,
+ size=upload_file.size,
+ extension=upload_file.extension,
+ url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
+ mime_type=upload_file.mime_type,
+ created_by=upload_file.created_by,
+ created_at=int(upload_file.created_at.timestamp()),
+ )
+    return file_payload.model_dump(mode="json"), 201
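
`RemoteFileUploadPayload.url` is declared as `HttpUrl`, so malformed URLs are rejected at validation time and the value has to be converted back to `str` before it is handed to the SSRF proxy, which is why the handler calls `str(payload.url)`. A standalone illustration of that behaviour:

```python
from pydantic import BaseModel, HttpUrl, ValidationError

class RemoteFileUploadPayload(BaseModel):
    url: HttpUrl

payload = RemoteFileUploadPayload.model_validate({"url": "https://example.com/file.png"})
print(str(payload.url))  # "https://example.com/file.png" (HttpUrl, not a plain str)

try:
    RemoteFileUploadPayload.model_validate({"url": "not a url"})
except ValidationError as exc:
    print(exc.error_count())  # 1
```
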
diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py
index 865f3610a7..29993100f6 100644
--- a/api/controllers/web/saved_message.py
+++ b/api/controllers/web/saved_message.py
@@ -1,40 +1,32 @@
-from flask_restx import fields, marshal_with, reqparse
-from flask_restx.inputs import int_range
+from flask import request
+from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import NotFound
+from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import NotCompletionAppError
from controllers.web.wraps import WebApiResource
-from fields.conversation_fields import message_file_fields
-from libs.helper import TimestampField, uuid_value
+from fields.conversation_fields import ResultResponse
+from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
+from libs.helper import UUIDStrOrEmpty
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
-feedback_fields = {"rating": fields.String}
-message_fields = {
- "id": fields.String,
- "inputs": fields.Raw,
- "query": fields.String,
- "answer": fields.String,
- "message_files": fields.List(fields.Nested(message_file_fields)),
- "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
- "created_at": TimestampField,
-}
+class SavedMessageListQuery(BaseModel):
+ last_id: UUIDStrOrEmpty | None = None
+ limit: int = Field(default=20, ge=1, le=100)
+
+
+class SavedMessageCreatePayload(BaseModel):
+ message_id: UUIDStrOrEmpty
+
+
+register_schema_models(web_ns, SavedMessageListQuery, SavedMessageCreatePayload)
@web_ns.route("/saved-messages")
class SavedMessageListApi(WebApiResource):
- saved_message_infinite_scroll_pagination_fields = {
- "limit": fields.Integer,
- "has_more": fields.Boolean,
- "data": fields.List(fields.Nested(message_fields)),
- }
-
- post_response_fields = {
- "result": fields.String,
- }
-
@web_ns.doc("Get Saved Messages")
@web_ns.doc(description="Retrieve paginated list of saved messages for a completion application.")
@web_ns.doc(
@@ -58,19 +50,21 @@ class SavedMessageListApi(WebApiResource):
500: "Internal Server Error",
}
)
- @marshal_with(saved_message_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
if app_model.mode != "completion":
raise NotCompletionAppError()
- parser = (
- reqparse.RequestParser()
- .add_argument("last_id", type=uuid_value, location="args")
- .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
- )
- args = parser.parse_args()
+ raw_args = request.args.to_dict()
+ query = SavedMessageListQuery.model_validate(raw_args)
- return SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"])
+ pagination = SavedMessageService.pagination_by_last_id(app_model, end_user, query.last_id, query.limit)
+ adapter = TypeAdapter(SavedMessageItem)
+ items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
+ return SavedMessageInfiniteScrollPagination(
+ limit=pagination.limit,
+ has_more=pagination.has_more,
+ data=items,
+ ).model_dump(mode="json")
@web_ns.doc("Save Message")
@web_ns.doc(description="Save a specific message for later reference.")
@@ -89,28 +83,22 @@ class SavedMessageListApi(WebApiResource):
500: "Internal Server Error",
}
)
- @marshal_with(post_response_fields)
def post(self, app_model, end_user):
if app_model.mode != "completion":
raise NotCompletionAppError()
- parser = reqparse.RequestParser().add_argument("message_id", type=uuid_value, required=True, location="json")
- args = parser.parse_args()
+ payload = SavedMessageCreatePayload.model_validate(web_ns.payload or {})
try:
- SavedMessageService.save(app_model, end_user, args["message_id"])
+ SavedMessageService.save(app_model, end_user, payload.message_id)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
- return {"result": "success"}
+ return ResultResponse(result="success").model_dump(mode="json")
@web_ns.route("/saved-messages/")
class SavedMessageApi(WebApiResource):
- delete_response_fields = {
- "result": fields.String,
- }
-
@web_ns.doc("Delete Saved Message")
@web_ns.doc(description="Remove a message from saved messages.")
@web_ns.doc(params={"message_id": {"description": "Message UUID to delete", "type": "string", "required": True}})
@@ -124,7 +112,6 @@ class SavedMessageApi(WebApiResource):
500: "Internal Server Error",
}
)
- @marshal_with(delete_response_fields)
def delete(self, app_model, end_user, message_id):
message_id = str(message_id)
@@ -133,4 +120,4 @@ class SavedMessageApi(WebApiResource):
SavedMessageService.delete(app_model, end_user, message_id)
- return {"result": "success"}, 204
+ return ResultResponse(result="success").model_dump(mode="json"), 204
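
Both paginated endpoints convert ORM rows into response items with `TypeAdapter(...).validate_python(obj, from_attributes=True)`. The sketch below shows that conversion with a plain stand-in object instead of the real SQLAlchemy model, so the item model and attribute names are illustrative only:

```python
from pydantic import BaseModel, TypeAdapter

class SavedItem(BaseModel):
    id: str
    answer: str

class FakeRow:
    """Stand-in for an ORM row exposing matching attributes."""
    def __init__(self, id: str, answer: str):
        self.id = id
        self.answer = answer

adapter = TypeAdapter(SavedItem)
rows = [FakeRow("m-1", "hello"), FakeRow("m-2", "world")]
# from_attributes=True reads object attributes instead of mapping keys.
items = [adapter.validate_python(row, from_attributes=True) for row in rows]
print([item.model_dump() for item in items])
```
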
diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py
index 3cbb07a296..95d8c6d5a5 100644
--- a/api/controllers/web/workflow.py
+++ b/api/controllers/web/workflow.py
@@ -1,8 +1,10 @@
import logging
+from typing import Any
-from flask_restx import reqparse
+from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError
+from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import (
CompletionRequestError,
@@ -27,19 +29,22 @@ from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
+
+class WorkflowRunPayload(BaseModel):
+ inputs: dict[str, Any] = Field(description="Input variables for the workflow")
+ files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed by the workflow")
+
+
logger = logging.getLogger(__name__)
+register_schema_models(web_ns, WorkflowRunPayload)
+
@web_ns.route("/workflows/run")
class WorkflowRunApi(WebApiResource):
@web_ns.doc("Run Workflow")
@web_ns.doc(description="Execute a workflow with provided inputs and files.")
- @web_ns.doc(
- params={
- "inputs": {"description": "Input variables for the workflow", "type": "object", "required": True},
- "files": {"description": "Files to be processed by the workflow", "type": "array", "required": False},
- }
- )
+ @web_ns.expect(web_ns.models[WorkflowRunPayload.__name__])
@web_ns.doc(
responses={
200: "Success",
@@ -58,12 +63,8 @@ class WorkflowRunApi(WebApiResource):
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
- parser = (
- reqparse.RequestParser()
- .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
- .add_argument("files", type=list, required=False, location="json")
- )
- args = parser.parse_args()
+ payload = WorkflowRunPayload.model_validate(web_ns.payload or {})
+ args = payload.model_dump(exclude_none=True)
try:
response = AppGenerateService.generate(
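
In `WorkflowRunApi.post`, the validated payload is converted back to a plain dict with `exclude_none=True`, so an omitted `files` field simply does not appear in the args handed to `AppGenerateService.generate`. A tiny illustration of that detail:

```python
from typing import Any
from pydantic import BaseModel

class WorkflowRunPayload(BaseModel):
    inputs: dict[str, Any]
    files: list[dict[str, Any]] | None = None

payload = WorkflowRunPayload.model_validate({"inputs": {"query": "hi"}})
print(payload.model_dump(exclude_none=True))  # {'inputs': {'query': 'hi'}} -- no 'files' key
```
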
diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py
index c196dbbdf1..3c6d36afe4 100644
--- a/api/core/agent/base_agent_runner.py
+++ b/api/core/agent/base_agent_runner.py
@@ -1,6 +1,7 @@
import json
import logging
import uuid
+from decimal import Decimal
from typing import Union, cast
from sqlalchemy import select
@@ -41,6 +42,7 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from factories import file_factory
+from models.enums import CreatorUserRole
from models.model import Conversation, Message, MessageAgentThought, MessageFile
logger = logging.getLogger(__name__)
@@ -289,6 +291,7 @@ class BaseAgentRunner(AppRunner):
thought = MessageAgentThought(
message_id=message_id,
message_chain_id=None,
+ tool_process_data=None,
thought="",
tool=tool_name,
tool_labels_str="{}",
@@ -296,20 +299,20 @@ class BaseAgentRunner(AppRunner):
tool_input=tool_input,
message=message,
message_token=0,
- message_unit_price=0,
- message_price_unit=0,
+ message_unit_price=Decimal(0),
+ message_price_unit=Decimal("0.001"),
message_files=json.dumps(messages_ids) if messages_ids else "",
answer="",
observation="",
answer_token=0,
- answer_unit_price=0,
- answer_price_unit=0,
+ answer_unit_price=Decimal(0),
+ answer_price_unit=Decimal("0.001"),
tokens=0,
- total_price=0,
+ total_price=Decimal(0),
position=self.agent_thought_count + 1,
currency="USD",
latency=0,
- created_by_role="account",
+ created_by_role=CreatorUserRole.ACCOUNT,
created_by=self.user_id,
)
@@ -342,7 +345,8 @@ class BaseAgentRunner(AppRunner):
raise ValueError("agent thought not found")
if thought:
- agent_thought.thought += thought
+ existing_thought = agent_thought.thought or ""
+ agent_thought.thought = f"{existing_thought}{thought}"
if tool_name:
agent_thought.tool = tool_name
@@ -440,21 +444,30 @@ class BaseAgentRunner(AppRunner):
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
if agent_thoughts:
for agent_thought in agent_thoughts:
- tools = agent_thought.tool
- if tools:
- tools = tools.split(";")
+ tool_names_raw = agent_thought.tool
+ if tool_names_raw:
+ tool_names = tool_names_raw.split(";")
tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_call_response: list[ToolPromptMessage] = []
- try:
- tool_inputs = json.loads(agent_thought.tool_input)
- except Exception:
- tool_inputs = {tool: {} for tool in tools}
- try:
- tool_responses = json.loads(agent_thought.observation)
- except Exception:
- tool_responses = dict.fromkeys(tools, agent_thought.observation)
+ tool_input_payload = agent_thought.tool_input
+ if tool_input_payload:
+ try:
+ tool_inputs = json.loads(tool_input_payload)
+ except Exception:
+ tool_inputs = {tool: {} for tool in tool_names}
+ else:
+ tool_inputs = {tool: {} for tool in tool_names}
- for tool in tools:
+ observation_payload = agent_thought.observation
+ if observation_payload:
+ try:
+ tool_responses = json.loads(observation_payload)
+ except Exception:
+ tool_responses = dict.fromkeys(tool_names, observation_payload)
+ else:
+ tool_responses = dict.fromkeys(tool_names, observation_payload)
+
+ for tool in tool_names:
# generate a uuid for tool call
tool_call_id = str(uuid.uuid4())
tool_calls.append(
@@ -484,7 +497,7 @@ class BaseAgentRunner(AppRunner):
*tool_call_response,
]
)
- if not tools:
+ if not tool_names_raw:
result.append(AssistantPromptMessage(content=agent_thought.thought))
else:
if message.answer:
diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py
index b32e35d0ca..a55f2d0f5f 100644
--- a/api/core/agent/cot_agent_runner.py
+++ b/api/core/agent/cot_agent_runner.py
@@ -22,6 +22,7 @@ from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransfo
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
+from core.workflow.nodes.agent.exc import AgentMaxIterationError
from models.model import Message
logger = logging.getLogger(__name__)
@@ -165,6 +166,11 @@ class CotAgentRunner(BaseAgentRunner, ABC):
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
self._agent_scratchpad.append(scratchpad)
+ # Check if max iteration is reached and model still wants to call tools
+ if iteration_step == max_iteration_steps and scratchpad.action:
+ if scratchpad.action.action_name.lower() != "final answer":
+ raise AgentMaxIterationError(app_config.agent.max_iteration)
+
# get llm usage
if "usage" in usage_dict:
if usage_dict["usage"] is not None:
diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py
index 88753ff2de..cef81be9fd 100644
--- a/api/core/agent/fc_agent_runner.py
+++ b/api/core/agent/fc_agent_runner.py
@@ -25,6 +25,7 @@ from core.model_runtime.entities.message_entities import ImagePromptMessageConte
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta, ToolProviderType
from core.tools.tool_engine import ToolEngine
+from core.workflow.nodes.agent.exc import AgentMaxIterationError
from models.model import Message
logger = logging.getLogger(__name__)
@@ -187,7 +188,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
),
)
- assistant_message = AssistantPromptMessage(content="", tool_calls=[])
+ assistant_message = AssistantPromptMessage(content=response, tool_calls=[])
if tool_calls:
assistant_message.tool_calls = [
AssistantPromptMessage.ToolCall(
@@ -199,8 +200,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
)
for tool_call in tool_calls
]
- else:
- assistant_message.content = response
self._current_thoughts.append(assistant_message)
@@ -222,6 +221,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):
final_answer += response + "\n"
+ # Check if max iteration is reached and model still wants to call tools
+ if iteration_step == max_iteration_steps and tool_calls:
+ raise AgentMaxIterationError(app_config.agent.max_iteration)
+
# call tools
tool_responses = []
for tool_call_id, tool_call_name, tool_call_args in tool_calls:
diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py
index 93f2742599..13c51529cc 100644
--- a/api/core/app/app_config/entities.py
+++ b/api/core/app/app_config/entities.py
@@ -120,7 +120,7 @@ class VariableEntity(BaseModel):
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
- json_schema: dict[str, Any] | None = Field(default=None)
+ json_schema: dict | None = Field(default=None)
@field_validator("description", mode="before")
@classmethod
@@ -134,7 +134,7 @@ class VariableEntity(BaseModel):
@field_validator("json_schema")
@classmethod
- def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None:
+ def validate_json_schema(cls, schema: dict | None) -> dict | None:
if schema is None:
return None
try:
diff --git a/api/core/app/apps/advanced_chat/app_config_manager.py b/api/core/app/apps/advanced_chat/app_config_manager.py
index e4b308a6f6..c21c494efe 100644
--- a/api/core/app/apps/advanced_chat/app_config_manager.py
+++ b/api/core/app/apps/advanced_chat/app_config_manager.py
@@ -26,7 +26,6 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig:
features_dict = workflow.features_dict
-
app_mode = AppMode.value_of(app_model.mode)
app_config = AdvancedChatAppConfig(
tenant_id=app_model.tenant_id,
diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py
index feb0d3358c..528c45f6c8 100644
--- a/api/core/app/apps/advanced_chat/app_generator.py
+++ b/api/core/app/apps/advanced_chat/app_generator.py
@@ -1,9 +1,11 @@
+from __future__ import annotations
+
import contextvars
import logging
import threading
import uuid
from collections.abc import Generator, Mapping
-from typing import Any, Literal, Union, overload
+from typing import TYPE_CHECKING, Any, Literal, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
@@ -13,6 +15,9 @@ from sqlalchemy.orm import Session, sessionmaker
import contexts
from configs import dify_config
from constants import UUID_NIL
+
+if TYPE_CHECKING:
+ from controllers.console.app.workflow import LoopNodeRunPayload
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
@@ -304,7 +309,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow: Workflow,
node_id: str,
user: Account | EndUser,
- args: Mapping,
+ args: LoopNodeRunPayload,
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
"""
@@ -320,7 +325,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
if not node_id:
raise ValueError("node_id is required")
- if args.get("inputs") is None:
+ if args.inputs is None:
raise ValueError("inputs is required")
# convert to app config
@@ -338,7 +343,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
extras={"auto_generate_conversation_name": False},
- single_loop_run=AdvancedChatAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
+ single_loop_run=AdvancedChatAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args.inputs),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py
index ee092e55c5..d702db0908 100644
--- a/api/core/app/apps/advanced_chat/app_runner.py
+++ b/api/core/app/apps/advanced_chat/app_runner.py
@@ -20,13 +20,15 @@ from core.app.entities.queue_entities import (
QueueTextChunkEvent,
)
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
+from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
+from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
+from core.db.session_factory import session_factory
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
-from core.variables.variables import VariableUnion
+from core.variables.variables import Variable
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.base import GraphEngineLayer
-from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
@@ -37,9 +39,9 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.otel import WorkflowAppRunnerHandler, trace_span
from models import Workflow
-from models.enums import UserFrom
from models.model import App, Conversation, Message, MessageAnnotation
from models.workflow import ConversationVariable
+from services.conversation_variable_updater import ConversationVariableUpdater
logger = logging.getLogger(__name__)
@@ -103,6 +105,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
if not app_record:
raise ValueError("App not found")
+ invoke_from = self.application_generate_entity.invoke_from
+ if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
+ invoke_from = InvokeFrom.DEBUGGER
+ user_from = self._resolve_user_from(invoke_from)
+
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
# Handle single iteration or single loop run
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
@@ -142,8 +149,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=self._workflow.environment_variables,
- # Based on the definition of `VariableUnion`,
- # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
+ # Based on the definition of `Variable`,
+ # `VariableBase` instances can be safely used as `Variable` since they are compatible.
conversation_variables=conversation_variables,
)
@@ -155,6 +162,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
workflow_id=self._workflow.id,
tenant_id=self._workflow.tenant_id,
user_id=self.application_generate_entity.user_id,
+ user_from=user_from,
+ invoke_from=invoke_from,
)
db.session.close()
@@ -172,12 +181,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
graph=graph,
graph_config=self._workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
- user_from=(
- UserFrom.ACCOUNT
- if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
- else UserFrom.END_USER
- ),
- invoke_from=self.application_generate_entity.invoke_from,
+ user_from=user_from,
+ invoke_from=invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state,
@@ -200,6 +205,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
)
workflow_entry.graph_engine.layer(persistence_layer)
+ conversation_variable_layer = ConversationVariablePersistenceLayer(
+ ConversationVariableUpdater(session_factory.get_session_maker())
+ )
+ workflow_entry.graph_engine.layer(conversation_variable_layer)
for layer in self._graph_engine_layers:
workflow_entry.graph_engine.layer(layer)
@@ -309,7 +318,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
trace_manager=app_generate_entity.trace_manager,
)
- def _initialize_conversation_variables(self) -> list[VariableUnion]:
+ def _initialize_conversation_variables(self) -> list[Variable]:
"""
Initialize conversation variables for the current conversation.
@@ -334,7 +343,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
conversation_variables = [var.to_variable() for var in existing_variables]
session.commit()
- return cast(list[VariableUnion], conversation_variables)
+ return cast(list[Variable], conversation_variables)
def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]:
"""
diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py
index 2760466a3b..8b6b8f227b 100644
--- a/api/core/app/apps/agent_chat/app_runner.py
+++ b/api/core/app/apps/agent_chat/app_runner.py
@@ -236,4 +236,7 @@ class AgentChatAppRunner(AppRunner):
queue_manager=queue_manager,
stream=application_generate_entity.stream,
agent=True,
+ message_id=message.id,
+ user_id=application_generate_entity.user_id,
+ tenant_id=app_config.tenant_id,
)
diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py
index 74c6d2eca6..d1e2f16b6f 100644
--- a/api/core/app/apps/base_app_generate_response_converter.py
+++ b/api/core/app/apps/base_app_generate_response_converter.py
@@ -79,6 +79,7 @@ class AppGenerateResponseConverter(ABC):
"document_name": resource["document_name"],
"score": resource["score"],
"content": resource["content"],
+ "summary": resource.get("summary"),
}
)
metadata["retriever_resources"] = updated_resources
diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py
index 1b0474142e..07bae66867 100644
--- a/api/core/app/apps/base_app_generator.py
+++ b/api/core/app/apps/base_app_generator.py
@@ -75,12 +75,24 @@ class BaseAppGenerator:
user_inputs = {**user_inputs, **files_inputs, **file_list_inputs}
# Check if all files are converted to File
- if any(filter(lambda v: isinstance(v, dict), user_inputs.values())):
- raise ValueError("Invalid input type")
- if any(
- filter(lambda v: isinstance(v, dict), filter(lambda item: isinstance(item, list), user_inputs.values()))
- ):
- raise ValueError("Invalid input type")
+ invalid_dict_keys = [
+ k
+ for k, v in user_inputs.items()
+ if isinstance(v, dict)
+ and entity_dictionary[k].type not in {VariableEntityType.FILE, VariableEntityType.JSON_OBJECT}
+ ]
+ if invalid_dict_keys:
+ raise ValueError(f"Invalid input type for {invalid_dict_keys}")
+
+ invalid_list_dict_keys = [
+ k
+ for k, v in user_inputs.items()
+ if isinstance(v, list)
+ and any(isinstance(item, dict) for item in v)
+ and entity_dictionary[k].type != VariableEntityType.FILE_LIST
+ ]
+ if invalid_list_dict_keys:
+ raise ValueError(f"Invalid input type for {invalid_list_dict_keys}")
return user_inputs
@@ -104,8 +116,9 @@ class BaseAppGenerator:
variable_entity.type in {VariableEntityType.FILE, VariableEntityType.FILE_LIST}
and not variable_entity.required
):
- # Treat empty string (frontend default) or empty list as unset
- if not value and isinstance(value, (str, list)):
+ # Treat empty string (frontend default) as unset
+ # For FILE_LIST, allow empty list [] to pass through
+ if isinstance(value, str) and not value:
return None
if variable_entity.type in {
@@ -175,6 +188,9 @@ class BaseAppGenerator:
value = True
elif value == 0:
value = False
+ case VariableEntityType.JSON_OBJECT:
+ if value and not isinstance(value, dict):
+ raise ValueError(f"{variable_entity.variable} in input form must be a dict")
case _:
raise AssertionError("this statement should be unreachable.")
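
The input-validation changes above allow dict values only for FILE and JSON_OBJECT variables (and list-of-dict values only for FILE_LIST), instead of rejecting every dict outright. A condensed, runnable sketch of the dict check, with a simplified enum standing in for the real `VariableEntityType`:

```python
from enum import StrEnum

class VariableEntityType(StrEnum):  # simplified stand-in for the real enum
    TEXT_INPUT = "text-input"
    FILE = "file"
    FILE_LIST = "file-list"
    JSON_OBJECT = "json_object"

def find_invalid_dict_inputs(user_inputs: dict, types: dict[str, VariableEntityType]) -> list[str]:
    # Dict values are only legal for FILE and JSON_OBJECT variables.
    return [
        key
        for key, value in user_inputs.items()
        if isinstance(value, dict)
        and types[key] not in {VariableEntityType.FILE, VariableEntityType.JSON_OBJECT}
    ]

print(find_invalid_dict_inputs(
    {"name": {"oops": True}, "config": {"a": 1}},
    {"name": VariableEntityType.TEXT_INPUT, "config": VariableEntityType.JSON_OBJECT},
))  # ['name']
```
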
diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py
index 698eee9894..b41bedbea4 100644
--- a/api/core/app/apps/base_app_queue_manager.py
+++ b/api/core/app/apps/base_app_queue_manager.py
@@ -90,6 +90,7 @@ class AppQueueManager:
"""
self._clear_task_belong_cache()
self._q.put(None)
+ self._graph_runtime_state = None # Release reference to allow GC to reclaim memory
def _clear_task_belong_cache(self) -> None:
"""
diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py
index e2e6c11480..617515945b 100644
--- a/api/core/app/apps/base_app_runner.py
+++ b/api/core/app/apps/base_app_runner.py
@@ -1,6 +1,8 @@
+import base64
import logging
import time
from collections.abc import Generator, Mapping, Sequence
+from mimetypes import guess_extension
from typing import TYPE_CHECKING, Any, Union
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
@@ -11,10 +13,16 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom,
ModelConfigWithCredentialsEntity,
)
-from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent
+from core.app.entities.queue_entities import (
+ QueueAgentMessageEvent,
+ QueueLLMChunkEvent,
+ QueueMessageEndEvent,
+ QueueMessageFileEvent,
+)
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
from core.external_data_tool.external_data_fetch import ExternalDataFetch
+from core.file.enums import FileTransferMethod, FileType
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@@ -22,6 +30,7 @@ from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
+ TextPromptMessageContent,
)
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeBadRequestError
@@ -29,7 +38,10 @@ from core.moderation.input_moderation import InputModeration
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
-from models.model import App, AppMode, Message, MessageAnnotation
+from core.tools.tool_file_manager import ToolFileManager
+from extensions.ext_database import db
+from models.enums import CreatorUserRole
+from models.model import App, AppMode, Message, MessageAnnotation, MessageFile
if TYPE_CHECKING:
from core.file.models import File
@@ -203,6 +215,9 @@ class AppRunner:
queue_manager: AppQueueManager,
stream: bool,
agent: bool = False,
+ message_id: str | None = None,
+ user_id: str | None = None,
+ tenant_id: str | None = None,
):
"""
Handle invoke result
@@ -210,21 +225,41 @@ class AppRunner:
:param queue_manager: application queue manager
:param stream: stream
:param agent: agent
+ :param message_id: message id for multimodal output
+ :param user_id: user id for multimodal output
+ :param tenant_id: tenant id for multimodal output
:return:
"""
if not stream and isinstance(invoke_result, LLMResult):
- self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
+ self._handle_invoke_result_direct(
+ invoke_result=invoke_result,
+ queue_manager=queue_manager,
+ )
elif stream and isinstance(invoke_result, Generator):
- self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
+ self._handle_invoke_result_stream(
+ invoke_result=invoke_result,
+ queue_manager=queue_manager,
+ agent=agent,
+ message_id=message_id,
+ user_id=user_id,
+ tenant_id=tenant_id,
+ )
else:
raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")
- def _handle_invoke_result_direct(self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool):
+ def _handle_invoke_result_direct(
+ self,
+ invoke_result: LLMResult,
+ queue_manager: AppQueueManager,
+ ):
"""
Handle invoke result direct
:param invoke_result: invoke result
:param queue_manager: application queue manager
-        :param agent: agent
:return:
"""
queue_manager.publish(
@@ -235,13 +270,22 @@ class AppRunner:
)
def _handle_invoke_result_stream(
- self, invoke_result: Generator[LLMResultChunk, None, None], queue_manager: AppQueueManager, agent: bool
+ self,
+ invoke_result: Generator[LLMResultChunk, None, None],
+ queue_manager: AppQueueManager,
+ agent: bool,
+ message_id: str | None = None,
+ user_id: str | None = None,
+ tenant_id: str | None = None,
):
"""
Handle invoke result
:param invoke_result: invoke result
:param queue_manager: application queue manager
:param agent: agent
+ :param message_id: message id for multimodal output
+ :param user_id: user id for multimodal output
+ :param tenant_id: tenant id for multimodal output
:return:
"""
model: str = ""
@@ -259,12 +303,26 @@ class AppRunner:
text += message.content
elif isinstance(message.content, list):
for content in message.content:
- if not isinstance(content, str):
- # TODO(QuantumGhost): Add multimodal output support for easy ui.
- _logger.warning("received multimodal output, type=%s", type(content))
+ if isinstance(content, str):
+ text += content
+ elif isinstance(content, TextPromptMessageContent):
text += content.data
+ elif isinstance(content, ImagePromptMessageContent):
+ if message_id and user_id and tenant_id:
+ try:
+ self._handle_multimodal_image_content(
+ content=content,
+ message_id=message_id,
+ user_id=user_id,
+ tenant_id=tenant_id,
+ queue_manager=queue_manager,
+ )
+ except Exception:
+ _logger.exception("Failed to handle multimodal image output")
+ else:
+ _logger.warning("Received multimodal output but missing required parameters")
else:
- text += content # failback to str
+ text += content.data if hasattr(content, "data") else str(content)
if not model:
model = result.model
@@ -289,6 +347,101 @@ class AppRunner:
PublishFrom.APPLICATION_MANAGER,
)
+ def _handle_multimodal_image_content(
+ self,
+ content: ImagePromptMessageContent,
+ message_id: str,
+ user_id: str,
+ tenant_id: str,
+ queue_manager: AppQueueManager,
+ ):
+ """
+ Handle multimodal image content from LLM response.
+ Save the image and create a MessageFile record.
+
+ :param content: ImagePromptMessageContent instance
+ :param message_id: message id
+ :param user_id: user id
+ :param tenant_id: tenant id
+ :param queue_manager: queue manager
+ :return:
+ """
+ _logger.info("Handling multimodal image content for message %s", message_id)
+
+ image_url = content.url
+ base64_data = content.base64_data
+
+        _logger.info("Image URL: %s, Base64 data present: %s", image_url, bool(base64_data))
+
+ if not image_url and not base64_data:
+ _logger.warning("Image content has neither URL nor base64 data")
+ return
+
+ tool_file_manager = ToolFileManager()
+
+ # Save the image file
+ try:
+ if image_url:
+ # Download image from URL
+ _logger.info("Downloading image from URL: %s", image_url)
+ tool_file = tool_file_manager.create_file_by_url(
+ user_id=user_id,
+ tenant_id=tenant_id,
+ file_url=image_url,
+ conversation_id=None,
+ )
+ _logger.info("Image saved successfully, tool_file_id: %s", tool_file.id)
+ elif base64_data:
+ if base64_data.startswith("data:"):
+ base64_data = base64_data.split(",", 1)[1]
+
+ image_binary = base64.b64decode(base64_data)
+ mimetype = content.mime_type or "image/png"
+ extension = guess_extension(mimetype) or ".png"
+
+ tool_file = tool_file_manager.create_file_by_raw(
+ user_id=user_id,
+ tenant_id=tenant_id,
+ conversation_id=None,
+ file_binary=image_binary,
+ mimetype=mimetype,
+ filename=f"generated_image{extension}",
+ )
+ _logger.info("Image saved successfully, tool_file_id: %s", tool_file.id)
+ else:
+ return
+ except Exception:
+ _logger.exception("Failed to save image file")
+ return
+
+ # Create MessageFile record
+ message_file = MessageFile(
+ message_id=message_id,
+ type=FileType.IMAGE,
+ transfer_method=FileTransferMethod.TOOL_FILE,
+ belongs_to="assistant",
+ url=f"/files/tools/{tool_file.id}",
+ upload_file_id=tool_file.id,
+ created_by_role=(
+ CreatorUserRole.ACCOUNT
+ if queue_manager.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}
+ else CreatorUserRole.END_USER
+ ),
+ created_by=user_id,
+ )
+
+ db.session.add(message_file)
+ db.session.commit()
+ db.session.refresh(message_file)
+
+ # Publish QueueMessageFileEvent
+ queue_manager.publish(
+ QueueMessageFileEvent(message_file_id=message_file.id),
+ PublishFrom.APPLICATION_MANAGER,
+ )
+
+ _logger.info("QueueMessageFileEvent published for message_file_id: %s", message_file.id)
+
def moderation_for_inputs(
self,
*,
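
The new `_handle_multimodal_image_content` strips an optional `data:` URI prefix before base64-decoding and derives a file extension from the MIME type. The decoding step in isolation looks roughly like this (a sketch with no Dify internals; the function name is made up):

```python
import base64
from mimetypes import guess_extension

def decode_image_payload(base64_data: str, mime_type: str | None) -> tuple[bytes, str]:
    # "data:image/png;base64,AAAA..." -> keep only the part after the first comma
    if base64_data.startswith("data:"):
        base64_data = base64_data.split(",", 1)[1]
    binary = base64.b64decode(base64_data)
    extension = guess_extension(mime_type or "image/png") or ".png"
    return binary, extension

binary, ext = decode_image_payload("data:image/png;base64,aGVsbG8=", "image/png")
print(len(binary), ext)  # 5 .png
```
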
diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py
index f8338b226b..7d1a4c619f 100644
--- a/api/core/app/apps/chat/app_runner.py
+++ b/api/core/app/apps/chat/app_runner.py
@@ -226,5 +226,10 @@ class ChatAppRunner(AppRunner):
# handle invoke result
self._handle_invoke_result(
- invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
+ invoke_result=invoke_result,
+ queue_manager=queue_manager,
+ stream=application_generate_entity.stream,
+ message_id=message.id,
+ user_id=application_generate_entity.user_id,
+ tenant_id=app_config.tenant_id,
)
diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py
index ddfb5725b4..a872c2e1f7 100644
--- a/api/core/app/apps/completion/app_runner.py
+++ b/api/core/app/apps/completion/app_runner.py
@@ -184,5 +184,10 @@ class CompletionAppRunner(AppRunner):
# handle invoke result
self._handle_invoke_result(
- invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
+ invoke_result=invoke_result,
+ queue_manager=queue_manager,
+ stream=application_generate_entity.stream,
+ message_id=message.id,
+ user_id=application_generate_entity.user_id,
+ tenant_id=app_config.tenant_id,
)
diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py
index 13eb40fd60..ea4441b5d8 100644
--- a/api/core/app/apps/pipeline/pipeline_generator.py
+++ b/api/core/app/apps/pipeline/pipeline_generator.py
@@ -130,7 +130,7 @@ class PipelineGenerator(BaseAppGenerator):
pipeline=pipeline, workflow=workflow, start_node_id=start_node_id
)
documents: list[Document] = []
- if invoke_from == InvokeFrom.PUBLISHED and not is_retry and not args.get("original_document_id"):
+ if invoke_from == InvokeFrom.PUBLISHED_PIPELINE and not is_retry and not args.get("original_document_id"):
from services.dataset_service import DocumentService
for datasource_info in datasource_info_list:
@@ -156,7 +156,7 @@ class PipelineGenerator(BaseAppGenerator):
for i, datasource_info in enumerate(datasource_info_list):
workflow_run_id = str(uuid.uuid4())
document_id = args.get("original_document_id") or None
- if invoke_from == InvokeFrom.PUBLISHED and not is_retry:
+ if invoke_from == InvokeFrom.PUBLISHED_PIPELINE and not is_retry:
document_id = document_id or documents[i].id
document_pipeline_execution_log = DocumentPipelineExecutionLog(
document_id=document_id,
diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py
index 4be9e01fbf..8ea34344b2 100644
--- a/api/core/app/apps/pipeline/pipeline_runner.py
+++ b/api/core/app/apps/pipeline/pipeline_runner.py
@@ -9,13 +9,13 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom,
RagPipelineGenerateEntity,
)
+from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
+from core.app.workflow.node_factory import DifyNodeFactory
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.enums import WorkflowType
from core.workflow.graph import Graph
-from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent
-from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
@@ -73,9 +73,15 @@ class PipelineRunner(WorkflowBasedAppRunner):
"""
app_config = self.application_generate_entity.app_config
app_config = cast(PipelineConfig, app_config)
+ invoke_from = self.application_generate_entity.invoke_from
+
+ if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
+ invoke_from = InvokeFrom.DEBUGGER
+
+ user_from = self._resolve_user_from(invoke_from)
user_id = None
- if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
+ if invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).where(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
@@ -117,7 +123,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
dataset_id=self.application_generate_entity.dataset_id,
datasource_type=self.application_generate_entity.datasource_type,
datasource_info=self.application_generate_entity.datasource_info,
- invoke_from=self.application_generate_entity.invoke_from.value,
+ invoke_from=invoke_from.value,
)
rag_pipeline_variables = []
@@ -149,6 +155,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
graph_runtime_state=graph_runtime_state,
start_node_id=self.application_generate_entity.start_node_id,
workflow=workflow,
+ user_from=user_from,
+ invoke_from=invoke_from,
)
# RUN WORKFLOW
@@ -159,12 +167,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
graph=graph,
graph_config=workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
- user_from=(
- UserFrom.ACCOUNT
- if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
- else UserFrom.END_USER
- ),
- invoke_from=self.application_generate_entity.invoke_from,
+ user_from=user_from,
+ invoke_from=invoke_from,
call_depth=self.application_generate_entity.call_depth,
graph_runtime_state=graph_runtime_state,
variable_pool=variable_pool,
@@ -210,7 +214,12 @@ class PipelineRunner(WorkflowBasedAppRunner):
return workflow
def _init_rag_pipeline_graph(
- self, workflow: Workflow, graph_runtime_state: GraphRuntimeState, start_node_id: str | None = None
+ self,
+ workflow: Workflow,
+ graph_runtime_state: GraphRuntimeState,
+ start_node_id: str | None = None,
+ user_from: UserFrom = UserFrom.ACCOUNT,
+ invoke_from: InvokeFrom = InvokeFrom.SERVICE_API,
) -> Graph:
"""
Init pipeline graph
@@ -253,8 +262,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
workflow_id=workflow.id,
graph_config=graph_config,
user_id=self.application_generate_entity.user_id,
- user_from=UserFrom.ACCOUNT,
- invoke_from=InvokeFrom.SERVICE_API,
+ user_from=user_from,
+ invoke_from=invoke_from,
call_depth=0,
)
diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py
index 0165c74295..ee205ed153 100644
--- a/api/core/app/apps/workflow/app_generator.py
+++ b/api/core/app/apps/workflow/app_generator.py
@@ -1,14 +1,16 @@
+from __future__ import annotations
+
import contextvars
import logging
import threading
import uuid
from collections.abc import Generator, Mapping, Sequence
-from typing import Any, Literal, Union, overload
+from typing import TYPE_CHECKING, Any, Literal, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy import select
-from sqlalchemy.orm import Session, sessionmaker
+from sqlalchemy.orm import sessionmaker
import contexts
from configs import dify_config
@@ -23,6 +25,7 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
+from core.db.session_factory import session_factory
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
@@ -39,6 +42,9 @@ from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTrigger
from models.enums import WorkflowRunTriggeredFrom
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
+if TYPE_CHECKING:
+ from controllers.console.app.workflow import LoopNodeRunPayload
+
SKIP_PREPARE_USER_INPUTS_KEY = "_skip_prepare_user_inputs"
logger = logging.getLogger(__name__)
@@ -380,7 +386,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow: Workflow,
node_id: str,
user: Account | EndUser,
- args: Mapping[str, Any],
+ args: LoopNodeRunPayload,
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
@@ -396,7 +402,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
if not node_id:
raise ValueError("node_id is required")
- if args.get("inputs") is None:
+ if args.inputs is None:
raise ValueError("inputs is required")
# convert to app config
@@ -412,7 +418,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
extras={"auto_generate_conversation_name": False},
- single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
+ single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args.inputs or {}),
workflow_execution_id=str(uuid.uuid4()),
)
contexts.plugin_tool_providers.set({})
@@ -476,7 +482,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
:return:
"""
with preserve_flask_contexts(flask_app, context_vars=context):
- with Session(db.engine, expire_on_commit=False) as session:
+ with session_factory.create_session() as session:
workflow = session.scalar(
select(Workflow).where(
Workflow.tenant_id == application_generate_entity.app_config.tenant_id,
diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py
index 894e6f397a..0ee3c177f2 100644
--- a/api/core/app/apps/workflow/app_runner.py
+++ b/api/core/app/apps/workflow/app_runner.py
@@ -7,10 +7,10 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
+from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.base import GraphEngineLayer
-from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
@@ -20,7 +20,6 @@ from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_redis import redis_client
from extensions.otel import WorkflowAppRunnerHandler, trace_span
from libs.datetime_utils import naive_utc_now
-from models.enums import UserFrom
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@@ -74,7 +73,12 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
)
+ invoke_from = self.application_generate_entity.invoke_from
# if only single iteration or single loop run is requested
+ if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
+ invoke_from = InvokeFrom.DEBUGGER
+ user_from = self._resolve_user_from(invoke_from)
+
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=self._workflow,
@@ -102,6 +106,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
workflow_id=self._workflow.id,
tenant_id=self._workflow.tenant_id,
user_id=self.application_generate_entity.user_id,
+ user_from=user_from,
+ invoke_from=invoke_from,
root_node_id=self._root_node_id,
)
@@ -120,12 +126,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
graph=graph,
graph_config=self._workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
- user_from=(
- UserFrom.ACCOUNT
- if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
- else UserFrom.END_USER
- ),
- invoke_from=self.application_generate_entity.invoke_from,
+ user_from=user_from,
+ invoke_from=invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state,
diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py
index 0e125b3538..13b7865f55 100644
--- a/api/core/app/apps/workflow_app_runner.py
+++ b/api/core/app/apps/workflow_app_runner.py
@@ -25,6 +25,7 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
+from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.base import GraphEngineLayer
@@ -53,7 +54,6 @@ from core.workflow.graph_events import (
)
from core.workflow.graph_events.graph import GraphRunAbortedEvent
from core.workflow.nodes import NodeType
-from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
@@ -77,10 +77,18 @@ class WorkflowBasedAppRunner:
self._app_id = app_id
self._graph_engine_layers = graph_engine_layers
+ @staticmethod
+ def _resolve_user_from(invoke_from: InvokeFrom) -> UserFrom:
+ if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}:
+ return UserFrom.ACCOUNT
+ return UserFrom.END_USER
+
def _init_graph(
self,
graph_config: Mapping[str, Any],
graph_runtime_state: GraphRuntimeState,
+ user_from: UserFrom,
+ invoke_from: InvokeFrom,
workflow_id: str = "",
tenant_id: str = "",
user_id: str = "",
@@ -105,8 +113,8 @@ class WorkflowBasedAppRunner:
workflow_id=workflow_id,
graph_config=graph_config,
user_id=user_id,
- user_from=UserFrom.ACCOUNT,
- invoke_from=InvokeFrom.SERVICE_API,
+ user_from=user_from,
+ invoke_from=invoke_from,
call_depth=0,
)
@@ -149,7 +157,7 @@ class WorkflowBasedAppRunner:
# Create initial runtime state with variable pool containing environment variables
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={},
environment_variables=workflow.environment_variables,
),
@@ -158,18 +166,22 @@ class WorkflowBasedAppRunner:
# Determine which type of single node execution and get graph/variable_pool
if single_iteration_run:
- graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
+ graph, variable_pool = self._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
node_id=single_iteration_run.node_id,
user_inputs=dict(single_iteration_run.inputs),
graph_runtime_state=graph_runtime_state,
+ node_type_filter_key="iteration_id",
+ node_type_label="iteration",
)
elif single_loop_run:
- graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
+ graph, variable_pool = self._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
node_id=single_loop_run.node_id,
user_inputs=dict(single_loop_run.inputs),
graph_runtime_state=graph_runtime_state,
+ node_type_filter_key="loop_id",
+ node_type_label="loop",
)
else:
raise ValueError("Neither single_iteration_run nor single_loop_run is specified")
@@ -250,7 +262,7 @@ class WorkflowBasedAppRunner:
graph_config=graph_config,
user_id="",
user_from=UserFrom.ACCOUNT,
- invoke_from=InvokeFrom.SERVICE_API,
+ invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
@@ -260,7 +272,9 @@ class WorkflowBasedAppRunner:
)
# init graph
- graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
+ graph = Graph.init(
+ graph_config=graph_config, node_factory=node_factory, root_node_id=node_id, skip_validation=True
+ )
if not graph:
raise ValueError("graph not found in workflow")
@@ -306,44 +320,6 @@ class WorkflowBasedAppRunner:
return graph, variable_pool
- def _get_graph_and_variable_pool_of_single_iteration(
- self,
- workflow: Workflow,
- node_id: str,
- user_inputs: dict[str, Any],
- graph_runtime_state: GraphRuntimeState,
- ) -> tuple[Graph, VariablePool]:
- """
- Get variable pool of single iteration
- """
- return self._get_graph_and_variable_pool_for_single_node_run(
- workflow=workflow,
- node_id=node_id,
- user_inputs=user_inputs,
- graph_runtime_state=graph_runtime_state,
- node_type_filter_key="iteration_id",
- node_type_label="iteration",
- )
-
- def _get_graph_and_variable_pool_of_single_loop(
- self,
- workflow: Workflow,
- node_id: str,
- user_inputs: dict[str, Any],
- graph_runtime_state: GraphRuntimeState,
- ) -> tuple[Graph, VariablePool]:
- """
- Get variable pool of single loop
- """
- return self._get_graph_and_variable_pool_for_single_node_run(
- workflow=workflow,
- node_id=node_id,
- user_inputs=user_inputs,
- graph_runtime_state=graph_runtime_state,
- node_type_filter_key="loop_id",
- node_type_label="loop",
- )
-
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
"""
Handle event
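
Callers of `_init_graph` now have to supply `user_from` and `invoke_from` explicitly instead of the previous hard-coded `ACCOUNT`/`SERVICE_API` pair. A minimal sketch of how a runner would thread these through using the new `_resolve_user_from` helper (the surrounding runner/entity variables are assumptions, not part of this diff):

```python
# Hedged sketch: how _init_graph is expected to be called after this change.
invoke_from = application_generate_entity.invoke_from  # e.g. InvokeFrom.DEBUGGER
graph = self._init_graph(
    graph_config=workflow.graph_dict,
    graph_runtime_state=graph_runtime_state,
    user_from=self._resolve_user_from(invoke_from),  # ACCOUNT for EXPLORE/DEBUGGER, else END_USER
    invoke_from=invoke_from,
    workflow_id=workflow.id,
    tenant_id=workflow.tenant_id,
    user_id=user_id,
)
```
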
diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py
index 0cb573cb86..5bc453420d 100644
--- a/api/core/app/entities/app_invoke_entities.py
+++ b/api/core/app/entities/app_invoke_entities.py
@@ -42,7 +42,8 @@ class InvokeFrom(StrEnum):
# DEBUGGER indicates that this invocation is from
# the workflow (or chatflow) edit page.
DEBUGGER = "debugger"
- PUBLISHED = "published"
+ # PUBLISHED_PIPELINE indicates that this invocation runs a published RAG pipeline workflow.
+ PUBLISHED_PIPELINE = "published"
# VALIDATION indicates that this invocation is from validation.
VALIDATION = "validation"
diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py
index 79fbafe39e..3f9f3da9b2 100644
--- a/api/core/app/features/annotation_reply/annotation_reply.py
+++ b/api/core/app/features/annotation_reply/annotation_reply.py
@@ -75,7 +75,7 @@ class AnnotationReplyFeature:
AppAnnotationService.add_annotation_history(
annotation.id,
app_record.id,
- annotation.question,
+ annotation.question_text,
annotation.content,
query,
user_id,
diff --git a/api/core/app/layers/conversation_variable_persist_layer.py b/api/core/app/layers/conversation_variable_persist_layer.py
new file mode 100644
index 0000000000..c070845b73
--- /dev/null
+++ b/api/core/app/layers/conversation_variable_persist_layer.py
@@ -0,0 +1,60 @@
+import logging
+
+from core.variables import VariableBase
+from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
+from core.workflow.conversation_variable_updater import ConversationVariableUpdater
+from core.workflow.enums import NodeType
+from core.workflow.graph_engine.layers.base import GraphEngineLayer
+from core.workflow.graph_events import GraphEngineEvent, NodeRunSucceededEvent
+from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
+
+logger = logging.getLogger(__name__)
+
+
+class ConversationVariablePersistenceLayer(GraphEngineLayer):
+ def __init__(self, conversation_variable_updater: ConversationVariableUpdater) -> None:
+ super().__init__()
+ self._conversation_variable_updater = conversation_variable_updater
+
+ def on_graph_start(self) -> None:
+ pass
+
+ def on_event(self, event: GraphEngineEvent) -> None:
+ if not isinstance(event, NodeRunSucceededEvent):
+ return
+ if event.node_type != NodeType.VARIABLE_ASSIGNER:
+ return
+ if self.graph_runtime_state is None:
+ return
+
+ updated_variables = common_helpers.get_updated_variables(event.node_run_result.process_data) or []
+ if not updated_variables:
+ return
+
+ conversation_id = self.graph_runtime_state.system_variable.conversation_id
+ if conversation_id is None:
+ return
+
+ updated_any = False
+ for item in updated_variables:
+ selector = item.selector
+ if len(selector) < 2:
+ logger.warning("Conversation variable selector invalid. selector=%s", selector)
+ continue
+ if selector[0] != CONVERSATION_VARIABLE_NODE_ID:
+ continue
+ variable = self.graph_runtime_state.variable_pool.get(selector)
+ if not isinstance(variable, VariableBase):
+ logger.warning(
+ "Conversation variable not found in variable pool. selector=%s",
+ selector,
+ )
+ continue
+ self._conversation_variable_updater.update(conversation_id=conversation_id, variable=variable)
+ updated_any = True
+
+ if updated_any:
+ self._conversation_variable_updater.flush()
+
+ def on_graph_end(self, error: Exception | None) -> None:
+ pass
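
This layer only reacts to `NodeRunSucceededEvent`s from variable-assigner nodes and flushes the updater once after a batch of conversation-variable writes. A hedged sketch of wiring it into the runner's layer list (the updater implementation and the surrounding setup are assumptions):

```python
from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer

# `updater` is any ConversationVariableUpdater implementation (assumed to exist elsewhere).
persist_layer = ConversationVariablePersistenceLayer(conversation_variable_updater=updater)

# Layers are handed to WorkflowBasedAppRunner via graph_engine_layers (see the runner diff above).
graph_engine_layers = [persist_layer, *other_layers]
```
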
diff --git a/api/core/app/layers/pause_state_persist_layer.py b/api/core/app/layers/pause_state_persist_layer.py
index 61a3e1baca..bf76ae8178 100644
--- a/api/core/app/layers/pause_state_persist_layer.py
+++ b/api/core/app/layers/pause_state_persist_layer.py
@@ -66,6 +66,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
"""
if isinstance(session_factory, Engine):
session_factory = sessionmaker(session_factory)
+ super().__init__()
self._session_maker = session_factory
self._state_owner_user_id = state_owner_user_id
self._generate_entity = generate_entity
@@ -98,8 +99,6 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
if not isinstance(event, GraphRunPausedEvent):
return
- assert self.graph_runtime_state is not None
-
entity_wrapper: _GenerateEntityUnion
if isinstance(self._generate_entity, WorkflowAppGenerateEntity):
entity_wrapper = _WorkflowGenerateEntityWrapper(entity=self._generate_entity)
diff --git a/api/core/app/layers/trigger_post_layer.py b/api/core/app/layers/trigger_post_layer.py
index fe1a46a945..a7ea9ef446 100644
--- a/api/core/app/layers/trigger_post_layer.py
+++ b/api/core/app/layers/trigger_post_layer.py
@@ -3,8 +3,8 @@ from datetime import UTC, datetime
from typing import Any, ClassVar
from pydantic import TypeAdapter
-from sqlalchemy.orm import Session, sessionmaker
+from core.db.session_factory import session_factory
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent
from core.workflow.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent
@@ -31,12 +31,11 @@ class TriggerPostLayer(GraphEngineLayer):
cfs_plan_scheduler_entity: AsyncWorkflowCFSPlanEntity,
start_time: datetime,
trigger_log_id: str,
- session_maker: sessionmaker[Session],
):
+ super().__init__()
self.trigger_log_id = trigger_log_id
self.start_time = start_time
self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity
- self.session_maker = session_maker
def on_graph_start(self):
pass
@@ -46,7 +45,7 @@ class TriggerPostLayer(GraphEngineLayer):
Update trigger log with success or failure.
"""
if isinstance(event, tuple(self._STATUS_MAP.keys())):
- with self.session_maker() as session:
+ with session_factory.create_session() as session:
repo = SQLAlchemyWorkflowTriggerLogRepository(session)
trigger_log = repo.get_by_id(self.trigger_log_id)
if not trigger_log:
@@ -57,10 +56,6 @@ class TriggerPostLayer(GraphEngineLayer):
elapsed_time = (datetime.now(UTC) - self.start_time).total_seconds()
# Extract relevant data from result
- if not self.graph_runtime_state:
- logger.exception("Graph runtime state is not set")
- return
-
outputs = self.graph_runtime_state.outputs
# BASICALLY, workflow_execution_id is the same as workflow_run_id

diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
index 5c169f4db1..6c997753fa 100644
--- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
+++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
@@ -39,6 +39,7 @@ from core.app.entities.task_entities import (
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
MessageEndStreamResponse,
+ StreamEvent,
StreamResponse,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
@@ -70,6 +71,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
_task_state: EasyUITaskState
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
+ _precomputed_event_type: StreamEvent | None = None
def __init__(
self,
@@ -342,9 +344,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
self._task_state.llm_result.message.content = current_content
if isinstance(event, QueueLLMChunkEvent):
+ # Determine the event type once, on first LLM chunk, and reuse for subsequent chunks
+ if not hasattr(self, "_precomputed_event_type") or self._precomputed_event_type is None:
+ self._precomputed_event_type = self._message_cycle_manager.get_message_event_type(
+ message_id=self._message_id
+ )
yield self._message_cycle_manager.message_to_stream_response(
answer=cast(str, delta_text),
message_id=self._message_id,
+ event_type=self._precomputed_event_type,
)
else:
yield self._agent_message_to_stream_response(
diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py
index 2e6f92efa5..2d4ee08daf 100644
--- a/api/core/app/task_pipeline/message_cycle_manager.py
+++ b/api/core/app/task_pipeline/message_cycle_manager.py
@@ -30,6 +30,7 @@ from core.app.entities.task_entities import (
StreamEvent,
WorkflowTaskState,
)
+from core.db.session_factory import session_factory
from core.llm_generator.llm_generator import LLMGenerator
from core.tools.signature import sign_tool_file
from extensions.ext_database import db
@@ -54,6 +55,22 @@ class MessageCycleManager:
):
self._application_generate_entity = application_generate_entity
self._task_state = task_state
+ self._message_has_file: set[str] = set()
+
+ def get_message_event_type(self, message_id: str) -> StreamEvent:
+ # Fast path: cached determination from prior QueueMessageFileEvent
+ if message_id in self._message_has_file:
+ return StreamEvent.MESSAGE_FILE
+
+ # Use SQLAlchemy 2.x style session.scalar(select(...))
+ with session_factory.create_session() as session:
+ message_file = session.scalar(select(MessageFile).where(MessageFile.message_id == message_id))
+
+ if message_file:
+ self._message_has_file.add(message_id)
+ return StreamEvent.MESSAGE_FILE
+
+ return StreamEvent.MESSAGE
def generate_conversation_name(self, *, conversation_id: str, query: str) -> Thread | None:
"""
@@ -185,6 +202,8 @@ class MessageCycleManager:
message_file = session.scalar(select(MessageFile).where(MessageFile.id == event.message_file_id))
if message_file and message_file.url is not None:
+ self._message_has_file.add(message_file.message_id)
+
# get tool file id
tool_file_id = message_file.url.split("/")[-1]
# trim extension
@@ -214,7 +233,11 @@ class MessageCycleManager:
return None
def message_to_stream_response(
- self, answer: str, message_id: str, from_variable_selector: list[str] | None = None
+ self,
+ answer: str,
+ message_id: str,
+ from_variable_selector: list[str] | None = None,
+ event_type: StreamEvent | None = None,
) -> MessageStreamResponse:
"""
Message to stream response.
@@ -222,16 +245,12 @@ class MessageCycleManager:
:param message_id: message id
:return:
"""
- with Session(db.engine, expire_on_commit=False) as session:
- message_file = session.scalar(select(MessageFile).where(MessageFile.id == message_id))
- event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE
-
return MessageStreamResponse(
task_id=self._application_generate_entity.task_id,
id=message_id,
answer=answer,
from_variable_selector=from_variable_selector,
- event=event_type,
+ event=event_type or StreamEvent.MESSAGE,
)
def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
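
The `event_type` parameter plus the `_message_has_file` cache remove the per-chunk `MessageFile` lookup that `message_to_stream_response` used to perform. A simplified sketch of the resulting streaming hot path (pipeline internals are condensed here):

```python
# Hedged sketch: at most one DB lookup per message instead of one per chunk.
event_type = None
for delta_text in llm_chunk_texts:
    if event_type is None:
        event_type = message_cycle_manager.get_message_event_type(message_id=message_id)
    yield message_cycle_manager.message_to_stream_response(
        answer=delta_text,
        message_id=message_id,
        event_type=event_type,
    )
```
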
diff --git a/api/core/app/workflow/__init__.py b/api/core/app/workflow/__init__.py
new file mode 100644
index 0000000000..172ee5d703
--- /dev/null
+++ b/api/core/app/workflow/__init__.py
@@ -0,0 +1,3 @@
+from .node_factory import DifyNodeFactory
+
+__all__ = ["DifyNodeFactory"]
diff --git a/api/core/app/workflow/layers/__init__.py b/api/core/app/workflow/layers/__init__.py
new file mode 100644
index 0000000000..945f75303c
--- /dev/null
+++ b/api/core/app/workflow/layers/__init__.py
@@ -0,0 +1,10 @@
+"""Workflow-level GraphEngine layers that depend on outer infrastructure."""
+
+from .observability import ObservabilityLayer
+from .persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
+
+__all__ = [
+ "ObservabilityLayer",
+ "PersistenceWorkflowInfo",
+ "WorkflowPersistenceLayer",
+]
diff --git a/api/core/app/workflow/layers/observability.py b/api/core/app/workflow/layers/observability.py
new file mode 100644
index 0000000000..94839c8ae3
--- /dev/null
+++ b/api/core/app/workflow/layers/observability.py
@@ -0,0 +1,176 @@
+"""
+Observability layer for GraphEngine.
+
+This layer creates OpenTelemetry spans for node execution, enabling distributed
+tracing of workflow execution. It establishes OTel context during node execution
+so that automatic instrumentation (HTTP requests, DB queries, etc.) automatically
+associates with the node span.
+"""
+
+import logging
+from dataclasses import dataclass
+from typing import cast, final
+
+from opentelemetry import context as context_api
+from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_context
+from typing_extensions import override
+
+from configs import dify_config
+from core.workflow.enums import NodeType
+from core.workflow.graph_engine.layers.base import GraphEngineLayer
+from core.workflow.graph_events import GraphNodeEventBase
+from core.workflow.nodes.base.node import Node
+from extensions.otel.parser import (
+ DefaultNodeOTelParser,
+ LLMNodeOTelParser,
+ NodeOTelParser,
+ RetrievalNodeOTelParser,
+ ToolNodeOTelParser,
+)
+from extensions.otel.runtime import is_instrument_flag_enabled
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass(slots=True)
+class _NodeSpanContext:
+ span: "Span"
+ token: object
+
+
+@final
+class ObservabilityLayer(GraphEngineLayer):
+ """
+ Layer that creates OpenTelemetry spans for node execution.
+
+ This layer:
+ - Creates a span when a node starts execution
+ - Establishes OTel context so automatic instrumentation associates with the span
+ - Sets complete attributes and status when node execution ends
+ """
+
+ def __init__(self) -> None:
+ super().__init__()
+ self._node_contexts: dict[str, _NodeSpanContext] = {}
+ self._parsers: dict[NodeType, NodeOTelParser] = {}
+ self._default_parser: NodeOTelParser = cast(NodeOTelParser, DefaultNodeOTelParser())
+ self._is_disabled: bool = False
+ self._tracer: Tracer | None = None
+ self._build_parser_registry()
+ self._init_tracer()
+
+ def _init_tracer(self) -> None:
+ """Initialize OpenTelemetry tracer in constructor."""
+ if not (dify_config.ENABLE_OTEL or is_instrument_flag_enabled()):
+ self._is_disabled = True
+ return
+
+ try:
+ self._tracer = get_tracer(__name__)
+ except Exception as e:
+ logger.warning("Failed to get OpenTelemetry tracer: %s", e)
+ self._is_disabled = True
+
+ def _build_parser_registry(self) -> None:
+ """Initialize parser registry for node types."""
+ self._parsers = {
+ NodeType.TOOL: ToolNodeOTelParser(),
+ NodeType.LLM: LLMNodeOTelParser(),
+ NodeType.KNOWLEDGE_RETRIEVAL: RetrievalNodeOTelParser(),
+ }
+
+ def _get_parser(self, node: Node) -> NodeOTelParser:
+ node_type = getattr(node, "node_type", None)
+ if isinstance(node_type, NodeType):
+ return self._parsers.get(node_type, self._default_parser)
+ return self._default_parser
+
+ @override
+ def on_graph_start(self) -> None:
+ """Called when graph execution starts."""
+ self._node_contexts.clear()
+
+ @override
+ def on_node_run_start(self, node: Node) -> None:
+ """
+ Called when a node starts execution.
+
+ Creates a span and establishes OTel context for automatic instrumentation.
+ """
+ if self._is_disabled:
+ return
+
+ try:
+ if not self._tracer:
+ return
+
+ execution_id = node.execution_id
+ if not execution_id:
+ return
+
+ parent_context = context_api.get_current()
+ span = self._tracer.start_span(
+ f"{node.title}",
+ kind=SpanKind.INTERNAL,
+ context=parent_context,
+ )
+
+ new_context = set_span_in_context(span)
+ token = context_api.attach(new_context)
+
+ self._node_contexts[execution_id] = _NodeSpanContext(span=span, token=token)
+
+ except Exception as e:
+ logger.warning("Failed to create OpenTelemetry span for node %s: %s", node.id, e)
+
+ @override
+ def on_node_run_end(
+ self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
+ ) -> None:
+ """
+ Called when a node finishes execution.
+
+ Sets complete attributes, records exceptions, and ends the span.
+ """
+ if self._is_disabled:
+ return
+
+ try:
+ execution_id = node.execution_id
+ if not execution_id:
+ return
+ node_context = self._node_contexts.get(execution_id)
+ if not node_context:
+ return
+
+ span = node_context.span
+ parser = self._get_parser(node)
+ try:
+ parser.parse(node=node, span=span, error=error, result_event=result_event)
+ span.end()
+ finally:
+ token = node_context.token
+ if token is not None:
+ try:
+ context_api.detach(token)
+ except Exception:
+ logger.warning("Failed to detach OpenTelemetry token: %s", token)
+ self._node_contexts.pop(execution_id, None)
+
+ except Exception as e:
+ logger.warning("Failed to end OpenTelemetry span for node %s: %s", node.id, e)
+
+ @override
+ def on_event(self, event) -> None:
+ """Not used in this layer."""
+ pass
+
+ @override
+ def on_graph_end(self, error: Exception | None) -> None:
+ """Called when graph execution ends."""
+ if self._node_contexts:
+ logger.warning(
+ "ObservabilityLayer: %d node spans were not properly ended",
+ len(self._node_contexts),
+ )
+ self._node_contexts.clear()
diff --git a/api/core/workflow/graph_engine/layers/persistence.py b/api/core/app/workflow/layers/persistence.py
similarity index 99%
rename from api/core/workflow/graph_engine/layers/persistence.py
rename to api/core/app/workflow/layers/persistence.py
index b70f36ec9e..41052b4f52 100644
--- a/api/core/workflow/graph_engine/layers/persistence.py
+++ b/api/core/app/workflow/layers/persistence.py
@@ -45,7 +45,6 @@ from core.workflow.graph_events import (
from core.workflow.node_events import NodeRunResult
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
-from core.workflow.workflow_entry import WorkflowEntry
from libs.datetime_utils import naive_utc_now
@@ -316,6 +315,9 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
# workflow inputs stay reusable without binding future runs to this conversation.
continue
inputs[f"sys.{field_name}"] = value
+ # Local import to avoid circular dependency during app bootstrapping.
+ from core.workflow.workflow_entry import WorkflowEntry
+
handled = WorkflowEntry.handle_special_values(inputs)
return handled or {}
@@ -337,8 +339,6 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
if update_finished:
execution.finished_at = naive_utc_now()
runtime_state = self.graph_runtime_state
- if runtime_state is None:
- return
execution.total_tokens = runtime_state.total_tokens
execution.total_steps = runtime_state.node_run_steps
execution.outputs = execution.outputs or runtime_state.outputs
@@ -404,6 +404,4 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
def _system_variables(self) -> Mapping[str, Any]:
runtime_state = self.graph_runtime_state
- if runtime_state is None:
- return {}
return runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID)
diff --git a/api/core/app/workflow/node_factory.py b/api/core/app/workflow/node_factory.py
new file mode 100644
index 0000000000..a5773bbef8
--- /dev/null
+++ b/api/core/app/workflow/node_factory.py
@@ -0,0 +1,143 @@
+from collections.abc import Callable, Sequence
+from typing import TYPE_CHECKING, final
+
+from typing_extensions import override
+
+from configs import dify_config
+from core.file.file_manager import file_manager
+from core.helper.code_executor.code_executor import CodeExecutor
+from core.helper.code_executor.code_node_provider import CodeNodeProvider
+from core.helper.ssrf_proxy import ssrf_proxy
+from core.tools.tool_file_manager import ToolFileManager
+from core.workflow.entities.graph_config import NodeConfigDict
+from core.workflow.enums import NodeType
+from core.workflow.graph.graph import NodeFactory
+from core.workflow.nodes.base.node import Node
+from core.workflow.nodes.code.code_node import CodeNode
+from core.workflow.nodes.code.limits import CodeNodeLimits
+from core.workflow.nodes.http_request.node import HttpRequestNode
+from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
+from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
+from core.workflow.nodes.template_transform.template_renderer import (
+ CodeExecutorJinja2TemplateRenderer,
+ Jinja2TemplateRenderer,
+)
+from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
+
+if TYPE_CHECKING:
+ from core.workflow.entities import GraphInitParams
+ from core.workflow.runtime import GraphRuntimeState
+
+
+@final
+class DifyNodeFactory(NodeFactory):
+ """
+ Default implementation of NodeFactory that uses the traditional node mapping.
+
+ This factory creates nodes by looking up their types in NODE_TYPE_CLASSES_MAPPING
+ and instantiating the appropriate node class.
+ """
+
+ def __init__(
+ self,
+ graph_init_params: "GraphInitParams",
+ graph_runtime_state: "GraphRuntimeState",
+ *,
+ code_executor: type[CodeExecutor] | None = None,
+ code_providers: Sequence[type[CodeNodeProvider]] | None = None,
+ code_limits: CodeNodeLimits | None = None,
+ template_renderer: Jinja2TemplateRenderer | None = None,
+ http_request_http_client: HttpClientProtocol | None = None,
+ http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
+ http_request_file_manager: FileManagerProtocol | None = None,
+ ) -> None:
+ self.graph_init_params = graph_init_params
+ self.graph_runtime_state = graph_runtime_state
+ self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor
+ self._code_providers: tuple[type[CodeNodeProvider], ...] = (
+ tuple(code_providers) if code_providers else CodeNode.default_code_providers()
+ )
+ self._code_limits = code_limits or CodeNodeLimits(
+ max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
+ max_number=dify_config.CODE_MAX_NUMBER,
+ min_number=dify_config.CODE_MIN_NUMBER,
+ max_precision=dify_config.CODE_MAX_PRECISION,
+ max_depth=dify_config.CODE_MAX_DEPTH,
+ max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH,
+ max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
+ max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
+ )
+ self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
+ self._http_request_http_client = http_request_http_client or ssrf_proxy
+ self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
+ self._http_request_file_manager = http_request_file_manager or file_manager
+
+ @override
+ def create_node(self, node_config: NodeConfigDict) -> Node:
+ """
+ Create a Node instance from node configuration data using the traditional mapping.
+
+ :param node_config: node configuration dictionary containing type and other data
+ :return: initialized Node instance
+ :raises ValueError: if node type is unknown or configuration is invalid
+ """
+ # Get node_id from config
+ node_id = node_config["id"]
+
+ # Get node type from config
+ node_data = node_config["data"]
+ try:
+ node_type = NodeType(node_data["type"])
+ except ValueError:
+ raise ValueError(f"Unknown node type: {node_data['type']}")
+
+ # Get node class
+ node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
+ if not node_mapping:
+ raise ValueError(f"No class mapping found for node type: {node_type}")
+
+ latest_node_class = node_mapping.get(LATEST_VERSION)
+ node_version = str(node_data.get("version", "1"))
+ matched_node_class = node_mapping.get(node_version)
+ node_class = matched_node_class or latest_node_class
+ if not node_class:
+ raise ValueError(f"No latest version class found for node type: {node_type}")
+
+ # Create node instance
+ if node_type == NodeType.CODE:
+ return CodeNode(
+ id=node_id,
+ config=node_config,
+ graph_init_params=self.graph_init_params,
+ graph_runtime_state=self.graph_runtime_state,
+ code_executor=self._code_executor,
+ code_providers=self._code_providers,
+ code_limits=self._code_limits,
+ )
+
+ if node_type == NodeType.TEMPLATE_TRANSFORM:
+ return TemplateTransformNode(
+ id=node_id,
+ config=node_config,
+ graph_init_params=self.graph_init_params,
+ graph_runtime_state=self.graph_runtime_state,
+ template_renderer=self._template_renderer,
+ )
+
+ if node_type == NodeType.HTTP_REQUEST:
+ return HttpRequestNode(
+ id=node_id,
+ config=node_config,
+ graph_init_params=self.graph_init_params,
+ graph_runtime_state=self.graph_runtime_state,
+ http_client=self._http_request_http_client,
+ tool_file_manager_factory=self._http_request_tool_file_manager_factory,
+ file_manager=self._http_request_file_manager,
+ )
+
+ return node_class(
+ id=node_id,
+ config=node_config,
+ graph_init_params=self.graph_init_params,
+ graph_runtime_state=self.graph_runtime_state,
+ )
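
Apart from the new import path (`core.app.workflow` instead of `core.workflow.nodes`), constructing a graph with the factory is unchanged; the executor, renderer, and HTTP-client collaborators all default to the production implementations. A minimal sketch (the init params and runtime state are assumed to be built as in `_init_graph` above):

```python
from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.graph import Graph

node_factory = DifyNodeFactory(
    graph_init_params=graph_init_params,
    graph_runtime_state=graph_runtime_state,
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
```
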
diff --git a/api/core/datasource/__base/datasource_plugin.py b/api/core/datasource/__base/datasource_plugin.py
index 50c7249fe4..451e4fda0e 100644
--- a/api/core/datasource/__base/datasource_plugin.py
+++ b/api/core/datasource/__base/datasource_plugin.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from abc import ABC, abstractmethod
from configs import dify_config
@@ -30,7 +32,7 @@ class DatasourcePlugin(ABC):
"""
return DatasourceProviderType.LOCAL_FILE
- def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
+ def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> DatasourcePlugin:
return self.__class__(
entity=self.entity.model_copy(),
runtime=runtime,
diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py
index 260dcf04f5..dde7d59726 100644
--- a/api/core/datasource/entities/datasource_entities.py
+++ b/api/core/datasource/entities/datasource_entities.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import enum
from enum import StrEnum
from typing import Any
@@ -31,7 +33,7 @@ class DatasourceProviderType(enum.StrEnum):
ONLINE_DRIVE = "online_drive"
@classmethod
- def value_of(cls, value: str) -> "DatasourceProviderType":
+ def value_of(cls, value: str) -> DatasourceProviderType:
"""
Get value of given mode.
@@ -81,7 +83,7 @@ class DatasourceParameter(PluginParameter):
typ: DatasourceParameterType,
required: bool,
options: list[str] | None = None,
- ) -> "DatasourceParameter":
+ ) -> DatasourceParameter:
"""
get a simple datasource parameter
@@ -187,14 +189,14 @@ class DatasourceInvokeMeta(BaseModel):
tool_config: dict | None = None
@classmethod
- def empty(cls) -> "DatasourceInvokeMeta":
+ def empty(cls) -> DatasourceInvokeMeta:
"""
Get an empty instance of DatasourceInvokeMeta
"""
return cls(time_cost=0.0, error=None, tool_config={})
@classmethod
- def error_instance(cls, error: str) -> "DatasourceInvokeMeta":
+ def error_instance(cls, error: str) -> DatasourceInvokeMeta:
"""
Get an instance of DatasourceInvokeMeta with error
"""
diff --git a/api/core/datasource/online_document/online_document_plugin.py b/api/core/datasource/online_document/online_document_plugin.py
index 98ea15e3fc..ce23da1e09 100644
--- a/api/core/datasource/online_document/online_document_plugin.py
+++ b/api/core/datasource/online_document/online_document_plugin.py
@@ -1,4 +1,4 @@
-from collections.abc import Generator, Mapping
+from collections.abc import Generator
from typing import Any
from core.datasource.__base.datasource_plugin import DatasourcePlugin
@@ -34,7 +34,7 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
def get_online_document_pages(
self,
user_id: str,
- datasource_parameters: Mapping[str, Any],
+ datasource_parameters: dict[str, Any],
provider_type: str,
) -> Generator[OnlineDocumentPagesMessage, None, None]:
manager = PluginDatasourceManager()
diff --git a/web/__mocks__/mime.js b/api/core/db/__init__.py
similarity index 100%
rename from web/__mocks__/mime.js
rename to api/core/db/__init__.py
diff --git a/api/core/db/session_factory.py b/api/core/db/session_factory.py
new file mode 100644
index 0000000000..45d4bc4594
--- /dev/null
+++ b/api/core/db/session_factory.py
@@ -0,0 +1,38 @@
+from sqlalchemy import Engine
+from sqlalchemy.orm import Session, sessionmaker
+
+_session_maker: sessionmaker[Session] | None = None
+
+
+def configure_session_factory(engine: Engine, expire_on_commit: bool = False):
+ """Configure the global session factory"""
+ global _session_maker
+ _session_maker = sessionmaker(bind=engine, expire_on_commit=expire_on_commit)
+
+
+def get_session_maker() -> sessionmaker[Session]:
+ if _session_maker is None:
+ raise RuntimeError("Session factory not configured. Call configure_session_factory() first.")
+ return _session_maker
+
+
+def create_session() -> Session:
+ return get_session_maker()()
+
+
+# Class wrapper for convenience
+class SessionFactory:
+ @staticmethod
+ def configure(engine: Engine, expire_on_commit: bool = False):
+ configure_session_factory(engine, expire_on_commit)
+
+ @staticmethod
+ def get_session_maker() -> sessionmaker[Session]:
+ return get_session_maker()
+
+ @staticmethod
+ def create_session() -> Session:
+ return create_session()
+
+
+session_factory = SessionFactory()
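
The global factory has to be configured once, where the SQLAlchemy engine is created, before `session_factory.create_session()` can be used by `TriggerPostLayer`, `MessageCycleManager`, and friends. A hedged sketch (the exact bootstrap hook is an assumption):

```python
from core.db.session_factory import configure_session_factory, session_factory
from extensions.ext_database import db  # engine source assumed

# once, during application startup
configure_session_factory(db.engine, expire_on_commit=False)

# afterwards, anywhere a short-lived session is needed
with session_factory.create_session() as session:
    ...  # SQLAlchemy 2.x style queries; the session is closed by the context manager
```
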
diff --git a/api/core/entities/knowledge_entities.py b/api/core/entities/knowledge_entities.py
index bed3a35400..b1ba3c3e2a 100644
--- a/api/core/entities/knowledge_entities.py
+++ b/api/core/entities/knowledge_entities.py
@@ -1,8 +1,9 @@
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, field_validator
class PreviewDetail(BaseModel):
content: str
+ summary: str | None = None
child_chunks: list[str] | None = None
@@ -20,9 +21,17 @@ class IndexingEstimate(BaseModel):
class PipelineDataset(BaseModel):
id: str
name: str
- description: str | None = Field(default="", description="knowledge dataset description")
+ description: str = Field(default="", description="knowledge dataset description")
chunk_structure: str
+ @field_validator("description", mode="before")
+ @classmethod
+ def normalize_description(cls, value: str | None) -> str:
+ """Coerce None to empty string so description is always a string."""
+ if value is None:
+ return ""
+ return value
+
class PipelineDocument(BaseModel):
id: str
diff --git a/api/core/entities/mcp_provider.py b/api/core/entities/mcp_provider.py
index 7484cea04a..135d2a4945 100644
--- a/api/core/entities/mcp_provider.py
+++ b/api/core/entities/mcp_provider.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import json
from datetime import datetime
from enum import StrEnum
@@ -75,7 +77,7 @@ class MCPProviderEntity(BaseModel):
updated_at: datetime
@classmethod
- def from_db_model(cls, db_provider: "MCPToolProvider") -> "MCPProviderEntity":
+ def from_db_model(cls, db_provider: MCPToolProvider) -> MCPProviderEntity:
"""Create entity from database model with decryption"""
return cls(
@@ -213,12 +215,23 @@ class MCPProviderEntity(BaseModel):
return None
def retrieve_tokens(self) -> OAuthTokens | None:
- """OAuth tokens if available"""
+ """Retrieve OAuth tokens if authentication is complete.
+
+ Returns:
+ OAuthTokens if the provider has been authenticated, None otherwise.
+ """
if not self.credentials:
return None
credentials = self.decrypt_credentials()
+ access_token = credentials.get("access_token", "")
+ # Return None if access_token is empty to avoid generating invalid "Authorization: Bearer " header.
+ # Note: We don't check for whitespace-only strings here because:
+ # 1. OAuth servers don't return whitespace-only access tokens in practice
+ # 2. Even if they did, the server would return 401, triggering the OAuth flow correctly
+ if not access_token:
+ return None
return OAuthTokens(
- access_token=credentials.get("access_token", ""),
+ access_token=access_token,
token_type=credentials.get("token_type", DEFAULT_TOKEN_TYPE),
expires_in=int(credentials.get("expires_in", str(DEFAULT_EXPIRES_IN)) or DEFAULT_EXPIRES_IN),
refresh_token=credentials.get("refresh_token", ""),
diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py
index 12431976f0..a123fb0321 100644
--- a/api/core/entities/model_entities.py
+++ b/api/core/entities/model_entities.py
@@ -30,7 +30,6 @@ class SimpleModelProviderEntity(BaseModel):
label: I18nObject
icon_small: I18nObject | None = None
icon_small_dark: I18nObject | None = None
- icon_large: I18nObject | None = None
supported_model_types: list[ModelType]
def __init__(self, provider_entity: ProviderEntity):
@@ -44,7 +43,6 @@ class SimpleModelProviderEntity(BaseModel):
label=provider_entity.label,
icon_small=provider_entity.icon_small,
icon_small_dark=provider_entity.icon_small_dark,
- icon_large=provider_entity.icon_large,
supported_model_types=provider_entity.supported_model_types,
)
@@ -94,7 +92,6 @@ class DefaultModelProviderEntity(BaseModel):
provider: str
label: I18nObject
icon_small: I18nObject | None = None
- icon_large: I18nObject | None = None
supported_model_types: Sequence[ModelType] = []
diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py
index 8a8067332d..0078ec7e4f 100644
--- a/api/core/entities/provider_entities.py
+++ b/api/core/entities/provider_entities.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from enum import StrEnum, auto
from typing import Union
@@ -178,7 +180,7 @@ class BasicProviderConfig(BaseModel):
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR
@classmethod
- def value_of(cls, value: str) -> "ProviderConfig.Type":
+ def value_of(cls, value: str) -> ProviderConfig.Type:
"""
Get value of given mode.
diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py
index 120fb73cdb..9945d7c1ab 100644
--- a/api/core/file/file_manager.py
+++ b/api/core/file/file_manager.py
@@ -104,6 +104,8 @@ def download(f: File, /):
):
return _download_file_content(f.storage_key)
elif f.transfer_method == FileTransferMethod.REMOTE_URL:
+ if f.remote_url is None:
+ raise ValueError("Missing file remote_url")
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
response.raise_for_status()
return response.content
@@ -134,6 +136,8 @@ def _download_file_content(path: str, /):
def _get_encoded_string(f: File, /):
match f.transfer_method:
case FileTransferMethod.REMOTE_URL:
+ if f.remote_url is None:
+ raise ValueError("Missing file remote_url")
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
response.raise_for_status()
data = response.content
@@ -164,3 +168,18 @@ def _to_url(f: File, /):
return sign_tool_file(tool_file_id=f.related_id, extension=f.extension)
else:
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
+
+
+class FileManager:
+ """
+ Adapter exposing file manager helpers behind FileManagerProtocol.
+
+ This is intentionally a thin wrapper over the existing module-level functions so callers can inject it
+ where a protocol-typed file manager is expected.
+ """
+
+ def download(self, f: File, /) -> bytes:
+ return download(f)
+
+
+file_manager = FileManager()
diff --git a/api/core/file/helpers.py b/api/core/file/helpers.py
index 6d553d7dc6..2ac483673a 100644
--- a/api/core/file/helpers.py
+++ b/api/core/file/helpers.py
@@ -8,8 +8,9 @@ import urllib.parse
from configs import dify_config
-def get_signed_file_url(upload_file_id: str, as_attachment=False) -> str:
- url = f"{dify_config.FILES_URL}/files/{upload_file_id}/file-preview"
+def get_signed_file_url(upload_file_id: str, as_attachment=False, for_external: bool = True) -> str:
+ base_url = dify_config.FILES_URL if for_external else (dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL)
+ url = f"{base_url}/files/{upload_file_id}/file-preview"
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
diff --git a/api/core/file/models.py b/api/core/file/models.py
index d149205d77..6324523b22 100644
--- a/api/core/file/models.py
+++ b/api/core/file/models.py
@@ -112,17 +112,17 @@ class File(BaseModel):
return text
- def generate_url(self) -> str | None:
+ def generate_url(self, for_external: bool = True) -> str | None:
if self.transfer_method == FileTransferMethod.REMOTE_URL:
return self.remote_url
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
if self.related_id is None:
raise ValueError("Missing file related_id")
- return helpers.get_signed_file_url(upload_file_id=self.related_id)
+ return helpers.get_signed_file_url(upload_file_id=self.related_id, for_external=for_external)
elif self.transfer_method in [FileTransferMethod.TOOL_FILE, FileTransferMethod.DATASOURCE_FILE]:
assert self.related_id is not None
assert self.extension is not None
- return sign_tool_file(tool_file_id=self.related_id, extension=self.extension)
+ return sign_tool_file(tool_file_id=self.related_id, extension=self.extension, for_external=for_external)
return None
def to_plugin_parameter(self) -> dict[str, Any]:
@@ -133,7 +133,7 @@ class File(BaseModel):
"extension": self.extension,
"size": self.size,
"type": self.type,
- "url": self.generate_url(),
+ "url": self.generate_url(for_external=False),
}
@model_validator(mode="after")
diff --git a/api/core/helper/code_executor/code_node_provider.py b/api/core/helper/code_executor/code_node_provider.py
index e93e1e4414..f4cce0b332 100644
--- a/api/core/helper/code_executor/code_node_provider.py
+++ b/api/core/helper/code_executor/code_node_provider.py
@@ -47,15 +47,16 @@ class CodeNodeProvider(BaseModel, ABC):
@classmethod
def get_default_config(cls) -> DefaultConfig:
- return {
- "type": "code",
- "config": {
- "variables": [
- {"variable": "arg1", "value_selector": []},
- {"variable": "arg2", "value_selector": []},
- ],
- "code_language": cls.get_language(),
- "code": cls.get_default_code(),
- "outputs": {"result": {"type": "string", "children": None}},
- },
+ variables: list[VariableConfig] = [
+ {"variable": "arg1", "value_selector": []},
+ {"variable": "arg2", "value_selector": []},
+ ]
+ outputs: dict[str, OutputConfig] = {"result": {"type": "string", "children": None}}
+
+ config: CodeConfig = {
+ "variables": variables,
+ "code_language": cls.get_language(),
+ "code": cls.get_default_code(),
+ "outputs": outputs,
}
+ return {"type": "code", "config": config}
diff --git a/api/core/helper/code_executor/jinja2/jinja2_transformer.py b/api/core/helper/code_executor/jinja2/jinja2_transformer.py
index 969125d2f7..5e4807401e 100644
--- a/api/core/helper/code_executor/jinja2/jinja2_transformer.py
+++ b/api/core/helper/code_executor/jinja2/jinja2_transformer.py
@@ -1,9 +1,14 @@
+from collections.abc import Mapping
from textwrap import dedent
+from typing import Any
from core.helper.code_executor.template_transformer import TemplateTransformer
class Jinja2TemplateTransformer(TemplateTransformer):
+ # Use separate placeholder for base64-encoded template to avoid confusion
+ _template_b64_placeholder: str = "{{template_b64}}"
+
@classmethod
def transform_response(cls, response: str):
"""
@@ -13,18 +18,35 @@ class Jinja2TemplateTransformer(TemplateTransformer):
"""
return {"result": cls.extract_result_str_from_response(response)}
+ @classmethod
+ def assemble_runner_script(cls, code: str, inputs: Mapping[str, Any]) -> str:
+ """
+ Override base class to use base64 encoding for template code.
+ This prevents issues with special characters (quotes, newlines) in templates
+ breaking the generated Python script. Fixes #26818.
+ """
+ script = cls.get_runner_script()
+ # Encode template as base64 to safely embed any content including quotes
+ code_b64 = cls.serialize_code(code)
+ script = script.replace(cls._template_b64_placeholder, code_b64)
+ inputs_str = cls.serialize_inputs(inputs)
+ script = script.replace(cls._inputs_placeholder, inputs_str)
+ return script
+
@classmethod
def get_runner_script(cls) -> str:
runner_script = dedent(f"""
- # declare main function
- def main(**inputs):
- import jinja2
- template = jinja2.Template('''{cls._code_placeholder}''')
- return template.render(**inputs)
-
+ import jinja2
import json
from base64 import b64decode
+ # declare main function
+ def main(**inputs):
+ # Decode base64-encoded template to handle special characters safely
+ template_code = b64decode('{cls._template_b64_placeholder}').decode('utf-8')
+ template = jinja2.Template(template_code)
+ return template.render(**inputs)
+
# decode and prepare input dict
inputs_obj = json.loads(b64decode('{cls._inputs_placeholder}').decode('utf-8'))
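
The fix works because base64 turns any template text, including the triple quotes and newlines that previously broke the generated script, into a safe ASCII token that the runner decodes back verbatim. A standalone illustration of that round trip:

```python
from base64 import b64decode, b64encode

template = "Hello '''{{ name }}'''\nline with \"quotes\""
code_b64 = b64encode(template.encode("utf-8")).decode("utf-8")

# inside the generated runner script the template comes back byte-for-byte identical
assert b64decode(code_b64).decode("utf-8") == template
```
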
diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py
index 3965f8cb31..5cdea19a8d 100644
--- a/api/core/helper/code_executor/template_transformer.py
+++ b/api/core/helper/code_executor/template_transformer.py
@@ -13,6 +13,15 @@ class TemplateTransformer(ABC):
_inputs_placeholder: str = "{{inputs}}"
_result_tag: str = "<>"
+ @classmethod
+ def serialize_code(cls, code: str) -> str:
+ """
+ Serialize template code to base64 to safely embed in generated script.
+ This prevents issues with special characters like quotes breaking the script.
+ """
+ code_bytes = code.encode("utf-8")
+ return b64encode(code_bytes).decode("utf-8")
+
@classmethod
def transform_caller(cls, code: str, inputs: Mapping[str, Any]) -> tuple[str, str]:
"""
@@ -67,7 +76,7 @@ class TemplateTransformer(ABC):
Post-process the result to convert scientific notation strings back to numbers
"""
- def convert_scientific_notation(value):
+ def convert_scientific_notation(value: Any) -> Any:
if isinstance(value, str):
# Check if the string looks like scientific notation
if re.match(r"^-?\d+\.?\d*e[+-]\d+$", value, re.IGNORECASE):
@@ -81,7 +90,7 @@ class TemplateTransformer(ABC):
return [convert_scientific_notation(v) for v in value]
return value
- return convert_scientific_notation(result) # type: ignore[no-any-return]
+ return convert_scientific_notation(result)
@classmethod
@abstractmethod
diff --git a/api/core/helper/csv_sanitizer.py b/api/core/helper/csv_sanitizer.py
new file mode 100644
index 0000000000..0023de5a35
--- /dev/null
+++ b/api/core/helper/csv_sanitizer.py
@@ -0,0 +1,89 @@
+"""CSV sanitization utilities to prevent formula injection attacks."""
+
+from typing import Any
+
+
+class CSVSanitizer:
+ """
+ Sanitizer for CSV export to prevent formula injection attacks.
+
+ This class provides methods to sanitize data before CSV export by escaping
+ characters that could be interpreted as formulas by spreadsheet applications
+ (Excel, LibreOffice, Google Sheets).
+
+ Formula injection occurs when user-controlled data starting with special
+ characters (=, +, -, @, tab, carriage return) is exported to CSV and opened
+ in a spreadsheet application, potentially executing malicious commands.
+ """
+
+ # Characters that can start a formula in Excel/LibreOffice/Google Sheets
+ FORMULA_CHARS = frozenset({"=", "+", "-", "@", "\t", "\r"})
+
+ @classmethod
+ def sanitize_value(cls, value: Any) -> str:
+ """
+ Sanitize a value for safe CSV export.
+
+ Prefixes formula-initiating characters with a single quote to prevent
+ Excel/LibreOffice/Google Sheets from treating them as formulas.
+
+ Args:
+ value: The value to sanitize (will be converted to string)
+
+ Returns:
+ Sanitized string safe for CSV export
+
+ Examples:
+ >>> CSVSanitizer.sanitize_value("=1+1")
+ "'=1+1"
+ >>> CSVSanitizer.sanitize_value("Hello World")
+ "Hello World"
+ >>> CSVSanitizer.sanitize_value(None)
+ ""
+ """
+ if value is None:
+ return ""
+
+ # Convert to string
+ str_value = str(value)
+
+ # If empty, return as is
+ if not str_value:
+ return ""
+
+ # Check if first character is a formula initiator
+ if str_value[0] in cls.FORMULA_CHARS:
+ # Prefix with single quote to escape
+ return f"'{str_value}"
+
+ return str_value
+
+ @classmethod
+ def sanitize_dict(cls, data: dict[str, Any], fields_to_sanitize: list[str] | None = None) -> dict[str, Any]:
+ """
+ Sanitize specified fields in a dictionary.
+
+ Args:
+ data: Dictionary containing data to sanitize
+ fields_to_sanitize: List of field names to sanitize.
+ If None, sanitizes all string fields.
+
+ Returns:
+ Dictionary with sanitized values (creates a shallow copy)
+
+ Examples:
+ >>> data = {"question": "=1+1", "answer": "+calc", "id": "123"}
+ >>> CSVSanitizer.sanitize_dict(data, ["question", "answer"])
+ {"question": "'=1+1", "answer": "'+calc", "id": "123"}
+ """
+ sanitized = data.copy()
+
+ if fields_to_sanitize is None:
+ # Sanitize all string fields
+ fields_to_sanitize = [k for k, v in data.items() if isinstance(v, str)]
+
+ for field in fields_to_sanitize:
+ if field in sanitized:
+ sanitized[field] = cls.sanitize_value(sanitized[field])
+
+ return sanitized
diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py
index 0de026f3c7..54068fc28d 100644
--- a/api/core/helper/ssrf_proxy.py
+++ b/api/core/helper/ssrf_proxy.py
@@ -4,11 +4,14 @@ Proxy requests to avoid SSRF
import logging
import time
+from typing import Any, TypeAlias
import httpx
+from pydantic import TypeAdapter, ValidationError
from configs import dify_config
from core.helper.http_client_pooling import get_pooled_http_client
+from core.tools.errors import ToolSSRFError
logger = logging.getLogger(__name__)
@@ -17,6 +20,9 @@ SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]
+Headers: TypeAlias = dict[str, str]
+_HEADERS_ADAPTER = TypeAdapter(Headers)
+
_SSL_VERIFIED_POOL_KEY = "ssrf:verified"
_SSL_UNVERIFIED_POOL_KEY = "ssrf:unverified"
_SSRF_CLIENT_LIMITS = httpx.Limits(
@@ -32,6 +38,10 @@ class MaxRetriesExceededError(ValueError):
pass
+request_error = httpx.RequestError
+max_retries_exceeded_error = MaxRetriesExceededError
+
+
def _create_proxy_mounts() -> dict[str, httpx.HTTPTransport]:
return {
"http://": httpx.HTTPTransport(
@@ -71,7 +81,57 @@ def _get_ssrf_client(ssl_verify_enabled: bool) -> httpx.Client:
)
-def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
+def _get_user_provided_host_header(headers: Headers | None) -> str | None:
+ """
+ Extract the user-provided Host header from the headers dict.
+
+ This is needed because when using a forward proxy, httpx may override the Host header.
+ We preserve the user's explicit Host header to support virtual hosting and other use cases.
+ """
+ if not headers:
+ return None
+ # Case-insensitive lookup for Host header
+ for key, value in headers.items():
+ if key.lower() == "host":
+ return value
+ return None
+
+
+def _inject_trace_headers(headers: Headers | None) -> Headers:
+ """
+ Inject W3C traceparent header for distributed tracing.
+
+ When OTEL is enabled, HTTPXClientInstrumentor handles trace propagation automatically.
+ When OTEL is disabled, we manually inject the traceparent header.
+ """
+ if headers is None:
+ headers = {}
+
+ # Skip if already present (case-insensitive check)
+ for key in headers:
+ if key.lower() == "traceparent":
+ return headers
+
+ # Skip if OTEL is enabled - HTTPXClientInstrumentor handles this automatically
+ if dify_config.ENABLE_OTEL:
+ return headers
+
+ # Generate and inject traceparent for non-OTEL scenarios
+ try:
+ from core.helper.trace_id_helper import generate_traceparent_header
+
+ traceparent = generate_traceparent_header()
+ if traceparent:
+ headers["traceparent"] = traceparent
+ except Exception:
+ # Silently ignore errors to avoid breaking requests
+ logger.debug("Failed to generate traceparent header", exc_info=True)
+
+ return headers
+
+
+def make_request(method: str, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
+ # Convert requests-style allow_redirects to httpx-style follow_redirects
if "allow_redirects" in kwargs:
allow_redirects = kwargs.pop("allow_redirects")
if "follow_redirects" not in kwargs:
@@ -87,13 +147,47 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
# prioritize per-call option, which can be switched on and off inside the HTTP node on the web UI
verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
+ if not isinstance(verify_option, bool):
+ raise ValueError("ssl_verify must be a boolean")
client = _get_ssrf_client(verify_option)
+ # Inject traceparent header for distributed tracing (when OTEL is not enabled)
+ try:
+ headers: Headers = _HEADERS_ADAPTER.validate_python(kwargs.get("headers") or {})
+ except ValidationError as e:
+ raise ValueError("headers must be a mapping of string keys to string values") from e
+ headers = _inject_trace_headers(headers)
+ kwargs["headers"] = headers
+
+ # Preserve user-provided Host header
+ # When using a forward proxy, httpx may override the Host header based on the URL.
+ # We extract and preserve any explicitly set Host header to support virtual hosting.
+ user_provided_host = _get_user_provided_host_header(headers)
+
retries = 0
while retries <= max_retries:
try:
+ # Preserve the user-provided Host header
+ # httpx may override the Host header when using a proxy
+ headers = {k: v for k, v in headers.items() if k.lower() != "host"}
+ if user_provided_host is not None:
+ headers["host"] = user_provided_host
+ kwargs["headers"] = headers
response = client.request(method=method, url=url, **kwargs)
+ # Check for SSRF protection by Squid proxy
+ if response.status_code in (401, 403):
+ # Check if this is a Squid SSRF rejection
+ server_header = response.headers.get("server", "").lower()
+ via_header = response.headers.get("via", "").lower()
+
+ # Squid typically identifies itself in Server or Via headers
+ if "squid" in server_header or "squid" in via_header:
+ raise ToolSSRFError(
+ f"Access to '{url}' was blocked by SSRF protection. "
+ f"The URL may point to a private or local network address. "
+ )
+
if response.status_code not in STATUS_FORCELIST:
return response
else:
@@ -114,25 +208,63 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}")
-def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
+def get(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
return make_request("GET", url, max_retries=max_retries, **kwargs)
-def post(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
+def post(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
return make_request("POST", url, max_retries=max_retries, **kwargs)
-def put(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
+def put(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
return make_request("PUT", url, max_retries=max_retries, **kwargs)
-def patch(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
+def patch(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
return make_request("PATCH", url, max_retries=max_retries, **kwargs)
-def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
+def delete(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
return make_request("DELETE", url, max_retries=max_retries, **kwargs)
-def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
+def head(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
return make_request("HEAD", url, max_retries=max_retries, **kwargs)
+
+
+class SSRFProxy:
+ """
+ Adapter exposing SSRF-protected HTTP helpers behind HttpClientProtocol.
+
+ This is intentionally a thin wrapper over the existing module-level functions so callers can inject it
+ where a protocol-typed HTTP client is expected.
+ """
+
+ @property
+ def max_retries_exceeded_error(self) -> type[Exception]:
+ return max_retries_exceeded_error
+
+ @property
+ def request_error(self) -> type[Exception]:
+ return request_error
+
+ def get(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
+ return get(url=url, max_retries=max_retries, **kwargs)
+
+ def head(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
+ return head(url=url, max_retries=max_retries, **kwargs)
+
+ def post(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
+ return post(url=url, max_retries=max_retries, **kwargs)
+
+ def put(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
+ return put(url=url, max_retries=max_retries, **kwargs)
+
+ def delete(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
+ return delete(url=url, max_retries=max_retries, **kwargs)
+
+ def patch(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
+ return patch(url=url, max_retries=max_retries, **kwargs)
+
+
+ssrf_proxy = SSRFProxy()
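
`SSRFProxy` mirrors the module-level helpers so nodes can depend on `HttpClientProtocol` rather than importing this module directly; `DifyNodeFactory` injects it into `HttpRequestNode` by default. A hedged sketch of overriding it in tests (the fake client class is an assumption):

```python
from core.app.workflow.node_factory import DifyNodeFactory
from core.helper.ssrf_proxy import ssrf_proxy

# default wiring (equivalent to omitting the argument)
factory = DifyNodeFactory(
    graph_init_params=graph_init_params,
    graph_runtime_state=graph_runtime_state,
    http_request_http_client=ssrf_proxy,
)

# tests can substitute any object that satisfies HttpClientProtocol
factory = DifyNodeFactory(
    graph_init_params=graph_init_params,
    graph_runtime_state=graph_runtime_state,
    http_request_http_client=FakeHttpClient(),  # hypothetical test double
)
```
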
diff --git a/api/core/helper/tool_provider_cache.py b/api/core/helper/tool_provider_cache.py
deleted file mode 100644
index eef5937407..0000000000
--- a/api/core/helper/tool_provider_cache.py
+++ /dev/null
@@ -1,56 +0,0 @@
-import json
-import logging
-from typing import Any
-
-from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
-from extensions.ext_redis import redis_client, redis_fallback
-
-logger = logging.getLogger(__name__)
-
-
-class ToolProviderListCache:
- """Cache for tool provider lists"""
-
- CACHE_TTL = 300 # 5 minutes
-
- @staticmethod
- def _generate_cache_key(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> str:
- """Generate cache key for tool providers list"""
- type_filter = typ or "all"
- return f"tool_providers:tenant_id:{tenant_id}:type:{type_filter}"
-
- @staticmethod
- @redis_fallback(default_return=None)
- def get_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> list[dict[str, Any]] | None:
- """Get cached tool providers"""
- cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
- cached_data = redis_client.get(cache_key)
- if cached_data:
- try:
- return json.loads(cached_data.decode("utf-8"))
- except (json.JSONDecodeError, UnicodeDecodeError):
- logger.warning("Failed to decode cached tool providers data")
- return None
- return None
-
- @staticmethod
- @redis_fallback()
- def set_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral, providers: list[dict[str, Any]]):
- """Cache tool providers"""
- cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
- redis_client.setex(cache_key, ToolProviderListCache.CACHE_TTL, json.dumps(providers))
-
- @staticmethod
- @redis_fallback()
- def invalidate_cache(tenant_id: str, typ: ToolProviderTypeApiLiteral = None):
- """Invalidate cache for tool providers"""
- if typ:
- # Invalidate specific type cache
- cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
- redis_client.delete(cache_key)
- else:
- # Invalidate all caches for this tenant
- pattern = f"tool_providers:tenant_id:{tenant_id}:*"
- keys = list(redis_client.scan_iter(pattern))
- if keys:
- redis_client.delete(*keys)
diff --git a/api/core/helper/trace_id_helper.py b/api/core/helper/trace_id_helper.py
index 820502e558..e827859109 100644
--- a/api/core/helper/trace_id_helper.py
+++ b/api/core/helper/trace_id_helper.py
@@ -103,3 +103,60 @@ def parse_traceparent_header(traceparent: str) -> str | None:
if len(parts) == 4 and len(parts[1]) == 32:
return parts[1]
return None
+
+
+def get_span_id_from_otel_context() -> str | None:
+ """
+ Retrieve the current span ID from the active OpenTelemetry trace context.
+
+ Returns:
+ A 16-character hex string representing the span ID, or None if not available.
+ """
+ try:
+ from opentelemetry.trace import get_current_span
+ from opentelemetry.trace.span import INVALID_SPAN_ID
+
+ span = get_current_span()
+ if not span:
+ return None
+
+ span_context = span.get_span_context()
+ if not span_context or span_context.span_id == INVALID_SPAN_ID:
+ return None
+
+ return f"{span_context.span_id:016x}"
+ except Exception:
+ return None
+
+
+def generate_traceparent_header() -> str | None:
+ """
+ Generate a W3C traceparent header from the current context.
+
+ Uses OpenTelemetry context if available, otherwise uses the
+ ContextVar-based trace_id from the logging context.
+
+ Format: {version}-{trace_id}-{span_id}-{flags}
+ Example: 00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01
+
+ Returns:
+ A valid traceparent header string, or None if generation fails.
+ """
+ import uuid
+
+ # Try OTEL context first
+ trace_id = get_trace_id_from_otel_context()
+ span_id = get_span_id_from_otel_context()
+
+ if trace_id and span_id:
+ return f"00-{trace_id}-{span_id}-01"
+
+ # Fallback: use ContextVar-based trace_id or generate new one
+ from core.logging.context import get_trace_id as get_logging_trace_id
+
+ trace_id = get_logging_trace_id() or uuid.uuid4().hex
+
+ # Generate a new span_id (16 hex chars)
+ span_id = uuid.uuid4().hex[:16]
+
+ return f"00-{trace_id}-{span_id}-01"
diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py
index af860a1070..370e64e385 100644
--- a/api/core/hosting_configuration.py
+++ b/api/core/hosting_configuration.py
@@ -56,6 +56,10 @@ class HostingConfiguration:
self.provider_map[f"{DEFAULT_PLUGIN_ID}/minimax/minimax"] = self.init_minimax()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/spark/spark"] = self.init_spark()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/zhipuai/zhipuai"] = self.init_zhipuai()
+ self.provider_map[f"{DEFAULT_PLUGIN_ID}/gemini/google"] = self.init_gemini()
+ self.provider_map[f"{DEFAULT_PLUGIN_ID}/x/x"] = self.init_xai()
+ self.provider_map[f"{DEFAULT_PLUGIN_ID}/deepseek/deepseek"] = self.init_deepseek()
+ self.provider_map[f"{DEFAULT_PLUGIN_ID}/tongyi/tongyi"] = self.init_tongyi()
self.moderation_config = self.init_moderation_config()
@@ -128,7 +132,7 @@ class HostingConfiguration:
quotas: list[HostingQuota] = []
if dify_config.HOSTED_OPENAI_TRIAL_ENABLED:
- hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT
+ hosted_quota_limit = 0
trial_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
quotas.append(trial_quota)
@@ -156,18 +160,49 @@ class HostingConfiguration:
quota_unit=quota_unit,
)
- @staticmethod
- def init_anthropic() -> HostingProvider:
- quota_unit = QuotaUnit.TOKENS
+ def init_gemini(self) -> HostingProvider:
+ quota_unit = QuotaUnit.CREDITS
+ quotas: list[HostingQuota] = []
+
+ if dify_config.HOSTED_GEMINI_TRIAL_ENABLED:
+ hosted_quota_limit = 0
+ trial_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_TRIAL_MODELS")
+ trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
+ quotas.append(trial_quota)
+
+ if dify_config.HOSTED_GEMINI_PAID_ENABLED:
+ paid_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_PAID_MODELS")
+ paid_quota = PaidHostingQuota(restrict_models=paid_models)
+ quotas.append(paid_quota)
+
+ if len(quotas) > 0:
+ credentials = {
+ "google_api_key": dify_config.HOSTED_GEMINI_API_KEY,
+ }
+
+ if dify_config.HOSTED_GEMINI_API_BASE:
+ credentials["google_base_url"] = dify_config.HOSTED_GEMINI_API_BASE
+
+ return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
+
+ return HostingProvider(
+ enabled=False,
+ quota_unit=quota_unit,
+ )
+
+ def init_anthropic(self) -> HostingProvider:
+ quota_unit = QuotaUnit.CREDITS
quotas: list[HostingQuota] = []
if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED:
- hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT
- trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit)
+ hosted_quota_limit = 0
+ trial_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_TRIAL_MODELS")
+ trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
quotas.append(trial_quota)
if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED:
- paid_quota = PaidHostingQuota()
+ paid_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_PAID_MODELS")
+ paid_quota = PaidHostingQuota(restrict_models=paid_models)
quotas.append(paid_quota)
if len(quotas) > 0:
@@ -185,6 +220,94 @@ class HostingConfiguration:
quota_unit=quota_unit,
)
+ def init_tongyi(self) -> HostingProvider:
+ quota_unit = QuotaUnit.CREDITS
+ quotas: list[HostingQuota] = []
+
+ if dify_config.HOSTED_TONGYI_TRIAL_ENABLED:
+ hosted_quota_limit = 0
+ trial_models = self.parse_restrict_models_from_env("HOSTED_TONGYI_TRIAL_MODELS")
+ trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
+ quotas.append(trial_quota)
+
+ if dify_config.HOSTED_TONGYI_PAID_ENABLED:
+ paid_models = self.parse_restrict_models_from_env("HOSTED_TONGYI_PAID_MODELS")
+ paid_quota = PaidHostingQuota(restrict_models=paid_models)
+ quotas.append(paid_quota)
+
+ if len(quotas) > 0:
+ credentials = {
+ "dashscope_api_key": dify_config.HOSTED_TONGYI_API_KEY,
+ "use_international_endpoint": dify_config.HOSTED_TONGYI_USE_INTERNATIONAL_ENDPOINT,
+ }
+
+ return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
+
+ return HostingProvider(
+ enabled=False,
+ quota_unit=quota_unit,
+ )
+
+ def init_xai(self) -> HostingProvider:
+ quota_unit = QuotaUnit.CREDITS
+ quotas: list[HostingQuota] = []
+
+ if dify_config.HOSTED_XAI_TRIAL_ENABLED:
+ hosted_quota_limit = 0
+ trial_models = self.parse_restrict_models_from_env("HOSTED_XAI_TRIAL_MODELS")
+ trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
+ quotas.append(trial_quota)
+
+ if dify_config.HOSTED_XAI_PAID_ENABLED:
+ paid_models = self.parse_restrict_models_from_env("HOSTED_XAI_PAID_MODELS")
+ paid_quota = PaidHostingQuota(restrict_models=paid_models)
+ quotas.append(paid_quota)
+
+ if len(quotas) > 0:
+ credentials = {
+ "api_key": dify_config.HOSTED_XAI_API_KEY,
+ }
+
+ if dify_config.HOSTED_XAI_API_BASE:
+ credentials["endpoint_url"] = dify_config.HOSTED_XAI_API_BASE
+
+ return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
+
+ return HostingProvider(
+ enabled=False,
+ quota_unit=quota_unit,
+ )
+
+ def init_deepseek(self) -> HostingProvider:
+ quota_unit = QuotaUnit.CREDITS
+ quotas: list[HostingQuota] = []
+
+ if dify_config.HOSTED_DEEPSEEK_TRIAL_ENABLED:
+ hosted_quota_limit = 0
+ trial_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_TRIAL_MODELS")
+ trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
+ quotas.append(trial_quota)
+
+ if dify_config.HOSTED_DEEPSEEK_PAID_ENABLED:
+ paid_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_PAID_MODELS")
+ paid_quota = PaidHostingQuota(restrict_models=paid_models)
+ quotas.append(paid_quota)
+
+ if len(quotas) > 0:
+ credentials = {
+ "api_key": dify_config.HOSTED_DEEPSEEK_API_KEY,
+ }
+
+ if dify_config.HOSTED_DEEPSEEK_API_BASE:
+ credentials["endpoint_url"] = dify_config.HOSTED_DEEPSEEK_API_BASE
+
+ return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
+
+ return HostingProvider(
+ enabled=False,
+ quota_unit=quota_unit,
+ )
+
@staticmethod
def init_minimax() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
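Each of the new `init_*` methods above follows the same assembly pattern: build a quota list from the trial/paid flags and only enable the provider when at least one quota exists. A simplified, self-contained sketch of that pattern (plain dicts stand in for the real `HostingProvider`/`TrialHostingQuota` entities; flag values are hypothetical):

```python
# Simplified sketch of the shared init_* pattern; not the real entities.
def build_hosting_provider(trial_enabled: bool, paid_enabled: bool, api_key: str) -> dict:
    quotas: list[dict] = []
    if trial_enabled:
        quotas.append({"type": "trial", "quota_limit": 0})
    if paid_enabled:
        quotas.append({"type": "paid"})
    if quotas:
        return {"enabled": True, "credentials": {"api_key": api_key}, "quotas": quotas}
    return {"enabled": False, "quotas": quotas}


# e.g. a trial-only provider:
print(build_hosting_provider(True, False, "sk-test"))
```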
diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index 59de4f403d..e172e88298 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -311,14 +311,18 @@ class IndexingRunner:
qa_preview_texts: list[QAPreviewDetail] = []
total_segments = 0
+ # doc_form represents the segmentation method (general, parent-child, QA)
index_type = doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
+ # each extract_setting corresponds to one source document
for extract_setting in extract_settings:
# extract
processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"])
)
+ # Extract document content
text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
+ # Cleaning and segmentation
documents = index_processor.transform(
text_docs,
current_user=None,
@@ -361,6 +365,12 @@ class IndexingRunner:
if doc_form and doc_form == "qa_model":
return IndexingEstimate(total_segments=total_segments * 20, qa_preview=qa_preview_texts, preview=[])
+
+ # Generate summary preview
+ summary_index_setting = tmp_processing_rule.get("summary_index_setting")
+ if summary_index_setting and summary_index_setting.get("enable") and preview_texts:
+ preview_texts = index_processor.generate_summary_preview(tenant_id, preview_texts, summary_index_setting)
+
return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
def _extract(
@@ -396,7 +406,7 @@ class IndexingRunner:
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
- "credential_id": data_source_info["credential_id"],
+ "credential_id": data_source_info.get("credential_id"),
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
diff --git a/api/core/llm_generator/entities.py b/api/core/llm_generator/entities.py
new file mode 100644
index 0000000000..3bb8d2c899
--- /dev/null
+++ b/api/core/llm_generator/entities.py
@@ -0,0 +1,20 @@
+"""Shared payload models for LLM generator helpers and controllers."""
+
+from pydantic import BaseModel, Field
+
+from core.app.app_config.entities import ModelConfig
+
+
+class RuleGeneratePayload(BaseModel):
+ instruction: str = Field(..., description="Rule generation instruction")
+ model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
+ no_variable: bool = Field(default=False, description="Whether to exclude variables")
+
+
+class RuleCodeGeneratePayload(RuleGeneratePayload):
+ code_language: str = Field(default="javascript", description="Programming language for code generation")
+
+
+class RuleStructuredOutputPayload(BaseModel):
+ instruction: str = Field(..., description="Structured output generation instruction")
+ model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
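A hedged sketch of how these payloads might be validated from a request body: the `model_config` alias lets the JSON keep that key while the Python attribute is `model_config_data` (avoiding pydantic's reserved `model_config` name). The `ModelConfig` fields used below are assumed from their usages later in this patch (`provider`, `name`, `completion_params`); the real entity may require more fields.

```python
# Illustrative only; the ModelConfig shape is an assumption, not confirmed.
from core.llm_generator.entities import RuleGeneratePayload

payload = RuleGeneratePayload.model_validate({
    "instruction": "Summarize support tickets",
    "model_config": {
        "provider": "openai",
        "name": "gpt-4o-mini",
        "completion_params": {"temperature": 0.7},
    },
})
assert payload.model_config_data.name == "gpt-4o-mini"
assert payload.no_variable is False
```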
diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py
index 4a577e6c38..5b2c640265 100644
--- a/api/core/llm_generator/llm_generator.py
+++ b/api/core/llm_generator/llm_generator.py
@@ -6,6 +6,8 @@ from typing import Protocol, cast
import json_repair
+from core.app.app_config.entities import ModelConfig
+from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.llm_generator.prompts import (
@@ -71,16 +73,23 @@ class LLMGenerator:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
)
- answer = cast(str, response.message.content)
- cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL)
- if cleaned_answer is None:
+ answer = response.message.get_text_content()
+ if answer == "":
return ""
try:
- result_dict = json.loads(cleaned_answer)
- answer = result_dict["Your Output"]
+ result_dict = json.loads(answer)
except json.JSONDecodeError:
- logger.exception("Failed to generate name after answer, use query instead")
+ result_dict = json_repair.loads(answer)
+
+ if not isinstance(result_dict, dict):
answer = query
+ else:
+ output = result_dict.get("Your Output")
+ if isinstance(output, str) and output.strip():
+ answer = output.strip()
+ else:
+ answer = query
+
name = answer.strip()
if len(name) > 75:
@@ -144,19 +153,19 @@ class LLMGenerator:
return questions
@classmethod
- def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool):
+ def generate_rule_config(cls, tenant_id: str, args: RuleGeneratePayload):
output_parser = RuleConfigGeneratorOutputParser()
error = ""
error_step = ""
rule_config = {"prompt": "", "variables": [], "opening_statement": "", "error": ""}
- model_parameters = model_config.get("completion_params", {})
- if no_variable:
+ model_parameters = args.model_config_data.completion_params
+ if args.no_variable:
prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE)
prompt_generate = prompt_template.format(
inputs={
- "TASK_DESCRIPTION": instruction,
+ "TASK_DESCRIPTION": args.instruction,
},
remove_template_variables=False,
)
@@ -168,8 +177,8 @@ class LLMGenerator:
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
- provider=model_config.get("provider", ""),
- model=model_config.get("name", ""),
+ provider=args.model_config_data.provider,
+ model=args.model_config_data.name,
)
try:
@@ -177,13 +186,13 @@ class LLMGenerator:
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
- rule_config["prompt"] = cast(str, response.message.content)
+ rule_config["prompt"] = response.message.get_text_content()
except InvokeError as e:
error = str(e)
error_step = "generate rule config"
except Exception as e:
- logger.exception("Failed to generate rule config, model: %s", model_config.get("name"))
+ logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
rule_config["error"] = str(e)
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
@@ -202,7 +211,7 @@ class LLMGenerator:
# format the prompt_generate_prompt
prompt_generate_prompt = prompt_template.format(
inputs={
- "TASK_DESCRIPTION": instruction,
+ "TASK_DESCRIPTION": args.instruction,
},
remove_template_variables=False,
)
@@ -213,8 +222,8 @@ class LLMGenerator:
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
- provider=model_config.get("provider", ""),
- model=model_config.get("name", ""),
+ provider=args.model_config_data.provider,
+ model=args.model_config_data.name,
)
try:
@@ -230,13 +239,11 @@ class LLMGenerator:
return rule_config
- rule_config["prompt"] = cast(str, prompt_content.message.content)
+ rule_config["prompt"] = prompt_content.message.get_text_content()
- if not isinstance(prompt_content.message.content, str):
- raise NotImplementedError("prompt content is not a string")
parameter_generate_prompt = parameter_template.format(
inputs={
- "INPUT_TEXT": prompt_content.message.content,
+ "INPUT_TEXT": prompt_content.message.get_text_content(),
},
remove_template_variables=False,
)
@@ -245,8 +252,8 @@ class LLMGenerator:
# the second step to generate the task_parameter and task_statement
statement_generate_prompt = statement_template.format(
inputs={
- "TASK_DESCRIPTION": instruction,
- "INPUT_TEXT": prompt_content.message.content,
+ "TASK_DESCRIPTION": args.instruction,
+ "INPUT_TEXT": prompt_content.message.get_text_content(),
},
remove_template_variables=False,
)
@@ -256,7 +263,7 @@ class LLMGenerator:
parameter_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
)
- rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content))
+ rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', parameter_content.message.get_text_content())
except InvokeError as e:
error = str(e)
error_step = "generate variables"
@@ -265,13 +272,13 @@ class LLMGenerator:
statement_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
)
- rule_config["opening_statement"] = cast(str, statement_content.message.content)
+ rule_config["opening_statement"] = statement_content.message.get_text_content()
except InvokeError as e:
error = str(e)
error_step = "generate conversation opener"
except Exception as e:
- logger.exception("Failed to generate rule config, model: %s", model_config.get("name"))
+ logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
rule_config["error"] = str(e)
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
@@ -279,16 +286,20 @@ class LLMGenerator:
return rule_config
@classmethod
- def generate_code(cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript"):
- if code_language == "python":
+ def generate_code(
+ cls,
+ tenant_id: str,
+ args: RuleCodeGeneratePayload,
+ ):
+ if args.code_language == "python":
prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
else:
prompt_template = PromptTemplateParser(JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE)
prompt = prompt_template.format(
inputs={
- "INSTRUCTION": instruction,
- "CODE_LANGUAGE": code_language,
+ "INSTRUCTION": args.instruction,
+ "CODE_LANGUAGE": args.code_language,
},
remove_template_variables=False,
)
@@ -297,28 +308,28 @@ class LLMGenerator:
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
- provider=model_config.get("provider", ""),
- model=model_config.get("name", ""),
+ provider=args.model_config_data.provider,
+ model=args.model_config_data.name,
)
prompt_messages = [UserPromptMessage(content=prompt)]
- model_parameters = model_config.get("completion_params", {})
+ model_parameters = args.model_config_data.completion_params
try:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
- generated_code = cast(str, response.message.content)
- return {"code": generated_code, "language": code_language, "error": ""}
+ generated_code = response.message.get_text_content()
+ return {"code": generated_code, "language": args.code_language, "error": ""}
except InvokeError as e:
error = str(e)
- return {"code": "", "language": code_language, "error": f"Failed to generate code. Error: {error}"}
+ return {"code": "", "language": args.code_language, "error": f"Failed to generate code. Error: {error}"}
except Exception as e:
logger.exception(
- "Failed to invoke LLM model, model: %s, language: %s", model_config.get("name"), code_language
+ "Failed to invoke LLM model, model: %s, language: %s", args.model_config_data.name, args.code_language
)
- return {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"}
+ return {"code": "", "language": args.code_language, "error": f"An unexpected error occurred: {str(e)}"}
@classmethod
def generate_qa_document(cls, tenant_id: str, query, document_language: str):
@@ -344,34 +355,31 @@ class LLMGenerator:
raise TypeError("Expected LLMResult when stream=False")
response = result
- answer = cast(str, response.message.content)
+ answer = response.message.get_text_content()
return answer.strip()
@classmethod
- def generate_structured_output(cls, tenant_id: str, instruction: str, model_config: dict):
+ def generate_structured_output(cls, tenant_id: str, args: RuleStructuredOutputPayload):
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
- provider=model_config.get("provider", ""),
- model=model_config.get("name", ""),
+ provider=args.model_config_data.provider,
+ model=args.model_config_data.name,
)
prompt_messages = [
SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE),
- UserPromptMessage(content=instruction),
+ UserPromptMessage(content=args.instruction),
]
- model_parameters = model_config.get("model_parameters", {})
+ model_parameters = args.model_config_data.completion_params
try:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
- raw_content = response.message.content
-
- if not isinstance(raw_content, str):
- raise ValueError(f"LLM response content must be a string, got: {type(raw_content)}")
+ raw_content = response.message.get_text_content()
try:
parsed_content = json.loads(raw_content)
@@ -388,12 +396,17 @@ class LLMGenerator:
error = str(e)
return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"}
except Exception as e:
- logger.exception("Failed to invoke LLM model, model: %s", model_config.get("name"))
+ logger.exception("Failed to invoke LLM model, model: %s", args.model_config_data.name)
return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}
@staticmethod
def instruction_modify_legacy(
- tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None
+ tenant_id: str,
+ flow_id: str,
+ current: str,
+ instruction: str,
+ model_config: ModelConfig,
+ ideal_output: str | None,
):
last_run: Message | None = (
db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first()
@@ -432,7 +445,7 @@ class LLMGenerator:
node_id: str,
current: str,
instruction: str,
- model_config: dict,
+ model_config: ModelConfig,
ideal_output: str | None,
workflow_service: WorkflowServiceInterface,
):
@@ -503,7 +516,7 @@ class LLMGenerator:
@staticmethod
def __instruction_modify_common(
tenant_id: str,
- model_config: dict,
+ model_config: ModelConfig,
last_run: dict | None,
current: str | None,
error_message: str | None,
@@ -524,8 +537,8 @@ class LLMGenerator:
model_instance = ModelManager().get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
- provider=model_config.get("provider", ""),
- model=model_config.get("name", ""),
+ provider=model_config.provider,
+ model=model_config.name,
)
match node_type:
case "llm" | "agent":
@@ -568,7 +581,5 @@ class LLMGenerator:
error = str(e)
return {"error": f"Failed to generate code. Error: {error}"}
except Exception as e:
- logger.exception(
- "Failed to invoke LLM model, model: %s", json.dumps(model_config.get("name")), exc_info=True
- )
+ logger.exception("Failed to invoke LLM model, model: %s", json.dumps(model_config.name), exc_info=True)
return {"error": f"An unexpected error occurred: {str(e)}"}
diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py
index ec2b7f2d44..d46cf049dd 100644
--- a/api/core/llm_generator/prompts.py
+++ b/api/core/llm_generator/prompts.py
@@ -434,3 +434,20 @@ INSTRUCTION_GENERATE_TEMPLATE_PROMPT = """The output of this prompt is not as ex
You should edit the prompt according to the IDEAL OUTPUT."""
INSTRUCTION_GENERATE_TEMPLATE_CODE = """Please fix the errors in the {{#error_message#}}."""
+
+DEFAULT_GENERATOR_SUMMARY_PROMPT = (
+ """Summarize the following content. Extract only the key information and main points. """
+ """Remove redundant details.
+
+Requirements:
+1. Write a concise summary in plain text
+2. Use the same language as the input content
+3. Focus on important facts, concepts, and details
+4. If images are included, describe their key information
+5. Do not use words like "好的", "ok", "I understand", "This text discusses", "The content mentions"
+6. Write directly without extra words
+
+Output only the summary text. Start summarizing now:
+
+"""
+)
diff --git a/api/core/logging/__init__.py b/api/core/logging/__init__.py
new file mode 100644
index 0000000000..db046cc9fa
--- /dev/null
+++ b/api/core/logging/__init__.py
@@ -0,0 +1,20 @@
+"""Structured logging components for Dify."""
+
+from core.logging.context import (
+ clear_request_context,
+ get_request_id,
+ get_trace_id,
+ init_request_context,
+)
+from core.logging.filters import IdentityContextFilter, TraceContextFilter
+from core.logging.structured_formatter import StructuredJSONFormatter
+
+__all__ = [
+ "IdentityContextFilter",
+ "StructuredJSONFormatter",
+ "TraceContextFilter",
+ "clear_request_context",
+ "get_request_id",
+ "get_trace_id",
+ "init_request_context",
+]
diff --git a/api/core/logging/context.py b/api/core/logging/context.py
new file mode 100644
index 0000000000..18633a0b05
--- /dev/null
+++ b/api/core/logging/context.py
@@ -0,0 +1,35 @@
+"""Request context for logging - framework agnostic.
+
+This module provides request-scoped context variables for logging,
+using Python's contextvars for thread-safe and async-safe storage.
+"""
+
+import uuid
+from contextvars import ContextVar
+
+_request_id: ContextVar[str] = ContextVar("log_request_id", default="")
+_trace_id: ContextVar[str] = ContextVar("log_trace_id", default="")
+
+
+def get_request_id() -> str:
+ """Get current request ID (10 hex chars)."""
+ return _request_id.get()
+
+
+def get_trace_id() -> str:
+ """Get fallback trace ID when OTEL is unavailable (32 hex chars)."""
+ return _trace_id.get()
+
+
+def init_request_context() -> None:
+ """Initialize request context. Call at start of each request."""
+ req_id = uuid.uuid4().hex[:10]
+ trace_id = uuid.uuid5(uuid.NAMESPACE_DNS, req_id).hex
+ _request_id.set(req_id)
+ _trace_id.set(trace_id)
+
+
+def clear_request_context() -> None:
+ """Clear request context. Call at end of request (optional)."""
+ _request_id.set("")
+ _trace_id.set("")
diff --git a/api/core/logging/filters.py b/api/core/logging/filters.py
new file mode 100644
index 0000000000..1e8aa8d566
--- /dev/null
+++ b/api/core/logging/filters.py
@@ -0,0 +1,94 @@
+"""Logging filters for structured logging."""
+
+import contextlib
+import logging
+
+import flask
+
+from core.logging.context import get_request_id, get_trace_id
+
+
+class TraceContextFilter(logging.Filter):
+ """
+ Filter that adds trace_id and span_id to log records.
+ Integrates with OpenTelemetry when available, falls back to ContextVar-based trace_id.
+ """
+
+ def filter(self, record: logging.LogRecord) -> bool:
+ # Get trace context from OpenTelemetry
+ trace_id, span_id = self._get_otel_context()
+
+ # Set trace_id (fallback to ContextVar if no OTEL context)
+ if trace_id:
+ record.trace_id = trace_id
+ else:
+ record.trace_id = get_trace_id()
+
+ record.span_id = span_id or ""
+
+ # For backward compatibility, also set req_id
+ record.req_id = get_request_id()
+
+ return True
+
+ def _get_otel_context(self) -> tuple[str, str]:
+ """Extract trace_id and span_id from OpenTelemetry context."""
+ with contextlib.suppress(Exception):
+ from opentelemetry.trace import get_current_span
+ from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
+
+ span = get_current_span()
+ if span and span.get_span_context():
+ ctx = span.get_span_context()
+ if ctx.is_valid and ctx.trace_id != INVALID_TRACE_ID:
+ trace_id = f"{ctx.trace_id:032x}"
+ span_id = f"{ctx.span_id:016x}" if ctx.span_id != INVALID_SPAN_ID else ""
+ return trace_id, span_id
+ return "", ""
+
+
+class IdentityContextFilter(logging.Filter):
+ """
+ Filter that adds user identity context to log records.
+ Extracts tenant_id, user_id, and user_type from Flask-Login current_user.
+ """
+
+ def filter(self, record: logging.LogRecord) -> bool:
+ identity = self._extract_identity()
+ record.tenant_id = identity.get("tenant_id", "")
+ record.user_id = identity.get("user_id", "")
+ record.user_type = identity.get("user_type", "")
+ return True
+
+ def _extract_identity(self) -> dict[str, str]:
+ """Extract identity from current_user if in request context."""
+ try:
+ if not flask.has_request_context():
+ return {}
+ from flask_login import current_user
+
+ # Check if user is authenticated using the proxy
+ if not current_user.is_authenticated:
+ return {}
+
+ # Access the underlying user object
+ user = current_user
+
+ from models import Account
+ from models.model import EndUser
+
+ identity: dict[str, str] = {}
+
+ if isinstance(user, Account):
+ if user.current_tenant_id:
+ identity["tenant_id"] = user.current_tenant_id
+ identity["user_id"] = user.id
+ identity["user_type"] = "account"
+ elif isinstance(user, EndUser):
+ identity["tenant_id"] = user.tenant_id
+ identity["user_id"] = user.id
+ identity["user_type"] = user.type or "end_user"
+
+ return identity
+ except Exception:
+ return {}
diff --git a/api/core/logging/structured_formatter.py b/api/core/logging/structured_formatter.py
new file mode 100644
index 0000000000..4295d2dd34
--- /dev/null
+++ b/api/core/logging/structured_formatter.py
@@ -0,0 +1,107 @@
+"""Structured JSON log formatter for Dify."""
+
+import logging
+import traceback
+from datetime import UTC, datetime
+from typing import Any
+
+import orjson
+
+from configs import dify_config
+
+
+class StructuredJSONFormatter(logging.Formatter):
+ """
+ JSON log formatter following the specified schema:
+ {
+ "ts": "ISO 8601 UTC",
+ "severity": "INFO|ERROR|WARN|DEBUG",
+ "service": "service name",
+ "caller": "file:line",
+ "trace_id": "hex 32",
+ "span_id": "hex 16",
+ "identity": { "tenant_id", "user_id", "user_type" },
+ "message": "log message",
+ "attributes": { ... },
+ "stack_trace": "..."
+ }
+ """
+
+ SEVERITY_MAP: dict[int, str] = {
+ logging.DEBUG: "DEBUG",
+ logging.INFO: "INFO",
+ logging.WARNING: "WARN",
+ logging.ERROR: "ERROR",
+ logging.CRITICAL: "ERROR",
+ }
+
+ def __init__(self, service_name: str | None = None):
+ super().__init__()
+ self._service_name = service_name or dify_config.APPLICATION_NAME
+
+ def format(self, record: logging.LogRecord) -> str:
+ log_dict = self._build_log_dict(record)
+ try:
+ return orjson.dumps(log_dict).decode("utf-8")
+ except TypeError:
+ # Fallback: convert non-serializable objects to string
+ import json
+
+ return json.dumps(log_dict, default=str, ensure_ascii=False)
+
+ def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]:
+ # Core fields
+ log_dict: dict[str, Any] = {
+ "ts": datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z"),
+ "severity": self.SEVERITY_MAP.get(record.levelno, "INFO"),
+ "service": self._service_name,
+ "caller": f"{record.filename}:{record.lineno}",
+ "message": record.getMessage(),
+ }
+
+ # Trace context (from TraceContextFilter)
+ trace_id = getattr(record, "trace_id", "")
+ span_id = getattr(record, "span_id", "")
+
+ if trace_id:
+ log_dict["trace_id"] = trace_id
+ if span_id:
+ log_dict["span_id"] = span_id
+
+ # Identity context (from IdentityContextFilter)
+ identity = self._extract_identity(record)
+ if identity:
+ log_dict["identity"] = identity
+
+ # Dynamic attributes
+ attributes = getattr(record, "attributes", None)
+ if attributes:
+ log_dict["attributes"] = attributes
+
+ # Stack trace for errors with exceptions
+ if record.exc_info and record.levelno >= logging.ERROR:
+ log_dict["stack_trace"] = self._format_exception(record.exc_info)
+
+ return log_dict
+
+ def _extract_identity(self, record: logging.LogRecord) -> dict[str, str] | None:
+ tenant_id = getattr(record, "tenant_id", None)
+ user_id = getattr(record, "user_id", None)
+ user_type = getattr(record, "user_type", None)
+
+ if not any([tenant_id, user_id, user_type]):
+ return None
+
+ identity: dict[str, str] = {}
+ if tenant_id:
+ identity["tenant_id"] = tenant_id
+ if user_id:
+ identity["user_id"] = user_id
+ if user_type:
+ identity["user_type"] = user_type
+ return identity
+
+ def _format_exception(self, exc_info: tuple[Any, ...]) -> str:
+ if exc_info and exc_info[0] is not None:
+ return "".join(traceback.format_exception(*exc_info))
+ return ""
diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py
index 92787b39dd..aef1afb235 100644
--- a/api/core/mcp/auth/auth_flow.py
+++ b/api/core/mcp/auth/auth_flow.py
@@ -47,7 +47,11 @@ def build_protected_resource_metadata_discovery_urls(
"""
Build a list of URLs to try for Protected Resource Metadata discovery.
- Per SEP-985, supports fallback when discovery fails at one URL.
+ Per RFC 9728 Section 5.1, supports fallback when discovery fails at one URL.
+ Priority order:
+ 1. URL from WWW-Authenticate header (if provided)
+ 2. Well-known URI with path: https://example.com/.well-known/oauth-protected-resource/public/mcp
+ 3. Well-known URI at root: https://example.com/.well-known/oauth-protected-resource
"""
urls = []
@@ -58,9 +62,18 @@ def build_protected_resource_metadata_discovery_urls(
# Fallback: construct from server URL
parsed = urlparse(server_url)
base_url = f"{parsed.scheme}://{parsed.netloc}"
- fallback_url = urljoin(base_url, "/.well-known/oauth-protected-resource")
- if fallback_url not in urls:
- urls.append(fallback_url)
+ path = parsed.path.rstrip("/")
+
+ # Priority 2: With path insertion (e.g., /.well-known/oauth-protected-resource/public/mcp)
+ if path:
+ path_url = f"{base_url}/.well-known/oauth-protected-resource{path}"
+ if path_url not in urls:
+ urls.append(path_url)
+
+ # Priority 3: At root (e.g., /.well-known/oauth-protected-resource)
+ root_url = f"{base_url}/.well-known/oauth-protected-resource"
+ if root_url not in urls:
+ urls.append(root_url)
return urls
@@ -71,30 +84,34 @@ def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: st
Supports both OAuth 2.0 (RFC 8414) and OpenID Connect discovery.
- Per RFC 8414 section 3:
- - If issuer has no path: https://example.com/.well-known/oauth-authorization-server
- - If issuer has path: https://example.com/.well-known/oauth-authorization-server{path}
-
- Example:
- - issuer: https://example.com/oauth
- - metadata: https://example.com/.well-known/oauth-authorization-server/oauth
+ Per RFC 8414 section 3.1 and section 5, try all possible endpoints:
+ - OAuth 2.0 with path insertion: https://example.com/.well-known/oauth-authorization-server/tenant1
+ - OpenID Connect with path insertion: https://example.com/.well-known/openid-configuration/tenant1
+ - OpenID Connect path appending: https://example.com/tenant1/.well-known/openid-configuration
+ - OAuth 2.0 at root: https://example.com/.well-known/oauth-authorization-server
+ - OpenID Connect at root: https://example.com/.well-known/openid-configuration
"""
urls = []
base_url = auth_server_url or server_url
parsed = urlparse(base_url)
base = f"{parsed.scheme}://{parsed.netloc}"
- path = parsed.path.rstrip("/") # Remove trailing slash
+ path = parsed.path.rstrip("/")
+ # OAuth 2.0 Authorization Server Metadata at root (MCP-03-26)
+ urls.append(f"{base}/.well-known/oauth-authorization-server")
- # Try OpenID Connect discovery first (more common)
- urls.append(urljoin(base + "/", ".well-known/openid-configuration"))
+ # OpenID Connect Discovery at root
+ urls.append(f"{base}/.well-known/openid-configuration")
- # OAuth 2.0 Authorization Server Metadata (RFC 8414)
- # Include the path component if present in the issuer URL
if path:
- urls.append(urljoin(base, f".well-known/oauth-authorization-server{path}"))
- else:
- urls.append(urljoin(base, ".well-known/oauth-authorization-server"))
+ # OpenID Connect Discovery with path insertion
+ urls.append(f"{base}/.well-known/openid-configuration{path}")
+
+ # OpenID Connect Discovery path appending
+ urls.append(f"{base}{path}/.well-known/openid-configuration")
+
+ # OAuth 2.0 Authorization Server Metadata with path insertion
+ urls.append(f"{base}/.well-known/oauth-authorization-server{path}")
return urls
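A small check of the discovery ordering documented above, assuming the function accepts the two URL arguments by keyword; the issuer URLs are made up:

```python
# Illustrative; argument names assumed from the function body above.
from core.mcp.auth.auth_flow import build_oauth_authorization_server_metadata_discovery_urls

urls = build_oauth_authorization_server_metadata_discovery_urls(
    auth_server_url="https://auth.example.com/tenant1",
    server_url="https://mcp.example.com/public/mcp",
)
assert urls == [
    "https://auth.example.com/.well-known/oauth-authorization-server",
    "https://auth.example.com/.well-known/openid-configuration",
    "https://auth.example.com/.well-known/openid-configuration/tenant1",
    "https://auth.example.com/tenant1/.well-known/openid-configuration",
    "https://auth.example.com/.well-known/oauth-authorization-server/tenant1",
]
```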
diff --git a/api/core/mcp/client/sse_client.py b/api/core/mcp/client/sse_client.py
index 24ca59ee45..1de1d5a073 100644
--- a/api/core/mcp/client/sse_client.py
+++ b/api/core/mcp/client/sse_client.py
@@ -61,6 +61,7 @@ class SSETransport:
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
self.endpoint_url: str | None = None
+ self.event_source: EventSource | None = None
def _validate_endpoint_url(self, endpoint_url: str) -> bool:
"""Validate that the endpoint URL matches the connection origin.
@@ -237,6 +238,9 @@ class SSETransport:
write_queue: WriteQueue = queue.Queue()
status_queue: StatusQueue = queue.Queue()
+ # Store event_source for graceful shutdown
+ self.event_source = event_source
+
# Start SSE reader thread
executor.submit(self.sse_reader, event_source, read_queue, status_queue)
@@ -296,6 +300,13 @@ def sse_client(
logger.exception("Error connecting to SSE endpoint")
raise
finally:
+ # Close the SSE connection to unblock the reader thread
+ if transport.event_source is not None:
+ try:
+ transport.event_source.response.close()
+ except RuntimeError:
+ pass
+
# Clean up queues
if read_queue:
read_queue.put(None)
diff --git a/api/core/mcp/client/streamable_client.py b/api/core/mcp/client/streamable_client.py
index 805c16c838..5c3cd0d8f8 100644
--- a/api/core/mcp/client/streamable_client.py
+++ b/api/core/mcp/client/streamable_client.py
@@ -8,6 +8,7 @@ and session management.
import logging
import queue
+import threading
from collections.abc import Callable, Generator
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
@@ -103,6 +104,9 @@ class StreamableHTTPTransport:
CONTENT_TYPE: JSON,
**self.headers,
}
+ self.stop_event = threading.Event()
+ self._active_responses: list[httpx.Response] = []
+ self._lock = threading.Lock()
def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
"""Update headers with session ID if available."""
@@ -111,6 +115,30 @@ class StreamableHTTPTransport:
headers[MCP_SESSION_ID] = self.session_id
return headers
+ def _register_response(self, response: httpx.Response):
+ """Register a response for cleanup on shutdown."""
+ with self._lock:
+ self._active_responses.append(response)
+
+ def _unregister_response(self, response: httpx.Response):
+ """Unregister a response after it's closed."""
+ with self._lock:
+ try:
+ self._active_responses.remove(response)
+ except ValueError as e:
+ logger.debug("Ignoring error during response unregister: %s", e)
+
+ def close_active_responses(self):
+ """Close all active SSE connections to unblock threads."""
+ with self._lock:
+ responses_to_close = list(self._active_responses)
+ self._active_responses.clear()
+ for response in responses_to_close:
+ try:
+ response.close()
+ except RuntimeError as e:
+ logger.debug("Ignoring error during active response close: %s", e)
+
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
"""Check if the message is an initialization request."""
return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
@@ -195,11 +223,21 @@ class StreamableHTTPTransport:
event_source.response.raise_for_status()
logger.debug("GET SSE connection established")
- for sse in event_source.iter_sse():
- self._handle_sse_event(sse, server_to_client_queue)
+ # Register response for cleanup
+ self._register_response(event_source.response)
+
+ try:
+ for sse in event_source.iter_sse():
+ if self.stop_event.is_set():
+ logger.debug("GET stream received stop signal")
+ break
+ self._handle_sse_event(sse, server_to_client_queue)
+ finally:
+ self._unregister_response(event_source.response)
except Exception as exc:
- logger.debug("GET stream error (non-fatal): %s", exc)
+ if not self.stop_event.is_set():
+ logger.debug("GET stream error (non-fatal): %s", exc)
def _handle_resumption_request(self, ctx: RequestContext):
"""Handle a resumption request using GET with SSE."""
@@ -224,15 +262,24 @@ class StreamableHTTPTransport:
event_source.response.raise_for_status()
logger.debug("Resumption GET SSE connection established")
- for sse in event_source.iter_sse():
- is_complete = self._handle_sse_event(
- sse,
- ctx.server_to_client_queue,
- original_request_id,
- ctx.metadata.on_resumption_token_update if ctx.metadata else None,
- )
- if is_complete:
- break
+ # Register response for cleanup
+ self._register_response(event_source.response)
+
+ try:
+ for sse in event_source.iter_sse():
+ if self.stop_event.is_set():
+ logger.debug("Resumption stream received stop signal")
+ break
+ is_complete = self._handle_sse_event(
+ sse,
+ ctx.server_to_client_queue,
+ original_request_id,
+ ctx.metadata.on_resumption_token_update if ctx.metadata else None,
+ )
+ if is_complete:
+ break
+ finally:
+ self._unregister_response(event_source.response)
def _handle_post_request(self, ctx: RequestContext):
"""Handle a POST request with response processing."""
@@ -266,17 +313,20 @@ class StreamableHTTPTransport:
if is_initialization:
self._maybe_extract_session_id_from_response(response)
- content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower())
+ # Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications:
+ # The server MUST NOT send a response to notifications.
+ if isinstance(message.root, JSONRPCRequest):
+ content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower())
- if content_type.startswith(JSON):
- self._handle_json_response(response, ctx.server_to_client_queue)
- elif content_type.startswith(SSE):
- self._handle_sse_response(response, ctx)
- else:
- self._handle_unexpected_content_type(
- content_type,
- ctx.server_to_client_queue,
- )
+ if content_type.startswith(JSON):
+ self._handle_json_response(response, ctx.server_to_client_queue)
+ elif content_type.startswith(SSE):
+ self._handle_sse_response(response, ctx)
+ else:
+ self._handle_unexpected_content_type(
+ content_type,
+ ctx.server_to_client_queue,
+ )
def _handle_json_response(
self,
@@ -295,17 +345,27 @@ class StreamableHTTPTransport:
def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext):
"""Handle SSE response from the server."""
try:
+ # Register response for cleanup
+ self._register_response(response)
+
event_source = EventSource(response)
- for sse in event_source.iter_sse():
- is_complete = self._handle_sse_event(
- sse,
- ctx.server_to_client_queue,
- resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
- )
- if is_complete:
- break
+ try:
+ for sse in event_source.iter_sse():
+ if self.stop_event.is_set():
+ logger.debug("SSE response stream received stop signal")
+ break
+ is_complete = self._handle_sse_event(
+ sse,
+ ctx.server_to_client_queue,
+ resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
+ )
+ if is_complete:
+ break
+ finally:
+ self._unregister_response(response)
except Exception as e:
- ctx.server_to_client_queue.put(e)
+ if not self.stop_event.is_set():
+ ctx.server_to_client_queue.put(e)
def _handle_unexpected_content_type(
self,
@@ -345,6 +405,11 @@ class StreamableHTTPTransport:
"""
while True:
try:
+ # Check if we should stop
+ if self.stop_event.is_set():
+ logger.debug("Post writer received stop signal")
+ break
+
# Read message from client queue with timeout to check stop_event periodically
session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
if session_message is None:
@@ -381,7 +446,8 @@ class StreamableHTTPTransport:
except queue.Empty:
continue
except Exception as exc:
- server_to_client_queue.put(exc)
+ if not self.stop_event.is_set():
+ server_to_client_queue.put(exc)
def terminate_session(self, client: httpx.Client):
"""Terminate the session by sending a DELETE request."""
@@ -465,6 +531,12 @@ def streamablehttp_client(
transport.get_session_id,
)
finally:
+ # Set stop event to signal all threads to stop
+ transport.stop_event.set()
+
+ # Close all active SSE connections to unblock threads
+ transport.close_active_responses()
+
if transport.session_id and terminate_on_close:
transport.terminate_session(client)
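The shutdown changes above boil down to one pattern: set a stop event, then close every registered streaming response so reader threads blocked on `iter_sse()` can exit. A generic, self-contained sketch of that pattern (the fake response object stands in for `httpx.Response`):

```python
# Generic sketch of the stop-event + registered-response shutdown pattern.
import threading


class StreamRegistry:
    def __init__(self) -> None:
        self.stop_event = threading.Event()
        self._responses: list = []
        self._lock = threading.Lock()

    def register(self, response) -> None:
        with self._lock:
            self._responses.append(response)

    def close_all(self) -> None:
        self.stop_event.set()
        with self._lock:
            responses, self._responses = list(self._responses), []
        for response in responses:
            response.close()  # unblocks any thread iterating this stream
```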
diff --git a/api/core/mcp/mcp_client.py b/api/core/mcp/mcp_client.py
index 5929204ecb..2b30bfc21b 100644
--- a/api/core/mcp/mcp_client.py
+++ b/api/core/mcp/mcp_client.py
@@ -59,7 +59,7 @@ class MCPClient:
try:
logger.debug("Not supported method %s found in URL path, trying default 'mcp' method.", method_name)
self.connect_server(sse_client, "sse")
- except MCPConnectionError:
+ except (MCPConnectionError, ValueError):
logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
self.connect_server(streamablehttp_client, "mcp")
diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py
index c97ae6eac7..e1a40593e7 100644
--- a/api/core/mcp/session/base_session.py
+++ b/api/core/mcp/session/base_session.py
@@ -68,13 +68,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
request_id: RequestId,
request_meta: RequestParams.Meta | None,
request: ReceiveRequestT,
- session: """BaseSession[
- SendRequestT,
- SendNotificationT,
- SendResultT,
- ReceiveRequestT,
- ReceiveNotificationT
- ]""",
+ session: """BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]""",
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
):
self.request_id = request_id
@@ -353,7 +347,7 @@ class BaseSession(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
- responder = RequestResponder(
+ responder = RequestResponder[ReceiveRequestT, SendResultT](
request_id=message.message.root.id,
request_meta=validated_request.root.params.meta if validated_request.root.params else None,
request=validated_request,
diff --git a/api/core/model_runtime/README.md b/api/core/model_runtime/README.md
index a6caa7eb1e..b9d2c55210 100644
--- a/api/core/model_runtime/README.md
+++ b/api/core/model_runtime/README.md
@@ -18,34 +18,20 @@ This module provides the interface for invoking and authenticating various model
- Model provider display
- 
-
- Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc. For detailed rule design, see: [Schema](./docs/en_US/schema.md).
+ Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc.
- Selectable model list display
- 
-
After configuring provider/model credentials, the dropdown (application orchestration interface/default model) allows viewing of the available LLM list. Greyed out items represent predefined model lists from providers without configured credentials, facilitating user review of supported models.
- In addition, this list also returns configurable parameter information and rules for LLM, as shown below:
-
- 
-
- These parameters are all defined in the backend, allowing different settings for various parameters supported by different models, as detailed in: [Schema](./docs/en_US/schema.md#ParameterRule).
+ In addition, this list also returns configurable parameter information and rules for LLM. These parameters are all defined in the backend, allowing different settings for various parameters supported by different models.
- Provider/model credential authentication
- 
-
- 
-
- The provider list returns configuration information for the credentials form, which can be authenticated through Runtime's interface. The first image above is a provider credential DEMO, and the second is a model credential DEMO.
+ The provider list returns configuration information for the credentials form, which can be authenticated through Runtime's interface.
## Structure
-
-
Model Runtime is divided into three layers:
- The outermost layer is the factory method
@@ -60,9 +46,6 @@ Model Runtime is divided into three layers:
It offers direct invocation of various model types, predefined model configuration information, getting predefined/remote model lists, model credential authentication methods. Different models provide additional special methods, like LLM's pre-computed tokens method, cost information obtaining method, etc., **allowing horizontal expansion** for different models under the same provider (within supported model types).
-## Next Steps
+## Documentation
-- Add new provider configuration: [Link](./docs/en_US/provider_scale_out.md)
-- Add new models for existing providers: [Link](./docs/en_US/provider_scale_out.md#AddModel)
-- View YAML configuration rules: [Link](./docs/en_US/schema.md)
-- Implement interface methods: [Link](./docs/en_US/interfaces.md)
+For detailed documentation on how to add new providers or models, please refer to the [Dify documentation](https://docs.dify.ai/).
diff --git a/api/core/model_runtime/README_CN.md b/api/core/model_runtime/README_CN.md
index dfe614347a..0a8b56b3fe 100644
--- a/api/core/model_runtime/README_CN.md
+++ b/api/core/model_runtime/README_CN.md
@@ -18,34 +18,20 @@
- 模型供应商展示
- 
-
- 展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等,规则设计详见:[Schema](./docs/zh_Hans/schema.md)。
+ 展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等。
- 可选择的模型列表展示
- 
+ 配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。
- 配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。
-
- 除此之外,该列表还返回了 LLM 可配置的参数信息和规则,如下图:
-
- 
-
- 这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数,详见:[Schema](./docs/zh_Hans/schema.md#ParameterRule)。
+ 除此之外,该列表还返回了 LLM 可配置的参数信息和规则。这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数。
- 供应商/模型凭据鉴权
- 
-
-
-
- 供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权,上图 1 为供应商凭据 DEMO,上图 2 为模型凭据 DEMO。
+ 供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权。
## 结构
-
-
Model Runtime 分三层:
- 最外层为工厂方法
@@ -59,8 +45,7 @@ Model Runtime 分三层:
对于供应商/模型凭据,有两种情况
- 如 OpenAI 这类中心化供应商,需要定义如**api_key**这类的鉴权凭据
- - 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据,就像下面这样,当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。
- 
+ - 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据。当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。
当配置好凭据后,就可以通过 DifyRuntime 的外部接口直接获取到对应供应商所需要的**Schema**(凭据表单规则),从而在可以在不修改前端逻辑的情况下,提供新的供应商/模型的支持。
@@ -74,20 +59,6 @@ Model Runtime 分三层:
- 模型凭据 (**在供应商层定义**):这是一类不经常变动,一般在配置好后就不会再变动的参数,如 **api_key**、**server_url** 等。在 DifyRuntime 中,他们的参数名一般为**credentials: dict[str, any]**,Provider 层的 credentials 会直接被传递到这一层,不需要再单独定义。
-## 下一步
+## 文档
-### [增加新的供应商配置 👈🏻](./docs/zh_Hans/provider_scale_out.md)
-
-当添加后,这里将会出现一个新的供应商
-
-
-
-### [为已存在的供应商新增模型 👈🏻](./docs/zh_Hans/provider_scale_out.md#%E5%A2%9E%E5%8A%A0%E6%A8%A1%E5%9E%8B)
-
-当添加后,对应供应商的模型列表中将会出现一个新的预定义模型供用户选择,如 GPT-3.5 GPT-4 ChatGLM3-6b 等,而对于支持自定义模型的供应商,则不需要新增模型。
-
-
-
-### [接口的具体实现 👈🏻](./docs/zh_Hans/interfaces.md)
-
-你可以在这里找到你想要查看的接口的具体实现,以及接口的参数和返回值的具体含义。
+有关如何添加新供应商或模型的详细文档,请参阅 [Dify 文档](https://docs.dify.ai/)。
diff --git a/api/core/model_runtime/entities/defaults.py b/api/core/model_runtime/entities/defaults.py
index 76969fea70..51c9c51257 100644
--- a/api/core/model_runtime/entities/defaults.py
+++ b/api/core/model_runtime/entities/defaults.py
@@ -88,7 +88,7 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
DefaultParameterName.MAX_TOKENS: {
"label": {
"en_US": "Max Tokens",
- "zh_Hans": "最大标记",
+ "zh_Hans": "最大 Token 数",
},
"type": "int",
"help": {
diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py
index 89dae2dbff..9e46d72893 100644
--- a/api/core/model_runtime/entities/message_entities.py
+++ b/api/core/model_runtime/entities/message_entities.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from abc import ABC
from collections.abc import Mapping, Sequence
from enum import StrEnum, auto
@@ -17,7 +19,7 @@ class PromptMessageRole(StrEnum):
TOOL = auto()
@classmethod
- def value_of(cls, value: str) -> "PromptMessageRole":
+ def value_of(cls, value: str) -> PromptMessageRole:
"""
Get value of given mode.
@@ -249,10 +251,7 @@ class AssistantPromptMessage(PromptMessage):
:return: True if prompt message is empty, False otherwise
"""
- if not super().is_empty() and not self.tool_calls:
- return False
-
- return True
+ return super().is_empty() and not self.tool_calls
class SystemPromptMessage(PromptMessage):
diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py
index aee6ce1108..19194d162c 100644
--- a/api/core/model_runtime/entities/model_entities.py
+++ b/api/core/model_runtime/entities/model_entities.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from decimal import Decimal
from enum import StrEnum, auto
from typing import Any
@@ -20,7 +22,7 @@ class ModelType(StrEnum):
TTS = auto()
@classmethod
- def value_of(cls, origin_model_type: str) -> "ModelType":
+ def value_of(cls, origin_model_type: str) -> ModelType:
"""
Get model type from origin model type.
@@ -103,7 +105,7 @@ class DefaultParameterName(StrEnum):
JSON_SCHEMA = auto()
@classmethod
- def value_of(cls, value: Any) -> "DefaultParameterName":
+ def value_of(cls, value: Any) -> DefaultParameterName:
"""
Get parameter name from value.
diff --git a/api/core/model_runtime/entities/provider_entities.py b/api/core/model_runtime/entities/provider_entities.py
index 648b209ef1..2d88751668 100644
--- a/api/core/model_runtime/entities/provider_entities.py
+++ b/api/core/model_runtime/entities/provider_entities.py
@@ -100,7 +100,6 @@ class SimpleProviderEntity(BaseModel):
label: I18nObject
icon_small: I18nObject | None = None
icon_small_dark: I18nObject | None = None
- icon_large: I18nObject | None = None
supported_model_types: Sequence[ModelType]
models: list[AIModelEntity] = []
@@ -123,7 +122,6 @@ class ProviderEntity(BaseModel):
label: I18nObject
description: I18nObject | None = None
icon_small: I18nObject | None = None
- icon_large: I18nObject | None = None
icon_small_dark: I18nObject | None = None
background: str | None = None
help: ProviderHelpEntity | None = None
@@ -157,7 +155,6 @@ class ProviderEntity(BaseModel):
provider=self.provider,
label=self.label,
icon_small=self.icon_small,
- icon_large=self.icon_large,
supported_model_types=self.supported_model_types,
models=self.models,
)
diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py
index 45f0335c2e..c3e50eaddd 100644
--- a/api/core/model_runtime/model_providers/__base/ai_model.py
+++ b/api/core/model_runtime/model_providers/__base/ai_model.py
@@ -1,10 +1,11 @@
import decimal
import hashlib
-from threading import Lock
+import logging
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field, ValidationError
+from redis import RedisError
-import contexts
+from configs import dify_config
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
from core.model_runtime.entities.model_entities import (
@@ -24,6 +25,9 @@ from core.model_runtime.errors.invoke import (
InvokeServerUnavailableError,
)
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
+from extensions.ext_redis import redis_client
+
+logger = logging.getLogger(__name__)
class AIModel(BaseModel):
@@ -144,34 +148,60 @@ class AIModel(BaseModel):
plugin_model_manager = PluginModelClient()
cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}"
- # sort credentials
sorted_credentials = sorted(credentials.items()) if credentials else []
cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])
+ cached_schema_json = None
try:
- contexts.plugin_model_schemas.get()
- except LookupError:
- contexts.plugin_model_schemas.set({})
- contexts.plugin_model_schema_lock.set(Lock())
-
- with contexts.plugin_model_schema_lock.get():
- if cache_key in contexts.plugin_model_schemas.get():
- return contexts.plugin_model_schemas.get()[cache_key]
-
- schema = plugin_model_manager.get_model_schema(
- tenant_id=self.tenant_id,
- user_id="unknown",
- plugin_id=self.plugin_id,
- provider=self.provider_name,
- model_type=self.model_type.value,
- model=model,
- credentials=credentials or {},
+ cached_schema_json = redis_client.get(cache_key)
+ except (RedisError, RuntimeError) as exc:
+ logger.warning(
+ "Failed to read plugin model schema cache for model %s: %s",
+ model,
+ str(exc),
+ exc_info=True,
)
+ if cached_schema_json:
+ try:
+ return AIModelEntity.model_validate_json(cached_schema_json)
+ except ValidationError:
+ logger.warning(
+ "Failed to validate cached plugin model schema for model %s",
+ model,
+ exc_info=True,
+ )
+ try:
+ redis_client.delete(cache_key)
+ except (RedisError, RuntimeError) as exc:
+ logger.warning(
+ "Failed to delete invalid plugin model schema cache for model %s: %s",
+ model,
+ str(exc),
+ exc_info=True,
+ )
- if schema:
- contexts.plugin_model_schemas.get()[cache_key] = schema
+ schema = plugin_model_manager.get_model_schema(
+ tenant_id=self.tenant_id,
+ user_id="unknown",
+ plugin_id=self.plugin_id,
+ provider=self.provider_name,
+ model_type=self.model_type.value,
+ model=model,
+ credentials=credentials or {},
+ )
- return schema
+ if schema:
+ try:
+ redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
+ except (RedisError, RuntimeError) as exc:
+ logger.warning(
+ "Failed to write plugin model schema cache for model %s: %s",
+ model,
+ str(exc),
+ exc_info=True,
+ )
+
+ return schema
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py
index c0f4c504d9..bbbdec61d1 100644
--- a/api/core/model_runtime/model_providers/__base/large_language_model.py
+++ b/api/core/model_runtime/model_providers/__base/large_language_model.py
@@ -1,7 +1,7 @@
import logging
import time
import uuid
-from collections.abc import Generator, Sequence
+from collections.abc import Callable, Generator, Iterator, Sequence
from typing import Union
from pydantic import ConfigDict
@@ -30,6 +30,153 @@ def _gen_tool_call_id() -> str:
return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
+def _run_callbacks(callbacks: Sequence[Callback] | None, *, event: str, invoke: Callable[[Callback], None]) -> None:
+ if not callbacks:
+ return
+
+ for callback in callbacks:
+ try:
+ invoke(callback)
+ except Exception as e:
+ if callback.raise_error:
+ raise
+ logger.warning("Callback %s %s failed with error %s", callback.__class__.__name__, event, e)
+
+
+def _get_or_create_tool_call(
+ existing_tools_calls: list[AssistantPromptMessage.ToolCall],
+ tool_call_id: str,
+) -> AssistantPromptMessage.ToolCall:
+ """
+ Get or create a tool call by ID.
+
+ If `tool_call_id` is empty, returns the most recently created tool call.
+ """
+ if not tool_call_id:
+ if not existing_tools_calls:
+ raise ValueError("tool_call_id is empty but no existing tool call is available to apply the delta")
+ return existing_tools_calls[-1]
+
+ tool_call = next((tool_call for tool_call in existing_tools_calls if tool_call.id == tool_call_id), None)
+ if tool_call is None:
+ tool_call = AssistantPromptMessage.ToolCall(
+ id=tool_call_id,
+ type="function",
+ function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
+ )
+ existing_tools_calls.append(tool_call)
+
+ return tool_call
+
+
+def _merge_tool_call_delta(
+ tool_call: AssistantPromptMessage.ToolCall,
+ delta: AssistantPromptMessage.ToolCall,
+) -> None:
+ if delta.id:
+ tool_call.id = delta.id
+ if delta.type:
+ tool_call.type = delta.type
+ if delta.function.name:
+ tool_call.function.name = delta.function.name
+ if delta.function.arguments:
+ tool_call.function.arguments += delta.function.arguments
+
+
+def _build_llm_result_from_first_chunk(
+ model: str,
+ prompt_messages: Sequence[PromptMessage],
+ chunks: Iterator[LLMResultChunk],
+) -> LLMResult:
+ """
+ Build a single `LLMResult` from the first returned chunk.
+
+ This is used for `stream=False` because the plugin side may still implement the response via a chunked stream.
+
+ Note:
+ This function always drains the `chunks` iterator after reading the first chunk to ensure any underlying
+ streaming resources are released (e.g., HTTP connections owned by the plugin runtime).
+ """
+ content = ""
+ content_list: list[PromptMessageContentUnionTypes] = []
+ usage = LLMUsage.empty_usage()
+ system_fingerprint: str | None = None
+ tools_calls: list[AssistantPromptMessage.ToolCall] = []
+
+ try:
+ first_chunk = next(chunks, None)
+ if first_chunk is not None:
+ if isinstance(first_chunk.delta.message.content, str):
+ content += first_chunk.delta.message.content
+ elif isinstance(first_chunk.delta.message.content, list):
+ content_list.extend(first_chunk.delta.message.content)
+
+ if first_chunk.delta.message.tool_calls:
+ _increase_tool_call(first_chunk.delta.message.tool_calls, tools_calls)
+
+ usage = first_chunk.delta.usage or LLMUsage.empty_usage()
+ system_fingerprint = first_chunk.system_fingerprint
+ finally:
+ try:
+ for _ in chunks:
+ pass
+ except Exception:
+ logger.debug("Failed to drain non-stream plugin chunk iterator.", exc_info=True)
+
+ return LLMResult(
+ model=model,
+ prompt_messages=prompt_messages,
+ message=AssistantPromptMessage(
+ content=content or content_list,
+ tool_calls=tools_calls,
+ ),
+ usage=usage,
+ system_fingerprint=system_fingerprint,
+ )
+
+
+def _invoke_llm_via_plugin(
+ *,
+ tenant_id: str,
+ user_id: str,
+ plugin_id: str,
+ provider: str,
+ model: str,
+ credentials: dict,
+ model_parameters: dict,
+ prompt_messages: Sequence[PromptMessage],
+ tools: list[PromptMessageTool] | None,
+ stop: Sequence[str] | None,
+ stream: bool,
+) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
+ from core.plugin.impl.model import PluginModelClient
+
+ plugin_model_manager = PluginModelClient()
+ return plugin_model_manager.invoke_llm(
+ tenant_id=tenant_id,
+ user_id=user_id,
+ plugin_id=plugin_id,
+ provider=provider,
+ model=model,
+ credentials=credentials,
+ model_parameters=model_parameters,
+ prompt_messages=list(prompt_messages),
+ tools=tools,
+ stop=list(stop) if stop else None,
+ stream=stream,
+ )
+
+
+def _normalize_non_stream_plugin_result(
+ model: str,
+ prompt_messages: Sequence[PromptMessage],
+ result: Union[LLMResult, Iterator[LLMResultChunk]],
+) -> LLMResult:
+ if isinstance(result, LLMResult):
+ return result
+ return _build_llm_result_from_first_chunk(model=model, prompt_messages=prompt_messages, chunks=result)
+
+
def _increase_tool_call(
new_tool_calls: list[AssistantPromptMessage.ToolCall], existing_tools_calls: list[AssistantPromptMessage.ToolCall]
):
@@ -40,42 +187,13 @@ def _increase_tool_call(
:param existing_tools_calls: List of existing tool calls to be modified IN-PLACE.
"""
- def get_tool_call(tool_call_id: str):
- """
- Get or create a tool call by ID
-
- :param tool_call_id: tool call ID
- :return: existing or new tool call
- """
- if not tool_call_id:
- return existing_tools_calls[-1]
-
- _tool_call = next((_tool_call for _tool_call in existing_tools_calls if _tool_call.id == tool_call_id), None)
- if _tool_call is None:
- _tool_call = AssistantPromptMessage.ToolCall(
- id=tool_call_id,
- type="function",
- function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
- )
- existing_tools_calls.append(_tool_call)
-
- return _tool_call
-
for new_tool_call in new_tool_calls:
# generate ID for tool calls with function name but no ID to track them
if new_tool_call.function.name and not new_tool_call.id:
new_tool_call.id = _gen_tool_call_id()
- # get tool call
- tool_call = get_tool_call(new_tool_call.id)
- # update tool call
- if new_tool_call.id:
- tool_call.id = new_tool_call.id
- if new_tool_call.type:
- tool_call.type = new_tool_call.type
- if new_tool_call.function.name:
- tool_call.function.name = new_tool_call.function.name
- if new_tool_call.function.arguments:
- tool_call.function.arguments += new_tool_call.function.arguments
+
+ tool_call = _get_or_create_tool_call(existing_tools_calls, new_tool_call.id)
+ _merge_tool_call_delta(tool_call, new_tool_call)
class LargeLanguageModel(AIModel):
@@ -141,10 +259,7 @@ class LargeLanguageModel(AIModel):
result: Union[LLMResult, Generator[LLMResultChunk, None, None]]
try:
- from core.plugin.impl.model import PluginModelClient
-
- plugin_model_manager = PluginModelClient()
- result = plugin_model_manager.invoke_llm(
+ result = _invoke_llm_via_plugin(
tenant_id=self.tenant_id,
user_id=user or "unknown",
plugin_id=self.plugin_id,
@@ -154,38 +269,13 @@ class LargeLanguageModel(AIModel):
model_parameters=model_parameters,
prompt_messages=prompt_messages,
tools=tools,
- stop=list(stop) if stop else None,
+ stop=stop,
stream=stream,
)
if not stream:
- content = ""
- content_list = []
- usage = LLMUsage.empty_usage()
- system_fingerprint = None
- tools_calls: list[AssistantPromptMessage.ToolCall] = []
-
- for chunk in result:
- if isinstance(chunk.delta.message.content, str):
- content += chunk.delta.message.content
- elif isinstance(chunk.delta.message.content, list):
- content_list.extend(chunk.delta.message.content)
- if chunk.delta.message.tool_calls:
- _increase_tool_call(chunk.delta.message.tool_calls, tools_calls)
-
- usage = chunk.delta.usage or LLMUsage.empty_usage()
- system_fingerprint = chunk.system_fingerprint
- break
-
- result = LLMResult(
- model=model,
- prompt_messages=prompt_messages,
- message=AssistantPromptMessage(
- content=content or content_list,
- tool_calls=tools_calls,
- ),
- usage=usage,
- system_fingerprint=system_fingerprint,
+ result = _normalize_non_stream_plugin_result(
+ model=model, prompt_messages=prompt_messages, result=result
)
except Exception as e:
self._trigger_invoke_error_callbacks(
@@ -204,7 +294,7 @@ class LargeLanguageModel(AIModel):
# TODO
raise self._transform_invoke_error(e)
- if stream and isinstance(result, Generator):
+ if stream and not isinstance(result, LLMResult):
return self._invoke_result_generator(
model=model,
result=result,
@@ -425,27 +515,21 @@ class LargeLanguageModel(AIModel):
:param user: unique user id
:param callbacks: callbacks
"""
- if callbacks:
- for callback in callbacks:
- try:
- callback.on_before_invoke(
- llm_instance=self,
- model=model,
- credentials=credentials,
- prompt_messages=prompt_messages,
- model_parameters=model_parameters,
- tools=tools,
- stop=stop,
- stream=stream,
- user=user,
- )
- except Exception as e:
- if callback.raise_error:
- raise e
- else:
- logger.warning(
- "Callback %s on_before_invoke failed with error %s", callback.__class__.__name__, e
- )
+ _run_callbacks(
+ callbacks,
+ event="on_before_invoke",
+ invoke=lambda callback: callback.on_before_invoke(
+ llm_instance=self,
+ model=model,
+ credentials=credentials,
+ prompt_messages=prompt_messages,
+ model_parameters=model_parameters,
+ tools=tools,
+ stop=stop,
+ stream=stream,
+ user=user,
+ ),
+ )
def _trigger_new_chunk_callbacks(
self,
@@ -473,26 +557,22 @@ class LargeLanguageModel(AIModel):
:param stream: is stream response
:param user: unique user id
"""
- if callbacks:
- for callback in callbacks:
- try:
- callback.on_new_chunk(
- llm_instance=self,
- chunk=chunk,
- model=model,
- credentials=credentials,
- prompt_messages=prompt_messages,
- model_parameters=model_parameters,
- tools=tools,
- stop=stop,
- stream=stream,
- user=user,
- )
- except Exception as e:
- if callback.raise_error:
- raise e
- else:
- logger.warning("Callback %s on_new_chunk failed with error %s", callback.__class__.__name__, e)
+ _run_callbacks(
+ callbacks,
+ event="on_new_chunk",
+ invoke=lambda callback: callback.on_new_chunk(
+ llm_instance=self,
+ chunk=chunk,
+ model=model,
+ credentials=credentials,
+ prompt_messages=prompt_messages,
+ model_parameters=model_parameters,
+ tools=tools,
+ stop=stop,
+ stream=stream,
+ user=user,
+ ),
+ )
def _trigger_after_invoke_callbacks(
self,
@@ -521,28 +601,22 @@ class LargeLanguageModel(AIModel):
:param user: unique user id
:param callbacks: callbacks
"""
- if callbacks:
- for callback in callbacks:
- try:
- callback.on_after_invoke(
- llm_instance=self,
- result=result,
- model=model,
- credentials=credentials,
- prompt_messages=prompt_messages,
- model_parameters=model_parameters,
- tools=tools,
- stop=stop,
- stream=stream,
- user=user,
- )
- except Exception as e:
- if callback.raise_error:
- raise e
- else:
- logger.warning(
- "Callback %s on_after_invoke failed with error %s", callback.__class__.__name__, e
- )
+ _run_callbacks(
+ callbacks,
+ event="on_after_invoke",
+ invoke=lambda callback: callback.on_after_invoke(
+ llm_instance=self,
+ result=result,
+ model=model,
+ credentials=credentials,
+ prompt_messages=prompt_messages,
+ model_parameters=model_parameters,
+ tools=tools,
+ stop=stop,
+ stream=stream,
+ user=user,
+ ),
+ )
def _trigger_invoke_error_callbacks(
self,
@@ -571,25 +645,19 @@ class LargeLanguageModel(AIModel):
:param user: unique user id
:param callbacks: callbacks
"""
- if callbacks:
- for callback in callbacks:
- try:
- callback.on_invoke_error(
- llm_instance=self,
- ex=ex,
- model=model,
- credentials=credentials,
- prompt_messages=prompt_messages,
- model_parameters=model_parameters,
- tools=tools,
- stop=stop,
- stream=stream,
- user=user,
- )
- except Exception as e:
- if callback.raise_error:
- raise e
- else:
- logger.warning(
- "Callback %s on_invoke_error failed with error %s", callback.__class__.__name__, e
- )
+ _run_callbacks(
+ callbacks,
+ event="on_invoke_error",
+ invoke=lambda callback: callback.on_invoke_error(
+ llm_instance=self,
+ ex=ex,
+ model=model,
+ credentials=credentials,
+ prompt_messages=prompt_messages,
+ model_parameters=model_parameters,
+ tools=tools,
+ stop=stop,
+ stream=stream,
+ user=user,
+ ),
+ )
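
The refactor above routes all four callback triggers (`on_before_invoke`, `on_new_chunk`, `on_after_invoke`, `on_invoke_error`) through a single `_run_callbacks` helper that swallows observer failures unless the callback opts in via `raise_error`. A standalone sketch of that pattern, with a toy `Callback` class standing in for the real base class:

```python
# Standalone sketch of the callback fan-out pattern; this Callback class is a toy
# stand-in, not the real core.model_runtime callback base class.
import logging
from collections.abc import Callable, Sequence

logger = logging.getLogger(__name__)


class Callback:
    raise_error: bool = False

    def on_new_chunk(self, chunk: str) -> None:
        print(f"chunk: {chunk}")


def run_callbacks(callbacks: Sequence[Callback] | None, *, event: str, invoke: Callable[[Callback], None]) -> None:
    if not callbacks:
        return
    for callback in callbacks:
        try:
            invoke(callback)
        except Exception as e:
            # A failing observer only aborts the invocation when it opts in via raise_error.
            if callback.raise_error:
                raise
            logger.warning("Callback %s %s failed: %s", type(callback).__name__, event, e)


# Each trigger site only supplies the event name and a lambda binding its own arguments.
run_callbacks([Callback()], event="on_new_chunk", invoke=lambda cb: cb.on_new_chunk("hello"))
```
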
diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py
index b8704ef4ed..9cfc6889ac 100644
--- a/api/core/model_runtime/model_providers/model_provider_factory.py
+++ b/api/core/model_runtime/model_providers/model_provider_factory.py
@@ -1,9 +1,15 @@
+from __future__ import annotations
+
import hashlib
import logging
from collections.abc import Sequence
from threading import Lock
+from pydantic import ValidationError
+from redis import RedisError
+
import contexts
+from configs import dify_config
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
from core.model_runtime.model_providers.__base.ai_model import AIModel
@@ -16,6 +22,7 @@ from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
+from extensions.ext_redis import redis_client
from models.provider_ids import ModelProviderID
logger = logging.getLogger(__name__)
@@ -38,7 +45,7 @@ class ModelProviderFactory:
plugin_providers = self.get_plugin_model_providers()
return [provider.declaration for provider in plugin_providers]
- def get_plugin_model_providers(self) -> Sequence["PluginModelProviderEntity"]:
+ def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]:
"""
Get all plugin model providers
:return: list of plugin model providers
@@ -76,7 +83,7 @@ class ModelProviderFactory:
plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
return plugin_model_provider_entity.declaration
- def get_plugin_model_provider(self, provider: str) -> "PluginModelProviderEntity":
+ def get_plugin_model_provider(self, provider: str) -> PluginModelProviderEntity:
"""
Get plugin model provider
:param provider: provider name
@@ -173,34 +180,60 @@ class ModelProviderFactory:
"""
plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider)
cache_key = f"{self.tenant_id}:{plugin_id}:{provider_name}:{model_type.value}:{model}"
- # sort credentials
sorted_credentials = sorted(credentials.items()) if credentials else []
cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])
+ cached_schema_json = None
try:
- contexts.plugin_model_schemas.get()
- except LookupError:
- contexts.plugin_model_schemas.set({})
- contexts.plugin_model_schema_lock.set(Lock())
-
- with contexts.plugin_model_schema_lock.get():
- if cache_key in contexts.plugin_model_schemas.get():
- return contexts.plugin_model_schemas.get()[cache_key]
-
- schema = self.plugin_model_manager.get_model_schema(
- tenant_id=self.tenant_id,
- user_id="unknown",
- plugin_id=plugin_id,
- provider=provider_name,
- model_type=model_type.value,
- model=model,
- credentials=credentials or {},
+ cached_schema_json = redis_client.get(cache_key)
+ except (RedisError, RuntimeError) as exc:
+ logger.warning(
+ "Failed to read plugin model schema cache for model %s: %s",
+ model,
+ str(exc),
+ exc_info=True,
)
+ if cached_schema_json:
+ try:
+ return AIModelEntity.model_validate_json(cached_schema_json)
+ except ValidationError:
+ logger.warning(
+ "Failed to validate cached plugin model schema for model %s",
+ model,
+ exc_info=True,
+ )
+ try:
+ redis_client.delete(cache_key)
+ except (RedisError, RuntimeError) as exc:
+ logger.warning(
+ "Failed to delete invalid plugin model schema cache for model %s: %s",
+ model,
+ str(exc),
+ exc_info=True,
+ )
- if schema:
- contexts.plugin_model_schemas.get()[cache_key] = schema
+ schema = self.plugin_model_manager.get_model_schema(
+ tenant_id=self.tenant_id,
+ user_id="unknown",
+ plugin_id=plugin_id,
+ provider=provider_name,
+ model_type=model_type.value,
+ model=model,
+ credentials=credentials or {},
+ )
- return schema
+ if schema:
+ try:
+ redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
+ except (RedisError, RuntimeError) as exc:
+ logger.warning(
+ "Failed to write plugin model schema cache for model %s: %s",
+ model,
+ str(exc),
+ exc_info=True,
+ )
+
+ return schema
def get_models(
self,
@@ -281,11 +314,13 @@ class ModelProviderFactory:
elif model_type == ModelType.TTS:
return TTSModel.model_validate(init_params)
+ raise ValueError(f"Unsupported model type: {model_type}")
+
def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
"""
Get provider icon
:param provider: provider name
- :param icon_type: icon type (icon_small or icon_large)
+ :param icon_type: icon type (icon_small or icon_small_dark)
:param lang: language (zh_Hans or en_US)
:return: provider icon
"""
@@ -309,13 +344,7 @@ class ModelProviderFactory:
else:
file_name = provider_schema.icon_small_dark.en_US
else:
- if not provider_schema.icon_large:
- raise ValueError(f"Provider {provider} does not have large icon.")
-
- if lang.lower() == "zh_hans":
- file_name = provider_schema.icon_large.zh_Hans
- else:
- file_name = provider_schema.icon_large.en_US
+ raise ValueError(f"Unsupported icon type: {icon_type}.")
if not file_name:
raise ValueError(f"Provider {provider} does not have icon.")
diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py
index d6bd4d2015..22ad756c91 100644
--- a/api/core/ops/aliyun_trace/aliyun_trace.py
+++ b/api/core/ops/aliyun_trace/aliyun_trace.py
@@ -1,6 +1,7 @@
import logging
from collections.abc import Sequence
+from opentelemetry.trace import SpanKind
from sqlalchemy.orm import sessionmaker
from core.ops.aliyun_trace.data_exporter.traceclient import (
@@ -54,7 +55,7 @@ from core.ops.entities.trace_entity import (
ToolTraceInfo,
WorkflowTraceInfo,
)
-from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.repositories import DifyCoreRepositoryFactory
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
@@ -151,6 +152,7 @@ class AliyunDataTrace(BaseTraceInstance):
),
status=status,
links=trace_metadata.links,
+ span_kind=SpanKind.SERVER,
)
self.trace_client.add_span(message_span)
@@ -273,7 +275,7 @@ class AliyunDataTrace(BaseTraceInstance):
service_account = self.get_service_account_with_tenant(app_id)
session_factory = sessionmaker(bind=db.engine)
- workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
app_id=app_id,
@@ -456,6 +458,7 @@ class AliyunDataTrace(BaseTraceInstance):
),
status=status,
links=trace_metadata.links,
+ span_kind=SpanKind.SERVER,
)
self.trace_client.add_span(message_span)
@@ -475,6 +478,7 @@ class AliyunDataTrace(BaseTraceInstance):
),
status=status,
links=trace_metadata.links,
+ span_kind=SpanKind.SERVER if message_span_id is None else SpanKind.INTERNAL,
)
self.trace_client.add_span(workflow_span)
diff --git a/api/core/ops/aliyun_trace/data_exporter/traceclient.py b/api/core/ops/aliyun_trace/data_exporter/traceclient.py
index d3324f8f82..7624586367 100644
--- a/api/core/ops/aliyun_trace/data_exporter/traceclient.py
+++ b/api/core/ops/aliyun_trace/data_exporter/traceclient.py
@@ -166,7 +166,7 @@ class SpanBuilder:
attributes=span_data.attributes,
events=span_data.events,
links=span_data.links,
- kind=trace_api.SpanKind.INTERNAL,
+ kind=span_data.span_kind,
status=span_data.status,
start_time=span_data.start_time,
end_time=span_data.end_time,
diff --git a/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py b/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py
index 20ff2d0875..9078031490 100644
--- a/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py
+++ b/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py
@@ -4,7 +4,7 @@ from typing import Any
from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import Event
-from opentelemetry.trace import Status, StatusCode
+from opentelemetry.trace import SpanKind, Status, StatusCode
from pydantic import BaseModel, Field
@@ -34,3 +34,4 @@ class SpanData(BaseModel):
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
start_time: int | None = Field(..., description="The start time of the span in nanoseconds.")
end_time: int | None = Field(..., description="The end time of the span in nanoseconds.")
+ span_kind: SpanKind = Field(default=SpanKind.INTERNAL, description="The OpenTelemetry SpanKind for this span.")
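
The entity change above adds a `span_kind` field defaulting to `SpanKind.INTERNAL`, which the Aliyun exporter then overrides with `SpanKind.SERVER` for root message and workflow spans. A small illustrative sketch of carrying the kind through a Pydantic span payload (this `SpanData` is simplified, not the real entity):

```python
# Illustrative only: a simplified SpanData showing how a SpanKind travels with the payload.
from opentelemetry.trace import SpanKind
from pydantic import BaseModel, Field


class SpanData(BaseModel):
    name: str
    span_kind: SpanKind = Field(default=SpanKind.INTERNAL)


root = SpanData(name="message", span_kind=SpanKind.SERVER)  # top-level entry span
child = SpanData(name="workflow-node")                      # defaults to INTERNAL
print(root.span_kind, child.span_kind)
```
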
diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
index 347992fa0d..a7b73e032e 100644
--- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
+++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
@@ -6,7 +6,13 @@ from datetime import datetime, timedelta
from typing import Any, Union, cast
from urllib.parse import urlparse
-from openinference.semconv.trace import OpenInferenceMimeTypeValues, OpenInferenceSpanKindValues, SpanAttributes
+from openinference.semconv.trace import (
+ MessageAttributes,
+ OpenInferenceMimeTypeValues,
+ OpenInferenceSpanKindValues,
+ SpanAttributes,
+ ToolCallAttributes,
+)
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GrpcOTLPSpanExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HttpOTLPSpanExporter
from opentelemetry.sdk import trace as trace_sdk
@@ -95,14 +101,14 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[tra
def datetime_to_nanos(dt: datetime | None) -> int:
- """Convert datetime to nanoseconds since epoch. If None, use current time."""
+ """Convert datetime to nanoseconds since epoch for Arize/Phoenix."""
if dt is None:
dt = datetime.now()
return int(dt.timestamp() * 1_000_000_000)
def error_to_string(error: Exception | str | None) -> str:
- """Convert an error to a string with traceback information."""
+ """Convert an error to a string with traceback information for Arize/Phoenix."""
error_message = "Empty Stack Trace"
if error:
if isinstance(error, Exception):
@@ -114,7 +120,7 @@ def error_to_string(error: Exception | str | None) -> str:
def set_span_status(current_span: Span, error: Exception | str | None = None):
- """Set the status of the current span based on the presence of an error."""
+ """Set the status of the current span based on the presence of an error for Arize/Phoenix."""
if error:
error_string = error_to_string(error)
current_span.set_status(Status(StatusCode.ERROR, error_string))
@@ -138,10 +144,17 @@ def set_span_status(current_span: Span, error: Exception | str | None = None):
def safe_json_dumps(obj: Any) -> str:
- """A convenience wrapper around `json.dumps` that ensures that any object can be safely encoded."""
+ """A convenience wrapper to ensure that any object can be safely encoded for Arize/Phoenix."""
return json.dumps(obj, default=str, ensure_ascii=False)
+def wrap_span_metadata(metadata, **kwargs):
+ """Add common metadata to all trace entity types for Arize/Phoenix."""
+ metadata["created_from"] = "Dify"
+ metadata.update(kwargs)
+ return metadata
+
+
class ArizePhoenixDataTrace(BaseTraceInstance):
def __init__(
self,
@@ -183,16 +196,27 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
raise
def workflow_trace(self, trace_info: WorkflowTraceInfo):
- workflow_metadata = {
- "workflow_run_id": trace_info.workflow_run_id or "",
- "message_id": trace_info.message_id or "",
- "workflow_app_log_id": trace_info.workflow_app_log_id or "",
- "status": trace_info.workflow_run_status or "",
- "status_message": trace_info.error or "",
- "level": "ERROR" if trace_info.error else "DEFAULT",
- "total_tokens": trace_info.total_tokens or 0,
- }
- workflow_metadata.update(trace_info.metadata)
+ file_list = trace_info.file_list if isinstance(trace_info.file_list, list) else []
+
+ metadata = wrap_span_metadata(
+ trace_info.metadata,
+ trace_id=trace_info.trace_id or "",
+ message_id=trace_info.message_id or "",
+ status=trace_info.workflow_run_status or "",
+ status_message=trace_info.error or "",
+ level="ERROR" if trace_info.error else "DEFAULT",
+ trace_entity_type="workflow",
+ conversation_id=trace_info.conversation_id or "",
+ workflow_app_log_id=trace_info.workflow_app_log_id or "",
+ workflow_id=trace_info.workflow_id or "",
+ tenant_id=trace_info.tenant_id or "",
+ workflow_run_id=trace_info.workflow_run_id or "",
+ workflow_run_elapsed_time=trace_info.workflow_run_elapsed_time or 0,
+ workflow_run_version=trace_info.workflow_run_version or "",
+ total_tokens=trace_info.total_tokens or 0,
+ file_list=safe_json_dumps(file_list),
+ query=trace_info.query or "",
+ )
dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
self.ensure_root_span(dify_trace_id)
@@ -201,10 +225,12 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
workflow_span = self.tracer.start_span(
name=TraceTaskName.WORKFLOW_TRACE.value,
attributes={
- SpanAttributes.INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False),
- SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
- SpanAttributes.METADATA: json.dumps(workflow_metadata, ensure_ascii=False),
+ SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.workflow_run_inputs),
+ SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
+ SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.workflow_run_outputs),
+ SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
+ SpanAttributes.METADATA: safe_json_dumps(metadata),
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
},
start_time=datetime_to_nanos(trace_info.start_time),
@@ -257,6 +283,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
"app_id": app_id,
"app_name": node_execution.title,
"status": node_execution.status,
+ "status_message": node_execution.error or "",
"level": "ERROR" if node_execution.status == "failed" else "DEFAULT",
}
)
@@ -290,11 +317,11 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
node_span = self.tracer.start_span(
name=node_execution.node_type,
attributes={
+ SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value,
SpanAttributes.INPUT_VALUE: safe_json_dumps(inputs_value),
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(outputs_value),
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
- SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value,
SpanAttributes.METADATA: safe_json_dumps(node_metadata),
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
},
@@ -339,30 +366,37 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
def message_trace(self, trace_info: MessageTraceInfo):
if trace_info.message_data is None:
+ logger.warning("[Arize/Phoenix] Message data is None, skipping message trace.")
return
- file_list = cast(list[str], trace_info.file_list) or []
+ file_list = trace_info.file_list if isinstance(trace_info.file_list, list) else []
message_file_data: MessageFile | None = trace_info.message_file_data
if message_file_data is not None:
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
file_list.append(file_url)
- message_metadata = {
- "message_id": trace_info.message_id or "",
- "conversation_mode": str(trace_info.conversation_mode or ""),
- "user_id": trace_info.message_data.from_account_id or "",
- "file_list": json.dumps(file_list),
- "status": trace_info.message_data.status or "",
- "status_message": trace_info.error or "",
- "level": "ERROR" if trace_info.error else "DEFAULT",
- "total_tokens": trace_info.total_tokens or 0,
- "prompt_tokens": trace_info.message_tokens or 0,
- "completion_tokens": trace_info.answer_tokens or 0,
- "ls_provider": trace_info.message_data.model_provider or "",
- "ls_model_name": trace_info.message_data.model_id or "",
- }
- message_metadata.update(trace_info.metadata)
+ metadata = wrap_span_metadata(
+ trace_info.metadata,
+ trace_id=trace_info.trace_id or "",
+ message_id=trace_info.message_id or "",
+ status=trace_info.message_data.status or "",
+ status_message=trace_info.error or "",
+ level="ERROR" if trace_info.error else "DEFAULT",
+ trace_entity_type="message",
+ conversation_model=trace_info.conversation_model or "",
+ message_tokens=trace_info.message_tokens or 0,
+ answer_tokens=trace_info.answer_tokens or 0,
+ total_tokens=trace_info.total_tokens or 0,
+ conversation_mode=trace_info.conversation_mode or "",
+ gen_ai_server_time_to_first_token=trace_info.gen_ai_server_time_to_first_token or 0,
+ llm_streaming_time_to_generate=trace_info.llm_streaming_time_to_generate or 0,
+ is_streaming_request=trace_info.is_streaming_request or False,
+ user_id=trace_info.message_data.from_account_id or "",
+ file_list=safe_json_dumps(file_list),
+ model_provider=trace_info.message_data.model_provider or "",
+ model_id=trace_info.message_data.model_id or "",
+ )
# Add end user data if available
if trace_info.message_data.from_end_user_id:
@@ -370,14 +404,16 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
db.session.query(EndUser).where(EndUser.id == trace_info.message_data.from_end_user_id).first()
)
if end_user_data is not None:
- message_metadata["end_user_id"] = end_user_data.session_id
+ metadata["end_user_id"] = end_user_data.session_id
attributes = {
- SpanAttributes.INPUT_VALUE: trace_info.message_data.query,
- SpanAttributes.OUTPUT_VALUE: trace_info.message_data.answer,
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
- SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
- SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
+ SpanAttributes.INPUT_VALUE: trace_info.message_data.query,
+ SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value,
+ SpanAttributes.OUTPUT_VALUE: trace_info.message_data.answer,
+ SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value,
+ SpanAttributes.METADATA: safe_json_dumps(metadata),
+ SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id or "",
}
dify_trace_id = trace_info.trace_id or trace_info.message_id
@@ -393,8 +429,10 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
try:
# Convert outputs to string based on type
+ outputs_mime_type = OpenInferenceMimeTypeValues.TEXT.value
if isinstance(trace_info.outputs, dict | list):
- outputs_str = json.dumps(trace_info.outputs, ensure_ascii=False)
+ outputs_str = safe_json_dumps(trace_info.outputs)
+ outputs_mime_type = OpenInferenceMimeTypeValues.JSON.value
elif isinstance(trace_info.outputs, str):
outputs_str = trace_info.outputs
else:
@@ -402,10 +440,12 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
llm_attributes = {
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.LLM.value,
- SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
+ SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs),
+ SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.OUTPUT_VALUE: outputs_str,
- SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
- SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
+ SpanAttributes.OUTPUT_MIME_TYPE: outputs_mime_type,
+ SpanAttributes.METADATA: safe_json_dumps(metadata),
+ SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id or "",
}
llm_attributes.update(self._construct_llm_attributes(trace_info.inputs))
if trace_info.total_tokens is not None and trace_info.total_tokens > 0:
@@ -449,16 +489,20 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
def moderation_trace(self, trace_info: ModerationTraceInfo):
if trace_info.message_data is None:
+ logger.warning("[Arize/Phoenix] Message data is None, skipping moderation trace.")
return
- metadata = {
- "message_id": trace_info.message_id,
- "tool_name": "moderation",
- "status": trace_info.message_data.status,
- "status_message": trace_info.message_data.error or "",
- "level": "ERROR" if trace_info.message_data.error else "DEFAULT",
- }
- metadata.update(trace_info.metadata)
+ metadata = wrap_span_metadata(
+ trace_info.metadata,
+ trace_id=trace_info.trace_id or "",
+ message_id=trace_info.message_id or "",
+ status=trace_info.message_data.status or "",
+ status_message=trace_info.message_data.error or "",
+ level="ERROR" if trace_info.message_data.error else "DEFAULT",
+ trace_entity_type="moderation",
+ model_provider=trace_info.message_data.model_provider or "",
+ model_id=trace_info.message_data.model_id or "",
+ )
dify_trace_id = trace_info.trace_id or trace_info.message_id
self.ensure_root_span(dify_trace_id)
@@ -467,18 +511,19 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
span = self.tracer.start_span(
name=TraceTaskName.MODERATION_TRACE.value,
attributes={
- SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
- SpanAttributes.OUTPUT_VALUE: json.dumps(
+ SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value,
+ SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs),
+ SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
+ SpanAttributes.OUTPUT_VALUE: safe_json_dumps(
{
- "action": trace_info.action,
"flagged": trace_info.flagged,
+ "action": trace_info.action,
"preset_response": trace_info.preset_response,
- "inputs": trace_info.inputs,
- },
- ensure_ascii=False,
+ "query": trace_info.query,
+ }
),
- SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
- SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
+ SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
+ SpanAttributes.METADATA: safe_json_dumps(metadata),
},
start_time=datetime_to_nanos(trace_info.start_time),
context=root_span_context,
@@ -494,22 +539,28 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
if trace_info.message_data is None:
+ logger.warning("[Arize/Phoenix] Message data is None, skipping suggested question trace.")
return
start_time = trace_info.start_time or trace_info.message_data.created_at
end_time = trace_info.end_time or trace_info.message_data.updated_at
- metadata = {
- "message_id": trace_info.message_id,
- "tool_name": "suggested_question",
- "status": trace_info.status,
- "status_message": trace_info.error or "",
- "level": "ERROR" if trace_info.error else "DEFAULT",
- "total_tokens": trace_info.total_tokens,
- "ls_provider": trace_info.model_provider or "",
- "ls_model_name": trace_info.model_id or "",
- }
- metadata.update(trace_info.metadata)
+ metadata = wrap_span_metadata(
+ trace_info.metadata,
+ trace_id=trace_info.trace_id or "",
+ message_id=trace_info.message_id or "",
+ status=trace_info.status or "",
+ status_message=trace_info.status_message or "",
+ level=trace_info.level or "",
+ trace_entity_type="suggested_question",
+ total_tokens=trace_info.total_tokens or 0,
+ from_account_id=trace_info.from_account_id or "",
+ agent_based=trace_info.agent_based or False,
+ from_source=trace_info.from_source or "",
+ model_provider=trace_info.model_provider or "",
+ model_id=trace_info.model_id or "",
+ workflow_run_id=trace_info.workflow_run_id or "",
+ )
dify_trace_id = trace_info.trace_id or trace_info.message_id
self.ensure_root_span(dify_trace_id)
@@ -518,10 +569,12 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
span = self.tracer.start_span(
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
attributes={
- SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
- SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False),
- SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
- SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
+ SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value,
+ SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs),
+ SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
+ SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.suggested_question),
+ SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
+ SpanAttributes.METADATA: safe_json_dumps(metadata),
},
start_time=datetime_to_nanos(start_time),
context=root_span_context,
@@ -537,21 +590,23 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
if trace_info.message_data is None:
+ logger.warning("[Arize/Phoenix] Message data is None, skipping dataset retrieval trace.")
return
start_time = trace_info.start_time or trace_info.message_data.created_at
end_time = trace_info.end_time or trace_info.message_data.updated_at
- metadata = {
- "message_id": trace_info.message_id,
- "tool_name": "dataset_retrieval",
- "status": trace_info.message_data.status,
- "status_message": trace_info.message_data.error or "",
- "level": "ERROR" if trace_info.message_data.error else "DEFAULT",
- "ls_provider": trace_info.message_data.model_provider or "",
- "ls_model_name": trace_info.message_data.model_id or "",
- }
- metadata.update(trace_info.metadata)
+ metadata = wrap_span_metadata(
+ trace_info.metadata,
+ trace_id=trace_info.trace_id or "",
+ message_id=trace_info.message_id or "",
+ status=trace_info.message_data.status or "",
+ status_message=trace_info.error or "",
+ level="ERROR" if trace_info.error else "DEFAULT",
+ trace_entity_type="dataset_retrieval",
+ model_provider=trace_info.message_data.model_provider or "",
+ model_id=trace_info.message_data.model_id or "",
+ )
dify_trace_id = trace_info.trace_id or trace_info.message_id
self.ensure_root_span(dify_trace_id)
@@ -560,20 +615,20 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
span = self.tracer.start_span(
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
attributes={
- SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
- SpanAttributes.OUTPUT_VALUE: json.dumps({"documents": trace_info.documents}, ensure_ascii=False),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.RETRIEVER.value,
- SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
- "start_time": start_time.isoformat() if start_time else "",
- "end_time": end_time.isoformat() if end_time else "",
+ SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs),
+ SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
+ SpanAttributes.OUTPUT_VALUE: safe_json_dumps({"documents": trace_info.documents}),
+ SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
+ SpanAttributes.METADATA: safe_json_dumps(metadata),
},
start_time=datetime_to_nanos(start_time),
context=root_span_context,
)
try:
- if trace_info.message_data.error:
- set_span_status(span, trace_info.message_data.error)
+ if trace_info.error:
+ set_span_status(span, trace_info.error)
else:
set_span_status(span)
finally:
@@ -584,30 +639,34 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
logger.warning("[Arize/Phoenix] Message data is None, skipping tool trace.")
return
- metadata = {
- "message_id": trace_info.message_id,
- "tool_config": json.dumps(trace_info.tool_config, ensure_ascii=False),
- }
+ metadata = wrap_span_metadata(
+ trace_info.metadata,
+ trace_id=trace_info.trace_id or "",
+ message_id=trace_info.message_id or "",
+ status=trace_info.message_data.status or "",
+ status_message=trace_info.error or "",
+ level="ERROR" if trace_info.error else "DEFAULT",
+ trace_entity_type="tool",
+ tool_config=safe_json_dumps(trace_info.tool_config),
+ time_cost=trace_info.time_cost or 0,
+ file_url=trace_info.file_url or "",
+ )
dify_trace_id = trace_info.trace_id or trace_info.message_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
- tool_params_str = (
- json.dumps(trace_info.tool_parameters, ensure_ascii=False)
- if isinstance(trace_info.tool_parameters, dict)
- else str(trace_info.tool_parameters)
- )
-
span = self.tracer.start_span(
name=trace_info.tool_name,
attributes={
- SpanAttributes.INPUT_VALUE: json.dumps(trace_info.tool_inputs, ensure_ascii=False),
- SpanAttributes.OUTPUT_VALUE: trace_info.tool_outputs,
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value,
- SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
+ SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.tool_inputs),
+ SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
+ SpanAttributes.OUTPUT_VALUE: trace_info.tool_outputs,
+ SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value,
+ SpanAttributes.METADATA: safe_json_dumps(metadata),
SpanAttributes.TOOL_NAME: trace_info.tool_name,
- SpanAttributes.TOOL_PARAMETERS: tool_params_str,
+ SpanAttributes.TOOL_PARAMETERS: safe_json_dumps(trace_info.tool_parameters),
},
start_time=datetime_to_nanos(trace_info.start_time),
context=root_span_context,
@@ -623,16 +682,22 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
if trace_info.message_data is None:
+ logger.warning("[Arize/Phoenix] Message data is None, skipping generate name trace.")
return
- metadata = {
- "project_name": self.project,
- "message_id": trace_info.message_id,
- "status": trace_info.message_data.status,
- "status_message": trace_info.message_data.error or "",
- "level": "ERROR" if trace_info.message_data.error else "DEFAULT",
- }
- metadata.update(trace_info.metadata)
+ metadata = wrap_span_metadata(
+ trace_info.metadata,
+ trace_id=trace_info.trace_id or "",
+ message_id=trace_info.message_id or "",
+ status=trace_info.message_data.status or "",
+ status_message=trace_info.message_data.error or "",
+ level="ERROR" if trace_info.message_data.error else "DEFAULT",
+ trace_entity_type="generate_name",
+ model_provider=trace_info.message_data.model_provider or "",
+ model_id=trace_info.message_data.model_id or "",
+ conversation_id=trace_info.conversation_id or "",
+ tenant_id=trace_info.tenant_id,
+ )
dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.conversation_id
self.ensure_root_span(dify_trace_id)
@@ -641,13 +706,13 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
span = self.tracer.start_span(
name=TraceTaskName.GENERATE_NAME_TRACE.value,
attributes={
- SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
- SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.outputs, ensure_ascii=False),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
- SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
- SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
- "start_time": trace_info.start_time.isoformat() if trace_info.start_time else "",
- "end_time": trace_info.end_time.isoformat() if trace_info.end_time else "",
+ SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs),
+ SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
+ SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.outputs),
+ SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
+ SpanAttributes.METADATA: safe_json_dumps(metadata),
+ SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
},
start_time=datetime_to_nanos(trace_info.start_time),
context=root_span_context,
@@ -688,32 +753,85 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
raise ValueError(f"[Arize/Phoenix] API check failed: {str(e)}")
def get_project_url(self):
+ """Build a redirect URL that forwards the user to the correct project for Arize/Phoenix."""
try:
- if self.arize_phoenix_config.endpoint == "https://otlp.arize.com":
- return "https://app.arize.com/"
- else:
- return f"{self.arize_phoenix_config.endpoint}/projects/"
+ project_name = self.arize_phoenix_config.project
+ endpoint = self.arize_phoenix_config.endpoint.rstrip("/")
+
+ # Arize
+ if isinstance(self.arize_phoenix_config, ArizeConfig):
+ return f"https://app.arize.com/?redirect_project_name={project_name}"
+
+ # Phoenix
+ return f"{endpoint}/projects/?redirect_project_name={project_name}"
+
except Exception as e:
- logger.info("[Arize/Phoenix] Get run url failed: %s", str(e), exc_info=True)
- raise ValueError(f"[Arize/Phoenix] Get run url failed: {str(e)}")
+ logger.info("[Arize/Phoenix] Failed to construct project URL: %s", str(e), exc_info=True)
+ raise ValueError(f"[Arize/Phoenix] Failed to construct project URL: {str(e)}")
def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
- """Helper method to construct LLM attributes with passed prompts."""
- attributes = {}
+ """Construct LLM attributes with passed prompts for Arize/Phoenix."""
+ attributes: dict[str, str] = {}
+
+ def set_attribute(path: str, value: object) -> None:
+ """Store an attribute safely as a string."""
+ if value is None:
+ return
+ try:
+ if isinstance(value, (dict, list)):
+ value = safe_json_dumps(value)
+ attributes[path] = str(value)
+ except Exception:
+ attributes[path] = str(value)
+
+ def set_message_attribute(message_index: int, key: str, value: object) -> None:
+ path = f"{SpanAttributes.LLM_INPUT_MESSAGES}.{message_index}.{key}"
+ set_attribute(path, value)
+
+ def set_tool_call_attributes(message_index: int, tool_index: int, tool_call: dict | object | None) -> None:
+ """Extract and assign tool call details safely."""
+ if not tool_call:
+ return
+
+ def safe_get(obj, key, default=None):
+ if isinstance(obj, dict):
+ return obj.get(key, default)
+ return getattr(obj, key, default)
+
+ function_obj = safe_get(tool_call, "function", {})
+ function_name = safe_get(function_obj, "name", "")
+ function_args = safe_get(function_obj, "arguments", {})
+ call_id = safe_get(tool_call, "id", "")
+
+ base_path = (
+ f"{SpanAttributes.LLM_INPUT_MESSAGES}."
+ f"{message_index}.{MessageAttributes.MESSAGE_TOOL_CALLS}.{tool_index}"
+ )
+
+ set_attribute(f"{base_path}.{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}", function_name)
+ set_attribute(f"{base_path}.{ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON}", function_args)
+ set_attribute(f"{base_path}.{ToolCallAttributes.TOOL_CALL_ID}", call_id)
+
+ # Handle list of messages
if isinstance(prompts, list):
- for i, msg in enumerate(prompts):
- if isinstance(msg, dict):
- attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "")
- attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get("role", "user")
- # todo: handle assistant and tool role messages, as they don't always
- # have a text field, but may have a tool_calls field instead
- # e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58',
- # 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]}
- elif isinstance(prompts, dict):
- attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(prompts)
- attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
- elif isinstance(prompts, str):
- attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = prompts
- attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
+ for message_index, message in enumerate(prompts):
+ if not isinstance(message, dict):
+ continue
+
+ role = message.get("role", "user")
+ content = message.get("text") or message.get("content") or ""
+
+ set_message_attribute(message_index, MessageAttributes.MESSAGE_ROLE, role)
+ set_message_attribute(message_index, MessageAttributes.MESSAGE_CONTENT, content)
+
+ tool_calls = message.get("tool_calls") or []
+ if isinstance(tool_calls, list):
+ for tool_index, tool_call in enumerate(tool_calls):
+ set_tool_call_attributes(message_index, tool_index, tool_call)
+
+ # Handle single dict or plain string prompt
+ elif isinstance(prompts, (dict, str)):
+ set_message_attribute(0, MessageAttributes.MESSAGE_CONTENT, prompts)
+ set_message_attribute(0, MessageAttributes.MESSAGE_ROLE, "user")
return attributes
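
The rewritten `_construct_llm_attributes` flattens each prompt message, and any tool calls it carries, into dotted OpenInference attribute keys. A short sketch of the key shapes produced for one assistant message with a single tool call; the literal keys shown in the comments assume the usual openinference-semantic-conventions values and should be treated as illustrative:

```python
# Sketch of the flattened attribute keys for one assistant message carrying a tool call.
from openinference.semconv.trace import MessageAttributes, SpanAttributes, ToolCallAttributes

msg_index, tool_index = 0, 0
msg_prefix = f"{SpanAttributes.LLM_INPUT_MESSAGES}.{msg_index}"
tool_prefix = f"{msg_prefix}.{MessageAttributes.MESSAGE_TOOL_CALLS}.{tool_index}"

attributes = {
    f"{msg_prefix}.{MessageAttributes.MESSAGE_ROLE}": "assistant",  # e.g. llm.input_messages.0.message.role
    f"{msg_prefix}.{MessageAttributes.MESSAGE_CONTENT}": "",        # e.g. llm.input_messages.0.message.content
    f"{tool_prefix}.{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}": "current_time",
    f"{tool_prefix}.{ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON}": "{}",
    f"{tool_prefix}.{ToolCallAttributes.TOOL_CALL_ID}": "call-1",
}
for key, value in attributes.items():
    print(key, "=", value)
```
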
diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py
index f45f15a6da..84f5bf5512 100644
--- a/api/core/ops/ops_trace_manager.py
+++ b/api/core/ops/ops_trace_manager.py
@@ -35,7 +35,6 @@ from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
from models.workflow import WorkflowAppLog
-from repositories.factory import DifyAPIRepositoryFactory
from tasks.ops_trace_task import process_trace_tasks
if TYPE_CHECKING:
@@ -473,6 +472,9 @@ class TraceTask:
if cls._workflow_run_repo is None:
with cls._repo_lock:
if cls._workflow_run_repo is None:
+ # Lazy import to avoid circular import during module initialization
+ from repositories.factory import DifyAPIRepositoryFactory
+
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
cls._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
return cls._workflow_run_repo
diff --git a/api/core/ops/utils.py b/api/core/ops/utils.py
index c00f785034..631e3b77b2 100644
--- a/api/core/ops/utils.py
+++ b/api/core/ops/utils.py
@@ -54,7 +54,7 @@ def generate_dotted_order(run_id: str, start_time: Union[str, datetime], parent_
generate dotted_order for langsmith
"""
start_time = datetime.fromisoformat(start_time) if isinstance(start_time, str) else start_time
- timestamp = start_time.strftime("%Y%m%dT%H%M%S%f")[:-3] + "Z"
+ timestamp = start_time.strftime("%Y%m%dT%H%M%S%f") + "Z"
current_segment = f"{timestamp}{run_id}"
if parent_dotted_order is None:
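
Dropping the `[:-3]` slice keeps the full six microsecond digits in the LangSmith dotted-order timestamp instead of truncating to milliseconds. A tiny worked example of the before/after output:

```python
# Worked example of the timestamp precision change in generate_dotted_order.
from datetime import datetime

start_time = datetime(2024, 5, 1, 12, 30, 45, 123456)

old = start_time.strftime("%Y%m%dT%H%M%S%f")[:-3] + "Z"  # '20240501T123045123Z' (milliseconds)
new = start_time.strftime("%Y%m%dT%H%M%S%f") + "Z"       # '20240501T123045123456Z' (microseconds)
print(old, new)
```
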
diff --git a/api/core/plugin/entities/parameters.py b/api/core/plugin/entities/parameters.py
index 88a3a7bd43..bfa662b9f6 100644
--- a/api/core/plugin/entities/parameters.py
+++ b/api/core/plugin/entities/parameters.py
@@ -76,7 +76,7 @@ class PluginParameter(BaseModel):
auto_generate: PluginParameterAutoGenerate | None = None
template: PluginParameterTemplate | None = None
required: bool = False
- default: Union[float, int, str, bool] | None = None
+ default: Union[float, int, str, bool, list, dict] | None = None
min: Union[float, int] | None = None
max: Union[float, int] | None = None
precision: int | None = None
diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py
index 3b83121357..6674228dc0 100644
--- a/api/core/plugin/entities/plugin_daemon.py
+++ b/api/core/plugin/entities/plugin_daemon.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import enum
from collections.abc import Mapping, Sequence
from datetime import datetime
@@ -242,7 +244,7 @@ class CredentialType(enum.StrEnum):
return [item.value for item in cls]
@classmethod
- def of(cls, credential_type: str) -> "CredentialType":
+ def of(cls, credential_type: str) -> CredentialType:
type_name = credential_type.lower()
if type_name in {"api-key", "api_key"}:
return cls.API_KEY
diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py
index a1c84bd5d9..7a6a598a2f 100644
--- a/api/core/plugin/impl/base.py
+++ b/api/core/plugin/impl/base.py
@@ -39,7 +39,7 @@ from core.trigger.errors import (
plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL))
_plugin_daemon_timeout_config = cast(
float | httpx.Timeout | None,
- getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 300.0),
+ getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 600.0),
)
plugin_daemon_request_timeout: httpx.Timeout | None
if _plugin_daemon_timeout_config is None:
@@ -103,6 +103,9 @@ class BasePluginClient:
prepared_headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
prepared_headers.setdefault("Accept-Encoding", "gzip, deflate, br")
+ # Inject traceparent header for distributed tracing
+ self._inject_trace_headers(prepared_headers)
+
prepared_data: bytes | dict[str, Any] | str | None = (
data if isinstance(data, (bytes, str, dict)) or data is None else None
)
@@ -114,6 +117,31 @@ class BasePluginClient:
return str(url), prepared_headers, prepared_data, params, files
+ def _inject_trace_headers(self, headers: dict[str, str]) -> None:
+ """
+ Inject W3C traceparent header for distributed tracing.
+
+ This ensures trace context is propagated to the plugin daemon even if
+ HTTPXClientInstrumentor doesn't cover module-level httpx functions.
+ """
+ if not dify_config.ENABLE_OTEL:
+ return
+
+ import contextlib
+
+ # Skip if already present (case-insensitive check)
+ for key in headers:
+ if key.lower() == "traceparent":
+ return
+
+ # Inject traceparent - works as fallback when OTEL instrumentation doesn't cover this call
+ with contextlib.suppress(Exception):
+ from core.helper.trace_id_helper import generate_traceparent_header
+
+ traceparent = generate_traceparent_header()
+ if traceparent:
+ headers["traceparent"] = traceparent
+
def _stream_request(
self,
method: str,
@@ -292,18 +320,17 @@ class BasePluginClient:
case PluginInvokeError.__name__:
error_object = json.loads(message)
invoke_error_type = error_object.get("error_type")
- args = error_object.get("args")
match invoke_error_type:
case InvokeRateLimitError.__name__:
- raise InvokeRateLimitError(description=args.get("description"))
+ raise InvokeRateLimitError(description=error_object.get("message"))
case InvokeAuthorizationError.__name__:
- raise InvokeAuthorizationError(description=args.get("description"))
+ raise InvokeAuthorizationError(description=error_object.get("message"))
case InvokeBadRequestError.__name__:
- raise InvokeBadRequestError(description=args.get("description"))
+ raise InvokeBadRequestError(description=error_object.get("message"))
case InvokeConnectionError.__name__:
- raise InvokeConnectionError(description=args.get("description"))
+ raise InvokeConnectionError(description=error_object.get("message"))
case InvokeServerUnavailableError.__name__:
- raise InvokeServerUnavailableError(description=args.get("description"))
+ raise InvokeServerUnavailableError(description=error_object.get("message"))
case CredentialsValidateFailedError.__name__:
raise CredentialsValidateFailedError(error_object.get("message"))
case EndpointSetupFailedError.__name__:
@@ -311,11 +338,11 @@ class BasePluginClient:
case TriggerProviderCredentialValidationError.__name__:
raise TriggerProviderCredentialValidationError(error_object.get("message"))
case TriggerPluginInvokeError.__name__:
- raise TriggerPluginInvokeError(description=error_object.get("description"))
+ raise TriggerPluginInvokeError(description=error_object.get("message"))
case TriggerInvokeError.__name__:
raise TriggerInvokeError(error_object.get("message"))
case EventIgnoreError.__name__:
- raise EventIgnoreError(description=error_object.get("description"))
+ raise EventIgnoreError(description=error_object.get("message"))
case _:
raise PluginInvokeError(description=message)
case PluginDaemonInternalServerError.__name__:
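
The `_inject_trace_headers` fallback above adds a W3C `traceparent` header only when OTEL is enabled and no such header is already present. A minimal sketch of the same idea using the standard OpenTelemetry propagator API; the real helper in `core.helper.trace_id_helper` may build the header differently:

```python
# Minimal sketch of traceparent injection via the standard OpenTelemetry propagator;
# the real code delegates to generate_traceparent_header in core.helper.trace_id_helper.
from opentelemetry.propagate import inject


def inject_trace_headers(headers: dict[str, str]) -> None:
    # Respect a header the caller already set (case-insensitive check).
    if any(key.lower() == "traceparent" for key in headers):
        return
    # The global propagator writes "traceparent" (and "tracestate" when present)
    # for the currently active span context, or nothing if there is none.
    inject(headers)


headers: dict[str, str] = {"Content-Type": "application/json"}
inject_trace_headers(headers)
print(headers)
```
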
diff --git a/api/core/plugin/impl/endpoint.py b/api/core/plugin/impl/endpoint.py
index 5b88742be5..2db5185a2c 100644
--- a/api/core/plugin/impl/endpoint.py
+++ b/api/core/plugin/impl/endpoint.py
@@ -1,5 +1,6 @@
from core.plugin.entities.endpoint import EndpointEntityWithInstance
from core.plugin.impl.base import BasePluginClient
+from core.plugin.impl.exc import PluginDaemonInternalServerError
class PluginEndpointClient(BasePluginClient):
@@ -70,18 +71,27 @@ class PluginEndpointClient(BasePluginClient):
def delete_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str):
"""
Delete the given endpoint.
+
+ This operation is idempotent: if the endpoint is already deleted (record not found),
+ it will return True instead of raising an error.
"""
- return self._request_with_plugin_daemon_response(
- "POST",
- f"plugin/{tenant_id}/endpoint/remove",
- bool,
- data={
- "endpoint_id": endpoint_id,
- },
- headers={
- "Content-Type": "application/json",
- },
- )
+ try:
+ return self._request_with_plugin_daemon_response(
+ "POST",
+ f"plugin/{tenant_id}/endpoint/remove",
+ bool,
+ data={
+ "endpoint_id": endpoint_id,
+ },
+ headers={
+ "Content-Type": "application/json",
+ },
+ )
+ except PluginDaemonInternalServerError as e:
+ # Make delete idempotent: if record is not found, consider it a success
+ if "record not found" in str(e.description).lower():
+ return True
+ raise
def enable_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str):
"""
diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py
index 6c818bdc8b..fdbfca4330 100644
--- a/api/core/provider_manager.py
+++ b/api/core/provider_manager.py
@@ -331,7 +331,6 @@ class ProviderManager:
provider=provider_schema.provider,
label=provider_schema.label,
icon_small=provider_schema.icon_small,
- icon_large=provider_schema.icon_large,
supported_model_types=provider_schema.supported_model_types,
),
)
@@ -619,18 +618,18 @@ class ProviderManager:
)
for quota in configuration.quotas:
- if quota.quota_type == ProviderQuotaType.TRIAL:
+ if quota.quota_type in (ProviderQuotaType.TRIAL, ProviderQuotaType.PAID):
# Init trial provider records if not exists
- if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
+ if quota.quota_type not in provider_quota_to_provider_record_dict:
try:
# FIXME: ignore the type error; only TrialHostingQuota has a limit, need to change the logic
new_provider_record = Provider(
tenant_id=tenant_id,
# TODO: Use provider name with prefix after the data migration.
provider_name=ModelProviderID(provider_name).provider_name,
- provider_type=ProviderType.SYSTEM,
- quota_type=ProviderQuotaType.TRIAL,
- quota_limit=quota.quota_limit, # type: ignore
+ provider_type=ProviderType.SYSTEM.value,
+ quota_type=quota.quota_type,
+ quota_limit=0, # type: ignore
quota_used=0,
is_valid=True,
)
@@ -642,8 +641,8 @@ class ProviderManager:
stmt = select(Provider).where(
Provider.tenant_id == tenant_id,
Provider.provider_name == ModelProviderID(provider_name).provider_name,
- Provider.provider_type == ProviderType.SYSTEM,
- Provider.quota_type == ProviderQuotaType.TRIAL,
+ Provider.provider_type == ProviderType.SYSTEM.value,
+ Provider.quota_type == quota.quota_type,
)
existed_provider_record = db.session.scalar(stmt)
if not existed_provider_record:
@@ -913,6 +912,22 @@ class ProviderManager:
provider_record
)
quota_configurations = []
+
+ if dify_config.EDITION == "CLOUD":
+ from services.credit_pool_service import CreditPoolService
+
+ trial_pool = CreditPoolService.get_pool(
+ tenant_id=tenant_id,
+ pool_type=ProviderQuotaType.TRIAL.value,
+ )
+ paid_pool = CreditPoolService.get_pool(
+ tenant_id=tenant_id,
+ pool_type=ProviderQuotaType.PAID.value,
+ )
+ else:
+ trial_pool = None
+ paid_pool = None
+
for provider_quota in provider_hosting_configuration.quotas:
if provider_quota.quota_type not in quota_type_to_provider_records_dict:
if provider_quota.quota_type == ProviderQuotaType.FREE:
@@ -933,16 +948,36 @@ class ProviderManager:
raise ValueError("quota_used is None")
if provider_record.quota_limit is None:
raise ValueError("quota_limit is None")
+ if provider_quota.quota_type == ProviderQuotaType.TRIAL and trial_pool is not None:
+ quota_configuration = QuotaConfiguration(
+ quota_type=provider_quota.quota_type,
+ quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
+ quota_used=trial_pool.quota_used,
+ quota_limit=trial_pool.quota_limit,
+ is_valid=trial_pool.quota_limit > trial_pool.quota_used or trial_pool.quota_limit == -1,
+ restrict_models=provider_quota.restrict_models,
+ )
- quota_configuration = QuotaConfiguration(
- quota_type=provider_quota.quota_type,
- quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
- quota_used=provider_record.quota_used,
- quota_limit=provider_record.quota_limit,
- is_valid=provider_record.quota_limit > provider_record.quota_used
- or provider_record.quota_limit == -1,
- restrict_models=provider_quota.restrict_models,
- )
+ elif provider_quota.quota_type == ProviderQuotaType.PAID and paid_pool is not None:
+ quota_configuration = QuotaConfiguration(
+ quota_type=provider_quota.quota_type,
+ quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
+ quota_used=paid_pool.quota_used,
+ quota_limit=paid_pool.quota_limit,
+ is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1,
+ restrict_models=provider_quota.restrict_models,
+ )
+
+ else:
+ quota_configuration = QuotaConfiguration(
+ quota_type=provider_quota.quota_type,
+ quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
+ quota_used=provider_record.quota_used,
+ quota_limit=provider_record.quota_limit,
+ is_valid=provider_record.quota_limit > provider_record.quota_used
+ or provider_record.quota_limit == -1,
+ restrict_models=provider_quota.restrict_models,
+ )
quota_configurations.append(quota_configuration)
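All three branches share the same validity rule: a quota remains usable while usage is below the limit, and a limit of -1 means unlimited. That rule, factored out as a tiny illustrative helper (not a function in the codebase):

```python
def quota_is_valid(quota_used: int, quota_limit: int) -> bool:
    """A quota stays valid while there is headroom; -1 marks an unlimited pool."""
    return quota_limit == -1 or quota_limit > quota_used


assert quota_is_valid(0, 200)        # fresh quota
assert quota_is_valid(10_000, -1)    # unlimited pool
assert not quota_is_valid(200, 200)  # exhausted
```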
diff --git a/api/core/rag/cleaner/clean_processor.py b/api/core/rag/cleaner/clean_processor.py
index 9cb009035b..e182c35b99 100644
--- a/api/core/rag/cleaner/clean_processor.py
+++ b/api/core/rag/cleaner/clean_processor.py
@@ -27,26 +27,44 @@ class CleanProcessor:
pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)"
text = re.sub(pattern, "", text)
- # Remove URL but keep Markdown image URLs
- # First, temporarily replace Markdown image URLs with a placeholder
- markdown_image_pattern = r"!\[.*?\]\((https?://[^\s)]+)\)"
- placeholders: list[str] = []
+ # Remove URL but keep Markdown image URLs and link URLs
+ # Replace the ENTIRE markdown link/image with a single placeholder to protect
+ # the link text (which might also be a URL) from being removed
+ markdown_link_pattern = r"\[([^\]]*)\]\((https?://[^)]+)\)"
+ markdown_image_pattern = r"!\[.*?\]\((https?://[^)]+)\)"
+ placeholders: list[tuple[str, str, str]] = [] # (type, text, url)
- def replace_with_placeholder(match, placeholders=placeholders):
+ def replace_markdown_with_placeholder(match, placeholders=placeholders):
+ link_type = "link"
+ link_text = match.group(1)
+ url = match.group(2)
+ placeholder = f"__MARKDOWN_PLACEHOLDER_{len(placeholders)}__"
+ placeholders.append((link_type, link_text, url))
+ return placeholder
+
+ def replace_image_with_placeholder(match, placeholders=placeholders):
+ link_type = "image"
url = match.group(1)
- placeholder = f"__MARKDOWN_IMAGE_URL_{len(placeholders)}__"
- placeholders.append(url)
- return f""
+ placeholder = f"__MARKDOWN_PLACEHOLDER_{len(placeholders)}__"
+ placeholders.append((link_type, "image", url))
+ return placeholder
- text = re.sub(markdown_image_pattern, replace_with_placeholder, text)
+ # Protect markdown links first
+ text = re.sub(markdown_link_pattern, replace_markdown_with_placeholder, text)
+ # Then protect markdown images
+ text = re.sub(markdown_image_pattern, replace_image_with_placeholder, text)
# Now remove all remaining URLs
- url_pattern = r"https?://[^\s)]+"
+ url_pattern = r"https?://\S+"
text = re.sub(url_pattern, "", text)
- # Finally, restore the Markdown image URLs
- for i, url in enumerate(placeholders):
- text = text.replace(f"__MARKDOWN_IMAGE_URL_{i}__", url)
+ # Restore the Markdown links and images
+ for i, (link_type, text_or_alt, url) in enumerate(placeholders):
+ placeholder = f"__MARKDOWN_PLACEHOLDER_{i}__"
+ if link_type == "link":
+ text = text.replace(placeholder, f"[{text_or_alt}]({url})")
+ else: # image
+ text = text.replace(placeholder, f"![{text_or_alt}]({url})")
return text
def filter_string(self, text):
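The protect/restore approach can be exercised on its own. A self-contained sketch with simplified patterns, same idea: swap markdown links and images for placeholders, strip every bare URL, then put the markdown back.

```python
import re


def clean_urls(text: str) -> str:
    placeholders: list[str] = []

    def protect(match: re.Match) -> str:
        # Stash the whole markdown link/image and leave an indexed marker behind.
        placeholders.append(match.group(0))
        return f"__MD_PLACEHOLDER_{len(placeholders) - 1}__"

    text = re.sub(r"!?\[[^\]]*\]\(https?://[^)]+\)", protect, text)
    text = re.sub(r"https?://\S+", "", text)  # drop remaining bare URLs
    for i, original in enumerate(placeholders):
        text = text.replace(f"__MD_PLACEHOLDER_{i}__", original)
    return text


print(clean_urls("see [docs](https://example.com) or https://tracker.example/spam"))
# -> "see [docs](https://example.com) or "
```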
diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py
index 97052717db..0f19ecadc8 100644
--- a/api/core/rag/datasource/keyword/jieba/jieba.py
+++ b/api/core/rag/datasource/keyword/jieba/jieba.py
@@ -90,13 +90,17 @@ class Jieba(BaseKeyword):
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k)
documents = []
+
+ segment_query_stmt = db.session.query(DocumentSegment).where(
+ DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id.in_(sorted_chunk_indices)
+ )
+ if document_ids_filter:
+ segment_query_stmt = segment_query_stmt.where(DocumentSegment.document_id.in_(document_ids_filter))
+
+ segments = db.session.execute(segment_query_stmt).scalars().all()
+ segment_map = {segment.index_node_id: segment for segment in segments}
for chunk_index in sorted_chunk_indices:
- segment_query = db.session.query(DocumentSegment).where(
- DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index
- )
- if document_ids_filter:
- segment_query = segment_query.where(DocumentSegment.document_id.in_(document_ids_filter))
- segment = segment_query.first()
+ segment = segment_map.get(chunk_index)
if segment:
documents.append(
diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py
index a139fba4d0..91c16ce079 100644
--- a/api/core/rag/datasource/retrieval_service.py
+++ b/api/core/rag/datasource/retrieval_service.py
@@ -1,4 +1,5 @@
import concurrent.futures
+import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any
@@ -7,12 +8,13 @@ from sqlalchemy import select
from sqlalchemy.orm import Session, load_only
from configs import dify_config
+from core.db.session_factory import session_factory
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
-from core.rag.embedding.retrieval import RetrievalSegments
+from core.rag.embedding.retrieval import RetrievalChildChunk, RetrievalSegments
from core.rag.entities.metadata_entities import MetadataCondition
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
@@ -22,7 +24,13 @@ from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.signature import sign_upload_file
from extensions.ext_database import db
-from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding
+from models.dataset import (
+ ChildChunk,
+ Dataset,
+ DocumentSegment,
+ DocumentSegmentSummary,
+ SegmentAttachmentBinding,
+)
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
from services.external_knowledge_service import ExternalDatasetService
@@ -35,6 +43,8 @@ default_retrieval_model = {
"score_threshold_enabled": False,
}
+logger = logging.getLogger(__name__)
+
class RetrievalService:
# Cache precompiled regular expressions to avoid repeated compilation
@@ -105,7 +115,12 @@ class RetrievalService:
)
)
- concurrent.futures.wait(futures, timeout=3600, return_when=concurrent.futures.ALL_COMPLETED)
+ if futures:
+ for future in concurrent.futures.as_completed(futures, timeout=3600):
+ if exceptions:
+ for f in futures:
+ f.cancel()
+ break
if exceptions:
raise ValueError(";\n".join(exceptions))
@@ -138,37 +153,47 @@ class RetrievalService:
@classmethod
def _deduplicate_documents(cls, documents: list[Document]) -> list[Document]:
- """Deduplicate documents based on doc_id to avoid duplicate chunks in hybrid search."""
+ """Deduplicate documents in O(n) while preserving first-seen order.
+
+ Rules:
+ - For provider == "dify" and metadata["doc_id"] exists: keep the doc with the highest
+ metadata["score"] among duplicates; if a later duplicate has no score, ignore it.
+ - For non-dify documents (or dify without doc_id): deduplicate by content key
+ (provider, page_content), keeping the first occurrence.
+ """
if not documents:
return documents
- unique_documents = []
- seen_doc_ids = set()
+ # Map of dedup key -> chosen Document
+ chosen: dict[tuple, Document] = {}
+ # Preserve the order of first appearance of each dedup key
+ order: list[tuple] = []
- for document in documents:
- # For dify provider documents, use doc_id for deduplication
- if document.provider == "dify" and document.metadata is not None and "doc_id" in document.metadata:
- doc_id = document.metadata["doc_id"]
- if doc_id not in seen_doc_ids:
- seen_doc_ids.add(doc_id)
- unique_documents.append(document)
- # If duplicate, keep the one with higher score
- elif "score" in document.metadata:
- # Find existing document with same doc_id and compare scores
- for i, existing_doc in enumerate(unique_documents):
- if (
- existing_doc.metadata
- and existing_doc.metadata.get("doc_id") == doc_id
- and existing_doc.metadata.get("score", 0) < document.metadata.get("score", 0)
- ):
- unique_documents[i] = document
- break
+ for doc in documents:
+ is_dify = doc.provider == "dify"
+ doc_id = (doc.metadata or {}).get("doc_id") if is_dify else None
+
+ if is_dify and doc_id:
+ key = ("dify", doc_id)
+ if key not in chosen:
+ chosen[key] = doc
+ order.append(key)
+ else:
+ # Only replace if the new one has a score and it's strictly higher
+ if "score" in doc.metadata:
+ new_score = float(doc.metadata.get("score", 0.0))
+ old_score = float(chosen[key].metadata.get("score", 0.0)) if chosen[key].metadata else 0.0
+ if new_score > old_score:
+ chosen[key] = doc
else:
- # For non-dify documents, use content-based deduplication
- if document not in unique_documents:
- unique_documents.append(document)
+ # Content-based dedup for non-dify or dify without doc_id
+ content_key = (doc.provider or "dify", doc.page_content)
+ if content_key not in chosen:
+ chosen[content_key] = doc
+ order.append(content_key)
+ # If duplicate content appears, we keep the first occurrence (no score comparison)
- return unique_documents
+ return [chosen[k] for k in order]
@classmethod
def _get_dataset(cls, dataset_id: str) -> Dataset | None:
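The rewritten `_deduplicate_documents` keeps one document per key, preferring the highest score among duplicates while preserving first-seen order. A runnable miniature of that bookkeeping with a stand-in document type (the real method additionally distinguishes dify doc_ids from content keys):

```python
from dataclasses import dataclass, field


@dataclass
class Doc:
    doc_id: str
    metadata: dict = field(default_factory=dict)


def dedupe(docs: list[Doc]) -> list[Doc]:
    chosen: dict[str, Doc] = {}
    order: list[str] = []
    for doc in docs:
        key = doc.doc_id
        if key not in chosen:
            chosen[key] = doc
            order.append(key)
        elif doc.metadata.get("score", 0.0) > chosen[key].metadata.get("score", 0.0):
            chosen[key] = doc  # keep the higher-scoring duplicate
    return [chosen[k] for k in order]


docs = [Doc("a", {"score": 0.4}), Doc("b", {"score": 0.9}), Doc("a", {"score": 0.8})]
assert [d.metadata["score"] for d in dedupe(docs)] == [0.8, 0.9]
```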
@@ -199,6 +224,7 @@ class RetrievalService:
)
all_documents.extend(documents)
except Exception as e:
+ logger.error(e, exc_info=True)
exceptions.append(str(e))
@classmethod
@@ -292,6 +318,7 @@ class RetrievalService:
else:
all_documents.extend(documents)
except Exception as e:
+ logger.error(e, exc_info=True)
exceptions.append(str(e))
@classmethod
@@ -340,6 +367,7 @@ class RetrievalService:
else:
all_documents.extend(documents)
except Exception as e:
+ logger.error(e, exc_info=True)
exceptions.append(str(e))
@staticmethod
@@ -367,174 +395,262 @@ class RetrievalService:
.all()
}
- records = []
- include_segment_ids = set()
- segment_child_map = {}
- segment_file_map = {}
- with Session(bind=db.engine, expire_on_commit=False) as session:
- # Process documents
- for document in documents:
- segment_id = None
- attachment_info = None
- child_chunk = None
- document_id = document.metadata.get("document_id")
- if document_id not in dataset_documents:
- continue
+ valid_dataset_documents = {}
+ image_doc_ids: list[Any] = []
+ child_index_node_ids = []
+ index_node_ids = []
+ doc_to_document_map = {}
+ summary_segment_ids = set() # Track segments retrieved via summary
+ summary_score_map: dict[str, float] = {} # Map original_chunk_id to summary score
- dataset_document = dataset_documents[document_id]
- if not dataset_document:
- continue
+ # First pass: collect all document IDs and identify summary documents
+ for document in documents:
+ document_id = document.metadata.get("document_id")
+ if document_id not in dataset_documents:
+ continue
- if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
- # Handle parent-child documents
- if document.metadata.get("doc_type") == DocType.IMAGE:
- attachment_info_dict = cls.get_segment_attachment_info(
- dataset_document.dataset_id,
- dataset_document.tenant_id,
- document.metadata.get("doc_id") or "",
- session,
- )
- if attachment_info_dict:
- attachment_info = attachment_info_dict["attachment_info"]
- segment_id = attachment_info_dict["segment_id"]
- else:
- child_index_node_id = document.metadata.get("doc_id")
- child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
- child_chunk = session.scalar(child_chunk_stmt)
+ dataset_document = dataset_documents[document_id]
+ if not dataset_document:
+ continue
+ valid_dataset_documents[document_id] = dataset_document
- if not child_chunk:
- continue
- segment_id = child_chunk.segment_id
+ doc_id = document.metadata.get("doc_id") or ""
+ doc_to_document_map[doc_id] = document
- if not segment_id:
- continue
-
- segment = (
- session.query(DocumentSegment)
- .where(
- DocumentSegment.dataset_id == dataset_document.dataset_id,
- DocumentSegment.enabled == True,
- DocumentSegment.status == "completed",
- DocumentSegment.id == segment_id,
- )
- .first()
- )
-
- if not segment:
- continue
-
- if segment.id not in include_segment_ids:
- include_segment_ids.add(segment.id)
- if child_chunk:
- child_chunk_detail = {
- "id": child_chunk.id,
- "content": child_chunk.content,
- "position": child_chunk.position,
- "score": document.metadata.get("score", 0.0),
- }
- map_detail = {
- "max_score": document.metadata.get("score", 0.0),
- "child_chunks": [child_chunk_detail],
- }
- segment_child_map[segment.id] = map_detail
- record = {
- "segment": segment,
- }
- if attachment_info:
- segment_file_map[segment.id] = [attachment_info]
- records.append(record)
- else:
- if child_chunk:
- child_chunk_detail = {
- "id": child_chunk.id,
- "content": child_chunk.content,
- "position": child_chunk.position,
- "score": document.metadata.get("score", 0.0),
- }
- if segment.id in segment_child_map:
- segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
- segment_child_map[segment.id]["max_score"] = max(
- segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
+ # Check if this is a summary document
+ is_summary = document.metadata.get("is_summary", False)
+ if is_summary:
+ # For summary documents, find the original chunk via original_chunk_id
+ original_chunk_id = document.metadata.get("original_chunk_id")
+ if original_chunk_id:
+ summary_segment_ids.add(original_chunk_id)
+ # Save summary's score for later use
+ summary_score = document.metadata.get("score")
+ if summary_score is not None:
+ try:
+ summary_score_float = float(summary_score)
+ # If the same segment has multiple summary hits, take the highest score
+ if original_chunk_id not in summary_score_map:
+ summary_score_map[original_chunk_id] = summary_score_float
+ else:
+ summary_score_map[original_chunk_id] = max(
+ summary_score_map[original_chunk_id], summary_score_float
)
- else:
- segment_child_map[segment.id] = {
- "max_score": document.metadata.get("score", 0.0),
- "child_chunks": [child_chunk_detail],
- }
- if attachment_info:
- if segment.id in segment_file_map:
- segment_file_map[segment.id].append(attachment_info)
- else:
- segment_file_map[segment.id] = [attachment_info]
- else:
- # Handle normal documents
- segment = None
- if document.metadata.get("doc_type") == DocType.IMAGE:
- attachment_info_dict = cls.get_segment_attachment_info(
- dataset_document.dataset_id,
- dataset_document.tenant_id,
- document.metadata.get("doc_id") or "",
- session,
- )
- if attachment_info_dict:
- attachment_info = attachment_info_dict["attachment_info"]
- segment_id = attachment_info_dict["segment_id"]
- document_segment_stmt = select(DocumentSegment).where(
- DocumentSegment.dataset_id == dataset_document.dataset_id,
- DocumentSegment.enabled == True,
- DocumentSegment.status == "completed",
- DocumentSegment.id == segment_id,
- )
- segment = session.scalar(document_segment_stmt)
- if segment:
- segment_file_map[segment.id] = [attachment_info]
- else:
- index_node_id = document.metadata.get("doc_id")
- if not index_node_id:
- continue
- document_segment_stmt = select(DocumentSegment).where(
- DocumentSegment.dataset_id == dataset_document.dataset_id,
- DocumentSegment.enabled == True,
- DocumentSegment.status == "completed",
- DocumentSegment.index_node_id == index_node_id,
- )
- segment = session.scalar(document_segment_stmt)
+ except (ValueError, TypeError):
+ # Skip invalid score values
+ pass
+ continue # Skip adding to other lists for summary documents
- if not segment:
- continue
- if segment.id not in include_segment_ids:
- include_segment_ids.add(segment.id)
- record = {
- "segment": segment,
- "score": document.metadata.get("score"), # type: ignore
+ if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
+ if document.metadata.get("doc_type") == DocType.IMAGE:
+ image_doc_ids.append(doc_id)
+ else:
+ child_index_node_ids.append(doc_id)
+ else:
+ if document.metadata.get("doc_type") == DocType.IMAGE:
+ image_doc_ids.append(doc_id)
+ else:
+ index_node_ids.append(doc_id)
+
+ image_doc_ids = [i for i in image_doc_ids if i]
+ child_index_node_ids = [i for i in child_index_node_ids if i]
+ index_node_ids = [i for i in index_node_ids if i]
+
+ segment_ids: list[str] = []
+ index_node_segments: list[DocumentSegment] = []
+ segments: list[DocumentSegment] = []
+ attachment_map: dict[str, list[dict[str, Any]]] = {}
+ child_chunk_map: dict[str, list[ChildChunk]] = {}
+ doc_segment_map: dict[str, list[str]] = {}
+ segment_summary_map: dict[str, str] = {} # Map segment_id to summary content
+
+ with session_factory.create_session() as session:
+ attachments = cls.get_segment_attachment_infos(image_doc_ids, session)
+
+ for attachment in attachments:
+ segment_ids.append(attachment["segment_id"])
+ if attachment["segment_id"] in attachment_map:
+ attachment_map[attachment["segment_id"]].append(attachment["attachment_info"])
+ else:
+ attachment_map[attachment["segment_id"]] = [attachment["attachment_info"]]
+ if attachment["segment_id"] in doc_segment_map:
+ doc_segment_map[attachment["segment_id"]].append(attachment["attachment_id"])
+ else:
+ doc_segment_map[attachment["segment_id"]] = [attachment["attachment_id"]]
+
+ child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids))
+ child_index_nodes = session.execute(child_chunk_stmt).scalars().all()
+
+ for i in child_index_nodes:
+ segment_ids.append(i.segment_id)
+ if i.segment_id in child_chunk_map:
+ child_chunk_map[i.segment_id].append(i)
+ else:
+ child_chunk_map[i.segment_id] = [i]
+ if i.segment_id in doc_segment_map:
+ doc_segment_map[i.segment_id].append(i.index_node_id)
+ else:
+ doc_segment_map[i.segment_id] = [i.index_node_id]
+
+ if index_node_ids:
+ document_segment_stmt = select(DocumentSegment).where(
+ DocumentSegment.enabled == True,
+ DocumentSegment.status == "completed",
+ DocumentSegment.index_node_id.in_(index_node_ids),
+ )
+ index_node_segments = session.execute(document_segment_stmt).scalars().all() # type: ignore
+ for index_node_segment in index_node_segments:
+ doc_segment_map[index_node_segment.id] = [index_node_segment.index_node_id]
+
+ if segment_ids:
+ document_segment_stmt = select(DocumentSegment).where(
+ DocumentSegment.enabled == True,
+ DocumentSegment.status == "completed",
+ DocumentSegment.id.in_(segment_ids),
+ )
+ segments = session.execute(document_segment_stmt).scalars().all() # type: ignore
+
+ if index_node_segments:
+ segments.extend(index_node_segments)
+
+ # Handle summary documents: query segments by original_chunk_id
+ if summary_segment_ids:
+ summary_segment_ids_list = list(summary_segment_ids)
+ summary_segment_stmt = select(DocumentSegment).where(
+ DocumentSegment.enabled == True,
+ DocumentSegment.status == "completed",
+ DocumentSegment.id.in_(summary_segment_ids_list),
+ )
+ summary_segments = session.execute(summary_segment_stmt).scalars().all() # type: ignore
+ segments.extend(summary_segments)
+ # Add summary segment IDs to segment_ids for summary query
+ for seg in summary_segments:
+ if seg.id not in segment_ids:
+ segment_ids.append(seg.id)
+
+ # Batch query summaries for segments retrieved via summary (only enabled summaries)
+ if summary_segment_ids:
+ summaries = (
+ session.query(DocumentSegmentSummary)
+ .filter(
+ DocumentSegmentSummary.chunk_id.in_(list(summary_segment_ids)),
+ DocumentSegmentSummary.status == "completed",
+ DocumentSegmentSummary.enabled == True, # Only retrieve enabled summaries
+ )
+ .all()
+ )
+ for summary in summaries:
+ if summary.summary_content:
+ segment_summary_map[summary.chunk_id] = summary.summary_content
+
+ include_segment_ids = set()
+ segment_child_map: dict[str, dict[str, Any]] = {}
+ records: list[dict[str, Any]] = []
+
+ for segment in segments:
+ child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, [])
+ attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, [])
+ ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(segment.document_id)
+
+ if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
+ if segment.id not in include_segment_ids:
+ include_segment_ids.add(segment.id)
+ # Check if this segment was retrieved via summary
+ # Use summary score as base score if available, otherwise 0.0
+ max_score = summary_score_map.get(segment.id, 0.0)
+
+ if child_chunks or attachment_infos:
+ child_chunk_details = []
+ for child_chunk in child_chunks:
+ child_document: Document | None = doc_to_document_map.get(child_chunk.index_node_id)
+ if child_document:
+ child_score = child_document.metadata.get("score", 0.0)
+ else:
+ child_score = 0.0
+ child_chunk_detail = {
+ "id": child_chunk.id,
+ "content": child_chunk.content,
+ "position": child_chunk.position,
+ "score": child_score,
+ }
+ child_chunk_details.append(child_chunk_detail)
+ max_score = max(max_score, child_score)
+ for attachment_info in attachment_infos:
+ file_document = doc_to_document_map.get(attachment_info["id"])
+ if file_document:
+ max_score = max(max_score, file_document.metadata.get("score", 0.0))
+
+ map_detail = {
+ "max_score": max_score,
+ "child_chunks": child_chunk_details,
}
- if attachment_info:
- segment_file_map[segment.id] = [attachment_info]
- records.append(record)
+ segment_child_map[segment.id] = map_detail
else:
- if attachment_info:
- attachment_infos = segment_file_map.get(segment.id, [])
- if attachment_info not in attachment_infos:
- attachment_infos.append(attachment_info)
- segment_file_map[segment.id] = attachment_infos
+ # No child chunks or attachments, use summary score if available
+ summary_score = summary_score_map.get(segment.id)
+ if summary_score is not None:
+ segment_child_map[segment.id] = {
+ "max_score": summary_score,
+ "child_chunks": [],
+ }
+ record: dict[str, Any] = {
+ "segment": segment,
+ }
+ records.append(record)
+ else:
+ if segment.id not in include_segment_ids:
+ include_segment_ids.add(segment.id)
+
+ # Check if this segment was retrieved via summary
+ # Use summary score if available (summary retrieval takes priority)
+ max_score = summary_score_map.get(segment.id, 0.0)
+
+ # If not retrieved via summary, use original segment's score
+ if segment.id not in summary_score_map:
+ segment_document = doc_to_document_map.get(segment.index_node_id)
+ if segment_document:
+ max_score = max(max_score, segment_document.metadata.get("score", 0.0))
+
+ # Also consider attachment scores
+ for attachment_info in attachment_infos:
+ file_doc = doc_to_document_map.get(attachment_info["id"])
+ if file_doc:
+ max_score = max(max_score, file_doc.metadata.get("score", 0.0))
+
+ record = {
+ "segment": segment,
+ "score": max_score,
+ }
+ records.append(record)
# Add child chunks information to records
for record in records:
if record["segment"].id in segment_child_map:
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
- record["score"] = segment_child_map[record["segment"].id]["max_score"]
- if record["segment"].id in segment_file_map:
- record["files"] = segment_file_map[record["segment"].id] # type: ignore[assignment]
+ record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore
+ if record["segment"].id in attachment_map:
+ record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment]
- result = []
+ result: list[RetrievalSegments] = []
for record in records:
# Extract segment
segment = record["segment"]
# Extract child_chunks, ensuring it's a list or None
- child_chunks = record.get("child_chunks")
- if not isinstance(child_chunks, list):
- child_chunks = None
+ raw_child_chunks = record.get("child_chunks")
+ child_chunks_list: list[RetrievalChildChunk] | None = None
+ if isinstance(raw_child_chunks, list):
+ # Sort by score descending
+ sorted_chunks = sorted(raw_child_chunks, key=lambda x: x.get("score", 0.0), reverse=True)
+ child_chunks_list = [
+ RetrievalChildChunk(
+ id=chunk["id"],
+ content=chunk["content"],
+ score=chunk.get("score", 0.0),
+ position=chunk["position"],
+ )
+ for chunk in sorted_chunks
+ ]
# Extract files, ensuring it's a list or None
files = record.get("files")
@@ -549,13 +665,20 @@ class RetrievalService:
else None
)
+ # Extract summary if this segment was retrieved via summary
+ summary_content = segment_summary_map.get(segment.id)
+
# Create RetrievalSegments object
retrieval_segment = RetrievalSegments(
- segment=segment, child_chunks=child_chunks, score=score, files=files
+ segment=segment,
+ child_chunks=child_chunks_list,
+ score=score,
+ files=files,
+ summary=summary_content,
)
result.append(retrieval_segment)
- return result
+ return sorted(result, key=lambda x: x.score if x.score is not None else 0.0, reverse=True)
except Exception as e:
db.session.rollback()
raise e
@@ -565,6 +688,8 @@ class RetrievalService:
flask_app: Flask,
retrieval_method: RetrievalMethod,
dataset: Dataset,
+ all_documents: list[Document],
+ exceptions: list[str],
query: str | None = None,
top_k: int = 4,
score_threshold: float | None = 0.0,
@@ -573,8 +698,6 @@ class RetrievalService:
weights: dict | None = None,
document_ids_filter: list[str] | None = None,
attachment_id: str | None = None,
- all_documents: list[Document] = [],
- exceptions: list[str] = [],
):
if not query and not attachment_id:
return
@@ -647,7 +770,14 @@ class RetrievalService:
document_ids_filter=document_ids_filter,
)
)
- concurrent.futures.wait(futures, timeout=300, return_when=concurrent.futures.ALL_COMPLETED)
+ # Use as_completed for early error propagation - cancel remaining futures on first error
+ if futures:
+ for future in concurrent.futures.as_completed(futures, timeout=300):
+ if future.exception():
+ # Cancel remaining futures to avoid unnecessary waiting
+ for f in futures:
+ f.cancel()
+ break
if exceptions:
raise ValueError(";\n".join(exceptions))
@@ -696,3 +826,37 @@ class RetrievalService:
}
return {"attachment_info": attachment_info, "segment_id": attachment_binding.segment_id}
return None
+
+ @classmethod
+ def get_segment_attachment_infos(cls, attachment_ids: list[str], session: Session) -> list[dict[str, Any]]:
+ attachment_infos = []
+ upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
+ if upload_files:
+ upload_file_ids = [upload_file.id for upload_file in upload_files]
+ attachment_bindings = (
+ session.query(SegmentAttachmentBinding)
+ .where(SegmentAttachmentBinding.attachment_id.in_(upload_file_ids))
+ .all()
+ )
+ attachment_binding_map = {binding.attachment_id: binding for binding in attachment_bindings}
+
+ if attachment_bindings:
+ for upload_file in upload_files:
+ attachment_binding = attachment_binding_map.get(upload_file.id)
+ attachment_info = {
+ "id": upload_file.id,
+ "name": upload_file.name,
+ "extension": "." + upload_file.extension,
+ "mime_type": upload_file.mime_type,
+ "source_url": sign_upload_file(upload_file.id, upload_file.extension),
+ "size": upload_file.size,
+ }
+ if attachment_binding:
+ attachment_infos.append(
+ {
+ "attachment_id": attachment_binding.attachment_id,
+ "attachment_info": attachment_info,
+ "segment_id": attachment_binding.segment_id,
+ }
+ )
+ return attachment_infos
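Both retrieval fan-out paths now iterate `as_completed` so the first failure can cancel work that has not started yet, rather than waiting out the full timeout. The pattern in isolation:

```python
import concurrent.futures
import time


def work(i: int) -> int:
    time.sleep(0.1)
    if i == 2:
        raise ValueError(f"task {i} failed")
    return i


with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
    futures = [executor.submit(work, i) for i in range(6)]
    for future in concurrent.futures.as_completed(futures, timeout=30):
        if future.exception():
            for f in futures:
                f.cancel()  # pending tasks never start; running ones finish
            break
```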
diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py
index a306f9ba0c..91bb71bfa6 100644
--- a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py
+++ b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import contextlib
import json
import logging
@@ -6,7 +8,7 @@ import re
import threading
import time
import uuid
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any
import clickzetta # type: ignore
from pydantic import BaseModel, model_validator
@@ -76,7 +78,7 @@ class ClickzettaConnectionPool:
Manages connection reuse across ClickzettaVector instances.
"""
- _instance: Optional["ClickzettaConnectionPool"] = None
+ _instance: ClickzettaConnectionPool | None = None
_lock = threading.Lock()
def __init__(self):
@@ -89,7 +91,7 @@ class ClickzettaConnectionPool:
self._start_cleanup_thread()
@classmethod
- def get_instance(cls) -> "ClickzettaConnectionPool":
+ def get_instance(cls) -> ClickzettaConnectionPool:
"""Get singleton instance of connection pool."""
if cls._instance is None:
with cls._lock:
@@ -104,7 +106,7 @@ class ClickzettaConnectionPool:
f"{config.workspace}:{config.vcluster}:{config.schema_name}"
)
- def _create_connection(self, config: ClickzettaConfig) -> "Connection":
+ def _create_connection(self, config: ClickzettaConfig) -> Connection:
"""Create a new ClickZetta connection."""
max_retries = 3
retry_delay = 1.0
@@ -134,7 +136,7 @@ class ClickzettaConnectionPool:
raise RuntimeError(f"Failed to create ClickZetta connection after {max_retries} attempts")
- def _configure_connection(self, connection: "Connection"):
+ def _configure_connection(self, connection: Connection):
"""Configure connection session settings."""
try:
with connection.cursor() as cursor:
@@ -181,7 +183,7 @@ class ClickzettaConnectionPool:
except Exception:
logger.exception("Failed to configure connection, continuing with defaults")
- def _is_connection_valid(self, connection: "Connection") -> bool:
+ def _is_connection_valid(self, connection: Connection) -> bool:
"""Check if connection is still valid."""
try:
with connection.cursor() as cursor:
@@ -190,7 +192,7 @@ class ClickzettaConnectionPool:
except Exception:
return False
- def get_connection(self, config: ClickzettaConfig) -> "Connection":
+ def get_connection(self, config: ClickzettaConfig) -> Connection:
"""Get a connection from the pool or create a new one."""
config_key = self._get_config_key(config)
@@ -221,7 +223,7 @@ class ClickzettaConnectionPool:
# No valid connection found, create new one
return self._create_connection(config)
- def return_connection(self, config: ClickzettaConfig, connection: "Connection"):
+ def return_connection(self, config: ClickzettaConfig, connection: Connection):
"""Return a connection to the pool."""
config_key = self._get_config_key(config)
@@ -315,22 +317,22 @@ class ClickzettaVector(BaseVector):
self._connection_pool = ClickzettaConnectionPool.get_instance()
self._init_write_queue()
- def _get_connection(self) -> "Connection":
+ def _get_connection(self) -> Connection:
"""Get a connection from the pool."""
return self._connection_pool.get_connection(self._config)
- def _return_connection(self, connection: "Connection"):
+ def _return_connection(self, connection: Connection):
"""Return a connection to the pool."""
self._connection_pool.return_connection(self._config, connection)
class ConnectionContext:
"""Context manager for borrowing and returning connections."""
- def __init__(self, vector_instance: "ClickzettaVector"):
+ def __init__(self, vector_instance: ClickzettaVector):
self.vector = vector_instance
self.connection: Connection | None = None
- def __enter__(self) -> "Connection":
+ def __enter__(self) -> Connection:
self.connection = self.vector._get_connection()
return self.connection
@@ -338,7 +340,7 @@ class ClickzettaVector(BaseVector):
if self.connection:
self.vector._return_connection(self.connection)
- def get_connection_context(self) -> "ClickzettaVector.ConnectionContext":
+ def get_connection_context(self) -> ClickzettaVector.ConnectionContext:
"""Get a connection context manager."""
return self.ConnectionContext(self)
@@ -437,7 +439,7 @@ class ClickzettaVector(BaseVector):
"""Return the vector database type."""
return "clickzetta"
- def _ensure_connection(self) -> "Connection":
+ def _ensure_connection(self) -> Connection:
"""Get a connection from the pool."""
return self._get_connection()
@@ -984,9 +986,11 @@ class ClickzettaVector(BaseVector):
# No need for dataset_id filter since each dataset has its own table
- # Use simple quote escaping for LIKE clause
- escaped_query = query.replace("'", "''")
- filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%'")
+ # Escape special characters for LIKE clause to prevent SQL injection
+ from libs.helper import escape_like_pattern
+
+ escaped_query = escape_like_pattern(query).replace("'", "''")
+ filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%' ESCAPE '\\\\'")
where_clause = " AND ".join(filter_clauses)
search_sql = f"""
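The LIKE hardening assumes a `libs.helper.escape_like_pattern` helper; its essential job is to neutralize the `%`, `_`, and escape characters before the single quotes are doubled. A rough sketch of such a helper (not the actual implementation):

```python
def escape_like_pattern(value: str, escape_char: str = "\\") -> str:
    """Escape %, _ and the escape character itself for use inside a LIKE pattern."""
    return (
        value.replace(escape_char, escape_char * 2)
        .replace("%", escape_char + "%")
        .replace("_", escape_char + "_")
    )


user_query = "100%_raw 'input'"
escaped = escape_like_pattern(user_query).replace("'", "''")
clause = f"content LIKE '%{escaped}%' ESCAPE '\\'"
```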
diff --git a/api/core/rag/datasource/vdb/iris/__init__.py b/api/core/rag/datasource/vdb/iris/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/rag/datasource/vdb/iris/iris_vector.py b/api/core/rag/datasource/vdb/iris/iris_vector.py
new file mode 100644
index 0000000000..50bb2429ec
--- /dev/null
+++ b/api/core/rag/datasource/vdb/iris/iris_vector.py
@@ -0,0 +1,508 @@
+"""InterSystems IRIS vector database implementation for Dify.
+
+This module provides vector storage and retrieval using IRIS native VECTOR type
+with HNSW indexing for efficient similarity search.
+"""
+
+from __future__ import annotations
+
+import json
+import logging
+import threading
+import uuid
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Any
+
+from configs import dify_config
+from configs.middleware.vdb.iris_config import IrisVectorConfig
+from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
+from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.embedding.embedding_base import Embeddings
+from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
+from models.dataset import Dataset
+
+if TYPE_CHECKING:
+ import iris
+else:
+ try:
+ import iris
+ except ImportError:
+ iris = None # type: ignore[assignment]
+
+logger = logging.getLogger(__name__)
+
+# Singleton connection pool to minimize IRIS license usage
+_pool_lock = threading.Lock()
+_pool_instance: IrisConnectionPool | None = None
+
+
+def get_iris_pool(config: IrisVectorConfig) -> IrisConnectionPool:
+ """Get or create the global IRIS connection pool (singleton pattern)."""
+ global _pool_instance # pylint: disable=global-statement
+ with _pool_lock:
+ if _pool_instance is None:
+ logger.info("Initializing IRIS connection pool")
+ _pool_instance = IrisConnectionPool(config)
+ return _pool_instance
+
+
+class IrisConnectionPool:
+ """Thread-safe connection pool for IRIS database."""
+
+ def __init__(self, config: IrisVectorConfig) -> None:
+ self.config = config
+ self._pool: list[Any] = []
+ self._lock = threading.Lock()
+ self._min_size = config.IRIS_MIN_CONNECTION
+ self._max_size = config.IRIS_MAX_CONNECTION
+ self._in_use = 0
+ self._schemas_initialized: set[str] = set() # Cache for initialized schemas
+ self._initialize_pool()
+
+ def _initialize_pool(self) -> None:
+ for _ in range(self._min_size):
+ self._pool.append(self._create_connection())
+
+ def _create_connection(self) -> Any:
+ return iris.connect(
+ hostname=self.config.IRIS_HOST,
+ port=self.config.IRIS_SUPER_SERVER_PORT,
+ namespace=self.config.IRIS_DATABASE,
+ username=self.config.IRIS_USER,
+ password=self.config.IRIS_PASSWORD,
+ )
+
+ def get_connection(self) -> Any:
+ """Get a connection from pool or create new if available."""
+ with self._lock:
+ if self._pool:
+ conn = self._pool.pop()
+ self._in_use += 1
+ return conn
+ if self._in_use < self._max_size:
+ conn = self._create_connection()
+ self._in_use += 1
+ return conn
+ raise RuntimeError("Connection pool exhausted")
+
+ def return_connection(self, conn: Any) -> None:
+ """Return connection to pool after validating it."""
+ if not conn:
+ return
+
+ # Validate connection health
+ is_valid = False
+ try:
+ cursor = conn.cursor()
+ cursor.execute("SELECT 1")
+ cursor.close()
+ is_valid = True
+ except (OSError, RuntimeError) as e:
+ logger.debug("Connection validation failed: %s", e)
+ try:
+ conn.close()
+ except (OSError, RuntimeError):
+ pass
+
+ with self._lock:
+ self._pool.append(conn if is_valid else self._create_connection())
+ self._in_use -= 1
+
+ def ensure_schema_exists(self, schema: str) -> None:
+ """Ensure schema exists in IRIS database.
+
+ This method is idempotent and thread-safe. It uses a memory cache to avoid
+ redundant database queries for already-verified schemas.
+
+ Args:
+ schema: Schema name to ensure exists
+
+ Raises:
+ Exception: If schema creation fails
+ """
+ # Fast path: check cache first (no lock needed for read-only set lookup)
+ if schema in self._schemas_initialized:
+ return
+
+ # Slow path: acquire lock and check again (double-checked locking)
+ with self._lock:
+ if schema in self._schemas_initialized:
+ return
+
+ # Get a connection to check/create schema
+ conn = self._pool[0] if self._pool else self._create_connection()
+ cursor = conn.cursor()
+ try:
+ # Check if schema exists using INFORMATION_SCHEMA
+ check_sql = """
+ SELECT COUNT(*) FROM INFORMATION_SCHEMA.SCHEMATA
+ WHERE SCHEMA_NAME = ?
+ """
+ cursor.execute(check_sql, (schema,)) # Must be tuple or list
+ exists = cursor.fetchone()[0] > 0
+
+ if not exists:
+ # Schema doesn't exist, create it
+ cursor.execute(f"CREATE SCHEMA {schema}")
+ conn.commit()
+ logger.info("Created schema: %s", schema)
+ else:
+ logger.debug("Schema already exists: %s", schema)
+
+ # Add to cache to skip future checks
+ self._schemas_initialized.add(schema)
+
+ except Exception:
+ conn.rollback()
+ logger.exception("Failed to ensure schema %s exists", schema)
+ raise
+ finally:
+ cursor.close()
+
+ def close_all(self) -> None:
+ """Close all connections (application shutdown only)."""
+ with self._lock:
+ for conn in self._pool:
+ try:
+ conn.close()
+ except (OSError, RuntimeError):
+ pass
+ self._pool.clear()
+ self._in_use = 0
+ self._schemas_initialized.clear()
+
+
+class IrisVector(BaseVector):
+ """IRIS vector database implementation using native VECTOR type and HNSW indexing."""
+
+ # Fallback score for full-text search when Rank function unavailable or TEXT_INDEX disabled
+ _FULL_TEXT_FALLBACK_SCORE = 0.5
+
+ def __init__(self, collection_name: str, config: IrisVectorConfig) -> None:
+ super().__init__(collection_name)
+ self.config = config
+ self.table_name = f"embedding_{collection_name}".upper()
+ self.schema = config.IRIS_SCHEMA or "dify"
+ self.pool = get_iris_pool(config)
+
+ def get_type(self) -> str:
+ return VectorType.IRIS
+
+ @contextmanager
+ def _get_cursor(self):
+ """Context manager for database cursor with connection pooling."""
+ conn = self.pool.get_connection()
+ cursor = conn.cursor()
+ try:
+ yield cursor
+ conn.commit()
+ except Exception:
+ conn.rollback()
+ raise
+ finally:
+ cursor.close()
+ self.pool.return_connection(conn)
+
+ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
+ dimension = len(embeddings[0])
+ self._create_collection(dimension)
+ return self.add_texts(texts, embeddings)
+
+ def add_texts(self, documents: list[Document], embeddings: list[list[float]], **_kwargs) -> list[str]:
+ """Add documents with embeddings to the collection."""
+ added_ids = []
+ with self._get_cursor() as cursor:
+ for i, doc in enumerate(documents):
+ doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) if doc.metadata else str(uuid.uuid4())
+ metadata = json.dumps(doc.metadata) if doc.metadata else "{}"
+ embedding_str = json.dumps(embeddings[i])
+
+ sql = f"INSERT INTO {self.schema}.{self.table_name} (id, text, meta, embedding) VALUES (?, ?, ?, ?)"
+ cursor.execute(sql, (doc_id, doc.page_content, metadata, embedding_str))
+ added_ids.append(doc_id)
+
+ return added_ids
+
+ def text_exists(self, id: str) -> bool: # pylint: disable=redefined-builtin
+ try:
+ with self._get_cursor() as cursor:
+ sql = f"SELECT 1 FROM {self.schema}.{self.table_name} WHERE id = ?"
+ cursor.execute(sql, (id,))
+ return cursor.fetchone() is not None
+ except (OSError, RuntimeError, ValueError):
+ return False
+
+ def delete_by_ids(self, ids: list[str]) -> None:
+ if not ids:
+ return
+
+ with self._get_cursor() as cursor:
+ placeholders = ",".join(["?" for _ in ids])
+ sql = f"DELETE FROM {self.schema}.{self.table_name} WHERE id IN ({placeholders})"
+ cursor.execute(sql, ids)
+
+ def delete_by_metadata_field(self, key: str, value: str) -> None:
+ """Delete documents by metadata field (JSON LIKE pattern matching)."""
+ with self._get_cursor() as cursor:
+ pattern = f'%"{key}": "{value}"%'
+ sql = f"DELETE FROM {self.schema}.{self.table_name} WHERE meta LIKE ?"
+ cursor.execute(sql, (pattern,))
+
+ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+ """Search similar documents using VECTOR_COSINE with HNSW index."""
+ top_k = kwargs.get("top_k", 4)
+ score_threshold = float(kwargs.get("score_threshold") or 0.0)
+ embedding_str = json.dumps(query_vector)
+
+ with self._get_cursor() as cursor:
+ sql = f"""
+ SELECT TOP {top_k} id, text, meta, VECTOR_COSINE(embedding, ?) as score
+ FROM {self.schema}.{self.table_name}
+ ORDER BY score DESC
+ """
+ cursor.execute(sql, (embedding_str,))
+
+ docs = []
+ for row in cursor.fetchall():
+ if len(row) >= 4:
+ text, meta_str, score = row[1], row[2], float(row[3])
+ if score >= score_threshold:
+ metadata = json.loads(meta_str) if meta_str else {}
+ metadata["score"] = score
+ docs.append(Document(page_content=text, metadata=metadata))
+ return docs
+
+ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+ """Search documents by full-text using iFind index with BM25 relevance scoring.
+
+ When IRIS_TEXT_INDEX is enabled, this method uses the auto-generated Rank
+ function from %iFind.Index.Basic to calculate BM25 relevance scores. The Rank
+ function is automatically created with naming: {schema}.{table_name}_{index}Rank
+
+ Args:
+ query: Search query string
+ **kwargs: Optional parameters including top_k, document_ids_filter
+
+ Returns:
+ List of Document objects with relevance scores in metadata["score"]
+ """
+ top_k = kwargs.get("top_k", 5)
+ document_ids_filter = kwargs.get("document_ids_filter")
+
+ with self._get_cursor() as cursor:
+ if self.config.IRIS_TEXT_INDEX:
+ # Use iFind full-text search with auto-generated Rank function
+ text_index_name = f"idx_{self.table_name}_text"
+ # IRIS removes underscores from function names
+ table_no_underscore = self.table_name.replace("_", "")
+ index_no_underscore = text_index_name.replace("_", "")
+ rank_function = f"{self.schema}.{table_no_underscore}_{index_no_underscore}Rank"
+
+ # Build WHERE clause with document ID filter if provided
+ where_clause = f"WHERE %ID %FIND search_index({text_index_name}, ?)"
+ # First param for Rank function, second for FIND
+ params = [query, query]
+
+ if document_ids_filter:
+ # Add document ID filter
+ placeholders = ",".join("?" * len(document_ids_filter))
+ where_clause += f" AND JSON_VALUE(meta, '$.document_id') IN ({placeholders})"
+ params.extend(document_ids_filter)
+
+ sql = f"""
+ SELECT TOP {top_k}
+ id,
+ text,
+ meta,
+ {rank_function}(%ID, ?) AS score
+ FROM {self.schema}.{self.table_name}
+ {where_clause}
+ ORDER BY score DESC
+ """
+
+ logger.debug(
+ "iFind search: query='%s', index='%s', rank='%s'",
+ query,
+ text_index_name,
+ rank_function,
+ )
+
+ try:
+ cursor.execute(sql, params)
+ except Exception: # pylint: disable=broad-exception-caught
+ # Fallback to query without Rank function if it fails
+ logger.warning(
+ "Rank function '%s' failed, using fixed score",
+ rank_function,
+ exc_info=True,
+ )
+ sql_fallback = f"""
+ SELECT TOP {top_k} id, text, meta, {self._FULL_TEXT_FALLBACK_SCORE} AS score
+ FROM {self.schema}.{self.table_name}
+ {where_clause}
+ """
+ # Skip first param (for Rank function)
+ cursor.execute(sql_fallback, params[1:])
+ else:
+ # Fallback to LIKE search (IRIS_TEXT_INDEX disabled)
+ from libs.helper import ( # pylint: disable=import-outside-toplevel
+ escape_like_pattern,
+ )
+
+ escaped_query = escape_like_pattern(query)
+ query_pattern = f"%{escaped_query}%"
+
+ # Build WHERE clause with document ID filter if provided
+ where_clause = "WHERE text LIKE ? ESCAPE '\\\\'"
+ params = [query_pattern]
+
+ if document_ids_filter:
+ placeholders = ",".join("?" * len(document_ids_filter))
+ where_clause += f" AND JSON_VALUE(meta, '$.document_id') IN ({placeholders})"
+ params.extend(document_ids_filter)
+
+ sql = f"""
+ SELECT TOP {top_k} id, text, meta, {self._FULL_TEXT_FALLBACK_SCORE} AS score
+ FROM {self.schema}.{self.table_name}
+ {where_clause}
+ ORDER BY LENGTH(text) ASC
+ """
+
+ logger.debug(
+ "LIKE fallback (TEXT_INDEX disabled): query='%s'",
+ query_pattern,
+ )
+ cursor.execute(sql, params)
+
+ docs = []
+ for row in cursor.fetchall():
+ # Expecting 4 columns: id, text, meta, score
+ if len(row) >= 4:
+ text_content = row[1]
+ meta_str = row[2]
+ score_value = row[3]
+
+ metadata = json.loads(meta_str) if meta_str else {}
+ # Add score to metadata for hybrid search compatibility
+ score = float(score_value) if score_value is not None else 0.0
+ metadata["score"] = score
+
+ docs.append(Document(page_content=text_content, metadata=metadata))
+
+ logger.info(
+ "Full-text search completed: query='%s', results=%d/%d",
+ query,
+ len(docs),
+ top_k,
+ )
+
+ if not docs:
+ logger.warning("Full-text search for '%s' returned no results", query)
+
+ return docs
+
+ def delete(self) -> None:
+ """Delete the entire collection (drop table - permanent)."""
+ with self._get_cursor() as cursor:
+ sql = f"DROP TABLE {self.schema}.{self.table_name}"
+ cursor.execute(sql)
+
+ def _create_collection(self, dimension: int) -> None:
+ """Create table with VECTOR column and HNSW index.
+
+ Uses Redis lock to prevent concurrent creation attempts across multiple
+ API server instances (api, worker, worker_beat).
+ """
+ cache_key = f"vector_indexing_{self._collection_name}"
+ lock_name = f"{cache_key}_lock"
+
+ with redis_client.lock(lock_name, timeout=20): # pylint: disable=not-context-manager
+ if redis_client.get(cache_key):
+ return
+
+ # Ensure schema exists (idempotent, cached after first call)
+ self.pool.ensure_schema_exists(self.schema)
+
+ with self._get_cursor() as cursor:
+ # Create table with VECTOR column
+ sql = f"""
+ CREATE TABLE {self.schema}.{self.table_name} (
+ id VARCHAR(255) PRIMARY KEY,
+ text CLOB,
+ meta CLOB,
+ embedding VECTOR(DOUBLE, {dimension})
+ )
+ """
+ logger.info("Creating table: %s.%s", self.schema, self.table_name)
+ cursor.execute(sql)
+
+ # Create HNSW index for vector similarity search
+ index_name = f"idx_{self.table_name}_embedding"
+ sql_index = (
+ f"CREATE INDEX {index_name} ON {self.schema}.{self.table_name} "
+ "(embedding) AS HNSW(Distance='Cosine')"
+ )
+ logger.info("Creating HNSW index: %s", index_name)
+ cursor.execute(sql_index)
+ logger.info("HNSW index created successfully: %s", index_name)
+
+ # Create full-text search index if enabled
+ logger.info(
+ "IRIS_TEXT_INDEX config value: %s (type: %s)",
+ self.config.IRIS_TEXT_INDEX,
+ type(self.config.IRIS_TEXT_INDEX),
+ )
+ if self.config.IRIS_TEXT_INDEX:
+ text_index_name = f"idx_{self.table_name}_text"
+ language = self.config.IRIS_TEXT_INDEX_LANGUAGE
+ # Fixed: Removed extra parentheses and corrected syntax
+ sql_text_index = f"""
+ CREATE INDEX {text_index_name} ON {self.schema}.{self.table_name} (text)
+ AS %iFind.Index.Basic
+ (LANGUAGE = '{language}', LOWER = 1, INDEXOPTION = 0)
+ """
+ logger.info(
+ "Creating text index: %s with language: %s",
+ text_index_name,
+ language,
+ )
+ logger.info("SQL for text index: %s", sql_text_index)
+ cursor.execute(sql_text_index)
+ logger.info("Text index created successfully: %s", text_index_name)
+ else:
+ logger.warning("Text index creation skipped - IRIS_TEXT_INDEX is disabled")
+
+ redis_client.set(cache_key, 1, ex=3600)
+
+
+class IrisVectorFactory(AbstractVectorFactory):
+ """Factory for creating IrisVector instances."""
+
+ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> IrisVector:
+ if dataset.index_struct_dict:
+ class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
+ collection_name = class_prefix
+ else:
+ dataset_id = dataset.id
+ collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+ index_struct_dict = self.gen_index_struct_dict(VectorType.IRIS, collection_name)
+ dataset.index_struct = json.dumps(index_struct_dict)
+
+ return IrisVector(
+ collection_name=collection_name,
+ config=IrisVectorConfig(
+ IRIS_HOST=dify_config.IRIS_HOST,
+ IRIS_SUPER_SERVER_PORT=dify_config.IRIS_SUPER_SERVER_PORT,
+ IRIS_USER=dify_config.IRIS_USER,
+ IRIS_PASSWORD=dify_config.IRIS_PASSWORD,
+ IRIS_DATABASE=dify_config.IRIS_DATABASE,
+ IRIS_SCHEMA=dify_config.IRIS_SCHEMA,
+ IRIS_CONNECTION_URL=dify_config.IRIS_CONNECTION_URL,
+ IRIS_MIN_CONNECTION=dify_config.IRIS_MIN_CONNECTION,
+ IRIS_MAX_CONNECTION=dify_config.IRIS_MAX_CONNECTION,
+ IRIS_TEXT_INDEX=dify_config.IRIS_TEXT_INDEX,
+ IRIS_TEXT_INDEX_LANGUAGE=dify_config.IRIS_TEXT_INDEX_LANGUAGE,
+ ),
+ )
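`ensure_schema_exists` above pairs a lock-free membership check with a locked re-check, so each schema is created or verified at most once per process. The double-checked locking pattern on its own, stripped of the database calls:

```python
import threading

_initialized: set[str] = set()
_lock = threading.Lock()


def ensure_initialized(name: str, do_init) -> None:
    # Fast path: already initialized, no lock contention.
    if name in _initialized:
        return
    # Slow path: re-check under the lock (double-checked locking).
    with _lock:
        if name in _initialized:
            return
        do_init(name)
        _initialized.add(name)


ensure_initialized("dify", lambda schema: print(f"CREATE SCHEMA {schema}"))
```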
diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py
index d82ab89a34..cb05c22b55 100644
--- a/api/core/rag/datasource/vdb/oracle/oraclevector.py
+++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py
@@ -289,7 +289,8 @@ class OracleVector(BaseVector):
words = pseg.cut(query)
current_entity = ""
for word, pos in words:
- if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名,ns: 地名,nt: 机构名
+ # `nr`: Person, `ns`: Location, `nt`: Organization
+ if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}:
current_entity += word
else:
if current_entity:
diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py
index 445a0a7f8b..0615b8312c 100644
--- a/api/core/rag/datasource/vdb/pgvector/pgvector.py
+++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py
@@ -255,7 +255,10 @@ class PGVector(BaseVector):
return
with self._get_cursor() as cur:
- cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
+ cur.execute("SELECT 1 FROM pg_extension WHERE extname = 'vector'")
+ if not cur.fetchone():
+ cur.execute("CREATE EXTENSION vector")
+
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
# PG hnsw index only support 2000 dimension or less
# ref: https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing
diff --git a/api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py b/api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py
index 86b6ace3f6..d080e8da58 100644
--- a/api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py
+++ b/api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py
@@ -213,7 +213,7 @@ class VastbaseVector(BaseVector):
with self._get_cursor() as cur:
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
- # Vastbase 支持的向量维度取值范围为 [1,16000]
+ # Vastbase supports vector dimensions in the range [1, 16,000]
if dimension <= 16000:
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
redis_client.set(collection_exist_cache_key, 1, ex=3600)
diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
index f8c62b908a..4a4a458f2e 100644
--- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
+++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
@@ -391,46 +391,78 @@ class QdrantVector(BaseVector):
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
- """Return docs most similar by bm25.
+ """Return docs most similar by full-text search.
+
+ Searches each keyword separately and merges results to ensure documents
+ matching ANY keyword are returned (OR logic). Results are capped at top_k.
+
+ Args:
+ query: Search query text. Multi-word queries are split into keywords,
+ with each keyword searched separately. Limited to 10 keywords.
+ **kwargs: Additional search parameters (top_k, document_ids_filter)
+
Returns:
- List of documents most similar to the query text and distance for each.
+ List of up to top_k unique documents matching any query keyword.
"""
from qdrant_client.http import models
- scroll_filter = models.Filter(
- must=[
- models.FieldCondition(
- key="group_id",
- match=models.MatchValue(value=self._group_id),
- ),
- models.FieldCondition(
- key="page_content",
- match=models.MatchText(text=query),
- ),
- ]
- )
+ # Build base must conditions (AND logic) for metadata filters
+ base_must_conditions: list = [
+ models.FieldCondition(
+ key="group_id",
+ match=models.MatchValue(value=self._group_id),
+ ),
+ ]
+
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
- if scroll_filter.must:
- scroll_filter.must.append(
- models.FieldCondition(
- key="metadata.document_id",
- match=models.MatchAny(any=document_ids_filter),
- )
+ base_must_conditions.append(
+ models.FieldCondition(
+ key="metadata.document_id",
+ match=models.MatchAny(any=document_ids_filter),
)
- response = self._client.scroll(
- collection_name=self._collection_name,
- scroll_filter=scroll_filter,
- limit=kwargs.get("top_k", 2),
- with_payload=True,
- with_vectors=True,
- )
- results = response[0]
- documents = []
- for result in results:
- if result:
- document = self._document_from_scored_point(result, Field.CONTENT_KEY, Field.METADATA_KEY)
- documents.append(document)
+ )
+
+ # Split query into keywords, deduplicate and limit to prevent DoS
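+ # dict.fromkeys keeps only the first occurrence of each keyword while preserving order.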
+ keywords = list(dict.fromkeys(kw.strip() for kw in query.strip().split() if kw.strip()))[:10]
+
+ if not keywords:
+ return []
+
+ top_k = kwargs.get("top_k", 2)
+ seen_ids: set[str | int] = set()
+ documents: list[Document] = []
+
+ # Search each keyword separately and merge results.
+ # This ensures each keyword gets its own search, preventing one keyword's
+ # results from completely overshadowing another's due to scroll ordering.
+ for keyword in keywords:
+ scroll_filter = models.Filter(
+ must=[
+ *base_must_conditions,
+ models.FieldCondition(
+ key="page_content",
+ match=models.MatchText(text=keyword),
+ ),
+ ]
+ )
+
+ response = self._client.scroll(
+ collection_name=self._collection_name,
+ scroll_filter=scroll_filter,
+ limit=top_k,
+ with_payload=True,
+ with_vectors=True,
+ )
+ results = response[0]
+
+ for result in results:
+ if result and result.id not in seen_ids:
+ seen_ids.add(result.id)
+ document = self._document_from_scored_point(result, Field.CONTENT_KEY, Field.METADATA_KEY)
+ documents.append(document)
+ if len(documents) >= top_k:
+ return documents
return documents
diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py
index 3a47241293..b9772b3c08 100644
--- a/api/core/rag/datasource/vdb/vector_factory.py
+++ b/api/core/rag/datasource/vdb/vector_factory.py
@@ -163,7 +163,7 @@ class Vector:
from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStoreFactory
return LindormVectorStoreFactory
- case VectorType.OCEANBASE:
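+ # SeekDB reuses the OceanBase vector factory (assumed to be compatible with the OceanBase client).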
+ case VectorType.OCEANBASE | VectorType.SEEKDB:
from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory
return OceanBaseVectorFactory
@@ -187,6 +187,10 @@ class Vector:
from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaVectorFactory
return ClickzettaVectorFactory
+ case VectorType.IRIS:
+ from core.rag.datasource.vdb.iris.iris_vector import IrisVectorFactory
+
+ return IrisVectorFactory
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")
diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py
index bc7d93a2e0..bd99a31446 100644
--- a/api/core/rag/datasource/vdb/vector_type.py
+++ b/api/core/rag/datasource/vdb/vector_type.py
@@ -27,8 +27,10 @@ class VectorType(StrEnum):
UPSTASH = "upstash"
TIDB_ON_QDRANT = "tidb_on_qdrant"
OCEANBASE = "oceanbase"
+ SEEKDB = "seekdb"
OPENGAUSS = "opengauss"
TABLESTORE = "tablestore"
HUAWEI_CLOUD = "huawei_cloud"
MATRIXONE = "matrixone"
CLICKZETTA = "clickzetta"
+ IRIS = "iris"
diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
index 84d1e26b34..b48dd93f04 100644
--- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
+++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
@@ -66,6 +66,8 @@ class WeaviateVector(BaseVector):
in a Weaviate collection.
"""
+ _DOCUMENT_ID_PROPERTY = "document_id"
+
def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list):
"""
Initializes the Weaviate vector store.
@@ -353,15 +355,12 @@ class WeaviateVector(BaseVector):
return []
col = self._client.collections.use(self._collection_name)
- props = list({*self._attributes, "document_id", Field.TEXT_KEY.value})
+ props = list({*self._attributes, self._DOCUMENT_ID_PROPERTY, Field.TEXT_KEY.value})
where = None
doc_ids = kwargs.get("document_ids_filter") or []
if doc_ids:
- ors = [Filter.by_property("document_id").equal(x) for x in doc_ids]
- where = ors[0]
- for f in ors[1:]:
- where = where | f
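+ # contains_any matches any of the given document IDs in a single filter, replacing the manual OR chain.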
+ where = Filter.by_property(self._DOCUMENT_ID_PROPERTY).contains_any(doc_ids)
top_k = int(kwargs.get("top_k", 4))
score_threshold = float(kwargs.get("score_threshold") or 0.0)
@@ -408,10 +407,7 @@ class WeaviateVector(BaseVector):
where = None
doc_ids = kwargs.get("document_ids_filter") or []
if doc_ids:
- ors = [Filter.by_property("document_id").equal(x) for x in doc_ids]
- where = ors[0]
- for f in ors[1:]:
- where = where | f
+ where = Filter.by_property(self._DOCUMENT_ID_PROPERTY).contains_any(doc_ids)
top_k = int(kwargs.get("top_k", 4))
diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py
index 1fe74d3042..69adac522d 100644
--- a/api/core/rag/docstore/dataset_docstore.py
+++ b/api/core/rag/docstore/dataset_docstore.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from collections.abc import Sequence
from typing import Any
@@ -22,7 +24,7 @@ class DatasetDocumentStore:
self._document_id = document_id
@classmethod
- def from_dict(cls, config_dict: dict[str, Any]) -> "DatasetDocumentStore":
+ def from_dict(cls, config_dict: dict[str, Any]) -> DatasetDocumentStore:
return cls(**config_dict)
def to_dict(self) -> dict[str, Any]:
diff --git a/api/core/rag/embedding/retrieval.py b/api/core/rag/embedding/retrieval.py
index b54a37b49e..f6834ab87b 100644
--- a/api/core/rag/embedding/retrieval.py
+++ b/api/core/rag/embedding/retrieval.py
@@ -20,3 +20,4 @@ class RetrievalSegments(BaseModel):
child_chunks: list[RetrievalChildChunk] | None = None
score: float | None = None
files: list[dict[str, str | int]] | None = None
+ summary: str | None = None # Summary content if retrieved via summary index
diff --git a/api/core/rag/entities/citation_metadata.py b/api/core/rag/entities/citation_metadata.py
index 9f66cd9a03..aec5c353f8 100644
--- a/api/core/rag/entities/citation_metadata.py
+++ b/api/core/rag/entities/citation_metadata.py
@@ -22,3 +22,4 @@ class RetrievalSourceMetadata(BaseModel):
doc_metadata: dict[str, Any] | None = None
title: str | None = None
files: list[dict[str, Any]] | None = None
+ summary: str | None = None
diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py
index c3bfbce98f..0c42034073 100644
--- a/api/core/rag/extractor/entity/extract_setting.py
+++ b/api/core/rag/extractor/entity/extract_setting.py
@@ -10,7 +10,7 @@ class NotionInfo(BaseModel):
"""
credential_id: str | None = None
- notion_workspace_id: str
+ notion_workspace_id: str | None = ""
notion_obj_id: str
notion_page_type: str
document: Document | None = None
diff --git a/api/core/rag/extractor/excel_extractor.py b/api/core/rag/extractor/excel_extractor.py
index ea9c6bd73a..875bfd1439 100644
--- a/api/core/rag/extractor/excel_extractor.py
+++ b/api/core/rag/extractor/excel_extractor.py
@@ -1,7 +1,7 @@
"""Abstract interface for document loader implementations."""
import os
-from typing import cast
+from typing import TypedDict
import pandas as pd
from openpyxl import load_workbook
@@ -10,6 +10,12 @@ from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
+class Candidate(TypedDict):
+ idx: int
+ count: int
+ map: dict[int, str]
+
+
class ExcelExtractor(BaseExtractor):
"""Load Excel files.
@@ -30,32 +36,38 @@ class ExcelExtractor(BaseExtractor):
file_extension = os.path.splitext(self._file_path)[-1].lower()
if file_extension == ".xlsx":
- wb = load_workbook(self._file_path, data_only=True)
- for sheet_name in wb.sheetnames:
- sheet = wb[sheet_name]
- data = sheet.values
- cols = next(data, None)
- if cols is None:
- continue
- df = pd.DataFrame(data, columns=cols)
-
- df.dropna(how="all", inplace=True)
-
- for index, row in df.iterrows():
- page_content = []
- for col_index, (k, v) in enumerate(row.items()):
- if pd.notna(v):
- cell = sheet.cell(
- row=cast(int, index) + 2, column=col_index + 1
- ) # +2 to account for header and 1-based index
- if cell.hyperlink:
- value = f"[{v}]({cell.hyperlink.target})"
- page_content.append(f'"{k}":"{value}"')
- else:
- page_content.append(f'"{k}":"{v}"')
- documents.append(
- Document(page_content=";".join(page_content), metadata={"source": self._file_path})
- )
+ wb = load_workbook(self._file_path, read_only=True, data_only=True)
+ try:
+ for sheet_name in wb.sheetnames:
+ sheet = wb[sheet_name]
+ header_row_idx, column_map, max_col_idx = self._find_header_and_columns(sheet)
+ if not column_map:
+ continue
+ start_row = header_row_idx + 1
+ for row in sheet.iter_rows(min_row=start_row, max_col=max_col_idx, values_only=False):
+ if all(cell.value is None for cell in row):
+ continue
+ page_content = []
+ for col_idx, cell in enumerate(row):
+ value = cell.value
+ if col_idx in column_map:
+ col_name = column_map[col_idx]
+ if hasattr(cell, "hyperlink") and cell.hyperlink:
+ target = getattr(cell.hyperlink, "target", None)
+ if target:
+ value = f"[{value}]({target})"
+ if value is None:
+ value = ""
+ elif not isinstance(value, str):
+ value = str(value)
+ value = value.strip().replace('"', '\\"')
+ page_content.append(f'"{col_name}":"{value}"')
+ if page_content:
+ documents.append(
+ Document(page_content=";".join(page_content), metadata={"source": self._file_path})
+ )
+ finally:
+ wb.close()
elif file_extension == ".xls":
excel_file = pd.ExcelFile(self._file_path, engine="xlrd")
@@ -63,9 +75,9 @@ class ExcelExtractor(BaseExtractor):
df = excel_file.parse(sheet_name=excel_sheet_name)
df.dropna(how="all", inplace=True)
- for _, row in df.iterrows():
+ for _, series_row in df.iterrows():
page_content = []
- for k, v in row.items():
+ for k, v in series_row.items():
if pd.notna(v):
page_content.append(f'"{k}":"{v}"')
documents.append(
@@ -75,3 +87,61 @@ class ExcelExtractor(BaseExtractor):
raise ValueError(f"Unsupported file extension: {file_extension}")
return documents
+
+ def _find_header_and_columns(self, sheet, scan_rows=10) -> tuple[int, dict[int, str], int]:
+ """
+ Scan first N rows to find the most likely header row.
+ Returns:
+ header_row_idx: 1-based index of the header row
+ column_map: Dict mapping 0-based column index to column name
+ max_col_idx: 1-based index of the last valid column (for iter_rows boundary)
+ """
+ # Store potential candidates: (row_index, non_empty_count, column_map)
+ candidates: list[Candidate] = []
+
+ # Limit scan to avoid performance issues on huge files
+ # We iterate manually to control the read scope
+ for current_row_idx, row in enumerate(sheet.iter_rows(min_row=1, max_row=scan_rows, values_only=True), start=1):
+ # Filter out empty cells and build a temp map for this row
+ # col_idx is 0-based
+ row_map = {}
+ for col_idx, cell_value in enumerate(row):
+ if cell_value is not None and str(cell_value).strip():
+ row_map[col_idx] = str(cell_value).strip().replace('"', '\\"')
+
+ if not row_map:
+ continue
+
+ non_empty_count = len(row_map)
+
+ # Header selection heuristic:
+ # - Prefer the first row with at least 2 non-empty columns.
+ # - Fallback: choose the row with the most non-empty columns
+ # (tie-breaker: smaller row index).
+ candidates.append({"idx": current_row_idx, "count": non_empty_count, "map": row_map})
+
+ if not candidates:
+ return 0, {}, 0
+
+ # Choose the best candidate header row using the heuristic described above.
+ best_candidate: Candidate | None = None
+
+ for cand in candidates:
+ if cand["count"] >= 2:
+ best_candidate = cand
+ break
+
+ # Fallback: if no row has >= 2 columns, or all have 1, just take the one with max columns
+ if not best_candidate:
+ # Sort by count desc, then index asc
+ candidates.sort(key=lambda x: (-x["count"], x["idx"]))
+ best_candidate = candidates[0]
+
+ # Determine max_col_idx (1-based for openpyxl)
+ # It is the index of the last valid column in our map + 1
+ max_col_idx = max(best_candidate["map"].keys()) + 1
+
+ return best_candidate["idx"], best_candidate["map"], max_col_idx
diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py
index 0f62f9c4b6..6d28ce25bc 100644
--- a/api/core/rag/extractor/extract_processor.py
+++ b/api/core/rag/extractor/extract_processor.py
@@ -112,7 +112,7 @@ class ExtractProcessor:
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
- extractor = PdfExtractor(file_path)
+ extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension in {".md", ".markdown", ".mdx"}:
extractor = (
UnstructuredMarkdownExtractor(file_path, unstructured_api_url, unstructured_api_key)
@@ -148,7 +148,7 @@ class ExtractProcessor:
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
- extractor = PdfExtractor(file_path)
+ extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension in {".md", ".markdown", ".mdx"}:
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
elif file_extension in {".htm", ".html"}:
@@ -166,7 +166,7 @@ class ExtractProcessor:
elif extract_setting.datasource_type == DatasourceType.NOTION:
assert extract_setting.notion_info is not None, "notion_info is required"
extractor = NotionExtractor(
- notion_workspace_id=extract_setting.notion_info.notion_workspace_id,
+ notion_workspace_id=extract_setting.notion_info.notion_workspace_id or "",
notion_obj_id=extract_setting.notion_info.notion_obj_id,
notion_page_type=extract_setting.notion_info.notion_page_type,
document_model=extract_setting.notion_info.document,
diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py
index 789ac8557d..5d6223db06 100644
--- a/api/core/rag/extractor/firecrawl/firecrawl_app.py
+++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py
@@ -25,7 +25,7 @@ class FirecrawlApp:
}
if params:
json_data.update(params)
- response = self._post_request(f"{self.base_url}/v2/scrape", json_data, headers)
+ response = self._post_request(self._build_url("v2/scrape"), json_data, headers)
if response.status_code == 200:
response_data = response.json()
data = response_data["data"]
@@ -42,7 +42,7 @@ class FirecrawlApp:
json_data = {"url": url}
if params:
json_data.update(params)
- response = self._post_request(f"{self.base_url}/v2/crawl", json_data, headers)
+ response = self._post_request(self._build_url("v2/crawl"), json_data, headers)
if response.status_code == 200:
# There's also another two fields in the response: "success" (bool) and "url" (str)
job_id = response.json().get("id")
@@ -58,7 +58,7 @@ class FirecrawlApp:
if params:
# Pass through provided params, including optional "sitemap": "only" | "include" | "skip"
json_data.update(params)
- response = self._post_request(f"{self.base_url}/v2/map", json_data, headers)
+ response = self._post_request(self._build_url("v2/map"), json_data, headers)
if response.status_code == 200:
return cast(dict[str, Any], response.json())
elif response.status_code in {402, 409, 500, 429, 408}:
@@ -69,7 +69,7 @@ class FirecrawlApp:
def check_crawl_status(self, job_id) -> dict[str, Any]:
headers = self._prepare_headers()
- response = self._get_request(f"{self.base_url}/v2/crawl/{job_id}", headers)
+ response = self._get_request(self._build_url(f"v2/crawl/{job_id}"), headers)
if response.status_code == 200:
crawl_status_response = response.json()
if crawl_status_response.get("status") == "completed":
@@ -120,6 +120,10 @@ class FirecrawlApp:
def _prepare_headers(self) -> dict[str, Any]:
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
+ def _build_url(self, path: str) -> str:
+ # ensure exactly one slash between base and path, regardless of user-provided base_url
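+ # e.g. "https://api.firecrawl.dev/" + "v2/scrape" -> "https://api.firecrawl.dev/v2/scrape"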
+ return f"{self.base_url.rstrip('/')}/{path.lstrip('/')}"
+
def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5) -> httpx.Response:
for attempt in range(retries):
response = httpx.post(url, headers=headers, json=data)
@@ -139,7 +143,11 @@ class FirecrawlApp:
return response
def _handle_error(self, response, action):
- error_message = response.json().get("error", "Unknown error occurred")
+ try:
+ payload = response.json()
+ error_message = payload.get("error") or payload.get("message") or response.text or "Unknown error occurred"
+ except json.JSONDecodeError:
+ error_message = response.text or "Unknown error occurred"
raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") # type: ignore[return]
def search(self, query: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
@@ -160,7 +168,7 @@ class FirecrawlApp:
}
if params:
json_data.update(params)
- response = self._post_request(f"{self.base_url}/v2/search", json_data, headers)
+ response = self._post_request(self._build_url("v2/search"), json_data, headers)
if response.status_code == 200:
response_data = response.json()
if not response_data.get("success"):
diff --git a/api/core/rag/extractor/helpers.py b/api/core/rag/extractor/helpers.py
index 5166c0c768..5b466b281c 100644
--- a/api/core/rag/extractor/helpers.py
+++ b/api/core/rag/extractor/helpers.py
@@ -45,6 +45,6 @@ def detect_file_encodings(file_path: str, timeout: int = 5, sample_size: int = 1
except concurrent.futures.TimeoutError:
raise TimeoutError(f"Timeout reached while detecting encoding for {file_path}")
- if all(encoding["encoding"] is None for encoding in encodings):
+ if all(encoding.encoding is None for encoding in encodings):
raise RuntimeError(f"Could not detect encoding for {file_path}")
- return [FileEncoding(**enc) for enc in encodings if enc["encoding"] is not None]
+ return [enc for enc in encodings if enc.encoding is not None]
diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py
index e87ab38349..372af8fd94 100644
--- a/api/core/rag/extractor/notion_extractor.py
+++ b/api/core/rag/extractor/notion_extractor.py
@@ -48,13 +48,21 @@ class NotionExtractor(BaseExtractor):
if notion_access_token:
self._notion_access_token = notion_access_token
else:
- self._notion_access_token = self._get_access_token(tenant_id, self._credential_id)
- if not self._notion_access_token:
+ try:
+ self._notion_access_token = self._get_access_token(tenant_id, self._credential_id)
+ except Exception as e:
+ logger.warning(
+ (
+ "Failed to get Notion access token from datasource credentials: %s, "
+ "falling back to environment variable NOTION_INTEGRATION_TOKEN"
+ ),
+ e,
+ )
integration_token = dify_config.NOTION_INTEGRATION_TOKEN
if integration_token is None:
raise ValueError(
"Must specify `integration_token` or set environment variable `NOTION_INTEGRATION_TOKEN`."
- )
+ ) from e
self._notion_access_token = integration_token
diff --git a/api/core/rag/extractor/pdf_extractor.py b/api/core/rag/extractor/pdf_extractor.py
index 80530d99a6..6aabcac704 100644
--- a/api/core/rag/extractor/pdf_extractor.py
+++ b/api/core/rag/extractor/pdf_extractor.py
@@ -1,25 +1,57 @@
"""Abstract interface for document loader implementations."""
import contextlib
+import io
+import logging
+import uuid
from collections.abc import Iterator
+import pypdfium2
+import pypdfium2.raw as pdfium_c
+
+from configs import dify_config
from core.rag.extractor.blob.blob import Blob
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
+from extensions.ext_database import db
from extensions.ext_storage import storage
+from libs.datetime_utils import naive_utc_now
+from models.enums import CreatorUserRole
+from models.model import UploadFile
+
+logger = logging.getLogger(__name__)
class PdfExtractor(BaseExtractor):
- """Load pdf files.
-
+ """
+ PdfExtractor is used to extract text and images from PDF files.
Args:
- file_path: Path to the file to load.
+ file_path: Path to the PDF file.
+ tenant_id: Workspace ID.
+ user_id: ID of the user performing the extraction.
+ file_cache_key: Optional cache key for the extracted text.
"""
- def __init__(self, file_path: str, file_cache_key: str | None = None):
- """Initialize with file path."""
+ # Magic bytes for image format detection: (magic_bytes, extension, mime_type)
+ IMAGE_FORMATS = [
+ (b"\xff\xd8\xff", "jpg", "image/jpeg"),
+ (b"\x89PNG\r\n\x1a\n", "png", "image/png"),
+ (b"\x00\x00\x00\x0c\x6a\x50\x20\x20\x0d\x0a\x87\x0a", "jp2", "image/jp2"),
+ (b"GIF8", "gif", "image/gif"),
+ (b"BM", "bmp", "image/bmp"),
+ (b"II*\x00", "tiff", "image/tiff"),
+ (b"MM\x00*", "tiff", "image/tiff"),
+ (b"II+\x00", "tiff", "image/tiff"),
+ (b"MM\x00+", "tiff", "image/tiff"),
+ ]
+ MAX_MAGIC_LEN = max(len(m) for m, _, _ in IMAGE_FORMATS)
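+ # The longest signature above is the 12-byte JPEG 2000 header, so MAX_MAGIC_LEN bytes suffice for detection.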
+
+ def __init__(self, file_path: str, tenant_id: str, user_id: str, file_cache_key: str | None = None):
+ """Initialize PdfExtractor."""
self._file_path = file_path
+ self._tenant_id = tenant_id
+ self._user_id = user_id
self._file_cache_key = file_cache_key
def extract(self) -> list[Document]:
@@ -50,7 +82,6 @@ class PdfExtractor(BaseExtractor):
def parse(self, blob: Blob) -> Iterator[Document]:
"""Lazily parse the blob."""
- import pypdfium2 # type: ignore
with blob.as_bytes_io() as file_path:
pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True)
@@ -59,8 +90,87 @@ class PdfExtractor(BaseExtractor):
text_page = page.get_textpage()
content = text_page.get_text_range()
text_page.close()
+
+ image_content = self._extract_images(page)
+ if image_content:
+ content += "\n" + image_content
+
page.close()
metadata = {"source": blob.source, "page": page_number}
yield Document(page_content=content, metadata=metadata)
finally:
pdf_reader.close()
+
+ def _extract_images(self, page) -> str:
+ """
+ Extract images from a PDF page, save them to storage and database,
+ and return markdown image links.
+
+ Args:
+ page: pypdfium2 page object.
+
+ Returns:
+ Markdown string containing links to the extracted images.
+ """
+ image_content = []
+ upload_files = []
+ base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
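+ # Prefer the internal files endpoint when configured, otherwise fall back to the public FILES_URL.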
+
+ try:
+ image_objects = page.get_objects(filter=(pdfium_c.FPDF_PAGEOBJ_IMAGE,))
+ for obj in image_objects:
+ try:
+ # Extract image bytes
+ img_byte_arr = io.BytesIO()
+ # Extract DCTDecode (JPEG) and JPXDecode (JPEG 2000) images directly
+ # Fallback to png for other formats
+ obj.extract(img_byte_arr, fb_format="png")
+ img_bytes = img_byte_arr.getvalue()
+
+ if not img_bytes:
+ continue
+
+ header = img_bytes[: self.MAX_MAGIC_LEN]
+ image_ext = None
+ mime_type = None
+ for magic, ext, mime in self.IMAGE_FORMATS:
+ if header.startswith(magic):
+ image_ext = ext
+ mime_type = mime
+ break
+
+ if not image_ext or not mime_type:
+ continue
+
+ file_uuid = str(uuid.uuid4())
+ file_key = "image_files/" + self._tenant_id + "/" + file_uuid + "." + image_ext
+
+ storage.save(file_key, img_bytes)
+
+ # save file to db
+ upload_file = UploadFile(
+ tenant_id=self._tenant_id,
+ storage_type=dify_config.STORAGE_TYPE,
+ key=file_key,
+ name=file_key,
+ size=len(img_bytes),
+ extension=image_ext,
+ mime_type=mime_type,
+ created_by=self._user_id,
+ created_by_role=CreatorUserRole.ACCOUNT,
+ created_at=naive_utc_now(),
+ used=True,
+ used_by=self._user_id,
+ used_at=naive_utc_now(),
+ )
+ upload_files.append(upload_file)
+ image_content.append(f"")
+ except Exception as e:
+ logger.warning("Failed to extract image from PDF: %s", e)
+ continue
+ except Exception as e:
+ logger.warning("Failed to get objects from PDF page: %s", e)
+ if upload_files:
+ db.session.add_all(upload_files)
+ db.session.commit()
+ return "\n".join(image_content)
diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py
index c7a5568866..1ddbfc5864 100644
--- a/api/core/rag/extractor/word_extractor.py
+++ b/api/core/rag/extractor/word_extractor.py
@@ -1,4 +1,7 @@
-"""Abstract interface for document loader implementations."""
+"""Word (.docx) document extractor used for RAG ingestion.
+
+Supports local file paths and remote URLs (downloaded via `core.helper.ssrf_proxy`).
+"""
import logging
import mimetypes
@@ -7,10 +10,10 @@ import re
import tempfile
import uuid
from urllib.parse import urlparse
-from xml.etree import ElementTree
-import httpx
from docx import Document as DocxDocument
+from docx.oxml.ns import qn
+from docx.text.run import Run
from configs import dify_config
from core.helper import ssrf_proxy
@@ -43,7 +46,7 @@ class WordExtractor(BaseExtractor):
# If the file is a web path, download it to a temporary file, and use that
if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path):
- response = httpx.get(self.file_path, timeout=None)
+ response = ssrf_proxy.get(self.file_path)
if response.status_code != 200:
response.close()
@@ -54,6 +57,7 @@ class WordExtractor(BaseExtractor):
self.temp_file = tempfile.NamedTemporaryFile() # noqa SIM115
try:
self.temp_file.write(response.content)
+ self.temp_file.flush()
finally:
response.close()
self.file_path = self.temp_file.name
@@ -83,23 +87,46 @@ class WordExtractor(BaseExtractor):
def _extract_images_from_docx(self, doc):
image_count = 0
image_map = {}
+ base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
- for rel in doc.part.rels.values():
+ for r_id, rel in doc.part.rels.items():
if "image" in rel.target_ref:
image_count += 1
if rel.is_external:
url = rel.target_ref
- response = ssrf_proxy.get(url)
+ if not self._is_valid_url(url):
+ continue
+ try:
+ response = ssrf_proxy.get(url)
+ except Exception as e:
+ logger.warning("Failed to download image from URL: %s: %s", url, str(e))
+ continue
if response.status_code == 200:
- image_ext = mimetypes.guess_extension(response.headers["Content-Type"])
+ image_ext = mimetypes.guess_extension(response.headers.get("Content-Type", ""))
if image_ext is None:
continue
file_uuid = str(uuid.uuid4())
- file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext
+ file_key = "image_files/" + self.tenant_id + "/" + file_uuid + image_ext
mime_type, _ = mimetypes.guess_type(file_key)
storage.save(file_key, response.content)
- else:
- continue
+ # save file to db
+ upload_file = UploadFile(
+ tenant_id=self.tenant_id,
+ storage_type=dify_config.STORAGE_TYPE,
+ key=file_key,
+ name=file_key,
+ size=0,
+ extension=str(image_ext),
+ mime_type=mime_type or "",
+ created_by=self.user_id,
+ created_by_role=CreatorUserRole.ACCOUNT,
+ created_at=naive_utc_now(),
+ used=True,
+ used_by=self.user_id,
+ used_at=naive_utc_now(),
+ )
+ db.session.add(upload_file)
+ image_map[r_id] = f""
else:
image_ext = rel.target_ref.split(".")[-1]
if image_ext is None:
@@ -110,27 +137,25 @@ class WordExtractor(BaseExtractor):
mime_type, _ = mimetypes.guess_type(file_key)
storage.save(file_key, rel.target_part.blob)
- # save file to db
- upload_file = UploadFile(
- tenant_id=self.tenant_id,
- storage_type=dify_config.STORAGE_TYPE,
- key=file_key,
- name=file_key,
- size=0,
- extension=str(image_ext),
- mime_type=mime_type or "",
- created_by=self.user_id,
- created_by_role=CreatorUserRole.ACCOUNT,
- created_at=naive_utc_now(),
- used=True,
- used_by=self.user_id,
- used_at=naive_utc_now(),
- )
-
- db.session.add(upload_file)
- db.session.commit()
- image_map[rel.target_part] = f""
-
+ # save file to db
+ upload_file = UploadFile(
+ tenant_id=self.tenant_id,
+ storage_type=dify_config.STORAGE_TYPE,
+ key=file_key,
+ name=file_key,
+ size=0,
+ extension=str(image_ext),
+ mime_type=mime_type or "",
+ created_by=self.user_id,
+ created_by_role=CreatorUserRole.ACCOUNT,
+ created_at=naive_utc_now(),
+ used=True,
+ used_by=self.user_id,
+ used_at=naive_utc_now(),
+ )
+ db.session.add(upload_file)
+ image_map[rel.target_part] = f""
+ db.session.commit()
return image_map
def _table_to_markdown(self, table, image_map):
@@ -186,11 +211,17 @@ class WordExtractor(BaseExtractor):
image_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed")
if not image_id:
continue
- image_part = paragraph.part.rels[image_id].target_part
-
- if image_part in image_map:
- image_link = image_map[image_part]
- paragraph_content.append(image_link)
+ rel = paragraph.part.rels.get(image_id)
+ if rel is None:
+ continue
+ # For external images, use image_id as key; for internal, use target_part
+ if rel.is_external:
+ if image_id in image_map:
+ paragraph_content.append(image_map[image_id])
+ else:
+ image_part = rel.target_part
+ if image_part in image_map:
+ paragraph_content.append(image_map[image_part])
else:
paragraph_content.append(run.text)
return "".join(paragraph_content).strip()
@@ -202,32 +233,20 @@ class WordExtractor(BaseExtractor):
image_map = self._extract_images_from_docx(doc)
- hyperlinks_url = None
- url_pattern = re.compile(r"http://[^\s+]+//|https://[^\s+]+")
- for para in doc.paragraphs:
- for run in para.runs:
- if run.text and hyperlinks_url:
- result = f" [{run.text}]({hyperlinks_url}) "
- run.text = result
- hyperlinks_url = None
- if "HYPERLINK" in run.element.xml:
- try:
- xml = ElementTree.XML(run.element.xml)
- x_child = [c for c in xml.iter() if c is not None]
- for x in x_child:
- if x is None:
- continue
- if x.tag.endswith("instrText"):
- if x.text is None:
- continue
- for i in url_pattern.findall(x.text):
- hyperlinks_url = str(i)
- except Exception:
- logger.exception("Failed to parse HYPERLINK xml")
-
def parse_paragraph(paragraph):
- paragraph_content = []
- for run in paragraph.runs:
+ def append_image_link(image_id, has_drawing, target_buffer):
+ """Helper to append image link from image_map based on relationship type."""
+ rel = doc.part.rels[image_id]
+ if rel.is_external:
+ if image_id in image_map and not has_drawing:
+ target_buffer.append(image_map[image_id])
+ else:
+ image_part = rel.target_part
+ if image_part in image_map and not has_drawing:
+ target_buffer.append(image_map[image_part])
+
+ def process_run(run, target_buffer):
+ # Helper to extract text and embedded images from a run element and append them to target_buffer
if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"):
# Process drawing type images
drawing_elements = run.element.findall(
@@ -243,10 +262,18 @@ class WordExtractor(BaseExtractor):
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed"
)
if embed_id:
- image_part = doc.part.related_parts.get(embed_id)
- if image_part in image_map:
- has_drawing = True
- paragraph_content.append(image_map[image_part])
+ rel = doc.part.rels.get(embed_id)
+ if rel is not None and rel.is_external:
+ # External image: use embed_id as key
+ if embed_id in image_map:
+ has_drawing = True
+ target_buffer.append(image_map[embed_id])
+ else:
+ # Internal image: use target_part as key
+ image_part = doc.part.related_parts.get(embed_id)
+ if image_part in image_map:
+ has_drawing = True
+ target_buffer.append(image_map[image_part])
# Process pict type images
shape_elements = run.element.findall(
".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}pict"
@@ -261,9 +288,7 @@ class WordExtractor(BaseExtractor):
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
)
if image_id and image_id in doc.part.rels:
- image_part = doc.part.rels[image_id].target_part
- if image_part in image_map and not has_drawing:
- paragraph_content.append(image_map[image_part])
+ append_image_link(image_id, has_drawing, target_buffer)
# Find imagedata element in VML
image_data = shape.find(".//{urn:schemas-microsoft-com:vml}imagedata")
if image_data is not None:
@@ -271,11 +296,93 @@ class WordExtractor(BaseExtractor):
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
)
if image_id and image_id in doc.part.rels:
- image_part = doc.part.rels[image_id].target_part
- if image_part in image_map and not has_drawing:
- paragraph_content.append(image_map[image_part])
+ append_image_link(image_id, has_drawing, target_buffer)
if run.text.strip():
- paragraph_content.append(run.text.strip())
+ target_buffer.append(run.text.strip())
+
+ def process_hyperlink(hyperlink_elem, target_buffer):
+ # Helper to extract text from a hyperlink element and append it to target_buffer
+ r_id = hyperlink_elem.get(qn("r:id"))
+
+ # Extract text from runs inside the hyperlink
+ link_text_parts = []
+ for run_elem in hyperlink_elem.findall(qn("w:r")):
+ run = Run(run_elem, paragraph)
+ # Hyperlink text may be split across multiple runs (e.g., with different formatting),
+ # so collect all run texts first
+ if run.text:
+ link_text_parts.append(run.text)
+
+ link_text = "".join(link_text_parts).strip()
+
+ # Resolve URL
+ if r_id:
+ try:
+ rel = doc.part.rels.get(r_id)
+ if rel and rel.is_external:
+ link_text = f"[{link_text or rel.target_ref}]({rel.target_ref})"
+ except Exception:
+ logger.exception("Failed to resolve URL for hyperlink with r:id: %s", r_id)
+
+ if link_text:
+ target_buffer.append(link_text)
+
+ paragraph_content = []
+ # State for legacy HYPERLINK fields
+ hyperlink_field_url = None
+ hyperlink_field_text_parts: list = []
+ is_collecting_field_text = False
+ # Iterate through paragraph elements in document order
+ for child in paragraph._element:
+ tag = child.tag
+ if tag == qn("w:r"):
+ # Regular run
+ run = Run(child, paragraph)
+
+ # Check for fldChar (begin/end/separate) and instrText for legacy hyperlinks
+ fld_chars = child.findall(qn("w:fldChar"))
+ instr_texts = child.findall(qn("w:instrText"))
+
+ # Handle Fields
+ if fld_chars or instr_texts:
+ # Process instrText to find HYPERLINK "url"
+ for instr in instr_texts:
+ if instr.text and "HYPERLINK" in instr.text:
+ # Quick regex to extract URL
+ match = re.search(r'HYPERLINK\s+"([^"]+)"', instr.text, re.IGNORECASE)
+ if match:
+ hyperlink_field_url = match.group(1)
+
+ # Process fldChar
+ for fld_char in fld_chars:
+ fld_char_type = fld_char.get(qn("w:fldCharType"))
+ if fld_char_type == "begin":
+ # Start of a field: reset legacy link state
+ hyperlink_field_url = None
+ hyperlink_field_text_parts = []
+ is_collecting_field_text = False
+ elif fld_char_type == "separate":
+ # Separator: if we found a URL, start collecting visible text
+ if hyperlink_field_url:
+ is_collecting_field_text = True
+ elif fld_char_type == "end":
+ # End of field
+ if is_collecting_field_text and hyperlink_field_url:
+ # Create markdown link and append to main content
+ display_text = "".join(hyperlink_field_text_parts).strip()
+ if display_text:
+ link_md = f"[{display_text}]({hyperlink_field_url})"
+ paragraph_content.append(link_md)
+ # Reset state
+ hyperlink_field_url = None
+ hyperlink_field_text_parts = []
+ is_collecting_field_text = False
+
+ # Decide where to append content
+ target_buffer = hyperlink_field_text_parts if is_collecting_field_text else paragraph_content
+ process_run(run, target_buffer)
+ elif tag == qn("w:hyperlink"):
+ process_hyperlink(child, paragraph_content)
return "".join(paragraph_content) if paragraph_content else ""
paragraphs = doc.paragraphs.copy()
diff --git a/api/core/rag/index_processor/constant/built_in_field.py b/api/core/rag/index_processor/constant/built_in_field.py
index 9ad69e7fe3..7c270a32d0 100644
--- a/api/core/rag/index_processor/constant/built_in_field.py
+++ b/api/core/rag/index_processor/constant/built_in_field.py
@@ -15,3 +15,4 @@ class MetadataDataSource(StrEnum):
notion_import = "notion"
local_file = "file_upload"
online_document = "online_document"
+ online_drive = "online_drive"
diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py
index 8a28eb477a..151a3de7d9 100644
--- a/api/core/rag/index_processor/index_processor_base.py
+++ b/api/core/rag/index_processor/index_processor_base.py
@@ -13,6 +13,7 @@ from urllib.parse import unquote, urlparse
import httpx
from configs import dify_config
+from core.entities.knowledge_entities import PreviewDetail
from core.helper import ssrf_proxy
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.constant.doc_type import DocType
@@ -45,6 +46,17 @@ class BaseIndexProcessor(ABC):
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
raise NotImplementedError
+ @abstractmethod
+ def generate_summary_preview(
+ self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
+ ) -> list[PreviewDetail]:
+ """
+ For each segment in preview_texts, generate a summary using an LLM and attach it to the segment,
+ e.g., in its `summary` attribute.
+ Subclasses must implement this method.
+ """
+ raise NotImplementedError
+
@abstractmethod
def load(
self,
@@ -231,7 +243,7 @@ class BaseIndexProcessor(ABC):
if not filename:
parsed_url = urlparse(image_url)
- # unquote 处理 URL 中的中文
+ # Decode percent-encoded characters in the URL path.
path = unquote(parsed_url.path)
filename = os.path.basename(path)
diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py
index cf68cff7dc..ab91e29145 100644
--- a/api/core/rag/index_processor/processor/paragraph_index_processor.py
+++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py
@@ -1,9 +1,27 @@
"""Paragraph index processor."""
+import logging
+import re
import uuid
from collections.abc import Mapping
-from typing import Any
+from typing import Any, cast
+logger = logging.getLogger(__name__)
+
+from core.entities.knowledge_entities import PreviewDetail
+from core.file import File, FileTransferMethod, FileType, file_manager
+from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
+from core.model_manager import ModelInstance
+from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
+from core.model_runtime.entities.message_entities import (
+ ImagePromptMessageContent,
+ PromptMessage,
+ PromptMessageContentUnionTypes,
+ TextPromptMessageContent,
+ UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import ModelFeature, ModelType
+from core.provider_manager import ProviderManager
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.retrieval_service import RetrievalService
@@ -17,12 +35,17 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
+from core.workflow.nodes.llm import llm_utils
+from extensions.ext_database import db
+from factories.file_factory import build_from_mapping
from libs import helper
+from models import UploadFile
from models.account import Account
-from models.dataset import Dataset, DatasetProcessRule
+from models.dataset import Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import Rule
+from services.summary_index_service import SummaryIndexService
class ParagraphIndexProcessor(BaseIndexProcessor):
@@ -108,6 +131,29 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
keyword.add_texts(documents)
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
+ # Note: summary indexes are disabled (not deleted) when segments are disabled;
+ # for that case, disable_summaries_for_segments is called directly in the task.
+ # This method handles actual deletion, so summaries are removed only when
+ # explicitly requested via the delete_summaries kwarg (e.g., when a segment is deleted).
+ delete_summaries = kwargs.get("delete_summaries", False)
+ if delete_summaries:
+ if node_ids:
+ # Find segments by index_node_id
+ segments = (
+ db.session.query(DocumentSegment)
+ .filter(
+ DocumentSegment.dataset_id == dataset.id,
+ DocumentSegment.index_node_id.in_(node_ids),
+ )
+ .all()
+ )
+ segment_ids = [segment.id for segment in segments]
+ if segment_ids:
+ SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
+ else:
+ # Delete all summaries for the dataset
+ SummaryIndexService.delete_summaries_for_segments(dataset, None)
+
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
if node_ids:
@@ -227,3 +273,322 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
}
else:
raise ValueError("Chunks is not a list")
+
+ def generate_summary_preview(
+ self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
+ ) -> list[PreviewDetail]:
+ """
+ For each segment, concurrently call generate_summary to generate a summary
+ and write it to the summary attribute of PreviewDetail.
+ In preview mode (indexing-estimate), if any summary generation fails, the method will raise an exception.
+ """
+ import concurrent.futures
+
+ from flask import current_app
+
+ # Capture Flask app context for worker threads
+ flask_app = None
+ try:
+ flask_app = current_app._get_current_object() # type: ignore
+ except RuntimeError:
+ logger.warning("No Flask application context available, summary generation may fail")
+
+ def process(preview: PreviewDetail) -> None:
+ """Generate summary for a single preview item."""
+ if flask_app:
+ # Ensure Flask app context in worker thread
+ with flask_app.app_context():
+ summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting)
+ preview.summary = summary
+ else:
+ # Fallback: try without app context (may fail)
+ summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting)
+ preview.summary = summary
+
+ # Generate summaries concurrently using ThreadPoolExecutor
+ # Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total)
+ timeout_seconds = min(300, 60 * len(preview_texts))
+ errors: list[Exception] = []
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_texts))) as executor:
+ futures = [executor.submit(process, preview) for preview in preview_texts]
+ # Wait for all tasks to complete with timeout
+ done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)
+
+ # Cancel tasks that didn't complete in time
+ if not_done:
+ timeout_error_msg = (
+ f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s"
+ )
+ logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg)
+ # In preview mode, timeout is also an error
+ errors.append(TimeoutError(timeout_error_msg))
+ for future in not_done:
+ future.cancel()
+ # Wait a bit for cancellation to take effect
+ concurrent.futures.wait(not_done, timeout=5)
+
+ # Collect exceptions from completed futures
+ for future in done:
+ try:
+ future.result() # This will raise any exception that occurred
+ except Exception as e:
+ logger.exception("Error in summary generation future")
+ errors.append(e)
+
+ # In preview mode (indexing-estimate), if there are any errors, fail the request
+ if errors:
+ error_messages = [str(e) for e in errors]
+ error_summary = (
+ f"Failed to generate summaries for {len(errors)} chunk(s). "
+ f"Errors: {'; '.join(error_messages[:3])}" # Show first 3 errors
+ )
+ if len(errors) > 3:
+ error_summary += f" (and {len(errors) - 3} more)"
+ logger.error("Summary generation failed in preview mode: %s", error_summary)
+ raise ValueError(error_summary)
+
+ return preview_texts
+
+ @staticmethod
+ def generate_summary(
+ tenant_id: str,
+ text: str,
+ summary_index_setting: dict | None = None,
+ segment_id: str | None = None,
+ ) -> tuple[str, LLMUsage]:
+ """
+ Generate a summary for the given text using ModelInstance.invoke_llm with the default or a custom summary prompt.
+ Vision-capable models are supported by including images from segment attachments or from the text content.
+
+ Args:
+ tenant_id: Tenant ID
+ text: Text content to summarize
+ summary_index_setting: Summary index configuration
+ segment_id: Optional segment ID to fetch attachments from SegmentAttachmentBinding table
+
+ Returns:
+ Tuple of (summary_content, llm_usage) where llm_usage is LLMUsage object
+ """
+ if not summary_index_setting or not summary_index_setting.get("enable"):
+ raise ValueError("summary_index_setting is required and must be enabled to generate summary.")
+
+ model_name = summary_index_setting.get("model_name")
+ model_provider_name = summary_index_setting.get("model_provider_name")
+ summary_prompt = summary_index_setting.get("summary_prompt")
+
+ if not model_name or not model_provider_name:
+ raise ValueError("model_name and model_provider_name are required in summary_index_setting")
+
+ # Import default summary prompt
+ if not summary_prompt:
+ summary_prompt = DEFAULT_GENERATOR_SUMMARY_PROMPT
+
+ provider_manager = ProviderManager()
+ provider_model_bundle = provider_manager.get_provider_model_bundle(
+ tenant_id, model_provider_name, ModelType.LLM
+ )
+ model_instance = ModelInstance(provider_model_bundle, model_name)
+
+ # Get model schema to check if vision is supported
+ model_schema = model_instance.model_type_instance.get_model_schema(model_name, model_instance.credentials)
+ supports_vision = model_schema and model_schema.features and ModelFeature.VISION in model_schema.features
+
+ # Extract images if model supports vision
+ image_files = []
+ if supports_vision:
+ # First, try to get images from SegmentAttachmentBinding (preferred method)
+ if segment_id:
+ image_files = ParagraphIndexProcessor._extract_images_from_segment_attachments(tenant_id, segment_id)
+
+ # If no images from attachments, fall back to extracting from text
+ if not image_files:
+ image_files = ParagraphIndexProcessor._extract_images_from_text(tenant_id, text)
+
+ # Build prompt messages
+ prompt_messages = []
+
+ if image_files:
+ # If we have images, create a UserPromptMessage with both text and images
+ prompt_message_contents: list[PromptMessageContentUnionTypes] = []
+
+ # Add images first
+ for file in image_files:
+ try:
+ file_content = file_manager.to_prompt_message_content(
+ file, image_detail_config=ImagePromptMessageContent.DETAIL.LOW
+ )
+ prompt_message_contents.append(file_content)
+ except Exception as e:
+ logger.warning("Failed to convert image file to prompt message content: %s", str(e))
+ continue
+
+ # Add text content
+ if prompt_message_contents: # Only add text if we successfully added images
+ prompt_message_contents.append(TextPromptMessageContent(data=f"{summary_prompt}\n{text}"))
+ prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
+ else:
+ # If image conversion failed, fall back to text-only
+ prompt = f"{summary_prompt}\n{text}"
+ prompt_messages.append(UserPromptMessage(content=prompt))
+ else:
+ # No images, use simple text prompt
+ prompt = f"{summary_prompt}\n{text}"
+ prompt_messages.append(UserPromptMessage(content=prompt))
+
+ result = model_instance.invoke_llm(
+ prompt_messages=cast(list[PromptMessage], prompt_messages), model_parameters={}, stream=False
+ )
+
+ # Type assertion: when stream=False, invoke_llm returns LLMResult, not Generator
+ if not isinstance(result, LLMResult):
+ raise ValueError("Expected LLMResult when stream=False")
+
+ summary_content = getattr(result.message, "content", "")
+ usage = result.usage
+
+ # Deduct quota for summary generation (same as workflow nodes)
+ try:
+ llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
+ except Exception as e:
+ # Log but don't fail summary generation if quota deduction fails
+ logger.warning("Failed to deduct quota for summary generation: %s", str(e))
+
+ return summary_content, usage
+
+ @staticmethod
+ def _extract_images_from_text(tenant_id: str, text: str) -> list[File]:
+ """
+ Extract images from markdown text and convert them to File objects.
+
+ Args:
+ tenant_id: Tenant ID
+ text: Text content that may contain markdown image links
+
+ Returns:
+ List of File objects representing images found in the text
+ """
+ # Extract markdown images using regex pattern
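+ # Matches markdown image syntax such as "![alt](/files/<uuid>/file-preview)" and captures the URL in group 1.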
+ pattern = r"!\[.*?\]\((.*?)\)"
+ images = re.findall(pattern, text)
+
+ if not images:
+ return []
+
+ upload_file_id_list = []
+
+ for image in images:
+ # For data before v0.10.0
+ pattern = r"/files/([a-f0-9\-]+)/image-preview(?:\?.*?)?"
+ match = re.search(pattern, image)
+ if match:
+ upload_file_id = match.group(1)
+ upload_file_id_list.append(upload_file_id)
+ continue
+
+ # For data after v0.10.0
+ pattern = r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?"
+ match = re.search(pattern, image)
+ if match:
+ upload_file_id = match.group(1)
+ upload_file_id_list.append(upload_file_id)
+ continue
+
+ # For tools directory - direct file formats (e.g., .png, .jpg, etc.)
+ pattern = r"/files/tools/([a-f0-9\-]+)\.([a-zA-Z0-9]+)(?:\?[^\s\)\"\']*)?"
+ match = re.search(pattern, image)
+ if match:
+ # Tool files are handled differently, skip for now
+ continue
+
+ if not upload_file_id_list:
+ return []
+
+ # Get unique IDs for database query
+ unique_upload_file_ids = list(set(upload_file_id_list))
+ upload_files = (
+ db.session.query(UploadFile)
+ .where(UploadFile.id.in_(unique_upload_file_ids), UploadFile.tenant_id == tenant_id)
+ .all()
+ )
+
+ # Create File objects from UploadFile records
+ file_objects = []
+ for upload_file in upload_files:
+ # Only process image files
+ if not upload_file.mime_type or "image" not in upload_file.mime_type:
+ continue
+
+ mapping = {
+ "upload_file_id": upload_file.id,
+ "transfer_method": FileTransferMethod.LOCAL_FILE.value,
+ "type": FileType.IMAGE.value,
+ }
+
+ try:
+ file_obj = build_from_mapping(
+ mapping=mapping,
+ tenant_id=tenant_id,
+ )
+ file_objects.append(file_obj)
+ except Exception as e:
+ logger.warning("Failed to create File object from UploadFile %s: %s", upload_file.id, str(e))
+ continue
+
+ return file_objects
+
+ @staticmethod
+ def _extract_images_from_segment_attachments(tenant_id: str, segment_id: str) -> list[File]:
+ """
+ Extract images from SegmentAttachmentBinding table (preferred method).
+ This matches how DatasetRetrieval gets segment attachments.
+
+ Args:
+ tenant_id: Tenant ID
+ segment_id: Segment ID to fetch attachments for
+
+ Returns:
+ List of File objects representing images found in segment attachments
+ """
+ from sqlalchemy import select
+
+ # Query attachments from SegmentAttachmentBinding table
+ attachments_with_bindings = db.session.execute(
+ select(SegmentAttachmentBinding, UploadFile)
+ .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
+ .where(
+ SegmentAttachmentBinding.segment_id == segment_id,
+ SegmentAttachmentBinding.tenant_id == tenant_id,
+ )
+ ).all()
+
+ if not attachments_with_bindings:
+ return []
+
+ file_objects = []
+ for _, upload_file in attachments_with_bindings:
+ # Only process image files
+ if not upload_file.mime_type or "image" not in upload_file.mime_type:
+ continue
+
+ try:
+ # Create File object directly (similar to DatasetRetrieval)
+ file_obj = File(
+ id=upload_file.id,
+ filename=upload_file.name,
+ extension="." + upload_file.extension,
+ mime_type=upload_file.mime_type,
+ tenant_id=tenant_id,
+ type=FileType.IMAGE,
+ transfer_method=FileTransferMethod.LOCAL_FILE,
+ remote_url=upload_file.source_url,
+ related_id=upload_file.id,
+ size=upload_file.size,
+ storage_key=upload_file.key,
+ )
+ file_objects.append(file_obj)
+ except Exception as e:
+ logger.warning("Failed to create File object from UploadFile %s: %s", upload_file.id, str(e))
+ continue
+
+ return file_objects
diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py
index 0366f3259f..961df2e50c 100644
--- a/api/core/rag/index_processor/processor/parent_child_index_processor.py
+++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py
@@ -1,11 +1,14 @@
"""Paragraph index processor."""
import json
+import logging
import uuid
from collections.abc import Mapping
from typing import Any
from configs import dify_config
+from core.db.session_factory import session_factory
+from core.entities.knowledge_entities import PreviewDetail
from core.model_manager import ModelInstance
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.retrieval_service import RetrievalService
@@ -25,6 +28,9 @@ from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegm
from models.dataset import Document as DatasetDocument
from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
+from services.summary_index_service import SummaryIndexService
+
+logger = logging.getLogger(__name__)
class ParentChildIndexProcessor(BaseIndexProcessor):
@@ -135,6 +141,30 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
# node_ids is segment's node_ids
+ # Note: summary indexes are disabled (not deleted) when segments are disabled;
+ # for that case, disable_summaries_for_segments is called directly in the task.
+ # This method handles actual deletion, so summaries are removed only when
+ # explicitly requested via the delete_summaries kwarg (e.g., when a segment is deleted).
+ delete_summaries = kwargs.get("delete_summaries", False)
+ if delete_summaries:
+ if node_ids:
+ # Find segments by index_node_id
+ with session_factory.create_session() as session:
+ segments = (
+ session.query(DocumentSegment)
+ .filter(
+ DocumentSegment.dataset_id == dataset.id,
+ DocumentSegment.index_node_id.in_(node_ids),
+ )
+ .all()
+ )
+ segment_ids = [segment.id for segment in segments]
+ if segment_ids:
+ SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
+ else:
+ # Delete all summaries for the dataset
+ SummaryIndexService.delete_summaries_for_segments(dataset, None)
+
if dataset.indexing_technique == "high_quality":
delete_child_chunks = kwargs.get("delete_child_chunks") or False
precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids")
@@ -326,3 +356,91 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
"preview": preview,
"total_segments": len(parent_childs.parent_child_chunks),
}
+
+ def generate_summary_preview(
+ self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
+ ) -> list[PreviewDetail]:
+ """
+ For each parent chunk in preview_texts, concurrently call generate_summary to generate a summary
+ and write it to the summary attribute of PreviewDetail.
+ In preview mode (indexing-estimate), if any summary generation fails, the method will raise an exception.
+
+ Note: For parent-child structure, we only generate summaries for parent chunks.
+ """
+ import concurrent.futures
+
+ from flask import current_app
+
+ # Capture Flask app context for worker threads
+ flask_app = None
+ try:
+ flask_app = current_app._get_current_object() # type: ignore
+ except RuntimeError:
+ logger.warning("No Flask application context available, summary generation may fail")
+
+ def process(preview: PreviewDetail) -> None:
+ """Generate summary for a single preview item (parent chunk)."""
+ from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
+
+ if flask_app:
+ # Ensure Flask app context in worker thread
+ with flask_app.app_context():
+ summary, _ = ParagraphIndexProcessor.generate_summary(
+ tenant_id=tenant_id,
+ text=preview.content,
+ summary_index_setting=summary_index_setting,
+ )
+ preview.summary = summary
+ else:
+ # Fallback: try without app context (may fail)
+ summary, _ = ParagraphIndexProcessor.generate_summary(
+ tenant_id=tenant_id,
+ text=preview.content,
+ summary_index_setting=summary_index_setting,
+ )
+ preview.summary = summary
+
+ # Generate summaries concurrently using ThreadPoolExecutor
+ # Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total)
+ timeout_seconds = min(300, 60 * len(preview_texts))
+ errors: list[Exception] = []
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_texts))) as executor:
+ futures = [executor.submit(process, preview) for preview in preview_texts]
+ # Wait for all tasks to complete with timeout
+ done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)
+
+ # Cancel tasks that didn't complete in time
+ if not_done:
+ timeout_error_msg = (
+ f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s"
+ )
+ logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg)
+ # In preview mode, timeout is also an error
+ errors.append(TimeoutError(timeout_error_msg))
+ for future in not_done:
+ future.cancel()
+ # Wait a bit for cancellation to take effect
+ concurrent.futures.wait(not_done, timeout=5)
+
+ # Collect exceptions from completed futures
+ for future in done:
+ try:
+ future.result() # This will raise any exception that occurred
+ except Exception as e:
+ logger.exception("Error in summary generation future")
+ errors.append(e)
+
+ # In preview mode (indexing-estimate), if there are any errors, fail the request
+ if errors:
+ error_messages = [str(e) for e in errors]
+ error_summary = (
+ f"Failed to generate summaries for {len(errors)} chunk(s). "
+ f"Errors: {'; '.join(error_messages[:3])}" # Show first 3 errors
+ )
+ if len(errors) > 3:
+ error_summary += f" (and {len(errors) - 3} more)"
+ logger.error("Summary generation failed in preview mode: %s", error_summary)
+ raise ValueError(error_summary)
+
+ return preview_texts
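For reference, the wait-with-timeout and error-collection pattern used by `generate_summary_preview` can be reduced to a small standalone sketch (standard library only; the helper name and parameters below are illustrative and not part of this change):

```python
import concurrent.futures


def run_all(tasks, per_task_timeout=60, max_total=300):
    """Run callables concurrently, collect errors, and fail if anything broke."""
    if not tasks:
        return
    timeout_seconds = min(max_total, per_task_timeout * len(tasks))
    errors: list[Exception] = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(tasks))) as executor:
        futures = [executor.submit(task) for task in tasks]
        done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)
        for future in not_done:
            future.cancel()  # only prevents tasks that have not started yet
            errors.append(TimeoutError("task did not finish in time"))
        for future in done:
            try:
                future.result()  # re-raises any exception from the worker
            except Exception as e:
                errors.append(e)
    if errors:
        raise errors[0]
```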
diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py
index 1183d5fbd7..272d2ed351 100644
--- a/api/core/rag/index_processor/processor/qa_index_processor.py
+++ b/api/core/rag/index_processor/processor/qa_index_processor.py
@@ -11,6 +11,8 @@ import pandas as pd
from flask import Flask, current_app
from werkzeug.datastructures import FileStorage
+from core.db.session_factory import session_factory
+from core.entities.knowledge_entities import PreviewDetail
from core.llm_generator.llm_generator import LLMGenerator
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.retrieval_service import RetrievalService
@@ -25,9 +27,10 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.account import Account
-from models.dataset import Dataset
+from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import Rule
+from services.summary_index_service import SummaryIndexService
logger = logging.getLogger(__name__)
@@ -144,6 +147,31 @@ class QAIndexProcessor(BaseIndexProcessor):
vector.create_multimodal(multimodal_documents)
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
+ # Note: Summary indexes are now disabled (not deleted) when segments are disabled.
+ # This method is called for actual deletion scenarios (e.g., when segment is deleted).
+ # For disable operations, disable_summaries_for_segments is called directly in the task.
+ # Note: qa_model doesn't generate summaries, but we clean them for completeness
+ # Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
+ delete_summaries = kwargs.get("delete_summaries", False)
+ if delete_summaries:
+ if node_ids:
+ # Find segments by index_node_id
+ with session_factory.create_session() as session:
+ segments = (
+ session.query(DocumentSegment)
+ .filter(
+ DocumentSegment.dataset_id == dataset.id,
+ DocumentSegment.index_node_id.in_(node_ids),
+ )
+ .all()
+ )
+ segment_ids = [segment.id for segment in segments]
+ if segment_ids:
+ SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
+ else:
+ # Delete all summaries for the dataset
+ SummaryIndexService.delete_summaries_for_segments(dataset, None)
+
vector = Vector(dataset)
if node_ids:
vector.delete_by_ids(node_ids)
@@ -212,6 +240,17 @@ class QAIndexProcessor(BaseIndexProcessor):
"total_segments": len(qa_chunks.qa_chunks),
}
+ def generate_summary_preview(
+ self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
+ ) -> list[PreviewDetail]:
+ """
+ QA model doesn't generate summaries, so this method returns preview_texts unchanged.
+
+ Note: QA model uses question-answer pairs, which don't require summary generation.
+ """
+ # QA model doesn't generate summaries, return as-is
+ return preview_texts
+
def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
format_documents = []
if document_node.page_content is None or not document_node.page_content.strip():
diff --git a/api/core/rag/pipeline/queue.py b/api/core/rag/pipeline/queue.py
index 7472598a7f..bf8db95b4e 100644
--- a/api/core/rag/pipeline/queue.py
+++ b/api/core/rag/pipeline/queue.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import json
from collections.abc import Sequence
from typing import Any
@@ -16,7 +18,7 @@ class TaskWrapper(BaseModel):
return self.model_dump_json()
@classmethod
- def deserialize(cls, serialized_data: str) -> "TaskWrapper":
+ def deserialize(cls, serialized_data: str) -> TaskWrapper:
return cls.model_validate_json(serialized_data)
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py
index 635eab73f0..541c241ae5 100644
--- a/api/core/rag/retrieval/dataset_retrieval.py
+++ b/api/core/rag/retrieval/dataset_retrieval.py
@@ -7,7 +7,7 @@ from collections.abc import Generator, Mapping
from typing import Any, Union, cast
from flask import Flask, current_app
-from sqlalchemy import and_, or_, select
+from sqlalchemy import and_, literal, or_, select
from sqlalchemy.orm import Session
from core.app.app_config.entities import (
@@ -151,20 +151,14 @@ class DatasetRetrieval:
if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
planning_strategy = PlanningStrategy.ROUTER
available_datasets = []
- for dataset_id in dataset_ids:
- # get dataset from dataset id
- dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
- dataset = db.session.scalar(dataset_stmt)
- # pass if dataset is not available
- if not dataset:
+ dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
+ datasets: list[Dataset] = db.session.execute(dataset_stmt).scalars().all() # type: ignore
+ for dataset in datasets:
+ if dataset.available_document_count == 0 and dataset.provider != "external":
continue
-
- # pass if dataset is not available
- if dataset and dataset.available_document_count == 0 and dataset.provider != "external":
- continue
-
available_datasets.append(dataset)
+
if inputs:
inputs = {key: str(value) for key, value in inputs.items()}
else:
@@ -242,20 +236,24 @@ class DatasetRetrieval:
if records:
for record in records:
segment = record.segment
+ # Build content: if summary exists, add it before the segment content
if segment.answer:
- document_context_list.append(
- DocumentContext(
- content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
- score=record.score,
- )
- )
+ segment_content = f"question:{segment.get_sign_content()} answer:{segment.answer}"
else:
- document_context_list.append(
- DocumentContext(
- content=segment.get_sign_content(),
- score=record.score,
- )
+ segment_content = segment.get_sign_content()
+
+ # If summary exists, prepend it to the content
+ if record.summary:
+ final_content = f"{record.summary}\n{segment_content}"
+ else:
+ final_content = segment_content
+
+ document_context_list.append(
+ DocumentContext(
+ content=final_content,
+ score=record.score,
)
+ )
if vision_enabled:
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
@@ -282,26 +280,35 @@ class DatasetRetrieval:
)
context_files.append(attachment_info)
if show_retrieve_source:
+ dataset_ids = [record.segment.dataset_id for record in records]
+ document_ids = [record.segment.document_id for record in records]
+ dataset_document_stmt = select(DatasetDocument).where(
+ DatasetDocument.id.in_(document_ids),
+ DatasetDocument.enabled == True,
+ DatasetDocument.archived == False,
+ )
+ documents = db.session.execute(dataset_document_stmt).scalars().all() # type: ignore
+ dataset_stmt = select(Dataset).where(
+ Dataset.id.in_(dataset_ids),
+ )
+ datasets = db.session.execute(dataset_stmt).scalars().all() # type: ignore
+ dataset_map = {i.id: i for i in datasets}
+ document_map = {i.id: i for i in documents}
for record in records:
segment = record.segment
- dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
- dataset_document_stmt = select(DatasetDocument).where(
- DatasetDocument.id == segment.document_id,
- DatasetDocument.enabled == True,
- DatasetDocument.archived == False,
- )
- document = db.session.scalar(dataset_document_stmt)
- if dataset and document:
+ dataset_item = dataset_map.get(segment.dataset_id)
+ document_item = document_map.get(segment.document_id)
+ if dataset_item and document_item:
source = RetrievalSourceMetadata(
- dataset_id=dataset.id,
- dataset_name=dataset.name,
- document_id=document.id,
- document_name=document.name,
- data_source_type=document.data_source_type,
+ dataset_id=dataset_item.id,
+ dataset_name=dataset_item.name,
+ document_id=document_item.id,
+ document_name=document_item.name,
+ data_source_type=document_item.data_source_type,
segment_id=segment.id,
retriever_from=invoke_from.to_source(),
score=record.score or 0.0,
- doc_metadata=document.doc_metadata,
+ doc_metadata=document_item.doc_metadata,
)
if invoke_from.to_source() == "dev":
@@ -313,6 +320,9 @@ class DatasetRetrieval:
source.content = f"question:{segment.content} \nanswer:{segment.answer}"
else:
source.content = segment.content
+ # Add summary if this segment was retrieved via summary
+ if hasattr(record, "summary") and record.summary:
+ source.summary = record.summary
retrieval_resource_list.append(source)
if hit_callback and retrieval_resource_list:
retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.score or 0.0, reverse=True)
@@ -512,7 +522,11 @@ class DatasetRetrieval:
0
].embedding_model_provider
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
+ dataset_count = len(available_datasets)
with measure_time() as timer:
+ cancel_event = threading.Event()
+ thread_exceptions: list[Exception] = []
+
if query:
query_thread = threading.Thread(
target=self._multiple_retrieve_thread,
@@ -531,6 +545,9 @@ class DatasetRetrieval:
"score_threshold": score_threshold,
"query": query,
"attachment_id": None,
+ "dataset_count": dataset_count,
+ "cancel_event": cancel_event,
+ "thread_exceptions": thread_exceptions,
},
)
all_threads.append(query_thread)
@@ -554,12 +571,26 @@ class DatasetRetrieval:
"score_threshold": score_threshold,
"query": None,
"attachment_id": attachment_id,
+ "dataset_count": dataset_count,
+ "cancel_event": cancel_event,
+ "thread_exceptions": thread_exceptions,
},
)
all_threads.append(attachment_thread)
attachment_thread.start()
- for thread in all_threads:
- thread.join()
+
+ # Poll threads with short timeout to detect errors quickly (fail-fast)
+ while any(t.is_alive() for t in all_threads):
+ for thread in all_threads:
+ thread.join(timeout=0.1)
+ if thread_exceptions:
+ cancel_event.set()
+ break
+ if thread_exceptions:
+ break
+
+ if thread_exceptions:
+ raise thread_exceptions[0]
self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, user_id)
if all_documents:
@@ -1033,7 +1064,7 @@ class DatasetRetrieval:
if automatic_metadata_filters:
conditions = []
for sequence, filter in enumerate(automatic_metadata_filters):
- self._process_metadata_filter_func(
+ self.process_metadata_filter_func(
sequence,
filter.get("condition"), # type: ignore
filter.get("metadata_name"), # type: ignore
@@ -1069,7 +1100,7 @@ class DatasetRetrieval:
value=expected_value,
)
)
- filters = self._process_metadata_filter_func(
+ filters = self.process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
@@ -1165,26 +1196,33 @@ class DatasetRetrieval:
return None
return automatic_metadata_filters
- def _process_metadata_filter_func(
- self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
+ @classmethod
+ def process_metadata_filter_func(
+ cls, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
):
if value is None and condition not in ("empty", "not empty"):
return filters
json_field = DatasetDocument.doc_metadata[metadata_name].as_string()
+ from libs.helper import escape_like_pattern
+
match condition:
case "contains":
- filters.append(json_field.like(f"%{value}%"))
+ escaped_value = escape_like_pattern(str(value))
+ filters.append(json_field.like(f"%{escaped_value}%", escape="\\"))
case "not contains":
- filters.append(json_field.notlike(f"%{value}%"))
+ escaped_value = escape_like_pattern(str(value))
+ filters.append(json_field.notlike(f"%{escaped_value}%", escape="\\"))
case "start with":
- filters.append(json_field.like(f"{value}%"))
+ escaped_value = escape_like_pattern(str(value))
+ filters.append(json_field.like(f"{escaped_value}%", escape="\\"))
case "end with":
- filters.append(json_field.like(f"%{value}"))
+ escaped_value = escape_like_pattern(str(value))
+ filters.append(json_field.like(f"%{escaped_value}", escape="\\"))
case "is" | "=":
if isinstance(value, str):
@@ -1215,6 +1253,20 @@ class DatasetRetrieval:
case "≥" | ">=":
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() >= value)
+ case "in" | "not in":
+ if isinstance(value, str):
+ value_list = [v.strip() for v in value.split(",") if v.strip()]
+ elif isinstance(value, (list, tuple)):
+ value_list = [str(v) for v in value if v is not None]
+ else:
+ value_list = [str(value)] if value is not None else []
+
+ if not value_list:
+ # `field in []` is False, `field not in []` is True
+ filters.append(literal(condition == "not in"))
+ else:
+ op = json_field.in_ if condition == "in" else json_field.notin_
+ filters.append(op(value_list))
case _:
pass
@@ -1386,69 +1438,89 @@ class DatasetRetrieval:
score_threshold: float,
query: str | None,
attachment_id: str | None,
+ dataset_count: int,
+ cancel_event: threading.Event | None = None,
+ thread_exceptions: list[Exception] | None = None,
):
- with flask_app.app_context():
- threads = []
- all_documents_item: list[Document] = []
- index_type = None
- for dataset in available_datasets:
- index_type = dataset.indexing_technique
- document_ids_filter = None
- if dataset.provider != "external":
- if metadata_condition and not metadata_filter_document_ids:
- continue
- if metadata_filter_document_ids:
- document_ids = metadata_filter_document_ids.get(dataset.id, [])
- if document_ids:
- document_ids_filter = document_ids
- else:
+ try:
+ with flask_app.app_context():
+ threads = []
+ all_documents_item: list[Document] = []
+ index_type = None
+ for dataset in available_datasets:
+ # Check for cancellation signal
+ if cancel_event and cancel_event.is_set():
+ break
+ index_type = dataset.indexing_technique
+ document_ids_filter = None
+ if dataset.provider != "external":
+ if metadata_condition and not metadata_filter_document_ids:
continue
- retrieval_thread = threading.Thread(
- target=self._retriever,
- kwargs={
- "flask_app": flask_app,
- "dataset_id": dataset.id,
- "query": query,
- "top_k": top_k,
- "all_documents": all_documents_item,
- "document_ids_filter": document_ids_filter,
- "metadata_condition": metadata_condition,
- "attachment_ids": [attachment_id] if attachment_id else None,
- },
- )
- threads.append(retrieval_thread)
- retrieval_thread.start()
- for thread in threads:
- thread.join()
+ if metadata_filter_document_ids:
+ document_ids = metadata_filter_document_ids.get(dataset.id, [])
+ if document_ids:
+ document_ids_filter = document_ids
+ else:
+ continue
+ retrieval_thread = threading.Thread(
+ target=self._retriever,
+ kwargs={
+ "flask_app": flask_app,
+ "dataset_id": dataset.id,
+ "query": query,
+ "top_k": top_k,
+ "all_documents": all_documents_item,
+ "document_ids_filter": document_ids_filter,
+ "metadata_condition": metadata_condition,
+ "attachment_ids": [attachment_id] if attachment_id else None,
+ },
+ )
+ threads.append(retrieval_thread)
+ retrieval_thread.start()
- if reranking_enable:
- # do rerank for searched documents
- data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
- if query:
- all_documents_item = data_post_processor.invoke(
- query=query,
- documents=all_documents_item,
- score_threshold=score_threshold,
- top_n=top_k,
- query_type=QueryType.TEXT_QUERY,
- )
- if attachment_id:
- all_documents_item = data_post_processor.invoke(
- documents=all_documents_item,
- score_threshold=score_threshold,
- top_n=top_k,
- query_type=QueryType.IMAGE_QUERY,
- query=attachment_id,
- )
- else:
- if index_type == IndexTechniqueType.ECONOMY:
- if not query:
- all_documents_item = []
- else:
- all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
- elif index_type == IndexTechniqueType.HIGH_QUALITY:
- all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
+ # Poll threads with short timeout to respond quickly to cancellation
+ while any(t.is_alive() for t in threads):
+ for thread in threads:
+ thread.join(timeout=0.1)
+ if cancel_event and cancel_event.is_set():
+ break
+ if cancel_event and cancel_event.is_set():
+ break
+
+ # Skip second reranking when there is only one dataset
+ if reranking_enable and dataset_count > 1:
+ # do rerank for searched documents
+ data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
+ if query:
+ all_documents_item = data_post_processor.invoke(
+ query=query,
+ documents=all_documents_item,
+ score_threshold=score_threshold,
+ top_n=top_k,
+ query_type=QueryType.TEXT_QUERY,
+ )
+ if attachment_id:
+ all_documents_item = data_post_processor.invoke(
+ documents=all_documents_item,
+ score_threshold=score_threshold,
+ top_n=top_k,
+ query_type=QueryType.IMAGE_QUERY,
+ query=attachment_id,
+ )
else:
- all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
- if all_documents_item:
- all_documents.extend(all_documents_item)
+ if index_type == IndexTechniqueType.ECONOMY:
+ if not query:
+ all_documents_item = []
+ else:
+ all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
+ elif index_type == IndexTechniqueType.HIGH_QUALITY:
+ all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
+ else:
+ all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
+ if all_documents_item:
+ all_documents.extend(all_documents_item)
+ except Exception as e:
+ if cancel_event:
+ cancel_event.set()
+ if thread_exceptions is not None:
+ thread_exceptions.append(e)
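The value normalization in the new `in` / `not in` branch can be illustrated in isolation (plain Python; the helper name is illustrative, not part of this change):

```python
def normalize_in_values(value):
    """Mirror the branch above: comma-split strings, stringify list items, drop empties."""
    if isinstance(value, str):
        return [v.strip() for v in value.split(",") if v.strip()]
    if isinstance(value, (list, tuple)):
        return [str(v) for v in value if v is not None]
    return [str(value)] if value is not None else []


assert normalize_in_values("a, b , ,c") == ["a", "b", "c"]
assert normalize_in_values(["x", None, 3]) == ["x", "3"]
assert normalize_in_values(None) == []
```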
diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py
index 801d2a2a52..b65cb14d8e 100644
--- a/api/core/rag/splitter/fixed_text_splitter.py
+++ b/api/core/rag/splitter/fixed_text_splitter.py
@@ -2,6 +2,7 @@
from __future__ import annotations
+import codecs
import re
from typing import Any
@@ -52,7 +53,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
def __init__(self, fixed_separator: str = "\n\n", separators: list[str] | None = None, **kwargs: Any):
"""Create a new TextSplitter."""
super().__init__(**kwargs)
- self._fixed_separator = fixed_separator
+ self._fixed_separator = codecs.decode(fixed_separator, "unicode_escape")
self._separators = separators or ["\n\n", "\n", "。", ". ", " ", ""]
def split_text(self, text: str) -> list[str]:
@@ -94,7 +95,8 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
splits = re.split(r" +", text)
else:
splits = text.split(separator)
- splits = [item + separator if i < len(splits) else item for i, item in enumerate(splits)]
+ if self._keep_separator:
+ splits = [s + separator for s in splits[:-1]] + splits[-1:]
else:
splits = list(text)
if separator == "\n":
@@ -103,7 +105,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
splits = [s for s in splits if (s not in {"", "\n"})]
_good_splits = []
_good_splits_lengths = [] # cache the lengths of the splits
- _separator = separator if self._keep_separator else ""
+ _separator = "" if self._keep_separator else separator
s_lens = self._length_function(splits)
if separator != "":
for s, s_len in zip(splits, s_lens):
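The `unicode_escape` decode above means a separator configured as escaped text (for example a literal `\n\n` typed as two characters) is interpreted as real control characters; a minimal illustration:

```python
import codecs

# A separator stored as the literal backslash sequence "\n\n" becomes real newlines.
assert codecs.decode("\\n\\n", "unicode_escape") == "\n\n"
# Separators without escape sequences pass through unchanged.
assert codecs.decode("***", "unicode_escape") == "***"
```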
diff --git a/api/core/schemas/registry.py b/api/core/schemas/registry.py
index 51bfae1cd3..b87fba4eaa 100644
--- a/api/core/schemas/registry.py
+++ b/api/core/schemas/registry.py
@@ -1,9 +1,11 @@
+from __future__ import annotations
+
import json
import logging
import threading
from collections.abc import Mapping, MutableMapping
from pathlib import Path
-from typing import Any, ClassVar, Optional
+from typing import Any, ClassVar
class SchemaRegistry:
@@ -11,7 +13,7 @@ class SchemaRegistry:
logger: ClassVar[logging.Logger] = logging.getLogger(__name__)
- _default_instance: ClassVar[Optional["SchemaRegistry"]] = None
+ _default_instance: ClassVar[SchemaRegistry | None] = None
_lock: ClassVar[threading.Lock] = threading.Lock()
def __init__(self, base_dir: str):
@@ -20,7 +22,7 @@ class SchemaRegistry:
self.metadata: MutableMapping[str, MutableMapping[str, Any]] = {}
@classmethod
- def default_registry(cls) -> "SchemaRegistry":
+ def default_registry(cls) -> SchemaRegistry:
"""Returns the default schema registry for builtin schemas (thread-safe singleton)"""
if cls._default_instance is None:
with cls._lock:
@@ -33,6 +35,7 @@ class SchemaRegistry:
registry.load_all_versions()
cls._default_instance = registry
+ return cls._default_instance
return cls._default_instance
diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py
index 6e4e616465..b7a45414ef 100644
--- a/api/core/tools/__base/tool.py
+++ b/api/core/tools/__base/tool.py
@@ -1,4 +1,5 @@
-import inspect
+from __future__ import annotations
+
from abc import ABC, abstractmethod
from collections.abc import Generator
from copy import deepcopy
@@ -25,7 +26,7 @@ class Tool(ABC):
self.entity = entity
self.runtime = runtime
- def fork_tool_runtime(self, runtime: ToolRuntime) -> "Tool":
+ def fork_tool_runtime(self, runtime: ToolRuntime) -> Tool:
"""
fork a new tool with metadata
:return: the new tool
@@ -179,7 +180,7 @@ class Tool(ABC):
type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image)
)
- def create_file_message(self, file: "File") -> ToolInvokeMessage:
+ def create_file_message(self, file: File) -> ToolInvokeMessage:
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.FILE,
message=ToolInvokeMessage.FileMessage(),
diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py
index 84efefba07..51b0407886 100644
--- a/api/core/tools/builtin_tool/tool.py
+++ b/api/core/tools/builtin_tool/tool.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
from core.tools.__base.tool import Tool
@@ -24,7 +26,7 @@ class BuiltinTool(Tool):
super().__init__(**kwargs)
self.provider = provider
- def fork_tool_runtime(self, runtime: ToolRuntime) -> "BuiltinTool":
+ def fork_tool_runtime(self, runtime: ToolRuntime) -> BuiltinTool:
"""
fork a new tool with metadata
:return: the new tool
diff --git a/api/core/tools/custom_tool/provider.py b/api/core/tools/custom_tool/provider.py
index 0cc992155a..e2f6c00555 100644
--- a/api/core/tools/custom_tool/provider.py
+++ b/api/core/tools/custom_tool/provider.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from pydantic import Field
from sqlalchemy import select
@@ -32,7 +34,7 @@ class ApiToolProviderController(ToolProviderController):
self.tools = []
@classmethod
- def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
+ def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> ApiToolProviderController:
credentials_schema = [
ProviderConfig(
name="auth_type",
diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py
index 353f3a646a..96268d029e 100644
--- a/api/core/tools/entities/tool_entities.py
+++ b/api/core/tools/entities/tool_entities.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import base64
import contextlib
from collections.abc import Mapping
@@ -55,7 +57,7 @@ class ToolProviderType(StrEnum):
MCP = auto()
@classmethod
- def value_of(cls, value: str) -> "ToolProviderType":
+ def value_of(cls, value: str) -> ToolProviderType:
"""
Get value of given mode.
@@ -79,7 +81,7 @@ class ApiProviderSchemaType(StrEnum):
OPENAI_ACTIONS = auto()
@classmethod
- def value_of(cls, value: str) -> "ApiProviderSchemaType":
+ def value_of(cls, value: str) -> ApiProviderSchemaType:
"""
Get value of given mode.
@@ -102,7 +104,7 @@ class ApiProviderAuthType(StrEnum):
API_KEY_QUERY = auto()
@classmethod
- def value_of(cls, value: str) -> "ApiProviderAuthType":
+ def value_of(cls, value: str) -> ApiProviderAuthType:
"""
Get value of given mode.
@@ -128,7 +130,7 @@ class ToolInvokeMessage(BaseModel):
text: str
class JsonMessage(BaseModel):
- json_object: dict
+ json_object: dict | list
suppress_output: bool = Field(default=False, description="Whether to suppress JSON output in result string")
class BlobMessage(BaseModel):
@@ -142,7 +144,14 @@ class ToolInvokeMessage(BaseModel):
end: bool = Field(..., description="Whether the chunk is the last chunk")
class FileMessage(BaseModel):
- pass
+ file_marker: str = Field(default="file_marker")
+
+ @model_validator(mode="before")
+ @classmethod
+ def validate_file_message(cls, values):
+ if isinstance(values, dict) and "file_marker" not in values:
+ raise ValueError("Invalid FileMessage: missing file_marker")
+ return values
class VariableMessage(BaseModel):
variable_name: str = Field(..., description="The name of the variable")
@@ -153,11 +162,11 @@ class ToolInvokeMessage(BaseModel):
@classmethod
def transform_variable_value(cls, values):
"""
- Only basic types and lists are allowed.
+ Only basic types, lists, and None are allowed.
"""
value = values.get("variable_value")
- if not isinstance(value, dict | list | str | int | float | bool):
- raise ValueError("Only basic types and lists are allowed.")
+ if value is not None and not isinstance(value, dict | list | str | int | float | bool):
+ raise ValueError("Only basic types, lists, and None are allowed.")
# if stream is true, the value must be a string
if values.get("stream"):
@@ -232,10 +241,22 @@ class ToolInvokeMessage(BaseModel):
@field_validator("message", mode="before")
@classmethod
- def decode_blob_message(cls, v):
+ def decode_blob_message(cls, v, info: ValidationInfo):
+        # Handle blob decoding
if isinstance(v, dict) and "blob" in v:
with contextlib.suppress(Exception):
v["blob"] = base64.b64decode(v["blob"])
+
+ # Force correct message type based on type field
+ # Only wrap dict types to avoid wrapping already parsed Pydantic model objects
+ if info.data and isinstance(info.data, dict) and isinstance(v, dict):
+ msg_type = info.data.get("type")
+ if msg_type == cls.MessageType.JSON:
+ if "json_object" not in v:
+ v = {"json_object": v}
+ elif msg_type == cls.MessageType.FILE:
+ v = {"file_marker": "file_marker"}
+
return v
@field_serializer("message")
@@ -307,7 +328,7 @@ class ToolParameter(PluginParameter):
typ: ToolParameterType,
required: bool,
options: list[str] | None = None,
- ) -> "ToolParameter":
+ ) -> ToolParameter:
"""
get a simple tool parameter
@@ -429,14 +450,14 @@ class ToolInvokeMeta(BaseModel):
tool_config: dict | None = None
@classmethod
- def empty(cls) -> "ToolInvokeMeta":
+ def empty(cls) -> ToolInvokeMeta:
"""
Get an empty instance of ToolInvokeMeta
"""
return cls(time_cost=0.0, error=None, tool_config={})
@classmethod
- def error_instance(cls, error: str) -> "ToolInvokeMeta":
+ def error_instance(cls, error: str) -> ToolInvokeMeta:
"""
Get an instance of ToolInvokeMeta with error
"""
diff --git a/api/core/tools/errors.py b/api/core/tools/errors.py
index b0c2232857..e4afe24426 100644
--- a/api/core/tools/errors.py
+++ b/api/core/tools/errors.py
@@ -29,6 +29,10 @@ class ToolApiSchemaError(ValueError):
pass
+class ToolSSRFError(ValueError):
+ pass
+
+
class ToolCredentialPolicyViolationError(ValueError):
pass
diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py
index 60cf2ab925..9f42704eda 100644
--- a/api/core/tools/mcp_tool/tool.py
+++ b/api/core/tools/mcp_tool/tool.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import base64
import json
import logging
@@ -6,7 +8,15 @@ from typing import Any
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPConnectionError
-from core.mcp.types import AudioContent, CallToolResult, ImageContent, TextContent
+from core.mcp.types import (
+ AudioContent,
+ BlobResourceContents,
+ CallToolResult,
+ EmbeddedResource,
+ ImageContent,
+ TextContent,
+ TextResourceContents,
+)
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
@@ -54,10 +64,19 @@ class MCPTool(Tool):
for content in result.content:
if isinstance(content, TextContent):
yield from self._process_text_content(content)
- elif isinstance(content, ImageContent):
- yield self._process_image_content(content)
- elif isinstance(content, AudioContent):
- yield self._process_audio_content(content)
+ elif isinstance(content, ImageContent | AudioContent):
+ yield self.create_blob_message(
+ blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType}
+ )
+ elif isinstance(content, EmbeddedResource):
+ resource = content.resource
+ if isinstance(resource, TextResourceContents):
+ yield self.create_text_message(resource.text)
+ elif isinstance(resource, BlobResourceContents):
+ mime_type = resource.mimeType or "application/octet-stream"
+ yield self.create_blob_message(blob=base64.b64decode(resource.blob), meta={"mime_type": mime_type})
+ else:
+ raise ToolInvokeError(f"Unsupported embedded resource type: {type(resource)}")
else:
logger.warning("Unsupported content type=%s", type(content))
@@ -102,15 +121,7 @@ class MCPTool(Tool):
for item in json_list:
yield self.create_json_message(item)
- def _process_image_content(self, content: ImageContent) -> ToolInvokeMessage:
- """Process image content and return a blob message."""
- return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
-
- def _process_audio_content(self, content: AudioContent) -> ToolInvokeMessage:
- """Process audio content and return a blob message."""
- return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
-
- def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
+ def fork_tool_runtime(self, runtime: ToolRuntime) -> MCPTool:
return MCPTool(
entity=self.entity,
runtime=runtime,
diff --git a/api/core/tools/plugin_tool/tool.py b/api/core/tools/plugin_tool/tool.py
index 828dc3b810..d3a2ad488c 100644
--- a/api/core/tools/plugin_tool/tool.py
+++ b/api/core/tools/plugin_tool/tool.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from collections.abc import Generator
from typing import Any
@@ -46,7 +48,7 @@ class PluginTool(Tool):
message_id=message_id,
)
- def fork_tool_runtime(self, runtime: ToolRuntime) -> "PluginTool":
+ def fork_tool_runtime(self, runtime: ToolRuntime) -> PluginTool:
return PluginTool(
entity=self.entity,
runtime=runtime,
diff --git a/api/core/tools/signature.py b/api/core/tools/signature.py
index fef3157f27..22e099deba 100644
--- a/api/core/tools/signature.py
+++ b/api/core/tools/signature.py
@@ -7,12 +7,12 @@ import time
from configs import dify_config
-def sign_tool_file(tool_file_id: str, extension: str) -> str:
+def sign_tool_file(tool_file_id: str, extension: str, for_external: bool = True) -> str:
"""
sign file to get a temporary url for plugin access
"""
- # Use internal URL for plugin/tool file access in Docker environments
- base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
+ # Use internal URL for plugin/tool file access in Docker environments, unless for_external is True
+ base_url = dify_config.FILES_URL if for_external else (dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL)
file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}"
timestamp = str(int(time.time()))
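Illustrative call sites for the new `for_external` flag (the file id below is hypothetical):

```python
from core.tools.signature import sign_tool_file

# Default (for_external=True): always build the link on the public FILES_URL.
public_url = sign_tool_file("tool-file-id", ".png")
# for_external=False: prefer INTERNAL_FILES_URL (e.g. Docker-internal), falling back to FILES_URL.
internal_url = sign_tool_file("tool-file-id", ".png", for_external=False)
```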
diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py
index a6efd91eee..29428116cb 100644
--- a/api/core/tools/tool_engine.py
+++ b/api/core/tools/tool_engine.py
@@ -1,5 +1,6 @@
import contextlib
import json
+import logging
from collections.abc import Generator, Iterable
from copy import deepcopy
from datetime import UTC, datetime
@@ -36,6 +37,8 @@ from extensions.ext_database import db
from models.enums import CreatorUserRole
from models.model import Message, MessageFile
+logger = logging.getLogger(__name__)
+
class ToolEngine:
"""
@@ -124,25 +127,31 @@ class ToolEngine:
# transform tool invoke message to get LLM friendly message
return plain_text, message_files, meta
except ToolProviderCredentialValidationError as e:
+ logger.error(e, exc_info=True)
error_response = "Please check your tool provider credentials"
agent_tool_callback.on_tool_error(e)
except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError) as e:
error_response = f"there is not a tool named {tool.entity.identity.name}"
+ logger.error(e, exc_info=True)
agent_tool_callback.on_tool_error(e)
except ToolParameterValidationError as e:
error_response = f"tool parameters validation error: {e}, please check your tool parameters"
agent_tool_callback.on_tool_error(e)
+ logger.error(e, exc_info=True)
except ToolInvokeError as e:
error_response = f"tool invoke error: {e}"
agent_tool_callback.on_tool_error(e)
+ logger.error(e, exc_info=True)
except ToolEngineInvokeError as e:
meta = e.meta
error_response = f"tool invoke error: {meta.error}"
agent_tool_callback.on_tool_error(e)
+ logger.error(e, exc_info=True)
return error_response, [], meta
except Exception as e:
error_response = f"unknown error: {e}"
agent_tool_callback.on_tool_error(e)
+ logger.error(e, exc_info=True)
return error_response, [], ToolInvokeMeta.error_instance(error_response)
diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py
index f8213d9fd7..d561d39923 100644
--- a/api/core/tools/tool_manager.py
+++ b/api/core/tools/tool_manager.py
@@ -189,16 +189,13 @@ class ToolManager:
raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")
if not provider_controller.need_credentials:
- return cast(
- BuiltinTool,
- builtin_tool.fork_tool_runtime(
- runtime=ToolRuntime(
- tenant_id=tenant_id,
- credentials={},
- invoke_from=invoke_from,
- tool_invoke_from=tool_invoke_from,
- )
- ),
+ return builtin_tool.fork_tool_runtime(
+ runtime=ToolRuntime(
+ tenant_id=tenant_id,
+ credentials={},
+ invoke_from=invoke_from,
+ tool_invoke_from=tool_invoke_from,
+ )
)
builtin_provider = None
if isinstance(provider_controller, PluginToolProviderController):
@@ -300,18 +297,15 @@ class ToolManager:
decrypted_credentials = refreshed_credentials.credentials
cache.delete()
- return cast(
- BuiltinTool,
- builtin_tool.fork_tool_runtime(
- runtime=ToolRuntime(
- tenant_id=tenant_id,
- credentials=dict(decrypted_credentials),
- credential_type=CredentialType.of(builtin_provider.credential_type),
- runtime_parameters={},
- invoke_from=invoke_from,
- tool_invoke_from=tool_invoke_from,
- )
- ),
+ return builtin_tool.fork_tool_runtime(
+ runtime=ToolRuntime(
+ tenant_id=tenant_id,
+ credentials=dict(decrypted_credentials),
+ credential_type=CredentialType.of(builtin_provider.credential_type),
+ runtime_parameters={},
+ invoke_from=invoke_from,
+ tool_invoke_from=tool_invoke_from,
+ )
)
elif provider_type == ToolProviderType.API:
diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
index f96510fb45..057ec41f65 100644
--- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
+++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
@@ -169,20 +169,24 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
if records:
for record in records:
segment = record.segment
+ # Build content: if summary exists, add it before the segment content
if segment.answer:
- document_context_list.append(
- DocumentContext(
- content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
- score=record.score,
- )
- )
+ segment_content = f"question:{segment.get_sign_content()} answer:{segment.answer}"
else:
- document_context_list.append(
- DocumentContext(
- content=segment.get_sign_content(),
- score=record.score,
- )
+ segment_content = segment.get_sign_content()
+
+ # If summary exists, prepend it to the content
+ if record.summary:
+ final_content = f"{record.summary}\n{segment_content}"
+ else:
+ final_content = segment_content
+
+ document_context_list.append(
+ DocumentContext(
+ content=final_content,
+ score=record.score,
)
+ )
if self.return_resource:
for record in records:
@@ -216,6 +220,9 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
source.content = f"question:{segment.content} \nanswer:{segment.answer}"
else:
source.content = segment.content
+ # Add summary if this segment was retrieved via summary
+ if hasattr(record, "summary") and record.summary:
+ source.summary = record.summary
retrieval_resource_list.append(source)
if self.return_resource and retrieval_resource_list:
diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py
index ca2aa39861..df322eda1c 100644
--- a/api/core/tools/utils/message_transformer.py
+++ b/api/core/tools/utils/message_transformer.py
@@ -101,6 +101,8 @@ class ToolFileMessageTransformer:
meta = message.meta or {}
mimetype = meta.get("mime_type", "application/octet-stream")
+ if not mimetype:
+ mimetype = "application/octet-stream"
# get filename from meta
filename = meta.get("filename", None)
# if message is str, encode it to bytes
diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py
index 6eabde3991..584975de05 100644
--- a/api/core/tools/utils/parser.py
+++ b/api/core/tools/utils/parser.py
@@ -378,7 +378,7 @@ class ApiBasedToolSchemaParser:
@staticmethod
def auto_parse_to_tool_bundle(
content: str, extra_info: dict | None = None, warning: dict | None = None
- ) -> tuple[list[ApiToolBundle], str]:
+ ) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]:
"""
auto parse to tool bundle
@@ -425,7 +425,7 @@ class ApiBasedToolSchemaParser:
except ToolApiSchemaError as e:
openapi_error = e
- # openai parse error, fallback to swagger
+ # openapi parse error, fallback to swagger
try:
converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
loaded_content, extra_info=extra_info, warning=warning
@@ -436,7 +436,6 @@ class ApiBasedToolSchemaParser:
), schema_type
except ToolApiSchemaError as e:
swagger_error = e
-
# swagger parse error, fallback to openai plugin
try:
openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
diff --git a/api/core/tools/utils/text_processing_utils.py b/api/core/tools/utils/text_processing_utils.py
index 0f9a91a111..4bfaa5e49b 100644
--- a/api/core/tools/utils/text_processing_utils.py
+++ b/api/core/tools/utils/text_processing_utils.py
@@ -4,6 +4,7 @@ import re
def remove_leading_symbols(text: str) -> str:
"""
Remove leading punctuation or symbols from the given text.
+ Preserves markdown links like [text](url) at the start.
Args:
text (str): The input text to process.
@@ -11,6 +12,11 @@ def remove_leading_symbols(text: str) -> str:
Returns:
str: The text with leading punctuation or symbols removed.
"""
+ # Check if text starts with a markdown link - preserve it
+ markdown_link_pattern = r"^\[([^\]]+)\]\((https?://[^)]+)\)"
+ if re.match(markdown_link_pattern, text):
+ return text
+
# Match Unicode ranges for punctuation and symbols
# FIXME this pattern is confused quick fix for #11868 maybe refactor it later
pattern = r'^[\[\]\u2000-\u2025\u2027-\u206F\u2E00-\u2E7F\u3000-\u300F\u3011-\u303F"#$%&\'()*+,./:;<=>?@^_`~]+'
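Expected behaviour of the markdown-link guard, with hypothetical inputs:

```python
from core.tools.utils.text_processing_utils import remove_leading_symbols

# A leading markdown link is preserved verbatim.
assert remove_leading_symbols("[Dify](https://example.com) docs") == "[Dify](https://example.com) docs"
# Plain leading punctuation is still stripped.
assert remove_leading_symbols("...hello") == "hello"
```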
diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py
index 0439fb1d60..a706f101ca 100644
--- a/api/core/tools/workflow_as_tool/provider.py
+++ b/api/core/tools/workflow_as_tool/provider.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from collections.abc import Mapping
from pydantic import Field
@@ -5,6 +7,7 @@ from sqlalchemy.orm import Session
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
+from core.db.session_factory import session_factory
from core.plugin.entities.parameters import PluginParameterOption
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
@@ -46,34 +49,30 @@ class WorkflowToolProviderController(ToolProviderController):
self.provider_id = provider_id
@classmethod
- def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController":
- with Session(db.engine, expire_on_commit=False) as session, session.begin():
- provider = session.get(WorkflowToolProvider, db_provider.id) if db_provider.id else None
- if not provider:
- raise ValueError("workflow provider not found")
- app = session.get(App, provider.app_id)
+ def from_db(cls, db_provider: WorkflowToolProvider) -> WorkflowToolProviderController:
+ with session_factory.create_session() as session, session.begin():
+ app = session.get(App, db_provider.app_id)
if not app:
raise ValueError("app not found")
- user = session.get(Account, provider.user_id) if provider.user_id else None
-
+ user = session.get(Account, db_provider.user_id) if db_provider.user_id else None
controller = WorkflowToolProviderController(
entity=ToolProviderEntity(
identity=ToolProviderIdentity(
author=user.name if user else "",
- name=provider.label,
- label=I18nObject(en_US=provider.label, zh_Hans=provider.label),
- description=I18nObject(en_US=provider.description, zh_Hans=provider.description),
- icon=provider.icon,
+ name=db_provider.label,
+ label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label),
+ description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
+ icon=db_provider.icon,
),
credentials_schema=[],
plugin_id=None,
),
- provider_id=provider.id or "",
+ provider_id=db_provider.id,
)
controller.tools = [
- controller._get_db_provider_tool(provider, app, session=session, user=user),
+ controller._get_db_provider_tool(db_provider, app, session=session, user=user),
]
return controller
diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py
index 30334f5da8..9c1ceff145 100644
--- a/api/core/tools/workflow_as_tool/tool.py
+++ b/api/core/tools/workflow_as_tool/tool.py
@@ -1,12 +1,13 @@
+from __future__ import annotations
+
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast
-from flask import has_request_context
from sqlalchemy import select
-from sqlalchemy.orm import Session
+from core.db.session_factory import session_factory
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.tools.__base.tool import Tool
@@ -18,9 +19,7 @@ from core.tools.entities.tool_entities import (
ToolProviderType,
)
from core.tools.errors import ToolInvokeError
-from extensions.ext_database import db
from factories.file_factory import build_from_mapping
-from libs.login import current_user
from models import Account, Tenant
from models.model import App, EndUser
from models.workflow import Workflow
@@ -181,7 +180,7 @@ class WorkflowTool(Tool):
return found
return None
- def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
+ def fork_tool_runtime(self, runtime: ToolRuntime) -> WorkflowTool:
"""
fork a new tool with metadata
@@ -208,50 +207,38 @@ class WorkflowTool(Tool):
Returns:
Account | EndUser | None: The resolved user object, or None if resolution fails.
"""
- if has_request_context():
- return self._resolve_user_from_request()
- else:
- return self._resolve_user_from_database(user_id=user_id)
-
- def _resolve_user_from_request(self) -> Account | EndUser | None:
- """
- Resolve user from Flask request context.
- """
- try:
- # Note: `current_user` is a LocalProxy. Never compare it with None directly.
- return getattr(current_user, "_get_current_object", lambda: current_user)()
- except Exception as e:
- logger.warning("Failed to resolve user from request context: %s", e)
- return None
+ return self._resolve_user_from_database(user_id=user_id)
def _resolve_user_from_database(self, user_id: str) -> Account | EndUser | None:
"""
Resolve user from database (worker/Celery context).
"""
+ with session_factory.create_session() as session:
+ tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id)
+ tenant = session.scalar(tenant_stmt)
+ if not tenant:
+ return None
+
+ user_stmt = select(Account).where(Account.id == user_id)
+ user = session.scalar(user_stmt)
+ if user:
+ user.current_tenant = tenant
+ session.expunge(user)
+ return user
+
+ end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id)
+ end_user = session.scalar(end_user_stmt)
+ if end_user:
+ session.expunge(end_user)
+ return end_user
- tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id)
- tenant = db.session.scalar(tenant_stmt)
- if not tenant:
return None
- user_stmt = select(Account).where(Account.id == user_id)
- user = db.session.scalar(user_stmt)
- if user:
- user.current_tenant = tenant
- return user
-
- end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id)
- end_user = db.session.scalar(end_user_stmt)
- if end_user:
- return end_user
-
- return None
-
def _get_workflow(self, app_id: str, version: str) -> Workflow:
"""
get the workflow by app id and version
"""
- with Session(db.engine, expire_on_commit=False) as session, session.begin():
+ with session_factory.create_session() as session, session.begin():
if not version:
stmt = (
select(Workflow)
@@ -263,22 +250,24 @@ class WorkflowTool(Tool):
stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version)
workflow = session.scalar(stmt)
- if not workflow:
- raise ValueError("workflow not found or not published")
+ if not workflow:
+ raise ValueError("workflow not found or not published")
- return workflow
+ session.expunge(workflow)
+ return workflow
def _get_app(self, app_id: str) -> App:
"""
get the app by app id
"""
stmt = select(App).where(App.id == app_id)
- with Session(db.engine, expire_on_commit=False) as session, session.begin():
+ with session_factory.create_session() as session, session.begin():
app = session.scalar(stmt)
- if not app:
- raise ValueError("app not found")
+ if not app:
+ raise ValueError("app not found")
- return app
+ session.expunge(app)
+ return app
def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]:
"""
diff --git a/api/core/trigger/debug/event_bus.py b/api/core/trigger/debug/event_bus.py
index 9d10e1a0e0..e3fb6a13d9 100644
--- a/api/core/trigger/debug/event_bus.py
+++ b/api/core/trigger/debug/event_bus.py
@@ -23,8 +23,8 @@ class TriggerDebugEventBus:
"""
# LUA_SELECT: Atomic poll or register for event
- # KEYS[1] = trigger_debug_inbox:{tenant_id}:{address_id}
- # KEYS[2] = trigger_debug_waiting_pool:{tenant_id}:...
+    # KEYS[1] = trigger_debug_inbox:{tenant_id}:{address_id}  (braces around tenant_id form a Redis Cluster hash tag)
+    # KEYS[2] = trigger_debug_waiting_pool:{tenant_id}:...  (same hash tag)
# ARGV[1] = address_id
LUA_SELECT = (
"local v=redis.call('GET',KEYS[1]);"
@@ -35,7 +35,7 @@ class TriggerDebugEventBus:
)
# LUA_DISPATCH: Dispatch event to all waiting addresses
- # KEYS[1] = trigger_debug_waiting_pool:{tenant_id}:...
+    # KEYS[1] = trigger_debug_waiting_pool:{tenant_id}:...  (braces around tenant_id form a Redis Cluster hash tag)
# ARGV[1] = tenant_id
# ARGV[2] = event_json
LUA_DISPATCH = (
@@ -43,7 +43,7 @@ class TriggerDebugEventBus:
"if #a==0 then return 0 end;"
"redis.call('DEL',KEYS[1]);"
"for i=1,#a do "
- f"redis.call('SET','trigger_debug_inbox:'..ARGV[1]..':'..a[i],ARGV[2],'EX',{TRIGGER_DEBUG_EVENT_TTL});"
+ f"redis.call('SET','trigger_debug_inbox:{{'..ARGV[1]..'}}'..':'..a[i],ARGV[2],'EX',{TRIGGER_DEBUG_EVENT_TTL});"
"end;"
"return #a"
)
@@ -108,7 +108,7 @@ class TriggerDebugEventBus:
Event object if available, None otherwise
"""
address_id: str = hashlib.sha256(f"{user_id}|{app_id}|{node_id}".encode()).hexdigest()
- address: str = f"trigger_debug_inbox:{tenant_id}:{address_id}"
+ address: str = f"trigger_debug_inbox:{{{tenant_id}}}:{address_id}"
try:
event_data = redis_client.eval(
diff --git a/api/core/trigger/debug/events.py b/api/core/trigger/debug/events.py
index 9f7bab5e49..9aec342ed1 100644
--- a/api/core/trigger/debug/events.py
+++ b/api/core/trigger/debug/events.py
@@ -42,7 +42,7 @@ def build_webhook_pool_key(tenant_id: str, app_id: str, node_id: str) -> str:
app_id: App ID
node_id: Node ID
"""
- return f"{TriggerDebugPoolKey.WEBHOOK}:{tenant_id}:{app_id}:{node_id}"
+ return f"{TriggerDebugPoolKey.WEBHOOK}:{{{tenant_id}}}:{app_id}:{node_id}"
class PluginTriggerDebugEvent(BaseDebugEvent):
@@ -64,4 +64,4 @@ def build_plugin_pool_key(tenant_id: str, provider_id: str, subscription_id: str
provider_id: Provider ID
subscription_id: Subscription ID
"""
- return f"{TriggerDebugPoolKey.PLUGIN}:{tenant_id}:{str(provider_id)}:{subscription_id}:{name}"
+ return f"{TriggerDebugPoolKey.PLUGIN}:{{{tenant_id}}}:{str(provider_id)}:{subscription_id}:{name}"
diff --git a/api/core/trigger/utils/encryption.py b/api/core/trigger/utils/encryption.py
index 026a65aa23..b12291e299 100644
--- a/api/core/trigger/utils/encryption.py
+++ b/api/core/trigger/utils/encryption.py
@@ -67,12 +67,16 @@ def create_trigger_provider_encrypter_for_subscription(
def delete_cache_for_subscription(tenant_id: str, provider_id: str, subscription_id: str):
- cache = TriggerProviderCredentialsCache(
+ TriggerProviderCredentialsCache(
tenant_id=tenant_id,
provider_id=provider_id,
credential_id=subscription_id,
- )
- cache.delete()
+ ).delete()
+ TriggerProviderPropertiesCache(
+ tenant_id=tenant_id,
+ provider_id=provider_id,
+ subscription_id=subscription_id,
+ ).delete()
def create_trigger_provider_encrypter_for_properties(
diff --git a/api/core/variables/__init__.py b/api/core/variables/__init__.py
index 7a1cbf9940..7498224923 100644
--- a/api/core/variables/__init__.py
+++ b/api/core/variables/__init__.py
@@ -30,6 +30,7 @@ from .variables import (
SecretVariable,
StringVariable,
Variable,
+ VariableBase,
)
__all__ = [
@@ -62,4 +63,5 @@ __all__ = [
"StringSegment",
"StringVariable",
"Variable",
+ "VariableBase",
]
diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py
index 406b4e6f93..8330f1fe19 100644
--- a/api/core/variables/segments.py
+++ b/api/core/variables/segments.py
@@ -232,7 +232,7 @@ def get_segment_discriminator(v: Any) -> SegmentType | None:
# - All variants in `SegmentUnion` must inherit from the `Segment` class.
# - The union must include all non-abstract subclasses of `Segment`, except:
# - `SegmentGroup`, which is not added to the variable pool.
-# - `Variable` and its subclasses, which are handled by `VariableUnion`.
+# - `VariableBase` and its subclasses, which are handled by `Variable`.
SegmentUnion: TypeAlias = Annotated[
(
Annotated[NoneSegment, Tag(SegmentType.NONE)]
diff --git a/api/core/variables/types.py b/api/core/variables/types.py
index ce71711344..13b926c978 100644
--- a/api/core/variables/types.py
+++ b/api/core/variables/types.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
from collections.abc import Mapping
from enum import StrEnum
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any
from core.file.models import File
@@ -52,7 +54,7 @@ class SegmentType(StrEnum):
return self in _ARRAY_TYPES
@classmethod
- def infer_segment_type(cls, value: Any) -> Optional["SegmentType"]:
+ def infer_segment_type(cls, value: Any) -> SegmentType | None:
"""
Attempt to infer the `SegmentType` based on the Python type of the `value` parameter.
@@ -173,7 +175,7 @@ class SegmentType(StrEnum):
raise AssertionError("this statement should be unreachable.")
@staticmethod
- def cast_value(value: Any, type_: "SegmentType"):
+ def cast_value(value: Any, type_: SegmentType):
# Cast Python's `bool` type to `int` when the runtime type requires
# an integer or number.
#
@@ -193,7 +195,7 @@ class SegmentType(StrEnum):
return [int(i) for i in value]
return value
- def exposed_type(self) -> "SegmentType":
+ def exposed_type(self) -> SegmentType:
"""Returns the type exposed to the frontend.
The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so these are returned as `NUMBER` here.
@@ -202,7 +204,7 @@ class SegmentType(StrEnum):
return SegmentType.NUMBER
return self
- def element_type(self) -> "SegmentType | None":
+ def element_type(self) -> SegmentType | None:
"""Return the element type of the current segment type, or `None` if the element type is undefined.
Raises:
@@ -217,7 +219,7 @@ class SegmentType(StrEnum):
return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)
@staticmethod
- def get_zero_value(t: "SegmentType"):
+ def get_zero_value(t: SegmentType):
# Lazy import to avoid circular dependency
from factories import variable_factory
diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py
index 9fd0bbc5b2..a19c53918d 100644
--- a/api/core/variables/variables.py
+++ b/api/core/variables/variables.py
@@ -27,7 +27,7 @@ from .segments import (
from .types import SegmentType
-class Variable(Segment):
+class VariableBase(Segment):
"""
A variable is a segment that has a name.
@@ -45,23 +45,23 @@ class Variable(Segment):
selector: Sequence[str] = Field(default_factory=list)
-class StringVariable(StringSegment, Variable):
+class StringVariable(StringSegment, VariableBase):
pass
-class FloatVariable(FloatSegment, Variable):
+class FloatVariable(FloatSegment, VariableBase):
pass
-class IntegerVariable(IntegerSegment, Variable):
+class IntegerVariable(IntegerSegment, VariableBase):
pass
-class ObjectVariable(ObjectSegment, Variable):
+class ObjectVariable(ObjectSegment, VariableBase):
pass
-class ArrayVariable(ArraySegment, Variable):
+class ArrayVariable(ArraySegment, VariableBase):
pass
@@ -89,16 +89,16 @@ class SecretVariable(StringVariable):
return encrypter.obfuscated_token(self.value)
-class NoneVariable(NoneSegment, Variable):
+class NoneVariable(NoneSegment, VariableBase):
value_type: SegmentType = SegmentType.NONE
value: None = None
-class FileVariable(FileSegment, Variable):
+class FileVariable(FileSegment, VariableBase):
pass
-class BooleanVariable(BooleanSegment, Variable):
+class BooleanVariable(BooleanSegment, VariableBase):
pass
@@ -139,13 +139,13 @@ class RAGPipelineVariableInput(BaseModel):
value: Any
-# The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic.
-# Use `Variable` for type hinting when serialization is not required.
+# The `Variable` type is used to enable serialization and deserialization with Pydantic.
+# Use `VariableBase` for type hinting when serialization is not required.
#
# Note:
-# - All variants in `VariableUnion` must inherit from the `Variable` class.
-# - The union must include all non-abstract subclasses of `Segment`, except:
-VariableUnion: TypeAlias = Annotated[
+# - All variants in `Variable` must inherit from the `VariableBase` class.
+# - The union must include all non-abstract subclasses of `VariableBase`.
+Variable: TypeAlias = Annotated[
(
Annotated[NoneVariable, Tag(SegmentType.NONE)]
| Annotated[StringVariable, Tag(SegmentType.STRING)]
diff --git a/api/core/workflow/README.md b/api/core/workflow/README.md
index 72f5dbe1e2..9a39f976a6 100644
--- a/api/core/workflow/README.md
+++ b/api/core/workflow/README.md
@@ -64,6 +64,9 @@ engine.layer(DebugLoggingLayer(level="INFO"))
engine.layer(ExecutionLimitsLayer(max_nodes=100))
```
+`engine.layer()` binds the read-only runtime state before execution, so layer hooks
+can assume `graph_runtime_state` is available.
+
### Event-Driven Architecture
All node executions emit events for monitoring and integration:
diff --git a/api/core/workflow/context/__init__.py b/api/core/workflow/context/__init__.py
new file mode 100644
index 0000000000..1237d6a017
--- /dev/null
+++ b/api/core/workflow/context/__init__.py
@@ -0,0 +1,34 @@
+"""
+Execution Context - Context management for workflow execution.
+
+This package provides Flask-independent context management for workflow
+execution in multi-threaded environments.
+"""
+
+from core.workflow.context.execution_context import (
+ AppContext,
+ ContextProviderNotFoundError,
+ ExecutionContext,
+ IExecutionContext,
+ NullAppContext,
+ capture_current_context,
+ read_context,
+ register_context,
+ register_context_capturer,
+ reset_context_provider,
+)
+from core.workflow.context.models import SandboxContext
+
+__all__ = [
+ "AppContext",
+ "ContextProviderNotFoundError",
+ "ExecutionContext",
+ "IExecutionContext",
+ "NullAppContext",
+ "SandboxContext",
+ "capture_current_context",
+ "read_context",
+ "register_context",
+ "register_context_capturer",
+ "reset_context_provider",
+]
diff --git a/api/core/workflow/context/execution_context.py b/api/core/workflow/context/execution_context.py
new file mode 100644
index 0000000000..e3007530f0
--- /dev/null
+++ b/api/core/workflow/context/execution_context.py
@@ -0,0 +1,284 @@
+"""
+Execution Context - Abstracted context management for workflow execution.
+"""
+
+import contextvars
+import threading
+from abc import ABC, abstractmethod
+from collections.abc import Callable, Generator
+from contextlib import AbstractContextManager, contextmanager
+from typing import Any, Protocol, TypeVar, final, runtime_checkable
+
+from pydantic import BaseModel
+
+
+class AppContext(ABC):
+ """
+ Abstract application context interface.
+
+ This abstraction allows workflow execution to work with or without Flask
+ by providing a common interface for application context management.
+ """
+
+ @abstractmethod
+ def get_config(self, key: str, default: Any = None) -> Any:
+ """Get configuration value by key."""
+ pass
+
+ @abstractmethod
+ def get_extension(self, name: str) -> Any:
+ """Get Flask extension by name (e.g., 'db', 'cache')."""
+ pass
+
+ @abstractmethod
+ def enter(self) -> AbstractContextManager[None]:
+ """Enter the application context."""
+ pass
+
+
+@runtime_checkable
+class IExecutionContext(Protocol):
+ """
+ Protocol for execution context.
+
+ This protocol defines the interface that all execution contexts must implement,
+ allowing both ExecutionContext and FlaskExecutionContext to be used interchangeably.
+ """
+
+ def __enter__(self) -> "IExecutionContext":
+ """Enter the execution context."""
+ ...
+
+ def __exit__(self, *args: Any) -> None:
+ """Exit the execution context."""
+ ...
+
+ @property
+ def user(self) -> Any:
+ """Get user object."""
+ ...
+
+
+@final
+class ExecutionContext:
+ """
+ Execution context for workflow execution in worker threads.
+
+ This class encapsulates all context needed for workflow execution:
+ - Application context (Flask app or standalone)
+ - Context variables for Python contextvars
+ - User information (optional)
+
+ It is designed to be serializable and passable to worker threads.
+ """
+
+ def __init__(
+ self,
+ app_context: AppContext | None = None,
+ context_vars: contextvars.Context | None = None,
+ user: Any = None,
+ ) -> None:
+ """
+ Initialize execution context.
+
+ Args:
+ app_context: Application context (Flask or standalone)
+ context_vars: Python contextvars to preserve
+ user: User object (optional)
+ """
+ self._app_context = app_context
+ self._context_vars = context_vars
+ self._user = user
+ self._local = threading.local()
+
+ @property
+ def app_context(self) -> AppContext | None:
+ """Get application context."""
+ return self._app_context
+
+ @property
+ def context_vars(self) -> contextvars.Context | None:
+ """Get context variables."""
+ return self._context_vars
+
+ @property
+ def user(self) -> Any:
+ """Get user object."""
+ return self._user
+
+ @contextmanager
+ def enter(self) -> Generator[None, None, None]:
+ """
+ Enter this execution context.
+
+ This is a convenience method that creates a context manager.
+ """
+ # Restore context variables if provided
+ if self._context_vars:
+ for var, val in self._context_vars.items():
+ var.set(val)
+
+ # Enter app context if available
+ if self._app_context is not None:
+ with self._app_context.enter():
+ yield
+ else:
+ yield
+
+ def __enter__(self) -> "ExecutionContext":
+ """Enter the execution context."""
+ cm = self.enter()
+ self._local.cm = cm
+ cm.__enter__()
+ return self
+
+ def __exit__(self, *args: Any) -> None:
+ """Exit the execution context."""
+ cm = getattr(self._local, "cm", None)
+ if cm is not None:
+ cm.__exit__(*args)
+
+
+class NullAppContext(AppContext):
+ """
+ Null implementation of AppContext for non-Flask environments.
+
+ This is used when running without Flask (e.g., in tests or standalone mode).
+ """
+
+ def __init__(self, config: dict[str, Any] | None = None) -> None:
+ """
+ Initialize null app context.
+
+ Args:
+ config: Optional configuration dictionary
+ """
+ self._config = config or {}
+ self._extensions: dict[str, Any] = {}
+
+ def get_config(self, key: str, default: Any = None) -> Any:
+ """Get configuration value by key."""
+ return self._config.get(key, default)
+
+ def get_extension(self, name: str) -> Any:
+ """Get extension by name."""
+ return self._extensions.get(name)
+
+ def set_extension(self, name: str, extension: Any) -> None:
+ """Set extension by name."""
+ self._extensions[name] = extension
+
+ @contextmanager
+ def enter(self) -> Generator[None, None, None]:
+ """Enter null context (no-op)."""
+ yield
+
+
+class ExecutionContextBuilder:
+ """
+ Builder for creating ExecutionContext instances.
+
+ This provides a fluent API for building execution contexts.
+ """
+
+ def __init__(self) -> None:
+ self._app_context: AppContext | None = None
+ self._context_vars: contextvars.Context | None = None
+ self._user: Any = None
+
+ def with_app_context(self, app_context: AppContext) -> "ExecutionContextBuilder":
+ """Set application context."""
+ self._app_context = app_context
+ return self
+
+ def with_context_vars(self, context_vars: contextvars.Context) -> "ExecutionContextBuilder":
+ """Set context variables."""
+ self._context_vars = context_vars
+ return self
+
+ def with_user(self, user: Any) -> "ExecutionContextBuilder":
+ """Set user."""
+ self._user = user
+ return self
+
+ def build(self) -> ExecutionContext:
+ """Build the execution context."""
+ return ExecutionContext(
+ app_context=self._app_context,
+ context_vars=self._context_vars,
+ user=self._user,
+ )
+
+
+_capturer: Callable[[], IExecutionContext] | None = None
+
+# Tenant-scoped providers using tuple keys for clarity and constant-time lookup.
+# Key mapping:
+# (name, tenant_id) -> provider
+# - name: namespaced identifier (recommend prefixing, e.g. "workflow.sandbox")
+# - tenant_id: tenant identifier string
+# Value:
+# provider: Callable[[], BaseModel] returning the typed context value
+# Type-safety note:
+# - This registry cannot enforce that all providers for a given name return the same BaseModel type.
+# - Implementors SHOULD provide typed wrappers around register/read (like Go's context best practice),
+# e.g. def register_sandbox_ctx(tenant_id: str, p: Callable[[], SandboxContext]) and
+# def read_sandbox_ctx(tenant_id: str) -> SandboxContext.
+_tenant_context_providers: dict[tuple[str, str], Callable[[], BaseModel]] = {}
+
+T = TypeVar("T", bound=BaseModel)
+
+
+class ContextProviderNotFoundError(KeyError):
+ """Raised when a tenant-scoped context provider is missing for a given (name, tenant_id)."""
+
+ pass
+
+
+def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None:
+ """Register a single enterable execution context capturer (e.g., Flask)."""
+ global _capturer
+ _capturer = capturer
+
+
+def register_context(name: str, tenant_id: str, provider: Callable[[], BaseModel]) -> None:
+ """Register a tenant-specific provider for a named context.
+
+ Tip: use a namespaced "name" (e.g., "workflow.sandbox") to avoid key collisions.
+ Consider adding a typed wrapper for this registration in your feature module.
+ """
+ _tenant_context_providers[(name, tenant_id)] = provider
+
+
+def read_context(name: str, *, tenant_id: str) -> BaseModel:
+ """
+ Read a context value for a specific tenant.
+
+ Raises ContextProviderNotFoundError (a KeyError subclass) if no provider is registered for (name, tenant_id).
+ """
+ prov = _tenant_context_providers.get((name, tenant_id))
+ if prov is None:
+ raise ContextProviderNotFoundError(f"Context provider '{name}' not registered for tenant '{tenant_id}'")
+ return prov()
+
+
+def capture_current_context() -> IExecutionContext:
+ """
+ Capture current execution context from the calling environment.
+
+ If a capturer is registered (e.g., Flask), use it. Otherwise, return a minimal
+ context with NullAppContext + copy of current contextvars.
+ """
+ if _capturer is None:
+ return ExecutionContext(
+ app_context=NullAppContext(),
+ context_vars=contextvars.copy_context(),
+ )
+ return _capturer()
+
+
+def reset_context_provider() -> None:
+ """Reset the capturer and all tenant-scoped context providers (primarily for tests)."""
+ global _capturer
+ _capturer = None
+ _tenant_context_providers.clear()
diff --git a/api/core/workflow/context/models.py b/api/core/workflow/context/models.py
new file mode 100644
index 0000000000..af5a4b2614
--- /dev/null
+++ b/api/core/workflow/context/models.py
@@ -0,0 +1,13 @@
+from __future__ import annotations
+
+from pydantic import AnyHttpUrl, BaseModel
+
+
+class SandboxContext(BaseModel):
+ """Typed context for sandbox integration. All fields optional by design."""
+
+ sandbox_url: AnyHttpUrl | None = None
+ sandbox_token: str | None = None # optional, if later needed for auth
+
+
+__all__ = ["SandboxContext"]
diff --git a/api/core/workflow/conversation_variable_updater.py b/api/core/workflow/conversation_variable_updater.py
index fd78248c17..75f47691da 100644
--- a/api/core/workflow/conversation_variable_updater.py
+++ b/api/core/workflow/conversation_variable_updater.py
@@ -1,7 +1,7 @@
import abc
from typing import Protocol
-from core.variables import Variable
+from core.variables import VariableBase
class ConversationVariableUpdater(Protocol):
@@ -20,12 +20,12 @@ class ConversationVariableUpdater(Protocol):
"""
@abc.abstractmethod
- def update(self, conversation_id: str, variable: "Variable"):
+ def update(self, conversation_id: str, variable: "VariableBase"):
"""
Updates the value of the specified conversation variable in the underlying storage.
:param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`.
- :param variable: The `Variable` instance containing the updated value.
+ :param variable: The `VariableBase` instance containing the updated value.
"""
pass
diff --git a/api/core/workflow/entities/graph_config.py b/api/core/workflow/entities/graph_config.py
new file mode 100644
index 0000000000..209dcfe6bc
--- /dev/null
+++ b/api/core/workflow/entities/graph_config.py
@@ -0,0 +1,24 @@
+from __future__ import annotations
+
+import sys
+
+from pydantic import TypeAdapter, with_config
+
+if sys.version_info >= (3, 12):
+ from typing import TypedDict
+else:
+ from typing_extensions import TypedDict
+
+
+@with_config(extra="allow")
+class NodeConfigData(TypedDict):
+ type: str
+
+
+@with_config(extra="allow")
+class NodeConfigDict(TypedDict):
+ id: str
+ data: NodeConfigData
+
+
+NodeConfigDictAdapter = TypeAdapter(NodeConfigDict)
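A quick sketch of how these TypedDicts behave when validating raw graph config (the sample payload is made up):

```python
from core.workflow.entities.graph_config import NodeConfigDictAdapter

raw = {"id": "start-1", "data": {"type": "start", "title": "Start"}}
node_config = NodeConfigDictAdapter.validate_python(raw)

# "id" and "data.type" are required; extra keys such as "title" are accepted
# because both TypedDicts are configured with extra="allow".
assert node_config["id"] == "start-1"
assert node_config["data"]["type"] == "start"
```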
diff --git a/api/core/workflow/entities/workflow_execution.py b/api/core/workflow/entities/workflow_execution.py
index a8a86d3db2..1b3fb36f1f 100644
--- a/api/core/workflow/entities/workflow_execution.py
+++ b/api/core/workflow/entities/workflow_execution.py
@@ -5,6 +5,8 @@ Models are independent of the storage mechanism and don't contain
implementation details like tenant_id, app_id, etc.
"""
+from __future__ import annotations
+
from collections.abc import Mapping
from datetime import datetime
from typing import Any
@@ -59,7 +61,7 @@ class WorkflowExecution(BaseModel):
graph: Mapping[str, Any],
inputs: Mapping[str, Any],
started_at: datetime,
- ) -> "WorkflowExecution":
+ ) -> WorkflowExecution:
return WorkflowExecution(
id_=id_,
workflow_id=workflow_id,
diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py
index cf12d5ec1f..bb3b13e8c6 100644
--- a/api/core/workflow/enums.py
+++ b/api/core/workflow/enums.py
@@ -211,6 +211,10 @@ class WorkflowExecutionStatus(StrEnum):
def is_ended(self) -> bool:
return self in _END_STATE
+ @classmethod
+ def ended_values(cls) -> list[str]:
+ return [status.value for status in _END_STATE]
+
_END_STATE = frozenset(
[
@@ -247,6 +251,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
DATASOURCE_INFO = "datasource_info"
+ COMPLETED_REASON = "completed_reason" # completed reason for loop node
class WorkflowNodeExecutionStatus(StrEnum):
diff --git a/api/core/workflow/graph/graph.py b/api/core/workflow/graph/graph.py
index ba5a01fc94..52bbbb20cc 100644
--- a/api/core/workflow/graph/graph.py
+++ b/api/core/workflow/graph/graph.py
@@ -1,17 +1,24 @@
+from __future__ import annotations
+
import logging
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Protocol, cast, final
+from pydantic import TypeAdapter
+
+from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType
from core.workflow.nodes.base.node import Node
-from libs.typing import is_str, is_str_dict
+from libs.typing import is_str
from .edge import Edge
from .validation import get_graph_validator
logger = logging.getLogger(__name__)
+_ListNodeConfigDict = TypeAdapter(list[NodeConfigDict])
+
class NodeFactory(Protocol):
"""
@@ -21,7 +28,7 @@ class NodeFactory(Protocol):
allowing for different node creation strategies while maintaining type safety.
"""
- def create_node(self, node_config: dict[str, object]) -> Node:
+ def create_node(self, node_config: NodeConfigDict) -> Node:
"""
Create a Node instance from node configuration data.
@@ -61,28 +68,24 @@ class Graph:
self.root_node = root_node
@classmethod
- def _parse_node_configs(cls, node_configs: list[dict[str, object]]) -> dict[str, dict[str, object]]:
+ def _parse_node_configs(cls, node_configs: list[NodeConfigDict]) -> dict[str, NodeConfigDict]:
"""
Parse node configurations and build a mapping of node IDs to configs.
:param node_configs: list of node configuration dictionaries
:return: mapping of node ID to node config
"""
- node_configs_map: dict[str, dict[str, object]] = {}
+ node_configs_map: dict[str, NodeConfigDict] = {}
for node_config in node_configs:
- node_id = node_config.get("id")
- if not node_id or not isinstance(node_id, str):
- continue
-
- node_configs_map[node_id] = node_config
+ node_configs_map[node_config["id"]] = node_config
return node_configs_map
@classmethod
def _find_root_node_id(
cls,
- node_configs_map: Mapping[str, Mapping[str, object]],
+ node_configs_map: Mapping[str, NodeConfigDict],
edge_configs: Sequence[Mapping[str, object]],
root_node_id: str | None = None,
) -> str:
@@ -111,10 +114,8 @@ class Graph:
# Prefer START node if available
start_node_id = None
for nid in root_candidates:
- node_data = node_configs_map[nid].get("data")
- if not is_str_dict(node_data):
- continue
- node_type = node_data.get("type")
+ node_data = node_configs_map[nid]["data"]
+ node_type = node_data["type"]
if not isinstance(node_type, str):
continue
if NodeType(node_type).is_start_node:
@@ -174,8 +175,8 @@ class Graph:
@classmethod
def _create_node_instances(
cls,
- node_configs_map: dict[str, dict[str, object]],
- node_factory: "NodeFactory",
+ node_configs_map: dict[str, NodeConfigDict],
+ node_factory: NodeFactory,
) -> dict[str, Node]:
"""
Create node instances from configurations using the node factory.
@@ -197,7 +198,7 @@ class Graph:
return nodes
@classmethod
- def new(cls) -> "GraphBuilder":
+ def new(cls) -> GraphBuilder:
"""Create a fluent builder for assembling a graph programmatically."""
return GraphBuilder(graph_cls=cls)
@@ -284,9 +285,10 @@ class Graph:
cls,
*,
graph_config: Mapping[str, object],
- node_factory: "NodeFactory",
+ node_factory: NodeFactory,
root_node_id: str | None = None,
- ) -> "Graph":
+ skip_validation: bool = False,
+ ) -> Graph:
"""
Initialize graph
@@ -300,7 +302,7 @@ class Graph:
node_configs = graph_config.get("nodes", [])
edge_configs = cast(list[dict[str, object]], edge_configs)
- node_configs = cast(list[dict[str, object]], node_configs)
+ node_configs = _ListNodeConfigDict.validate_python(node_configs)
if not node_configs:
raise ValueError("Graph must have at least one node")
@@ -337,8 +339,9 @@ class Graph:
root_node=root_node,
)
- # Validate the graph structure using built-in validators
- get_graph_validator().validate(graph)
+ if not skip_validation:
+ # Validate the graph structure using built-in validators
+ get_graph_validator().validate(graph)
return graph
@@ -383,7 +386,7 @@ class GraphBuilder:
self._edges: list[Edge] = []
self._edge_counter = 0
- def add_root(self, node: Node) -> "GraphBuilder":
+ def add_root(self, node: Node) -> GraphBuilder:
"""Register the root node. Must be called exactly once."""
if self._nodes:
@@ -398,7 +401,7 @@ class GraphBuilder:
*,
from_node_id: str | None = None,
source_handle: str = "source",
- ) -> "GraphBuilder":
+ ) -> GraphBuilder:
"""Append a node and connect it from the specified predecessor."""
if not self._nodes:
@@ -419,7 +422,7 @@ class GraphBuilder:
return self
- def connect(self, *, tail: str, head: str, source_handle: str = "source") -> "GraphBuilder":
+ def connect(self, *, tail: str, head: str, source_handle: str = "source") -> GraphBuilder:
"""Connect two existing nodes without adding a new node."""
if tail not in self._nodes_by_id:
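For reference, a minimal sketch of the updated `Graph.init()` call; the config payload is made up, and `node_factory` is assumed to be any object implementing `NodeFactory.create_node`:

```python
from core.workflow.graph import Graph

graph_config = {
    "nodes": [{"id": "start", "data": {"type": "start"}}],
    "edges": [],
}

# Node configs are now validated against NodeConfigDict before use.
graph = Graph.init(
    graph_config=graph_config,
    node_factory=node_factory,  # assumed: implements NodeFactory.create_node
    skip_validation=False,      # set True to bypass the built-in structural validators
)
```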
diff --git a/api/core/workflow/graph_engine/__init__.py b/api/core/workflow/graph_engine/__init__.py
index fe792c71ad..0e1c7dd60a 100644
--- a/api/core/workflow/graph_engine/__init__.py
+++ b/api/core/workflow/graph_engine/__init__.py
@@ -1,3 +1,4 @@
+from .config import GraphEngineConfig
from .graph_engine import GraphEngine
-__all__ = ["GraphEngine"]
+__all__ = ["GraphEngine", "GraphEngineConfig"]
diff --git a/api/core/workflow/graph_engine/command_channels/redis_channel.py b/api/core/workflow/graph_engine/command_channels/redis_channel.py
index 4be3adb8f8..0fccd4a0fd 100644
--- a/api/core/workflow/graph_engine/command_channels/redis_channel.py
+++ b/api/core/workflow/graph_engine/command_channels/redis_channel.py
@@ -9,7 +9,7 @@ Each instance uses a unique key for its command queue.
import json
from typing import TYPE_CHECKING, Any, final
-from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand
+from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
if TYPE_CHECKING:
from extensions.ext_redis import RedisClientWrapper
@@ -113,6 +113,8 @@ class RedisChannel:
return AbortCommand.model_validate(data)
if command_type == CommandType.PAUSE:
return PauseCommand.model_validate(data)
+ if command_type == CommandType.UPDATE_VARIABLES:
+ return UpdateVariablesCommand.model_validate(data)
# For other command types, use base class
return GraphEngineCommand.model_validate(data)
diff --git a/api/core/workflow/graph_engine/command_processing/__init__.py b/api/core/workflow/graph_engine/command_processing/__init__.py
index 837f5e55fd..7b4f0dfff7 100644
--- a/api/core/workflow/graph_engine/command_processing/__init__.py
+++ b/api/core/workflow/graph_engine/command_processing/__init__.py
@@ -5,11 +5,12 @@ This package handles external commands sent to the engine
during execution.
"""
-from .command_handlers import AbortCommandHandler, PauseCommandHandler
+from .command_handlers import AbortCommandHandler, PauseCommandHandler, UpdateVariablesCommandHandler
from .command_processor import CommandProcessor
__all__ = [
"AbortCommandHandler",
"CommandProcessor",
"PauseCommandHandler",
+ "UpdateVariablesCommandHandler",
]
diff --git a/api/core/workflow/graph_engine/command_processing/command_handlers.py b/api/core/workflow/graph_engine/command_processing/command_handlers.py
index e9f109c88c..cfe856d9e8 100644
--- a/api/core/workflow/graph_engine/command_processing/command_handlers.py
+++ b/api/core/workflow/graph_engine/command_processing/command_handlers.py
@@ -4,9 +4,10 @@ from typing import final
from typing_extensions import override
from core.workflow.entities.pause_reason import SchedulingPause
+from core.workflow.runtime import VariablePool
from ..domain.graph_execution import GraphExecution
-from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
+from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
from .command_processor import CommandHandler
logger = logging.getLogger(__name__)
@@ -31,3 +32,25 @@ class PauseCommandHandler(CommandHandler):
reason = command.reason
pause_reason = SchedulingPause(message=reason)
execution.pause(pause_reason)
+
+
+@final
+class UpdateVariablesCommandHandler(CommandHandler):
+ def __init__(self, variable_pool: VariablePool) -> None:
+ self._variable_pool = variable_pool
+
+ @override
+ def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
+ assert isinstance(command, UpdateVariablesCommand)
+ for update in command.updates:
+ try:
+ variable = update.value
+ self._variable_pool.add(variable.selector, variable)
+ logger.debug("Updated variable %s for workflow %s", variable.selector, execution.workflow_id)
+ except ValueError as exc:
+ logger.warning(
+ "Skipping invalid variable selector %s for workflow %s: %s",
+ getattr(update.value, "selector", None),
+ execution.workflow_id,
+ exc,
+ )
diff --git a/api/core/workflow/graph_engine/config.py b/api/core/workflow/graph_engine/config.py
new file mode 100644
index 0000000000..10dbbd7535
--- /dev/null
+++ b/api/core/workflow/graph_engine/config.py
@@ -0,0 +1,14 @@
+"""
+GraphEngine configuration models.
+"""
+
+from pydantic import BaseModel
+
+
+class GraphEngineConfig(BaseModel):
+ """Configuration for GraphEngine worker pool scaling."""
+
+ min_workers: int = 1
+ max_workers: int = 5
+ scale_up_threshold: int = 3
+ scale_down_idle_time: float = 5.0
diff --git a/api/core/workflow/graph_engine/entities/commands.py b/api/core/workflow/graph_engine/entities/commands.py
index 0d51b2b716..41276eb444 100644
--- a/api/core/workflow/graph_engine/entities/commands.py
+++ b/api/core/workflow/graph_engine/entities/commands.py
@@ -5,17 +5,21 @@ This module defines command types that can be sent to a running GraphEngine
instance to control its execution flow.
"""
-from enum import StrEnum
+from collections.abc import Sequence
+from enum import StrEnum, auto
from typing import Any
from pydantic import BaseModel, Field
+from core.variables.variables import Variable
+
class CommandType(StrEnum):
"""Types of commands that can be sent to GraphEngine."""
- ABORT = "abort"
- PAUSE = "pause"
+ ABORT = auto()
+ PAUSE = auto()
+ UPDATE_VARIABLES = auto()
class GraphEngineCommand(BaseModel):
@@ -37,3 +41,16 @@ class PauseCommand(GraphEngineCommand):
command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command")
reason: str = Field(default="unknown reason", description="reason for pause")
+
+
+class VariableUpdate(BaseModel):
+ """Represents a single variable update instruction."""
+
+ value: Variable = Field(description="New variable value")
+
+
+class UpdateVariablesCommand(GraphEngineCommand):
+ """Command to update a group of variables in the variable pool."""
+
+ command_type: CommandType = Field(default=CommandType.UPDATE_VARIABLES, description="Type of command")
+ updates: Sequence[VariableUpdate] = Field(default_factory=list, description="Variable updates")
diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py
index a4b2df2a8c..2b76b563ff 100644
--- a/api/core/workflow/graph_engine/graph_engine.py
+++ b/api/core/workflow/graph_engine/graph_engine.py
@@ -5,14 +5,15 @@ This engine uses a modular architecture with separated packages following
Domain-Driven Design principles for improved maintainability and testability.
"""
-import contextvars
+from __future__ import annotations
+
import logging
import queue
+import threading
from collections.abc import Generator
from typing import TYPE_CHECKING, cast, final
-from flask import Flask, current_app
-
+from core.workflow.context import capture_current_context
from core.workflow.enums import NodeExecutionType
from core.workflow.graph import Graph
from core.workflow.graph_events import (
@@ -30,8 +31,14 @@ from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWr
if TYPE_CHECKING: # pragma: no cover - used only for static analysis
from core.workflow.runtime.graph_runtime_state import GraphProtocol
-from .command_processing import AbortCommandHandler, CommandProcessor, PauseCommandHandler
-from .entities.commands import AbortCommand, PauseCommand
+from .command_processing import (
+ AbortCommandHandler,
+ CommandProcessor,
+ PauseCommandHandler,
+ UpdateVariablesCommandHandler,
+)
+from .config import GraphEngineConfig
+from .entities.commands import AbortCommand, PauseCommand, UpdateVariablesCommand
from .error_handler import ErrorHandler
from .event_management import EventHandler, EventManager
from .graph_state_manager import GraphStateManager
@@ -39,7 +46,6 @@ from .graph_traversal import EdgeProcessor, SkipPropagator
from .layers.base import GraphEngineLayer
from .orchestration import Dispatcher, ExecutionCoordinator
from .protocols.command_channel import CommandChannel
-from .ready_queue import ReadyQueue
from .worker_management import WorkerPool
if TYPE_CHECKING:
@@ -64,32 +70,26 @@ class GraphEngine:
graph: Graph,
graph_runtime_state: GraphRuntimeState,
command_channel: CommandChannel,
- min_workers: int | None = None,
- max_workers: int | None = None,
- scale_up_threshold: int | None = None,
- scale_down_idle_time: float | None = None,
+ config: GraphEngineConfig,
) -> None:
"""Initialize the graph engine with all subsystems and dependencies."""
+ # stop event
+ self._stop_event = threading.Event()
# Bind runtime state to current workflow context
self._graph = graph
self._graph_runtime_state = graph_runtime_state
+ self._graph_runtime_state.stop_event = self._stop_event
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
self._command_channel = command_channel
+ self._config = config
# Graph execution tracks the overall execution state
self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution)
self._graph_execution.workflow_id = workflow_id
- # === Worker Management Parameters ===
- # Parameters for dynamic worker pool scaling
- self._min_workers = min_workers
- self._max_workers = max_workers
- self._scale_up_threshold = scale_up_threshold
- self._scale_down_idle_time = scale_down_idle_time
-
# === Execution Queues ===
- self._ready_queue = cast(ReadyQueue, self._graph_runtime_state.ready_queue)
+ self._ready_queue = self._graph_runtime_state.ready_queue
# Queue for events generated during execution
self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
@@ -140,30 +140,26 @@ class GraphEngine:
pause_handler = PauseCommandHandler()
self._command_processor.register_handler(PauseCommand, pause_handler)
- # === Worker Pool Setup ===
- # Capture Flask app context for worker threads
- flask_app: Flask | None = None
- try:
- app = current_app._get_current_object() # type: ignore
- if isinstance(app, Flask):
- flask_app = app
- except RuntimeError:
- pass
+ update_variables_handler = UpdateVariablesCommandHandler(self._graph_runtime_state.variable_pool)
+ self._command_processor.register_handler(UpdateVariablesCommand, update_variables_handler)
- # Capture context variables for worker threads
- context_vars = contextvars.copy_context()
+ # === Extensibility ===
+ # Layers allow plugins to extend engine functionality
+ self._layers: list[GraphEngineLayer] = []
+
+ # === Worker Pool Setup ===
+ # Capture execution context for worker threads
+ execution_context = capture_current_context()
# Create worker pool for parallel node execution
self._worker_pool = WorkerPool(
ready_queue=self._ready_queue,
event_queue=self._event_queue,
graph=self._graph,
- flask_app=flask_app,
- context_vars=context_vars,
- min_workers=self._min_workers,
- max_workers=self._max_workers,
- scale_up_threshold=self._scale_up_threshold,
- scale_down_idle_time=self._scale_down_idle_time,
+ layers=self._layers,
+ execution_context=execution_context,
+ config=self._config,
+ stop_event=self._stop_event,
)
# === Orchestration ===
@@ -194,12 +190,9 @@ class GraphEngine:
event_handler=self._event_handler_registry,
execution_coordinator=self._execution_coordinator,
event_emitter=self._event_manager,
+ stop_event=self._stop_event,
)
- # === Extensibility ===
- # Layers allow plugins to extend engine functionality
- self._layers: list[GraphEngineLayer] = []
-
# === Validation ===
# Ensure all nodes share the same GraphRuntimeState instance
self._validate_graph_state_consistency()
@@ -211,9 +204,16 @@ class GraphEngine:
if id(node.graph_runtime_state) != expected_state_id:
raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance")
- def layer(self, layer: GraphEngineLayer) -> "GraphEngine":
+ def _bind_layer_context(
+ self,
+ layer: GraphEngineLayer,
+ ) -> None:
+ layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel)
+
+ def layer(self, layer: GraphEngineLayer) -> GraphEngine:
"""Add a layer for extending functionality."""
self._layers.append(layer)
+ self._bind_layer_context(layer)
return self
def run(self) -> Generator[GraphEngineEvent, None, None]:
@@ -300,14 +300,7 @@ class GraphEngine:
def _initialize_layers(self) -> None:
"""Initialize layers with context."""
self._event_manager.set_layers(self._layers)
- # Create a read-only wrapper for the runtime state
- read_only_state = ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state)
for layer in self._layers:
- try:
- layer.initialize(read_only_state, self._command_channel)
- except Exception as e:
- logger.warning("Failed to initialize layer %s: %s", layer.__class__.__name__, e)
-
try:
layer.on_graph_start()
except Exception as e:
@@ -315,6 +308,7 @@ class GraphEngine:
def _start_execution(self, *, resume: bool = False) -> None:
"""Start execution subsystems."""
+ self._stop_event.clear()
paused_nodes: list[str] = []
if resume:
paused_nodes = self._graph_runtime_state.consume_paused_nodes()
@@ -342,13 +336,12 @@ class GraphEngine:
def _stop_execution(self) -> None:
"""Stop execution subsystems."""
+ self._stop_event.set()
self._dispatcher.stop()
self._worker_pool.stop()
# Don't mark complete here as the dispatcher already does it
# Notify layers
- logger = logging.getLogger(__name__)
-
for layer in self._layers:
try:
layer.on_graph_end(self._graph_execution.error)
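Putting the new constructor signature together, a wiring sketch; it assumes `graph`, `graph_runtime_state`, and `command_channel` were built elsewhere, that `workflow_id` is the remaining constructor argument not shown in this hunk, and the `DebugLoggingLayer` import path:

```python
from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
from core.workflow.graph_engine.layers.debug_logging import DebugLoggingLayer  # path assumed

engine = GraphEngine(
    workflow_id="wf-123",  # placeholder
    graph=graph,
    graph_runtime_state=graph_runtime_state,
    command_channel=command_channel,
    config=GraphEngineConfig(min_workers=1, max_workers=4),
)

# layer() binds the read-only runtime state immediately, before run() is called.
engine.layer(DebugLoggingLayer(level="INFO"))

for event in engine.run():
    ...  # stream events to the caller
```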
diff --git a/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py b/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py
index 78f8ecdcdf..b9c9243963 100644
--- a/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py
+++ b/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py
@@ -60,6 +60,7 @@ class SkipPropagator:
if edge_states["has_taken"]:
# Enqueue node
self._state_manager.enqueue_node(downstream_node_id)
+ self._state_manager.start_execution(downstream_node_id)
return
# All edges are skipped, propagate skip to this node
diff --git a/api/core/workflow/graph_engine/layers/README.md b/api/core/workflow/graph_engine/layers/README.md
index 17845ee1f0..b0f295037c 100644
--- a/api/core/workflow/graph_engine/layers/README.md
+++ b/api/core/workflow/graph_engine/layers/README.md
@@ -8,7 +8,7 @@ Pluggable middleware for engine extensions.
Abstract base class for layers.
-- `initialize()` - Receive runtime context
+- `initialize()` - Receive runtime context (runtime state is bound here and always available to hooks)
- `on_graph_start()` - Execution start hook
- `on_event()` - Process all events
- `on_graph_end()` - Execution end hook
@@ -34,6 +34,9 @@ engine.layer(debug_layer)
engine.run()
```
+`engine.layer()` binds the read-only runtime state before execution, so
+`graph_runtime_state` is always available inside layer hooks.
+
## Custom Layers
```python
diff --git a/api/core/workflow/graph_engine/layers/base.py b/api/core/workflow/graph_engine/layers/base.py
index 24c12c2934..ff4a483aed 100644
--- a/api/core/workflow/graph_engine/layers/base.py
+++ b/api/core/workflow/graph_engine/layers/base.py
@@ -8,10 +8,19 @@ intercept and respond to GraphEngine events.
from abc import ABC, abstractmethod
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
-from core.workflow.graph_events import GraphEngineEvent
+from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase
+from core.workflow.nodes.base.node import Node
from core.workflow.runtime import ReadOnlyGraphRuntimeState
+class GraphEngineLayerNotInitializedError(Exception):
+ """Raised when a layer's runtime state is accessed before initialization."""
+
+ def __init__(self, layer_name: str | None = None) -> None:
+ name = layer_name or "GraphEngineLayer"
+ super().__init__(f"{name} runtime state is not initialized. Bind the layer to a GraphEngine before access.")
+
+
class GraphEngineLayer(ABC):
"""
Abstract base class for GraphEngine layers.
@@ -27,22 +36,27 @@ class GraphEngineLayer(ABC):
def __init__(self) -> None:
"""Initialize the layer. Subclasses can override with custom parameters."""
- self.graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
+ self._graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
self.command_channel: CommandChannel | None = None
+ @property
+ def graph_runtime_state(self) -> ReadOnlyGraphRuntimeState:
+ if self._graph_runtime_state is None:
+ raise GraphEngineLayerNotInitializedError(type(self).__name__)
+ return self._graph_runtime_state
+
def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None:
"""
Initialize the layer with engine dependencies.
- Called by GraphEngine before execution starts to inject the read-only runtime state
- and command channel. This allows layers to observe engine context and send
- commands, but prevents direct state modification.
-
+ Called by GraphEngine to inject the read-only runtime state and command channel.
+ This is invoked when the layer is registered with a `GraphEngine` instance.
+ Implementations should be idempotent.
Args:
graph_runtime_state: Read-only view of the runtime state
command_channel: Channel for sending commands to the engine
"""
- self.graph_runtime_state = graph_runtime_state
+ self._graph_runtime_state = graph_runtime_state
self.command_channel = command_channel
@abstractmethod
@@ -83,3 +97,32 @@ class GraphEngineLayer(ABC):
error: The exception that caused execution to fail, or None if successful
"""
pass
+
+ def on_node_run_start(self, node: Node) -> None:
+ """
+ Called immediately before a node begins execution.
+
+ Layers can override to inject behavior (e.g., start spans) prior to node execution.
+ The node's execution ID is available via `node._node_execution_id` and will be
+ consistent with all events emitted by this node execution.
+
+ Args:
+ node: The node instance about to be executed
+ """
+ return
+
+ def on_node_run_end(
+ self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
+ ) -> None:
+ """
+ Called after a node finishes execution.
+
+ The node's execution ID is available via `node._node_execution_id` and matches
+ the `id` field in all events emitted by this node execution.
+
+ Args:
+ node: The node instance that just finished execution
+ error: Exception instance if the node failed, otherwise None
+ result_event: The final result event from node execution (succeeded/failed/paused), if any
+ """
+ return
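A minimal sketch of a layer using the new per-node hooks; `TimingLayer` is a hypothetical name, and the no-op graph-level hooks are included only to satisfy the abstract interface:

```python
import logging
import time

from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase
from core.workflow.nodes.base.node import Node

logger = logging.getLogger(__name__)


class TimingLayer(GraphEngineLayer):
    """Hypothetical layer that measures per-node wall-clock time."""

    def __init__(self) -> None:
        super().__init__()
        self._started_at: dict[str, float] = {}

    def on_graph_start(self) -> None:
        self._started_at.clear()

    def on_event(self, event: GraphEngineEvent) -> None:
        pass

    def on_graph_end(self, error: Exception | None) -> None:
        pass

    def on_node_run_start(self, node: Node) -> None:
        self._started_at[node.id] = time.perf_counter()

    def on_node_run_end(
        self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
    ) -> None:
        started = self._started_at.pop(node.id, None)
        elapsed = time.perf_counter() - started if started is not None else 0.0
        logger.info("node %s finished in %.3fs (error=%s)", node.id, elapsed, error)
```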
diff --git a/api/core/workflow/graph_engine/layers/debug_logging.py b/api/core/workflow/graph_engine/layers/debug_logging.py
index 034ebcf54f..e0402cd09c 100644
--- a/api/core/workflow/graph_engine/layers/debug_logging.py
+++ b/api/core/workflow/graph_engine/layers/debug_logging.py
@@ -109,10 +109,8 @@ class DebugLoggingLayer(GraphEngineLayer):
self.logger.info("=" * 80)
self.logger.info("🚀 GRAPH EXECUTION STARTED")
self.logger.info("=" * 80)
-
- if self.graph_runtime_state:
- # Log initial state
- self.logger.info("Initial State:")
+ # Log initial state
+ self.logger.info("Initial State:")
@override
def on_event(self, event: GraphEngineEvent) -> None:
@@ -243,8 +241,7 @@ class DebugLoggingLayer(GraphEngineLayer):
self.logger.info(" Node retries: %s", self.retry_count)
# Log final state if available
- if self.graph_runtime_state and self.include_outputs:
- if self.graph_runtime_state.outputs:
- self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))
+ if self.include_outputs and self.graph_runtime_state.outputs:
+ self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))
self.logger.info("=" * 80)
diff --git a/api/core/workflow/graph_engine/manager.py b/api/core/workflow/graph_engine/manager.py
index 0577ba8f02..d2cfa755d9 100644
--- a/api/core/workflow/graph_engine/manager.py
+++ b/api/core/workflow/graph_engine/manager.py
@@ -3,14 +3,20 @@ GraphEngine Manager for sending control commands via Redis channel.
This module provides a simplified interface for controlling workflow executions
using the new Redis command channel, without requiring user permission checks.
-Supports stop, pause, and resume operations.
"""
import logging
+from collections.abc import Sequence
from typing import final
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
-from core.workflow.graph_engine.entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
+from core.workflow.graph_engine.entities.commands import (
+ AbortCommand,
+ GraphEngineCommand,
+ PauseCommand,
+ UpdateVariablesCommand,
+ VariableUpdate,
+)
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
@@ -23,7 +29,6 @@ class GraphEngineManager:
This class provides a simple interface for controlling workflow executions
by sending commands through Redis channels, without user validation.
- Supports stop and pause operations.
"""
@staticmethod
@@ -45,6 +50,16 @@ class GraphEngineManager:
pause_command = PauseCommand(reason=reason or "User requested pause")
GraphEngineManager._send_command(task_id, pause_command)
+ @staticmethod
+ def send_update_variables_command(task_id: str, updates: Sequence[VariableUpdate]) -> None:
+ """Send a command to update variables in a running workflow."""
+
+ if not updates:
+ return
+
+ update_command = UpdateVariablesCommand(updates=updates)
+ GraphEngineManager._send_command(task_id, update_command)
+
@staticmethod
def _send_command(task_id: str, command: GraphEngineCommand) -> None:
"""Send a command to the workflow-specific Redis channel."""
diff --git a/api/core/workflow/graph_engine/orchestration/dispatcher.py b/api/core/workflow/graph_engine/orchestration/dispatcher.py
index 334a3f77bf..27439a2412 100644
--- a/api/core/workflow/graph_engine/orchestration/dispatcher.py
+++ b/api/core/workflow/graph_engine/orchestration/dispatcher.py
@@ -44,6 +44,7 @@ class Dispatcher:
event_queue: queue.Queue[GraphNodeEventBase],
event_handler: "EventHandler",
execution_coordinator: ExecutionCoordinator,
+ stop_event: threading.Event,
event_emitter: EventManager | None = None,
) -> None:
"""
@@ -61,7 +62,7 @@ class Dispatcher:
self._event_emitter = event_emitter
self._thread: threading.Thread | None = None
- self._stop_event = threading.Event()
+ self._stop_event = stop_event
self._start_time: float | None = None
def start(self) -> None:
@@ -69,16 +70,14 @@ class Dispatcher:
if self._thread and self._thread.is_alive():
return
- self._stop_event.clear()
self._start_time = time.time()
self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True)
self._thread.start()
def stop(self) -> None:
"""Stop the dispatcher thread."""
- self._stop_event.set()
if self._thread and self._thread.is_alive():
- self._thread.join(timeout=10.0)
+ self._thread.join(timeout=2.0)
def _dispatcher_loop(self) -> None:
"""Main dispatcher loop."""
diff --git a/api/core/workflow/graph_engine/ready_queue/factory.py b/api/core/workflow/graph_engine/ready_queue/factory.py
index 1144e1de69..a9d4f470e5 100644
--- a/api/core/workflow/graph_engine/ready_queue/factory.py
+++ b/api/core/workflow/graph_engine/ready_queue/factory.py
@@ -2,6 +2,8 @@
Factory for creating ReadyQueue instances from serialized state.
"""
+from __future__ import annotations
+
from typing import TYPE_CHECKING
from .in_memory import InMemoryReadyQueue
@@ -11,7 +13,7 @@ if TYPE_CHECKING:
from .protocol import ReadyQueue
-def create_ready_queue_from_state(state: ReadyQueueState) -> "ReadyQueue":
+def create_ready_queue_from_state(state: ReadyQueueState) -> ReadyQueue:
"""
Create a ReadyQueue instance from a serialized state.
diff --git a/api/core/workflow/graph_engine/response_coordinator/coordinator.py b/api/core/workflow/graph_engine/response_coordinator/coordinator.py
index 98e0ea91ef..e82ba29438 100644
--- a/api/core/workflow/graph_engine/response_coordinator/coordinator.py
+++ b/api/core/workflow/graph_engine/response_coordinator/coordinator.py
@@ -15,10 +15,10 @@ from uuid import uuid4
from pydantic import BaseModel, Field
from core.workflow.enums import NodeExecutionType, NodeState
-from core.workflow.graph import Graph
from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
from core.workflow.nodes.base.template import TextSegment, VariableSegment
from core.workflow.runtime import VariablePool
+from core.workflow.runtime.graph_runtime_state import GraphProtocol
from .path import Path
from .session import ResponseSession
@@ -75,7 +75,7 @@ class ResponseStreamCoordinator:
Ensures ordered streaming of responses based on upstream node outputs and constants.
"""
- def __init__(self, variable_pool: "VariablePool", graph: "Graph") -> None:
+ def __init__(self, variable_pool: "VariablePool", graph: GraphProtocol) -> None:
"""
Initialize coordinator with variable pool.
diff --git a/api/core/workflow/graph_engine/response_coordinator/session.py b/api/core/workflow/graph_engine/response_coordinator/session.py
index 8b7c2e441e..5e4fada7d9 100644
--- a/api/core/workflow/graph_engine/response_coordinator/session.py
+++ b/api/core/workflow/graph_engine/response_coordinator/session.py
@@ -5,13 +5,15 @@ This module contains the private ResponseSession class used internally
by ResponseStreamCoordinator to manage streaming sessions.
"""
+from __future__ import annotations
+
from dataclasses import dataclass
from core.workflow.nodes.answer.answer_node import AnswerNode
-from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
+from core.workflow.runtime.graph_runtime_state import NodeProtocol
@dataclass
@@ -27,21 +29,26 @@ class ResponseSession:
index: int = 0 # Current position in the template segments
@classmethod
- def from_node(cls, node: Node) -> "ResponseSession":
+ def from_node(cls, node: NodeProtocol) -> ResponseSession:
"""
- Create a ResponseSession from an AnswerNode or EndNode.
+ Create a ResponseSession from a response-capable node.
+
+ The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer;
+ at execution time the node must be an `AnswerNode`, `EndNode`, or `KnowledgeIndexNode` that provides:
+ - `id: str`
+ - `get_streaming_template() -> Template`
Args:
- node: Must be either an AnswerNode or EndNode instance
+ node: Node from the materialized workflow graph.
Returns:
ResponseSession configured with the node's streaming template
Raises:
- TypeError: If node is not an AnswerNode or EndNode
+ TypeError: If node is not a supported response node type.
"""
if not isinstance(node, AnswerNode | EndNode | KnowledgeIndexNode):
- raise TypeError
+ raise TypeError("ResponseSession.from_node only supports AnswerNode, EndNode, or KnowledgeIndexNode")
return cls(
node_id=node.id,
template=node.get_streaming_template(),
diff --git a/api/core/workflow/graph_engine/worker.py b/api/core/workflow/graph_engine/worker.py
index 73e59ee298..512df6ff86 100644
--- a/api/core/workflow/graph_engine/worker.py
+++ b/api/core/workflow/graph_engine/worker.py
@@ -5,24 +5,26 @@ Workers pull node IDs from the ready_queue, execute nodes, and push events
to the event_queue for the dispatcher to process.
"""
-import contextvars
import queue
import threading
import time
+from collections.abc import Sequence
from datetime import datetime
-from typing import final
-from uuid import uuid4
+from typing import TYPE_CHECKING, final
-from flask import Flask
from typing_extensions import override
+from core.workflow.context import IExecutionContext
from core.workflow.graph import Graph
-from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
+from core.workflow.graph_engine.layers.base import GraphEngineLayer
+from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, is_node_result_event
from core.workflow.nodes.base.node import Node
-from libs.flask_utils import preserve_flask_contexts
from .ready_queue import ReadyQueue
+if TYPE_CHECKING:
+ pass
+
@final
class Worker(threading.Thread):
@@ -39,9 +41,10 @@ class Worker(threading.Thread):
ready_queue: ReadyQueue,
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
+ layers: Sequence[GraphEngineLayer],
+ stop_event: threading.Event,
worker_id: int = 0,
- flask_app: Flask | None = None,
- context_vars: contextvars.Context | None = None,
+ execution_context: IExecutionContext | None = None,
) -> None:
"""
Initialize worker thread.
@@ -50,23 +53,26 @@ class Worker(threading.Thread):
ready_queue: Ready queue containing node IDs ready for execution
event_queue: Queue for pushing execution events
graph: Graph containing nodes to execute
+ layers: Graph engine layers for node execution hooks
+ stop_event: Shared event signaled by GraphEngine to stop the worker
worker_id: Unique identifier for this worker
- flask_app: Optional Flask application for context preservation
- context_vars: Optional context variables to preserve in worker thread
+ execution_context: Optional execution context for context preservation
"""
super().__init__(name=f"GraphWorker-{worker_id}", daemon=True)
self._ready_queue = ready_queue
self._event_queue = event_queue
self._graph = graph
self._worker_id = worker_id
- self._flask_app = flask_app
- self._context_vars = context_vars
- self._stop_event = threading.Event()
+ self._execution_context = execution_context
+ self._stop_event = stop_event
+ self._layers = layers if layers is not None else []
self._last_task_time = time.time()
def stop(self) -> None:
- """Signal the worker to stop processing."""
- self._stop_event.set()
+ """Worker is controlled via shared stop_event from GraphEngine.
+
+ This method is a no-op retained for backward compatibility.
+ """
+ pass
@property
def is_idle(self) -> bool:
@@ -106,7 +112,7 @@ class Worker(threading.Thread):
self._ready_queue.task_done()
except Exception as e:
error_event = NodeRunFailedEvent(
- id=str(uuid4()),
+ id=node.execution_id,
node_id=node.id,
node_type=node.node_type,
in_iteration_id=None,
@@ -122,20 +128,56 @@ class Worker(threading.Thread):
Args:
node: The node instance to execute
"""
- # Execute the node with preserved context if Flask app is provided
- if self._flask_app and self._context_vars:
- with preserve_flask_contexts(
- flask_app=self._flask_app,
- context_vars=self._context_vars,
- ):
- # Execute the node
+ node.ensure_execution_id()
+
+ error: Exception | None = None
+ result_event: GraphNodeEventBase | None = None
+
+ # Execute the node with preserved context if execution context is provided
+ if self._execution_context is not None:
+ with self._execution_context:
+ self._invoke_node_run_start_hooks(node)
+ try:
+ node_events = node.run()
+ for event in node_events:
+ self._event_queue.put(event)
+ if is_node_result_event(event):
+ result_event = event
+ except Exception as exc:
+ error = exc
+ raise
+ finally:
+ self._invoke_node_run_end_hooks(node, error, result_event)
+ else:
+ self._invoke_node_run_start_hooks(node)
+ try:
node_events = node.run()
for event in node_events:
- # Forward event to dispatcher immediately for streaming
self._event_queue.put(event)
- else:
- # Execute without context preservation
- node_events = node.run()
- for event in node_events:
- # Forward event to dispatcher immediately for streaming
- self._event_queue.put(event)
+ if is_node_result_event(event):
+ result_event = event
+ except Exception as exc:
+ error = exc
+ raise
+ finally:
+ self._invoke_node_run_end_hooks(node, error, result_event)
+
+ def _invoke_node_run_start_hooks(self, node: Node) -> None:
+ """Invoke on_node_run_start hooks for all layers."""
+ for layer in self._layers:
+ try:
+ layer.on_node_run_start(node)
+ except Exception:
+ # Silently ignore layer errors to prevent disrupting node execution
+ continue
+
+ def _invoke_node_run_end_hooks(
+ self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
+ ) -> None:
+ """Invoke on_node_run_end hooks for all layers."""
+ for layer in self._layers:
+ try:
+ layer.on_node_run_end(node, error, result_event)
+ except Exception:
+ # Silently ignore layer errors to prevent disrupting node execution
+ continue
diff --git a/api/core/workflow/graph_engine/worker_management/worker_pool.py b/api/core/workflow/graph_engine/worker_management/worker_pool.py
index a9aada9ea5..3bff566ac8 100644
--- a/api/core/workflow/graph_engine/worker_management/worker_pool.py
+++ b/api/core/workflow/graph_engine/worker_management/worker_pool.py
@@ -8,22 +8,19 @@ DynamicScaler, and WorkerFactory into a single class.
import logging
import queue
import threading
-from typing import TYPE_CHECKING, final
+from typing import final
-from configs import dify_config
+from core.workflow.context import IExecutionContext
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase
+from ..config import GraphEngineConfig
+from ..layers.base import GraphEngineLayer
from ..ready_queue import ReadyQueue
from ..worker import Worker
logger = logging.getLogger(__name__)
-if TYPE_CHECKING:
- from contextvars import Context
-
- from flask import Flask
-
@final
class WorkerPool:
@@ -39,12 +36,10 @@ class WorkerPool:
ready_queue: ReadyQueue,
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
- flask_app: "Flask | None" = None,
- context_vars: "Context | None" = None,
- min_workers: int | None = None,
- max_workers: int | None = None,
- scale_up_threshold: int | None = None,
- scale_down_idle_time: float | None = None,
+ layers: list[GraphEngineLayer],
+ stop_event: threading.Event,
+ config: GraphEngineConfig,
+ execution_context: IExecutionContext | None = None,
) -> None:
"""
Initialize the simple worker pool.
@@ -53,30 +48,23 @@ class WorkerPool:
ready_queue: Ready queue for nodes ready for execution
event_queue: Queue for worker events
graph: The workflow graph
- flask_app: Optional Flask app for context preservation
- context_vars: Optional context variables
- min_workers: Minimum number of workers
- max_workers: Maximum number of workers
- scale_up_threshold: Queue depth to trigger scale up
- scale_down_idle_time: Seconds before scaling down idle workers
+ layers: Graph engine layers for node execution hooks
+ stop_event: Shared event signaled by GraphEngine to stop all workers
+ config: GraphEngine worker pool configuration
+ execution_context: Optional execution context for context preservation
"""
self._ready_queue = ready_queue
self._event_queue = event_queue
self._graph = graph
- self._flask_app = flask_app
- self._context_vars = context_vars
-
- # Scaling parameters with defaults
- self._min_workers = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS
- self._max_workers = max_workers or dify_config.GRAPH_ENGINE_MAX_WORKERS
- self._scale_up_threshold = scale_up_threshold or dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD
- self._scale_down_idle_time = scale_down_idle_time or dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME
+ self._execution_context = execution_context
+ self._layers = layers
+ self._config = config
# Worker management
self._workers: list[Worker] = []
self._worker_counter = 0
self._lock = threading.RLock()
self._running = False
+ self._stop_event = stop_event
# No longer tracking worker states with callbacks to avoid lock contention
@@ -97,18 +85,18 @@ class WorkerPool:
if initial_count is None:
node_count = len(self._graph.nodes)
if node_count < 10:
- initial_count = self._min_workers
+ initial_count = self._config.min_workers
elif node_count < 50:
- initial_count = min(self._min_workers + 1, self._max_workers)
+ initial_count = min(self._config.min_workers + 1, self._config.max_workers)
else:
- initial_count = min(self._min_workers + 2, self._max_workers)
+ initial_count = min(self._config.min_workers + 2, self._config.max_workers)
logger.debug(
"Starting worker pool: %d workers (nodes=%d, min=%d, max=%d)",
initial_count,
node_count,
- self._min_workers,
- self._max_workers,
+ self._config.min_workers,
+ self._config.max_workers,
)
# Create initial workers
@@ -131,7 +119,7 @@ class WorkerPool:
# Wait for workers to finish
for worker in self._workers:
if worker.is_alive():
- worker.join(timeout=10.0)
+ worker.join(timeout=2.0)
self._workers.clear()
@@ -144,9 +132,10 @@ class WorkerPool:
ready_queue=self._ready_queue,
event_queue=self._event_queue,
graph=self._graph,
+ layers=self._layers,
worker_id=worker_id,
- flask_app=self._flask_app,
- context_vars=self._context_vars,
+ execution_context=self._execution_context,
+ stop_event=self._stop_event,
)
worker.start()
@@ -176,7 +165,7 @@ class WorkerPool:
Returns:
True if scaled up, False otherwise
"""
- if queue_depth > self._scale_up_threshold and current_count < self._max_workers:
+ if queue_depth > self._config.scale_up_threshold and current_count < self._config.max_workers:
old_count = current_count
self._create_worker()
@@ -185,7 +174,7 @@ class WorkerPool:
old_count,
len(self._workers),
queue_depth,
- self._scale_up_threshold,
+ self._config.scale_up_threshold,
)
return True
return False
@@ -204,7 +193,7 @@ class WorkerPool:
True if scaled down, False otherwise
"""
# Skip if we're at minimum or have no idle workers
- if current_count <= self._min_workers or idle_count == 0:
+ if current_count <= self._config.min_workers or idle_count == 0:
return False
# Check if we have excess capacity
@@ -222,10 +211,10 @@ class WorkerPool:
for worker in self._workers:
# Check if worker is idle and has exceeded idle time threshold
- if worker.is_idle and worker.idle_duration >= self._scale_down_idle_time:
+ if worker.is_idle and worker.idle_duration >= self._config.scale_down_idle_time:
# Don't remove if it would leave us unable to handle the queue
remaining_workers = current_count - len(workers_to_remove) - 1
- if remaining_workers >= self._min_workers and remaining_workers >= max(1, queue_depth // 2):
+ if remaining_workers >= self._config.min_workers and remaining_workers >= max(1, queue_depth // 2):
workers_to_remove.append((worker, worker.worker_id))
# Only remove one worker per check to avoid aggressive scaling
break
@@ -242,7 +231,7 @@ class WorkerPool:
old_count,
len(self._workers),
len(workers_to_remove),
- self._scale_down_idle_time,
+ self._config.scale_down_idle_time,
queue_depth,
active_count,
idle_count - len(workers_to_remove),
@@ -286,6 +275,6 @@ class WorkerPool:
return {
"total_workers": len(self._workers),
"queue_depth": self._ready_queue.qsize(),
- "min_workers": self._min_workers,
- "max_workers": self._max_workers,
+ "min_workers": self._config.min_workers,
+ "max_workers": self._config.max_workers,
}
diff --git a/api/core/workflow/graph_events/__init__.py b/api/core/workflow/graph_events/__init__.py
index 7a5edbb331..2b6ee4ec1c 100644
--- a/api/core/workflow/graph_events/__init__.py
+++ b/api/core/workflow/graph_events/__init__.py
@@ -44,6 +44,7 @@ from .node import (
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
+ is_node_result_event,
)
__all__ = [
@@ -73,4 +74,5 @@ __all__ = [
"NodeRunStartedEvent",
"NodeRunStreamChunkEvent",
"NodeRunSucceededEvent",
+ "is_node_result_event",
]
diff --git a/api/core/workflow/graph_events/node.py b/api/core/workflow/graph_events/node.py
index f225798d41..4d0108e77b 100644
--- a/api/core/workflow/graph_events/node.py
+++ b/api/core/workflow/graph_events/node.py
@@ -56,3 +56,26 @@ class NodeRunRetryEvent(NodeRunStartedEvent):
class NodeRunPauseRequestedEvent(GraphNodeEventBase):
reason: PauseReason = Field(..., description="pause reason")
+
+
+def is_node_result_event(event: GraphNodeEventBase) -> bool:
+ """
+ Check if an event is a final result event from node execution.
+
+ A result event indicates the completion of a node execution and contains
+ runtime information such as inputs, outputs, or error details.
+
+ Args:
+ event: The event to check
+
+ Returns:
+ True if the event is a node result event (succeeded/failed/paused), False otherwise
+ """
+ return isinstance(
+ event,
+ (
+ NodeRunSucceededEvent,
+ NodeRunFailedEvent,
+ NodeRunPauseRequestedEvent,
+ ),
+ )
diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py
index 4be006de11..5a365f769d 100644
--- a/api/core/workflow/nodes/agent/agent_node.py
+++ b/api/core/workflow/nodes/agent/agent_node.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
@@ -167,7 +169,7 @@ class AgentNode(Node[AgentNodeData]):
variable_pool: VariablePool,
node_data: AgentNodeData,
for_log: bool = False,
- strategy: "PluginAgentStrategy",
+ strategy: PluginAgentStrategy,
) -> dict[str, Any]:
"""
Generate parameters based on the given tool parameters, variable pool, and node data.
@@ -233,7 +235,18 @@ class AgentNode(Node[AgentNodeData]):
0,
):
value_param = param.get("value", {})
- params[key] = value_param.get("value", "") if value_param is not None else None
+ if value_param and value_param.get("type", "") == "variable":
+ variable_selector = value_param.get("value")
+ if not variable_selector:
+ raise ValueError("Variable selector is missing for a variable-type parameter.")
+
+ variable = variable_pool.get(variable_selector)
+ if variable is None:
+ raise AgentVariableNotFoundError(str(variable_selector))
+
+ params[key] = variable.value
+ else:
+ params[key] = value_param.get("value", "") if value_param is not None else None
else:
params[key] = None
parameters = params
@@ -328,7 +341,7 @@ class AgentNode(Node[AgentNodeData]):
def _generate_credentials(
self,
parameters: dict[str, Any],
- ) -> "InvokeCredentials":
+ ) -> InvokeCredentials:
"""
Generate credentials based on the given agent parameters.
"""
@@ -442,9 +455,7 @@ class AgentNode(Node[AgentNodeData]):
model_schema.features.remove(feature)
return model_schema
- def _filter_mcp_type_tool(
- self, strategy: "PluginAgentStrategy", tools: list[dict[str, Any]]
- ) -> list[dict[str, Any]]:
+ def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Filter MCP type tool
:param strategy: plugin agent strategy
@@ -483,7 +494,7 @@ class AgentNode(Node[AgentNodeData]):
text = ""
files: list[File] = []
- json_list: list[dict] = []
+ json_list: list[dict | list] = []
agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
@@ -557,13 +568,18 @@ class AgentNode(Node[AgentNodeData]):
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if node_type == NodeType.AGENT:
- msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
- llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
- agent_execution_metadata = {
- WorkflowNodeExecutionMetadataKey(key): value
- for key, value in msg_metadata.items()
- if key in WorkflowNodeExecutionMetadataKey.__members__.values()
- }
+ if isinstance(message.message.json_object, dict):
+ msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
+ llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
+ agent_execution_metadata = {
+ WorkflowNodeExecutionMetadataKey(key): value
+ for key, value in msg_metadata.items()
+ if key in WorkflowNodeExecutionMetadataKey.__members__.values()
+ }
+ else:
+ msg_metadata = {}
+ llm_usage = LLMUsage.empty_usage()
+ agent_execution_metadata = {}
if message.message.json_object:
json_list.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
@@ -672,7 +688,7 @@ class AgentNode(Node[AgentNodeData]):
yield agent_log
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
- json_output: list[dict[str, Any]] = []
+ json_output: list[dict[str, Any] | list[Any]] = []
# Step 1: append each agent log as its own dict.
if agent_logs:
diff --git a/api/core/workflow/nodes/agent/exc.py b/api/core/workflow/nodes/agent/exc.py
index 944f5f0b20..ba2c83d8a6 100644
--- a/api/core/workflow/nodes/agent/exc.py
+++ b/api/core/workflow/nodes/agent/exc.py
@@ -119,3 +119,14 @@ class AgentVariableTypeError(AgentNodeError):
self.expected_type = expected_type
self.actual_type = actual_type
super().__init__(message)
+
+
+class AgentMaxIterationError(AgentNodeError):
+ """Exception raised when the agent exceeds the maximum iteration limit."""
+
+ def __init__(self, max_iteration: int):
+ self.max_iteration = max_iteration
+ super().__init__(
+ f"Agent exceeded the maximum iteration limit of {max_iteration}. "
+ f"The agent was unable to complete the task within the allowed number of iterations."
+ )
diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py
index 5aab6bbde4..c5426e3fb7 100644
--- a/api/core/workflow/nodes/base/entities.py
+++ b/api/core/workflow/nodes/base/entities.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import json
from abc import ABC
from builtins import type as type_
@@ -111,9 +113,9 @@ class DefaultValue(BaseModel):
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
@model_validator(mode="after")
- def validate_value_type(self) -> "DefaultValue":
+ def validate_value_type(self) -> DefaultValue:
# Type validation configuration
- type_validators = {
+ type_validators: dict[DefaultValueType, dict[str, Any]] = {
DefaultValueType.STRING: {
"type": str,
"converter": lambda x: x,
diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py
index c2e1105971..63e0260341 100644
--- a/api/core/workflow/nodes/base/node.py
+++ b/api/core/workflow/nodes/base/node.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import importlib
import logging
import operator
@@ -59,7 +61,7 @@ logger = logging.getLogger(__name__)
class Node(Generic[NodeDataT]):
- node_type: ClassVar["NodeType"]
+ node_type: ClassVar[NodeType]
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
_node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData
@@ -198,14 +200,14 @@ class Node(Generic[NodeDataT]):
return None
# Global registry populated via __init_subclass__
- _registry: ClassVar[dict["NodeType", dict[str, type["Node"]]]] = {}
+ _registry: ClassVar[dict[NodeType, dict[str, type[Node]]]] = {}
def __init__(
self,
id: str,
config: Mapping[str, Any],
- graph_init_params: "GraphInitParams",
- graph_runtime_state: "GraphRuntimeState",
+ graph_init_params: GraphInitParams,
+ graph_runtime_state: GraphRuntimeState,
) -> None:
self._graph_init_params = graph_init_params
self.id = id
@@ -241,9 +243,18 @@ class Node(Generic[NodeDataT]):
return
@property
- def graph_init_params(self) -> "GraphInitParams":
+ def graph_init_params(self) -> GraphInitParams:
return self._graph_init_params
+ @property
+ def execution_id(self) -> str:
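+ """Return the node execution ID for the current run."""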
+ return self._node_execution_id
+
+ def ensure_execution_id(self) -> str:
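+ """Lazily generate the node execution ID on first use and return it."""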
+ if not self._node_execution_id:
+ self._node_execution_id = str(uuid4())
+ return self._node_execution_id
+
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
return cast(NodeDataT, self._node_data_type.model_validate(data))
@@ -255,15 +266,17 @@ class Node(Generic[NodeDataT]):
"""
raise NotImplementedError
+ def _should_stop(self) -> bool:
+ """Check if execution should be stopped."""
+ return self.graph_runtime_state.stop_event.is_set()
+
def run(self) -> Generator[GraphNodeEventBase, None, None]:
- # Generate a single node execution ID to use for all events
- if not self._node_execution_id:
- self._node_execution_id = str(uuid4())
+ execution_id = self.ensure_execution_id()
self._start_at = naive_utc_now()
# Create and push start event with required fields
start_event = NodeRunStartedEvent(
- id=self._node_execution_id,
+ id=execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.title,
@@ -321,10 +334,25 @@ class Node(Generic[NodeDataT]):
if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance]
yield self._dispatch(event)
elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance]
- event.id = self._node_execution_id
+ event.id = self.execution_id
yield event
else:
yield event
+
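+ # Treat a set stop event as cooperative cancellation and surface it as a failed node event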
+ if self._should_stop():
+ error_message = "Execution cancelled"
+ yield NodeRunFailedEvent(
+ id=self.execution_id,
+ node_id=self._node_id,
+ node_type=self.node_type,
+ start_at=self._start_at,
+ node_run_result=NodeRunResult(
+ status=WorkflowNodeExecutionStatus.FAILED,
+ error=error_message,
+ ),
+ error=error_message,
+ )
+ return
except Exception as e:
logger.exception("Node %s failed to run", self._node_id)
result = NodeRunResult(
@@ -333,7 +361,7 @@ class Node(Generic[NodeDataT]):
error_type="WorkflowNodeError",
)
yield NodeRunFailedEvent(
- id=self._node_execution_id,
+ id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
@@ -431,7 +459,7 @@ class Node(Generic[NodeDataT]):
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
@classmethod
- def get_node_type_classes_mapping(cls) -> Mapping["NodeType", Mapping[str, type["Node"]]]:
+ def get_node_type_classes_mapping(cls) -> Mapping[NodeType, Mapping[str, type[Node]]]:
"""Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry.
Import all modules under core.workflow.nodes so subclasses register themselves on import.
@@ -441,12 +469,8 @@ class Node(Generic[NodeDataT]):
import core.workflow.nodes as _nodes_pkg
for _, _modname, _ in pkgutil.walk_packages(_nodes_pkg.__path__, _nodes_pkg.__name__ + "."):
- # Avoid importing modules that depend on the registry to prevent circular imports
- # e.g. node_factory imports node_mapping which builds the mapping here.
- if _modname in {
- "core.workflow.nodes.node_factory",
- "core.workflow.nodes.node_mapping",
- }:
+ # Avoid importing modules that depend on the registry to prevent circular imports.
+ if _modname == "core.workflow.nodes.node_mapping":
continue
importlib.import_module(_modname)
@@ -512,7 +536,7 @@ class Node(Generic[NodeDataT]):
match result.status:
case WorkflowNodeExecutionStatus.FAILED:
return NodeRunFailedEvent(
- id=self._node_execution_id,
+ id=self.execution_id,
node_id=self.id,
node_type=self.node_type,
start_at=self._start_at,
@@ -521,7 +545,7 @@ class Node(Generic[NodeDataT]):
)
case WorkflowNodeExecutionStatus.SUCCEEDED:
return NodeRunSucceededEvent(
- id=self._node_execution_id,
+ id=self.execution_id,
node_id=self.id,
node_type=self.node_type,
start_at=self._start_at,
@@ -537,7 +561,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent:
return NodeRunStreamChunkEvent(
- id=self._node_execution_id,
+ id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
selector=event.selector,
@@ -550,7 +574,7 @@ class Node(Generic[NodeDataT]):
match event.node_run_result.status:
case WorkflowNodeExecutionStatus.SUCCEEDED:
return NodeRunSucceededEvent(
- id=self._node_execution_id,
+ id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
@@ -558,7 +582,7 @@ class Node(Generic[NodeDataT]):
)
case WorkflowNodeExecutionStatus.FAILED:
return NodeRunFailedEvent(
- id=self._node_execution_id,
+ id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
@@ -573,7 +597,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: PauseRequestedEvent) -> NodeRunPauseRequestedEvent:
return NodeRunPauseRequestedEvent(
- id=self._node_execution_id,
+ id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.PAUSED),
@@ -583,7 +607,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent:
return NodeRunAgentLogEvent(
- id=self._node_execution_id,
+ id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
message_id=event.message_id,
@@ -599,7 +623,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent:
return NodeRunLoopStartedEvent(
- id=self._node_execution_id,
+ id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@@ -612,7 +636,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: LoopNextEvent) -> NodeRunLoopNextEvent:
return NodeRunLoopNextEvent(
- id=self._node_execution_id,
+ id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@@ -623,7 +647,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent:
return NodeRunLoopSucceededEvent(
- id=self._node_execution_id,
+ id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@@ -637,7 +661,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent:
return NodeRunLoopFailedEvent(
- id=self._node_execution_id,
+ id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@@ -652,7 +676,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent:
return NodeRunIterationStartedEvent(
- id=self._node_execution_id,
+ id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@@ -665,7 +689,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: IterationNextEvent) -> NodeRunIterationNextEvent:
return NodeRunIterationNextEvent(
- id=self._node_execution_id,
+ id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@@ -676,7 +700,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent:
return NodeRunIterationSucceededEvent(
- id=self._node_execution_id,
+ id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@@ -690,7 +714,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent:
return NodeRunIterationFailedEvent(
- id=self._node_execution_id,
+ id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@@ -705,7 +729,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent:
return NodeRunRetrieverResourceEvent(
- id=self._node_execution_id,
+ id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
retriever_resources=event.retriever_resources,
diff --git a/api/core/workflow/nodes/base/template.py b/api/core/workflow/nodes/base/template.py
index ba3e2058cf..81f4b9f6fb 100644
--- a/api/core/workflow/nodes/base/template.py
+++ b/api/core/workflow/nodes/base/template.py
@@ -4,6 +4,8 @@ This module provides a unified template structure for both Answer and End nodes,
similar to SegmentGroup but focused on template representation without values.
"""
+from __future__ import annotations
+
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
@@ -58,7 +60,7 @@ class Template:
segments: list[TemplateSegmentUnion]
@classmethod
- def from_answer_template(cls, template_str: str) -> "Template":
+ def from_answer_template(cls, template_str: str) -> Template:
"""Create a Template from an Answer node template string.
Example:
@@ -107,7 +109,7 @@ class Template:
return cls(segments=segments)
@classmethod
- def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> "Template":
+ def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> Template:
"""Create a Template from an End node outputs configuration.
End nodes are treated as templates of concatenated variables with newlines.
diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py
index a38e10030a..e3035d3bf0 100644
--- a/api/core/workflow/nodes/code/code_node.py
+++ b/api/core/workflow/nodes/code/code_node.py
@@ -1,8 +1,7 @@
from collections.abc import Mapping, Sequence
from decimal import Decimal
-from typing import Any, cast
+from typing import TYPE_CHECKING, Any, ClassVar, cast
-from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
@@ -13,6 +12,7 @@ from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.entities import CodeNodeData
+from core.workflow.nodes.code.limits import CodeNodeLimits
from .exc import (
CodeNodeError,
@@ -20,9 +20,41 @@ from .exc import (
OutputValidationError,
)
+if TYPE_CHECKING:
+ from core.workflow.entities import GraphInitParams
+ from core.workflow.runtime import GraphRuntimeState
+
class CodeNode(Node[CodeNodeData]):
node_type = NodeType.CODE
+ _DEFAULT_CODE_PROVIDERS: ClassVar[tuple[type[CodeNodeProvider], ...]] = (
+ Python3CodeProvider,
+ JavascriptCodeProvider,
+ )
+ _limits: CodeNodeLimits
+
+ def __init__(
+ self,
+ id: str,
+ config: Mapping[str, Any],
+ graph_init_params: "GraphInitParams",
+ graph_runtime_state: "GraphRuntimeState",
+ *,
+ code_executor: type[CodeExecutor] | None = None,
+ code_providers: Sequence[type[CodeNodeProvider]] | None = None,
+ code_limits: CodeNodeLimits,
+ ) -> None:
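+ """Initialize the code node with an injectable code executor, code providers, and output validation limits."""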
+ super().__init__(
+ id=id,
+ config=config,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ )
+ self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor
+ self._code_providers: tuple[type[CodeNodeProvider], ...] = (
+ tuple(code_providers) if code_providers else self._DEFAULT_CODE_PROVIDERS
+ )
+ self._limits = code_limits
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@@ -35,11 +67,16 @@ class CodeNode(Node[CodeNodeData]):
if filters:
code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3))
- providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
- code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language))
+ code_provider: type[CodeNodeProvider] = next(
+ provider for provider in cls._DEFAULT_CODE_PROVIDERS if provider.is_accept_language(code_language)
+ )
return code_provider.get_default_config()
+ @classmethod
+ def default_code_providers(cls) -> tuple[type[CodeNodeProvider], ...]:
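+ """Return the built-in code providers (Python3 and JavaScript)."""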
+ return cls._DEFAULT_CODE_PROVIDERS
+
@classmethod
def version(cls) -> str:
return "1"
@@ -60,7 +97,8 @@ class CodeNode(Node[CodeNodeData]):
variables[variable_name] = variable.to_object() if variable else None
# Run code
try:
- result = CodeExecutor.execute_workflow_code_template(
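+ # Validate that a provider accepts the configured language before executing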
+ _ = self._select_code_provider(code_language)
+ result = self._code_executor.execute_workflow_code_template(
language=code_language,
code=code,
inputs=variables,
@@ -75,6 +113,12 @@ class CodeNode(Node[CodeNodeData]):
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
+ def _select_code_provider(self, code_language: CodeLanguage) -> type[CodeNodeProvider]:
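+ """Return the first configured provider that accepts the given language, or raise CodeNodeError."""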
+ for provider in self._code_providers:
+ if provider.is_accept_language(code_language):
+ return provider
+ raise CodeNodeError(f"Unsupported code language: {code_language}")
+
def _check_string(self, value: str | None, variable: str) -> str | None:
"""
Check string
@@ -85,10 +129,10 @@ class CodeNode(Node[CodeNodeData]):
if value is None:
return None
- if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
+ if len(value) > self._limits.max_string_length:
raise OutputValidationError(
f"The length of output variable `{variable}` must be"
- f" less than {dify_config.CODE_MAX_STRING_LENGTH} characters"
+ f" less than {self._limits.max_string_length} characters"
)
return value.replace("\x00", "")
@@ -109,20 +153,20 @@ class CodeNode(Node[CodeNodeData]):
if value is None:
return None
- if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
+ if value > self._limits.max_number or value < self._limits.min_number:
raise OutputValidationError(
f"Output variable `{variable}` is out of range,"
- f" it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}."
+ f" it must be between {self._limits.min_number} and {self._limits.max_number}."
)
if isinstance(value, float):
decimal_value = Decimal(str(value)).normalize()
precision = -decimal_value.as_tuple().exponent if decimal_value.as_tuple().exponent < 0 else 0 # type: ignore[operator]
# raise error if precision is too high
- if precision > dify_config.CODE_MAX_PRECISION:
+ if precision > self._limits.max_precision:
raise OutputValidationError(
f"Output variable `{variable}` has too high precision,"
- f" it must be less than {dify_config.CODE_MAX_PRECISION} digits."
+ f" it must be less than {self._limits.max_precision} digits."
)
return value
@@ -137,8 +181,8 @@ class CodeNode(Node[CodeNodeData]):
# TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes.
# Note that `_transform_result` may produce lists containing `None` values,
# which don't conform to the type requirements of `Array*Segment` classes.
- if depth > dify_config.CODE_MAX_DEPTH:
- raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.")
+ if depth > self._limits.max_depth:
+ raise DepthLimitError(f"Depth limit {self._limits.max_depth} reached, object too deep.")
transformed_result: dict[str, Any] = {}
if output_schema is None:
@@ -272,10 +316,10 @@ class CodeNode(Node[CodeNodeData]):
f"Output {prefix}{dot}{output_name} is not an array, got {type(value)} instead."
)
else:
- if len(value) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH:
+ if len(value) > self._limits.max_number_array_length:
raise OutputValidationError(
f"The length of output variable `{prefix}{dot}{output_name}` must be"
- f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements."
+ f" less than {self._limits.max_number_array_length} elements."
)
for i, inner_value in enumerate(value):
@@ -305,10 +349,10 @@ class CodeNode(Node[CodeNodeData]):
f" got {type(result.get(output_name))} instead."
)
else:
- if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH:
+ if len(result[output_name]) > self._limits.max_string_array_length:
raise OutputValidationError(
f"The length of output variable `{prefix}{dot}{output_name}` must be"
- f" less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements."
+ f" less than {self._limits.max_string_array_length} elements."
)
transformed_result[output_name] = [
@@ -326,10 +370,10 @@ class CodeNode(Node[CodeNodeData]):
f" got {type(result.get(output_name))} instead."
)
else:
- if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH:
+ if len(result[output_name]) > self._limits.max_object_array_length:
raise OutputValidationError(
f"The length of output variable `{prefix}{dot}{output_name}` must be"
- f" less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements."
+ f" less than {self._limits.max_object_array_length} elements."
)
for i, value in enumerate(result[output_name]):
diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py
index 10a1c897e9..8026011196 100644
--- a/api/core/workflow/nodes/code/entities.py
+++ b/api/core/workflow/nodes/code/entities.py
@@ -1,4 +1,4 @@
-from typing import Annotated, Literal, Self
+from typing import Annotated, Literal
from pydantic import AfterValidator, BaseModel
@@ -34,7 +34,7 @@ class CodeNodeData(BaseNodeData):
class Output(BaseModel):
type: Annotated[SegmentType, AfterValidator(_validate_type)]
- children: dict[str, Self] | None = None
+ children: dict[str, "CodeNodeData.Output"] | None = None
class Dependency(BaseModel):
name: str
diff --git a/api/core/workflow/nodes/code/limits.py b/api/core/workflow/nodes/code/limits.py
new file mode 100644
index 0000000000..a6b9e9e68e
--- /dev/null
+++ b/api/core/workflow/nodes/code/limits.py
@@ -0,0 +1,13 @@
+from dataclasses import dataclass
+
+
+@dataclass(frozen=True)
+class CodeNodeLimits:
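+ """Validation limits for code node outputs: string/array lengths, numeric range and precision, and nesting depth."""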
+ max_string_length: int
+ max_number: int | float
+ min_number: int | float
+ max_precision: int
+ max_depth: int
+ max_number_array_length: int
+ max_string_array_length: int
+ max_object_array_length: int
diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py
index bb2140f42e..fd71d610b4 100644
--- a/api/core/workflow/nodes/datasource/datasource_node.py
+++ b/api/core/workflow/nodes/datasource/datasource_node.py
@@ -69,11 +69,13 @@ class DatasourceNode(Node[DatasourceNodeData]):
if datasource_type is None:
raise DatasourceNodeError("Datasource type is not set")
+ datasource_type = DatasourceProviderType.value_of(datasource_type)
+
datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id=f"{node_data.plugin_id}/{node_data.provider_name}",
datasource_name=node_data.datasource_name or "",
tenant_id=self.tenant_id,
- datasource_type=DatasourceProviderType.value_of(datasource_type),
+ datasource_type=datasource_type,
)
datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id)
@@ -301,7 +303,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
text = ""
files: list[File] = []
- json: list[dict] = []
+ json: list[dict | list] = []
variables: dict[str, Any] = {}
diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py
index f0c84872fb..7de8216562 100644
--- a/api/core/workflow/nodes/http_request/executor.py
+++ b/api/core/workflow/nodes/http_request/executor.py
@@ -2,7 +2,7 @@ import base64
import json
import secrets
import string
-from collections.abc import Mapping
+from collections.abc import Callable, Mapping
from copy import deepcopy
from typing import Any, Literal
from urllib.parse import urlencode, urlparse
@@ -11,12 +11,13 @@ import httpx
from json_repair import repair_json
from configs import dify_config
-from core.file import file_manager
from core.file.enums import FileTransferMethod
-from core.helper import ssrf_proxy
+from core.file.file_manager import file_manager as default_file_manager
+from core.helper.ssrf_proxy import ssrf_proxy
from core.variables.segments import ArrayFileSegment, FileSegment
from core.workflow.runtime import VariablePool
+from ..protocols import FileManagerProtocol, HttpClientProtocol
from .entities import (
HttpRequestNodeAuthorization,
HttpRequestNodeData,
@@ -78,6 +79,8 @@ class Executor:
timeout: HttpRequestNodeTimeout,
variable_pool: VariablePool,
max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES,
+ http_client: HttpClientProtocol | None = None,
+ file_manager: FileManagerProtocol | None = None,
):
# If authorization API key is present, convert the API key using the variable pool
if node_data.authorization.type == "api-key":
@@ -86,6 +89,11 @@ class Executor:
node_data.authorization.config.api_key = variable_pool.convert_template(
node_data.authorization.config.api_key
).text
+ # Validate that API key is not empty after template conversion
+ if not node_data.authorization.config.api_key or not node_data.authorization.config.api_key.strip():
+ raise AuthorizationConfigError(
+ "API key is required for authorization but was empty. Please provide a valid API key."
+ )
self.url = node_data.url
self.method = node_data.method
@@ -99,6 +107,8 @@ class Executor:
self.data = None
self.json = None
self.max_retries = max_retries
+ self._http_client = http_client or ssrf_proxy
+ self._file_manager = file_manager or default_file_manager
# init template
self.variable_pool = variable_pool
@@ -195,7 +205,7 @@ class Executor:
if file_variable is None:
raise FileFetchError(f"cannot fetch file with selector {file_selector}")
file = file_variable.value
- self.content = file_manager.download(file)
+ self.content = self._file_manager.download(file)
case "x-www-form-urlencoded":
form_data = {
self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template(
@@ -234,7 +244,7 @@ class Executor:
):
file_tuple = (
file.filename,
- file_manager.download(file),
+ self._file_manager.download(file),
file.mime_type or "application/octet-stream",
)
if key not in files:
@@ -326,20 +336,19 @@ class Executor:
"""
do http request depending on api bundle
"""
- _METHOD_MAP = {
- "get": ssrf_proxy.get,
- "head": ssrf_proxy.head,
- "post": ssrf_proxy.post,
- "put": ssrf_proxy.put,
- "delete": ssrf_proxy.delete,
- "patch": ssrf_proxy.patch,
+ _METHOD_MAP: dict[str, Callable[..., httpx.Response]] = {
+ "get": self._http_client.get,
+ "head": self._http_client.head,
+ "post": self._http_client.post,
+ "put": self._http_client.put,
+ "delete": self._http_client.delete,
+ "patch": self._http_client.patch,
}
method_lc = self.method.lower()
if method_lc not in _METHOD_MAP:
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
- request_args = {
- "url": self.url,
+ request_args: dict[str, Any] = {
"data": self.data,
"files": self.files,
"json": self.json,
@@ -352,10 +361,13 @@ class Executor:
}
# request_args = {k: v for k, v in request_args.items() if v is not None}
try:
- response: httpx.Response = _METHOD_MAP[method_lc](**request_args, max_retries=self.max_retries)
- except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
+ response = _METHOD_MAP[method_lc](
+ url=self.url,
+ **request_args,
+ max_retries=self.max_retries,
+ )
+ except (self._http_client.max_retries_exceeded_error, self._http_client.request_error) as e:
raise HttpRequestNodeError(str(e)) from e
- # FIXME: fix type ignore, this maybe httpx type issue
return response
def invoke(self) -> Response:
diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py
index 9bd1cb9761..480482375f 100644
--- a/api/core/workflow/nodes/http_request/node.py
+++ b/api/core/workflow/nodes/http_request/node.py
@@ -1,10 +1,12 @@
import logging
import mimetypes
-from collections.abc import Mapping, Sequence
-from typing import Any
+from collections.abc import Callable, Mapping, Sequence
+from typing import TYPE_CHECKING, Any
from configs import dify_config
from core.file import File, FileTransferMethod
+from core.file.file_manager import file_manager as default_file_manager
+from core.helper.ssrf_proxy import ssrf_proxy
from core.tools.tool_file_manager import ToolFileManager
from core.variables.segments import ArrayFileSegment
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@@ -13,6 +15,7 @@ from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.http_request.executor import Executor
+from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
from factories import file_factory
from .entities import (
@@ -30,10 +33,35 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
logger = logging.getLogger(__name__)
+if TYPE_CHECKING:
+ from core.workflow.entities import GraphInitParams
+ from core.workflow.runtime import GraphRuntimeState
+
class HttpRequestNode(Node[HttpRequestNodeData]):
node_type = NodeType.HTTP_REQUEST
+ def __init__(
+ self,
+ id: str,
+ config: Mapping[str, Any],
+ graph_init_params: "GraphInitParams",
+ graph_runtime_state: "GraphRuntimeState",
+ *,
+ http_client: HttpClientProtocol | None = None,
+ tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
+ file_manager: FileManagerProtocol | None = None,
+ ) -> None:
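+ """Initialize the HTTP request node with an injectable HTTP client, tool file manager factory, and file manager."""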
+ super().__init__(
+ id=id,
+ config=config,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ )
+ self._http_client = http_client or ssrf_proxy
+ self._tool_file_manager_factory = tool_file_manager_factory
+ self._file_manager = file_manager or default_file_manager
+
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
@@ -71,6 +99,8 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
timeout=self._get_request_timeout(self.node_data),
variable_pool=self.graph_runtime_state.variable_pool,
max_retries=0,
+ http_client=self._http_client,
+ file_manager=self._file_manager,
)
process_data["request"] = http_executor.to_log()
@@ -199,7 +229,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
mime_type = (
content_disposition_type or content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
)
- tool_file_manager = ToolFileManager()
+ tool_file_manager = self._tool_file_manager_factory()
tool_file = tool_file_manager.create_file_by_raw(
user_id=self.user_id,
diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py
index e5d86414c1..25a881ea7d 100644
--- a/api/core/workflow/nodes/iteration/iteration_node.py
+++ b/api/core/workflow/nodes/iteration/iteration_node.py
@@ -1,17 +1,15 @@
-import contextvars
import logging
from collections.abc import Generator, Mapping, Sequence
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, NewType, cast
-from flask import Flask, current_app
from typing_extensions import TypeIs
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
-from core.variables.variables import VariableUnion
+from core.variables.variables import Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.enums import (
NodeExecutionType,
@@ -39,7 +37,6 @@ from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from core.workflow.runtime import VariablePool
from libs.datetime_utils import naive_utc_now
-from libs.flask_utils import preserve_flask_contexts
from .exc import (
InvalidIteratorValueError,
@@ -51,6 +48,7 @@ from .exc import (
)
if TYPE_CHECKING:
+ from core.workflow.context import IExecutionContext
from core.workflow.graph_engine import GraphEngine
logger = logging.getLogger(__name__)
@@ -240,7 +238,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
datetime,
list[GraphNodeEventBase],
object | None,
- dict[str, VariableUnion],
+ dict[str, Variable],
LLMUsage,
]
],
@@ -252,8 +250,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
self._execute_single_iteration_parallel,
index=index,
item=item,
- flask_app=current_app._get_current_object(), # type: ignore
- context_vars=contextvars.copy_context(),
+ execution_context=self._capture_execution_context(),
)
future_to_index[future] = index
@@ -306,11 +303,10 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
self,
index: int,
item: object,
- flask_app: Flask,
- context_vars: contextvars.Context,
- ) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, VariableUnion], LLMUsage]:
+ execution_context: "IExecutionContext",
+ ) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
"""Execute a single iteration in parallel mode and return results."""
- with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
+ with execution_context:
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
events: list[GraphNodeEventBase] = []
outputs_temp: list[object] = []
@@ -339,6 +335,12 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
graph_engine.graph_runtime_state.llm_usage,
)
+ def _capture_execution_context(self) -> "IExecutionContext":
+ """Capture current execution context for parallel iterations."""
+ from core.workflow.context import capture_current_context
+
+ return capture_current_context()
+
def _handle_iteration_success(
self,
started_at: datetime,
@@ -395,7 +397,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
return outputs
# Check if all non-None outputs are lists
- non_none_outputs = [output for output in outputs if output is not None]
+ non_none_outputs: list[object] = [output for output in outputs if output is not None]
if not non_none_outputs:
return outputs
@@ -515,11 +517,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
return variable_mapping
- def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, VariableUnion]:
+ def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, Variable]:
conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()}
- def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, VariableUnion]) -> None:
+ def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, Variable]) -> None:
parent_pool = self.graph_runtime_state.variable_pool
parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
@@ -586,11 +588,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
def _create_graph_engine(self, index: int, item: object):
# Import dependencies
+ from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
- from core.workflow.graph_engine import GraphEngine
+ from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
from core.workflow.graph_engine.command_channels import InMemoryChannel
- from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.runtime import GraphRuntimeState
# Create GraphInitParams from node attributes
@@ -638,6 +640,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
graph=iteration_graph,
graph_runtime_state=graph_runtime_state_copy,
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
+ config=GraphEngineConfig(),
)
return graph_engine
diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py
index 3daca90b9b..bfeb9b5b79 100644
--- a/api/core/workflow/nodes/knowledge_index/entities.py
+++ b/api/core/workflow/nodes/knowledge_index/entities.py
@@ -158,3 +158,5 @@ class KnowledgeIndexNodeData(BaseNodeData):
type: str = "knowledge-index"
chunk_structure: str
index_chunk_variable_selector: list[str]
+ indexing_technique: str | None = None
+ summary_index_setting: dict | None = None
diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
index 17ca4bef7b..b88c2d510f 100644
--- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
+++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
@@ -1,9 +1,11 @@
+import concurrent.futures
import datetime
import logging
import time
from collections.abc import Mapping
from typing import Any
+from flask import current_app
from sqlalchemy import func, select
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -16,7 +18,9 @@ from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
-from models.dataset import Dataset, Document, DocumentSegment
+from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary
+from services.summary_index_service import SummaryIndexService
+from tasks.generate_summary_index_task import generate_summary_index_task
from .entities import KnowledgeIndexNodeData
from .exc import (
@@ -67,7 +71,20 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
# index knowledge
try:
if is_preview:
- outputs = self._get_preview_output(node_data.chunk_structure, chunks)
+ # Preview mode: generate summaries for chunks directly without saving to database
+ # Format preview and generate summaries on-the-fly
+ # Get indexing_technique and summary_index_setting from node_data (workflow graph config)
+ # or fall back to the dataset if not available in node_data
+ indexing_technique = node_data.indexing_technique or dataset.indexing_technique
+ summary_index_setting = node_data.summary_index_setting or dataset.summary_index_setting
+
+ outputs = self._get_preview_output_with_summaries(
+ node_data.chunk_structure,
+ chunks,
+ dataset=dataset,
+ indexing_technique=indexing_technique,
+ summary_index_setting=summary_index_setting,
+ )
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
@@ -148,6 +165,11 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
)
.scalar()
)
+ # Update need_summary based on dataset's summary_index_setting
+ if dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True:
+ document.need_summary = True
+ else:
+ document.need_summary = False
db.session.add(document)
# update document segment status
db.session.query(DocumentSegment).where(
@@ -163,6 +185,9 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
db.session.commit()
+ # Generate summary index if enabled
+ self._handle_summary_index_generation(dataset, document, variable_pool)
+
return {
"dataset_id": ds_id_value,
"dataset_name": dataset_name_value,
@@ -173,9 +198,304 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
"display_status": "completed",
}
- def _get_preview_output(self, chunk_structure: str, chunks: Any) -> Mapping[str, Any]:
+ def _handle_summary_index_generation(
+ self,
+ dataset: Dataset,
+ document: Document,
+ variable_pool: VariablePool,
+ ) -> None:
+ """
+ Handle summary index generation based on mode (debug/preview or production).
+
+ Args:
+ dataset: Dataset containing the document
+ document: Document to generate summaries for
+ variable_pool: Variable pool to check invoke_from
+ """
+ # Only generate summary index for high_quality indexing technique
+ if dataset.indexing_technique != "high_quality":
+ return
+
+ # Check if summary index is enabled
+ summary_index_setting = dataset.summary_index_setting
+ if not summary_index_setting or not summary_index_setting.get("enable"):
+ return
+
+ # Skip qa_model documents
+ if document.doc_form == "qa_model":
+ return
+
+ # Determine if in preview/debug mode
+ invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
+ is_preview = invoke_from and invoke_from.value == InvokeFrom.DEBUGGER
+
+ if is_preview:
+ try:
+ # Query segments that need summary generation
+ query = db.session.query(DocumentSegment).filter_by(
+ dataset_id=dataset.id,
+ document_id=document.id,
+ status="completed",
+ enabled=True,
+ )
+ segments = query.all()
+
+ if not segments:
+ logger.info("No segments found for document %s", document.id)
+ return
+
+ # Filter segments based on mode
+ segments_to_process = []
+ for segment in segments:
+ # Skip if summary already exists
+ existing_summary = (
+ db.session.query(DocumentSegmentSummary)
+ .filter_by(chunk_id=segment.id, dataset_id=dataset.id, status="completed")
+ .first()
+ )
+ if existing_summary:
+ continue
+
+ # For parent-child mode, all segments are parent chunks, so process all
+ segments_to_process.append(segment)
+
+ if not segments_to_process:
+ logger.info("No segments need summary generation for document %s", document.id)
+ return
+
+ # Use ThreadPoolExecutor for concurrent generation
+ flask_app = current_app._get_current_object() # type: ignore
+ max_workers = min(10, len(segments_to_process)) # Limit to 10 workers
+
+ def process_segment(segment: DocumentSegment) -> None:
+ """Process a single segment in a thread with Flask app context."""
+ with flask_app.app_context():
+ try:
+ SummaryIndexService.generate_and_vectorize_summary(segment, dataset, summary_index_setting)
+ except Exception:
+ logger.exception(
+ "Failed to generate summary for segment %s",
+ segment.id,
+ )
+ # Continue processing other segments
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+ futures = [executor.submit(process_segment, segment) for segment in segments_to_process]
+ # Wait for all tasks to complete
+ concurrent.futures.wait(futures)
+
+ logger.info(
+ "Successfully generated summary index for %s segments in document %s",
+ len(segments_to_process),
+ document.id,
+ )
+ except Exception:
+ logger.exception("Failed to generate summary index for document %s", document.id)
+ # Don't fail the entire indexing process if summary generation fails
+ else:
+ # Production mode: asynchronous generation
+ logger.info(
+ "Queuing summary index generation task for document %s (production mode)",
+ document.id,
+ )
+ try:
+ generate_summary_index_task.delay(dataset.id, document.id, None)
+ logger.info("Summary index generation task queued for document %s", document.id)
+ except Exception:
+ logger.exception(
+ "Failed to queue summary index generation task for document %s",
+ document.id,
+ )
+ # Don't fail the entire indexing process if task queuing fails
+
+ def _get_preview_output_with_summaries(
+ self,
+ chunk_structure: str,
+ chunks: Any,
+ dataset: Dataset,
+ indexing_technique: str | None = None,
+ summary_index_setting: dict | None = None,
+ ) -> Mapping[str, Any]:
+ """
+ Generate preview output with summaries for chunks in preview mode.
+ This method generates summaries on-the-fly without saving to database.
+
+ Args:
+ chunk_structure: Chunk structure type
+ chunks: Chunks to generate preview for
+ dataset: Dataset object (for tenant_id)
+ indexing_technique: Indexing technique from node config or dataset
+ summary_index_setting: Summary index setting from node config or dataset
+ """
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
- return index_processor.format_preview(chunks)
+ preview_output = index_processor.format_preview(chunks)
+
+ # Check if summary index is enabled
+ if indexing_technique != "high_quality":
+ return preview_output
+
+ if not summary_index_setting or not summary_index_setting.get("enable"):
+ return preview_output
+
+ # Generate summaries for chunks
+ if "preview" in preview_output and isinstance(preview_output["preview"], list):
+ chunk_count = len(preview_output["preview"])
+ logger.info(
+ "Generating summaries for %s chunks in preview mode (dataset: %s)",
+ chunk_count,
+ dataset.id,
+ )
+ # Use ParagraphIndexProcessor's generate_summary method
+ from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
+
+ # Get Flask app for application context in worker threads
+ flask_app = None
+ try:
+ flask_app = current_app._get_current_object() # type: ignore
+ except RuntimeError:
+ logger.warning("No Flask application context available, summary generation may fail")
+
+ def generate_summary_for_chunk(preview_item: dict) -> None:
+ """Generate summary for a single chunk."""
+ if "content" in preview_item:
+ # Set Flask application context in worker thread
+ if flask_app:
+ with flask_app.app_context():
+ summary, _ = ParagraphIndexProcessor.generate_summary(
+ tenant_id=dataset.tenant_id,
+ text=preview_item["content"],
+ summary_index_setting=summary_index_setting,
+ )
+ if summary:
+ preview_item["summary"] = summary
+ else:
+ # Fallback: try without app context (may fail)
+ summary, _ = ParagraphIndexProcessor.generate_summary(
+ tenant_id=dataset.tenant_id,
+ text=preview_item["content"],
+ summary_index_setting=summary_index_setting,
+ )
+ if summary:
+ preview_item["summary"] = summary
+
+ # Generate summaries concurrently using ThreadPoolExecutor
+ # Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total)
+ timeout_seconds = min(300, 60 * len(preview_output["preview"]))
+ errors: list[Exception] = []
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_output["preview"]))) as executor:
+ futures = [
+ executor.submit(generate_summary_for_chunk, preview_item)
+ for preview_item in preview_output["preview"]
+ ]
+ # Wait for all tasks to complete with timeout
+ done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)
+
+ # Cancel tasks that didn't complete in time
+ if not_done:
+ timeout_error_msg = (
+ f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s"
+ )
+ logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg)
+ # In preview mode, timeout is also an error
+ errors.append(TimeoutError(timeout_error_msg))
+ for future in not_done:
+ future.cancel()
+ # Wait a bit for cancellation to take effect
+ concurrent.futures.wait(not_done, timeout=5)
+
+ # Collect exceptions from completed futures
+ for future in done:
+ try:
+ future.result() # This will raise any exception that occurred
+ except Exception as e:
+ logger.exception("Error in summary generation future")
+ errors.append(e)
+
+ # In preview mode, if there are any errors, fail the request
+ if errors:
+ error_messages = [str(e) for e in errors]
+ error_summary = (
+ f"Failed to generate summaries for {len(errors)} chunk(s). "
+ f"Errors: {'; '.join(error_messages[:3])}" # Show first 3 errors
+ )
+ if len(errors) > 3:
+ error_summary += f" (and {len(errors) - 3} more)"
+ logger.error("Summary generation failed in preview mode: %s", error_summary)
+ raise KnowledgeIndexNodeError(error_summary)
+
+ completed_count = sum(1 for item in preview_output["preview"] if item.get("summary") is not None)
+ logger.info(
+ "Completed summary generation for preview chunks: %s/%s succeeded",
+ completed_count,
+ len(preview_output["preview"]),
+ )
+
+ return preview_output
+
+ def _get_preview_output(
+ self,
+ chunk_structure: str,
+ chunks: Any,
+ dataset: Dataset | None = None,
+ variable_pool: VariablePool | None = None,
+ ) -> Mapping[str, Any]:
+ index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
+ preview_output = index_processor.format_preview(chunks)
+
+ # If dataset is provided, try to enrich preview with summaries
+ if dataset and variable_pool:
+ document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
+ if document_id:
+ document = db.session.query(Document).filter_by(id=document_id.value).first()
+ if document:
+ # Query summaries for this document
+ summaries = (
+ db.session.query(DocumentSegmentSummary)
+ .filter_by(
+ dataset_id=dataset.id,
+ document_id=document.id,
+ status="completed",
+ enabled=True,
+ )
+ .all()
+ )
+
+ if summaries:
+ # Create a map of segment content to summary for matching
+ # Use content matching as chunks in preview might not be indexed yet
+ summary_by_content = {}
+ for summary in summaries:
+ segment = (
+ db.session.query(DocumentSegment)
+ .filter_by(id=summary.chunk_id, dataset_id=dataset.id)
+ .first()
+ )
+ if segment:
+ # Normalize content for matching (strip whitespace)
+ normalized_content = segment.content.strip()
+ summary_by_content[normalized_content] = summary.summary_content
+
+ # Enrich preview with summaries by content matching
+ if "preview" in preview_output and isinstance(preview_output["preview"], list):
+ matched_count = 0
+ for preview_item in preview_output["preview"]:
+ if "content" in preview_item:
+ # Normalize content for matching
+ normalized_chunk_content = preview_item["content"].strip()
+ if normalized_chunk_content in summary_by_content:
+ preview_item["summary"] = summary_by_content[normalized_chunk_content]
+ matched_count += 1
+
+ if matched_count > 0:
+ logger.info(
+ "Enriched preview with %s existing summaries (dataset: %s, document: %s)",
+ matched_count,
+ dataset.id,
+ document.id,
+ )
+
+ return preview_output
@classmethod
def version(cls) -> str:
diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
index adc474bd60..3c4850ebac 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
@@ -6,7 +6,7 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
-from sqlalchemy import and_, func, literal, or_, select
+from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import sessionmaker
from core.app.app_config.entities import DatasetRetrieveConfigEntity
@@ -419,6 +419,9 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}"
else:
source["content"] = segment.get_sign_content()
+ # Add summary if available
+ if record.summary:
+ source["summary"] = record.summary
retrieval_resource_list.append(source)
if retrieval_resource_list:
retrieval_resource_list = sorted(
@@ -460,7 +463,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
if automatic_metadata_filters:
conditions = []
for sequence, filter in enumerate(automatic_metadata_filters):
- self._process_metadata_filter_func(
+ DatasetRetrieval.process_metadata_filter_func(
sequence,
filter.get("condition", ""),
filter.get("metadata_name", ""),
@@ -504,7 +507,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
value=expected_value,
)
)
- filters = self._process_metadata_filter_func(
+ filters = DatasetRetrieval.process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
@@ -603,87 +606,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
return [], usage
return automatic_metadata_filters, usage
- def _process_metadata_filter_func(
- self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]
- ) -> list[Any]:
- if value is None and condition not in ("empty", "not empty"):
- return filters
-
- json_field = Document.doc_metadata[metadata_name].as_string()
-
- match condition:
- case "contains":
- filters.append(json_field.like(f"%{value}%"))
-
- case "not contains":
- filters.append(json_field.notlike(f"%{value}%"))
-
- case "start with":
- filters.append(json_field.like(f"{value}%"))
-
- case "end with":
- filters.append(json_field.like(f"%{value}"))
- case "in":
- if isinstance(value, str):
- value_list = [v.strip() for v in value.split(",") if v.strip()]
- elif isinstance(value, (list, tuple)):
- value_list = [str(v) for v in value if v is not None]
- else:
- value_list = [str(value)] if value is not None else []
-
- if not value_list:
- filters.append(literal(False))
- else:
- filters.append(json_field.in_(value_list))
-
- case "not in":
- if isinstance(value, str):
- value_list = [v.strip() for v in value.split(",") if v.strip()]
- elif isinstance(value, (list, tuple)):
- value_list = [str(v) for v in value if v is not None]
- else:
- value_list = [str(value)] if value is not None else []
-
- if not value_list:
- filters.append(literal(True))
- else:
- filters.append(json_field.notin_(value_list))
-
- case "is" | "=":
- if isinstance(value, str):
- filters.append(json_field == value)
- elif isinstance(value, (int, float)):
- filters.append(Document.doc_metadata[metadata_name].as_float() == value)
-
- case "is not" | "≠":
- if isinstance(value, str):
- filters.append(json_field != value)
- elif isinstance(value, (int, float)):
- filters.append(Document.doc_metadata[metadata_name].as_float() != value)
-
- case "empty":
- filters.append(Document.doc_metadata[metadata_name].is_(None))
-
- case "not empty":
- filters.append(Document.doc_metadata[metadata_name].isnot(None))
-
- case "before" | "<":
- filters.append(Document.doc_metadata[metadata_name].as_float() < value)
-
- case "after" | ">":
- filters.append(Document.doc_metadata[metadata_name].as_float() > value)
-
- case "≤" | "<=":
- filters.append(Document.doc_metadata[metadata_name].as_float() <= value)
-
- case "≥" | ">=":
- filters.append(Document.doc_metadata[metadata_name].as_float() >= value)
-
- case _:
- pass
-
- return filters
-
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py
index 813d898b9a..235f5b9c52 100644
--- a/api/core/workflow/nodes/list_operator/node.py
+++ b/api/core/workflow/nodes/list_operator/node.py
@@ -196,13 +196,13 @@ def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
case "name":
return lambda x: x.filename or ""
case "type":
- return lambda x: x.type
+ return lambda x: str(x.type)
case "extension":
return lambda x: x.extension or ""
case "mime_type":
return lambda x: x.mime_type or ""
case "transfer_method":
- return lambda x: x.transfer_method
+ return lambda x: str(x.transfer_method)
case "url":
return lambda x: x.remote_url or ""
case "related_id":
@@ -276,7 +276,6 @@ def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Calla
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
- extract_func: Callable[[File], Any]
if key in {"name", "extension", "mime_type", "url", "related_id"} and isinstance(value, str):
extract_func = _get_file_extract_string_func(key=key)
return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x))
@@ -284,8 +283,8 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str
extract_func = _get_file_extract_string_func(key=key)
return lambda x: _get_sequence_filter_func(condition=condition, value=value)(extract_func(x))
elif key == "size" and isinstance(value, str):
- extract_func = _get_file_extract_number_func(key=key)
- return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x))
+ extract_number = _get_file_extract_number_func(key=key)
+ return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_number(x))
else:
raise InvalidKeyError(f"Invalid key: {key}")
diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/core/workflow/nodes/llm/llm_utils.py
index 0c545469bc..01e25cbf5c 100644
--- a/api/core/workflow/nodes/llm/llm_utils.py
+++ b/api/core/workflow/nodes/llm/llm_utils.py
@@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
-from core.entities.provider_entities import QuotaUnit
+from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.file.models import File
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
@@ -136,21 +136,37 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
used_quota = 1
if used_quota is not None and system_configuration.current_quota_type is not None:
- with Session(db.engine) as session:
- stmt = (
- update(Provider)
- .where(
- Provider.tenant_id == tenant_id,
- # TODO: Use provider name with prefix after the data migration.
- Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
- Provider.provider_type == ProviderType.SYSTEM,
- Provider.quota_type == system_configuration.current_quota_type.value,
- Provider.quota_limit > Provider.quota_used,
- )
- .values(
- quota_used=Provider.quota_used + used_quota,
- last_used=naive_utc_now(),
- )
+ if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
+ from services.credit_pool_service import CreditPoolService
+
+ CreditPoolService.check_and_deduct_credits(
+ tenant_id=tenant_id,
+ credits_required=used_quota,
)
- session.execute(stmt)
- session.commit()
+ elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
+ from services.credit_pool_service import CreditPoolService
+
+ CreditPoolService.check_and_deduct_credits(
+ tenant_id=tenant_id,
+ credits_required=used_quota,
+ pool_type="paid",
+ )
+ else:
+ with Session(db.engine) as session:
+ stmt = (
+ update(Provider)
+ .where(
+ Provider.tenant_id == tenant_id,
+ # TODO: Use provider name with prefix after the data migration.
+ Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
+ Provider.provider_type == ProviderType.SYSTEM.value,
+ Provider.quota_type == system_configuration.current_quota_type.value,
+ Provider.quota_limit > Provider.quota_used,
+ )
+ .values(
+ quota_used=Provider.quota_used + used_quota,
+ last_used=naive_utc_now(),
+ )
+ )
+ session.execute(stmt)
+ session.commit()
diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
index 04e2802191..beccf79344 100644
--- a/api/core/workflow/nodes/llm/node.py
+++ b/api/core/workflow/nodes/llm/node.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import base64
import io
import json
@@ -113,7 +115,7 @@ class LLMNode(Node[LLMNodeData]):
# Instance attributes specific to LLMNode.
# Output variable for file
- _file_outputs: list["File"]
+ _file_outputs: list[File]
_llm_file_saver: LLMFileSaver
@@ -121,8 +123,8 @@ class LLMNode(Node[LLMNodeData]):
self,
id: str,
config: Mapping[str, Any],
- graph_init_params: "GraphInitParams",
- graph_runtime_state: "GraphRuntimeState",
+ graph_init_params: GraphInitParams,
+ graph_runtime_state: GraphRuntimeState,
*,
llm_file_saver: LLMFileSaver | None = None,
):
@@ -361,7 +363,7 @@ class LLMNode(Node[LLMNodeData]):
structured_output_enabled: bool,
structured_output: Mapping[str, Any] | None = None,
file_saver: LLMFileSaver,
- file_outputs: list["File"],
+ file_outputs: list[File],
node_id: str,
node_type: NodeType,
reasoning_format: Literal["separated", "tagged"] = "tagged",
@@ -415,7 +417,7 @@ class LLMNode(Node[LLMNodeData]):
*,
invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None],
file_saver: LLMFileSaver,
- file_outputs: list["File"],
+ file_outputs: list[File],
node_id: str,
node_type: NodeType,
reasoning_format: Literal["separated", "tagged"] = "tagged",
@@ -525,7 +527,7 @@ class LLMNode(Node[LLMNodeData]):
)
@staticmethod
- def _image_file_to_markdown(file: "File", /):
+ def _image_file_to_markdown(file: File, /):
        text_chunk = f"![]({file.generate_url()})"
return text_chunk
@@ -683,6 +685,8 @@ class LLMNode(Node[LLMNodeData]):
if "content" not in item:
raise InvalidContextStructureError(f"Invalid context structure: {item}")
+ if item.get("summary"):
+ context_str += item["summary"] + "\n"
context_str += item["content"] + "\n"
retriever_resource = self._convert_to_original_retriever_resource(item)
@@ -744,6 +748,7 @@ class LLMNode(Node[LLMNodeData]):
page=metadata.get("page"),
doc_metadata=metadata.get("doc_metadata"),
files=context_dict.get("files"),
+ summary=context_dict.get("summary"),
)
return source
@@ -774,7 +779,7 @@ class LLMNode(Node[LLMNodeData]):
def fetch_prompt_messages(
*,
sys_query: str | None = None,
- sys_files: Sequence["File"],
+ sys_files: Sequence[File],
context: str | None = None,
memory: TokenBufferMemory | None = None,
model_config: ModelConfigWithCredentialsEntity,
@@ -785,7 +790,7 @@ class LLMNode(Node[LLMNodeData]):
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
tenant_id: str,
- context_files: list["File"] | None = None,
+ context_files: list[File] | None = None,
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
prompt_messages: list[PromptMessage] = []
@@ -847,18 +852,16 @@ class LLMNode(Node[LLMNodeData]):
# Insert histories into the prompt
prompt_content = prompt_messages[0].content
# For issue #11247 - Check if prompt content is a string or a list
- prompt_content_type = type(prompt_content)
- if prompt_content_type == str:
+ if isinstance(prompt_content, str):
prompt_content = str(prompt_content)
if "#histories#" in prompt_content:
prompt_content = prompt_content.replace("#histories#", memory_text)
else:
prompt_content = memory_text + "\n" + prompt_content
prompt_messages[0].content = prompt_content
- elif prompt_content_type == list:
- prompt_content = prompt_content if isinstance(prompt_content, list) else []
+ elif isinstance(prompt_content, list):
for content_item in prompt_content:
- if content_item.type == PromptMessageContentType.TEXT:
+ if isinstance(content_item, TextPromptMessageContent):
if "#histories#" in content_item.data:
content_item.data = content_item.data.replace("#histories#", memory_text)
else:
@@ -868,13 +871,12 @@ class LLMNode(Node[LLMNodeData]):
# Add current query to the prompt message
if sys_query:
- if prompt_content_type == str:
+ if isinstance(prompt_content, str):
prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
prompt_messages[0].content = prompt_content
- elif prompt_content_type == list:
- prompt_content = prompt_content if isinstance(prompt_content, list) else []
+ elif isinstance(prompt_content, list):
for content_item in prompt_content:
- if content_item.type == PromptMessageContentType.TEXT:
+ if isinstance(content_item, TextPromptMessageContent):
content_item.data = sys_query + "\n" + content_item.data
else:
raise ValueError("Invalid prompt content type")
@@ -1028,14 +1030,14 @@ class LLMNode(Node[LLMNodeData]):
if typed_node_data.prompt_config:
enable_jinja = False
- if isinstance(prompt_template, list):
+ if isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
+ if prompt_template.edition_type == "jinja2":
+ enable_jinja = True
+ else:
for prompt in prompt_template:
if prompt.edition_type == "jinja2":
enable_jinja = True
break
- else:
- if prompt_template.edition_type == "jinja2":
- enable_jinja = True
if enable_jinja:
for variable_selector in typed_node_data.prompt_config.jinja2_variables or []:
@@ -1137,7 +1139,7 @@ class LLMNode(Node[LLMNodeData]):
*,
invoke_result: LLMResult | LLMResultWithStructuredOutput,
saver: LLMFileSaver,
- file_outputs: list["File"],
+ file_outputs: list[File],
reasoning_format: Literal["separated", "tagged"] = "tagged",
request_latency: float | None = None,
) -> ModelInvokeCompletedEvent:
@@ -1179,7 +1181,7 @@ class LLMNode(Node[LLMNodeData]):
*,
content: ImagePromptMessageContent,
file_saver: LLMFileSaver,
- ) -> "File":
+ ) -> File:
"""_save_multimodal_output saves multi-modal contents generated by LLM plugins.
There are two kinds of multimodal outputs:
@@ -1229,7 +1231,7 @@ class LLMNode(Node[LLMNodeData]):
*,
contents: str | list[PromptMessageContentUnionTypes] | None,
file_saver: LLMFileSaver,
- file_outputs: list["File"],
+ file_outputs: list[File],
) -> Generator[str, None, None]:
"""Convert intermediate prompt messages into strings and yield them to the caller.
diff --git a/api/core/workflow/nodes/loop/entities.py b/api/core/workflow/nodes/loop/entities.py
index 4fcad888e4..92a8702fc3 100644
--- a/api/core/workflow/nodes/loop/entities.py
+++ b/api/core/workflow/nodes/loop/entities.py
@@ -1,3 +1,4 @@
+from enum import StrEnum
from typing import Annotated, Any, Literal
from pydantic import AfterValidator, BaseModel, Field, field_validator
@@ -96,3 +97,8 @@ class LoopState(BaseLoopState):
Get current output.
"""
return self.current_output
+
+
+class LoopCompletedReason(StrEnum):
+ LOOP_BREAK = "loop_break"
+ LOOP_COMPLETED = "loop_completed"
diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py
index 1c26bbc2d0..84a9c29414 100644
--- a/api/core/workflow/nodes/loop/loop_node.py
+++ b/api/core/workflow/nodes/loop/loop_node.py
@@ -29,7 +29,7 @@ from core.workflow.node_events import (
)
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.node import Node
-from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData
+from core.workflow.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData
from core.workflow.utils.condition.processor import ConditionProcessor
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
from libs.datetime_utils import naive_utc_now
@@ -96,6 +96,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
loop_duration_map: dict[str, float] = {}
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
loop_usage = LLMUsage.empty_usage()
+ loop_node_ids = self._extract_loop_node_ids_from_config(self.graph_config, self._node_id)
# Start Loop event
yield LoopStartedEvent(
@@ -118,6 +119,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
loop_count = 0
for i in range(loop_count):
+ # Clear stale variables from previous loop iterations to avoid streaming old values
+ self._clear_loop_subgraph_variables(loop_node_ids)
graph_engine = self._create_graph_engine(start_at=start_at, root_node_id=root_node_id)
loop_start_time = naive_utc_now()
@@ -177,7 +180,11 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
- "completed_reason": "loop_break" if reach_break_condition else "loop_completed",
+ WorkflowNodeExecutionMetadataKey.COMPLETED_REASON: (
+ LoopCompletedReason.LOOP_BREAK
+ if reach_break_condition
+                        else LoopCompletedReason.LOOP_COMPLETED
+ ),
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
@@ -274,6 +281,17 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
if WorkflowNodeExecutionMetadataKey.LOOP_ID not in current_metadata:
event.node_run_result.metadata = {**current_metadata, **loop_metadata}
+ def _clear_loop_subgraph_variables(self, loop_node_ids: set[str]) -> None:
+ """
+ Remove variables produced by loop sub-graph nodes from previous iterations.
+
+ Keeping stale variables causes a freshly created response coordinator in the
+ next iteration to fall back to outdated values when no stream chunks exist.
+ """
+ variable_pool = self.graph_runtime_state.variable_pool
+ for node_id in loop_node_ids:
+ variable_pool.remove([node_id])
+
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
@@ -395,11 +413,11 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
# Import dependencies
+ from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
- from core.workflow.graph_engine import GraphEngine
+ from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
from core.workflow.graph_engine.command_channels import InMemoryChannel
- from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.runtime import GraphRuntimeState
# Create GraphInitParams from node attributes
@@ -434,6 +452,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
graph=loop_graph,
graph_runtime_state=graph_runtime_state_copy,
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
+ config=GraphEngineConfig(),
)
return graph_engine
diff --git a/api/core/workflow/nodes/node_factory.py b/api/core/workflow/nodes/node_factory.py
deleted file mode 100644
index c55ad346bf..0000000000
--- a/api/core/workflow/nodes/node_factory.py
+++ /dev/null
@@ -1,80 +0,0 @@
-from typing import TYPE_CHECKING, final
-
-from typing_extensions import override
-
-from core.workflow.enums import NodeType
-from core.workflow.graph import NodeFactory
-from core.workflow.nodes.base.node import Node
-from libs.typing import is_str, is_str_dict
-
-from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
-
-if TYPE_CHECKING:
- from core.workflow.entities import GraphInitParams
- from core.workflow.runtime import GraphRuntimeState
-
-
-@final
-class DifyNodeFactory(NodeFactory):
- """
- Default implementation of NodeFactory that uses the traditional node mapping.
-
- This factory creates nodes by looking up their types in NODE_TYPE_CLASSES_MAPPING
- and instantiating the appropriate node class.
- """
-
- def __init__(
- self,
- graph_init_params: "GraphInitParams",
- graph_runtime_state: "GraphRuntimeState",
- ) -> None:
- self.graph_init_params = graph_init_params
- self.graph_runtime_state = graph_runtime_state
-
- @override
- def create_node(self, node_config: dict[str, object]) -> Node:
- """
- Create a Node instance from node configuration data using the traditional mapping.
-
- :param node_config: node configuration dictionary containing type and other data
- :return: initialized Node instance
- :raises ValueError: if node type is unknown or configuration is invalid
- """
- # Get node_id from config
- node_id = node_config.get("id")
- if not is_str(node_id):
- raise ValueError("Node config missing id")
-
- # Get node type from config
- node_data = node_config.get("data", {})
- if not is_str_dict(node_data):
- raise ValueError(f"Node {node_id} missing data information")
-
- node_type_str = node_data.get("type")
- if not is_str(node_type_str):
- raise ValueError(f"Node {node_id} missing or invalid type information")
-
- try:
- node_type = NodeType(node_type_str)
- except ValueError:
- raise ValueError(f"Unknown node type: {node_type_str}")
-
- # Get node class
- node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
- if not node_mapping:
- raise ValueError(f"No class mapping found for node type: {node_type}")
-
- latest_node_class = node_mapping.get(LATEST_VERSION)
- node_version = str(node_data.get("version", "1"))
- matched_node_class = node_mapping.get(node_version)
- node_class = matched_node_class or latest_node_class
- if not node_class:
- raise ValueError(f"No latest version class found for node type: {node_type}")
-
- # Create node instance
- return node_class(
- id=node_id,
- config=node_config,
- graph_init_params=self.graph_init_params,
- graph_runtime_state=self.graph_runtime_state,
- )
diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
index 93db417b15..08e0542d61 100644
--- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
+++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
@@ -281,7 +281,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
# handle invoke result
- text = invoke_result.message.content or ""
+ text = invoke_result.message.get_text_content()
if not isinstance(text, str):
raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.")
diff --git a/api/core/workflow/nodes/protocols.py b/api/core/workflow/nodes/protocols.py
new file mode 100644
index 0000000000..2ad39e0ab5
--- /dev/null
+++ b/api/core/workflow/nodes/protocols.py
@@ -0,0 +1,29 @@
+from typing import Any, Protocol
+
+import httpx
+
+from core.file import File
+
+
+class HttpClientProtocol(Protocol):
+ @property
+ def max_retries_exceeded_error(self) -> type[Exception]: ...
+
+ @property
+ def request_error(self) -> type[Exception]: ...
+
+ def get(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
+
+ def head(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
+
+ def post(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
+
+ def put(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
+
+ def delete(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
+
+ def patch(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
+
+
+class FileManagerProtocol(Protocol):
+ def download(self, f: File, /) -> bytes: ...
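Since these are structural `Protocol`s, any object with matching methods satisfies them without inheriting from them. As a hedged sketch (the in-memory stub below is illustrative only, not part of this patch, and keying downloads by `related_id` is an assumption), a test double for `FileManagerProtocol` could look like:

```python
from core.file import File
from core.workflow.nodes.protocols import FileManagerProtocol


class InMemoryFileManager:
    """Hypothetical stub that serves file bytes from a dict keyed by related_id."""

    def __init__(self, blobs: dict[str, bytes]) -> None:
        self._blobs = blobs

    def download(self, f: File, /) -> bytes:
        # Assumption for illustration: the File's related_id is the lookup key.
        return self._blobs[f.related_id or ""]


# Structural typing: no subclassing is needed for this assignment to type-check.
manager: FileManagerProtocol = InMemoryFileManager({"file-1": b"hello"})
```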
diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py
index 38effa79f7..53c1b4ee6b 100644
--- a/api/core/workflow/nodes/start/start_node.py
+++ b/api/core/workflow/nodes/start/start_node.py
@@ -42,9 +42,17 @@ class StartNode(Node[StartNodeData]):
if value is None and variable.required:
raise ValueError(f"{key} is required in input form")
- if not isinstance(value, dict):
- raise ValueError(f"{key} must be a JSON object")
+ # If no value provided, skip further processing for this key
+ if not value:
+ continue
+ if not isinstance(value, dict):
+ raise ValueError(f"JSON object for '{key}' must be an object")
+
+ # Overwrite with normalized dict to ensure downstream consistency
+ node_inputs[key] = value
+
+            # If a schema exists, validate against it
schema = variable.json_schema
if not schema:
continue
@@ -53,4 +61,3 @@ class StartNode(Node[StartNodeData]):
Draft7Validator(schema).validate(value)
except ValidationError as e:
raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}")
- node_inputs[key] = value
diff --git a/api/core/workflow/nodes/template_transform/template_renderer.py b/api/core/workflow/nodes/template_transform/template_renderer.py
new file mode 100644
index 0000000000..a5f06bf2bb
--- /dev/null
+++ b/api/core/workflow/nodes/template_transform/template_renderer.py
@@ -0,0 +1,40 @@
+from __future__ import annotations
+
+from collections.abc import Mapping
+from typing import Any, Protocol
+
+from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
+
+
+class TemplateRenderError(ValueError):
+ """Raised when rendering a Jinja2 template fails."""
+
+
+class Jinja2TemplateRenderer(Protocol):
+ """Render Jinja2 templates for template transform nodes."""
+
+ def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
+ """Render a Jinja2 template with provided variables."""
+ raise NotImplementedError
+
+
+class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer):
+ """Adapter that renders Jinja2 templates via CodeExecutor."""
+
+ _code_executor: type[CodeExecutor]
+
+ def __init__(self, code_executor: type[CodeExecutor] | None = None) -> None:
+ self._code_executor = code_executor or CodeExecutor
+
+ def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
+ try:
+ result = self._code_executor.execute_workflow_code_template(
+ language=CodeLanguage.JINJA2, code=template, inputs=variables
+ )
+ except CodeExecutionError as exc:
+ raise TemplateRenderError(str(exc)) from exc
+
+ rendered = result.get("result")
+ if not isinstance(rendered, str):
+ raise TemplateRenderError("Template render result must be a string.")
+ return rendered
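Because `TemplateTransformNode` now takes an injectable `Jinja2TemplateRenderer`, the `CodeExecutor`-backed adapter can be swapped out, e.g. in unit tests. A minimal sketch, assuming `jinja2` is importable in the test environment (the local renderer below is illustrative, not part of this patch):

```python
from collections.abc import Mapping
from typing import Any

from jinja2.sandbox import SandboxedEnvironment

from core.workflow.nodes.template_transform.template_renderer import TemplateRenderError


class LocalJinja2TemplateRenderer:
    """Illustrative renderer that evaluates templates in-process instead of via CodeExecutor."""

    def __init__(self) -> None:
        self._env = SandboxedEnvironment(autoescape=False)

    def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
        try:
            return self._env.from_string(template).render(**variables)
        except Exception as exc:
            # Mirror the adapter's behavior: surface failures as TemplateRenderError.
            raise TemplateRenderError(str(exc)) from exc
```

It satisfies the `Jinja2TemplateRenderer` protocol structurally and could be passed as `template_renderer=LocalJinja2TemplateRenderer()` when constructing the node.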
diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py
index 2274323960..f7e0bccccf 100644
--- a/api/core/workflow/nodes/template_transform/template_transform_node.py
+++ b/api/core/workflow/nodes/template_transform/template_transform_node.py
@@ -1,18 +1,44 @@
from collections.abc import Mapping, Sequence
-from typing import Any
+from typing import TYPE_CHECKING, Any
from configs import dify_config
-from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
+from core.workflow.nodes.template_transform.template_renderer import (
+ CodeExecutorJinja2TemplateRenderer,
+ Jinja2TemplateRenderer,
+ TemplateRenderError,
+)
+
+if TYPE_CHECKING:
+ from core.workflow.entities import GraphInitParams
+ from core.workflow.runtime import GraphRuntimeState
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
class TemplateTransformNode(Node[TemplateTransformNodeData]):
node_type = NodeType.TEMPLATE_TRANSFORM
+ _template_renderer: Jinja2TemplateRenderer
+
+ def __init__(
+ self,
+ id: str,
+ config: Mapping[str, Any],
+ graph_init_params: "GraphInitParams",
+ graph_runtime_state: "GraphRuntimeState",
+ *,
+ template_renderer: Jinja2TemplateRenderer | None = None,
+ ) -> None:
+ super().__init__(
+ id=id,
+ config=config,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ )
+ self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@@ -39,13 +65,11 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
variables[variable_name] = value.to_object() if value else None
# Run code
try:
- result = CodeExecutor.execute_workflow_code_template(
- language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables
- )
- except CodeExecutionError as e:
+ rendered = self._template_renderer.render_template(self.node_data.template, variables)
+ except TemplateRenderError as e:
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
- if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
+ if len(rendered) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
return NodeRunResult(
inputs=variables,
status=WorkflowNodeExecutionStatus.FAILED,
@@ -53,7 +77,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
)
return NodeRunResult(
- status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": result["result"]}
+ status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": rendered}
)
@classmethod
diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py
index c1cfbb1edc..8fe33c240a 100644
--- a/api/core/workflow/nodes/tool/entities.py
+++ b/api/core/workflow/nodes/tool/entities.py
@@ -54,8 +54,8 @@ class ToolNodeData(BaseNodeData, ToolEntity):
for val in value:
if not isinstance(val, str):
raise ValueError("value must be a list of strings")
- elif typ == "constant" and not isinstance(value, str | int | float | bool | dict):
- raise ValueError("value must be a string, int, float, bool or dict")
+ elif typ == "constant" and not isinstance(value, (allowed_types := (str, int, float, bool, dict, list))):
+ raise ValueError(f"value must be one of: {', '.join(t.__name__ for t in allowed_types)}")
return typ
tool_parameters: dict[str, ToolInput]
diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py
index 2e7ec757b4..68ac60e4f6 100644
--- a/api/core/workflow/nodes/tool/tool_node.py
+++ b/api/core/workflow/nodes/tool/tool_node.py
@@ -244,7 +244,7 @@ class ToolNode(Node[ToolNodeData]):
text = ""
files: list[File] = []
- json: list[dict] = []
+ json: list[dict | list] = []
variables: dict[str, Any] = {}
@@ -400,7 +400,7 @@ class ToolNode(Node[ToolNodeData]):
message.message.metadata = dict_metadata
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
- json_output: list[dict[str, Any]] = []
+ json_output: list[dict[str, Any] | list[Any]] = []
        # Step 2: normalize JSON into {"data": [...]}; convert json to list[dict]
if json:
diff --git a/api/core/workflow/nodes/trigger_webhook/node.py b/api/core/workflow/nodes/trigger_webhook/node.py
index 3631c8653d..ec8c4b8ee3 100644
--- a/api/core/workflow/nodes/trigger_webhook/node.py
+++ b/api/core/workflow/nodes/trigger_webhook/node.py
@@ -1,14 +1,22 @@
+import logging
from collections.abc import Mapping
from typing import Any
+from core.file import FileTransferMethod
+from core.variables.types import SegmentType
+from core.variables.variables import FileVariable
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import NodeExecutionType, NodeType
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
+from factories import file_factory
+from factories.variable_factory import build_segment_with_type
from .entities import ContentType, WebhookData
+logger = logging.getLogger(__name__)
+
class TriggerWebhookNode(Node[WebhookData]):
node_type = NodeType.TRIGGER_WEBHOOK
@@ -60,6 +68,34 @@ class TriggerWebhookNode(Node[WebhookData]):
outputs=outputs,
)
+ def generate_file_var(self, param_name: str, file: dict):
+ related_id = file.get("related_id")
+ transfer_method_value = file.get("transfer_method")
+ if transfer_method_value:
+ transfer_method = FileTransferMethod.value_of(transfer_method_value)
+ match transfer_method:
+ case FileTransferMethod.LOCAL_FILE | FileTransferMethod.REMOTE_URL:
+ file["upload_file_id"] = related_id
+ case FileTransferMethod.TOOL_FILE:
+ file["tool_file_id"] = related_id
+ case FileTransferMethod.DATASOURCE_FILE:
+ file["datasource_file_id"] = related_id
+
+ try:
+ file_obj = file_factory.build_from_mapping(
+ mapping=file,
+ tenant_id=self.tenant_id,
+ )
+ file_segment = build_segment_with_type(SegmentType.FILE, file_obj)
+ return FileVariable(name=param_name, value=file_segment.value, selector=[self.id, param_name])
+ except ValueError:
+ logger.error(
+ "Failed to build FileVariable for webhook file parameter %s",
+ param_name,
+ exc_info=True,
+ )
+ return None
+
def _extract_configured_outputs(self, webhook_inputs: dict[str, Any]) -> dict[str, Any]:
"""Extract outputs based on node configuration from webhook inputs."""
outputs = {}
@@ -107,18 +143,33 @@ class TriggerWebhookNode(Node[WebhookData]):
outputs[param_name] = str(webhook_data.get("body", {}).get("raw", ""))
continue
elif self.node_data.content_type == ContentType.BINARY:
- outputs[param_name] = webhook_data.get("body", {}).get("raw", b"")
+ raw_data: dict = webhook_data.get("body", {}).get("raw", {})
+ file_var = self.generate_file_var(param_name, raw_data)
+ if file_var:
+ outputs[param_name] = file_var
+ else:
+ outputs[param_name] = raw_data
continue
if param_type == "file":
# Get File object (already processed by webhook controller)
- file_obj = webhook_data.get("files", {}).get(param_name)
- outputs[param_name] = file_obj
+ files = webhook_data.get("files", {})
+ if files and isinstance(files, dict):
+ file = files.get(param_name)
+ if file and isinstance(file, dict):
+ file_var = self.generate_file_var(param_name, file)
+ if file_var:
+ outputs[param_name] = file_var
+ else:
+ outputs[param_name] = files
+ else:
+ outputs[param_name] = files
+ else:
+ outputs[param_name] = files
else:
# Get regular body parameter
outputs[param_name] = webhook_data.get("body", {}).get(param_name)
# Include raw webhook data for debugging/advanced use
outputs["_webhook_raw"] = webhook_data
-
return outputs
diff --git a/api/core/workflow/nodes/variable_assigner/common/impl.py b/api/core/workflow/nodes/variable_assigner/common/impl.py
deleted file mode 100644
index 050e213535..0000000000
--- a/api/core/workflow/nodes/variable_assigner/common/impl.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from sqlalchemy import select
-from sqlalchemy.orm import Session
-
-from core.variables.variables import Variable
-from extensions.ext_database import db
-from models import ConversationVariable
-
-from .exc import VariableOperatorNodeError
-
-
-class ConversationVariableUpdaterImpl:
- def update(self, conversation_id: str, variable: Variable):
- stmt = select(ConversationVariable).where(
- ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
- )
- with Session(db.engine) as session:
- row = session.scalar(stmt)
- if not row:
- raise VariableOperatorNodeError("conversation variable not found in the database")
- row.data = variable.model_dump_json()
- session.commit()
-
- def flush(self):
- pass
-
-
-def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl:
- return ConversationVariableUpdaterImpl()
diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py
index da23207b62..9f5818f4bb 100644
--- a/api/core/workflow/nodes/variable_assigner/v1/node.py
+++ b/api/core/workflow/nodes/variable_assigner/v1/node.py
@@ -1,9 +1,8 @@
-from collections.abc import Callable, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, TypeAlias
+from collections.abc import Mapping, Sequence
+from typing import TYPE_CHECKING, Any
-from core.variables import SegmentType, Variable
+from core.variables import SegmentType, VariableBase
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
-from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
@@ -11,19 +10,14 @@ from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
-from ..common.impl import conversation_variable_updater_factory
from .node_data import VariableAssignerData, WriteMode
if TYPE_CHECKING:
from core.workflow.runtime import GraphRuntimeState
-_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
-
-
class VariableAssignerNode(Node[VariableAssignerData]):
node_type = NodeType.VARIABLE_ASSIGNER
- _conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
def __init__(
self,
@@ -31,7 +25,6 @@ class VariableAssignerNode(Node[VariableAssignerData]):
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
- conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory,
):
super().__init__(
id=id,
@@ -39,7 +32,15 @@ class VariableAssignerNode(Node[VariableAssignerData]):
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
- self._conv_var_updater_factory = conv_var_updater_factory
+
+ def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
+ """
+ Check if this Variable Assigner node blocks the output of specific variables.
+
+ Returns True if this node updates any of the requested conversation variables.
+ """
+ assigned_selector = tuple(self.node_data.assigned_variable_selector)
+ return assigned_selector in variable_selectors
@classmethod
def version(cls) -> str:
@@ -72,7 +73,7 @@ class VariableAssignerNode(Node[VariableAssignerData]):
assigned_variable_selector = self.node_data.assigned_variable_selector
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
- if not isinstance(original_variable, Variable):
+ if not isinstance(original_variable, VariableBase):
raise VariableOperatorNodeError("assigned variable not found")
match self.node_data.write_mode:
@@ -96,16 +97,7 @@ class VariableAssignerNode(Node[VariableAssignerData]):
# Over write the variable.
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)
- # TODO: Move database operation to the pipeline.
- # Update conversation variable.
- conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
- if not conversation_id:
- raise VariableOperatorNodeError("conversation_id not found")
- conv_var_updater = self._conv_var_updater_factory()
- conv_var_updater.update(conversation_id=conversation_id.text, variable=updated_variable)
- conv_var_updater.flush()
updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)]
-
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={
diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py
index 389fb54d35..5857702e72 100644
--- a/api/core/workflow/nodes/variable_assigner/v2/node.py
+++ b/api/core/workflow/nodes/variable_assigner/v2/node.py
@@ -1,24 +1,20 @@
import json
from collections.abc import Mapping, MutableMapping, Sequence
-from typing import Any, cast
+from typing import TYPE_CHECKING, Any
-from core.app.entities.app_invoke_entities import InvokeFrom
-from core.variables import SegmentType, Variable
+from core.variables import SegmentType, VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
-from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
-from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
from . import helpers
from .entities import VariableAssignerNodeData, VariableOperationItem
from .enums import InputType, Operation
from .exc import (
- ConversationIDNotFoundError,
InputTypeNotSupportedError,
InvalidDataError,
InvalidInputValueError,
@@ -26,6 +22,10 @@ from .exc import (
VariableNotFoundError,
)
+if TYPE_CHECKING:
+ from core.workflow.entities import GraphInitParams
+ from core.workflow.runtime import GraphRuntimeState
+
def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem):
selector_node_id = item.variable_selector[0]
@@ -53,6 +53,20 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
class VariableAssignerNode(Node[VariableAssignerNodeData]):
node_type = NodeType.VARIABLE_ASSIGNER
+ def __init__(
+ self,
+ id: str,
+ config: Mapping[str, Any],
+ graph_init_params: "GraphInitParams",
+ graph_runtime_state: "GraphRuntimeState",
+ ):
+ super().__init__(
+ id=id,
+ config=config,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ )
+
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
"""
Check if this Variable Assigner node blocks the output of specific variables.
@@ -70,9 +84,6 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
return False
- def _conv_var_updater_factory(self) -> ConversationVariableUpdater:
- return conversation_variable_updater_factory()
-
@classmethod
def version(cls) -> str:
return "2"
@@ -107,7 +118,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
# ==================== Validation Part
# Check if variable exists
- if not isinstance(variable, Variable):
+ if not isinstance(variable, VariableBase):
raise VariableNotFoundError(variable_selector=item.variable_selector)
# Check if operation is supported
@@ -179,26 +190,12 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
# remove the duplicated items first.
updated_variable_selectors = list(set(map(tuple, updated_variable_selectors)))
- conv_var_updater = self._conv_var_updater_factory()
- # Update variables
for selector in updated_variable_selectors:
variable = self.graph_runtime_state.variable_pool.get(selector)
- if not isinstance(variable, Variable):
+ if not isinstance(variable, VariableBase):
raise VariableNotFoundError(variable_selector=selector)
process_data[variable.name] = variable.value
- if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID:
- conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
- if not conversation_id:
- if self.invoke_from != InvokeFrom.DEBUGGER:
- raise ConversationIDNotFoundError
- else:
- conversation_id = conversation_id.value
- conv_var_updater.update(
- conversation_id=cast(str, conversation_id),
- variable=variable,
- )
- conv_var_updater.flush()
updated_variables = [
common_helpers.variable_to_processed_data(selector, seg)
for selector in updated_variable_selectors
@@ -216,7 +213,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
def _handle_item(
self,
*,
- variable: Variable,
+ variable: VariableBase,
operation: Operation,
value: Any,
):
diff --git a/api/core/workflow/repositories/draft_variable_repository.py b/api/core/workflow/repositories/draft_variable_repository.py
index 97bfcd5666..66ef714c16 100644
--- a/api/core/workflow/repositories/draft_variable_repository.py
+++ b/api/core/workflow/repositories/draft_variable_repository.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import abc
from collections.abc import Mapping
from typing import Any, Protocol
@@ -23,7 +25,7 @@ class DraftVariableSaverFactory(Protocol):
node_type: NodeType,
node_execution_id: str,
enclosing_node_id: str | None = None,
- ) -> "DraftVariableSaver":
+ ) -> DraftVariableSaver:
pass
diff --git a/api/core/workflow/runtime/graph_runtime_state.py b/api/core/workflow/runtime/graph_runtime_state.py
index 1561b789df..acf0ee6839 100644
--- a/api/core/workflow/runtime/graph_runtime_state.py
+++ b/api/core/workflow/runtime/graph_runtime_state.py
@@ -2,15 +2,17 @@ from __future__ import annotations
import importlib
import json
+import threading
from collections.abc import Mapping, Sequence
from copy import deepcopy
from dataclasses import dataclass
-from typing import Any, Protocol
+from typing import Any, ClassVar, Protocol
from pydantic.json import pydantic_encoder
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.pause_reason import PauseReason
+from core.workflow.enums import NodeExecutionType, NodeState, NodeType
from core.workflow.runtime.variable_pool import VariablePool
@@ -102,14 +104,33 @@ class ResponseStreamCoordinatorProtocol(Protocol):
...
+class NodeProtocol(Protocol):
+ """Structural interface for graph nodes."""
+
+ id: str
+ state: NodeState
+ execution_type: NodeExecutionType
+ node_type: ClassVar[NodeType]
+
+ def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: ...
+
+
+class EdgeProtocol(Protocol):
+ id: str
+ state: NodeState
+ tail: str
+ head: str
+ source_handle: str
+
+
class GraphProtocol(Protocol):
"""Structural interface required from graph instances attached to the runtime state."""
- nodes: Mapping[str, object]
- edges: Mapping[str, object]
- root_node: object
+ nodes: Mapping[str, NodeProtocol]
+ edges: Mapping[str, EdgeProtocol]
+ root_node: NodeProtocol
- def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ...
+ def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ...
@dataclass(slots=True)
@@ -168,6 +189,7 @@ class GraphRuntimeState:
self._pending_response_coordinator_dump: str | None = None
self._pending_graph_execution_workflow_id: str | None = None
self._paused_nodes: set[str] = set()
+ self.stop_event: threading.Event = threading.Event()
if graph is not None:
self.attach_graph(graph)
diff --git a/api/core/workflow/runtime/graph_runtime_state_protocol.py b/api/core/workflow/runtime/graph_runtime_state_protocol.py
index 5e0878e873..bfbb5ba704 100644
--- a/api/core/workflow/runtime/graph_runtime_state_protocol.py
+++ b/api/core/workflow/runtime/graph_runtime_state_protocol.py
@@ -1,4 +1,4 @@
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
from typing import Any, Protocol
from core.model_runtime.entities.llm_entities import LLMUsage
@@ -9,7 +9,7 @@ from core.workflow.system_variable import SystemVariableReadOnlyView
class ReadOnlyVariablePool(Protocol):
"""Read-only interface for VariablePool."""
- def get(self, node_id: str, variable_key: str) -> Segment | None:
+ def get(self, selector: Sequence[str], /) -> Segment | None:
"""Get a variable value (read-only)."""
...
diff --git a/api/core/workflow/runtime/read_only_wrappers.py b/api/core/workflow/runtime/read_only_wrappers.py
index 8539727fd6..d3e4c60d9b 100644
--- a/api/core/workflow/runtime/read_only_wrappers.py
+++ b/api/core/workflow/runtime/read_only_wrappers.py
@@ -1,6 +1,6 @@
from __future__ import annotations
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
from copy import deepcopy
from typing import Any
@@ -18,9 +18,9 @@ class ReadOnlyVariablePoolWrapper:
def __init__(self, variable_pool: VariablePool) -> None:
self._variable_pool = variable_pool
- def get(self, node_id: str, variable_key: str) -> Segment | None:
+ def get(self, selector: Sequence[str], /) -> Segment | None:
"""Return a copy of a variable value if present."""
- value = self._variable_pool.get([node_id, variable_key])
+ value = self._variable_pool.get(selector)
return deepcopy(value) if value is not None else None
def get_all_by_node(self, node_id: str) -> Mapping[str, object]:
diff --git a/api/core/workflow/runtime/variable_pool.py b/api/core/workflow/runtime/variable_pool.py
index 7fbaec9e70..c4b077fa69 100644
--- a/api/core/workflow/runtime/variable_pool.py
+++ b/api/core/workflow/runtime/variable_pool.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import re
from collections import defaultdict
from collections.abc import Mapping, Sequence
@@ -7,10 +9,10 @@ from typing import Annotated, Any, Union, cast
from pydantic import BaseModel, Field
from core.file import File, FileAttribute, file_manager
-from core.variables import Segment, SegmentGroup, Variable
+from core.variables import Segment, SegmentGroup, VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import FileSegment, ObjectSegment
-from core.variables.variables import RAGPipelineVariableInput, VariableUnion
+from core.variables.variables import RAGPipelineVariableInput, Variable
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
@@ -30,7 +32,7 @@ class VariablePool(BaseModel):
# The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one.
- variable_dictionary: defaultdict[str, Annotated[dict[str, VariableUnion], Field(default_factory=dict)]] = Field(
+ variable_dictionary: defaultdict[str, Annotated[dict[str, Variable], Field(default_factory=dict)]] = Field(
description="Variables mapping",
default=defaultdict(dict),
)
@@ -42,15 +44,15 @@ class VariablePool(BaseModel):
)
system_variables: SystemVariable = Field(
description="System variables",
- default_factory=SystemVariable.empty,
+ default_factory=SystemVariable.default,
)
- environment_variables: Sequence[VariableUnion] = Field(
+ environment_variables: Sequence[Variable] = Field(
description="Environment variables.",
- default_factory=list[VariableUnion],
+ default_factory=list[Variable],
)
- conversation_variables: Sequence[VariableUnion] = Field(
+ conversation_variables: Sequence[Variable] = Field(
description="Conversation variables.",
- default_factory=list[VariableUnion],
+ default_factory=list[Variable],
)
rag_pipeline_variables: list[RAGPipelineVariableInput] = Field(
description="RAG pipeline variables.",
@@ -103,7 +105,7 @@ class VariablePool(BaseModel):
f"got {len(selector)} elements"
)
- if isinstance(value, Variable):
+ if isinstance(value, VariableBase):
variable = value
elif isinstance(value, Segment):
variable = variable_factory.segment_to_variable(segment=value, selector=selector)
@@ -112,9 +114,9 @@ class VariablePool(BaseModel):
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
node_id, name = self._selector_to_keys(selector)
- # Based on the definition of `VariableUnion`,
- # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
- self.variable_dictionary[node_id][name] = cast(VariableUnion, variable)
+ # Based on the definition of `Variable`,
+ # `VariableBase` instances can be safely used as `Variable` since they are compatible.
+ self.variable_dictionary[node_id][name] = cast(Variable, variable)
@classmethod
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]:
@@ -267,6 +269,6 @@ class VariablePool(BaseModel):
self.add(selector, value)
@classmethod
- def empty(cls) -> "VariablePool":
+ def empty(cls) -> VariablePool:
"""Create an empty variable pool."""
- return cls(system_variables=SystemVariable.empty())
+ return cls(system_variables=SystemVariable.default())
diff --git a/api/core/workflow/system_variable.py b/api/core/workflow/system_variable.py
index ad925912a4..6946e3e6ab 100644
--- a/api/core/workflow/system_variable.py
+++ b/api/core/workflow/system_variable.py
@@ -1,6 +1,9 @@
+from __future__ import annotations
+
from collections.abc import Mapping, Sequence
from types import MappingProxyType
from typing import Any
+from uuid import uuid4
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator
@@ -70,8 +73,8 @@ class SystemVariable(BaseModel):
return data
@classmethod
- def empty(cls) -> "SystemVariable":
- return cls()
+ def default(cls) -> SystemVariable:
+ return cls(workflow_execution_id=str(uuid4()))
def to_dict(self) -> dict[SystemVariableKey, Any]:
# NOTE: This method is provided for compatibility with legacy code.
@@ -114,7 +117,7 @@ class SystemVariable(BaseModel):
d[SystemVariableKey.TIMESTAMP] = self.timestamp
return d
- def as_view(self) -> "SystemVariableReadOnlyView":
+ def as_view(self) -> SystemVariableReadOnlyView:
return SystemVariableReadOnlyView(self)
diff --git a/api/core/workflow/variable_loader.py b/api/core/workflow/variable_loader.py
index ea0bdc3537..7992785fe1 100644
--- a/api/core/workflow/variable_loader.py
+++ b/api/core/workflow/variable_loader.py
@@ -2,7 +2,7 @@ import abc
from collections.abc import Mapping, Sequence
from typing import Any, Protocol
-from core.variables import Variable
+from core.variables import VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.runtime import VariablePool
@@ -26,7 +26,7 @@ class VariableLoader(Protocol):
"""
@abc.abstractmethod
- def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
+ def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
"""Load variables based on the provided selectors. If the selectors are empty,
this method should return an empty list.
@@ -36,7 +36,7 @@ class VariableLoader(Protocol):
:param: selectors: a list of string list, each inner list should have at least two elements:
- the first element is the node ID,
- the second element is the variable name.
- :return: a list of Variable objects that match the provided selectors.
+ :return: a list of VariableBase objects that match the provided selectors.
"""
pass
@@ -46,7 +46,7 @@ class _DummyVariableLoader(VariableLoader):
Serves as a placeholder when no variable loading is needed.
"""
- def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
+ def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
return []
diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py
index d4ec29518a..4b1845cda2 100644
--- a/api/core/workflow/workflow_entry.py
+++ b/api/core/workflow/workflow_entry.py
@@ -7,12 +7,14 @@ from typing import Any
from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.workflow.layers.observability import ObservabilityLayer
+from core.app.workflow.node_factory import DifyNodeFactory
from core.file.models import File
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph import Graph
-from core.workflow.graph_engine import GraphEngine
+from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
@@ -23,6 +25,7 @@ from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
+from extensions.otel.runtime import is_instrument_flag_enabled
from factories import file_factory
from models.enums import UserFrom
from models.workflow import Workflow
@@ -78,6 +81,12 @@ class WorkflowEntry:
graph=graph,
graph_runtime_state=graph_runtime_state,
command_channel=command_channel,
+ config=GraphEngineConfig(
+ min_workers=dify_config.GRAPH_ENGINE_MIN_WORKERS,
+ max_workers=dify_config.GRAPH_ENGINE_MAX_WORKERS,
+ scale_up_threshold=dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD,
+ scale_down_idle_time=dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME,
+ ),
)
# Add debug logging layer when in debug mode
@@ -98,6 +107,10 @@ class WorkflowEntry:
)
self.graph_engine.layer(limits_layer)
+ # Add observability layer when OTel is enabled
+ if dify_config.ENABLE_OTEL or is_instrument_flag_enabled():
+ self.graph_engine.layer(ObservabilityLayer())
+
def run(self) -> Generator[GraphEngineEvent, None, None]:
graph_engine = self.graph_engine
@@ -132,12 +145,10 @@ class WorkflowEntry:
:return:
"""
node_config = workflow.get_node_config_by_id(node_id)
- node_config_data = node_config.get("data", {})
+ node_config_data = node_config["data"]
- # Get node class
- node_type = NodeType(node_config_data.get("type"))
- node_version = node_config_data.get("version", "1")
- node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
+ # Get node type
+ node_type = NodeType(node_config_data["type"])
# init graph init params and runtime state
graph_init_params = GraphInitParams(
@@ -153,12 +164,12 @@ class WorkflowEntry:
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init workflow run state
- node = node_cls(
- id=str(uuid.uuid4()),
- config=node_config,
+ node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
+ node = node_factory.create_node(node_config)
+ node_cls = type(node)
try:
# variable selector to variable mapping
@@ -185,8 +196,7 @@ class WorkflowEntry:
)
try:
- # run node
- generator = node.run()
+ generator = cls._traced_node_run(node)
except Exception as e:
logger.exception(
"error while running node, workflow_id=%s, node_id=%s, node_type=%s, node_version=%s",
@@ -273,7 +283,7 @@ class WorkflowEntry:
# init variable pool
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs={},
environment_variables=[],
)
@@ -319,8 +329,7 @@ class WorkflowEntry:
tenant_id=tenant_id,
)
- # run node
- generator = node.run()
+ generator = cls._traced_node_run(node)
return node, generator
except Exception as e:
@@ -426,3 +435,26 @@ class WorkflowEntry:
input_value = current_variable.value | input_value
variable_pool.add([variable_node_id] + variable_key_list, input_value)
+
+ @staticmethod
+ def _traced_node_run(node: Node) -> Generator[GraphNodeEventBase, None, None]:
+ """
+ Wraps a node's run method with OpenTelemetry tracing and returns a generator.
+ """
+ # Wrap node.run() with ObservabilityLayer hooks to produce node-level spans
+ layer = ObservabilityLayer()
+ layer.on_graph_start()
+ node.ensure_execution_id()
+
+ def _gen():
+ error: Exception | None = None
+ layer.on_node_run_start(node)
+ try:
+ yield from node.run()
+ except Exception as exc:
+ error = exc
+ raise
+ finally:
+ layer.on_node_run_end(node, error)
+
+ return _gen()
diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh
index 6313085e64..c0279f893b 100755
--- a/api/docker/entrypoint.sh
+++ b/api/docker/entrypoint.sh
@@ -3,8 +3,9 @@
set -e
# Set UTF-8 encoding to address potential encoding issues in containerized environments
-export LANG=${LANG:-en_US.UTF-8}
-export LC_ALL=${LC_ALL:-en_US.UTF-8}
+# Use C.UTF-8 which is universally available in all containers
+export LANG=${LANG:-C.UTF-8}
+export LC_ALL=${LC_ALL:-C.UTF-8}
export PYTHONIOENCODING=${PYTHONIOENCODING:-utf-8}
if [[ "${MIGRATION_ENABLED}" == "true" ]]; then
@@ -34,10 +35,10 @@ if [[ "${MODE}" == "worker" ]]; then
if [[ -z "${CELERY_QUEUES}" ]]; then
if [[ "${EDITION}" == "CLOUD" ]]; then
# Cloud edition: separate queues for dataset and trigger tasks
- DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
+ DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
else
# Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues
- DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
+ DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
fi
else
DEFAULT_QUEUES="${CELERY_QUEUES}"
@@ -69,6 +70,53 @@ if [[ "${MODE}" == "worker" ]]; then
elif [[ "${MODE}" == "beat" ]]; then
exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}
+
+elif [[ "${MODE}" == "job" ]]; then
+ # Job mode: Run a one-time Flask command and exit
+ # Pass Flask command and arguments via container args
+ # Example K8s usage:
+ # args:
+ # - create-tenant
+ # - --email
+ # - admin@example.com
+ #
+ # Example Docker usage:
+ # docker run -e MODE=job dify-api:latest create-tenant --email admin@example.com
+
+ if [[ $# -eq 0 ]]; then
+ echo "Error: No command specified for job mode."
+ echo ""
+ echo "Usage examples:"
+ echo " Kubernetes:"
+ echo " args: [create-tenant, --email, admin@example.com]"
+ echo ""
+ echo " Docker:"
+ echo " docker run -e MODE=job dify-api create-tenant --email admin@example.com"
+ echo ""
+ echo "Available commands:"
+ echo " create-tenant, reset-password, reset-email, upgrade-db,"
+ echo " vdb-migrate, install-plugins, and more..."
+ echo ""
+ echo "Run 'flask --help' to see all available commands."
+ exit 1
+ fi
+
+ echo "Running Flask job command: flask $*"
+
+ # Temporarily disable exit on error to capture exit code
+ set +e
+ flask "$@"
+ JOB_EXIT_CODE=$?
+ set -e
+
+ if [[ ${JOB_EXIT_CODE} -eq 0 ]]; then
+ echo "Job completed successfully."
+ else
+ echo "Job failed with exit code ${JOB_EXIT_CODE}."
+ fi
+
+ exit ${JOB_EXIT_CODE}
+
else
if [[ "${DEBUG}" == "true" ]]; then
exec flask run --host=${DIFY_BIND_ADDRESS:-0.0.0.0} --port=${DIFY_PORT:-5001} --debug
diff --git a/api/enums/hosted_provider.py b/api/enums/hosted_provider.py
new file mode 100644
index 0000000000..c6d3715dc1
--- /dev/null
+++ b/api/enums/hosted_provider.py
@@ -0,0 +1,21 @@
+from enum import StrEnum
+
+
+class HostedTrialProvider(StrEnum):
+ """
+ Enum representing hosted model provider names for trial access.
+ """
+
+ OPENAI = "langgenius/openai/openai"
+ ANTHROPIC = "langgenius/anthropic/anthropic"
+ GEMINI = "langgenius/gemini/google"
+ X = "langgenius/x/x"
+ DEEPSEEK = "langgenius/deepseek/deepseek"
+ TONGYI = "langgenius/tongyi/tongyi"
+
+ @property
+ def config_key(self) -> str:
+ """Return the config key used in dify_config (e.g., HOSTED_{config_key}_PAID_ENABLED)."""
+ if self == HostedTrialProvider.X:
+ return "XAI"
+ return self.name
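A short usage sketch of the enum above: each member's value is the plugin-qualified provider id, while `config_key` feeds the `HOSTED_{config_key}_PAID_ENABLED`-style lookups described in the docstring. The `getattr` line is illustrative of how a flag name could be resolved, not a confirmed call site:

```python
from enums.hosted_provider import HostedTrialProvider

provider = HostedTrialProvider.X
assert provider.value == "langgenius/x/x"
assert provider.config_key == "XAI"

# Hypothetical flag lookup built from the config key:
flag_name = f"HOSTED_{provider.config_key}_PAID_ENABLED"  # -> "HOSTED_XAI_PAID_ENABLED"
# enabled = getattr(dify_config, flag_name, False)
```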
diff --git a/api/events/event_handlers/__init__.py b/api/events/event_handlers/__init__.py
index c79764983b..d37217e168 100644
--- a/api/events/event_handlers/__init__.py
+++ b/api/events/event_handlers/__init__.py
@@ -6,6 +6,7 @@ from .create_site_record_when_app_created import handle as handle_create_site_re
from .delete_tool_parameters_cache_when_sync_draft_workflow import (
handle as handle_delete_tool_parameters_cache_when_sync_draft_workflow,
)
+from .queue_credential_sync_when_tenant_created import handle as handle_queue_credential_sync_when_tenant_created
from .sync_plugin_trigger_when_app_created import handle as handle_sync_plugin_trigger_when_app_created
from .sync_webhook_when_app_created import handle as handle_sync_webhook_when_app_created
from .sync_workflow_schedule_when_app_published import handle as handle_sync_workflow_schedule_when_app_published
@@ -30,6 +31,7 @@ __all__ = [
"handle_create_installed_app_when_app_created",
"handle_create_site_record_when_app_created",
"handle_delete_tool_parameters_cache_when_sync_draft_workflow",
+ "handle_queue_credential_sync_when_tenant_created",
"handle_sync_plugin_trigger_when_app_created",
"handle_sync_webhook_when_app_created",
"handle_sync_workflow_schedule_when_app_published",
diff --git a/api/events/event_handlers/clean_when_dataset_deleted.py b/api/events/event_handlers/clean_when_dataset_deleted.py
index 1666e2e29f..d6007662d8 100644
--- a/api/events/event_handlers/clean_when_dataset_deleted.py
+++ b/api/events/event_handlers/clean_when_dataset_deleted.py
@@ -15,4 +15,5 @@ def handle(sender: Dataset, **kwargs):
dataset.index_struct,
dataset.collection_binding_id,
dataset.doc_form,
+ dataset.pipeline_id,
)
diff --git a/api/events/event_handlers/queue_credential_sync_when_tenant_created.py b/api/events/event_handlers/queue_credential_sync_when_tenant_created.py
new file mode 100644
index 0000000000..6566c214b0
--- /dev/null
+++ b/api/events/event_handlers/queue_credential_sync_when_tenant_created.py
@@ -0,0 +1,19 @@
+from configs import dify_config
+from events.tenant_event import tenant_was_created
+from services.enterprise.workspace_sync import WorkspaceSyncService
+
+
+@tenant_was_created.connect
+def handle(sender, **kwargs):
+ """Queue credential sync when a tenant/workspace is created."""
+    # Only queue sync tasks if the enterprise edition (which provides the plugin manager) is enabled
+ if not dify_config.ENTERPRISE_ENABLED:
+ return
+
+ tenant = sender
+
+ # Determine source from kwargs if available, otherwise use generic
+ source = kwargs.get("source", "tenant_created")
+
+ # Queue credential sync task to Redis for enterprise backend to process
+ WorkspaceSyncService.queue_credential_sync(tenant.id, source=source)
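For context, the handler receives the tenant as the blinker `sender` and reads an optional `source` keyword. A hypothetical emitter would look like the following; the `source` value and the stand-in tenant object are illustrative, not actual Dify call sites:

```python
from types import SimpleNamespace

from events.tenant_event import tenant_was_created

tenant = SimpleNamespace(id="tenant-123")  # stand-in for a real Tenant row
tenant_was_created.send(tenant, source="console_setup")
# handle() above runs; when ENTERPRISE_ENABLED is set it queues
# WorkspaceSyncService.queue_credential_sync("tenant-123", source="console_setup")
```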
diff --git a/api/events/event_handlers/update_provider_when_message_created.py b/api/events/event_handlers/update_provider_when_message_created.py
index 84266ab0fa..1ddcc8f792 100644
--- a/api/events/event_handlers/update_provider_when_message_created.py
+++ b/api/events/event_handlers/update_provider_when_message_created.py
@@ -10,7 +10,7 @@ from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
-from core.entities.provider_entities import QuotaUnit, SystemConfiguration
+from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, SystemConfiguration
from events.message_event import message_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client, redis_fallback
@@ -134,22 +134,38 @@ def handle(sender: Message, **kwargs):
system_configuration=system_configuration,
model_name=model_config.model,
)
-
if used_quota is not None:
- quota_update = _ProviderUpdateOperation(
- filters=_ProviderUpdateFilters(
+ if provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
+ from services.credit_pool_service import CreditPoolService
+
+ CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
- provider_name=ModelProviderID(model_config.provider).provider_name,
- provider_type=ProviderType.SYSTEM,
- quota_type=provider_configuration.system_configuration.current_quota_type.value,
- ),
- values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
- additional_filters=_ProviderUpdateAdditionalFilters(
- quota_limit_check=True # Provider.quota_limit > Provider.quota_used
- ),
- description="quota_deduction_update",
- )
- updates_to_perform.append(quota_update)
+ credits_required=used_quota,
+ pool_type="trial",
+ )
+ elif provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.PAID:
+ from services.credit_pool_service import CreditPoolService
+
+ CreditPoolService.check_and_deduct_credits(
+ tenant_id=tenant_id,
+ credits_required=used_quota,
+ pool_type="paid",
+ )
+ else:
+ quota_update = _ProviderUpdateOperation(
+ filters=_ProviderUpdateFilters(
+ tenant_id=tenant_id,
+ provider_name=ModelProviderID(model_config.provider).provider_name,
+ provider_type=ProviderType.SYSTEM.value,
+ quota_type=provider_configuration.system_configuration.current_quota_type.value,
+ ),
+ values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
+ additional_filters=_ProviderUpdateAdditionalFilters(
+ quota_limit_check=True # Provider.quota_limit > Provider.quota_used
+ ),
+ description="quota_deduction_update",
+ )
+ updates_to_perform.append(quota_update)
# Execute all updates
start_time = time_module.perf_counter()
diff --git a/api/extensions/ext_blueprints.py b/api/extensions/ext_blueprints.py
index 725e5351e6..7d13f0c061 100644
--- a/api/extensions/ext_blueprints.py
+++ b/api/extensions/ext_blueprints.py
@@ -6,14 +6,25 @@ BASE_CORS_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE, HEAD
SERVICE_API_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, "Authorization")
AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN)
FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)
+EMBED_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE)
EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id")
+def _apply_cors_once(bp, /, **cors_kwargs):
+ """Make CORS idempotent so blueprints can be reused across multiple app instances."""
+
+ if getattr(bp, "_dify_cors_applied", False):
+ return
+
+ from flask_cors import CORS
+
+ CORS(bp, **cors_kwargs)
+ bp._dify_cors_applied = True
+
+
def init_app(app: DifyApp):
# register blueprint routers
- from flask_cors import CORS
-
from controllers.console import bp as console_app_bp
from controllers.files import bp as files_bp
from controllers.inner_api import bp as inner_api_bp
@@ -22,7 +33,7 @@ def init_app(app: DifyApp):
from controllers.trigger import bp as trigger_bp
from controllers.web import bp as web_bp
- CORS(
+ _apply_cors_once(
service_api_bp,
allow_headers=list(SERVICE_API_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
@@ -30,17 +41,35 @@ def init_app(app: DifyApp):
)
app.register_blueprint(service_api_bp)
- CORS(
+ _apply_cors_once(
web_bp,
- resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
- supports_credentials=True,
- allow_headers=list(AUTHENTICATED_HEADERS),
- methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
+ resources={
+ # Embedded bot endpoints (unauthenticated, cross-origin safe)
+ r"^/chat-messages$": {
+ "origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS,
+ "supports_credentials": False,
+ "allow_headers": list(EMBED_HEADERS),
+ "methods": ["GET", "POST", "OPTIONS"],
+ },
+ r"^/chat-messages/.*": {
+ "origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS,
+ "supports_credentials": False,
+ "allow_headers": list(EMBED_HEADERS),
+ "methods": ["GET", "POST", "OPTIONS"],
+ },
+ # Default web application endpoints (authenticated)
+ r"/*": {
+ "origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS,
+ "supports_credentials": True,
+ "allow_headers": list(AUTHENTICATED_HEADERS),
+ "methods": ["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
+ },
+ },
expose_headers=list(EXPOSED_HEADERS),
)
app.register_blueprint(web_bp)
- CORS(
+ _apply_cors_once(
console_app_bp,
resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
supports_credentials=True,
@@ -50,7 +79,7 @@ def init_app(app: DifyApp):
)
app.register_blueprint(console_app_bp)
- CORS(
+ _apply_cors_once(
files_bp,
allow_headers=list(FILES_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
@@ -62,7 +91,7 @@ def init_app(app: DifyApp):
app.register_blueprint(mcp_bp)
# Register trigger blueprint with CORS for webhook calls
- CORS(
+ _apply_cors_once(
trigger_bp,
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH", "HEAD"],
diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py
index 5cf4984709..af983f6d87 100644
--- a/api/extensions/ext_celery.py
+++ b/api/extensions/ext_celery.py
@@ -12,9 +12,8 @@ from dify_app import DifyApp
def _get_celery_ssl_options() -> dict[str, Any] | None:
"""Get SSL configuration for Celery broker/backend connections."""
- # Use REDIS_USE_SSL for consistency with the main Redis client
# Only apply SSL if we're using Redis as broker/backend
- if not dify_config.REDIS_USE_SSL:
+ if not dify_config.BROKER_USE_SSL:
return None
# Check if Celery is actually using Redis
@@ -47,7 +46,11 @@ def _get_celery_ssl_options() -> dict[str, Any] | None:
def init_app(app: DifyApp) -> Celery:
class FlaskTask(Task):
def __call__(self, *args: object, **kwargs: object) -> object:
+ from core.logging.context import init_request_context
+
with app.app_context():
+ # Initialize logging context for this task (similar to before_request in Flask)
+ init_request_context()
return self.run(*args, **kwargs)
broker_transport_options = {}
@@ -99,6 +102,8 @@ def init_app(app: DifyApp) -> Celery:
imports = [
"tasks.async_workflow_tasks", # trigger workers
"tasks.trigger_processing_tasks", # async trigger processing
+ "tasks.generate_summary_index_task", # summary index generation
+ "tasks.regenerate_summary_index_task", # summary index regeneration
]
day = dify_config.CELERY_BEAT_SCHEDULER_TIME
@@ -160,6 +165,13 @@ def init_app(app: DifyApp) -> Celery:
"task": "schedule.clean_workflow_runlogs_precise.clean_workflow_runlogs_precise",
"schedule": crontab(minute="0", hour="2"),
}
+ if dify_config.ENABLE_WORKFLOW_RUN_CLEANUP_TASK:
+ # for saas only
+ imports.append("schedule.clean_workflow_runs_task")
+ beat_schedule["clean_workflow_runs_task"] = {
+ "task": "schedule.clean_workflow_runs_task.clean_workflow_runs_task",
+ "schedule": crontab(minute="0", hour="0"),
+ }
if dify_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK:
imports.append("schedule.workflow_schedule_task")
beat_schedule["workflow_schedule_task"] = {
diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py
index 71a63168a5..46885761a1 100644
--- a/api/extensions/ext_commands.py
+++ b/api/extensions/ext_commands.py
@@ -4,13 +4,18 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
from commands import (
add_qdrant_index,
+ archive_workflow_runs,
+ clean_expired_messages,
+ clean_workflow_runs,
cleanup_orphaned_draft_variables,
clear_free_plan_tenant_expired_logs,
clear_orphaned_file_records,
convert_to_agent_apps,
create_tenant,
+ delete_archived_workflow_runs,
extract_plugins,
extract_unique_plugins,
+ file_usage,
fix_app_site_missing,
install_plugins,
install_rag_pipeline_plugins,
@@ -21,6 +26,7 @@ def init_app(app: DifyApp):
reset_email,
reset_encrypt_key_pair,
reset_password,
+ restore_workflow_runs,
setup_datasource_oauth_client,
setup_system_tool_oauth_client,
setup_system_trigger_oauth_client,
@@ -47,6 +53,7 @@ def init_app(app: DifyApp):
clear_free_plan_tenant_expired_logs,
clear_orphaned_file_records,
remove_orphaned_files_on_storage,
+ file_usage,
setup_system_tool_oauth_client,
setup_system_trigger_oauth_client,
cleanup_orphaned_draft_variables,
@@ -54,6 +61,11 @@ def init_app(app: DifyApp):
setup_datasource_oauth_client,
transform_datasource_credentials,
install_rag_pipeline_plugins,
+ archive_workflow_runs,
+ delete_archived_workflow_runs,
+ restore_workflow_runs,
+ clean_workflow_runs,
+ clean_expired_messages,
]
for cmd in cmds_to_register:
app.cli.add_command(cmd)
diff --git a/api/extensions/ext_database.py b/api/extensions/ext_database.py
index c90b1d0a9f..2e0d4c889a 100644
--- a/api/extensions/ext_database.py
+++ b/api/extensions/ext_database.py
@@ -53,3 +53,10 @@ def _setup_gevent_compatibility():
def init_app(app: DifyApp):
db.init_app(app)
_setup_gevent_compatibility()
+
+ # Eagerly build the engine so pool_size/max_overflow/etc. come from config
+ try:
+ with app.app_context():
+ _ = db.engine # triggers engine creation with the configured options
+ except Exception:
+ logger.exception("Failed to initialize SQLAlchemy engine during app startup")
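The eager `db.engine` access above matters because Flask-SQLAlchemy creates engines lazily: without it, pool options only take effect when the first query runs. A rough sketch of the behaviour, using generic Flask-SQLAlchemy config keys rather than Dify's exact settings:

```python
from flask import Flask
from flask_sqlalchemy import SQLAlchemy

app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
app.config["SQLALCHEMY_ENGINE_OPTIONS"] = {"pool_recycle": 3600}

db = SQLAlchemy(app)

with app.app_context():
    engine = db.engine  # first access builds the engine from SQLALCHEMY_ENGINE_OPTIONS
    assert db.engine is engine  # later accesses reuse the same configured engine
```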
diff --git a/api/extensions/ext_fastopenapi.py b/api/extensions/ext_fastopenapi.py
new file mode 100644
index 0000000000..ab4d23a072
--- /dev/null
+++ b/api/extensions/ext_fastopenapi.py
@@ -0,0 +1,48 @@
+from fastopenapi.routers import FlaskRouter
+from flask_cors import CORS
+
+from configs import dify_config
+from controllers.fastopenapi import console_router
+from dify_app import DifyApp
+from extensions.ext_blueprints import AUTHENTICATED_HEADERS, EXPOSED_HEADERS
+
+DOCS_PREFIX = "/fastopenapi"
+
+
+def init_app(app: DifyApp) -> None:
+ docs_enabled = dify_config.SWAGGER_UI_ENABLED
+ docs_url = f"{DOCS_PREFIX}/docs" if docs_enabled else None
+ redoc_url = f"{DOCS_PREFIX}/redoc" if docs_enabled else None
+ openapi_url = f"{DOCS_PREFIX}/openapi.json" if docs_enabled else None
+
+ router = FlaskRouter(
+ app=app,
+ docs_url=docs_url,
+ redoc_url=redoc_url,
+ openapi_url=openapi_url,
+ openapi_version="3.0.0",
+ title="Dify API (FastOpenAPI PoC)",
+ version="1.0",
+ description="FastOpenAPI proof of concept for Dify API",
+ )
+
+ # Ensure route decorators are evaluated.
+ import controllers.console.init_validate as init_validate_module
+ import controllers.console.ping as ping_module
+ from controllers.console import remote_files, setup
+
+ _ = init_validate_module
+ _ = ping_module
+ _ = remote_files
+ _ = setup
+
+ router.include_router(console_router, prefix="/console/api")
+ CORS(
+ app,
+ resources={r"/console/api/.*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
+ supports_credentials=True,
+ allow_headers=list(AUTHENTICATED_HEADERS),
+ methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
+ expose_headers=list(EXPOSED_HEADERS),
+ )
+ app.extensions["fastopenapi"] = router
diff --git a/api/extensions/ext_logging.py b/api/extensions/ext_logging.py
index 000d03ac41..978a40c503 100644
--- a/api/extensions/ext_logging.py
+++ b/api/extensions/ext_logging.py
@@ -1,18 +1,19 @@
+"""Logging extension for Dify Flask application."""
+
import logging
import os
import sys
-import uuid
from logging.handlers import RotatingFileHandler
-import flask
-
from configs import dify_config
-from core.helper.trace_id_helper import get_trace_id_from_otel_context
from dify_app import DifyApp
def init_app(app: DifyApp):
+ """Initialize logging with support for text or JSON format."""
log_handlers: list[logging.Handler] = []
+
+ # File handler
log_file = dify_config.LOG_FILE
if log_file:
log_dir = os.path.dirname(log_file)
@@ -25,27 +26,53 @@ def init_app(app: DifyApp):
)
)
- # Always add StreamHandler to log to console
+ # Console handler
sh = logging.StreamHandler(sys.stdout)
log_handlers.append(sh)
- # Apply RequestIdFilter to all handlers
- for handler in log_handlers:
- handler.addFilter(RequestIdFilter())
+ # Apply filters to all handlers
+ from core.logging.filters import IdentityContextFilter, TraceContextFilter
+ for handler in log_handlers:
+ handler.addFilter(TraceContextFilter())
+ handler.addFilter(IdentityContextFilter())
+
+ # Configure formatter based on format type
+ formatter = _create_formatter()
+ for handler in log_handlers:
+ handler.setFormatter(formatter)
+
+ # Configure root logger
logging.basicConfig(
level=dify_config.LOG_LEVEL,
- format=dify_config.LOG_FORMAT,
- datefmt=dify_config.LOG_DATEFORMAT,
handlers=log_handlers,
force=True,
)
- # Apply RequestIdFormatter to all handlers
- apply_request_id_formatter()
-
# Disable propagation for noisy loggers to avoid duplicate logs
logging.getLogger("sqlalchemy.engine").propagate = False
+
+ # Apply timezone if specified (only for text format)
+ if dify_config.LOG_OUTPUT_FORMAT == "text":
+ _apply_timezone(log_handlers)
+
+
+def _create_formatter() -> logging.Formatter:
+ """Create appropriate formatter based on configuration."""
+ if dify_config.LOG_OUTPUT_FORMAT == "json":
+ from core.logging.structured_formatter import StructuredJSONFormatter
+
+ return StructuredJSONFormatter()
+ else:
+ # Text format - use existing pattern with backward compatible formatter
+ return _TextFormatter(
+ fmt=dify_config.LOG_FORMAT,
+ datefmt=dify_config.LOG_DATEFORMAT,
+ )
+
+
+def _apply_timezone(handlers: list[logging.Handler]):
+ """Apply timezone conversion to text formatters."""
log_tz = dify_config.LOG_TZ
if log_tz:
from datetime import datetime
@@ -57,34 +84,51 @@ def init_app(app: DifyApp):
def time_converter(seconds):
return datetime.fromtimestamp(seconds, tz=timezone).timetuple()
- for handler in logging.root.handlers:
+ for handler in handlers:
if handler.formatter:
- handler.formatter.converter = time_converter
+ handler.formatter.converter = time_converter # type: ignore[attr-defined]
-def get_request_id():
- if getattr(flask.g, "request_id", None):
- return flask.g.request_id
+class _TextFormatter(logging.Formatter):
+ """Text formatter that ensures trace_id and req_id are always present."""
- new_uuid = uuid.uuid4().hex[:10]
- flask.g.request_id = new_uuid
-
- return new_uuid
+ def format(self, record: logging.LogRecord) -> str:
+ if not hasattr(record, "req_id"):
+ record.req_id = ""
+ if not hasattr(record, "trace_id"):
+ record.trace_id = ""
+ if not hasattr(record, "span_id"):
+ record.span_id = ""
+ return super().format(record)
+def get_request_id() -> str:
+ """Get request ID for current request context.
+
+ Deprecated: Use core.logging.context.get_request_id() directly.
+ """
+ from core.logging.context import get_request_id as _get_request_id
+
+ return _get_request_id()
+
+
+# Backward compatibility aliases
class RequestIdFilter(logging.Filter):
- # This is a logging filter that makes the request ID available for use in
- # the logging format. Note that we're checking if we're in a request
- # context, as we may want to log things before Flask is fully loaded.
- def filter(self, record):
- trace_id = get_trace_id_from_otel_context() or ""
- record.req_id = get_request_id() if flask.has_request_context() else ""
- record.trace_id = trace_id
+ """Deprecated: Use TraceContextFilter from core.logging.filters instead."""
+
+ def filter(self, record: logging.LogRecord) -> bool:
+ from core.logging.context import get_request_id as _get_request_id
+ from core.logging.context import get_trace_id as _get_trace_id
+
+ record.req_id = _get_request_id()
+ record.trace_id = _get_trace_id()
return True
class RequestIdFormatter(logging.Formatter):
- def format(self, record):
+ """Deprecated: Use _TextFormatter instead."""
+
+ def format(self, record: logging.LogRecord) -> str:
if not hasattr(record, "req_id"):
record.req_id = ""
if not hasattr(record, "trace_id"):
@@ -93,6 +137,7 @@ class RequestIdFormatter(logging.Formatter):
def apply_request_id_formatter():
+ """Deprecated: Formatter is now applied in init_app."""
for handler in logging.root.handlers:
if handler.formatter:
handler.formatter = RequestIdFormatter(dify_config.LOG_FORMAT, dify_config.LOG_DATEFORMAT)
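Because the handlers attach context filters and `_TextFormatter` backfills `req_id`/`trace_id` when they are absent, a text `LOG_FORMAT` can reference those fields safely. A minimal standalone sketch; the filter is a stand-in for `TraceContextFilter` and the format string is illustrative, not Dify's default:

```python
import logging
import sys


class ContextFilter(logging.Filter):
    """Stand-in for TraceContextFilter: injects ids so the formatter never fails."""

    def filter(self, record: logging.LogRecord) -> bool:
        record.req_id = getattr(record, "req_id", "") or "req-abc123"
        record.trace_id = getattr(record, "trace_id", "") or "trace-xyz789"
        return True


handler = logging.StreamHandler(sys.stdout)
handler.addFilter(ContextFilter())
handler.setFormatter(logging.Formatter("%(asctime)s [%(req_id)s] [%(trace_id)s] %(levelname)s %(message)s"))
logging.basicConfig(level=logging.INFO, handlers=[handler], force=True)

logging.getLogger(__name__).info("hello")  # ... [req-abc123] [trace-xyz789] INFO hello
```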
diff --git a/api/extensions/ext_logstore.py b/api/extensions/ext_logstore.py
new file mode 100644
index 0000000000..cda2d1ad1e
--- /dev/null
+++ b/api/extensions/ext_logstore.py
@@ -0,0 +1,94 @@
+"""
+Logstore extension for Dify application.
+
+This extension initializes the logstore (Aliyun SLS) on application startup,
+creating necessary projects, logstores, and indexes if they don't exist.
+"""
+
+import logging
+import os
+
+from dotenv import load_dotenv
+
+from configs import dify_config
+from dify_app import DifyApp
+
+logger = logging.getLogger(__name__)
+
+
+def is_enabled() -> bool:
+ """
+ Check if logstore extension is enabled.
+
+ Logstore is considered enabled when:
+ 1. All required Aliyun SLS environment variables are set
+ 2. At least one repository configuration points to a logstore implementation
+
+ Returns:
+ True if logstore should be initialized, False otherwise
+ """
+ # Load environment variables from .env file
+ load_dotenv()
+
+ # Check if Aliyun SLS connection parameters are configured
+ required_vars = [
+ "ALIYUN_SLS_ACCESS_KEY_ID",
+ "ALIYUN_SLS_ACCESS_KEY_SECRET",
+ "ALIYUN_SLS_ENDPOINT",
+ "ALIYUN_SLS_REGION",
+ "ALIYUN_SLS_PROJECT_NAME",
+ ]
+
+ sls_vars_set = all(os.environ.get(var) for var in required_vars)
+
+ if not sls_vars_set:
+ return False
+
+ # Check if any repository configuration points to logstore implementation
+ repository_configs = [
+ dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY,
+ dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY,
+ dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY,
+ dify_config.API_WORKFLOW_RUN_REPOSITORY,
+ ]
+
+ uses_logstore = any("logstore" in config.lower() for config in repository_configs)
+
+ if not uses_logstore:
+ return False
+
+ logger.info("Logstore extension enabled: SLS variables set and repository configured to use logstore")
+ return True
+
+
+def init_app(app: DifyApp):
+ """
+ Initialize logstore on application startup.
+ If initialization fails, the application continues running without logstore features.
+
+ Args:
+ app: The Dify application instance
+ """
+ try:
+ from extensions.logstore.aliyun_logstore import AliyunLogStore
+
+ logger.info("Initializing Aliyun SLS Logstore...")
+
+ # Create logstore client and initialize resources
+ logstore_client = AliyunLogStore()
+ logstore_client.init_project_logstore()
+
+ app.extensions["logstore"] = logstore_client
+
+ logger.info("Logstore initialized successfully")
+
+ except Exception:
+ logger.exception(
+ "Logstore initialization failed. Configuration: endpoint=%s, region=%s, project=%s, timeout=%ss. "
+ "Application will continue but logstore features will NOT work.",
+ os.environ.get("ALIYUN_SLS_ENDPOINT"),
+ os.environ.get("ALIYUN_SLS_REGION"),
+ os.environ.get("ALIYUN_SLS_PROJECT_NAME"),
+ os.environ.get("ALIYUN_SLS_CHECK_CONNECTIVITY_TIMEOUT", "30"),
+ )
+ # Don't raise - allow application to continue even if logstore setup fails
diff --git a/api/extensions/ext_session_factory.py b/api/extensions/ext_session_factory.py
new file mode 100644
index 0000000000..0eb43d66f4
--- /dev/null
+++ b/api/extensions/ext_session_factory.py
@@ -0,0 +1,7 @@
+from core.db.session_factory import configure_session_factory
+from extensions.ext_database import db
+
+
+def init_app(app):
+ with app.app_context():
+ configure_session_factory(db.engine)
diff --git a/api/extensions/logstore/__init__.py b/api/extensions/logstore/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/extensions/logstore/aliyun_logstore.py b/api/extensions/logstore/aliyun_logstore.py
new file mode 100644
index 0000000000..f6a4765f14
--- /dev/null
+++ b/api/extensions/logstore/aliyun_logstore.py
@@ -0,0 +1,928 @@
+from __future__ import annotations
+
+import logging
+import os
+import socket
+import threading
+import time
+from collections.abc import Sequence
+from typing import Any
+
+import sqlalchemy as sa
+from aliyun.log import ( # type: ignore[import-untyped]
+ GetLogsRequest,
+ IndexConfig,
+ IndexKeyConfig,
+ IndexLineConfig,
+ LogClient,
+ LogItem,
+ PutLogsRequest,
+)
+from aliyun.log.auth import AUTH_VERSION_4 # type: ignore[import-untyped]
+from aliyun.log.logexception import LogException # type: ignore[import-untyped]
+from dotenv import load_dotenv
+from sqlalchemy.orm import DeclarativeBase
+
+from configs import dify_config
+from extensions.logstore.aliyun_logstore_pg import AliyunLogStorePG
+
+logger = logging.getLogger(__name__)
+
+
+class AliyunLogStore:
+ """
+ Singleton class for Aliyun SLS LogStore operations.
+
+ Ensures only one instance exists to prevent multiple PG connection pools.
+ """
+
+ _instance: AliyunLogStore | None = None
+ _initialized: bool = False
+
+ # Track delayed PG connection for newly created projects
+ _pg_connection_timer: threading.Timer | None = None
+    _pg_connection_delay: int = 90  # delay in seconds before attempting the PG connection for a newly created project
+
+ # Default tokenizer for text/json fields and full-text index
+ # Common delimiters: comma, space, quotes, punctuation, operators, brackets, special chars
+ DEFAULT_TOKEN_LIST = [
+ ",",
+ " ",
+ '"',
+        "'",
+ ";",
+ "=",
+ "(",
+ ")",
+ "[",
+ "]",
+ "{",
+ "}",
+ "?",
+ "@",
+ "&",
+ "<",
+ ">",
+ "/",
+ ":",
+ "\n",
+ "\t",
+ ]
+
+ def __new__(cls) -> AliyunLogStore:
+ """Implement singleton pattern."""
+ if cls._instance is None:
+ cls._instance = super().__new__(cls)
+ return cls._instance
+
+ project_des = "dify"
+
+ workflow_execution_logstore = "workflow_execution"
+
+ workflow_node_execution_logstore = "workflow_node_execution"
+
+ @staticmethod
+ def _sqlalchemy_type_to_logstore_type(column: Any) -> str:
+ """
+ Map SQLAlchemy column type to Aliyun LogStore index type.
+
+ Args:
+ column: SQLAlchemy column object
+
+ Returns:
+ LogStore index type: 'text', 'long', 'double', or 'json'
+ """
+ column_type = column.type
+
+ # Integer types -> long
+ if isinstance(column_type, (sa.Integer, sa.BigInteger, sa.SmallInteger)):
+ return "long"
+
+ # Float types -> double
+ if isinstance(column_type, (sa.Float, sa.Numeric)):
+ return "double"
+
+ # String and Text types -> text
+ if isinstance(column_type, (sa.String, sa.Text)):
+ return "text"
+
+ # DateTime -> text (stored as ISO format string in logstore)
+ if isinstance(column_type, sa.DateTime):
+ return "text"
+
+ # Boolean -> long (stored as 0/1)
+ if isinstance(column_type, sa.Boolean):
+ return "long"
+
+ # JSON -> json
+ if isinstance(column_type, sa.JSON):
+ return "json"
+
+ # Default to text for unknown types
+ return "text"
+
+ @staticmethod
+ def _generate_index_keys_from_model(model_class: type[DeclarativeBase]) -> dict[str, IndexKeyConfig]:
+ """
+ Automatically generate LogStore field index configuration from SQLAlchemy model.
+
+ This method introspects the SQLAlchemy model's column definitions and creates
+ corresponding LogStore index configurations. When the PG schema is updated via
+ Flask-Migrate, this method will automatically pick up the new fields on next startup.
+
+ Args:
+ model_class: SQLAlchemy model class (e.g., WorkflowRun, WorkflowNodeExecutionModel)
+
+ Returns:
+ Dictionary mapping field names to IndexKeyConfig objects
+ """
+ index_keys = {}
+
+ # Iterate over all mapped columns in the model
+ if hasattr(model_class, "__mapper__"):
+ for column_name, column_property in model_class.__mapper__.columns.items():
+ # Skip relationship properties and other non-column attributes
+ if not hasattr(column_property, "type"):
+ continue
+
+ # Map SQLAlchemy type to LogStore type
+ logstore_type = AliyunLogStore._sqlalchemy_type_to_logstore_type(column_property)
+
+ # Create index configuration
+ # - text fields: case_insensitive for better search, with tokenizer and Chinese support
+ # - all fields: doc_value=True for analytics
+ if logstore_type == "text":
+ index_keys[column_name] = IndexKeyConfig(
+ index_type="text",
+ case_sensitive=False,
+ doc_value=True,
+ token_list=AliyunLogStore.DEFAULT_TOKEN_LIST,
+ chinese=True,
+ )
+ else:
+ index_keys[column_name] = IndexKeyConfig(index_type=logstore_type, doc_value=True)
+
+ # Add log_version field (not in PG model, but used in logstore for versioning)
+ index_keys["log_version"] = IndexKeyConfig(index_type="long", doc_value=True)
+
+ return index_keys
+
+ def __init__(self) -> None:
+ # Skip initialization if already initialized (singleton pattern)
+ if self.__class__._initialized:
+ return
+
+ load_dotenv()
+
+ self.access_key_id: str = os.environ.get("ALIYUN_SLS_ACCESS_KEY_ID", "")
+ self.access_key_secret: str = os.environ.get("ALIYUN_SLS_ACCESS_KEY_SECRET", "")
+ self.endpoint: str = os.environ.get("ALIYUN_SLS_ENDPOINT", "")
+ self.region: str = os.environ.get("ALIYUN_SLS_REGION", "")
+ self.project_name: str = os.environ.get("ALIYUN_SLS_PROJECT_NAME", "")
+ self.logstore_ttl: int = int(os.environ.get("ALIYUN_SLS_LOGSTORE_TTL", 365))
+ self.log_enabled: bool = (
+ os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true"
+ or os.environ.get("LOGSTORE_SQL_ECHO", "false").lower() == "true"
+ )
+ self.pg_mode_enabled: bool = os.environ.get("LOGSTORE_PG_MODE_ENABLED", "true").lower() == "true"
+
+ # Get timeout configuration
+ check_timeout = int(os.environ.get("ALIYUN_SLS_CHECK_CONNECTIVITY_TIMEOUT", 30))
+
+ # Pre-check endpoint connectivity to prevent indefinite hangs
+ self._check_endpoint_connectivity(self.endpoint, check_timeout)
+
+ # Initialize SDK client
+ self.client = LogClient(
+ self.endpoint, self.access_key_id, self.access_key_secret, auth_version=AUTH_VERSION_4, region=self.region
+ )
+
+ # Append Dify identification to the existing user agent
+ original_user_agent = self.client._user_agent # pyright: ignore[reportPrivateUsage]
+ dify_version = dify_config.project.version
+ enhanced_user_agent = f"Dify,Dify-{dify_version},{original_user_agent}"
+ self.client.set_user_agent(enhanced_user_agent)
+
+ # PG client will be initialized in init_project_logstore
+ self._pg_client: AliyunLogStorePG | None = None
+ self._use_pg_protocol: bool = False
+
+ self.__class__._initialized = True
+
+ @staticmethod
+ def _check_endpoint_connectivity(endpoint: str, timeout: int) -> None:
+ """
+ Check if the SLS endpoint is reachable before creating LogClient.
+ Prevents indefinite hangs when the endpoint is unreachable.
+
+ Args:
+ endpoint: SLS endpoint URL
+ timeout: Connection timeout in seconds
+
+ Raises:
+ ConnectionError: If endpoint is not reachable
+ """
+ # Parse endpoint URL to extract hostname and port
+ from urllib.parse import urlparse
+
+ parsed_url = urlparse(endpoint if "://" in endpoint else f"http://{endpoint}")
+ hostname = parsed_url.hostname
+ port = parsed_url.port or (443 if parsed_url.scheme == "https" else 80)
+
+ if not hostname:
+ raise ConnectionError(f"Invalid endpoint URL: {endpoint}")
+
+ sock = None
+ try:
+ # Create socket and set timeout
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.settimeout(timeout)
+ sock.connect((hostname, port))
+ except Exception as e:
+ # Catch all exceptions and provide clear error message
+ error_type = type(e).__name__
+ raise ConnectionError(
+ f"Cannot connect to {hostname}:{port} (timeout={timeout}s): [{error_type}] {e}"
+ ) from e
+ finally:
+ # Ensure socket is properly closed
+ if sock:
+ try:
+ sock.close()
+ except Exception: # noqa: S110
+ pass # Ignore errors during cleanup
+
+ @property
+ def supports_pg_protocol(self) -> bool:
+ """Check if PG protocol is supported and enabled."""
+ return self._use_pg_protocol
+
+ def _attempt_pg_connection_init(self) -> bool:
+ """
+ Attempt to initialize PG connection.
+
+ This method tries to establish PG connection and performs necessary checks.
+ It's used both for immediate connection (existing projects) and delayed connection (new projects).
+
+ Returns:
+ True if PG connection was successfully established, False otherwise.
+ """
+ if not self.pg_mode_enabled or not self._pg_client:
+ return False
+
+ try:
+ self._use_pg_protocol = self._pg_client.init_connection()
+ if self._use_pg_protocol:
+ logger.info("Using PG protocol for project %s", self.project_name)
+ # Check if scan_index is enabled for all logstores
+ self._check_and_disable_pg_if_scan_index_disabled()
+ return True
+ else:
+ logger.info("Using SDK mode for project %s", self.project_name)
+ return False
+ except Exception as e:
+ logger.info("Using SDK mode for project %s", self.project_name)
+ logger.debug("PG connection details: %s", str(e))
+ self._use_pg_protocol = False
+ return False
+
+ def _delayed_pg_connection_init(self) -> None:
+ """
+ Delayed initialization of PG connection for newly created projects.
+
+        This method is called by a background timer _pg_connection_delay seconds (currently 90) after project creation.
+ """
+ # Double check conditions in case state changed
+ if self._use_pg_protocol:
+ return
+
+ self._attempt_pg_connection_init()
+ self.__class__._pg_connection_timer = None
+
+ def init_project_logstore(self):
+ """
+ Initialize project, logstore, index, and PG connection.
+
+ This method should be called once during application startup to ensure
+ all required resources exist and connections are established.
+ """
+ # Step 1: Ensure project and logstore exist
+ project_is_new = False
+ if not self.is_project_exist():
+ self.create_project()
+ project_is_new = True
+
+ self.create_logstore_if_not_exist()
+
+ # Step 2: Initialize PG client and connection (if enabled)
+ if not self.pg_mode_enabled:
+ logger.info("PG mode is disabled. Will use SDK mode.")
+ return
+
+ # Create PG client if not already created
+ if self._pg_client is None:
+ logger.info("Initializing PG client for project %s...", self.project_name)
+ self._pg_client = AliyunLogStorePG(
+ self.access_key_id, self.access_key_secret, self.endpoint, self.project_name
+ )
+
+ # Step 3: Establish PG connection based on project status
+ if project_is_new:
+ # For newly created projects, schedule delayed PG connection
+ self._use_pg_protocol = False
+ logger.info("Using SDK mode for project %s (newly created)", self.project_name)
+ if self.__class__._pg_connection_timer is not None:
+ self.__class__._pg_connection_timer.cancel()
+ self.__class__._pg_connection_timer = threading.Timer(
+ self.__class__._pg_connection_delay,
+ self._delayed_pg_connection_init,
+ )
+ self.__class__._pg_connection_timer.daemon = True # Don't block app shutdown
+ self.__class__._pg_connection_timer.start()
+ else:
+ # For existing projects, attempt PG connection immediately
+ self._attempt_pg_connection_init()
+
+ def _check_and_disable_pg_if_scan_index_disabled(self) -> None:
+ """
+ Check if scan_index is enabled for all logstores.
+ If any logstore has scan_index=false, disable PG protocol.
+
+ This is necessary because PG protocol requires scan_index to be enabled.
+ """
+ logstore_name_list = [
+ AliyunLogStore.workflow_execution_logstore,
+ AliyunLogStore.workflow_node_execution_logstore,
+ ]
+
+ for logstore_name in logstore_name_list:
+ existing_config = self.get_existing_index_config(logstore_name)
+ if existing_config and not existing_config.scan_index:
+ logger.info(
+ "Logstore %s requires scan_index enabled, using SDK mode for project %s",
+ logstore_name,
+ self.project_name,
+ )
+ self._use_pg_protocol = False
+ # Close PG connection if it was initialized
+ if self._pg_client:
+ self._pg_client.close()
+ self._pg_client = None
+ return
+
+ def is_project_exist(self) -> bool:
+ try:
+ self.client.get_project(self.project_name)
+ return True
+ except Exception as e:
+ if e.args[0] == "ProjectNotExist":
+ return False
+ else:
+ raise e
+
+ def create_project(self):
+ try:
+ self.client.create_project(self.project_name, AliyunLogStore.project_des)
+ logger.info("Project %s created successfully", self.project_name)
+ except LogException as e:
+ logger.exception(
+ "Failed to create project %s: errorCode=%s, errorMessage=%s, requestId=%s",
+ self.project_name,
+ e.get_error_code(),
+ e.get_error_message(),
+ e.get_request_id(),
+ )
+ raise
+
+ def is_logstore_exist(self, logstore_name: str) -> bool:
+ try:
+ _ = self.client.get_logstore(self.project_name, logstore_name)
+ return True
+ except Exception as e:
+ if e.args[0] == "LogStoreNotExist":
+ return False
+ else:
+ raise e
+
+ def create_logstore_if_not_exist(self) -> None:
+ logstore_name_list = [
+ AliyunLogStore.workflow_execution_logstore,
+ AliyunLogStore.workflow_node_execution_logstore,
+ ]
+
+ for logstore_name in logstore_name_list:
+ if not self.is_logstore_exist(logstore_name):
+ try:
+ self.client.create_logstore(
+ project_name=self.project_name, logstore_name=logstore_name, ttl=self.logstore_ttl
+ )
+ logger.info("logstore %s created successfully", logstore_name)
+ except LogException as e:
+ logger.exception(
+ "Failed to create logstore %s: errorCode=%s, errorMessage=%s, requestId=%s",
+ logstore_name,
+ e.get_error_code(),
+ e.get_error_message(),
+ e.get_request_id(),
+ )
+ raise
+
+ # Ensure index contains all Dify-required fields
+ # This intelligently merges with existing config, preserving custom indexes
+ self.ensure_index_config(logstore_name)
+
+ def is_index_exist(self, logstore_name: str) -> bool:
+ try:
+ _ = self.client.get_index_config(self.project_name, logstore_name)
+ return True
+ except Exception as e:
+ if e.args[0] == "IndexConfigNotExist":
+ return False
+ else:
+ raise e
+
+ def get_existing_index_config(self, logstore_name: str) -> IndexConfig | None:
+ """
+ Get existing index configuration from logstore.
+
+ Args:
+ logstore_name: Name of the logstore
+
+ Returns:
+ IndexConfig object if index exists, None otherwise
+ """
+ try:
+ response = self.client.get_index_config(self.project_name, logstore_name)
+ return response.get_index_config()
+ except Exception as e:
+ if e.args[0] == "IndexConfigNotExist":
+ return None
+ else:
+ logger.exception("Failed to get index config for logstore %s", logstore_name)
+ raise e
+
+ def _get_workflow_execution_index_keys(self) -> dict[str, IndexKeyConfig]:
+ """
+ Get field index configuration for workflow_execution logstore.
+
+ This method automatically generates index configuration from the WorkflowRun SQLAlchemy model.
+ When the PG schema is updated via Flask-Migrate, the index configuration will be automatically
+ updated on next application startup.
+ """
+ from models.workflow import WorkflowRun
+
+ index_keys = self._generate_index_keys_from_model(WorkflowRun)
+
+ # Add custom fields that are in logstore but not in PG model
+ # These fields are added by the repository layer
+ index_keys["error_message"] = IndexKeyConfig(
+ index_type="text",
+ case_sensitive=False,
+ doc_value=True,
+ token_list=self.DEFAULT_TOKEN_LIST,
+ chinese=True,
+ ) # Maps to 'error' in PG
+ index_keys["started_at"] = IndexKeyConfig(
+ index_type="text",
+ case_sensitive=False,
+ doc_value=True,
+ token_list=self.DEFAULT_TOKEN_LIST,
+ chinese=True,
+ ) # Maps to 'created_at' in PG
+
+ logger.info("Generated %d index keys for workflow_execution from WorkflowRun model", len(index_keys))
+ return index_keys
+
+ def _get_workflow_node_execution_index_keys(self) -> dict[str, IndexKeyConfig]:
+ """
+ Get field index configuration for workflow_node_execution logstore.
+
+ This method automatically generates index configuration from the WorkflowNodeExecutionModel.
+ When the PG schema is updated via Flask-Migrate, the index configuration will be automatically
+ updated on next application startup.
+ """
+ from models.workflow import WorkflowNodeExecutionModel
+
+ index_keys = self._generate_index_keys_from_model(WorkflowNodeExecutionModel)
+
+ logger.debug(
+ "Generated %d index keys for workflow_node_execution from WorkflowNodeExecutionModel", len(index_keys)
+ )
+ return index_keys
+
+ def _get_index_config(self, logstore_name: str) -> IndexConfig:
+ """
+ Get index configuration for the specified logstore.
+
+ Args:
+ logstore_name: Name of the logstore
+
+ Returns:
+ IndexConfig object with line and field indexes
+ """
+ # Create full-text index (line config) with tokenizer
+ line_config = IndexLineConfig(token_list=self.DEFAULT_TOKEN_LIST, case_sensitive=False, chinese=True)
+
+ # Get field index configuration based on logstore name
+ field_keys = {}
+ if logstore_name == AliyunLogStore.workflow_execution_logstore:
+ field_keys = self._get_workflow_execution_index_keys()
+ elif logstore_name == AliyunLogStore.workflow_node_execution_logstore:
+ field_keys = self._get_workflow_node_execution_index_keys()
+
+ # key_config_list should be a dict, not a list
+ # Create index config with both line and field indexes
+ return IndexConfig(line_config=line_config, key_config_list=field_keys, scan_index=True)
+
+ def create_index(self, logstore_name: str) -> None:
+ """
+ Create index for the specified logstore with both full-text and field indexes.
+ Field indexes are automatically generated from the corresponding SQLAlchemy model.
+ """
+ index_config = self._get_index_config(logstore_name)
+
+ try:
+ self.client.create_index(self.project_name, logstore_name, index_config)
+ logger.info(
+ "index for %s created successfully with %d field indexes",
+ logstore_name,
+ len(index_config.key_config_list or {}),
+ )
+ except LogException as e:
+ logger.exception(
+ "Failed to create index for logstore %s: errorCode=%s, errorMessage=%s, requestId=%s",
+ logstore_name,
+ e.get_error_code(),
+ e.get_error_message(),
+ e.get_request_id(),
+ )
+ raise
+
+ def _merge_index_configs(
+ self, existing_config: IndexConfig, required_keys: dict[str, IndexKeyConfig], logstore_name: str
+ ) -> tuple[IndexConfig, bool]:
+ """
+ Intelligently merge existing index config with Dify's required field indexes.
+
+ This method:
+ 1. Preserves all existing field indexes in logstore (including custom fields)
+ 2. Adds missing Dify-required fields
+ 3. Updates fields where type doesn't match (with json/text compatibility)
+ 4. Corrects case mismatches (e.g., if Dify needs 'status' but logstore has 'Status')
+
+ Type compatibility rules:
+ - json and text types are considered compatible (users can manually choose either)
+ - All other type mismatches will be corrected to match Dify requirements
+
+ Note: Logstore is case-sensitive and doesn't allow duplicate fields with different cases.
+ Case mismatch means: existing field name differs from required name only in case.
+
+ Args:
+ existing_config: Current index configuration from logstore
+ required_keys: Dify's required field index configurations
+ logstore_name: Name of the logstore (for logging)
+
+ Returns:
+ Tuple of (merged_config, needs_update)
+ """
+ # key_config_list is already a dict in the SDK
+ # Make a copy to avoid modifying the original
+ existing_keys = dict(existing_config.key_config_list) if existing_config.key_config_list else {}
+
+ # Track changes
+ needs_update = False
+ case_corrections = [] # Fields that need case correction (e.g., 'Status' -> 'status')
+ missing_fields = []
+ type_mismatches = []
+
+ # First pass: Check for and resolve case mismatches with required fields
+ # Note: Logstore itself doesn't allow duplicate fields with different cases,
+ # so we only need to check if the existing case matches the required case
+ for required_name in required_keys:
+ lower_name = required_name.lower()
+ # Find key that matches case-insensitively but not exactly
+ wrong_case_key = None
+ for existing_key in existing_keys:
+ if existing_key.lower() == lower_name and existing_key != required_name:
+ wrong_case_key = existing_key
+ break
+
+ if wrong_case_key:
+ # Field exists but with wrong case (e.g., 'Status' when we need 'status')
+ # Remove the wrong-case key, will be added back with correct case later
+ case_corrections.append((wrong_case_key, required_name))
+ del existing_keys[wrong_case_key]
+ needs_update = True
+
+ # Second pass: Check each required field
+ for required_name, required_config in required_keys.items():
+ # Check for exact match (case-sensitive)
+ if required_name in existing_keys:
+ existing_type = existing_keys[required_name].index_type
+ required_type = required_config.index_type
+
+ # Check if type matches
+ # Special case: json and text are interchangeable for JSON content fields
+ # Allow users to manually configure text instead of json (or vice versa) without forcing updates
+ is_compatible = existing_type == required_type or ({existing_type, required_type} == {"json", "text"})
+
+ if not is_compatible:
+ type_mismatches.append((required_name, existing_type, required_type))
+ # Update with correct type
+ existing_keys[required_name] = required_config
+ needs_update = True
+ # else: field exists with compatible type, no action needed
+ else:
+ # Field doesn't exist (may have been removed in first pass due to case conflict)
+ missing_fields.append(required_name)
+ existing_keys[required_name] = required_config
+ needs_update = True
+
+ # Log changes
+ if missing_fields:
+ logger.info(
+ "Logstore %s: Adding %d missing Dify-required fields: %s",
+ logstore_name,
+ len(missing_fields),
+ ", ".join(missing_fields[:10]) + ("..." if len(missing_fields) > 10 else ""),
+ )
+
+ if type_mismatches:
+ logger.info(
+ "Logstore %s: Fixing %d type mismatches: %s",
+ logstore_name,
+ len(type_mismatches),
+ ", ".join([f"{name}({old}->{new})" for name, old, new in type_mismatches[:5]])
+ + ("..." if len(type_mismatches) > 5 else ""),
+ )
+
+ if case_corrections:
+ logger.info(
+ "Logstore %s: Correcting %d field name cases: %s",
+ logstore_name,
+ len(case_corrections),
+ ", ".join([f"'{old}' -> '{new}'" for old, new in case_corrections[:5]])
+ + ("..." if len(case_corrections) > 5 else ""),
+ )
+
+ # Create merged config
+ # key_config_list should be a dict, not a list
+ # Preserve the original scan_index value - don't force it to True
+ merged_config = IndexConfig(
+ line_config=existing_config.line_config
+ or IndexLineConfig(token_list=self.DEFAULT_TOKEN_LIST, case_sensitive=False, chinese=True),
+ key_config_list=existing_keys,
+ scan_index=existing_config.scan_index,
+ )
+
+ return merged_config, needs_update
+
+ def ensure_index_config(self, logstore_name: str) -> None:
+ """
+ Ensure index configuration includes all Dify-required fields.
+
+ This method intelligently manages index configuration:
+ 1. If index doesn't exist, create it with Dify's required fields
+ 2. If index exists:
+ - Check if all Dify-required fields are present
+ - Check if field types match requirements
+ - Only update if fields are missing or types are incorrect
+ - Preserve any additional custom index configurations
+
+ This approach allows users to add their own custom indexes without being overwritten.
+ """
+ # Get Dify's required field indexes
+ required_keys = {}
+ if logstore_name == AliyunLogStore.workflow_execution_logstore:
+ required_keys = self._get_workflow_execution_index_keys()
+ elif logstore_name == AliyunLogStore.workflow_node_execution_logstore:
+ required_keys = self._get_workflow_node_execution_index_keys()
+
+ # Check if index exists
+ existing_config = self.get_existing_index_config(logstore_name)
+
+ if existing_config is None:
+ # Index doesn't exist, create it
+ logger.info(
+ "Logstore %s: Index doesn't exist, creating with %d required fields",
+ logstore_name,
+ len(required_keys),
+ )
+ self.create_index(logstore_name)
+ else:
+ merged_config, needs_update = self._merge_index_configs(existing_config, required_keys, logstore_name)
+
+ if needs_update:
+ logger.info("Logstore %s: Updating index to include Dify-required fields", logstore_name)
+ try:
+ self.client.update_index(self.project_name, logstore_name, merged_config)
+ logger.info(
+ "Logstore %s: Index updated successfully, now has %d total field indexes",
+ logstore_name,
+ len(merged_config.key_config_list or {}),
+ )
+ except LogException as e:
+ logger.exception(
+ "Failed to update index for logstore %s: errorCode=%s, errorMessage=%s, requestId=%s",
+ logstore_name,
+ e.get_error_code(),
+ e.get_error_message(),
+ e.get_request_id(),
+ )
+ raise
+ else:
+ logger.info(
+ "Logstore %s: Index already contains all %d Dify-required fields with correct types, "
+ "no update needed",
+ logstore_name,
+ len(required_keys),
+ )
+
+ def put_log(self, logstore: str, contents: Sequence[tuple[str, str]]) -> None:
+ # Route to PG or SDK based on protocol availability
+ if self._use_pg_protocol and self._pg_client:
+ self._pg_client.put_log(logstore, contents, self.log_enabled)
+ else:
+ log_item = LogItem(contents=contents)
+ request = PutLogsRequest(project=self.project_name, logstore=logstore, logitems=[log_item])
+
+ if self.log_enabled:
+ logger.info(
+ "[LogStore-SDK] PUT_LOG | logstore=%s | project=%s | items_count=%d",
+ logstore,
+ self.project_name,
+ len(contents),
+ )
+
+ try:
+ self.client.put_logs(request)
+ except LogException as e:
+ logger.exception(
+ "Failed to put logs to logstore %s: errorCode=%s, errorMessage=%s, requestId=%s",
+ logstore,
+ e.get_error_code(),
+ e.get_error_message(),
+ e.get_request_id(),
+ )
+ raise
+
+ def get_logs(
+ self,
+ logstore: str,
+ from_time: int,
+ to_time: int,
+ topic: str = "",
+ query: str = "",
+ line: int = 100,
+ offset: int = 0,
+ reverse: bool = True,
+ ) -> list[dict]:
+ request = GetLogsRequest(
+ project=self.project_name,
+ logstore=logstore,
+ fromTime=from_time,
+ toTime=to_time,
+ topic=topic,
+ query=query,
+ line=line,
+ offset=offset,
+ reverse=reverse,
+ )
+
+ if self.log_enabled:
+ logger.info(
+ "[LogStore] GET_LOGS | logstore=%s | project=%s | query=%s | "
+ "from_time=%d | to_time=%d | line=%d | offset=%d | reverse=%s",
+ logstore,
+ self.project_name,
+ query,
+ from_time,
+ to_time,
+ line,
+ offset,
+ reverse,
+ )
+
+ try:
+ response = self.client.get_logs(request)
+ result = []
+ logs = response.get_logs() if response else []
+ for log in logs:
+ result.append(log.get_contents())
+
+ if self.log_enabled:
+ logger.info(
+ "[LogStore] GET_LOGS RESULT | logstore=%s | returned_count=%d",
+ logstore,
+ len(result),
+ )
+
+ return result
+ except LogException as e:
+ logger.exception(
+ "Failed to get logs from logstore %s with query '%s': errorCode=%s, errorMessage=%s, requestId=%s",
+ logstore,
+ query,
+ e.get_error_code(),
+ e.get_error_message(),
+ e.get_request_id(),
+ )
+ raise
+
+ def execute_sql(
+ self,
+ sql: str,
+ logstore: str | None = None,
+ query: str = "*",
+ from_time: int | None = None,
+ to_time: int | None = None,
+ power_sql: bool = False,
+ ) -> list[dict]:
+ """
+ Execute SQL query for aggregation and analysis.
+
+ Args:
+ sql: SQL query string (SELECT statement)
+ logstore: Name of the logstore (required)
+ query: Search/filter query for SDK mode (default: "*" for all logs).
+ Only used in SDK mode. PG mode ignores this parameter.
+ from_time: Start time (Unix timestamp) - only used in SDK mode
+ to_time: End time (Unix timestamp) - only used in SDK mode
+ power_sql: Whether to use enhanced SQL mode (default: False)
+
+ Returns:
+ List of result rows as dictionaries
+
+ Note:
+            - PG mode: Executes the SQL directly; the query parameter is ignored
+ - SDK mode: Combines query and sql as "query | sql"
+ """
+ # Logstore is required
+ if not logstore:
+ raise ValueError("logstore parameter is required for execute_sql")
+
+ # Route to PG or SDK based on protocol availability
+ if self._use_pg_protocol and self._pg_client:
+ # PG mode: execute SQL directly (ignore query parameter)
+ return self._pg_client.execute_sql(sql, logstore, self.log_enabled)
+ else:
+ # SDK mode: combine query and sql as "query | sql"
+ full_query = f"{query} | {sql}"
+
+ # Provide default time range if not specified
+ if from_time is None:
+ from_time = 0
+
+ if to_time is None:
+ to_time = int(time.time()) # now
+
+ request = GetLogsRequest(
+ project=self.project_name,
+ logstore=logstore,
+ fromTime=from_time,
+ toTime=to_time,
+ query=full_query,
+ )
+
+ if self.log_enabled:
+ logger.info(
+ "[LogStore-SDK] EXECUTE_SQL | logstore=%s | project=%s | from_time=%d | to_time=%d | full_query=%s",
+ logstore,
+ self.project_name,
+ from_time,
+ to_time,
+ full_query,
+ )
+
+ try:
+ response = self.client.get_logs(request)
+
+ result = []
+ logs = response.get_logs() if response else []
+ for log in logs:
+ result.append(log.get_contents())
+
+ if self.log_enabled:
+ logger.info(
+ "[LogStore-SDK] EXECUTE_SQL RESULT | logstore=%s | returned_count=%d",
+ logstore,
+ len(result),
+ )
+
+ return result
+ except LogException as e:
+ logger.exception(
+ "Failed to execute SQL, logstore %s: errorCode=%s, errorMessage=%s, requestId=%s, full_query=%s",
+ logstore,
+ e.get_error_code(),
+ e.get_error_message(),
+ e.get_request_id(),
+ full_query,
+ )
+ raise
+
+
+if __name__ == "__main__":
+ aliyun_logstore = AliyunLogStore()
+ # aliyun_logstore.init_project_logstore()
+ aliyun_logstore.put_log(AliyunLogStore.workflow_execution_logstore, [("key1", "value1")])
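To make the model-driven index generation concrete, here is a standalone sketch of the same column-type mapping applied to a toy SQLAlchemy model; the model and field names are illustrative, not Dify's:

```python
import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class ToyRun(Base):
    __tablename__ = "toy_run"

    id: Mapped[str] = mapped_column(sa.String(36), primary_key=True)
    status: Mapped[str] = mapped_column(sa.String(32))
    elapsed: Mapped[float] = mapped_column(sa.Float)
    steps: Mapped[int] = mapped_column(sa.Integer)
    graph: Mapped[dict] = mapped_column(sa.JSON)


def logstore_type(column) -> str:
    # Same rules as _sqlalchemy_type_to_logstore_type, condensed.
    t = column.type
    if isinstance(t, (sa.Integer, sa.BigInteger, sa.SmallInteger, sa.Boolean)):
        return "long"
    if isinstance(t, (sa.Float, sa.Numeric)):
        return "double"
    if isinstance(t, sa.JSON):
        return "json"
    return "text"  # String, Text, DateTime, and anything else


print({name: logstore_type(col) for name, col in ToyRun.__mapper__.columns.items()})
# {'id': 'text', 'status': 'text', 'elapsed': 'double', 'steps': 'long', 'graph': 'json'}
```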
diff --git a/api/extensions/logstore/aliyun_logstore_pg.py b/api/extensions/logstore/aliyun_logstore_pg.py
new file mode 100644
index 0000000000..874c20d144
--- /dev/null
+++ b/api/extensions/logstore/aliyun_logstore_pg.py
@@ -0,0 +1,272 @@
+import logging
+import os
+import socket
+import time
+from collections.abc import Sequence
+from contextlib import contextmanager
+from typing import Any
+
+import psycopg2
+from sqlalchemy import create_engine
+
+from configs import dify_config
+
+logger = logging.getLogger(__name__)
+
+
+class AliyunLogStorePG:
+ """PostgreSQL protocol support for Aliyun SLS LogStore using SQLAlchemy connection pool."""
+
+ def __init__(self, access_key_id: str, access_key_secret: str, endpoint: str, project_name: str):
+ """
+ Initialize PG connection for SLS.
+
+ Args:
+ access_key_id: Aliyun access key ID
+ access_key_secret: Aliyun access key secret
+ endpoint: SLS endpoint
+ project_name: SLS project name
+ """
+ self._access_key_id = access_key_id
+ self._access_key_secret = access_key_secret
+ self._endpoint = endpoint
+ self.project_name = project_name
+ self._engine: Any = None # SQLAlchemy Engine
+ self._use_pg_protocol = False
+
+ def _check_port_connectivity(self, host: str, port: int, timeout: float = 2.0) -> bool:
+ """Fast TCP port check to avoid long waits on unsupported regions."""
+ try:
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.settimeout(timeout)
+ result = sock.connect_ex((host, port))
+ sock.close()
+ return result == 0
+ except Exception as e:
+ logger.debug("Port connectivity check failed for %s:%d: %s", host, port, str(e))
+ return False
+
+ def init_connection(self) -> bool:
+ """Initialize SQLAlchemy connection pool with pool_recycle and TCP keepalive support."""
+ try:
+ pg_host = self._endpoint.replace("http://", "").replace("https://", "")
+
+ # Pool configuration
+ pool_size = int(os.environ.get("ALIYUN_SLS_PG_POOL_SIZE", 5))
+ max_overflow = int(os.environ.get("ALIYUN_SLS_PG_MAX_OVERFLOW", 5))
+ pool_recycle = int(os.environ.get("ALIYUN_SLS_PG_POOL_RECYCLE", 3600))
+ pool_pre_ping = os.environ.get("ALIYUN_SLS_PG_POOL_PRE_PING", "false").lower() == "true"
+
+ logger.debug("Check PG protocol connection to SLS: host=%s, project=%s", pg_host, self.project_name)
+
+ # Fast port check to avoid long waits
+ if not self._check_port_connectivity(pg_host, 5432, timeout=1.0):
+ logger.debug("Using SDK mode for host=%s", pg_host)
+ return False
+
+ # Build connection URL
+ from urllib.parse import quote_plus
+
+ username = quote_plus(self._access_key_id)
+ password = quote_plus(self._access_key_secret)
+ database_url = (
+ f"postgresql+psycopg2://{username}:{password}@{pg_host}:5432/{self.project_name}?sslmode=require"
+ )
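+            # The SLS project name is used as the database in the URL; sslmode=require enforces TLS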
+
+ # Create SQLAlchemy engine with connection pool
+ self._engine = create_engine(
+ database_url,
+ pool_size=pool_size,
+ max_overflow=max_overflow,
+ pool_recycle=pool_recycle,
+ pool_pre_ping=pool_pre_ping,
+ pool_timeout=30,
+ connect_args={
+ "connect_timeout": 5,
+ "application_name": f"Dify-{dify_config.project.version}-fixautocommit",
+ "keepalives": 1,
+ "keepalives_idle": 60,
+ "keepalives_interval": 10,
+ "keepalives_count": 5,
+ },
+ )
+
+ self._use_pg_protocol = True
+ logger.info(
+ "PG protocol initialized for SLS project=%s (pool_size=%d, pool_recycle=%ds)",
+ self.project_name,
+ pool_size,
+ pool_recycle,
+ )
+ return True
+
+ except Exception as e:
+ self._use_pg_protocol = False
+ if self._engine:
+ try:
+ self._engine.dispose()
+ except Exception:
+ logger.debug("Failed to dispose engine during cleanup, ignoring")
+ self._engine = None
+
+ logger.debug("Using SDK mode for region: %s", str(e))
+ return False
+
+ @contextmanager
+ def _get_connection(self):
+ """Get connection from SQLAlchemy pool. Pool handles recycle, invalidation, and keepalive automatically."""
+ if not self._engine:
+ raise RuntimeError("SQLAlchemy engine is not initialized")
+
+ connection = self._engine.raw_connection()
+ try:
+ connection.autocommit = True # SLS PG protocol does not support transactions
+ yield connection
+ except Exception:
+ raise
+ finally:
+ connection.close()
+
+ def close(self) -> None:
+ """Dispose SQLAlchemy engine and close all connections."""
+ if self._engine:
+ try:
+ self._engine.dispose()
+ logger.info("SQLAlchemy engine disposed")
+ except Exception:
+ logger.exception("Failed to dispose engine")
+
+ def _is_retriable_error(self, error: Exception) -> bool:
+ """Check if error is retriable (connection-related issues)."""
+ # Check for psycopg2 connection errors directly
+ if isinstance(error, (psycopg2.OperationalError, psycopg2.InterfaceError)):
+ return True
+
+ error_msg = str(error).lower()
+ retriable_patterns = [
+ "connection",
+ "timeout",
+ "closed",
+ "broken pipe",
+ "reset by peer",
+ "no route to host",
+ "network",
+ "operational error",
+ "interface error",
+ ]
+ return any(pattern in error_msg for pattern in retriable_patterns)
+
+ def put_log(self, logstore: str, contents: Sequence[tuple[str, str]], log_enabled: bool = False) -> None:
+ """Write log to SLS using INSERT with automatic retry (3 attempts with exponential backoff)."""
+ if not contents:
+ return
+
+ fields = [field_name for field_name, _ in contents]
+ values = [value for _, value in contents]
+ field_list = ", ".join([f'"{field}"' for field in fields])
+
+ if log_enabled:
+ logger.info(
+ "[LogStore-PG] PUT_LOG | logstore=%s | project=%s | items_count=%d",
+ logstore,
+ self.project_name,
+ len(contents),
+ )
+
+ max_retries = 3
+ retry_delay = 0.1
+
+ for attempt in range(max_retries):
+ try:
+ with self._get_connection() as conn:
+ with conn.cursor() as cursor:
+ placeholders = ", ".join(["%s"] * len(fields))
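+                        # mogrify binds and escapes the values, producing a SQL-safe literal for the INSERT below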
+ values_literal = cursor.mogrify(f"({placeholders})", values).decode("utf-8")
+ insert_sql = f'INSERT INTO "{logstore}" ({field_list}) VALUES {values_literal}'
+ cursor.execute(insert_sql)
+ return
+
+ except psycopg2.Error as e:
+ if not self._is_retriable_error(e):
+ logger.exception("Failed to put logs to logstore %s (non-retriable error)", logstore)
+ raise
+
+ if attempt < max_retries - 1:
+ logger.warning(
+ "Failed to put logs to logstore %s (attempt %d/%d): %s. Retrying...",
+ logstore,
+ attempt + 1,
+ max_retries,
+ str(e),
+ )
+ time.sleep(retry_delay)
+ retry_delay *= 2
+ else:
+ logger.exception("Failed to put logs to logstore %s after %d attempts", logstore, max_retries)
+ raise
+
+ def execute_sql(self, sql: str, logstore: str, log_enabled: bool = False) -> list[dict[str, Any]]:
+ """Execute SQL query with automatic retry (3 attempts with exponential backoff)."""
+ if log_enabled:
+ logger.info(
+ "[LogStore-PG] EXECUTE_SQL | logstore=%s | project=%s | sql=%s",
+ logstore,
+ self.project_name,
+ sql,
+ )
+
+ max_retries = 3
+ retry_delay = 0.1
+
+ for attempt in range(max_retries):
+ try:
+ with self._get_connection() as conn:
+ with conn.cursor() as cursor:
+ cursor.execute(sql)
+ columns = [desc[0] for desc in cursor.description]
+
+ result = []
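+                    # Every column value is returned as a string ("" for NULL); callers convert with safe_int/safe_float as needed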
+ for row in cursor.fetchall():
+ row_dict = {}
+ for col, val in zip(columns, row):
+ row_dict[col] = "" if val is None else str(val)
+ result.append(row_dict)
+
+ if log_enabled:
+ logger.info(
+ "[LogStore-PG] EXECUTE_SQL RESULT | logstore=%s | returned_count=%d",
+ logstore,
+ len(result),
+ )
+
+ return result
+
+ except psycopg2.Error as e:
+ if not self._is_retriable_error(e):
+ logger.exception(
+ "Failed to execute SQL on logstore %s (non-retriable error): sql=%s",
+ logstore,
+ sql,
+ )
+ raise
+
+ if attempt < max_retries - 1:
+ logger.warning(
+ "Failed to execute SQL on logstore %s (attempt %d/%d): %s. Retrying...",
+ logstore,
+ attempt + 1,
+ max_retries,
+ str(e),
+ )
+ time.sleep(retry_delay)
+ retry_delay *= 2
+ else:
+ logger.exception(
+ "Failed to execute SQL on logstore %s after %d attempts: sql=%s",
+ logstore,
+ max_retries,
+ sql,
+ )
+ raise
+
+ return []
diff --git a/api/extensions/logstore/repositories/__init__.py b/api/extensions/logstore/repositories/__init__.py
new file mode 100644
index 0000000000..b5a4fcf844
--- /dev/null
+++ b/api/extensions/logstore/repositories/__init__.py
@@ -0,0 +1,29 @@
+"""
+LogStore repository utilities.
+"""
+
+from typing import Any
+
+
+def safe_float(value: Any, default: float = 0.0) -> float:
+ """
+ Safely convert a value to float, handling 'null' strings and None.
+ """
+ if value is None or value in {"null", ""}:
+ return default
+ try:
+ return float(value)
+ except (ValueError, TypeError):
+ return default
+
+
+def safe_int(value: Any, default: int = 0) -> int:
+ """
+ Safely convert a value to int, handling 'null' strings and None.
+ """
+ if value is None or value in {"null", ""}:
+ return default
+ try:
+ return int(float(value))
+ except (ValueError, TypeError):
+ return default
diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py
new file mode 100644
index 0000000000..f67723630b
--- /dev/null
+++ b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py
@@ -0,0 +1,393 @@
+"""
+LogStore implementation of DifyAPIWorkflowNodeExecutionRepository.
+
+This module provides the LogStore-based implementation for service-layer
+WorkflowNodeExecutionModel operations using Aliyun SLS LogStore.
+"""
+
+import logging
+import time
+from collections.abc import Sequence
+from datetime import datetime
+from typing import Any
+
+from sqlalchemy.orm import sessionmaker
+
+from extensions.logstore.aliyun_logstore import AliyunLogStore
+from extensions.logstore.repositories import safe_float, safe_int
+from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value
+from models.workflow import WorkflowNodeExecutionModel
+from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
+
+logger = logging.getLogger(__name__)
+
+
+def _dict_to_workflow_node_execution_model(data: dict[str, Any]) -> WorkflowNodeExecutionModel:
+ """
+ Convert LogStore result dictionary to WorkflowNodeExecutionModel instance.
+
+ Args:
+ data: Dictionary from LogStore query result
+
+ Returns:
+ WorkflowNodeExecutionModel instance (detached from session)
+
+ Note:
+ The returned model is not attached to any SQLAlchemy session.
+ Relationship fields (like offload_data) are not loaded from LogStore.
+ """
+ logger.debug("_dict_to_workflow_node_execution_model: data keys=%s", list(data.keys())[:5])
+ # Create model instance without session
+ model = WorkflowNodeExecutionModel()
+
+ # Map all required fields with validation
+ # Critical fields - must not be None
+ model.id = data.get("id") or ""
+ model.tenant_id = data.get("tenant_id") or ""
+ model.app_id = data.get("app_id") or ""
+ model.workflow_id = data.get("workflow_id") or ""
+ model.triggered_from = data.get("triggered_from") or ""
+ model.node_id = data.get("node_id") or ""
+ model.node_type = data.get("node_type") or ""
+ model.status = data.get("status") or "running" # Default status if missing
+ model.title = data.get("title") or ""
+ model.created_by_role = data.get("created_by_role") or ""
+ model.created_by = data.get("created_by") or ""
+
+ model.index = safe_int(data.get("index", 0))
+ model.elapsed_time = safe_float(data.get("elapsed_time", 0))
+
+ # Optional fields
+ model.workflow_run_id = data.get("workflow_run_id")
+ model.predecessor_node_id = data.get("predecessor_node_id")
+ model.node_execution_id = data.get("node_execution_id")
+ model.inputs = data.get("inputs")
+ model.process_data = data.get("process_data")
+ model.outputs = data.get("outputs")
+ model.error = data.get("error")
+ model.execution_metadata = data.get("execution_metadata")
+
+ # Handle datetime fields
+ created_at = data.get("created_at")
+ if created_at:
+ if isinstance(created_at, str):
+ model.created_at = datetime.fromisoformat(created_at)
+ elif isinstance(created_at, (int, float)):
+ model.created_at = datetime.fromtimestamp(created_at)
+ else:
+ model.created_at = created_at
+ else:
+ # Provide default created_at if missing
+ model.created_at = datetime.now()
+
+ finished_at = data.get("finished_at")
+ if finished_at:
+ if isinstance(finished_at, str):
+ model.finished_at = datetime.fromisoformat(finished_at)
+ elif isinstance(finished_at, (int, float)):
+ model.finished_at = datetime.fromtimestamp(finished_at)
+ else:
+ model.finished_at = finished_at
+
+ return model
+
+
+class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository):
+ """
+ LogStore implementation of DifyAPIWorkflowNodeExecutionRepository.
+
+ Provides service-layer database operations for WorkflowNodeExecutionModel
+ using LogStore SQL queries with optimized deduplication strategies.
+ """
+
+ def __init__(self, session_maker: sessionmaker | None = None):
+ """
+ Initialize the repository with LogStore client.
+
+ Args:
+ session_maker: SQLAlchemy sessionmaker (unused, for compatibility with factory pattern)
+ """
+ logger.debug("LogstoreAPIWorkflowNodeExecutionRepository.__init__: initializing")
+ self.logstore_client = AliyunLogStore()
+
+ def get_node_last_execution(
+ self,
+ tenant_id: str,
+ app_id: str,
+ workflow_id: str,
+ node_id: str,
+ ) -> WorkflowNodeExecutionModel | None:
+ """
+ Get the most recent execution for a specific node.
+
+ Uses query syntax to get raw logs and selects the one with max log_version.
+ Returns the most recent execution ordered by created_at.
+ """
+ logger.debug(
+ "get_node_last_execution: tenant_id=%s, app_id=%s, workflow_id=%s, node_id=%s",
+ tenant_id,
+ app_id,
+ workflow_id,
+ node_id,
+ )
+ try:
+ # Escape parameters to prevent SQL injection
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+ escaped_workflow_id = escape_identifier(workflow_id)
+ escaped_node_id = escape_identifier(node_id)
+
+ # Check if PG protocol is supported
+ if self.logstore_client.supports_pg_protocol:
+ # Use PG protocol with SQL query (get latest version of each record)
+ sql_query = f"""
+ SELECT * FROM (
+ SELECT *,
+ ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
+ FROM "{AliyunLogStore.workflow_node_execution_logstore}"
+ WHERE tenant_id = '{escaped_tenant_id}'
+ AND app_id = '{escaped_app_id}'
+ AND workflow_id = '{escaped_workflow_id}'
+ AND node_id = '{escaped_node_id}'
+ AND __time__ > 0
+ ) AS subquery WHERE rn = 1
+ LIMIT 100
+ """
+ results = self.logstore_client.execute_sql(
+ sql=sql_query,
+ logstore=AliyunLogStore.workflow_node_execution_logstore,
+ )
+ else:
+ # Use SDK with LogStore query syntax
+ query = (
+ f"tenant_id: {escaped_tenant_id} and app_id: {escaped_app_id} "
+ f"and workflow_id: {escaped_workflow_id} and node_id: {escaped_node_id}"
+ )
+ from_time = 0
+ to_time = int(time.time()) # now
+
+ results = self.logstore_client.get_logs(
+ logstore=AliyunLogStore.workflow_node_execution_logstore,
+ from_time=from_time,
+ to_time=to_time,
+ query=query,
+ line=100,
+ reverse=False,
+ )
+
+ if not results:
+ return None
+
+ # For SDK mode, group by id and select the one with max log_version for each group
+ # For PG mode, this is already done by the SQL query
+ if not self.logstore_client.supports_pg_protocol:
+ id_to_results: dict[str, list[dict[str, Any]]] = {}
+ for row in results:
+ row_id = row.get("id")
+ if row_id:
+ if row_id not in id_to_results:
+ id_to_results[row_id] = []
+ id_to_results[row_id].append(row)
+
+ # For each id, select the row with max log_version
+ deduplicated_results = []
+ for rows in id_to_results.values():
+ if len(rows) > 1:
+ max_row = max(rows, key=lambda x: int(x.get("log_version", 0)))
+ else:
+ max_row = rows[0]
+ deduplicated_results.append(max_row)
+ else:
+ # For PG mode, results are already deduplicated by the SQL query
+ deduplicated_results = results
+
+ # Sort by created_at DESC and return the most recent one
+ deduplicated_results.sort(
+                key=lambda x: str(x.get("created_at") or ""),  # ISO 8601 created_at strings sort chronologically as text
+ reverse=True,
+ )
+
+ if deduplicated_results:
+ return _dict_to_workflow_node_execution_model(deduplicated_results[0])
+
+ return None
+
+ except Exception:
+ logger.exception("Failed to get node last execution from LogStore")
+ raise
+
+ def get_executions_by_workflow_run(
+ self,
+ tenant_id: str,
+ app_id: str,
+ workflow_run_id: str,
+ ) -> Sequence[WorkflowNodeExecutionModel]:
+ """
+ Get all node executions for a specific workflow run.
+
+ Uses query syntax to get raw logs and selects the one with max log_version for each node execution.
+ Ordered by index DESC for trace visualization.
+ """
+ logger.debug(
+ "[LogStore] get_executions_by_workflow_run: tenant_id=%s, app_id=%s, workflow_run_id=%s",
+ tenant_id,
+ app_id,
+ workflow_run_id,
+ )
+ try:
+ # Escape parameters to prevent SQL injection
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+ escaped_workflow_run_id = escape_identifier(workflow_run_id)
+
+ # Check if PG protocol is supported
+ if self.logstore_client.supports_pg_protocol:
+ # Use PG protocol with SQL query (get latest version of each record)
+ sql_query = f"""
+ SELECT * FROM (
+ SELECT *,
+ ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
+ FROM "{AliyunLogStore.workflow_node_execution_logstore}"
+ WHERE tenant_id = '{escaped_tenant_id}'
+ AND app_id = '{escaped_app_id}'
+ AND workflow_run_id = '{escaped_workflow_run_id}'
+ AND __time__ > 0
+ ) AS subquery WHERE rn = 1
+ LIMIT 1000
+ """
+ results = self.logstore_client.execute_sql(
+ sql=sql_query,
+ logstore=AliyunLogStore.workflow_node_execution_logstore,
+ )
+ else:
+ # Use SDK with LogStore query syntax
+ query = (
+ f"tenant_id: {escaped_tenant_id} and app_id: {escaped_app_id} "
+ f"and workflow_run_id: {escaped_workflow_run_id}"
+ )
+ from_time = 0
+ to_time = int(time.time()) # now
+
+ results = self.logstore_client.get_logs(
+ logstore=AliyunLogStore.workflow_node_execution_logstore,
+ from_time=from_time,
+ to_time=to_time,
+ query=query,
+ line=1000, # Get more results for node executions
+ reverse=False,
+ )
+
+ if not results:
+ return []
+
+ # For SDK mode, group by id and select the one with max log_version for each group
+ # For PG mode, this is already done by the SQL query
+ models = []
+ if not self.logstore_client.supports_pg_protocol:
+ id_to_results: dict[str, list[dict[str, Any]]] = {}
+ for row in results:
+ row_id = row.get("id")
+ if row_id:
+ if row_id not in id_to_results:
+ id_to_results[row_id] = []
+ id_to_results[row_id].append(row)
+
+ # For each id, select the row with max log_version
+ for rows in id_to_results.values():
+ if len(rows) > 1:
+ max_row = max(rows, key=lambda x: int(x.get("log_version", 0)))
+ else:
+ max_row = rows[0]
+
+ model = _dict_to_workflow_node_execution_model(max_row)
+ if model and model.id: # Ensure model is valid
+ models.append(model)
+ else:
+ # For PG mode, results are already deduplicated by the SQL query
+ for row in results:
+ model = _dict_to_workflow_node_execution_model(row)
+ if model and model.id: # Ensure model is valid
+ models.append(model)
+
+ # Sort by index DESC for trace visualization
+ models.sort(key=lambda x: x.index, reverse=True)
+
+ return models
+
+ except Exception:
+ logger.exception("Failed to get executions by workflow run from LogStore")
+ raise
+
+ def get_execution_by_id(
+ self,
+ execution_id: str,
+ tenant_id: str | None = None,
+ ) -> WorkflowNodeExecutionModel | None:
+ """
+ Get a workflow node execution by its ID.
+ Uses query syntax to get raw logs and selects the one with max log_version.
+ """
+ logger.debug("get_execution_by_id: execution_id=%s, tenant_id=%s", execution_id, tenant_id)
+ try:
+ # Escape parameters to prevent SQL injection
+ escaped_execution_id = escape_identifier(execution_id)
+
+ # Check if PG protocol is supported
+ if self.logstore_client.supports_pg_protocol:
+ # Use PG protocol with SQL query (get latest version of record)
+ if tenant_id:
+ escaped_tenant_id = escape_identifier(tenant_id)
+ tenant_filter = f"AND tenant_id = '{escaped_tenant_id}'"
+ else:
+ tenant_filter = ""
+
+ sql_query = f"""
+ SELECT * FROM (
+ SELECT *,
+ ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
+ FROM "{AliyunLogStore.workflow_node_execution_logstore}"
+ WHERE id = '{escaped_execution_id}' {tenant_filter} AND __time__ > 0
+ ) AS subquery WHERE rn = 1
+ LIMIT 1
+ """
+ results = self.logstore_client.execute_sql(
+ sql=sql_query,
+ logstore=AliyunLogStore.workflow_node_execution_logstore,
+ )
+ else:
+ # Use SDK with LogStore query syntax
+ # Note: Values must be quoted in LogStore query syntax to prevent injection
+ if tenant_id:
+ query = (
+ f"id:{escape_logstore_query_value(execution_id)} "
+ f"and tenant_id:{escape_logstore_query_value(tenant_id)}"
+ )
+ else:
+ query = f"id:{escape_logstore_query_value(execution_id)}"
+
+ from_time = 0
+ to_time = int(time.time()) # now
+
+ results = self.logstore_client.get_logs(
+ logstore=AliyunLogStore.workflow_node_execution_logstore,
+ from_time=from_time,
+ to_time=to_time,
+ query=query,
+ line=100,
+ reverse=False,
+ )
+
+ if not results:
+ return None
+
+ # For PG mode, result is already the latest version
+ # For SDK mode, if multiple results, select the one with max log_version
+ if self.logstore_client.supports_pg_protocol or len(results) == 1:
+ return _dict_to_workflow_node_execution_model(results[0])
+ else:
+ max_result = max(results, key=lambda x: int(x.get("log_version", 0)))
+ return _dict_to_workflow_node_execution_model(max_result)
+
+ except Exception:
+ logger.exception("Failed to get execution by ID from LogStore: execution_id=%s", execution_id)
+ raise
diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py
new file mode 100644
index 0000000000..14382ed876
--- /dev/null
+++ b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py
@@ -0,0 +1,820 @@
+"""
+LogStore API WorkflowRun Repository Implementation
+
+This module provides the LogStore-based implementation of the APIWorkflowRunRepository
+protocol. It handles service-layer WorkflowRun database operations using Aliyun SLS LogStore
+with optimized queries for statistics and pagination.
+
+Key Features:
+- LogStore SQL queries for aggregation and statistics
+- Optimized deduplication using finished_at IS NOT NULL filter
+- Window functions only when necessary (running status queries)
+- Multi-tenant data isolation and security
+- SQL injection prevention via parameter escaping
+"""
+
+import logging
+import os
+import time
+from collections.abc import Sequence
+from datetime import datetime
+from typing import Any, cast
+
+from sqlalchemy.orm import sessionmaker
+
+from extensions.logstore.aliyun_logstore import AliyunLogStore
+from extensions.logstore.repositories import safe_float, safe_int
+from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string
+from libs.infinite_scroll_pagination import InfiniteScrollPagination
+from models.enums import WorkflowRunTriggeredFrom
+from models.workflow import WorkflowRun
+from repositories.api_workflow_run_repository import APIWorkflowRunRepository
+from repositories.types import (
+ AverageInteractionStats,
+ DailyRunsStats,
+ DailyTerminalsStats,
+ DailyTokenCostStats,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun:
+ """
+ Convert LogStore result dictionary to WorkflowRun instance.
+
+ Args:
+ data: Dictionary from LogStore query result
+
+ Returns:
+ WorkflowRun instance
+ """
+ logger.debug("_dict_to_workflow_run: data keys=%s", list(data.keys())[:5])
+ # Create model instance without session
+ model = WorkflowRun()
+
+ # Map all required fields with validation
+ # Critical fields - must not be None
+ model.id = data.get("id") or ""
+ model.tenant_id = data.get("tenant_id") or ""
+ model.app_id = data.get("app_id") or ""
+ model.workflow_id = data.get("workflow_id") or ""
+ model.type = data.get("type") or ""
+ model.triggered_from = data.get("triggered_from") or ""
+ model.version = data.get("version") or ""
+ model.status = data.get("status") or "running" # Default status if missing
+ model.created_by_role = data.get("created_by_role") or ""
+ model.created_by = data.get("created_by") or ""
+
+ model.total_tokens = safe_int(data.get("total_tokens", 0))
+ model.total_steps = safe_int(data.get("total_steps", 0))
+ model.exceptions_count = safe_int(data.get("exceptions_count", 0))
+
+ # Optional fields
+ model.graph = data.get("graph")
+ model.inputs = data.get("inputs")
+ model.outputs = data.get("outputs")
+ model.error = data.get("error_message") or data.get("error")
+
+ # Handle datetime fields
+ started_at = data.get("started_at") or data.get("created_at")
+ if started_at:
+ if isinstance(started_at, str):
+ model.created_at = datetime.fromisoformat(started_at)
+ elif isinstance(started_at, (int, float)):
+ model.created_at = datetime.fromtimestamp(started_at)
+ else:
+ model.created_at = started_at
+ else:
+ # Provide default created_at if missing
+ model.created_at = datetime.now()
+
+ finished_at = data.get("finished_at")
+ if finished_at:
+ if isinstance(finished_at, str):
+ model.finished_at = datetime.fromisoformat(finished_at)
+ elif isinstance(finished_at, (int, float)):
+ model.finished_at = datetime.fromtimestamp(finished_at)
+ else:
+ model.finished_at = finished_at
+
+ # Compute elapsed_time from started_at and finished_at
+ # LogStore doesn't store elapsed_time, it's computed in WorkflowExecution domain entity
+ if model.finished_at and model.created_at:
+ model.elapsed_time = (model.finished_at - model.created_at).total_seconds()
+ else:
+ # Use safe conversion to handle 'null' strings and None values
+ model.elapsed_time = safe_float(data.get("elapsed_time", 0))
+
+ return model
+
+
+class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
+ """
+ LogStore implementation of APIWorkflowRunRepository.
+
+ Provides service-layer WorkflowRun database operations using LogStore SQL
+ with optimized query strategies:
+ - Use finished_at IS NOT NULL for deduplication (10-100x faster)
+ - Use window functions only when running status is required
+ - Proper time range filtering for LogStore queries
+ """
+
+ def __init__(self, session_maker: sessionmaker | None = None):
+ """
+ Initialize the repository with LogStore client.
+
+ Args:
+ session_maker: SQLAlchemy sessionmaker (unused, for compatibility with factory pattern)
+ """
+ logger.debug("LogstoreAPIWorkflowRunRepository.__init__: initializing")
+ self.logstore_client = AliyunLogStore()
+
+ # Control flag for dual-read (fallback to PostgreSQL when LogStore returns no results)
+ # Set to True to enable fallback for safe migration from PostgreSQL to LogStore
+ # Set to False for new deployments without legacy data in PostgreSQL
+ self._enable_dual_read = os.environ.get("LOGSTORE_DUAL_READ_ENABLED", "true").lower() == "true"
+
+ def get_paginated_workflow_runs(
+ self,
+ tenant_id: str,
+ app_id: str,
+ triggered_from: WorkflowRunTriggeredFrom | Sequence[WorkflowRunTriggeredFrom],
+ limit: int = 20,
+ last_id: str | None = None,
+ status: str | None = None,
+ ) -> InfiniteScrollPagination:
+ """
+ Get paginated workflow runs with filtering.
+
+ Uses window function for deduplication to support both running and finished states.
+
+ Args:
+ tenant_id: Tenant identifier for multi-tenant isolation
+ app_id: Application identifier
+ triggered_from: Filter by trigger source(s)
+ limit: Maximum number of records to return (default: 20)
+ last_id: Cursor for pagination - ID of the last record from previous page
+ status: Optional filter by status
+
+ Returns:
+ InfiniteScrollPagination object
+ """
+ logger.debug(
+ "get_paginated_workflow_runs: tenant_id=%s, app_id=%s, limit=%d, status=%s",
+ tenant_id,
+ app_id,
+ limit,
+ status,
+ )
+ # Convert triggered_from to list if needed
+ if isinstance(triggered_from, (WorkflowRunTriggeredFrom, str)):
+ triggered_from_list = [triggered_from]
+ else:
+ triggered_from_list = list(triggered_from)
+
+ # Escape parameters to prevent SQL injection
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+
+ # Build triggered_from filter with escaped values
+ # Support both enum and string values for triggered_from
+ triggered_from_filter = " OR ".join(
+ [
+ f"triggered_from='{escape_sql_string(tf.value if isinstance(tf, WorkflowRunTriggeredFrom) else tf)}'"
+ for tf in triggered_from_list
+ ]
+ )
+
+ # Build status filter with escaped value
+ status_filter = f"AND status='{escape_sql_string(status)}'" if status else ""
+
+ # Build last_id filter for pagination
+        # Note: simplified; proper cursor-based pagination would also need the created_at of the last record
+ last_id_filter = ""
+ if last_id:
+ # TODO: Implement proper cursor-based pagination with created_at
+ logger.warning("last_id pagination not fully implemented for LogStore")
+
+ # Use window function to get latest log_version of each workflow run
+ sql = f"""
+ SELECT * FROM (
+ SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
+ FROM {AliyunLogStore.workflow_execution_logstore}
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND ({triggered_from_filter})
+ {status_filter}
+ {last_id_filter}
+ ) t
+ WHERE rn = 1
+ ORDER BY created_at DESC
+ LIMIT {limit + 1}
+ """
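+        # LIMIT limit + 1 fetches one extra row so has_more can be derived without a second COUNT query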
+
+ try:
+ results = self.logstore_client.execute_sql(
+ sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore, from_time=None, to_time=None
+ )
+
+ # Check if there are more records
+ has_more = len(results) > limit
+ if has_more:
+ results = results[:limit]
+
+ # Convert results to WorkflowRun models
+ workflow_runs = [_dict_to_workflow_run(row) for row in results]
+ return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
+
+ except Exception:
+ logger.exception("Failed to get paginated workflow runs from LogStore")
+ raise
+
+ def get_workflow_run_by_id(
+ self,
+ tenant_id: str,
+ app_id: str,
+ run_id: str,
+ ) -> WorkflowRun | None:
+ """
+ Get a specific workflow run by ID with tenant and app isolation.
+
+ Uses query syntax to get raw logs and selects the one with max log_version in code.
+ Falls back to PostgreSQL if not found in LogStore (for data consistency during migration).
+ """
+ logger.debug("get_workflow_run_by_id: tenant_id=%s, app_id=%s, run_id=%s", tenant_id, app_id, run_id)
+
+ try:
+ # Escape parameters to prevent SQL injection
+ escaped_run_id = escape_identifier(run_id)
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+
+ # Check if PG protocol is supported
+ if self.logstore_client.supports_pg_protocol:
+ # Use PG protocol with SQL query (get latest version of record)
+ sql_query = f"""
+ SELECT * FROM (
+ SELECT *,
+ ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
+ FROM "{AliyunLogStore.workflow_execution_logstore}"
+ WHERE id = '{escaped_run_id}'
+ AND tenant_id = '{escaped_tenant_id}'
+ AND app_id = '{escaped_app_id}'
+ AND __time__ > 0
+ ) AS subquery WHERE rn = 1
+ LIMIT 100
+ """
+ results = self.logstore_client.execute_sql(
+ sql=sql_query,
+ logstore=AliyunLogStore.workflow_execution_logstore,
+ )
+ else:
+ # Use SDK with LogStore query syntax
+ # Note: Values must be quoted in LogStore query syntax to prevent injection
+ query = (
+ f"id:{escape_logstore_query_value(run_id)} "
+ f"and tenant_id:{escape_logstore_query_value(tenant_id)} "
+ f"and app_id:{escape_logstore_query_value(app_id)}"
+ )
+ from_time = 0
+ to_time = int(time.time()) # now
+
+ results = self.logstore_client.get_logs(
+ logstore=AliyunLogStore.workflow_execution_logstore,
+ from_time=from_time,
+ to_time=to_time,
+ query=query,
+ line=100,
+ reverse=False,
+ )
+
+ if not results:
+ # Fallback to PostgreSQL for records created before LogStore migration
+ if self._enable_dual_read:
+ logger.debug(
+ "WorkflowRun not found in LogStore, falling back to PostgreSQL: "
+ "run_id=%s, tenant_id=%s, app_id=%s",
+ run_id,
+ tenant_id,
+ app_id,
+ )
+ return self._fallback_get_workflow_run_by_id_with_tenant(run_id, tenant_id, app_id)
+ return None
+
+ # For PG mode, results are already deduplicated by the SQL query
+ # For SDK mode, if multiple results, select the one with max log_version
+ if self.logstore_client.supports_pg_protocol or len(results) == 1:
+ return _dict_to_workflow_run(results[0])
+ else:
+ max_result = max(results, key=lambda x: int(x.get("log_version", 0)))
+ return _dict_to_workflow_run(max_result)
+
+ except Exception:
+ logger.exception("Failed to get workflow run by ID from LogStore: run_id=%s", run_id)
+ # Try PostgreSQL fallback on any error (only if dual-read is enabled)
+ if self._enable_dual_read:
+ try:
+ return self._fallback_get_workflow_run_by_id_with_tenant(run_id, tenant_id, app_id)
+ except Exception:
+ logger.exception(
+ "PostgreSQL fallback also failed: run_id=%s, tenant_id=%s, app_id=%s", run_id, tenant_id, app_id
+ )
+ raise
+
+ def _fallback_get_workflow_run_by_id_with_tenant(
+ self, run_id: str, tenant_id: str, app_id: str
+ ) -> WorkflowRun | None:
+ """Fallback to PostgreSQL query for records not in LogStore (with tenant isolation)."""
+ from sqlalchemy import select
+ from sqlalchemy.orm import Session
+
+ from extensions.ext_database import db
+
+ with Session(db.engine) as session:
+ stmt = select(WorkflowRun).where(
+ WorkflowRun.id == run_id, WorkflowRun.tenant_id == tenant_id, WorkflowRun.app_id == app_id
+ )
+ return session.scalar(stmt)
+
+ def get_workflow_run_by_id_without_tenant(
+ self,
+ run_id: str,
+ ) -> WorkflowRun | None:
+ """
+ Get a specific workflow run by ID without tenant/app context.
+ Uses query syntax to get raw logs and selects the one with max log_version.
+ Falls back to PostgreSQL if not found in LogStore (controlled by LOGSTORE_DUAL_READ_ENABLED).
+ """
+ logger.debug("get_workflow_run_by_id_without_tenant: run_id=%s", run_id)
+
+ try:
+ # Escape parameter to prevent SQL injection
+ escaped_run_id = escape_identifier(run_id)
+
+ # Check if PG protocol is supported
+ if self.logstore_client.supports_pg_protocol:
+ # Use PG protocol with SQL query (get latest version of record)
+ sql_query = f"""
+ SELECT * FROM (
+ SELECT *,
+ ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
+ FROM "{AliyunLogStore.workflow_execution_logstore}"
+ WHERE id = '{escaped_run_id}' AND __time__ > 0
+ ) AS subquery WHERE rn = 1
+ LIMIT 100
+ """
+ results = self.logstore_client.execute_sql(
+ sql=sql_query,
+ logstore=AliyunLogStore.workflow_execution_logstore,
+ )
+ else:
+ # Use SDK with LogStore query syntax
+ # Note: Values must be quoted in LogStore query syntax
+ query = f"id:{escape_logstore_query_value(run_id)}"
+ from_time = 0
+ to_time = int(time.time()) # now
+
+ results = self.logstore_client.get_logs(
+ logstore=AliyunLogStore.workflow_execution_logstore,
+ from_time=from_time,
+ to_time=to_time,
+ query=query,
+ line=100,
+ reverse=False,
+ )
+
+ if not results:
+ # Fallback to PostgreSQL for records created before LogStore migration
+ if self._enable_dual_read:
+ logger.debug("WorkflowRun not found in LogStore, falling back to PostgreSQL: run_id=%s", run_id)
+ return self._fallback_get_workflow_run_by_id(run_id)
+ return None
+
+ # For PG mode, results are already deduplicated by the SQL query
+ # For SDK mode, if multiple results, select the one with max log_version
+ if self.logstore_client.supports_pg_protocol or len(results) == 1:
+ return _dict_to_workflow_run(results[0])
+ else:
+ max_result = max(results, key=lambda x: int(x.get("log_version", 0)))
+ return _dict_to_workflow_run(max_result)
+
+ except Exception:
+ logger.exception("Failed to get workflow run without tenant: run_id=%s", run_id)
+ # Try PostgreSQL fallback on any error (only if dual-read is enabled)
+ if self._enable_dual_read:
+ try:
+ return self._fallback_get_workflow_run_by_id(run_id)
+ except Exception:
+ logger.exception("PostgreSQL fallback also failed: run_id=%s", run_id)
+ raise
+
+ def _fallback_get_workflow_run_by_id(self, run_id: str) -> WorkflowRun | None:
+ """Fallback to PostgreSQL query for records not in LogStore."""
+ from sqlalchemy import select
+ from sqlalchemy.orm import Session
+
+ from extensions.ext_database import db
+
+ with Session(db.engine) as session:
+ stmt = select(WorkflowRun).where(WorkflowRun.id == run_id)
+ return session.scalar(stmt)
+
+ def get_workflow_runs_count(
+ self,
+ tenant_id: str,
+ app_id: str,
+ triggered_from: str,
+ status: str | None = None,
+ time_range: str | None = None,
+ ) -> dict[str, int]:
+ """
+ Get workflow runs count statistics grouped by status.
+
+ Optimization: Use finished_at IS NOT NULL for completed runs (10-50x faster)
+ """
+ logger.debug(
+ "get_workflow_runs_count: tenant_id=%s, app_id=%s, triggered_from=%s, status=%s",
+ tenant_id,
+ app_id,
+ triggered_from,
+ status,
+ )
+ # Escape parameters to prevent SQL injection
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+ escaped_triggered_from = escape_sql_string(triggered_from)
+
+ # Build time range filter
+ time_filter = ""
+ if time_range:
+ # TODO: Parse time_range and convert to from_time/to_time
+ logger.warning("time_range filter not implemented")
+
+ # If status is provided, simple count
+ if status:
+ escaped_status = escape_sql_string(status)
+
+ if status == "running":
+ # Running status requires window function
+ sql = f"""
+ SELECT COUNT(*) as count
+ FROM (
+ SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
+ FROM {AliyunLogStore.workflow_execution_logstore}
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND triggered_from='{escaped_triggered_from}'
+ AND status='running'
+ {time_filter}
+ ) t
+ WHERE rn = 1
+ """
+ else:
+ # Finished status uses optimized filter
+ sql = f"""
+ SELECT COUNT(DISTINCT id) as count
+ FROM {AliyunLogStore.workflow_execution_logstore}
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND triggered_from='{escaped_triggered_from}'
+ AND status='{escaped_status}'
+ AND finished_at IS NOT NULL
+ {time_filter}
+ """
+
+ try:
+ results = self.logstore_client.execute_sql(
+ sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore
+ )
+                count = safe_int(results[0]["count"]) if results else 0  # LogStore returns values as strings
+
+ return {
+ "total": count,
+ "running": count if status == "running" else 0,
+ "succeeded": count if status == "succeeded" else 0,
+ "failed": count if status == "failed" else 0,
+ "stopped": count if status == "stopped" else 0,
+ "partial-succeeded": count if status == "partial-succeeded" else 0,
+ }
+ except Exception:
+ logger.exception("Failed to get workflow runs count")
+ raise
+
+ # No status filter - get counts grouped by status
+ # Use optimized query for finished runs, separate query for running
+ try:
+ # Escape parameters (already escaped above, reuse variables)
+ # Count finished runs grouped by status
+ finished_sql = f"""
+ SELECT status, COUNT(DISTINCT id) as count
+ FROM {AliyunLogStore.workflow_execution_logstore}
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND triggered_from='{escaped_triggered_from}'
+ AND finished_at IS NOT NULL
+ {time_filter}
+ GROUP BY status
+ """
+
+ # Count running runs
+ running_sql = f"""
+ SELECT COUNT(*) as count
+ FROM (
+ SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
+ FROM {AliyunLogStore.workflow_execution_logstore}
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND triggered_from='{escaped_triggered_from}'
+ AND status='running'
+ {time_filter}
+ ) t
+ WHERE rn = 1
+ """
+
+ finished_results = self.logstore_client.execute_sql(
+ sql=finished_sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore
+ )
+ running_results = self.logstore_client.execute_sql(
+ sql=running_sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore
+ )
+
+ # Build response
+ status_counts = {
+ "running": 0,
+ "succeeded": 0,
+ "failed": 0,
+ "stopped": 0,
+ "partial-succeeded": 0,
+ }
+
+ total = 0
+ for result in finished_results:
+ status_val = result.get("status")
+                count = safe_int(result.get("count", 0))  # values arrive as strings from LogStore
+ if status_val in status_counts:
+ status_counts[status_val] = count
+ total += count
+
+ # Add running count
+            running_count = safe_int(running_results[0]["count"]) if running_results else 0
+ status_counts["running"] = running_count
+ total += running_count
+
+ return {"total": total} | status_counts
+
+ except Exception:
+ logger.exception("Failed to get workflow runs count")
+ raise
+
+ def get_daily_runs_statistics(
+ self,
+ tenant_id: str,
+ app_id: str,
+ triggered_from: str,
+ start_date: datetime | None = None,
+ end_date: datetime | None = None,
+ timezone: str = "UTC",
+ ) -> list[DailyRunsStats]:
+ """
+ Get daily runs statistics using optimized query.
+
+ Optimization: Use finished_at IS NOT NULL + COUNT(DISTINCT id) (20-100x faster)
+ """
+ logger.debug(
+ "get_daily_runs_statistics: tenant_id=%s, app_id=%s, triggered_from=%s", tenant_id, app_id, triggered_from
+ )
+
+ # Escape parameters to prevent SQL injection
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+ escaped_triggered_from = escape_sql_string(triggered_from)
+
+ # Build time range filter (datetime.isoformat() is safe)
+ time_filter = ""
+ if start_date:
+ time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
+ if end_date:
+ time_filter += f" AND __time__ < to_unixtime(from_iso8601_timestamp('{end_date.isoformat()}'))"
+
+ # Optimized query: Use finished_at filter to avoid window function
+ sql = f"""
+ SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT id) as runs
+ FROM {AliyunLogStore.workflow_execution_logstore}
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND triggered_from='{escaped_triggered_from}'
+ AND finished_at IS NOT NULL
+ {time_filter}
+ GROUP BY date
+ ORDER BY date
+ """
+
+ try:
+ results = self.logstore_client.execute_sql(
+ sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore
+ )
+
+ response_data = []
+ for row in results:
+                response_data.append({"date": str(row.get("date", "")), "runs": safe_int(row.get("runs", 0))})
+
+ return cast(list[DailyRunsStats], response_data)
+
+ except Exception:
+ logger.exception("Failed to get daily runs statistics")
+ raise
+
+ def get_daily_terminals_statistics(
+ self,
+ tenant_id: str,
+ app_id: str,
+ triggered_from: str,
+ start_date: datetime | None = None,
+ end_date: datetime | None = None,
+ timezone: str = "UTC",
+ ) -> list[DailyTerminalsStats]:
+ """
+ Get daily terminals statistics using optimized query.
+
+ Optimization: Use finished_at IS NOT NULL + COUNT(DISTINCT created_by) (20-100x faster)
+ """
+ logger.debug(
+ "get_daily_terminals_statistics: tenant_id=%s, app_id=%s, triggered_from=%s",
+ tenant_id,
+ app_id,
+ triggered_from,
+ )
+
+ # Escape parameters to prevent SQL injection
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+ escaped_triggered_from = escape_sql_string(triggered_from)
+
+ # Build time range filter (datetime.isoformat() is safe)
+ time_filter = ""
+ if start_date:
+ time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
+ if end_date:
+ time_filter += f" AND __time__ < to_unixtime(from_iso8601_timestamp('{end_date.isoformat()}'))"
+
+ sql = f"""
+ SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT created_by) as terminal_count
+ FROM {AliyunLogStore.workflow_execution_logstore}
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND triggered_from='{escaped_triggered_from}'
+ AND finished_at IS NOT NULL
+ {time_filter}
+ GROUP BY date
+ ORDER BY date
+ """
+
+ try:
+ results = self.logstore_client.execute_sql(
+ sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore
+ )
+
+ response_data = []
+ for row in results:
+                response_data.append(
+                    {"date": str(row.get("date", "")), "terminal_count": safe_int(row.get("terminal_count", 0))}
+                )
+
+ return cast(list[DailyTerminalsStats], response_data)
+
+ except Exception:
+ logger.exception("Failed to get daily terminals statistics")
+ raise
+
+ def get_daily_token_cost_statistics(
+ self,
+ tenant_id: str,
+ app_id: str,
+ triggered_from: str,
+ start_date: datetime | None = None,
+ end_date: datetime | None = None,
+ timezone: str = "UTC",
+ ) -> list[DailyTokenCostStats]:
+ """
+ Get daily token cost statistics using optimized query.
+
+ Optimization: Use finished_at IS NOT NULL + SUM(total_tokens) (20-100x faster)
+ """
+ logger.debug(
+ "get_daily_token_cost_statistics: tenant_id=%s, app_id=%s, triggered_from=%s",
+ tenant_id,
+ app_id,
+ triggered_from,
+ )
+
+ # Escape parameters to prevent SQL injection
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+ escaped_triggered_from = escape_sql_string(triggered_from)
+
+ # Build time range filter (datetime.isoformat() is safe)
+ time_filter = ""
+ if start_date:
+ time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
+ if end_date:
+ time_filter += f" AND __time__ < to_unixtime(from_iso8601_timestamp('{end_date.isoformat()}'))"
+
+ sql = f"""
+ SELECT DATE(from_unixtime(__time__)) as date, SUM(total_tokens) as token_count
+ FROM {AliyunLogStore.workflow_execution_logstore}
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND triggered_from='{escaped_triggered_from}'
+ AND finished_at IS NOT NULL
+ {time_filter}
+ GROUP BY date
+ ORDER BY date
+ """
+
+ try:
+ results = self.logstore_client.execute_sql(
+ sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore
+ )
+
+ response_data = []
+ for row in results:
+                response_data.append(
+                    {"date": str(row.get("date", "")), "token_count": safe_int(row.get("token_count", 0))}
+                )
+
+ return cast(list[DailyTokenCostStats], response_data)
+
+ except Exception:
+ logger.exception("Failed to get daily token cost statistics")
+ raise
+
+ def get_average_app_interaction_statistics(
+ self,
+ tenant_id: str,
+ app_id: str,
+ triggered_from: str,
+ start_date: datetime | None = None,
+ end_date: datetime | None = None,
+ timezone: str = "UTC",
+ ) -> list[AverageInteractionStats]:
+ """
+ Get average app interaction statistics using optimized query.
+
+ Optimization: Use finished_at IS NOT NULL + AVG (20-100x faster)
+ """
+ logger.debug(
+ "get_average_app_interaction_statistics: tenant_id=%s, app_id=%s, triggered_from=%s",
+ tenant_id,
+ app_id,
+ triggered_from,
+ )
+
+ # Escape parameters to prevent SQL injection
+ escaped_tenant_id = escape_identifier(tenant_id)
+ escaped_app_id = escape_identifier(app_id)
+ escaped_triggered_from = escape_sql_string(triggered_from)
+
+ # Build time range filter (datetime.isoformat() is safe)
+ time_filter = ""
+ if start_date:
+ time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
+ if end_date:
+ time_filter += f" AND __time__ < to_unixtime(from_iso8601_timestamp('{end_date.isoformat()}'))"
+
+ sql = f"""
+ SELECT
+ AVG(sub.interactions) AS interactions,
+ sub.date
+ FROM (
+ SELECT
+ DATE(from_unixtime(__time__)) AS date,
+ created_by,
+ COUNT(DISTINCT id) AS interactions
+ FROM {AliyunLogStore.workflow_execution_logstore}
+ WHERE tenant_id='{escaped_tenant_id}'
+ AND app_id='{escaped_app_id}'
+ AND triggered_from='{escaped_triggered_from}'
+ AND finished_at IS NOT NULL
+ {time_filter}
+ GROUP BY date, created_by
+ ) sub
+ GROUP BY sub.date
+ """
+
+ try:
+ results = self.logstore_client.execute_sql(
+ sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore
+ )
+
+ response_data = []
+ for row in results:
+ response_data.append(
+ {
+ "date": str(row.get("date", "")),
+                        "interactions": safe_float(row.get("interactions", 0)),
+ }
+ )
+
+ return cast(list[AverageInteractionStats], response_data)
+
+ except Exception:
+ logger.exception("Failed to get average app interaction statistics")
+ raise
diff --git a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py
new file mode 100644
index 0000000000..9928879a7b
--- /dev/null
+++ b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py
@@ -0,0 +1,188 @@
+import json
+import logging
+import os
+import time
+from typing import Union
+
+from sqlalchemy.engine import Engine
+from sqlalchemy.orm import sessionmaker
+
+from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
+from core.workflow.entities import WorkflowExecution
+from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
+from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
+from extensions.logstore.aliyun_logstore import AliyunLogStore
+from libs.helper import extract_tenant_id
+from models import (
+ Account,
+ CreatorUserRole,
+ EndUser,
+)
+from models.enums import WorkflowRunTriggeredFrom
+
+logger = logging.getLogger(__name__)
+
+
+class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
+ def __init__(
+ self,
+ session_factory: sessionmaker | Engine,
+ user: Union[Account, EndUser],
+ app_id: str | None,
+ triggered_from: WorkflowRunTriggeredFrom | None,
+ ):
+ """
+ Initialize the repository with a SQLAlchemy sessionmaker or engine and context information.
+
+ Args:
+ session_factory: SQLAlchemy sessionmaker or engine for creating sessions
+ user: Account or EndUser object containing tenant_id, user ID, and role information
+ app_id: App ID for filtering by application (can be None)
+ triggered_from: Source of the execution trigger (DEBUGGING or APP_RUN)
+ """
+ logger.debug(
+ "LogstoreWorkflowExecutionRepository.__init__: app_id=%s, triggered_from=%s", app_id, triggered_from
+ )
+ # Initialize LogStore client
+ # Note: Project/logstore/index initialization is done at app startup via ext_logstore
+ self.logstore_client = AliyunLogStore()
+
+ # Extract tenant_id from user
+ tenant_id = extract_tenant_id(user)
+ if not tenant_id:
+ raise ValueError("User must have a tenant_id or current_tenant_id")
+ self._tenant_id = tenant_id
+
+ # Store app context
+ self._app_id = app_id
+
+ # Extract user context
+ self._triggered_from = triggered_from
+ self._creator_user_id = user.id
+
+ # Determine user role based on user type
+ self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
+
+ # Initialize SQL repository for dual-write support
+ self.sql_repository = SQLAlchemyWorkflowExecutionRepository(session_factory, user, app_id, triggered_from)
+
+ # Control flag for dual-write (write to both LogStore and SQL database)
+ # Set to True to enable dual-write for safe migration, False to use LogStore only
+ self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "false").lower() == "true"
+
+ # Control flag for whether to write the `graph` field to LogStore.
+ # If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field;
+ # otherwise write an empty {} instead. Defaults to writing the `graph` field.
+ self._enable_put_graph_field = os.environ.get("LOGSTORE_ENABLE_PUT_GRAPH_FIELD", "true").lower() == "true"
+
+ def _to_logstore_model(self, domain_model: WorkflowExecution) -> list[tuple[str, str]]:
+ """
+ Convert a domain model to a logstore model (List[Tuple[str, str]]).
+
+ Args:
+ domain_model: The domain model to convert
+
+ Returns:
+ The logstore model as a list of key-value tuples
+ """
+ logger.debug(
+ "_to_logstore_model: id=%s, workflow_id=%s, status=%s",
+ domain_model.id_,
+ domain_model.workflow_id,
+ domain_model.status.value,
+ )
+ # Use values from constructor if provided
+ if not self._triggered_from:
+ raise ValueError("triggered_from is required in repository constructor")
+ if not self._creator_user_id:
+ raise ValueError("created_by is required in repository constructor")
+ if not self._creator_user_role:
+ raise ValueError("created_by_role is required in repository constructor")
+
+ # Generate log_version as nanosecond timestamp for record versioning
+ log_version = str(time.time_ns())
+
+ # Use WorkflowRuntimeTypeConverter to handle complex types (Segment, File, etc.)
+ json_converter = WorkflowRuntimeTypeConverter()
+
+ logstore_model = [
+ ("id", domain_model.id_),
+ ("log_version", log_version), # Add log_version field for append-only writes
+ ("tenant_id", self._tenant_id),
+ ("app_id", self._app_id or ""),
+ ("workflow_id", domain_model.workflow_id),
+ (
+ "triggered_from",
+ self._triggered_from.value if hasattr(self._triggered_from, "value") else str(self._triggered_from),
+ ),
+ ("type", domain_model.workflow_type.value),
+ ("version", domain_model.workflow_version),
+ (
+ "graph",
+ json.dumps(json_converter.to_json_encodable(domain_model.graph), ensure_ascii=False)
+ if domain_model.graph and self._enable_put_graph_field
+ else "{}",
+ ),
+ (
+ "inputs",
+ json.dumps(json_converter.to_json_encodable(domain_model.inputs), ensure_ascii=False)
+ if domain_model.inputs
+ else "{}",
+ ),
+ (
+ "outputs",
+ json.dumps(json_converter.to_json_encodable(domain_model.outputs), ensure_ascii=False)
+ if domain_model.outputs
+ else "{}",
+ ),
+ ("status", domain_model.status.value),
+ ("error_message", domain_model.error_message or ""),
+ ("total_tokens", str(domain_model.total_tokens)),
+ ("total_steps", str(domain_model.total_steps)),
+ ("exceptions_count", str(domain_model.exceptions_count)),
+ (
+ "created_by_role",
+ self._creator_user_role.value
+ if hasattr(self._creator_user_role, "value")
+ else str(self._creator_user_role),
+ ),
+ ("created_by", self._creator_user_id),
+ ("started_at", domain_model.started_at.isoformat() if domain_model.started_at else ""),
+ ("finished_at", domain_model.finished_at.isoformat() if domain_model.finished_at else ""),
+ ]
+
+ return logstore_model
+
+ def save(self, execution: WorkflowExecution) -> None:
+ """
+ Save or update a WorkflowExecution domain entity to the logstore.
+
+ This method serves as a domain-to-logstore adapter that:
+ 1. Converts the domain entity to its logstore representation
+ 2. Persists the logstore model using Aliyun SLS
+ 3. Maintains proper multi-tenancy by including tenant context during conversion
+ 4. Optionally writes to SQL database for dual-write support (controlled by LOGSTORE_DUAL_WRITE_ENABLED)
+
+ Args:
+ execution: The WorkflowExecution domain entity to persist
+ """
+ logger.debug(
+ "save: id=%s, workflow_id=%s, status=%s", execution.id_, execution.workflow_id, execution.status.value
+ )
+ try:
+ logstore_model = self._to_logstore_model(execution)
+ self.logstore_client.put_log(AliyunLogStore.workflow_execution_logstore, logstore_model)
+
+ logger.debug("Saved workflow execution to logstore: id=%s", execution.id_)
+ except Exception:
+ logger.exception("Failed to save workflow execution to logstore: id=%s", execution.id_)
+ raise
+
+ # Dual-write to SQL database if enabled (for safe migration)
+ if self._enable_dual_write:
+ try:
+ self.sql_repository.save(execution)
+ logger.debug("Dual-write: saved workflow execution to SQL database: id=%s", execution.id_)
+ except Exception:
+ logger.exception("Failed to dual-write workflow execution to SQL database: id=%s", execution.id_)
+ # Don't raise - LogStore write succeeded, SQL is just a backup
diff --git a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py
new file mode 100644
index 0000000000..4897171b12
--- /dev/null
+++ b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py
@@ -0,0 +1,396 @@
+"""
+LogStore implementation of the WorkflowNodeExecutionRepository.
+
+This module provides a LogStore-based repository for WorkflowNodeExecution entities,
+using Aliyun SLS LogStore with append-only writes and version control.
+"""
+
+import json
+import logging
+import os
+import time
+from collections.abc import Sequence
+from datetime import datetime
+from typing import Any, Union
+
+from sqlalchemy.engine import Engine
+from sqlalchemy.orm import sessionmaker
+
+from core.model_runtime.utils.encoders import jsonable_encoder
+from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.workflow.entities import WorkflowNodeExecution
+from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
+from core.workflow.enums import NodeType
+from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
+from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
+from extensions.logstore.aliyun_logstore import AliyunLogStore
+from extensions.logstore.repositories import safe_float, safe_int
+from extensions.logstore.sql_escape import escape_identifier
+from libs.helper import extract_tenant_id
+from models import (
+ Account,
+ CreatorUserRole,
+ EndUser,
+ WorkflowNodeExecutionTriggeredFrom,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecution:
+ """
+ Convert LogStore result dictionary to WorkflowNodeExecution domain model.
+
+ Args:
+ data: Dictionary from LogStore query result
+
+ Returns:
+ WorkflowNodeExecution domain model instance
+ """
+ logger.debug("_dict_to_workflow_node_execution: data keys=%s", list(data.keys())[:5])
+ # Parse JSON fields
+    inputs = json.loads(data.get("inputs") or "{}")  # "or" guards against empty strings (NULL in PG mode)
+    process_data = json.loads(data.get("process_data") or "{}")
+    outputs = json.loads(data.get("outputs") or "{}")
+    metadata = json.loads(data.get("execution_metadata") or "{}")
+
+ # Convert metadata to domain enum keys
+ domain_metadata = {}
+ for k, v in metadata.items():
+ try:
+ domain_metadata[WorkflowNodeExecutionMetadataKey(k)] = v
+ except ValueError:
+ # Skip invalid metadata keys
+ continue
+
+ # Convert status to domain enum
+ status = WorkflowNodeExecutionStatus(data.get("status", "running"))
+
+ # Parse datetime fields
+ created_at = datetime.fromisoformat(data.get("created_at", "")) if data.get("created_at") else datetime.now()
+ finished_at = datetime.fromisoformat(data.get("finished_at", "")) if data.get("finished_at") else None
+
+ return WorkflowNodeExecution(
+ id=data.get("id", ""),
+ node_execution_id=data.get("node_execution_id"),
+ workflow_id=data.get("workflow_id", ""),
+ workflow_execution_id=data.get("workflow_run_id"),
+ index=safe_int(data.get("index", 0)),
+ predecessor_node_id=data.get("predecessor_node_id"),
+ node_id=data.get("node_id", ""),
+ node_type=NodeType(data.get("node_type", "start")),
+ title=data.get("title", ""),
+ inputs=inputs,
+ process_data=process_data,
+ outputs=outputs,
+ status=status,
+ error=data.get("error"),
+ elapsed_time=safe_float(data.get("elapsed_time", 0.0)),
+ metadata=domain_metadata,
+ created_at=created_at,
+ finished_at=finished_at,
+ )
+
+
+class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
+ """
+ LogStore implementation of the WorkflowNodeExecutionRepository interface.
+
+ This implementation uses Aliyun SLS LogStore with an append-only write strategy:
+ - Each save() operation appends a new record with a version timestamp
+ - Updates are simulated by writing new records with higher version numbers
+    - Queries retrieve the latest version of each record via a ROW_NUMBER() window ordered by log_version
+    - Multi-tenancy is maintained through tenant_id filtering
+
+    Version Strategy:
+        log_version = time.time_ns()  # Nanosecond timestamp for unique ordering
+ """
+
+ def __init__(
+ self,
+ session_factory: sessionmaker | Engine,
+ user: Union[Account, EndUser],
+ app_id: str | None,
+ triggered_from: WorkflowNodeExecutionTriggeredFrom | None,
+ ):
+ """
+ Initialize the repository with a SQLAlchemy sessionmaker or engine and context information.
+
+ Args:
+ session_factory: SQLAlchemy sessionmaker or engine for creating sessions
+ user: Account or EndUser object containing tenant_id, user ID, and role information
+ app_id: App ID for filtering by application (can be None)
+ triggered_from: Source of the execution trigger (SINGLE_STEP or WORKFLOW_RUN)
+ """
+ logger.debug(
+ "LogstoreWorkflowNodeExecutionRepository.__init__: app_id=%s, triggered_from=%s", app_id, triggered_from
+ )
+ # Initialize LogStore client
+ self.logstore_client = AliyunLogStore()
+
+ # Extract tenant_id from user
+ tenant_id = extract_tenant_id(user)
+ if not tenant_id:
+ raise ValueError("User must have a tenant_id or current_tenant_id")
+ self._tenant_id = tenant_id
+
+ # Store app context
+ self._app_id = app_id
+
+ # Extract user context
+ self._triggered_from = triggered_from
+ self._creator_user_id = user.id
+
+ # Determine user role based on user type
+ self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
+
+ # Initialize SQL repository for dual-write support
+ self.sql_repository = SQLAlchemyWorkflowNodeExecutionRepository(session_factory, user, app_id, triggered_from)
+
+ # Control flag for dual-write (write to both LogStore and SQL database)
+ # Set to True to enable dual-write for safe migration, False to use LogStore only
+ self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "false").lower() == "true"
+
+ def _to_logstore_model(self, domain_model: WorkflowNodeExecution) -> Sequence[tuple[str, str]]:
+ logger.debug(
+ "_to_logstore_model: id=%s, node_id=%s, status=%s",
+ domain_model.id,
+ domain_model.node_id,
+ domain_model.status.value,
+ )
+ if not self._triggered_from:
+ raise ValueError("triggered_from is required in repository constructor")
+ if not self._creator_user_id:
+ raise ValueError("created_by is required in repository constructor")
+ if not self._creator_user_role:
+ raise ValueError("created_by_role is required in repository constructor")
+
+ # Generate log_version as nanosecond timestamp for record versioning
+ log_version = str(time.time_ns())
+
+ json_converter = WorkflowRuntimeTypeConverter()
+
+ logstore_model = [
+ ("id", domain_model.id),
+ ("log_version", log_version), # Add log_version field for append-only writes
+ ("tenant_id", self._tenant_id),
+ ("app_id", self._app_id or ""),
+ ("workflow_id", domain_model.workflow_id),
+ (
+ "triggered_from",
+ self._triggered_from.value if hasattr(self._triggered_from, "value") else str(self._triggered_from),
+ ),
+ ("workflow_run_id", domain_model.workflow_execution_id or ""),
+ ("index", str(domain_model.index)),
+ ("predecessor_node_id", domain_model.predecessor_node_id or ""),
+ ("node_execution_id", domain_model.node_execution_id or ""),
+ ("node_id", domain_model.node_id),
+ ("node_type", domain_model.node_type.value),
+ ("title", domain_model.title),
+ (
+ "inputs",
+ json.dumps(json_converter.to_json_encodable(domain_model.inputs), ensure_ascii=False)
+ if domain_model.inputs
+ else "{}",
+ ),
+ (
+ "process_data",
+ json.dumps(json_converter.to_json_encodable(domain_model.process_data), ensure_ascii=False)
+ if domain_model.process_data
+ else "{}",
+ ),
+ (
+ "outputs",
+ json.dumps(json_converter.to_json_encodable(domain_model.outputs), ensure_ascii=False)
+ if domain_model.outputs
+ else "{}",
+ ),
+ ("status", domain_model.status.value),
+ ("error", domain_model.error or ""),
+ ("elapsed_time", str(domain_model.elapsed_time)),
+ (
+ "execution_metadata",
+ json.dumps(jsonable_encoder(domain_model.metadata), ensure_ascii=False)
+ if domain_model.metadata
+ else "{}",
+ ),
+ ("created_at", domain_model.created_at.isoformat() if domain_model.created_at else ""),
+ ("created_by_role", self._creator_user_role.value),
+ ("created_by", self._creator_user_id),
+ ("finished_at", domain_model.finished_at.isoformat() if domain_model.finished_at else ""),
+ ]
+
+ return logstore_model
+
+ def save(self, execution: WorkflowNodeExecution) -> None:
+ """
+ Save or update a NodeExecution domain entity to LogStore.
+
+ This method serves as a domain-to-logstore adapter that:
+ 1. Converts the domain entity to its logstore representation
+ 2. Appends a new record with a log_version timestamp
+ 3. Maintains proper multi-tenancy by including tenant context during conversion
+ 4. Optionally writes to SQL database for dual-write support (controlled by LOGSTORE_DUAL_WRITE_ENABLED)
+
+ Each save operation creates a new record. Updates are simulated by writing
+ new records with higher log_version numbers.
+
+ Args:
+ execution: The NodeExecution domain entity to persist
+ """
+ logger.debug(
+ "save: id=%s, node_execution_id=%s, status=%s",
+ execution.id,
+ execution.node_execution_id,
+ execution.status.value,
+ )
+ try:
+ logstore_model = self._to_logstore_model(execution)
+ self.logstore_client.put_log(AliyunLogStore.workflow_node_execution_logstore, logstore_model)
+
+ logger.debug(
+ "Saved node execution to LogStore: id=%s, node_execution_id=%s, status=%s",
+ execution.id,
+ execution.node_execution_id,
+ execution.status.value,
+ )
+ except Exception:
+ logger.exception(
+ "Failed to save node execution to LogStore: id=%s, node_execution_id=%s",
+ execution.id,
+ execution.node_execution_id,
+ )
+ raise
+
+ # Dual-write to SQL database if enabled (for safe migration)
+ if self._enable_dual_write:
+ try:
+ self.sql_repository.save(execution)
+ logger.debug("Dual-write: saved node execution to SQL database: id=%s", execution.id)
+ except Exception:
+ logger.exception("Failed to dual-write node execution to SQL database: id=%s", execution.id)
+ # Don't raise - LogStore write succeeded, SQL is just a backup
+
+ def save_execution_data(self, execution: WorkflowNodeExecution) -> None:
+ """
+ Save or update the inputs, process_data, or outputs associated with a specific
+ node_execution record.
+
+        For the LogStore implementation, this is a no-op for the LogStore write because save()
+        already writes all fields, including inputs, process_data, and outputs. The caller
+        typically calls save() first to persist status/metadata, then calls save_execution_data()
+        to persist data fields. Since LogStore writes complete records atomically, a second
+        write here would only create a duplicate record, so it is skipped.
+
+ However, if dual-write is enabled, we still need to call the SQL repository's
+ save_execution_data() method to properly update the SQL database.
+
+ Args:
+ execution: The NodeExecution instance with data to save
+ """
+ logger.debug(
+ "save_execution_data: no-op for LogStore (data already saved by save()): id=%s, node_execution_id=%s",
+ execution.id,
+ execution.node_execution_id,
+ )
+ # No-op for LogStore: save() already writes all fields including inputs, process_data, and outputs
+ # Calling save() again would create a duplicate record in the append-only LogStore
+
+ # Dual-write to SQL database if enabled (for safe migration)
+ if self._enable_dual_write:
+ try:
+ self.sql_repository.save_execution_data(execution)
+ logger.debug("Dual-write: saved node execution data to SQL database: id=%s", execution.id)
+ except Exception:
+ logger.exception("Failed to dual-write node execution data to SQL database: id=%s", execution.id)
+ # Don't raise - LogStore write succeeded, SQL is just a backup
+
+ def get_by_workflow_run(
+ self,
+ workflow_run_id: str,
+ order_config: OrderConfig | None = None,
+ ) -> Sequence[WorkflowNodeExecution]:
+ """
+ Retrieve all NodeExecution instances for a specific workflow run.
+        Uses a LogStore SQL query with a window function to get the latest version of each node execution.
+        This ensures we only get the most recent version of each node execution record.
+
+        Args:
+ workflow_run_id: The workflow run ID
+ order_config: Optional configuration for ordering results
+ order_config.order_by: List of fields to order by (e.g., ["index", "created_at"])
+ order_config.order_direction: Direction to order ("asc" or "desc")
+
+ Returns:
+ A list of NodeExecution instances
+
+ Note:
+ This method uses ROW_NUMBER() window function partitioned by node_execution_id
+ to get the latest version (highest log_version) of each node execution.
+ """
+ logger.debug("get_by_workflow_run: workflow_run_id=%s, order_config=%s", workflow_run_id, order_config)
+ # Build SQL query with deduplication using window function
+ # ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC)
+ # ensures we get the latest version of each node execution
+
+ # Escape parameters to prevent SQL injection
+ escaped_workflow_run_id = escape_identifier(workflow_run_id)
+ escaped_tenant_id = escape_identifier(self._tenant_id)
+
+ # Build ORDER BY clause for outer query
+ order_clause = ""
+ if order_config and order_config.order_by:
+ order_fields = []
+ for field in order_config.order_by:
+ # Map domain field names to logstore field names if needed
+ field_name = field
+ if order_config.order_direction == "desc":
+ order_fields.append(f"{field_name} DESC")
+ else:
+ order_fields.append(f"{field_name} ASC")
+ if order_fields:
+ order_clause = "ORDER BY " + ", ".join(order_fields)
+
+ # Build app_id filter for subquery
+ app_id_filter = ""
+ if self._app_id:
+ escaped_app_id = escape_identifier(self._app_id)
+ app_id_filter = f" AND app_id='{escaped_app_id}'"
+
+ # Use window function to get latest version of each node execution
+ sql = f"""
+ SELECT * FROM (
+ SELECT *, ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC) AS rn
+ FROM {AliyunLogStore.workflow_node_execution_logstore}
+ WHERE workflow_run_id='{escaped_workflow_run_id}'
+ AND tenant_id='{escaped_tenant_id}'
+ {app_id_filter}
+ ) t
+ WHERE rn = 1
+ """
+
+ if order_clause:
+ sql += f" {order_clause}"
+
+ try:
+ # Execute SQL query
+ results = self.logstore_client.execute_sql(
+ sql=sql,
+ query="*",
+ logstore=AliyunLogStore.workflow_node_execution_logstore,
+ )
+
+ # Convert LogStore results to WorkflowNodeExecution domain models
+ executions = []
+ for row in results:
+ try:
+ execution = _dict_to_workflow_node_execution(row)
+ executions.append(execution)
+ except Exception as e:
+ logger.warning("Failed to convert row to WorkflowNodeExecution: %s, row=%s", e, row)
+ continue
+
+ return executions
+
+ except Exception:
+ logger.exception("Failed to retrieve node executions from LogStore: workflow_run_id=%s", workflow_run_id)
+ raise
diff --git a/api/extensions/logstore/sql_escape.py b/api/extensions/logstore/sql_escape.py
new file mode 100644
index 0000000000..d88d6bd959
--- /dev/null
+++ b/api/extensions/logstore/sql_escape.py
@@ -0,0 +1,134 @@
+"""
+SQL Escape Utility for LogStore Queries
+
+This module provides escaping utilities to prevent injection attacks in LogStore queries.
+
+LogStore supports two query modes:
+1. PG Protocol Mode: Uses SQL syntax with single quotes for strings
+2. SDK Mode: Uses LogStore query syntax (key: value) with double quotes
+
+Key Security Concerns:
+- Prevent tenant A from accessing tenant B's data via injection
+- SLS queries are read-only, so we focus on data access control
+- Different escaping strategies for SQL vs LogStore query syntax
+"""
+
+
+def escape_sql_string(value: str) -> str:
+ """
+ Escape a string value for safe use in SQL queries.
+
+ This function escapes single quotes by doubling them, which is the standard
+ SQL escaping method. This prevents SQL injection by ensuring that user input
+ cannot break out of string literals.
+
+ Args:
+ value: The string value to escape
+
+ Returns:
+ Escaped string safe for use in SQL queries
+
+ Examples:
+ >>> escape_sql_string("normal_value")
+ "normal_value"
+ >>> escape_sql_string("value' OR '1'='1")
+ "value'' OR ''1''=''1"
+ >>> escape_sql_string("tenant's_id")
+ "tenant''s_id"
+
+ Security:
+ - Prevents breaking out of string literals
+ - Stops injection attacks like: ' OR '1'='1
+ - Protects against cross-tenant data access
+ """
+ if not value:
+ return value
+
+ # Escape single quotes by doubling them (standard SQL escaping)
+ # This prevents breaking out of string literals in SQL queries
+ return value.replace("'", "''")
+
+
+def escape_identifier(value: str) -> str:
+ """
+ Escape an identifier (tenant_id, app_id, run_id, etc.) for safe SQL use.
+
+ This function is for PG protocol mode (SQL syntax).
+ For SDK mode, use escape_logstore_query_value() instead.
+
+ Args:
+ value: The identifier value to escape
+
+ Returns:
+ Escaped identifier safe for use in SQL queries
+
+ Examples:
+ >>> escape_identifier("550e8400-e29b-41d4-a716-446655440000")
+ "550e8400-e29b-41d4-a716-446655440000"
+ >>> escape_identifier("tenant_id' OR '1'='1")
+ "tenant_id'' OR ''1''=''1"
+
+ Security:
+ - Prevents SQL injection via identifiers
+ - Stops cross-tenant access attempts
+ - Works for UUIDs, alphanumeric IDs, and similar identifiers
+ """
+ # For identifiers, use the same escaping as strings
+ # This is simple and effective for preventing injection
+ return escape_sql_string(value)
+
+
+def escape_logstore_query_value(value: str) -> str:
+ """
+ Escape value for LogStore query syntax (SDK mode).
+
+ LogStore query syntax rules:
+ 1. Keywords (and/or/not) are case-insensitive
+ 2. Single quotes are ordinary characters (no special meaning)
+ 3. Double quotes wrap values: key:"value"
+ 4. Backslash is the escape character:
+ - \" for double quote inside value
+ - \\ for backslash itself
+ 5. Parentheses can change query structure
+
+ To prevent injection:
+ - Wrap value in double quotes to treat special chars as literals
+ - Escape backslashes and double quotes using backslash
+
+ Args:
+ value: The value to escape for LogStore query syntax
+
+ Returns:
+ Quoted and escaped value safe for LogStore query syntax (includes the quotes)
+
+ Examples:
+ >>> escape_logstore_query_value("normal_value")
+ '"normal_value"'
+ >>> escape_logstore_query_value("value or field:evil")
+ '"value or field:evil"' # 'or' and ':' are now literals
+ >>> escape_logstore_query_value('value"test')
+ '"value\\"test"' # Internal double quote escaped
+ >>> escape_logstore_query_value('value\\test')
+ '"value\\\\test"' # Backslash escaped
+
+ Security:
+ - Prevents injection via and/or/not keywords
+ - Prevents injection via colons (:)
+ - Prevents injection via parentheses
+ - Protects against cross-tenant data access
+
+ Note:
+ Escape order is critical: backslash first, then double quotes.
+ Otherwise, we'd double-escape the escape character itself.
+ """
+ if not value:
+ return '""'
+
+ # IMPORTANT: Escape backslashes FIRST, then double quotes
+ # This prevents double-escaping (e.g., " -> \" -> \\" incorrectly)
+ escaped = value.replace("\\", "\\\\") # \ -> \\
+ escaped = escaped.replace('"', '\\"') # " -> \"
+
+ # Wrap in double quotes to treat as literal string
+ # This prevents and/or/not/:/() from being interpreted as operators
+ return f'"{escaped}"'
diff --git a/api/extensions/otel/decorators/base.py b/api/extensions/otel/decorators/base.py
index 9604a3b6d5..14221d24dd 100644
--- a/api/extensions/otel/decorators/base.py
+++ b/api/extensions/otel/decorators/base.py
@@ -1,5 +1,4 @@
import functools
-import os
from collections.abc import Callable
from typing import Any, TypeVar, cast
@@ -7,22 +6,13 @@ from opentelemetry.trace import get_tracer
from configs import dify_config
from extensions.otel.decorators.handler import SpanHandler
+from extensions.otel.runtime import is_instrument_flag_enabled
T = TypeVar("T", bound=Callable[..., Any])
_HANDLER_INSTANCES: dict[type[SpanHandler], SpanHandler] = {SpanHandler: SpanHandler()}
-def _is_instrument_flag_enabled() -> bool:
- """
- Check if external instrumentation is enabled via environment variable.
-
- Third-party non-invasive instrumentation agents set this flag to coordinate
- with Dify's manual OpenTelemetry instrumentation.
- """
- return os.getenv("ENABLE_OTEL_FOR_INSTRUMENT", "").strip().lower() == "true"
-
-
def _get_handler_instance(handler_class: type[SpanHandler]) -> SpanHandler:
"""Get or create a singleton instance of the handler class."""
if handler_class not in _HANDLER_INSTANCES:
@@ -43,7 +33,7 @@ def trace_span(handler_class: type[SpanHandler] | None = None) -> Callable[[T],
def decorator(func: T) -> T:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
- if not (dify_config.ENABLE_OTEL or _is_instrument_flag_enabled()):
+ if not (dify_config.ENABLE_OTEL or is_instrument_flag_enabled()):
return func(*args, **kwargs)
handler = _get_handler_instance(handler_class or SpanHandler)
diff --git a/api/extensions/otel/instrumentation.py b/api/extensions/otel/instrumentation.py
index 3597110cba..6617f69513 100644
--- a/api/extensions/otel/instrumentation.py
+++ b/api/extensions/otel/instrumentation.py
@@ -19,26 +19,43 @@ logger = logging.getLogger(__name__)
class ExceptionLoggingHandler(logging.Handler):
+ """
+ Handler that records exceptions to the current OpenTelemetry span.
+
+ Unlike creating a new span, this records exceptions on the existing span
+ to maintain trace context consistency throughout the request lifecycle.
+ """
+
def emit(self, record: logging.LogRecord):
with contextlib.suppress(Exception):
- if record.exc_info:
- tracer = get_tracer_provider().get_tracer("dify.exception.logging")
- with tracer.start_as_current_span(
- "log.exception",
- attributes={
- "log.level": record.levelname,
- "log.message": record.getMessage(),
- "log.logger": record.name,
- "log.file.path": record.pathname,
- "log.file.line": record.lineno,
- },
- ) as span:
- span.set_status(StatusCode.ERROR)
- if record.exc_info[1]:
- span.record_exception(record.exc_info[1])
- span.set_attribute("exception.message", str(record.exc_info[1]))
- if record.exc_info[0]:
- span.set_attribute("exception.type", record.exc_info[0].__name__)
+ if not record.exc_info:
+ return
+
+ from opentelemetry.trace import get_current_span
+
+ span = get_current_span()
+ if not span or not span.is_recording():
+ return
+
+ # Record exception on the current span instead of creating a new one
+ span.set_status(StatusCode.ERROR, record.getMessage())
+
+ # Add log context as span events/attributes
+ span.add_event(
+ "log.exception",
+ attributes={
+ "log.level": record.levelname,
+ "log.message": record.getMessage(),
+ "log.logger": record.name,
+ "log.file.path": record.pathname,
+ "log.file.line": record.lineno,
+ },
+ )
+
+ if record.exc_info[1]:
+ span.record_exception(record.exc_info[1])
+ if record.exc_info[0]:
+ span.set_attribute("exception.type", record.exc_info[0].__name__)
def instrument_exception_logging() -> None:
diff --git a/api/extensions/otel/parser/__init__.py b/api/extensions/otel/parser/__init__.py
new file mode 100644
index 0000000000..164db7c275
--- /dev/null
+++ b/api/extensions/otel/parser/__init__.py
@@ -0,0 +1,20 @@
+"""
+OpenTelemetry node parsers for workflow nodes.
+
+This module provides parsers that extract node-specific metadata and set
+OpenTelemetry span attributes according to semantic conventions.
+"""
+
+from extensions.otel.parser.base import DefaultNodeOTelParser, NodeOTelParser, safe_json_dumps
+from extensions.otel.parser.llm import LLMNodeOTelParser
+from extensions.otel.parser.retrieval import RetrievalNodeOTelParser
+from extensions.otel.parser.tool import ToolNodeOTelParser
+
+__all__ = [
+ "DefaultNodeOTelParser",
+ "LLMNodeOTelParser",
+ "NodeOTelParser",
+ "RetrievalNodeOTelParser",
+ "ToolNodeOTelParser",
+ "safe_json_dumps",
+]
diff --git a/api/extensions/otel/parser/base.py b/api/extensions/otel/parser/base.py
new file mode 100644
index 0000000000..f4db26e840
--- /dev/null
+++ b/api/extensions/otel/parser/base.py
@@ -0,0 +1,117 @@
+"""
+Base parser interface and utilities for OpenTelemetry node parsers.
+"""
+
+import json
+from typing import Any, Protocol
+
+from opentelemetry.trace import Span
+from opentelemetry.trace.status import Status, StatusCode
+from pydantic import BaseModel
+
+from core.file.models import File
+from core.variables import Segment
+from core.workflow.enums import NodeType
+from core.workflow.graph_events import GraphNodeEventBase
+from core.workflow.nodes.base.node import Node
+from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes
+
+
+def safe_json_dumps(obj: Any, ensure_ascii: bool = False) -> str:
+ """
+ Safely serialize objects to JSON, handling non-serializable types.
+
+ Handles:
+ - Segment types (ArrayFileSegment, FileSegment, etc.) - converts to their value
+ - File objects - converts to dict using to_dict()
+ - BaseModel objects - converts using model_dump()
+ - Other types - falls back to str() representation
+
+ Args:
+ obj: Object to serialize
+ ensure_ascii: Whether to ensure ASCII encoding
+
+ Returns:
+ JSON string representation of the object
+ """
+
+ def _convert_value(value: Any) -> Any:
+ """Recursively convert non-serializable values."""
+ if value is None:
+ return None
+ if isinstance(value, (bool, int, float, str)):
+ return value
+ if isinstance(value, Segment):
+ # Convert Segment to its underlying value
+ return _convert_value(value.value)
+ if isinstance(value, File):
+ # Convert File to dict
+ return value.to_dict()
+ if isinstance(value, BaseModel):
+ # Convert Pydantic model to dict
+ return _convert_value(value.model_dump(mode="json"))
+ if isinstance(value, dict):
+ return {k: _convert_value(v) for k, v in value.items()}
+ if isinstance(value, (list, tuple)):
+ return [_convert_value(item) for item in value]
+ # Fallback to string representation for unknown types
+ return str(value)
+
+ try:
+ converted = _convert_value(obj)
+ return json.dumps(converted, ensure_ascii=ensure_ascii)
+ except (TypeError, ValueError) as e:
+ # If conversion still fails, return error message as string
+ return json.dumps(
+ {"error": f"Failed to serialize: {type(obj).__name__}", "message": str(e)}, ensure_ascii=ensure_ascii
+ )
+
+
+class NodeOTelParser(Protocol):
+ """Parser interface for node-specific OpenTelemetry enrichment."""
+
+ def parse(
+ self, *, node: Node, span: "Span", error: Exception | None, result_event: GraphNodeEventBase | None = None
+ ) -> None: ...
+
+
+class DefaultNodeOTelParser:
+ """Fallback parser used when no node-specific parser is registered."""
+
+ def parse(
+ self, *, node: Node, span: "Span", error: Exception | None, result_event: GraphNodeEventBase | None = None
+ ) -> None:
+ span.set_attribute("node.id", node.id)
+ if node.execution_id:
+ span.set_attribute("node.execution_id", node.execution_id)
+ if hasattr(node, "node_type") and node.node_type:
+ span.set_attribute("node.type", node.node_type.value)
+
+ span.set_attribute(GenAIAttributes.FRAMEWORK, "dify")
+
+ node_type = getattr(node, "node_type", None)
+ if isinstance(node_type, NodeType):
+ if node_type == NodeType.LLM:
+ span.set_attribute(GenAIAttributes.SPAN_KIND, "LLM")
+ elif node_type == NodeType.KNOWLEDGE_RETRIEVAL:
+ span.set_attribute(GenAIAttributes.SPAN_KIND, "RETRIEVER")
+ elif node_type == NodeType.TOOL:
+ span.set_attribute(GenAIAttributes.SPAN_KIND, "TOOL")
+ else:
+ span.set_attribute(GenAIAttributes.SPAN_KIND, "TASK")
+ else:
+ span.set_attribute(GenAIAttributes.SPAN_KIND, "TASK")
+
+ # Extract inputs and outputs from result_event
+ if result_event and result_event.node_run_result:
+ node_run_result = result_event.node_run_result
+ if node_run_result.inputs:
+ span.set_attribute(ChainAttributes.INPUT_VALUE, safe_json_dumps(node_run_result.inputs))
+ if node_run_result.outputs:
+ span.set_attribute(ChainAttributes.OUTPUT_VALUE, safe_json_dumps(node_run_result.outputs))
+
+ if error:
+ span.record_exception(error)
+ span.set_status(Status(StatusCode.ERROR, str(error)))
+ else:
+ span.set_status(Status(StatusCode.OK))
diff --git a/api/extensions/otel/parser/llm.py b/api/extensions/otel/parser/llm.py
new file mode 100644
index 0000000000..8556974080
--- /dev/null
+++ b/api/extensions/otel/parser/llm.py
@@ -0,0 +1,155 @@
+"""
+Parser for LLM nodes that captures LLM-specific metadata.
+"""
+
+import logging
+from collections.abc import Mapping
+from typing import Any
+
+from opentelemetry.trace import Span
+
+from core.workflow.graph_events import GraphNodeEventBase
+from core.workflow.nodes.base.node import Node
+from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps
+from extensions.otel.semconv.gen_ai import LLMAttributes
+
+logger = logging.getLogger(__name__)
+
+
+def _format_input_messages(process_data: Mapping[str, Any]) -> str:
+ """
+ Format input messages from process_data for LLM spans.
+
+ Args:
+ process_data: Process data containing prompts
+
+ Returns:
+ JSON string of formatted input messages
+ """
+ try:
+ if not isinstance(process_data, dict):
+ return safe_json_dumps([])
+
+ prompts = process_data.get("prompts", [])
+ if not prompts:
+ return safe_json_dumps([])
+
+ valid_roles = {"system", "user", "assistant", "tool"}
+ input_messages = []
+ for prompt in prompts:
+ if not isinstance(prompt, dict):
+ continue
+
+ role = prompt.get("role", "")
+ text = prompt.get("text", "")
+
+ if not role or role not in valid_roles:
+ continue
+
+ if text:
+ message = {"role": role, "parts": [{"type": "text", "content": text}]}
+ input_messages.append(message)
+
+ return safe_json_dumps(input_messages)
+ except Exception as e:
+ logger.warning("Failed to format input messages: %s", e, exc_info=True)
+ return safe_json_dumps([])
+
+
+def _format_output_messages(outputs: Mapping[str, Any]) -> str:
+ """
+ Format output messages from outputs for LLM spans.
+
+ Args:
+ outputs: Output data containing text and finish_reason
+
+ Returns:
+ JSON string of formatted output messages
+ """
+ try:
+ if not isinstance(outputs, dict):
+ return safe_json_dumps([])
+
+ text = outputs.get("text", "")
+ finish_reason = outputs.get("finish_reason", "")
+
+ if not text:
+ return safe_json_dumps([])
+
+ valid_finish_reasons = {"stop", "length", "content_filter", "tool_call", "error"}
+ if finish_reason not in valid_finish_reasons:
+ finish_reason = "stop"
+
+ output_message = {
+ "role": "assistant",
+ "parts": [{"type": "text", "content": text}],
+ "finish_reason": finish_reason,
+ }
+
+ return safe_json_dumps([output_message])
+ except Exception as e:
+ logger.warning("Failed to format output messages: %s", e, exc_info=True)
+ return safe_json_dumps([])
+
+
+class LLMNodeOTelParser:
+ """Parser for LLM nodes that captures LLM-specific metadata."""
+
+ def __init__(self) -> None:
+ self._delegate = DefaultNodeOTelParser()
+
+ def parse(
+ self, *, node: Node, span: "Span", error: Exception | None, result_event: GraphNodeEventBase | None = None
+ ) -> None:
+ self._delegate.parse(node=node, span=span, error=error, result_event=result_event)
+
+ if not result_event or not result_event.node_run_result:
+ return
+
+ node_run_result = result_event.node_run_result
+ process_data = node_run_result.process_data or {}
+ outputs = node_run_result.outputs or {}
+
+ # Extract usage data (from process_data or outputs)
+ usage_data = process_data.get("usage") or outputs.get("usage") or {}
+
+ # Model and provider information
+ model_name = process_data.get("model_name") or ""
+ model_provider = process_data.get("model_provider") or ""
+
+ if model_name:
+ span.set_attribute(LLMAttributes.REQUEST_MODEL, model_name)
+ if model_provider:
+ span.set_attribute(LLMAttributes.PROVIDER_NAME, model_provider)
+
+ # Token usage
+ if usage_data:
+ prompt_tokens = usage_data.get("prompt_tokens", 0)
+ completion_tokens = usage_data.get("completion_tokens", 0)
+ total_tokens = usage_data.get("total_tokens", 0)
+
+ span.set_attribute(LLMAttributes.USAGE_INPUT_TOKENS, prompt_tokens)
+ span.set_attribute(LLMAttributes.USAGE_OUTPUT_TOKENS, completion_tokens)
+ span.set_attribute(LLMAttributes.USAGE_TOTAL_TOKENS, total_tokens)
+
+ # Prompts and completion
+ prompts = process_data.get("prompts", [])
+ if prompts:
+ prompts_json = safe_json_dumps(prompts)
+ span.set_attribute(LLMAttributes.PROMPT, prompts_json)
+
+ text_output = str(outputs.get("text", ""))
+ if text_output:
+ span.set_attribute(LLMAttributes.COMPLETION, text_output)
+
+ # Finish reason
+ finish_reason = outputs.get("finish_reason") or ""
+ if finish_reason:
+ span.set_attribute(LLMAttributes.RESPONSE_FINISH_REASON, finish_reason)
+
+ # Structured input/output messages
+ gen_ai_input_message = _format_input_messages(process_data)
+ gen_ai_output_message = _format_output_messages(outputs)
+
+ span.set_attribute(LLMAttributes.INPUT_MESSAGE, gen_ai_input_message)
+ span.set_attribute(LLMAttributes.OUTPUT_MESSAGE, gen_ai_output_message)
diff --git a/api/extensions/otel/parser/retrieval.py b/api/extensions/otel/parser/retrieval.py
new file mode 100644
index 0000000000..fc151af691
--- /dev/null
+++ b/api/extensions/otel/parser/retrieval.py
@@ -0,0 +1,105 @@
+"""
+Parser for knowledge retrieval nodes that captures retrieval-specific metadata.
+"""
+
+import logging
+from collections.abc import Sequence
+from typing import Any
+
+from opentelemetry.trace import Span
+
+from core.variables import Segment
+from core.workflow.graph_events import GraphNodeEventBase
+from core.workflow.nodes.base.node import Node
+from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps
+from extensions.otel.semconv.gen_ai import RetrieverAttributes
+
+logger = logging.getLogger(__name__)
+
+
+def _format_retrieval_documents(retrieval_documents: list[Any]) -> list:
+ """
+ Format retrieval documents for semantic conventions.
+
+ Args:
+ retrieval_documents: List of retrieval document dictionaries
+
+ Returns:
+ List of formatted semantic documents
+ """
+ try:
+ if not isinstance(retrieval_documents, list):
+ return []
+
+ semantic_documents = []
+ for doc in retrieval_documents:
+ if not isinstance(doc, dict):
+ continue
+
+ metadata = doc.get("metadata", {})
+ content = doc.get("content", "")
+ title = doc.get("title", "")
+ score = metadata.get("score", 0.0)
+ document_id = metadata.get("document_id", "")
+
+ semantic_metadata = {}
+ if title:
+ semantic_metadata["title"] = title
+ if metadata.get("source"):
+ semantic_metadata["source"] = metadata["source"]
+ elif metadata.get("_source"):
+ semantic_metadata["source"] = metadata["_source"]
+ if metadata.get("doc_metadata"):
+ doc_metadata = metadata["doc_metadata"]
+ if isinstance(doc_metadata, dict):
+ semantic_metadata.update(doc_metadata)
+
+ semantic_doc = {
+ "document": {"content": content, "metadata": semantic_metadata, "score": score, "id": document_id}
+ }
+ semantic_documents.append(semantic_doc)
+
+ return semantic_documents
+ except Exception as e:
+ logger.warning("Failed to format retrieval documents: %s", e, exc_info=True)
+ return []
+
+
+class RetrievalNodeOTelParser:
+ """Parser for knowledge retrieval nodes that captures retrieval-specific metadata."""
+
+ def __init__(self) -> None:
+ self._delegate = DefaultNodeOTelParser()
+
+ def parse(
+ self, *, node: Node, span: "Span", error: Exception | None, result_event: GraphNodeEventBase | None = None
+ ) -> None:
+ self._delegate.parse(node=node, span=span, error=error, result_event=result_event)
+
+ if not result_event or not result_event.node_run_result:
+ return
+
+ node_run_result = result_event.node_run_result
+ inputs = node_run_result.inputs or {}
+ outputs = node_run_result.outputs or {}
+
+ # Extract query from inputs
+ query = str(inputs.get("query", "")) if inputs else ""
+ if query:
+ span.set_attribute(RetrieverAttributes.QUERY, query)
+
+ # Extract and format retrieval documents from outputs
+ result_value = outputs.get("result") if outputs else None
+ retrieval_documents: list[Any] = []
+ if result_value:
+ value_to_check = result_value
+ if isinstance(result_value, Segment):
+ value_to_check = result_value.value
+
+ if isinstance(value_to_check, (list, Sequence)):
+ retrieval_documents = list(value_to_check)
+
+ if retrieval_documents:
+ semantic_retrieval_documents = _format_retrieval_documents(retrieval_documents)
+ semantic_retrieval_documents_json = safe_json_dumps(semantic_retrieval_documents)
+ span.set_attribute(RetrieverAttributes.DOCUMENT, semantic_retrieval_documents_json)
diff --git a/api/extensions/otel/parser/tool.py b/api/extensions/otel/parser/tool.py
new file mode 100644
index 0000000000..b99180722b
--- /dev/null
+++ b/api/extensions/otel/parser/tool.py
@@ -0,0 +1,47 @@
+"""
+Parser for tool nodes that captures tool-specific metadata.
+"""
+
+from opentelemetry.trace import Span
+
+from core.workflow.enums import WorkflowNodeExecutionMetadataKey
+from core.workflow.graph_events import GraphNodeEventBase
+from core.workflow.nodes.base.node import Node
+from core.workflow.nodes.tool.entities import ToolNodeData
+from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps
+from extensions.otel.semconv.gen_ai import ToolAttributes
+
+
+class ToolNodeOTelParser:
+ """Parser for tool nodes that captures tool-specific metadata."""
+
+ def __init__(self) -> None:
+ self._delegate = DefaultNodeOTelParser()
+
+ def parse(
+ self, *, node: Node, span: "Span", error: Exception | None, result_event: GraphNodeEventBase | None = None
+ ) -> None:
+ self._delegate.parse(node=node, span=span, error=error, result_event=result_event)
+
+ tool_data = getattr(node, "_node_data", None)
+ if not isinstance(tool_data, ToolNodeData):
+ return
+
+ span.set_attribute(ToolAttributes.TOOL_NAME, node.title)
+ span.set_attribute(ToolAttributes.TOOL_TYPE, tool_data.provider_type.value)
+
+ # Extract tool info from metadata (consistent with aliyun_trace)
+ tool_info = {}
+ if result_event and result_event.node_run_result:
+ node_run_result = result_event.node_run_result
+ if node_run_result.metadata:
+ tool_info = node_run_result.metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO, {})
+
+ if tool_info:
+ span.set_attribute(ToolAttributes.TOOL_DESCRIPTION, safe_json_dumps(tool_info))
+
+ if result_event and result_event.node_run_result and result_event.node_run_result.inputs:
+ span.set_attribute(ToolAttributes.TOOL_CALL_ARGUMENTS, safe_json_dumps(result_event.node_run_result.inputs))
+
+ if result_event and result_event.node_run_result and result_event.node_run_result.outputs:
+ span.set_attribute(ToolAttributes.TOOL_CALL_RESULT, safe_json_dumps(result_event.node_run_result.outputs))
diff --git a/api/extensions/otel/runtime.py b/api/extensions/otel/runtime.py
index 16f5ccf488..a7181d2683 100644
--- a/api/extensions/otel/runtime.py
+++ b/api/extensions/otel/runtime.py
@@ -1,4 +1,5 @@
import logging
+import os
import sys
from typing import Union
@@ -71,3 +72,13 @@ def init_celery_worker(*args, **kwargs):
if dify_config.DEBUG:
logger.info("Initializing OpenTelemetry for Celery worker")
CeleryInstrumentor(tracer_provider=tracer_provider, meter_provider=metric_provider).instrument()
+
+
+def is_instrument_flag_enabled() -> bool:
+ """
+ Check if external instrumentation is enabled via environment variable.
+
+ Third-party non-invasive instrumentation agents set this flag to coordinate
+ with Dify's manual OpenTelemetry instrumentation.
+ """
+ return os.getenv("ENABLE_OTEL_FOR_INSTRUMENT", "").strip().lower() == "true"
diff --git a/api/extensions/otel/semconv/__init__.py b/api/extensions/otel/semconv/__init__.py
index dc79dee222..0db3075815 100644
--- a/api/extensions/otel/semconv/__init__.py
+++ b/api/extensions/otel/semconv/__init__.py
@@ -1,6 +1,13 @@
"""Semantic convention shortcuts for Dify-specific spans."""
from .dify import DifySpanAttributes
-from .gen_ai import GenAIAttributes
+from .gen_ai import ChainAttributes, GenAIAttributes, LLMAttributes, RetrieverAttributes, ToolAttributes
-__all__ = ["DifySpanAttributes", "GenAIAttributes"]
+__all__ = [
+ "ChainAttributes",
+ "DifySpanAttributes",
+ "GenAIAttributes",
+ "LLMAttributes",
+ "RetrieverAttributes",
+ "ToolAttributes",
+]
diff --git a/api/extensions/otel/semconv/gen_ai.py b/api/extensions/otel/semconv/gen_ai.py
index 83c52ed34f..88c2058c06 100644
--- a/api/extensions/otel/semconv/gen_ai.py
+++ b/api/extensions/otel/semconv/gen_ai.py
@@ -62,3 +62,37 @@ class ToolAttributes:
TOOL_CALL_RESULT = "gen_ai.tool.call.result"
"""Tool invocation result."""
+
+
+class LLMAttributes:
+ """LLM operation attribute keys."""
+
+ REQUEST_MODEL = "gen_ai.request.model"
+ """Model identifier."""
+
+ PROVIDER_NAME = "gen_ai.provider.name"
+ """Provider name."""
+
+ USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
+ """Number of input tokens."""
+
+ USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
+ """Number of output tokens."""
+
+ USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
+ """Total number of tokens."""
+
+ PROMPT = "gen_ai.prompt"
+ """Prompt text."""
+
+ COMPLETION = "gen_ai.completion"
+ """Completion text."""
+
+ RESPONSE_FINISH_REASON = "gen_ai.response.finish_reason"
+ """Finish reason for the response."""
+
+ INPUT_MESSAGE = "gen_ai.input.messages"
+ """Input messages in structured format."""
+
+ OUTPUT_MESSAGE = "gen_ai.output.messages"
+ """Output messages in structured format."""
diff --git a/api/extensions/storage/aliyun_oss_storage.py b/api/extensions/storage/aliyun_oss_storage.py
index 2283581f62..3d7ef99c9e 100644
--- a/api/extensions/storage/aliyun_oss_storage.py
+++ b/api/extensions/storage/aliyun_oss_storage.py
@@ -26,6 +26,7 @@ class AliyunOssStorage(BaseStorage):
self.bucket_name,
connect_timeout=30,
region=region,
+ cloudbox_id=dify_config.ALIYUN_CLOUDBOX_ID,
)
def save(self, filename, data):
diff --git a/api/extensions/storage/clickzetta_volume/file_lifecycle.py b/api/extensions/storage/clickzetta_volume/file_lifecycle.py
index 51a97b20f8..1d9911465b 100644
--- a/api/extensions/storage/clickzetta_volume/file_lifecycle.py
+++ b/api/extensions/storage/clickzetta_volume/file_lifecycle.py
@@ -5,6 +5,8 @@ automatic cleanup, backup and restore.
Supports complete lifecycle management for knowledge base files.
"""
+from __future__ import annotations
+
import json
import logging
import operator
@@ -48,7 +50,7 @@ class FileMetadata:
return data
@classmethod
- def from_dict(cls, data: dict) -> "FileMetadata":
+ def from_dict(cls, data: dict) -> FileMetadata:
"""Create instance from dictionary"""
data = data.copy()
data["created_at"] = datetime.fromisoformat(data["created_at"])
diff --git a/api/extensions/storage/huawei_obs_storage.py b/api/extensions/storage/huawei_obs_storage.py
index 74fed26f65..72cb59abbe 100644
--- a/api/extensions/storage/huawei_obs_storage.py
+++ b/api/extensions/storage/huawei_obs_storage.py
@@ -17,6 +17,7 @@ class HuaweiObsStorage(BaseStorage):
access_key_id=dify_config.HUAWEI_OBS_ACCESS_KEY,
secret_access_key=dify_config.HUAWEI_OBS_SECRET_KEY,
server=dify_config.HUAWEI_OBS_SERVER,
+ path_style=dify_config.HUAWEI_OBS_PATH_STYLE,
)
def save(self, filename, data):
diff --git a/api/extensions/storage/opendal_storage.py b/api/extensions/storage/opendal_storage.py
index a084844d72..83c5c2d12f 100644
--- a/api/extensions/storage/opendal_storage.py
+++ b/api/extensions/storage/opendal_storage.py
@@ -87,15 +87,16 @@ class OpenDALStorage(BaseStorage):
if not self.exists(path):
raise FileNotFoundError("Path not found")
- all_files = self.op.scan(path=path)
+ # Use the new OpenDAL 0.46.0+ API with recursive listing
+ lister = self.op.list(path, recursive=True)
if files and directories:
logger.debug("files and directories on %s scanned", path)
- return [f.path for f in all_files]
+ return [entry.path for entry in lister]
if files:
logger.debug("files on %s scanned", path)
- return [f.path for f in all_files if not f.path.endswith("/")]
+ return [entry.path for entry in lister if not entry.metadata.is_dir]
elif directories:
logger.debug("directories on %s scanned", path)
- return [f.path for f in all_files if f.path.endswith("/")]
+ return [entry.path for entry in lister if entry.metadata.is_dir]
else:
raise ValueError("At least one of files or directories must be True")
diff --git a/api/extensions/storage/tencent_cos_storage.py b/api/extensions/storage/tencent_cos_storage.py
index ea5d982efc..cf092c6973 100644
--- a/api/extensions/storage/tencent_cos_storage.py
+++ b/api/extensions/storage/tencent_cos_storage.py
@@ -13,12 +13,20 @@ class TencentCosStorage(BaseStorage):
super().__init__()
self.bucket_name = dify_config.TENCENT_COS_BUCKET_NAME
- config = CosConfig(
- Region=dify_config.TENCENT_COS_REGION,
- SecretId=dify_config.TENCENT_COS_SECRET_ID,
- SecretKey=dify_config.TENCENT_COS_SECRET_KEY,
- Scheme=dify_config.TENCENT_COS_SCHEME,
- )
+ if dify_config.TENCENT_COS_CUSTOM_DOMAIN:
+ config = CosConfig(
+ Domain=dify_config.TENCENT_COS_CUSTOM_DOMAIN,
+ SecretId=dify_config.TENCENT_COS_SECRET_ID,
+ SecretKey=dify_config.TENCENT_COS_SECRET_KEY,
+ Scheme=dify_config.TENCENT_COS_SCHEME,
+ )
+ else:
+ config = CosConfig(
+ Region=dify_config.TENCENT_COS_REGION,
+ SecretId=dify_config.TENCENT_COS_SECRET_ID,
+ SecretKey=dify_config.TENCENT_COS_SECRET_KEY,
+ Scheme=dify_config.TENCENT_COS_SCHEME,
+ )
self.client = CosS3Client(config)
def save(self, filename, data):
diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py
index 737a79f2b0..0be836c8f1 100644
--- a/api/factories/file_factory.py
+++ b/api/factories/file_factory.py
@@ -1,3 +1,4 @@
+import logging
import mimetypes
import os
import re
@@ -17,6 +18,8 @@ from core.helper import ssrf_proxy
from extensions.ext_database import db
from models import MessageFile, ToolFile, UploadFile
+logger = logging.getLogger(__name__)
+
def build_from_message_files(
*,
@@ -112,7 +115,18 @@ def build_from_mappings(
# TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query.
# Implement batch processing to reduce database load when handling multiple files.
# Filter out None/empty mappings to avoid errors
- valid_mappings = [m for m in mappings if m and m.get("transfer_method")]
+ def is_valid_mapping(m: Mapping[str, Any]) -> bool:
+ if not m or not m.get("transfer_method"):
+ return False
+ # For REMOTE_URL transfer method, ensure url or remote_url is provided and not None
+ transfer_method = m.get("transfer_method")
+ if transfer_method == FileTransferMethod.REMOTE_URL:
+ url = m.get("url") or m.get("remote_url")
+ if not url:
+ return False
+ return True
+
+ valid_mappings = [m for m in mappings if is_valid_mapping(m)]
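+    # Illustrative (comment only; literal enum values assumed): a remote-file mapping without a URL is
+    # dropped here instead of failing later, e.g.
+    #   {"transfer_method": "remote_url", "url": None}            -> filtered out
+    #   {"transfer_method": "local_file", "upload_file_id": "id"} -> kept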
files = [
build_from_mapping(
mapping=mapping,
@@ -356,15 +370,20 @@ def _build_from_tool_file(
transfer_method: FileTransferMethod,
strict_type_validation: bool = False,
) -> File:
+    # The tool_file_id must be present in the mapping before querying the database
+    tool_file_id = mapping.get("tool_file_id")
+
+    if not tool_file_id:
+        raise ValueError("tool_file_id is missing from the file mapping")
tool_file = db.session.scalar(
select(ToolFile).where(
- ToolFile.id == mapping.get("tool_file_id"),
+ ToolFile.id == tool_file_id,
ToolFile.tenant_id == tenant_id,
)
)
if tool_file is None:
- raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")
+ raise ValueError(f"ToolFile {tool_file_id} not found")
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
@@ -402,10 +421,13 @@ def _build_from_datasource_file(
transfer_method: FileTransferMethod,
strict_type_validation: bool = False,
) -> File:
+ datasource_file_id = mapping.get("datasource_file_id")
+ if not datasource_file_id:
+ raise ValueError(f"DatasourceFile {datasource_file_id} not found")
datasource_file = (
db.session.query(UploadFile)
.where(
- UploadFile.id == mapping.get("datasource_file_id"),
+ UploadFile.id == datasource_file_id,
UploadFile.tenant_id == tenant_id,
)
.first()
diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py
index 494194369a..3f030ae127 100644
--- a/api/factories/variable_factory.py
+++ b/api/factories/variable_factory.py
@@ -38,7 +38,7 @@ from core.variables.variables import (
ObjectVariable,
SecretVariable,
StringVariable,
- Variable,
+ VariableBase,
)
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
@@ -72,25 +72,25 @@ SEGMENT_TO_VARIABLE_MAP = {
}
-def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
+def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
if not mapping.get("name"):
raise VariableError("missing name")
return _build_variable_from_mapping(mapping=mapping, selector=[CONVERSATION_VARIABLE_NODE_ID, mapping["name"]])
-def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
+def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
if not mapping.get("name"):
raise VariableError("missing name")
return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]])
-def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
+def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
if not mapping.get("variable"):
raise VariableError("missing variable")
return mapping["variable"]
-def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
+def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> VariableBase:
"""
This factory function is used to create the environment variable or the conversation variable,
not support the File type.
@@ -100,7 +100,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
if (value := mapping.get("value")) is None:
raise VariableError("missing value")
- result: Variable
+ result: VariableBase
match value_type:
case SegmentType.STRING:
result = StringVariable.model_validate(mapping)
@@ -134,7 +134,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
if not result.selector:
result = result.model_copy(update={"selector": selector})
- return cast(Variable, result)
+ return cast(VariableBase, result)
def build_segment(value: Any, /) -> Segment:
@@ -285,8 +285,8 @@ def segment_to_variable(
id: str | None = None,
name: str | None = None,
description: str = "",
-) -> Variable:
- if isinstance(segment, Variable):
+) -> VariableBase:
+ if isinstance(segment, VariableBase):
return segment
name = name or selector[-1]
id = id or str(uuid4())
@@ -297,7 +297,7 @@ def segment_to_variable(
variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
return cast(
- Variable,
+ VariableBase,
variable_class(
id=id,
name=name,
diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py
index 38835d5ac7..e69306dcb2 100644
--- a/api/fields/annotation_fields.py
+++ b/api/fields/annotation_fields.py
@@ -1,4 +1,4 @@
-from flask_restx import Api, Namespace, fields
+from flask_restx import Namespace, fields
from libs.helper import TimestampField
@@ -12,7 +12,7 @@ annotation_fields = {
}
-def build_annotation_model(api_or_ns: Api | Namespace):
+def build_annotation_model(api_or_ns: Namespace):
"""Build the annotation model for the API or Namespace."""
return api_or_ns.model("Annotation", annotation_fields)
diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py
index ecc267cf38..d8ae0ad8b8 100644
--- a/api/fields/conversation_fields.py
+++ b/api/fields/conversation_fields.py
@@ -1,236 +1,338 @@
-from flask_restx import Api, Namespace, fields
+from __future__ import annotations
-from fields.member_fields import simple_account_fields
-from libs.helper import TimestampField
+from datetime import datetime
+from typing import Any, TypeAlias
-from .raws import FilesContainedField
+from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
+
+from core.file import File
+
+JSONValue: TypeAlias = Any
-class MessageTextField(fields.Raw):
- def format(self, value):
- return value[0]["text"] if value else ""
+class ResponseModel(BaseModel):
+ model_config = ConfigDict(
+ from_attributes=True,
+ extra="ignore",
+ populate_by_name=True,
+ serialize_by_alias=True,
+ protected_namespaces=(),
+ )
-feedback_fields = {
- "rating": fields.String,
- "content": fields.String,
- "from_source": fields.String,
- "from_end_user_id": fields.String,
- "from_account": fields.Nested(simple_account_fields, allow_null=True),
-}
+class MessageFile(ResponseModel):
+ id: str
+ filename: str
+ type: str
+ url: str | None = None
+ mime_type: str | None = None
+ size: int | None = None
+ transfer_method: str
+ belongs_to: str | None = None
+ upload_file_id: str | None = None
-annotation_fields = {
- "id": fields.String,
- "question": fields.String,
- "content": fields.String,
- "account": fields.Nested(simple_account_fields, allow_null=True),
- "created_at": TimestampField,
-}
-
-annotation_hit_history_fields = {
- "annotation_id": fields.String(attribute="id"),
- "annotation_create_account": fields.Nested(simple_account_fields, allow_null=True),
- "created_at": TimestampField,
-}
-
-message_file_fields = {
- "id": fields.String,
- "filename": fields.String,
- "type": fields.String,
- "url": fields.String,
- "mime_type": fields.String,
- "size": fields.Integer,
- "transfer_method": fields.String,
- "belongs_to": fields.String(default="user"),
- "upload_file_id": fields.String(default=None),
-}
+ @field_validator("transfer_method", mode="before")
+ @classmethod
+ def _normalize_transfer_method(cls, value: object) -> str:
+ if isinstance(value, str):
+ return value
+ return str(value)
-def build_message_file_model(api_or_ns: Api | Namespace):
- """Build the message file fields for the API or Namespace."""
- return api_or_ns.model("MessageFile", message_file_fields)
+class SimpleConversation(ResponseModel):
+ id: str
+ name: str
+ inputs: dict[str, JSONValue]
+ status: str
+ introduction: str | None = None
+ created_at: int | None = None
+ updated_at: int | None = None
+
+ @field_validator("inputs", mode="before")
+ @classmethod
+ def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
+ return format_files_contained(value)
+
+ @field_validator("created_at", "updated_at", mode="before")
+ @classmethod
+ def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+ if isinstance(value, datetime):
+ return to_timestamp(value)
+ return value
-agent_thought_fields = {
- "id": fields.String,
- "chain_id": fields.String,
- "message_id": fields.String,
- "position": fields.Integer,
- "thought": fields.String,
- "tool": fields.String,
- "tool_labels": fields.Raw,
- "tool_input": fields.String,
- "created_at": TimestampField,
- "observation": fields.String,
- "files": fields.List(fields.String),
-}
-
-message_detail_fields = {
- "id": fields.String,
- "conversation_id": fields.String,
- "inputs": FilesContainedField,
- "query": fields.String,
- "message": fields.Raw,
- "message_tokens": fields.Integer,
- "answer": fields.String(attribute="re_sign_file_url_answer"),
- "answer_tokens": fields.Integer,
- "provider_response_latency": fields.Float,
- "from_source": fields.String,
- "from_end_user_id": fields.String,
- "from_account_id": fields.String,
- "feedbacks": fields.List(fields.Nested(feedback_fields)),
- "workflow_run_id": fields.String,
- "annotation": fields.Nested(annotation_fields, allow_null=True),
- "annotation_hit_history": fields.Nested(annotation_hit_history_fields, allow_null=True),
- "created_at": TimestampField,
- "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
- "message_files": fields.List(fields.Nested(message_file_fields)),
- "metadata": fields.Raw(attribute="message_metadata_dict"),
- "status": fields.String,
- "error": fields.String,
- "parent_message_id": fields.String,
-}
-
-feedback_stat_fields = {"like": fields.Integer, "dislike": fields.Integer}
-status_count_fields = {"success": fields.Integer, "failed": fields.Integer, "partial_success": fields.Integer}
-model_config_fields = {
- "opening_statement": fields.String,
- "suggested_questions": fields.Raw,
- "model": fields.Raw,
- "user_input_form": fields.Raw,
- "pre_prompt": fields.String,
- "agent_mode": fields.Raw,
-}
-
-simple_model_config_fields = {
- "model": fields.Raw(attribute="model_dict"),
- "pre_prompt": fields.String,
-}
-
-simple_message_detail_fields = {
- "inputs": FilesContainedField,
- "query": fields.String,
- "message": MessageTextField,
- "answer": fields.String,
-}
-
-conversation_fields = {
- "id": fields.String,
- "status": fields.String,
- "from_source": fields.String,
- "from_end_user_id": fields.String,
- "from_end_user_session_id": fields.String(),
- "from_account_id": fields.String,
- "from_account_name": fields.String,
- "read_at": TimestampField,
- "created_at": TimestampField,
- "updated_at": TimestampField,
- "annotation": fields.Nested(annotation_fields, allow_null=True),
- "model_config": fields.Nested(simple_model_config_fields),
- "user_feedback_stats": fields.Nested(feedback_stat_fields),
- "admin_feedback_stats": fields.Nested(feedback_stat_fields),
- "message": fields.Nested(simple_message_detail_fields, attribute="first_message"),
-}
-
-conversation_pagination_fields = {
- "page": fields.Integer,
- "limit": fields.Integer(attribute="per_page"),
- "total": fields.Integer,
- "has_more": fields.Boolean(attribute="has_next"),
- "data": fields.List(fields.Nested(conversation_fields), attribute="items"),
-}
-
-conversation_message_detail_fields = {
- "id": fields.String,
- "status": fields.String,
- "from_source": fields.String,
- "from_end_user_id": fields.String,
- "from_account_id": fields.String,
- "created_at": TimestampField,
- "model_config": fields.Nested(model_config_fields),
- "message": fields.Nested(message_detail_fields, attribute="first_message"),
-}
-
-conversation_with_summary_fields = {
- "id": fields.String,
- "status": fields.String,
- "from_source": fields.String,
- "from_end_user_id": fields.String,
- "from_end_user_session_id": fields.String,
- "from_account_id": fields.String,
- "from_account_name": fields.String,
- "name": fields.String,
- "summary": fields.String(attribute="summary_or_query"),
- "read_at": TimestampField,
- "created_at": TimestampField,
- "updated_at": TimestampField,
- "annotated": fields.Boolean,
- "model_config": fields.Nested(simple_model_config_fields),
- "message_count": fields.Integer,
- "user_feedback_stats": fields.Nested(feedback_stat_fields),
- "admin_feedback_stats": fields.Nested(feedback_stat_fields),
- "status_count": fields.Nested(status_count_fields),
-}
-
-conversation_with_summary_pagination_fields = {
- "page": fields.Integer,
- "limit": fields.Integer(attribute="per_page"),
- "total": fields.Integer,
- "has_more": fields.Boolean(attribute="has_next"),
- "data": fields.List(fields.Nested(conversation_with_summary_fields), attribute="items"),
-}
-
-conversation_detail_fields = {
- "id": fields.String,
- "status": fields.String,
- "from_source": fields.String,
- "from_end_user_id": fields.String,
- "from_account_id": fields.String,
- "created_at": TimestampField,
- "updated_at": TimestampField,
- "annotated": fields.Boolean,
- "introduction": fields.String,
- "model_config": fields.Nested(model_config_fields),
- "message_count": fields.Integer,
- "user_feedback_stats": fields.Nested(feedback_stat_fields),
- "admin_feedback_stats": fields.Nested(feedback_stat_fields),
-}
-
-simple_conversation_fields = {
- "id": fields.String,
- "name": fields.String,
- "inputs": FilesContainedField,
- "status": fields.String,
- "introduction": fields.String,
- "created_at": TimestampField,
- "updated_at": TimestampField,
-}
-
-conversation_delete_fields = {
- "result": fields.String,
-}
-
-conversation_infinite_scroll_pagination_fields = {
- "limit": fields.Integer,
- "has_more": fields.Boolean,
- "data": fields.List(fields.Nested(simple_conversation_fields)),
-}
+class ConversationInfiniteScrollPagination(ResponseModel):
+ limit: int
+ has_more: bool
+ data: list[SimpleConversation]
-def build_conversation_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
- """Build the conversation infinite scroll pagination model for the API or Namespace."""
- simple_conversation_model = build_simple_conversation_model(api_or_ns)
-
- copied_fields = conversation_infinite_scroll_pagination_fields.copy()
- copied_fields["data"] = fields.List(fields.Nested(simple_conversation_model))
- return api_or_ns.model("ConversationInfiniteScrollPagination", copied_fields)
+class ConversationDelete(ResponseModel):
+ result: str
-def build_conversation_delete_model(api_or_ns: Api | Namespace):
- """Build the conversation delete model for the API or Namespace."""
- return api_or_ns.model("ConversationDelete", conversation_delete_fields)
+class ResultResponse(ResponseModel):
+ result: str
-def build_simple_conversation_model(api_or_ns: Api | Namespace):
- """Build the simple conversation model for the API or Namespace."""
- return api_or_ns.model("SimpleConversation", simple_conversation_fields)
+class SimpleAccount(ResponseModel):
+ id: str
+ name: str
+ email: str
+
+
+class Feedback(ResponseModel):
+ rating: str
+ content: str | None = None
+ from_source: str
+ from_end_user_id: str | None = None
+ from_account: SimpleAccount | None = None
+
+
+class Annotation(ResponseModel):
+ id: str
+ question: str | None = None
+ content: str
+ account: SimpleAccount | None = None
+ created_at: int | None = None
+
+ @field_validator("created_at", mode="before")
+ @classmethod
+ def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
+ if isinstance(value, datetime):
+ return to_timestamp(value)
+ return value
+
+
+class AnnotationHitHistory(ResponseModel):
+ annotation_id: str
+ annotation_create_account: SimpleAccount | None = None
+ created_at: int | None = None
+
+ @field_validator("created_at", mode="before")
+ @classmethod
+ def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
+ if isinstance(value, datetime):
+ return to_timestamp(value)
+ return value
+
+
+class AgentThought(ResponseModel):
+ id: str
+ chain_id: str | None = None
+ message_chain_id: str | None = Field(default=None, exclude=True, validation_alias="message_chain_id")
+ message_id: str
+ position: int
+ thought: str | None = None
+ tool: str | None = None
+ tool_labels: JSONValue
+ tool_input: str | None = None
+ created_at: int | None = None
+ observation: str | None = None
+ files: list[str]
+
+ @field_validator("created_at", mode="before")
+ @classmethod
+ def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
+ if isinstance(value, datetime):
+ return to_timestamp(value)
+ return value
+
+ @model_validator(mode="after")
+ def _fallback_chain_id(self):
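+        # Fall back to the message_chain_id attribute when chain_id is not set.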
+ if self.chain_id is None and self.message_chain_id:
+ self.chain_id = self.message_chain_id
+ return self
+
+
+class MessageDetail(ResponseModel):
+ id: str
+ conversation_id: str
+ inputs: dict[str, JSONValue]
+ query: str
+ message: JSONValue
+ message_tokens: int
+ answer: str
+ answer_tokens: int
+ provider_response_latency: float
+ from_source: str
+ from_end_user_id: str | None = None
+ from_account_id: str | None = None
+ feedbacks: list[Feedback]
+ workflow_run_id: str | None = None
+ annotation: Annotation | None = None
+ annotation_hit_history: AnnotationHitHistory | None = None
+ created_at: int | None = None
+ agent_thoughts: list[AgentThought]
+ message_files: list[MessageFile]
+ metadata: JSONValue
+ status: str
+ error: str | None = None
+ parent_message_id: str | None = None
+
+ @field_validator("inputs", mode="before")
+ @classmethod
+ def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
+ return format_files_contained(value)
+
+ @field_validator("created_at", mode="before")
+ @classmethod
+ def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
+ if isinstance(value, datetime):
+ return to_timestamp(value)
+ return value
+
+
+class FeedbackStat(ResponseModel):
+ like: int
+ dislike: int
+
+
+class StatusCount(ResponseModel):
+ success: int
+ failed: int
+ partial_success: int
+
+
+class ModelConfig(ResponseModel):
+ opening_statement: str | None = None
+ suggested_questions: JSONValue | None = None
+ model: JSONValue | None = None
+ user_input_form: JSONValue | None = None
+ pre_prompt: str | None = None
+ agent_mode: JSONValue | None = None
+
+
+class SimpleModelConfig(ResponseModel):
+ model: JSONValue | None = None
+ pre_prompt: str | None = None
+
+
+class SimpleMessageDetail(ResponseModel):
+ inputs: dict[str, JSONValue]
+ query: str
+ message: str
+ answer: str
+
+ @field_validator("inputs", mode="before")
+ @classmethod
+ def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
+ return format_files_contained(value)
+
+
+class Conversation(ResponseModel):
+ id: str
+ status: str
+ from_source: str
+ from_end_user_id: str | None = None
+ from_end_user_session_id: str | None = None
+ from_account_id: str | None = None
+ from_account_name: str | None = None
+ read_at: int | None = None
+ created_at: int | None = None
+ updated_at: int | None = None
+ annotation: Annotation | None = None
+ model_config_: SimpleModelConfig | None = Field(default=None, alias="model_config")
+ user_feedback_stats: FeedbackStat | None = None
+ admin_feedback_stats: FeedbackStat | None = None
+ message: SimpleMessageDetail | None = None
+
+
+class ConversationPagination(ResponseModel):
+ page: int
+ limit: int
+ total: int
+ has_more: bool
+ data: list[Conversation]
+
+
+class ConversationMessageDetail(ResponseModel):
+ id: str
+ status: str
+ from_source: str
+ from_end_user_id: str | None = None
+ from_account_id: str | None = None
+ created_at: int | None = None
+ model_config_: ModelConfig | None = Field(default=None, alias="model_config")
+ message: MessageDetail | None = None
+
+
+class ConversationWithSummary(ResponseModel):
+ id: str
+ status: str
+ from_source: str
+ from_end_user_id: str | None = None
+ from_end_user_session_id: str | None = None
+ from_account_id: str | None = None
+ from_account_name: str | None = None
+ name: str
+ summary: str
+ read_at: int | None = None
+ created_at: int | None = None
+ updated_at: int | None = None
+ annotated: bool
+ model_config_: SimpleModelConfig | None = Field(default=None, alias="model_config")
+ message_count: int
+ user_feedback_stats: FeedbackStat | None = None
+ admin_feedback_stats: FeedbackStat | None = None
+ status_count: StatusCount | None = None
+
+
+class ConversationWithSummaryPagination(ResponseModel):
+ page: int
+ limit: int
+ total: int
+ has_more: bool
+ data: list[ConversationWithSummary]
+
+
+class ConversationDetail(ResponseModel):
+ id: str
+ status: str
+ from_source: str
+ from_end_user_id: str | None = None
+ from_account_id: str | None = None
+ created_at: int | None = None
+ updated_at: int | None = None
+ annotated: bool
+ introduction: str | None = None
+ model_config_: ModelConfig | None = Field(default=None, alias="model_config")
+ message_count: int
+ user_feedback_stats: FeedbackStat | None = None
+ admin_feedback_stats: FeedbackStat | None = None
+
+
+def to_timestamp(value: datetime | None) -> int | None:
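+    """Convert a datetime to a Unix timestamp in seconds, passing None through."""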
+ if value is None:
+ return None
+ return int(value.timestamp())
+
+
+def format_files_contained(value: JSONValue) -> JSONValue:
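+    """Recursively replace any File objects nested in dicts/lists with their dict representation."""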
+ if isinstance(value, File):
+ return value.model_dump()
+ if isinstance(value, dict):
+ return {k: format_files_contained(v) for k, v in value.items()}
+ if isinstance(value, list):
+ return [format_files_contained(v) for v in value]
+ return value
+
+
+def message_text(value: JSONValue) -> str:
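+    """Return the 'text' of the first message part when value is a non-empty list; otherwise an empty string."""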
+ if isinstance(value, list) and value:
+ first = value[0]
+ if isinstance(first, dict):
+ text = first.get("text")
+ if isinstance(text, str):
+ return text
+ return ""
+
+
+def extract_model_config(value: object | None) -> dict[str, JSONValue]:
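+    """Normalize a model config source (None, dict, or object exposing to_dict()) into a plain dict."""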
+ if value is None:
+ return {}
+ if isinstance(value, dict):
+ return value
+ if hasattr(value, "to_dict"):
+ return value.to_dict()
+ return {}
diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py
index 7d5e311591..c55014a368 100644
--- a/api/fields/conversation_variable_fields.py
+++ b/api/fields/conversation_variable_fields.py
@@ -1,4 +1,4 @@
-from flask_restx import Api, Namespace, fields
+from flask_restx import Namespace, fields
from libs.helper import TimestampField
@@ -29,12 +29,12 @@ conversation_variable_infinite_scroll_pagination_fields = {
}
-def build_conversation_variable_model(api_or_ns: Api | Namespace):
+def build_conversation_variable_model(api_or_ns: Namespace):
"""Build the conversation variable model for the API or Namespace."""
return api_or_ns.model("ConversationVariable", conversation_variable_fields)
-def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
+def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Namespace):
"""Build the conversation variable infinite scroll pagination model for the API or Namespace."""
# Build the nested variable model first
conversation_variable_model = build_conversation_variable_model(api_or_ns)
diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py
index 1e5ec7d200..ff6578098b 100644
--- a/api/fields/dataset_fields.py
+++ b/api/fields/dataset_fields.py
@@ -39,6 +39,14 @@ dataset_retrieval_model_fields = {
"score_threshold_enabled": fields.Boolean,
"score_threshold": fields.Float,
}
+
+dataset_summary_index_fields = {
+ "enable": fields.Boolean,
+ "model_name": fields.String,
+ "model_provider_name": fields.String,
+ "summary_prompt": fields.String,
+}
+
external_retrieval_model_fields = {
"top_k": fields.Integer,
"score_threshold": fields.Float,
@@ -83,6 +91,7 @@ dataset_detail_fields = {
"embedding_model_provider": fields.String,
"embedding_available": fields.Boolean,
"retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields),
+ "summary_index_setting": fields.Nested(dataset_summary_index_fields),
"tags": fields.List(fields.Nested(tag_fields)),
"doc_form": fields.String,
"external_knowledge_info": fields.Nested(external_knowledge_info_fields),
diff --git a/api/fields/document_fields.py b/api/fields/document_fields.py
index 9be59f7454..35a2a04f3e 100644
--- a/api/fields/document_fields.py
+++ b/api/fields/document_fields.py
@@ -33,6 +33,11 @@ document_fields = {
"hit_count": fields.Integer,
"doc_form": fields.String,
"doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"),
+    # Summary index generation status:
+    # e.g. "SUMMARIZING" while the generation task is queued or running
+ "summary_index_status": fields.String,
+ # Whether this document needs summary index generation
+ "need_summary": fields.Boolean,
}
document_with_segments_fields = {
@@ -60,6 +65,10 @@ document_with_segments_fields = {
"completed_segments": fields.Integer,
"total_segments": fields.Integer,
"doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"),
+    # Summary index generation status:
+    # e.g. "SUMMARIZING" while the generation task is queued or running
+ "summary_index_status": fields.String,
+ "need_summary": fields.Boolean, # Whether this document needs summary index generation
}
dataset_and_document_fields = {
diff --git a/api/fields/end_user_fields.py b/api/fields/end_user_fields.py
index ea43e3b5fd..5389b0213a 100644
--- a/api/fields/end_user_fields.py
+++ b/api/fields/end_user_fields.py
@@ -1,4 +1,4 @@
-from flask_restx import Api, Namespace, fields
+from flask_restx import Namespace, fields
simple_end_user_fields = {
"id": fields.String,
@@ -8,5 +8,5 @@ simple_end_user_fields = {
}
-def build_simple_end_user_model(api_or_ns: Api | Namespace):
+def build_simple_end_user_model(api_or_ns: Namespace):
return api_or_ns.model("SimpleEndUser", simple_end_user_fields)
diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py
index a707500445..913fb675f9 100644
--- a/api/fields/file_fields.py
+++ b/api/fields/file_fields.py
@@ -1,93 +1,85 @@
-from flask_restx import Api, Namespace, fields
+from __future__ import annotations
-from libs.helper import TimestampField
+from datetime import datetime
-upload_config_fields = {
- "file_size_limit": fields.Integer,
- "batch_count_limit": fields.Integer,
- "image_file_size_limit": fields.Integer,
- "video_file_size_limit": fields.Integer,
- "audio_file_size_limit": fields.Integer,
- "workflow_file_upload_limit": fields.Integer,
- "image_file_batch_limit": fields.Integer,
- "single_chunk_attachment_limit": fields.Integer,
-}
+from pydantic import BaseModel, ConfigDict, field_validator
-def build_upload_config_model(api_or_ns: Api | Namespace):
- """Build the upload config model for the API or Namespace.
-
- Args:
- api_or_ns: Flask-RestX Api or Namespace instance
-
- Returns:
- The registered model
- """
- return api_or_ns.model("UploadConfig", upload_config_fields)
+class ResponseModel(BaseModel):
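+    """Base class for response schemas: populated from ORM attributes, extra fields ignored, serialized by alias."""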
+ model_config = ConfigDict(
+ from_attributes=True,
+ extra="ignore",
+ populate_by_name=True,
+ serialize_by_alias=True,
+ protected_namespaces=(),
+ )
-file_fields = {
- "id": fields.String,
- "name": fields.String,
- "size": fields.Integer,
- "extension": fields.String,
- "mime_type": fields.String,
- "created_by": fields.String,
- "created_at": TimestampField,
- "preview_url": fields.String,
- "source_url": fields.String,
-}
+def _to_timestamp(value: datetime | int | None) -> int | None:
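+    """Convert datetime values to Unix timestamps; ints and None pass through unchanged."""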
+ if isinstance(value, datetime):
+ return int(value.timestamp())
+ return value
-def build_file_model(api_or_ns: Api | Namespace):
- """Build the file model for the API or Namespace.
-
- Args:
- api_or_ns: Flask-RestX Api or Namespace instance
-
- Returns:
- The registered model
- """
- return api_or_ns.model("File", file_fields)
+class UploadConfig(ResponseModel):
+ file_size_limit: int
+ batch_count_limit: int
+ file_upload_limit: int | None = None
+ image_file_size_limit: int
+ video_file_size_limit: int
+ audio_file_size_limit: int
+ workflow_file_upload_limit: int
+ image_file_batch_limit: int
+ single_chunk_attachment_limit: int
+ attachment_image_file_size_limit: int | None = None
-remote_file_info_fields = {
- "file_type": fields.String(attribute="file_type"),
- "file_length": fields.Integer(attribute="file_length"),
-}
+class FileResponse(ResponseModel):
+ id: str
+ name: str
+ size: int
+ extension: str | None = None
+ mime_type: str | None = None
+ created_by: str | None = None
+ created_at: int | None = None
+ preview_url: str | None = None
+ source_url: str | None = None
+ original_url: str | None = None
+ user_id: str | None = None
+ tenant_id: str | None = None
+ conversation_id: str | None = None
+ file_key: str | None = None
+
+ @field_validator("created_at", mode="before")
+ @classmethod
+ def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
+ return _to_timestamp(value)
-def build_remote_file_info_model(api_or_ns: Api | Namespace):
- """Build the remote file info model for the API or Namespace.
-
- Args:
- api_or_ns: Flask-RestX Api or Namespace instance
-
- Returns:
- The registered model
- """
- return api_or_ns.model("RemoteFileInfo", remote_file_info_fields)
+class RemoteFileInfo(ResponseModel):
+ file_type: str
+ file_length: int
-file_fields_with_signed_url = {
- "id": fields.String,
- "name": fields.String,
- "size": fields.Integer,
- "extension": fields.String,
- "url": fields.String,
- "mime_type": fields.String,
- "created_by": fields.String,
- "created_at": TimestampField,
-}
+class FileWithSignedUrl(ResponseModel):
+ id: str
+ name: str
+ size: int
+ extension: str | None = None
+ url: str | None = None
+ mime_type: str | None = None
+ created_by: str | None = None
+ created_at: int | None = None
+
+ @field_validator("created_at", mode="before")
+ @classmethod
+ def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
+ return _to_timestamp(value)
-def build_file_with_signed_url_model(api_or_ns: Api | Namespace):
- """Build the file with signed URL model for the API or Namespace.
-
- Args:
- api_or_ns: Flask-RestX Api or Namespace instance
-
- Returns:
- The registered model
- """
- return api_or_ns.model("FileWithSignedUrl", file_fields_with_signed_url)
+__all__ = [
+ "FileResponse",
+ "FileWithSignedUrl",
+ "RemoteFileInfo",
+ "UploadConfig",
+]
diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py
index e70f9fa722..0b54992835 100644
--- a/api/fields/hit_testing_fields.py
+++ b/api/fields/hit_testing_fields.py
@@ -58,4 +58,5 @@ hit_testing_record_fields = {
"score": fields.Float,
"tsne_position": fields.Raw,
"files": fields.List(fields.Nested(files_fields)),
+ "summary": fields.String, # Summary content if retrieved via summary index
}
diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py
index 08e38a6931..25160927e6 100644
--- a/api/fields/member_fields.py
+++ b/api/fields/member_fields.py
@@ -1,4 +1,4 @@
-from flask_restx import Api, Namespace, fields
+from flask_restx import Namespace, fields
from libs.helper import AvatarUrlField, TimestampField
@@ -9,7 +9,7 @@ simple_account_fields = {
}
-def build_simple_account_model(api_or_ns: Api | Namespace):
+def build_simple_account_model(api_or_ns: Namespace):
return api_or_ns.model("SimpleAccount", simple_account_fields)
diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py
index a419da2e18..e6c3b42f93 100644
--- a/api/fields/message_fields.py
+++ b/api/fields/message_fields.py
@@ -1,77 +1,139 @@
-from flask_restx import Api, Namespace, fields
+from __future__ import annotations
-from fields.conversation_fields import message_file_fields
-from libs.helper import TimestampField
+from datetime import datetime
+from typing import TypeAlias
+from uuid import uuid4
-from .raws import FilesContainedField
+from pydantic import BaseModel, ConfigDict, Field, field_validator
-feedback_fields = {
- "rating": fields.String,
-}
+from core.file import File
+from fields.conversation_fields import AgentThought, JSONValue, MessageFile
+
+JSONValueType: TypeAlias = JSONValue
-def build_feedback_model(api_or_ns: Api | Namespace):
- """Build the feedback model for the API or Namespace."""
- return api_or_ns.model("Feedback", feedback_fields)
+class ResponseModel(BaseModel):
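+    """Base class for message response schemas: populated from ORM attributes, extra fields ignored."""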
+ model_config = ConfigDict(from_attributes=True, extra="ignore")
-agent_thought_fields = {
- "id": fields.String,
- "chain_id": fields.String,
- "message_id": fields.String,
- "position": fields.Integer,
- "thought": fields.String,
- "tool": fields.String,
- "tool_labels": fields.Raw,
- "tool_input": fields.String,
- "created_at": TimestampField,
- "observation": fields.String,
- "files": fields.List(fields.String),
-}
+class SimpleFeedback(ResponseModel):
+ rating: str | None = None
-def build_agent_thought_model(api_or_ns: Api | Namespace):
- """Build the agent thought model for the API or Namespace."""
- return api_or_ns.model("AgentThought", agent_thought_fields)
+class RetrieverResource(ResponseModel):
+ id: str = Field(default_factory=lambda: str(uuid4()))
+ message_id: str = Field(default_factory=lambda: str(uuid4()))
+ position: int
+ dataset_id: str | None = None
+ dataset_name: str | None = None
+ document_id: str | None = None
+ document_name: str | None = None
+ data_source_type: str | None = None
+ segment_id: str | None = None
+ score: float | None = None
+ hit_count: int | None = None
+ word_count: int | None = None
+ segment_position: int | None = None
+ index_node_hash: str | None = None
+ content: str | None = None
+ summary: str | None = None
+ created_at: int | None = None
+
+ @field_validator("created_at", mode="before")
+ @classmethod
+ def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
+ if isinstance(value, datetime):
+ return to_timestamp(value)
+ return value
-retriever_resource_fields = {
- "id": fields.String,
- "message_id": fields.String,
- "position": fields.Integer,
- "dataset_id": fields.String,
- "dataset_name": fields.String,
- "document_id": fields.String,
- "document_name": fields.String,
- "data_source_type": fields.String,
- "segment_id": fields.String,
- "score": fields.Float,
- "hit_count": fields.Integer,
- "word_count": fields.Integer,
- "segment_position": fields.Integer,
- "index_node_hash": fields.String,
- "content": fields.String,
- "created_at": TimestampField,
-}
+class MessageListItem(ResponseModel):
+ id: str
+ conversation_id: str
+ parent_message_id: str | None = None
+ inputs: dict[str, JSONValueType]
+ query: str
+ answer: str = Field(validation_alias="re_sign_file_url_answer")
+ feedback: SimpleFeedback | None = Field(default=None, validation_alias="user_feedback")
+ retriever_resources: list[RetrieverResource]
+ created_at: int | None = None
+ agent_thoughts: list[AgentThought]
+ message_files: list[MessageFile]
+ status: str
+ error: str | None = None
-message_fields = {
- "id": fields.String,
- "conversation_id": fields.String,
- "parent_message_id": fields.String,
- "inputs": FilesContainedField,
- "query": fields.String,
- "answer": fields.String(attribute="re_sign_file_url_answer"),
- "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
- "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
- "created_at": TimestampField,
- "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
- "message_files": fields.List(fields.Nested(message_file_fields)),
- "status": fields.String,
- "error": fields.String,
-}
+ @field_validator("inputs", mode="before")
+ @classmethod
+ def _normalize_inputs(cls, value: JSONValueType) -> JSONValueType:
+ return format_files_contained(value)
-message_infinite_scroll_pagination_fields = {
- "limit": fields.Integer,
- "has_more": fields.Boolean,
- "data": fields.List(fields.Nested(message_fields)),
-}
+ @field_validator("created_at", mode="before")
+ @classmethod
+ def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
+ if isinstance(value, datetime):
+ return to_timestamp(value)
+ return value
+
+
+class WebMessageListItem(MessageListItem):
+ metadata: JSONValueType | None = Field(default=None, validation_alias="message_metadata_dict")
+
+
+class MessageInfiniteScrollPagination(ResponseModel):
+ limit: int
+ has_more: bool
+ data: list[MessageListItem]
+
+
+class WebMessageInfiniteScrollPagination(ResponseModel):
+ limit: int
+ has_more: bool
+ data: list[WebMessageListItem]
+
+
+class SavedMessageItem(ResponseModel):
+ id: str
+ inputs: dict[str, JSONValueType]
+ query: str
+ answer: str
+ message_files: list[MessageFile]
+ feedback: SimpleFeedback | None = Field(default=None, validation_alias="user_feedback")
+ created_at: int | None = None
+
+ @field_validator("inputs", mode="before")
+ @classmethod
+ def _normalize_inputs(cls, value: JSONValueType) -> JSONValueType:
+ return format_files_contained(value)
+
+ @field_validator("created_at", mode="before")
+ @classmethod
+ def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
+ if isinstance(value, datetime):
+ return to_timestamp(value)
+ return value
+
+
+class SavedMessageInfiniteScrollPagination(ResponseModel):
+ limit: int
+ has_more: bool
+ data: list[SavedMessageItem]
+
+
+class SuggestedQuestionsResponse(ResponseModel):
+ data: list[str]
+
+
+def to_timestamp(value: datetime | None) -> int | None:
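+    """Convert a datetime to a Unix timestamp in seconds, passing None through."""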
+ if value is None:
+ return None
+ return int(value.timestamp())
+
+
+def format_files_contained(value: JSONValueType) -> JSONValueType:
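+    """Recursively replace any File objects nested in dicts/lists with their dict representation."""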
+ if isinstance(value, File):
+ return value.model_dump()
+ if isinstance(value, dict):
+ return {k: format_files_contained(v) for k, v in value.items()}
+ if isinstance(value, list):
+ return [format_files_contained(v) for v in value]
+ return value
diff --git a/api/fields/rag_pipeline_fields.py b/api/fields/rag_pipeline_fields.py
index f9e858c68b..97c02e7085 100644
--- a/api/fields/rag_pipeline_fields.py
+++ b/api/fields/rag_pipeline_fields.py
@@ -1,4 +1,4 @@
-from flask_restx import fields # type: ignore
+from flask_restx import fields
from fields.workflow_fields import workflow_partial_fields
from libs.helper import AppIconUrlField, TimestampField
diff --git a/api/fields/segment_fields.py b/api/fields/segment_fields.py
index 56d6b68378..2ce9fb154c 100644
--- a/api/fields/segment_fields.py
+++ b/api/fields/segment_fields.py
@@ -49,4 +49,5 @@ segment_fields = {
"stopped_at": TimestampField,
"child_chunks": fields.List(fields.Nested(child_chunk_fields)),
"attachments": fields.List(fields.Nested(attachment_fields)),
+ "summary": fields.String, # Summary content for the segment
}
diff --git a/api/fields/tag_fields.py b/api/fields/tag_fields.py
index d5b7c86a04..e359a4408c 100644
--- a/api/fields/tag_fields.py
+++ b/api/fields/tag_fields.py
@@ -1,4 +1,4 @@
-from flask_restx import Api, Namespace, fields
+from flask_restx import Namespace, fields
dataset_tag_fields = {
"id": fields.String,
@@ -8,5 +8,5 @@ dataset_tag_fields = {
}
-def build_dataset_tag_fields(api_or_ns: Api | Namespace):
+def build_dataset_tag_fields(api_or_ns: Namespace):
return api_or_ns.model("DataSetTag", dataset_tag_fields)
diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py
index 4cbdf6f0ca..ae70356322 100644
--- a/api/fields/workflow_app_log_fields.py
+++ b/api/fields/workflow_app_log_fields.py
@@ -1,8 +1,13 @@
-from flask_restx import Api, Namespace, fields
+from flask_restx import Namespace, fields
from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields
from fields.member_fields import build_simple_account_model, simple_account_fields
-from fields.workflow_run_fields import build_workflow_run_for_log_model, workflow_run_for_log_fields
+from fields.workflow_run_fields import (
+ build_workflow_run_for_archived_log_model,
+ build_workflow_run_for_log_model,
+ workflow_run_for_archived_log_fields,
+ workflow_run_for_log_fields,
+)
from libs.helper import TimestampField
workflow_app_log_partial_fields = {
@@ -17,7 +22,7 @@ workflow_app_log_partial_fields = {
}
-def build_workflow_app_log_partial_model(api_or_ns: Api | Namespace):
+def build_workflow_app_log_partial_model(api_or_ns: Namespace):
"""Build the workflow app log partial model for the API or Namespace."""
workflow_run_model = build_workflow_run_for_log_model(api_or_ns)
simple_account_model = build_simple_account_model(api_or_ns)
@@ -34,6 +39,33 @@ def build_workflow_app_log_partial_model(api_or_ns: Api | Namespace):
return api_or_ns.model("WorkflowAppLogPartial", copied_fields)
+workflow_archived_log_partial_fields = {
+ "id": fields.String,
+ "workflow_run": fields.Nested(workflow_run_for_archived_log_fields, allow_null=True),
+ "trigger_metadata": fields.Raw,
+ "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
+ "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True),
+ "created_at": TimestampField,
+}
+
+
+def build_workflow_archived_log_partial_model(api_or_ns: Namespace):
+ """Build the workflow archived log partial model for the API or Namespace."""
+ workflow_run_model = build_workflow_run_for_archived_log_model(api_or_ns)
+ simple_account_model = build_simple_account_model(api_or_ns)
+ simple_end_user_model = build_simple_end_user_model(api_or_ns)
+
+ copied_fields = workflow_archived_log_partial_fields.copy()
+ copied_fields["workflow_run"] = fields.Nested(workflow_run_model, allow_null=True)
+ copied_fields["created_by_account"] = fields.Nested(
+ simple_account_model, attribute="created_by_account", allow_null=True
+ )
+ copied_fields["created_by_end_user"] = fields.Nested(
+ simple_end_user_model, attribute="created_by_end_user", allow_null=True
+ )
+ return api_or_ns.model("WorkflowArchivedLogPartial", copied_fields)
+
+
workflow_app_log_pagination_fields = {
"page": fields.Integer,
"limit": fields.Integer,
@@ -43,7 +75,7 @@ workflow_app_log_pagination_fields = {
}
-def build_workflow_app_log_pagination_model(api_or_ns: Api | Namespace):
+def build_workflow_app_log_pagination_model(api_or_ns: Namespace):
"""Build the workflow app log pagination model for the API or Namespace."""
# Build the nested partial model first
workflow_app_log_partial_model = build_workflow_app_log_partial_model(api_or_ns)
@@ -51,3 +83,21 @@ def build_workflow_app_log_pagination_model(api_or_ns: Api | Namespace):
copied_fields = workflow_app_log_pagination_fields.copy()
copied_fields["data"] = fields.List(fields.Nested(workflow_app_log_partial_model))
return api_or_ns.model("WorkflowAppLogPagination", copied_fields)
+
+
+workflow_archived_log_pagination_fields = {
+ "page": fields.Integer,
+ "limit": fields.Integer,
+ "total": fields.Integer,
+ "has_more": fields.Boolean,
+ "data": fields.List(fields.Nested(workflow_archived_log_partial_fields)),
+}
+
+
+def build_workflow_archived_log_pagination_model(api_or_ns: Namespace):
+ """Build the workflow archived log pagination model for the API or Namespace."""
+ workflow_archived_log_partial_model = build_workflow_archived_log_partial_model(api_or_ns)
+
+ copied_fields = workflow_archived_log_pagination_fields.copy()
+ copied_fields["data"] = fields.List(fields.Nested(workflow_archived_log_partial_model))
+ return api_or_ns.model("WorkflowArchivedLogPagination", copied_fields)
diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py
index d037b0c442..2755f77f61 100644
--- a/api/fields/workflow_fields.py
+++ b/api/fields/workflow_fields.py
@@ -1,7 +1,7 @@
from flask_restx import fields
from core.helper import encrypter
-from core.variables import SecretVariable, SegmentType, Variable
+from core.variables import SecretVariable, SegmentType, VariableBase
from fields.member_fields import simple_account_fields
from libs.helper import TimestampField
@@ -21,7 +21,7 @@ class EnvironmentVariableField(fields.Raw):
"value_type": value.value_type.value,
"description": value.description,
}
- if isinstance(value, Variable):
+ if isinstance(value, VariableBase):
return {
"id": value.id,
"name": value.name,
diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py
index 821ce62ecc..35bb442c59 100644
--- a/api/fields/workflow_run_fields.py
+++ b/api/fields/workflow_run_fields.py
@@ -1,4 +1,4 @@
-from flask_restx import Api, Namespace, fields
+from flask_restx import Namespace, fields
from fields.end_user_fields import simple_end_user_fields
from fields.member_fields import simple_account_fields
@@ -19,10 +19,23 @@ workflow_run_for_log_fields = {
}
-def build_workflow_run_for_log_model(api_or_ns: Api | Namespace):
+def build_workflow_run_for_log_model(api_or_ns: Namespace):
return api_or_ns.model("WorkflowRunForLog", workflow_run_for_log_fields)
+workflow_run_for_archived_log_fields = {
+ "id": fields.String,
+ "status": fields.String,
+ "triggered_from": fields.String,
+ "elapsed_time": fields.Float,
+ "total_tokens": fields.Integer,
+}
+
+
+def build_workflow_run_for_archived_log_model(api_or_ns: Namespace):
+ return api_or_ns.model("WorkflowRunForArchivedLog", workflow_run_for_archived_log_fields)
+
+
workflow_run_for_list_fields = {
"id": fields.String,
"version": fields.String,
diff --git a/api/libs/archive_storage.py b/api/libs/archive_storage.py
new file mode 100644
index 0000000000..66b57ac661
--- /dev/null
+++ b/api/libs/archive_storage.py
@@ -0,0 +1,353 @@
+"""
+Archive Storage Client for S3-compatible storage.
+
+This module provides a dedicated storage client for archiving or exporting logs
+to S3-compatible object storage.
+"""
+
+import base64
+import datetime
+import hashlib
+import logging
+from collections.abc import Generator
+from typing import Any, cast
+
+import boto3
+import orjson
+from botocore.client import Config
+from botocore.exceptions import ClientError
+
+from configs import dify_config
+
+logger = logging.getLogger(__name__)
+
+
+class ArchiveStorageError(Exception):
+ """Base exception for archive storage operations."""
+
+ pass
+
+
+class ArchiveStorageNotConfiguredError(ArchiveStorageError):
+ """Raised when archive storage is not properly configured."""
+
+ pass
+
+
+class ArchiveStorage:
+ """
+ S3-compatible storage client for archiving or exporting.
+
+ This client provides methods for storing and retrieving archived data in JSONL format.
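+
+    Example (sketch; the object key below is illustrative):
+        storage = get_archive_storage()
+        payload = ArchiveStorage.serialize_to_jsonl(records)
+        storage.put_object("workflow-runs/2024/01/logs.jsonl", payload)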
+ """
+
+ def __init__(self, bucket: str):
+ if not dify_config.ARCHIVE_STORAGE_ENABLED:
+ raise ArchiveStorageNotConfiguredError("Archive storage is not enabled")
+
+ if not bucket:
+ raise ArchiveStorageNotConfiguredError("Archive storage bucket is not configured")
+ if not all(
+ [
+ dify_config.ARCHIVE_STORAGE_ENDPOINT,
+ bucket,
+ dify_config.ARCHIVE_STORAGE_ACCESS_KEY,
+ dify_config.ARCHIVE_STORAGE_SECRET_KEY,
+ ]
+ ):
+ raise ArchiveStorageNotConfiguredError(
+ "Archive storage configuration is incomplete. "
+ "Required: ARCHIVE_STORAGE_ENDPOINT, ARCHIVE_STORAGE_ACCESS_KEY, "
+ "ARCHIVE_STORAGE_SECRET_KEY, and a bucket name"
+ )
+
+ self.bucket = bucket
+ self.client = boto3.client(
+ "s3",
+ endpoint_url=dify_config.ARCHIVE_STORAGE_ENDPOINT,
+ aws_access_key_id=dify_config.ARCHIVE_STORAGE_ACCESS_KEY,
+ aws_secret_access_key=dify_config.ARCHIVE_STORAGE_SECRET_KEY,
+ region_name=dify_config.ARCHIVE_STORAGE_REGION,
+ config=Config(
+ s3={"addressing_style": "path"},
+ max_pool_connections=64,
+ ),
+ )
+
+ # Verify bucket accessibility
+ try:
+ self.client.head_bucket(Bucket=self.bucket)
+ except ClientError as e:
+ error_code = e.response.get("Error", {}).get("Code")
+ if error_code == "404":
+ raise ArchiveStorageNotConfiguredError(f"Archive bucket '{self.bucket}' does not exist")
+ elif error_code == "403":
+ raise ArchiveStorageNotConfiguredError(f"Access denied to archive bucket '{self.bucket}'")
+ else:
+ raise ArchiveStorageError(f"Failed to access archive bucket: {e}")
+
+ def put_object(self, key: str, data: bytes) -> str:
+ """
+ Upload an object to the archive storage.
+
+ Args:
+ key: Object key (path) within the bucket
+ data: Binary data to upload
+
+ Returns:
+ MD5 checksum of the uploaded data
+
+ Raises:
+ ArchiveStorageError: If upload fails
+ """
+ checksum = hashlib.md5(data).hexdigest()
+ try:
+ response = self.client.put_object(
+ Bucket=self.bucket,
+ Key=key,
+ Body=data,
+ ContentMD5=self._content_md5(data),
+ )
+ etag = response.get("ETag")
+ if not etag:
+ raise ArchiveStorageError(f"Missing ETag for '{key}'")
+ normalized_etag = etag.strip('"')
+ if normalized_etag != checksum:
+ raise ArchiveStorageError(f"ETag mismatch for '{key}': expected={checksum}, actual={normalized_etag}")
+ logger.debug("Uploaded object: %s (size=%d, checksum=%s)", key, len(data), checksum)
+ return checksum
+ except ClientError as e:
+ raise ArchiveStorageError(f"Failed to upload object '{key}': {e}")
+
+ def get_object(self, key: str) -> bytes:
+ """
+ Download an object from the archive storage.
+
+ Args:
+ key: Object key (path) within the bucket
+
+ Returns:
+ Binary data of the object
+
+ Raises:
+ ArchiveStorageError: If download fails
+ FileNotFoundError: If object does not exist
+ """
+ try:
+ response = self.client.get_object(Bucket=self.bucket, Key=key)
+ return response["Body"].read()
+ except ClientError as e:
+ error_code = e.response.get("Error", {}).get("Code")
+ if error_code == "NoSuchKey":
+ raise FileNotFoundError(f"Archive object not found: {key}")
+ raise ArchiveStorageError(f"Failed to download object '{key}': {e}")
+
+ def get_object_stream(self, key: str) -> Generator[bytes, None, None]:
+ """
+ Stream an object from the archive storage.
+
+ Args:
+ key: Object key (path) within the bucket
+
+ Yields:
+ Chunks of binary data
+
+ Raises:
+ ArchiveStorageError: If download fails
+ FileNotFoundError: If object does not exist
+ """
+ try:
+ response = self.client.get_object(Bucket=self.bucket, Key=key)
+ yield from response["Body"].iter_chunks()
+ except ClientError as e:
+ error_code = e.response.get("Error", {}).get("Code")
+ if error_code == "NoSuchKey":
+ raise FileNotFoundError(f"Archive object not found: {key}")
+ raise ArchiveStorageError(f"Failed to stream object '{key}': {e}")
+
+ def object_exists(self, key: str) -> bool:
+ """
+ Check if an object exists in the archive storage.
+
+ Args:
+ key: Object key (path) within the bucket
+
+ Returns:
+ True if object exists, False otherwise
+ """
+ try:
+ self.client.head_object(Bucket=self.bucket, Key=key)
+ return True
+ except ClientError:
+ return False
+
+ def delete_object(self, key: str) -> None:
+ """
+ Delete an object from the archive storage.
+
+ Args:
+ key: Object key (path) within the bucket
+
+ Raises:
+ ArchiveStorageError: If deletion fails
+ """
+ try:
+ self.client.delete_object(Bucket=self.bucket, Key=key)
+ logger.debug("Deleted object: %s", key)
+ except ClientError as e:
+ raise ArchiveStorageError(f"Failed to delete object '{key}': {e}")
+
+ def generate_presigned_url(self, key: str, expires_in: int = 3600) -> str:
+ """
+ Generate a pre-signed URL for downloading an object.
+
+ Args:
+ key: Object key (path) within the bucket
+ expires_in: URL validity duration in seconds (default: 1 hour)
+
+ Returns:
+ Pre-signed URL string.
+
+ Raises:
+ ArchiveStorageError: If generation fails
+ """
+ try:
+ return self.client.generate_presigned_url(
+ ClientMethod="get_object",
+ Params={"Bucket": self.bucket, "Key": key},
+ ExpiresIn=expires_in,
+ )
+ except ClientError as e:
+ raise ArchiveStorageError(f"Failed to generate pre-signed URL for '{key}': {e}")
+
+ def list_objects(self, prefix: str) -> list[str]:
+ """
+ List objects under a given prefix.
+
+ Args:
+ prefix: Object key prefix to filter by
+
+ Returns:
+ List of object keys matching the prefix
+ """
+ keys = []
+ paginator = self.client.get_paginator("list_objects_v2")
+
+ try:
+ for page in paginator.paginate(Bucket=self.bucket, Prefix=prefix):
+ for obj in page.get("Contents", []):
+ keys.append(obj["Key"])
+ except ClientError as e:
+ raise ArchiveStorageError(f"Failed to list objects with prefix '{prefix}': {e}")
+
+ return keys
+
+ @staticmethod
+ def _content_md5(data: bytes) -> str:
+ """Calculate base64-encoded MD5 for Content-MD5 header."""
+ return base64.b64encode(hashlib.md5(data).digest()).decode()
+
+ @staticmethod
+ def serialize_to_jsonl(records: list[dict[str, Any]]) -> bytes:
+ """
+ Serialize records to JSONL format.
+
+ Args:
+ records: List of dictionaries to serialize
+
+ Returns:
+ JSONL bytes
+ """
+ lines = []
+ for record in records:
+ serialized = ArchiveStorage._serialize_record(record)
+ lines.append(orjson.dumps(serialized))
+
+ jsonl_content = b"\n".join(lines)
+ if jsonl_content:
+ jsonl_content += b"\n"
+
+ return jsonl_content
+
+ @staticmethod
+ def deserialize_from_jsonl(data: bytes) -> list[dict[str, Any]]:
+ """
+ Deserialize JSONL data to records.
+
+ Args:
+ data: JSONL bytes
+
+ Returns:
+ List of dictionaries
+ """
+ records = []
+
+ for line in data.splitlines():
+ if line:
+ records.append(orjson.loads(line))
+
+ return records
+
+ @staticmethod
+ def _serialize_record(record: dict[str, Any]) -> dict[str, Any]:
+ """Serialize a single record, converting special types."""
+
+ def _serialize(item: Any) -> Any:
+ if isinstance(item, datetime.datetime):
+ return item.isoformat()
+ if isinstance(item, dict):
+ return {key: _serialize(value) for key, value in item.items()}
+ if isinstance(item, list):
+ return [_serialize(value) for value in item]
+ return item
+
+ return cast(dict[str, Any], _serialize(record))
+
+ @staticmethod
+ def compute_checksum(data: bytes) -> str:
+ """Compute MD5 checksum of data."""
+ return hashlib.md5(data).hexdigest()
+
+
+# Singleton instance (lazy initialization)
+_archive_storage: ArchiveStorage | None = None
+_export_storage: ArchiveStorage | None = None
+
+
+def get_archive_storage() -> ArchiveStorage:
+ """
+ Get the archive storage singleton instance.
+
+ Returns:
+ ArchiveStorage instance
+
+ Raises:
+ ArchiveStorageNotConfiguredError: If archive storage is not configured
+ """
+ global _archive_storage
+ if _archive_storage is None:
+ archive_bucket = dify_config.ARCHIVE_STORAGE_ARCHIVE_BUCKET
+ if not archive_bucket:
+ raise ArchiveStorageNotConfiguredError(
+ "Archive storage bucket is not configured. Required: ARCHIVE_STORAGE_ARCHIVE_BUCKET"
+ )
+ _archive_storage = ArchiveStorage(bucket=archive_bucket)
+ return _archive_storage
+
+
+def get_export_storage() -> ArchiveStorage:
+ """
+ Get the export storage singleton instance.
+
+ Returns:
+ ArchiveStorage instance
+ """
+ global _export_storage
+ if _export_storage is None:
+ export_bucket = dify_config.ARCHIVE_STORAGE_EXPORT_BUCKET
+ if not export_bucket:
+ raise ArchiveStorageNotConfiguredError(
+ "Archive export bucket is not configured. Required: ARCHIVE_STORAGE_EXPORT_BUCKET"
+ )
+ _export_storage = ArchiveStorage(bucket=export_bucket)
+ return _export_storage
diff --git a/api/libs/broadcast_channel/channel.py b/api/libs/broadcast_channel/channel.py
index 5bbf0c79a3..d4cb3e9971 100644
--- a/api/libs/broadcast_channel/channel.py
+++ b/api/libs/broadcast_channel/channel.py
@@ -2,6 +2,8 @@
Broadcast channel for Pub/Sub messaging.
"""
+from __future__ import annotations
+
import types
from abc import abstractmethod
from collections.abc import Iterator
@@ -129,6 +131,6 @@ class BroadcastChannel(Protocol):
"""
@abstractmethod
- def topic(self, topic: str) -> "Topic":
+ def topic(self, topic: str) -> Topic:
"""topic returns a `Topic` instance for the given topic name."""
...
diff --git a/api/libs/broadcast_channel/redis/channel.py b/api/libs/broadcast_channel/redis/channel.py
index 1fc3db8156..5bb4f579c1 100644
--- a/api/libs/broadcast_channel/redis/channel.py
+++ b/api/libs/broadcast_channel/redis/channel.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from redis import Redis
@@ -20,7 +22,7 @@ class BroadcastChannel:
):
self._client = redis_client
- def topic(self, topic: str) -> "Topic":
+ def topic(self, topic: str) -> Topic:
return Topic(self._client, topic)
diff --git a/api/libs/broadcast_channel/redis/sharded_channel.py b/api/libs/broadcast_channel/redis/sharded_channel.py
index 16e3a80ee1..d190c51bbc 100644
--- a/api/libs/broadcast_channel/redis/sharded_channel.py
+++ b/api/libs/broadcast_channel/redis/sharded_channel.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from redis import Redis
@@ -18,7 +20,7 @@ class ShardedRedisBroadcastChannel:
):
self._client = redis_client
- def topic(self, topic: str) -> "ShardedTopic":
+ def topic(self, topic: str) -> ShardedTopic:
return ShardedTopic(self._client, topic)
diff --git a/api/libs/email_i18n.py b/api/libs/email_i18n.py
index ff74ccbe8e..0828cf80bf 100644
--- a/api/libs/email_i18n.py
+++ b/api/libs/email_i18n.py
@@ -6,6 +6,8 @@ in Dify. It follows Domain-Driven Design principles with proper type hints and
eliminates the need for repetitive language switching logic.
"""
+from __future__ import annotations
+
from dataclasses import dataclass
from enum import StrEnum, auto
from typing import Any, Protocol
@@ -53,7 +55,7 @@ class EmailLanguage(StrEnum):
ZH_HANS = "zh-Hans"
@classmethod
- def from_language_code(cls, language_code: str) -> "EmailLanguage":
+ def from_language_code(cls, language_code: str) -> EmailLanguage:
"""Convert a language code to EmailLanguage with fallback to English."""
if language_code == "zh-Hans":
return cls.ZH_HANS
diff --git a/api/libs/encryption.py b/api/libs/encryption.py
new file mode 100644
index 0000000000..81be8cce97
--- /dev/null
+++ b/api/libs/encryption.py
@@ -0,0 +1,66 @@
+"""
+Field Encoding/Decoding Utilities
+
+Provides Base64 decoding for sensitive fields (password, verification code)
+received from the frontend.
+
+Note: This uses Base64 encoding for obfuscation, not cryptographic encryption.
+Real security relies on HTTPS for transport layer encryption.
+"""
+
+import base64
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class FieldEncryption:
+ """Handle decoding of sensitive fields during transmission"""
+
+ @classmethod
+ def decrypt_field(cls, encoded_text: str) -> str | None:
+ """
+ Decode Base64 encoded field from frontend.
+
+ Args:
+ encoded_text: Base64 encoded text from frontend
+
+ Returns:
+ Decoded plaintext, or None if decoding fails
+ """
+ try:
+ # Decode base64
+ decoded_bytes = base64.b64decode(encoded_text)
+ decoded_text = decoded_bytes.decode("utf-8")
+ logger.debug("Field decoding successful")
+ return decoded_text
+
+ except Exception:
+ # Decoding failed - return None to trigger error in caller
+ return None
+
+ @classmethod
+ def decrypt_password(cls, encrypted_password: str) -> str | None:
+ """
+        Decode the Base64-encoded password field.
+
+        Args:
+            encrypted_password: Base64-encoded password from frontend
+
+        Returns:
+            Decoded password, or None if decoding fails
+ """
+ return cls.decrypt_field(encrypted_password)
+
+ @classmethod
+ def decrypt_verification_code(cls, encrypted_code: str) -> str | None:
+ """
+        Decode the Base64-encoded verification code field.
+
+        Args:
+            encrypted_code: Base64-encoded code from frontend
+
+        Returns:
+            Decoded code, or None if decoding fails
+ """
+ return cls.decrypt_field(encrypted_code)
diff --git a/api/libs/external_api.py b/api/libs/external_api.py
index 61a90ee4a9..e8592407c3 100644
--- a/api/libs/external_api.py
+++ b/api/libs/external_api.py
@@ -1,5 +1,4 @@
import re
-import sys
from collections.abc import Mapping
from typing import Any
@@ -109,11 +108,8 @@ def register_external_error_handlers(api: Api):
data.setdefault("code", "unknown")
data.setdefault("status", status_code)
- # Log stack
- exc_info: Any = sys.exc_info()
- if exc_info[1] is None:
- exc_info = (None, None, None)
- current_app.log_exception(exc_info)
+ # Note: Exception logging is handled by Flask/Flask-RESTX framework automatically
+ # Explicit log_exception call removed to avoid duplicate log entries
return data, status_code
diff --git a/api/libs/gmpy2_pkcs10aep_cipher.py b/api/libs/gmpy2_pkcs10aep_cipher.py
index 23eb8dca05..ef26699fb3 100644
--- a/api/libs/gmpy2_pkcs10aep_cipher.py
+++ b/api/libs/gmpy2_pkcs10aep_cipher.py
@@ -136,7 +136,7 @@ class PKCS1OAepCipher:
# Step 3a (OS2IP)
em_int = bytes_to_long(em)
# Step 3b (RSAEP)
- m_int = gmpy2.powmod(em_int, self._key.e, self._key.n)
+ m_int: int = gmpy2.powmod(em_int, self._key.e, self._key.n) # type: ignore[attr-defined]
# Step 3c (I2OSP)
c = long_to_bytes(m_int, k)
return c
@@ -169,7 +169,7 @@ class PKCS1OAepCipher:
ct_int = bytes_to_long(ciphertext)
# Step 2b (RSADP)
# m_int = self._key._decrypt(ct_int)
- m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n)
+ m_int: int = gmpy2.powmod(ct_int, self._key.d, self._key.n) # type: ignore[attr-defined]
# Complete step 2c (I2OSP)
em = long_to_bytes(m_int, k)
# Step 3a
diff --git a/api/libs/helper.py b/api/libs/helper.py
index a278ace6ad..07c4823727 100644
--- a/api/libs/helper.py
+++ b/api/libs/helper.py
@@ -11,6 +11,7 @@ from collections.abc import Generator, Mapping
from datetime import datetime
from hashlib import sha256
from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast
+from uuid import UUID
from zoneinfo import available_timezones
from flask import Response, stream_with_context
@@ -31,6 +32,38 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+def escape_like_pattern(pattern: str) -> str:
+ """
+ Escape special characters in a string for safe use in SQL LIKE patterns.
+
+ This function escapes the special characters used in SQL LIKE patterns:
+    - Backslash (\\) -> \\\\
+ - Percent (%) -> \\%
+ - Underscore (_) -> \\_
+
+ The escaped pattern can then be safely used in SQL LIKE queries with the
+ ESCAPE '\\' clause to prevent SQL injection via LIKE wildcards.
+
+ Args:
+ pattern: The string pattern to escape
+
+ Returns:
+ Escaped string safe for use in SQL LIKE queries
+
+ Examples:
+ >>> escape_like_pattern("50% discount")
+ '50\\% discount'
+ >>> escape_like_pattern("test_data")
+ 'test\\_data'
+ >>> escape_like_pattern("path\\to\\file")
+ 'path\\\\to\\\\file'
+ """
+ if not pattern:
+ return pattern
+ # Escape backslash first, then percent and underscore
+ return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
+
+
def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None:
"""
Extract tenant_id from Account or EndUser object.
@@ -119,6 +152,19 @@ def uuid_value(value: Any) -> str:
raise ValueError(error)
+def normalize_uuid(value: str | UUID) -> str:
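+    """Return the canonical string form of a UUID ("" for falsy input); raise ValueError if invalid."""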
+ if not value:
+ return ""
+
+ try:
+ return uuid_value(value)
+ except ValueError as exc:
+ raise ValueError("must be a valid UUID") from exc
+
+
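+# Annotated string type accepting a valid UUID or an empty string (normalized via normalize_uuid).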
+UUIDStrOrEmpty = Annotated[str, AfterValidator(normalize_uuid)]
+
+
def alphanumeric(value: str):
# check if the value is alphanumeric and underlined
if re.match(r"^[a-zA-Z0-9_]+$", value):
@@ -184,7 +230,7 @@ def timezone(timezone_string):
def convert_datetime_to_date(field, target_timezone: str = ":tz"):
if dify_config.DB_TYPE == "postgresql":
return f"DATE(DATE_TRUNC('day', {field} AT TIME ZONE 'UTC' AT TIME ZONE {target_timezone}))"
- elif dify_config.DB_TYPE == "mysql":
+ elif dify_config.DB_TYPE in ["mysql", "oceanbase", "seekdb"]:
return f"DATE(CONVERT_TZ({field}, 'UTC', {target_timezone}))"
else:
raise NotImplementedError(f"Unsupported database type: {dify_config.DB_TYPE}")
@@ -215,7 +261,11 @@ def generate_text_hash(text: str) -> str:
def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
if isinstance(response, dict):
- return Response(response=json.dumps(jsonable_encoder(response)), status=200, mimetype="application/json")
+ return Response(
+ response=json.dumps(jsonable_encoder(response)),
+ status=200,
+ content_type="application/json; charset=utf-8",
+ )
else:
def generate() -> Generator:
diff --git a/api/libs/login.py b/api/libs/login.py
index 4b8ee2d1f8..73caa492fe 100644
--- a/api/libs/login.py
+++ b/api/libs/login.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
from collections.abc import Callable
from functools import wraps
-from typing import Any
+from typing import TYPE_CHECKING, Any
from flask import current_app, g, has_request_context, request
from flask_login.config import EXEMPT_METHODS
@@ -9,7 +11,9 @@ from werkzeug.local import LocalProxy
from configs import dify_config
from libs.token import check_csrf_token
from models import Account
-from models.model import EndUser
+
+if TYPE_CHECKING:
+ from models.model import EndUser
def current_account_with_tenant():
diff --git a/api/libs/smtp.py b/api/libs/smtp.py
index 4044c6f7ed..6f82f1440a 100644
--- a/api/libs/smtp.py
+++ b/api/libs/smtp.py
@@ -3,6 +3,8 @@ import smtplib
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
+from configs import dify_config
+
logger = logging.getLogger(__name__)
@@ -19,20 +21,21 @@ class SMTPClient:
self.opportunistic_tls = opportunistic_tls
def send(self, mail: dict):
- smtp = None
+ smtp: smtplib.SMTP | None = None
+ local_host = dify_config.SMTP_LOCAL_HOSTNAME
try:
- if self.use_tls:
- if self.opportunistic_tls:
- smtp = smtplib.SMTP(self.server, self.port, timeout=10)
- # Send EHLO command with the HELO domain name as the server address
- smtp.ehlo(self.server)
- smtp.starttls()
- # Resend EHLO command to identify the TLS session
- smtp.ehlo(self.server)
- else:
- smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10)
+ if self.use_tls and not self.opportunistic_tls:
+ # SMTP with SSL (implicit TLS)
+ smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10, local_hostname=local_host)
else:
- smtp = smtplib.SMTP(self.server, self.port, timeout=10)
+ # Plain SMTP or SMTP with STARTTLS (explicit TLS)
+ smtp = smtplib.SMTP(self.server, self.port, timeout=10, local_hostname=local_host)
+
+ assert smtp is not None
+ if self.use_tls and self.opportunistic_tls:
+ smtp.ehlo(self.server)
+ smtp.starttls()
+ smtp.ehlo(self.server)
# Only authenticate if both username and password are non-empty
if self.username and self.password and self.username.strip() and self.password.strip():
diff --git a/api/libs/workspace_permission.py b/api/libs/workspace_permission.py
new file mode 100644
index 0000000000..dd42a7facf
--- /dev/null
+++ b/api/libs/workspace_permission.py
@@ -0,0 +1,74 @@
+"""
+Workspace permission helper functions.
+
+These helpers check both billing/plan level and workspace-specific policy level permissions.
+Checks are performed at two levels:
+1. Billing/plan level - via FeatureService (e.g., SANDBOX plan restrictions)
+2. Workspace policy level - via EnterpriseService (admin-configured per workspace)
+"""
+
+import logging
+
+from werkzeug.exceptions import Forbidden
+
+from configs import dify_config
+from services.enterprise.enterprise_service import EnterpriseService
+from services.feature_service import FeatureService
+
+logger = logging.getLogger(__name__)
+
+
+def check_workspace_member_invite_permission(workspace_id: str) -> None:
+ """
+ Check if workspace allows member invitations at both billing and policy levels.
+
+ Checks performed:
+ 1. Billing/plan level - For future expansion (currently no plan-level restriction)
+ 2. Enterprise policy level - Admin-configured workspace permission
+
+ Args:
+ workspace_id: The workspace ID to check permissions for
+
+ Raises:
+ Forbidden: If either billing plan or workspace policy prohibits member invitations
+ """
+ # Check enterprise workspace policy level (only if enterprise enabled)
+ if dify_config.ENTERPRISE_ENABLED:
+ try:
+ permission = EnterpriseService.WorkspacePermissionService.get_permission(workspace_id)
+ if not permission.allow_member_invite:
+ raise Forbidden("Workspace policy prohibits member invitations")
+ except Forbidden:
+ raise
+ except Exception:
+ logger.exception("Failed to check workspace invite permission for %s", workspace_id)
+
+
+def check_workspace_owner_transfer_permission(workspace_id: str) -> None:
+ """
+ Check if workspace allows owner transfer at both billing and policy levels.
+
+ Checks performed:
+ 1. Billing/plan level - SANDBOX plan blocks owner transfer
+ 2. Enterprise policy level - Admin-configured workspace permission
+
+ Args:
+ workspace_id: The workspace ID to check permissions for
+
+ Raises:
+ Forbidden: If either billing plan or workspace policy prohibits ownership transfer
+ """
+ features = FeatureService.get_features(workspace_id)
+ if not features.is_allow_transfer_workspace:
+ raise Forbidden("Your current plan does not allow workspace ownership transfer")
+
+ # Check enterprise workspace policy level (only if enterprise enabled)
+ if dify_config.ENTERPRISE_ENABLED:
+ try:
+ permission = EnterpriseService.WorkspacePermissionService.get_permission(workspace_id)
+ if not permission.allow_owner_transfer:
+ raise Forbidden("Workspace policy prohibits ownership transfer")
+ except Forbidden:
+ raise
+ except Exception:
+ logger.exception("Failed to check workspace transfer permission for %s", workspace_id)
diff --git a/api/migrations/versions/00bacef91f18_rename_api_provider_description.py b/api/migrations/versions/00bacef91f18_rename_api_provider_description.py
index 17ed067d81..657d28f896 100644
--- a/api/migrations/versions/00bacef91f18_rename_api_provider_description.py
+++ b/api/migrations/versions/00bacef91f18_rename_api_provider_description.py
@@ -11,9 +11,6 @@ from alembic import op
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = '00bacef91f18'
down_revision = '8ec536f3c800'
@@ -23,31 +20,17 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
- batch_op.add_column(sa.Column('description', sa.Text(), nullable=False))
- batch_op.drop_column('description_str')
- else:
- with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
- batch_op.add_column(sa.Column('description', models.types.LongText(), nullable=False))
- batch_op.drop_column('description_str')
+ with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('description', models.types.LongText(), nullable=False))
+ batch_op.drop_column('description_str')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
- batch_op.add_column(sa.Column('description_str', sa.TEXT(), autoincrement=False, nullable=False))
- batch_op.drop_column('description')
- else:
- with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
- batch_op.add_column(sa.Column('description_str', models.types.LongText(), autoincrement=False, nullable=False))
- batch_op.drop_column('description')
+ with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('description_str', models.types.LongText(), autoincrement=False, nullable=False))
+ batch_op.drop_column('description')
# ### end Alembic commands ###
diff --git a/api/migrations/versions/114eed84c228_remove_tool_id_from_model_invoke.py b/api/migrations/versions/114eed84c228_remove_tool_id_from_model_invoke.py
index ed70bf5d08..912d9dbfa4 100644
--- a/api/migrations/versions/114eed84c228_remove_tool_id_from_model_invoke.py
+++ b/api/migrations/versions/114eed84c228_remove_tool_id_from_model_invoke.py
@@ -7,14 +7,10 @@ Create Date: 2024-01-10 04:40:57.257824
"""
import sqlalchemy as sa
from alembic import op
-from sqlalchemy.dialects import postgresql
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = '114eed84c228'
down_revision = 'c71211c8f604'
@@ -32,13 +28,7 @@ def upgrade():
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
- batch_op.add_column(sa.Column('tool_id', postgresql.UUID(), autoincrement=False, nullable=False))
- else:
- with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
- batch_op.add_column(sa.Column('tool_id', models.types.StringUUID(), autoincrement=False, nullable=False))
+ with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('tool_id', models.types.StringUUID(), autoincrement=False, nullable=False))
# ### end Alembic commands ###
diff --git a/api/migrations/versions/161cadc1af8d_add_dataset_permission_tenant_id.py b/api/migrations/versions/161cadc1af8d_add_dataset_permission_tenant_id.py
index 509bd5d0e8..0ca905129d 100644
--- a/api/migrations/versions/161cadc1af8d_add_dataset_permission_tenant_id.py
+++ b/api/migrations/versions/161cadc1af8d_add_dataset_permission_tenant_id.py
@@ -11,9 +11,6 @@ from alembic import op
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = '161cadc1af8d'
down_revision = '7e6a8693e07a'
@@ -23,16 +20,9 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('dataset_permissions', schema=None) as batch_op:
- # Step 1: Add column without NOT NULL constraint
- op.add_column('dataset_permissions', sa.Column('tenant_id', sa.UUID(), nullable=False))
- else:
- with op.batch_alter_table('dataset_permissions', schema=None) as batch_op:
- # Step 1: Add column without NOT NULL constraint
- op.add_column('dataset_permissions', sa.Column('tenant_id', models.types.StringUUID(), nullable=False))
+ with op.batch_alter_table('dataset_permissions', schema=None) as batch_op:
+ # Step 1: Add column without NOT NULL constraint
+ op.add_column('dataset_permissions', sa.Column('tenant_id', models.types.StringUUID(), nullable=False))
# ### end Alembic commands ###
diff --git a/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py b/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py
index 0767b725f6..be1b42f883 100644
--- a/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py
+++ b/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py
@@ -9,11 +9,6 @@ from alembic import op
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-import sqlalchemy as sa
-from sqlalchemy.dialects import postgresql
-
# revision identifiers, used by Alembic.
revision = '6af6a521a53e'
down_revision = 'd57ba9ebb251'
@@ -23,58 +18,30 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
- batch_op.alter_column('document_id',
- existing_type=sa.UUID(),
- nullable=True)
- batch_op.alter_column('data_source_type',
- existing_type=sa.TEXT(),
- nullable=True)
- batch_op.alter_column('segment_id',
- existing_type=sa.UUID(),
- nullable=True)
- else:
- with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
- batch_op.alter_column('document_id',
- existing_type=models.types.StringUUID(),
- nullable=True)
- batch_op.alter_column('data_source_type',
- existing_type=models.types.LongText(),
- nullable=True)
- batch_op.alter_column('segment_id',
- existing_type=models.types.StringUUID(),
- nullable=True)
+ with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
+ batch_op.alter_column('document_id',
+ existing_type=models.types.StringUUID(),
+ nullable=True)
+ batch_op.alter_column('data_source_type',
+ existing_type=models.types.LongText(),
+ nullable=True)
+ batch_op.alter_column('segment_id',
+ existing_type=models.types.StringUUID(),
+ nullable=True)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
- batch_op.alter_column('segment_id',
- existing_type=sa.UUID(),
- nullable=False)
- batch_op.alter_column('data_source_type',
- existing_type=sa.TEXT(),
- nullable=False)
- batch_op.alter_column('document_id',
- existing_type=sa.UUID(),
- nullable=False)
- else:
- with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
- batch_op.alter_column('segment_id',
- existing_type=models.types.StringUUID(),
- nullable=False)
- batch_op.alter_column('data_source_type',
- existing_type=models.types.LongText(),
- nullable=False)
- batch_op.alter_column('document_id',
- existing_type=models.types.StringUUID(),
- nullable=False)
+ with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
+ batch_op.alter_column('segment_id',
+ existing_type=models.types.StringUUID(),
+ nullable=False)
+ batch_op.alter_column('data_source_type',
+ existing_type=models.types.LongText(),
+ nullable=False)
+ batch_op.alter_column('document_id',
+ existing_type=models.types.StringUUID(),
+ nullable=False)
# ### end Alembic commands ###
diff --git a/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py b/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py
index a749c8bddf..5d12419bf7 100644
--- a/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py
+++ b/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py
@@ -8,7 +8,6 @@ Create Date: 2024-11-01 04:34:23.816198
from alembic import op
import models as models
import sqlalchemy as sa
-from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = 'd3f6769a94a3'
diff --git a/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py b/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py
index 45842295ea..a49d6a52f6 100644
--- a/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py
+++ b/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py
@@ -28,85 +28,45 @@ def upgrade():
op.execute("UPDATE sites SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL")
op.execute("UPDATE tool_api_providers SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL")
- if _is_pg(conn):
- with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
- batch_op.alter_column('custom_disclaimer',
- existing_type=sa.VARCHAR(length=255),
- type_=sa.TEXT(),
- nullable=False)
+ with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
+ batch_op.alter_column('custom_disclaimer',
+ existing_type=sa.VARCHAR(length=255),
+ type_=models.types.LongText(),
+ nullable=False)
- with op.batch_alter_table('sites', schema=None) as batch_op:
- batch_op.alter_column('custom_disclaimer',
- existing_type=sa.VARCHAR(length=255),
- type_=sa.TEXT(),
- nullable=False)
+ with op.batch_alter_table('sites', schema=None) as batch_op:
+ batch_op.alter_column('custom_disclaimer',
+ existing_type=sa.VARCHAR(length=255),
+ type_=models.types.LongText(),
+ nullable=False)
- with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
- batch_op.alter_column('custom_disclaimer',
- existing_type=sa.VARCHAR(length=255),
- type_=sa.TEXT(),
- nullable=False)
- else:
- with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
- batch_op.alter_column('custom_disclaimer',
- existing_type=sa.VARCHAR(length=255),
- type_=models.types.LongText(),
- nullable=False)
-
- with op.batch_alter_table('sites', schema=None) as batch_op:
- batch_op.alter_column('custom_disclaimer',
- existing_type=sa.VARCHAR(length=255),
- type_=models.types.LongText(),
- nullable=False)
-
- with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
- batch_op.alter_column('custom_disclaimer',
- existing_type=sa.VARCHAR(length=255),
- type_=models.types.LongText(),
- nullable=False)
+ with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
+ batch_op.alter_column('custom_disclaimer',
+ existing_type=sa.VARCHAR(length=255),
+ type_=models.types.LongText(),
+ nullable=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
- batch_op.alter_column('custom_disclaimer',
- existing_type=sa.TEXT(),
- type_=sa.VARCHAR(length=255),
- nullable=True)
+ with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
+ batch_op.alter_column('custom_disclaimer',
+ existing_type=models.types.LongText(),
+ type_=sa.VARCHAR(length=255),
+ nullable=True)
- with op.batch_alter_table('sites', schema=None) as batch_op:
- batch_op.alter_column('custom_disclaimer',
- existing_type=sa.TEXT(),
- type_=sa.VARCHAR(length=255),
- nullable=True)
+ with op.batch_alter_table('sites', schema=None) as batch_op:
+ batch_op.alter_column('custom_disclaimer',
+ existing_type=models.types.LongText(),
+ type_=sa.VARCHAR(length=255),
+ nullable=True)
- with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
- batch_op.alter_column('custom_disclaimer',
- existing_type=sa.TEXT(),
- type_=sa.VARCHAR(length=255),
- nullable=True)
- else:
- with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
- batch_op.alter_column('custom_disclaimer',
- existing_type=models.types.LongText(),
- type_=sa.VARCHAR(length=255),
- nullable=True)
-
- with op.batch_alter_table('sites', schema=None) as batch_op:
- batch_op.alter_column('custom_disclaimer',
- existing_type=models.types.LongText(),
- type_=sa.VARCHAR(length=255),
- nullable=True)
-
- with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
- batch_op.alter_column('custom_disclaimer',
- existing_type=models.types.LongText(),
- type_=sa.VARCHAR(length=255),
- nullable=True)
+ with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
+ batch_op.alter_column('custom_disclaimer',
+ existing_type=models.types.LongText(),
+ type_=sa.VARCHAR(length=255),
+ nullable=True)
# ### end Alembic commands ###
diff --git a/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py b/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py
index fdd8984029..8a36c9c4a5 100644
--- a/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py
+++ b/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py
@@ -49,57 +49,33 @@ def upgrade():
op.execute("UPDATE workflows SET updated_at = created_at WHERE updated_at IS NULL")
op.execute("UPDATE workflows SET graph = '' WHERE graph IS NULL")
op.execute("UPDATE workflows SET features = '' WHERE features IS NULL")
- if _is_pg(conn):
- with op.batch_alter_table('workflows', schema=None) as batch_op:
- batch_op.alter_column('graph',
- existing_type=sa.TEXT(),
- nullable=False)
- batch_op.alter_column('features',
- existing_type=sa.TEXT(),
- nullable=False)
- batch_op.alter_column('updated_at',
- existing_type=postgresql.TIMESTAMP(),
- nullable=False)
- else:
- with op.batch_alter_table('workflows', schema=None) as batch_op:
- batch_op.alter_column('graph',
- existing_type=models.types.LongText(),
- nullable=False)
- batch_op.alter_column('features',
- existing_type=models.types.LongText(),
- nullable=False)
- batch_op.alter_column('updated_at',
- existing_type=sa.TIMESTAMP(),
- nullable=False)
+
+ with op.batch_alter_table('workflows', schema=None) as batch_op:
+ batch_op.alter_column('graph',
+ existing_type=models.types.LongText(),
+ nullable=False)
+ batch_op.alter_column('features',
+ existing_type=models.types.LongText(),
+ nullable=False)
+ batch_op.alter_column('updated_at',
+ existing_type=sa.TIMESTAMP(),
+ nullable=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('workflows', schema=None) as batch_op:
- batch_op.alter_column('updated_at',
- existing_type=postgresql.TIMESTAMP(),
- nullable=True)
- batch_op.alter_column('features',
- existing_type=sa.TEXT(),
- nullable=True)
- batch_op.alter_column('graph',
- existing_type=sa.TEXT(),
- nullable=True)
- else:
- with op.batch_alter_table('workflows', schema=None) as batch_op:
- batch_op.alter_column('updated_at',
- existing_type=sa.TIMESTAMP(),
- nullable=True)
- batch_op.alter_column('features',
- existing_type=models.types.LongText(),
- nullable=True)
- batch_op.alter_column('graph',
- existing_type=models.types.LongText(),
- nullable=True)
+ with op.batch_alter_table('workflows', schema=None) as batch_op:
+ batch_op.alter_column('updated_at',
+ existing_type=sa.TIMESTAMP(),
+ nullable=True)
+ batch_op.alter_column('features',
+ existing_type=models.types.LongText(),
+ nullable=True)
+ batch_op.alter_column('graph',
+ existing_type=models.types.LongText(),
+ nullable=True)
if _is_pg(conn):
with op.batch_alter_table('messages', schema=None) as batch_op:
diff --git a/api/migrations/versions/2025_08_13_1605-0e154742a5fa_add_provider_model_multi_credential.py b/api/migrations/versions/2025_08_13_1605-0e154742a5fa_add_provider_model_multi_credential.py
index 16ca902726..1fc4a64df1 100644
--- a/api/migrations/versions/2025_08_13_1605-0e154742a5fa_add_provider_model_multi_credential.py
+++ b/api/migrations/versions/2025_08_13_1605-0e154742a5fa_add_provider_model_multi_credential.py
@@ -86,57 +86,30 @@ def upgrade():
def migrate_existing_provider_models_data():
"""migrate provider_models table data to provider_model_credentials"""
- conn = op.get_bind()
- # Define table structure for data manipulation
- if _is_pg(conn):
- provider_models_table = table('provider_models',
- column('id', models.types.StringUUID()),
- column('tenant_id', models.types.StringUUID()),
- column('provider_name', sa.String()),
- column('model_name', sa.String()),
- column('model_type', sa.String()),
- column('encrypted_config', sa.Text()),
- column('created_at', sa.DateTime()),
- column('updated_at', sa.DateTime()),
- column('credential_id', models.types.StringUUID()),
- )
- else:
- provider_models_table = table('provider_models',
- column('id', models.types.StringUUID()),
- column('tenant_id', models.types.StringUUID()),
- column('provider_name', sa.String()),
- column('model_name', sa.String()),
- column('model_type', sa.String()),
- column('encrypted_config', models.types.LongText()),
- column('created_at', sa.DateTime()),
- column('updated_at', sa.DateTime()),
- column('credential_id', models.types.StringUUID()),
- )
+    # Define table structure for data manipulation
+ provider_models_table = table('provider_models',
+ column('id', models.types.StringUUID()),
+ column('tenant_id', models.types.StringUUID()),
+ column('provider_name', sa.String()),
+ column('model_name', sa.String()),
+ column('model_type', sa.String()),
+ column('encrypted_config', models.types.LongText()),
+ column('created_at', sa.DateTime()),
+ column('updated_at', sa.DateTime()),
+ column('credential_id', models.types.StringUUID()),
+ )
- if _is_pg(conn):
- provider_model_credentials_table = table('provider_model_credentials',
- column('id', models.types.StringUUID()),
- column('tenant_id', models.types.StringUUID()),
- column('provider_name', sa.String()),
- column('model_name', sa.String()),
- column('model_type', sa.String()),
- column('credential_name', sa.String()),
- column('encrypted_config', sa.Text()),
- column('created_at', sa.DateTime()),
- column('updated_at', sa.DateTime())
- )
- else:
- provider_model_credentials_table = table('provider_model_credentials',
- column('id', models.types.StringUUID()),
- column('tenant_id', models.types.StringUUID()),
- column('provider_name', sa.String()),
- column('model_name', sa.String()),
- column('model_type', sa.String()),
- column('credential_name', sa.String()),
- column('encrypted_config', models.types.LongText()),
- column('created_at', sa.DateTime()),
- column('updated_at', sa.DateTime())
- )
+ provider_model_credentials_table = table('provider_model_credentials',
+ column('id', models.types.StringUUID()),
+ column('tenant_id', models.types.StringUUID()),
+ column('provider_name', sa.String()),
+ column('model_name', sa.String()),
+ column('model_type', sa.String()),
+ column('credential_name', sa.String()),
+ column('encrypted_config', models.types.LongText()),
+ column('created_at', sa.DateTime()),
+ column('updated_at', sa.DateTime())
+ )
# Get database connection
@@ -183,14 +156,8 @@ def migrate_existing_provider_models_data():
def downgrade():
# Re-add encrypted_config column to provider_models table
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('provider_models', schema=None) as batch_op:
- batch_op.add_column(sa.Column('encrypted_config', sa.Text(), nullable=True))
- else:
- with op.batch_alter_table('provider_models', schema=None) as batch_op:
- batch_op.add_column(sa.Column('encrypted_config', models.types.LongText(), nullable=True))
+ with op.batch_alter_table('provider_models', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('encrypted_config', models.types.LongText(), nullable=True))
if not context.is_offline_mode():
# Migrate data back from provider_model_credentials to provider_models
diff --git a/api/migrations/versions/2025_08_20_1747-8d289573e1da_add_oauth_provider_apps.py b/api/migrations/versions/2025_08_20_1747-8d289573e1da_add_oauth_provider_apps.py
index 75b4d61173..79fe9d9bba 100644
--- a/api/migrations/versions/2025_08_20_1747-8d289573e1da_add_oauth_provider_apps.py
+++ b/api/migrations/versions/2025_08_20_1747-8d289573e1da_add_oauth_provider_apps.py
@@ -8,7 +8,6 @@ Create Date: 2025-08-20 17:47:17.015695
from alembic import op
import models as models
import sqlalchemy as sa
-from libs.uuid_utils import uuidv7
def _is_pg(conn):
diff --git a/api/migrations/versions/2025_09_08_1007-c20211f18133_add_headers_to_mcp_provider.py b/api/migrations/versions/2025_09_08_1007-c20211f18133_add_headers_to_mcp_provider.py
index 4f472fe4b4..cf2b973d2d 100644
--- a/api/migrations/versions/2025_09_08_1007-c20211f18133_add_headers_to_mcp_provider.py
+++ b/api/migrations/versions/2025_09_08_1007-c20211f18133_add_headers_to_mcp_provider.py
@@ -9,8 +9,6 @@ from alembic import op
import models as models
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
import sqlalchemy as sa
@@ -23,12 +21,7 @@ depends_on = None
def upgrade():
# Add encrypted_headers column to tool_mcp_providers table
- conn = op.get_bind()
-
- if _is_pg(conn):
- op.add_column('tool_mcp_providers', sa.Column('encrypted_headers', sa.Text(), nullable=True))
- else:
- op.add_column('tool_mcp_providers', sa.Column('encrypted_headers', models.types.LongText(), nullable=True))
+ op.add_column('tool_mcp_providers', sa.Column('encrypted_headers', models.types.LongText(), nullable=True))
def downgrade():
diff --git a/api/migrations/versions/2025_09_17_1515-68519ad5cd18_knowledge_pipeline_migrate.py b/api/migrations/versions/2025_09_17_1515-68519ad5cd18_knowledge_pipeline_migrate.py
index 8eac0dee10..bad516dcac 100644
--- a/api/migrations/versions/2025_09_17_1515-68519ad5cd18_knowledge_pipeline_migrate.py
+++ b/api/migrations/versions/2025_09_17_1515-68519ad5cd18_knowledge_pipeline_migrate.py
@@ -44,6 +44,7 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='datasource_oauth_config_pkey'),
sa.UniqueConstraint('plugin_id', 'provider', name='datasource_oauth_config_datasource_id_provider_idx')
)
+
if _is_pg(conn):
op.create_table('datasource_oauth_tenant_params',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
@@ -70,6 +71,7 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='datasource_oauth_tenant_config_pkey'),
sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='datasource_oauth_tenant_config_unique')
)
+
if _is_pg(conn):
op.create_table('datasource_providers',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
@@ -104,6 +106,7 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='datasource_provider_pkey'),
sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', 'name', name='datasource_provider_unique_name')
)
+
with op.batch_alter_table('datasource_providers', schema=None) as batch_op:
batch_op.create_index('datasource_provider_auth_type_provider_idx', ['tenant_id', 'plugin_id', 'provider'], unique=False)
@@ -133,6 +136,7 @@ def upgrade():
sa.Column('created_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint('id', name='document_pipeline_execution_log_pkey')
)
+
with op.batch_alter_table('document_pipeline_execution_logs', schema=None) as batch_op:
batch_op.create_index('document_pipeline_execution_logs_document_id_idx', ['document_id'], unique=False)
@@ -174,6 +178,7 @@ def upgrade():
sa.Column('updated_by', models.types.StringUUID(), nullable=True),
sa.PrimaryKeyConstraint('id', name='pipeline_built_in_template_pkey')
)
+
if _is_pg(conn):
op.create_table('pipeline_customized_templates',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
@@ -193,7 +198,6 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='pipeline_customized_template_pkey')
)
else:
- # MySQL: Use compatible syntax
op.create_table('pipeline_customized_templates',
sa.Column('id', models.types.StringUUID(), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
@@ -211,6 +215,7 @@ def upgrade():
sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint('id', name='pipeline_customized_template_pkey')
)
+
with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op:
batch_op.create_index('pipeline_customized_template_tenant_idx', ['tenant_id'], unique=False)
@@ -236,6 +241,7 @@ def upgrade():
sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint('id', name='pipeline_recommended_plugin_pkey')
)
+
if _is_pg(conn):
op.create_table('pipelines',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
@@ -266,6 +272,7 @@ def upgrade():
sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint('id', name='pipeline_pkey')
)
+
if _is_pg(conn):
op.create_table('workflow_draft_variable_files',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
@@ -292,6 +299,7 @@ def upgrade():
sa.Column('value_type', sa.String(20), nullable=False),
sa.PrimaryKeyConstraint('id', name=op.f('workflow_draft_variable_files_pkey'))
)
+
if _is_pg(conn):
op.create_table('workflow_node_execution_offload',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
@@ -316,6 +324,7 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name=op.f('workflow_node_execution_offload_pkey')),
sa.UniqueConstraint('node_execution_id', 'type', name=op.f('workflow_node_execution_offload_node_execution_id_key'))
)
+
if _is_pg(conn):
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.add_column(sa.Column('keyword_number', sa.Integer(), server_default=sa.text('10'), nullable=True))
@@ -342,6 +351,7 @@ def upgrade():
comment='Indicates whether the current value is the default for a conversation variable. Always `FALSE` for other types of variables.',)
)
batch_op.create_index('workflow_draft_variable_file_id_idx', ['file_id'], unique=False)
+
if _is_pg(conn):
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.add_column(sa.Column('rag_pipeline_variables', sa.Text(), server_default='{}', nullable=False))
diff --git a/api/migrations/versions/2025_10_21_1430-ae662b25d9bc_remove_builtin_template_user.py b/api/migrations/versions/2025_10_21_1430-ae662b25d9bc_remove_builtin_template_user.py
index 0776ab0818..ec0cfbd11d 100644
--- a/api/migrations/versions/2025_10_21_1430-ae662b25d9bc_remove_builtin_template_user.py
+++ b/api/migrations/versions/2025_10_21_1430-ae662b25d9bc_remove_builtin_template_user.py
@@ -9,8 +9,6 @@ from alembic import op
import models as models
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
import sqlalchemy as sa
@@ -33,15 +31,9 @@ def upgrade():
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
- batch_op.add_column(sa.Column('created_by', sa.UUID(), autoincrement=False, nullable=False))
- batch_op.add_column(sa.Column('updated_by', sa.UUID(), autoincrement=False, nullable=True))
- else:
- with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
- batch_op.add_column(sa.Column('created_by', models.types.StringUUID(), autoincrement=False, nullable=False))
- batch_op.add_column(sa.Column('updated_by', models.types.StringUUID(), autoincrement=False, nullable=True))
+
+ with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('created_by', models.types.StringUUID(), autoincrement=False, nullable=False))
+ batch_op.add_column(sa.Column('updated_by', models.types.StringUUID(), autoincrement=False, nullable=True))
# ### end Alembic commands ###
diff --git a/api/migrations/versions/2025_10_22_1611-03f8dcbc611e_add_workflowpause_model.py b/api/migrations/versions/2025_10_22_1611-03f8dcbc611e_add_workflowpause_model.py
index 627219cc4b..12905b3674 100644
--- a/api/migrations/versions/2025_10_22_1611-03f8dcbc611e_add_workflowpause_model.py
+++ b/api/migrations/versions/2025_10_22_1611-03f8dcbc611e_add_workflowpause_model.py
@@ -9,7 +9,6 @@ Create Date: 2025-10-22 16:11:31.805407
from alembic import op
import models as models
import sqlalchemy as sa
-from libs.uuid_utils import uuidv7
def _is_pg(conn):
return conn.dialect.name == "postgresql"
diff --git a/api/migrations/versions/2025_10_30_1518-669ffd70119c_introduce_trigger.py b/api/migrations/versions/2025_10_30_1518-669ffd70119c_introduce_trigger.py
index 9641a15c89..c27c1058d1 100644
--- a/api/migrations/versions/2025_10_30_1518-669ffd70119c_introduce_trigger.py
+++ b/api/migrations/versions/2025_10_30_1518-669ffd70119c_introduce_trigger.py
@@ -105,6 +105,7 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='trigger_oauth_tenant_client_pkey'),
sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_trigger_oauth_tenant_client')
)
+
if _is_pg(conn):
op.create_table('trigger_subscriptions',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
@@ -143,6 +144,7 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='trigger_provider_pkey'),
sa.UniqueConstraint('tenant_id', 'provider_id', 'name', name='unique_trigger_provider')
)
+
with op.batch_alter_table('trigger_subscriptions', schema=None) as batch_op:
batch_op.create_index('idx_trigger_providers_endpoint', ['endpoint_id'], unique=True)
batch_op.create_index('idx_trigger_providers_tenant_endpoint', ['tenant_id', 'endpoint_id'], unique=False)
@@ -176,6 +178,7 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='workflow_plugin_trigger_pkey'),
sa.UniqueConstraint('app_id', 'node_id', name='uniq_app_node_subscription')
)
+
with op.batch_alter_table('workflow_plugin_triggers', schema=None) as batch_op:
batch_op.create_index('workflow_plugin_trigger_tenant_subscription_idx', ['tenant_id', 'subscription_id', 'event_name'], unique=False)
@@ -207,6 +210,7 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='workflow_schedule_plan_pkey'),
sa.UniqueConstraint('app_id', 'node_id', name='uniq_app_node')
)
+
with op.batch_alter_table('workflow_schedule_plans', schema=None) as batch_op:
batch_op.create_index('workflow_schedule_plan_next_idx', ['next_run_at'], unique=False)
@@ -264,6 +268,7 @@ def upgrade():
sa.Column('finished_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id', name='workflow_trigger_log_pkey')
)
+
with op.batch_alter_table('workflow_trigger_logs', schema=None) as batch_op:
batch_op.create_index('workflow_trigger_log_created_at_idx', ['created_at'], unique=False)
batch_op.create_index('workflow_trigger_log_status_idx', ['status'], unique=False)
@@ -299,6 +304,7 @@ def upgrade():
sa.UniqueConstraint('app_id', 'node_id', name='uniq_node'),
sa.UniqueConstraint('webhook_id', name='uniq_webhook_id')
)
+
with op.batch_alter_table('workflow_webhook_triggers', schema=None) as batch_op:
batch_op.create_index('workflow_webhook_trigger_tenant_idx', ['tenant_id'], unique=False)
diff --git a/api/migrations/versions/2025_11_06_1603-9e6fa5cbcd80_make_message_annotation_question_not_.py b/api/migrations/versions/2025_11_06_1603-9e6fa5cbcd80_make_message_annotation_question_not_.py
new file mode 100644
index 0000000000..624be1d073
--- /dev/null
+++ b/api/migrations/versions/2025_11_06_1603-9e6fa5cbcd80_make_message_annotation_question_not_.py
@@ -0,0 +1,60 @@
+"""make message annotation question not nullable
+
+Revision ID: 9e6fa5cbcd80
+Revises: 288345cd01d1
+Create Date: 2025-11-06 16:03:54.549378
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '9e6fa5cbcd80'
+down_revision = '288345cd01d1'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ bind = op.get_bind()
+ message_annotations = sa.table(
+ "message_annotations",
+ sa.column("id", sa.String),
+ sa.column("message_id", sa.String),
+ sa.column("question", sa.Text),
+ )
+ messages = sa.table(
+ "messages",
+ sa.column("id", sa.String),
+ sa.column("query", sa.Text),
+ )
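+    # Backfill NULL questions from the linked message's query via a correlated subquery.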
+ update_question_from_message = (
+ sa.update(message_annotations)
+ .where(
+ sa.and_(
+ message_annotations.c.question.is_(None),
+ message_annotations.c.message_id.isnot(None),
+ )
+ )
+ .values(
+ question=sa.select(sa.func.coalesce(messages.c.query, ""))
+ .where(messages.c.id == message_annotations.c.message_id)
+ .scalar_subquery()
+ )
+ )
+ bind.execute(update_question_from_message)
+
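+    # Any annotation still missing a question (e.g. no linked message) falls back to an empty string.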
+ fill_remaining_questions = (
+ sa.update(message_annotations)
+ .where(message_annotations.c.question.is_(None))
+ .values(question="")
+ )
+ bind.execute(fill_remaining_questions)
+ with op.batch_alter_table('message_annotations', schema=None) as batch_op:
+ batch_op.alter_column('question', existing_type=sa.TEXT(), nullable=False)
+
+
+def downgrade():
+ with op.batch_alter_table('message_annotations', schema=None) as batch_op:
+ batch_op.alter_column('question', existing_type=sa.TEXT(), nullable=True)
diff --git a/api/migrations/versions/2025_12_16_1817-03ea244985ce_add_type_column_not_null_default_tool.py b/api/migrations/versions/2025_12_16_1817-03ea244985ce_add_type_column_not_null_default_tool.py
new file mode 100644
index 0000000000..2bdd430e81
--- /dev/null
+++ b/api/migrations/versions/2025_12_16_1817-03ea244985ce_add_type_column_not_null_default_tool.py
@@ -0,0 +1,31 @@
+"""add type column not null default tool
+
+Revision ID: 03ea244985ce
+Revises: d57accd375ae
+Create Date: 2025-12-16 18:17:12.193877
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = '03ea244985ce'
+down_revision = 'd57accd375ae'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
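+    # server_default keeps existing rows valid: they are backfilled with 'tool'.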
+ with op.batch_alter_table('pipeline_recommended_plugins', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('type', sa.String(length=50), server_default=sa.text("'tool'"), nullable=False))
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('pipeline_recommended_plugins', schema=None) as batch_op:
+ batch_op.drop_column('type')
+ # ### end Alembic commands ###
diff --git a/api/migrations/versions/2025_12_25_1039-7df29de0f6be_add_credit_pool.py b/api/migrations/versions/2025_12_25_1039-7df29de0f6be_add_credit_pool.py
new file mode 100644
index 0000000000..e89fcee7e5
--- /dev/null
+++ b/api/migrations/versions/2025_12_25_1039-7df29de0f6be_add_credit_pool.py
@@ -0,0 +1,46 @@
+"""add credit pool
+
+Revision ID: 7df29de0f6be
+Revises: 03ea244985ce
+Create Date: 2025-12-25 10:39:15.139304
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = '7df29de0f6be'
+down_revision = '03ea244985ce'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_table('tenant_credit_pools',
+ sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+ sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
+ sa.Column('pool_type', sa.String(length=40), server_default='trial', nullable=False),
+ sa.Column('quota_limit', sa.BigInteger(), nullable=False),
+ sa.Column('quota_used', sa.BigInteger(), nullable=False),
+ sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+ sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+ sa.PrimaryKeyConstraint('id', name='tenant_credit_pool_pkey')
+ )
+ with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op:
+ batch_op.create_index('tenant_credit_pool_pool_type_idx', ['pool_type'], unique=False)
+ batch_op.create_index('tenant_credit_pool_tenant_id_idx', ['tenant_id'], unique=False)
+
+ # ### end Alembic commands ###
+
+
+def downgrade():
+
+ with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op:
+ batch_op.drop_index('tenant_credit_pool_tenant_id_idx')
+ batch_op.drop_index('tenant_credit_pool_pool_type_idx')
+
+ op.drop_table('tenant_credit_pools')
+ # ### end Alembic commands ###
diff --git a/api/migrations/versions/2026_01_09_1630-905527cc8fd3_add_workflow_run_created_at_id_idx.py b/api/migrations/versions/2026_01_09_1630-905527cc8fd3_add_workflow_run_created_at_id_idx.py
new file mode 100644
index 0000000000..7e0cc8ec9d
--- /dev/null
+++ b/api/migrations/versions/2026_01_09_1630-905527cc8fd3_add_workflow_run_created_at_id_idx.py
@@ -0,0 +1,30 @@
+"""add workflow_run_created_at_id_idx
+
+Revision ID: 905527cc8fd3
+Revises: 7df29de0f6be
+Create Date: 2026-01-09 16:30:02.462084
+
+"""
+from alembic import op
+import models as models
+
+# revision identifiers, used by Alembic.
+revision = '905527cc8fd3'
+down_revision = '7df29de0f6be'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
+ batch_op.create_index('workflow_run_created_at_id_idx', ['created_at', 'id'], unique=False)
+
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
+ batch_op.drop_index('workflow_run_created_at_id_idx')
+ # ### end Alembic commands ###
diff --git a/api/migrations/versions/2026_01_12_1729-3334862ee907_feat_add_created_at_id_index_to_messages.py b/api/migrations/versions/2026_01_12_1729-3334862ee907_feat_add_created_at_id_index_to_messages.py
new file mode 100644
index 0000000000..758369ba99
--- /dev/null
+++ b/api/migrations/versions/2026_01_12_1729-3334862ee907_feat_add_created_at_id_index_to_messages.py
@@ -0,0 +1,33 @@
+"""feat: add created_at id index to messages
+
+Revision ID: 3334862ee907
+Revises: 905527cc8fd3
+Create Date: 2026-01-12 17:29:44.846544
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '3334862ee907'
+down_revision = '905527cc8fd3'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('messages', schema=None) as batch_op:
+ batch_op.create_index('message_created_at_id_idx', ['created_at', 'id'], unique=False)
+
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('messages', schema=None) as batch_op:
+ batch_op.drop_index('message_created_at_id_idx')
+
+ # ### end Alembic commands ###
diff --git a/api/migrations/versions/2026_01_16_1715-288345cd01d1_change_workflow_node_execution_run_index.py b/api/migrations/versions/2026_01_16_1715-288345cd01d1_change_workflow_node_execution_run_index.py
new file mode 100644
index 0000000000..2e1af0c83f
--- /dev/null
+++ b/api/migrations/versions/2026_01_16_1715-288345cd01d1_change_workflow_node_execution_run_index.py
@@ -0,0 +1,35 @@
+"""change workflow node execution workflow_run index
+
+Revision ID: 288345cd01d1
+Revises: 3334862ee907
+Create Date: 2026-01-16 17:15:00.000000
+
+"""
+from alembic import op
+
+
+# revision identifiers, used by Alembic.
+revision = "288345cd01d1"
+down_revision = "3334862ee907"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
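+    # Replace the wide composite index with a narrower index on workflow_run_id only.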
+ with op.batch_alter_table("workflow_node_executions", schema=None) as batch_op:
+ batch_op.drop_index("workflow_node_execution_workflow_run_idx")
+ batch_op.create_index(
+ "workflow_node_execution_workflow_run_id_idx",
+ ["workflow_run_id"],
+ unique=False,
+ )
+
+
+def downgrade():
+ with op.batch_alter_table("workflow_node_executions", schema=None) as batch_op:
+ batch_op.drop_index("workflow_node_execution_workflow_run_id_idx")
+ batch_op.create_index(
+ "workflow_node_execution_workflow_run_idx",
+ ["tenant_id", "app_id", "workflow_id", "triggered_from", "workflow_run_id"],
+ unique=False,
+ )
diff --git a/api/migrations/versions/2026_01_17_1110-f9f6d18a37f9_add_table_explore_banner_and_trial.py b/api/migrations/versions/2026_01_17_1110-f9f6d18a37f9_add_table_explore_banner_and_trial.py
new file mode 100644
index 0000000000..b99ca04e3f
--- /dev/null
+++ b/api/migrations/versions/2026_01_17_1110-f9f6d18a37f9_add_table_explore_banner_and_trial.py
@@ -0,0 +1,73 @@
+"""add table explore banner and trial
+
+Revision ID: f9f6d18a37f9
+Revises: 9e6fa5cbcd80
+Create Date: 2026-01-17 11:10:18.079355
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = 'f9f6d18a37f9'
+down_revision = '9e6fa5cbcd80'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_table('account_trial_app_records',
+ sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+ sa.Column('account_id', models.types.StringUUID(), nullable=False),
+ sa.Column('app_id', models.types.StringUUID(), nullable=False),
+ sa.Column('count', sa.Integer(), nullable=False),
+ sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+ sa.PrimaryKeyConstraint('id', name='user_trial_app_pkey'),
+ sa.UniqueConstraint('account_id', 'app_id', name='unique_account_trial_app_record')
+ )
+ with op.batch_alter_table('account_trial_app_records', schema=None) as batch_op:
+ batch_op.create_index('account_trial_app_record_account_id_idx', ['account_id'], unique=False)
+ batch_op.create_index('account_trial_app_record_app_id_idx', ['app_id'], unique=False)
+
+ op.create_table('exporle_banners',
+ sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+ sa.Column('content', sa.JSON(), nullable=False),
+ sa.Column('link', sa.String(length=255), nullable=False),
+ sa.Column('sort', sa.Integer(), nullable=False),
+ sa.Column('status', sa.String(length=255), server_default=sa.text("'enabled'::character varying"), nullable=False),
+ sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+ sa.Column('language', sa.String(length=255), server_default=sa.text("'en-US'::character varying"), nullable=False),
+ sa.PrimaryKeyConstraint('id', name='exporler_banner_pkey')
+ )
+ op.create_table('trial_apps',
+ sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+ sa.Column('app_id', models.types.StringUUID(), nullable=False),
+ sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
+ sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+ sa.Column('trial_limit', sa.Integer(), nullable=False),
+ sa.PrimaryKeyConstraint('id', name='trial_app_pkey'),
+ sa.UniqueConstraint('app_id', name='unique_trail_app_id')
+ )
+ with op.batch_alter_table('trial_apps', schema=None) as batch_op:
+ batch_op.create_index('trial_app_app_id_idx', ['app_id'], unique=False)
+ batch_op.create_index('trial_app_tenant_id_idx', ['tenant_id'], unique=False)
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('trial_apps', schema=None) as batch_op:
+ batch_op.drop_index('trial_app_tenant_id_idx')
+ batch_op.drop_index('trial_app_app_id_idx')
+
+ op.drop_table('trial_apps')
+ op.drop_table('exporle_banners')
+ with op.batch_alter_table('account_trial_app_records', schema=None) as batch_op:
+ batch_op.drop_index('account_trial_app_record_app_id_idx')
+ batch_op.drop_index('account_trial_app_record_account_id_idx')
+
+ op.drop_table('account_trial_app_records')
+ # ### end Alembic commands ###
diff --git a/api/migrations/versions/2026_01_21_1718-9d77545f524e_add_workflow_archive_logs.py b/api/migrations/versions/2026_01_21_1718-9d77545f524e_add_workflow_archive_logs.py
new file mode 100644
index 0000000000..5e7298af54
--- /dev/null
+++ b/api/migrations/versions/2026_01_21_1718-9d77545f524e_add_workflow_archive_logs.py
@@ -0,0 +1,95 @@
+"""create workflow_archive_logs
+
+Revision ID: 9d77545f524e
+Revises: f9f6d18a37f9
+Create Date: 2026-01-21 17:18:56.292479
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+
+
+def _is_pg(conn):
+ return conn.dialect.name == "postgresql"
+
+# revision identifiers, used by Alembic.
+revision = '9d77545f524e'
+down_revision = 'f9f6d18a37f9'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ conn = op.get_bind()
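+    # PostgreSQL gets a server-side uuidv7() default for id; other dialects must supply ids from the application.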
+ if _is_pg(conn):
+ op.create_table('workflow_archive_logs',
+ sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
+ sa.Column('log_id', models.types.StringUUID(), nullable=True),
+ sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
+ sa.Column('app_id', models.types.StringUUID(), nullable=False),
+ sa.Column('workflow_id', models.types.StringUUID(), nullable=False),
+ sa.Column('workflow_run_id', models.types.StringUUID(), nullable=False),
+ sa.Column('created_by_role', sa.String(length=255), nullable=False),
+ sa.Column('created_by', models.types.StringUUID(), nullable=False),
+ sa.Column('log_created_at', sa.DateTime(), nullable=True),
+ sa.Column('log_created_from', sa.String(length=255), nullable=True),
+ sa.Column('run_version', sa.String(length=255), nullable=False),
+ sa.Column('run_status', sa.String(length=255), nullable=False),
+ sa.Column('run_triggered_from', sa.String(length=255), nullable=False),
+ sa.Column('run_error', models.types.LongText(), nullable=True),
+ sa.Column('run_elapsed_time', sa.Float(), server_default=sa.text('0'), nullable=False),
+ sa.Column('run_total_tokens', sa.BigInteger(), server_default=sa.text('0'), nullable=False),
+ sa.Column('run_total_steps', sa.Integer(), server_default=sa.text('0'), nullable=True),
+ sa.Column('run_created_at', sa.DateTime(), nullable=False),
+ sa.Column('run_finished_at', sa.DateTime(), nullable=True),
+ sa.Column('run_exceptions_count', sa.Integer(), server_default=sa.text('0'), nullable=True),
+ sa.Column('trigger_metadata', models.types.LongText(), nullable=True),
+ sa.Column('archived_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+ sa.PrimaryKeyConstraint('id', name='workflow_archive_log_pkey')
+ )
+ else:
+ op.create_table('workflow_archive_logs',
+ sa.Column('id', models.types.StringUUID(), nullable=False),
+ sa.Column('log_id', models.types.StringUUID(), nullable=True),
+ sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
+ sa.Column('app_id', models.types.StringUUID(), nullable=False),
+ sa.Column('workflow_id', models.types.StringUUID(), nullable=False),
+ sa.Column('workflow_run_id', models.types.StringUUID(), nullable=False),
+ sa.Column('created_by_role', sa.String(length=255), nullable=False),
+ sa.Column('created_by', models.types.StringUUID(), nullable=False),
+ sa.Column('log_created_at', sa.DateTime(), nullable=True),
+ sa.Column('log_created_from', sa.String(length=255), nullable=True),
+ sa.Column('run_version', sa.String(length=255), nullable=False),
+ sa.Column('run_status', sa.String(length=255), nullable=False),
+ sa.Column('run_triggered_from', sa.String(length=255), nullable=False),
+ sa.Column('run_error', models.types.LongText(), nullable=True),
+ sa.Column('run_elapsed_time', sa.Float(), server_default=sa.text('0'), nullable=False),
+ sa.Column('run_total_tokens', sa.BigInteger(), server_default=sa.text('0'), nullable=False),
+ sa.Column('run_total_steps', sa.Integer(), server_default=sa.text('0'), nullable=True),
+ sa.Column('run_created_at', sa.DateTime(), nullable=False),
+ sa.Column('run_finished_at', sa.DateTime(), nullable=True),
+ sa.Column('run_exceptions_count', sa.Integer(), server_default=sa.text('0'), nullable=True),
+ sa.Column('trigger_metadata', models.types.LongText(), nullable=True),
+ sa.Column('archived_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+ sa.PrimaryKeyConstraint('id', name='workflow_archive_log_pkey')
+ )
+ with op.batch_alter_table('workflow_archive_logs', schema=None) as batch_op:
+ batch_op.create_index('workflow_archive_log_app_idx', ['tenant_id', 'app_id'], unique=False)
+ batch_op.create_index('workflow_archive_log_run_created_at_idx', ['run_created_at'], unique=False)
+ batch_op.create_index('workflow_archive_log_workflow_run_id_idx', ['workflow_run_id'], unique=False)
+
+
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('workflow_archive_logs', schema=None) as batch_op:
+ batch_op.drop_index('workflow_archive_log_workflow_run_id_idx')
+ batch_op.drop_index('workflow_archive_log_run_created_at_idx')
+ batch_op.drop_index('workflow_archive_log_app_idx')
+
+ op.drop_table('workflow_archive_logs')
+ # ### end Alembic commands ###
diff --git a/api/migrations/versions/2026_01_27_1815-788d3099ae3a_add_summary_index_feature.py b/api/migrations/versions/2026_01_27_1815-788d3099ae3a_add_summary_index_feature.py
new file mode 100644
index 0000000000..c6c72859dc
--- /dev/null
+++ b/api/migrations/versions/2026_01_27_1815-788d3099ae3a_add_summary_index_feature.py
@@ -0,0 +1,107 @@
+"""add summary index feature
+
+Revision ID: 788d3099ae3a
+Revises: 9d77545f524e
+Create Date: 2026-01-27 18:15:45.277928
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+
+def _is_pg(conn):
+ return conn.dialect.name == "postgresql"
+
+# revision identifiers, used by Alembic.
+revision = '788d3099ae3a'
+down_revision = '9d77545f524e'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ conn = op.get_bind()
+ if _is_pg(conn):
+ op.create_table('document_segment_summaries',
+ sa.Column('id', models.types.StringUUID(), nullable=False),
+ sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
+ sa.Column('document_id', models.types.StringUUID(), nullable=False),
+ sa.Column('chunk_id', models.types.StringUUID(), nullable=False),
+ sa.Column('summary_content', models.types.LongText(), nullable=True),
+ sa.Column('summary_index_node_id', sa.String(length=255), nullable=True),
+ sa.Column('summary_index_node_hash', sa.String(length=255), nullable=True),
+ sa.Column('tokens', sa.Integer(), nullable=True),
+ sa.Column('status', sa.String(length=32), server_default=sa.text("'generating'"), nullable=False),
+ sa.Column('error', models.types.LongText(), nullable=True),
+ sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False),
+ sa.Column('disabled_at', sa.DateTime(), nullable=True),
+ sa.Column('disabled_by', models.types.StringUUID(), nullable=True),
+ sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+ sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+ sa.PrimaryKeyConstraint('id', name='document_segment_summaries_pkey')
+ )
+ with op.batch_alter_table('document_segment_summaries', schema=None) as batch_op:
+ batch_op.create_index('document_segment_summaries_chunk_id_idx', ['chunk_id'], unique=False)
+ batch_op.create_index('document_segment_summaries_dataset_id_idx', ['dataset_id'], unique=False)
+ batch_op.create_index('document_segment_summaries_document_id_idx', ['document_id'], unique=False)
+ batch_op.create_index('document_segment_summaries_status_idx', ['status'], unique=False)
+
+ with op.batch_alter_table('datasets', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('summary_index_setting', models.types.AdjustedJSON(), nullable=True))
+
+ with op.batch_alter_table('documents', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('need_summary', sa.Boolean(), server_default=sa.text('false'), nullable=False))
+ else:
+ # MySQL: Use compatible syntax
+ op.create_table(
+ 'document_segment_summaries',
+ sa.Column('id', models.types.StringUUID(), nullable=False),
+ sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
+ sa.Column('document_id', models.types.StringUUID(), nullable=False),
+ sa.Column('chunk_id', models.types.StringUUID(), nullable=False),
+ sa.Column('summary_content', models.types.LongText(), nullable=True),
+ sa.Column('summary_index_node_id', sa.String(length=255), nullable=True),
+ sa.Column('summary_index_node_hash', sa.String(length=255), nullable=True),
+ sa.Column('tokens', sa.Integer(), nullable=True),
+ sa.Column('status', sa.String(length=32), server_default=sa.text("'generating'"), nullable=False),
+ sa.Column('error', models.types.LongText(), nullable=True),
+ sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False),
+ sa.Column('disabled_at', sa.DateTime(), nullable=True),
+ sa.Column('disabled_by', models.types.StringUUID(), nullable=True),
+ sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+ sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+ sa.PrimaryKeyConstraint('id', name='document_segment_summaries_pkey'),
+ )
+ with op.batch_alter_table('document_segment_summaries', schema=None) as batch_op:
+ batch_op.create_index('document_segment_summaries_chunk_id_idx', ['chunk_id'], unique=False)
+ batch_op.create_index('document_segment_summaries_dataset_id_idx', ['dataset_id'], unique=False)
+ batch_op.create_index('document_segment_summaries_document_id_idx', ['document_id'], unique=False)
+ batch_op.create_index('document_segment_summaries_status_idx', ['status'], unique=False)
+
+ with op.batch_alter_table('datasets', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('summary_index_setting', models.types.AdjustedJSON(), nullable=True))
+
+ with op.batch_alter_table('documents', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('need_summary', sa.Boolean(), server_default=sa.text('false'), nullable=False))
+
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+
+ with op.batch_alter_table('documents', schema=None) as batch_op:
+ batch_op.drop_column('need_summary')
+
+ with op.batch_alter_table('datasets', schema=None) as batch_op:
+ batch_op.drop_column('summary_index_setting')
+
+ with op.batch_alter_table('document_segment_summaries', schema=None) as batch_op:
+ batch_op.drop_index('document_segment_summaries_status_idx')
+ batch_op.drop_index('document_segment_summaries_document_id_idx')
+ batch_op.drop_index('document_segment_summaries_dataset_id_idx')
+ batch_op.drop_index('document_segment_summaries_chunk_id_idx')
+
+ op.drop_table('document_segment_summaries')
+ # ### end Alembic commands ###
diff --git a/api/migrations/versions/23db93619b9d_add_message_files_into_agent_thought.py b/api/migrations/versions/23db93619b9d_add_message_files_into_agent_thought.py
index fae506906b..127ffd5599 100644
--- a/api/migrations/versions/23db93619b9d_add_message_files_into_agent_thought.py
+++ b/api/migrations/versions/23db93619b9d_add_message_files_into_agent_thought.py
@@ -11,9 +11,6 @@ from alembic import op
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = '23db93619b9d'
down_revision = '8ae9bc661daa'
@@ -23,14 +20,8 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
- batch_op.add_column(sa.Column('message_files', sa.Text(), nullable=True))
- else:
- with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
- batch_op.add_column(sa.Column('message_files', models.types.LongText(), nullable=True))
+ with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('message_files', models.types.LongText(), nullable=True))
# ### end Alembic commands ###
diff --git a/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py b/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py
index 2676ef0b94..31829d8e58 100644
--- a/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py
+++ b/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py
@@ -62,14 +62,8 @@ def upgrade():
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('annotation_reply', sa.TEXT(), autoincrement=False, nullable=True))
- else:
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('annotation_reply', models.types.LongText(), autoincrement=False, nullable=True))
+ with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('annotation_reply', models.types.LongText(), autoincrement=False, nullable=True))
with op.batch_alter_table('app_annotation_settings', schema=None) as batch_op:
batch_op.drop_index('app_annotation_settings_app_idx')
diff --git a/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py b/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py
index 3362a3a09f..07a8cd86b1 100644
--- a/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py
+++ b/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py
@@ -11,9 +11,6 @@ from alembic import op
import models as models
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = '2a3aebbbf4bb'
down_revision = 'c031d46af369'
@@ -23,14 +20,8 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('apps', schema=None) as batch_op:
- batch_op.add_column(sa.Column('tracing', sa.Text(), nullable=True))
- else:
- with op.batch_alter_table('apps', schema=None) as batch_op:
- batch_op.add_column(sa.Column('tracing', models.types.LongText(), nullable=True))
+ with op.batch_alter_table('apps', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('tracing', models.types.LongText(), nullable=True))
# ### end Alembic commands ###
diff --git a/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py b/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py
index 40bd727f66..211b2d8882 100644
--- a/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py
+++ b/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py
@@ -7,14 +7,10 @@ Create Date: 2023-09-22 15:41:01.243183
"""
import sqlalchemy as sa
from alembic import op
-from sqlalchemy.dialects import postgresql
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = '2e9819ca5b28'
down_revision = 'ab23c11305d4'
@@ -24,35 +20,19 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('api_tokens', schema=None) as batch_op:
- batch_op.add_column(sa.Column('tenant_id', postgresql.UUID(), nullable=True))
- batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False)
- batch_op.drop_column('dataset_id')
- else:
- with op.batch_alter_table('api_tokens', schema=None) as batch_op:
- batch_op.add_column(sa.Column('tenant_id', models.types.StringUUID(), nullable=True))
- batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False)
- batch_op.drop_column('dataset_id')
+ with op.batch_alter_table('api_tokens', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('tenant_id', models.types.StringUUID(), nullable=True))
+ batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False)
+ batch_op.drop_column('dataset_id')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('api_tokens', schema=None) as batch_op:
- batch_op.add_column(sa.Column('dataset_id', postgresql.UUID(), autoincrement=False, nullable=True))
- batch_op.drop_index('api_token_tenant_idx')
- batch_op.drop_column('tenant_id')
- else:
- with op.batch_alter_table('api_tokens', schema=None) as batch_op:
- batch_op.add_column(sa.Column('dataset_id', models.types.StringUUID(), autoincrement=False, nullable=True))
- batch_op.drop_index('api_token_tenant_idx')
- batch_op.drop_column('tenant_id')
+ with op.batch_alter_table('api_tokens', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('dataset_id', models.types.StringUUID(), autoincrement=False, nullable=True))
+ batch_op.drop_index('api_token_tenant_idx')
+ batch_op.drop_column('tenant_id')
# ### end Alembic commands ###
diff --git a/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py b/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py
index 76056a9460..3491c85e2f 100644
--- a/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py
+++ b/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py
@@ -7,14 +7,10 @@ Create Date: 2024-03-07 08:30:29.133614
"""
import sqlalchemy as sa
from alembic import op
-from sqlalchemy.dialects import postgresql
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = '42e85ed5564d'
down_revision = 'f9107f83abab'
@@ -24,59 +20,31 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('conversations', schema=None) as batch_op:
- batch_op.alter_column('app_model_config_id',
- existing_type=postgresql.UUID(),
- nullable=True)
- batch_op.alter_column('model_provider',
- existing_type=sa.VARCHAR(length=255),
- nullable=True)
- batch_op.alter_column('model_id',
- existing_type=sa.VARCHAR(length=255),
- nullable=True)
- else:
- with op.batch_alter_table('conversations', schema=None) as batch_op:
- batch_op.alter_column('app_model_config_id',
- existing_type=models.types.StringUUID(),
- nullable=True)
- batch_op.alter_column('model_provider',
- existing_type=sa.VARCHAR(length=255),
- nullable=True)
- batch_op.alter_column('model_id',
- existing_type=sa.VARCHAR(length=255),
- nullable=True)
+ with op.batch_alter_table('conversations', schema=None) as batch_op:
+ batch_op.alter_column('app_model_config_id',
+ existing_type=models.types.StringUUID(),
+ nullable=True)
+ batch_op.alter_column('model_provider',
+ existing_type=sa.VARCHAR(length=255),
+ nullable=True)
+ batch_op.alter_column('model_id',
+ existing_type=sa.VARCHAR(length=255),
+ nullable=True)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('conversations', schema=None) as batch_op:
- batch_op.alter_column('model_id',
- existing_type=sa.VARCHAR(length=255),
- nullable=False)
- batch_op.alter_column('model_provider',
- existing_type=sa.VARCHAR(length=255),
- nullable=False)
- batch_op.alter_column('app_model_config_id',
- existing_type=postgresql.UUID(),
- nullable=False)
- else:
- with op.batch_alter_table('conversations', schema=None) as batch_op:
- batch_op.alter_column('model_id',
- existing_type=sa.VARCHAR(length=255),
- nullable=False)
- batch_op.alter_column('model_provider',
- existing_type=sa.VARCHAR(length=255),
- nullable=False)
- batch_op.alter_column('app_model_config_id',
- existing_type=models.types.StringUUID(),
- nullable=False)
+ with op.batch_alter_table('conversations', schema=None) as batch_op:
+ batch_op.alter_column('model_id',
+ existing_type=sa.VARCHAR(length=255),
+ nullable=False)
+ batch_op.alter_column('model_provider',
+ existing_type=sa.VARCHAR(length=255),
+ nullable=False)
+ batch_op.alter_column('app_model_config_id',
+ existing_type=models.types.StringUUID(),
+ nullable=False)
# ### end Alembic commands ###
diff --git a/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py b/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py
index ef066587b7..8537a87233 100644
--- a/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py
+++ b/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py
@@ -6,14 +6,10 @@ Create Date: 2024-01-12 03:42:27.362415
"""
from alembic import op
-from sqlalchemy.dialects import postgresql
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = '4829e54d2fee'
down_revision = '114eed84c228'
@@ -23,39 +19,21 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- # PostgreSQL: Keep original syntax
- with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
- batch_op.alter_column('message_chain_id',
- existing_type=postgresql.UUID(),
- nullable=True)
- else:
- # MySQL: Use compatible syntax
- with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
- batch_op.alter_column('message_chain_id',
- existing_type=models.types.StringUUID(),
- nullable=True)
+
+ with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
+ batch_op.alter_column('message_chain_id',
+ existing_type=models.types.StringUUID(),
+ nullable=True)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- # PostgreSQL: Keep original syntax
- with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
- batch_op.alter_column('message_chain_id',
- existing_type=postgresql.UUID(),
- nullable=False)
- else:
- # MySQL: Use compatible syntax
- with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
- batch_op.alter_column('message_chain_id',
- existing_type=models.types.StringUUID(),
- nullable=False)
+
+ with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
+ batch_op.alter_column('message_chain_id',
+ existing_type=models.types.StringUUID(),
+ nullable=False)
# ### end Alembic commands ###
diff --git a/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py b/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py
index b080e7680b..22405e3cc8 100644
--- a/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py
+++ b/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py
@@ -6,14 +6,10 @@ Create Date: 2024-03-14 04:54:56.679506
"""
from alembic import op
-from sqlalchemy.dialects import postgresql
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = '563cf8bf777b'
down_revision = 'b5429b71023c'
@@ -23,35 +19,19 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('tool_files', schema=None) as batch_op:
- batch_op.alter_column('conversation_id',
- existing_type=postgresql.UUID(),
- nullable=True)
- else:
- with op.batch_alter_table('tool_files', schema=None) as batch_op:
- batch_op.alter_column('conversation_id',
- existing_type=models.types.StringUUID(),
- nullable=True)
+ with op.batch_alter_table('tool_files', schema=None) as batch_op:
+ batch_op.alter_column('conversation_id',
+ existing_type=models.types.StringUUID(),
+ nullable=True)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('tool_files', schema=None) as batch_op:
- batch_op.alter_column('conversation_id',
- existing_type=postgresql.UUID(),
- nullable=False)
- else:
- with op.batch_alter_table('tool_files', schema=None) as batch_op:
- batch_op.alter_column('conversation_id',
- existing_type=models.types.StringUUID(),
- nullable=False)
+ with op.batch_alter_table('tool_files', schema=None) as batch_op:
+ batch_op.alter_column('conversation_id',
+ existing_type=models.types.StringUUID(),
+ nullable=False)
# ### end Alembic commands ###
diff --git a/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py b/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py
index 1ace8ea5a0..01d7d5ba21 100644
--- a/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py
+++ b/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py
@@ -48,12 +48,9 @@ def upgrade():
with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
batch_op.create_index('provider_model_name_idx', ['provider_name', 'model_name'], unique=False)
- if _is_pg(conn):
- with op.batch_alter_table('datasets', schema=None) as batch_op:
- batch_op.add_column(sa.Column('collection_binding_id', postgresql.UUID(), nullable=True))
- else:
- with op.batch_alter_table('datasets', schema=None) as batch_op:
- batch_op.add_column(sa.Column('collection_binding_id', models.types.StringUUID(), nullable=True))
+
+ with op.batch_alter_table('datasets', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('collection_binding_id', models.types.StringUUID(), nullable=True))
# ### end Alembic commands ###
diff --git a/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py b/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py
index 457338ef42..0faa48f535 100644
--- a/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py
+++ b/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py
@@ -11,9 +11,6 @@ from alembic import op
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = '714aafe25d39'
down_revision = 'f2a6fc85e260'
@@ -23,16 +20,9 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
- batch_op.add_column(sa.Column('annotation_question', sa.Text(), nullable=False))
- batch_op.add_column(sa.Column('annotation_content', sa.Text(), nullable=False))
- else:
- with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
- batch_op.add_column(sa.Column('annotation_question', models.types.LongText(), nullable=False))
- batch_op.add_column(sa.Column('annotation_content', models.types.LongText(), nullable=False))
+ with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('annotation_question', models.types.LongText(), nullable=False))
+ batch_op.add_column(sa.Column('annotation_content', models.types.LongText(), nullable=False))
# ### end Alembic commands ###
diff --git a/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py b/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py
index 7bcd1a1be3..aa7b4a21e2 100644
--- a/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py
+++ b/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py
@@ -11,9 +11,6 @@ from alembic import op
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = '77e83833755c'
down_revision = '6dcb43972bdc'
@@ -23,14 +20,8 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('retriever_resource', sa.Text(), nullable=True))
- else:
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('retriever_resource', models.types.LongText(), nullable=True))
+ with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('retriever_resource', models.types.LongText(), nullable=True))
# ### end Alembic commands ###
diff --git a/api/migrations/versions/7ce5a52e4eee_add_tool_providers.py b/api/migrations/versions/7ce5a52e4eee_add_tool_providers.py
index 3c0aa082d5..34a17697d3 100644
--- a/api/migrations/versions/7ce5a52e4eee_add_tool_providers.py
+++ b/api/migrations/versions/7ce5a52e4eee_add_tool_providers.py
@@ -27,7 +27,6 @@ def upgrade():
conn = op.get_bind()
if _is_pg(conn):
- # PostgreSQL: Keep original syntax
op.create_table('tool_providers',
sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', postgresql.UUID(), nullable=False),
@@ -40,7 +39,6 @@ def upgrade():
sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
)
else:
- # MySQL: Use compatible syntax
op.create_table('tool_providers',
sa.Column('id', models.types.StringUUID(), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
@@ -52,12 +50,9 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'),
sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
)
- if _is_pg(conn):
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('sensitive_word_avoidance', sa.Text(), nullable=True))
- else:
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('sensitive_word_avoidance', models.types.LongText(), nullable=True))
+
+ with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('sensitive_word_avoidance', models.types.LongText(), nullable=True))
# ### end Alembic commands ###
diff --git a/api/migrations/versions/88072f0caa04_add_custom_config_in_tenant.py b/api/migrations/versions/88072f0caa04_add_custom_config_in_tenant.py
index beea90b384..884839c010 100644
--- a/api/migrations/versions/88072f0caa04_add_custom_config_in_tenant.py
+++ b/api/migrations/versions/88072f0caa04_add_custom_config_in_tenant.py
@@ -11,9 +11,6 @@ from alembic import op
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = '88072f0caa04'
down_revision = '246ba09cbbdb'
@@ -23,14 +20,8 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('tenants', schema=None) as batch_op:
- batch_op.add_column(sa.Column('custom_config', sa.Text(), nullable=True))
- else:
- with op.batch_alter_table('tenants', schema=None) as batch_op:
- batch_op.add_column(sa.Column('custom_config', models.types.LongText(), nullable=True))
+ with op.batch_alter_table('tenants', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('custom_config', models.types.LongText(), nullable=True))
# ### end Alembic commands ###
diff --git a/api/migrations/versions/89c7899ca936_.py b/api/migrations/versions/89c7899ca936_.py
index 2420710e74..d26f1e82d6 100644
--- a/api/migrations/versions/89c7899ca936_.py
+++ b/api/migrations/versions/89c7899ca936_.py
@@ -11,9 +11,6 @@ from alembic import op
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = '89c7899ca936'
down_revision = '187385f442fc'
@@ -23,39 +20,21 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('sites', schema=None) as batch_op:
- batch_op.alter_column('description',
- existing_type=sa.VARCHAR(length=255),
- type_=sa.Text(),
- existing_nullable=True)
- else:
- with op.batch_alter_table('sites', schema=None) as batch_op:
- batch_op.alter_column('description',
- existing_type=sa.VARCHAR(length=255),
- type_=models.types.LongText(),
- existing_nullable=True)
+ with op.batch_alter_table('sites', schema=None) as batch_op:
+ batch_op.alter_column('description',
+ existing_type=sa.VARCHAR(length=255),
+ type_=models.types.LongText(),
+ existing_nullable=True)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('sites', schema=None) as batch_op:
- batch_op.alter_column('description',
- existing_type=sa.Text(),
- type_=sa.VARCHAR(length=255),
- existing_nullable=True)
- else:
- with op.batch_alter_table('sites', schema=None) as batch_op:
- batch_op.alter_column('description',
- existing_type=models.types.LongText(),
- type_=sa.VARCHAR(length=255),
- existing_nullable=True)
+ with op.batch_alter_table('sites', schema=None) as batch_op:
+ batch_op.alter_column('description',
+ existing_type=models.types.LongText(),
+ type_=sa.VARCHAR(length=255),
+ existing_nullable=True)
# ### end Alembic commands ###
diff --git a/api/migrations/versions/8ec536f3c800_rename_api_provider_credentails.py b/api/migrations/versions/8ec536f3c800_rename_api_provider_credentails.py
index 111e81240b..6022ea2c20 100644
--- a/api/migrations/versions/8ec536f3c800_rename_api_provider_credentails.py
+++ b/api/migrations/versions/8ec536f3c800_rename_api_provider_credentails.py
@@ -11,9 +11,6 @@ from alembic import op
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = '8ec536f3c800'
down_revision = 'ad472b61a054'
@@ -23,14 +20,8 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
- batch_op.add_column(sa.Column('credentials_str', sa.Text(), nullable=False))
- else:
- with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
- batch_op.add_column(sa.Column('credentials_str', models.types.LongText(), nullable=False))
+ with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('credentials_str', models.types.LongText(), nullable=False))
# ### end Alembic commands ###
diff --git a/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py b/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py
index 1c1c6cacbb..9d6d40114d 100644
--- a/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py
+++ b/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py
@@ -57,12 +57,9 @@ def upgrade():
batch_op.create_index('message_file_created_by_idx', ['created_by'], unique=False)
batch_op.create_index('message_file_message_idx', ['message_id'], unique=False)
- if _is_pg(conn):
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('file_upload', sa.Text(), nullable=True))
- else:
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('file_upload', models.types.LongText(), nullable=True))
+
+ with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('file_upload', models.types.LongText(), nullable=True))
if _is_pg(conn):
with op.batch_alter_table('upload_files', schema=None) as batch_op:
diff --git a/api/migrations/versions/9f4e3427ea84_add_created_by_role.py b/api/migrations/versions/9f4e3427ea84_add_created_by_role.py
index 5d29d354f3..0b3f92a12e 100644
--- a/api/migrations/versions/9f4e3427ea84_add_created_by_role.py
+++ b/api/migrations/versions/9f4e3427ea84_add_created_by_role.py
@@ -24,7 +24,6 @@ def upgrade():
conn = op.get_bind()
if _is_pg(conn):
- # PostgreSQL: Keep original syntax
with op.batch_alter_table('pinned_conversations', schema=None) as batch_op:
batch_op.add_column(sa.Column('created_by_role', sa.String(length=255), server_default=sa.text("'end_user'::character varying"), nullable=False))
batch_op.drop_index('pinned_conversation_conversation_idx')
@@ -35,7 +34,6 @@ def upgrade():
batch_op.drop_index('saved_message_message_idx')
batch_op.create_index('saved_message_message_idx', ['app_id', 'message_id', 'created_by_role', 'created_by'], unique=False)
else:
- # MySQL: Use compatible syntax
with op.batch_alter_table('pinned_conversations', schema=None) as batch_op:
batch_op.add_column(sa.Column('created_by_role', sa.String(length=255), server_default=sa.text("'end_user'"), nullable=False))
batch_op.drop_index('pinned_conversation_conversation_idx')
diff --git a/api/migrations/versions/a5b56fb053ef_app_config_add_speech_to_text.py b/api/migrations/versions/a5b56fb053ef_app_config_add_speech_to_text.py
index 616cb2f163..c8747a51f7 100644
--- a/api/migrations/versions/a5b56fb053ef_app_config_add_speech_to_text.py
+++ b/api/migrations/versions/a5b56fb053ef_app_config_add_speech_to_text.py
@@ -11,9 +11,6 @@ from alembic import op
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = 'a5b56fb053ef'
down_revision = 'd3d503a3471c'
@@ -23,14 +20,8 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('speech_to_text', sa.Text(), nullable=True))
- else:
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('speech_to_text', models.types.LongText(), nullable=True))
+ with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('speech_to_text', models.types.LongText(), nullable=True))
# ### end Alembic commands ###
diff --git a/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py b/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py
index 900ff78036..f56aeb7e66 100644
--- a/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py
+++ b/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py
@@ -11,9 +11,6 @@ from alembic import op
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = 'a9836e3baeee'
down_revision = '968fff4c0ab9'
@@ -23,14 +20,8 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('external_data_tools', sa.Text(), nullable=True))
- else:
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('external_data_tools', models.types.LongText(), nullable=True))
+ with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('external_data_tools', models.types.LongText(), nullable=True))
# ### end Alembic commands ###
diff --git a/api/migrations/versions/b24be59fbb04_.py b/api/migrations/versions/b24be59fbb04_.py
index b0a6d10d8c..ae91eaf1bc 100644
--- a/api/migrations/versions/b24be59fbb04_.py
+++ b/api/migrations/versions/b24be59fbb04_.py
@@ -11,9 +11,6 @@ from alembic import op
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = 'b24be59fbb04'
down_revision = 'de95f5c77138'
@@ -23,14 +20,8 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('text_to_speech', sa.Text(), nullable=True))
- else:
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('text_to_speech', models.types.LongText(), nullable=True))
+ with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('text_to_speech', models.types.LongText(), nullable=True))
# ### end Alembic commands ###
diff --git a/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py b/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py
index 772395c25b..c02c24c23f 100644
--- a/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py
+++ b/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py
@@ -11,9 +11,6 @@ from alembic import op
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = 'b3a09c049e8e'
down_revision = '2e9819ca5b28'
@@ -23,20 +20,11 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('prompt_type', sa.String(length=255), nullable=False, server_default='simple'))
- batch_op.add_column(sa.Column('chat_prompt_config', sa.Text(), nullable=True))
- batch_op.add_column(sa.Column('completion_prompt_config', sa.Text(), nullable=True))
- batch_op.add_column(sa.Column('dataset_configs', sa.Text(), nullable=True))
- else:
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('prompt_type', sa.String(length=255), nullable=False, server_default='simple'))
- batch_op.add_column(sa.Column('chat_prompt_config', models.types.LongText(), nullable=True))
- batch_op.add_column(sa.Column('completion_prompt_config', models.types.LongText(), nullable=True))
- batch_op.add_column(sa.Column('dataset_configs', models.types.LongText(), nullable=True))
+ with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('prompt_type', sa.String(length=255), nullable=False, server_default='simple'))
+ batch_op.add_column(sa.Column('chat_prompt_config', models.types.LongText(), nullable=True))
+ batch_op.add_column(sa.Column('completion_prompt_config', models.types.LongText(), nullable=True))
+ batch_op.add_column(sa.Column('dataset_configs', models.types.LongText(), nullable=True))
# ### end Alembic commands ###
diff --git a/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py b/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py
index 76be794ff4..fe51d1c78d 100644
--- a/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py
+++ b/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py
@@ -7,7 +7,6 @@ Create Date: 2024-06-17 10:01:00.255189
"""
import sqlalchemy as sa
from alembic import op
-from sqlalchemy.dialects import postgresql
import models.types
diff --git a/api/migrations/versions/e1901f623fd0_add_annotation_reply.py b/api/migrations/versions/e1901f623fd0_add_annotation_reply.py
index 9e02ec5d84..36e934f0fc 100644
--- a/api/migrations/versions/e1901f623fd0_add_annotation_reply.py
+++ b/api/migrations/versions/e1901f623fd0_add_annotation_reply.py
@@ -54,12 +54,9 @@ def upgrade():
batch_op.create_index('app_annotation_hit_histories_annotation_idx', ['annotation_id'], unique=False)
batch_op.create_index('app_annotation_hit_histories_app_idx', ['app_id'], unique=False)
- if _is_pg(conn):
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('annotation_reply', sa.Text(), nullable=True))
- else:
- with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
- batch_op.add_column(sa.Column('annotation_reply', models.types.LongText(), nullable=True))
+
+ with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('annotation_reply', models.types.LongText(), nullable=True))
if _is_pg(conn):
with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
@@ -68,54 +65,31 @@ def upgrade():
with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
batch_op.add_column(sa.Column('type', sa.String(length=40), server_default=sa.text("'dataset'"), nullable=False))
- if _is_pg(conn):
- with op.batch_alter_table('message_annotations', schema=None) as batch_op:
- batch_op.add_column(sa.Column('question', sa.Text(), nullable=True))
- batch_op.add_column(sa.Column('hit_count', sa.Integer(), server_default=sa.text('0'), nullable=False))
- batch_op.alter_column('conversation_id',
- existing_type=postgresql.UUID(),
- nullable=True)
- batch_op.alter_column('message_id',
- existing_type=postgresql.UUID(),
- nullable=True)
- else:
- with op.batch_alter_table('message_annotations', schema=None) as batch_op:
- batch_op.add_column(sa.Column('question', models.types.LongText(), nullable=True))
- batch_op.add_column(sa.Column('hit_count', sa.Integer(), server_default=sa.text('0'), nullable=False))
- batch_op.alter_column('conversation_id',
- existing_type=models.types.StringUUID(),
- nullable=True)
- batch_op.alter_column('message_id',
- existing_type=models.types.StringUUID(),
- nullable=True)
+ with op.batch_alter_table('message_annotations', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('question', models.types.LongText(), nullable=True))
+ batch_op.add_column(sa.Column('hit_count', sa.Integer(), server_default=sa.text('0'), nullable=False))
+ batch_op.alter_column('conversation_id',
+ existing_type=models.types.StringUUID(),
+ nullable=True)
+ batch_op.alter_column('message_id',
+ existing_type=models.types.StringUUID(),
+ nullable=True)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
- if _is_pg(conn):
- with op.batch_alter_table('message_annotations', schema=None) as batch_op:
- batch_op.alter_column('message_id',
- existing_type=postgresql.UUID(),
- nullable=False)
- batch_op.alter_column('conversation_id',
- existing_type=postgresql.UUID(),
- nullable=False)
- batch_op.drop_column('hit_count')
- batch_op.drop_column('question')
- else:
- with op.batch_alter_table('message_annotations', schema=None) as batch_op:
- batch_op.alter_column('message_id',
- existing_type=models.types.StringUUID(),
- nullable=False)
- batch_op.alter_column('conversation_id',
- existing_type=models.types.StringUUID(),
- nullable=False)
- batch_op.drop_column('hit_count')
- batch_op.drop_column('question')
+ with op.batch_alter_table('message_annotations', schema=None) as batch_op:
+ batch_op.alter_column('message_id',
+ existing_type=models.types.StringUUID(),
+ nullable=False)
+ batch_op.alter_column('conversation_id',
+ existing_type=models.types.StringUUID(),
+ nullable=False)
+ batch_op.drop_column('hit_count')
+ batch_op.drop_column('question')
with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
batch_op.drop_column('type')
diff --git a/api/migrations/versions/f2a6fc85e260_add_anntation_history_message_id.py b/api/migrations/versions/f2a6fc85e260_add_anntation_history_message_id.py
index 02098e91c1..ac1c14e50c 100644
--- a/api/migrations/versions/f2a6fc85e260_add_anntation_history_message_id.py
+++ b/api/migrations/versions/f2a6fc85e260_add_anntation_history_message_id.py
@@ -12,9 +12,6 @@ from sqlalchemy.dialects import postgresql
import models.types
-def _is_pg(conn):
- return conn.dialect.name == "postgresql"
-
# revision identifiers, used by Alembic.
revision = 'f2a6fc85e260'
down_revision = '46976cc39132'
@@ -24,16 +21,9 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
- conn = op.get_bind()
-
- if _is_pg(conn):
- with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
- batch_op.add_column(sa.Column('message_id', postgresql.UUID(), nullable=False))
- batch_op.create_index('app_annotation_hit_histories_message_idx', ['message_id'], unique=False)
- else:
- with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
- batch_op.add_column(sa.Column('message_id', models.types.StringUUID(), nullable=False))
- batch_op.create_index('app_annotation_hit_histories_message_idx', ['message_id'], unique=False)
+ with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('message_id', models.types.StringUUID(), nullable=False))
+ batch_op.create_index('app_annotation_hit_histories_message_idx', ['message_id'], unique=False)
# ### end Alembic commands ###
diff --git a/api/models/__init__.py b/api/models/__init__.py
index 906bc3198e..74b33130ef 100644
--- a/api/models/__init__.py
+++ b/api/models/__init__.py
@@ -35,6 +35,7 @@ from .enums import (
WorkflowTriggerStatus,
)
from .model import (
+ AccountTrialAppRecord,
ApiRequest,
ApiToken,
App,
@@ -47,6 +48,7 @@ from .model import (
DatasetRetrieverResource,
DifySetup,
EndUser,
+ ExporleBanner,
IconType,
InstalledApp,
Message,
@@ -60,7 +62,9 @@ from .model import (
Site,
Tag,
TagBinding,
+ TenantCreditPool,
TraceAppConfig,
+ TrialApp,
UploadFile,
)
from .oauth import DatasourceOauthParamConfig, DatasourceProvider
@@ -99,6 +103,7 @@ from .workflow import (
Workflow,
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
+ WorkflowArchiveLog,
WorkflowNodeExecutionModel,
WorkflowNodeExecutionOffload,
WorkflowNodeExecutionTriggeredFrom,
@@ -113,6 +118,7 @@ __all__ = [
"Account",
"AccountIntegrate",
"AccountStatus",
+ "AccountTrialAppRecord",
"ApiRequest",
"ApiToken",
"ApiToolProvider",
@@ -149,6 +155,7 @@ __all__ = [
"DocumentSegment",
"Embedding",
"EndUser",
+ "ExporleBanner",
"ExternalKnowledgeApis",
"ExternalKnowledgeBindings",
"IconType",
@@ -177,6 +184,7 @@ __all__ = [
"Tenant",
"TenantAccountJoin",
"TenantAccountRole",
+ "TenantCreditPool",
"TenantDefaultModel",
"TenantPreferredModelProvider",
"TenantStatus",
@@ -186,6 +194,7 @@ __all__ = [
"ToolLabelBinding",
"ToolModelInvoke",
"TraceAppConfig",
+ "TrialApp",
"TriggerOAuthSystemClient",
"TriggerOAuthTenantClient",
"TriggerSubscription",
@@ -195,6 +204,7 @@ __all__ = [
"Workflow",
"WorkflowAppLog",
"WorkflowAppLogCreatedFrom",
+ "WorkflowArchiveLog",
"WorkflowNodeExecutionModel",
"WorkflowNodeExecutionOffload",
"WorkflowNodeExecutionTriggeredFrom",
diff --git a/api/models/account.py b/api/models/account.py
index 420e6adc6c..f7a9c20026 100644
--- a/api/models/account.py
+++ b/api/models/account.py
@@ -8,7 +8,7 @@ from uuid import uuid4
import sqlalchemy as sa
from flask_login import UserMixin
from sqlalchemy import DateTime, String, func, select
-from sqlalchemy.orm import Mapped, Session, mapped_column
+from sqlalchemy.orm import Mapped, Session, mapped_column, validates
from typing_extensions import deprecated
from .base import TypeBase
@@ -116,6 +116,12 @@ class Account(UserMixin, TypeBase):
role: TenantAccountRole | None = field(default=None, init=False)
_current_tenant: "Tenant | None" = field(default=None, init=False)
+ @validates("status")
+ def _normalize_status(self, _key: str, value: str | AccountStatus) -> str:
+ if isinstance(value, AccountStatus):
+ return value.value
+ return value
+
@property
def is_password_set(self):
return self.password is not None
diff --git a/api/models/dataset.py b/api/models/dataset.py
index ba2eaf6749..e7da2961bc 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -72,6 +72,7 @@ class Dataset(Base):
keyword_number = mapped_column(sa.Integer, nullable=True, server_default=sa.text("10"))
collection_binding_id = mapped_column(StringUUID, nullable=True)
retrieval_model = mapped_column(AdjustedJSON, nullable=True)
+ summary_index_setting = mapped_column(AdjustedJSON, nullable=True)
built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
icon_info = mapped_column(AdjustedJSON, nullable=True)
runtime_mode = mapped_column(sa.String(255), nullable=True, server_default=sa.text("'general'"))
@@ -419,6 +420,7 @@ class Document(Base):
doc_metadata = mapped_column(AdjustedJSON, nullable=True)
doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'"))
doc_language = mapped_column(String(255), nullable=True)
+ need_summary: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"]
@@ -1149,7 +1151,7 @@ class DatasetCollectionBinding(TypeBase):
)
-class TidbAuthBinding(Base):
+class TidbAuthBinding(TypeBase):
__tablename__ = "tidb_auth_bindings"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
@@ -1158,7 +1160,13 @@ class TidbAuthBinding(Base):
sa.Index("tidb_auth_bindings_created_at_idx", "created_at"),
sa.Index("tidb_auth_bindings_status_idx", "status"),
)
- id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()))
+ id: Mapped[str] = mapped_column(
+ StringUUID,
+ primary_key=True,
+ insert_default=lambda: str(uuid4()),
+ default_factory=lambda: str(uuid4()),
+ init=False,
+ )
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
cluster_id: Mapped[str] = mapped_column(String(255), nullable=False)
cluster_name: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -1166,7 +1174,9 @@ class TidbAuthBinding(Base):
status: Mapped[str] = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'CREATING'"))
account: Mapped[str] = mapped_column(String(255), nullable=False)
password: Mapped[str] = mapped_column(String(255), nullable=False)
- created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
+ created_at: Mapped[datetime] = mapped_column(
+ DateTime, nullable=False, server_default=func.current_timestamp(), init=False
+ )
class Whitelist(TypeBase):
@@ -1532,6 +1542,7 @@ class PipelineRecommendedPlugin(TypeBase):
)
plugin_id: Mapped[str] = mapped_column(LongText, nullable=False)
provider_name: Mapped[str] = mapped_column(LongText, nullable=False)
+ type: Mapped[str] = mapped_column(sa.String(50), nullable=False, server_default=sa.text("'tool'"))
position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True)
created_at: Mapped[datetime] = mapped_column(
@@ -1566,3 +1577,36 @@ class SegmentAttachmentBinding(Base):
segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
attachment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
+
+
+class DocumentSegmentSummary(Base):
+ __tablename__ = "document_segment_summaries"
+ __table_args__ = (
+ sa.PrimaryKeyConstraint("id", name="document_segment_summaries_pkey"),
+ sa.Index("document_segment_summaries_dataset_id_idx", "dataset_id"),
+ sa.Index("document_segment_summaries_document_id_idx", "document_id"),
+ sa.Index("document_segment_summaries_chunk_id_idx", "chunk_id"),
+ sa.Index("document_segment_summaries_status_idx", "status"),
+ )
+
+ id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
+ dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ # corresponds to DocumentSegment.id or parent chunk id
+ chunk_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ summary_content: Mapped[str] = mapped_column(LongText, nullable=True)
+ summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True)
+ summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True)
+ tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
+ status: Mapped[str] = mapped_column(String(32), nullable=False, server_default=sa.text("'generating'"))
+ error: Mapped[str] = mapped_column(LongText, nullable=True)
+ enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
+ disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
+ disabled_by = mapped_column(StringUUID, nullable=True)
+ created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_at: Mapped[datetime] = mapped_column(
+ DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
+ )
+
+ def __repr__(self):
+ return f"<DocumentSegmentSummary id={self.id} chunk_id={self.chunk_id} status={self.status}>"
diff --git a/api/models/model.py b/api/models/model.py
index c8fbdc40ec..c1c6e04ce9 100644
--- a/api/models/model.py
+++ b/api/models/model.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import json
import re
import uuid
@@ -5,13 +7,13 @@ from collections.abc import Mapping
from datetime import datetime
from decimal import Decimal
from enum import StrEnum, auto
-from typing import TYPE_CHECKING, Any, Literal, Optional, cast
+from typing import TYPE_CHECKING, Any, Literal, cast
from uuid import uuid4
import sqlalchemy as sa
from flask import request
-from flask_login import UserMixin
-from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
+from flask_login import UserMixin # type: ignore[import-untyped]
+from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from sqlalchemy.orm import Mapped, Session, mapped_column
from configs import dify_config
@@ -54,7 +56,7 @@ class AppMode(StrEnum):
RAG_PIPELINE = "rag-pipeline"
@classmethod
- def value_of(cls, value: str) -> "AppMode":
+ def value_of(cls, value: str) -> AppMode:
"""
Get value of given mode.
@@ -70,6 +72,7 @@ class AppMode(StrEnum):
class IconType(StrEnum):
IMAGE = auto()
EMOJI = auto()
+ LINK = auto()
class App(Base):
@@ -81,7 +84,7 @@ class App(Base):
name: Mapped[str] = mapped_column(String(255))
description: Mapped[str] = mapped_column(LongText, default=sa.text("''"))
mode: Mapped[str] = mapped_column(String(255))
- icon_type: Mapped[str | None] = mapped_column(String(255)) # image, emoji
+ icon_type: Mapped[str | None] = mapped_column(String(255)) # image, emoji, link
icon = mapped_column(String(255))
icon_background: Mapped[str | None] = mapped_column(String(255))
app_model_config_id = mapped_column(StringUUID, nullable=True)
@@ -111,24 +114,28 @@ class App(Base):
else:
app_model_config = self.app_model_config
if app_model_config:
- return app_model_config.pre_prompt
+ pre_prompt = app_model_config.pre_prompt or ""
+ # Truncate to 200 characters with ellipsis if using prompt as description
+ if len(pre_prompt) > 200:
+ return pre_prompt[:200] + "..."
+ return pre_prompt
else:
return ""
@property
- def site(self) -> Optional["Site"]:
+ def site(self) -> Site | None:
site = db.session.query(Site).where(Site.app_id == self.id).first()
return site
@property
- def app_model_config(self) -> Optional["AppModelConfig"]:
+ def app_model_config(self) -> AppModelConfig | None:
if self.app_model_config_id:
return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
return None
@property
- def workflow(self) -> Optional["Workflow"]:
+ def workflow(self) -> Workflow | None:
if self.workflow_id:
from .workflow import Workflow
@@ -283,7 +290,7 @@ class App(Base):
return deleted_tools
@property
- def tags(self) -> list["Tag"]:
+ def tags(self) -> list[Tag]:
tags = (
db.session.query(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id)
@@ -308,40 +315,48 @@ class App(Base):
return None
-class AppModelConfig(Base):
+class AppModelConfig(TypeBase):
__tablename__ = "app_model_configs"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="app_model_config_pkey"), sa.Index("app_app_id_idx", "app_id"))
- id = mapped_column(StringUUID, default=lambda: str(uuid4()))
- app_id = mapped_column(StringUUID, nullable=False)
- provider = mapped_column(String(255), nullable=True)
- model_id = mapped_column(String(255), nullable=True)
- configs = mapped_column(sa.JSON, nullable=True)
- created_by = mapped_column(StringUUID, nullable=True)
- created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_by = mapped_column(StringUUID, nullable=True)
- updated_at = mapped_column(
- sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
+ id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ provider: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
+ model_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
+ configs: Mapped[Any | None] = mapped_column(sa.JSON, nullable=True, default=None)
+ created_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
+ created_at: Mapped[datetime] = mapped_column(
+ sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
- opening_statement = mapped_column(LongText)
- suggested_questions = mapped_column(LongText)
- suggested_questions_after_answer = mapped_column(LongText)
- speech_to_text = mapped_column(LongText)
- text_to_speech = mapped_column(LongText)
- more_like_this = mapped_column(LongText)
- model = mapped_column(LongText)
- user_input_form = mapped_column(LongText)
- dataset_query_variable = mapped_column(String(255))
- pre_prompt = mapped_column(LongText)
- agent_mode = mapped_column(LongText)
- sensitive_word_avoidance = mapped_column(LongText)
- retriever_resource = mapped_column(LongText)
- prompt_type = mapped_column(String(255), nullable=False, server_default=sa.text("'simple'"))
- chat_prompt_config = mapped_column(LongText)
- completion_prompt_config = mapped_column(LongText)
- dataset_configs = mapped_column(LongText)
- external_data_tools = mapped_column(LongText)
- file_upload = mapped_column(LongText)
+ updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
+ updated_at: Mapped[datetime] = mapped_column(
+ sa.DateTime,
+ nullable=False,
+ server_default=func.current_timestamp(),
+ onupdate=func.current_timestamp(),
+ init=False,
+ )
+ opening_statement: Mapped[str | None] = mapped_column(LongText, default=None)
+ suggested_questions: Mapped[str | None] = mapped_column(LongText, default=None)
+ suggested_questions_after_answer: Mapped[str | None] = mapped_column(LongText, default=None)
+ speech_to_text: Mapped[str | None] = mapped_column(LongText, default=None)
+ text_to_speech: Mapped[str | None] = mapped_column(LongText, default=None)
+ more_like_this: Mapped[str | None] = mapped_column(LongText, default=None)
+ model: Mapped[str | None] = mapped_column(LongText, default=None)
+ user_input_form: Mapped[str | None] = mapped_column(LongText, default=None)
+ dataset_query_variable: Mapped[str | None] = mapped_column(String(255), default=None)
+ pre_prompt: Mapped[str | None] = mapped_column(LongText, default=None)
+ agent_mode: Mapped[str | None] = mapped_column(LongText, default=None)
+ sensitive_word_avoidance: Mapped[str | None] = mapped_column(LongText, default=None)
+ retriever_resource: Mapped[str | None] = mapped_column(LongText, default=None)
+ prompt_type: Mapped[str] = mapped_column(
+ String(255), nullable=False, server_default=sa.text("'simple'"), default="simple"
+ )
+ chat_prompt_config: Mapped[str | None] = mapped_column(LongText, default=None)
+ completion_prompt_config: Mapped[str | None] = mapped_column(LongText, default=None)
+ dataset_configs: Mapped[str | None] = mapped_column(LongText, default=None)
+ external_data_tools: Mapped[str | None] = mapped_column(LongText, default=None)
+ file_upload: Mapped[str | None] = mapped_column(LongText, default=None)
@property
def app(self) -> App | None:
@@ -596,6 +611,70 @@ class InstalledApp(TypeBase):
return tenant
+class TrialApp(Base):
+ __tablename__ = "trial_apps"
+ __table_args__ = (
+ sa.PrimaryKeyConstraint("id", name="trial_app_pkey"),
+ sa.Index("trial_app_app_id_idx", "app_id"),
+ sa.Index("trial_app_tenant_id_idx", "tenant_id"),
+ sa.UniqueConstraint("app_id", name="unique_trail_app_id"),
+ )
+
+ id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
+ app_id = mapped_column(StringUUID, nullable=False)
+ tenant_id = mapped_column(StringUUID, nullable=False)
+ created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
+ trial_limit = mapped_column(sa.Integer, nullable=False, default=3)
+
+ @property
+ def app(self) -> App | None:
+ app = db.session.query(App).where(App.id == self.app_id).first()
+ return app
+
+
+class AccountTrialAppRecord(Base):
+ __tablename__ = "account_trial_app_records"
+ __table_args__ = (
+ sa.PrimaryKeyConstraint("id", name="user_trial_app_pkey"),
+ sa.Index("account_trial_app_record_account_id_idx", "account_id"),
+ sa.Index("account_trial_app_record_app_id_idx", "app_id"),
+ sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"),
+ )
+ id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
+ account_id = mapped_column(StringUUID, nullable=False)
+ app_id = mapped_column(StringUUID, nullable=False)
+ count = mapped_column(sa.Integer, nullable=False, default=0)
+ created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
+
+ @property
+ def app(self) -> App | None:
+ app = db.session.query(App).where(App.id == self.app_id).first()
+ return app
+
+ @property
+ def user(self) -> Account | None:
+ user = db.session.query(Account).where(Account.id == self.account_id).first()
+ return user
+
+
+class ExporleBanner(TypeBase):
+ __tablename__ = "exporle_banners"
+ __table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),)
+ id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
+ content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
+ link: Mapped[str] = mapped_column(String(255), nullable=False)
+ sort: Mapped[int] = mapped_column(sa.Integer, nullable=False)
+ status: Mapped[str] = mapped_column(
+ sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"), default="enabled"
+ )
+ created_at: Mapped[datetime] = mapped_column(
+ sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
+ )
+ language: Mapped[str] = mapped_column(
+ String(255), nullable=False, server_default=sa.text("'en-US'::character varying"), default="en-US"
+ )
+
+
class OAuthProviderApp(TypeBase):
"""
Globally shared OAuth provider app information.
@@ -745,8 +824,8 @@ class Conversation(Base):
override_model_configs = json.loads(self.override_model_configs)
if "model" in override_model_configs:
- app_model_config = AppModelConfig()
- app_model_config = app_model_config.from_model_config_dict(override_model_configs)
+ # AppModelConfig now requires app_id at construction, so pass it explicitly when rebuilding from the override config
+ app_model_config = AppModelConfig(app_id=self.app_id).from_model_config_dict(override_model_configs)
model_config = app_model_config.to_dict()
else:
model_config["configs"] = override_model_configs
@@ -961,6 +1040,7 @@ class Message(Base):
Index("message_workflow_run_id_idx", "conversation_id", "workflow_run_id"),
Index("message_created_at_idx", "created_at"),
Index("message_app_mode_idx", "app_mode"),
+ Index("message_created_at_id_idx", "created_at", "id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
@@ -1189,7 +1269,7 @@ class Message(Base):
return json.loads(self.message_metadata) if self.message_metadata else {}
@property
- def agent_thoughts(self) -> list["MessageAgentThought"]:
+ def agent_thoughts(self) -> list[MessageAgentThought]:
return (
db.session.query(MessageAgentThought)
.where(MessageAgentThought.message_id == self.id)
@@ -1302,7 +1382,7 @@ class Message(Base):
}
@classmethod
- def from_dict(cls, data: dict[str, Any]) -> "Message":
+ def from_dict(cls, data: dict[str, Any]) -> Message:
return cls(
id=data["id"],
app_id=data["app_id"],
@@ -1415,15 +1495,20 @@ class MessageAnnotation(Base):
app_id: Mapped[str] = mapped_column(StringUUID)
conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"))
message_id: Mapped[str | None] = mapped_column(StringUUID)
- question = mapped_column(LongText, nullable=True)
- content = mapped_column(LongText, nullable=False)
+ question: Mapped[str] = mapped_column(LongText, nullable=False)
+ content: Mapped[str] = mapped_column(LongText, nullable=False)
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
- account_id = mapped_column(StringUUID, nullable=False)
- created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_at = mapped_column(
+ account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
+ @property
+ def question_text(self) -> str:
+        """Return a non-empty question string, falling back to the annotation content."""
+ return self.question or self.content
+
@property
def account(self):
account = db.session.query(Account).where(Account.id == self.account_id).first()
@@ -1435,7 +1520,7 @@ class MessageAnnotation(Base):
return account
-class AppAnnotationHitHistory(Base):
+class AppAnnotationHitHistory(TypeBase):
__tablename__ = "app_annotation_hit_histories"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"),
@@ -1445,17 +1530,19 @@ class AppAnnotationHitHistory(Base):
sa.Index("app_annotation_hit_histories_message_idx", "message_id"),
)
- id = mapped_column(StringUUID, default=lambda: str(uuid4()))
- app_id = mapped_column(StringUUID, nullable=False)
+ id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
+ app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
annotation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
- source = mapped_column(LongText, nullable=False)
- question = mapped_column(LongText, nullable=False)
- account_id = mapped_column(StringUUID, nullable=False)
- created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
- score = mapped_column(Float, nullable=False, server_default=sa.text("0"))
- message_id = mapped_column(StringUUID, nullable=False)
- annotation_question = mapped_column(LongText, nullable=False)
- annotation_content = mapped_column(LongText, nullable=False)
+ source: Mapped[str] = mapped_column(LongText, nullable=False)
+ question: Mapped[str] = mapped_column(LongText, nullable=False)
+ account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ created_at: Mapped[datetime] = mapped_column(
+ sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
+ )
+ score: Mapped[float] = mapped_column(Float, nullable=False, server_default=sa.text("0"))
+ message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ annotation_question: Mapped[str] = mapped_column(LongText, nullable=False)
+ annotation_content: Mapped[str] = mapped_column(LongText, nullable=False)
@property
def account(self):
@@ -1524,7 +1611,7 @@ class OperationLog(TypeBase):
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
action: Mapped[str] = mapped_column(String(255), nullable=False)
- content: Mapped[Any] = mapped_column(sa.JSON)
+ content: Mapped[Any | None] = mapped_column(sa.JSON, nullable=True)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
@@ -1831,7 +1918,7 @@ class MessageChain(TypeBase):
)
-class MessageAgentThought(Base):
+class MessageAgentThought(TypeBase):
__tablename__ = "message_agent_thoughts"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"),
@@ -1839,34 +1926,42 @@ class MessageAgentThought(Base):
sa.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"),
)
- id = mapped_column(StringUUID, default=lambda: str(uuid4()))
- message_id = mapped_column(StringUUID, nullable=False)
- message_chain_id = mapped_column(StringUUID, nullable=True)
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
+ )
+ message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
- thought = mapped_column(LongText, nullable=True)
- tool = mapped_column(LongText, nullable=True)
- tool_labels_str = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
- tool_meta_str = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
- tool_input = mapped_column(LongText, nullable=True)
- observation = mapped_column(LongText, nullable=True)
+ created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
+ created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ message_chain_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
+ thought: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
+ tool: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
+ tool_labels_str: Mapped[str] = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
+ tool_meta_str: Mapped[str] = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
+ tool_input: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
+ observation: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
# plugin_id = mapped_column(StringUUID, nullable=True) ## for future design
- tool_process_data = mapped_column(LongText, nullable=True)
- message = mapped_column(LongText, nullable=True)
- message_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
- message_unit_price = mapped_column(sa.Numeric, nullable=True)
- message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
- message_files = mapped_column(LongText, nullable=True)
- answer = mapped_column(LongText, nullable=True)
- answer_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
- answer_unit_price = mapped_column(sa.Numeric, nullable=True)
- answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
- tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
- total_price = mapped_column(sa.Numeric, nullable=True)
- currency = mapped_column(String(255), nullable=True)
- latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True)
- created_by_role = mapped_column(String(255), nullable=False)
- created_by = mapped_column(StringUUID, nullable=False)
- created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp())
+ tool_process_data: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
+ message: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
+ message_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
+ message_unit_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None)
+ message_price_unit: Mapped[Decimal] = mapped_column(
+ sa.Numeric(10, 7), nullable=False, default=Decimal("0.001"), server_default=sa.text("0.001")
+ )
+ message_files: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
+ answer: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
+ answer_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
+ answer_unit_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None)
+ answer_price_unit: Mapped[Decimal] = mapped_column(
+ sa.Numeric(10, 7), nullable=False, default=Decimal("0.001"), server_default=sa.text("0.001")
+ )
+ tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
+ total_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None)
+ currency: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
+ latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True, default=None)
+ created_at: Mapped[datetime] = mapped_column(
+ sa.DateTime, nullable=False, init=False, server_default=sa.func.current_timestamp()
+ )
@property
def files(self) -> list[Any]:
@@ -2061,3 +2156,35 @@ class TraceAppConfig(TypeBase):
"created_at": str(self.created_at) if self.created_at else None,
"updated_at": str(self.updated_at) if self.updated_at else None,
}
+
+
+class TenantCreditPool(TypeBase):
+ __tablename__ = "tenant_credit_pools"
+ __table_args__ = (
+ sa.PrimaryKeyConstraint("id", name="tenant_credit_pool_pkey"),
+ sa.Index("tenant_credit_pool_tenant_id_idx", "tenant_id"),
+ sa.Index("tenant_credit_pool_pool_type_idx", "pool_type"),
+ )
+
+ id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=text("uuid_generate_v4()"), init=False)
+ tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ pool_type: Mapped[str] = mapped_column(String(40), nullable=False, default="trial", server_default="trial")
+ quota_limit: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
+ quota_used: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
+ created_at: Mapped[datetime] = mapped_column(
+ sa.DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP"), init=False
+ )
+ updated_at: Mapped[datetime] = mapped_column(
+ sa.DateTime,
+ nullable=False,
+ server_default=func.current_timestamp(),
+ onupdate=func.current_timestamp(),
+ init=False,
+ )
+
+ @property
+ def remaining_credits(self) -> int:
+ return max(0, self.quota_limit - self.quota_used)
+
+ def has_sufficient_credits(self, required_credits: int) -> bool:
+ return self.remaining_credits >= required_credits
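
The new `TenantCreditPool` model keeps its credit math in two small helpers: `remaining_credits` clamps the balance at zero and `has_sufficient_credits` compares against it. A minimal standalone sketch of that logic (a plain dataclass, not the real SQLAlchemy model):

```python
# Minimal sketch mirroring TenantCreditPool's credit helpers; illustrative only,
# detached from SQLAlchemy and the real TypeBase.
from dataclasses import dataclass


@dataclass
class CreditPoolView:
    quota_limit: int
    quota_used: int

    @property
    def remaining_credits(self) -> int:
        # Never report a negative balance, even if usage overshoots the limit.
        return max(0, self.quota_limit - self.quota_used)

    def has_sufficient_credits(self, required_credits: int) -> bool:
        return self.remaining_credits >= required_credits


pool = CreditPoolView(quota_limit=200, quota_used=150)
assert pool.remaining_credits == 50
assert pool.has_sufficient_credits(50) and not pool.has_sufficient_credits(51)
```
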
diff --git a/api/models/provider.py b/api/models/provider.py
index 2afd8c5329..441b54c797 100644
--- a/api/models/provider.py
+++ b/api/models/provider.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from datetime import datetime
from enum import StrEnum, auto
from functools import cached_property
@@ -19,7 +21,7 @@ class ProviderType(StrEnum):
SYSTEM = auto()
@staticmethod
- def value_of(value: str) -> "ProviderType":
+ def value_of(value: str) -> ProviderType:
for member in ProviderType:
if member.value == value:
return member
@@ -37,7 +39,7 @@ class ProviderQuotaType(StrEnum):
"""hosted trial quota"""
@staticmethod
- def value_of(value: str) -> "ProviderQuotaType":
+ def value_of(value: str) -> ProviderQuotaType:
for member in ProviderQuotaType:
if member.value == value:
return member
@@ -76,7 +78,7 @@ class Provider(TypeBase):
quota_type: Mapped[str | None] = mapped_column(String(40), nullable=True, server_default=text("''"), default="")
quota_limit: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True, default=None)
- quota_used: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, default=0)
+ quota_used: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True, default=0)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
diff --git a/api/models/tools.py b/api/models/tools.py
index e4f9bcb582..e7b98dcf27 100644
--- a/api/models/tools.py
+++ b/api/models/tools.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import json
from datetime import datetime
from decimal import Decimal
@@ -167,11 +169,11 @@ class ApiToolProvider(TypeBase):
)
@property
- def schema_type(self) -> "ApiProviderSchemaType":
+ def schema_type(self) -> ApiProviderSchemaType:
return ApiProviderSchemaType.value_of(self.schema_type_str)
@property
- def tools(self) -> list["ApiToolBundle"]:
+ def tools(self) -> list[ApiToolBundle]:
return [ApiToolBundle.model_validate(tool) for tool in json.loads(self.tools_str)]
@property
@@ -267,7 +269,7 @@ class WorkflowToolProvider(TypeBase):
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
@property
- def parameter_configurations(self) -> list["WorkflowToolParameterConfiguration"]:
+ def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]:
return [
WorkflowToolParameterConfiguration.model_validate(config)
for config in json.loads(self.parameter_configuration)
@@ -359,7 +361,7 @@ class MCPToolProvider(TypeBase):
except (json.JSONDecodeError, TypeError):
return []
- def to_entity(self) -> "MCPProviderEntity":
+ def to_entity(self) -> MCPProviderEntity:
"""Convert to domain entity"""
from core.entities.mcp_provider import MCPProviderEntity
@@ -533,5 +535,5 @@ class DeprecatedPublishedAppTool(TypeBase):
)
@property
- def description_i18n(self) -> "I18nObject":
+ def description_i18n(self) -> I18nObject:
return I18nObject.model_validate(json.loads(self.description))
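
Several of the changes above only add `from __future__ import annotations` and drop the quotes around forward references (`"ApiProviderSchemaType"` becomes `ApiProviderSchemaType`). A small sketch of why that is safe under postponed annotation evaluation:

```python
# Sketch: with postponed evaluation, annotations are kept as strings and never
# evaluated at class-definition time, so self-referential return types need no quotes.
from __future__ import annotations

import typing


class Node:
    def clone(self) -> Node:  # without the future import this would need "Node"
        return Node()


# get_type_hints resolves the stored string back to the class once it exists.
assert typing.get_type_hints(Node.clone)["return"] is Node
```
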
diff --git a/api/models/trigger.py b/api/models/trigger.py
index 87e2a5ccfc..209345eb84 100644
--- a/api/models/trigger.py
+++ b/api/models/trigger.py
@@ -415,7 +415,7 @@ class AppTrigger(TypeBase):
node_id: Mapped[str | None] = mapped_column(String(64), nullable=False)
trigger_type: Mapped[str] = mapped_column(EnumText(AppTriggerType, length=50), nullable=False)
title: Mapped[str] = mapped_column(String(255), nullable=False)
- provider_name: Mapped[str] = mapped_column(String(255), server_default="", default="") # why it is nullable?
+ provider_name: Mapped[str | None] = mapped_column(String(255), nullable=True, server_default="", default="")
status: Mapped[str] = mapped_column(
EnumText(AppTriggerStatus, length=50), nullable=False, default=AppTriggerStatus.ENABLED
)
diff --git a/api/models/workflow.py b/api/models/workflow.py
index 853d5afefc..83956b1114 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -29,6 +29,7 @@ from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
)
+from core.workflow.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause
from core.workflow.enums import NodeType
from extensions.ext_storage import Storage
@@ -44,7 +45,7 @@ if TYPE_CHECKING:
from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE
from core.helper import encrypter
-from core.variables import SecretVariable, Segment, SegmentType, Variable
+from core.variables import SecretVariable, Segment, SegmentType, VariableBase
from factories import variable_factory
from libs import helper
@@ -176,8 +177,8 @@ class Workflow(Base): # bug
graph: str,
features: str,
created_by: str,
- environment_variables: Sequence[Variable],
- conversation_variables: Sequence[Variable],
+ environment_variables: Sequence[VariableBase],
+ conversation_variables: Sequence[VariableBase],
rag_pipeline_variables: list[dict],
marked_name: str = "",
marked_comment: str = "",
@@ -226,11 +227,10 @@ class Workflow(Base): # bug
#
# Currently, the following functions / methods would mutate the returned dict:
#
- # - `_get_graph_and_variable_pool_of_single_iteration`.
- # - `_get_graph_and_variable_pool_of_single_loop`.
+ # - `_get_graph_and_variable_pool_for_single_node_run`.
return json.loads(self.graph) if self.graph else {}
- def get_node_config_by_id(self, node_id: str) -> Mapping[str, Any]:
+ def get_node_config_by_id(self, node_id: str) -> NodeConfigDict:
"""Extract a node configuration from the workflow graph by node ID.
A node configuration is a dictionary containing the node's properties, including
the node's id, title, and its data as a dict.
@@ -248,8 +248,7 @@ class Workflow(Base): # bug
node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes))
except StopIteration:
raise NodeNotFoundError(node_id)
- assert isinstance(node_config, dict)
- return node_config
+ return NodeConfigDictAdapter.validate_python(node_config)
@staticmethod
def get_node_type_from_node_config(node_config: Mapping[str, Any]) -> NodeType:
@@ -445,7 +444,7 @@ class Workflow(Base): # bug
# decrypt secret variables value
def decrypt_func(
- var: Variable,
+ var: VariableBase,
) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable:
if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
@@ -461,7 +460,7 @@ class Workflow(Base): # bug
return decrypted_results
@environment_variables.setter
- def environment_variables(self, value: Sequence[Variable]):
+ def environment_variables(self, value: Sequence[VariableBase]):
if not value:
self._environment_variables = "{}"
return
@@ -485,7 +484,7 @@ class Workflow(Base): # bug
value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name})
# encrypt secret variables value
- def encrypt_func(var: Variable) -> Variable:
+ def encrypt_func(var: VariableBase) -> VariableBase:
if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)})
else:
@@ -515,7 +514,7 @@ class Workflow(Base): # bug
return result
@property
- def conversation_variables(self) -> Sequence[Variable]:
+ def conversation_variables(self) -> Sequence[VariableBase]:
# TODO: find some way to init `self._conversation_variables` when instance created.
if self._conversation_variables is None:
self._conversation_variables = "{}"
@@ -525,7 +524,7 @@ class Workflow(Base): # bug
return results
@conversation_variables.setter
- def conversation_variables(self, value: Sequence[Variable]):
+ def conversation_variables(self, value: Sequence[VariableBase]):
self._conversation_variables = json.dumps(
{var.name: var.model_dump() for var in value},
ensure_ascii=False,
@@ -595,6 +594,7 @@ class WorkflowRun(Base):
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="workflow_run_pkey"),
sa.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"),
+ sa.Index("workflow_run_created_at_id_idx", "created_at", "id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
@@ -780,11 +780,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
return (
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
Index(
- "workflow_node_execution_workflow_run_idx",
- "tenant_id",
- "app_id",
- "workflow_id",
- "triggered_from",
+ "workflow_node_execution_workflow_run_id_idx",
"workflow_run_id",
),
Index(
@@ -1166,6 +1162,69 @@ class WorkflowAppLog(TypeBase):
}
+class WorkflowArchiveLog(TypeBase):
+ """
+ Workflow archive log.
+
+ Stores essential workflow run snapshot data for archived app logs.
+
+ Field sources:
+ - Shared fields (tenant/app/workflow/run ids, created_by*): from WorkflowRun for consistency.
+ - log_* fields: from WorkflowAppLog when present; null if the run has no app log.
+ - run_* fields: workflow run snapshot fields from WorkflowRun.
+ - trigger_metadata: snapshot from WorkflowTriggerLog when present.
+ """
+
+ __tablename__ = "workflow_archive_logs"
+ __table_args__ = (
+ sa.PrimaryKeyConstraint("id", name="workflow_archive_log_pkey"),
+ sa.Index("workflow_archive_log_app_idx", "tenant_id", "app_id"),
+ sa.Index("workflow_archive_log_workflow_run_id_idx", "workflow_run_id"),
+ sa.Index("workflow_archive_log_run_created_at_idx", "run_created_at"),
+ )
+
+ id: Mapped[str] = mapped_column(
+ StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
+ )
+
+ tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ workflow_run_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
+ created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
+
+ log_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
+ log_created_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
+ log_created_from: Mapped[str | None] = mapped_column(String(255), nullable=True)
+
+ run_version: Mapped[str] = mapped_column(String(255), nullable=False)
+ run_status: Mapped[str] = mapped_column(String(255), nullable=False)
+ run_triggered_from: Mapped[str] = mapped_column(String(255), nullable=False)
+ run_error: Mapped[str | None] = mapped_column(LongText, nullable=True)
+ run_elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
+ run_total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))
+    run_total_steps: Mapped[int | None] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
+ run_created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
+ run_finished_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
+    run_exceptions_count: Mapped[int | None] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
+
+ trigger_metadata: Mapped[str | None] = mapped_column(LongText, nullable=True)
+ archived_at: Mapped[datetime] = mapped_column(
+ DateTime, nullable=False, server_default=func.current_timestamp(), init=False
+ )
+
+ @property
+ def workflow_run_summary(self) -> dict[str, Any]:
+ return {
+ "id": self.workflow_run_id,
+ "status": self.run_status,
+ "triggered_from": self.run_triggered_from,
+ "elapsed_time": self.run_elapsed_time,
+ "total_tokens": self.run_total_tokens,
+ }
+
+
class ConversationVariable(TypeBase):
__tablename__ = "workflow_conversation_variables"
@@ -1181,7 +1240,7 @@ class ConversationVariable(TypeBase):
)
@classmethod
- def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> "ConversationVariable":
+ def from_variable(cls, *, app_id: str, conversation_id: str, variable: VariableBase) -> "ConversationVariable":
obj = cls(
id=variable.id,
app_id=app_id,
@@ -1190,7 +1249,7 @@ class ConversationVariable(TypeBase):
)
return obj
- def to_variable(self) -> Variable:
+ def to_variable(self) -> VariableBase:
mapping = json.loads(self.data)
return variable_factory.build_conversation_variable_from_mapping(mapping)
@@ -1506,6 +1565,7 @@ class WorkflowDraftVariable(Base):
file_id: str | None = None,
) -> "WorkflowDraftVariable":
variable = WorkflowDraftVariable()
+ variable.id = str(uuid4())
variable.created_at = naive_utc_now()
variable.updated_at = naive_utc_now()
variable.description = description
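
`get_node_config_by_id` keeps the same lookup as before (scan `graph["nodes"]`, raise `NodeNotFoundError` on a miss) but now validates the result through `NodeConfigDictAdapter` instead of a bare `isinstance` assert. A rough, self-contained sketch of the lookup itself (stand-in error class, validation step omitted):

```python
# Rough sketch of the node lookup in Workflow.get_node_config_by_id; NodeNotFoundError
# is a stand-in here and the NodeConfigDictAdapter validation step is left out.
from typing import Any


class NodeNotFoundError(Exception):
    pass


def get_node_config_by_id(graph: dict[str, Any], node_id: str) -> dict[str, Any]:
    nodes = graph.get("nodes", [])
    try:
        node_config = next(filter(lambda node: node["id"] == node_id, nodes))
    except StopIteration:
        raise NodeNotFoundError(node_id)
    return node_config


graph = {"nodes": [{"id": "start", "data": {"type": "start", "title": "Start"}}]}
assert get_node_config_by_id(graph, "start")["data"]["type"] == "start"
```
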
diff --git a/api/pyproject.toml b/api/pyproject.toml
index 2a8432f571..97e6c83ed6 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -1,9 +1,10 @@
[project]
name = "dify-api"
-version = "1.11.0"
+version = "1.12.0"
requires-python = ">=3.11,<3.13"
dependencies = [
+ "aliyun-log-python-sdk~=0.9.37",
"arize-phoenix-otel~=0.9.2",
"azure-identity==1.16.1",
"beautifulsoup4==4.12.2",
@@ -30,7 +31,8 @@ dependencies = [
"gunicorn~=23.0.0",
"httpx[socks]~=0.27.0",
"jieba==0.42.1",
- "json-repair>=0.41.1",
+ "json-repair>=0.55.1",
+ "jsonschema>=4.25.1",
"langfuse~=2.51.3",
"langsmith~=0.1.77",
"markdown~=3.5.1",
@@ -62,12 +64,12 @@ dependencies = [
"pandas[excel,output-formatting,performance]~=2.2.2",
"psycogreen~=1.0.2",
"psycopg2-binary~=2.9.6",
- "pycryptodome==3.19.1",
+ "pycryptodome==3.23.0",
"pydantic~=2.11.4",
"pydantic-extra-types~=2.10.3",
"pydantic-settings~=2.11.0",
"pyjwt~=2.10.1",
- "pypdfium2==4.30.0",
+ "pypdfium2==5.2.0",
"python-docx~=1.1.0",
"python-dotenv==1.0.1",
"pyyaml~=6.0.1",
@@ -85,13 +87,13 @@ dependencies = [
"sseclient-py~=1.8.0",
"httpx-sse~=0.4.0",
"sendgrid~=6.12.3",
- "flask-restx~=1.3.0",
+ "flask-restx~=1.3.2",
"packaging~=23.2",
"croniter>=6.0.0",
"weaviate-client==4.17.0",
"apscheduler>=3.11.0",
"weave>=0.52.16",
- "jsonschema>=4.25.1",
+ "fastopenapi[flask]>=0.7.0",
]
# Before adding new dependency, consider place it in
# alphabet order (a-z) and suitable group.
@@ -114,7 +116,7 @@ dev = [
"dotenv-linter~=0.5.0",
"faker~=38.2.0",
"lxml-stubs~=0.5.1",
- "ty~=0.0.1a19",
+ "ty>=0.0.14",
"basedpyright~=1.31.0",
"ruff~=0.14.0",
"pytest~=8.3.2",
@@ -173,6 +175,7 @@ dev = [
# "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved.
"sseclient-py>=1.8.0",
"pytest-timeout>=2.4.0",
+ "pytest-xdist>=3.8.0",
]
############################################################
@@ -188,7 +191,7 @@ storage = [
"opendal~=0.46.0",
"oss2==2.18.5",
"supabase~=2.18.1",
- "tos~=2.7.1",
+ "tos~=2.9.0",
]
############################################################
@@ -216,6 +219,7 @@ vdb = [
"pymochow==2.2.9",
"pyobvector~=0.2.17",
"qdrant-client==1.9.0",
+ "intersystems-irispython>=5.1.0",
"tablestore==6.3.7",
"tcvectordb~=1.6.4",
"tidb-vector==0.0.9",
diff --git a/api/pyrightconfig.json b/api/pyrightconfig.json
index 6a689b96df..007c49ddb0 100644
--- a/api/pyrightconfig.json
+++ b/api/pyrightconfig.json
@@ -8,6 +8,7 @@
],
"typeCheckingMode": "strict",
"allowedUntypedLibraries": [
+ "fastopenapi",
"flask_restx",
"flask_login",
"opentelemetry.instrumentation.celery",
diff --git a/api/pytest.ini b/api/pytest.ini
index afb53b47cc..4a9470fa0c 100644
--- a/api/pytest.ini
+++ b/api/pytest.ini
@@ -1,5 +1,5 @@
[pytest]
-addopts = --cov=./api --cov-report=json --cov-report=xml
+addopts = --cov=./api --cov-report=json
env =
ANTHROPIC_API_KEY = sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
AZURE_OPENAI_API_BASE = https://difyai-openai.openai.azure.com
diff --git a/api/repositories/api_workflow_node_execution_repository.py b/api/repositories/api_workflow_node_execution_repository.py
index fa2c94b623..5b3f635301 100644
--- a/api/repositories/api_workflow_node_execution_repository.py
+++ b/api/repositories/api_workflow_node_execution_repository.py
@@ -13,8 +13,10 @@ from collections.abc import Sequence
from datetime import datetime
from typing import Protocol
+from sqlalchemy.orm import Session
+
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
-from models.workflow import WorkflowNodeExecutionModel
+from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload
class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Protocol):
@@ -130,6 +132,18 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr
"""
...
+ def count_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
+ """
+ Count node executions and offloads for the given workflow run ids.
+ """
+ ...
+
+ def delete_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
+ """
+ Delete node executions and offloads for the given workflow run ids.
+ """
+ ...
+
def delete_executions_by_app(
self,
tenant_id: str,
@@ -195,3 +209,23 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr
The number of executions deleted
"""
...
+
+ def get_offloads_by_execution_ids(
+ self,
+ session: Session,
+ node_execution_ids: Sequence[str],
+ ) -> Sequence[WorkflowNodeExecutionOffload]:
+ """
+ Get offload records by node execution IDs.
+
+ This method retrieves workflow node execution offload records
+ that belong to the given node execution IDs.
+
+ Args:
+ session: The database session to use
+ node_execution_ids: List of node execution IDs to filter by
+
+ Returns:
+ A sequence of WorkflowNodeExecutionOffload instances
+ """
+ ...
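
The repository interface stays a `typing.Protocol`, so implementations satisfy it structurally rather than by inheritance. A toy sketch of that shape with the new `count_by_runs` method (signatures simplified; the `Session` parameter is dropped for brevity):

```python
# Toy sketch of structural typing against the repository Protocol; the real interface
# also takes a SQLAlchemy Session, which is omitted here.
from collections.abc import Sequence
from typing import Protocol


class NodeExecutionCounter(Protocol):
    def count_by_runs(self, run_ids: Sequence[str]) -> tuple[int, int]: ...


class InMemoryCounter:
    def __init__(self, executions: dict[str, int], offloads: dict[str, int]) -> None:
        self._executions = executions
        self._offloads = offloads

    def count_by_runs(self, run_ids: Sequence[str]) -> tuple[int, int]:
        return (
            sum(self._executions.get(run_id, 0) for run_id in run_ids),
            sum(self._offloads.get(run_id, 0) for run_id in run_ids),
        )


def report(counter: NodeExecutionCounter, run_ids: Sequence[str]) -> tuple[int, int]:
    return counter.count_by_runs(run_ids)


assert report(InMemoryCounter({"r1": 3}, {"r1": 1}), ["r1"]) == (3, 1)
```
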
diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py
index fd547c78ba..1d3954571f 100644
--- a/api/repositories/api_workflow_run_repository.py
+++ b/api/repositories/api_workflow_run_repository.py
@@ -34,15 +34,18 @@ Example:
```
"""
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
from datetime import datetime
from typing import Protocol
+from sqlalchemy.orm import Session
+
from core.workflow.entities.pause_reason import PauseReason
+from core.workflow.enums import WorkflowType
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
-from models.workflow import WorkflowRun
+from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.types import (
AverageInteractionStats,
@@ -253,6 +256,151 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
"""
...
+ def get_runs_batch_by_time_range(
+ self,
+ start_from: datetime | None,
+ end_before: datetime,
+ last_seen: tuple[datetime, str] | None,
+ batch_size: int,
+ run_types: Sequence[WorkflowType] | None = None,
+ tenant_ids: Sequence[str] | None = None,
+ ) -> Sequence[WorkflowRun]:
+ """
+        Fetch ended workflow runs in a time window for archival and cleanup batching.
+ """
+ ...
+
+ def get_archived_run_ids(
+ self,
+ session: Session,
+ run_ids: Sequence[str],
+ ) -> set[str]:
+ """
+ Fetch workflow run IDs that already have archive log records.
+ """
+ ...
+
+ def get_archived_logs_by_time_range(
+ self,
+ session: Session,
+ tenant_ids: Sequence[str] | None,
+ start_date: datetime,
+ end_date: datetime,
+ limit: int,
+ ) -> Sequence[WorkflowArchiveLog]:
+ """
+ Fetch archived workflow logs by time range for restore.
+ """
+ ...
+
+ def get_archived_log_by_run_id(
+ self,
+ run_id: str,
+ ) -> WorkflowArchiveLog | None:
+ """
+ Fetch a workflow archive log by workflow run ID.
+ """
+ ...
+
+ def delete_archive_log_by_run_id(
+ self,
+ session: Session,
+ run_id: str,
+ ) -> int:
+ """
+ Delete archive log by workflow run ID.
+
+ Used after restoring a workflow run to remove the archive log record,
+ allowing the run to be archived again if needed.
+
+ Args:
+ session: Database session
+ run_id: Workflow run ID
+
+ Returns:
+ Number of records deleted (0 or 1)
+ """
+ ...
+
+ def delete_runs_with_related(
+ self,
+ runs: Sequence[WorkflowRun],
+ delete_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None,
+ delete_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None,
+ ) -> dict[str, int]:
+ """
+ Delete workflow runs and their related records (node executions, offloads, app logs,
+ trigger logs, pauses, pause reasons).
+ """
+ ...
+
+ def get_pause_records_by_run_id(
+ self,
+ session: Session,
+ run_id: str,
+ ) -> Sequence[WorkflowPause]:
+ """
+ Fetch workflow pause records by workflow run ID.
+ """
+ ...
+
+ def get_pause_reason_records_by_run_id(
+ self,
+ session: Session,
+ pause_ids: Sequence[str],
+ ) -> Sequence[WorkflowPauseReason]:
+ """
+ Fetch workflow pause reason records by pause IDs.
+ """
+ ...
+
+ def get_app_logs_by_run_id(
+ self,
+ session: Session,
+ run_id: str,
+ ) -> Sequence[WorkflowAppLog]:
+ """
+ Fetch workflow app logs by workflow run ID.
+ """
+ ...
+
+ def create_archive_logs(
+ self,
+ session: Session,
+ run: WorkflowRun,
+ app_logs: Sequence[WorkflowAppLog],
+ trigger_metadata: str | None,
+ ) -> int:
+ """
+ Create archive log records for a workflow run.
+ """
+ ...
+
+ def get_archived_runs_by_time_range(
+ self,
+ session: Session,
+ tenant_ids: Sequence[str] | None,
+ start_date: datetime,
+ end_date: datetime,
+ limit: int,
+ ) -> Sequence[WorkflowRun]:
+ """
+ Return workflow runs that already have archive logs, for cleanup of `workflow_runs`.
+ """
+ ...
+
+ def count_runs_with_related(
+ self,
+ runs: Sequence[WorkflowRun],
+ count_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None,
+ count_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None,
+ ) -> dict[str, int]:
+ """
+ Count workflow runs and their related records (node executions, offloads, app logs,
+ trigger logs, pauses, pause reasons) without deleting data.
+ """
+ ...
+
def create_workflow_pause(
self,
workflow_run_id: str,
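
`delete_runs_with_related` and `count_runs_with_related` take the node-execution and trigger-log operations as injected callables, so this repository never has to import the other repositories directly. A simplified sketch of that callable-injection shape (sessions and ORM models left out; the lambdas stand in for the real repository methods):

```python
# Simplified sketch of the callable-injection pattern; run IDs replace WorkflowRun
# objects and the Session parameter is omitted.
from collections.abc import Callable, Sequence


def delete_runs_with_related(
    run_ids: Sequence[str],
    delete_node_executions: Callable[[Sequence[str]], tuple[int, int]] | None = None,
    delete_trigger_logs: Callable[[Sequence[str]], int] | None = None,
) -> dict[str, int]:
    node_executions, offloads = delete_node_executions(run_ids) if delete_node_executions else (0, 0)
    trigger_logs = delete_trigger_logs(run_ids) if delete_trigger_logs else 0
    return {
        "runs": len(run_ids),
        "node_executions": node_executions,
        "offloads": offloads,
        "trigger_logs": trigger_logs,
    }


stats = delete_runs_with_related(
    ["r1", "r2"],
    delete_node_executions=lambda ids: (2 * len(ids), len(ids)),
    delete_trigger_logs=lambda ids: len(ids),
)
assert stats == {"runs": 2, "node_executions": 4, "offloads": 2, "trigger_logs": 2}
```
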
diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py
index 7e2173acdd..b19cc73bd1 100644
--- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py
+++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py
@@ -9,11 +9,14 @@ from collections.abc import Sequence
from datetime import datetime
from typing import cast
-from sqlalchemy import asc, delete, desc, select
+from sqlalchemy import asc, delete, desc, func, select
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session, sessionmaker
-from models.workflow import WorkflowNodeExecutionModel
+from models.workflow import (
+ WorkflowNodeExecutionModel,
+ WorkflowNodeExecutionOffload,
+)
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
@@ -290,3 +293,85 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
result = cast(CursorResult, session.execute(stmt))
session.commit()
return result.rowcount
+
+ def delete_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
+ """
+ Delete node executions (and offloads) for the given workflow runs using workflow_run_id.
+ """
+ if not run_ids:
+ return 0, 0
+
+ run_ids = list(run_ids)
+ run_id_filter = WorkflowNodeExecutionModel.workflow_run_id.in_(run_ids)
+ node_execution_ids = select(WorkflowNodeExecutionModel.id).where(run_id_filter)
+
+ offloads_deleted = (
+ cast(
+ CursorResult,
+ session.execute(
+ delete(WorkflowNodeExecutionOffload).where(
+ WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids)
+ )
+ ),
+ ).rowcount
+ or 0
+ )
+
+ node_executions_deleted = (
+ cast(
+ CursorResult,
+ session.execute(delete(WorkflowNodeExecutionModel).where(run_id_filter)),
+ ).rowcount
+ or 0
+ )
+
+ return node_executions_deleted, offloads_deleted
+
+ def count_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
+ """
+ Count node executions (and offloads) for the given workflow runs using workflow_run_id.
+ """
+ if not run_ids:
+ return 0, 0
+
+ run_ids = list(run_ids)
+ run_id_filter = WorkflowNodeExecutionModel.workflow_run_id.in_(run_ids)
+
+ node_executions_count = (
+ session.scalar(select(func.count()).select_from(WorkflowNodeExecutionModel).where(run_id_filter)) or 0
+ )
+ node_execution_ids = select(WorkflowNodeExecutionModel.id).where(run_id_filter)
+ offloads_count = (
+ session.scalar(
+ select(func.count())
+ .select_from(WorkflowNodeExecutionOffload)
+ .where(WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids))
+ )
+ or 0
+ )
+
+ return int(node_executions_count), int(offloads_count)
+
+ @staticmethod
+ def get_by_run(
+ session: Session,
+ run_id: str,
+ ) -> Sequence[WorkflowNodeExecutionModel]:
+ """
+ Fetch node executions for a run using workflow_run_id.
+ """
+ stmt = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.workflow_run_id == run_id)
+ return list(session.scalars(stmt))
+
+ def get_offloads_by_execution_ids(
+ self,
+ session: Session,
+ node_execution_ids: Sequence[str],
+ ) -> Sequence[WorkflowNodeExecutionOffload]:
+ if not node_execution_ids:
+ return []
+
+ stmt = select(WorkflowNodeExecutionOffload).where(
+ WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids)
+ )
+ return list(session.scalars(stmt))
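
`delete_by_runs` reaches the offload rows through an `IN` sub-select on their parent executions and deletes children before parents. A compact Core-only sketch of that pattern against throwaway tables on in-memory SQLite (table and column names are illustrative, not the real models):

```python
# Compact sketch of the child-first delete via an IN sub-select; throwaway Core tables
# on in-memory SQLite stand in for WorkflowNodeExecutionModel / WorkflowNodeExecutionOffload.
import sqlalchemy as sa
from sqlalchemy import delete, select

metadata = sa.MetaData()
execs = sa.Table(
    "execs",
    metadata,
    sa.Column("id", sa.String, primary_key=True),
    sa.Column("run_id", sa.String),
)
offloads = sa.Table(
    "offloads",
    metadata,
    sa.Column("id", sa.String, primary_key=True),
    sa.Column("node_execution_id", sa.String),
)

engine = sa.create_engine("sqlite://")
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(execs.insert(), [{"id": "e1", "run_id": "r1"}, {"id": "e2", "run_id": "r2"}])
    conn.execute(offloads.insert(), [{"id": "o1", "node_execution_id": "e1"}])

    run_id_filter = execs.c.run_id.in_(["r1"])
    exec_ids = select(execs.c.id).where(run_id_filter)  # sub-select reused by the child delete
    offloads_deleted = conn.execute(
        delete(offloads).where(offloads.c.node_execution_id.in_(exec_ids))
    ).rowcount
    execs_deleted = conn.execute(delete(execs).where(run_id_filter)).rowcount
    assert (execs_deleted, offloads_deleted) == (1, 1)
```
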
diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py
index b172c6a3ac..d5214be042 100644
--- a/api/repositories/sqlalchemy_api_workflow_run_repository.py
+++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py
@@ -21,7 +21,7 @@ Implementation Notes:
import logging
import uuid
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
from datetime import datetime
from decimal import Decimal
from typing import Any, cast
@@ -32,7 +32,7 @@ from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session, selectinload, sessionmaker
from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, SchedulingPause
-from core.workflow.enums import WorkflowExecutionStatus
+from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from libs.helper import convert_datetime_to_date
@@ -40,8 +40,7 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.time_parser import get_time_threshold
from libs.uuid_utils import uuidv7
from models.enums import WorkflowRunTriggeredFrom
-from models.workflow import WorkflowPause as WorkflowPauseModel
-from models.workflow import WorkflowPauseReason, WorkflowRun
+from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.types import (
@@ -314,6 +313,335 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
logger.info("Total deleted %s workflow runs for app %s", total_deleted, app_id)
return total_deleted
+ def get_runs_batch_by_time_range(
+ self,
+ start_from: datetime | None,
+ end_before: datetime,
+ last_seen: tuple[datetime, str] | None,
+ batch_size: int,
+ run_types: Sequence[WorkflowType] | None = None,
+ tenant_ids: Sequence[str] | None = None,
+ ) -> Sequence[WorkflowRun]:
+ """
+ Fetch ended workflow runs in a time window for archival and clean batching.
+        Fetch ended workflow runs in a time window for archival and cleanup batching.
+ Query scope:
+ - created_at in [start_from, end_before)
+ - type in run_types (when provided)
+ - status is an ended state
+ - optional tenant_id filter and cursor (last_seen) for pagination
+ """
+ with self._session_maker() as session:
+ stmt = (
+ select(WorkflowRun)
+ .where(
+ WorkflowRun.created_at < end_before,
+ WorkflowRun.status.in_(WorkflowExecutionStatus.ended_values()),
+ )
+ .order_by(WorkflowRun.created_at.asc(), WorkflowRun.id.asc())
+ .limit(batch_size)
+ )
+ if run_types is not None:
+ if not run_types:
+ return []
+ stmt = stmt.where(WorkflowRun.type.in_(run_types))
+
+ if start_from:
+ stmt = stmt.where(WorkflowRun.created_at >= start_from)
+
+ if tenant_ids:
+ stmt = stmt.where(WorkflowRun.tenant_id.in_(tenant_ids))
+
+ if last_seen:
+ stmt = stmt.where(
+ or_(
+ WorkflowRun.created_at > last_seen[0],
+ and_(WorkflowRun.created_at == last_seen[0], WorkflowRun.id > last_seen[1]),
+ )
+ )
+
+ return session.scalars(stmt).all()
+
+ def get_archived_run_ids(
+ self,
+ session: Session,
+ run_ids: Sequence[str],
+ ) -> set[str]:
+ if not run_ids:
+ return set()
+
+ stmt = select(WorkflowArchiveLog.workflow_run_id).where(WorkflowArchiveLog.workflow_run_id.in_(run_ids))
+ return set(session.scalars(stmt).all())
+
+ def get_archived_log_by_run_id(
+ self,
+ run_id: str,
+ ) -> WorkflowArchiveLog | None:
+ with self._session_maker() as session:
+ stmt = select(WorkflowArchiveLog).where(WorkflowArchiveLog.workflow_run_id == run_id).limit(1)
+ return session.scalar(stmt)
+
+ def delete_archive_log_by_run_id(
+ self,
+ session: Session,
+ run_id: str,
+ ) -> int:
+ stmt = delete(WorkflowArchiveLog).where(WorkflowArchiveLog.workflow_run_id == run_id)
+ result = session.execute(stmt)
+ return cast(CursorResult, result).rowcount or 0
+
+ def get_pause_records_by_run_id(
+ self,
+ session: Session,
+ run_id: str,
+ ) -> Sequence[WorkflowPause]:
+ stmt = select(WorkflowPause).where(WorkflowPause.workflow_run_id == run_id)
+ return list(session.scalars(stmt))
+
+ def get_pause_reason_records_by_run_id(
+ self,
+ session: Session,
+ pause_ids: Sequence[str],
+ ) -> Sequence[WorkflowPauseReason]:
+ if not pause_ids:
+ return []
+
+ stmt = select(WorkflowPauseReason).where(WorkflowPauseReason.pause_id.in_(pause_ids))
+ return list(session.scalars(stmt))
+
+ def delete_runs_with_related(
+ self,
+ runs: Sequence[WorkflowRun],
+ delete_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None,
+ delete_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None,
+ ) -> dict[str, int]:
+ if not runs:
+ return {
+ "runs": 0,
+ "node_executions": 0,
+ "offloads": 0,
+ "app_logs": 0,
+ "trigger_logs": 0,
+ "pauses": 0,
+ "pause_reasons": 0,
+ }
+
+ with self._session_maker() as session:
+ run_ids = [run.id for run in runs]
+ if delete_node_executions:
+ node_executions_deleted, offloads_deleted = delete_node_executions(session, runs)
+ else:
+ node_executions_deleted, offloads_deleted = 0, 0
+
+ app_logs_result = session.execute(delete(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(run_ids)))
+ app_logs_deleted = cast(CursorResult, app_logs_result).rowcount or 0
+
+ pause_stmt = select(WorkflowPause.id).where(WorkflowPause.workflow_run_id.in_(run_ids))
+ pause_ids = session.scalars(pause_stmt).all()
+ pause_reasons_deleted = 0
+ pauses_deleted = 0
+
+ if pause_ids:
+ pause_reasons_result = session.execute(
+ delete(WorkflowPauseReason).where(WorkflowPauseReason.pause_id.in_(pause_ids))
+ )
+ pause_reasons_deleted = cast(CursorResult, pause_reasons_result).rowcount or 0
+ pauses_result = session.execute(delete(WorkflowPause).where(WorkflowPause.id.in_(pause_ids)))
+ pauses_deleted = cast(CursorResult, pauses_result).rowcount or 0
+
+ trigger_logs_deleted = delete_trigger_logs(session, run_ids) if delete_trigger_logs else 0
+
+ runs_result = session.execute(delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids)))
+ runs_deleted = cast(CursorResult, runs_result).rowcount or 0
+
+ session.commit()
+
+ return {
+ "runs": runs_deleted,
+ "node_executions": node_executions_deleted,
+ "offloads": offloads_deleted,
+ "app_logs": app_logs_deleted,
+ "trigger_logs": trigger_logs_deleted,
+ "pauses": pauses_deleted,
+ "pause_reasons": pause_reasons_deleted,
+ }
+
+ def get_app_logs_by_run_id(
+ self,
+ session: Session,
+ run_id: str,
+ ) -> Sequence[WorkflowAppLog]:
+ stmt = select(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id == run_id)
+ return list(session.scalars(stmt))
+
+ def create_archive_logs(
+ self,
+ session: Session,
+ run: WorkflowRun,
+ app_logs: Sequence[WorkflowAppLog],
+ trigger_metadata: str | None,
+ ) -> int:
+ if not app_logs:
+ archive_log = WorkflowArchiveLog(
+ log_id=None,
+ log_created_at=None,
+ log_created_from=None,
+ tenant_id=run.tenant_id,
+ app_id=run.app_id,
+ workflow_id=run.workflow_id,
+ workflow_run_id=run.id,
+ created_by_role=run.created_by_role,
+ created_by=run.created_by,
+ run_version=run.version,
+ run_status=run.status,
+ run_triggered_from=run.triggered_from,
+ run_error=run.error,
+ run_elapsed_time=run.elapsed_time,
+ run_total_tokens=run.total_tokens,
+ run_total_steps=run.total_steps,
+ run_created_at=run.created_at,
+ run_finished_at=run.finished_at,
+ run_exceptions_count=run.exceptions_count,
+ trigger_metadata=trigger_metadata,
+ )
+ session.add(archive_log)
+ return 1
+
+ archive_logs = [
+ WorkflowArchiveLog(
+ log_id=app_log.id,
+ log_created_at=app_log.created_at,
+ log_created_from=app_log.created_from,
+ tenant_id=run.tenant_id,
+ app_id=run.app_id,
+ workflow_id=run.workflow_id,
+ workflow_run_id=run.id,
+ created_by_role=run.created_by_role,
+ created_by=run.created_by,
+ run_version=run.version,
+ run_status=run.status,
+ run_triggered_from=run.triggered_from,
+ run_error=run.error,
+ run_elapsed_time=run.elapsed_time,
+ run_total_tokens=run.total_tokens,
+ run_total_steps=run.total_steps,
+ run_created_at=run.created_at,
+ run_finished_at=run.finished_at,
+ run_exceptions_count=run.exceptions_count,
+ trigger_metadata=trigger_metadata,
+ )
+ for app_log in app_logs
+ ]
+ session.add_all(archive_logs)
+ return len(archive_logs)
+
+ def get_archived_runs_by_time_range(
+ self,
+ session: Session,
+ tenant_ids: Sequence[str] | None,
+ start_date: datetime,
+ end_date: datetime,
+ limit: int,
+ ) -> Sequence[WorkflowRun]:
+ """
+ Retrieves WorkflowRun records by joining workflow_archive_logs.
+
+ Used to identify runs that are already archived and ready for deletion.
+ """
+ stmt = (
+ select(WorkflowRun)
+ .join(WorkflowArchiveLog, WorkflowArchiveLog.workflow_run_id == WorkflowRun.id)
+ .where(
+ WorkflowArchiveLog.run_created_at >= start_date,
+ WorkflowArchiveLog.run_created_at < end_date,
+ )
+ .order_by(WorkflowArchiveLog.run_created_at.asc(), WorkflowArchiveLog.workflow_run_id.asc())
+ .limit(limit)
+ )
+ if tenant_ids:
+ stmt = stmt.where(WorkflowArchiveLog.tenant_id.in_(tenant_ids))
+ return list(session.scalars(stmt))
+
+ def get_archived_logs_by_time_range(
+ self,
+ session: Session,
+ tenant_ids: Sequence[str] | None,
+ start_date: datetime,
+ end_date: datetime,
+ limit: int,
+ ) -> Sequence[WorkflowArchiveLog]:
+ # Returns WorkflowArchiveLog rows directly; use this when workflow_runs may be deleted.
+ stmt = (
+ select(WorkflowArchiveLog)
+ .where(
+ WorkflowArchiveLog.run_created_at >= start_date,
+ WorkflowArchiveLog.run_created_at < end_date,
+ )
+ .order_by(WorkflowArchiveLog.run_created_at.asc(), WorkflowArchiveLog.workflow_run_id.asc())
+ .limit(limit)
+ )
+ if tenant_ids:
+ stmt = stmt.where(WorkflowArchiveLog.tenant_id.in_(tenant_ids))
+ return list(session.scalars(stmt))
+
+ def count_runs_with_related(
+ self,
+ runs: Sequence[WorkflowRun],
+ count_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None,
+ count_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None,
+ ) -> dict[str, int]:
+ if not runs:
+ return {
+ "runs": 0,
+ "node_executions": 0,
+ "offloads": 0,
+ "app_logs": 0,
+ "trigger_logs": 0,
+ "pauses": 0,
+ "pause_reasons": 0,
+ }
+
+ with self._session_maker() as session:
+ run_ids = [run.id for run in runs]
+ if count_node_executions:
+ node_executions_count, offloads_count = count_node_executions(session, runs)
+ else:
+ node_executions_count, offloads_count = 0, 0
+
+ app_logs_count = (
+ session.scalar(
+ select(func.count()).select_from(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(run_ids))
+ )
+ or 0
+ )
+
+ pause_ids = session.scalars(
+ select(WorkflowPause.id).where(WorkflowPause.workflow_run_id.in_(run_ids))
+ ).all()
+ pauses_count = len(pause_ids)
+ pause_reasons_count = 0
+ if pause_ids:
+ pause_reasons_count = (
+ session.scalar(
+ select(func.count())
+ .select_from(WorkflowPauseReason)
+ .where(WorkflowPauseReason.pause_id.in_(pause_ids))
+ )
+ or 0
+ )
+
+ trigger_logs_count = count_trigger_logs(session, run_ids) if count_trigger_logs else 0
+
+ return {
+ "runs": len(runs),
+ "node_executions": node_executions_count,
+ "offloads": offloads_count,
+ "app_logs": int(app_logs_count),
+ "trigger_logs": trigger_logs_count,
+ "pauses": pauses_count,
+ "pause_reasons": int(pause_reasons_count),
+ }
+
def create_workflow_pause(
self,
workflow_run_id: str,
@@ -340,9 +668,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
ValueError: If workflow_run_id is invalid or workflow run doesn't exist
RuntimeError: If workflow is already paused or in invalid state
"""
- previous_pause_model_query = select(WorkflowPauseModel).where(
- WorkflowPauseModel.workflow_run_id == workflow_run_id
- )
+ previous_pause_model_query = select(WorkflowPause).where(WorkflowPause.workflow_run_id == workflow_run_id)
with self._session_maker() as session, session.begin():
# Get the workflow run
workflow_run = session.get(WorkflowRun, workflow_run_id)
@@ -367,7 +693,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
# Upload the state file
# Create the pause record
- pause_model = WorkflowPauseModel()
+ pause_model = WorkflowPause()
pause_model.id = str(uuidv7())
pause_model.workflow_id = workflow_run.workflow_id
pause_model.workflow_run_id = workflow_run.id
@@ -539,13 +865,13 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
"""
with self._session_maker() as session, session.begin():
# Get the pause model by ID
- pause_model = session.get(WorkflowPauseModel, pause_entity.id)
+ pause_model = session.get(WorkflowPause, pause_entity.id)
if pause_model is None:
raise _WorkflowRunError(f"WorkflowPause not found: {pause_entity.id}")
self._delete_pause_model(session, pause_model)
@staticmethod
- def _delete_pause_model(session: Session, pause_model: WorkflowPauseModel):
+ def _delete_pause_model(session: Session, pause_model: WorkflowPause):
storage.delete(pause_model.state_object_key)
# Delete the pause record
@@ -580,15 +906,15 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
_limit: int = limit or 1000
pruned_record_ids: list[str] = []
cond = or_(
- WorkflowPauseModel.created_at < expiration,
+ WorkflowPause.created_at < expiration,
and_(
- WorkflowPauseModel.resumed_at.is_not(null()),
- WorkflowPauseModel.resumed_at < resumption_expiration,
+ WorkflowPause.resumed_at.is_not(null()),
+ WorkflowPause.resumed_at < resumption_expiration,
),
)
# First, collect pause records to delete with their state files
# Expired pauses (created before expiration time)
- stmt = select(WorkflowPauseModel).where(cond).limit(_limit)
+ stmt = select(WorkflowPause).where(cond).limit(_limit)
with self._session_maker(expire_on_commit=False) as session:
# Old resumed pauses (resumed more than resumption_duration ago)
@@ -599,7 +925,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
# Delete state files from storage
for pause in pauses_to_delete:
with self._session_maker(expire_on_commit=False) as session, session.begin():
- # todo: this issues a separate query for each WorkflowPauseModel record.
+ # todo: this issues a separate query for each WorkflowPause record.
# consider batching this lookup.
try:
storage.delete(pause.state_object_key)
@@ -851,7 +1177,7 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
def __init__(
self,
*,
- pause_model: WorkflowPauseModel,
+ pause_model: WorkflowPause,
reason_models: Sequence[WorkflowPauseReason],
human_input_form: Sequence = (),
) -> None:
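
`get_runs_batch_by_time_range` pages with a `(created_at, id)` keyset cursor (backed by the new `workflow_run_created_at_id_idx` index) rather than OFFSET. A pure-Python sketch of the predicate it encodes with `or_`/`and_`:

```python
# Pure-Python sketch of the keyset predicate: strictly after the cursor by created_at,
# with ties broken by id. Illustrative only; the real code expresses this in SQL.
from datetime import datetime


def after_cursor(created_at: datetime, row_id: str, last_seen: tuple[datetime, str] | None) -> bool:
    if last_seen is None:
        return True
    return created_at > last_seen[0] or (created_at == last_seen[0] and row_id > last_seen[1])


rows = [
    (datetime(2024, 1, 1), "a"),
    (datetime(2024, 1, 1), "b"),
    (datetime(2024, 1, 2), "a"),
]
# Resume after (2024-01-01, "a"): the same-timestamp row "b" still comes through.
page = [row for row in rows if after_cursor(row[0], row[1], (datetime(2024, 1, 1), "a"))]
assert page == [(datetime(2024, 1, 1), "b"), (datetime(2024, 1, 2), "a")]
```
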
diff --git a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py
index 0d67e286b0..f3dc4cd60b 100644
--- a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py
+++ b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py
@@ -4,8 +4,10 @@ SQLAlchemy implementation of WorkflowTriggerLogRepository.
from collections.abc import Sequence
from datetime import UTC, datetime, timedelta
+from typing import cast
-from sqlalchemy import and_, select
+from sqlalchemy import and_, delete, func, select
+from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session
from models.enums import WorkflowTriggerStatus
@@ -44,6 +46,11 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository):
return self.session.scalar(query)
+ def list_by_run_id(self, run_id: str) -> Sequence[WorkflowTriggerLog]:
+ """List trigger logs for a workflow run."""
+ query = select(WorkflowTriggerLog).where(WorkflowTriggerLog.workflow_run_id == run_id)
+ return list(self.session.scalars(query).all())
+
def get_failed_for_retry(
self, tenant_id: str, max_retry_count: int = 3, limit: int = 100
) -> Sequence[WorkflowTriggerLog]:
@@ -84,3 +91,37 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository):
)
return list(self.session.scalars(query).all())
+
+ def delete_by_run_ids(self, run_ids: Sequence[str]) -> int:
+ """
+ Delete trigger logs associated with the given workflow run ids.
+
+ Args:
+ run_ids: Collection of workflow run identifiers.
+
+ Returns:
+ Number of rows deleted.
+ """
+ if not run_ids:
+ return 0
+
+ result = self.session.execute(delete(WorkflowTriggerLog).where(WorkflowTriggerLog.workflow_run_id.in_(run_ids)))
+ return cast(CursorResult, result).rowcount or 0
+
+ def count_by_run_ids(self, run_ids: Sequence[str]) -> int:
+ """
+ Count trigger logs associated with the given workflow run ids.
+
+ Args:
+ run_ids: Collection of workflow run identifiers.
+
+ Returns:
+ Number of rows matched.
+ """
+ if not run_ids:
+ return 0
+
+ count = self.session.scalar(
+ select(func.count()).select_from(WorkflowTriggerLog).where(WorkflowTriggerLog.workflow_run_id.in_(run_ids))
+ )
+ return int(count or 0)
diff --git a/api/repositories/workflow_trigger_log_repository.py b/api/repositories/workflow_trigger_log_repository.py
index 138b8779ac..b0009e398d 100644
--- a/api/repositories/workflow_trigger_log_repository.py
+++ b/api/repositories/workflow_trigger_log_repository.py
@@ -109,3 +109,15 @@ class WorkflowTriggerLogRepository(Protocol):
A sequence of recent WorkflowTriggerLog instances
"""
...
+
+ def delete_by_run_ids(self, run_ids: Sequence[str]) -> int:
+ """
+ Delete trigger logs for workflow run IDs.
+
+ Args:
+ run_ids: Workflow run IDs to delete
+
+ Returns:
+ Number of rows deleted
+ """
+ ...
diff --git a/api/schedule/clean_messages.py b/api/schedule/clean_messages.py
index 352a84b592..be5f483b95 100644
--- a/api/schedule/clean_messages.py
+++ b/api/schedule/clean_messages.py
@@ -1,90 +1,78 @@
-import datetime
import logging
import time
import click
-from sqlalchemy.exc import SQLAlchemyError
+from redis.exceptions import LockError
import app
from configs import dify_config
-from enums.cloud_plan import CloudPlan
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
-from models.model import (
- App,
- Message,
- MessageAgentThought,
- MessageAnnotation,
- MessageChain,
- MessageFeedback,
- MessageFile,
-)
-from models.web import SavedMessage
-from services.feature_service import FeatureService
+from services.retention.conversation.messages_clean_policy import create_message_clean_policy
+from services.retention.conversation.messages_clean_service import MessagesCleanService
logger = logging.getLogger(__name__)
-@app.celery.task(queue="dataset")
+@app.celery.task(queue="retention")
def clean_messages():
- click.echo(click.style("Start clean messages.", fg="green"))
- start_at = time.perf_counter()
- plan_sandbox_clean_message_day = datetime.datetime.now() - datetime.timedelta(
- days=dify_config.PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING
- )
- while True:
- try:
- # Main query with join and filter
- messages = (
- db.session.query(Message)
- .where(Message.created_at < plan_sandbox_clean_message_day)
- .order_by(Message.created_at.desc())
- .limit(100)
- .all()
- )
+ """
+    Clean expired messages based on the configured clean policy.
- except SQLAlchemyError:
- raise
- if not messages:
- break
- for message in messages:
- app = db.session.query(App).filter_by(id=message.app_id).first()
- if not app:
- logger.warning(
- "Expected App record to exist, but none was found, app_id=%s, message_id=%s",
- message.app_id,
- message.id,
- )
- continue
- features_cache_key = f"features:{app.tenant_id}"
- plan_cache = redis_client.get(features_cache_key)
- if plan_cache is None:
- features = FeatureService.get_features(app.tenant_id)
- redis_client.setex(features_cache_key, 600, features.billing.subscription.plan)
- plan = features.billing.subscription.plan
- else:
- plan = plan_cache.decode()
- if plan == CloudPlan.SANDBOX:
- # clean related message
- db.session.query(MessageFeedback).where(MessageFeedback.message_id == message.id).delete(
- synchronize_session=False
- )
- db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message.id).delete(
- synchronize_session=False
- )
- db.session.query(MessageChain).where(MessageChain.message_id == message.id).delete(
- synchronize_session=False
- )
- db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message.id).delete(
- synchronize_session=False
- )
- db.session.query(MessageFile).where(MessageFile.message_id == message.id).delete(
- synchronize_session=False
- )
- db.session.query(SavedMessage).where(SavedMessage.message_id == message.id).delete(
- synchronize_session=False
- )
- db.session.query(Message).where(Message.id == message.id).delete()
- db.session.commit()
- end_at = time.perf_counter()
- click.echo(click.style(f"Cleaned messages from db success latency: {end_at - start_at}", fg="green"))
+ This task uses MessagesCleanService to efficiently clean messages in batches.
+ The behavior depends on BILLING_ENABLED configuration:
+ - BILLING_ENABLED=True: only delete messages from sandbox tenants (with whitelist/grace period)
+ - BILLING_ENABLED=False: delete all messages within the time range
+ """
+ click.echo(click.style("clean_messages: start clean messages.", fg="green"))
+ start_at = time.perf_counter()
+
+ try:
+ # Create policy based on billing configuration
+ policy = create_message_clean_policy(
+ graceful_period_days=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD,
+ )
+
+ # Create and run the cleanup service
+        # lock the task to avoid concurrent execution as data volume grows in the future
+ with redis_client.lock(
+ "retention:clean_messages", timeout=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL, blocking=False
+ ):
+ service = MessagesCleanService.from_days(
+ policy=policy,
+ days=dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS,
+ batch_size=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE,
+ )
+ stats = service.run()
+
+ end_at = time.perf_counter()
+ click.echo(
+ click.style(
+ f"clean_messages: completed successfully\n"
+ f" - Latency: {end_at - start_at:.2f}s\n"
+ f" - Batches processed: {stats['batches']}\n"
+ f" - Total messages scanned: {stats['total_messages']}\n"
+ f" - Messages filtered: {stats['filtered_messages']}\n"
+ f" - Messages deleted: {stats['total_deleted']}",
+ fg="green",
+ )
+ )
+ except LockError:
+ end_at = time.perf_counter()
+        logger.exception("clean_messages: failed to acquire task lock, skipping current execution")
+ click.echo(
+ click.style(
+ f"clean_messages: skipped (lock already held) - latency: {end_at - start_at:.2f}s",
+ fg="yellow",
+ )
+ )
+ raise
+ except Exception as e:
+ end_at = time.perf_counter()
+ logger.exception("clean_messages failed")
+ click.echo(
+ click.style(
+ f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}",
+ fg="red",
+ )
+ )
+ raise
diff --git a/api/schedule/clean_workflow_runs_task.py b/api/schedule/clean_workflow_runs_task.py
new file mode 100644
index 0000000000..ff45a3ddf2
--- /dev/null
+++ b/api/schedule/clean_workflow_runs_task.py
@@ -0,0 +1,79 @@
+import logging
+from datetime import UTC, datetime
+
+import click
+from redis.exceptions import LockError
+
+import app
+from configs import dify_config
+from extensions.ext_redis import redis_client
+from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup
+
+logger = logging.getLogger(__name__)
+
+
+@app.celery.task(queue="retention")
+def clean_workflow_runs_task() -> None:
+ """
+ Scheduled cleanup for workflow runs and related records (sandbox tenants only).
+ """
+ click.echo(
+ click.style(
+ (
+ "Scheduled workflow run cleanup starting: "
+ f"cutoff={dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS} days, "
+ f"batch={dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE}"
+ ),
+ fg="green",
+ )
+ )
+
+ start_time = datetime.now(UTC)
+
+ try:
+        # Lock the task to avoid concurrent execution as data volume grows over time
+ with redis_client.lock(
+ "retention:clean_workflow_runs_task",
+ timeout=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL,
+ blocking=False,
+ ):
+ WorkflowRunCleanup(
+ days=dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS,
+ batch_size=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE,
+ start_from=None,
+ end_before=None,
+ ).run()
+
+ end_time = datetime.now(UTC)
+ elapsed = end_time - start_time
+ click.echo(
+ click.style(
+ f"Scheduled workflow run cleanup finished. start={start_time.isoformat()} "
+ f"end={end_time.isoformat()} duration={elapsed}",
+ fg="green",
+ )
+ )
+ except LockError:
+ end_time = datetime.now(UTC)
+ elapsed = end_time - start_time
+        logger.exception("clean_workflow_runs_task: failed to acquire task lock, skipping current execution")
+ click.echo(
+ click.style(
+ f"Scheduled workflow run cleanup skipped (lock already held). "
+ f"start={start_time.isoformat()} end={end_time.isoformat()} duration={elapsed}",
+ fg="yellow",
+ )
+ )
+ raise
+ except Exception as e:
+ end_time = datetime.now(UTC)
+ elapsed = end_time - start_time
+ logger.exception("clean_workflow_runs_task failed")
+ click.echo(
+ click.style(
+ f"Scheduled workflow run cleanup failed. start={start_time.isoformat()} "
+ f"end={end_time.isoformat()} duration={elapsed} - {str(e)}",
+ fg="red",
+ )
+ )
+ raise
diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py
index c343063fae..ed46c1c70a 100644
--- a/api/schedule/create_tidb_serverless_task.py
+++ b/api/schedule/create_tidb_serverless_task.py
@@ -50,10 +50,13 @@ def create_clusters(batch_size):
)
for new_cluster in new_clusters:
tidb_auth_binding = TidbAuthBinding(
+ tenant_id=None,
cluster_id=new_cluster["cluster_id"],
cluster_name=new_cluster["cluster_name"],
account=new_cluster["account"],
password=new_cluster["password"],
+ active=False,
+ status="CREATING",
)
db.session.add(tidb_auth_binding)
db.session.commit()
diff --git a/api/schedule/queue_monitor_task.py b/api/schedule/queue_monitor_task.py
index db610df290..77d6b5a138 100644
--- a/api/schedule/queue_monitor_task.py
+++ b/api/schedule/queue_monitor_task.py
@@ -16,6 +16,11 @@ celery_redis = Redis(
port=redis_config.get("port") or 6379,
password=redis_config.get("password") or None,
db=int(redis_config.get("virtual_host")) if redis_config.get("virtual_host") else 1,
+ ssl=bool(dify_config.BROKER_USE_SSL),
+ ssl_ca_certs=dify_config.REDIS_SSL_CA_CERTS if dify_config.BROKER_USE_SSL else None,
+ ssl_cert_reqs=getattr(dify_config, "REDIS_SSL_CERT_REQS", None) if dify_config.BROKER_USE_SSL else None,
+ ssl_certfile=getattr(dify_config, "REDIS_SSL_CERTFILE", None) if dify_config.BROKER_USE_SSL else None,
+ ssl_keyfile=getattr(dify_config, "REDIS_SSL_KEYFILE", None) if dify_config.BROKER_USE_SSL else None,
)
logger = logging.getLogger(__name__)
diff --git a/api/services/account_service.py b/api/services/account_service.py
index 5a549dc318..35e4a505af 100644
--- a/api/services/account_service.py
+++ b/api/services/account_service.py
@@ -8,7 +8,7 @@ from hashlib import sha256
from typing import Any, cast
from pydantic import BaseModel
-from sqlalchemy import func
+from sqlalchemy import func, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized
@@ -748,6 +748,21 @@ class AccountService:
cls.email_code_login_rate_limiter.increment_rate_limit(email)
return token
+ @staticmethod
+ def get_account_by_email_with_case_fallback(email: str, session: Session | None = None) -> Account | None:
+ """
+ Retrieve an account by email and fall back to the lowercase email if the original lookup fails.
+
+ This keeps backward compatibility for older records that stored uppercase emails while the
+ rest of the system gradually normalizes new inputs.
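+        For example, a lookup for "User@Example.com" that finds no exact match is retried with "user@example.com".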
+ """
+ query_session = session or db.session
+ account = query_session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
+ if account or email == email.lower():
+ return account
+
+ return query_session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()
+
@classmethod
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "email_code_login")
@@ -999,6 +1014,11 @@ class TenantService:
tenant.encrypt_public_key = generate_key_pair(tenant.id)
db.session.commit()
+
+ from services.credit_pool_service import CreditPoolService
+
+ CreditPoolService.create_default_pool(tenant.id)
+
return tenant
@staticmethod
@@ -1358,16 +1378,27 @@ class RegisterService:
if not inviter:
raise ValueError("Inviter is required")
+ normalized_email = email.lower()
+
"""Invite new member"""
+ # Check workspace permission for member invitations
+ from libs.workspace_permission import check_workspace_member_invite_permission
+
+ check_workspace_member_invite_permission(tenant.id)
+
with Session(db.engine) as session:
- account = session.query(Account).filter_by(email=email).first()
+ account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if not account:
TenantService.check_member_permission(tenant, inviter, None, "add")
- name = email.split("@")[0]
+ name = normalized_email.split("@")[0]
account = cls.register(
- email=email, name=name, language=language, status=AccountStatus.PENDING, is_setup=True
+ email=normalized_email,
+ name=name,
+ language=language,
+ status=AccountStatus.PENDING,
+ is_setup=True,
)
# Create new tenant member for invited tenant
TenantService.create_tenant_member(tenant, account, role)
@@ -1389,7 +1420,7 @@ class RegisterService:
# send email
send_invite_member_mail_task.delay(
language=language,
- to=email,
+ to=account.email,
token=token,
inviter_name=inviter.name if inviter else "Dify",
workspace_name=tenant.name,
@@ -1488,6 +1519,16 @@ class RegisterService:
invitation: dict = json.loads(data)
return invitation
+ @classmethod
+ def get_invitation_with_case_fallback(
+ cls, workspace_id: str | None, email: str | None, token: str
+ ) -> dict[str, Any] | None:
+ invitation = cls.get_invitation_if_token_valid(workspace_id, email, token)
+ if invitation or not email or email == email.lower():
+ return invitation
+ normalized_email = email.lower()
+ return cls.get_invitation_if_token_valid(workspace_id, normalized_email, token)
+
def _generate_refresh_token(length: int = 64):
token = secrets.token_hex(length)
diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py
index 9258def907..56e9cc6a00 100644
--- a/api/services/annotation_service.py
+++ b/api/services/annotation_service.py
@@ -1,10 +1,14 @@
+import logging
import uuid
import pandas as pd
+
+logger = logging.getLogger(__name__)
from sqlalchemy import or_, select
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
+from core.helper.csv_sanitizer import CSVSanitizer
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
@@ -73,7 +77,7 @@ class AppAnnotationService:
if annotation_setting:
add_annotation_to_index_task.delay(
annotation.id,
- annotation.question,
+ question,
current_tenant_id,
app_id,
annotation_setting.collection_binding_id,
@@ -133,13 +137,16 @@ class AppAnnotationService:
if not app:
raise NotFound("App not found")
if keyword:
+ from libs.helper import escape_like_pattern
+
+ escaped_keyword = escape_like_pattern(keyword)
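+            # e.g. a keyword like "100%_done" is escaped so "%" and "_" match literally rather than as LIKE wildcards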
stmt = (
select(MessageAnnotation)
.where(MessageAnnotation.app_id == app_id)
.where(
or_(
- MessageAnnotation.question.ilike(f"%{keyword}%"),
- MessageAnnotation.content.ilike(f"%{keyword}%"),
+ MessageAnnotation.question.ilike(f"%{escaped_keyword}%", escape="\\"),
+ MessageAnnotation.content.ilike(f"%{escaped_keyword}%", escape="\\"),
)
)
.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
@@ -155,6 +162,12 @@ class AppAnnotationService:
@classmethod
def export_annotation_list_by_app_id(cls, app_id: str):
+ """
+ Export all annotations for an app with CSV injection protection.
+
+ Sanitizes question and content fields to prevent formula injection attacks
+ when exported to CSV format.
+ """
# get app info
_, current_tenant_id = current_account_with_tenant()
app = (
@@ -171,6 +184,16 @@ class AppAnnotationService:
.order_by(MessageAnnotation.created_at.desc())
.all()
)
+
+ # Sanitize CSV-injectable fields to prevent formula injection
+ for annotation in annotations:
+ # Sanitize question field if present
+ if annotation.question:
+ annotation.question = CSVSanitizer.sanitize_value(annotation.question)
+ # Sanitize content field (answer)
+ if annotation.content:
+ annotation.content = CSVSanitizer.sanitize_value(annotation.content)
+
return annotations
@classmethod
@@ -186,8 +209,12 @@ class AppAnnotationService:
if not app:
raise NotFound("App not found")
+ question = args.get("question")
+ if question is None:
+ raise ValueError("'question' is required")
+
annotation = MessageAnnotation(
- app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id
+ app_id=app.id, content=args["answer"], question=question, account_id=current_user.id
)
db.session.add(annotation)
db.session.commit()
@@ -196,7 +223,7 @@ class AppAnnotationService:
if annotation_setting:
add_annotation_to_index_task.delay(
annotation.id,
- args["question"],
+ question,
current_tenant_id,
app_id,
annotation_setting.collection_binding_id,
@@ -221,8 +248,12 @@ class AppAnnotationService:
if not annotation:
raise NotFound("Annotation not found")
+ question = args.get("question")
+ if question is None:
+ raise ValueError("'question' is required")
+
annotation.content = args["answer"]
- annotation.question = args["question"]
+ annotation.question = question
db.session.commit()
# if annotation reply is enabled , add annotation to index
@@ -233,7 +264,7 @@ class AppAnnotationService:
if app_annotation_setting:
update_annotation_to_index_task.delay(
annotation.id,
- annotation.question,
+ annotation.question_text,
current_tenant_id,
app_id,
app_annotation_setting.collection_binding_id,
@@ -330,6 +361,18 @@ class AppAnnotationService:
@classmethod
def batch_import_app_annotations(cls, app_id, file: FileStorage):
+ """
+ Batch import annotations from CSV file with enhanced security checks.
+
+ Security features:
+ - File size validation
+ - Row count limits (min/max)
+ - Memory-efficient CSV parsing
+ - Subscription quota validation
+ - Concurrency tracking
+ """
+ from configs import dify_config
+
# get app info
current_user, current_tenant_id = current_account_with_tenant()
app = (
@@ -341,16 +384,80 @@ class AppAnnotationService:
if not app:
raise NotFound("App not found")
+ job_id: str | None = None # Initialize to avoid unbound variable error
try:
- # Skip the first row
- df = pd.read_csv(file.stream, dtype=str)
- result = []
- for _, row in df.iterrows():
- content = {"question": row.iloc[0], "answer": row.iloc[1]}
+            # Quick sanity check before full parsing (memory efficient)
+            # Read only the first 8KB to confirm the file contains at least one line break
+ file.stream.seek(0)
+ first_chunk = file.stream.read(8192) # Read first 8KB
+ file.stream.seek(0)
+
+            # A chunk without any newline indicates an empty or malformed CSV
+ newline_count = first_chunk.count(b"\n")
+ if newline_count == 0:
+ raise ValueError("The CSV file appears to be empty or invalid.")
+
+ # Parse CSV with row limit to prevent memory exhaustion
+            # Limit parsed rows via nrows rather than loading the entire file
+ max_records = dify_config.ANNOTATION_IMPORT_MAX_RECORDS
+ min_records = dify_config.ANNOTATION_IMPORT_MIN_RECORDS
+
+            # Read at most max_records + 1 rows; the extra row signals that the limit was exceeded
+ df = pd.read_csv(
+ file.stream,
+ dtype=str,
+ nrows=max_records + 1, # Read one extra to detect overflow
+ engine="python",
+ on_bad_lines="skip", # Skip malformed lines instead of crashing
+ )
+
+ # Validate column count
+ if len(df.columns) < 2:
+ raise ValueError("Invalid CSV format. The file must contain at least 2 columns (question and answer).")
+
+ # Build result list with validation
+ result: list[dict] = []
+ for idx, row in df.iterrows():
+ # Stop if we exceed the limit
+ if len(result) >= max_records:
+ raise ValueError(
+ f"The CSV file contains too many records. Maximum {max_records} records allowed per import. "
+ f"Please split your file into smaller batches."
+ )
+
+ # Extract and validate question and answer
+ try:
+ question_raw = row.iloc[0]
+ answer_raw = row.iloc[1]
+ except (IndexError, KeyError):
+ continue # Skip malformed rows
+
+ # Convert to string and strip whitespace
+ question = str(question_raw).strip() if question_raw is not None else ""
+ answer = str(answer_raw).strip() if answer_raw is not None else ""
+
+ # Skip empty entries or NaN values
+ if not question or not answer or question.lower() == "nan" or answer.lower() == "nan":
+ continue
+
+ # Validate length constraints (idx is pandas index, convert to int for display)
+ row_num = int(idx) + 2 if isinstance(idx, (int, float)) else len(result) + 2
+ if len(question) > 2000:
+ raise ValueError(f"Question at row {row_num} is too long. Maximum 2000 characters allowed.")
+ if len(answer) > 10000:
+ raise ValueError(f"Answer at row {row_num} is too long. Maximum 10000 characters allowed.")
+
+ content = {"question": question, "answer": answer}
result.append(content)
- if len(result) == 0:
- raise ValueError("The CSV file is empty.")
- # check annotation limit
+
+ # Validate minimum records
+ if len(result) < min_records:
+ raise ValueError(
+ f"The CSV file must contain at least {min_records} valid annotation record(s). "
+ f"Found {len(result)} valid record(s)."
+ )
+
+ # Check annotation quota limit
features = FeatureService.get_features(current_tenant_id)
if features.billing.enabled:
annotation_quota_limit = features.annotation_quota_limit
@@ -359,12 +466,34 @@ class AppAnnotationService:
# async job
job_id = str(uuid.uuid4())
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
- # send batch add segments task
+
+ # Register job in active tasks list for concurrency tracking
+ current_time = int(naive_utc_now().timestamp() * 1000)
+ active_jobs_key = f"annotation_import_active:{current_tenant_id}"
+ redis_client.zadd(active_jobs_key, {job_id: current_time})
+ redis_client.expire(active_jobs_key, 7200) # 2 hours TTL
+
+ # Set job status
redis_client.setnx(indexing_cache_key, "waiting")
batch_import_annotations_task.delay(str(job_id), result, app_id, current_tenant_id, current_user.id)
- except Exception as e:
+
+ except ValueError as e:
return {"error_msg": str(e)}
- return {"job_id": job_id, "job_status": "waiting"}
+ except Exception as e:
+ # Clean up active job registration on error (only if job was created)
+ if job_id is not None:
+ try:
+ active_jobs_key = f"annotation_import_active:{current_tenant_id}"
+ redis_client.zrem(active_jobs_key, job_id)
+ except Exception:
+ # Silently ignore cleanup errors - the job will be auto-expired
+ logger.debug("Failed to clean up active job tracking during error handling")
+
+            # Surface any other processing error with a generic message
+ error_str = str(e)
+ return {"error_msg": f"An error occurred while processing the file: {error_str}"}
+
+ return {"job_id": job_id, "job_status": "waiting", "record_count": len(result)}
@classmethod
def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit):
diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py
index 1dd6faea5d..0f42c99246 100644
--- a/api/services/app_dsl_service.py
+++ b/api/services/app_dsl_service.py
@@ -32,7 +32,7 @@ from extensions.ext_redis import redis_client
from factories import variable_factory
from libs.datetime_utils import naive_utc_now
from models import Account, App, AppMode
-from models.model import AppModelConfig
+from models.model import AppModelConfig, IconType
from models.workflow import Workflow
from services.plugin.dependencies_analysis import DependenciesAnalysisService
from services.workflow_draft_variable_service import WorkflowDraftVariableService
@@ -155,6 +155,7 @@ class AppDslService:
parsed_url.scheme == "https"
and parsed_url.netloc == "github.com"
and parsed_url.path.endswith((".yml", ".yaml"))
+ and "/blob/" in parsed_url.path
):
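+            # e.g. https://github.com/org/repo/blob/main/app.yml
+            #      -> https://raw.githubusercontent.com/org/repo/main/app.yml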
yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
yaml_url = yaml_url.replace("/blob/", "/")
@@ -427,10 +428,10 @@ class AppDslService:
# Set icon type
icon_type_value = icon_type or app_data.get("icon_type")
- if icon_type_value in ["emoji", "link", "image"]:
+ if icon_type_value in [IconType.EMOJI, IconType.IMAGE, IconType.LINK]:
icon_type = icon_type_value
else:
- icon_type = "emoji"
+ icon_type = IconType.EMOJI
icon = icon or str(app_data.get("icon", ""))
if app:
@@ -520,12 +521,10 @@ class AppDslService:
raise ValueError("Missing model_config for chat/agent-chat/completion app")
# Initialize or update model config
if not app.app_model_config:
- app_model_config = AppModelConfig().from_model_config_dict(model_config)
+ app_model_config = AppModelConfig(
+ app_id=app.id, created_by=account.id, updated_by=account.id
+ ).from_model_config_dict(model_config)
app_model_config.id = str(uuid4())
- app_model_config.app_id = app.id
- app_model_config.created_by = account.id
- app_model_config.updated_by = account.id
-
app.app_model_config_id = app_model_config.id
self._session.add(app_model_config)
@@ -782,15 +781,16 @@ class AppDslService:
return dependencies
@classmethod
- def get_leaked_dependencies(cls, tenant_id: str, dsl_dependencies: list[dict]) -> list[PluginDependency]:
+ def get_leaked_dependencies(
+ cls, tenant_id: str, dsl_dependencies: list[PluginDependency]
+ ) -> list[PluginDependency]:
"""
Returns the leaked dependencies in current workspace
"""
- dependencies = [PluginDependency.model_validate(dep) for dep in dsl_dependencies]
- if not dependencies:
+ if not dsl_dependencies:
return []
- return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies)
+ return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dsl_dependencies)
@staticmethod
def _generate_aes_key(tenant_id: str) -> bytes:
diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py
index 4514c86f7c..ce85f2e914 100644
--- a/api/services/app_generate_service.py
+++ b/api/services/app_generate_service.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
import uuid
from collections.abc import Generator, Mapping
-from typing import Any, Union
+from typing import TYPE_CHECKING, Any, Union
from configs import dify_config
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
@@ -14,9 +16,13 @@ from enums.quota_type import QuotaType, unlimited
from extensions.otel import AppGenerateHandler, trace_span
from models.model import Account, App, AppMode, EndUser
from models.workflow import Workflow
-from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
+from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
+from services.errors.llm import InvokeRateLimitError
from services.workflow_service import WorkflowService
+if TYPE_CHECKING:
+ from controllers.console.app.workflow import LoopNodeRunPayload
+
class AppGenerateService:
@classmethod
@@ -164,7 +170,9 @@ class AppGenerateService:
raise ValueError(f"Invalid app mode {app_model.mode}")
@classmethod
- def generate_single_loop(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
+ def generate_single_loop(
+ cls, app_model: App, user: Account, node_id: str, args: LoopNodeRunPayload, streaming: bool = True
+ ):
if app_model.mode == AppMode.ADVANCED_CHAT:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
diff --git a/api/services/app_service.py b/api/services/app_service.py
index ef89a4fd10..af458ff618 100644
--- a/api/services/app_service.py
+++ b/api/services/app_service.py
@@ -55,8 +55,11 @@ class AppService:
if args.get("is_created_by_me", False):
filters.append(App.created_by == user_id)
if args.get("name"):
+ from libs.helper import escape_like_pattern
+
name = args["name"][:30]
- filters.append(App.name.ilike(f"%{name}%"))
+ escaped_name = escape_like_pattern(name)
+ filters.append(App.name.ilike(f"%{escaped_name}%", escape="\\"))
# Check if tag_ids is not empty to avoid WHERE false condition
if args.get("tag_ids") and len(args["tag_ids"]) > 0:
target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, args["tag_ids"])
@@ -147,10 +150,9 @@ class AppService:
db.session.flush()
if default_model_config:
- app_model_config = AppModelConfig(**default_model_config)
- app_model_config.app_id = app.id
- app_model_config.created_by = account.id
- app_model_config.updated_by = account.id
+ app_model_config = AppModelConfig(
+ **default_model_config, app_id=app.id, created_by=account.id, updated_by=account.id
+ )
db.session.add(app_model_config)
db.session.flush()
diff --git a/api/services/async_workflow_service.py b/api/services/async_workflow_service.py
index e100582511..bc73b7c8c2 100644
--- a/api/services/async_workflow_service.py
+++ b/api/services/async_workflow_service.py
@@ -21,7 +21,7 @@ from models.model import App, EndUser
from models.trigger import WorkflowTriggerLog
from models.workflow import Workflow
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
-from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowNotFoundError
+from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError
from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData
from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority
from services.workflow_service import WorkflowService
@@ -141,7 +141,7 @@ class AsyncWorkflowService:
trigger_log_repo.update(trigger_log)
session.commit()
- raise InvokeRateLimitError(
+ raise WorkflowQuotaLimitError(
f"Workflow execution quota limit reached for tenant {trigger_data.tenant_id}"
) from e
diff --git a/api/services/auth/firecrawl/firecrawl.py b/api/services/auth/firecrawl/firecrawl.py
index d455475bfc..b002706931 100644
--- a/api/services/auth/firecrawl/firecrawl.py
+++ b/api/services/auth/firecrawl/firecrawl.py
@@ -26,7 +26,7 @@ class FirecrawlAuth(ApiKeyAuthBase):
"limit": 1,
"scrapeOptions": {"onlyMainContent": True},
}
- response = self._post_request(f"{self.base_url}/v1/crawl", options, headers)
+ response = self._post_request(self._build_url("v1/crawl"), options, headers)
if response.status_code == 200:
return True
else:
@@ -35,15 +35,17 @@ class FirecrawlAuth(ApiKeyAuthBase):
def _prepare_headers(self):
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
+ def _build_url(self, path: str) -> str:
+ # ensure exactly one slash between base and path, regardless of user-provided base_url
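+        # e.g. "https://example.com/" + "/v1/crawl" -> "https://example.com/v1/crawl"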
+ return f"{self.base_url.rstrip('/')}/{path.lstrip('/')}"
+
def _post_request(self, url, data, headers):
return httpx.post(url, headers=headers, json=data)
def _handle_error(self, response):
- if response.status_code in {402, 409, 500}:
- error_message = response.json().get("error", "Unknown error occurred")
- raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
- else:
- if response.text:
- error_message = json.loads(response.text).get("error", "Unknown error occurred")
- raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
- raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}")
+ try:
+ payload = response.json()
+ except json.JSONDecodeError:
+ payload = {}
+ error_message = payload.get("error") or payload.get("message") or (response.text or "Unknown error occurred")
+ raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
diff --git a/api/services/billing_service.py b/api/services/billing_service.py
index 54e1c9d285..946b8cdfdb 100644
--- a/api/services/billing_service.py
+++ b/api/services/billing_service.py
@@ -1,8 +1,13 @@
+import json
+import logging
import os
+from collections.abc import Sequence
from typing import Literal
import httpx
+from pydantic import TypeAdapter
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
+from typing_extensions import TypedDict
from werkzeug.exceptions import InternalServerError
from enums.cloud_plan import CloudPlan
@@ -11,6 +16,15 @@ from extensions.ext_redis import redis_client
from libs.helper import RateLimiter
from models import Account, TenantAccountJoin, TenantAccountRole
+logger = logging.getLogger(__name__)
+
+
+class SubscriptionPlan(TypedDict):
+    """Tenant subscription plan information."""
+
+ plan: str
+ expiration_date: int
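+    # Illustrative shape: {"plan": "sandbox", "expiration_date": 1767168000}
+    # (actual plan names and expiration encoding come from the billing API)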
+
class BillingService:
base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
@@ -18,6 +32,11 @@ class BillingService:
compliance_download_rate_limiter = RateLimiter("compliance_download_rate_limiter", 4, 60)
+ # Redis key prefix for tenant plan cache
+ _PLAN_CACHE_KEY_PREFIX = "tenant_plan:"
+ # Cache TTL: 10 minutes
+ _PLAN_CACHE_TTL = 600
+
@classmethod
def get_info(cls, tenant_id: str):
params = {"tenant_id": tenant_id}
@@ -112,7 +131,7 @@ class BillingService:
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
url = f"{cls.base_url}{endpoint}"
- response = httpx.request(method, url, json=json, params=params, headers=headers)
+ response = httpx.request(method, url, json=json, params=params, headers=headers, follow_redirects=True)
if method == "GET" and response.status_code != httpx.codes.OK:
raise ValueError("Unable to retrieve billing information. Please try again later or contact support.")
if method == "PUT":
@@ -124,6 +143,9 @@ class BillingService:
raise ValueError("Invalid arguments.")
if method == "POST" and response.status_code != httpx.codes.OK:
raise ValueError(f"Unable to send request to {url}. Please try again later or contact support.")
+ if method == "DELETE" and response.status_code != httpx.codes.OK:
+ logger.error("billing_service: DELETE response: %s %s", response.status_code, response.text)
+ raise ValueError(f"Unable to process delete request {url}. Please try again later or contact support.")
return response.json()
@staticmethod
@@ -146,7 +168,7 @@ class BillingService:
def delete_account(cls, account_id: str):
"""Delete account."""
params = {"account_id": account_id}
- return cls._send_request("DELETE", "/account/", params=params)
+ return cls._send_request("DELETE", "/account", params=params)
@classmethod
def is_email_in_freeze(cls, email: str) -> bool:
@@ -239,3 +261,135 @@ class BillingService:
def sync_partner_tenants_bindings(cls, account_id: str, partner_key: str, click_id: str):
payload = {"account_id": account_id, "click_id": click_id}
return cls._send_request("PUT", f"/partners/{partner_key}/tenants", json=payload)
+
+ @classmethod
+ def get_plan_bulk(cls, tenant_ids: Sequence[str]) -> dict[str, SubscriptionPlan]:
+ """
+        Bulk fetch billing subscription plans via the billing API.
+ Payload: {"tenant_ids": ["t1", "t2", ...]} (max 200 per request)
+ Returns:
+ Mapping of tenant_id -> {plan: str, expiration_date: int}
+ """
+ results: dict[str, SubscriptionPlan] = {}
+ subscription_adapter = TypeAdapter(SubscriptionPlan)
+
+ chunk_size = 200
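+        # e.g. 450 tenant ids are fetched as three requests of 200, 200 and 50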
+ for i in range(0, len(tenant_ids), chunk_size):
+ chunk = tenant_ids[i : i + chunk_size]
+ try:
+ resp = cls._send_request("POST", "/subscription/plan/batch", json={"tenant_ids": chunk})
+ data = resp.get("data", {})
+
+ for tenant_id, plan in data.items():
+ try:
+ subscription_plan = subscription_adapter.validate_python(plan)
+ results[tenant_id] = subscription_plan
+ except Exception:
+ logger.exception(
+ "get_plan_bulk: failed to validate subscription plan for tenant(%s)", tenant_id
+ )
+ continue
+ except Exception:
+ logger.exception("get_plan_bulk: failed to fetch billing info batch for tenants: %s", chunk)
+ continue
+
+ return results
+
+ @classmethod
+ def _make_plan_cache_key(cls, tenant_id: str) -> str:
+ return f"{cls._PLAN_CACHE_KEY_PREFIX}{tenant_id}"
+
+ @classmethod
+ def get_plan_bulk_with_cache(cls, tenant_ids: Sequence[str]) -> dict[str, SubscriptionPlan]:
+ """
+        Bulk fetch billing subscription plans with caching to reduce load on the billing API in batch-job scenarios.
+
+        NOTE: if you need strong data consistency, use get_plan_bulk instead.
+
+ Returns:
+ Mapping of tenant_id -> {plan: str, expiration_date: int}
+ """
+ tenant_plans: dict[str, SubscriptionPlan] = {}
+
+ if not tenant_ids:
+ return tenant_plans
+
+ subscription_adapter = TypeAdapter(SubscriptionPlan)
+
+ # Step 1: Batch fetch from Redis cache using mget
+ redis_keys = [cls._make_plan_cache_key(tenant_id) for tenant_id in tenant_ids]
+ try:
+ cached_values = redis_client.mget(redis_keys)
+
+ if len(cached_values) != len(tenant_ids):
+ raise Exception(
+ "get_plan_bulk_with_cache: unexpected error: redis mget failed: cached values length mismatch"
+ )
+
+ # Map cached values back to tenant_ids
+ cache_misses: list[str] = []
+
+ for tenant_id, cached_value in zip(tenant_ids, cached_values):
+ if cached_value:
+ try:
+ # Redis returns bytes, decode to string and parse JSON
+ json_str = cached_value.decode("utf-8") if isinstance(cached_value, bytes) else cached_value
+ plan_dict = json.loads(json_str)
+ subscription_plan = subscription_adapter.validate_python(plan_dict)
+ tenant_plans[tenant_id] = subscription_plan
+ except Exception:
+ logger.exception(
+ "get_plan_bulk_with_cache: process tenant(%s) failed, add to cache misses", tenant_id
+ )
+ cache_misses.append(tenant_id)
+ else:
+ cache_misses.append(tenant_id)
+
+ logger.info(
+ "get_plan_bulk_with_cache: cache hits=%s, cache misses=%s",
+ len(tenant_plans),
+ len(cache_misses),
+ )
+ except Exception:
+ logger.exception("get_plan_bulk_with_cache: redis mget failed, falling back to API")
+ cache_misses = list(tenant_ids)
+
+ # Step 2: Fetch missing plans from billing API
+ if cache_misses:
+ bulk_plans = BillingService.get_plan_bulk(cache_misses)
+
+ if bulk_plans:
+ plans_to_cache: dict[str, SubscriptionPlan] = {}
+
+ for tenant_id, subscription_plan in bulk_plans.items():
+ tenant_plans[tenant_id] = subscription_plan
+ plans_to_cache[tenant_id] = subscription_plan
+
+ # Step 3: Batch update Redis cache using pipeline
+ if plans_to_cache:
+ try:
+ pipe = redis_client.pipeline()
+ for tenant_id, subscription_plan in plans_to_cache.items():
+ redis_key = cls._make_plan_cache_key(tenant_id)
+ # Serialize dict to JSON string
+ json_str = json.dumps(subscription_plan)
+ pipe.setex(redis_key, cls._PLAN_CACHE_TTL, json_str)
+ pipe.execute()
+
+ logger.info(
+ "get_plan_bulk_with_cache: cached %s new tenant plans to Redis",
+ len(plans_to_cache),
+ )
+ except Exception:
+ logger.exception("get_plan_bulk_with_cache: redis pipeline failed")
+
+ return tenant_plans
+
+ @classmethod
+ def get_expired_subscription_cleanup_whitelist(cls) -> Sequence[str]:
+ resp = cls._send_request("GET", "/subscription/cleanup/whitelist")
+ data = resp.get("data", [])
+ tenant_whitelist = []
+ for item in data:
+ tenant_whitelist.append(item["tenant_id"])
+ return tenant_whitelist
diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py
index 5253199552..295d48d8a1 100644
--- a/api/services/conversation_service.py
+++ b/api/services/conversation_service.py
@@ -6,16 +6,18 @@ from typing import Any, Union
from sqlalchemy import asc, desc, func, or_, select
from sqlalchemy.orm import Session
+from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
+from core.db.session_factory import session_factory
from core.llm_generator.llm_generator import LLMGenerator
from core.variables.types import SegmentType
-from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
from extensions.ext_database import db
from factories import variable_factory
from libs.datetime_utils import naive_utc_now
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account, ConversationVariable
from models.model import App, Conversation, EndUser, Message
+from services.conversation_variable_updater import ConversationVariableUpdater
from services.errors.conversation import (
ConversationNotExistsError,
ConversationVariableNotExistsError,
@@ -202,6 +204,7 @@ class ConversationService:
user: Union[Account, EndUser] | None,
limit: int,
last_id: str | None,
+ variable_name: str | None = None,
) -> InfiniteScrollPagination:
conversation = cls.get_conversation(app_model, conversation_id, user)
@@ -212,7 +215,27 @@ class ConversationService:
.order_by(ConversationVariable.created_at)
)
- with Session(db.engine) as session:
+ # Apply variable_name filter if provided
+ if variable_name:
+            # Escape LIKE wildcards in the user-supplied name before building the pattern
+ from libs.helper import escape_like_pattern
+
+ escaped_variable_name = escape_like_pattern(variable_name)
+ # Filter using JSON extraction to match variable names case-insensitively
+ if dify_config.DB_TYPE in ["mysql", "oceanbase", "seekdb"]:
+ stmt = stmt.where(
+ func.json_extract(ConversationVariable.data, "$.name").ilike(
+ f"%{escaped_variable_name}%", escape="\\"
+ )
+ )
+ elif dify_config.DB_TYPE == "postgresql":
+ stmt = stmt.where(
+ func.json_extract_path_text(ConversationVariable.data, "name").ilike(
+ f"%{escaped_variable_name}%", escape="\\"
+ )
+ )
+
+ with session_factory.create_session() as session:
if last_id:
last_variable = session.scalar(stmt.where(ConversationVariable.id == last_id))
if not last_variable:
@@ -279,7 +302,7 @@ class ConversationService:
.where(ConversationVariable.id == variable_id)
)
- with Session(db.engine) as session:
+ with session_factory.create_session() as session:
existing_variable = session.scalar(stmt)
if not existing_variable:
raise ConversationVariableNotExistsError()
@@ -314,7 +337,7 @@ class ConversationService:
updated_variable = variable_factory.build_conversation_variable_from_mapping(updated_variable_dict)
# Use the conversation variable updater to persist the changes
- updater = conversation_variable_updater_factory()
+ updater = ConversationVariableUpdater(session_factory.get_session_maker())
updater.update(conversation_id, updated_variable)
updater.flush()
diff --git a/api/services/conversation_variable_updater.py b/api/services/conversation_variable_updater.py
new file mode 100644
index 0000000000..92008d5ff1
--- /dev/null
+++ b/api/services/conversation_variable_updater.py
@@ -0,0 +1,28 @@
+from sqlalchemy import select
+from sqlalchemy.orm import Session, sessionmaker
+
+from core.variables.variables import VariableBase
+from models import ConversationVariable
+
+
+class ConversationVariableNotFoundError(Exception):
+ pass
+
+
+class ConversationVariableUpdater:
+ def __init__(self, session_maker: sessionmaker[Session]) -> None:
+ self._session_maker: sessionmaker[Session] = session_maker
+
+ def update(self, conversation_id: str, variable: VariableBase) -> None:
+ stmt = select(ConversationVariable).where(
+ ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
+ )
+ with self._session_maker() as session:
+ row = session.scalar(stmt)
+ if not row:
+ raise ConversationVariableNotFoundError("conversation variable not found in the database")
+ row.data = variable.model_dump_json()
+ session.commit()
+
+ def flush(self) -> None:
+ pass
diff --git a/api/services/credit_pool_service.py b/api/services/credit_pool_service.py
new file mode 100644
index 0000000000..1954602571
--- /dev/null
+++ b/api/services/credit_pool_service.py
@@ -0,0 +1,85 @@
+import logging
+
+from sqlalchemy import update
+from sqlalchemy.orm import Session
+
+from configs import dify_config
+from core.errors.error import QuotaExceededError
+from extensions.ext_database import db
+from models import TenantCreditPool
+
+logger = logging.getLogger(__name__)
+
+
+class CreditPoolService:
+ @classmethod
+ def create_default_pool(cls, tenant_id: str) -> TenantCreditPool:
+        """Create the default credit pool for a new tenant."""
+ credit_pool = TenantCreditPool(
+ tenant_id=tenant_id, quota_limit=dify_config.HOSTED_POOL_CREDITS, quota_used=0, pool_type="trial"
+ )
+ db.session.add(credit_pool)
+ db.session.commit()
+ return credit_pool
+
+ @classmethod
+ def get_pool(cls, tenant_id: str, pool_type: str = "trial") -> TenantCreditPool | None:
+        """Get the tenant's credit pool."""
+ return (
+ db.session.query(TenantCreditPool)
+ .filter_by(
+ tenant_id=tenant_id,
+ pool_type=pool_type,
+ )
+ .first()
+ )
+
+ @classmethod
+ def check_credits_available(
+ cls,
+ tenant_id: str,
+ credits_required: int,
+ pool_type: str = "trial",
+ ) -> bool:
+        """Check whether enough credits are available without deducting them."""
+ pool = cls.get_pool(tenant_id, pool_type)
+ if not pool:
+ return False
+ return pool.remaining_credits >= credits_required
+
+ @classmethod
+ def check_and_deduct_credits(
+ cls,
+ tenant_id: str,
+ credits_required: int,
+ pool_type: str = "trial",
+ ) -> int:
+        """Check and deduct credits; returns the number of credits actually deducted."""
+
+ pool = cls.get_pool(tenant_id, pool_type)
+ if not pool:
+ raise QuotaExceededError("Credit pool not found")
+
+ if pool.remaining_credits <= 0:
+ raise QuotaExceededError("No credits remaining")
+
+ # deduct all remaining credits if less than required
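+        # e.g. 10 credits required with only 3 remaining -> deduct 3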
+ actual_credits = min(credits_required, pool.remaining_credits)
+
+ try:
+ with Session(db.engine) as session:
+ stmt = (
+ update(TenantCreditPool)
+ .where(
+ TenantCreditPool.tenant_id == tenant_id,
+ TenantCreditPool.pool_type == pool_type,
+ )
+ .values(quota_used=TenantCreditPool.quota_used + actual_credits)
+ )
+ session.execute(stmt)
+ session.commit()
+ except Exception:
+ logger.exception("Failed to deduct credits for tenant %s", tenant_id)
+ raise QuotaExceededError("Failed to deduct credits")
+
+ return actual_credits
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index 7841b8b33d..0b3fcbe4ae 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -13,10 +13,11 @@ import sqlalchemy as sa
from redis.exceptions import LockNotOwnedError
from sqlalchemy import exists, func, select
from sqlalchemy.orm import Session
-from werkzeug.exceptions import NotFound
+from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
+from core.file import helpers as file_helpers
from core.helper.name_generator import generate_incremental_name
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
@@ -73,6 +74,7 @@ from services.errors.document import DocumentIndexingError
from services.errors.file import FileNotExistsError
from services.external_knowledge_service import ExternalDatasetService
from services.feature_service import FeatureModel, FeatureService
+from services.file_service import FileService
from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.tag_service import TagService
from services.vector_service import VectorService
@@ -87,6 +89,7 @@ from tasks.disable_segments_from_index_task import disable_segments_from_index_t
from tasks.document_indexing_update_task import document_indexing_update_task
from tasks.enable_segments_to_index_task import enable_segments_to_index_task
from tasks.recover_document_indexing_task import recover_document_indexing_task
+from tasks.regenerate_summary_index_task import regenerate_summary_index_task
from tasks.remove_document_from_index_task import remove_document_from_index_task
from tasks.retry_document_indexing_task import retry_document_indexing_task
from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task
@@ -144,7 +147,8 @@ class DatasetService:
query = query.where(Dataset.permission == DatasetPermissionEnum.ALL_TEAM)
if search:
- query = query.where(Dataset.name.ilike(f"%{search}%"))
+ escaped_search = helper.escape_like_pattern(search)
+ query = query.where(Dataset.name.ilike(f"%{escaped_search}%", escape="\\"))
# Check if tag_ids is not empty to avoid WHERE false condition
if tag_ids and len(tag_ids) > 0:
@@ -208,6 +212,7 @@ class DatasetService:
embedding_model_provider: str | None = None,
embedding_model_name: str | None = None,
retrieval_model: RetrievalModel | None = None,
+ summary_index_setting: dict | None = None,
):
# check if dataset name already exists
if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
@@ -250,6 +255,8 @@ class DatasetService:
dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None
dataset.permission = permission or DatasetPermissionEnum.ONLY_ME
dataset.provider = provider
+ if summary_index_setting is not None:
+ dataset.summary_index_setting = summary_index_setting
db.session.add(dataset)
db.session.flush()
@@ -473,6 +480,11 @@ class DatasetService:
if external_retrieval_model:
dataset.retrieval_model = external_retrieval_model
+ # Update summary index setting if provided
+ summary_index_setting = data.get("summary_index_setting", None)
+ if summary_index_setting is not None:
+ dataset.summary_index_setting = summary_index_setting
+
# Update basic dataset properties
dataset.name = data.get("name", dataset.name)
dataset.description = data.get("description", dataset.description)
@@ -561,6 +573,9 @@ class DatasetService:
# update Retrieval model
if data.get("retrieval_model"):
filtered_data["retrieval_model"] = data["retrieval_model"]
+ # update summary index setting
+ if data.get("summary_index_setting"):
+ filtered_data["summary_index_setting"] = data.get("summary_index_setting")
# update icon info
if data.get("icon_info"):
filtered_data["icon_info"] = data.get("icon_info")
@@ -569,12 +584,27 @@ class DatasetService:
db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data)
db.session.commit()
+ # Reload dataset to get updated values
+ db.session.refresh(dataset)
+
# update pipeline knowledge base node data
DatasetService._update_pipeline_knowledge_base_node_data(dataset, user.id)
# Trigger vector index task if indexing technique changed
if action:
deal_dataset_vector_index_task.delay(dataset.id, action)
+ # If embedding_model changed, also regenerate summary vectors
+ if action == "update":
+ regenerate_summary_index_task.delay(
+ dataset.id,
+ regenerate_reason="embedding_model_changed",
+ regenerate_vectors_only=True,
+ )
+
+ # Note: summary_index_setting changes do not trigger automatic regeneration of existing summaries.
+ # The new setting will only apply to:
+ # 1. New documents added after the setting change
+ # 2. Manual summary generation requests
return dataset
@@ -613,6 +643,7 @@ class DatasetService:
knowledge_index_node_data["chunk_structure"] = dataset.chunk_structure
knowledge_index_node_data["indexing_technique"] = dataset.indexing_technique # pyright: ignore[reportAttributeAccessIssue]
knowledge_index_node_data["keyword_number"] = dataset.keyword_number
+ knowledge_index_node_data["summary_index_setting"] = dataset.summary_index_setting
node["data"] = knowledge_index_node_data
updated = True
except Exception:
@@ -851,6 +882,54 @@ class DatasetService:
)
filtered_data["collection_binding_id"] = dataset_collection_binding.id
+ @staticmethod
+ def _check_summary_index_setting_model_changed(dataset: Dataset, data: dict[str, Any]) -> bool:
+ """
+ Check if summary_index_setting model (model_name or model_provider_name) has changed.
+
+ Args:
+ dataset: Current dataset object
+ data: Update data dictionary
+
+ Returns:
+ bool: True if summary model changed, False otherwise
+ """
+ # Check if summary_index_setting is being updated
+ if "summary_index_setting" not in data or data.get("summary_index_setting") is None:
+ return False
+
+ new_summary_setting = data.get("summary_index_setting")
+ old_summary_setting = dataset.summary_index_setting
+
+ # If new setting is disabled, no need to regenerate
+ if not new_summary_setting or not new_summary_setting.get("enable"):
+ return False
+
+ # If old setting doesn't exist, no need to regenerate (no existing summaries to regenerate)
+ # Note: This task only regenerates existing summaries, not generates new ones
+ if not old_summary_setting:
+ return False
+
+ # Compare model_name and model_provider_name
+ old_model_name = old_summary_setting.get("model_name")
+ old_model_provider = old_summary_setting.get("model_provider_name")
+ new_model_name = new_summary_setting.get("model_name")
+ new_model_provider = new_summary_setting.get("model_provider_name")
+
+ # Check if model changed
+ if old_model_name != new_model_name or old_model_provider != new_model_provider:
+ logger.info(
+ "Summary index setting model changed for dataset %s: old=%s/%s, new=%s/%s",
+ dataset.id,
+ old_model_provider,
+ old_model_name,
+ new_model_provider,
+ new_model_name,
+ )
+ return True
+
+ return False
+
@staticmethod
def update_rag_pipeline_dataset_settings(
session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False
@@ -886,6 +965,9 @@ class DatasetService:
else:
raise ValueError("Invalid index method")
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
+ # Update summary_index_setting if provided
+ if knowledge_configuration.summary_index_setting is not None:
+ dataset.summary_index_setting = knowledge_configuration.summary_index_setting
session.add(dataset)
else:
if dataset.chunk_structure and dataset.chunk_structure != knowledge_configuration.chunk_structure:
@@ -991,6 +1073,9 @@ class DatasetService:
if dataset.keyword_number != knowledge_configuration.keyword_number:
dataset.keyword_number = knowledge_configuration.keyword_number
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
+ # Update summary_index_setting if provided
+ if knowledge_configuration.summary_index_setting is not None:
+ dataset.summary_index_setting = knowledge_configuration.summary_index_setting
session.add(dataset)
session.commit()
if action:
@@ -1161,6 +1246,7 @@ class DocumentService:
Document.archived.is_(True),
),
}
+ DOCUMENT_BATCH_DOWNLOAD_ZIP_FILENAME_EXTENSION = ".zip"
@classmethod
def normalize_display_status(cls, status: str | None) -> str | None:
@@ -1287,6 +1373,187 @@ class DocumentService:
else:
return None
+ @staticmethod
+ def get_documents_by_ids(dataset_id: str, document_ids: Sequence[str]) -> Sequence[Document]:
+ """Fetch documents for a dataset in a single batch query."""
+ if not document_ids:
+ return []
+ document_id_list: list[str] = [str(document_id) for document_id in document_ids]
+ # Fetch all requested documents in one query to avoid N+1 lookups.
+ documents: Sequence[Document] = db.session.scalars(
+ select(Document).where(
+ Document.dataset_id == dataset_id,
+ Document.id.in_(document_id_list),
+ )
+ ).all()
+ return documents
+
+ @staticmethod
+ def get_document_download_url(document: Document) -> str:
+ """
+ Return a signed download URL for an upload-file document.
+ """
+ upload_file = DocumentService._get_upload_file_for_upload_file_document(document)
+ return file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True)
+
+ @staticmethod
+ def enrich_documents_with_summary_index_status(
+ documents: Sequence[Document],
+ dataset: Dataset,
+ tenant_id: str,
+ ) -> None:
+ """
+ Enrich documents with summary_index_status based on dataset summary index settings.
+
+ This method calculates and sets the summary_index_status for each document that needs summary.
+ Documents that don't need summary or when summary index is disabled will have status set to None.
+
+ Args:
+ documents: List of Document instances to enrich
+ dataset: Dataset instance containing summary_index_setting
+ tenant_id: Tenant ID for summary status lookup
+ """
+ # Check if dataset has summary index enabled
+ has_summary_index = dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True
+
+ # Filter documents that need summary calculation
+ documents_need_summary = [doc for doc in documents if doc.need_summary is True]
+ document_ids_need_summary = [str(doc.id) for doc in documents_need_summary]
+
+ # Calculate summary_index_status for documents that need summary (only if dataset summary index is enabled)
+ summary_status_map: dict[str, str | None] = {}
+ if has_summary_index and document_ids_need_summary:
+ from services.summary_index_service import SummaryIndexService
+
+ summary_status_map = SummaryIndexService.get_documents_summary_index_status(
+ document_ids=document_ids_need_summary,
+ dataset_id=dataset.id,
+ tenant_id=tenant_id,
+ )
+
+ # Add summary_index_status to each document
+ for document in documents:
+ if has_summary_index and document.need_summary is True:
+ # Get status from map, default to None (not queued yet)
+ document.summary_index_status = summary_status_map.get(str(document.id)) # type: ignore[attr-defined]
+ else:
+ # Return null if summary index is not enabled or document doesn't need summary
+ document.summary_index_status = None # type: ignore[attr-defined]
+
+ @staticmethod
+ def prepare_document_batch_download_zip(
+ *,
+ dataset_id: str,
+ document_ids: Sequence[str],
+ tenant_id: str,
+ current_user: Account,
+ ) -> tuple[list[UploadFile], str]:
+ """
+ Resolve upload files for batch ZIP downloads and generate a client-visible filename.
+ """
+ dataset = DatasetService.get_dataset(dataset_id)
+ if not dataset:
+ raise NotFound("Dataset not found.")
+ try:
+ DatasetService.check_dataset_permission(dataset, current_user)
+ except NoPermissionError as e:
+ raise Forbidden(str(e))
+
+ upload_files_by_document_id = DocumentService._get_upload_files_by_document_id_for_zip_download(
+ dataset_id=dataset_id,
+ document_ids=document_ids,
+ tenant_id=tenant_id,
+ )
+ upload_files = [upload_files_by_document_id[document_id] for document_id in document_ids]
+ download_name = DocumentService._generate_document_batch_download_zip_filename()
+ return upload_files, download_name
+
+ @staticmethod
+ def _generate_document_batch_download_zip_filename() -> str:
+ """
+ Generate a random attachment filename for the batch download ZIP.
+ """
+ return f"{uuid.uuid4().hex}{DocumentService.DOCUMENT_BATCH_DOWNLOAD_ZIP_FILENAME_EXTENSION}"
+
+ @staticmethod
+ def _get_upload_file_id_for_upload_file_document(
+ document: Document,
+ *,
+ invalid_source_message: str,
+ missing_file_message: str,
+ ) -> str:
+ """
+ Normalize and validate `Document -> UploadFile` linkage for download flows.
+ """
+ if document.data_source_type != "upload_file":
+ raise NotFound(invalid_source_message)
+
+ data_source_info: dict[str, Any] = document.data_source_info_dict or {}
+ upload_file_id: str | None = data_source_info.get("upload_file_id")
+ if not upload_file_id:
+ raise NotFound(missing_file_message)
+
+ return str(upload_file_id)
+
+ @staticmethod
+ def _get_upload_file_for_upload_file_document(document: Document) -> UploadFile:
+ """
+ Load the `UploadFile` row for an upload-file document.
+ """
+ upload_file_id = DocumentService._get_upload_file_id_for_upload_file_document(
+ document,
+ invalid_source_message="Document does not have an uploaded file to download.",
+ missing_file_message="Uploaded file not found.",
+ )
+ upload_files_by_id = FileService.get_upload_files_by_ids(document.tenant_id, [upload_file_id])
+ upload_file = upload_files_by_id.get(upload_file_id)
+ if not upload_file:
+ raise NotFound("Uploaded file not found.")
+ return upload_file
+
+ @staticmethod
+ def _get_upload_files_by_document_id_for_zip_download(
+ *,
+ dataset_id: str,
+ document_ids: Sequence[str],
+ tenant_id: str,
+ ) -> dict[str, UploadFile]:
+ """
+ Batch load upload files keyed by document id for ZIP downloads.
+ """
+ document_id_list: list[str] = [str(document_id) for document_id in document_ids]
+
+ documents = DocumentService.get_documents_by_ids(dataset_id, document_id_list)
+ documents_by_id: dict[str, Document] = {str(document.id): document for document in documents}
+
+ missing_document_ids: set[str] = set(document_id_list) - set(documents_by_id.keys())
+ if missing_document_ids:
+ raise NotFound("Document not found.")
+
+ upload_file_ids: list[str] = []
+ upload_file_ids_by_document_id: dict[str, str] = {}
+ for document_id, document in documents_by_id.items():
+ if document.tenant_id != tenant_id:
+ raise Forbidden("No permission.")
+
+ upload_file_id = DocumentService._get_upload_file_id_for_upload_file_document(
+ document,
+ invalid_source_message="Only uploaded-file documents can be downloaded as ZIP.",
+ missing_file_message="Only uploaded-file documents can be downloaded as ZIP.",
+ )
+ upload_file_ids.append(upload_file_id)
+ upload_file_ids_by_document_id[document_id] = upload_file_id
+
+ upload_files_by_id = FileService.get_upload_files_by_ids(tenant_id, upload_file_ids)
+ missing_upload_file_ids: set[str] = set(upload_file_ids) - set(upload_files_by_id.keys())
+ if missing_upload_file_ids:
+ raise NotFound("Only uploaded-file documents can be downloaded as ZIP.")
+
+ return {
+ document_id: upload_files_by_id[upload_file_id]
+ for document_id, upload_file_id in upload_file_ids_by_document_id.items()
+ }
+
@staticmethod
def get_document_by_id(document_id: str) -> Document | None:
document = db.session.query(Document).where(Document.id == document_id).first()
@@ -1419,7 +1686,7 @@ class DocumentService:
document.name = name
db.session.add(document)
- if document.data_source_info_dict:
+ if document.data_source_info_dict and "upload_file_id" in document.data_source_info_dict:
db.session.query(UploadFile).where(
UploadFile.id == document.data_source_info_dict["upload_file_id"]
).update({UploadFile.name: name})
@@ -1636,6 +1903,20 @@ class DocumentService:
return [], ""
db.session.add(dataset_process_rule)
db.session.flush()
+ else:
+ # Fallback when no process_rule provided in knowledge_config:
+ # 1) reuse dataset.latest_process_rule if present
+ # 2) otherwise create an automatic rule
+ dataset_process_rule = getattr(dataset, "latest_process_rule", None)
+ if not dataset_process_rule:
+ dataset_process_rule = DatasetProcessRule(
+ dataset_id=dataset.id,
+ mode="automatic",
+ rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
+ created_by=account.id,
+ )
+ db.session.add(dataset_process_rule)
+ db.session.flush()
lock_name = f"add_document_lock_dataset_id_{dataset.id}"
try:
with redis_client.lock(lock_name, timeout=600):
@@ -1647,65 +1928,67 @@ class DocumentService:
if not knowledge_config.data_source.info_list.file_info_list:
raise ValueError("File source info is required")
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
- for file_id in upload_file_list:
- file = (
- db.session.query(UploadFile)
- .where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
- .first()
+ files = (
+ db.session.query(UploadFile)
+ .where(
+ UploadFile.tenant_id == dataset.tenant_id,
+ UploadFile.id.in_(upload_file_list),
)
+ .all()
+ )
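+ # Compare against the de-duplicated id list so repeated ids in the request don't mask missing files.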
+ if len(files) != len(set(upload_file_list)):
+ raise FileNotExistsError("One or more files not found.")
- # raise error if file not found
- if not file:
- raise FileNotExistsError()
-
- file_name = file.name
+ file_names = [file.name for file in files]
+ db_documents = (
+ db.session.query(Document)
+ .where(
+ Document.dataset_id == dataset.id,
+ Document.tenant_id == current_user.current_tenant_id,
+ Document.data_source_type == "upload_file",
+ Document.enabled == True,
+ Document.name.in_(file_names),
+ )
+ .all()
+ )
+ documents_map = {document.name: document for document in db_documents}
+ for file in files:
data_source_info: dict[str, str | bool] = {
- "upload_file_id": file_id,
+ "upload_file_id": file.id,
}
- # check duplicate
- if knowledge_config.duplicate:
- document = (
- db.session.query(Document)
- .filter_by(
- dataset_id=dataset.id,
- tenant_id=current_user.current_tenant_id,
- data_source_type="upload_file",
- enabled=True,
- name=file_name,
- )
- .first()
+ document = documents_map.get(file.name)
+ if knowledge_config.duplicate and document:
+ document.dataset_process_rule_id = dataset_process_rule.id
+ document.updated_at = naive_utc_now()
+ document.created_from = created_from
+ document.doc_form = knowledge_config.doc_form
+ document.doc_language = knowledge_config.doc_language
+ document.data_source_info = json.dumps(data_source_info)
+ document.batch = batch
+ document.indexing_status = "waiting"
+ db.session.add(document)
+ documents.append(document)
+ duplicate_document_ids.append(document.id)
+ continue
+ else:
+ document = DocumentService.build_document(
+ dataset,
+ dataset_process_rule.id,
+ knowledge_config.data_source.info_list.data_source_type,
+ knowledge_config.doc_form,
+ knowledge_config.doc_language,
+ data_source_info,
+ created_from,
+ position,
+ account,
+ file.name,
+ batch,
)
- if document:
- document.dataset_process_rule_id = dataset_process_rule.id
- document.updated_at = naive_utc_now()
- document.created_from = created_from
- document.doc_form = knowledge_config.doc_form
- document.doc_language = knowledge_config.doc_language
- document.data_source_info = json.dumps(data_source_info)
- document.batch = batch
- document.indexing_status = "waiting"
- db.session.add(document)
- documents.append(document)
- duplicate_document_ids.append(document.id)
- continue
- document = DocumentService.build_document(
- dataset,
- dataset_process_rule.id,
- knowledge_config.data_source.info_list.data_source_type,
- knowledge_config.doc_form,
- knowledge_config.doc_language,
- data_source_info,
- created_from,
- position,
- account,
- file_name,
- batch,
- )
- db.session.add(document)
- db.session.flush()
- document_ids.append(document.id)
- documents.append(document)
- position += 1
+ db.session.add(document)
+ db.session.flush()
+ document_ids.append(document.id)
+ documents.append(document)
+ position += 1
elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore
if not notion_info_list:
@@ -1807,6 +2090,8 @@ class DocumentService:
DuplicateDocumentIndexingTaskProxy(
dataset.tenant_id, dataset.id, duplicate_document_ids
).delay()
+ # Note: Summary index generation is triggered in document_indexing_task after indexing completes
+ # to ensure segments are available. See tasks/document_indexing_task.py
except LockNotOwnedError:
pass
@@ -2111,6 +2396,11 @@ class DocumentService:
name: str,
batch: str,
):
+ # Set need_summary based on dataset's summary_index_setting
+ need_summary = False
+ if dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True:
+ need_summary = True
+
document = Document(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
@@ -2124,6 +2414,7 @@ class DocumentService:
created_by=account.id,
doc_form=document_form,
doc_language=document_language,
+ need_summary=need_summary,
)
doc_metadata = {}
if dataset.built_in_field_enabled:
@@ -2348,6 +2639,7 @@ class DocumentService:
embedding_model_provider=knowledge_config.embedding_model_provider,
collection_binding_id=dataset_collection_binding_id,
retrieval_model=retrieval_model.model_dump() if retrieval_model else None,
+ summary_index_setting=knowledge_config.summary_index_setting,
is_multimodal=knowledge_config.is_multimodal,
)
@@ -2529,6 +2821,14 @@ class DocumentService:
if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int):
raise ValueError("Process rule segmentation max_tokens is invalid")
+ # validate summary index setting
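+ # Expected shape when enabled (inferred from the checks below): {"enable": true, "model_name": "...", "model_provider_name": "..."}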
+ summary_index_setting = args["process_rule"].get("summary_index_setting")
+ if summary_index_setting and summary_index_setting.get("enable"):
+ if "model_name" not in summary_index_setting or not summary_index_setting["model_name"]:
+ raise ValueError("Summary index model name is required")
+ if "model_provider_name" not in summary_index_setting or not summary_index_setting["model_provider_name"]:
+ raise ValueError("Summary index model provider name is required")
+
@staticmethod
def batch_update_document_status(
dataset: Dataset, document_ids: list[str], action: Literal["enable", "disable", "archive", "un_archive"], user
@@ -2801,20 +3101,20 @@ class SegmentService:
db.session.add(binding)
db.session.commit()
- # save vector index
- try:
- VectorService.create_segments_vector(
- [args["keywords"]], [segment_document], dataset, document.doc_form
- )
- except Exception as e:
- logger.exception("create segment index failed")
- segment_document.enabled = False
- segment_document.disabled_at = naive_utc_now()
- segment_document.status = "error"
- segment_document.error = str(e)
- db.session.commit()
- segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first()
- return segment
+ # save vector index
+ try:
+ keywords = args.get("keywords")
+ keywords_list = [keywords] if keywords is not None else None
+ VectorService.create_segments_vector(keywords_list, [segment_document], dataset, document.doc_form)
+ except Exception as e:
+ logger.exception("create segment index failed")
+ segment_document.enabled = False
+ segment_document.disabled_at = naive_utc_now()
+ segment_document.status = "error"
+ segment_document.error = str(e)
+ db.session.commit()
+ segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first()
+ return segment
except LockNotOwnedError:
pass
@@ -2997,6 +3297,35 @@ class SegmentService:
if args.enabled or keyword_changed:
# update segment vector index
VectorService.update_segment_vector(args.keywords, segment, dataset)
+ # update summary index if summary is provided and has changed
+ if args.summary is not None:
+ # When the user manually provides a summary, allow saving it even if summary_index_setting doesn't exist
+ # summary_index_setting is only needed for LLM generation, not for manual summary vectorization
+ # Vectorization uses dataset.embedding_model, which doesn't require summary_index_setting
+ if dataset.indexing_technique == "high_quality":
+ # Query existing summary from database
+ from models.dataset import DocumentSegmentSummary
+
+ existing_summary = (
+ db.session.query(DocumentSegmentSummary)
+ .where(
+ DocumentSegmentSummary.chunk_id == segment.id,
+ DocumentSegmentSummary.dataset_id == dataset.id,
+ )
+ .first()
+ )
+
+ # Check if summary has changed
+ existing_summary_content = existing_summary.summary_content if existing_summary else None
+ if existing_summary_content != args.summary:
+ # Summary has changed, update it
+ from services.summary_index_service import SummaryIndexService
+
+ try:
+ SummaryIndexService.update_summary_for_segment(segment, dataset, args.summary)
+ except Exception:
+ logger.exception("Failed to update summary for segment %s", segment.id)
+ # Don't fail the entire update if summary update fails
else:
segment_hash = helper.generate_text_hash(content)
tokens = 0
@@ -3071,6 +3400,73 @@ class SegmentService:
elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX):
# update segment vector index
VectorService.update_segment_vector(args.keywords, segment, dataset)
+ # Handle summary index when content changed
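+ # Sketch of the branches below: no summary supplied -> regenerate only if a summary already
+ # exists and summary indexing is enabled; a different summary supplied -> store the
+ # user-provided text; the same summary supplied -> regenerate from the new content when enabled.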
+ if dataset.indexing_technique == "high_quality":
+ from models.dataset import DocumentSegmentSummary
+
+ existing_summary = (
+ db.session.query(DocumentSegmentSummary)
+ .where(
+ DocumentSegmentSummary.chunk_id == segment.id,
+ DocumentSegmentSummary.dataset_id == dataset.id,
+ )
+ .first()
+ )
+
+ if args.summary is None:
+ # User didn't provide summary, auto-regenerate if segment previously had summary
+ # Auto-regeneration only happens if summary_index_setting exists and enable is True
+ if (
+ existing_summary
+ and dataset.summary_index_setting
+ and dataset.summary_index_setting.get("enable") is True
+ ):
+ # Segment previously had summary, regenerate it with new content
+ from services.summary_index_service import SummaryIndexService
+
+ try:
+ SummaryIndexService.generate_and_vectorize_summary(
+ segment, dataset, dataset.summary_index_setting
+ )
+ logger.info("Auto-regenerated summary for segment %s after content change", segment.id)
+ except Exception:
+ logger.exception("Failed to auto-regenerate summary for segment %s", segment.id)
+ # Don't fail the entire update if summary regeneration fails
+ else:
+ # User provided summary, check if it has changed
+ # Manual summary updates are allowed even if summary_index_setting doesn't exist
+ existing_summary_content = existing_summary.summary_content if existing_summary else None
+ if existing_summary_content != args.summary:
+ # Summary has changed, use user-provided summary
+ from services.summary_index_service import SummaryIndexService
+
+ try:
+ SummaryIndexService.update_summary_for_segment(segment, dataset, args.summary)
+ logger.info("Updated summary for segment %s with user-provided content", segment.id)
+ except Exception:
+ logger.exception("Failed to update summary for segment %s", segment.id)
+ # Don't fail the entire update if summary update fails
+ else:
+ # Summary hasn't changed, regenerate based on new content
+ # Auto-regeneration only happens if summary_index_setting exists and enable is True
+ if (
+ existing_summary
+ and dataset.summary_index_setting
+ and dataset.summary_index_setting.get("enable") is True
+ ):
+ from services.summary_index_service import SummaryIndexService
+
+ try:
+ SummaryIndexService.generate_and_vectorize_summary(
+ segment, dataset, dataset.summary_index_setting
+ )
+ logger.info(
+ "Regenerated summary for segment %s after content change (summary unchanged)",
+ segment.id,
+ )
+ except Exception:
+ logger.exception("Failed to regenerate summary for segment %s", segment.id)
+ # Don't fail the entire update if summary regeneration fails
# update multimodel vector index
VectorService.update_multimodel_vector(segment, args.attachment_ids or [], dataset)
except Exception as e:
@@ -3407,7 +3803,8 @@ class SegmentService:
.order_by(ChildChunk.position.asc())
)
if keyword:
- query = query.where(ChildChunk.content.ilike(f"%{keyword}%"))
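+ # Escape LIKE wildcards (% and _) so the keyword is matched literally.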
+ escaped_keyword = helper.escape_like_pattern(keyword)
+ query = query.where(ChildChunk.content.ilike(f"%{escaped_keyword}%", escape="\\"))
return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
@classmethod
@@ -3440,9 +3837,10 @@ class SegmentService:
query = query.where(DocumentSegment.status.in_(status_list))
if keyword:
- query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
+ escaped_keyword = helper.escape_like_pattern(keyword)
+ query = query.where(DocumentSegment.content.ilike(f"%{escaped_keyword}%", escape="\\"))
- query = query.order_by(DocumentSegment.position.asc())
+ query = query.order_by(DocumentSegment.position.asc(), DocumentSegment.id.asc())
paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
return paginated_segments.items, paginated_segments.total
@@ -3457,6 +3855,39 @@ class SegmentService:
)
return result if isinstance(result, DocumentSegment) else None
+ @classmethod
+ def get_segments_by_document_and_dataset(
+ cls,
+ document_id: str,
+ dataset_id: str,
+ status: str | None = None,
+ enabled: bool | None = None,
+ ) -> Sequence[DocumentSegment]:
+ """
+ Get segments for a document in a dataset with optional filtering.
+
+ Args:
+ document_id: Document ID
+ dataset_id: Dataset ID
+ status: Optional status filter (e.g., "completed")
+ enabled: Optional enabled filter (True/False)
+
+ Returns:
+ Sequence of DocumentSegment instances
+ """
+ query = select(DocumentSegment).where(
+ DocumentSegment.document_id == document_id,
+ DocumentSegment.dataset_id == dataset_id,
+ )
+
+ if status is not None:
+ query = query.where(DocumentSegment.status == status)
+
+ if enabled is not None:
+ query = query.where(DocumentSegment.enabled == enabled)
+
+ return db.session.scalars(query).all()
+
class DatasetCollectionBindingService:
@classmethod
diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py
index bdc960aa2d..e3832475aa 100644
--- a/api/services/enterprise/base.py
+++ b/api/services/enterprise/base.py
@@ -1,9 +1,14 @@
+import logging
import os
from collections.abc import Mapping
from typing import Any
import httpx
+from core.helper.trace_id_helper import generate_traceparent_header
+
+logger = logging.getLogger(__name__)
+
class BaseRequest:
proxies: Mapping[str, str] | None = {
@@ -38,6 +43,15 @@ class BaseRequest:
headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key}
url = f"{cls.base_url}{endpoint}"
mounts = cls._build_mounts()
+
+ try:
+ # ensure traceparent even when OTEL is disabled
+ traceparent = generate_traceparent_header()
+ if traceparent:
+ headers["traceparent"] = traceparent
+ except Exception:
+ logger.debug("Failed to generate traceparent header", exc_info=True)
+
with httpx.Client(mounts=mounts) as client:
response = client.request(method, url, json=json, params=params, headers=headers)
return response.json()
diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py
index 83d0fcf296..a5133dfcb4 100644
--- a/api/services/enterprise/enterprise_service.py
+++ b/api/services/enterprise/enterprise_service.py
@@ -13,6 +13,23 @@ class WebAppSettings(BaseModel):
)
+class WorkspacePermission(BaseModel):
+ workspace_id: str = Field(
+ description="The ID of the workspace.",
+ alias="workspaceId",
+ )
+ allow_member_invite: bool = Field(
+ description="Whether to allow members to invite new members to the workspace.",
+ default=False,
+ alias="allowMemberInvite",
+ )
+ allow_owner_transfer: bool = Field(
+ description="Whether to allow owners to transfer ownership of the workspace.",
+ default=False,
+ alias="allowOwnerTransfer",
+ )
+
+
class EnterpriseService:
@classmethod
def get_info(cls):
@@ -44,6 +61,16 @@ class EnterpriseService:
except ValueError as e:
raise ValueError(f"Invalid date format: {data}") from e
+ class WorkspacePermissionService:
+ @classmethod
+ def get_permission(cls, workspace_id: str):
+ if not workspace_id:
+ raise ValueError("workspace_id must be provided.")
+ data = EnterpriseRequest.send_request("GET", f"/workspaces/{workspace_id}/permission")
+ if not data or "permission" not in data:
+ raise ValueError("No data found.")
+ return WorkspacePermission.model_validate(data["permission"])
+
class WebAppAuth:
@classmethod
def is_user_allowed_to_access_webapp(cls, user_id: str, app_id: str):
@@ -110,5 +137,5 @@ class EnterpriseService:
if not app_id:
raise ValueError("app_id must be provided.")
- body = {"appId": app_id}
- EnterpriseRequest.send_request("DELETE", "/webapp/clean", json=body)
+ params = {"appId": app_id}
+ EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)
diff --git a/api/services/enterprise/workspace_sync.py b/api/services/enterprise/workspace_sync.py
new file mode 100644
index 0000000000..acfe325397
--- /dev/null
+++ b/api/services/enterprise/workspace_sync.py
@@ -0,0 +1,58 @@
+import json
+import logging
+import uuid
+from datetime import UTC, datetime
+
+from redis import RedisError
+
+from extensions.ext_redis import redis_client
+
+logger = logging.getLogger(__name__)
+
+WORKSPACE_SYNC_QUEUE = "enterprise:workspace:sync:queue"
+WORKSPACE_SYNC_PROCESSING = "enterprise:workspace:sync:processing"
+
+
+class WorkspaceSyncService:
+ """Service to publish workspace sync tasks to Redis queue for enterprise backend consumption"""
+
+ @staticmethod
+ def queue_credential_sync(workspace_id: str, *, source: str) -> bool:
+ """
+ Queue a credential sync task for a newly created workspace.
+
+ This publishes a task to Redis that will be consumed by the enterprise backend
+ worker to sync credentials with the plugin-manager.
+
+ Args:
+ workspace_id: The workspace/tenant ID to sync credentials for
+ source: Source of the sync request (for debugging/tracking)
+
+ Returns:
+ bool: True if task was queued successfully, False otherwise
+ """
+ try:
+ task = {
+ "task_id": str(uuid.uuid4()),
+ "workspace_id": workspace_id,
+ "retry_count": 0,
+ "created_at": datetime.now(UTC).isoformat(),
+ "source": source,
+ }
+
+ # Push to Redis list (queue) - LPUSH adds to the head, worker consumes from tail with RPOP
+ redis_client.lpush(WORKSPACE_SYNC_QUEUE, json.dumps(task))
+
+ logger.info(
+ "Queued credential sync task for workspace %s, task_id: %s, source: %s",
+ workspace_id,
+ task["task_id"],
+ source,
+ )
+ return True
+
+ except (RedisError, TypeError) as e:
+ logger.error("Failed to queue credential sync for workspace %s: %s", workspace_id, str(e), exc_info=True)
+ # Don't raise - we don't want to fail workspace creation if queueing fails
+ # The scheduled task will catch it later
+ return False
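+
+
+# Usage sketch (assumption: called right after a new workspace/tenant is created):
+#     WorkspaceSyncService.queue_credential_sync(tenant.id, source="workspace_create")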
diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py
index 7959734e89..8dc5b93501 100644
--- a/api/services/entities/knowledge_entities/knowledge_entities.py
+++ b/api/services/entities/knowledge_entities/knowledge_entities.py
@@ -119,6 +119,7 @@ class KnowledgeConfig(BaseModel):
data_source: DataSource | None = None
process_rule: ProcessRule | None = None
retrieval_model: RetrievalModel | None = None
+ summary_index_setting: dict | None = None
doc_form: str = "text_model"
doc_language: str = "English"
embedding_model: str | None = None
@@ -141,6 +142,7 @@ class SegmentUpdateArgs(BaseModel):
regenerate_child_chunks: bool = False
enabled: bool | None = None
attachment_ids: list[str] | None = None
+ summary: str | None = None # Summary content for summary index
class ChildChunkUpdateArgs(BaseModel):
diff --git a/api/services/entities/knowledge_entities/rag_pipeline_entities.py b/api/services/entities/knowledge_entities/rag_pipeline_entities.py
index a97ccab914..041ae4edba 100644
--- a/api/services/entities/knowledge_entities/rag_pipeline_entities.py
+++ b/api/services/entities/knowledge_entities/rag_pipeline_entities.py
@@ -23,7 +23,7 @@ class RagPipelineDatasetCreateEntity(BaseModel):
description: str
icon_info: IconInfo
permission: str
- partial_member_list: list[str] | None = None
+ partial_member_list: list[dict[str, str]] | None = None
yaml_content: str | None = None
@@ -116,6 +116,8 @@ class KnowledgeConfiguration(BaseModel):
embedding_model: str = ""
keyword_number: int | None = 10
retrieval_model: RetrievalSetting
+ # add summary index setting
+ summary_index_setting: dict | None = None
@field_validator("embedding_model_provider", mode="before")
@classmethod
diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py
index f405546909..a29d848ac5 100644
--- a/api/services/entities/model_provider_entities.py
+++ b/api/services/entities/model_provider_entities.py
@@ -70,7 +70,6 @@ class ProviderResponse(BaseModel):
description: I18nObject | None = None
icon_small: I18nObject | None = None
icon_small_dark: I18nObject | None = None
- icon_large: I18nObject | None = None
background: str | None = None
help: ProviderHelpEntity | None = None
supported_model_types: Sequence[ModelType]
@@ -98,11 +97,6 @@ class ProviderResponse(BaseModel):
en_US=f"{url_prefix}/icon_small_dark/en_US",
zh_Hans=f"{url_prefix}/icon_small_dark/zh_Hans",
)
-
- if self.icon_large is not None:
- self.icon_large = I18nObject(
- en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
- )
return self
@@ -116,7 +110,6 @@ class ProviderWithModelsResponse(BaseModel):
label: I18nObject
icon_small: I18nObject | None = None
icon_small_dark: I18nObject | None = None
- icon_large: I18nObject | None = None
status: CustomConfigurationStatus
models: list[ProviderModelWithStatusEntity]
@@ -134,11 +127,6 @@ class ProviderWithModelsResponse(BaseModel):
self.icon_small_dark = I18nObject(
en_US=f"{url_prefix}/icon_small_dark/en_US", zh_Hans=f"{url_prefix}/icon_small_dark/zh_Hans"
)
-
- if self.icon_large is not None:
- self.icon_large = I18nObject(
- en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
- )
return self
@@ -163,11 +151,6 @@ class SimpleProviderEntityResponse(SimpleProviderEntity):
self.icon_small_dark = I18nObject(
en_US=f"{url_prefix}/icon_small_dark/en_US", zh_Hans=f"{url_prefix}/icon_small_dark/zh_Hans"
)
-
- if self.icon_large is not None:
- self.icon_large = I18nObject(
- en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
- )
return self
diff --git a/api/services/errors/app.py b/api/services/errors/app.py
index 24e4760acc..60e59e97dc 100644
--- a/api/services/errors/app.py
+++ b/api/services/errors/app.py
@@ -18,8 +18,8 @@ class WorkflowIdFormatError(Exception):
pass
-class InvokeRateLimitError(Exception):
- """Raised when rate limit is exceeded for workflow invocations."""
+class WorkflowQuotaLimitError(Exception):
+ """Raised when workflow execution quota is exceeded (for async/background workflows)."""
pass
diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py
index 40faa85b9a..65dd41af43 100644
--- a/api/services/external_knowledge_service.py
+++ b/api/services/external_knowledge_service.py
@@ -35,7 +35,10 @@ class ExternalDatasetService:
.order_by(ExternalKnowledgeApis.created_at.desc())
)
if search:
- query = query.where(ExternalKnowledgeApis.name.ilike(f"%{search}%"))
+ from libs.helper import escape_like_pattern
+
+ escaped_search = escape_like_pattern(search)
+ query = query.where(ExternalKnowledgeApis.name.ilike(f"%{escaped_search}%", escape="\\"))
external_knowledge_apis = db.paginate(
select=query, page=page, per_page=per_page, max_per_page=100, error_out=False
diff --git a/api/services/feature_service.py b/api/services/feature_service.py
index 8035adc734..d94ae49d91 100644
--- a/api/services/feature_service.py
+++ b/api/services/feature_service.py
@@ -4,6 +4,7 @@ from pydantic import BaseModel, ConfigDict, Field
from configs import dify_config
from enums.cloud_plan import CloudPlan
+from enums.hosted_provider import HostedTrialProvider
from services.billing_service import BillingService
from services.enterprise.enterprise_service import EnterpriseService
@@ -140,6 +141,7 @@ class FeatureModel(BaseModel):
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
knowledge_pipeline: KnowledgePipeline = KnowledgePipeline()
+ next_credit_reset_date: int = 0
class KnowledgeRateLimitModel(BaseModel):
@@ -169,6 +171,9 @@ class SystemFeatureModel(BaseModel):
plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel()
enable_change_email: bool = True
plugin_manager: PluginManagerModel = PluginManagerModel()
+ trial_models: list[str] = []
+ enable_trial_app: bool = False
+ enable_explore_banner: bool = False
class FeatureService:
@@ -199,7 +204,7 @@ class FeatureService:
return knowledge_rate_limit
@classmethod
- def get_system_features(cls) -> SystemFeatureModel:
+ def get_system_features(cls, is_authenticated: bool = False) -> SystemFeatureModel:
system_features = SystemFeatureModel()
cls._fulfill_system_params_from_env(system_features)
@@ -209,7 +214,7 @@ class FeatureService:
system_features.webapp_auth.enabled = True
system_features.enable_change_email = False
system_features.plugin_manager.enabled = True
- cls._fulfill_params_from_enterprise(system_features)
+ cls._fulfill_params_from_enterprise(system_features, is_authenticated)
if dify_config.MARKETPLACE_ENABLED:
system_features.enable_marketplace = True
@@ -224,6 +229,20 @@ class FeatureService:
system_features.is_allow_register = dify_config.ALLOW_REGISTER
system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE
system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != ""
+ system_features.trial_models = cls._fulfill_trial_models_from_env()
+ system_features.enable_trial_app = dify_config.ENABLE_TRIAL_APP
+ system_features.enable_explore_banner = dify_config.ENABLE_EXPLORE_BANNER
+
+ @classmethod
+ def _fulfill_trial_models_from_env(cls) -> list[str]:
+ return [
+ provider.value
+ for provider in HostedTrialProvider
+ if (
+ getattr(dify_config, f"HOSTED_{provider.config_key}_PAID_ENABLED", False)
+ and getattr(dify_config, f"HOSTED_{provider.config_key}_TRIAL_ENABLED", False)
+ )
+ ]
@classmethod
def _fulfill_params_from_env(cls, features: FeatureModel):
@@ -301,8 +320,11 @@ class FeatureService:
if "knowledge_pipeline_publish_enabled" in billing_info:
features.knowledge_pipeline.publish_enabled = billing_info["knowledge_pipeline_publish_enabled"]
+ if "next_credit_reset_date" in billing_info:
+ features.next_credit_reset_date = billing_info["next_credit_reset_date"]
+
@classmethod
- def _fulfill_params_from_enterprise(cls, features: SystemFeatureModel):
+ def _fulfill_params_from_enterprise(cls, features: SystemFeatureModel, is_authenticated: bool = False):
enterprise_info = EnterpriseService.get_info()
if "SSOEnforcedForSignin" in enterprise_info:
@@ -339,19 +361,14 @@ class FeatureService:
)
features.webapp_auth.sso_config.protocol = enterprise_info.get("SSOEnforcedForWebProtocol", "")
- if "License" in enterprise_info:
- license_info = enterprise_info["License"]
+ if is_authenticated and (license_info := enterprise_info.get("License")):
+ features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
+ features.license.expired_at = license_info.get("expiredAt", "")
- if "status" in license_info:
- features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
-
- if "expiredAt" in license_info:
- features.license.expired_at = license_info["expiredAt"]
-
- if "workspaces" in license_info:
- features.license.workspaces.enabled = license_info["workspaces"]["enabled"]
- features.license.workspaces.limit = license_info["workspaces"]["limit"]
- features.license.workspaces.size = license_info["workspaces"]["used"]
+ if workspaces_info := license_info.get("workspaces"):
+ features.license.workspaces.enabled = workspaces_info.get("enabled", False)
+ features.license.workspaces.limit = workspaces_info.get("limit", 0)
+ features.license.workspaces.size = workspaces_info.get("used", 0)
if "PluginInstallationPermission" in enterprise_info:
plugin_installation_info = enterprise_info["PluginInstallationPermission"]
diff --git a/api/services/file_service.py b/api/services/file_service.py
index 0911cf38c4..a0a99f3f82 100644
--- a/api/services/file_service.py
+++ b/api/services/file_service.py
@@ -2,7 +2,11 @@ import base64
import hashlib
import os
import uuid
+from collections.abc import Iterator, Sequence
+from contextlib import contextmanager, suppress
+from tempfile import NamedTemporaryFile
from typing import Literal, Union
+from zipfile import ZIP_DEFLATED, ZipFile
from sqlalchemy import Engine, select
from sqlalchemy.orm import Session, sessionmaker
@@ -17,6 +21,7 @@ from constants import (
)
from core.file import helpers as file_helpers
from core.rag.extractor.extract_processor import ExtractProcessor
+from extensions.ext_database import db
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from libs.helper import extract_tenant_id
@@ -167,6 +172,9 @@ class FileService:
return upload_file
def get_file_preview(self, file_id: str):
+ """
+ Return a short text preview extracted from a document file.
+ """
with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
@@ -253,3 +261,101 @@ class FileService:
return
storage.delete(upload_file.key)
session.delete(upload_file)
+
+ @staticmethod
+ def get_upload_files_by_ids(tenant_id: str, upload_file_ids: Sequence[str]) -> dict[str, UploadFile]:
+ """
+ Fetch `UploadFile` rows for a tenant in a single batch query.
+
+ This is a generic `UploadFile` lookup helper (not dataset/document specific), so it lives in `FileService`.
+ """
+ if not upload_file_ids:
+ return {}
+
+ # Normalize and deduplicate ids before using them in the IN clause.
+ upload_file_id_list: list[str] = [str(upload_file_id) for upload_file_id in upload_file_ids]
+ unique_upload_file_ids: list[str] = list(set(upload_file_id_list))
+
+ # Fetch upload files in one query for efficient batch access.
+ upload_files: Sequence[UploadFile] = db.session.scalars(
+ select(UploadFile).where(
+ UploadFile.tenant_id == tenant_id,
+ UploadFile.id.in_(unique_upload_file_ids),
+ )
+ ).all()
+ return {str(upload_file.id): upload_file for upload_file in upload_files}
+
+ @staticmethod
+ def _sanitize_zip_entry_name(name: str) -> str:
+ """
+ Sanitize a ZIP entry name to avoid path traversal and weird separators.
+
+ We keep this conservative: the upload flow already rejects `/` and `\\`, but older rows (or imported data)
+ could still contain unsafe names.
+ """
+ # Drop any directory components and prevent empty names.
+ base = os.path.basename(name).strip() or "file"
+
+ # ZIP uses forward slashes as separators; remove any residual separator characters.
+ return base.replace("/", "_").replace("\\", "_")
+
+ @staticmethod
+ def _dedupe_zip_entry_name(original_name: str, used_names: set[str]) -> str:
+ """
+ Return a unique ZIP entry name, inserting suffixes before the extension.
+ """
+ # Keep the original name when it's not already used.
+ if original_name not in used_names:
+ return original_name
+
+ # Insert suffixes before the extension (e.g., "doc.txt" -> "doc (1).txt").
+ stem, extension = os.path.splitext(original_name)
+ suffix = 1
+ while True:
+ candidate = f"{stem} ({suffix}){extension}"
+ if candidate not in used_names:
+ return candidate
+ suffix += 1
+
+ @staticmethod
+ @contextmanager
+ def build_upload_files_zip_tempfile(
+ *,
+ upload_files: Sequence[UploadFile],
+ ) -> Iterator[str]:
+ """
+ Build a ZIP from `UploadFile`s and yield a tempfile path.
+
+ We yield a path (rather than an open file handle) to avoid "read of closed file" issues when Flask/Werkzeug
+ streams responses. The caller is expected to keep this context open until the response is fully sent, then
+ close it (e.g., via `response.call_on_close(...)`) to delete the tempfile.
+ """
+ used_names: set[str] = set()
+
+ # Build a ZIP in a temp file and keep it on disk until the caller finishes streaming it.
+ tmp_path: str | None = None
+ try:
+ with NamedTemporaryFile(mode="w+b", suffix=".zip", delete=False) as tmp:
+ tmp_path = tmp.name
+ with ZipFile(tmp, mode="w", compression=ZIP_DEFLATED) as zf:
+ for upload_file in upload_files:
+ # Ensure the entry name is safe and unique.
+ safe_name = FileService._sanitize_zip_entry_name(upload_file.name)
+ arcname = FileService._dedupe_zip_entry_name(safe_name, used_names)
+ used_names.add(arcname)
+
+ # Stream file bytes from storage into the ZIP entry.
+ with zf.open(arcname, "w") as entry:
+ for chunk in storage.load(upload_file.key, stream=True):
+ entry.write(chunk)
+
+ # Flush so `send_file(path, ...)` can re-open it safely on all platforms.
+ tmp.flush()
+
+ assert tmp_path is not None
+ yield tmp_path
+ finally:
+ # Remove the temp file when the context is closed (typically after the response finishes streaming).
+ if tmp_path is not None:
+ with suppress(FileNotFoundError):
+ os.remove(tmp_path)
diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py
index 8e8e78f83f..8cbf3a25c3 100644
--- a/api/services/hit_testing_service.py
+++ b/api/services/hit_testing_service.py
@@ -178,8 +178,8 @@ class HitTestingService:
@classmethod
def hit_testing_args_check(cls, args):
- query = args["query"]
- attachment_ids = args["attachment_ids"]
+ query = args.get("query")
+ attachment_ids = args.get("attachment_ids")
if not attachment_ids and not query:
raise ValueError("Query or attachment_ids is required")
diff --git a/api/services/message_service.py b/api/services/message_service.py
index e1a256e64d..a53ca8b22d 100644
--- a/api/services/message_service.py
+++ b/api/services/message_service.py
@@ -261,10 +261,9 @@ class MessageService:
else:
conversation_override_model_configs = json.loads(conversation.override_model_configs)
app_model_config = AppModelConfig(
- id=conversation.app_model_config_id,
app_id=app_model.id,
)
-
+ app_model_config.id = conversation.app_model_config_id
app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs)
if not app_model_config:
raise ValueError("did not find app model config")
diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py
index eea382febe..edd1004b82 100644
--- a/api/services/model_provider_service.py
+++ b/api/services/model_provider_service.py
@@ -99,7 +99,6 @@ class ModelProviderService:
description=provider_configuration.provider.description,
icon_small=provider_configuration.provider.icon_small,
icon_small_dark=provider_configuration.provider.icon_small_dark,
- icon_large=provider_configuration.provider.icon_large,
background=provider_configuration.provider.background,
help=provider_configuration.provider.help,
supported_model_types=provider_configuration.provider.supported_model_types,
@@ -423,7 +422,6 @@ class ModelProviderService:
label=first_model.provider.label,
icon_small=first_model.provider.icon_small,
icon_small_dark=first_model.provider.icon_small_dark,
- icon_large=first_model.provider.icon_large,
status=CustomConfigurationStatus.ACTIVE,
models=[
ProviderModelWithStatusEntity(
@@ -488,7 +486,6 @@ class ModelProviderService:
provider=result.provider.provider,
label=result.provider.label,
icon_small=result.provider.icon_small,
- icon_large=result.provider.icon_large,
supported_model_types=result.provider.supported_model_types,
),
)
@@ -522,7 +519,7 @@ class ModelProviderService:
:param tenant_id: workspace id
:param provider: provider name
- :param icon_type: icon type (icon_small or icon_large)
+ :param icon_type: icon type (icon_small or icon_small_dark)
:param lang: language (zh_Hans or en_US)
:return:
"""
diff --git a/api/services/plugin/plugin_parameter_service.py b/api/services/plugin/plugin_parameter_service.py
index c517d9f966..40565c56ed 100644
--- a/api/services/plugin/plugin_parameter_service.py
+++ b/api/services/plugin/plugin_parameter_service.py
@@ -105,3 +105,49 @@ class PluginParameterService:
)
.options
)
+
+ @staticmethod
+ def get_dynamic_select_options_with_credentials(
+ tenant_id: str,
+ user_id: str,
+ plugin_id: str,
+ provider: str,
+ action: str,
+ parameter: str,
+ credential_id: str,
+ credentials: Mapping[str, Any],
+ ) -> Sequence[PluginParameterOption]:
+ """
+ Get dynamic select options using provided credentials directly.
+ Used for edit mode when credentials have been modified but not yet saved.
+
+ Security: credential_id is validated against tenant_id to ensure
+ users can only access their own credentials.
+ """
+ from constants import HIDDEN_VALUE
+
+ # Get original subscription to replace hidden values (with tenant_id check for security)
+ original_subscription = TriggerProviderService.get_subscription_by_id(tenant_id, credential_id)
+ if not original_subscription:
+ raise ValueError(f"Subscription {credential_id} not found")
+
+ # Replace [__HIDDEN__] with original values
+ resolved_credentials: dict[str, Any] = {
+ key: (original_subscription.credentials.get(key) if value == HIDDEN_VALUE else value)
+ for key, value in credentials.items()
+ }
+
+ return (
+ DynamicSelectClient()
+ .fetch_dynamic_select_options(
+ tenant_id,
+ user_id,
+ plugin_id,
+ provider,
+ action,
+ resolved_credentials,
+ original_subscription.credential_type or CredentialType.UNAUTHORIZED.value,
+ parameter,
+ )
+ .options
+ )
diff --git a/api/services/plugin/plugin_service.py b/api/services/plugin/plugin_service.py
index b8303eb724..411c335c17 100644
--- a/api/services/plugin/plugin_service.py
+++ b/api/services/plugin/plugin_service.py
@@ -3,6 +3,7 @@ from collections.abc import Mapping, Sequence
from mimetypes import guess_type
from pydantic import BaseModel
+from sqlalchemy import select
from yarl import URL
from configs import dify_config
@@ -25,7 +26,9 @@ from core.plugin.entities.plugin_daemon import (
from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.debugging import PluginDebuggingClient
from core.plugin.impl.plugin import PluginInstaller
+from extensions.ext_database import db
from extensions.ext_redis import redis_client
+from models.provider import ProviderCredential
from models.provider_ids import GenericProviderID
from services.errors.plugin import PluginInstallationForbiddenError
from services.feature_service import FeatureService, PluginInstallationScope
@@ -506,6 +509,33 @@ class PluginService:
@staticmethod
def uninstall(tenant_id: str, plugin_installation_id: str) -> bool:
manager = PluginInstaller()
+
+ # Get plugin info before uninstalling to delete associated credentials
+ try:
+ plugins = manager.list_plugins(tenant_id)
+ plugin = next((p for p in plugins if p.installation_id == plugin_installation_id), None)
+
+ if plugin:
+ plugin_id = plugin.plugin_id
+ logger.info("Deleting credentials for plugin: %s", plugin_id)
+
+ # Delete provider credentials that match this plugin
+ credentials = db.session.scalars(
+ select(ProviderCredential).where(
+ ProviderCredential.tenant_id == tenant_id,
+ ProviderCredential.provider_name.like(f"{plugin_id}/%"),
+ )
+ ).all()
+
+ for cred in credentials:
+ db.session.delete(cred)
+
+ db.session.commit()
+ logger.info("Deleted %d credentials for plugin: %s", len(credentials), plugin_id)
+ except Exception as e:
+ logger.warning("Failed to delete credentials: %s", e)
+ # Continue with uninstall even if credential deletion fails
+
return manager.uninstall(tenant_id, plugin_installation_id)
@staticmethod
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index 097d16e2a7..ccc6abcc06 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -36,7 +36,7 @@ from core.rag.entities.event import (
)
from core.repositories.factory import DifyCoreRepositoryFactory
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
-from core.variables.variables import Variable
+from core.variables.variables import VariableBase
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
@@ -270,8 +270,8 @@ class RagPipelineService:
graph: dict,
unique_hash: str | None,
account: Account,
- environment_variables: Sequence[Variable],
- conversation_variables: Sequence[Variable],
+ environment_variables: Sequence[VariableBase],
+ conversation_variables: Sequence[VariableBase],
rag_pipeline_variables: list,
) -> Workflow:
"""
@@ -436,7 +436,7 @@ class RagPipelineService:
user_inputs=user_inputs,
user_id=account.id,
variable_pool=VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs=user_inputs,
environment_variables=[],
conversation_variables=[],
@@ -874,7 +874,7 @@ class RagPipelineService:
variable_pool = node_instance.graph_runtime_state.variable_pool
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
if invoke_from:
- if invoke_from.value == InvokeFrom.PUBLISHED:
+ if invoke_from.value == InvokeFrom.PUBLISHED_PIPELINE:
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
if document_id:
document = db.session.query(Document).where(Document.id == document_id.value).first()
@@ -1248,14 +1248,13 @@ class RagPipelineService:
session.commit()
return workflow_node_execution_db_model
- def get_recommended_plugins(self) -> dict:
+ def get_recommended_plugins(self, type: str) -> dict:
# Query active recommended plugins
- pipeline_recommended_plugins = (
- db.session.query(PipelineRecommendedPlugin)
- .where(PipelineRecommendedPlugin.active == True)
- .order_by(PipelineRecommendedPlugin.position.asc())
- .all()
- )
+ query = db.session.query(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True)
+ if type and type != "all":
+ query = query.where(PipelineRecommendedPlugin.type == type)
+
+ pipeline_recommended_plugins = query.order_by(PipelineRecommendedPlugin.position.asc()).all()
if not pipeline_recommended_plugins:
return {
@@ -1319,7 +1318,7 @@ class RagPipelineService:
"datasource_info_list": [json.loads(document_pipeline_execution_log.datasource_info)],
"original_document_id": document.id,
},
- invoke_from=InvokeFrom.PUBLISHED,
+ invoke_from=InvokeFrom.PUBLISHED_PIPELINE,
streaming=False,
call_depth=0,
workflow_thread_pool_id=None,
diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py
index 06f294863d..be1ce834f6 100644
--- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py
+++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py
@@ -343,6 +343,9 @@ class RagPipelineDslService:
dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
elif knowledge_configuration.indexing_technique == "economy":
dataset.keyword_number = knowledge_configuration.keyword_number
+ # Update summary_index_setting if provided
+ if knowledge_configuration.summary_index_setting is not None:
+ dataset.summary_index_setting = knowledge_configuration.summary_index_setting
dataset.pipeline_id = pipeline.id
self._session.add(dataset)
self._session.commit()
@@ -477,6 +480,9 @@ class RagPipelineDslService:
dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
elif knowledge_configuration.indexing_technique == "economy":
dataset.keyword_number = knowledge_configuration.keyword_number
+ # Update summary_index_setting if provided
+ if knowledge_configuration.summary_index_setting is not None:
+ dataset.summary_index_setting = knowledge_configuration.summary_index_setting
dataset.pipeline_id = pipeline.id
self._session.add(dataset)
self._session.commit()
@@ -870,15 +876,16 @@ class RagPipelineDslService:
return dependencies
@classmethod
- def get_leaked_dependencies(cls, tenant_id: str, dsl_dependencies: list[dict]) -> list[PluginDependency]:
+ def get_leaked_dependencies(
+ cls, tenant_id: str, dsl_dependencies: list[PluginDependency]
+ ) -> list[PluginDependency]:
"""
Returns the leaked dependencies in current workspace
"""
- dependencies = [PluginDependency.model_validate(dep) for dep in dsl_dependencies]
- if not dependencies:
+ if not dsl_dependencies:
return []
- return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies)
+ return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dsl_dependencies)
def _generate_aes_key(self, tenant_id: str) -> bytes:
"""Generate AES key based on tenant_id"""
diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py
index 84f97907c0..8ea365e907 100644
--- a/api/services/rag_pipeline/rag_pipeline_transform_service.py
+++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py
@@ -44,7 +44,7 @@ class RagPipelineTransformService:
doc_form = dataset.doc_form
if not doc_form:
return self._transform_to_empty_pipeline(dataset)
- retrieval_model = dataset.retrieval_model
+ retrieval_model = RetrievalSetting.model_validate(dataset.retrieval_model) if dataset.retrieval_model else None
pipeline_yaml = self._get_transform_yaml(doc_form, datasource_type, indexing_technique)
# deal dependencies
self._deal_dependencies(pipeline_yaml, dataset.tenant_id)
@@ -154,7 +154,12 @@ class RagPipelineTransformService:
return node
def _deal_knowledge_index(
- self, dataset: Dataset, doc_form: str, indexing_technique: str | None, retrieval_model: dict, node: dict
+ self,
+ dataset: Dataset,
+ doc_form: str,
+ indexing_technique: str | None,
+ retrieval_model: RetrievalSetting | None,
+ node: dict,
):
knowledge_configuration_dict = node.get("data", {})
knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration_dict)
@@ -163,10 +168,9 @@ class RagPipelineTransformService:
knowledge_configuration.embedding_model = dataset.embedding_model
knowledge_configuration.embedding_model_provider = dataset.embedding_model_provider
if retrieval_model:
- retrieval_setting = RetrievalSetting.model_validate(retrieval_model)
if indexing_technique == "economy":
- retrieval_setting.search_method = RetrievalMethod.KEYWORD_SEARCH
- knowledge_configuration.retrieval_model = retrieval_setting
+ retrieval_model.search_method = RetrievalMethod.KEYWORD_SEARCH
+ knowledge_configuration.retrieval_model = retrieval_model
else:
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py
index 544383a106..6b211a5632 100644
--- a/api/services/recommended_app_service.py
+++ b/api/services/recommended_app_service.py
@@ -1,4 +1,7 @@
from configs import dify_config
+from extensions.ext_database import db
+from models.model import AccountTrialAppRecord, TrialApp
+from services.feature_service import FeatureService
from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory
@@ -20,6 +23,15 @@ class RecommendedAppService:
)
)
+ if FeatureService.get_system_features().enable_trial_app:
+ apps = result["recommended_apps"]
+ for app in apps:
+ app_id = app["app_id"]
+ trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first()
+ if trial_app_model:
+ app["can_trial"] = True
+ else:
+ app["can_trial"] = False
return result
@classmethod
@@ -32,4 +44,30 @@ class RecommendedAppService:
mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)()
result: dict = retrieval_instance.get_recommend_app_detail(app_id)
+ if FeatureService.get_system_features().enable_trial_app:
+ app_id = result["id"]
+ trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first()
+ if trial_app_model:
+ result["can_trial"] = True
+ else:
+ result["can_trial"] = False
return result
+
+ @classmethod
+ def add_trial_app_record(cls, app_id: str, account_id: str):
+ """
+ Add trial app record.
+ :param app_id: app id
+ :param account_id: account id
+ :return:
+ """
+ account_trial_app_record = (
+ db.session.query(AccountTrialAppRecord)
+ .where(AccountTrialAppRecord.app_id == app_id, AccountTrialAppRecord.account_id == account_id)
+ .first()
+ )
+ if account_trial_app_record:
+ account_trial_app_record.count += 1
+ db.session.commit()
+ else:
+ db.session.add(AccountTrialAppRecord(app_id=app_id, count=1, account_id=account_id))
+ db.session.commit()
diff --git a/api/services/retention/__init__.py b/api/services/retention/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/services/retention/conversation/messages_clean_policy.py b/api/services/retention/conversation/messages_clean_policy.py
new file mode 100644
index 0000000000..6e647b983b
--- /dev/null
+++ b/api/services/retention/conversation/messages_clean_policy.py
@@ -0,0 +1,216 @@
+import datetime
+import logging
+from abc import ABC, abstractmethod
+from collections.abc import Callable, Sequence
+from dataclasses import dataclass
+
+from configs import dify_config
+from enums.cloud_plan import CloudPlan
+from services.billing_service import BillingService, SubscriptionPlan
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class SimpleMessage:
+ id: str
+ app_id: str
+ created_at: datetime.datetime
+
+
+class MessagesCleanPolicy(ABC):
+ """
+ Abstract base class for message cleanup policies.
+
+ A policy determines which messages from a batch should be deleted.
+ """
+
+ @abstractmethod
+ def filter_message_ids(
+ self,
+ messages: Sequence[SimpleMessage],
+ app_to_tenant: dict[str, str],
+ ) -> Sequence[str]:
+ """
+ Filter messages and return IDs of messages that should be deleted.
+
+ Args:
+ messages: Batch of messages to evaluate
+ app_to_tenant: Mapping from app_id to tenant_id
+
+ Returns:
+ List of message IDs that should be deleted
+ """
+ ...
+
+
+class BillingDisabledPolicy(MessagesCleanPolicy):
+ """
+ Policy for the community or enterprise edition (billing disabled).
+
+ No special filtering logic; all message IDs are returned.
+ """
+
+ def filter_message_ids(
+ self,
+ messages: Sequence[SimpleMessage],
+ app_to_tenant: dict[str, str],
+ ) -> Sequence[str]:
+ return [msg.id for msg in messages]
+
+
+class BillingSandboxPolicy(MessagesCleanPolicy):
+ """
+ Policy for sandbox plan tenants in cloud edition (billing enabled).
+
+ Filters messages based on sandbox plan expiration rules:
+ - Skip tenants in the whitelist
+ - Only delete messages from sandbox plan tenants
+ - Respect grace period after subscription expiration
+ - Safe default: if tenant mapping or plan is missing, do NOT delete
+ """
+
+ def __init__(
+ self,
+ plan_provider: Callable[[Sequence[str]], dict[str, SubscriptionPlan]],
+ graceful_period_days: int = 21,
+ tenant_whitelist: Sequence[str] | None = None,
+ current_timestamp: int | None = None,
+ ) -> None:
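+ # plan_provider maps tenant ids to subscription plans in bulk (e.g. BillingService.get_plan_bulk_with_cache,
+ # as wired up by create_message_clean_policy below); current_timestamp is injectable for tests and
+ # defaults to the current UTC time when None.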
+ self._graceful_period_days = graceful_period_days
+ self._tenant_whitelist: Sequence[str] = tenant_whitelist or []
+ self._plan_provider = plan_provider
+ self._current_timestamp = current_timestamp
+
+ def filter_message_ids(
+ self,
+ messages: Sequence[SimpleMessage],
+ app_to_tenant: dict[str, str],
+ ) -> Sequence[str]:
+ """
+ Filter messages based on sandbox plan expiration rules.
+
+ Args:
+ messages: Batch of messages to evaluate
+ app_to_tenant: Mapping from app_id to tenant_id
+
+ Returns:
+ List of message IDs that should be deleted
+ """
+ if not messages or not app_to_tenant:
+ return []
+
+ # Get unique tenant_ids and fetch subscription plans
+ tenant_ids = list(set(app_to_tenant.values()))
+ tenant_plans = self._plan_provider(tenant_ids)
+
+ if not tenant_plans:
+ return []
+
+ # Apply sandbox deletion rules
+ return self._filter_expired_sandbox_messages(
+ messages=messages,
+ app_to_tenant=app_to_tenant,
+ tenant_plans=tenant_plans,
+ )
+
+ def _filter_expired_sandbox_messages(
+ self,
+ messages: Sequence[SimpleMessage],
+ app_to_tenant: dict[str, str],
+ tenant_plans: dict[str, SubscriptionPlan],
+ ) -> list[str]:
+ """
+ Filter messages that should be deleted based on sandbox plan expiration.
+
+ A message should be deleted if:
+ 1. It belongs to a sandbox tenant AND
+ 2. Either:
+ a) The tenant has no previous subscription (expiration_date == -1), OR
+ b) The subscription expired more than graceful_period_days ago
+
+ Args:
+ messages: List of message objects with id and app_id attributes
+ app_to_tenant: Mapping from app_id to tenant_id
+ tenant_plans: Mapping from tenant_id to subscription plan info
+
+ Returns:
+ List of message IDs that should be deleted
+ """
+ current_timestamp = self._current_timestamp
+ if current_timestamp is None:
+ current_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
+
+ sandbox_message_ids: list[str] = []
+ graceful_period_seconds = self._graceful_period_days * 24 * 60 * 60
+
+ for msg in messages:
+ # Get tenant_id for this message's app
+ tenant_id = app_to_tenant.get(msg.app_id)
+ if not tenant_id:
+ continue
+
+ # Skip tenant messages in whitelist
+ if tenant_id in self._tenant_whitelist:
+ continue
+
+ # Get subscription plan for this tenant
+ tenant_plan = tenant_plans.get(tenant_id)
+ if not tenant_plan:
+ continue
+
+ plan = str(tenant_plan["plan"])
+ expiration_date = int(tenant_plan["expiration_date"])
+
+ # Only process sandbox plans
+ if plan != CloudPlan.SANDBOX:
+ continue
+
+ # Case 1: No previous subscription (-1 means never had a paid subscription)
+ if expiration_date == -1:
+ sandbox_message_ids.append(msg.id)
+ continue
+
+ # Case 2: Subscription expired beyond grace period
+ if current_timestamp - expiration_date > graceful_period_seconds:
+ sandbox_message_ids.append(msg.id)
+
+ return sandbox_message_ids
+
+
+def create_message_clean_policy(
+ graceful_period_days: int = 21,
+ current_timestamp: int | None = None,
+) -> MessagesCleanPolicy:
+ """
+ Factory function to create the appropriate message clean policy.
+
+ Determines which policy to use based on BILLING_ENABLED configuration:
+ - If BILLING_ENABLED is True: returns BillingSandboxPolicy
+ - If BILLING_ENABLED is False: returns BillingDisabledPolicy
+
+ Args:
+ graceful_period_days: Grace period in days after subscription expiration (default: 21)
+ current_timestamp: Current Unix timestamp for testing (default: None, uses current time)
+ """
+ if not dify_config.BILLING_ENABLED:
+ logger.info("create_message_clean_policy: billing disabled, using BillingDisabledPolicy")
+ return BillingDisabledPolicy()
+
+ # Billing enabled - fetch whitelist from BillingService
+ tenant_whitelist = BillingService.get_expired_subscription_cleanup_whitelist()
+ plan_provider = BillingService.get_plan_bulk_with_cache
+
+ logger.info(
+ "create_message_clean_policy: billing enabled, using BillingSandboxPolicy "
+ "(graceful_period_days=%s, whitelist=%s)",
+ graceful_period_days,
+ tenant_whitelist,
+ )
+
+ return BillingSandboxPolicy(
+ plan_provider=plan_provider,
+ graceful_period_days=graceful_period_days,
+ tenant_whitelist=tenant_whitelist,
+ current_timestamp=current_timestamp,
+ )
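+
+
+# Usage sketch (assumption: driven by a scheduled cleanup job; MessagesCleanService lives in
+# services/retention/conversation/messages_clean_service.py):
+#     policy = create_message_clean_policy(graceful_period_days=21)
+#     MessagesCleanService.from_days(policy=policy, days=30).run()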
diff --git a/api/services/retention/conversation/messages_clean_service.py b/api/services/retention/conversation/messages_clean_service.py
new file mode 100644
index 0000000000..3ca5d82860
--- /dev/null
+++ b/api/services/retention/conversation/messages_clean_service.py
@@ -0,0 +1,334 @@
+import datetime
+import logging
+import random
+from collections.abc import Sequence
+from typing import cast
+
+from sqlalchemy import delete, select
+from sqlalchemy.engine import CursorResult
+from sqlalchemy.orm import Session
+
+from extensions.ext_database import db
+from models.model import (
+ App,
+ AppAnnotationHitHistory,
+ DatasetRetrieverResource,
+ Message,
+ MessageAgentThought,
+ MessageAnnotation,
+ MessageChain,
+ MessageFeedback,
+ MessageFile,
+)
+from models.web import SavedMessage
+from services.retention.conversation.messages_clean_policy import (
+ MessagesCleanPolicy,
+ SimpleMessage,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class MessagesCleanService:
+ """
+ Service for cleaning expired messages based on retention policies.
+
+ When billing is disabled (community/non-cloud editions), all messages in the time range are deleted.
+ When billing is enabled, only sandbox-plan tenant messages are deleted (with whitelist and grace-period support).
+ """
+
+ def __init__(
+ self,
+ policy: MessagesCleanPolicy,
+ end_before: datetime.datetime,
+ start_from: datetime.datetime | None = None,
+ batch_size: int = 1000,
+ dry_run: bool = False,
+ ) -> None:
+ """
+ Initialize the service with cleanup parameters.
+
+ Args:
+ policy: The policy that determines which messages to delete
+ end_before: End time (exclusive) of the range
+ start_from: Optional start time (inclusive) of the range
+ batch_size: Number of messages to process per batch
+ dry_run: Whether to perform a dry run (no actual deletion)
+ """
+ self._policy = policy
+ self._end_before = end_before
+ self._start_from = start_from
+ self._batch_size = batch_size
+ self._dry_run = dry_run
+
+ @classmethod
+ def from_time_range(
+ cls,
+ policy: MessagesCleanPolicy,
+ start_from: datetime.datetime,
+ end_before: datetime.datetime,
+ batch_size: int = 1000,
+ dry_run: bool = False,
+ ) -> "MessagesCleanService":
+ """
+ Create a service instance for cleaning messages within a specific time range.
+
+ Time range is [start_from, end_before).
+
+ Args:
+ policy: The policy that determines which messages to delete
+ start_from: Start time (inclusive) of the range
+ end_before: End time (exclusive) of the range
+ batch_size: Number of messages to process per batch
+ dry_run: Whether to perform a dry run (no actual deletion)
+
+ Returns:
+ MessagesCleanService instance
+
+ Raises:
+            ValueError: If start_from >= end_before or other parameters are invalid
+ """
+ if start_from >= end_before:
+ raise ValueError(f"start_from ({start_from}) must be less than end_before ({end_before})")
+
+ if batch_size <= 0:
+ raise ValueError(f"batch_size ({batch_size}) must be greater than 0")
+
+ logger.info(
+ "clean_messages: start_from=%s, end_before=%s, batch_size=%s, policy=%s",
+ start_from,
+ end_before,
+ batch_size,
+ policy.__class__.__name__,
+ )
+
+ return cls(
+ policy=policy,
+ end_before=end_before,
+ start_from=start_from,
+ batch_size=batch_size,
+ dry_run=dry_run,
+ )
+
+ @classmethod
+ def from_days(
+ cls,
+ policy: MessagesCleanPolicy,
+ days: int = 30,
+ batch_size: int = 1000,
+ dry_run: bool = False,
+ ) -> "MessagesCleanService":
+ """
+ Create a service instance for cleaning messages older than specified days.
+
+ Args:
+ policy: The policy that determines which messages to delete
+ days: Number of days to look back from now
+ batch_size: Number of messages to process per batch
+ dry_run: Whether to perform a dry run (no actual deletion)
+
+ Returns:
+ MessagesCleanService instance
+
+ Raises:
+            ValueError: If parameters are invalid
+ """
+ if days < 0:
+ raise ValueError(f"days ({days}) must be greater than or equal to 0")
+
+ if batch_size <= 0:
+ raise ValueError(f"batch_size ({batch_size}) must be greater than 0")
+
+ end_before = datetime.datetime.now() - datetime.timedelta(days=days)
+
+ logger.info(
+ "clean_messages: days=%s, end_before=%s, batch_size=%s, policy=%s",
+ days,
+ end_before,
+ batch_size,
+ policy.__class__.__name__,
+ )
+
+ return cls(policy=policy, end_before=end_before, start_from=None, batch_size=batch_size, dry_run=dry_run)
+
+ def run(self) -> dict[str, int]:
+ """
+ Execute the message cleanup operation.
+
+ Returns:
+            Dict with statistics: batches, total_messages, filtered_messages, total_deleted
+ """
+ return self._clean_messages_by_time_range()
+
+ def _clean_messages_by_time_range(self) -> dict[str, int]:
+ """
+ Clean messages within a time range using cursor-based pagination.
+
+ Time range is [start_from, end_before)
+
+ Steps:
+ 1. Iterate messages using cursor pagination (by created_at, id)
+ 2. Query app_id -> tenant_id mapping
+ 3. Delegate to policy to determine which messages to delete
+ 4. Batch delete messages and their relations
+
+ Returns:
+            Dict with statistics: batches, total_messages, filtered_messages, total_deleted
+ """
+ stats = {
+ "batches": 0,
+ "total_messages": 0,
+ "filtered_messages": 0,
+ "total_deleted": 0,
+ }
+
+ # Cursor-based pagination using (created_at, id) to avoid infinite loops
+ # and ensure proper ordering with time-based filtering
+ _cursor: tuple[datetime.datetime, str] | None = None
+
+ logger.info(
+ "clean_messages: start cleaning messages (dry_run=%s), start_from=%s, end_before=%s",
+ self._dry_run,
+ self._start_from,
+ self._end_before,
+ )
+
+ while True:
+ stats["batches"] += 1
+
+ # Step 1: Fetch a batch of messages using cursor
+ with Session(db.engine, expire_on_commit=False) as session:
+ msg_stmt = (
+ select(Message.id, Message.app_id, Message.created_at)
+ .where(Message.created_at < self._end_before)
+ .order_by(Message.created_at, Message.id)
+ .limit(self._batch_size)
+ )
+
+ if self._start_from:
+ msg_stmt = msg_stmt.where(Message.created_at >= self._start_from)
+
+ # Apply cursor condition: (created_at, id) > (last_created_at, last_message_id)
+ # This translates to:
+ # created_at > last_created_at OR (created_at = last_created_at AND id > last_message_id)
+ if _cursor:
+ # Continuing from previous batch
+ msg_stmt = msg_stmt.where(
+ (Message.created_at > _cursor[0])
+ | ((Message.created_at == _cursor[0]) & (Message.id > _cursor[1]))
+ )
+
+ raw_messages = list(session.execute(msg_stmt).all())
+ messages = [
+ SimpleMessage(id=msg_id, app_id=app_id, created_at=msg_created_at)
+ for msg_id, app_id, msg_created_at in raw_messages
+ ]
+
+ # Track total messages fetched across all batches
+ stats["total_messages"] += len(messages)
+
+ if not messages:
+ logger.info("clean_messages (batch %s): no more messages to process", stats["batches"])
+ break
+
+ # Update cursor to the last message's (created_at, id)
+ _cursor = (messages[-1].created_at, messages[-1].id)
+
+ # Step 2: Extract app_ids and query tenant_ids
+ app_ids = list({msg.app_id for msg in messages})
+
+ if not app_ids:
+ logger.info("clean_messages (batch %s): no app_ids found, skip", stats["batches"])
+ continue
+
+ app_stmt = select(App.id, App.tenant_id).where(App.id.in_(app_ids))
+ apps = list(session.execute(app_stmt).all())
+
+ if not apps:
+ logger.info("clean_messages (batch %s): no apps found, skip", stats["batches"])
+ continue
+
+ # Build app_id -> tenant_id mapping
+ app_to_tenant: dict[str, str] = {app.id: app.tenant_id for app in apps}
+
+ # Step 3: Delegate to policy to determine which messages to delete
+ message_ids_to_delete = self._policy.filter_message_ids(messages, app_to_tenant)
+
+ if not message_ids_to_delete:
+ logger.info("clean_messages (batch %s): no messages to delete, skip", stats["batches"])
+ continue
+
+ stats["filtered_messages"] += len(message_ids_to_delete)
+
+ # Step 4: Batch delete messages and their relations
+ if not self._dry_run:
+ with Session(db.engine, expire_on_commit=False) as session:
+ # Delete related records first
+ self._batch_delete_message_relations(session, message_ids_to_delete)
+
+ # Delete messages
+ delete_stmt = delete(Message).where(Message.id.in_(message_ids_to_delete))
+ delete_result = cast(CursorResult, session.execute(delete_stmt))
+ messages_deleted = delete_result.rowcount
+ session.commit()
+
+ stats["total_deleted"] += messages_deleted
+
+ logger.info(
+ "clean_messages (batch %s): processed %s messages, deleted %s messages",
+ stats["batches"],
+ len(messages),
+ messages_deleted,
+ )
+ else:
+ # Log random sample of message IDs that would be deleted (up to 10)
+ sample_size = min(10, len(message_ids_to_delete))
+ sampled_ids = random.sample(list(message_ids_to_delete), sample_size)
+
+ logger.info(
+ "clean_messages (batch %s, dry_run): would delete %s messages, sampling %s ids:",
+ stats["batches"],
+ len(message_ids_to_delete),
+ sample_size,
+ )
+ for msg_id in sampled_ids:
+ logger.info("clean_messages (batch %s, dry_run) sample: message_id=%s", stats["batches"], msg_id)
+
+ logger.info(
+ "clean_messages completed: total batches: %s, total messages: %s, filtered messages: %s, total deleted: %s",
+ stats["batches"],
+ stats["total_messages"],
+ stats["filtered_messages"],
+ stats["total_deleted"],
+ )
+
+ return stats
+
+ @staticmethod
+ def _batch_delete_message_relations(session: Session, message_ids: Sequence[str]) -> None:
+ """
+ Batch delete all related records for given message IDs.
+
+ Args:
+ session: Database session
+ message_ids: List of message IDs to delete relations for
+ """
+ if not message_ids:
+ return
+
+ # Delete all related records in batch
+ session.execute(delete(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)))
+
+ session.execute(delete(MessageAnnotation).where(MessageAnnotation.message_id.in_(message_ids)))
+
+ session.execute(delete(MessageChain).where(MessageChain.message_id.in_(message_ids)))
+
+ session.execute(delete(MessageAgentThought).where(MessageAgentThought.message_id.in_(message_ids)))
+
+ session.execute(delete(MessageFile).where(MessageFile.message_id.in_(message_ids)))
+
+ session.execute(delete(SavedMessage).where(SavedMessage.message_id.in_(message_ids)))
+
+ session.execute(delete(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id.in_(message_ids)))
+
+ session.execute(delete(DatasetRetrieverResource).where(DatasetRetrieverResource.message_id.in_(message_ids)))
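+
+
+# Illustrative usage sketch (editorial example, not part of the original change). It exercises
+# the explicit [start_from, end_before) window; it assumes a configured database and a policy
+# built via create_message_clean_policy().
+if __name__ == "__main__":  # pragma: no cover - example only
+    from services.retention.conversation.messages_clean_policy import create_message_clean_policy
+
+    example_policy = create_message_clean_policy()
+    service = MessagesCleanService.from_time_range(
+        policy=example_policy,
+        start_from=datetime.datetime(2024, 1, 1),
+        end_before=datetime.datetime(2024, 2, 1),
+        batch_size=500,
+        dry_run=True,
+    )
+    # Returned keys: batches, total_messages, filtered_messages, total_deleted
+    print(service.run())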
diff --git a/api/services/retention/workflow_run/__init__.py b/api/services/retention/workflow_run/__init__.py
new file mode 100644
index 0000000000..18dd42c91e
--- /dev/null
+++ b/api/services/retention/workflow_run/__init__.py
@@ -0,0 +1 @@
+"""Workflow run retention services."""
diff --git a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py
new file mode 100644
index 0000000000..ea5cbb7740
--- /dev/null
+++ b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py
@@ -0,0 +1,531 @@
+"""
+Archive Paid Plan Workflow Run Logs Service.
+
+This service archives workflow run logs that are older than the configured retention
+period (default: 90 days) for paid-plan tenants, writing them to S3-compatible storage.
+
+Archived tables:
+- workflow_runs
+- workflow_app_logs
+- workflow_node_executions
+- workflow_node_execution_offload
+- workflow_pauses
+- workflow_pause_reasons
+- workflow_trigger_logs
+
+"""
+
+import datetime
+import io
+import json
+import logging
+import time
+import zipfile
+from collections.abc import Sequence
+from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass, field
+from typing import Any
+
+import click
+from sqlalchemy import inspect
+from sqlalchemy.orm import Session, sessionmaker
+
+from configs import dify_config
+from core.workflow.enums import WorkflowType
+from enums.cloud_plan import CloudPlan
+from extensions.ext_database import db
+from libs.archive_storage import (
+ ArchiveStorage,
+ ArchiveStorageNotConfiguredError,
+ get_archive_storage,
+)
+from models.workflow import WorkflowAppLog, WorkflowRun
+from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
+from repositories.api_workflow_run_repository import APIWorkflowRunRepository
+from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
+from services.billing_service import BillingService
+from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME, ARCHIVE_SCHEMA_VERSION
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class TableStats:
+ """Statistics for a single archived table."""
+
+ table_name: str
+ row_count: int
+ checksum: str
+ size_bytes: int
+
+
+@dataclass
+class ArchiveResult:
+ """Result of archiving a single workflow run."""
+
+ run_id: str
+ tenant_id: str
+ success: bool
+ tables: list[TableStats] = field(default_factory=list)
+ error: str | None = None
+ elapsed_time: float = 0.0
+
+
+@dataclass
+class ArchiveSummary:
+ """Summary of the entire archive operation."""
+
+ total_runs_processed: int = 0
+ runs_archived: int = 0
+ runs_skipped: int = 0
+ runs_failed: int = 0
+ total_elapsed_time: float = 0.0
+
+
+class WorkflowRunArchiver:
+ """
+ Archive workflow run logs for paid plan users.
+
+ Storage Layout:
+ {tenant_id}/app_id={app_id}/year={YYYY}/month={MM}/workflow_run_id={run_id}/
+ └── archive.v1.0.zip
+ ├── manifest.json
+ ├── workflow_runs.jsonl
+ ├── workflow_app_logs.jsonl
+ ├── workflow_node_executions.jsonl
+ ├── workflow_node_execution_offload.jsonl
+ ├── workflow_pauses.jsonl
+ ├── workflow_pause_reasons.jsonl
+ └── workflow_trigger_logs.jsonl
+ """
+
+ ARCHIVED_TYPE = [
+ WorkflowType.WORKFLOW,
+ WorkflowType.RAG_PIPELINE,
+ ]
+ ARCHIVED_TABLES = [
+ "workflow_runs",
+ "workflow_app_logs",
+ "workflow_node_executions",
+ "workflow_node_execution_offload",
+ "workflow_pauses",
+ "workflow_pause_reasons",
+ "workflow_trigger_logs",
+ ]
+
+ start_from: datetime.datetime | None
+ end_before: datetime.datetime
+
+ def __init__(
+ self,
+ days: int = 90,
+ batch_size: int = 100,
+ start_from: datetime.datetime | None = None,
+ end_before: datetime.datetime | None = None,
+ workers: int = 1,
+ tenant_ids: Sequence[str] | None = None,
+ limit: int | None = None,
+ dry_run: bool = False,
+ delete_after_archive: bool = False,
+ workflow_run_repo: APIWorkflowRunRepository | None = None,
+ ):
+ """
+ Initialize the archiver.
+
+ Args:
+ days: Archive runs older than this many days
+ batch_size: Number of runs to process per batch
+ start_from: Optional start time (inclusive) for archiving
+ end_before: Optional end time (exclusive) for archiving
+ workers: Number of concurrent workflow runs to archive
+ tenant_ids: Optional tenant IDs for grayscale rollout
+ limit: Maximum number of runs to archive (None for unlimited)
+ dry_run: If True, only preview without making changes
+ delete_after_archive: If True, delete runs and related data after archiving
+ """
+ self.days = days
+ self.batch_size = batch_size
+ if start_from or end_before:
+ if start_from is None or end_before is None:
+ raise ValueError("start_from and end_before must be provided together")
+ if start_from >= end_before:
+ raise ValueError("start_from must be earlier than end_before")
+ self.start_from = start_from.replace(tzinfo=datetime.UTC)
+ self.end_before = end_before.replace(tzinfo=datetime.UTC)
+ else:
+ self.start_from = None
+ self.end_before = datetime.datetime.now(datetime.UTC) - datetime.timedelta(days=days)
+ if workers < 1:
+ raise ValueError("workers must be at least 1")
+ self.workers = workers
+ self.tenant_ids = sorted(set(tenant_ids)) if tenant_ids else []
+ self.limit = limit
+ self.dry_run = dry_run
+ self.delete_after_archive = delete_after_archive
+ self.workflow_run_repo = workflow_run_repo
+
+ def run(self) -> ArchiveSummary:
+ """
+ Main archiving loop.
+
+ Returns:
+ ArchiveSummary with statistics about the operation
+ """
+ summary = ArchiveSummary()
+ start_time = time.time()
+
+ click.echo(
+ click.style(
+ self._build_start_message(),
+ fg="white",
+ )
+ )
+
+ # Initialize archive storage (will raise if not configured)
+ try:
+ if not self.dry_run:
+ storage = get_archive_storage()
+ else:
+ storage = None
+ except ArchiveStorageNotConfiguredError as e:
+ click.echo(click.style(f"Archive storage not configured: {e}", fg="red"))
+ return summary
+
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ repo = self._get_workflow_run_repo()
+
+ def _archive_with_session(run: WorkflowRun) -> ArchiveResult:
+ with session_maker() as session:
+ return self._archive_run(session, storage, run)
+
+ last_seen: tuple[datetime.datetime, str] | None = None
+ archived_count = 0
+
+ with ThreadPoolExecutor(max_workers=self.workers) as executor:
+ while True:
+ # Check limit
+ if self.limit and archived_count >= self.limit:
+ click.echo(click.style(f"Reached limit of {self.limit} runs", fg="yellow"))
+ break
+
+ # Fetch batch of runs
+ runs = self._get_runs_batch(last_seen)
+
+ if not runs:
+ break
+
+ run_ids = [run.id for run in runs]
+ with session_maker() as session:
+ archived_run_ids = repo.get_archived_run_ids(session, run_ids)
+
+ last_seen = (runs[-1].created_at, runs[-1].id)
+
+ # Filter to paid tenants only
+ tenant_ids = {run.tenant_id for run in runs}
+ paid_tenants = self._filter_paid_tenants(tenant_ids)
+
+ runs_to_process: list[WorkflowRun] = []
+ for run in runs:
+ summary.total_runs_processed += 1
+
+ # Skip non-paid tenants
+ if run.tenant_id not in paid_tenants:
+ summary.runs_skipped += 1
+ continue
+
+ # Skip already archived runs
+ if run.id in archived_run_ids:
+ summary.runs_skipped += 1
+ continue
+
+ # Check limit
+ if self.limit and archived_count + len(runs_to_process) >= self.limit:
+ break
+
+ runs_to_process.append(run)
+
+ if not runs_to_process:
+ continue
+
+ results = list(executor.map(_archive_with_session, runs_to_process))
+
+ for run, result in zip(runs_to_process, results):
+ if result.success:
+ summary.runs_archived += 1
+ archived_count += 1
+ click.echo(
+ click.style(
+ f"{'[DRY RUN] Would archive' if self.dry_run else 'Archived'} "
+ f"run {run.id} (tenant={run.tenant_id}, "
+ f"tables={len(result.tables)}, time={result.elapsed_time:.2f}s)",
+ fg="green",
+ )
+ )
+ else:
+ summary.runs_failed += 1
+ click.echo(
+ click.style(
+ f"Failed to archive run {run.id}: {result.error}",
+ fg="red",
+ )
+ )
+
+ summary.total_elapsed_time = time.time() - start_time
+ click.echo(
+ click.style(
+ f"{'[DRY RUN] ' if self.dry_run else ''}Archive complete: "
+ f"processed={summary.total_runs_processed}, archived={summary.runs_archived}, "
+ f"skipped={summary.runs_skipped}, failed={summary.runs_failed}, "
+ f"time={summary.total_elapsed_time:.2f}s",
+ fg="white",
+ )
+ )
+
+ return summary
+
+ def _get_runs_batch(
+ self,
+ last_seen: tuple[datetime.datetime, str] | None,
+ ) -> Sequence[WorkflowRun]:
+ """Fetch a batch of workflow runs to archive."""
+ repo = self._get_workflow_run_repo()
+ return repo.get_runs_batch_by_time_range(
+ start_from=self.start_from,
+ end_before=self.end_before,
+ last_seen=last_seen,
+ batch_size=self.batch_size,
+ run_types=self.ARCHIVED_TYPE,
+ tenant_ids=self.tenant_ids or None,
+ )
+
+ def _build_start_message(self) -> str:
+ range_desc = f"before {self.end_before.isoformat()}"
+ if self.start_from:
+ range_desc = f"between {self.start_from.isoformat()} and {self.end_before.isoformat()}"
+ return (
+ f"{'[DRY RUN] ' if self.dry_run else ''}Starting workflow run archiving "
+ f"for runs {range_desc} "
+ f"(batch_size={self.batch_size}, tenant_ids={','.join(self.tenant_ids) or 'all'})"
+ )
+
+ def _filter_paid_tenants(self, tenant_ids: set[str]) -> set[str]:
+ """Filter tenant IDs to only include paid tenants."""
+ if not dify_config.BILLING_ENABLED:
+ # If billing is not enabled, treat all tenants as paid
+ return tenant_ids
+
+ if not tenant_ids:
+ return set()
+
+ try:
+ bulk_info = BillingService.get_plan_bulk_with_cache(list(tenant_ids))
+ except Exception:
+ logger.exception("Failed to fetch billing plans for tenants")
+ # On error, skip all tenants in this batch
+ return set()
+
+        # Keep only tenants on a paid plan (PROFESSIONAL or TEAM)
+ paid = set()
+ for tid, info in bulk_info.items():
+ if info and info.get("plan") in (CloudPlan.PROFESSIONAL, CloudPlan.TEAM):
+ paid.add(tid)
+
+ return paid
+
+ def _archive_run(
+ self,
+ session: Session,
+ storage: ArchiveStorage | None,
+ run: WorkflowRun,
+ ) -> ArchiveResult:
+ """Archive a single workflow run."""
+ start_time = time.time()
+ result = ArchiveResult(run_id=run.id, tenant_id=run.tenant_id, success=False)
+
+ try:
+ # Extract data from all tables
+ table_data, app_logs, trigger_metadata = self._extract_data(session, run)
+
+ if self.dry_run:
+ # In dry run, just report what would be archived
+ for table_name in self.ARCHIVED_TABLES:
+ records = table_data.get(table_name, [])
+ result.tables.append(
+ TableStats(
+ table_name=table_name,
+ row_count=len(records),
+ checksum="",
+ size_bytes=0,
+ )
+ )
+ result.success = True
+ else:
+ if storage is None:
+ raise ArchiveStorageNotConfiguredError("Archive storage not configured")
+ archive_key = self._get_archive_key(run)
+
+ # Serialize tables for the archive bundle
+ table_stats: list[TableStats] = []
+ table_payloads: dict[str, bytes] = {}
+ for table_name in self.ARCHIVED_TABLES:
+ records = table_data.get(table_name, [])
+ data = ArchiveStorage.serialize_to_jsonl(records)
+ table_payloads[table_name] = data
+ checksum = ArchiveStorage.compute_checksum(data)
+
+ table_stats.append(
+ TableStats(
+ table_name=table_name,
+ row_count=len(records),
+ checksum=checksum,
+ size_bytes=len(data),
+ )
+ )
+
+ # Generate and upload archive bundle
+ manifest = self._generate_manifest(run, table_stats)
+ manifest_data = json.dumps(manifest, indent=2, default=str).encode("utf-8")
+ archive_data = self._build_archive_bundle(manifest_data, table_payloads)
+ storage.put_object(archive_key, archive_data)
+
+ repo = self._get_workflow_run_repo()
+ archived_log_count = repo.create_archive_logs(session, run, app_logs, trigger_metadata)
+ session.commit()
+
+ deleted_counts = None
+ if self.delete_after_archive:
+ deleted_counts = repo.delete_runs_with_related(
+ [run],
+ delete_node_executions=self._delete_node_executions,
+ delete_trigger_logs=self._delete_trigger_logs,
+ )
+
+ logger.info(
+ "Archived workflow run %s: tables=%s, archived_logs=%s, deleted=%s",
+ run.id,
+ {s.table_name: s.row_count for s in table_stats},
+ archived_log_count,
+ deleted_counts,
+ )
+
+ result.tables = table_stats
+ result.success = True
+
+ except Exception as e:
+ logger.exception("Failed to archive workflow run %s", run.id)
+ result.error = str(e)
+ session.rollback()
+
+ result.elapsed_time = time.time() - start_time
+ return result
+
+ def _extract_data(
+ self,
+ session: Session,
+ run: WorkflowRun,
+ ) -> tuple[dict[str, list[dict[str, Any]]], Sequence[WorkflowAppLog], str | None]:
+ table_data: dict[str, list[dict[str, Any]]] = {}
+ table_data["workflow_runs"] = [self._row_to_dict(run)]
+ repo = self._get_workflow_run_repo()
+ app_logs = repo.get_app_logs_by_run_id(session, run.id)
+ table_data["workflow_app_logs"] = [self._row_to_dict(row) for row in app_logs]
+ node_exec_repo = self._get_workflow_node_execution_repo(session)
+ node_exec_records = node_exec_repo.get_executions_by_workflow_run(
+ tenant_id=run.tenant_id,
+ app_id=run.app_id,
+ workflow_run_id=run.id,
+ )
+ node_exec_ids = [record.id for record in node_exec_records]
+ offload_records = node_exec_repo.get_offloads_by_execution_ids(session, node_exec_ids)
+ table_data["workflow_node_executions"] = [self._row_to_dict(row) for row in node_exec_records]
+ table_data["workflow_node_execution_offload"] = [self._row_to_dict(row) for row in offload_records]
+ repo = self._get_workflow_run_repo()
+ pause_records = repo.get_pause_records_by_run_id(session, run.id)
+ pause_ids = [pause.id for pause in pause_records]
+ pause_reason_records = repo.get_pause_reason_records_by_run_id(
+ session,
+ pause_ids,
+ )
+ table_data["workflow_pauses"] = [self._row_to_dict(row) for row in pause_records]
+ table_data["workflow_pause_reasons"] = [self._row_to_dict(row) for row in pause_reason_records]
+ trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
+ trigger_records = trigger_repo.list_by_run_id(run.id)
+ table_data["workflow_trigger_logs"] = [self._row_to_dict(row) for row in trigger_records]
+ trigger_metadata = trigger_records[0].trigger_metadata if trigger_records else None
+ return table_data, app_logs, trigger_metadata
+
+ @staticmethod
+ def _row_to_dict(row: Any) -> dict[str, Any]:
+ mapper = inspect(row).mapper
+ return {str(column.name): getattr(row, mapper.get_property_by_column(column).key) for column in mapper.columns}
+
+ def _get_archive_key(self, run: WorkflowRun) -> str:
+ """Get the storage key for the archive bundle."""
+ created_at = run.created_at
+ prefix = (
+ f"{run.tenant_id}/app_id={run.app_id}/year={created_at.strftime('%Y')}/"
+ f"month={created_at.strftime('%m')}/workflow_run_id={run.id}"
+ )
+ return f"{prefix}/{ARCHIVE_BUNDLE_NAME}"
+
+ def _generate_manifest(
+ self,
+ run: WorkflowRun,
+ table_stats: list[TableStats],
+ ) -> dict[str, Any]:
+ """Generate a manifest for the archived workflow run."""
+ return {
+ "schema_version": ARCHIVE_SCHEMA_VERSION,
+ "workflow_run_id": run.id,
+ "tenant_id": run.tenant_id,
+ "app_id": run.app_id,
+ "workflow_id": run.workflow_id,
+ "created_at": run.created_at.isoformat(),
+ "archived_at": datetime.datetime.now(datetime.UTC).isoformat(),
+ "tables": {
+ stat.table_name: {
+ "row_count": stat.row_count,
+ "checksum": stat.checksum,
+ "size_bytes": stat.size_bytes,
+ }
+ for stat in table_stats
+ },
+ }
+
+ def _build_archive_bundle(self, manifest_data: bytes, table_payloads: dict[str, bytes]) -> bytes:
+ buffer = io.BytesIO()
+ with zipfile.ZipFile(buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
+ archive.writestr("manifest.json", manifest_data)
+ for table_name in self.ARCHIVED_TABLES:
+ data = table_payloads.get(table_name)
+ if data is None:
+ raise ValueError(f"Missing archive payload for {table_name}")
+ archive.writestr(f"{table_name}.jsonl", data)
+ return buffer.getvalue()
+
+ def _delete_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int:
+ trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
+ return trigger_repo.delete_by_run_ids(run_ids)
+
+ def _delete_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]:
+ run_ids = [run.id for run in runs]
+ return self._get_workflow_node_execution_repo(session).delete_by_runs(session, run_ids)
+
+ def _get_workflow_node_execution_repo(
+ self,
+ session: Session,
+ ) -> DifyAPIWorkflowNodeExecutionRepository:
+ from repositories.factory import DifyAPIRepositoryFactory
+
+ session_maker = sessionmaker(bind=session.get_bind(), expire_on_commit=False)
+ return DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
+
+ def _get_workflow_run_repo(self) -> APIWorkflowRunRepository:
+ if self.workflow_run_repo is not None:
+ return self.workflow_run_repo
+
+ from repositories.factory import DifyAPIRepositoryFactory
+
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
+ return self.workflow_run_repo
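+
+
+# Illustrative usage sketch (editorial example, not part of the original change). It assumes
+# archive storage (libs.archive_storage) and billing are configured; dry_run=True previews
+# which runs would be archived without uploading bundles or writing archive logs.
+if __name__ == "__main__":  # pragma: no cover - example only
+    archiver = WorkflowRunArchiver(
+        days=90,
+        batch_size=100,
+        workers=2,
+        tenant_ids=None,  # archive all paid tenants; pass IDs for a grayscale rollout
+        limit=500,  # stop after 500 runs
+        dry_run=True,
+        delete_after_archive=False,
+    )
+    summary = archiver.run()
+    print(summary)  # ArchiveSummary(total_runs_processed=..., runs_archived=..., ...)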
diff --git a/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py b/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py
new file mode 100644
index 0000000000..c3e0dce399
--- /dev/null
+++ b/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py
@@ -0,0 +1,293 @@
+import datetime
+import logging
+from collections.abc import Iterable, Sequence
+
+import click
+from sqlalchemy.orm import Session, sessionmaker
+
+from configs import dify_config
+from enums.cloud_plan import CloudPlan
+from extensions.ext_database import db
+from models.workflow import WorkflowRun
+from repositories.api_workflow_run_repository import APIWorkflowRunRepository
+from repositories.factory import DifyAPIRepositoryFactory
+from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
+from services.billing_service import BillingService, SubscriptionPlan
+
+logger = logging.getLogger(__name__)
+
+
+class WorkflowRunCleanup:
+ def __init__(
+ self,
+ days: int,
+ batch_size: int,
+ start_from: datetime.datetime | None = None,
+ end_before: datetime.datetime | None = None,
+ workflow_run_repo: APIWorkflowRunRepository | None = None,
+ dry_run: bool = False,
+ ):
+ if (start_from is None) ^ (end_before is None):
+ raise ValueError("start_from and end_before must be both set or both omitted.")
+
+ computed_cutoff = datetime.datetime.now() - datetime.timedelta(days=days)
+ self.window_start = start_from
+ self.window_end = end_before or computed_cutoff
+
+ if self.window_start and self.window_end <= self.window_start:
+ raise ValueError("end_before must be greater than start_from.")
+
+ if batch_size <= 0:
+ raise ValueError("batch_size must be greater than 0.")
+
+ self.batch_size = batch_size
+ self._cleanup_whitelist: set[str] | None = None
+ self.dry_run = dry_run
+ self.free_plan_grace_period_days = dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD
+ self.workflow_run_repo: APIWorkflowRunRepository
+ if workflow_run_repo:
+ self.workflow_run_repo = workflow_run_repo
+ else:
+ # Lazy import to avoid circular dependencies during module import
+ from repositories.factory import DifyAPIRepositoryFactory
+
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
+
+ def run(self) -> None:
+ click.echo(
+ click.style(
+ f"{'Inspecting' if self.dry_run else 'Cleaning'} workflow runs "
+ f"{'between ' + self.window_start.isoformat() + ' and ' if self.window_start else 'before '}"
+ f"{self.window_end.isoformat()} (batch={self.batch_size})",
+ fg="white",
+ )
+ )
+ if self.dry_run:
+ click.echo(click.style("Dry run mode enabled. No data will be deleted.", fg="yellow"))
+
+ total_runs_deleted = 0
+ total_runs_targeted = 0
+ related_totals = self._empty_related_counts() if self.dry_run else None
+ batch_index = 0
+ last_seen: tuple[datetime.datetime, str] | None = None
+
+ while True:
+ run_rows = self.workflow_run_repo.get_runs_batch_by_time_range(
+ start_from=self.window_start,
+ end_before=self.window_end,
+ last_seen=last_seen,
+ batch_size=self.batch_size,
+ )
+ if not run_rows:
+ break
+
+ batch_index += 1
+ last_seen = (run_rows[-1].created_at, run_rows[-1].id)
+ tenant_ids = {row.tenant_id for row in run_rows}
+ free_tenants = self._filter_free_tenants(tenant_ids)
+ free_runs = [row for row in run_rows if row.tenant_id in free_tenants]
+ paid_or_skipped = len(run_rows) - len(free_runs)
+
+ if not free_runs:
+ skipped_message = (
+ f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)"
+ )
+ click.echo(
+ click.style(
+ skipped_message,
+ fg="yellow",
+ )
+ )
+ continue
+
+ total_runs_targeted += len(free_runs)
+
+ if self.dry_run:
+ batch_counts = self.workflow_run_repo.count_runs_with_related(
+ free_runs,
+ count_node_executions=self._count_node_executions,
+ count_trigger_logs=self._count_trigger_logs,
+ )
+ if related_totals is not None:
+ for key in related_totals:
+ related_totals[key] += batch_counts.get(key, 0)
+ sample_ids = ", ".join(run.id for run in free_runs[:5])
+ click.echo(
+ click.style(
+ f"[batch #{batch_index}] would delete {len(free_runs)} runs "
+ f"(sample ids: {sample_ids}) and skip {paid_or_skipped} paid/unknown",
+ fg="yellow",
+ )
+ )
+ continue
+
+ try:
+ counts = self.workflow_run_repo.delete_runs_with_related(
+ free_runs,
+ delete_node_executions=self._delete_node_executions,
+ delete_trigger_logs=self._delete_trigger_logs,
+ )
+ except Exception:
+ logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0])
+ raise
+
+ total_runs_deleted += counts["runs"]
+ click.echo(
+ click.style(
+ f"[batch #{batch_index}] deleted runs: {counts['runs']} "
+ f"(nodes {counts['node_executions']}, offloads {counts['offloads']}, "
+ f"app_logs {counts['app_logs']}, trigger_logs {counts['trigger_logs']}, "
+ f"pauses {counts['pauses']}, pause_reasons {counts['pause_reasons']}); "
+ f"skipped {paid_or_skipped} paid/unknown",
+ fg="green",
+ )
+ )
+
+ if self.dry_run:
+ if self.window_start:
+ summary_message = (
+ f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
+ f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
+ )
+ else:
+ summary_message = (
+ f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
+ f"before {self.window_end.isoformat()}"
+ )
+ if related_totals is not None:
+ summary_message = f"{summary_message}; related records: {self._format_related_counts(related_totals)}"
+ summary_color = "yellow"
+ else:
+ if self.window_start:
+ summary_message = (
+ f"Cleanup complete. Deleted {total_runs_deleted} workflow runs "
+ f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
+ )
+ else:
+ summary_message = (
+ f"Cleanup complete. Deleted {total_runs_deleted} workflow runs before {self.window_end.isoformat()}"
+ )
+ summary_color = "white"
+
+ click.echo(click.style(summary_message, fg=summary_color))
+
+ def _filter_free_tenants(self, tenant_ids: Iterable[str]) -> set[str]:
+ tenant_id_list = list(tenant_ids)
+
+ if not dify_config.BILLING_ENABLED:
+ return set(tenant_id_list)
+
+ if not tenant_id_list:
+ return set()
+
+ cleanup_whitelist = self._get_cleanup_whitelist()
+
+ try:
+ bulk_info = BillingService.get_plan_bulk_with_cache(tenant_id_list)
+ except Exception:
+ bulk_info = {}
+ logger.exception("Failed to fetch billing plans in bulk for tenants: %s", tenant_id_list)
+
+ eligible_free_tenants: set[str] = set()
+ for tenant_id in tenant_id_list:
+ if tenant_id in cleanup_whitelist:
+ continue
+
+ info = bulk_info.get(tenant_id)
+ if info is None:
+ logger.warning("Missing billing info for tenant %s in bulk resp; treating as non-free", tenant_id)
+ continue
+
+ if info.get("plan") != CloudPlan.SANDBOX:
+ continue
+
+ if self._is_within_grace_period(tenant_id, info):
+ continue
+
+ eligible_free_tenants.add(tenant_id)
+
+ return eligible_free_tenants
+
+ def _expiration_datetime(self, tenant_id: str, expiration_value: int) -> datetime.datetime | None:
+ if expiration_value < 0:
+ return None
+
+ try:
+ return datetime.datetime.fromtimestamp(expiration_value, datetime.UTC)
+ except (OverflowError, OSError, ValueError):
+ logger.exception("Failed to parse expiration timestamp for tenant %s", tenant_id)
+ return None
+
+ def _is_within_grace_period(self, tenant_id: str, info: SubscriptionPlan) -> bool:
+ if self.free_plan_grace_period_days <= 0:
+ return False
+
+ expiration_value = info.get("expiration_date", -1)
+ expiration_at = self._expiration_datetime(tenant_id, expiration_value)
+ if expiration_at is None:
+ return False
+
+ grace_deadline = expiration_at + datetime.timedelta(days=self.free_plan_grace_period_days)
+ return datetime.datetime.now(datetime.UTC) < grace_deadline
+
+ def _get_cleanup_whitelist(self) -> set[str]:
+ if self._cleanup_whitelist is not None:
+ return self._cleanup_whitelist
+
+ if not dify_config.BILLING_ENABLED:
+ self._cleanup_whitelist = set()
+ return self._cleanup_whitelist
+
+ try:
+ whitelist_ids = BillingService.get_expired_subscription_cleanup_whitelist()
+ except Exception:
+ logger.exception("Failed to fetch cleanup whitelist from billing service")
+ whitelist_ids = []
+
+ self._cleanup_whitelist = set(whitelist_ids)
+ return self._cleanup_whitelist
+
+ def _delete_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int:
+ trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
+ return trigger_repo.delete_by_run_ids(run_ids)
+
+ def _count_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int:
+ trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
+ return trigger_repo.count_by_run_ids(run_ids)
+
+ @staticmethod
+ def _empty_related_counts() -> dict[str, int]:
+ return {
+ "node_executions": 0,
+ "offloads": 0,
+ "app_logs": 0,
+ "trigger_logs": 0,
+ "pauses": 0,
+ "pause_reasons": 0,
+ }
+
+ @staticmethod
+ def _format_related_counts(counts: dict[str, int]) -> str:
+ return (
+ f"node_executions {counts['node_executions']}, "
+ f"offloads {counts['offloads']}, "
+ f"app_logs {counts['app_logs']}, "
+ f"trigger_logs {counts['trigger_logs']}, "
+ f"pauses {counts['pauses']}, "
+ f"pause_reasons {counts['pause_reasons']}"
+ )
+
+ def _count_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]:
+ run_ids = [run.id for run in runs]
+ repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
+ session_maker=sessionmaker(bind=session.get_bind(), expire_on_commit=False)
+ )
+ return repo.count_by_runs(session, run_ids)
+
+ def _delete_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]:
+ run_ids = [run.id for run in runs]
+ repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
+ session_maker=sessionmaker(bind=session.get_bind(), expire_on_commit=False)
+ )
+ return repo.delete_by_runs(session, run_ids)
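+
+
+# Illustrative usage sketch (editorial example, not part of the original change). It previews
+# which sandbox-plan workflow runs would be removed; it assumes a configured database and,
+# when billing is enabled, a reachable billing service.
+if __name__ == "__main__":  # pragma: no cover - example only
+    cleanup = WorkflowRunCleanup(
+        days=30,
+        batch_size=200,
+        dry_run=True,  # only report counts and sample IDs; nothing is deleted
+    )
+    cleanup.run()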
diff --git a/api/services/retention/workflow_run/constants.py b/api/services/retention/workflow_run/constants.py
new file mode 100644
index 0000000000..162bb4947d
--- /dev/null
+++ b/api/services/retention/workflow_run/constants.py
@@ -0,0 +1,2 @@
+ARCHIVE_SCHEMA_VERSION = "1.0"
+ARCHIVE_BUNDLE_NAME = f"archive.v{ARCHIVE_SCHEMA_VERSION}.zip"
diff --git a/api/services/retention/workflow_run/delete_archived_workflow_run.py b/api/services/retention/workflow_run/delete_archived_workflow_run.py
new file mode 100644
index 0000000000..11873bf1b9
--- /dev/null
+++ b/api/services/retention/workflow_run/delete_archived_workflow_run.py
@@ -0,0 +1,134 @@
+"""
+Delete Archived Workflow Run Service.
+
+This service deletes archived workflow run data from the database while keeping
+archive logs intact.
+"""
+
+import time
+from collections.abc import Sequence
+from dataclasses import dataclass, field
+from datetime import datetime
+
+from sqlalchemy.orm import Session, sessionmaker
+
+from extensions.ext_database import db
+from models.workflow import WorkflowRun
+from repositories.api_workflow_run_repository import APIWorkflowRunRepository
+from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
+
+
+@dataclass
+class DeleteResult:
+ run_id: str
+ tenant_id: str
+ success: bool
+ deleted_counts: dict[str, int] = field(default_factory=dict)
+ error: str | None = None
+ elapsed_time: float = 0.0
+
+
+class ArchivedWorkflowRunDeletion:
+ def __init__(self, dry_run: bool = False):
+ self.dry_run = dry_run
+ self.workflow_run_repo: APIWorkflowRunRepository | None = None
+
+ def delete_by_run_id(self, run_id: str) -> DeleteResult:
+ start_time = time.time()
+ result = DeleteResult(run_id=run_id, tenant_id="", success=False)
+
+ repo = self._get_workflow_run_repo()
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ with session_maker() as session:
+ run = session.get(WorkflowRun, run_id)
+ if not run:
+ result.error = f"Workflow run {run_id} not found"
+ result.elapsed_time = time.time() - start_time
+ return result
+
+ result.tenant_id = run.tenant_id
+ if not repo.get_archived_run_ids(session, [run.id]):
+ result.error = f"Workflow run {run_id} is not archived"
+ result.elapsed_time = time.time() - start_time
+ return result
+
+ result = self._delete_run(run)
+ result.elapsed_time = time.time() - start_time
+ return result
+
+ def delete_batch(
+ self,
+ tenant_ids: list[str] | None,
+ start_date: datetime,
+ end_date: datetime,
+ limit: int = 100,
+ ) -> list[DeleteResult]:
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ results: list[DeleteResult] = []
+
+ repo = self._get_workflow_run_repo()
+ with session_maker() as session:
+ runs = list(
+ repo.get_archived_runs_by_time_range(
+ session=session,
+ tenant_ids=tenant_ids,
+ start_date=start_date,
+ end_date=end_date,
+ limit=limit,
+ )
+ )
+ for run in runs:
+ results.append(self._delete_run(run))
+
+ return results
+
+ def _delete_run(self, run: WorkflowRun) -> DeleteResult:
+ start_time = time.time()
+ result = DeleteResult(run_id=run.id, tenant_id=run.tenant_id, success=False)
+ if self.dry_run:
+ result.success = True
+ result.elapsed_time = time.time() - start_time
+ return result
+
+ repo = self._get_workflow_run_repo()
+ try:
+ deleted_counts = repo.delete_runs_with_related(
+ [run],
+ delete_node_executions=self._delete_node_executions,
+ delete_trigger_logs=self._delete_trigger_logs,
+ )
+ result.deleted_counts = deleted_counts
+ result.success = True
+ except Exception as e:
+ result.error = str(e)
+ result.elapsed_time = time.time() - start_time
+ return result
+
+ @staticmethod
+ def _delete_trigger_logs(session: Session, run_ids: Sequence[str]) -> int:
+ trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
+ return trigger_repo.delete_by_run_ids(run_ids)
+
+ @staticmethod
+ def _delete_node_executions(
+ session: Session,
+ runs: Sequence[WorkflowRun],
+ ) -> tuple[int, int]:
+ from repositories.factory import DifyAPIRepositoryFactory
+
+ run_ids = [run.id for run in runs]
+ repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
+ session_maker=sessionmaker(bind=session.get_bind(), expire_on_commit=False)
+ )
+ return repo.delete_by_runs(session, run_ids)
+
+ def _get_workflow_run_repo(self) -> APIWorkflowRunRepository:
+ if self.workflow_run_repo is not None:
+ return self.workflow_run_repo
+
+ from repositories.factory import DifyAPIRepositoryFactory
+
+ self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(
+ sessionmaker(bind=db.engine, expire_on_commit=False)
+ )
+ return self.workflow_run_repo
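+
+
+# Illustrative usage sketch (editorial example, not part of the original change). Deletion is
+# only performed for runs that already have an archive log; the run ID below is a placeholder,
+# and dry_run=True makes both calls side-effect free.
+if __name__ == "__main__":  # pragma: no cover - example only
+    deletion = ArchivedWorkflowRunDeletion(dry_run=True)
+
+    # Delete a single archived run by ID (placeholder ID).
+    single = deletion.delete_by_run_id("00000000-0000-0000-0000-000000000000")
+    print(single.success, single.error, single.deleted_counts)
+
+    # Delete archived runs created inside a time window.
+    batch = deletion.delete_batch(
+        tenant_ids=None,
+        start_date=datetime(2024, 1, 1),
+        end_date=datetime(2024, 2, 1),
+        limit=50,
+    )
+    print(f"{len(batch)} archived runs processed")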
diff --git a/api/services/retention/workflow_run/restore_archived_workflow_run.py b/api/services/retention/workflow_run/restore_archived_workflow_run.py
new file mode 100644
index 0000000000..d4a6e87585
--- /dev/null
+++ b/api/services/retention/workflow_run/restore_archived_workflow_run.py
@@ -0,0 +1,481 @@
+"""
+Restore Archived Workflow Run Service.
+
+This service restores archived workflow run data from S3-compatible storage
+back to the database.
+"""
+
+import io
+import json
+import logging
+import time
+import zipfile
+from collections.abc import Callable
+from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass
+from datetime import datetime
+from typing import Any, cast
+
+import click
+from sqlalchemy.dialects.postgresql import insert as pg_insert
+from sqlalchemy.engine import CursorResult
+from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
+
+from extensions.ext_database import db
+from libs.archive_storage import (
+ ArchiveStorage,
+ ArchiveStorageNotConfiguredError,
+ get_archive_storage,
+)
+from models.trigger import WorkflowTriggerLog
+from models.workflow import (
+ WorkflowAppLog,
+ WorkflowArchiveLog,
+ WorkflowNodeExecutionModel,
+ WorkflowNodeExecutionOffload,
+ WorkflowPause,
+ WorkflowPauseReason,
+ WorkflowRun,
+)
+from repositories.api_workflow_run_repository import APIWorkflowRunRepository
+from repositories.factory import DifyAPIRepositoryFactory
+from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
+
+logger = logging.getLogger(__name__)
+
+
+# Mapping of table names to SQLAlchemy models
+TABLE_MODELS = {
+ "workflow_runs": WorkflowRun,
+ "workflow_app_logs": WorkflowAppLog,
+ "workflow_node_executions": WorkflowNodeExecutionModel,
+ "workflow_node_execution_offload": WorkflowNodeExecutionOffload,
+ "workflow_pauses": WorkflowPause,
+ "workflow_pause_reasons": WorkflowPauseReason,
+ "workflow_trigger_logs": WorkflowTriggerLog,
+}
+
+SchemaMapper = Callable[[dict[str, Any]], dict[str, Any]]
+
+SCHEMA_MAPPERS: dict[str, dict[str, SchemaMapper]] = {
+ "1.0": {},
+}
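+
+# Editorial note (not part of the original change): when the archive schema evolves, the new
+# version registers per-table mappers here. A hypothetical sketch:
+#
+#     def _rename_state_to_status(record: dict[str, Any]) -> dict[str, Any]:
+#         record = dict(record)
+#         record["status"] = record.pop("state", record.get("status"))
+#         return record
+#
+#     SCHEMA_MAPPERS["1.1"] = {"workflow_runs": _rename_state_to_status}
+#
+# _get_schema_version() rejects any version that has no entry, so restores of newer bundles
+# fail fast until a mapping is added.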
+
+
+@dataclass
+class RestoreResult:
+ """Result of restoring a single workflow run."""
+
+ run_id: str
+ tenant_id: str
+ success: bool
+ restored_counts: dict[str, int]
+ error: str | None = None
+ elapsed_time: float = 0.0
+
+
+class WorkflowRunRestore:
+ """
+ Restore archived workflow run data from storage to database.
+
+ This service reads archived data from storage and restores it to the
+ database tables. It handles idempotency by skipping records that already
+ exist in the database.
+ """
+
+ def __init__(self, dry_run: bool = False, workers: int = 1):
+ """
+ Initialize the restore service.
+
+ Args:
+ dry_run: If True, only preview without making changes
+ workers: Number of concurrent workflow runs to restore
+ """
+ self.dry_run = dry_run
+ if workers < 1:
+ raise ValueError("workers must be at least 1")
+ self.workers = workers
+ self.workflow_run_repo: APIWorkflowRunRepository | None = None
+
+ def _restore_from_run(
+ self,
+ run: WorkflowRun | WorkflowArchiveLog,
+ *,
+ session_maker: sessionmaker,
+ ) -> RestoreResult:
+ start_time = time.time()
+ run_id = run.workflow_run_id if isinstance(run, WorkflowArchiveLog) else run.id
+ created_at = run.run_created_at if isinstance(run, WorkflowArchiveLog) else run.created_at
+ result = RestoreResult(
+ run_id=run_id,
+ tenant_id=run.tenant_id,
+ success=False,
+ restored_counts={},
+ )
+
+ if not self.dry_run:
+ click.echo(
+ click.style(
+ f"Starting restore for workflow run {run_id} (tenant={run.tenant_id})",
+ fg="white",
+ )
+ )
+
+ try:
+ storage = get_archive_storage()
+ except ArchiveStorageNotConfiguredError as e:
+ result.error = str(e)
+ click.echo(click.style(f"Archive storage not configured: {e}", fg="red"))
+ result.elapsed_time = time.time() - start_time
+ return result
+
+ prefix = (
+ f"{run.tenant_id}/app_id={run.app_id}/year={created_at.strftime('%Y')}/"
+ f"month={created_at.strftime('%m')}/workflow_run_id={run_id}"
+ )
+ archive_key = f"{prefix}/{ARCHIVE_BUNDLE_NAME}"
+ try:
+ archive_data = storage.get_object(archive_key)
+ except FileNotFoundError:
+ result.error = f"Archive bundle not found: {archive_key}"
+ click.echo(click.style(result.error, fg="red"))
+ result.elapsed_time = time.time() - start_time
+ return result
+
+ with session_maker() as session:
+ try:
+ with zipfile.ZipFile(io.BytesIO(archive_data), mode="r") as archive:
+ try:
+ manifest = self._load_manifest_from_zip(archive)
+ except ValueError as e:
+ result.error = f"Archive bundle invalid: {e}"
+ click.echo(click.style(result.error, fg="red"))
+ return result
+
+ tables = manifest.get("tables", {})
+ schema_version = self._get_schema_version(manifest)
+ for table_name, info in tables.items():
+ row_count = info.get("row_count", 0)
+ if row_count == 0:
+ result.restored_counts[table_name] = 0
+ continue
+
+ if self.dry_run:
+ result.restored_counts[table_name] = row_count
+ continue
+
+ member_path = f"{table_name}.jsonl"
+ try:
+ data = archive.read(member_path)
+ except KeyError:
+ click.echo(
+ click.style(
+ f" Warning: Table data not found in archive: {member_path}",
+ fg="yellow",
+ )
+ )
+ result.restored_counts[table_name] = 0
+ continue
+
+ records = ArchiveStorage.deserialize_from_jsonl(data)
+ restored = self._restore_table_records(
+ session,
+ table_name,
+ records,
+ schema_version=schema_version,
+ )
+ result.restored_counts[table_name] = restored
+ if not self.dry_run:
+ click.echo(
+ click.style(
+ f" Restored {restored}/{len(records)} records to {table_name}",
+ fg="white",
+ )
+ )
+
+ # Verify row counts match manifest
+ manifest_total = sum(info.get("row_count", 0) for info in tables.values())
+ restored_total = sum(result.restored_counts.values())
+
+ if not self.dry_run:
+ # Note: restored count might be less than manifest count if records already exist
+ logger.info(
+ "Restore verification: manifest_total=%d, restored_total=%d",
+ manifest_total,
+ restored_total,
+ )
+
+ # Delete the archive log record after successful restore
+ repo = self._get_workflow_run_repo()
+ repo.delete_archive_log_by_run_id(session, run_id)
+
+ session.commit()
+
+ result.success = True
+ if not self.dry_run:
+ click.echo(
+ click.style(
+ f"Completed restore for workflow run {run_id}: restored={result.restored_counts}",
+ fg="green",
+ )
+ )
+
+ except Exception as e:
+ logger.exception("Failed to restore workflow run %s", run_id)
+ result.error = str(e)
+ session.rollback()
+ click.echo(click.style(f"Restore failed: {e}", fg="red"))
+
+ result.elapsed_time = time.time() - start_time
+ return result
+
+ def _get_workflow_run_repo(self) -> APIWorkflowRunRepository:
+ if self.workflow_run_repo is not None:
+ return self.workflow_run_repo
+
+ self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(
+ sessionmaker(bind=db.engine, expire_on_commit=False)
+ )
+ return self.workflow_run_repo
+
+ @staticmethod
+ def _load_manifest_from_zip(archive: zipfile.ZipFile) -> dict[str, Any]:
+ try:
+ data = archive.read("manifest.json")
+ except KeyError as e:
+ raise ValueError("manifest.json missing from archive bundle") from e
+ return json.loads(data.decode("utf-8"))
+
+ def _restore_table_records(
+ self,
+ session: Session,
+ table_name: str,
+ records: list[dict[str, Any]],
+ *,
+ schema_version: str,
+ ) -> int:
+ """
+ Restore records to a table.
+
+ Uses INSERT ... ON CONFLICT DO NOTHING for idempotency.
+
+ Args:
+ session: Database session
+ table_name: Name of the table
+ records: List of record dictionaries
+ schema_version: Archived schema version from manifest
+
+ Returns:
+ Number of records actually inserted
+ """
+ if not records:
+ return 0
+
+ model = TABLE_MODELS.get(table_name)
+ if not model:
+ logger.warning("Unknown table: %s", table_name)
+ return 0
+
+ column_names, required_columns, non_nullable_with_default = self._get_model_column_info(model)
+ unknown_fields: set[str] = set()
+
+ # Apply schema mapping, filter to current columns, then convert datetimes
+ converted_records = []
+ for record in records:
+ mapped = self._apply_schema_mapping(table_name, schema_version, record)
+ unknown_fields.update(set(mapped.keys()) - column_names)
+ filtered = {key: value for key, value in mapped.items() if key in column_names}
+ for key in non_nullable_with_default:
+ if key in filtered and filtered[key] is None:
+ filtered.pop(key)
+ missing_required = [key for key in required_columns if key not in filtered or filtered.get(key) is None]
+ if missing_required:
+ missing_cols = ", ".join(sorted(missing_required))
+ raise ValueError(
+ f"Missing required columns for {table_name} (schema_version={schema_version}): {missing_cols}"
+ )
+ converted = self._convert_datetime_fields(filtered, model)
+ converted_records.append(converted)
+ if unknown_fields:
+ logger.warning(
+ "Dropped unknown columns for %s (schema_version=%s): %s",
+ table_name,
+ schema_version,
+ ", ".join(sorted(unknown_fields)),
+ )
+
+ # Use INSERT ... ON CONFLICT DO NOTHING for idempotency
+ stmt = pg_insert(model).values(converted_records)
+ stmt = stmt.on_conflict_do_nothing(index_elements=["id"])
+
+ result = session.execute(stmt)
+ return cast(CursorResult, result).rowcount or 0
+
+ def _convert_datetime_fields(
+ self,
+ record: dict[str, Any],
+ model: type[DeclarativeBase] | Any,
+ ) -> dict[str, Any]:
+ """Convert ISO datetime strings to datetime objects."""
+ from sqlalchemy import DateTime
+
+ result = dict(record)
+
+ for column in model.__table__.columns:
+ if isinstance(column.type, DateTime):
+ value = result.get(column.key)
+ if isinstance(value, str):
+ try:
+ result[column.key] = datetime.fromisoformat(value)
+ except ValueError:
+ pass
+
+ return result
+
+ def _get_schema_version(self, manifest: dict[str, Any]) -> str:
+ schema_version = manifest.get("schema_version")
+ if not schema_version:
+ logger.warning("Manifest missing schema_version; defaulting to 1.0")
+ schema_version = "1.0"
+ schema_version = str(schema_version)
+ if schema_version not in SCHEMA_MAPPERS:
+ raise ValueError(f"Unsupported schema_version {schema_version}. Add a mapping before restoring.")
+ return schema_version
+
+ def _apply_schema_mapping(
+ self,
+ table_name: str,
+ schema_version: str,
+ record: dict[str, Any],
+ ) -> dict[str, Any]:
+ # Keep hook for forward/backward compatibility when schema evolves.
+ mapper = SCHEMA_MAPPERS.get(schema_version, {}).get(table_name)
+ if mapper is None:
+ return dict(record)
+ return mapper(record)
+
+ def _get_model_column_info(
+ self,
+ model: type[DeclarativeBase] | Any,
+ ) -> tuple[set[str], set[str], set[str]]:
+ columns = list(model.__table__.columns)
+ column_names = {column.key for column in columns}
+ required_columns = {
+ column.key
+ for column in columns
+ if not column.nullable
+ and column.default is None
+ and column.server_default is None
+ and not column.autoincrement
+ }
+ non_nullable_with_default = {
+ column.key
+ for column in columns
+ if not column.nullable
+ and (column.default is not None or column.server_default is not None or column.autoincrement)
+ }
+ return column_names, required_columns, non_nullable_with_default
+
+ def restore_batch(
+ self,
+ tenant_ids: list[str] | None,
+ start_date: datetime,
+ end_date: datetime,
+ limit: int = 100,
+ ) -> list[RestoreResult]:
+ """
+ Restore multiple workflow runs by time range.
+
+ Args:
+ tenant_ids: Optional tenant IDs
+ start_date: Start date filter
+ end_date: End date filter
+ limit: Maximum number of runs to restore (default: 100)
+
+ Returns:
+ List of RestoreResult objects
+ """
+ results: list[RestoreResult] = []
+ if tenant_ids is not None and not tenant_ids:
+ return results
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ repo = self._get_workflow_run_repo()
+
+ with session_maker() as session:
+ archive_logs = repo.get_archived_logs_by_time_range(
+ session=session,
+ tenant_ids=tenant_ids,
+ start_date=start_date,
+ end_date=end_date,
+ limit=limit,
+ )
+
+ click.echo(
+ click.style(
+ f"Found {len(archive_logs)} archived workflow runs to restore",
+ fg="white",
+ )
+ )
+
+ def _restore_with_session(archive_log: WorkflowArchiveLog) -> RestoreResult:
+ return self._restore_from_run(
+ archive_log,
+ session_maker=session_maker,
+ )
+
+ with ThreadPoolExecutor(max_workers=self.workers) as executor:
+ results = list(executor.map(_restore_with_session, archive_logs))
+
+ total_counts: dict[str, int] = {}
+ for result in results:
+ for table_name, count in result.restored_counts.items():
+ total_counts[table_name] = total_counts.get(table_name, 0) + count
+ success_count = sum(1 for result in results if result.success)
+
+ if self.dry_run:
+ click.echo(
+ click.style(
+ f"[DRY RUN] Would restore {len(results)} workflow runs: totals={total_counts}",
+ fg="yellow",
+ )
+ )
+ else:
+ click.echo(
+ click.style(
+ f"Restored {success_count}/{len(results)} workflow runs: totals={total_counts}",
+ fg="green",
+ )
+ )
+
+ return results
+
+ def restore_by_run_id(
+ self,
+ run_id: str,
+ ) -> RestoreResult:
+ """
+ Restore a single workflow run by run ID.
+ """
+ repo = self._get_workflow_run_repo()
+ archive_log = repo.get_archived_log_by_run_id(run_id)
+
+ if not archive_log:
+ click.echo(click.style(f"Workflow run archive {run_id} not found", fg="red"))
+ return RestoreResult(
+ run_id=run_id,
+ tenant_id="",
+ success=False,
+ restored_counts={},
+ error=f"Workflow run archive {run_id} not found",
+ )
+
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ result = self._restore_from_run(archive_log, session_maker=session_maker)
+ if self.dry_run and result.success:
+ click.echo(
+ click.style(
+ f"[DRY RUN] Would restore workflow run {run_id}: totals={result.restored_counts}",
+ fg="yellow",
+ )
+ )
+ return result
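+
+
+# Illustrative usage sketch (editorial example, not part of the original change). Restoring
+# downloads the archive bundle and re-inserts rows with ON CONFLICT DO NOTHING; the run ID
+# below is a placeholder, and dry_run=True only reports the row counts from the manifest.
+if __name__ == "__main__":  # pragma: no cover - example only
+    restore = WorkflowRunRestore(dry_run=True, workers=2)
+
+    # Restore a single run by ID (placeholder ID).
+    result = restore.restore_by_run_id("00000000-0000-0000-0000-000000000000")
+    print(result.success, result.restored_counts)
+
+    # Restore all archived runs inside a time window (all tenants).
+    results = restore.restore_batch(
+        tenant_ids=None,
+        start_date=datetime(2024, 1, 1),
+        end_date=datetime(2024, 2, 1),
+        limit=20,
+    )
+    print(sum(1 for r in results if r.success), "runs restored")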
diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py
new file mode 100644
index 0000000000..b8e1f8bc3f
--- /dev/null
+++ b/api/services/summary_index_service.py
@@ -0,0 +1,1432 @@
+"""Summary index service for generating and managing document segment summaries."""
+
+import logging
+import time
+import uuid
+from datetime import UTC, datetime
+from typing import Any
+
+from sqlalchemy.orm import Session
+
+from core.db.session_factory import session_factory
+from core.model_manager import ModelManager
+from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.entities.model_entities import ModelType
+from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.index_processor.constant.doc_type import DocType
+from core.rag.models.document import Document
+from libs import helper
+from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary
+from models.dataset import Document as DatasetDocument
+
+logger = logging.getLogger(__name__)
+
+
+class SummaryIndexService:
+ """Service for generating and managing summary indexes."""
+
+ @staticmethod
+ def generate_summary_for_segment(
+ segment: DocumentSegment,
+ dataset: Dataset,
+ summary_index_setting: dict,
+ ) -> tuple[str, LLMUsage]:
+ """
+ Generate summary for a single segment.
+
+ Args:
+ segment: DocumentSegment to generate summary for
+ dataset: Dataset containing the segment
+ summary_index_setting: Summary index configuration
+
+ Returns:
+ Tuple of (summary_content, llm_usage) where llm_usage is LLMUsage object
+
+ Raises:
+ ValueError: If summary_index_setting is invalid or generation fails
+ """
+ # Reuse the existing generate_summary method from ParagraphIndexProcessor
+ # Use lazy import to avoid circular import
+ from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
+
+ summary_content, usage = ParagraphIndexProcessor.generate_summary(
+ tenant_id=dataset.tenant_id,
+ text=segment.content,
+ summary_index_setting=summary_index_setting,
+ segment_id=segment.id,
+ )
+
+ if not summary_content:
+ raise ValueError("Generated summary is empty")
+
+ return summary_content, usage
+
+ @staticmethod
+ def create_summary_record(
+ segment: DocumentSegment,
+ dataset: Dataset,
+ summary_content: str,
+ status: str = "generating",
+ ) -> DocumentSegmentSummary:
+ """
+ Create or update a DocumentSegmentSummary record.
+        If a summary record already exists for this segment, it is updated instead of a new one being created.
+
+ Args:
+ segment: DocumentSegment to create summary for
+ dataset: Dataset containing the segment
+ summary_content: Generated summary content
+ status: Summary status (default: "generating")
+
+ Returns:
+ Created or updated DocumentSegmentSummary instance
+ """
+ with session_factory.create_session() as session:
+ # Check if summary record already exists
+ existing_summary = (
+ session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
+ )
+
+ if existing_summary:
+ # Update existing record
+ existing_summary.summary_content = summary_content
+ existing_summary.status = status
+ existing_summary.error = None # type: ignore[assignment] # Clear any previous errors
+ # Re-enable if it was disabled
+ if not existing_summary.enabled:
+ existing_summary.enabled = True
+ existing_summary.disabled_at = None
+ existing_summary.disabled_by = None
+ session.add(existing_summary)
+ session.flush()
+ return existing_summary
+ else:
+ # Create new record (enabled by default)
+ summary_record = DocumentSegmentSummary(
+ dataset_id=dataset.id,
+ document_id=segment.document_id,
+ chunk_id=segment.id,
+ summary_content=summary_content,
+ status=status,
+ enabled=True, # Explicitly set enabled to True
+ )
+ session.add(summary_record)
+ session.flush()
+ return summary_record
+
+ @staticmethod
+ def vectorize_summary(
+ summary_record: DocumentSegmentSummary,
+ segment: DocumentSegment,
+ dataset: Dataset,
+ session: Session | None = None,
+ ) -> None:
+ """
+ Vectorize summary and store in vector database.
+
+ Args:
+ summary_record: DocumentSegmentSummary record
+ segment: Original DocumentSegment
+ dataset: Dataset containing the segment
+ session: Optional SQLAlchemy session. If provided, uses this session instead of creating a new one.
+ If not provided, creates a new session and commits automatically.
+ """
+ if dataset.indexing_technique != "high_quality":
+ logger.warning(
+ "Summary vectorization skipped for dataset %s: indexing_technique is not high_quality",
+ dataset.id,
+ )
+ return
+
+ # Get summary_record_id for later session queries
+ summary_record_id = summary_record.id
+ # Save the original session parameter for use in error handling
+ original_session = session
+ logger.debug(
+ "Starting vectorization for segment %s, summary_record_id=%s, using_provided_session=%s",
+ segment.id,
+ summary_record_id,
+ original_session is not None,
+ )
+
+ # Reuse existing index_node_id if available (like segment does), otherwise generate new one
+ old_summary_node_id = summary_record.summary_index_node_id
+ if old_summary_node_id:
+ # Reuse existing index_node_id (like segment behavior)
+ summary_index_node_id = old_summary_node_id
+ logger.debug("Reusing existing index_node_id %s for segment %s", summary_index_node_id, segment.id)
+ else:
+ # Generate new index node ID only for new summaries
+ summary_index_node_id = str(uuid.uuid4())
+ logger.debug("Generated new index_node_id %s for segment %s", summary_index_node_id, segment.id)
+
+ # Always regenerate hash (in case summary content changed)
+ summary_content = summary_record.summary_content
+ if not summary_content or not summary_content.strip():
+ raise ValueError(f"Summary content is empty for segment {segment.id}, cannot vectorize")
+ summary_hash = helper.generate_text_hash(summary_content)
+
+ # Delete old vector only if we're reusing the same index_node_id (to overwrite)
+ # If index_node_id changed, the old vector should have been deleted elsewhere
+ if old_summary_node_id and old_summary_node_id == summary_index_node_id:
+ try:
+ vector = Vector(dataset)
+ vector.delete_by_ids([old_summary_node_id])
+ except Exception as e:
+ logger.warning(
+ "Failed to delete old summary vector for segment %s: %s. Continuing with new vectorization.",
+ segment.id,
+ str(e),
+ )
+
+ # Calculate embedding tokens for summary (for logging and statistics)
+ embedding_tokens = 0
+ try:
+ model_manager = ModelManager()
+ embedding_model = model_manager.get_model_instance(
+ tenant_id=dataset.tenant_id,
+ provider=dataset.embedding_model_provider,
+ model_type=ModelType.TEXT_EMBEDDING,
+ model=dataset.embedding_model,
+ )
+ if embedding_model:
+ tokens_list = embedding_model.get_text_embedding_num_tokens([summary_content])
+ embedding_tokens = tokens_list[0] if tokens_list else 0
+ except Exception as e:
+ logger.warning("Failed to calculate embedding tokens for summary: %s", str(e))
+
+ # Create document with summary content and metadata
+ summary_document = Document(
+ page_content=summary_content,
+ metadata={
+ "doc_id": summary_index_node_id,
+ "doc_hash": summary_hash,
+ "dataset_id": dataset.id,
+ "document_id": segment.document_id,
+ "original_chunk_id": segment.id, # Key: link to original chunk
+ "doc_type": DocType.TEXT,
+ "is_summary": True, # Identifier for summary documents
+ },
+ )
+
+ # Vectorize and store with retry mechanism for connection errors
+ max_retries = 3
+ retry_delay = 2.0
+
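+        # With max_retries=3 and retry_delay=2.0, the wait before each retry is retry_delay * 2**attempt:
+        # 2.0s after the first failure, 4.0s after the second; the third failure is final, and only
+        # connection-type errors are retried at all.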
+ for attempt in range(max_retries):
+ try:
+ logger.debug(
+ "Attempting to vectorize summary for segment %s (attempt %s/%s)",
+ segment.id,
+ attempt + 1,
+ max_retries,
+ )
+ vector = Vector(dataset)
+ # Use duplicate_check=False to ensure re-vectorization even if old vector still exists
+ # The old vector should have been deleted above, but if deletion failed,
+ # we still want to re-vectorize (upsert will overwrite)
+ vector.add_texts([summary_document], duplicate_check=False)
+ logger.debug(
+ "Successfully added summary vector to database for segment %s (attempt %s/%s)",
+ segment.id,
+ attempt + 1,
+ max_retries,
+ )
+
+ # Log embedding token usage
+ if embedding_tokens > 0:
+ logger.info(
+ "Summary embedding for segment %s used %s tokens",
+ segment.id,
+ embedding_tokens,
+ )
+
+ # Success - update summary record with index node info
+ # Use provided session if available, otherwise create a new one
+ use_provided_session = session is not None
+ if not use_provided_session:
+ logger.debug("Creating new session for vectorization of segment %s", segment.id)
+ session_context = session_factory.create_session()
+ session = session_context.__enter__()
+ else:
+ logger.debug("Using provided session for vectorization of segment %s", segment.id)
+ session_context = None # Don't use context manager for provided session
+
+            # Type narrowing: session cannot be None after the if/else above
+ if session is None:
+ raise RuntimeError("Session should not be None at this point")
+
+ try:
+ # Declare summary_record_in_session variable
+ summary_record_in_session: DocumentSegmentSummary | None
+
+ # If using provided session, merge the summary_record into it
+ if use_provided_session:
+ # Merge the summary_record into the provided session
+ logger.debug(
+ "Merging summary_record (id=%s) into provided session for segment %s",
+ summary_record_id,
+ segment.id,
+ )
+ summary_record_in_session = session.merge(summary_record)
+ logger.debug(
+ "Successfully merged summary_record for segment %s, merged_id=%s",
+ segment.id,
+ summary_record_in_session.id,
+ )
+ else:
+ # Query the summary record in the new session
+ logger.debug(
+ "Querying summary_record by id=%s for segment %s in new session",
+ summary_record_id,
+ segment.id,
+ )
+ summary_record_in_session = (
+ session.query(DocumentSegmentSummary).filter_by(id=summary_record_id).first()
+ )
+
+ if not summary_record_in_session:
+ # Record not found - try to find by chunk_id and dataset_id instead
+ logger.debug(
+ "Summary record not found by id=%s, trying chunk_id=%s and dataset_id=%s "
+ "for segment %s",
+ summary_record_id,
+ segment.id,
+ dataset.id,
+ segment.id,
+ )
+ summary_record_in_session = (
+ session.query(DocumentSegmentSummary)
+ .filter_by(chunk_id=segment.id, dataset_id=dataset.id)
+ .first()
+ )
+
+ if not summary_record_in_session:
+ # Still not found - create a new one using the parameter data
+ logger.warning(
+ "Summary record not found in database for segment %s (id=%s), creating new one. "
+ "This may indicate a session isolation issue.",
+ segment.id,
+ summary_record_id,
+ )
+ summary_record_in_session = DocumentSegmentSummary(
+ id=summary_record_id, # Use the same ID if available
+ dataset_id=dataset.id,
+ document_id=segment.document_id,
+ chunk_id=segment.id,
+ summary_content=summary_content,
+ summary_index_node_id=summary_index_node_id,
+ summary_index_node_hash=summary_hash,
+ tokens=embedding_tokens,
+ status="completed",
+ enabled=True,
+ )
+ session.add(summary_record_in_session)
+ logger.info(
+ "Created new summary record (id=%s) for segment %s after vectorization",
+ summary_record_id,
+ segment.id,
+ )
+ else:
+ # Found by chunk_id - update it
+ logger.info(
+ "Found summary record for segment %s by chunk_id "
+ "(id mismatch: expected %s, found %s). "
+ "This may indicate the record was created in a different session.",
+ segment.id,
+ summary_record_id,
+ summary_record_in_session.id,
+ )
+ else:
+ logger.debug(
+ "Found summary_record (id=%s) for segment %s in new session",
+ summary_record_id,
+ segment.id,
+ )
+
+ # At this point, summary_record_in_session is guaranteed to be not None
+ if summary_record_in_session is None:
+ raise RuntimeError("summary_record_in_session should not be None at this point")
+
+ # Update all fields including summary_content
+ # Always use the summary_content from the parameter (which is the latest from outer session)
+ # rather than relying on what's in the database, in case outer session hasn't committed yet
+ summary_record_in_session.summary_index_node_id = summary_index_node_id
+ summary_record_in_session.summary_index_node_hash = summary_hash
+ summary_record_in_session.tokens = embedding_tokens # Save embedding tokens
+ summary_record_in_session.status = "completed"
+ # Ensure summary_content is preserved (use the latest from summary_record parameter)
+ # This is critical: use the parameter value, not the database value
+ summary_record_in_session.summary_content = summary_content
+ # Explicitly update updated_at to ensure it's refreshed even if other fields haven't changed
+ summary_record_in_session.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ session.add(summary_record_in_session)
+
+ # Only commit if we created the session ourselves
+ if not use_provided_session:
+ logger.debug("Committing session for segment %s (self-created session)", segment.id)
+ session.commit()
+ logger.debug("Successfully committed session for segment %s", segment.id)
+ else:
+ # When using provided session, flush to ensure changes are written to database
+ # This prevents refresh() from overwriting our changes
+ logger.debug(
+ "Flushing session for segment %s (using provided session, caller will commit)",
+ segment.id,
+ )
+ session.flush()
+ logger.debug("Successfully flushed session for segment %s", segment.id)
+ # If using provided session, let the caller handle commit
+
+ logger.info(
+ "Successfully vectorized summary for segment %s, index_node_id=%s, index_node_hash=%s, "
+ "tokens=%s, summary_record_id=%s, use_provided_session=%s",
+ segment.id,
+ summary_index_node_id,
+ summary_hash,
+ embedding_tokens,
+ summary_record_in_session.id,
+ use_provided_session,
+ )
+ # Update the original object for consistency
+ summary_record.summary_index_node_id = summary_index_node_id
+ summary_record.summary_index_node_hash = summary_hash
+ summary_record.tokens = embedding_tokens
+ summary_record.status = "completed"
+ summary_record.summary_content = summary_content
+ if summary_record_in_session.updated_at:
+ summary_record.updated_at = summary_record_in_session.updated_at
+ finally:
+ # Only close session if we created it ourselves
+ if not use_provided_session and session_context:
+ session_context.__exit__(None, None, None)
+ # Success, exit function
+ return
+
+            except Exception as e:
+ error_str = str(e).lower()
+ # Check if it's a connection-related error that might be transient
+ is_connection_error = any(
+ keyword in error_str
+ for keyword in [
+ "connection",
+ "disconnected",
+ "timeout",
+ "network",
+ "could not connect",
+ "server disconnected",
+ "weaviate",
+ ]
+ )
+
+ if is_connection_error and attempt < max_retries - 1:
+ # Retry for connection errors
+ wait_time = retry_delay * (2**attempt) # Exponential backoff
+ logger.warning(
+ "Vectorization attempt %s/%s failed for segment %s (connection error): %s. "
+ "Retrying in %.1f seconds...",
+ attempt + 1,
+ max_retries,
+ segment.id,
+ str(e),
+ wait_time,
+ )
+ time.sleep(wait_time)
+ continue
+ else:
+ # Final attempt failed or non-connection error - log and update status
+ logger.error(
+ "Failed to vectorize summary for segment %s after %s attempts: %s. "
+ "summary_record_id=%s, index_node_id=%s, use_provided_session=%s",
+ segment.id,
+ attempt + 1,
+ str(e),
+ summary_record_id,
+ summary_index_node_id,
+ session is not None,
+ exc_info=True,
+ )
+ # Update error status in session
+ # Use the original_session saved at function start (the function parameter)
+ logger.debug(
+ "Updating error status for segment %s, summary_record_id=%s, has_original_session=%s",
+ segment.id,
+ summary_record_id,
+ original_session is not None,
+ )
+ # Always create a new session for error handling to avoid issues with closed sessions
+ # Even if original_session was provided, we create a new one for safety
+ with session_factory.create_session() as error_session:
+ # Try to find the record by id first
+ # Note: Using assignment only (no type annotation) to avoid redeclaration error
+ summary_record_in_session = (
+ error_session.query(DocumentSegmentSummary).filter_by(id=summary_record_id).first()
+ )
+ if not summary_record_in_session:
+ # Try to find by chunk_id and dataset_id
+ logger.debug(
+ "Summary record not found by id=%s, trying chunk_id=%s and dataset_id=%s "
+ "for segment %s",
+ summary_record_id,
+ segment.id,
+ dataset.id,
+ segment.id,
+ )
+ summary_record_in_session = (
+ error_session.query(DocumentSegmentSummary)
+ .filter_by(chunk_id=segment.id, dataset_id=dataset.id)
+ .first()
+ )
+
+ if summary_record_in_session:
+ summary_record_in_session.status = "error"
+ summary_record_in_session.error = f"Vectorization failed: {str(e)}"
+ summary_record_in_session.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ error_session.add(summary_record_in_session)
+ error_session.commit()
+ logger.info(
+ "Updated error status in new session for segment %s, record_id=%s",
+ segment.id,
+ summary_record_in_session.id,
+ )
+ # Update the original object for consistency
+ summary_record.status = "error"
+ summary_record.error = summary_record_in_session.error
+ summary_record.updated_at = summary_record_in_session.updated_at
+ else:
+ logger.warning(
+ "Could not update error status: summary record not found for segment %s (id=%s). "
+ "This may indicate a session isolation issue.",
+ segment.id,
+ summary_record_id,
+ )
+ raise
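+
+    # Usage sketch (illustrative only; mirrors the session semantics documented in the docstring above):
+    #
+    #     # Self-managed session: vectorize_summary opens, commits and closes its own session.
+    #     SummaryIndexService.vectorize_summary(summary_record, segment, dataset)
+    #
+    #     # Caller-managed session: changes are only flushed; the caller commits.
+    #     with session_factory.create_session() as session:
+    #         SummaryIndexService.vectorize_summary(summary_record, segment, dataset, session=session)
+    #         session.commit()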
+
+ @staticmethod
+ def batch_create_summary_records(
+ segments: list[DocumentSegment],
+ dataset: Dataset,
+ status: str = "not_started",
+ ) -> None:
+ """
+ Batch create summary records for segments with specified status.
+ If a record already exists, update its status.
+
+ Args:
+ segments: List of DocumentSegment instances
+ dataset: Dataset containing the segments
+ status: Initial status for the records (default: "not_started")
+ """
+ segment_ids = [segment.id for segment in segments]
+ if not segment_ids:
+ return
+
+ with session_factory.create_session() as session:
+ # Query existing summary records
+ existing_summaries = (
+ session.query(DocumentSegmentSummary)
+ .filter(
+ DocumentSegmentSummary.chunk_id.in_(segment_ids),
+ DocumentSegmentSummary.dataset_id == dataset.id,
+ )
+ .all()
+ )
+ existing_summary_map = {summary.chunk_id: summary for summary in existing_summaries}
+
+ # Create or update records
+ for segment in segments:
+ existing_summary = existing_summary_map.get(segment.id)
+ if existing_summary:
+ # Update existing record
+ existing_summary.status = status
+ existing_summary.error = None # type: ignore[assignment] # Clear any previous errors
+ if not existing_summary.enabled:
+ existing_summary.enabled = True
+ existing_summary.disabled_at = None
+ existing_summary.disabled_by = None
+ session.add(existing_summary)
+ else:
+ # Create new record
+ summary_record = DocumentSegmentSummary(
+ dataset_id=dataset.id,
+ document_id=segment.document_id,
+ chunk_id=segment.id,
+ summary_content=None, # Will be filled later
+ status=status,
+ enabled=True,
+ )
+ session.add(summary_record)
+
+ @staticmethod
+ def update_summary_record_error(
+ segment: DocumentSegment,
+ dataset: Dataset,
+ error: str,
+ ) -> None:
+ """
+ Update summary record with error status.
+
+ Args:
+ segment: DocumentSegment
+ dataset: Dataset containing the segment
+ error: Error message
+ """
+ with session_factory.create_session() as session:
+ summary_record = (
+ session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
+ )
+
+ if summary_record:
+ summary_record.status = "error"
+ summary_record.error = error
+ session.add(summary_record)
+ session.commit()
+ else:
+ logger.warning("Summary record not found for segment %s when updating error", segment.id)
+
+ @staticmethod
+ def generate_and_vectorize_summary(
+ segment: DocumentSegment,
+ dataset: Dataset,
+ summary_index_setting: dict,
+ ) -> DocumentSegmentSummary:
+ """
+ Generate summary for a segment and vectorize it.
+ Assumes summary record already exists (created by batch_create_summary_records).
+
+ Args:
+ segment: DocumentSegment to generate summary for
+ dataset: Dataset containing the segment
+ summary_index_setting: Summary index configuration
+
+ Returns:
+ Created DocumentSegmentSummary instance
+
+ Raises:
+ ValueError: If summary generation fails
+ """
+ with session_factory.create_session() as session:
+ try:
+ # Get or refresh summary record in this session
+ summary_record_in_session = (
+ session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
+ )
+
+ if not summary_record_in_session:
+ # If not found, create one
+ logger.warning("Summary record not found for segment %s, creating one", segment.id)
+ summary_record_in_session = DocumentSegmentSummary(
+ dataset_id=dataset.id,
+ document_id=segment.document_id,
+ chunk_id=segment.id,
+ summary_content="",
+ status="generating",
+ enabled=True,
+ )
+ session.add(summary_record_in_session)
+ session.flush()
+
+ # Update status to "generating"
+ summary_record_in_session.status = "generating"
+ summary_record_in_session.error = None # type: ignore[assignment]
+ session.add(summary_record_in_session)
+ # Don't flush here - wait until after vectorization succeeds
+
+ # Generate summary (returns summary_content and llm_usage)
+ summary_content, llm_usage = SummaryIndexService.generate_summary_for_segment(
+ segment, dataset, summary_index_setting
+ )
+
+ # Update summary content
+ summary_record_in_session.summary_content = summary_content
+ session.add(summary_record_in_session)
+ # Flush to ensure summary_content is saved before vectorize_summary queries it
+ session.flush()
+
+ # Log LLM usage for summary generation
+ if llm_usage and llm_usage.total_tokens > 0:
+ logger.info(
+ "Summary generation for segment %s used %s tokens (prompt: %s, completion: %s)",
+ segment.id,
+ llm_usage.total_tokens,
+ llm_usage.prompt_tokens,
+ llm_usage.completion_tokens,
+ )
+
+                # Vectorize the summary (the old vector, if any, is removed before the new one is added).
+                # vectorize_summary updates status to "completed", tokens and summary_content on the
+                # provided session; the commit below then persists everything together.
+ try:
+ # Pass the session to vectorize_summary to avoid session isolation issues
+ SummaryIndexService.vectorize_summary(summary_record_in_session, segment, dataset, session=session)
+ # Refresh the object from database to get the updated status and tokens from vectorize_summary
+ session.refresh(summary_record_in_session)
+ # Commit the session
+ # (summary_record_in_session should have status="completed" and tokens from refresh)
+ session.commit()
+ logger.info("Successfully generated and vectorized summary for segment %s", segment.id)
+ return summary_record_in_session
+ except Exception as vectorize_error:
+ # If vectorization fails, update status to error in current session
+ logger.exception("Failed to vectorize summary for segment %s", segment.id)
+ summary_record_in_session.status = "error"
+ summary_record_in_session.error = f"Vectorization failed: {str(vectorize_error)}"
+ session.add(summary_record_in_session)
+ session.commit()
+ raise
+
+ except Exception as e:
+ logger.exception("Failed to generate summary for segment %s", segment.id)
+ # Update summary record with error status
+ summary_record_in_session = (
+ session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
+ )
+ if summary_record_in_session:
+ summary_record_in_session.status = "error"
+ summary_record_in_session.error = str(e)
+ session.add(summary_record_in_session)
+ session.commit()
+ raise
+
+ @staticmethod
+ def generate_summaries_for_document(
+ dataset: Dataset,
+ document: DatasetDocument,
+ summary_index_setting: dict,
+ segment_ids: list[str] | None = None,
+ only_parent_chunks: bool = False,
+ ) -> list[DocumentSegmentSummary]:
+ """
+ Generate summaries for all segments in a document including vectorization.
+
+ Args:
+ dataset: Dataset containing the document
+ document: DatasetDocument to generate summaries for
+ summary_index_setting: Summary index configuration
+ segment_ids: Optional list of specific segment IDs to process
+ only_parent_chunks: If True, only process parent chunks (for parent-child mode)
+
+ Returns:
+ List of created DocumentSegmentSummary instances
+ """
+ # Only generate summary index for high_quality indexing technique
+ if dataset.indexing_technique != "high_quality":
+ logger.info(
+ "Skipping summary generation for dataset %s: indexing_technique is %s, not 'high_quality'",
+ dataset.id,
+ dataset.indexing_technique,
+ )
+ return []
+
+ if not summary_index_setting or not summary_index_setting.get("enable"):
+ logger.info("Summary index is disabled for dataset %s", dataset.id)
+ return []
+
+ # Skip qa_model documents
+ if document.doc_form == "qa_model":
+ logger.info("Skipping summary generation for qa_model document %s", document.id)
+ return []
+
+ logger.info(
+ "Starting summary generation for document %s in dataset %s, segment_ids: %s, only_parent_chunks: %s",
+ document.id,
+ dataset.id,
+ len(segment_ids) if segment_ids else "all",
+ only_parent_chunks,
+ )
+
+ with session_factory.create_session() as session:
+ # Query segments (only enabled segments)
+ query = session.query(DocumentSegment).filter_by(
+ dataset_id=dataset.id,
+ document_id=document.id,
+ status="completed",
+ enabled=True, # Only generate summaries for enabled segments
+ )
+
+ if segment_ids:
+ query = query.filter(DocumentSegment.id.in_(segment_ids))
+
+ segments = query.all()
+
+ if not segments:
+ logger.info("No segments found for document %s", document.id)
+ return []
+
+ # Batch create summary records with "not_started" status before processing
+ # This ensures all records exist upfront, allowing status tracking
+ SummaryIndexService.batch_create_summary_records(
+ segments=segments,
+ dataset=dataset,
+ status="not_started",
+ )
+ session.commit() # Commit initial records
+
+ summary_records = []
+
+ for segment in segments:
+                if only_parent_chunks:
+                    # In parent-child mode every DocumentSegment returned by the query above is a parent
+                    # chunk; child chunks live in the ChildChunk table and never appear here, so there is
+                    # nothing extra to filter out.
+                    pass
+
+ try:
+ summary_record = SummaryIndexService.generate_and_vectorize_summary(
+ segment, dataset, summary_index_setting
+ )
+ summary_records.append(summary_record)
+ except Exception as e:
+ logger.exception("Failed to generate summary for segment %s", segment.id)
+ # Update summary record with error status
+ SummaryIndexService.update_summary_record_error(
+ segment=segment,
+ dataset=dataset,
+ error=str(e),
+ )
+ # Continue with other segments
+ continue
+
+ logger.info(
+ "Completed summary generation for document %s: %s summaries generated and vectorized",
+ document.id,
+ len(summary_records),
+ )
+ return summary_records
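+
+    # Usage sketch (hypothetical caller such as a background indexing task; the exact shape of
+    # summary_index_setting beyond the "enable" flag is an assumption):
+    #
+    #     summaries = SummaryIndexService.generate_summaries_for_document(
+    #         dataset=dataset,
+    #         document=document,
+    #         summary_index_setting={"enable": True},
+    #     )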
+
+ @staticmethod
+ def disable_summaries_for_segments(
+ dataset: Dataset,
+ segment_ids: list[str] | None = None,
+ disabled_by: str | None = None,
+ ) -> None:
+ """
+ Disable summary records and remove vectors from vector database for segments.
+ Unlike delete, this preserves the summary records but marks them as disabled.
+
+ Args:
+ dataset: Dataset containing the segments
+ segment_ids: List of segment IDs to disable summaries for. If None, disable all.
+ disabled_by: User ID who disabled the summaries
+ """
+ from libs.datetime_utils import naive_utc_now
+
+ with session_factory.create_session() as session:
+ query = session.query(DocumentSegmentSummary).filter_by(
+ dataset_id=dataset.id,
+ enabled=True, # Only disable enabled summaries
+ )
+
+ if segment_ids:
+ query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
+
+ summaries = query.all()
+
+ if not summaries:
+ return
+
+ logger.info(
+ "Disabling %s summary records for dataset %s, segment_ids: %s",
+ len(summaries),
+ dataset.id,
+ len(segment_ids) if segment_ids else "all",
+ )
+
+ # Remove from vector database (but keep records)
+ if dataset.indexing_technique == "high_quality":
+ summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id]
+ if summary_node_ids:
+ try:
+ vector = Vector(dataset)
+ vector.delete_by_ids(summary_node_ids)
+ except Exception as e:
+ logger.warning("Failed to remove summary vectors: %s", str(e))
+
+ # Disable summary records (don't delete)
+ now = naive_utc_now()
+ for summary in summaries:
+ summary.enabled = False
+ summary.disabled_at = now
+ summary.disabled_by = disabled_by
+ session.add(summary)
+
+ session.commit()
+ logger.info("Disabled %s summary records for dataset %s", len(summaries), dataset.id)
+
+ @staticmethod
+ def enable_summaries_for_segments(
+ dataset: Dataset,
+ segment_ids: list[str] | None = None,
+ ) -> None:
+ """
+ Enable summary records and re-add vectors to vector database for segments.
+
+ Note: This method enables summaries based on chunk status, not summary_index_setting.enable.
+ The summary_index_setting.enable flag only controls automatic generation,
+ not whether existing summaries can be used.
+ Summary.enabled should always be kept in sync with chunk.enabled.
+
+ Args:
+ dataset: Dataset containing the segments
+ segment_ids: List of segment IDs to enable summaries for. If None, enable all.
+ """
+ # Only enable summary index for high_quality indexing technique
+ if dataset.indexing_technique != "high_quality":
+ return
+
+ with session_factory.create_session() as session:
+ query = session.query(DocumentSegmentSummary).filter_by(
+ dataset_id=dataset.id,
+ enabled=False, # Only enable disabled summaries
+ )
+
+ if segment_ids:
+ query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
+
+ summaries = query.all()
+
+ if not summaries:
+ return
+
+ logger.info(
+ "Enabling %s summary records for dataset %s, segment_ids: %s",
+ len(summaries),
+ dataset.id,
+ len(segment_ids) if segment_ids else "all",
+ )
+
+ # Re-vectorize and re-add to vector database
+ enabled_count = 0
+ for summary in summaries:
+ # Get the original segment
+ segment = (
+ session.query(DocumentSegment)
+ .filter_by(
+ id=summary.chunk_id,
+ dataset_id=dataset.id,
+ )
+ .first()
+ )
+
+ # Summary.enabled stays in sync with chunk.enabled,
+ # only enable summary if the associated chunk is enabled.
+ if not segment or not segment.enabled or segment.status != "completed":
+ continue
+
+ if not summary.summary_content:
+ continue
+
+ try:
+                    # Re-vectorize the summary; vectorize_summary updates status and tokens on the
+                    # provided session (passing it avoids session isolation issues).
+ SummaryIndexService.vectorize_summary(summary, segment, dataset, session=session)
+
+ # Refresh the object from database to get the updated status and tokens from vectorize_summary
+ session.refresh(summary)
+
+ # Enable summary record
+ summary.enabled = True
+ summary.disabled_at = None
+ summary.disabled_by = None
+ session.add(summary)
+ enabled_count += 1
+ except Exception:
+ logger.exception("Failed to re-vectorize summary %s", summary.id)
+ # Keep it disabled if vectorization fails
+ continue
+
+ session.commit()
+ logger.info("Enabled %s summary records for dataset %s", enabled_count, dataset.id)
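+
+    # Note: disable_summaries_for_segments and enable_summaries_for_segments form a pair; as noted in the
+    # docstrings above, summary.enabled tracks the owning chunk's enabled flag rather than
+    # summary_index_setting.enable, which only controls automatic generation.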
+
+ @staticmethod
+ def delete_summaries_for_segments(
+ dataset: Dataset,
+ segment_ids: list[str] | None = None,
+ ) -> None:
+ """
+ Delete summary records and vectors for segments (used only for actual deletion scenarios).
+ For disable/enable operations, use disable_summaries_for_segments/enable_summaries_for_segments.
+
+ Args:
+ dataset: Dataset containing the segments
+ segment_ids: List of segment IDs to delete summaries for. If None, delete all.
+ """
+ with session_factory.create_session() as session:
+ query = session.query(DocumentSegmentSummary).filter_by(dataset_id=dataset.id)
+
+ if segment_ids:
+ query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
+
+ summaries = query.all()
+
+ if not summaries:
+ return
+
+ # Delete from vector database
+ if dataset.indexing_technique == "high_quality":
+ summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id]
+ if summary_node_ids:
+ vector = Vector(dataset)
+ vector.delete_by_ids(summary_node_ids)
+
+ # Delete summary records
+ for summary in summaries:
+ session.delete(summary)
+
+ session.commit()
+ logger.info("Deleted %s summary records for dataset %s", len(summaries), dataset.id)
+
+ @staticmethod
+ def update_summary_for_segment(
+ segment: DocumentSegment,
+ dataset: Dataset,
+ summary_content: str,
+ ) -> DocumentSegmentSummary | None:
+ """
+ Update summary for a segment and re-vectorize it.
+
+ Args:
+ segment: DocumentSegment to update summary for
+ dataset: Dataset containing the segment
+ summary_content: New summary content
+
+ Returns:
+ Updated DocumentSegmentSummary instance, or None if indexing technique is not high_quality
+ """
+ # Only update summary index for high_quality indexing technique
+ if dataset.indexing_technique != "high_quality":
+ return None
+
+ # When user manually provides summary, allow saving even if summary_index_setting doesn't exist
+ # summary_index_setting is only needed for LLM generation, not for manual summary vectorization
+ # Vectorization uses dataset.embedding_model, which doesn't require summary_index_setting
+
+ # Skip qa_model documents
+ if segment.document and segment.document.doc_form == "qa_model":
+ return None
+
+ with session_factory.create_session() as session:
+ try:
+ # Check if summary_content is empty (whitespace-only strings are considered empty)
+ if not summary_content or not summary_content.strip():
+ # If summary is empty, only delete existing summary vector and record
+ summary_record = (
+ session.query(DocumentSegmentSummary)
+ .filter_by(chunk_id=segment.id, dataset_id=dataset.id)
+ .first()
+ )
+
+ if summary_record:
+ # Delete old vector if exists
+ old_summary_node_id = summary_record.summary_index_node_id
+ if old_summary_node_id:
+ try:
+ vector = Vector(dataset)
+ vector.delete_by_ids([old_summary_node_id])
+ except Exception as e:
+ logger.warning(
+ "Failed to delete old summary vector for segment %s: %s",
+ segment.id,
+ str(e),
+ )
+
+ # Delete summary record since summary is empty
+ session.delete(summary_record)
+ session.commit()
+ logger.info("Deleted summary for segment %s (empty content provided)", segment.id)
+ return None
+ else:
+ # No existing summary record, nothing to do
+ logger.info("No summary record found for segment %s, nothing to delete", segment.id)
+ return None
+
+ # Find existing summary record
+ summary_record = (
+ session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
+ )
+
+ if summary_record:
+ # Update existing summary
+ old_summary_node_id = summary_record.summary_index_node_id
+
+ # Update summary content
+ summary_record.summary_content = summary_content
+ summary_record.status = "generating"
+ summary_record.error = None # type: ignore[assignment] # Clear any previous errors
+ session.add(summary_record)
+ # Flush to ensure summary_content is saved before vectorize_summary queries it
+ session.flush()
+
+ # Delete old vector if exists (before vectorization)
+ if old_summary_node_id:
+ try:
+ vector = Vector(dataset)
+ vector.delete_by_ids([old_summary_node_id])
+ except Exception as e:
+ logger.warning(
+ "Failed to delete old summary vector for segment %s: %s",
+ segment.id,
+ str(e),
+ )
+
+                    # Re-vectorize the summary; vectorize_summary updates status to "completed", tokens
+                    # and summary_content on the provided session.
+ # Note: vectorize_summary may take time due to embedding API calls, but we need to complete it
+ # to ensure the summary is properly indexed
+ try:
+ # Pass the session to vectorize_summary to avoid session isolation issues
+ SummaryIndexService.vectorize_summary(summary_record, segment, dataset, session=session)
+ # Refresh the object from database to get the updated status and tokens from vectorize_summary
+ session.refresh(summary_record)
+ # Now commit the session (summary_record should have status="completed" and tokens from refresh)
+ session.commit()
+ logger.info("Successfully updated and re-vectorized summary for segment %s", segment.id)
+ return summary_record
+ except Exception as e:
+ # If vectorization fails, update status to error in current session
+ # Don't raise the exception - just log it and return the record with error status
+ # This allows the segment update to complete even if vectorization fails
+ summary_record.status = "error"
+ summary_record.error = f"Vectorization failed: {str(e)}"
+ session.commit()
+ logger.exception("Failed to vectorize summary for segment %s", segment.id)
+ # Return the record with error status instead of raising
+ # The caller can check the status if needed
+ return summary_record
+ else:
+ # Create new summary record if doesn't exist
+ summary_record = SummaryIndexService.create_summary_record(
+ segment, dataset, summary_content, status="generating"
+ )
+                    # Re-vectorize the summary; vectorize_summary updates status to "completed" and tokens
+                    # on the provided session.
+ # Note: summary_record was created in a different session,
+ # so we need to merge it into current session
+ try:
+ # Merge the record into current session first (since it was created in a different session)
+ summary_record = session.merge(summary_record)
+ # Pass the session to vectorize_summary - it will update the merged record
+ SummaryIndexService.vectorize_summary(summary_record, segment, dataset, session=session)
+ # Refresh to get updated status and tokens from database
+ session.refresh(summary_record)
+ # Commit the session to persist the changes
+ session.commit()
+ logger.info("Successfully created and vectorized summary for segment %s", segment.id)
+ return summary_record
+ except Exception as e:
+ # If vectorization fails, update status to error in current session
+ # Merge the record into current session first
+ error_record = session.merge(summary_record)
+ error_record.status = "error"
+ error_record.error = f"Vectorization failed: {str(e)}"
+ session.commit()
+ logger.exception("Failed to vectorize summary for segment %s", segment.id)
+ # Return the record with error status instead of raising
+ return error_record
+
+ except Exception as e:
+ logger.exception("Failed to update summary for segment %s", segment.id)
+ # Update summary record with error status if it exists
+ summary_record = (
+ session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
+ )
+ if summary_record:
+ summary_record.status = "error"
+ summary_record.error = str(e)
+ session.add(summary_record)
+ session.commit()
+ raise
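+
+    # Behaviour summary (derived from the implementation above): an empty or whitespace-only
+    # summary_content deletes the existing summary record and its vector; a non-empty value updates
+    # (or creates) the record and re-vectorizes it, returning the record even if vectorization fails.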
+
+ @staticmethod
+ def get_segment_summary(segment_id: str, dataset_id: str) -> DocumentSegmentSummary | None:
+ """
+ Get summary for a single segment.
+
+ Args:
+ segment_id: Segment ID (chunk_id)
+ dataset_id: Dataset ID
+
+ Returns:
+ DocumentSegmentSummary instance if found, None otherwise
+ """
+ with session_factory.create_session() as session:
+ return (
+ session.query(DocumentSegmentSummary)
+ .where(
+ DocumentSegmentSummary.chunk_id == segment_id,
+ DocumentSegmentSummary.dataset_id == dataset_id,
+ DocumentSegmentSummary.enabled == True, # Only return enabled summaries
+ )
+ .first()
+ )
+
+ @staticmethod
+ def get_segments_summaries(segment_ids: list[str], dataset_id: str) -> dict[str, DocumentSegmentSummary]:
+ """
+ Get summaries for multiple segments.
+
+ Args:
+ segment_ids: List of segment IDs (chunk_ids)
+ dataset_id: Dataset ID
+
+ Returns:
+ Dictionary mapping segment_id to DocumentSegmentSummary (only enabled summaries)
+ """
+ if not segment_ids:
+ return {}
+
+ with session_factory.create_session() as session:
+ summary_records = (
+ session.query(DocumentSegmentSummary)
+ .where(
+ DocumentSegmentSummary.chunk_id.in_(segment_ids),
+ DocumentSegmentSummary.dataset_id == dataset_id,
+ DocumentSegmentSummary.enabled == True, # Only return enabled summaries
+ )
+ .all()
+ )
+
+ return {summary.chunk_id: summary for summary in summary_records}
+
+ @staticmethod
+ def get_document_summaries(
+ document_id: str, dataset_id: str, segment_ids: list[str] | None = None
+ ) -> list[DocumentSegmentSummary]:
+ """
+ Get all summary records for a document.
+
+ Args:
+ document_id: Document ID
+ dataset_id: Dataset ID
+ segment_ids: Optional list of segment IDs to filter by
+
+ Returns:
+ List of DocumentSegmentSummary instances (only enabled summaries)
+ """
+ with session_factory.create_session() as session:
+ query = session.query(DocumentSegmentSummary).filter(
+ DocumentSegmentSummary.document_id == document_id,
+ DocumentSegmentSummary.dataset_id == dataset_id,
+ DocumentSegmentSummary.enabled == True, # Only return enabled summaries
+ )
+
+ if segment_ids:
+ query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
+
+ return query.all()
+
+ @staticmethod
+ def get_document_summary_index_status(document_id: str, dataset_id: str, tenant_id: str) -> str | None:
+ """
+ Get summary_index_status for a single document.
+
+ Args:
+ document_id: Document ID
+ dataset_id: Dataset ID
+ tenant_id: Tenant ID
+
+ Returns:
+ "SUMMARIZING" if there are pending summaries, None otherwise
+ """
+        # Get all segments for this document (segments in "re_segment" status are excluded)
+ with session_factory.create_session() as session:
+ segments = (
+ session.query(DocumentSegment.id)
+ .where(
+ DocumentSegment.document_id == document_id,
+ DocumentSegment.status != "re_segment",
+ DocumentSegment.tenant_id == tenant_id,
+ )
+ .all()
+ )
+ segment_ids = [seg.id for seg in segments]
+
+ if not segment_ids:
+ return None
+
+ # Get all summary records for these segments
+ summaries = SummaryIndexService.get_segments_summaries(segment_ids, dataset_id)
+ summary_status_map = {chunk_id: summary.status for chunk_id, summary in summaries.items()}
+
+ # Check if there are any "not_started" or "generating" status summaries
+ has_pending_summaries = any(
+ summary_status_map.get(segment_id) is not None # Ensure summary exists (enabled=True)
+ and summary_status_map[segment_id] in ("not_started", "generating")
+ for segment_id in segment_ids
+ )
+
+ return "SUMMARIZING" if has_pending_summaries else None
+
+ @staticmethod
+ def get_documents_summary_index_status(
+ document_ids: list[str], dataset_id: str, tenant_id: str
+ ) -> dict[str, str | None]:
+ """
+ Get summary_index_status for multiple documents.
+
+ Args:
+ document_ids: List of document IDs
+ dataset_id: Dataset ID
+ tenant_id: Tenant ID
+
+ Returns:
+ Dictionary mapping document_id to summary_index_status ("SUMMARIZING" or None)
+ """
+ if not document_ids:
+ return {}
+
+        # Get all segments for these documents (segments in "re_segment" status are excluded)
+ with session_factory.create_session() as session:
+ segments = (
+ session.query(DocumentSegment.id, DocumentSegment.document_id)
+ .where(
+ DocumentSegment.document_id.in_(document_ids),
+ DocumentSegment.status != "re_segment",
+ DocumentSegment.tenant_id == tenant_id,
+ )
+ .all()
+ )
+
+ # Group segments by document_id
+ document_segments_map: dict[str, list[str]] = {}
+ for segment in segments:
+ doc_id = str(segment.document_id)
+ if doc_id not in document_segments_map:
+ document_segments_map[doc_id] = []
+ document_segments_map[doc_id].append(segment.id)
+
+ # Get all summary records for these segments
+ all_segment_ids = [seg.id for seg in segments]
+ summaries = SummaryIndexService.get_segments_summaries(all_segment_ids, dataset_id)
+ summary_status_map = {chunk_id: summary.status for chunk_id, summary in summaries.items()}
+
+ # Calculate summary_index_status for each document
+ result: dict[str, str | None] = {}
+ for doc_id in document_ids:
+ segment_ids = document_segments_map.get(doc_id, [])
+ if not segment_ids:
+ # No segments, status is None (not started)
+ result[doc_id] = None
+ continue
+
+ # Check if there are any "not_started" or "generating" status summaries
+ # Only check enabled=True summaries (already filtered in query)
+ # If segment has no summary record (summary_status_map.get returns None),
+ # it means the summary is disabled (enabled=False) or not created yet, ignore it
+ has_pending_summaries = any(
+ summary_status_map.get(segment_id) is not None # Ensure summary exists (enabled=True)
+ and summary_status_map[segment_id] in ("not_started", "generating")
+ for segment_id in segment_ids
+ )
+
+ if has_pending_summaries:
+ # Task is still running (not started or generating)
+ result[doc_id] = "SUMMARIZING"
+ else:
+ # All enabled=True summaries are "completed" or "error", task finished
+ # Or no enabled=True summaries exist (all disabled)
+ result[doc_id] = None
+
+ return result
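+
+    # Example result (sketch): {"doc-1": "SUMMARIZING", "doc-2": None}. A document reports "SUMMARIZING"
+    # while any enabled summary is "not_started" or "generating", and None once every enabled summary is
+    # "completed" or "error" (or no enabled summaries exist).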
+
+ @staticmethod
+ def get_document_summary_status_detail(
+ document_id: str,
+ dataset_id: str,
+ ) -> dict[str, Any]:
+ """
+ Get detailed summary status for a document.
+
+ Args:
+ document_id: Document ID
+ dataset_id: Dataset ID
+
+ Returns:
+ Dictionary containing:
+ - total_segments: Total number of segments in the document
+ - summary_status: Dictionary with status counts
+ - completed: Number of summaries completed
+ - generating: Number of summaries being generated
+ - error: Number of summaries with errors
+ - not_started: Number of segments without summary records
+ - summaries: List of summary records with status and content preview
+ """
+ from services.dataset_service import SegmentService
+
+ # Get all segments for this document
+ segments = SegmentService.get_segments_by_document_and_dataset(
+ document_id=document_id,
+ dataset_id=dataset_id,
+ status="completed",
+ enabled=True,
+ )
+
+ total_segments = len(segments)
+
+ # Get all summary records for these segments
+ segment_ids = [segment.id for segment in segments]
+ summaries = []
+ if segment_ids:
+ summaries = SummaryIndexService.get_document_summaries(
+ document_id=document_id,
+ dataset_id=dataset_id,
+ segment_ids=segment_ids,
+ )
+
+ # Create a mapping of chunk_id to summary
+ summary_map = {summary.chunk_id: summary for summary in summaries}
+
+ # Count statuses
+ status_counts = {
+ "completed": 0,
+ "generating": 0,
+ "error": 0,
+ "not_started": 0,
+ }
+
+ summary_list = []
+ for segment in segments:
+ summary = summary_map.get(segment.id)
+ if summary:
+ status = summary.status
+ status_counts[status] = status_counts.get(status, 0) + 1
+ summary_list.append(
+ {
+ "segment_id": segment.id,
+ "segment_position": segment.position,
+ "status": summary.status,
+ "summary_preview": (
+ summary.summary_content[:100] + "..."
+ if summary.summary_content and len(summary.summary_content) > 100
+ else summary.summary_content
+ ),
+ "error": summary.error,
+ "created_at": int(summary.created_at.timestamp()) if summary.created_at else None,
+ "updated_at": int(summary.updated_at.timestamp()) if summary.updated_at else None,
+ }
+ )
+ else:
+ status_counts["not_started"] += 1
+ summary_list.append(
+ {
+ "segment_id": segment.id,
+ "segment_position": segment.position,
+ "status": "not_started",
+ "summary_preview": None,
+ "error": None,
+ "created_at": None,
+ "updated_at": None,
+ }
+ )
+
+ return {
+ "total_segments": total_segments,
+ "summary_status": status_counts,
+ "summaries": summary_list,
+ }
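+
+    # Example return value (sketch, field values are illustrative):
+    #     {
+    #         "total_segments": 2,
+    #         "summary_status": {"completed": 1, "generating": 0, "error": 0, "not_started": 1},
+    #         "summaries": [
+    #             {"segment_id": "...", "segment_position": 1, "status": "completed", ...},
+    #             {"segment_id": "...", "segment_position": 2, "status": "not_started", ...},
+    #         ],
+    #     }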
diff --git a/api/services/tag_service.py b/api/services/tag_service.py
index 937e6593fe..56f4ae9494 100644
--- a/api/services/tag_service.py
+++ b/api/services/tag_service.py
@@ -19,9 +19,12 @@ class TagService:
.where(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
)
if keyword:
- query = query.where(sa.and_(Tag.name.ilike(f"%{keyword}%")))
+ from libs.helper import escape_like_pattern
+
+ escaped_keyword = escape_like_pattern(keyword)
+ query = query.where(sa.and_(Tag.name.ilike(f"%{escaped_keyword}%", escape="\\")))
query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
- results: list = query.order_by(Tag.created_at.desc()).all()
+ results = query.order_by(Tag.created_at.desc()).all()
return results
@staticmethod
diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py
index b3b6e36346..c32157919b 100644
--- a/api/services/tools/api_tools_manage_service.py
+++ b/api/services/tools/api_tools_manage_service.py
@@ -7,7 +7,6 @@ from httpx import get
from sqlalchemy import select
from core.entities.provider_entities import ProviderConfig
-from core.helper.tool_provider_cache import ToolProviderListCache
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.custom_tool.provider import ApiToolProviderController
@@ -86,7 +85,9 @@ class ApiToolManageService:
raise ValueError(f"invalid schema: {str(e)}")
@staticmethod
- def convert_schema_to_tool_bundles(schema: str, extra_info: dict | None = None) -> tuple[list[ApiToolBundle], str]:
+ def convert_schema_to_tool_bundles(
+ schema: str, extra_info: dict | None = None
+ ) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]:
"""
convert schema to tool bundles
@@ -104,7 +105,7 @@ class ApiToolManageService:
provider_name: str,
icon: dict,
credentials: dict,
- schema_type: str,
+ schema_type: ApiProviderSchemaType,
schema: str,
privacy_policy: str,
custom_disclaimer: str,
@@ -113,9 +114,6 @@ class ApiToolManageService:
"""
create api tool provider
"""
- if schema_type not in [member.value for member in ApiProviderSchemaType]:
- raise ValueError(f"invalid schema type {schema}")
-
provider_name = provider_name.strip()
# check if the provider exists
@@ -178,9 +176,6 @@ class ApiToolManageService:
# update labels
ToolLabelManager.update_tool_labels(provider_controller, labels)
- # Invalidate tool providers cache
- ToolProviderListCache.invalidate_cache(tenant_id)
-
return {"result": "success"}
@staticmethod
@@ -245,18 +240,15 @@ class ApiToolManageService:
original_provider: str,
icon: dict,
credentials: dict,
- schema_type: str,
+ _schema_type: ApiProviderSchemaType,
schema: str,
- privacy_policy: str,
+ privacy_policy: str | None,
custom_disclaimer: str,
labels: list[str],
):
"""
update api tool provider
"""
- if schema_type not in [member.value for member in ApiProviderSchemaType]:
- raise ValueError(f"invalid schema type {schema}")
-
provider_name = provider_name.strip()
# check if the provider exists
@@ -281,7 +273,7 @@ class ApiToolManageService:
provider.icon = json.dumps(icon)
provider.schema = schema
provider.description = extra_info.get("description", "")
- provider.schema_type_str = ApiProviderSchemaType.OPENAPI
+ provider.schema_type_str = schema_type
provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
provider.privacy_policy = privacy_policy
provider.custom_disclaimer = custom_disclaimer
@@ -322,9 +314,6 @@ class ApiToolManageService:
# update labels
ToolLabelManager.update_tool_labels(provider_controller, labels)
- # Invalidate tool providers cache
- ToolProviderListCache.invalidate_cache(tenant_id)
-
return {"result": "success"}
@staticmethod
@@ -347,9 +336,6 @@ class ApiToolManageService:
db.session.delete(provider)
db.session.commit()
- # Invalidate tool providers cache
- ToolProviderListCache.invalidate_cache(tenant_id)
-
return {"result": "success"}
@staticmethod
@@ -366,7 +352,7 @@ class ApiToolManageService:
tool_name: str,
credentials: dict,
parameters: dict,
- schema_type: str,
+ schema_type: ApiProviderSchemaType,
schema: str,
):
"""
diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py
index cf1d39fa25..6797a67dde 100644
--- a/api/services/tools/builtin_tools_manage_service.py
+++ b/api/services/tools/builtin_tools_manage_service.py
@@ -12,7 +12,6 @@ from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper.name_generator import generate_incremental_name
from core.helper.position_helper import is_filtered
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
-from core.helper.tool_provider_cache import ToolProviderListCache
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
@@ -205,9 +204,6 @@ class BuiltinToolManageService:
db_provider.name = name
session.commit()
-
- # Invalidate tool providers cache
- ToolProviderListCache.invalidate_cache(tenant_id)
except Exception as e:
session.rollback()
raise ValueError(str(e))
@@ -286,12 +282,10 @@ class BuiltinToolManageService:
session.add(db_provider)
session.commit()
-
- # Invalidate tool providers cache
- ToolProviderListCache.invalidate_cache(tenant_id)
except Exception as e:
session.rollback()
raise ValueError(str(e))
+
return {"result": "success"}
@staticmethod
@@ -409,9 +403,6 @@ class BuiltinToolManageService:
)
cache.delete()
- # Invalidate tool providers cache
- ToolProviderListCache.invalidate_cache(tenant_id)
-
return {"result": "success"}
@staticmethod
@@ -434,8 +425,6 @@ class BuiltinToolManageService:
target_provider.is_default = True
session.commit()
- # Invalidate tool providers cache
- ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"}
@staticmethod
diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py
index d641fe0315..0be106f597 100644
--- a/api/services/tools/mcp_tools_manage_service.py
+++ b/api/services/tools/mcp_tools_manage_service.py
@@ -15,7 +15,6 @@ from sqlalchemy.orm import Session
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity
from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache
-from core.helper.tool_provider_cache import ToolProviderListCache
from core.mcp.auth.auth_flow import auth
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPAuthError, MCPError
@@ -65,6 +64,15 @@ class ServerUrlValidationResult(BaseModel):
return self.needs_validation and self.validation_passed and self.reconnect_result is not None
+class ProviderUrlValidationData(BaseModel):
+    """Data required for URL validation, extracted from the database so network operations can run outside the session."""
+
+ current_server_url_hash: str
+ headers: dict[str, str]
+ timeout: float | None
+ sse_read_timeout: float | None
+
+
class MCPToolManageService:
"""Service class for managing MCP tools and providers."""
@@ -166,9 +174,6 @@ class MCPToolManageService:
self._session.add(mcp_tool)
self._session.flush()
- # Invalidate tool providers cache
- ToolProviderListCache.invalidate_cache(tenant_id)
-
mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
return mcp_providers
@@ -192,7 +197,7 @@ class MCPToolManageService:
Update an MCP provider.
Args:
- validation_result: Pre-validation result from validate_server_url_change.
+ validation_result: Pre-validation result from validate_server_url_standalone.
If provided and contains reconnect_result, it will be used
instead of performing network operations.
"""
@@ -251,8 +256,6 @@ class MCPToolManageService:
# Flush changes to database
self._session.flush()
- # Invalidate tool providers cache
- ToolProviderListCache.invalidate_cache(tenant_id)
except IntegrityError as e:
self._handle_integrity_error(e, name, server_url, server_identifier)
@@ -261,9 +264,6 @@ class MCPToolManageService:
mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
self._session.delete(mcp_tool)
- # Invalidate tool providers cache
- ToolProviderListCache.invalidate_cache(tenant_id)
-
def list_providers(
self, *, tenant_id: str, for_list: bool = False, include_sensitive: bool = True
) -> list[ToolProviderApiEntity]:
@@ -319,8 +319,14 @@ class MCPToolManageService:
except MCPError as e:
raise ValueError(f"Failed to connect to MCP server: {e}")
- # Update database with retrieved tools
- db_provider.tools = json.dumps([tool.model_dump() for tool in tools])
+ # Update database with retrieved tools (ensure description is a non-null string)
+ tools_payload = []
+ for tool in tools:
+ data = tool.model_dump()
+ if data.get("description") is None:
+ data["description"] = ""
+ tools_payload.append(data)
+ db_provider.tools = json.dumps(tools_payload)
db_provider.authed = True
db_provider.updated_at = datetime.now()
self._session.flush()
@@ -546,30 +552,39 @@ class MCPToolManageService:
)
return self.execute_auth_actions(auth_result)
- def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> ReconnectResult:
- """Attempt to reconnect to MCP provider with new server URL."""
+ def get_provider_for_url_validation(self, *, tenant_id: str, provider_id: str) -> ProviderUrlValidationData:
+ """
+ Get provider data required for URL validation.
+ This method performs database read and should be called within a session.
+
+ Returns:
+ ProviderUrlValidationData: Data needed for standalone URL validation
+ """
+ provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
provider_entity = provider.to_entity()
- headers = provider_entity.headers
+ return ProviderUrlValidationData(
+ current_server_url_hash=provider.server_url_hash,
+ headers=provider_entity.headers,
+ timeout=provider_entity.timeout,
+ sse_read_timeout=provider_entity.sse_read_timeout,
+ )
- try:
- tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity)
- return ReconnectResult(
- authed=True,
- tools=json.dumps([tool.model_dump() for tool in tools]),
- encrypted_credentials=EMPTY_CREDENTIALS_JSON,
- )
- except MCPAuthError:
- return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON)
- except MCPError as e:
- raise ValueError(f"Failed to re-connect MCP server: {e}") from e
-
- def validate_server_url_change(
- self, *, tenant_id: str, provider_id: str, new_server_url: str
+ @staticmethod
+ def validate_server_url_standalone(
+ *,
+ tenant_id: str,
+ new_server_url: str,
+ validation_data: ProviderUrlValidationData,
) -> ServerUrlValidationResult:
"""
Validate server URL change by attempting to connect to the new server.
- This method should be called BEFORE update_provider to perform network operations
- outside of the database transaction.
+ This method performs network operations and MUST be called OUTSIDE of any database session
+ to avoid holding locks during network I/O.
+
+ Args:
+ tenant_id: Tenant ID for encryption
+ new_server_url: The new server URL to validate
+ validation_data: Provider data obtained from get_provider_for_url_validation
Returns:
ServerUrlValidationResult: Validation result with connection status and tools if successful
@@ -579,25 +594,30 @@ class MCPToolManageService:
return ServerUrlValidationResult(needs_validation=False)
# Validate URL format
- if not self._is_valid_url(new_server_url):
+ parsed = urlparse(new_server_url)
+ if not all([parsed.scheme, parsed.netloc]) or parsed.scheme not in ["http", "https"]:
raise ValueError("Server URL is not valid.")
# Always encrypt and hash the URL
encrypted_server_url = encrypter.encrypt_token(tenant_id, new_server_url)
new_server_url_hash = hashlib.sha256(new_server_url.encode()).hexdigest()
- # Get current provider
- provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
-
# Check if URL is actually different
- if new_server_url_hash == provider.server_url_hash:
+ if new_server_url_hash == validation_data.current_server_url_hash:
# URL hasn't changed, but still return the encrypted data
return ServerUrlValidationResult(
- needs_validation=False, encrypted_server_url=encrypted_server_url, server_url_hash=new_server_url_hash
+ needs_validation=False,
+ encrypted_server_url=encrypted_server_url,
+ server_url_hash=new_server_url_hash,
)
- # Perform validation by attempting to connect
- reconnect_result = self._reconnect_provider(server_url=new_server_url, provider=provider)
+ # Perform network validation - this is the expensive operation that must stay outside the session
+ reconnect_result = MCPToolManageService._reconnect_with_url(
+ server_url=new_server_url,
+ headers=validation_data.headers,
+ timeout=validation_data.timeout,
+ sse_read_timeout=validation_data.sse_read_timeout,
+ )
return ServerUrlValidationResult(
needs_validation=True,
validation_passed=True,
@@ -606,6 +626,60 @@ class MCPToolManageService:
server_url_hash=new_server_url_hash,
)
+ @staticmethod
+ def reconnect_with_url(
+ *,
+ server_url: str,
+ headers: dict[str, str],
+ timeout: float | None,
+ sse_read_timeout: float | None,
+ ) -> ReconnectResult:
+ return MCPToolManageService._reconnect_with_url(
+ server_url=server_url,
+ headers=headers,
+ timeout=timeout,
+ sse_read_timeout=sse_read_timeout,
+ )
+
+ @staticmethod
+ def _reconnect_with_url(
+ *,
+ server_url: str,
+ headers: dict[str, str],
+ timeout: float | None,
+ sse_read_timeout: float | None,
+ ) -> ReconnectResult:
+ """
+ Attempt to connect to the MCP server with the given URL.
+ This is a static method that performs network I/O without database access.
+ """
+ from core.mcp.mcp_client import MCPClient
+
+ try:
+ with MCPClient(
+ server_url=server_url,
+ headers=headers,
+ timeout=timeout,
+ sse_read_timeout=sse_read_timeout,
+ ) as mcp_client:
+ tools = mcp_client.list_tools()
+ # Ensure tool descriptions are non-null in payload
+ tools_payload = []
+ for t in tools:
+ d = t.model_dump()
+ if d.get("description") is None:
+ d["description"] = ""
+ tools_payload.append(d)
+ return ReconnectResult(
+ authed=True,
+ tools=json.dumps(tools_payload),
+ encrypted_credentials=EMPTY_CREDENTIALS_JSON,
+ )
+ except MCPAuthError:
+ return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON)
+ except MCPError as e:
+ raise ValueError(f"Failed to re-connect MCP server: {e}") from e
+
def _build_tool_provider_response(
self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list
) -> ToolProviderApiEntity:
diff --git a/api/services/tools/tools_manage_service.py b/api/services/tools/tools_manage_service.py
index 038c462f15..51e9120b8d 100644
--- a/api/services/tools/tools_manage_service.py
+++ b/api/services/tools/tools_manage_service.py
@@ -1,6 +1,5 @@
import logging
-from core.helper.tool_provider_cache import ToolProviderListCache
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
from core.tools.tool_manager import ToolManager
from services.tools.tools_transform_service import ToolTransformService
@@ -16,14 +15,6 @@ class ToolCommonService:
:return: the list of tool providers
"""
- # Try to get from cache first
- cached_result = ToolProviderListCache.get_cached_providers(tenant_id, typ)
- if cached_result is not None:
- logger.debug("Returning cached tool providers for tenant %s, type %s", tenant_id, typ)
- return cached_result
-
- # Cache miss - fetch from database
- logger.debug("Cache miss for tool providers, fetching from database for tenant %s, type %s", tenant_id, typ)
providers = ToolManager.list_providers_from_api(user_id, tenant_id, typ)
# add icon
@@ -32,7 +23,4 @@ class ToolCommonService:
result = [provider.to_dict() for provider in providers]
- # Cache the result
- ToolProviderListCache.set_cached_providers(tenant_id, typ, result)
-
return result
diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py
index fe77ff2dc5..ab5d5480df 100644
--- a/api/services/tools/workflow_tools_manage_service.py
+++ b/api/services/tools/workflow_tools_manage_service.py
@@ -7,7 +7,6 @@ from typing import Any
from sqlalchemy import or_, select
from sqlalchemy.orm import Session
-from core.helper.tool_provider_cache import ToolProviderListCache
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
@@ -68,34 +67,31 @@ class WorkflowToolManageService:
if workflow is None:
raise ValueError(f"Workflow not found for app {workflow_app_id}")
- with Session(db.engine, expire_on_commit=False) as session, session.begin():
- workflow_tool_provider = WorkflowToolProvider(
- tenant_id=tenant_id,
- user_id=user_id,
- app_id=workflow_app_id,
- name=name,
- label=label,
- icon=json.dumps(icon),
- description=description,
- parameter_configuration=json.dumps(parameters),
- privacy_policy=privacy_policy,
- version=workflow.version,
- )
- session.add(workflow_tool_provider)
+ workflow_tool_provider = WorkflowToolProvider(
+ tenant_id=tenant_id,
+ user_id=user_id,
+ app_id=workflow_app_id,
+ name=name,
+ label=label,
+ icon=json.dumps(icon),
+ description=description,
+ parameter_configuration=json.dumps(parameters),
+ privacy_policy=privacy_policy,
+ version=workflow.version,
+ )
try:
WorkflowToolProviderController.from_db(workflow_tool_provider)
except Exception as e:
raise ValueError(str(e))
+ with Session(db.engine, expire_on_commit=False) as session, session.begin():
+ session.add(workflow_tool_provider)
+
if labels is not None:
ToolLabelManager.update_tool_labels(
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
)
-
- # Invalidate tool providers cache
- ToolProviderListCache.invalidate_cache(tenant_id)
-
return {"result": "success"}
@classmethod
@@ -183,9 +179,6 @@ class WorkflowToolManageService:
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
)
- # Invalidate tool providers cache
- ToolProviderListCache.invalidate_cache(tenant_id)
-
return {"result": "success"}
@classmethod
@@ -248,9 +241,6 @@ class WorkflowToolManageService:
db.session.commit()
- # Invalidate tool providers cache
- ToolProviderListCache.invalidate_cache(tenant_id)
-
return {"result": "success"}
@classmethod
diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py
index 668e4c5be2..688993c798 100644
--- a/api/services/trigger/trigger_provider_service.py
+++ b/api/services/trigger/trigger_provider_service.py
@@ -94,16 +94,23 @@ class TriggerProviderService:
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
for subscription in subscriptions:
- encrypter, _ = create_trigger_provider_encrypter_for_subscription(
+ credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription(
tenant_id=tenant_id,
controller=provider_controller,
subscription=subscription,
)
subscription.credentials = dict(
- encrypter.mask_credentials(dict(encrypter.decrypt(subscription.credentials)))
+ credential_encrypter.mask_credentials(dict(credential_encrypter.decrypt(subscription.credentials)))
)
- subscription.properties = dict(encrypter.mask_credentials(dict(encrypter.decrypt(subscription.properties))))
- subscription.parameters = dict(encrypter.mask_credentials(dict(encrypter.decrypt(subscription.parameters))))
+ properties_encrypter, _ = create_trigger_provider_encrypter_for_properties(
+ tenant_id=tenant_id,
+ controller=provider_controller,
+ subscription=subscription,
+ )
+ subscription.properties = dict(
+ properties_encrypter.mask_credentials(dict(properties_encrypter.decrypt(subscription.properties)))
+ )
+ subscription.parameters = dict(subscription.parameters)
count = workflows_in_use_map.get(subscription.id)
subscription.workflows_in_use = count if count is not None else 0
@@ -209,6 +216,101 @@ class TriggerProviderService:
logger.exception("Failed to add trigger provider")
raise ValueError(str(e))
+ @classmethod
+ def update_trigger_subscription(
+ cls,
+ tenant_id: str,
+ subscription_id: str,
+ name: str | None = None,
+ properties: Mapping[str, Any] | None = None,
+ parameters: Mapping[str, Any] | None = None,
+ credentials: Mapping[str, Any] | None = None,
+ credential_expires_at: int | None = None,
+ expires_at: int | None = None,
+ ) -> None:
+ """
+ Update an existing trigger subscription.
+
+ :param tenant_id: Tenant ID
+ :param subscription_id: Subscription instance ID
+ :param name: Optional new name for this subscription
+ :param properties: Optional new properties
+ :param parameters: Optional new parameters
+ :param credentials: Optional new credentials
+ :param credential_expires_at: Optional new credential expiration timestamp
+ :param expires_at: Optional new expiration timestamp
+ :return: None; raises ValueError if the subscription is not found or the name already exists
+ """
+ with Session(db.engine, expire_on_commit=False) as session:
+ # Use distributed lock to prevent race conditions on the same subscription
+ lock_key = f"trigger_subscription_update_lock:{tenant_id}_{subscription_id}"
+ with redis_client.lock(lock_key, timeout=20):
+ subscription: TriggerSubscription | None = (
+ session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
+ )
+ if not subscription:
+ raise ValueError(f"Trigger subscription {subscription_id} not found")
+
+ provider_id = TriggerProviderID(subscription.provider_id)
+ provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
+
+ # Check for name uniqueness if name is being updated
+ if name is not None and name != subscription.name:
+ existing = (
+ session.query(TriggerSubscription)
+ .filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
+ .first()
+ )
+ if existing:
+ raise ValueError(f"Subscription name '{name}' already exists for this provider")
+ subscription.name = name
+
+ # Update properties if provided
+ if properties is not None:
+ properties_encrypter, _ = create_provider_encrypter(
+ tenant_id=tenant_id,
+ config=provider_controller.get_properties_schema(),
+ cache=NoOpProviderCredentialCache(),
+ )
+ # Handle hidden values - preserve original encrypted values
+ original_properties = properties_encrypter.decrypt(subscription.properties)
+ new_properties: dict[str, Any] = {
+ key: value if value != HIDDEN_VALUE else original_properties.get(key, UNKNOWN_VALUE)
+ for key, value in properties.items()
+ }
+ subscription.properties = dict(properties_encrypter.encrypt(new_properties))
+
+ # Update parameters if provided
+ if parameters is not None:
+ subscription.parameters = dict(parameters)
+
+ # Update credentials if provided
+ if credentials is not None:
+ credential_type = CredentialType.of(subscription.credential_type)
+ credential_encrypter, _ = create_provider_encrypter(
+ tenant_id=tenant_id,
+ config=provider_controller.get_credential_schema_config(credential_type),
+ cache=NoOpProviderCredentialCache(),
+ )
+ subscription.credentials = dict(credential_encrypter.encrypt(dict(credentials)))
+
+ # Update credential expiration timestamp if provided
+ if credential_expires_at is not None:
+ subscription.credential_expires_at = credential_expires_at
+
+ # Update expiration timestamp if provided
+ if expires_at is not None:
+ subscription.expires_at = expires_at
+
+ session.commit()
+
+ # Clear subscription cache
+ delete_cache_for_subscription(
+ tenant_id=tenant_id,
+ provider_id=subscription.provider_id,
+ subscription_id=subscription.id,
+ )
+
@classmethod
def get_subscription_by_id(cls, tenant_id: str, subscription_id: str | None = None) -> TriggerSubscription | None:
"""
@@ -257,17 +359,18 @@ class TriggerProviderService:
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
credential_type: CredentialType = CredentialType.of(subscription.credential_type)
+ provider_id = TriggerProviderID(subscription.provider_id)
+ provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
+ tenant_id=tenant_id, provider_id=provider_id
+ )
+ encrypter, _ = create_trigger_provider_encrypter_for_subscription(
+ tenant_id=tenant_id,
+ controller=provider_controller,
+ subscription=subscription,
+ )
+
is_auto_created: bool = credential_type in [CredentialType.OAUTH2, CredentialType.API_KEY]
if is_auto_created:
- provider_id = TriggerProviderID(subscription.provider_id)
- provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
- tenant_id=tenant_id, provider_id=provider_id
- )
- encrypter, _ = create_trigger_provider_encrypter_for_subscription(
- tenant_id=tenant_id,
- controller=provider_controller,
- subscription=subscription,
- )
try:
TriggerManager.unsubscribe_trigger(
tenant_id=tenant_id,
@@ -280,8 +383,8 @@ class TriggerProviderService:
except Exception as e:
logger.exception("Error unsubscribing trigger", exc_info=e)
- # Clear cache
session.delete(subscription)
+ # Clear cache
delete_cache_for_subscription(
tenant_id=tenant_id,
provider_id=subscription.provider_id,
@@ -688,3 +791,127 @@ class TriggerProviderService:
)
subscription.properties = dict(properties_encrypter.decrypt(subscription.properties))
return subscription
+
+ @classmethod
+ def verify_subscription_credentials(
+ cls,
+ tenant_id: str,
+ user_id: str,
+ provider_id: TriggerProviderID,
+ subscription_id: str,
+ credentials: dict[str, Any],
+ ) -> dict[str, Any]:
+ """
+ Verify credentials for an existing subscription without updating it.
+
+ This is used in edit mode to validate new credentials before a rebuild.
+
+ :param tenant_id: Tenant ID
+ :param user_id: User ID
+ :param provider_id: Provider identifier
+ :param subscription_id: Subscription ID
+ :param credentials: New credentials to verify
+ :return: dict with 'verified' boolean
+ """
+ provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
+ if not provider_controller:
+ raise ValueError(f"Provider {provider_id} not found")
+
+ subscription = cls.get_subscription_by_id(
+ tenant_id=tenant_id,
+ subscription_id=subscription_id,
+ )
+ if not subscription:
+ raise ValueError(f"Subscription {subscription_id} not found")
+
+ credential_type = CredentialType.of(subscription.credential_type)
+
+ # For API Key, validate the new credentials
+ if credential_type == CredentialType.API_KEY:
+ new_credentials: dict[str, Any] = {
+ key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE)
+ for key, value in credentials.items()
+ }
+ try:
+ provider_controller.validate_credentials(user_id, credentials=new_credentials)
+ return {"verified": True}
+ except Exception as e:
+ raise ValueError(f"Invalid credentials: {e}") from e
+
+ return {"verified": True}
+
+ @classmethod
+ def rebuild_trigger_subscription(
+ cls,
+ tenant_id: str,
+ provider_id: TriggerProviderID,
+ subscription_id: str,
+ credentials: Mapping[str, Any],
+ parameters: Mapping[str, Any],
+ name: str | None = None,
+ ) -> None:
+ """
+ Rebuild an existing subscription in place.
+
+ This method rebuilds the subscription by calling the DELETE and CREATE APIs of the third-party provider (e.g. GitHub),
+ keeping the same subscription_id and endpoint_id so the webhook URL remains unchanged.
+
+ :param tenant_id: Tenant ID
+ :param provider_id: Provider identifier
+ :param subscription_id: Subscription ID
+ :param credentials: Credentials for the subscription
+ :param parameters: Parameters for the subscription
+ :param name: Optional new name for the subscription
+ :return: None
+ """
+ provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
+ if not provider_controller:
+ raise ValueError(f"Provider {provider_id} not found")
+
+ subscription = TriggerProviderService.get_subscription_by_id(
+ tenant_id=tenant_id,
+ subscription_id=subscription_id,
+ )
+ if not subscription:
+ raise ValueError(f"Subscription {subscription_id} not found")
+
+ credential_type = CredentialType.of(subscription.credential_type)
+ if credential_type not in {CredentialType.OAUTH2, CredentialType.API_KEY}:
+ raise ValueError(f"Credential type {credential_type} not supported for auto creation")
+
+ # Delete the previous subscription
+ user_id = subscription.user_id
+ unsubscribe_result = TriggerManager.unsubscribe_trigger(
+ tenant_id=tenant_id,
+ user_id=user_id,
+ provider_id=provider_id,
+ subscription=subscription.to_entity(),
+ credentials=subscription.credentials,
+ credential_type=credential_type,
+ )
+ if not unsubscribe_result.success:
+ raise ValueError(f"Failed to delete previous subscription: {unsubscribe_result.message}")
+
+ # Create a new subscription with the same subscription_id and endpoint_id
+ new_credentials: dict[str, Any] = {
+ key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE)
+ for key, value in credentials.items()
+ }
+ new_subscription: TriggerSubscriptionEntity = TriggerManager.subscribe_trigger(
+ tenant_id=tenant_id,
+ user_id=user_id,
+ provider_id=provider_id,
+ endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id),
+ parameters=parameters,
+ credentials=new_credentials,
+ credential_type=credential_type,
+ )
+ TriggerProviderService.update_trigger_subscription(
+ tenant_id=tenant_id,
+ subscription_id=subscription.id,
+ name=name,
+ parameters=parameters,
+ credentials=new_credentials,
+ properties=new_subscription.properties,
+ expires_at=new_subscription.expires_at,
+ )
diff --git a/api/services/trigger/trigger_subscription_builder_service.py b/api/services/trigger/trigger_subscription_builder_service.py
index 571393c782..37f852da3e 100644
--- a/api/services/trigger/trigger_subscription_builder_service.py
+++ b/api/services/trigger/trigger_subscription_builder_service.py
@@ -453,11 +453,12 @@ class TriggerSubscriptionBuilderService:
if not subscription_builder:
return None
- # response to validation endpoint
- controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
- tenant_id=subscription_builder.tenant_id, provider_id=TriggerProviderID(subscription_builder.provider_id)
- )
try:
+ # response to validation endpoint
+ controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
+ tenant_id=subscription_builder.tenant_id,
+ provider_id=TriggerProviderID(subscription_builder.provider_id),
+ )
dispatch_response: TriggerDispatchResponse = controller.dispatch(
request=request,
subscription=subscription_builder.to_subscription(),
diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py
index 4b3e1330fd..4159f5f8f4 100644
--- a/api/services/trigger/webhook_service.py
+++ b/api/services/trigger/webhook_service.py
@@ -33,6 +33,11 @@ from services.errors.app import QuotaExceededError
from services.trigger.app_trigger_service import AppTriggerService
from services.workflow.entities import WebhookTriggerData
+try:
+ import magic
+except ImportError:
+ magic = None # type: ignore[assignment]
+
logger = logging.getLogger(__name__)
@@ -317,7 +322,8 @@ class WebhookService:
try:
file_content = request.get_data()
if file_content:
- file_obj = cls._create_file_from_binary(file_content, "application/octet-stream", webhook_trigger)
+ mimetype = cls._detect_binary_mimetype(file_content)
+ file_obj = cls._create_file_from_binary(file_content, mimetype, webhook_trigger)
return {"raw": file_obj.to_dict()}, {}
else:
return {"raw": None}, {}
@@ -341,6 +347,18 @@ class WebhookService:
body = {"raw": ""}
return body, {}
+ @staticmethod
+ def _detect_binary_mimetype(file_content: bytes) -> str:
+ """Guess MIME type for binary payloads using python-magic when available."""
+ if magic is not None:
+ try:
+ detected = magic.from_buffer(file_content[:1024], mime=True)
+ if detected:
+ return detected
+ except Exception:
+ logger.debug("python-magic detection failed for octet-stream payload")
+ return "application/octet-stream"
+
@classmethod
def _process_file_uploads(
cls, files: Mapping[str, FileStorage], webhook_trigger: WorkflowWebhookTrigger
@@ -845,10 +863,18 @@ class WebhookService:
not_found_in_cache.append(node_id)
continue
- with Session(db.engine) as session:
- try:
- # lock the concurrent webhook trigger creation
- redis_client.lock(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock", timeout=10)
+ lock_key = f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock"
+ lock = redis_client.lock(lock_key, timeout=10)
+ lock_acquired = False
+
+ try:
+ # acquire the lock with blocking and timeout
+ lock_acquired = lock.acquire(blocking=True, blocking_timeout=10)
+ if not lock_acquired:
+ logger.warning("Failed to acquire lock for webhook sync, app %s", app.id)
+ raise RuntimeError("Failed to acquire lock for webhook trigger synchronization")
+
+ with Session(db.engine) as session:
# fetch the non-cached nodes from DB
all_records = session.scalars(
select(WorkflowWebhookTrigger).where(
@@ -885,11 +911,16 @@ class WebhookService:
session.delete(nodes_id_in_db[node_id])
redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:{node_id}")
session.commit()
- except Exception:
- logger.exception("Failed to sync webhook relationships for app %s", app.id)
- raise
- finally:
- redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:apps:{app.id}:lock")
+ except Exception:
+ logger.exception("Failed to sync webhook relationships for app %s", app.id)
+ raise
+ finally:
+ # release the lock only if it was acquired
+ if lock_acquired:
+ try:
+ lock.release()
+ except Exception:
+ logger.exception("Failed to release lock for webhook sync, app %s", app.id)
@classmethod
def generate_webhook_id(cls) -> str:
diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py
index 6eb8d0031d..f973361341 100644
--- a/api/services/variable_truncator.py
+++ b/api/services/variable_truncator.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import dataclasses
from abc import ABC, abstractmethod
from collections.abc import Mapping
@@ -106,7 +108,7 @@ class VariableTruncator(BaseTruncator):
self._max_size_bytes = max_size_bytes
@classmethod
- def default(cls) -> "VariableTruncator":
+ def default(cls) -> VariableTruncator:
return VariableTruncator(
max_size_bytes=dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE,
array_element_limit=dify_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH,
@@ -410,9 +412,12 @@ class VariableTruncator(BaseTruncator):
@overload
def _truncate_json_primitives(self, val: None, target_size: int) -> _PartResult[None]: ...
+ @overload
+ def _truncate_json_primitives(self, val: File, target_size: int) -> _PartResult[File]: ...
+
def _truncate_json_primitives(
self,
- val: UpdatedVariable | str | list[object] | dict[str, object] | bool | int | float | None,
+ val: UpdatedVariable | File | str | list[object] | dict[str, object] | bool | int | float | None,
target_size: int,
) -> _PartResult[Any]:
"""Truncate a value within an object to fit within budget."""
@@ -425,6 +430,9 @@ class VariableTruncator(BaseTruncator):
return self._truncate_array(val, target_size)
elif isinstance(val, dict):
return self._truncate_object(val, target_size)
+ elif isinstance(val, File):
+ # File objects should not be truncated, return as-is
+ return _PartResult(val, self.calculate_json_size(val), False)
elif val is None or isinstance(val, (bool, int, float)):
return _PartResult(val, self.calculate_json_size(val), False)
else:
diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py
index 9bd797a45f..5ca0b63001 100644
--- a/api/services/webapp_auth_service.py
+++ b/api/services/webapp_auth_service.py
@@ -12,6 +12,7 @@ from libs.passport import PassportService
from libs.password import compare_password
from models import Account, AccountStatus
from models.model import App, EndUser, Site
+from services.account_service import AccountService
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
@@ -32,7 +33,7 @@ class WebAppAuthService:
@staticmethod
def authenticate(email: str, password: str) -> Account:
"""authenticate account with email and password"""
- account = db.session.query(Account).filter_by(email=email).first()
+ account = AccountService.get_account_by_email_with_case_fallback(email)
if not account:
raise AccountNotFoundError()
@@ -52,7 +53,7 @@ class WebAppAuthService:
@classmethod
def get_user_through_email(cls, email: str):
- account = db.session.query(Account).where(Account.email == email).first()
+ account = AccountService.get_account_by_email_with_case_fallback(email)
if not account:
return None
diff --git a/api/services/website_service.py b/api/services/website_service.py
index a23f01ec71..fe48c3b08e 100644
--- a/api/services/website_service.py
+++ b/api/services/website_service.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import datetime
import json
from dataclasses import dataclass
@@ -78,7 +80,7 @@ class WebsiteCrawlApiRequest:
return CrawlRequest(url=self.url, provider=self.provider, options=options)
@classmethod
- def from_args(cls, args: dict) -> "WebsiteCrawlApiRequest":
+ def from_args(cls, args: dict) -> WebsiteCrawlApiRequest:
"""Create from Flask-RESTful parsed arguments."""
provider = args.get("provider")
url = args.get("url")
@@ -102,7 +104,7 @@ class WebsiteCrawlStatusApiRequest:
job_id: str
@classmethod
- def from_args(cls, args: dict, job_id: str) -> "WebsiteCrawlStatusApiRequest":
+ def from_args(cls, args: dict, job_id: str) -> WebsiteCrawlStatusApiRequest:
"""Create from Flask-RESTful parsed arguments."""
provider = args.get("provider")
if not provider:
diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py
index 01f0c7a55a..efc76c33bc 100644
--- a/api/services/workflow_app_service.py
+++ b/api/services/workflow_app_service.py
@@ -7,7 +7,7 @@ from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import Session
from core.workflow.enums import WorkflowExecutionStatus
-from models import Account, App, EndUser, WorkflowAppLog, WorkflowRun
+from models import Account, App, EndUser, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun
from models.enums import AppTriggerType, CreatorUserRole
from models.trigger import WorkflowTriggerLog
from services.plugin.plugin_service import PluginService
@@ -86,12 +86,19 @@ class WorkflowAppService:
# Join to workflow run for filtering when needed.
if keyword:
- keyword_like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u")
+ from libs.helper import escape_like_pattern
+
+ # Escape LIKE wildcard characters in the keyword so they are matched literally rather than as patterns
+ escaped_keyword = escape_like_pattern(keyword[:30])
+ keyword_like_val = f"%{escaped_keyword}%"
keyword_conditions = [
- WorkflowRun.inputs.ilike(keyword_like_val),
- WorkflowRun.outputs.ilike(keyword_like_val),
+ WorkflowRun.inputs.ilike(keyword_like_val, escape="\\"),
+ WorkflowRun.outputs.ilike(keyword_like_val, escape="\\"),
# filter keyword by end user session id if created by end user role
- and_(WorkflowRun.created_by_role == "end_user", EndUser.session_id.ilike(keyword_like_val)),
+ and_(
+ WorkflowRun.created_by_role == "end_user",
+ EndUser.session_id.ilike(keyword_like_val, escape="\\"),
+ ),
]
# filter keyword by workflow run id
@@ -166,7 +173,80 @@ class WorkflowAppService:
"data": items,
}
- def handle_trigger_metadata(self, tenant_id: str, meta_val: str) -> dict[str, Any]:
+ def get_paginate_workflow_archive_logs(
+ self,
+ *,
+ session: Session,
+ app_model: App,
+ page: int = 1,
+ limit: int = 20,
+ ):
+ """
+ Get paginate workflow archive logs using SQLAlchemy 2.0 style.
+ """
+ stmt = select(WorkflowArchiveLog).where(
+ WorkflowArchiveLog.tenant_id == app_model.tenant_id,
+ WorkflowArchiveLog.app_id == app_model.id,
+ WorkflowArchiveLog.log_id.isnot(None),
+ )
+
+ stmt = stmt.order_by(WorkflowArchiveLog.run_created_at.desc())
+
+ count_stmt = select(func.count()).select_from(stmt.subquery())
+ total = session.scalar(count_stmt) or 0
+
+ offset_stmt = stmt.offset((page - 1) * limit).limit(limit)
+
+ logs = list(session.scalars(offset_stmt).all())
+ account_ids = {log.created_by for log in logs if log.created_by_role == CreatorUserRole.ACCOUNT}
+ end_user_ids = {log.created_by for log in logs if log.created_by_role == CreatorUserRole.END_USER}
+
+ accounts_by_id = {}
+ if account_ids:
+ accounts_by_id = {
+ account.id: account
+ for account in session.scalars(select(Account).where(Account.id.in_(account_ids))).all()
+ }
+
+ end_users_by_id = {}
+ if end_user_ids:
+ end_users_by_id = {
+ end_user.id: end_user
+ for end_user in session.scalars(select(EndUser).where(EndUser.id.in_(end_user_ids))).all()
+ }
+
+ items = []
+ for log in logs:
+ if log.created_by_role == CreatorUserRole.ACCOUNT:
+ created_by_account = accounts_by_id.get(log.created_by)
+ created_by_end_user = None
+ elif log.created_by_role == CreatorUserRole.END_USER:
+ created_by_account = None
+ created_by_end_user = end_users_by_id.get(log.created_by)
+ else:
+ created_by_account = None
+ created_by_end_user = None
+
+ items.append(
+ {
+ "id": log.id,
+ "workflow_run": log.workflow_run_summary,
+ "trigger_metadata": self.handle_trigger_metadata(app_model.tenant_id, log.trigger_metadata),
+ "created_by_account": created_by_account,
+ "created_by_end_user": created_by_end_user,
+ "created_at": log.log_created_at,
+ }
+ )
+
+ return {
+ "page": page,
+ "limit": limit,
+ "total": total,
+ "has_more": total > page * limit,
+ "data": items,
+ }
+
+ def handle_trigger_metadata(self, tenant_id: str, meta_val: str | None) -> dict[str, Any]:
metadata: dict[str, Any] | None = self._safe_json_loads(meta_val)
if not metadata:
return {}
diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py
index f299ce3baa..70b0190231 100644
--- a/api/services/workflow_draft_variable_service.py
+++ b/api/services/workflow_draft_variable_service.py
@@ -15,7 +15,7 @@ from sqlalchemy.sql.expression import and_, or_
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
-from core.variables import Segment, StringSegment, Variable
+from core.variables import Segment, StringSegment, VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import (
ArrayFileSegment,
@@ -77,14 +77,14 @@ class DraftVarLoader(VariableLoader):
# Application ID for which variables are being loaded.
_app_id: str
_tenant_id: str
- _fallback_variables: Sequence[Variable]
+ _fallback_variables: Sequence[VariableBase]
def __init__(
self,
engine: Engine,
app_id: str,
tenant_id: str,
- fallback_variables: Sequence[Variable] | None = None,
+ fallback_variables: Sequence[VariableBase] | None = None,
):
self._engine = engine
self._app_id = app_id
@@ -94,12 +94,12 @@ class DraftVarLoader(VariableLoader):
def _selector_to_tuple(self, selector: Sequence[str]) -> tuple[str, str]:
return (selector[0], selector[1])
- def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
+ def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
if not selectors:
return []
- # Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding Variable instance.
- variable_by_selector: dict[tuple[str, str], Variable] = {}
+ # Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding variable instance.
+ variable_by_selector: dict[tuple[str, str], VariableBase] = {}
with Session(bind=self._engine, expire_on_commit=False) as session:
srv = WorkflowDraftVariableService(session)
@@ -145,7 +145,7 @@ class DraftVarLoader(VariableLoader):
return list(variable_by_selector.values())
- def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], Variable]:
+ def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], VariableBase]:
# This logic is closely tied to `WorkflowDraftVaribleService._try_offload_large_variable`
# and must remain synchronized with it.
# Ideally, these should be co-located for better maintainability.
@@ -679,6 +679,7 @@ def _batch_upsert_draft_variable(
def _model_to_insertion_dict(model: WorkflowDraftVariable) -> dict[str, Any]:
d: dict[str, Any] = {
+ "id": model.id,
"app_id": model.app_id,
"last_edited_at": None,
"node_id": model.node_id,
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index b45a167b73..6404136994 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -13,8 +13,8 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.file import File
from core.repositories import DifyCoreRepositoryFactory
-from core.variables import Variable
-from core.variables.variables import VariableUnion
+from core.variables import VariableBase
+from core.variables.variables import Variable
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.errors import WorkflowNodeRunFailedError
@@ -198,8 +198,8 @@ class WorkflowService:
features: dict,
unique_hash: str | None,
account: Account,
- environment_variables: Sequence[Variable],
- conversation_variables: Sequence[Variable],
+ environment_variables: Sequence[VariableBase],
+ conversation_variables: Sequence[VariableBase],
) -> Workflow:
"""
Sync draft workflow
@@ -675,7 +675,7 @@ class WorkflowService:
else:
variable_pool = VariablePool(
- system_variables=SystemVariable.empty(),
+ system_variables=SystemVariable.default(),
user_inputs=user_inputs,
environment_variables=draft_workflow.environment_variables,
conversation_variables=[],
@@ -1044,7 +1044,7 @@ def _setup_variable_pool(
workflow: Workflow,
node_type: NodeType,
conversation_id: str,
- conversation_variables: list[Variable],
+ conversation_variables: list[VariableBase],
):
# Only inject system variables for START node type.
if node_type == NodeType.START or node_type.is_trigger_node:
@@ -1063,16 +1063,16 @@ def _setup_variable_pool(
system_variable.conversation_id = conversation_id
system_variable.dialogue_count = 1
else:
- system_variable = SystemVariable.empty()
+ system_variable = SystemVariable.default()
# init variable pool
variable_pool = VariablePool(
system_variables=system_variable,
user_inputs=user_inputs,
environment_variables=workflow.environment_variables,
- # Based on the definition of `VariableUnion`,
- # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
- conversation_variables=cast(list[VariableUnion], conversation_variables), #
+ # Based on the definition of `Variable`,
+ # `VariableBase` instances can be safely used as `Variable` since they are compatible.
+ conversation_variables=cast(list[Variable], conversation_variables), #
)
return variable_pool
diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py
index 292ac6e008..3ee41c2e8d 100644
--- a/api/services/workspace_service.py
+++ b/api/services/workspace_service.py
@@ -31,7 +31,8 @@ class WorkspaceService:
assert tenant_account_join is not None, "TenantAccountJoin not found"
tenant_info["role"] = tenant_account_join.role
- can_replace_logo = FeatureService.get_features(tenant.id).can_replace_logo
+ feature = FeatureService.get_features(tenant.id)
+ can_replace_logo = feature.can_replace_logo
if can_replace_logo and TenantService.has_roles(tenant, [TenantAccountRole.OWNER, TenantAccountRole.ADMIN]):
base_url = dify_config.FILES_URL
@@ -46,5 +47,19 @@ class WorkspaceService:
"remove_webapp_brand": remove_webapp_brand,
"replace_webapp_logo": replace_webapp_logo,
}
+ if dify_config.EDITION == "CLOUD":
+ tenant_info["next_credit_reset_date"] = feature.next_credit_reset_date
+
+ from services.credit_pool_service import CreditPoolService
+
+ paid_pool = CreditPoolService.get_pool(tenant_id=tenant.id, pool_type="paid")
+ if paid_pool:
+ tenant_info["trial_credits"] = paid_pool.quota_limit
+ tenant_info["trial_credits_used"] = paid_pool.quota_used
+ else:
+ trial_pool = CreditPoolService.get_pool(tenant_id=tenant.id, pool_type="trial")
+ if trial_pool:
+ tenant_info["trial_credits"] = trial_pool.quota_limit
+ tenant_info["trial_credits_used"] = trial_pool.quota_used
return tenant_info
diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py
index e7dead8a56..2d3d00cd50 100644
--- a/api/tasks/add_document_to_index_task.py
+++ b/api/tasks/add_document_to_index_task.py
@@ -4,11 +4,11 @@ import time
import click
from celery import shared_task
+from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DatasetAutoDisableLog, DocumentSegment
@@ -28,106 +28,119 @@ def add_document_to_index_task(dataset_document_id: str):
logger.info(click.style(f"Start add document to index: {dataset_document_id}", fg="green"))
start_at = time.perf_counter()
- dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document_id).first()
- if not dataset_document:
- logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red"))
- db.session.close()
- return
+ with session_factory.create_session() as session:
+ dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == dataset_document_id).first()
+ if not dataset_document:
+ logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red"))
+ return
- if dataset_document.indexing_status != "completed":
- db.session.close()
- return
+ if dataset_document.indexing_status != "completed":
+ return
- indexing_cache_key = f"document_{dataset_document.id}_indexing"
+ indexing_cache_key = f"document_{dataset_document.id}_indexing"
- try:
- dataset = dataset_document.dataset
- if not dataset:
- raise Exception(f"Document {dataset_document.id} dataset {dataset_document.dataset_id} doesn't exist.")
+ try:
+ dataset = dataset_document.dataset
+ if not dataset:
+ raise Exception(f"Document {dataset_document.id} dataset {dataset_document.dataset_id} doesn't exist.")
- segments = (
- db.session.query(DocumentSegment)
- .where(
- DocumentSegment.document_id == dataset_document.id,
- DocumentSegment.status == "completed",
+ segments = (
+ session.query(DocumentSegment)
+ .where(
+ DocumentSegment.document_id == dataset_document.id,
+ DocumentSegment.status == "completed",
+ )
+ .order_by(DocumentSegment.position.asc())
+ .all()
)
- .order_by(DocumentSegment.position.asc())
- .all()
- )
- documents = []
- multimodal_documents = []
- for segment in segments:
- document = Document(
- page_content=segment.content,
- metadata={
- "doc_id": segment.index_node_id,
- "doc_hash": segment.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
+ documents = []
+ multimodal_documents = []
+ for segment in segments:
+ document = Document(
+ page_content=segment.content,
+ metadata={
+ "doc_id": segment.index_node_id,
+ "doc_hash": segment.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
+ if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
+ child_chunks = segment.get_child_chunks()
+ if child_chunks:
+ child_documents = []
+ for child_chunk in child_chunks:
+ child_document = ChildDocument(
+ page_content=child_chunk.content,
+ metadata={
+ "doc_id": child_chunk.index_node_id,
+ "doc_hash": child_chunk.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
+ child_documents.append(child_document)
+ document.children = child_documents
+ if dataset.is_multimodal:
+ for attachment in segment.attachments:
+ multimodal_documents.append(
+ AttachmentDocument(
+ page_content=attachment["name"],
+ metadata={
+ "doc_id": attachment["id"],
+ "doc_hash": "",
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ "doc_type": DocType.IMAGE,
+ },
+ )
+ )
+ documents.append(document)
+
+ index_type = dataset.doc_form
+ index_processor = IndexProcessorFactory(index_type).init_index_processor()
+ index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
+
+ # delete auto disable log
+ session.query(DatasetAutoDisableLog).where(
+ DatasetAutoDisableLog.document_id == dataset_document.id
+ ).delete()
+
+ # update segment to enable
+ session.query(DocumentSegment).where(DocumentSegment.document_id == dataset_document.id).update(
+ {
+ DocumentSegment.enabled: True,
+ DocumentSegment.disabled_at: None,
+ DocumentSegment.disabled_by: None,
+ DocumentSegment.updated_at: naive_utc_now(),
+ }
)
- if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
- child_chunks = segment.get_child_chunks()
- if child_chunks:
- child_documents = []
- for child_chunk in child_chunks:
- child_document = ChildDocument(
- page_content=child_chunk.content,
- metadata={
- "doc_id": child_chunk.index_node_id,
- "doc_hash": child_chunk.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
- )
- child_documents.append(child_document)
- document.children = child_documents
- if dataset.is_multimodal:
- for attachment in segment.attachments:
- multimodal_documents.append(
- AttachmentDocument(
- page_content=attachment["name"],
- metadata={
- "doc_id": attachment["id"],
- "doc_hash": "",
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- "doc_type": DocType.IMAGE,
- },
- )
+ session.commit()
+
+ # Enable summary indexes for all segments in this document
+ from services.summary_index_service import SummaryIndexService
+
+ segment_ids_list = [segment.id for segment in segments]
+ if segment_ids_list:
+ try:
+ SummaryIndexService.enable_summaries_for_segments(
+ dataset=dataset,
+ segment_ids=segment_ids_list,
)
- documents.append(document)
+ except Exception as e:
+ logger.warning("Failed to enable summaries for document %s: %s", dataset_document.id, str(e))
- index_type = dataset.doc_form
- index_processor = IndexProcessorFactory(index_type).init_index_processor()
- index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
-
- # delete auto disable log
- db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == dataset_document.id).delete()
-
- # update segment to enable
- db.session.query(DocumentSegment).where(DocumentSegment.document_id == dataset_document.id).update(
- {
- DocumentSegment.enabled: True,
- DocumentSegment.disabled_at: None,
- DocumentSegment.disabled_by: None,
- DocumentSegment.updated_at: naive_utc_now(),
- }
- )
- db.session.commit()
-
- end_at = time.perf_counter()
- logger.info(
- click.style(f"Document added to index: {dataset_document.id} latency: {end_at - start_at}", fg="green")
- )
- except Exception as e:
- logger.exception("add document to index failed")
- dataset_document.enabled = False
- dataset_document.disabled_at = naive_utc_now()
- dataset_document.indexing_status = "error"
- dataset_document.error = str(e)
- db.session.commit()
- finally:
- redis_client.delete(indexing_cache_key)
- db.session.close()
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(f"Document added to index: {dataset_document.id} latency: {end_at - start_at}", fg="green")
+ )
+ except Exception as e:
+ logger.exception("add document to index failed")
+ dataset_document.enabled = False
+ dataset_document.disabled_at = naive_utc_now()
+ dataset_document.indexing_status = "error"
+ dataset_document.error = str(e)
+ session.commit()
+ finally:
+ redis_client.delete(indexing_cache_key)
diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py
index 8e46e8d0e3..fc6bf03454 100644
--- a/api/tasks/annotation/batch_import_annotations_task.py
+++ b/api/tasks/annotation/batch_import_annotations_task.py
@@ -5,9 +5,9 @@ import click
from celery import shared_task
from werkzeug.exceptions import NotFound
+from core.db.session_factory import session_factory
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, MessageAnnotation
@@ -30,65 +30,74 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id:
logger.info(click.style(f"Start batch import annotation: {job_id}", fg="green"))
start_at = time.perf_counter()
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
- # get app info
- app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
+ active_jobs_key = f"annotation_import_active:{tenant_id}"
- if app:
- try:
- documents = []
- for content in content_list:
- annotation = MessageAnnotation(
- app_id=app.id, content=content["answer"], question=content["question"], account_id=user_id
+ with session_factory.create_session() as session:
+ # get app info
+ app = session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
+
+ if app:
+ try:
+ documents = []
+ for content in content_list:
+ annotation = MessageAnnotation(
+ app_id=app.id, content=content["answer"], question=content["question"], account_id=user_id
+ )
+ session.add(annotation)
+ session.flush()
+
+ document = Document(
+ page_content=content["question"],
+ metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id},
+ )
+ documents.append(document)
+ # if annotation reply is enabled , batch add annotations' index
+ app_annotation_setting = (
+ session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
)
- db.session.add(annotation)
- db.session.flush()
- document = Document(
- page_content=content["question"],
- metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id},
- )
- documents.append(document)
- # if annotation reply is enabled , batch add annotations' index
- app_annotation_setting = (
- db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
- )
+ if app_annotation_setting:
+ dataset_collection_binding = (
+ DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
+ app_annotation_setting.collection_binding_id, "annotation"
+ )
+ )
+ if not dataset_collection_binding:
+ raise NotFound("App annotation setting not found")
+ dataset = Dataset(
+ id=app_id,
+ tenant_id=tenant_id,
+ indexing_technique="high_quality",
+ embedding_model_provider=dataset_collection_binding.provider_name,
+ embedding_model=dataset_collection_binding.model_name,
+ collection_binding_id=dataset_collection_binding.id,
+ )
- if app_annotation_setting:
- dataset_collection_binding = (
- DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
- app_annotation_setting.collection_binding_id, "annotation"
+ vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
+ vector.create(documents, duplicate_check=True)
+
+ session.commit()
+ redis_client.setex(indexing_cache_key, 600, "completed")
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ "Build index successful for batch import annotation: {} latency: {}".format(
+ job_id, end_at - start_at
+ ),
+ fg="green",
)
)
- if not dataset_collection_binding:
- raise NotFound("App annotation setting not found")
- dataset = Dataset(
- id=app_id,
- tenant_id=tenant_id,
- indexing_technique="high_quality",
- embedding_model_provider=dataset_collection_binding.provider_name,
- embedding_model=dataset_collection_binding.model_name,
- collection_binding_id=dataset_collection_binding.id,
- )
-
- vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
- vector.create(documents, duplicate_check=True)
-
- db.session.commit()
- redis_client.setex(indexing_cache_key, 600, "completed")
- end_at = time.perf_counter()
- logger.info(
- click.style(
- "Build index successful for batch import annotation: {} latency: {}".format(
- job_id, end_at - start_at
- ),
- fg="green",
- )
- )
- except Exception as e:
- db.session.rollback()
- redis_client.setex(indexing_cache_key, 600, "error")
- indexing_error_msg_key = f"app_annotation_batch_import_error_msg_{str(job_id)}"
- redis_client.setex(indexing_error_msg_key, 600, str(e))
- logger.exception("Build index for batch import annotations failed")
- finally:
- db.session.close()
+ except Exception as e:
+ session.rollback()
+ redis_client.setex(indexing_cache_key, 600, "error")
+ indexing_error_msg_key = f"app_annotation_batch_import_error_msg_{str(job_id)}"
+ redis_client.setex(indexing_error_msg_key, 600, str(e))
+ logger.exception("Build index for batch import annotations failed")
+ finally:
+ # Clean up active job tracking to release concurrency slot
+ try:
+ redis_client.zrem(active_jobs_key, job_id)
+ logger.debug("Released concurrency slot for job: %s", job_id)
+ except Exception as cleanup_error:
+ # Log but don't fail if cleanup fails - the job will be auto-expired
+ logger.warning("Failed to clean up active job tracking for %s: %s", job_id, cleanup_error)
diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py
index c0020b29ed..7b5cd46b00 100644
--- a/api/tasks/annotation/disable_annotation_reply_task.py
+++ b/api/tasks/annotation/disable_annotation_reply_task.py
@@ -5,8 +5,8 @@ import click
from celery import shared_task
from sqlalchemy import exists, select
+from core.db.session_factory import session_factory
from core.rag.datasource.vdb.vector_factory import Vector
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, MessageAnnotation
@@ -22,50 +22,55 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
logger.info(click.style(f"Start delete app annotations index: {app_id}", fg="green"))
start_at = time.perf_counter()
# get app info
- app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
- annotations_exists = db.session.scalar(select(exists().where(MessageAnnotation.app_id == app_id)))
- if not app:
- logger.info(click.style(f"App not found: {app_id}", fg="red"))
- db.session.close()
- return
+ with session_factory.create_session() as session:
+ app = session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
+ annotations_exists = session.scalar(select(exists().where(MessageAnnotation.app_id == app_id)))
+ if not app:
+ logger.info(click.style(f"App not found: {app_id}", fg="red"))
+ return
- app_annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
-
- if not app_annotation_setting:
- logger.info(click.style(f"App annotation setting not found: {app_id}", fg="red"))
- db.session.close()
- return
-
- disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
- disable_app_annotation_job_key = f"disable_app_annotation_job_{str(job_id)}"
-
- try:
- dataset = Dataset(
- id=app_id,
- tenant_id=tenant_id,
- indexing_technique="high_quality",
- collection_binding_id=app_annotation_setting.collection_binding_id,
+ app_annotation_setting = (
+ session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
)
+ if not app_annotation_setting:
+ logger.info(click.style(f"App annotation setting not found: {app_id}", fg="red"))
+ return
+
+ disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
+ disable_app_annotation_job_key = f"disable_app_annotation_job_{str(job_id)}"
+
try:
- if annotations_exists:
- vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
- vector.delete()
- except Exception:
- logger.exception("Delete annotation index failed when annotation deleted.")
- redis_client.setex(disable_app_annotation_job_key, 600, "completed")
+ dataset = Dataset(
+ id=app_id,
+ tenant_id=tenant_id,
+ indexing_technique="high_quality",
+ collection_binding_id=app_annotation_setting.collection_binding_id,
+ )
- # delete annotation setting
- db.session.delete(app_annotation_setting)
- db.session.commit()
+ try:
+ if annotations_exists:
+ vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
+ vector.delete()
+ except Exception:
+ logger.exception("Delete annotation index failed when annotation deleted.")
+ redis_client.setex(disable_app_annotation_job_key, 600, "completed")
- end_at = time.perf_counter()
- logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green"))
- except Exception as e:
- logger.exception("Annotation batch deleted index failed")
- redis_client.setex(disable_app_annotation_job_key, 600, "error")
- disable_app_annotation_error_key = f"disable_app_annotation_error_{str(job_id)}"
- redis_client.setex(disable_app_annotation_error_key, 600, str(e))
- finally:
- redis_client.delete(disable_app_annotation_key)
- db.session.close()
+ # delete annotation setting
+ session.delete(app_annotation_setting)
+ session.commit()
+
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ f"App annotations index deleted : {app_id} latency: {end_at - start_at}",
+ fg="green",
+ )
+ )
+ except Exception as e:
+ logger.exception("Annotation batch deleted index failed")
+ redis_client.setex(disable_app_annotation_job_key, 600, "error")
+ disable_app_annotation_error_key = f"disable_app_annotation_error_{str(job_id)}"
+ redis_client.setex(disable_app_annotation_error_key, 600, str(e))
+ finally:
+ redis_client.delete(disable_app_annotation_key)
diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py
index cdc07c77a8..4f8e2fec7a 100644
--- a/api/tasks/annotation/enable_annotation_reply_task.py
+++ b/api/tasks/annotation/enable_annotation_reply_task.py
@@ -5,9 +5,9 @@ import click
from celery import shared_task
from sqlalchemy import select
+from core.db.session_factory import session_factory
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset
@@ -33,92 +33,98 @@ def enable_annotation_reply_task(
logger.info(click.style(f"Start add app annotation to index: {app_id}", fg="green"))
start_at = time.perf_counter()
# get app info
- app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
+ with session_factory.create_session() as session:
+ app = session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
- if not app:
- logger.info(click.style(f"App not found: {app_id}", fg="red"))
- db.session.close()
- return
+ if not app:
+ logger.info(click.style(f"App not found: {app_id}", fg="red"))
+ return
- annotations = db.session.scalars(select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)).all()
- enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
- enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"
+ annotations = session.scalars(select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)).all()
+ enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
+ enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"
- try:
- documents = []
- dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
- embedding_provider_name, embedding_model_name, "annotation"
- )
- annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
- if annotation_setting:
- if dataset_collection_binding.id != annotation_setting.collection_binding_id:
- old_dataset_collection_binding = (
- DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
- annotation_setting.collection_binding_id, "annotation"
- )
- )
- if old_dataset_collection_binding and annotations:
- old_dataset = Dataset(
- id=app_id,
- tenant_id=tenant_id,
- indexing_technique="high_quality",
- embedding_model_provider=old_dataset_collection_binding.provider_name,
- embedding_model=old_dataset_collection_binding.model_name,
- collection_binding_id=old_dataset_collection_binding.id,
- )
-
- old_vector = Vector(old_dataset, attributes=["doc_id", "annotation_id", "app_id"])
- try:
- old_vector.delete()
- except Exception as e:
- logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red"))
- annotation_setting.score_threshold = score_threshold
- annotation_setting.collection_binding_id = dataset_collection_binding.id
- annotation_setting.updated_user_id = user_id
- annotation_setting.updated_at = naive_utc_now()
- db.session.add(annotation_setting)
- else:
- new_app_annotation_setting = AppAnnotationSetting(
- app_id=app_id,
- score_threshold=score_threshold,
- collection_binding_id=dataset_collection_binding.id,
- created_user_id=user_id,
- updated_user_id=user_id,
+ try:
+ documents = []
+ dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
+ embedding_provider_name, embedding_model_name, "annotation"
)
- db.session.add(new_app_annotation_setting)
+ annotation_setting = (
+ session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
+ )
+ if annotation_setting:
+ if dataset_collection_binding.id != annotation_setting.collection_binding_id:
+ old_dataset_collection_binding = (
+ DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
+ annotation_setting.collection_binding_id, "annotation"
+ )
+ )
+ if old_dataset_collection_binding and annotations:
+ old_dataset = Dataset(
+ id=app_id,
+ tenant_id=tenant_id,
+ indexing_technique="high_quality",
+ embedding_model_provider=old_dataset_collection_binding.provider_name,
+ embedding_model=old_dataset_collection_binding.model_name,
+ collection_binding_id=old_dataset_collection_binding.id,
+ )
- dataset = Dataset(
- id=app_id,
- tenant_id=tenant_id,
- indexing_technique="high_quality",
- embedding_model_provider=embedding_provider_name,
- embedding_model=embedding_model_name,
- collection_binding_id=dataset_collection_binding.id,
- )
- if annotations:
- for annotation in annotations:
- document = Document(
- page_content=annotation.question,
- metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id},
+ old_vector = Vector(old_dataset, attributes=["doc_id", "annotation_id", "app_id"])
+ try:
+ old_vector.delete()
+ except Exception as e:
+ logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red"))
+ annotation_setting.score_threshold = score_threshold
+ annotation_setting.collection_binding_id = dataset_collection_binding.id
+ annotation_setting.updated_user_id = user_id
+ annotation_setting.updated_at = naive_utc_now()
+ session.add(annotation_setting)
+ else:
+ new_app_annotation_setting = AppAnnotationSetting(
+ app_id=app_id,
+ score_threshold=score_threshold,
+ collection_binding_id=dataset_collection_binding.id,
+ created_user_id=user_id,
+ updated_user_id=user_id,
)
- documents.append(document)
+ session.add(new_app_annotation_setting)
- vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
- try:
- vector.delete_by_metadata_field("app_id", app_id)
- except Exception as e:
- logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red"))
- vector.create(documents)
- db.session.commit()
- redis_client.setex(enable_app_annotation_job_key, 600, "completed")
- end_at = time.perf_counter()
- logger.info(click.style(f"App annotations added to index: {app_id} latency: {end_at - start_at}", fg="green"))
- except Exception as e:
- logger.exception("Annotation batch created index failed")
- redis_client.setex(enable_app_annotation_job_key, 600, "error")
- enable_app_annotation_error_key = f"enable_app_annotation_error_{str(job_id)}"
- redis_client.setex(enable_app_annotation_error_key, 600, str(e))
- db.session.rollback()
- finally:
- redis_client.delete(enable_app_annotation_key)
- db.session.close()
+ dataset = Dataset(
+ id=app_id,
+ tenant_id=tenant_id,
+ indexing_technique="high_quality",
+ embedding_model_provider=embedding_provider_name,
+ embedding_model=embedding_model_name,
+ collection_binding_id=dataset_collection_binding.id,
+ )
+ if annotations:
+ for annotation in annotations:
+ document = Document(
+ page_content=annotation.question_text,
+ metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id},
+ )
+ documents.append(document)
+
+ vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
+ try:
+ vector.delete_by_metadata_field("app_id", app_id)
+ except Exception as e:
+ logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red"))
+ vector.create(documents)
+ session.commit()
+ redis_client.setex(enable_app_annotation_job_key, 600, "completed")
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ f"App annotations added to index: {app_id} latency: {end_at - start_at}",
+ fg="green",
+ )
+ )
+ except Exception as e:
+ logger.exception("Annotation batch created index failed")
+ redis_client.setex(enable_app_annotation_job_key, 600, "error")
+ enable_app_annotation_error_key = f"enable_app_annotation_error_{str(job_id)}"
+ redis_client.setex(enable_app_annotation_error_key, 600, str(e))
+ session.rollback()
+ finally:
+ redis_client.delete(enable_app_annotation_key)
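The change above is representative of the whole series: each task now opens one explicitly scoped session from `session_factory.create_session()`, commits inside the `try`, rolls back on failure, and relies on the context manager instead of a manual `db.session.close()`. A minimal, self-contained sketch of that shape is below; it assumes a plain SQLAlchemy `sessionmaker` as a stand-in for the project's factory, plus a hypothetical `Annotation` model and `mark_job` callback, so none of these names come from the codebase itself.

```python
# Minimal sketch of the context-managed session pattern (assumed stand-ins:
# SessionLocal for session_factory.create_session(), Annotation, mark_job).
from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker


class Base(DeclarativeBase):
    pass


class Annotation(Base):  # hypothetical model, for illustration only
    __tablename__ = "annotations"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    status: Mapped[str] = mapped_column(String, default="waiting")


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
SessionLocal = sessionmaker(bind=engine, expire_on_commit=False)


def index_annotation_task(annotation_id: str, mark_job) -> None:
    """Do all work inside one scoped session; the context manager closes it."""
    with SessionLocal() as session:
        annotation = session.scalars(
            select(Annotation).where(Annotation.id == annotation_id)
        ).first()
        if not annotation:
            return  # early return is safe: no explicit session.close() needed
        try:
            annotation.status = "indexed"
            session.commit()
            mark_job("completed")
        except Exception:
            session.rollback()  # leave the session clean before reporting failure
            mark_job("error")
```

In the real tasks, `mark_job` corresponds to the Redis status keys (`setex`/`delete`) that still run in the `except`/`finally` branches.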
diff --git a/api/tasks/async_workflow_tasks.py b/api/tasks/async_workflow_tasks.py
index f8aac5b469..b51884148e 100644
--- a/api/tasks/async_workflow_tasks.py
+++ b/api/tasks/async_workflow_tasks.py
@@ -10,13 +10,13 @@ from typing import Any
from celery import shared_task
from sqlalchemy import select
-from sqlalchemy.orm import Session, sessionmaker
+from sqlalchemy.orm import Session
from configs import dify_config
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.layers.trigger_post_layer import TriggerPostLayer
-from extensions.ext_database import db
+from core.db.session_factory import session_factory
from models.account import Account
from models.enums import CreatorUserRole, WorkflowTriggerStatus
from models.model import App, EndUser, Tenant
@@ -98,10 +98,7 @@ def _execute_workflow_common(
):
"""Execute workflow with common logic and trigger log updates."""
- # Create a new session for this task
- session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
-
- with session_factory() as session:
+ with session_factory.create_session() as session:
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
# Get trigger log
@@ -157,7 +154,7 @@ def _execute_workflow_common(
root_node_id=trigger_data.root_node_id,
graph_engine_layers=[
# TODO: Re-enable TimeSliceLayer after the HITL release.
- TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id, session_factory),
+ TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id),
],
)
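Here the behavioural change is small: the task stops building its own `sessionmaker(bind=db.engine, expire_on_commit=False)` on every invocation and uses the shared factory instead, and `TriggerPostLayer` no longer receives a factory argument. A rough sketch of the "configure once, reuse per call" idea, with a plain SQLAlchemy `sessionmaker` and a hypothetical helper standing in for the repository and generator calls:

```python
# Sketch: create the sessionmaker once at import time and reuse it per task,
# rather than constructing a new factory inside every invocation.
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker

engine = create_engine("sqlite:///:memory:")

# Built once; calling it is cheap, and every task shares the same configuration.
SessionLocal = sessionmaker(bind=engine, expire_on_commit=False)


def process_trigger_log(session: Session, trigger_log_id: str) -> None:
    """Hypothetical stand-in for the trigger-log repository / generator calls."""
    session.commit()


def run_workflow_task(trigger_log_id: str) -> None:
    with SessionLocal() as session:
        process_trigger_log(session, trigger_log_id)
```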
diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py
index 3e1bd16cc7..d388284980 100644
--- a/api/tasks/batch_clean_document_task.py
+++ b/api/tasks/batch_clean_document_task.py
@@ -3,11 +3,11 @@ import time
import click
from celery import shared_task
-from sqlalchemy import select
+from sqlalchemy import delete, select
+from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
-from extensions.ext_database import db
from extensions.ext_storage import storage
from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment
from models.model import UploadFile
@@ -28,65 +28,66 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
"""
logger.info(click.style("Start batch clean documents when documents deleted", fg="green"))
start_at = time.perf_counter()
+ if not doc_form:
+ raise ValueError("doc_form is required")
- try:
- if not doc_form:
- raise ValueError("doc_form is required")
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
+ with session_factory.create_session() as session:
+ try:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
- if not dataset:
- raise Exception("Document has no dataset")
+ if not dataset:
+ raise Exception("Document has no dataset")
- db.session.query(DatasetMetadataBinding).where(
- DatasetMetadataBinding.dataset_id == dataset_id,
- DatasetMetadataBinding.document_id.in_(document_ids),
- ).delete(synchronize_session=False)
+ session.query(DatasetMetadataBinding).where(
+ DatasetMetadataBinding.dataset_id == dataset_id,
+ DatasetMetadataBinding.document_id.in_(document_ids),
+ ).delete(synchronize_session=False)
- segments = db.session.scalars(
- select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
- ).all()
- # check segment is exist
- if segments:
- index_node_ids = [segment.index_node_id for segment in segments]
- index_processor = IndexProcessorFactory(doc_form).init_index_processor()
- index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+ segments = session.scalars(
+ select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
+ ).all()
+ # check segment is exist
+ if segments:
+ index_node_ids = [segment.index_node_id for segment in segments]
+ index_processor = IndexProcessorFactory(doc_form).init_index_processor()
+ index_processor.clean(
+ dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
+ )
- for segment in segments:
- image_upload_file_ids = get_image_upload_file_ids(segment.content)
- for upload_file_id in image_upload_file_ids:
- image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
+ for segment in segments:
+ image_upload_file_ids = get_image_upload_file_ids(segment.content)
+ image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all()
+ for image_file in image_files:
+ try:
+ if image_file and image_file.key:
+ storage.delete(image_file.key)
+ except Exception:
+ logger.exception(
+ "Delete image_files failed when storage deleted, \
+                            image_upload_file_id: %s",
+ image_file.id,
+ )
+ stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
+ session.execute(stmt)
+ session.delete(segment)
+ if file_ids:
+ files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
+ for file in files:
try:
- if image_file and image_file.key:
- storage.delete(image_file.key)
+ storage.delete(file.key)
except Exception:
- logger.exception(
- "Delete image_files failed when storage deleted, \
- image_upload_file_is: %s",
- upload_file_id,
- )
- db.session.delete(image_file)
- db.session.delete(segment)
+ logger.exception("Delete file failed when document deleted, file_id: %s", file.id)
+ stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids))
+ session.execute(stmt)
- db.session.commit()
- if file_ids:
- files = db.session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
- for file in files:
- try:
- storage.delete(file.key)
- except Exception:
- logger.exception("Delete file failed when document deleted, file_id: %s", file.id)
- db.session.delete(file)
+ session.commit()
- db.session.commit()
-
- end_at = time.perf_counter()
- logger.info(
- click.style(
- f"Cleaned documents when documents deleted latency: {end_at - start_at}",
- fg="green",
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ f"Cleaned documents when documents deleted latency: {end_at - start_at}",
+ fg="green",
+ )
)
- )
- except Exception:
- logger.exception("Cleaned documents when documents deleted failed")
- finally:
- db.session.close()
+ except Exception:
+ logger.exception("Cleaned documents when documents deleted failed")
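Beyond the session-scoping change, this task (and the clean-up tasks that follow) replaces per-id `query(...).first()` lookups and per-object `session.delete()` calls with one `IN` query for the rows it still needs to touch and a single bulk `delete()` statement. A self-contained sketch of that pattern, assuming a toy `StoredFile` model and an injected `delete_from_storage` callable rather than the real `UploadFile`/`storage` pair:

```python
# Sketch of the bulk-fetch / bulk-delete pattern (StoredFile and
# delete_from_storage are illustrative stand-ins, not project APIs).
from sqlalchemy import String, create_engine, delete, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker


class Base(DeclarativeBase):
    pass


class StoredFile(Base):
    __tablename__ = "stored_files"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    key: Mapped[str] = mapped_column(String)


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
SessionLocal = sessionmaker(bind=engine, expire_on_commit=False)


def delete_files(file_ids: list[str], delete_from_storage) -> None:
    with SessionLocal() as session:
        # One round trip to load every row whose external object must be removed...
        files = session.scalars(
            select(StoredFile).where(StoredFile.id.in_(file_ids))
        ).all()
        for file in files:
            try:
                delete_from_storage(file.key)
            except Exception:
                pass  # the real tasks log the failure and keep going
        # ...and one statement to drop them all, instead of N session.delete() calls.
        session.execute(delete(StoredFile).where(StoredFile.id.in_(file_ids)))
        session.commit()
```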
diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py
index bd95af2614..8ee09d5738 100644
--- a/api/tasks/batch_create_segment_to_index_task.py
+++ b/api/tasks/batch_create_segment_to_index_task.py
@@ -9,9 +9,9 @@ import pandas as pd
from celery import shared_task
from sqlalchemy import func
+from core.db.session_factory import session_factory
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
@@ -48,104 +48,107 @@ def batch_create_segment_to_index_task(
indexing_cache_key = f"segment_batch_import_{job_id}"
- try:
- dataset = db.session.get(Dataset, dataset_id)
- if not dataset:
- raise ValueError("Dataset not exist.")
+ with session_factory.create_session() as session:
+ try:
+ dataset = session.get(Dataset, dataset_id)
+ if not dataset:
+ raise ValueError("Dataset not exist.")
- dataset_document = db.session.get(Document, document_id)
- if not dataset_document:
- raise ValueError("Document not exist.")
+ dataset_document = session.get(Document, document_id)
+ if not dataset_document:
+ raise ValueError("Document not exist.")
- if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
- raise ValueError("Document is not available.")
+ if (
+ not dataset_document.enabled
+ or dataset_document.archived
+ or dataset_document.indexing_status != "completed"
+ ):
+ raise ValueError("Document is not available.")
- upload_file = db.session.get(UploadFile, upload_file_id)
- if not upload_file:
- raise ValueError("UploadFile not found.")
+ upload_file = session.get(UploadFile, upload_file_id)
+ if not upload_file:
+ raise ValueError("UploadFile not found.")
- with tempfile.TemporaryDirectory() as temp_dir:
- suffix = Path(upload_file.key).suffix
- file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
- storage.download(upload_file.key, file_path)
+ with tempfile.TemporaryDirectory() as temp_dir:
+ suffix = Path(upload_file.key).suffix
+ file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
+ storage.download(upload_file.key, file_path)
- df = pd.read_csv(file_path)
- content = []
- for _, row in df.iterrows():
+ df = pd.read_csv(file_path)
+ content = []
+ for _, row in df.iterrows():
+ if dataset_document.doc_form == "qa_model":
+ data = {"content": row.iloc[0], "answer": row.iloc[1]}
+ else:
+ data = {"content": row.iloc[0]}
+ content.append(data)
+ if len(content) == 0:
+ raise ValueError("The CSV file is empty.")
+
+ document_segments = []
+ embedding_model = None
+ if dataset.indexing_technique == "high_quality":
+ model_manager = ModelManager()
+ embedding_model = model_manager.get_model_instance(
+ tenant_id=dataset.tenant_id,
+ provider=dataset.embedding_model_provider,
+ model_type=ModelType.TEXT_EMBEDDING,
+ model=dataset.embedding_model,
+ )
+
+ word_count_change = 0
+ if embedding_model:
+ tokens_list = embedding_model.get_text_embedding_num_tokens(
+ texts=[segment["content"] for segment in content]
+ )
+ else:
+ tokens_list = [0] * len(content)
+
+ for segment, tokens in zip(content, tokens_list):
+ content = segment["content"]
+ doc_id = str(uuid.uuid4())
+ segment_hash = helper.generate_text_hash(content)
+ max_position = (
+ session.query(func.max(DocumentSegment.position))
+ .where(DocumentSegment.document_id == dataset_document.id)
+ .scalar()
+ )
+ segment_document = DocumentSegment(
+ tenant_id=tenant_id,
+ dataset_id=dataset_id,
+ document_id=document_id,
+ index_node_id=doc_id,
+ index_node_hash=segment_hash,
+ position=max_position + 1 if max_position else 1,
+ content=content,
+ word_count=len(content),
+ tokens=tokens,
+ created_by=user_id,
+ indexing_at=naive_utc_now(),
+ status="completed",
+ completed_at=naive_utc_now(),
+ )
if dataset_document.doc_form == "qa_model":
- data = {"content": row.iloc[0], "answer": row.iloc[1]}
- else:
- data = {"content": row.iloc[0]}
- content.append(data)
- if len(content) == 0:
- raise ValueError("The CSV file is empty.")
+ segment_document.answer = segment["answer"]
+ segment_document.word_count += len(segment["answer"])
+ word_count_change += segment_document.word_count
+ session.add(segment_document)
+ document_segments.append(segment_document)
- document_segments = []
- embedding_model = None
- if dataset.indexing_technique == "high_quality":
- model_manager = ModelManager()
- embedding_model = model_manager.get_model_instance(
- tenant_id=dataset.tenant_id,
- provider=dataset.embedding_model_provider,
- model_type=ModelType.TEXT_EMBEDDING,
- model=dataset.embedding_model,
- )
+ assert dataset_document.word_count is not None
+ dataset_document.word_count += word_count_change
+ session.add(dataset_document)
- word_count_change = 0
- if embedding_model:
- tokens_list = embedding_model.get_text_embedding_num_tokens(
- texts=[segment["content"] for segment in content]
+ VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
+ session.commit()
+ redis_client.setex(indexing_cache_key, 600, "completed")
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ f"Segment batch created job: {job_id} latency: {end_at - start_at}",
+ fg="green",
+ )
)
- else:
- tokens_list = [0] * len(content)
-
- for segment, tokens in zip(content, tokens_list):
- content = segment["content"]
- doc_id = str(uuid.uuid4())
- segment_hash = helper.generate_text_hash(content)
- max_position = (
- db.session.query(func.max(DocumentSegment.position))
- .where(DocumentSegment.document_id == dataset_document.id)
- .scalar()
- )
- segment_document = DocumentSegment(
- tenant_id=tenant_id,
- dataset_id=dataset_id,
- document_id=document_id,
- index_node_id=doc_id,
- index_node_hash=segment_hash,
- position=max_position + 1 if max_position else 1,
- content=content,
- word_count=len(content),
- tokens=tokens,
- created_by=user_id,
- indexing_at=naive_utc_now(),
- status="completed",
- completed_at=naive_utc_now(),
- )
- if dataset_document.doc_form == "qa_model":
- segment_document.answer = segment["answer"]
- segment_document.word_count += len(segment["answer"])
- word_count_change += segment_document.word_count
- db.session.add(segment_document)
- document_segments.append(segment_document)
-
- assert dataset_document.word_count is not None
- dataset_document.word_count += word_count_change
- db.session.add(dataset_document)
-
- VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
- db.session.commit()
- redis_client.setex(indexing_cache_key, 600, "completed")
- end_at = time.perf_counter()
- logger.info(
- click.style(
- f"Segment batch created job: {job_id} latency: {end_at - start_at}",
- fg="green",
- )
- )
- except Exception:
- logger.exception("Segments batch created index failed")
- redis_client.setex(indexing_cache_key, 600, "error")
- finally:
- db.session.close()
+ except Exception:
+ logger.exception("Segments batch created index failed")
+ redis_client.setex(indexing_cache_key, 600, "error")
diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py
index 8608df6b8e..0d51a743ad 100644
--- a/api/tasks/clean_dataset_task.py
+++ b/api/tasks/clean_dataset_task.py
@@ -3,12 +3,13 @@ import time
import click
from celery import shared_task
-from sqlalchemy import select
+from sqlalchemy import delete, select
+from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
-from extensions.ext_database import db
from extensions.ext_storage import storage
+from models import WorkflowType
from models.dataset import (
AppDatasetJoin,
Dataset,
@@ -18,9 +19,11 @@ from models.dataset import (
DatasetQuery,
Document,
DocumentSegment,
+ Pipeline,
SegmentAttachmentBinding,
)
from models.model import UploadFile
+from models.workflow import Workflow
logger = logging.getLogger(__name__)
@@ -34,6 +37,7 @@ def clean_dataset_task(
index_struct: str,
collection_binding_id: str,
doc_form: str,
+ pipeline_id: str | None = None,
):
"""
Clean dataset when dataset deleted.
@@ -49,127 +53,155 @@ def clean_dataset_task(
logger.info(click.style(f"Start clean dataset when dataset deleted: {dataset_id}", fg="green"))
start_at = time.perf_counter()
- try:
- dataset = Dataset(
- id=dataset_id,
- tenant_id=tenant_id,
- indexing_technique=indexing_technique,
- index_struct=index_struct,
- collection_binding_id=collection_binding_id,
- )
- documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all()
- segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all()
- # Use JOIN to fetch attachments with bindings in a single query
- attachments_with_bindings = db.session.execute(
- select(SegmentAttachmentBinding, UploadFile)
- .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
- .where(SegmentAttachmentBinding.tenant_id == tenant_id, SegmentAttachmentBinding.dataset_id == dataset_id)
- ).all()
-
- # Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace
- # This ensures all invalid doc_form values are properly handled
- if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()):
- # Use default paragraph index type for empty/invalid datasets to enable vector database cleanup
- from core.rag.index_processor.constant.index_type import IndexStructureType
-
- doc_form = IndexStructureType.PARAGRAPH_INDEX
- logger.info(
- click.style(f"Invalid doc_form detected, using default index type for cleanup: {doc_form}", fg="yellow")
- )
-
- # Add exception handling around IndexProcessorFactory.clean() to prevent single point of failure
- # This ensures Document/Segment deletion can continue even if vector database cleanup fails
+ with session_factory.create_session() as session:
try:
- index_processor = IndexProcessorFactory(doc_form).init_index_processor()
- index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)
- logger.info(click.style(f"Successfully cleaned vector database for dataset: {dataset_id}", fg="green"))
- except Exception:
- logger.exception(click.style(f"Failed to clean vector database for dataset {dataset_id}", fg="red"))
- # Continue with document and segment deletion even if vector cleanup fails
- logger.info(
- click.style(f"Continuing with document and segment deletion for dataset: {dataset_id}", fg="yellow")
+ dataset = Dataset(
+ id=dataset_id,
+ tenant_id=tenant_id,
+ indexing_technique=indexing_technique,
+ index_struct=index_struct,
+ collection_binding_id=collection_binding_id,
)
+ documents = session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all()
+ segments = session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all()
+ # Use JOIN to fetch attachments with bindings in a single query
+ attachments_with_bindings = session.execute(
+ select(SegmentAttachmentBinding, UploadFile)
+ .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
+ .where(
+ SegmentAttachmentBinding.tenant_id == tenant_id,
+ SegmentAttachmentBinding.dataset_id == dataset_id,
+ )
+ ).all()
- if documents is None or len(documents) == 0:
- logger.info(click.style(f"No documents found for dataset: {dataset_id}", fg="green"))
- else:
- logger.info(click.style(f"Cleaning documents for dataset: {dataset_id}", fg="green"))
+ # Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace
+ # This ensures all invalid doc_form values are properly handled
+ if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()):
+ # Use default paragraph index type for empty/invalid datasets to enable vector database cleanup
+ from core.rag.index_processor.constant.index_type import IndexStructureType
- for document in documents:
- db.session.delete(document)
- # delete document file
+ doc_form = IndexStructureType.PARAGRAPH_INDEX
+ logger.info(
+ click.style(
+ f"Invalid doc_form detected, using default index type for cleanup: {doc_form}",
+ fg="yellow",
+ )
+ )
- for segment in segments:
- image_upload_file_ids = get_image_upload_file_ids(segment.content)
- for upload_file_id in image_upload_file_ids:
- image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
- if image_file is None:
- continue
+ # Add exception handling around IndexProcessorFactory.clean() to prevent single point of failure
+ # This ensures Document/Segment deletion can continue even if vector database cleanup fails
+ try:
+ index_processor = IndexProcessorFactory(doc_form).init_index_processor()
+ index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)
+ logger.info(click.style(f"Successfully cleaned vector database for dataset: {dataset_id}", fg="green"))
+ except Exception:
+ logger.exception(click.style(f"Failed to clean vector database for dataset {dataset_id}", fg="red"))
+ # Continue with document and segment deletion even if vector cleanup fails
+ logger.info(
+ click.style(f"Continuing with document and segment deletion for dataset: {dataset_id}", fg="yellow")
+ )
+
+ if documents is None or len(documents) == 0:
+ logger.info(click.style(f"No documents found for dataset: {dataset_id}", fg="green"))
+ else:
+ logger.info(click.style(f"Cleaning documents for dataset: {dataset_id}", fg="green"))
+
+ for document in documents:
+ session.delete(document)
+
+ segment_ids = [segment.id for segment in segments]
+ for segment in segments:
+ image_upload_file_ids = get_image_upload_file_ids(segment.content)
+ image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all()
+ for image_file in image_files:
+ if image_file is None:
+ continue
+ try:
+ storage.delete(image_file.key)
+ except Exception:
+ logger.exception(
+ "Delete image_files failed when storage deleted, \
+                            image_upload_file_id: %s",
+ image_file.id,
+ )
+ stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
+ session.execute(stmt)
+
+ segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
+ session.execute(segment_delete_stmt)
+ # delete segment attachments
+ if attachments_with_bindings:
+ attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
+ binding_ids = [binding.id for binding, _ in attachments_with_bindings]
+ for binding, attachment_file in attachments_with_bindings:
try:
- storage.delete(image_file.key)
+ storage.delete(attachment_file.key)
except Exception:
logger.exception(
- "Delete image_files failed when storage deleted, \
- image_upload_file_is: %s",
- upload_file_id,
+ "Delete attachment_file failed when storage deleted, \
+ attachment_file_id: %s",
+ binding.attachment_id,
)
- db.session.delete(image_file)
- db.session.delete(segment)
- # delete segment attachments
- if attachments_with_bindings:
- for binding, attachment_file in attachments_with_bindings:
- try:
- storage.delete(attachment_file.key)
- except Exception:
- logger.exception(
- "Delete attachment_file failed when storage deleted, \
- attachment_file_id: %s",
- binding.attachment_id,
- )
- db.session.delete(attachment_file)
- db.session.delete(binding)
+ attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
+ session.execute(attachment_file_delete_stmt)
- db.session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete()
- db.session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete()
- db.session.query(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id).delete()
- # delete dataset metadata
- db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete()
- db.session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete()
- # delete files
- if documents:
- for document in documents:
- try:
+ binding_delete_stmt = delete(SegmentAttachmentBinding).where(
+ SegmentAttachmentBinding.id.in_(binding_ids)
+ )
+ session.execute(binding_delete_stmt)
+
+ session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete()
+ session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete()
+ session.query(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id).delete()
+ # delete dataset metadata
+ session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete()
+ session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete()
+ # delete pipeline and workflow
+ if pipeline_id:
+ session.query(Pipeline).where(Pipeline.id == pipeline_id).delete()
+ session.query(Workflow).where(
+ Workflow.tenant_id == tenant_id,
+ Workflow.app_id == pipeline_id,
+ Workflow.type == WorkflowType.RAG_PIPELINE,
+ ).delete()
+ # delete files
+ if documents:
+ file_ids = []
+ for document in documents:
if document.data_source_type == "upload_file":
if document.data_source_info:
data_source_info = document.data_source_info_dict
if data_source_info and "upload_file_id" in data_source_info:
file_id = data_source_info["upload_file_id"]
- file = (
- db.session.query(UploadFile)
- .where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
- .first()
- )
- if not file:
- continue
- storage.delete(file.key)
- db.session.delete(file)
- except Exception:
- continue
+ file_ids.append(file_id)
+ files = session.query(UploadFile).where(UploadFile.id.in_(file_ids)).all()
+ for file in files:
+ storage.delete(file.key)
- db.session.commit()
- end_at = time.perf_counter()
- logger.info(
- click.style(f"Cleaned dataset when dataset deleted: {dataset_id} latency: {end_at - start_at}", fg="green")
- )
- except Exception:
- # Add rollback to prevent dirty session state in case of exceptions
- # This ensures the database session is properly cleaned up
- try:
- db.session.rollback()
- logger.info(click.style(f"Rolled back database session for dataset: {dataset_id}", fg="yellow"))
+ file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids))
+ session.execute(file_delete_stmt)
+
+ session.commit()
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ f"Cleaned dataset when dataset deleted: {dataset_id} latency: {end_at - start_at}",
+ fg="green",
+ )
+ )
except Exception:
- logger.exception("Failed to rollback database session")
+ # Add rollback to prevent dirty session state in case of exceptions
+ # This ensures the database session is properly cleaned up
+ try:
+ session.rollback()
+ logger.info(click.style(f"Rolled back database session for dataset: {dataset_id}", fg="yellow"))
+ except Exception:
+ logger.exception("Failed to rollback database session")
- logger.exception("Cleaned dataset when dataset deleted failed")
- finally:
- db.session.close()
+ logger.exception("Cleaned dataset when dataset deleted failed")
+ finally:
+ # Explicitly close the session for test expectations and safety
+ try:
+ session.close()
+ except Exception:
+ logger.exception("Failed to close database session")
diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py
index 6d2feb1da3..91ace6be02 100644
--- a/api/tasks/clean_document_task.py
+++ b/api/tasks/clean_document_task.py
@@ -3,11 +3,11 @@ import time
import click
from celery import shared_task
-from sqlalchemy import select
+from sqlalchemy import delete, select
+from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
-from extensions.ext_database import db
from extensions.ext_storage import storage
from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment, SegmentAttachmentBinding
from models.model import UploadFile
@@ -29,85 +29,96 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
logger.info(click.style(f"Start clean document when document deleted: {document_id}", fg="green"))
start_at = time.perf_counter()
- try:
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
+ with session_factory.create_session() as session:
+ try:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
- if not dataset:
- raise Exception("Document has no dataset")
+ if not dataset:
+ raise Exception("Document has no dataset")
- segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
- # Use JOIN to fetch attachments with bindings in a single query
- attachments_with_bindings = db.session.execute(
- select(SegmentAttachmentBinding, UploadFile)
- .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
- .where(
- SegmentAttachmentBinding.tenant_id == dataset.tenant_id,
- SegmentAttachmentBinding.dataset_id == dataset_id,
- SegmentAttachmentBinding.document_id == document_id,
- )
- ).all()
- # check segment is exist
- if segments:
- index_node_ids = [segment.index_node_id for segment in segments]
- index_processor = IndexProcessorFactory(doc_form).init_index_processor()
- index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+ segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
+ # Use JOIN to fetch attachments with bindings in a single query
+ attachments_with_bindings = session.execute(
+ select(SegmentAttachmentBinding, UploadFile)
+ .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
+ .where(
+ SegmentAttachmentBinding.tenant_id == dataset.tenant_id,
+ SegmentAttachmentBinding.dataset_id == dataset_id,
+ SegmentAttachmentBinding.document_id == document_id,
+ )
+ ).all()
+ # check segment is exist
+ if segments:
+ index_node_ids = [segment.index_node_id for segment in segments]
+ index_processor = IndexProcessorFactory(doc_form).init_index_processor()
+ index_processor.clean(
+ dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
+ )
- for segment in segments:
- image_upload_file_ids = get_image_upload_file_ids(segment.content)
- for upload_file_id in image_upload_file_ids:
- image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
- if image_file is None:
- continue
+ for segment in segments:
+ image_upload_file_ids = get_image_upload_file_ids(segment.content)
+ image_files = session.scalars(
+ select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
+ ).all()
+ for image_file in image_files:
+ if image_file is None:
+ continue
+ try:
+ storage.delete(image_file.key)
+ except Exception:
+ logger.exception(
+ "Delete image_files failed when storage deleted, \
+                                image_upload_file_id: %s",
+ image_file.id,
+ )
+
+ image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
+ session.execute(image_file_delete_stmt)
+ session.delete(segment)
+
+ session.commit()
+ if file_id:
+ file = session.query(UploadFile).where(UploadFile.id == file_id).first()
+ if file:
try:
- storage.delete(image_file.key)
+ storage.delete(file.key)
+ except Exception:
+ logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
+ session.delete(file)
+ # delete segment attachments
+ if attachments_with_bindings:
+ attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
+ binding_ids = [binding.id for binding, _ in attachments_with_bindings]
+ for binding, attachment_file in attachments_with_bindings:
+ try:
+ storage.delete(attachment_file.key)
except Exception:
logger.exception(
- "Delete image_files failed when storage deleted, \
- image_upload_file_is: %s",
- upload_file_id,
+ "Delete attachment_file failed when storage deleted, \
+ attachment_file_id: %s",
+ binding.attachment_id,
)
- db.session.delete(image_file)
- db.session.delete(segment)
+ attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
+ session.execute(attachment_file_delete_stmt)
- db.session.commit()
- if file_id:
- file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
- if file:
- try:
- storage.delete(file.key)
- except Exception:
- logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
- db.session.delete(file)
- db.session.commit()
- # delete segment attachments
- if attachments_with_bindings:
- for binding, attachment_file in attachments_with_bindings:
- try:
- storage.delete(attachment_file.key)
- except Exception:
- logger.exception(
- "Delete attachment_file failed when storage deleted, \
- attachment_file_id: %s",
- binding.attachment_id,
- )
- db.session.delete(attachment_file)
- db.session.delete(binding)
+ binding_delete_stmt = delete(SegmentAttachmentBinding).where(
+ SegmentAttachmentBinding.id.in_(binding_ids)
+ )
+ session.execute(binding_delete_stmt)
- # delete dataset metadata binding
- db.session.query(DatasetMetadataBinding).where(
- DatasetMetadataBinding.dataset_id == dataset_id,
- DatasetMetadataBinding.document_id == document_id,
- ).delete()
- db.session.commit()
+ # delete dataset metadata binding
+ session.query(DatasetMetadataBinding).where(
+ DatasetMetadataBinding.dataset_id == dataset_id,
+ DatasetMetadataBinding.document_id == document_id,
+ ).delete()
+ session.commit()
- end_at = time.perf_counter()
- logger.info(
- click.style(
- f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
- fg="green",
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
+ fg="green",
+ )
)
- )
- except Exception:
- logger.exception("Cleaned document when document deleted failed")
- finally:
- db.session.close()
+ except Exception:
+ logger.exception("Cleaned document when document deleted failed")
diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py
index 771b43f9b0..4214f043e0 100644
--- a/api/tasks/clean_notion_document_task.py
+++ b/api/tasks/clean_notion_document_task.py
@@ -3,10 +3,10 @@ import time
import click
from celery import shared_task
-from sqlalchemy import select
+from sqlalchemy import delete, select
+from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
logger = logging.getLogger(__name__)
@@ -24,37 +24,39 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
logger.info(click.style(f"Start clean document when import form notion document deleted: {dataset_id}", fg="green"))
start_at = time.perf_counter()
- try:
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
+ with session_factory.create_session() as session:
+ try:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
- if not dataset:
- raise Exception("Document has no dataset")
- index_type = dataset.doc_form
- index_processor = IndexProcessorFactory(index_type).init_index_processor()
- for document_id in document_ids:
- document = db.session.query(Document).where(Document.id == document_id).first()
- db.session.delete(document)
+ if not dataset:
+ raise Exception("Document has no dataset")
+ index_type = dataset.doc_form
+ index_processor = IndexProcessorFactory(index_type).init_index_processor()
- segments = db.session.scalars(
- select(DocumentSegment).where(DocumentSegment.document_id == document_id)
- ).all()
- index_node_ids = [segment.index_node_id for segment in segments]
+ document_delete_stmt = delete(Document).where(Document.id.in_(document_ids))
+ session.execute(document_delete_stmt)
- index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+ for document_id in document_ids:
+ segments = session.scalars(
+ select(DocumentSegment).where(DocumentSegment.document_id == document_id)
+ ).all()
+ index_node_ids = [segment.index_node_id for segment in segments]
- for segment in segments:
- db.session.delete(segment)
- db.session.commit()
- end_at = time.perf_counter()
- logger.info(
- click.style(
- "Clean document when import form notion document deleted end :: {} latency: {}".format(
- dataset_id, end_at - start_at
- ),
- fg="green",
+ index_processor.clean(
+ dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
+ )
+ segment_ids = [segment.id for segment in segments]
+ segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
+ session.execute(segment_delete_stmt)
+ session.commit()
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+                    "Clean document when import from notion document deleted end :: {} latency: {}".format(
+ dataset_id, end_at - start_at
+ ),
+ fg="green",
+ )
)
- )
- except Exception:
- logger.exception("Cleaned document when import form notion document deleted failed")
- finally:
- db.session.close()
+ except Exception:
+            logger.exception("Cleaned document when import from notion document deleted failed")
diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py
index 6b2907cffd..b5e472d71e 100644
--- a/api/tasks/create_segment_to_index_task.py
+++ b/api/tasks/create_segment_to_index_task.py
@@ -4,9 +4,9 @@ import time
import click
from celery import shared_task
+from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DocumentSegment
@@ -25,75 +25,77 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
logger.info(click.style(f"Start create segment to index: {segment_id}", fg="green"))
start_at = time.perf_counter()
- segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
- if not segment:
- logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
- db.session.close()
- return
-
- if segment.status != "waiting":
- db.session.close()
- return
-
- indexing_cache_key = f"segment_{segment.id}_indexing"
-
- try:
- # update segment status to indexing
- db.session.query(DocumentSegment).filter_by(id=segment.id).update(
- {
- DocumentSegment.status: "indexing",
- DocumentSegment.indexing_at: naive_utc_now(),
- }
- )
- db.session.commit()
- document = Document(
- page_content=segment.content,
- metadata={
- "doc_id": segment.index_node_id,
- "doc_hash": segment.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
- )
-
- dataset = segment.dataset
-
- if not dataset:
- logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
+ with session_factory.create_session() as session:
+ segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
+ if not segment:
+ logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
return
- dataset_document = segment.document
-
- if not dataset_document:
- logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
+ if segment.status != "waiting":
return
- if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
- logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
- return
+ indexing_cache_key = f"segment_{segment.id}_indexing"
- index_type = dataset.doc_form
- index_processor = IndexProcessorFactory(index_type).init_index_processor()
- index_processor.load(dataset, [document])
+ try:
+ # update segment status to indexing
+ session.query(DocumentSegment).filter_by(id=segment.id).update(
+ {
+ DocumentSegment.status: "indexing",
+ DocumentSegment.indexing_at: naive_utc_now(),
+ }
+ )
+ session.commit()
+ document = Document(
+ page_content=segment.content,
+ metadata={
+ "doc_id": segment.index_node_id,
+ "doc_hash": segment.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
- # update segment to completed
- db.session.query(DocumentSegment).filter_by(id=segment.id).update(
- {
- DocumentSegment.status: "completed",
- DocumentSegment.completed_at: naive_utc_now(),
- }
- )
- db.session.commit()
+ dataset = segment.dataset
- end_at = time.perf_counter()
- logger.info(click.style(f"Segment created to index: {segment.id} latency: {end_at - start_at}", fg="green"))
- except Exception as e:
- logger.exception("create segment to index failed")
- segment.enabled = False
- segment.disabled_at = naive_utc_now()
- segment.status = "error"
- segment.error = str(e)
- db.session.commit()
- finally:
- redis_client.delete(indexing_cache_key)
- db.session.close()
+ if not dataset:
+ logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
+ return
+
+ dataset_document = segment.document
+
+ if not dataset_document:
+ logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
+ return
+
+ if (
+ not dataset_document.enabled
+ or dataset_document.archived
+ or dataset_document.indexing_status != "completed"
+ ):
+ logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
+ return
+
+ index_type = dataset.doc_form
+ index_processor = IndexProcessorFactory(index_type).init_index_processor()
+ index_processor.load(dataset, [document])
+
+ # update segment to completed
+ session.query(DocumentSegment).filter_by(id=segment.id).update(
+ {
+ DocumentSegment.status: "completed",
+ DocumentSegment.completed_at: naive_utc_now(),
+ }
+ )
+ session.commit()
+
+ end_at = time.perf_counter()
+ logger.info(click.style(f"Segment created to index: {segment.id} latency: {end_at - start_at}", fg="green"))
+ except Exception as e:
+ logger.exception("create segment to index failed")
+ segment.enabled = False
+ segment.disabled_at = naive_utc_now()
+ segment.status = "error"
+ segment.error = str(e)
+ session.commit()
+ finally:
+ redis_client.delete(indexing_cache_key)
diff --git a/api/tasks/deal_dataset_index_update_task.py b/api/tasks/deal_dataset_index_update_task.py
index 3d13afdec0..fa844a8647 100644
--- a/api/tasks/deal_dataset_index_update_task.py
+++ b/api/tasks/deal_dataset_index_update_task.py
@@ -4,11 +4,11 @@ import time
import click
from celery import shared_task # type: ignore
+from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
-from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
@@ -24,166 +24,174 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
logging.info(click.style("Start deal dataset index update: {}".format(dataset_id), fg="green"))
start_at = time.perf_counter()
- try:
- dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
+ with session_factory.create_session() as session:
+ try:
+ dataset = session.query(Dataset).filter_by(id=dataset_id).first()
- if not dataset:
- raise Exception("Dataset not found")
- index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
- index_processor = IndexProcessorFactory(index_type).init_index_processor()
- if action == "upgrade":
- dataset_documents = (
- db.session.query(DatasetDocument)
- .where(
- DatasetDocument.dataset_id == dataset_id,
- DatasetDocument.indexing_status == "completed",
- DatasetDocument.enabled == True,
- DatasetDocument.archived == False,
+ if not dataset:
+ raise Exception("Dataset not found")
+ index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
+ index_processor = IndexProcessorFactory(index_type).init_index_processor()
+ if action == "upgrade":
+ dataset_documents = (
+ session.query(DatasetDocument)
+ .where(
+ DatasetDocument.dataset_id == dataset_id,
+ DatasetDocument.indexing_status == "completed",
+ DatasetDocument.enabled == True,
+ DatasetDocument.archived == False,
+ )
+ .all()
)
- .all()
- )
- if dataset_documents:
- dataset_documents_ids = [doc.id for doc in dataset_documents]
- db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
- {"indexing_status": "indexing"}, synchronize_session=False
- )
- db.session.commit()
+ if dataset_documents:
+ dataset_documents_ids = [doc.id for doc in dataset_documents]
+ session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
+ {"indexing_status": "indexing"}, synchronize_session=False
+ )
+ session.commit()
- for dataset_document in dataset_documents:
- try:
- # add from vector index
- segments = (
- db.session.query(DocumentSegment)
- .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
- .order_by(DocumentSegment.position.asc())
- .all()
- )
- if segments:
- documents = []
- for segment in segments:
- document = Document(
- page_content=segment.content,
- metadata={
- "doc_id": segment.index_node_id,
- "doc_hash": segment.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
+ for dataset_document in dataset_documents:
+ try:
+ # add from vector index
+ segments = (
+ session.query(DocumentSegment)
+ .where(
+ DocumentSegment.document_id == dataset_document.id,
+ DocumentSegment.enabled == True,
)
-
- documents.append(document)
- # save vector index
- # clean keywords
- index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
- index_processor.load(dataset, documents, with_keywords=False)
- db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
- {"indexing_status": "completed"}, synchronize_session=False
- )
- db.session.commit()
- except Exception as e:
- db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
- {"indexing_status": "error", "error": str(e)}, synchronize_session=False
- )
- db.session.commit()
- elif action == "update":
- dataset_documents = (
- db.session.query(DatasetDocument)
- .where(
- DatasetDocument.dataset_id == dataset_id,
- DatasetDocument.indexing_status == "completed",
- DatasetDocument.enabled == True,
- DatasetDocument.archived == False,
- )
- .all()
- )
- # add new index
- if dataset_documents:
- # update document status
- dataset_documents_ids = [doc.id for doc in dataset_documents]
- db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
- {"indexing_status": "indexing"}, synchronize_session=False
- )
- db.session.commit()
-
- # clean index
- index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
-
- for dataset_document in dataset_documents:
- # update from vector index
- try:
- segments = (
- db.session.query(DocumentSegment)
- .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
- .order_by(DocumentSegment.position.asc())
- .all()
- )
- if segments:
- documents = []
- multimodal_documents = []
- for segment in segments:
- document = Document(
- page_content=segment.content,
- metadata={
- "doc_id": segment.index_node_id,
- "doc_hash": segment.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
- )
- if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
- child_chunks = segment.get_child_chunks()
- if child_chunks:
- child_documents = []
- for child_chunk in child_chunks:
- child_document = ChildDocument(
- page_content=child_chunk.content,
- metadata={
- "doc_id": child_chunk.index_node_id,
- "doc_hash": child_chunk.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
- )
- child_documents.append(child_document)
- document.children = child_documents
- if dataset.is_multimodal:
- for attachment in segment.attachments:
- multimodal_documents.append(
- AttachmentDocument(
- page_content=attachment["name"],
- metadata={
- "doc_id": attachment["id"],
- "doc_hash": "",
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- "doc_type": DocType.IMAGE,
- },
- )
- )
- documents.append(document)
- # save vector index
- index_processor.load(
- dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
+ .order_by(DocumentSegment.position.asc())
+ .all()
)
- db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
- {"indexing_status": "completed"}, synchronize_session=False
- )
- db.session.commit()
- except Exception as e:
- db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
- {"indexing_status": "error", "error": str(e)}, synchronize_session=False
- )
- db.session.commit()
- else:
- # clean collection
- index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
+ if segments:
+ documents = []
+ for segment in segments:
+ document = Document(
+ page_content=segment.content,
+ metadata={
+ "doc_id": segment.index_node_id,
+ "doc_hash": segment.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
- end_at = time.perf_counter()
- logging.info(
- click.style("Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")
- )
- except Exception:
- logging.exception("Deal dataset vector index failed")
- finally:
- db.session.close()
+ documents.append(document)
+ # save vector index
+ # clean keywords
+ index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
+ index_processor.load(dataset, documents, with_keywords=False)
+ session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
+ {"indexing_status": "completed"}, synchronize_session=False
+ )
+ session.commit()
+ except Exception as e:
+ session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
+ {"indexing_status": "error", "error": str(e)}, synchronize_session=False
+ )
+ session.commit()
+ elif action == "update":
+ dataset_documents = (
+ session.query(DatasetDocument)
+ .where(
+ DatasetDocument.dataset_id == dataset_id,
+ DatasetDocument.indexing_status == "completed",
+ DatasetDocument.enabled == True,
+ DatasetDocument.archived == False,
+ )
+ .all()
+ )
+ # add new index
+ if dataset_documents:
+ # update document status
+ dataset_documents_ids = [doc.id for doc in dataset_documents]
+ session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
+ {"indexing_status": "indexing"}, synchronize_session=False
+ )
+ session.commit()
+
+ # clean index
+ index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
+
+ for dataset_document in dataset_documents:
+ # update from vector index
+ try:
+ segments = (
+ session.query(DocumentSegment)
+ .where(
+ DocumentSegment.document_id == dataset_document.id,
+ DocumentSegment.enabled == True,
+ )
+ .order_by(DocumentSegment.position.asc())
+ .all()
+ )
+ if segments:
+ documents = []
+ multimodal_documents = []
+ for segment in segments:
+ document = Document(
+ page_content=segment.content,
+ metadata={
+ "doc_id": segment.index_node_id,
+ "doc_hash": segment.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
+ if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
+ child_chunks = segment.get_child_chunks()
+ if child_chunks:
+ child_documents = []
+ for child_chunk in child_chunks:
+ child_document = ChildDocument(
+ page_content=child_chunk.content,
+ metadata={
+ "doc_id": child_chunk.index_node_id,
+ "doc_hash": child_chunk.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
+ child_documents.append(child_document)
+ document.children = child_documents
+ if dataset.is_multimodal:
+ for attachment in segment.attachments:
+ multimodal_documents.append(
+ AttachmentDocument(
+ page_content=attachment["name"],
+ metadata={
+ "doc_id": attachment["id"],
+ "doc_hash": "",
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ "doc_type": DocType.IMAGE,
+ },
+ )
+ )
+ documents.append(document)
+ # save vector index
+ index_processor.load(
+ dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
+ )
+ session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
+ {"indexing_status": "completed"}, synchronize_session=False
+ )
+ session.commit()
+ except Exception as e:
+ session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
+ {"indexing_status": "error", "error": str(e)}, synchronize_session=False
+ )
+ session.commit()
+ else:
+ # clean collection
+ index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
+
+ end_at = time.perf_counter()
+ logging.info(
+ click.style(
+ "Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at),
+ fg="green",
+ )
+ )
+ except Exception:
+ logging.exception("Deal dataset vector index failed")
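
The hunk above, like every task file touched in this patch, swaps the module-level `db.session` (and its manual `finally: db.session.close()`) for a session obtained from `session_factory.create_session()`. A minimal sketch of the pattern, assuming only what the diffs show: `create_session()` is a context manager that yields a SQLAlchemy session and disposes of it on exit.

```python
# Hedged sketch of the session-handling pattern applied across these task files.
# Assumption: session_factory.create_session() yields a SQLAlchemy Session and
# closes it when the `with` block exits, so no explicit close() is needed.
import logging

from core.db.session_factory import session_factory
from models.dataset import Dataset

logger = logging.getLogger(__name__)


def example_task(dataset_id: str) -> None:
    with session_factory.create_session() as session:
        try:
            dataset = session.query(Dataset).filter_by(id=dataset_id).first()
            if not dataset:
                return
            # ... do the task's work against `session` instead of db.session ...
            session.commit()
        except Exception:
            logger.exception("task failed")
```
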
diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py
index 1c7de3b1ce..0047e04a17 100644
--- a/api/tasks/deal_dataset_vector_index_task.py
+++ b/api/tasks/deal_dataset_vector_index_task.py
@@ -5,11 +5,11 @@ import click
from celery import shared_task
from sqlalchemy import select
+from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
-from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
@@ -27,160 +27,170 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
logger.info(click.style(f"Start deal dataset vector index: {dataset_id}", fg="green"))
start_at = time.perf_counter()
- try:
- dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
+ with session_factory.create_session() as session:
+ try:
+ dataset = session.query(Dataset).filter_by(id=dataset_id).first()
- if not dataset:
- raise Exception("Dataset not found")
- index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
- index_processor = IndexProcessorFactory(index_type).init_index_processor()
- if action == "remove":
- index_processor.clean(dataset, None, with_keywords=False)
- elif action == "add":
- dataset_documents = db.session.scalars(
- select(DatasetDocument).where(
- DatasetDocument.dataset_id == dataset_id,
- DatasetDocument.indexing_status == "completed",
- DatasetDocument.enabled == True,
- DatasetDocument.archived == False,
- )
- ).all()
+ if not dataset:
+ raise Exception("Dataset not found")
+ index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
+ index_processor = IndexProcessorFactory(index_type).init_index_processor()
+ if action == "remove":
+ index_processor.clean(dataset, None, with_keywords=False)
+ elif action == "add":
+ dataset_documents = session.scalars(
+ select(DatasetDocument).where(
+ DatasetDocument.dataset_id == dataset_id,
+ DatasetDocument.indexing_status == "completed",
+ DatasetDocument.enabled == True,
+ DatasetDocument.archived == False,
+ )
+ ).all()
- if dataset_documents:
- dataset_documents_ids = [doc.id for doc in dataset_documents]
- db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
- {"indexing_status": "indexing"}, synchronize_session=False
- )
- db.session.commit()
+ if dataset_documents:
+ dataset_documents_ids = [doc.id for doc in dataset_documents]
+ session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
+ {"indexing_status": "indexing"}, synchronize_session=False
+ )
+ session.commit()
- for dataset_document in dataset_documents:
- try:
- # add from vector index
- segments = (
- db.session.query(DocumentSegment)
- .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
- .order_by(DocumentSegment.position.asc())
- .all()
- )
- if segments:
- documents = []
- for segment in segments:
- document = Document(
- page_content=segment.content,
- metadata={
- "doc_id": segment.index_node_id,
- "doc_hash": segment.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
+ for dataset_document in dataset_documents:
+ try:
+ # add from vector index
+ segments = (
+ session.query(DocumentSegment)
+ .where(
+ DocumentSegment.document_id == dataset_document.id,
+ DocumentSegment.enabled == True,
)
-
- documents.append(document)
- # save vector index
- index_processor.load(dataset, documents, with_keywords=False)
- db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
- {"indexing_status": "completed"}, synchronize_session=False
- )
- db.session.commit()
- except Exception as e:
- db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
- {"indexing_status": "error", "error": str(e)}, synchronize_session=False
- )
- db.session.commit()
- elif action == "update":
- dataset_documents = db.session.scalars(
- select(DatasetDocument).where(
- DatasetDocument.dataset_id == dataset_id,
- DatasetDocument.indexing_status == "completed",
- DatasetDocument.enabled == True,
- DatasetDocument.archived == False,
- )
- ).all()
- # add new index
- if dataset_documents:
- # update document status
- dataset_documents_ids = [doc.id for doc in dataset_documents]
- db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
- {"indexing_status": "indexing"}, synchronize_session=False
- )
- db.session.commit()
-
- # clean index
- index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
-
- for dataset_document in dataset_documents:
- # update from vector index
- try:
- segments = (
- db.session.query(DocumentSegment)
- .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
- .order_by(DocumentSegment.position.asc())
- .all()
- )
- if segments:
- documents = []
- multimodal_documents = []
- for segment in segments:
- document = Document(
- page_content=segment.content,
- metadata={
- "doc_id": segment.index_node_id,
- "doc_hash": segment.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
- )
- if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
- child_chunks = segment.get_child_chunks()
- if child_chunks:
- child_documents = []
- for child_chunk in child_chunks:
- child_document = ChildDocument(
- page_content=child_chunk.content,
- metadata={
- "doc_id": child_chunk.index_node_id,
- "doc_hash": child_chunk.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
- )
- child_documents.append(child_document)
- document.children = child_documents
- if dataset.is_multimodal:
- for attachment in segment.attachments:
- multimodal_documents.append(
- AttachmentDocument(
- page_content=attachment["name"],
- metadata={
- "doc_id": attachment["id"],
- "doc_hash": "",
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- "doc_type": DocType.IMAGE,
- },
- )
- )
- documents.append(document)
- # save vector index
- index_processor.load(
- dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
+ .order_by(DocumentSegment.position.asc())
+ .all()
)
- db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
- {"indexing_status": "completed"}, synchronize_session=False
- )
- db.session.commit()
- except Exception as e:
- db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
- {"indexing_status": "error", "error": str(e)}, synchronize_session=False
- )
- db.session.commit()
- else:
- # clean collection
- index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
+ if segments:
+ documents = []
+ for segment in segments:
+ document = Document(
+ page_content=segment.content,
+ metadata={
+ "doc_id": segment.index_node_id,
+ "doc_hash": segment.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
- end_at = time.perf_counter()
- logger.info(click.style(f"Deal dataset vector index: {dataset_id} latency: {end_at - start_at}", fg="green"))
- except Exception:
- logger.exception("Deal dataset vector index failed")
- finally:
- db.session.close()
+ documents.append(document)
+ # save vector index
+ index_processor.load(dataset, documents, with_keywords=False)
+ session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
+ {"indexing_status": "completed"}, synchronize_session=False
+ )
+ session.commit()
+ except Exception as e:
+ session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
+ {"indexing_status": "error", "error": str(e)}, synchronize_session=False
+ )
+ session.commit()
+ elif action == "update":
+ dataset_documents = session.scalars(
+ select(DatasetDocument).where(
+ DatasetDocument.dataset_id == dataset_id,
+ DatasetDocument.indexing_status == "completed",
+ DatasetDocument.enabled == True,
+ DatasetDocument.archived == False,
+ )
+ ).all()
+ # add new index
+ if dataset_documents:
+ # update document status
+ dataset_documents_ids = [doc.id for doc in dataset_documents]
+ session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
+ {"indexing_status": "indexing"}, synchronize_session=False
+ )
+ session.commit()
+
+ # clean index
+ index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
+
+ for dataset_document in dataset_documents:
+ # update from vector index
+ try:
+ segments = (
+ session.query(DocumentSegment)
+ .where(
+ DocumentSegment.document_id == dataset_document.id,
+ DocumentSegment.enabled == True,
+ )
+ .order_by(DocumentSegment.position.asc())
+ .all()
+ )
+ if segments:
+ documents = []
+ multimodal_documents = []
+ for segment in segments:
+ document = Document(
+ page_content=segment.content,
+ metadata={
+ "doc_id": segment.index_node_id,
+ "doc_hash": segment.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
+ if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
+ child_chunks = segment.get_child_chunks()
+ if child_chunks:
+ child_documents = []
+ for child_chunk in child_chunks:
+ child_document = ChildDocument(
+ page_content=child_chunk.content,
+ metadata={
+ "doc_id": child_chunk.index_node_id,
+ "doc_hash": child_chunk.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
+ child_documents.append(child_document)
+ document.children = child_documents
+ if dataset.is_multimodal:
+ for attachment in segment.attachments:
+ multimodal_documents.append(
+ AttachmentDocument(
+ page_content=attachment["name"],
+ metadata={
+ "doc_id": attachment["id"],
+ "doc_hash": "",
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ "doc_type": DocType.IMAGE,
+ },
+ )
+ )
+ documents.append(document)
+ # save vector index
+ index_processor.load(
+ dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
+ )
+ session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
+ {"indexing_status": "completed"}, synchronize_session=False
+ )
+ session.commit()
+ except Exception as e:
+ session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
+ {"indexing_status": "error", "error": str(e)}, synchronize_session=False
+ )
+ session.commit()
+ else:
+ # clean collection
+ index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
+
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ f"Deal dataset vector index: {dataset_id} latency: {end_at - start_at}",
+ fg="green",
+ )
+ )
+ except Exception:
+ logger.exception("Deal dataset vector index failed")
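
In the `update` branch above, each enabled segment is rebuilt into the RAG document models before `index_processor.load(...)` runs. A condensed sketch of that composition; variables such as `segment`, `dataset`, `dataset_document`, `documents`, and `multimodal_documents` come from the surrounding loop, and the model classes are the ones imported at the top of the file.

```python
# Fragment of the per-segment loop: parent document, optional child chunks,
# and optional image attachments for multimodal datasets.
document = Document(
    page_content=segment.content,
    metadata={
        "doc_id": segment.index_node_id,
        "doc_hash": segment.index_node_hash,
        "document_id": segment.document_id,
        "dataset_id": segment.dataset_id,
    },
)
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
    child_chunks = segment.get_child_chunks()
    if child_chunks:
        document.children = [
            ChildDocument(
                page_content=chunk.content,
                metadata={
                    "doc_id": chunk.index_node_id,
                    "doc_hash": chunk.index_node_hash,
                    "document_id": segment.document_id,
                    "dataset_id": segment.dataset_id,
                },
            )
            for chunk in child_chunks
        ]
if dataset.is_multimodal:
    multimodal_documents.extend(
        AttachmentDocument(
            page_content=attachment["name"],
            metadata={
                "doc_id": attachment["id"],
                "doc_hash": "",
                "document_id": segment.document_id,
                "dataset_id": segment.dataset_id,
                "doc_type": DocType.IMAGE,
            },
        )
        for attachment in segment.attachments
    )
documents.append(document)
```
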
diff --git a/api/tasks/delete_account_task.py b/api/tasks/delete_account_task.py
index fb5eb1d691..ecf6f9cb39 100644
--- a/api/tasks/delete_account_task.py
+++ b/api/tasks/delete_account_task.py
@@ -2,7 +2,8 @@ import logging
from celery import shared_task
-from extensions.ext_database import db
+from configs import dify_config
+from core.db.session_factory import session_factory
from models import Account
from services.billing_service import BillingService
from tasks.mail_account_deletion_task import send_deletion_success_task
@@ -12,15 +13,17 @@ logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def delete_account_task(account_id):
- account = db.session.query(Account).where(Account.id == account_id).first()
- try:
- BillingService.delete_account(account_id)
- except Exception:
- logger.exception("Failed to delete account %s from billing service.", account_id)
- raise
+ with session_factory.create_session() as session:
+ account = session.query(Account).where(Account.id == account_id).first()
+ try:
+ if dify_config.BILLING_ENABLED:
+ BillingService.delete_account(account_id)
+ except Exception:
+ logger.exception("Failed to delete account %s from billing service.", account_id)
+ raise
- if not account:
- logger.error("Account %s not found.", account_id)
- return
- # send success email
- send_deletion_success_task.delay(account.email)
+ if not account:
+ logger.error("Account %s not found.", account_id)
+ return
+ # send success email
+ send_deletion_success_task.delay(account.email)
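
Two behavioural points in this hunk are easy to miss: the billing call is now gated on `dify_config.BILLING_ENABLED`, and the exception is re-raised so the success email is never queued when billing cleanup fails. A compressed sketch of the resulting flow (not the literal code):

```python
with session_factory.create_session() as session:
    account = session.query(Account).where(Account.id == account_id).first()
    try:
        if dify_config.BILLING_ENABLED:  # self-hosted deployments skip the external call
            BillingService.delete_account(account_id)
    except Exception:
        logger.exception("Failed to delete account %s from billing service.", account_id)
        raise  # fail the task before any success email is sent
    if account:
        send_deletion_success_task.delay(account.email)
```
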
diff --git a/api/tasks/delete_conversation_task.py b/api/tasks/delete_conversation_task.py
index 756b67c93e..9664b8ac73 100644
--- a/api/tasks/delete_conversation_task.py
+++ b/api/tasks/delete_conversation_task.py
@@ -4,7 +4,7 @@ import time
import click
from celery import shared_task
-from extensions.ext_database import db
+from core.db.session_factory import session_factory
from models import ConversationVariable
from models.model import Message, MessageAnnotation, MessageFeedback
from models.tools import ToolConversationVariables, ToolFile
@@ -27,44 +27,46 @@ def delete_conversation_related_data(conversation_id: str):
)
start_at = time.perf_counter()
- try:
- db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == conversation_id).delete(
- synchronize_session=False
- )
-
- db.session.query(MessageFeedback).where(MessageFeedback.conversation_id == conversation_id).delete(
- synchronize_session=False
- )
-
- db.session.query(ToolConversationVariables).where(
- ToolConversationVariables.conversation_id == conversation_id
- ).delete(synchronize_session=False)
-
- db.session.query(ToolFile).where(ToolFile.conversation_id == conversation_id).delete(synchronize_session=False)
-
- db.session.query(ConversationVariable).where(ConversationVariable.conversation_id == conversation_id).delete(
- synchronize_session=False
- )
-
- db.session.query(Message).where(Message.conversation_id == conversation_id).delete(synchronize_session=False)
-
- db.session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
- synchronize_session=False
- )
-
- db.session.commit()
-
- end_at = time.perf_counter()
- logger.info(
- click.style(
- f"Succeeded cleaning data from db for conversation_id {conversation_id} latency: {end_at - start_at}",
- fg="green",
+ with session_factory.create_session() as session:
+ try:
+ session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == conversation_id).delete(
+ synchronize_session=False
)
- )
- except Exception as e:
- logger.exception("Failed to delete data from db for conversation_id: %s failed", conversation_id)
- db.session.rollback()
- raise e
- finally:
- db.session.close()
+ session.query(MessageFeedback).where(MessageFeedback.conversation_id == conversation_id).delete(
+ synchronize_session=False
+ )
+
+ session.query(ToolConversationVariables).where(
+ ToolConversationVariables.conversation_id == conversation_id
+ ).delete(synchronize_session=False)
+
+ session.query(ToolFile).where(ToolFile.conversation_id == conversation_id).delete(synchronize_session=False)
+
+ session.query(ConversationVariable).where(ConversationVariable.conversation_id == conversation_id).delete(
+ synchronize_session=False
+ )
+
+ session.query(Message).where(Message.conversation_id == conversation_id).delete(synchronize_session=False)
+
+ session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
+ synchronize_session=False
+ )
+
+ session.commit()
+
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ (
+ f"Succeeded cleaning data from db for conversation_id {conversation_id} "
+ f"latency: {end_at - start_at}"
+ ),
+ fg="green",
+ )
+ )
+
+ except Exception:
+ logger.exception("Failed to delete data from db for conversation_id: %s failed", conversation_id)
+ session.rollback()
+ raise
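
The conversation cleanup keeps all related deletes in one session and one commit, rolling back and re-raising on failure so Celery can retry without leaving partial cleanup behind. The loop below is a condensed illustration of the same idea, not the literal code:

```python
# All conversation-scoped tables are cleared in a single transaction.
related_models = (
    MessageAnnotation,
    MessageFeedback,
    ToolConversationVariables,
    ToolFile,
    ConversationVariable,
    Message,
    PinnedConversation,
)
with session_factory.create_session() as session:
    try:
        for model in related_models:
            session.query(model).where(model.conversation_id == conversation_id).delete(
                synchronize_session=False
            )
        session.commit()
    except Exception:
        session.rollback()
        raise
```
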
diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py
index bea5c952cf..764c635d83 100644
--- a/api/tasks/delete_segment_from_index_task.py
+++ b/api/tasks/delete_segment_from_index_task.py
@@ -4,8 +4,8 @@ import time
import click
from celery import shared_task
+from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from extensions.ext_database import db
from models.dataset import Dataset, Document, SegmentAttachmentBinding
from models.model import UploadFile
@@ -26,49 +26,54 @@ def delete_segment_from_index_task(
"""
logger.info(click.style("Start delete segment from index", fg="green"))
start_at = time.perf_counter()
- try:
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
- if not dataset:
- logging.warning("Dataset %s not found, skipping index cleanup", dataset_id)
- return
+ with session_factory.create_session() as session:
+ try:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+ if not dataset:
+ logger.warning("Dataset %s not found, skipping index cleanup", dataset_id)
+ return
- dataset_document = db.session.query(Document).where(Document.id == document_id).first()
- if not dataset_document:
- return
+ dataset_document = session.query(Document).where(Document.id == document_id).first()
+ if not dataset_document:
+ return
- if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
- logging.info("Document not in valid state for index operations, skipping")
- return
- doc_form = dataset_document.doc_form
+ if (
+ not dataset_document.enabled
+ or dataset_document.archived
+ or dataset_document.indexing_status != "completed"
+ ):
+ logger.info("Document not in valid state for index operations, skipping")
+ return
+ doc_form = dataset_document.doc_form
- # Proceed with index cleanup using the index_node_ids directly
- index_processor = IndexProcessorFactory(doc_form).init_index_processor()
- index_processor.clean(
- dataset,
- index_node_ids,
- with_keywords=True,
- delete_child_chunks=True,
- precomputed_child_node_ids=child_node_ids,
- )
- if dataset.is_multimodal:
- # delete segment attachment binding
- segment_attachment_bindings = (
- db.session.query(SegmentAttachmentBinding)
- .where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
- .all()
+ # Proceed with index cleanup using the index_node_ids directly
+ # For actual deletion, we should delete summaries (not just disable them)
+ index_processor = IndexProcessorFactory(doc_form).init_index_processor()
+ index_processor.clean(
+ dataset,
+ index_node_ids,
+ with_keywords=True,
+ delete_child_chunks=True,
+ precomputed_child_node_ids=child_node_ids,
+ delete_summaries=True, # Actually delete summaries when segment is deleted
)
- if segment_attachment_bindings:
- attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
- index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False)
- for binding in segment_attachment_bindings:
- db.session.delete(binding)
- # delete upload file
- db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False)
- db.session.commit()
+ if dataset.is_multimodal:
+ # delete segment attachment binding
+ segment_attachment_bindings = (
+ session.query(SegmentAttachmentBinding)
+ .where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
+ .all()
+ )
+ if segment_attachment_bindings:
+ attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
+ index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False)
+ for binding in segment_attachment_bindings:
+ session.delete(binding)
+ # delete upload file
+ session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False)
+ session.commit()
- end_at = time.perf_counter()
- logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green"))
- except Exception:
- logger.exception("delete segment from index failed")
- finally:
- db.session.close()
+ end_at = time.perf_counter()
+ logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green"))
+ except Exception:
+ logger.exception("delete segment from index failed")
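
Hard-deleting a segment now passes `delete_summaries=True` so summaries are removed rather than merely disabled, and multimodal datasets also drop the attachment index nodes, their bindings, and the backing upload files. A trimmed sketch of that cleanup path (`dataset`, `index_node_ids`, `child_node_ids`, `segment_ids`, and `session` come from the surrounding task body):

```python
# Fragment of the deletion path above.
index_processor.clean(
    dataset,
    index_node_ids,
    with_keywords=True,
    delete_child_chunks=True,
    precomputed_child_node_ids=child_node_ids,
    delete_summaries=True,  # hard delete: remove summaries instead of disabling them
)
if dataset.is_multimodal:
    bindings = (
        session.query(SegmentAttachmentBinding)
        .where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
        .all()
    )
    if bindings:
        attachment_ids = [binding.attachment_id for binding in bindings]
        index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False)
        for binding in bindings:
            session.delete(binding)
        session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(
            synchronize_session=False
        )
        session.commit()
```
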
diff --git a/api/tasks/disable_segment_from_index_task.py b/api/tasks/disable_segment_from_index_task.py
index 6b5f01b416..bc45171623 100644
--- a/api/tasks/disable_segment_from_index_task.py
+++ b/api/tasks/disable_segment_from_index_task.py
@@ -4,8 +4,8 @@ import time
import click
from celery import shared_task
+from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment
@@ -23,46 +23,65 @@ def disable_segment_from_index_task(segment_id: str):
logger.info(click.style(f"Start disable segment from index: {segment_id}", fg="green"))
start_at = time.perf_counter()
- segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
- if not segment:
- logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
- db.session.close()
- return
-
- if segment.status != "completed":
- logger.info(click.style(f"Segment is not completed, disable is not allowed: {segment_id}", fg="red"))
- db.session.close()
- return
-
- indexing_cache_key = f"segment_{segment.id}_indexing"
-
- try:
- dataset = segment.dataset
-
- if not dataset:
- logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
+ with session_factory.create_session() as session:
+ segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
+ if not segment:
+ logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
return
- dataset_document = segment.document
-
- if not dataset_document:
- logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
+ if segment.status != "completed":
+ logger.info(click.style(f"Segment is not completed, disable is not allowed: {segment_id}", fg="red"))
return
- if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
- logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
- return
+ indexing_cache_key = f"segment_{segment.id}_indexing"
- index_type = dataset_document.doc_form
- index_processor = IndexProcessorFactory(index_type).init_index_processor()
- index_processor.clean(dataset, [segment.index_node_id])
+ try:
+ dataset = segment.dataset
- end_at = time.perf_counter()
- logger.info(click.style(f"Segment removed from index: {segment.id} latency: {end_at - start_at}", fg="green"))
- except Exception:
- logger.exception("remove segment from index failed")
- segment.enabled = True
- db.session.commit()
- finally:
- redis_client.delete(indexing_cache_key)
- db.session.close()
+ if not dataset:
+ logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
+ return
+
+ dataset_document = segment.document
+
+ if not dataset_document:
+ logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
+ return
+
+ if (
+ not dataset_document.enabled
+ or dataset_document.archived
+ or dataset_document.indexing_status != "completed"
+ ):
+ logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
+ return
+
+ index_type = dataset_document.doc_form
+ index_processor = IndexProcessorFactory(index_type).init_index_processor()
+ index_processor.clean(dataset, [segment.index_node_id])
+
+ # Disable summary index for this segment
+ from services.summary_index_service import SummaryIndexService
+
+ try:
+ SummaryIndexService.disable_summaries_for_segments(
+ dataset=dataset,
+ segment_ids=[segment.id],
+ disabled_by=segment.disabled_by,
+ )
+ except Exception as e:
+ logger.warning("Failed to disable summary for segment %s: %s", segment.id, str(e))
+
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ f"Segment removed from index: {segment.id} latency: {end_at - start_at}",
+ fg="green",
+ )
+ )
+ except Exception:
+ logger.exception("remove segment from index failed")
+ segment.enabled = True
+ session.commit()
+ finally:
+ redis_client.delete(indexing_cache_key)
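
Disabling a single segment now also disables its summary entries, but only on a best-effort basis: the `SummaryIndexService` call is wrapped in its own try/except (and imported inside the function, presumably to avoid a circular import), so a summary failure is logged as a warning rather than re-enabling the segment.

```python
# Fragment of the disable path above; a failure here must not fail the task.
from services.summary_index_service import SummaryIndexService

try:
    SummaryIndexService.disable_summaries_for_segments(
        dataset=dataset,
        segment_ids=[segment.id],
        disabled_by=segment.disabled_by,
    )
except Exception as e:
    logger.warning("Failed to disable summary for segment %s: %s", segment.id, str(e))
```
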
diff --git a/api/tasks/disable_segments_from_index_task.py b/api/tasks/disable_segments_from_index_task.py
index c2a3de29f4..3cc267e821 100644
--- a/api/tasks/disable_segments_from_index_task.py
+++ b/api/tasks/disable_segments_from_index_task.py
@@ -5,8 +5,8 @@ import click
from celery import shared_task
from sqlalchemy import select
+from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
@@ -26,69 +26,80 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
"""
start_at = time.perf_counter()
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
- if not dataset:
- logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
- db.session.close()
- return
+ with session_factory.create_session() as session:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+ if not dataset:
+ logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
+ return
- dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
+ dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
- if not dataset_document:
- logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
- db.session.close()
- return
- if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
- logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan"))
- db.session.close()
- return
- # sync index processor
- index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
+ if not dataset_document:
+ logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
+ return
+ if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
+ logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan"))
+ return
+ # sync index processor
+ index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
- segments = db.session.scalars(
- select(DocumentSegment).where(
- DocumentSegment.id.in_(segment_ids),
- DocumentSegment.dataset_id == dataset_id,
- DocumentSegment.document_id == document_id,
- )
- ).all()
-
- if not segments:
- db.session.close()
- return
-
- try:
- index_node_ids = [segment.index_node_id for segment in segments]
- if dataset.is_multimodal:
- segment_ids = [segment.id for segment in segments]
- segment_attachment_bindings = (
- db.session.query(SegmentAttachmentBinding)
- .where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
- .all()
+ segments = session.scalars(
+ select(DocumentSegment).where(
+ DocumentSegment.id.in_(segment_ids),
+ DocumentSegment.dataset_id == dataset_id,
+ DocumentSegment.document_id == document_id,
)
- if segment_attachment_bindings:
- attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
- index_node_ids.extend(attachment_ids)
- index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
+ ).all()
- end_at = time.perf_counter()
- logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green"))
- except Exception:
- # update segment error msg
- db.session.query(DocumentSegment).where(
- DocumentSegment.id.in_(segment_ids),
- DocumentSegment.dataset_id == dataset_id,
- DocumentSegment.document_id == document_id,
- ).update(
- {
- "disabled_at": None,
- "disabled_by": None,
- "enabled": True,
- }
- )
- db.session.commit()
- finally:
- for segment in segments:
- indexing_cache_key = f"segment_{segment.id}_indexing"
- redis_client.delete(indexing_cache_key)
- db.session.close()
+ if not segments:
+ return
+
+ try:
+ index_node_ids = [segment.index_node_id for segment in segments]
+ if dataset.is_multimodal:
+ segment_ids = [segment.id for segment in segments]
+ segment_attachment_bindings = (
+ session.query(SegmentAttachmentBinding)
+ .where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
+ .all()
+ )
+ if segment_attachment_bindings:
+ attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
+ index_node_ids.extend(attachment_ids)
+ index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
+
+ # Disable summary indexes for these segments
+ from services.summary_index_service import SummaryIndexService
+
+ segment_ids_list = [segment.id for segment in segments]
+ try:
+ # Get disabled_by from first segment (they should all have the same disabled_by)
+ disabled_by = segments[0].disabled_by if segments else None
+ SummaryIndexService.disable_summaries_for_segments(
+ dataset=dataset,
+ segment_ids=segment_ids_list,
+ disabled_by=disabled_by,
+ )
+ except Exception as e:
+ logger.warning("Failed to disable summaries for segments: %s", str(e))
+
+ end_at = time.perf_counter()
+ logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green"))
+ except Exception:
+ # update segment error msg
+ session.query(DocumentSegment).where(
+ DocumentSegment.id.in_(segment_ids),
+ DocumentSegment.dataset_id == dataset_id,
+ DocumentSegment.document_id == document_id,
+ ).update(
+ {
+ "disabled_at": None,
+ "disabled_by": None,
+ "enabled": True,
+ }
+ )
+ session.commit()
+ finally:
+ for segment in segments:
+ indexing_cache_key = f"segment_{segment.id}_indexing"
+ redis_client.delete(indexing_cache_key)
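
The batch variant follows the same shape, with one extra wrinkle: for multimodal datasets the attachment node ids are appended to the segment node ids so a single `clean()` call removes both kinds of entries. A trimmed fragment:

```python
# Fragment of the batch disable path above.
index_node_ids = [segment.index_node_id for segment in segments]
if dataset.is_multimodal:
    bindings = (
        session.query(SegmentAttachmentBinding)
        .where(SegmentAttachmentBinding.segment_id.in_([s.id for s in segments]))
        .all()
    )
    index_node_ids.extend(binding.attachment_id for binding in bindings)
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
```
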
diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py
index 4c1f38c3bb..149185f6e2 100644
--- a/api/tasks/document_indexing_sync_task.py
+++ b/api/tasks/document_indexing_sync_task.py
@@ -2,17 +2,16 @@ import logging
import time
import click
-import sqlalchemy as sa
from celery import shared_task
-from sqlalchemy import select
+from sqlalchemy import delete, select
+from core.db.session_factory import session_factory
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.extractor.notion_extractor import NotionExtractor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
-from models.source import DataSourceOauthBinding
+from services.datasource_provider_service import DatasourceProviderService
logger = logging.getLogger(__name__)
@@ -29,96 +28,103 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
logger.info(click.style(f"Start sync document: {document_id}", fg="green"))
start_at = time.perf_counter()
- document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ with session_factory.create_session() as session:
+ document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
- if not document:
- logger.info(click.style(f"Document not found: {document_id}", fg="red"))
- db.session.close()
- return
+ if not document:
+ logger.info(click.style(f"Document not found: {document_id}", fg="red"))
+ return
- data_source_info = document.data_source_info_dict
- if document.data_source_type == "notion_import":
- if (
- not data_source_info
- or "notion_page_id" not in data_source_info
- or "notion_workspace_id" not in data_source_info
- ):
- raise ValueError("no notion page found")
- workspace_id = data_source_info["notion_workspace_id"]
- page_id = data_source_info["notion_page_id"]
- page_type = data_source_info["type"]
- page_edited_time = data_source_info["last_edited_time"]
+ data_source_info = document.data_source_info_dict
+ if document.data_source_type == "notion_import":
+ if (
+ not data_source_info
+ or "notion_page_id" not in data_source_info
+ or "notion_workspace_id" not in data_source_info
+ ):
+ raise ValueError("no notion page found")
+ workspace_id = data_source_info["notion_workspace_id"]
+ page_id = data_source_info["notion_page_id"]
+ page_type = data_source_info["type"]
+ page_edited_time = data_source_info["last_edited_time"]
+ credential_id = data_source_info.get("credential_id")
- data_source_binding = (
- db.session.query(DataSourceOauthBinding)
- .where(
- sa.and_(
- DataSourceOauthBinding.tenant_id == document.tenant_id,
- DataSourceOauthBinding.provider == "notion",
- DataSourceOauthBinding.disabled == False,
- DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
- )
+ # Get credentials from datasource provider
+ datasource_provider_service = DatasourceProviderService()
+ credential = datasource_provider_service.get_datasource_credentials(
+ tenant_id=document.tenant_id,
+ credential_id=credential_id,
+ provider="notion_datasource",
+ plugin_id="langgenius/notion_datasource",
)
- .first()
- )
- if not data_source_binding:
- raise ValueError("Data source binding not found.")
- loader = NotionExtractor(
- notion_workspace_id=workspace_id,
- notion_obj_id=page_id,
- notion_page_type=page_type,
- notion_access_token=data_source_binding.access_token,
- tenant_id=document.tenant_id,
- )
-
- last_edited_time = loader.get_notion_last_edited_time()
-
- # check the page is updated
- if last_edited_time != page_edited_time:
- document.indexing_status = "parsing"
- document.processing_started_at = naive_utc_now()
- db.session.commit()
-
- # delete all document segment and index
- try:
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
- if not dataset:
- raise Exception("Dataset not found")
- index_type = document.doc_form
- index_processor = IndexProcessorFactory(index_type).init_index_processor()
-
- segments = db.session.scalars(
- select(DocumentSegment).where(DocumentSegment.document_id == document_id)
- ).all()
- index_node_ids = [segment.index_node_id for segment in segments]
-
- # delete from vector index
- index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
-
- for segment in segments:
- db.session.delete(segment)
-
- end_at = time.perf_counter()
- logger.info(
- click.style(
- "Cleaned document when document update data source or process rule: {} latency: {}".format(
- document_id, end_at - start_at
- ),
- fg="green",
- )
+ if not credential:
+ logger.error(
+ "Datasource credential not found for document %s, tenant_id: %s, credential_id: %s",
+ document_id,
+ document.tenant_id,
+ credential_id,
)
- except Exception:
- logger.exception("Cleaned document when document update data source or process rule failed")
+ document.indexing_status = "error"
+ document.error = "Datasource credential not found. Please reconnect your Notion workspace."
+ document.stopped_at = naive_utc_now()
+ session.commit()
+ return
- try:
- indexing_runner = IndexingRunner()
- indexing_runner.run([document])
- end_at = time.perf_counter()
- logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
- except DocumentIsPausedError as ex:
- logger.info(click.style(str(ex), fg="yellow"))
- except Exception:
- logger.exception("document_indexing_sync_task failed, document_id: %s", document_id)
- finally:
- db.session.close()
+ loader = NotionExtractor(
+ notion_workspace_id=workspace_id,
+ notion_obj_id=page_id,
+ notion_page_type=page_type,
+ notion_access_token=credential.get("integration_secret"),
+ tenant_id=document.tenant_id,
+ )
+
+ last_edited_time = loader.get_notion_last_edited_time()
+
+ # check the page is updated
+ if last_edited_time != page_edited_time:
+ document.indexing_status = "parsing"
+ document.processing_started_at = naive_utc_now()
+ session.commit()
+
+ # delete all document segment and index
+ try:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+ if not dataset:
+ raise Exception("Dataset not found")
+ index_type = document.doc_form
+ index_processor = IndexProcessorFactory(index_type).init_index_processor()
+
+ segments = session.scalars(
+ select(DocumentSegment).where(DocumentSegment.document_id == document_id)
+ ).all()
+ index_node_ids = [segment.index_node_id for segment in segments]
+
+ # delete from vector index
+ index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+
+ segment_ids = [segment.id for segment in segments]
+ segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
+ session.execute(segment_delete_stmt)
+
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ "Cleaned document when document update data source or process rule: {} latency: {}".format(
+ document_id, end_at - start_at
+ ),
+ fg="green",
+ )
+ )
+ except Exception:
+ logger.exception("Cleaned document when document update data source or process rule failed")
+
+ try:
+ indexing_runner = IndexingRunner()
+ indexing_runner.run([document])
+ end_at = time.perf_counter()
+ logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
+ except DocumentIsPausedError as ex:
+ logger.info(click.style(str(ex), fg="yellow"))
+ except Exception:
+ logger.exception("document_indexing_sync_task failed, document_id: %s", document_id)
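
The Notion sync no longer reads `DataSourceOauthBinding` rows; credentials come from `DatasourceProviderService`, and a missing credential marks the document as errored with a user-facing message instead of raising. A condensed sketch of the lookup (the `workspace_id`, `page_id`, and `page_type` values come from `data_source_info` as shown above):

```python
# Fragment of the credential lookup above.
credential = DatasourceProviderService().get_datasource_credentials(
    tenant_id=document.tenant_id,
    credential_id=data_source_info.get("credential_id"),
    provider="notion_datasource",
    plugin_id="langgenius/notion_datasource",
)
if not credential:
    document.indexing_status = "error"
    document.error = "Datasource credential not found. Please reconnect your Notion workspace."
    document.stopped_at = naive_utc_now()
    session.commit()
    return

loader = NotionExtractor(
    notion_workspace_id=workspace_id,
    notion_obj_id=page_id,
    notion_page_type=page_type,
    notion_access_token=credential.get("integration_secret"),
    tenant_id=document.tenant_id,
)
```
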
diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py
index acbdab631b..34496e9c6f 100644
--- a/api/tasks/document_indexing_task.py
+++ b/api/tasks/document_indexing_task.py
@@ -6,14 +6,15 @@ import click
from celery import shared_task
from configs import dify_config
+from core.db.session_factory import session_factory
from core.entities.document_task import DocumentTask
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
-from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document
from services.feature_service import FeatureService
+from tasks.generate_summary_index_task import generate_summary_index_task
logger = logging.getLogger(__name__)
@@ -46,66 +47,135 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
documents = []
start_at = time.perf_counter()
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
- if not dataset:
- logger.info(click.style(f"Dataset is not found: {dataset_id}", fg="yellow"))
- db.session.close()
- return
- # check document limit
- features = FeatureService.get_features(dataset.tenant_id)
- try:
- if features.billing.enabled:
- vector_space = features.vector_space
- count = len(document_ids)
- batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
- if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
- raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
- if count > batch_upload_limit:
- raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
- if 0 < vector_space.limit <= vector_space.size:
- raise ValueError(
- "Your total number of documents plus the number of uploads have over the limit of "
- "your subscription."
+ with session_factory.create_session() as session:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+ if not dataset:
+ logger.info(click.style(f"Dataset is not found: {dataset_id}", fg="yellow"))
+ return
+ # check document limit
+ features = FeatureService.get_features(dataset.tenant_id)
+ try:
+ if features.billing.enabled:
+ vector_space = features.vector_space
+ count = len(document_ids)
+ batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
+ if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
+ raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
+ if count > batch_upload_limit:
+ raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
+ if 0 < vector_space.limit <= vector_space.size:
+ raise ValueError(
+ "Your total number of documents plus the number of uploads has exceeded the limit of "
+ "your subscription."
+ )
+ except Exception as e:
+ for document_id in document_ids:
+ document = (
+ session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
- except Exception as e:
+ if document:
+ document.indexing_status = "error"
+ document.error = str(e)
+ document.stopped_at = naive_utc_now()
+ session.add(document)
+ session.commit()
+ return
+
for document_id in document_ids:
+ logger.info(click.style(f"Start process document: {document_id}", fg="green"))
+
document = (
- db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
+
if document:
- document.indexing_status = "error"
- document.error = str(e)
- document.stopped_at = naive_utc_now()
- db.session.add(document)
- db.session.commit()
- db.session.close()
- return
+ document.indexing_status = "parsing"
+ document.processing_started_at = naive_utc_now()
+ documents.append(document)
+ session.add(document)
+ session.commit()
- for document_id in document_ids:
- logger.info(click.style(f"Start process document: {document_id}", fg="green"))
+ try:
+ indexing_runner = IndexingRunner()
+ indexing_runner.run(documents)
+ end_at = time.perf_counter()
+ logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
- document = (
- db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
- )
+ # Trigger summary index generation for completed documents if enabled
+ # Only generate for high_quality indexing technique and when summary_index_setting is enabled
+ # Re-query dataset to get latest summary_index_setting (in case it was updated)
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+ if not dataset:
+ logger.warning("Dataset %s not found after indexing", dataset_id)
+ return
- if document:
- document.indexing_status = "parsing"
- document.processing_started_at = naive_utc_now()
- documents.append(document)
- db.session.add(document)
- db.session.commit()
-
- try:
- indexing_runner = IndexingRunner()
- indexing_runner.run(documents)
- end_at = time.perf_counter()
- logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
- except DocumentIsPausedError as ex:
- logger.info(click.style(str(ex), fg="yellow"))
- except Exception:
- logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
- finally:
- db.session.close()
+ if dataset.indexing_technique == "high_quality":
+ summary_index_setting = dataset.summary_index_setting
+ if summary_index_setting and summary_index_setting.get("enable"):
+ # expire all session to get latest document's indexing status
+ session.expire_all()
+ # Check each document's indexing status and trigger summary generation if completed
+ for document_id in document_ids:
+ # Re-query document to get latest status (IndexingRunner may have updated it)
+ document = (
+ session.query(Document)
+ .where(Document.id == document_id, Document.dataset_id == dataset_id)
+ .first()
+ )
+ if document:
+ logger.info(
+ "Checking document %s for summary generation: status=%s, doc_form=%s, need_summary=%s",
+ document_id,
+ document.indexing_status,
+ document.doc_form,
+ document.need_summary,
+ )
+ if (
+ document.indexing_status == "completed"
+ and document.doc_form != "qa_model"
+ and document.need_summary is True
+ ):
+ try:
+ generate_summary_index_task.delay(dataset.id, document_id, None)
+ logger.info(
+ "Queued summary index generation task for document %s in dataset %s "
+ "after indexing completed",
+ document_id,
+ dataset.id,
+ )
+ except Exception:
+ logger.exception(
+ "Failed to queue summary index generation task for document %s",
+ document_id,
+ )
+ # Don't fail the entire indexing process if summary task queuing fails
+ else:
+ logger.info(
+ "Skipping summary generation for document %s: "
+ "status=%s, doc_form=%s, need_summary=%s",
+ document_id,
+ document.indexing_status,
+ document.doc_form,
+ document.need_summary,
+ )
+ else:
+ logger.warning("Document %s not found after indexing", document_id)
+ else:
+ logger.info(
+ "Summary index generation skipped for dataset %s: summary_index_setting.enable=%s",
+ dataset.id,
+ summary_index_setting.get("enable") if summary_index_setting else None,
+ )
+ else:
+ logger.info(
+ "Summary index generation skipped for dataset %s: indexing_technique=%s (not 'high_quality')",
+ dataset.id,
+ dataset.indexing_technique,
+ )
+ except DocumentIsPausedError as ex:
+ logger.info(click.style(str(ex), fg="yellow"))
+ except Exception:
+ logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
def _document_indexing_with_tenant_queue(
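
After `IndexingRunner` finishes, the task re-reads the dataset and each document (hence `session.expire_all()`) and queues `generate_summary_index_task` only for completed, non-QA documents that still need a summary, and only when the dataset uses `high_quality` indexing with the summary setting enabled. The queueing itself is wrapped so a broker failure cannot fail the indexing task. A condensed sketch:

```python
# Fragment of the post-indexing summary trigger above.
if dataset.indexing_technique == "high_quality" and (dataset.summary_index_setting or {}).get("enable"):
    session.expire_all()  # pick up statuses written by IndexingRunner
    for document_id in document_ids:
        document = (
            session.query(Document)
            .where(Document.id == document_id, Document.dataset_id == dataset_id)
            .first()
        )
        if (
            document
            and document.indexing_status == "completed"
            and document.doc_form != "qa_model"
            and document.need_summary is True
        ):
            try:
                generate_summary_index_task.delay(dataset.id, document_id, None)
            except Exception:
                logger.exception("Failed to queue summary index generation for document %s", document_id)
```
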
diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py
index 161502a228..67a23be952 100644
--- a/api/tasks/document_indexing_update_task.py
+++ b/api/tasks/document_indexing_update_task.py
@@ -3,8 +3,9 @@ import time
import click
from celery import shared_task
-from sqlalchemy import select
+from sqlalchemy import delete, select
+from core.db.session_factory import session_factory
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
@@ -26,56 +27,54 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
logger.info(click.style(f"Start update document: {document_id}", fg="green"))
start_at = time.perf_counter()
- document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ with session_factory.create_session() as session:
+ document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
- if not document:
- logger.info(click.style(f"Document not found: {document_id}", fg="red"))
- db.session.close()
- return
+ if not document:
+ logger.info(click.style(f"Document not found: {document_id}", fg="red"))
+ return
- document.indexing_status = "parsing"
- document.processing_started_at = naive_utc_now()
- db.session.commit()
+ document.indexing_status = "parsing"
+ document.processing_started_at = naive_utc_now()
+ session.commit()
- # delete all document segment and index
- try:
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
- if not dataset:
- raise Exception("Dataset not found")
+ # delete all document segment and index
+ try:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+ if not dataset:
+ raise Exception("Dataset not found")
- index_type = document.doc_form
- index_processor = IndexProcessorFactory(index_type).init_index_processor()
+ index_type = document.doc_form
+ index_processor = IndexProcessorFactory(index_type).init_index_processor()
- segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
- if segments:
- index_node_ids = [segment.index_node_id for segment in segments]
+ segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
+ if segments:
+ index_node_ids = [segment.index_node_id for segment in segments]
- # delete from vector index
- index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
-
- for segment in segments:
- db.session.delete(segment)
- db.session.commit()
- end_at = time.perf_counter()
- logger.info(
- click.style(
- "Cleaned document when document update data source or process rule: {} latency: {}".format(
- document_id, end_at - start_at
- ),
- fg="green",
+ # delete from vector index
+ index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+ segment_ids = [segment.id for segment in segments]
+ segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
+ session.execute(segment_delete_stmt)
+ session.commit()
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ "Cleaned document when document update data source or process rule: {} latency: {}".format(
+ document_id, end_at - start_at
+ ),
+ fg="green",
+ )
)
- )
- except Exception:
- logger.exception("Cleaned document when document update data source or process rule failed")
+ except Exception:
+ logger.exception("Cleaned document when document update data source or process rule failed")
- try:
- indexing_runner = IndexingRunner()
- indexing_runner.run([document])
- end_at = time.perf_counter()
- logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
- except DocumentIsPausedError as ex:
- logger.info(click.style(str(ex), fg="yellow"))
- except Exception:
- logger.exception("document_indexing_update_task failed, document_id: %s", document_id)
- finally:
- db.session.close()
+ try:
+ indexing_runner = IndexingRunner()
+ indexing_runner.run([document])
+ end_at = time.perf_counter()
+ logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
+ except DocumentIsPausedError as ex:
+ logger.info(click.style(str(ex), fg="yellow"))
+ except Exception:
+ logger.exception("document_indexing_update_task failed, document_id: %s", document_id)
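
Segment cleanup here (and in the duplicate-indexing task below) switches from per-row `session.delete(segment)` calls to a single bulk `DELETE` statement, which avoids loading and expiring each ORM object. A minimal sketch:

```python
# Bulk delete of a document's segments, as used in the hunks above and below.
from sqlalchemy import delete

segment_ids = [segment.id for segment in segments]
session.execute(delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)))
session.commit()
```
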
diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py
index 4078c8910e..00a963255b 100644
--- a/api/tasks/duplicate_document_indexing_task.py
+++ b/api/tasks/duplicate_document_indexing_task.py
@@ -4,15 +4,15 @@ from collections.abc import Callable, Sequence
import click
from celery import shared_task
-from sqlalchemy import select
+from sqlalchemy import delete, select
from configs import dify_config
+from core.db.session_factory import session_factory
from core.entities.document_task import DocumentTask
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
-from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from services.feature_service import FeatureService
@@ -76,63 +76,64 @@ def _duplicate_document_indexing_task_with_tenant_queue(
def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[str]):
- documents = []
+ documents: list[Document] = []
start_at = time.perf_counter()
- try:
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
- if dataset is None:
- logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
- db.session.close()
- return
-
- # check document limit
- features = FeatureService.get_features(dataset.tenant_id)
+ with session_factory.create_session() as session:
try:
- if features.billing.enabled:
- vector_space = features.vector_space
- count = len(document_ids)
- if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
- raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
- batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
- if count > batch_upload_limit:
- raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
- current = int(getattr(vector_space, "size", 0) or 0)
- limit = int(getattr(vector_space, "limit", 0) or 0)
- if limit > 0 and (current + count) > limit:
- raise ValueError(
- "Your total number of documents plus the number of uploads have exceeded the limit of "
- "your subscription."
- )
- except Exception as e:
- for document_id in document_ids:
- document = (
- db.session.query(Document)
- .where(Document.id == document_id, Document.dataset_id == dataset_id)
- .first()
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+ if dataset is None:
+ logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
+ return
+
+ # check document limit
+ features = FeatureService.get_features(dataset.tenant_id)
+ try:
+ if features.billing.enabled:
+ vector_space = features.vector_space
+ count = len(document_ids)
+ if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
+ raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
+ batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
+ if count > batch_upload_limit:
+ raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
+ current = int(getattr(vector_space, "size", 0) or 0)
+ limit = int(getattr(vector_space, "limit", 0) or 0)
+ if limit > 0 and (current + count) > limit:
+ raise ValueError(
+ "Your total number of documents plus the number of uploads have exceeded the limit of "
+ "your subscription."
+ )
+ except Exception as e:
+ documents = list(
+ session.scalars(
+ select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
+ ).all()
)
- if document:
- document.indexing_status = "error"
- document.error = str(e)
- document.stopped_at = naive_utc_now()
- db.session.add(document)
- db.session.commit()
- return
+ for document in documents:
+ if document:
+ document.indexing_status = "error"
+ document.error = str(e)
+ document.stopped_at = naive_utc_now()
+ session.add(document)
+ session.commit()
+ return
- for document_id in document_ids:
- logger.info(click.style(f"Start process document: {document_id}", fg="green"))
-
- document = (
- db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ documents = list(
+ session.scalars(
+ select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
+ ).all()
)
- if document:
+ for document in documents:
+ logger.info(click.style(f"Start process document: {document.id}", fg="green"))
+
# clean old data
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
- segments = db.session.scalars(
- select(DocumentSegment).where(DocumentSegment.document_id == document_id)
+ segments = session.scalars(
+ select(DocumentSegment).where(DocumentSegment.document_id == document.id)
).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
@@ -140,26 +141,24 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
- for segment in segments:
- db.session.delete(segment)
- db.session.commit()
+ segment_ids = [segment.id for segment in segments]
+ segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
+ session.execute(segment_delete_stmt)
+ session.commit()
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
- documents.append(document)
- db.session.add(document)
- db.session.commit()
+ session.add(document)
+ session.commit()
- indexing_runner = IndexingRunner()
- indexing_runner.run(documents)
- end_at = time.perf_counter()
- logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
- except DocumentIsPausedError as ex:
- logger.info(click.style(str(ex), fg="yellow"))
- except Exception:
- logger.exception("duplicate_document_indexing_task failed, dataset_id: %s", dataset_id)
- finally:
- db.session.close()
+ indexing_runner = IndexingRunner()
+ indexing_runner.run(list(documents))
+ end_at = time.perf_counter()
+ logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
+ except DocumentIsPausedError as ex:
+ logger.info(click.style(str(ex), fg="yellow"))
+ except Exception:
+ logger.exception("duplicate_document_indexing_task failed, dataset_id: %s", dataset_id)
@shared_task(queue="dataset")
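The pattern above recurs throughout this patch: module-level `db.session` plus manual `db.session.close()` is replaced by a context-managed session from `session_factory.create_session()`, per-id queries become a single `IN (...)` query, and per-row deletes become one bulk `delete()` statement. A minimal sketch of that shape, assuming `create_session()` yields a standard SQLAlchemy `Session`; the helper name below is illustrative and not part of the patch:

```python
from sqlalchemy import delete, select

from core.db.session_factory import session_factory
from models.dataset import Document, DocumentSegment


def _fail_documents(dataset_id: str, document_ids: list[str], error: str) -> None:
    # Illustrative helper: batch-load documents, mark them failed, and drop their segments.
    with session_factory.create_session() as session:
        documents = session.scalars(
            select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
        ).all()
        for document in documents:
            document.indexing_status = "error"
            document.error = error
            session.add(document)
        # One bulk delete instead of deleting segments row by row.
        session.execute(delete(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)))
        session.commit()
        # No explicit close: the context manager disposes of the session.
```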
diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py
index 7615469ed0..41ebb0b076 100644
--- a/api/tasks/enable_segment_to_index_task.py
+++ b/api/tasks/enable_segment_to_index_task.py
@@ -4,11 +4,11 @@ import time
import click
from celery import shared_task
+from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DocumentSegment
@@ -27,91 +27,104 @@ def enable_segment_to_index_task(segment_id: str):
logger.info(click.style(f"Start enable segment to index: {segment_id}", fg="green"))
start_at = time.perf_counter()
- segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
- if not segment:
- logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
- db.session.close()
- return
-
- if segment.status != "completed":
- logger.info(click.style(f"Segment is not completed, enable is not allowed: {segment_id}", fg="red"))
- db.session.close()
- return
-
- indexing_cache_key = f"segment_{segment.id}_indexing"
-
- try:
- document = Document(
- page_content=segment.content,
- metadata={
- "doc_id": segment.index_node_id,
- "doc_hash": segment.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
- )
-
- dataset = segment.dataset
-
- if not dataset:
- logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
+ with session_factory.create_session() as session:
+ segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
+ if not segment:
+ logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
return
- dataset_document = segment.document
-
- if not dataset_document:
- logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
+ if segment.status != "completed":
+ logger.info(click.style(f"Segment is not completed, enable is not allowed: {segment_id}", fg="red"))
return
- if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
- logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
- return
+ indexing_cache_key = f"segment_{segment.id}_indexing"
- index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
- if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
- child_chunks = segment.get_child_chunks()
- if child_chunks:
- child_documents = []
- for child_chunk in child_chunks:
- child_document = ChildDocument(
- page_content=child_chunk.content,
- metadata={
- "doc_id": child_chunk.index_node_id,
- "doc_hash": child_chunk.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- },
- )
- child_documents.append(child_document)
- document.children = child_documents
- multimodel_documents = []
- if dataset.is_multimodal:
- for attachment in segment.attachments:
- multimodel_documents.append(
- AttachmentDocument(
- page_content=attachment["name"],
- metadata={
- "doc_id": attachment["id"],
- "doc_hash": "",
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- "doc_type": DocType.IMAGE,
- },
+ try:
+ document = Document(
+ page_content=segment.content,
+ metadata={
+ "doc_id": segment.index_node_id,
+ "doc_hash": segment.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
+
+ dataset = segment.dataset
+
+ if not dataset:
+ logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
+ return
+
+ dataset_document = segment.document
+
+ if not dataset_document:
+ logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
+ return
+
+ if (
+ not dataset_document.enabled
+ or dataset_document.archived
+ or dataset_document.indexing_status != "completed"
+ ):
+ logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
+ return
+
+ index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
+ if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
+ child_chunks = segment.get_child_chunks()
+ if child_chunks:
+ child_documents = []
+ for child_chunk in child_chunks:
+ child_document = ChildDocument(
+ page_content=child_chunk.content,
+ metadata={
+ "doc_id": child_chunk.index_node_id,
+ "doc_hash": child_chunk.index_node_hash,
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ },
+ )
+ child_documents.append(child_document)
+ document.children = child_documents
+ multimodel_documents = []
+ if dataset.is_multimodal:
+ for attachment in segment.attachments:
+ multimodel_documents.append(
+ AttachmentDocument(
+ page_content=attachment["name"],
+ metadata={
+ "doc_id": attachment["id"],
+ "doc_hash": "",
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ "doc_type": DocType.IMAGE,
+ },
+ )
)
+
+ # save vector index
+ index_processor.load(dataset, [document], multimodal_documents=multimodel_documents)
+
+ # Enable summary index for this segment
+ from services.summary_index_service import SummaryIndexService
+
+ try:
+ SummaryIndexService.enable_summaries_for_segments(
+ dataset=dataset,
+ segment_ids=[segment.id],
)
+ except Exception as e:
+ logger.warning("Failed to enable summary for segment %s: %s", segment.id, str(e))
- # save vector index
- index_processor.load(dataset, [document], multimodal_documents=multimodel_documents)
-
- end_at = time.perf_counter()
- logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green"))
- except Exception as e:
- logger.exception("enable segment to index failed")
- segment.enabled = False
- segment.disabled_at = naive_utc_now()
- segment.status = "error"
- segment.error = str(e)
- db.session.commit()
- finally:
- redis_client.delete(indexing_cache_key)
- db.session.close()
+ end_at = time.perf_counter()
+ logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green"))
+ except Exception as e:
+ logger.exception("enable segment to index failed")
+ segment.enabled = False
+ segment.disabled_at = naive_utc_now()
+ segment.status = "error"
+ segment.error = str(e)
+ session.commit()
+ finally:
+ redis_client.delete(indexing_cache_key)
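New in this task is the best-effort summary-index hook: summary generation is delegated to `SummaryIndexService`, and a failure there is downgraded to a warning so it cannot abort the vector indexing itself. A condensed sketch of that pattern as used above (the wrapper function is illustrative):

```python
import logging

from services.summary_index_service import SummaryIndexService

logger = logging.getLogger(__name__)


def _enable_summaries_best_effort(dataset, segment_ids: list[str]) -> None:
    # Mirrors the hook added in the diff: summary failures are logged, not raised.
    try:
        SummaryIndexService.enable_summaries_for_segments(dataset=dataset, segment_ids=segment_ids)
    except Exception as e:
        logger.warning("Failed to enable summaries for segments %s: %s", segment_ids, str(e))
```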
diff --git a/api/tasks/enable_segments_to_index_task.py b/api/tasks/enable_segments_to_index_task.py
index 9f17d09e18..d90eb4c39f 100644
--- a/api/tasks/enable_segments_to_index_task.py
+++ b/api/tasks/enable_segments_to_index_task.py
@@ -5,11 +5,11 @@ import click
from celery import shared_task
from sqlalchemy import select
+from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, DocumentSegment
@@ -29,105 +29,114 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
Usage: enable_segments_to_index_task.delay(segment_ids, dataset_id, document_id)
"""
start_at = time.perf_counter()
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
- if not dataset:
- logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
- return
+ with session_factory.create_session() as session:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+ if not dataset:
+ logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
+ return
- dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
+ dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
- if not dataset_document:
- logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
- db.session.close()
- return
- if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
- logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan"))
- db.session.close()
- return
- # sync index processor
- index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
+ if not dataset_document:
+ logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
+ return
+ if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
+ logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan"))
+ return
+ # sync index processor
+ index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
- segments = db.session.scalars(
- select(DocumentSegment).where(
- DocumentSegment.id.in_(segment_ids),
- DocumentSegment.dataset_id == dataset_id,
- DocumentSegment.document_id == document_id,
- )
- ).all()
- if not segments:
- logger.info(click.style(f"Segments not found: {segment_ids}", fg="cyan"))
- db.session.close()
- return
-
- try:
- documents = []
- multimodal_documents = []
- for segment in segments:
- document = Document(
- page_content=segment.content,
- metadata={
- "doc_id": segment.index_node_id,
- "doc_hash": segment.index_node_hash,
- "document_id": document_id,
- "dataset_id": dataset_id,
- },
+ segments = session.scalars(
+ select(DocumentSegment).where(
+ DocumentSegment.id.in_(segment_ids),
+ DocumentSegment.dataset_id == dataset_id,
+ DocumentSegment.document_id == document_id,
)
+ ).all()
+ if not segments:
+ logger.info(click.style(f"Segments not found: {segment_ids}", fg="cyan"))
+ return
- if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
- child_chunks = segment.get_child_chunks()
- if child_chunks:
- child_documents = []
- for child_chunk in child_chunks:
- child_document = ChildDocument(
- page_content=child_chunk.content,
- metadata={
- "doc_id": child_chunk.index_node_id,
- "doc_hash": child_chunk.index_node_hash,
- "document_id": document_id,
- "dataset_id": dataset_id,
- },
+ try:
+ documents = []
+ multimodal_documents = []
+ for segment in segments:
+ document = Document(
+ page_content=segment.content,
+ metadata={
+ "doc_id": segment.index_node_id,
+ "doc_hash": segment.index_node_hash,
+ "document_id": document_id,
+ "dataset_id": dataset_id,
+ },
+ )
+
+ if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
+ child_chunks = segment.get_child_chunks()
+ if child_chunks:
+ child_documents = []
+ for child_chunk in child_chunks:
+ child_document = ChildDocument(
+ page_content=child_chunk.content,
+ metadata={
+ "doc_id": child_chunk.index_node_id,
+ "doc_hash": child_chunk.index_node_hash,
+ "document_id": document_id,
+ "dataset_id": dataset_id,
+ },
+ )
+ child_documents.append(child_document)
+ document.children = child_documents
+
+ if dataset.is_multimodal:
+ for attachment in segment.attachments:
+ multimodal_documents.append(
+ AttachmentDocument(
+ page_content=attachment["name"],
+ metadata={
+ "doc_id": attachment["id"],
+ "doc_hash": "",
+ "document_id": segment.document_id,
+ "dataset_id": segment.dataset_id,
+ "doc_type": DocType.IMAGE,
+ },
+ )
)
- child_documents.append(child_document)
- document.children = child_documents
+ documents.append(document)
+ # save vector index
+ index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
- if dataset.is_multimodal:
- for attachment in segment.attachments:
- multimodal_documents.append(
- AttachmentDocument(
- page_content=attachment["name"],
- metadata={
- "doc_id": attachment["id"],
- "doc_hash": "",
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- "doc_type": DocType.IMAGE,
- },
- )
- )
- documents.append(document)
- # save vector index
- index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
+ # Enable summary indexes for these segments
+ from services.summary_index_service import SummaryIndexService
- end_at = time.perf_counter()
- logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green"))
- except Exception as e:
- logger.exception("enable segments to index failed")
- # update segment error msg
- db.session.query(DocumentSegment).where(
- DocumentSegment.id.in_(segment_ids),
- DocumentSegment.dataset_id == dataset_id,
- DocumentSegment.document_id == document_id,
- ).update(
- {
- "error": str(e),
- "status": "error",
- "disabled_at": naive_utc_now(),
- "enabled": False,
- }
- )
- db.session.commit()
- finally:
- for segment in segments:
- indexing_cache_key = f"segment_{segment.id}_indexing"
- redis_client.delete(indexing_cache_key)
- db.session.close()
+ segment_ids_list = [segment.id for segment in segments]
+ try:
+ SummaryIndexService.enable_summaries_for_segments(
+ dataset=dataset,
+ segment_ids=segment_ids_list,
+ )
+ except Exception as e:
+ logger.warning("Failed to enable summaries for segments: %s", str(e))
+
+ end_at = time.perf_counter()
+ logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green"))
+ except Exception as e:
+ logger.exception("enable segments to index failed")
+ # update segment error msg
+ session.query(DocumentSegment).where(
+ DocumentSegment.id.in_(segment_ids),
+ DocumentSegment.dataset_id == dataset_id,
+ DocumentSegment.document_id == document_id,
+ ).update(
+ {
+ "error": str(e),
+ "status": "error",
+ "disabled_at": naive_utc_now(),
+ "enabled": False,
+ }
+ )
+ session.commit()
+ finally:
+ for segment in segments:
+ indexing_cache_key = f"segment_{segment.id}_indexing"
+ redis_client.delete(indexing_cache_key)
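For reference, the batch variant keeps its existing entry point; callers enqueue it exactly as the task docstring describes. A small usage reminder, assuming the standard `tasks.` package path (the calling site is not part of this diff):

```python
# segment_ids is a list of DocumentSegment IDs belonging to one document in one dataset.
from tasks.enable_segments_to_index_task import enable_segments_to_index_task

enable_segments_to_index_task.delay(segment_ids, dataset_id, document_id)
```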
diff --git a/api/tasks/generate_summary_index_task.py b/api/tasks/generate_summary_index_task.py
new file mode 100644
index 0000000000..e4273e16b5
--- /dev/null
+++ b/api/tasks/generate_summary_index_task.py
@@ -0,0 +1,119 @@
+"""Async task for generating summary indexes."""
+
+import logging
+import time
+
+import click
+from celery import shared_task
+
+from core.db.session_factory import session_factory
+from models.dataset import Dataset, DocumentSegment
+from models.dataset import Document as DatasetDocument
+from services.summary_index_service import SummaryIndexService
+
+logger = logging.getLogger(__name__)
+
+
+@shared_task(queue="dataset")
+def generate_summary_index_task(dataset_id: str, document_id: str, segment_ids: list[str] | None = None):
+ """
+ Async generate summary index for document segments.
+
+ Args:
+ dataset_id: Dataset ID
+ document_id: Document ID
+ segment_ids: Optional list of specific segment IDs to process. If None, process all segments.
+
+ Usage:
+ generate_summary_index_task.delay(dataset_id, document_id)
+ generate_summary_index_task.delay(dataset_id, document_id, segment_ids)
+ """
+ logger.info(
+ click.style(
+ f"Start generating summary index for document {document_id} in dataset {dataset_id}",
+ fg="green",
+ )
+ )
+ start_at = time.perf_counter()
+
+ try:
+ with session_factory.create_session() as session:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+ if not dataset:
+ logger.error(click.style(f"Dataset not found: {dataset_id}", fg="red"))
+ return
+
+ document = session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
+ if not document:
+ logger.error(click.style(f"Document not found: {document_id}", fg="red"))
+ return
+
+ # Check if document needs summary
+ if not document.need_summary:
+ logger.info(
+ click.style(
+ f"Skipping summary generation for document {document_id}: need_summary is False",
+ fg="cyan",
+ )
+ )
+ return
+
+ # Only generate summary index for high_quality indexing technique
+ if dataset.indexing_technique != "high_quality":
+ logger.info(
+ click.style(
+ f"Skipping summary generation for dataset {dataset_id}: "
+ f"indexing_technique is {dataset.indexing_technique}, not 'high_quality'",
+ fg="cyan",
+ )
+ )
+ return
+
+ # Check if summary index is enabled
+ summary_index_setting = dataset.summary_index_setting
+ if not summary_index_setting or not summary_index_setting.get("enable"):
+ logger.info(
+ click.style(
+ f"Summary index is disabled for dataset {dataset_id}",
+ fg="cyan",
+ )
+ )
+ return
+
+ # Determine if only parent chunks should be processed
+ only_parent_chunks = dataset.chunk_structure == "parent_child_index"
+
+ # Generate summaries
+ summary_records = SummaryIndexService.generate_summaries_for_document(
+ dataset=dataset,
+ document=document,
+ summary_index_setting=summary_index_setting,
+ segment_ids=segment_ids,
+ only_parent_chunks=only_parent_chunks,
+ )
+
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ f"Summary index generation completed for document {document_id}: "
+ f"{len(summary_records)} summaries generated, latency: {end_at - start_at}",
+ fg="green",
+ )
+ )
+
+ except Exception as e:
+ logger.exception("Failed to generate summary index for document %s", document_id)
+ # Update document segments with error status if needed
+ if segment_ids:
+ error_message = f"Summary generation failed: {str(e)}"
+ with session_factory.create_session() as session:
+ session.query(DocumentSegment).filter(
+ DocumentSegment.id.in_(segment_ids),
+ DocumentSegment.dataset_id == dataset_id,
+ ).update(
+ {
+ DocumentSegment.error: error_message,
+ },
+ synchronize_session=False,
+ )
+ session.commit()
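The new task guards on the document's `need_summary` flag, the dataset's `high_quality` indexing technique, and an enabled `summary_index_setting` before calling `SummaryIndexService.generate_summaries_for_document`. Callers enqueue it as the docstring shows; a brief sketch, with the calling site being illustrative rather than part of this diff:

```python
# Illustrative enqueue after a document finishes indexing; omit segment_ids to
# process every segment of the document.
from tasks.generate_summary_index_task import generate_summary_index_task

generate_summary_index_task.delay(dataset_id, document_id)               # all segments
generate_summary_index_task.delay(dataset_id, document_id, segment_ids)  # specific segments
```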
diff --git a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py
index e6492c230d..b5e6508006 100644
--- a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py
+++ b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py
@@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
RETRY_TIMES_OF_ONE_PLUGIN_IN_ONE_TENANT = 3
CACHE_REDIS_KEY_PREFIX = "plugin_autoupgrade_check_task:cached_plugin_manifests:"
-CACHE_REDIS_TTL = 60 * 15 # 15 minutes
+CACHE_REDIS_TTL = 60 * 60 # 1 hour
def _get_redis_cache_key(plugin_id: str) -> str:
diff --git a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py
index 1eef361a92..3c5e152520 100644
--- a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py
+++ b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py
@@ -178,7 +178,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
workflow_id=workflow_id,
user=account,
application_generate_entity=entity,
- invoke_from=InvokeFrom.PUBLISHED,
+ invoke_from=InvokeFrom.PUBLISHED_PIPELINE,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
diff --git a/api/tasks/rag_pipeline/rag_pipeline_run_task.py b/api/tasks/rag_pipeline/rag_pipeline_run_task.py
index 275f5abe6e..093342d1a3 100644
--- a/api/tasks/rag_pipeline/rag_pipeline_run_task.py
+++ b/api/tasks/rag_pipeline/rag_pipeline_run_task.py
@@ -178,7 +178,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
workflow_id=workflow_id,
user=account,
application_generate_entity=entity,
- invoke_from=InvokeFrom.PUBLISHED,
+ invoke_from=InvokeFrom.PUBLISHED_PIPELINE,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py
index 1b2a653c01..af72023da1 100644
--- a/api/tasks/recover_document_indexing_task.py
+++ b/api/tasks/recover_document_indexing_task.py
@@ -4,8 +4,8 @@ import time
import click
from celery import shared_task
+from core.db.session_factory import session_factory
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
-from extensions.ext_database import db
from models.dataset import Document
logger = logging.getLogger(__name__)
@@ -23,26 +23,24 @@ def recover_document_indexing_task(dataset_id: str, document_id: str):
logger.info(click.style(f"Recover document: {document_id}", fg="green"))
start_at = time.perf_counter()
- document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ with session_factory.create_session() as session:
+ document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
- if not document:
- logger.info(click.style(f"Document not found: {document_id}", fg="red"))
- db.session.close()
- return
+ if not document:
+ logger.info(click.style(f"Document not found: {document_id}", fg="red"))
+ return
- try:
- indexing_runner = IndexingRunner()
- if document.indexing_status in {"waiting", "parsing", "cleaning"}:
- indexing_runner.run([document])
- elif document.indexing_status == "splitting":
- indexing_runner.run_in_splitting_status(document)
- elif document.indexing_status == "indexing":
- indexing_runner.run_in_indexing_status(document)
- end_at = time.perf_counter()
- logger.info(click.style(f"Processed document: {document.id} latency: {end_at - start_at}", fg="green"))
- except DocumentIsPausedError as ex:
- logger.info(click.style(str(ex), fg="yellow"))
- except Exception:
- logger.exception("recover_document_indexing_task failed, document_id: %s", document_id)
- finally:
- db.session.close()
+ try:
+ indexing_runner = IndexingRunner()
+ if document.indexing_status in {"waiting", "parsing", "cleaning"}:
+ indexing_runner.run([document])
+ elif document.indexing_status == "splitting":
+ indexing_runner.run_in_splitting_status(document)
+ elif document.indexing_status == "indexing":
+ indexing_runner.run_in_indexing_status(document)
+ end_at = time.perf_counter()
+ logger.info(click.style(f"Processed document: {document.id} latency: {end_at - start_at}", fg="green"))
+ except DocumentIsPausedError as ex:
+ logger.info(click.style(str(ex), fg="yellow"))
+ except Exception:
+ logger.exception("recover_document_indexing_task failed, document_id: %s", document_id)
diff --git a/api/tasks/regenerate_summary_index_task.py b/api/tasks/regenerate_summary_index_task.py
new file mode 100644
index 0000000000..cf8988d13e
--- /dev/null
+++ b/api/tasks/regenerate_summary_index_task.py
@@ -0,0 +1,315 @@
+"""Task for regenerating summary indexes when dataset settings change."""
+
+import logging
+import time
+from collections import defaultdict
+
+import click
+from celery import shared_task
+from sqlalchemy import or_, select
+
+from core.db.session_factory import session_factory
+from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary
+from models.dataset import Document as DatasetDocument
+from services.summary_index_service import SummaryIndexService
+
+logger = logging.getLogger(__name__)
+
+
+@shared_task(queue="dataset")
+def regenerate_summary_index_task(
+ dataset_id: str,
+ regenerate_reason: str = "summary_model_changed",
+ regenerate_vectors_only: bool = False,
+):
+ """
+ Regenerate summary indexes for all documents in a dataset.
+
+ This task is triggered when:
+ 1. summary_index_setting model changes (regenerate_reason="summary_model_changed")
+ - Regenerates summary content and vectors for all existing summaries
+ 2. embedding_model changes (regenerate_reason="embedding_model_changed")
+ - Only regenerates vectors for existing summaries (keeps summary content)
+
+ Args:
+ dataset_id: Dataset ID
+ regenerate_reason: Reason for regeneration ("summary_model_changed" or "embedding_model_changed")
+ regenerate_vectors_only: If True, only regenerate vectors without regenerating summary content
+ """
+ logger.info(
+ click.style(
+ f"Start regenerate summary index for dataset {dataset_id}, reason: {regenerate_reason}",
+ fg="green",
+ )
+ )
+ start_at = time.perf_counter()
+
+ try:
+ with session_factory.create_session() as session:
+ dataset = session.query(Dataset).filter_by(id=dataset_id).first()
+ if not dataset:
+ logger.error(click.style(f"Dataset not found: {dataset_id}", fg="red"))
+ return
+
+ # Only regenerate summary index for high_quality indexing technique
+ if dataset.indexing_technique != "high_quality":
+ logger.info(
+ click.style(
+ f"Skipping summary regeneration for dataset {dataset_id}: "
+ f"indexing_technique is {dataset.indexing_technique}, not 'high_quality'",
+ fg="cyan",
+ )
+ )
+ return
+
+ # Check if summary index is enabled (only for summary_model change)
+ # For embedding_model change, we still re-vectorize existing summaries even if setting is disabled
+ summary_index_setting = dataset.summary_index_setting
+ if not regenerate_vectors_only:
+ # For summary_model change, require summary_index_setting to be enabled
+ if not summary_index_setting or not summary_index_setting.get("enable"):
+ logger.info(
+ click.style(
+ f"Summary index is disabled for dataset {dataset_id}",
+ fg="cyan",
+ )
+ )
+ return
+
+ total_segments_processed = 0
+ total_segments_failed = 0
+
+ if regenerate_vectors_only:
+ # For embedding_model change: directly query all segments with existing summaries
+ # Don't require document indexing_status == "completed"
+ # Include summaries with status "completed" or "error" (if they have content)
+ segments_with_summaries = (
+ session.query(DocumentSegment, DocumentSegmentSummary)
+ .join(
+ DocumentSegmentSummary,
+ DocumentSegment.id == DocumentSegmentSummary.chunk_id,
+ )
+ .join(
+ DatasetDocument,
+ DocumentSegment.document_id == DatasetDocument.id,
+ )
+ .where(
+ DocumentSegment.dataset_id == dataset_id,
+ DocumentSegment.status == "completed", # Segment must be completed
+ DocumentSegment.enabled == True,
+ DocumentSegmentSummary.dataset_id == dataset_id,
+ DocumentSegmentSummary.summary_content.isnot(None), # Must have summary content
+ # Include completed summaries or error summaries (with content)
+ or_(
+ DocumentSegmentSummary.status == "completed",
+ DocumentSegmentSummary.status == "error",
+ ),
+ DatasetDocument.enabled == True, # Document must be enabled
+ DatasetDocument.archived == False, # Document must not be archived
+ DatasetDocument.doc_form != "qa_model", # Skip qa_model documents
+ )
+ .order_by(DocumentSegment.document_id.asc(), DocumentSegment.position.asc())
+ .all()
+ )
+
+ if not segments_with_summaries:
+ logger.info(
+ click.style(
+ f"No segments with summaries found for re-vectorization in dataset {dataset_id}",
+ fg="cyan",
+ )
+ )
+ return
+
+ logger.info(
+ "Found %s segments with summaries for re-vectorization in dataset %s",
+ len(segments_with_summaries),
+ dataset_id,
+ )
+
+ # Group by document for logging
+ segments_by_document = defaultdict(list)
+ for segment, summary_record in segments_with_summaries:
+ segments_by_document[segment.document_id].append((segment, summary_record))
+
+ logger.info(
+ "Segments grouped into %s documents for re-vectorization",
+ len(segments_by_document),
+ )
+
+ for document_id, segment_summary_pairs in segments_by_document.items():
+ logger.info(
+ "Re-vectorizing summaries for %s segments in document %s",
+ len(segment_summary_pairs),
+ document_id,
+ )
+
+ for segment, summary_record in segment_summary_pairs:
+ try:
+ # Delete old vector
+ if summary_record.summary_index_node_id:
+ try:
+ from core.rag.datasource.vdb.vector_factory import Vector
+
+ vector = Vector(dataset)
+ vector.delete_by_ids([summary_record.summary_index_node_id])
+ except Exception as e:
+ logger.warning(
+ "Failed to delete old summary vector for segment %s: %s",
+ segment.id,
+ str(e),
+ )
+
+ # Re-vectorize with new embedding model
+ SummaryIndexService.vectorize_summary(summary_record, segment, dataset)
+ session.commit()
+ total_segments_processed += 1
+
+ except Exception as e:
+ logger.error(
+ "Failed to re-vectorize summary for segment %s: %s",
+ segment.id,
+ str(e),
+ exc_info=True,
+ )
+ total_segments_failed += 1
+ # Update summary record with error status
+ summary_record.status = "error"
+ summary_record.error = f"Re-vectorization failed: {str(e)}"
+ session.add(summary_record)
+ session.commit()
+ continue
+
+ else:
+ # For summary_model change: require document indexing_status == "completed"
+ # Get all documents with completed indexing status
+ dataset_documents = session.scalars(
+ select(DatasetDocument).where(
+ DatasetDocument.dataset_id == dataset_id,
+ DatasetDocument.indexing_status == "completed",
+ DatasetDocument.enabled == True,
+ DatasetDocument.archived == False,
+ )
+ ).all()
+
+ if not dataset_documents:
+ logger.info(
+ click.style(
+ f"No documents found for summary regeneration in dataset {dataset_id}",
+ fg="cyan",
+ )
+ )
+ return
+
+ logger.info(
+ "Found %s documents for summary regeneration in dataset %s",
+ len(dataset_documents),
+ dataset_id,
+ )
+
+ for dataset_document in dataset_documents:
+ # Skip qa_model documents
+ if dataset_document.doc_form == "qa_model":
+ continue
+
+ try:
+ # Get all segments with existing summaries
+ segments = (
+ session.query(DocumentSegment)
+ .join(
+ DocumentSegmentSummary,
+ DocumentSegment.id == DocumentSegmentSummary.chunk_id,
+ )
+ .where(
+ DocumentSegment.document_id == dataset_document.id,
+ DocumentSegment.dataset_id == dataset_id,
+ DocumentSegment.status == "completed",
+ DocumentSegment.enabled == True,
+ DocumentSegmentSummary.dataset_id == dataset_id,
+ )
+ .order_by(DocumentSegment.position.asc())
+ .all()
+ )
+
+ if not segments:
+ continue
+
+ logger.info(
+ "Regenerating summaries for %s segments in document %s",
+ len(segments),
+ dataset_document.id,
+ )
+
+ for segment in segments:
+ summary_record = None
+ try:
+ # Get existing summary record
+ summary_record = (
+ session.query(DocumentSegmentSummary)
+ .filter_by(
+ chunk_id=segment.id,
+ dataset_id=dataset_id,
+ )
+ .first()
+ )
+
+ if not summary_record:
+ logger.warning("Summary record not found for segment %s, skipping", segment.id)
+ continue
+
+ # Regenerate both summary content and vectors (for summary_model change)
+ SummaryIndexService.generate_and_vectorize_summary(
+ segment, dataset, summary_index_setting
+ )
+ session.commit()
+ total_segments_processed += 1
+
+ except Exception as e:
+ logger.error(
+ "Failed to regenerate summary for segment %s: %s",
+ segment.id,
+ str(e),
+ exc_info=True,
+ )
+ total_segments_failed += 1
+ # Update summary record with error status
+ if summary_record:
+ summary_record.status = "error"
+ summary_record.error = f"Regeneration failed: {str(e)}"
+ session.add(summary_record)
+ session.commit()
+ continue
+
+ except Exception as e:
+ logger.error(
+ "Failed to process document %s for summary regeneration: %s",
+ dataset_document.id,
+ str(e),
+ exc_info=True,
+ )
+ continue
+
+ end_at = time.perf_counter()
+ if regenerate_vectors_only:
+ logger.info(
+ click.style(
+ f"Summary re-vectorization completed for dataset {dataset_id}: "
+ f"{total_segments_processed} segments processed successfully, "
+ f"{total_segments_failed} segments failed, "
+ f"latency: {end_at - start_at:.2f}s",
+ fg="green",
+ )
+ )
+ else:
+ logger.info(
+ click.style(
+ f"Summary index regeneration completed for dataset {dataset_id}: "
+ f"{total_segments_processed} segments processed successfully, "
+ f"{total_segments_failed} segments failed, "
+ f"latency: {end_at - start_at:.2f}s",
+ fg="green",
+ )
+ )
+
+ except Exception:
+ logger.exception("Regenerate summary index failed for dataset %s", dataset_id)
diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py
index 3227f6da96..817249845a 100644
--- a/api/tasks/remove_app_and_related_data_task.py
+++ b/api/tasks/remove_app_and_related_data_task.py
@@ -1,15 +1,20 @@
import logging
import time
from collections.abc import Callable
+from typing import Any, cast
import click
import sqlalchemy as sa
from celery import shared_task
from sqlalchemy import delete
+from sqlalchemy.engine import CursorResult
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker
+from configs import dify_config
+from core.db.session_factory import session_factory
from extensions.ext_database import db
+from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage
from models import (
ApiToken,
AppAnnotationHitHistory,
@@ -40,6 +45,7 @@ from models.workflow import (
ConversationVariable,
Workflow,
WorkflowAppLog,
+ WorkflowArchiveLog,
)
from repositories.factory import DifyAPIRepositoryFactory
@@ -64,6 +70,9 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
_delete_app_workflow_runs(tenant_id, app_id)
_delete_app_workflow_node_executions(tenant_id, app_id)
_delete_app_workflow_app_logs(tenant_id, app_id)
+ if dify_config.BILLING_ENABLED and dify_config.ARCHIVE_STORAGE_ENABLED:
+ _delete_app_workflow_archive_logs(tenant_id, app_id)
+ _delete_archived_workflow_run_files(tenant_id, app_id)
_delete_app_conversations(tenant_id, app_id)
_delete_app_messages(tenant_id, app_id)
_delete_workflow_tool_providers(tenant_id, app_id)
@@ -77,7 +86,6 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
_delete_workflow_webhook_triggers(tenant_id, app_id)
_delete_workflow_schedule_plans(tenant_id, app_id)
_delete_workflow_trigger_logs(tenant_id, app_id)
-
end_at = time.perf_counter()
logger.info(click.style(f"App and related data deleted: {app_id} latency: {end_at - start_at}", fg="green"))
except SQLAlchemyError as e:
@@ -89,8 +97,8 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
def _delete_app_model_configs(tenant_id: str, app_id: str):
- def del_model_config(model_config_id: str):
- db.session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False)
+ def del_model_config(session, model_config_id: str):
+ session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False)
_delete_records(
"""select id from app_model_configs where app_id=:app_id limit 1000""",
@@ -101,8 +109,8 @@ def _delete_app_model_configs(tenant_id: str, app_id: str):
def _delete_app_site(tenant_id: str, app_id: str):
- def del_site(site_id: str):
- db.session.query(Site).where(Site.id == site_id).delete(synchronize_session=False)
+ def del_site(session, site_id: str):
+ session.query(Site).where(Site.id == site_id).delete(synchronize_session=False)
_delete_records(
"""select id from sites where app_id=:app_id limit 1000""",
@@ -113,8 +121,8 @@ def _delete_app_site(tenant_id: str, app_id: str):
def _delete_app_mcp_servers(tenant_id: str, app_id: str):
- def del_mcp_server(mcp_server_id: str):
- db.session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False)
+ def del_mcp_server(session, mcp_server_id: str):
+ session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False)
_delete_records(
"""select id from app_mcp_servers where app_id=:app_id limit 1000""",
@@ -125,8 +133,8 @@ def _delete_app_mcp_servers(tenant_id: str, app_id: str):
def _delete_app_api_tokens(tenant_id: str, app_id: str):
- def del_api_token(api_token_id: str):
- db.session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False)
+ def del_api_token(session, api_token_id: str):
+ session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False)
_delete_records(
"""select id from api_tokens where app_id=:app_id limit 1000""",
@@ -137,8 +145,8 @@ def _delete_app_api_tokens(tenant_id: str, app_id: str):
def _delete_installed_apps(tenant_id: str, app_id: str):
- def del_installed_app(installed_app_id: str):
- db.session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False)
+ def del_installed_app(session, installed_app_id: str):
+ session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False)
_delete_records(
"""select id from installed_apps where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@@ -149,10 +157,8 @@ def _delete_installed_apps(tenant_id: str, app_id: str):
def _delete_recommended_apps(tenant_id: str, app_id: str):
- def del_recommended_app(recommended_app_id: str):
- db.session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(
- synchronize_session=False
- )
+ def del_recommended_app(session, recommended_app_id: str):
+ session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(synchronize_session=False)
_delete_records(
"""select id from recommended_apps where app_id=:app_id limit 1000""",
@@ -163,8 +169,8 @@ def _delete_recommended_apps(tenant_id: str, app_id: str):
def _delete_app_annotation_data(tenant_id: str, app_id: str):
- def del_annotation_hit_history(annotation_hit_history_id: str):
- db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete(
+ def del_annotation_hit_history(session, annotation_hit_history_id: str):
+ session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete(
synchronize_session=False
)
@@ -175,8 +181,8 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
"annotation hit history",
)
- def del_annotation_setting(annotation_setting_id: str):
- db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete(
+ def del_annotation_setting(session, annotation_setting_id: str):
+ session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete(
synchronize_session=False
)
@@ -189,8 +195,8 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
def _delete_app_dataset_joins(tenant_id: str, app_id: str):
- def del_dataset_join(dataset_join_id: str):
- db.session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False)
+ def del_dataset_join(session, dataset_join_id: str):
+ session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False)
_delete_records(
"""select id from app_dataset_joins where app_id=:app_id limit 1000""",
@@ -201,8 +207,8 @@ def _delete_app_dataset_joins(tenant_id: str, app_id: str):
def _delete_app_workflows(tenant_id: str, app_id: str):
- def del_workflow(workflow_id: str):
- db.session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False)
+ def del_workflow(session, workflow_id: str):
+ session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False)
_delete_records(
"""select id from workflows where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@@ -241,10 +247,8 @@ def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
- def del_workflow_app_log(workflow_app_log_id: str):
- db.session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(
- synchronize_session=False
- )
+ def del_workflow_app_log(session, workflow_app_log_id: str):
+ session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(synchronize_session=False)
_delete_records(
"""select id from workflow_app_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@@ -254,12 +258,51 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
)
-def _delete_app_conversations(tenant_id: str, app_id: str):
- def del_conversation(conversation_id: str):
- db.session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
+def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str):
+ def del_workflow_archive_log(workflow_archive_log_id: str):
+ db.session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete(
synchronize_session=False
)
- db.session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False)
+
+ _delete_records(
+ """select id from workflow_archive_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
+ {"tenant_id": tenant_id, "app_id": app_id},
+ del_workflow_archive_log,
+ "workflow archive log",
+ )
+
+
+def _delete_archived_workflow_run_files(tenant_id: str, app_id: str):
+ prefix = f"{tenant_id}/app_id={app_id}/"
+ try:
+ archive_storage = get_archive_storage()
+ except ArchiveStorageNotConfiguredError as e:
+ logger.info("Archive storage not configured, skipping archive file cleanup: %s", e)
+ return
+
+ try:
+ keys = archive_storage.list_objects(prefix)
+ except Exception:
+ logger.exception("Failed to list archive files for app %s", app_id)
+ return
+
+ deleted = 0
+ for key in keys:
+ try:
+ archive_storage.delete_object(key)
+ deleted += 1
+ except Exception:
+ logger.exception("Failed to delete archive object %s", key)
+
+ logger.info("Deleted %s archive objects for app %s", deleted, app_id)
+
+
+def _delete_app_conversations(tenant_id: str, app_id: str):
+ def del_conversation(session, conversation_id: str):
+ session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
+ synchronize_session=False
+ )
+ session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False)
_delete_records(
"""select id from conversations where app_id=:app_id limit 1000""",
@@ -270,28 +313,26 @@ def _delete_app_conversations(tenant_id: str, app_id: str):
def _delete_conversation_variables(*, app_id: str):
- stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id)
- with db.engine.connect() as conn:
- conn.execute(stmt)
- conn.commit()
+ with session_factory.create_session() as session:
+ stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id)
+ session.execute(stmt)
+ session.commit()
logger.info(click.style(f"Deleted conversation variables for app {app_id}", fg="green"))
def _delete_app_messages(tenant_id: str, app_id: str):
- def del_message(message_id: str):
- db.session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(
+ def del_message(session, message_id: str):
+ session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(synchronize_session=False)
+ session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete(
synchronize_session=False
)
- db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete(
+ session.query(MessageChain).where(MessageChain.message_id == message_id).delete(synchronize_session=False)
+ session.query(MessageAgentThought).where(MessageAgentThought.message_id == message_id).delete(
synchronize_session=False
)
- db.session.query(MessageChain).where(MessageChain.message_id == message_id).delete(synchronize_session=False)
- db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message_id).delete(
- synchronize_session=False
- )
- db.session.query(MessageFile).where(MessageFile.message_id == message_id).delete(synchronize_session=False)
- db.session.query(SavedMessage).where(SavedMessage.message_id == message_id).delete(synchronize_session=False)
- db.session.query(Message).where(Message.id == message_id).delete()
+ session.query(MessageFile).where(MessageFile.message_id == message_id).delete(synchronize_session=False)
+ session.query(SavedMessage).where(SavedMessage.message_id == message_id).delete(synchronize_session=False)
+ session.query(Message).where(Message.id == message_id).delete()
_delete_records(
"""select id from messages where app_id=:app_id limit 1000""",
@@ -302,8 +343,8 @@ def _delete_app_messages(tenant_id: str, app_id: str):
def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
- def del_tool_provider(tool_provider_id: str):
- db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete(
+ def del_tool_provider(session, tool_provider_id: str):
+ session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete(
synchronize_session=False
)
@@ -316,8 +357,8 @@ def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
def _delete_app_tag_bindings(tenant_id: str, app_id: str):
- def del_tag_binding(tag_binding_id: str):
- db.session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False)
+ def del_tag_binding(session, tag_binding_id: str):
+ session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False)
_delete_records(
"""select id from tag_bindings where tenant_id=:tenant_id and target_id=:app_id limit 1000""",
@@ -328,8 +369,8 @@ def _delete_app_tag_bindings(tenant_id: str, app_id: str):
def _delete_end_users(tenant_id: str, app_id: str):
- def del_end_user(end_user_id: str):
- db.session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False)
+ def del_end_user(session, end_user_id: str):
+ session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False)
_delete_records(
"""select id from end_users where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@@ -340,10 +381,8 @@ def _delete_end_users(tenant_id: str, app_id: str):
def _delete_trace_app_configs(tenant_id: str, app_id: str):
- def del_trace_app_config(trace_app_config_id: str):
- db.session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(
- synchronize_session=False
- )
+ def del_trace_app_config(session, trace_app_config_id: str):
+ session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(synchronize_session=False)
_delete_records(
"""select id from trace_app_config where app_id=:app_id limit 1000""",
@@ -381,14 +420,14 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
total_files_deleted = 0
while True:
- with db.engine.begin() as conn:
+ with session_factory.create_session() as session:
# Get a batch of draft variable IDs along with their file_ids
query_sql = """
SELECT id, file_id FROM workflow_draft_variables
WHERE app_id = :app_id
LIMIT :batch_size
"""
- result = conn.execute(sa.text(query_sql), {"app_id": app_id, "batch_size": batch_size})
+ result = session.execute(sa.text(query_sql), {"app_id": app_id, "batch_size": batch_size})
rows = list(result)
if not rows:
@@ -399,7 +438,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
# Clean up associated Offload data first
if file_ids:
- files_deleted = _delete_draft_variable_offload_data(conn, file_ids)
+ files_deleted = _delete_draft_variable_offload_data(session, file_ids)
total_files_deleted += files_deleted
# Delete the draft variables
@@ -407,8 +446,11 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
DELETE FROM workflow_draft_variables
WHERE id IN :ids
"""
- deleted_result = conn.execute(sa.text(delete_sql), {"ids": tuple(draft_var_ids)})
- batch_deleted = deleted_result.rowcount
+ deleted_result = cast(
+ CursorResult[Any],
+ session.execute(sa.text(delete_sql), {"ids": tuple(draft_var_ids)}),
+ )
+ batch_deleted: int = int(getattr(deleted_result, "rowcount", 0) or 0)
total_deleted += batch_deleted
logger.info(click.style(f"Deleted {batch_deleted} draft variables (batch) for app {app_id}", fg="green"))
@@ -423,7 +465,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
return total_deleted
-def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
+def _delete_draft_variable_offload_data(session, file_ids: list[str]) -> int:
"""
Delete Offload data associated with WorkflowDraftVariable file_ids.
@@ -434,7 +476,7 @@ def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
4. Deletes WorkflowDraftVariableFile records
Args:
- conn: Database connection
+ session: Database session
file_ids: List of WorkflowDraftVariableFile IDs
Returns:
@@ -450,12 +492,12 @@ def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
try:
# Get WorkflowDraftVariableFile records and their associated UploadFile keys
query_sql = """
- SELECT wdvf.id, uf.key, uf.id as upload_file_id
- FROM workflow_draft_variable_files wdvf
- JOIN upload_files uf ON wdvf.upload_file_id = uf.id
- WHERE wdvf.id IN :file_ids
- """
- result = conn.execute(sa.text(query_sql), {"file_ids": tuple(file_ids)})
+ SELECT wdvf.id, uf.key, uf.id as upload_file_id
+ FROM workflow_draft_variable_files wdvf
+ JOIN upload_files uf ON wdvf.upload_file_id = uf.id
+ WHERE wdvf.id IN :file_ids
+ """
+ result = session.execute(sa.text(query_sql), {"file_ids": tuple(file_ids)})
file_records = list(result)
# Delete from object storage and collect upload file IDs
@@ -473,17 +515,19 @@ def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
# Delete UploadFile records
if upload_file_ids:
delete_upload_files_sql = """
- DELETE FROM upload_files
- WHERE id IN :upload_file_ids
- """
- conn.execute(sa.text(delete_upload_files_sql), {"upload_file_ids": tuple(upload_file_ids)})
+ DELETE FROM upload_files
+ WHERE id IN :upload_file_ids
+ """
+ session.execute(sa.text(delete_upload_files_sql), {"upload_file_ids": tuple(upload_file_ids)})
# Delete WorkflowDraftVariableFile records
delete_variable_files_sql = """
- DELETE FROM workflow_draft_variable_files
- WHERE id IN :file_ids
- """
- conn.execute(sa.text(delete_variable_files_sql), {"file_ids": tuple(file_ids)})
+ DELETE FROM workflow_draft_variable_files
+ WHERE id IN :file_ids
+ """
+ session.execute(sa.text(delete_variable_files_sql), {"file_ids": tuple(file_ids)})
except Exception:
logging.exception("Error deleting draft variable offload data:")
@@ -493,8 +537,8 @@ def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
def _delete_app_triggers(tenant_id: str, app_id: str):
- def del_app_trigger(trigger_id: str):
- db.session.query(AppTrigger).where(AppTrigger.id == trigger_id).delete(synchronize_session=False)
+ def del_app_trigger(session, trigger_id: str):
+ session.query(AppTrigger).where(AppTrigger.id == trigger_id).delete(synchronize_session=False)
_delete_records(
"""select id from app_triggers where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@@ -505,8 +549,8 @@ def _delete_app_triggers(tenant_id: str, app_id: str):
def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str):
- def del_plugin_trigger(trigger_id: str):
- db.session.query(WorkflowPluginTrigger).where(WorkflowPluginTrigger.id == trigger_id).delete(
+ def del_plugin_trigger(session, trigger_id: str):
+ session.query(WorkflowPluginTrigger).where(WorkflowPluginTrigger.id == trigger_id).delete(
synchronize_session=False
)
@@ -519,8 +563,8 @@ def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str):
def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str):
- def del_webhook_trigger(trigger_id: str):
- db.session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.id == trigger_id).delete(
+ def del_webhook_trigger(session, trigger_id: str):
+ session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.id == trigger_id).delete(
synchronize_session=False
)
@@ -533,10 +577,8 @@ def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str):
def _delete_workflow_schedule_plans(tenant_id: str, app_id: str):
- def del_schedule_plan(plan_id: str):
- db.session.query(WorkflowSchedulePlan).where(WorkflowSchedulePlan.id == plan_id).delete(
- synchronize_session=False
- )
+ def del_schedule_plan(session, plan_id: str):
+ session.query(WorkflowSchedulePlan).where(WorkflowSchedulePlan.id == plan_id).delete(synchronize_session=False)
_delete_records(
"""select id from workflow_schedule_plans where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@@ -547,8 +589,8 @@ def _delete_workflow_schedule_plans(tenant_id: str, app_id: str):
def _delete_workflow_trigger_logs(tenant_id: str, app_id: str):
- def del_trigger_log(log_id: str):
- db.session.query(WorkflowTriggerLog).where(WorkflowTriggerLog.id == log_id).delete(synchronize_session=False)
+ def del_trigger_log(session, log_id: str):
+ session.query(WorkflowTriggerLog).where(WorkflowTriggerLog.id == log_id).delete(synchronize_session=False)
_delete_records(
"""select id from workflow_trigger_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@@ -560,18 +602,22 @@ def _delete_workflow_trigger_logs(tenant_id: str, app_id: str):
def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None:
while True:
- with db.engine.begin() as conn:
- rs = conn.execute(sa.text(query_sql), params)
- if rs.rowcount == 0:
+ with session_factory.create_session() as session:
+ rs = session.execute(sa.text(query_sql), params)
+ rows = rs.fetchall()
+ if not rows:
break
- for i in rs:
+ for i in rows:
record_id = str(i.id)
try:
- delete_func(record_id)
- db.session.commit()
+ delete_func(session, record_id)
logger.info(click.style(f"Deleted {name} {record_id}", fg="green"))
except Exception:
logger.exception("Error occurred while deleting %s %s", name, record_id)
- continue
+ # continue with next record even if one deletion fails
+ session.rollback()
+ break
+ session.commit()
+
rs.close()
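The key change in `_delete_records` is that each per-record callback now receives the active session as its first argument and no longer commits itself; the loop commits once per fetched batch and rolls the batch back if a deletion raises. A condensed sketch of the refactored loop shape, for reading alongside the hunks above:

```python
import logging

import sqlalchemy as sa

from core.db.session_factory import session_factory

logger = logging.getLogger(__name__)


def _delete_records_sketch(query_sql: str, params: dict, delete_func, name: str) -> None:
    # Condensed shape: fetch a batch of ids, hand the session to the callback for each id,
    # commit once per batch, and roll back the batch if any deletion fails (the outer
    # while loop then retries with a fresh query).
    while True:
        with session_factory.create_session() as session:
            rows = session.execute(sa.text(query_sql), params).fetchall()
            if not rows:
                break
            for row in rows:
                try:
                    delete_func(session, str(row.id))
                except Exception:
                    logger.exception("Error occurred while deleting %s %s", name, row.id)
                    session.rollback()
                    break
            session.commit()
```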
diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py
index c0ab2d0b41..55259ab527 100644
--- a/api/tasks/remove_document_from_index_task.py
+++ b/api/tasks/remove_document_from_index_task.py
@@ -5,8 +5,8 @@ import click
from celery import shared_task
from sqlalchemy import select
+from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Document, DocumentSegment
@@ -25,52 +25,70 @@ def remove_document_from_index_task(document_id: str):
logger.info(click.style(f"Start remove document segments from index: {document_id}", fg="green"))
start_at = time.perf_counter()
- document = db.session.query(Document).where(Document.id == document_id).first()
- if not document:
- logger.info(click.style(f"Document not found: {document_id}", fg="red"))
- db.session.close()
- return
+ with session_factory.create_session() as session:
+ document = session.query(Document).where(Document.id == document_id).first()
+ if not document:
+ logger.info(click.style(f"Document not found: {document_id}", fg="red"))
+ return
- if document.indexing_status != "completed":
- logger.info(click.style(f"Document is not completed, remove is not allowed: {document_id}", fg="red"))
- db.session.close()
- return
+ if document.indexing_status != "completed":
+ logger.info(click.style(f"Document is not completed, remove is not allowed: {document_id}", fg="red"))
+ return
- indexing_cache_key = f"document_{document.id}_indexing"
+ indexing_cache_key = f"document_{document.id}_indexing"
- try:
- dataset = document.dataset
+ try:
+ dataset = document.dataset
- if not dataset:
- raise Exception("Document has no dataset")
+ if not dataset:
+ raise Exception("Document has no dataset")
- index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
+ index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
- segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all()
- index_node_ids = [segment.index_node_id for segment in segments]
- if index_node_ids:
- try:
- index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
- except Exception:
- logger.exception("clean dataset %s from index failed", dataset.id)
- # update segment to disable
- db.session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).update(
- {
- DocumentSegment.enabled: False,
- DocumentSegment.disabled_at: naive_utc_now(),
- DocumentSegment.disabled_by: document.disabled_by,
- DocumentSegment.updated_at: naive_utc_now(),
- }
- )
- db.session.commit()
+ segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all()
- end_at = time.perf_counter()
- logger.info(click.style(f"Document removed from index: {document.id} latency: {end_at - start_at}", fg="green"))
- except Exception:
- logger.exception("remove document from index failed")
- if not document.archived:
- document.enabled = True
- db.session.commit()
- finally:
- redis_client.delete(indexing_cache_key)
- db.session.close()
+ # Disable summary indexes for all segments in this document
+ from services.summary_index_service import SummaryIndexService
+
+ segment_ids_list = [segment.id for segment in segments]
+ if segment_ids_list:
+ try:
+ SummaryIndexService.disable_summaries_for_segments(
+ dataset=dataset,
+ segment_ids=segment_ids_list,
+ disabled_by=document.disabled_by,
+ )
+ except Exception as e:
+ logger.warning("Failed to disable summaries for document %s: %s", document.id, str(e))
+
+ index_node_ids = [segment.index_node_id for segment in segments]
+ if index_node_ids:
+ try:
+ index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
+ except Exception:
+ logger.exception("clean dataset %s from index failed", dataset.id)
+ # update segment to disable
+ session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).update(
+ {
+ DocumentSegment.enabled: False,
+ DocumentSegment.disabled_at: naive_utc_now(),
+ DocumentSegment.disabled_by: document.disabled_by,
+ DocumentSegment.updated_at: naive_utc_now(),
+ }
+ )
+ session.commit()
+
+ end_at = time.perf_counter()
+ logger.info(
+ click.style(
+ f"Document removed from index: {document.id} latency: {end_at - start_at}",
+ fg="green",
+ )
+ )
+ except Exception:
+ logger.exception("remove document from index failed")
+ if not document.archived:
+ document.enabled = True
+ session.commit()
+ finally:
+ redis_client.delete(indexing_cache_key)
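The structural point of this rewrite is that the whole task body now runs inside a context-managed session, so the old `finally: db.session.close()` disappears; the `with` block owns the session lifecycle and early returns simply leave it. A compressed sketch of the resulting shape, assuming `session_factory.create_session()` yields a context-managed SQLAlchemy `Session` (the `...` stands in for the index-cleanup body shown in the diff):

```python
def remove_document_from_index(document_id: str) -> None:
    # Simplified shape of the refactored task (illustrative, not the full body).
    with session_factory.create_session() as session:
        document = session.query(Document).where(Document.id == document_id).first()
        if not document:
            return  # early exits just leave the with-block; no manual close needed

        indexing_cache_key = f"document_{document.id}_indexing"
        try:
            ...  # clean the vector index and disable segments via `session`
        except Exception:
            logger.exception("remove document from index failed")
            if not document.archived:
                document.enabled = True
                session.commit()
        finally:
            redis_client.delete(indexing_cache_key)
        # the session is closed automatically when the with-block exits
```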
diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py
index 9d208647e6..f20b15ac83 100644
--- a/api/tasks/retry_document_indexing_task.py
+++ b/api/tasks/retry_document_indexing_task.py
@@ -3,11 +3,11 @@ import time
import click
from celery import shared_task
-from sqlalchemy import select
+from sqlalchemy import delete, select
+from core.db.session_factory import session_factory
from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models import Account, Tenant
@@ -29,97 +29,97 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_
Usage: retry_document_indexing_task.delay(dataset_id, document_ids, user_id)
"""
start_at = time.perf_counter()
- try:
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
- if not dataset:
- logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
- return
- user = db.session.query(Account).where(Account.id == user_id).first()
- if not user:
- logger.info(click.style(f"User not found: {user_id}", fg="red"))
- return
- tenant = db.session.query(Tenant).where(Tenant.id == dataset.tenant_id).first()
- if not tenant:
- raise ValueError("Tenant not found")
- user.current_tenant = tenant
+ with session_factory.create_session() as session:
+ try:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+ if not dataset:
+ logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
+ return
+ user = session.query(Account).where(Account.id == user_id).first()
+ if not user:
+ logger.info(click.style(f"User not found: {user_id}", fg="red"))
+ return
+ tenant = session.query(Tenant).where(Tenant.id == dataset.tenant_id).first()
+ if not tenant:
+ raise ValueError("Tenant not found")
+ user.current_tenant = tenant
- for document_id in document_ids:
- retry_indexing_cache_key = f"document_{document_id}_is_retried"
- # check document limit
- features = FeatureService.get_features(tenant.id)
- try:
- if features.billing.enabled:
- vector_space = features.vector_space
- if 0 < vector_space.limit <= vector_space.size:
- raise ValueError(
- "Your total number of documents plus the number of uploads have over the limit of "
- "your subscription."
- )
- except Exception as e:
+ for document_id in document_ids:
+ retry_indexing_cache_key = f"document_{document_id}_is_retried"
+ # check document limit
+ features = FeatureService.get_features(tenant.id)
+ try:
+ if features.billing.enabled:
+ vector_space = features.vector_space
+ if 0 < vector_space.limit <= vector_space.size:
+ raise ValueError(
+ "Your total number of documents plus the number of uploads have over the limit of "
+ "your subscription."
+ )
+ except Exception as e:
+ document = (
+ session.query(Document)
+ .where(Document.id == document_id, Document.dataset_id == dataset_id)
+ .first()
+ )
+ if document:
+ document.indexing_status = "error"
+ document.error = str(e)
+ document.stopped_at = naive_utc_now()
+ session.add(document)
+ session.commit()
+ redis_client.delete(retry_indexing_cache_key)
+ return
+
+ logger.info(click.style(f"Start retry document: {document_id}", fg="green"))
document = (
- db.session.query(Document)
- .where(Document.id == document_id, Document.dataset_id == dataset_id)
- .first()
+ session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
- if document:
+ if not document:
+ logger.info(click.style(f"Document not found: {document_id}", fg="yellow"))
+ return
+ try:
+ # clean old data
+ index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
+
+ segments = session.scalars(
+ select(DocumentSegment).where(DocumentSegment.document_id == document_id)
+ ).all()
+ if segments:
+ index_node_ids = [segment.index_node_id for segment in segments]
+ # delete from vector index
+ index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+
+ segment_ids = [segment.id for segment in segments]
+ segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
+ session.execute(segment_delete_stmt)
+ session.commit()
+
+ document.indexing_status = "parsing"
+ document.processing_started_at = naive_utc_now()
+ session.add(document)
+ session.commit()
+
+ if dataset.runtime_mode == "rag_pipeline":
+ rag_pipeline_service = RagPipelineService()
+ rag_pipeline_service.retry_error_document(dataset, document, user)
+ else:
+ indexing_runner = IndexingRunner()
+ indexing_runner.run([document])
+ redis_client.delete(retry_indexing_cache_key)
+ except Exception as ex:
document.indexing_status = "error"
- document.error = str(e)
+ document.error = str(ex)
document.stopped_at = naive_utc_now()
- db.session.add(document)
- db.session.commit()
- redis_client.delete(retry_indexing_cache_key)
- return
-
- logger.info(click.style(f"Start retry document: {document_id}", fg="green"))
- document = (
- db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ session.add(document)
+ session.commit()
+ logger.info(click.style(str(ex), fg="yellow"))
+ redis_client.delete(retry_indexing_cache_key)
+ logger.exception("retry_document_indexing_task failed, document_id: %s", document_id)
+ end_at = time.perf_counter()
+ logger.info(click.style(f"Retry dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
+ except Exception as e:
+ logger.exception(
+ "retry_document_indexing_task failed, dataset_id: %s, document_ids: %s", dataset_id, document_ids
)
- if not document:
- logger.info(click.style(f"Document not found: {document_id}", fg="yellow"))
- return
- try:
- # clean old data
- index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
-
- segments = db.session.scalars(
- select(DocumentSegment).where(DocumentSegment.document_id == document_id)
- ).all()
- if segments:
- index_node_ids = [segment.index_node_id for segment in segments]
- # delete from vector index
- index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
-
- for segment in segments:
- db.session.delete(segment)
- db.session.commit()
-
- document.indexing_status = "parsing"
- document.processing_started_at = naive_utc_now()
- db.session.add(document)
- db.session.commit()
-
- if dataset.runtime_mode == "rag_pipeline":
- rag_pipeline_service = RagPipelineService()
- rag_pipeline_service.retry_error_document(dataset, document, user)
- else:
- indexing_runner = IndexingRunner()
- indexing_runner.run([document])
- redis_client.delete(retry_indexing_cache_key)
- except Exception as ex:
- document.indexing_status = "error"
- document.error = str(ex)
- document.stopped_at = naive_utc_now()
- db.session.add(document)
- db.session.commit()
- logger.info(click.style(str(ex), fg="yellow"))
- redis_client.delete(retry_indexing_cache_key)
- logger.exception("retry_document_indexing_task failed, document_id: %s", document_id)
- end_at = time.perf_counter()
- logger.info(click.style(f"Retry dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
- except Exception as e:
- logger.exception(
- "retry_document_indexing_task failed, dataset_id: %s, document_ids: %s", dataset_id, document_ids
- )
- raise e
- finally:
- db.session.close()
+ raise e
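One behavioural detail worth calling out in the hunk above: the old per-row `db.session.delete(segment)` loop becomes a single bulk `DELETE ... WHERE id IN (...)`. That issues one SQL statement instead of one per segment, and, as with any SQLAlchemy bulk delete, it bypasses ORM-level cascades and does not synchronize already-loaded objects in the session. A before/after sketch using the names from the task:

```python
from sqlalchemy import delete

# Before: one ORM delete per segment, tracked object by object.
# for segment in segments:
#     session.delete(segment)

# After: a single bulk DELETE covering all of the document's segments.
segment_ids = [segment.id for segment in segments]
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
session.execute(segment_delete_stmt)
session.commit()
```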
diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py
index 0dc1d841f4..f1c8c56995 100644
--- a/api/tasks/sync_website_document_indexing_task.py
+++ b/api/tasks/sync_website_document_indexing_task.py
@@ -3,11 +3,11 @@ import time
import click
from celery import shared_task
-from sqlalchemy import select
+from sqlalchemy import delete, select
+from core.db.session_factory import session_factory
from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
@@ -27,69 +27,71 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
"""
start_at = time.perf_counter()
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
- if dataset is None:
- raise ValueError("Dataset not found")
+ with session_factory.create_session() as session:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+ if dataset is None:
+ raise ValueError("Dataset not found")
- sync_indexing_cache_key = f"document_{document_id}_is_sync"
- # check document limit
- features = FeatureService.get_features(dataset.tenant_id)
- try:
- if features.billing.enabled:
- vector_space = features.vector_space
- if 0 < vector_space.limit <= vector_space.size:
- raise ValueError(
- "Your total number of documents plus the number of uploads have over the limit of "
- "your subscription."
- )
- except Exception as e:
- document = (
- db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
- )
- if document:
+ sync_indexing_cache_key = f"document_{document_id}_is_sync"
+ # check document limit
+ features = FeatureService.get_features(dataset.tenant_id)
+ try:
+ if features.billing.enabled:
+ vector_space = features.vector_space
+ if 0 < vector_space.limit <= vector_space.size:
+ raise ValueError(
+ "Your total number of documents plus the number of uploads have over the limit of "
+ "your subscription."
+ )
+ except Exception as e:
+ document = (
+ session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ )
+ if document:
+ document.indexing_status = "error"
+ document.error = str(e)
+ document.stopped_at = naive_utc_now()
+ session.add(document)
+ session.commit()
+ redis_client.delete(sync_indexing_cache_key)
+ return
+
+ logger.info(click.style(f"Start sync website document: {document_id}", fg="green"))
+ document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ if not document:
+ logger.info(click.style(f"Document not found: {document_id}", fg="yellow"))
+ return
+ try:
+ # clean old data
+ index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
+
+ segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
+ if segments:
+ index_node_ids = [segment.index_node_id for segment in segments]
+ # delete from vector index
+ index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+
+ segment_ids = [segment.id for segment in segments]
+ segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
+ session.execute(segment_delete_stmt)
+ session.commit()
+
+ document.indexing_status = "parsing"
+ document.processing_started_at = naive_utc_now()
+ session.add(document)
+ session.commit()
+
+ indexing_runner = IndexingRunner()
+ indexing_runner.run([document])
+ redis_client.delete(sync_indexing_cache_key)
+ except Exception as ex:
document.indexing_status = "error"
- document.error = str(e)
+ document.error = str(ex)
document.stopped_at = naive_utc_now()
- db.session.add(document)
- db.session.commit()
- redis_client.delete(sync_indexing_cache_key)
- return
-
- logger.info(click.style(f"Start sync website document: {document_id}", fg="green"))
- document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
- if not document:
- logger.info(click.style(f"Document not found: {document_id}", fg="yellow"))
- return
- try:
- # clean old data
- index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
-
- segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
- if segments:
- index_node_ids = [segment.index_node_id for segment in segments]
- # delete from vector index
- index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
-
- for segment in segments:
- db.session.delete(segment)
- db.session.commit()
-
- document.indexing_status = "parsing"
- document.processing_started_at = naive_utc_now()
- db.session.add(document)
- db.session.commit()
-
- indexing_runner = IndexingRunner()
- indexing_runner.run([document])
- redis_client.delete(sync_indexing_cache_key)
- except Exception as ex:
- document.indexing_status = "error"
- document.error = str(ex)
- document.stopped_at = naive_utc_now()
- db.session.add(document)
- db.session.commit()
- logger.info(click.style(str(ex), fg="yellow"))
- redis_client.delete(sync_indexing_cache_key)
- logger.exception("sync_website_document_indexing_task failed, document_id: %s", document_id)
- end_at = time.perf_counter()
- logger.info(click.style(f"Sync document: {document_id} latency: {end_at - start_at}", fg="green"))
+ session.add(document)
+ session.commit()
+ logger.info(click.style(str(ex), fg="yellow"))
+ redis_client.delete(sync_indexing_cache_key)
+ logger.exception("sync_website_document_indexing_task failed, document_id: %s", document_id)
+ end_at = time.perf_counter()
+ logger.info(click.style(f"Sync document: {document_id} latency: {end_at - start_at}", fg="green"))
diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py
index ee1d31aa91..d18ea2c23c 100644
--- a/api/tasks/trigger_processing_tasks.py
+++ b/api/tasks/trigger_processing_tasks.py
@@ -16,6 +16,7 @@ from sqlalchemy import func, select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
+from core.db.session_factory import session_factory
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.entities.request import TriggerInvokeEventResponse
from core.plugin.impl.exc import PluginInvokeError
@@ -27,7 +28,6 @@ from core.trigger.trigger_manager import TriggerManager
from core.workflow.enums import NodeType, WorkflowExecutionStatus
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
from enums.quota_type import QuotaType, unlimited
-from extensions.ext_database import db
from models.enums import (
AppTriggerType,
CreatorUserRole,
@@ -257,7 +257,7 @@ def dispatch_triggered_workflow(
tenant_id=subscription.tenant_id, provider_id=TriggerProviderID(subscription.provider_id)
)
trigger_entity: TriggerProviderEntity = provider_controller.entity
- with Session(db.engine) as session:
+ with session_factory.create_session() as session:
workflows: Mapping[str, Workflow] = _get_latest_workflows_by_app_ids(session, subscribers)
end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch(
diff --git a/api/tasks/trigger_subscription_refresh_tasks.py b/api/tasks/trigger_subscription_refresh_tasks.py
index ed92f3f3c5..7698a1a6b8 100644
--- a/api/tasks/trigger_subscription_refresh_tasks.py
+++ b/api/tasks/trigger_subscription_refresh_tasks.py
@@ -7,9 +7,9 @@ from celery import shared_task
from sqlalchemy.orm import Session
from configs import dify_config
+from core.db.session_factory import session_factory
from core.plugin.entities.plugin_daemon import CredentialType
from core.trigger.utils.locks import build_trigger_refresh_lock_key
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.trigger import TriggerSubscription
from services.trigger.trigger_provider_service import TriggerProviderService
@@ -92,7 +92,7 @@ def trigger_subscription_refresh(tenant_id: str, subscription_id: str) -> None:
logger.info("Begin subscription refresh: tenant=%s id=%s", tenant_id, subscription_id)
try:
now: int = _now_ts()
- with Session(db.engine) as session:
+ with session_factory.create_session() as session:
subscription: TriggerSubscription | None = _load_subscription(session, tenant_id, subscription_id)
if not subscription:
diff --git a/api/tasks/workflow_execution_tasks.py b/api/tasks/workflow_execution_tasks.py
index 7d145fb50c..3b3c6e5313 100644
--- a/api/tasks/workflow_execution_tasks.py
+++ b/api/tasks/workflow_execution_tasks.py
@@ -10,11 +10,10 @@ import logging
from celery import shared_task
from sqlalchemy import select
-from sqlalchemy.orm import sessionmaker
+from core.db.session_factory import session_factory
from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
-from extensions.ext_database import db
from models import CreatorUserRole, WorkflowRun
from models.enums import WorkflowRunTriggeredFrom
@@ -46,10 +45,7 @@ def save_workflow_execution_task(
True if successful, False otherwise
"""
try:
- # Create a new session for this task
- session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
-
- with session_factory() as session:
+ with session_factory.create_session() as session:
# Deserialize execution data
execution = WorkflowExecution.model_validate(execution_data)
diff --git a/api/tasks/workflow_node_execution_tasks.py b/api/tasks/workflow_node_execution_tasks.py
index 8f5127670f..b30a4ff15b 100644
--- a/api/tasks/workflow_node_execution_tasks.py
+++ b/api/tasks/workflow_node_execution_tasks.py
@@ -10,13 +10,12 @@ import logging
from celery import shared_task
from sqlalchemy import select
-from sqlalchemy.orm import sessionmaker
+from core.db.session_factory import session_factory
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
)
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
-from extensions.ext_database import db
from models import CreatorUserRole, WorkflowNodeExecutionModel
from models.workflow import WorkflowNodeExecutionTriggeredFrom
@@ -48,10 +47,7 @@ def save_workflow_node_execution_task(
True if successful, False otherwise
"""
try:
- # Create a new session for this task
- session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
-
- with session_factory() as session:
+ with session_factory.create_session() as session:
# Deserialize execution data
execution = WorkflowNodeExecution.model_validate(execution_data)
diff --git a/api/tasks/workflow_schedule_tasks.py b/api/tasks/workflow_schedule_tasks.py
index f54e02a219..8c64d3ab27 100644
--- a/api/tasks/workflow_schedule_tasks.py
+++ b/api/tasks/workflow_schedule_tasks.py
@@ -1,15 +1,14 @@
import logging
from celery import shared_task
-from sqlalchemy.orm import sessionmaker
+from core.db.session_factory import session_factory
from core.workflow.nodes.trigger_schedule.exc import (
ScheduleExecutionError,
ScheduleNotFoundError,
TenantOwnerNotFoundError,
)
from enums.quota_type import QuotaType, unlimited
-from extensions.ext_database import db
from models.trigger import WorkflowSchedulePlan
from services.async_workflow_service import AsyncWorkflowService
from services.errors.app import QuotaExceededError
@@ -33,10 +32,7 @@ def run_schedule_trigger(schedule_id: str) -> None:
TenantOwnerNotFoundError: If no owner/admin for tenant
ScheduleExecutionError: If workflow trigger fails
"""
-
- session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
-
- with session_factory() as session:
+ with session_factory.create_session() as session:
schedule = session.get(WorkflowSchedulePlan, schedule_id)
if not schedule:
raise ScheduleNotFoundError(f"Schedule {schedule_id} not found")
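Taken together, the `workflow_*` task changes all make the same swap: instead of each task constructing a throwaway `sessionmaker(bind=db.engine, expire_on_commit=False)`, they use the shared factory from `core.db.session_factory`. A minimal before/after sketch, assuming `session_factory.create_session()` returns a context-managed SQLAlchemy `Session`:

```python
# Before: every task built its own factory bound to the global engine.
from sqlalchemy.orm import sessionmaker
from extensions.ext_database import db

legacy_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
with legacy_factory() as session:
    schedule = session.get(WorkflowSchedulePlan, schedule_id)

# After: one shared factory owns the engine binding and session defaults.
from core.db.session_factory import session_factory

with session_factory.create_session() as session:
    schedule = session.get(WorkflowSchedulePlan, schedule_id)
```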
diff --git a/api/templates/invite_member_mail_template_en-US.html b/api/templates/invite_member_mail_template_en-US.html
index a07c5f4b16..7b296519f0 100644
--- a/api/templates/invite_member_mail_template_en-US.html
+++ b/api/templates/invite_member_mail_template_en-US.html
@@ -83,7 +83,30 @@
Dear {{ to }},
{{ inviter_name }} is pleased to invite you to join our workspace on Dify, a platform specifically designed for LLM application development. On Dify, you can explore, create, and collaborate to build and operate AI applications.
Click the button below to log in to Dify and join the workspace.
diff --git a/api/templates/register_email_when_account_exist_template_en-US.html b/api/templates/register_email_when_account_exist_template_en-US.html
index ac5042c274..e2bb99c989 100644
--- a/api/templates/register_email_when_account_exist_template_en-US.html
+++ b/api/templates/register_email_when_account_exist_template_en-US.html
@@ -115,7 +115,30 @@
We noticed you tried to sign up, but this email is already registered with an existing account.
Please log in here:
+ If the button doesn't work, copy and paste this link into your browser:
+
+ {{ login_url }}
+
+
+
If you forgot your password, you can reset it here: Reset Password
diff --git a/api/templates/register_email_when_account_exist_template_zh-CN.html b/api/templates/register_email_when_account_exist_template_zh-CN.html
index 326b58343a..6a5bbd135b 100644
--- a/api/templates/register_email_when_account_exist_template_zh-CN.html
+++ b/api/templates/register_email_when_account_exist_template_zh-CN.html
@@ -115,7 +115,30 @@
我们注意到您尝试注册,但此电子邮件已注册。
请在此登录:
diff --git a/api/templates/without-brand/invite_member_mail_template_en-US.html b/api/templates/without-brand/invite_member_mail_template_en-US.html
index f9157284fa..687ece617a 100644
--- a/api/templates/without-brand/invite_member_mail_template_en-US.html
+++ b/api/templates/without-brand/invite_member_mail_template_en-US.html
@@ -92,12 +92,34 @@
platform specifically designed for LLM application development. On {{application_title}}, you can explore,
create, and collaborate to build and operate AI applications.
Click the button below to log in to {{application_title}} and join the workspace.