Merge remote-tracking branch 'upstream/main' into feat/replayable-stream

yunlu.wen 2026-04-21 16:05:26 +08:00
commit 91008c0329
2538 changed files with 46543 additions and 26039 deletions

View File

@ -200,7 +200,7 @@ When assigned to test a directory/path, test **ALL content** within that path:
- ✅ **Import real project components** directly (including base components and siblings)
- ✅ **Only mock**: API services (`@/service/*`), `next/navigation`, complex context providers
- ❌ **DO NOT mock** base components (`@/app/components/base/*`)
- ❌ **DO NOT mock** base components (`@/app/components/base/*`) or dify-ui primitives (`@langgenius/dify-ui/*`)
- ❌ **DO NOT mock** sibling/child components in the same directory
> See [Test Structure Template](#test-structure-template) for correct import/mock patterns.
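For example, a spec following these rules might look like this (a minimal sketch; `AppCard`, its `appId` prop, and `@/service/apps` are hypothetical names used only to show which modules get mocked, while base and dify-ui components stay real):
```typescript
import { render, screen } from '@testing-library/react'
import { describe, expect, it, vi } from 'vitest'

import AppCard from './index' // hypothetical component under test

// Only the API service and next/navigation are mocked.
vi.mock('@/service/apps', () => ({
  fetchAppDetail: vi.fn().mockResolvedValue({ id: '1', name: 'Demo' }),
}))
vi.mock('next/navigation', () => ({
  useRouter: () => ({ push: vi.fn() }),
}))

describe('AppCard', () => {
  it('renders with real base and dify-ui components', async () => {
    render(<AppCard appId="1" />)
    expect(await screen.findByText('Demo')).toBeInTheDocument()
  })
})
```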
@ -325,12 +325,12 @@ For more detailed information, refer to:
### Reference Examples in Codebase
- `web/utils/classnames.spec.ts` - Utility function tests
- `web/app/components/base/button/index.spec.tsx` - Component tests
- `web/app/components/base/radio/__tests__/index.spec.tsx` - Component tests
- `web/__mocks__/provider-context.ts` - Mock factory example
### Project Configuration
- `web/vitest.config.ts` - Vitest configuration
- `web/vite.config.ts` - Vite/Vitest configuration
- `web/vitest.setup.ts` - Test environment setup
- `web/scripts/analyze-component.js` - Component analysis tool
- Modules are not mocked automatically. Global mocks live in `web/vitest.setup.ts` (for example `react-i18next`, `next/image`); mock other modules like `ky` or `mime` locally in test files.
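A local mock for one of those modules might look like this (a sketch; the mocked shape of `ky` is an assumption and should mirror only what the spec actually calls):
```typescript
import { vi } from 'vitest'

// Per-file mock of `ky`: stub only the methods this spec touches.
vi.mock('ky', () => ({
  default: {
    get: vi.fn(() => ({ json: vi.fn().mockResolvedValue({ ok: true }) })),
    post: vi.fn(() => ({ json: vi.fn().mockResolvedValue({ ok: true }) })),
  },
}))
```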

View File

@ -36,7 +36,7 @@ Use this checklist when generating or reviewing tests for Dify frontend componen
### Integration vs Mocking
- [ ] **DO NOT mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
- [ ] **DO NOT mock base components or dify-ui primitives** (base `Loading`, `Input`, `Badge`; dify-ui `Button`, `Tooltip`, `Dialog`, etc.)
- [ ] Import real project components instead of mocking
- [ ] Only mock: API calls, complex context providers, third-party libs with side effects
- [ ] Prefer integration testing when using single spec file
@ -73,7 +73,7 @@ Use this checklist when generating or reviewing tests for Dify frontend componen
### Mocks
- [ ] **DO NOT mock base components** (`@/app/components/base/*`)
- [ ] **DO NOT mock base components or dify-ui primitives** (`@/app/components/base/*` or `@langgenius/dify-ui/*`)
- [ ] `vi.clearAllMocks()` in `beforeEach` (not `afterEach`)
- [ ] Shared mock state reset in `beforeEach`
- [ ] i18n uses global mock (auto-loaded in `web/vitest.setup.ts`); only override locally for custom translations
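A sketch of that reset pattern (the translation map in the local `react-i18next` override is an assumed shape for illustration only):
```typescript
import { beforeEach, vi } from 'vitest'

// Local override of the global i18n mock, used only when a spec asserts
// on specific translation strings.
vi.mock('react-i18next', () => ({
  useTranslation: () => ({
    t: (key: string) =>
      ({ 'app.create': 'Create App' } as Record<string, string>)[key] ?? key,
  }),
}))

beforeEach(() => {
  // Reset shared mock state before each test, not after.
  vi.clearAllMocks()
})
```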

View File

@ -2,29 +2,27 @@
## ⚠️ Important: What NOT to Mock
### DO NOT Mock Base Components
### DO NOT Mock Base Components or dify-ui Primitives
**Never mock components from `@/app/components/base/`** such as:
**Never mock components from `@/app/components/base/` or from `@langgenius/dify-ui/*`** such as:
- `Loading`, `Spinner`
- `Button`, `Input`, `Select`
- `Tooltip`, `Modal`, `Dropdown`
- `Icon`, `Badge`, `Tag`
- Legacy base (`@/app/components/base/*`): `Loading`, `Spinner`, `Input`, `Badge`, `Tag`
- dify-ui primitives (`@langgenius/dify-ui/*`): `Button`, `Tooltip`, `Dialog`, `Popover`, `DropdownMenu`, `ContextMenu`, `Select`, `AlertDialog`, `Toast`
**Why?**
- Base components will have their own dedicated tests
- These components have their own dedicated tests
- Mocking them creates false positives (tests pass but real integration fails)
- Using real components tests actual integration behavior
```typescript
// ❌ WRONG: Don't mock base components
// ❌ WRONG: Don't mock base components or dify-ui primitives
vi.mock('@/app/components/base/loading', () => () => <div>Loading</div>)
vi.mock('@/app/components/base/button', () => ({ children }: any) => <button>{children}</button>)
vi.mock('@langgenius/dify-ui/button', () => ({ Button: ({ children }: any) => <button>{children}</button> }))
// ✅ CORRECT: Import and use real base components
// ✅ CORRECT: Import and use the real components
import Loading from '@/app/components/base/loading'
import Button from '@/app/components/base/button'
import { Button } from '@langgenius/dify-ui/button'
// They will render normally in tests
```
@ -319,7 +317,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => {
### ✅ DO
1. **Use real base components** - Import from `@/app/components/base/` directly
1. **Use real base components and dify-ui primitives** - Import from `@/app/components/base/` or `@langgenius/dify-ui/*` directly
1. **Use real project components** - Prefer importing over mocking
1. **Use real Zustand stores** - Set test state via `store.setState()`
1. **Reset mocks in `beforeEach`**, not `afterEach`
@ -330,7 +328,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => {
### ❌ DON'T
1. **Don't mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
1. **Don't mock base components or dify-ui primitives** (`Loading`, `Input`, `Button`, `Tooltip`, `Dialog`, etc.)
1. **Don't mock Zustand store modules** - Use real stores with `setState()`
1. Don't mock components you can import directly
1. Don't create overly simplified mocks that miss conditional logic
@ -342,7 +340,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => {
```
Need to use a component in test?
├─ Is it from @/app/components/base/*?
├─ Is it from @/app/components/base/* or @langgenius/dify-ui/*?
│ └─ YES → Import real component, DO NOT mock
├─ Is it a project component?

View File

@ -1,19 +0,0 @@
name: Anti-Slop PR Check
on:
pull_request_target:
types: [opened, edited, synchronize]
permissions:
pull-requests: write
contents: read
jobs:
anti-slop:
runs-on: ubuntu-latest
steps:
- uses: peakoss/anti-slop@85daca1880e9e1af197fc06ea03349daf08f4202 # v0.2.1
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
close-pr: false
failure-add-pr-labels: "needs-revision"

View File

@ -35,7 +35,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@ -84,7 +84,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@ -105,7 +105,7 @@ jobs:
run: sh .github/workflows/expose_service_ports.sh
- name: Set up Sandbox
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
with:
compose-file: |
docker/docker-compose.middleware.yaml
@ -156,7 +156,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
python-version: "3.12"

View File

@ -25,7 +25,7 @@ jobs:
- name: Check Docker Compose inputs
if: github.event_name != 'merge_group'
id: docker-compose-changes
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: |
docker/generate_docker_compose
@ -35,7 +35,7 @@ jobs:
- name: Check web inputs
if: github.event_name != 'merge_group'
id: web-changes
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: |
web/**
@ -48,7 +48,7 @@ jobs:
- name: Check api inputs
if: github.event_name != 'merge_group'
id: api-changes
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: |
api/**
@ -58,7 +58,7 @@ jobs:
python-version: "3.11"
- if: github.event_name != 'merge_group'
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
- name: Generate Docker Compose
if: github.event_name != 'merge_group' && steps.docker-compose-changes.outputs.any_changed == 'true'
@ -120,8 +120,7 @@ jobs:
- name: ESLint autofix
if: github.event_name != 'merge_group' && steps.web-changes.outputs.any_changed == 'true'
run: |
cd web
vp exec eslint --concurrency=2 --prune-suppressions --quiet || true
- if: github.event_name != 'merge_group'
uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3
uses: autofix-ci/action@c5b2d67aa2274e7b5a18224e8171550871fc7e4a # v1.3.4

View File

@ -19,7 +19,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
python-version: "3.12"
@ -40,7 +40,7 @@ jobs:
cp middleware.env.example middleware.env
- name: Set up Middlewares
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
with:
compose-file: |
docker/docker-compose.middleware.yaml
@ -69,7 +69,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
python-version: "3.12"
@ -94,7 +94,7 @@ jobs:
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=mysql/' middleware.env
- name: Set up Middlewares
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
with:
compose-file: |
docker/docker-compose.middleware.yaml

View File

@ -76,13 +76,11 @@ jobs:
diff += '\\n\\n... (truncated) ...';
}
const body = diff.trim()
? '### Pyrefly Diff\n<details>\n<summary>base → PR</summary>\n\n```diff\n' + diff + '\n```\n</details>'
: '### Pyrefly Diff\nNo changes detected.';
await github.rest.issues.createComment({
issue_number: prNumber,
owner: context.repo.owner,
repo: context.repo.repo,
body,
});
if (diff.trim()) {
await github.rest.issues.createComment({
issue_number: prNumber,
owner: context.repo.owner,
repo: context.repo.repo,
body: '### Pyrefly Diff\n<details>\n<summary>base → PR</summary>\n\n```diff\n' + diff + '\n```\n</details>',
});
}

View File

@ -22,7 +22,7 @@ jobs:
fetch-depth: 0
- name: Setup Python & UV
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true

View File

@ -24,7 +24,7 @@ jobs:
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: Setup Python & UV
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true

View File

@ -22,7 +22,7 @@ jobs:
fetch-depth: 0
- name: Setup Python & UV
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true

View File

@ -25,7 +25,7 @@ jobs:
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: |
api/**
@ -33,7 +33,7 @@ jobs:
- name: Setup UV and Python
if: steps.changed-files.outputs.any_changed == 'true'
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: false
python-version: "3.12"
@ -73,10 +73,12 @@ jobs:
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: |
web/**
e2e/**
sdks/nodejs-client/**
packages/**
package.json
pnpm-lock.yaml
@ -93,16 +95,16 @@ jobs:
- name: Restore ESLint cache
if: steps.changed-files.outputs.any_changed == 'true'
id: eslint-cache-restore
uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: web/.eslintcache
key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
path: .eslintcache
key: ${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
restore-keys: |
${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
- name: Web style check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
working-directory: .
run: vp run lint:ci
- name: Web tsslint
@ -112,7 +114,7 @@ jobs:
- name: Web type check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
working-directory: .
run: vp run type-check
- name: Web dead code check
@ -122,9 +124,9 @@ jobs:
- name: Save ESLint cache
if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true'
uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: web/.eslintcache
path: .eslintcache
key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }}
superlinter:
@ -140,7 +142,7 @@ jobs:
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: |
**.sh

View File

@ -30,7 +30,7 @@ jobs:
persist-credentials: false
- name: Use Node.js
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0
uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0
with:
node-version: 22
cache: ''

View File

@ -158,7 +158,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.context.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@b47fd721da662d48c5680e154ad16a73ed74d2e0 # v1.0.93
uses: anthropics/claude-code-action@38ec876110f9fbf8b950c79f534430740c3ac009 # v1.0.101
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}

View File

@ -36,7 +36,7 @@ jobs:
remove_tool_cache: true
- name: Setup UV and Python
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@ -65,7 +65,7 @@ jobs:
# tiflash
- name: Set up Full Vector Store Matrix
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
with:
compose-file: |
docker/docker-compose.yaml

View File

@ -33,7 +33,7 @@ jobs:
remove_tool_cache: true
- name: Setup UV and Python
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@ -62,7 +62,7 @@ jobs:
# tiflash
- name: Set up Vector Stores for Smoke Coverage
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
with:
compose-file: |
docker/docker-compose.yaml

View File

@ -28,7 +28,7 @@ jobs:
uses: ./.github/actions/setup-web
- name: Setup UV and Python
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
python-version: "3.12"

View File

@ -89,3 +89,37 @@ jobs:
flags: web
env:
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
dify-ui-test:
name: dify-ui Tests
runs-on: ubuntu-latest
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
defaults:
run:
shell: bash
working-directory: ./packages/dify-ui
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Setup web environment
uses: ./.github/actions/setup-web
- name: Install Chromium for Browser Mode
run: vp exec playwright install --with-deps chromium
- name: Run dify-ui tests
run: vp test run --coverage --silent=passed-only
- name: Report coverage
if: ${{ env.CODECOV_TOKEN != '' }}
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
with:
directory: packages/dify-ui/coverage
flags: dify-ui
env:
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}

.gitignore
View File

@ -203,6 +203,7 @@ sdks/python-client/dify_client.egg-info
.vscode/*
!.vscode/launch.json.template
!.vscode/settings.example.json
!.vscode/README.md
api/.vscode
# vscode Code History Extension
@ -236,9 +237,15 @@ scripts/stress-test/reports/
.playwright-mcp/
.serena/
# vitest browser mode attachments (failure screenshots, traces, etc.)
.vitest-attachments/
**/__screenshots__/
# settings
*.local.json
*.local.md
# Code Agent Folder
.qoder/*
.eslintcache

View File

@ -56,44 +56,9 @@ if $api_modified; then
fi
fi
if $web_modified; then
if $skip_web_checks; then
echo "Git operation in progress, skipping web checks"
exit 0
fi
echo "Running ESLint on web module"
if git diff --cached --quiet -- 'web/**/*.ts' 'web/**/*.tsx'; then
web_ts_modified=false
else
ts_diff_status=$?
if [ $ts_diff_status -eq 1 ]; then
web_ts_modified=true
else
echo "Unable to determine staged TypeScript changes (git exit code: $ts_diff_status)."
exit $ts_diff_status
fi
fi
cd ./web || exit 1
vp staged
if $web_ts_modified; then
echo "Running TypeScript type-check:tsgo"
if ! npm run type-check:tsgo; then
echo "Type check failed. Please run 'npm run type-check:tsgo' to fix the errors."
exit 1
fi
else
echo "No staged TypeScript changes detected, skipping type-check:tsgo"
fi
echo "Running knip"
if ! npm run knip; then
echo "Knip check failed. Please run 'npm run knip' to fix the errors."
exit 1
fi
cd ../
if $skip_web_checks; then
echo "Git operation in progress, skipping web checks"
exit 0
fi
vp staged

View File

@ -1,12 +1,16 @@
{
// Disable the default formatter, use eslint instead
"prettier.enable": false,
"editor.formatOnSave": false,
"cucumber.features": [
"e2e/features/**/*.feature",
],
"cucumber.glue": [
"e2e/features/**/*.ts",
],
"tailwindCSS.experimental.configFile": "web/app/styles/globals.css",
// Auto fix
"editor.codeActionsOnSave": {
"source.fixAll.eslint": "explicit",
"source.organizeImports": "never"
},
// Silent the stylistic rules in your IDE, but still auto fix them

View File

@ -2,6 +2,7 @@ import base64
import secrets
import click
from sqlalchemy.orm import Session
from constants.languages import languages
from extensions.ext_database import db
@ -43,10 +44,11 @@ def reset_password(email, new_password, password_confirm):
# encrypt password with salt
password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account = db.session.merge(account)
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.commit()
with Session(db.engine) as session:
account = session.merge(account)
account.password = base64_password_hashed
account.password_salt = base64_salt
session.commit()
AccountService.reset_login_error_rate_limit(normalized_email)
click.echo(click.style("Password reset successfully.", fg="green"))
@ -77,9 +79,10 @@ def reset_email(email, new_email, email_confirm):
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
return
account = db.session.merge(account)
account.email = normalized_new_email
db.session.commit()
with Session(db.engine) as session:
account = session.merge(account)
account.email = normalized_new_email
session.commit()
click.echo(click.style("Email updated successfully.", fg="green"))

View File

@ -0,0 +1 @@
CURRENT_APP_DSL_VERSION = "0.6.0"

View File

@ -129,6 +129,7 @@ class AppNamePayload(BaseModel):
class AppIconPayload(BaseModel):
icon: str | None = Field(default=None, description="Icon data")
icon_type: IconType | None = Field(default=None, description="Icon type")
icon_background: str | None = Field(default=None, description="Icon background color")
@ -729,7 +730,12 @@ class AppIconApi(Resource):
args = AppIconPayload.model_validate(console_ns.payload or {})
app_service = AppService()
app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "")
app_model = app_service.update_app_icon(
app_model,
args.icon or "",
args.icon_background or "",
args.icon_type,
)
response_model = AppDetail.model_validate(app_model, from_attributes=True)
return response_model.model_dump(mode="json")

View File

@ -45,7 +45,7 @@ class ConversationVariableResponse(ResponseModel):
def _normalize_value_type(cls, value: Any) -> str:
exposed_type = getattr(value, "exposed_type", None)
if callable(exposed_type):
return str(exposed_type().value)
return str(exposed_type())
if isinstance(value, str):
return value
try:

View File

@ -18,12 +18,6 @@ from models.enums import AppMCPServerStatus
from models.model import AppMCPServer
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class MCPServerCreatePayload(BaseModel):
description: str | None = Field(default=None, description="Server description")
parameters: dict[str, Any] = Field(..., description="Server parameters configuration")
@ -36,19 +30,25 @@ class MCPServerUpdatePayload(BaseModel):
status: str | None = Field(default=None, description="Server status")
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class AppMCPServerResponse(ResponseModel):
id: str
name: str
server_code: str
description: str
status: str
status: AppMCPServerStatus
parameters: dict[str, Any] | list[Any] | str
created_at: int | None = None
updated_at: int | None = None
@field_validator("parameters", mode="before")
@classmethod
def _parse_json_string(cls, value: Any) -> Any:
def _normalize_parameters(cls, value: Any) -> Any:
if isinstance(value, str):
try:
return json.loads(value)
@ -70,7 +70,9 @@ class AppMCPServerController(Resource):
@console_ns.doc("get_app_mcp_server")
@console_ns.doc(description="Get MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Server configuration", console_ns.models[AppMCPServerResponse.__name__])
@console_ns.response(
200, "MCP server configuration retrieved successfully", console_ns.models[AppMCPServerResponse.__name__]
)
@login_required
@account_initialization_required
@setup_required
@ -85,7 +87,9 @@ class AppMCPServerController(Resource):
@console_ns.doc(description="Create MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__])
@console_ns.response(200, "Server created", console_ns.models[AppMCPServerResponse.__name__])
@console_ns.response(
201, "MCP server configuration created successfully", console_ns.models[AppMCPServerResponse.__name__]
)
@console_ns.response(403, "Insufficient permissions")
@account_initialization_required
@get_app_model
@ -111,13 +115,15 @@ class AppMCPServerController(Resource):
)
db.session.add(server)
db.session.commit()
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json"), 201
@console_ns.doc("update_app_mcp_server")
@console_ns.doc(description="Update MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__])
@console_ns.response(200, "Server updated", console_ns.models[AppMCPServerResponse.__name__])
@console_ns.response(
200, "MCP server configuration updated successfully", console_ns.models[AppMCPServerResponse.__name__]
)
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "Server not found")
@get_app_model
@ -154,7 +160,7 @@ class AppMCPServerRefreshController(Resource):
@console_ns.doc("refresh_app_mcp_server")
@console_ns.doc(description="Refresh MCP server configuration and regenerate server code")
@console_ns.doc(params={"server_id": "Server ID"})
@console_ns.response(200, "Server refreshed", console_ns.models[AppMCPServerResponse.__name__])
@console_ns.response(200, "MCP server refreshed successfully", console_ns.models[AppMCPServerResponse.__name__])
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "Server not found")
@setup_required

View File

@ -102,7 +102,7 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
value_type = workflow_draft_var.value_type
return value_type.exposed_type().value
return str(value_type.exposed_type())
class FullContentDict(TypedDict):
@ -122,7 +122,7 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> FullContentDict
result: FullContentDict = {
"size_bytes": variable_file.size,
"value_type": variable_file.value_type.exposed_type().value,
"value_type": str(variable_file.value_type.exposed_type()),
"length": variable_file.length,
"download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True),
}
@ -598,7 +598,7 @@ class EnvironmentVariableCollectionApi(Resource):
"name": v.name,
"description": v.description,
"selector": v.selector,
"value_type": v.value_type.exposed_type().value,
"value_type": str(v.value_type.exposed_type()),
"value": v.value,
# Do not track edited for env vars.
"edited": False,

View File

@ -50,6 +50,7 @@ from fields.dataset_fields import (
from fields.document_fields import document_status_fields
from graphon.model_runtime.entities.model_entities import ModelType
from libs.login import current_account_with_tenant, login_required
from libs.url_utils import normalize_api_base_url
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermission, DatasetPermissionEnum
from models.enums import ApiTokenType, SegmentStatus
@ -889,7 +890,8 @@ class DatasetApiBaseUrlApi(Resource):
@login_required
@account_initialization_required
def get(self):
return {"api_base_url": (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"}
base = dify_config.SERVICE_API_URL or request.host_url.rstrip("/")
return {"api_base_url": normalize_api_base_url(base)}
@console_ns.route("/datasets/retrieval-setting")

View File

@ -3,18 +3,19 @@ import logging
from argparse import ArgumentTypeError
from collections.abc import Sequence
from contextlib import ExitStack
from datetime import datetime
from typing import Any, Literal, cast
import sqlalchemy as sa
from flask import request, send_file
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource, marshal
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import asc, desc, func, select
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.common.controller_schemas import DocumentBatchDownloadZipPayload
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from core.errors.error import (
LLMBadRequestError,
@ -29,11 +30,9 @@ from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from extensions.ext_database import db
from fields.dataset_fields import dataset_fields
from fields.base import ResponseModel
from fields.document_fields import (
dataset_and_document_fields,
document_fields,
document_metadata_fields,
document_status_fields,
document_with_segments_fields,
)
@ -72,27 +71,100 @@ from ..wraps import (
logger = logging.getLogger(__name__)
# Register models for flask_restx to avoid dict type issues in Swagger
dataset_model = get_or_create_model("Dataset", dataset_fields)
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
document_metadata_model = get_or_create_model("DocumentMetadata", document_metadata_fields)
document_fields_copy = document_fields.copy()
document_fields_copy["doc_metadata"] = fields.List(
fields.Nested(document_metadata_model), attribute="doc_metadata_details"
)
document_model = get_or_create_model("Document", document_fields_copy)
def _normalize_enum(value: Any) -> Any:
if isinstance(value, str) or value is None:
return value
return getattr(value, "value", value)
document_with_segments_fields_copy = document_with_segments_fields.copy()
document_with_segments_fields_copy["doc_metadata"] = fields.List(
fields.Nested(document_metadata_model), attribute="doc_metadata_details"
)
document_with_segments_model = get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy)
dataset_and_document_fields_copy = dataset_and_document_fields.copy()
dataset_and_document_fields_copy["dataset"] = fields.Nested(dataset_model)
dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(document_model))
dataset_and_document_model = get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
class DatasetResponse(ResponseModel):
id: str
name: str
description: str | None = None
permission: str | None = None
data_source_type: str | None = None
indexing_technique: str | None = None
created_by: str | None = None
created_at: int | None = None
@field_validator("data_source_type", "indexing_technique", mode="before")
@classmethod
def _normalize_enum_fields(cls, value: Any) -> Any:
return _normalize_enum(value)
@field_validator("created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class DocumentMetadataResponse(ResponseModel):
id: str
name: str
type: str
value: str | None = None
class DocumentResponse(ResponseModel):
id: str
position: int | None = None
data_source_type: str | None = None
data_source_info: Any = Field(default=None, validation_alias="data_source_info_dict")
data_source_detail_dict: Any = None
dataset_process_rule_id: str | None = None
name: str
created_from: str | None = None
created_by: str | None = None
created_at: int | None = None
tokens: int | None = None
indexing_status: str | None = None
error: str | None = None
enabled: bool | None = None
disabled_at: int | None = None
disabled_by: str | None = None
archived: bool | None = None
display_status: str | None = None
word_count: int | None = None
hit_count: int | None = None
doc_form: str | None = None
doc_metadata: list[DocumentMetadataResponse] = Field(default_factory=list, validation_alias="doc_metadata_details")
summary_index_status: str | None = None
need_summary: bool | None = None
@field_validator("data_source_type", "indexing_status", "display_status", "doc_form", mode="before")
@classmethod
def _normalize_enum_fields(cls, value: Any) -> Any:
return _normalize_enum(value)
@field_validator("doc_metadata", mode="before")
@classmethod
def _normalize_doc_metadata(cls, value: Any) -> list[Any]:
if value is None:
return []
return value
@field_validator("created_at", "disabled_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class DocumentWithSegmentsResponse(DocumentResponse):
process_rule_dict: Any = None
completed_segments: int | None = None
total_segments: int | None = None
class DatasetAndDocumentResponse(ResponseModel):
dataset: DatasetResponse
documents: list[DocumentResponse]
batch: str
class DocumentRetryPayload(BaseModel):
@ -107,6 +179,11 @@ class GenerateSummaryPayload(BaseModel):
document_list: list[str]
class DocumentMetadataUpdatePayload(BaseModel):
doc_type: str | None = None
doc_metadata: Any = None
class DocumentDatasetListParam(BaseModel):
page: int = Field(1, title="Page", description="Page number.")
limit: int = Field(20, title="Limit", description="Page size.")
@ -124,7 +201,13 @@ register_schema_models(
DocumentRetryPayload,
DocumentRenamePayload,
GenerateSummaryPayload,
DocumentMetadataUpdatePayload,
DocumentBatchDownloadZipPayload,
DatasetResponse,
DocumentMetadataResponse,
DocumentResponse,
DocumentWithSegmentsResponse,
DatasetAndDocumentResponse,
)
@ -357,10 +440,10 @@ class DatasetDocumentListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(dataset_and_document_model)
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
@console_ns.response(200, "Documents created successfully", console_ns.models[DatasetAndDocumentResponse.__name__])
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
@ -398,7 +481,9 @@ class DatasetDocumentListApi(Resource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
return {"dataset": dataset, "documents": documents, "batch": batch}
return DatasetAndDocumentResponse.model_validate(
{"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True
).model_dump(mode="json")
@setup_required
@login_required
@ -426,12 +511,13 @@ class DatasetInitApi(Resource):
@console_ns.doc("init_dataset")
@console_ns.doc(description="Initialize dataset with documents")
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
@console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model)
@console_ns.response(
201, "Dataset initialized successfully", console_ns.models[DatasetAndDocumentResponse.__name__]
)
@console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
@marshal_with(dataset_and_document_model)
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
@ -479,9 +565,9 @@ class DatasetInitApi(Resource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
response = {"dataset": dataset, "documents": documents, "batch": batch}
return response
return DatasetAndDocumentResponse.model_validate(
{"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True
).model_dump(mode="json")
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate")
@ -988,15 +1074,7 @@ class DocumentMetadataApi(DocumentResource):
@console_ns.doc("update_document_metadata")
@console_ns.doc(description="Update document metadata")
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@console_ns.expect(
console_ns.model(
"UpdateDocumentMetadataRequest",
{
"doc_type": fields.String(description="Document type"),
"doc_metadata": fields.Raw(description="Document metadata"),
},
)
)
@console_ns.expect(console_ns.models[DocumentMetadataUpdatePayload.__name__])
@console_ns.response(200, "Document metadata updated successfully")
@console_ns.response(404, "Document not found")
@console_ns.response(403, "Permission denied")
@ -1009,10 +1087,10 @@ class DocumentMetadataApi(DocumentResource):
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
req_data = request.get_json()
req_data = DocumentMetadataUpdatePayload.model_validate(request.get_json() or {})
doc_type = req_data.get("doc_type")
doc_metadata = req_data.get("doc_metadata")
doc_type = req_data.doc_type
doc_metadata = req_data.doc_metadata
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
@ -1194,7 +1272,7 @@ class DocumentRenameApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(document_model)
@console_ns.response(200, "Document renamed successfully", console_ns.models[DocumentResponse.__name__])
@console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
def post(self, dataset_id, document_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
@ -1212,7 +1290,7 @@ class DocumentRenameApi(DocumentResource):
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return document
return DocumentResponse.model_validate(document, from_attributes=True).model_dump(mode="json")
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")

View File

@ -595,13 +595,25 @@ class ChangeEmailSendEmailApi(Resource):
account = None
user_email = None
email_for_sending = args.email.lower()
if args.phase is not None and args.phase == "new_email":
# Default to the initial phase; any legacy/unexpected client input is
# coerced back to `old_email` so we never trust the caller to declare
# later phases without a verified predecessor token.
send_phase = AccountService.CHANGE_EMAIL_PHASE_OLD
if args.phase is not None and args.phase == AccountService.CHANGE_EMAIL_PHASE_NEW:
send_phase = AccountService.CHANGE_EMAIL_PHASE_NEW
if args.token is None:
raise InvalidTokenError()
reset_data = AccountService.get_change_email_data(args.token)
if reset_data is None:
raise InvalidTokenError()
# The token used to request a new-email code must come from the
# old-email verification step. This prevents the bypass described
# in GHSA-4q3w-q5mc-45rq where the phase-1 token was reused here.
token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
if token_phase != AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED:
raise InvalidTokenError()
user_email = reset_data.get("email", "")
if user_email.lower() != current_user.email.lower():
@ -620,7 +632,7 @@ class ChangeEmailSendEmailApi(Resource):
email=email_for_sending,
old_email=user_email,
language=language,
phase=args.phase,
phase=send_phase,
)
return {"result": "success", "data": token}
@ -655,12 +667,31 @@ class ChangeEmailCheckApi(Resource):
AccountService.add_change_email_error_rate_limit(user_email)
raise EmailCodeError()
# Only advance tokens that were minted by the matching send-code step;
# refuse tokens that have already progressed or lack a phase marker so
# the chain `old_email -> old_email_verified -> new_email -> new_email_verified`
# is strictly enforced.
phase_transitions = {
AccountService.CHANGE_EMAIL_PHASE_OLD: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
AccountService.CHANGE_EMAIL_PHASE_NEW: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
}
token_phase = token_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
if not isinstance(token_phase, str):
raise InvalidTokenError()
refreshed_phase = phase_transitions.get(token_phase)
if refreshed_phase is None:
raise InvalidTokenError()
# Verified, revoke the first token
AccountService.revoke_change_email_token(args.token)
# Refresh token data by generating a new token
# Refresh token data by generating a new token that carries the
# upgraded phase so later steps can check it.
_, new_token = AccountService.generate_change_email_token(
user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
user_email,
code=args.code,
old_email=token_data.get("old_email"),
additional_data={AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: refreshed_phase},
)
AccountService.reset_change_email_error_rate_limit(user_email)
@ -690,13 +721,29 @@ class ChangeEmailResetApi(Resource):
if not reset_data:
raise InvalidTokenError()
AccountService.revoke_change_email_token(args.token)
# Only tokens that completed both verification phases may be used to
# change the email. This closes GHSA-4q3w-q5mc-45rq where a token from
# the initial send-code step could be replayed directly here.
token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
if token_phase != AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED:
raise InvalidTokenError()
# Bind the new email to the token that was mailed and verified, so a
# verified token cannot be reused with a different `new_email` value.
token_email = reset_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if normalized_token_email != normalized_new_email:
raise InvalidTokenError()
old_email = reset_data.get("old_email", "")
current_user, _ = current_account_with_tenant()
if current_user.email.lower() != old_email.lower():
raise AccountNotFound()
# Revoke only after all checks pass so failed attempts don't burn a
# legitimately verified token.
AccountService.revoke_change_email_token(args.token)
updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)
AccountService.send_change_email_completed_notify_email(

View File

@ -1131,6 +1131,14 @@ class ToolMCPAuthApi(Resource):
with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
parsed = urlparse(server_url)
sanitized_url = f"{parsed.scheme}://{parsed.hostname}{parsed.path}"
logger.warning(
"MCP authorization failed for provider %s (url=%s)",
provider_id,
sanitized_url,
exc_info=True,
)
raise ValueError(f"Failed to connect to MCP server: {e}") from e

View File

@ -20,10 +20,13 @@ class TenantUserPayload(BaseModel):
def get_user(tenant_id: str, user_id: str | None) -> EndUser:
"""
Get current user
Get current user.
NOTE: user_id is not trusted, it could be maliciously set to any value.
As a result, it could only be considered as an end user id.
As a result, it could only be considered as an end user id. Even when a
concrete end-user ID is supplied, lookups must stay tenant-scoped so one
tenant cannot bind another tenant's user record into the plugin request
context.
"""
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
@ -42,7 +45,14 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
.limit(1)
)
else:
user_model = session.get(EndUser, user_id)
user_model = session.scalar(
select(EndUser)
.where(
EndUser.id == user_id,
EndUser.tenant_id == tenant_id,
)
.limit(1)
)
if not user_model:
user_model = EndUser(

View File

@ -84,10 +84,10 @@ class ConversationVariableResponse(ResponseModel):
def normalize_value_type(cls, value: Any) -> str:
exposed_type = getattr(value, "exposed_type", None)
if callable(exposed_type):
return str(exposed_type().value)
return str(exposed_type())
if isinstance(value, str):
try:
return str(SegmentType(value).exposed_type().value)
return str(SegmentType(value).exposed_type())
except ValueError:
return value
try:

View File

@ -42,7 +42,7 @@ from graphon.model_runtime.entities import (
)
from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from graphon.model_runtime.entities.model_entities import ModelFeature
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from models.enums import CreatorUserRole
from models.model import Conversation, Message, MessageAgentThought, MessageFile

View File

@ -299,7 +299,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# update prompt tool
for prompt_tool in prompt_messages_tools:
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
tool_instance = tool_instances.get(prompt_tool.name)
if tool_instance:
self.update_prompt_message_tool(tool_instance, prompt_tool)
iteration_step += 1

View File

@ -7,7 +7,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from graphon.model_runtime.entities.llm_entities import LLMMode
from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
class ModelConfigConverter:

View File

@ -18,7 +18,7 @@ from core.moderation.base import ModerationError
from extensions.ext_database import db
from graphon.model_runtime.entities.llm_entities import LLMMode
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from models.model import App, Conversation, Message
logger = logging.getLogger(__name__)

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from collections.abc import Iterator
from collections.abc import Generator # Changed from Iterator
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass
@ -32,7 +32,7 @@ def get_current_file_access_scope() -> FileAccessScope | None:
@contextmanager
def bind_file_access_scope(scope: FileAccessScope) -> Iterator[None]:
def bind_file_access_scope(scope: FileAccessScope) -> Generator[None, None, None]: # Changed from Iterator[None]
token = _current_file_access_scope.set(scope)
try:
yield

View File

@ -59,7 +59,7 @@ from graphon.model_runtime.entities.message_entities import (
AssistantPromptMessage,
TextPromptMessageContent,
)
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from libs.datetime_utils import naive_utc_now
from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile

View File

@ -12,13 +12,14 @@ from typing import TYPE_CHECKING, Literal
from configs import dify_config
from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol
from core.db.session_factory import session_factory
from core.helper.ssrf_proxy import ssrf_proxy
from core.helper.ssrf_proxy import graphon_ssrf_proxy
from core.tools.signature import sign_tool_file
from core.workflow.file_reference import parse_file_reference
from extensions.ext_storage import storage
from graphon.file import FileTransferMethod
from graphon.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
from graphon.file.protocols import WorkflowFileRuntimeProtocol
from graphon.file.runtime import set_workflow_file_runtime
from graphon.http.protocols import HttpResponseProtocol
if TYPE_CHECKING:
from graphon.file import File
@ -43,7 +44,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
return dify_config.MULTIMODAL_SEND_FORMAT
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol:
return ssrf_proxy.get(url, follow_redirects=follow_redirects)
return graphon_ssrf_proxy.get(url, follow_redirects=follow_redirects)
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:
return storage.load(path, stream=stream)

View File

@ -349,7 +349,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
execution.total_tokens = runtime_state.total_tokens
execution.total_steps = runtime_state.node_run_steps
execution.outputs = execution.outputs or runtime_state.outputs
execution.exceptions_count = runtime_state.exceptions_count
execution.exceptions_count = max(execution.exceptions_count, runtime_state.exceptions_count)
def _update_node_execution(
self,

View File

@ -352,11 +352,11 @@ class DatasourceManager:
raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")
file_info = File(
id=upload_file.id,
file_id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
type=FileType.CUSTOM,
file_type=FileType.CUSTOM,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
reference=build_file_reference(record_id=str(upload_file.id)),

View File

@ -31,7 +31,7 @@ from graphon.model_runtime.entities.provider_entities import (
FormType,
ProviderEntity,
)
from graphon.model_runtime.model_providers.__base.ai_model import AIModel
from graphon.model_runtime.model_providers.base.ai_model import AIModel
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from graphon.model_runtime.runtime import ModelRuntime
from libs.datetime_utils import naive_utc_now
@ -318,34 +318,28 @@ class ProviderConfiguration(BaseModel):
else [],
)
def validate_provider_credentials(
self, credentials: dict[str, Any], credential_id: str = "", session: Session | None = None
):
def validate_provider_credentials(self, credentials: dict[str, Any], credential_id: str = ""):
"""
Validate custom credentials.
:param credentials: provider credentials
:param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate
:param session: optional database session
:return:
"""
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.provider_credential_schema.credential_form_schemas
if self.provider.provider_credential_schema
else []
)
def _validate(s: Session):
# Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.provider_credential_schema.credential_form_schemas
if self.provider.provider_credential_schema
else []
)
if credential_id:
if credential_id:
with Session(db.engine) as session:
try:
stmt = select(ProviderCredential).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name.in_(self._get_provider_names()),
ProviderCredential.id == credential_id,
)
credential_record = s.execute(stmt).scalar_one_or_none()
# fix origin data
credential_record = session.execute(stmt).scalar_one_or_none()
if credential_record and credential_record.encrypted_config:
if not credential_record.encrypted_config.startswith("{"):
original_credentials = {"openai_api_key": credential_record.encrypted_config}
@ -356,31 +350,23 @@ class ProviderConfiguration(BaseModel):
except JSONDecodeError:
original_credentials = {}
# encrypt credentials
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if send [__HIDDEN__] in secret input, it will be same as original value
if value == HIDDEN_VALUE and key in original_credentials:
credentials[key] = encrypter.decrypt_token(
tenant_id=self.tenant_id, token=original_credentials[key]
)
model_provider_factory = self.get_model_provider_factory()
validated_credentials = model_provider_factory.provider_credentials_validate(
provider=self.provider.provider, credentials=credentials
)
for key, value in validated_credentials.items():
for key, value in credentials.items():
if key in provider_credential_secret_variables:
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
if value == HIDDEN_VALUE and key in original_credentials:
credentials[key] = encrypter.decrypt_token(
tenant_id=self.tenant_id, token=original_credentials[key]
)
return validated_credentials
model_provider_factory = self.get_model_provider_factory()
validated_credentials = model_provider_factory.provider_credentials_validate(
provider=self.provider.provider, credentials=credentials
)
if session:
return _validate(session)
else:
with Session(db.engine) as new_session:
return _validate(new_session)
for key, value in validated_credentials.items():
if key in provider_credential_secret_variables and isinstance(value, str):
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
return validated_credentials
def _generate_provider_credential_name(self, session) -> str:
"""
@ -457,14 +443,16 @@ class ProviderConfiguration(BaseModel):
:param credential_name: credential name
:return:
"""
with Session(db.engine) as session:
with Session(db.engine) as pre_session:
if credential_name:
if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
if self._check_provider_credential_name_exists(credential_name=credential_name, session=pre_session):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
else:
credential_name = self._generate_provider_credential_name(session)
credential_name = self._generate_provider_credential_name(pre_session)
credentials = self.validate_provider_credentials(credentials=credentials, session=session)
credentials = self.validate_provider_credentials(credentials=credentials)
with Session(db.engine) as session:
provider_record = self._get_provider_record(session)
try:
new_record = ProviderCredential(
@ -477,7 +465,6 @@ class ProviderConfiguration(BaseModel):
session.flush()
if not provider_record:
# If provider record does not exist, create it
provider_record = Provider(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
@ -530,15 +517,15 @@ class ProviderConfiguration(BaseModel):
:param credential_name: credential name
:return:
"""
with Session(db.engine) as session:
with Session(db.engine) as pre_session:
if credential_name and self._check_provider_credential_name_exists(
credential_name=credential_name, session=session, exclude_id=credential_id
credential_name=credential_name, session=pre_session, exclude_id=credential_id
):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
credentials = self.validate_provider_credentials(
credentials=credentials, credential_id=credential_id, session=session
)
credentials = self.validate_provider_credentials(credentials=credentials, credential_id=credential_id)
with Session(db.engine) as session:
provider_record = self._get_provider_record(session)
stmt = select(ProviderCredential).where(
ProviderCredential.id == credential_id,
@ -546,12 +533,10 @@ class ProviderConfiguration(BaseModel):
ProviderCredential.provider_name.in_(self._get_provider_names()),
)
# Get the credential record to update
credential_record = session.execute(stmt).scalar_one_or_none()
if not credential_record:
raise ValueError("Credential record not found.")
try:
# Update credential
credential_record.encrypted_config = json.dumps(credentials)
credential_record.updated_at = naive_utc_now()
if credential_name:
@ -879,7 +864,6 @@ class ProviderConfiguration(BaseModel):
model: str,
credentials: dict[str, Any],
credential_id: str = "",
session: Session | None = None,
):
"""
Validate custom model credentials.
@ -890,16 +874,14 @@ class ProviderConfiguration(BaseModel):
:param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate
:return:
"""
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.model_credential_schema.credential_form_schemas
if self.provider.model_credential_schema
else []
)
def _validate(s: Session):
# Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.model_credential_schema.credential_form_schemas
if self.provider.model_credential_schema
else []
)
if credential_id:
if credential_id:
with Session(db.engine) as session:
try:
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
@ -908,7 +890,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type,
)
credential_record = s.execute(stmt).scalar_one_or_none()
credential_record = session.execute(stmt).scalar_one_or_none()
original_credentials = (
json.loads(credential_record.encrypted_config)
if credential_record and credential_record.encrypted_config
@ -917,31 +899,23 @@ class ProviderConfiguration(BaseModel):
except JSONDecodeError:
original_credentials = {}
# decrypt credentials
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if [__HIDDEN__] was sent for a secret input, fall back to the stored original value
if value == HIDDEN_VALUE and key in original_credentials:
credentials[key] = encrypter.decrypt_token(
tenant_id=self.tenant_id, token=original_credentials[key]
)
model_provider_factory = self.get_model_provider_factory()
validated_credentials = model_provider_factory.model_credentials_validate(
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
for key, value in validated_credentials.items():
for key, value in credentials.items():
if key in provider_credential_secret_variables:
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
if value == HIDDEN_VALUE and key in original_credentials:
credentials[key] = encrypter.decrypt_token(
tenant_id=self.tenant_id, token=original_credentials[key]
)
return validated_credentials
model_provider_factory = self.get_model_provider_factory()
validated_credentials = model_provider_factory.model_credentials_validate(
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
if session:
return _validate(session)
else:
with Session(db.engine) as new_session:
return _validate(new_session)
for key, value in validated_credentials.items():
if key in provider_credential_secret_variables and isinstance(value, str):
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
return validated_credentials
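# Simplified stand-in (illustrative sketch, not part of this diff) for the
# HIDDEN_VALUE handling above: encryption, sessions, and persistence are
# omitted, and the function and variable names are assumptions.
HIDDEN_PLACEHOLDER = "[__HIDDEN__]"

def merge_hidden_secrets(
    incoming: dict[str, str], stored: dict[str, str], secret_keys: set[str]
) -> dict[str, str]:
    # When the client echoes the placeholder back for a secret field, reuse the
    # stored value so validation runs against the real credential.
    merged = dict(incoming)
    for key in secret_keys:
        if merged.get(key) == HIDDEN_PLACEHOLDER and key in stored:
            merged[key] = stored[key]
    return merged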
def create_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict[str, Any], credential_name: str | None
@ -954,20 +928,22 @@ class ProviderConfiguration(BaseModel):
:param credentials: model credentials dict
:return:
"""
with Session(db.engine) as session:
with Session(db.engine) as pre_session:
if credential_name:
if self._check_custom_model_credential_name_exists(
model=model, model_type=model_type, credential_name=credential_name, session=session
model=model, model_type=model_type, credential_name=credential_name, session=pre_session
):
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
else:
credential_name = self._generate_custom_model_credential_name(
model=model, model_type=model_type, session=session
model=model, model_type=model_type, session=pre_session
)
# validate custom model config
credentials = self.validate_custom_model_credentials(
model_type=model_type, model=model, credentials=credentials, session=session
)
credentials = self.validate_custom_model_credentials(
model_type=model_type, model=model, credentials=credentials
)
with Session(db.engine) as session:
provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
try:
@ -982,7 +958,6 @@ class ProviderConfiguration(BaseModel):
session.add(credential)
session.flush()
# save provider model
if not provider_model_record:
provider_model_record = ProviderModel(
tenant_id=self.tenant_id,
@ -1024,23 +999,24 @@ class ProviderConfiguration(BaseModel):
:param credential_id: credential id
:return:
"""
with Session(db.engine) as session:
with Session(db.engine) as pre_session:
if credential_name and self._check_custom_model_credential_name_exists(
model=model,
model_type=model_type,
credential_name=credential_name,
session=session,
session=pre_session,
exclude_id=credential_id,
):
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
# validate custom model config
credentials = self.validate_custom_model_credentials(
model_type=model_type,
model=model,
credentials=credentials,
credential_id=credential_id,
session=session,
)
credentials = self.validate_custom_model_credentials(
model_type=model_type,
model=model,
credentials=credentials,
credential_id=credential_id,
)
with Session(db.engine) as session:
provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
stmt = select(ProviderModelCredential).where(
@ -1055,7 +1031,6 @@ class ProviderConfiguration(BaseModel):
raise ValueError("Credential record not found.")
try:
# Update credential
credential_record.encrypted_config = json.dumps(credentials)
credential_record.updated_at = naive_utc_now()
if credential_name:

View File

@ -102,7 +102,7 @@ class TemplateTransformer(ABC):
@classmethod
def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str:
inputs_json_str = dumps_with_segments(inputs, ensure_ascii=False).encode()
inputs_json_str = dumps_with_segments(inputs).encode()
input_base64_encoded = b64encode(inputs_json_str).decode("utf-8")
return input_base64_encoded

View File

@ -8,7 +8,7 @@ from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_
from extensions.ext_hosting_provider import hosting_configuration
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.invoke import InvokeBadRequestError
from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
from models.provider import ProviderType
logger = logging.getLogger(__name__)

View File

@ -12,6 +12,7 @@ from pydantic import TypeAdapter, ValidationError
from configs import dify_config
from core.helper.http_client_pooling import get_pooled_http_client
from core.tools.errors import ToolSSRFError
from graphon.http.response import HttpResponse
logger = logging.getLogger(__name__)
@ -267,4 +268,47 @@ class SSRFProxy:
return patch(url=url, max_retries=max_retries, **kwargs)
def _to_graphon_http_response(response: httpx.Response) -> HttpResponse:
"""Convert an ``httpx`` response into Graphon's transport-agnostic wrapper."""
return HttpResponse(
status_code=response.status_code,
headers=dict(response.headers),
content=response.content,
url=str(response.url) if response.url else None,
reason_phrase=response.reason_phrase,
fallback_text=response.text,
)
class GraphonSSRFProxy:
"""Adapter exposing SSRF helpers behind Graphon's ``HttpClientProtocol``."""
@property
def max_retries_exceeded_error(self) -> type[Exception]:
return max_retries_exceeded_error
@property
def request_error(self) -> type[Exception]:
return request_error
def get(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
return _to_graphon_http_response(get(url=url, max_retries=max_retries, **kwargs))
def head(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
return _to_graphon_http_response(head(url=url, max_retries=max_retries, **kwargs))
def post(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
return _to_graphon_http_response(post(url=url, max_retries=max_retries, **kwargs))
def put(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
return _to_graphon_http_response(put(url=url, max_retries=max_retries, **kwargs))
def delete(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
return _to_graphon_http_response(delete(url=url, max_retries=max_retries, **kwargs))
def patch(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
return _to_graphon_http_response(patch(url=url, max_retries=max_retries, **kwargs))
ssrf_proxy = SSRFProxy()
graphon_ssrf_proxy = GraphonSSRFProxy()
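# Illustrative usage sketch (not part of this diff): the adapter above can be
# handed to any Graphon component that expects an ``HttpClientProtocol``. The
# helper name and error-free happy path below are assumptions for the example.
def fetch_status_code(url: str) -> int:
    # Every helper on the adapter returns Graphon's transport-agnostic HttpResponse.
    response = graphon_ssrf_proxy.get(url, max_retries=3)
    return response.status_code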

View File

@ -303,9 +303,16 @@ class StreamableHTTPTransport:
if response.status_code == 404:
if isinstance(message.root, JSONRPCRequest):
error_msg = (
f"MCP server URL returned 404 Not Found: {self.url} "
"— verify the server URL is correct and the server is running"
if is_initialization
else "Session terminated by server"
)
self._send_session_terminated_error(
ctx.server_to_client_queue,
message.root.id,
message=error_msg,
)
return
@ -381,12 +388,13 @@ class StreamableHTTPTransport:
self,
server_to_client_queue: ServerToClientQueue,
request_id: RequestId,
message: str = "Session terminated by server",
):
"""Send a session terminated error response."""
jsonrpc_error = JSONRPCError(
jsonrpc="2.0",
id=request_id,
error=ErrorData(code=32600, message="Session terminated by server"),
error=ErrorData(code=32600, message=message),
)
session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
server_to_client_queue.put(session_message)

View File

@ -1,6 +1,6 @@
import logging
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
from typing import IO, Any, Literal, Optional, Union, cast, overload
from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload
from configs import dify_config
from core.entities import PluginCredentialType
@ -18,15 +18,17 @@ from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFe
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel
from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
from graphon.model_runtime.model_providers.base.rerank_model import RerankModel
from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
from graphon.model_runtime.model_providers.base.tts_model import TTSModel
from models.provider import ProviderType
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
class ModelInstance:
@ -168,7 +170,7 @@ class ModelInstance:
return cast(
Union[LLMResult, Generator],
self._round_robin_invoke(
function=self.model_type_instance.invoke,
self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
prompt_messages=list(prompt_messages),
@ -193,7 +195,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")
return self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
self.model_type_instance.get_num_tokens,
model=self.model_name,
credentials=self.credentials,
prompt_messages=list(prompt_messages),
@ -213,7 +215,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
texts=texts,
@ -235,7 +237,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
multimodel_documents=multimodel_documents,
@ -252,7 +254,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
self.model_type_instance.get_num_tokens,
model=self.model_name,
credentials=self.credentials,
texts=texts,
@ -277,7 +279,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
query=query,
@ -305,7 +307,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
return self._round_robin_invoke(
function=self.model_type_instance.invoke_multimodal_rerank,
self.model_type_instance.invoke_multimodal_rerank,
model=self.model_name,
credentials=self.credentials,
query=query,
@ -324,7 +326,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, ModerationModel):
raise Exception("Model type instance is not ModerationModel")
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
text=text,
@ -340,7 +342,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, Speech2TextModel):
raise Exception("Model type instance is not Speech2TextModel")
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
file=file,
@ -357,14 +359,14 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
content_text=content_text,
voice=voice,
)
def _round_robin_invoke[**P, R](self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
def _round_robin_invoke(self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
"""
Round-robin invoke
:param function: function to invoke
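# Minimal standalone sketch of the ParamSpec pattern introduced above; the
# retry count, error handling, and names are assumptions for the illustration.
from collections.abc import Callable
from typing import ParamSpec, TypeVar

_P = ParamSpec("_P")
_R = TypeVar("_R")

def call_with_retries(function: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs) -> _R:
    # ParamSpec preserves the wrapped callable's signature, so forwarded
    # positional and keyword arguments stay type-checked at the call site.
    last_error: Exception = RuntimeError("no attempts made")
    for _ in range(3):
        try:
            return function(*args, **kwargs)
        except Exception as error:  # stand-in for provider invoke errors
            last_error = error
    raise last_error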

View File

@ -1,8 +1,8 @@
from enum import StrEnum
from pydantic import BaseModel, ValidationInfo, field_validator
from pydantic import BaseModel
from core.ops.utils import validate_integer_id, validate_project_name, validate_url, validate_url_with_path
from core.ops.utils import validate_project_name, validate_url
class TracingProviderEnum(StrEnum):
@ -52,220 +52,5 @@ class BaseTracingConfig(BaseModel):
return validate_project_name(v, default_name)
class ArizeConfig(BaseTracingConfig):
"""
Model class for Arize tracing config.
"""
api_key: str | None = None
space_id: str | None = None
project: str | None = None
endpoint: str = "https://otlp.arize.com"
@field_validator("project")
@classmethod
def project_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "default")
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
return cls.validate_endpoint_url(v, "https://otlp.arize.com")
class PhoenixConfig(BaseTracingConfig):
"""
Model class for Phoenix tracing config.
"""
api_key: str | None = None
project: str | None = None
endpoint: str = "https://app.phoenix.arize.com"
@field_validator("project")
@classmethod
def project_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "default")
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
return validate_url_with_path(v, "https://app.phoenix.arize.com")
class LangfuseConfig(BaseTracingConfig):
"""
Model class for Langfuse tracing config.
"""
public_key: str
secret_key: str
host: str = "https://api.langfuse.com"
@field_validator("host")
@classmethod
def host_validator(cls, v, info: ValidationInfo):
return validate_url_with_path(v, "https://api.langfuse.com")
class LangSmithConfig(BaseTracingConfig):
"""
Model class for Langsmith tracing config.
"""
api_key: str
project: str
endpoint: str = "https://api.smith.langchain.com"
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
# LangSmith only allows HTTPS
return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",))
class OpikConfig(BaseTracingConfig):
"""
Model class for Opik tracing config.
"""
api_key: str | None = None
project: str | None = None
workspace: str | None = None
url: str = "https://www.comet.com/opik/api/"
@field_validator("project")
@classmethod
def project_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "Default Project")
@field_validator("url")
@classmethod
def url_validator(cls, v, info: ValidationInfo):
return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/")
class WeaveConfig(BaseTracingConfig):
"""
Model class for Weave tracing config.
"""
api_key: str
entity: str | None = None
project: str
endpoint: str = "https://trace.wandb.ai"
host: str | None = None
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
# Weave only allows HTTPS for endpoint
return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",))
@field_validator("host")
@classmethod
def host_validator(cls, v, info: ValidationInfo):
if v is not None and v.strip() != "":
return validate_url(v, v, allowed_schemes=("https", "http"))
return v
class AliyunConfig(BaseTracingConfig):
"""
Model class for Aliyun tracing config.
"""
app_name: str = "dify_app"
license_key: str
endpoint: str
@field_validator("app_name")
@classmethod
def app_name_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "dify_app")
@field_validator("license_key")
@classmethod
def license_key_validator(cls, v, info: ValidationInfo):
if not v or v.strip() == "":
raise ValueError("License key cannot be empty")
return v
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
# aliyun uses two URL formats, which may include a URL path
return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com")
class TencentConfig(BaseTracingConfig):
"""
Tencent APM tracing config
"""
token: str
endpoint: str
service_name: str
@field_validator("token")
@classmethod
def token_validator(cls, v, info: ValidationInfo):
if not v or v.strip() == "":
raise ValueError("Token cannot be empty")
return v
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
return cls.validate_endpoint_url(v, "https://apm.tencentcloudapi.com")
@field_validator("service_name")
@classmethod
def service_name_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "dify_app")
class MLflowConfig(BaseTracingConfig):
"""
Model class for MLflow tracing config.
"""
tracking_uri: str = "http://localhost:5000"
experiment_id: str = "0" # Default experiment id in MLflow is 0
username: str | None = None
password: str | None = None
@field_validator("tracking_uri")
@classmethod
def tracking_uri_validator(cls, v, info: ValidationInfo):
if isinstance(v, str) and v.startswith("databricks"):
raise ValueError(
"Please use Databricks tracing config below to record traces to Databricks-managed MLflow instances."
)
return validate_url_with_path(v, "http://localhost:5000")
@field_validator("experiment_id")
@classmethod
def experiment_id_validator(cls, v, info: ValidationInfo):
return validate_integer_id(v)
class DatabricksConfig(BaseTracingConfig):
"""
Model class for Databricks (Databricks-managed MLflow) tracing config.
"""
experiment_id: str
host: str
client_id: str | None = None
client_secret: str | None = None
personal_access_token: str | None = None
@field_validator("experiment_id")
@classmethod
def experiment_id_validator(cls, v, info: ValidationInfo):
return validate_integer_id(v)
OPS_FILE_PATH = "ops_trace/"
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"

View File

@ -204,114 +204,117 @@ class TracingProviderConfigEntry(TypedDict):
class OpsTraceProviderConfigMap(collections.UserDict[str, TracingProviderConfigEntry]):
def __getitem__(self, provider: str) -> TracingProviderConfigEntry:
match provider:
case TracingProviderEnum.LANGFUSE:
from core.ops.entities.config_entity import LangfuseConfig
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
try:
match provider:
case TracingProviderEnum.LANGFUSE:
from dify_trace_langfuse.config import LangfuseConfig
from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace
return {
"config_class": LangfuseConfig,
"secret_keys": ["public_key", "secret_key"],
"other_keys": ["host", "project_key"],
"trace_instance": LangFuseDataTrace,
}
return {
"config_class": LangfuseConfig,
"secret_keys": ["public_key", "secret_key"],
"other_keys": ["host", "project_key"],
"trace_instance": LangFuseDataTrace,
}
case TracingProviderEnum.LANGSMITH:
from core.ops.entities.config_entity import LangSmithConfig
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
case TracingProviderEnum.LANGSMITH:
from dify_trace_langsmith.config import LangSmithConfig
from dify_trace_langsmith.langsmith_trace import LangSmithDataTrace
return {
"config_class": LangSmithConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "endpoint"],
"trace_instance": LangSmithDataTrace,
}
return {
"config_class": LangSmithConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "endpoint"],
"trace_instance": LangSmithDataTrace,
}
case TracingProviderEnum.OPIK:
from core.ops.entities.config_entity import OpikConfig
from core.ops.opik_trace.opik_trace import OpikDataTrace
case TracingProviderEnum.OPIK:
from dify_trace_opik.config import OpikConfig
from dify_trace_opik.opik_trace import OpikDataTrace
return {
"config_class": OpikConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "url", "workspace"],
"trace_instance": OpikDataTrace,
}
return {
"config_class": OpikConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "url", "workspace"],
"trace_instance": OpikDataTrace,
}
case TracingProviderEnum.WEAVE:
from core.ops.entities.config_entity import WeaveConfig
from core.ops.weave_trace.weave_trace import WeaveDataTrace
case TracingProviderEnum.WEAVE:
from dify_trace_weave.config import WeaveConfig
from dify_trace_weave.weave_trace import WeaveDataTrace
return {
"config_class": WeaveConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "entity", "endpoint", "host"],
"trace_instance": WeaveDataTrace,
}
case TracingProviderEnum.ARIZE:
from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace
from core.ops.entities.config_entity import ArizeConfig
return {
"config_class": WeaveConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "entity", "endpoint", "host"],
"trace_instance": WeaveDataTrace,
}
case TracingProviderEnum.ARIZE:
from dify_trace_arize_phoenix.arize_phoenix_trace import ArizePhoenixDataTrace
from dify_trace_arize_phoenix.config import ArizeConfig
return {
"config_class": ArizeConfig,
"secret_keys": ["api_key", "space_id"],
"other_keys": ["project", "endpoint"],
"trace_instance": ArizePhoenixDataTrace,
}
case TracingProviderEnum.PHOENIX:
from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace
from core.ops.entities.config_entity import PhoenixConfig
return {
"config_class": ArizeConfig,
"secret_keys": ["api_key", "space_id"],
"other_keys": ["project", "endpoint"],
"trace_instance": ArizePhoenixDataTrace,
}
case TracingProviderEnum.PHOENIX:
from dify_trace_arize_phoenix.arize_phoenix_trace import ArizePhoenixDataTrace
from dify_trace_arize_phoenix.config import PhoenixConfig
return {
"config_class": PhoenixConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "endpoint"],
"trace_instance": ArizePhoenixDataTrace,
}
case TracingProviderEnum.ALIYUN:
from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace
from core.ops.entities.config_entity import AliyunConfig
return {
"config_class": PhoenixConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "endpoint"],
"trace_instance": ArizePhoenixDataTrace,
}
case TracingProviderEnum.ALIYUN:
from dify_trace_aliyun.aliyun_trace import AliyunDataTrace
from dify_trace_aliyun.config import AliyunConfig
return {
"config_class": AliyunConfig,
"secret_keys": ["license_key"],
"other_keys": ["endpoint", "app_name"],
"trace_instance": AliyunDataTrace,
}
case TracingProviderEnum.MLFLOW:
from core.ops.entities.config_entity import MLflowConfig
from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace
return {
"config_class": AliyunConfig,
"secret_keys": ["license_key"],
"other_keys": ["endpoint", "app_name"],
"trace_instance": AliyunDataTrace,
}
case TracingProviderEnum.MLFLOW:
from dify_trace_mlflow.config import MLflowConfig
from dify_trace_mlflow.mlflow_trace import MLflowDataTrace
return {
"config_class": MLflowConfig,
"secret_keys": ["password"],
"other_keys": ["tracking_uri", "experiment_id", "username"],
"trace_instance": MLflowDataTrace,
}
case TracingProviderEnum.DATABRICKS:
from core.ops.entities.config_entity import DatabricksConfig
from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace
return {
"config_class": MLflowConfig,
"secret_keys": ["password"],
"other_keys": ["tracking_uri", "experiment_id", "username"],
"trace_instance": MLflowDataTrace,
}
case TracingProviderEnum.DATABRICKS:
from dify_trace_mlflow.config import DatabricksConfig
from dify_trace_mlflow.mlflow_trace import MLflowDataTrace
return {
"config_class": DatabricksConfig,
"secret_keys": ["personal_access_token", "client_secret"],
"other_keys": ["host", "client_id", "experiment_id"],
"trace_instance": MLflowDataTrace,
}
return {
"config_class": DatabricksConfig,
"secret_keys": ["personal_access_token", "client_secret"],
"other_keys": ["host", "client_id", "experiment_id"],
"trace_instance": MLflowDataTrace,
}
case TracingProviderEnum.TENCENT:
from core.ops.entities.config_entity import TencentConfig
from core.ops.tencent_trace.tencent_trace import TencentDataTrace
case TracingProviderEnum.TENCENT:
from dify_trace_tencent.config import TencentConfig
from dify_trace_tencent.tencent_trace import TencentDataTrace
return {
"config_class": TencentConfig,
"secret_keys": ["token"],
"other_keys": ["endpoint", "service_name"],
"trace_instance": TencentDataTrace,
}
return {
"config_class": TencentConfig,
"secret_keys": ["token"],
"other_keys": ["endpoint", "service_name"],
"trace_instance": TencentDataTrace,
}
case _:
raise KeyError(f"Unsupported tracing provider: {provider}")
case _:
raise KeyError(f"Unsupported tracing provider: {provider}")
except ImportError:
raise ImportError(f"Provider {provider} is not installed.")
provider_config_map = OpsTraceProviderConfigMap()

View File

@ -66,15 +66,15 @@ class PluginModelRuntime(ModelRuntime):
if not provider_schema.icon_small:
raise ValueError(f"Provider {provider} does not have small icon.")
file_name = (
provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US
provider_schema.icon_small.zh_hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_us
)
elif icon_type.lower() == "icon_small_dark":
if not provider_schema.icon_small_dark:
raise ValueError(f"Provider {provider} does not have small dark icon.")
file_name = (
provider_schema.icon_small_dark.zh_Hans
provider_schema.icon_small_dark.zh_hans
if lang.lower() == "zh_hans"
else provider_schema.icon_small_dark.en_US
else provider_schema.icon_small_dark.en_us
)
else:
raise ValueError(f"Unsupported icon type: {icon_type}.")

View File

@ -10,7 +10,7 @@ from graphon.model_runtime.entities.message_entities import (
SystemPromptMessage,
UserPromptMessage,
)
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
class AgentHistoryPromptTransform(PromptTransform):

View File

@ -70,12 +70,32 @@ class ProviderManager:
Request-bound managers may carry caller identity in that runtime, and the
resulting ``ProviderConfiguration`` objects must reuse it for downstream
model-type and schema lookups.
Configuration assembly is cached per manager instance so call chains that
share one request-scoped manager can reuse the same provider graph instead
of rebuilding it for every lookup. Call ``clear_configurations_cache()``
when a long-lived manager needs to observe writes performed within the same
instance scope.
"""
decoding_rsa_key: Any | None
decoding_cipher_rsa: Any | None
_model_runtime: ModelRuntime
_configurations_cache: dict[str, ProviderConfigurations]
def __init__(self, model_runtime: ModelRuntime):
self.decoding_rsa_key = None
self.decoding_cipher_rsa = None
self._model_runtime = model_runtime
self._configurations_cache = {}
def clear_configurations_cache(self, tenant_id: str | None = None) -> None:
"""Drop assembled provider configurations cached on this manager instance."""
if tenant_id is None:
self._configurations_cache.clear()
return
self._configurations_cache.pop(tenant_id, None)
def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
"""
@ -114,6 +134,10 @@ class ProviderManager:
:param tenant_id:
:return:
"""
cached_configurations = self._configurations_cache.get(tenant_id)
if cached_configurations is not None:
return cached_configurations
# Get all provider records of the workspace
provider_name_to_provider_records_dict = self._get_all_providers(tenant_id)
@ -273,6 +297,8 @@ class ProviderManager:
provider_configurations[str(provider_id_entity)] = provider_configuration
self._configurations_cache[tenant_id] = provider_configurations
# Return the encapsulated object
return provider_configurations
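# Illustrative sketch (not part of this diff) of the per-instance cache added
# above; ``model_runtime`` and ``tenant_id`` are placeholders supplied by the caller.
def _configuration_cache_example(model_runtime: ModelRuntime, tenant_id: str) -> None:
    manager = ProviderManager(model_runtime)
    first = manager.get_configurations(tenant_id)   # assembled and cached
    second = manager.get_configurations(tenant_id)  # served from the cache
    assert first is second
    manager.clear_configurations_cache(tenant_id)   # drop the entry after writes
    assert manager.get_configurations(tenant_id) is not first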

View File

@ -139,8 +139,10 @@ class Jieba(BaseKeyword):
"__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table},
}
dataset_keyword_table = self.dataset.dataset_keyword_table
keyword_data_source_type = dataset_keyword_table.data_source_type
keyword_data_source_type = dataset_keyword_table.data_source_type if dataset_keyword_table else "file"
if keyword_data_source_type == "database":
if dataset_keyword_table is None:
return
dataset_keyword_table.keyword_table = dumps_with_sets(keyword_table_dict)
db.session.commit()
else:

View File

@ -1,4 +1,5 @@
import re
from collections.abc import Callable
from operator import itemgetter
from typing import cast
@ -80,12 +81,14 @@ class JiebaKeywordTableHandler:
def extract_tags(self, sentence: str, top_k: int | None = 20, **kwargs):
# Basic frequency-based keyword extraction as a fallback when TF-IDF is unavailable.
top_k = kwargs.pop("topK", top_k)
top_k = cast(int | None, kwargs.pop("topK", top_k))
if top_k is None:
top_k = 20
cut = getattr(jieba, "cut", None)
if self._lcut:
tokens = self._lcut(sentence)
elif callable(cut):
tokens = list(cut(sentence))
tokens = list(cast(Callable[[str], list[str]], cut)(sentence))
else:
tokens = re.findall(r"\w+", sentence)
@ -108,7 +111,7 @@ class JiebaKeywordTableHandler:
sentence=text,
topK=max_keywords_per_chunk,
)
# jieba.analyse.extract_tags returns list[Any] when withFlag is False by default.
# jieba.analyse.extract_tags returns an untyped list when withFlag is False by default.
keywords = cast(list[str], keywords)
return set(self._expand_tokens_with_subtokens(set(keywords)))

View File

@ -158,7 +158,7 @@ class RetrievalService:
)
if futures:
for future in concurrent.futures.as_completed(futures, timeout=3600):
for _ in concurrent.futures.as_completed(futures, timeout=3600):
if exceptions:
for f in futures:
f.cancel()

View File

@ -14,7 +14,7 @@ from core.rag.embedding.embedding_base import Embeddings
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from graphon.model_runtime.entities.model_entities import ModelPropertyKey
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
from libs import helper
from models.dataset import Embedding

View File

@ -94,6 +94,7 @@ class ExtractProcessor:
cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str | None = None
) -> list[Document]:
if extract_setting.datasource_type == DatasourceType.FILE:
upload_file = extract_setting.upload_file
with tempfile.TemporaryDirectory() as temp_dir:
upload_file = extract_setting.upload_file
if not file_path:
@ -104,6 +105,7 @@ class ExtractProcessor:
storage.download(upload_file.key, file_path)
input_file = Path(file_path)
file_extension = input_file.suffix.lower()
assert upload_file is not None, "upload_file is required"
etl_type = dify_config.ETL_TYPE
extractor: BaseExtractor | None = None
if etl_type == "Unstructured":

View File

@ -3,6 +3,7 @@
Supports local file paths and remote URLs (downloaded via `core.helper.ssrf_proxy`).
"""
import inspect
import logging
import mimetypes
import os
@ -36,8 +37,11 @@ class WordExtractor(BaseExtractor):
file_path: Path to the file to load.
"""
_closed: bool
def __init__(self, file_path: str, tenant_id: str, user_id: str):
"""Initialize with file path."""
self._closed = False
self.file_path = file_path
self.tenant_id = tenant_id
self.user_id = user_id
@ -65,9 +69,27 @@ class WordExtractor(BaseExtractor):
elif not os.path.isfile(self.file_path):
raise ValueError(f"File path {self.file_path} is not a valid file or url")
def close(self) -> None:
"""Best-effort cleanup for downloaded temporary files."""
if getattr(self, "_closed", False):
return
self._closed = True
temp_file = getattr(self, "temp_file", None)
if temp_file is None:
return
try:
close_result = temp_file.close()
if inspect.isawaitable(close_result):
close_awaitable = getattr(close_result, "close", None)
if callable(close_awaitable):
close_awaitable()
except Exception:
logger.debug("Failed to cleanup downloaded word temp file", exc_info=True)
def __del__(self):
if hasattr(self, "temp_file"):
self.temp_file.close()
self.close()
def extract(self) -> list[Document]:
"""Load given path as single page."""

View File

@ -609,11 +609,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
try:
# Create File object directly (similar to DatasetRetrieval)
file_obj = File(
id=upload_file.id,
file_id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
type=FileType.IMAGE,
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
reference=build_file_reference(

View File

@ -68,7 +68,7 @@ from graphon.file import File, FileTransferMethod, FileType
from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from libs.helper import parse_uuid_str_or_none
from libs.json_in_md_parser import parse_and_check_json_markdown
from models import UploadFile
@ -517,11 +517,11 @@ class DatasetRetrieval:
if attachments_with_bindings:
for _, upload_file in attachments_with_bindings:
attachment_info = File(
id=upload_file.id,
file_id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
type=FileType.IMAGE,
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
reference=build_file_reference(

View File

@ -28,7 +28,7 @@ class FunctionCallMultiDatasetRouter:
SystemPromptMessage(content="You are a helpful AI assistant."),
UserPromptMessage(content=query),
]
result: LLMResult = model_instance.invoke_llm(
result: LLMResult = model_instance.invoke_llm( # pyright: ignore[reportCallIssue, reportArgumentType]
prompt_messages=prompt_messages,
tools=dataset_tools,
stream=False,

View File

@ -4,12 +4,12 @@ from __future__ import annotations
import codecs
import re
from collections.abc import Collection
from collections.abc import Set as AbstractSet
from typing import Any, Literal
from core.model_manager import ModelInstance
from core.rag.splitter.text_splitter import RecursiveCharacterTextSplitter
from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
from graphon.model_runtime.model_providers.base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
@ -21,8 +21,8 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
def from_encoder[T: EnhanceRecursiveCharacterTextSplitter](
cls: type[T],
embedding_model_instance: ModelInstance | None,
allowed_special: Literal["all"] | set[str] = set(),
disallowed_special: Literal["all"] | Collection[str] = "all",
allowed_special: Literal["all"] | AbstractSet[str] = frozenset(),
disallowed_special: Literal["all"] | AbstractSet[str] = "all",
**kwargs: Any,
) -> T:
def _token_encoder(texts: list[str]) -> list[int]:
@ -40,6 +40,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
return [len(text) for text in texts]
_ = _token_encoder # kept for future token-length wiring
return cls(length_function=_character_encoder, **kwargs)

View File

@ -4,7 +4,8 @@ import copy
import logging
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, Collection, Iterable, Sequence, Set
from collections.abc import Callable, Iterable, Sequence
from collections.abc import Set as AbstractSet
from dataclasses import dataclass
from typing import Any, Literal
@ -187,8 +188,8 @@ class TokenTextSplitter(TextSplitter):
self,
encoding_name: str = "gpt2",
model_name: str | None = None,
allowed_special: Literal["all"] | Set[str] = set(),
disallowed_special: Literal["all"] | Collection[str] = "all",
allowed_special: Literal["all"] | AbstractSet[str] = frozenset(),
disallowed_special: Literal["all"] | AbstractSet[str] = "all",
**kwargs: Any,
):
"""Create a new TextSplitter."""
@ -207,8 +208,8 @@ class TokenTextSplitter(TextSplitter):
else:
enc = tiktoken.get_encoding(encoding_name)
self._tokenizer = enc
self._allowed_special = allowed_special
self._disallowed_special = disallowed_special
self._allowed_special: Literal["all"] | AbstractSet[str] = allowed_special
self._disallowed_special: Literal["all"] | AbstractSet[str] = disallowed_special
def split_text(self, text: str) -> list[str]:
def _encode(_text: str) -> list[int]:

View File

@ -8,7 +8,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload
from core.db.session_factory import session_factory
from core.workflow.human_input_compat import (
from core.workflow.human_input_adapter import (
BoundRecipient,
DeliveryChannelConfig,
EmailDeliveryMethod,

View File

@ -38,6 +38,17 @@ class ToolCredentialPolicyViolationError(ValueError):
pass
class ApiToolProviderNotFoundError(ValueError):
error_code = "api_tool_provider_not_found"
provider_name: str
tenant_id: str
def __init__(self, provider_name: str, tenant_id: str):
self.provider_name = provider_name
self.tenant_id = tenant_id
super().__init__(f"api provider {provider_name} does not exist")
class WorkflowToolHumanInputNotSupportedError(BaseHTTPException):
error_code = "workflow_tool_human_input_not_supported"
description = "Workflow with Human Input nodes cannot be published as a workflow tool."

View File

@ -28,7 +28,7 @@ class ToolFileManager:
def _build_graph_file_reference(tool_file: ToolFile) -> File:
extension = guess_extension(tool_file.mimetype) or ".bin"
return File(
type=get_file_type_by_mime_type(tool_file.mimetype),
file_type=get_file_type_by_mime_type(tool_file.mimetype),
transfer_method=FileTransferMethod.TOOL_FILE,
remote_url=tool_file.original_url,
reference=build_file_reference(record_id=str(tool_file.id)),

View File

@ -1082,7 +1082,12 @@ class ToolManager:
continue
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
if tool_input.type == "variable":
variable = variable_pool.get(tool_input.value)
variable_selector = tool_input.value
if not isinstance(variable_selector, list) or not all(
isinstance(selector_part, str) for selector_part in variable_selector
):
raise ToolParameterError("Variable tool input must be a variable selector")
variable = variable_pool.get(variable_selector)
if variable is None:
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
parameter_value = variable.value

View File

@ -41,6 +41,10 @@ def safe_json_value(v):
return v.hex()
elif isinstance(v, memoryview):
return v.tobytes().hex()
elif isinstance(v, np.integer):
return int(v)
elif isinstance(v, np.floating):
return float(v)
elif isinstance(v, np.ndarray):
return v.tolist()
elif isinstance(v, dict):

View File

@ -21,7 +21,7 @@ from graphon.model_runtime.errors.invoke import (
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.model_runtime.utils.encoders import jsonable_encoder
from models.tools import ToolModelInvoke

View File

@ -105,7 +105,7 @@ class Article:
def extract_using_readabilipy(html: str):
json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=True)
json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=False)
article = Article(
title=json_article.get("title") or "",
author=json_article.get("byline") or "",

View File

@ -357,7 +357,10 @@ class WorkflowTool(Tool):
def _update_file_mapping(self, file_dict: dict[str, Any]) -> dict[str, Any]:
file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id"))
transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
transfer_method_value = file_dict.get("transfer_method")
if not isinstance(transfer_method_value, str):
raise ValueError("Workflow file mapping is missing a valid transfer_method")
transfer_method = FileTransferMethod.value_of(transfer_method_value)
match transfer_method:
case FileTransferMethod.TOOL_FILE:
file_dict["tool_file_id"] = file_id

View File

@ -1,8 +1,8 @@
"""Workflow-layer adapters for legacy human-input payload keys.
"""Workflow-to-Graphon adapters for persisted node payloads.
Stored workflow graphs and editor payloads may still use Dify-specific human
input recipient keys. Normalize them here before handing configs to
`graphon` so graph-owned models only see graph-neutral field names.
Stored workflow graphs and editor payloads still contain a small set of
Dify-owned field spellings and value shapes. Adapt them here before handing the
payload to Graphon so Graphon-owned models only see current contracts.
"""
from __future__ import annotations
@ -185,7 +185,7 @@ def _copy_mapping(value: object) -> dict[str, Any] | None:
return None
def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
def adapt_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
normalized = _copy_mapping(node_data)
if normalized is None:
raise TypeError(f"human-input node data must be a mapping, got {type(node_data).__name__}")
@ -215,7 +215,7 @@ def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | Bas
def parse_human_input_delivery_methods(node_data: Mapping[str, Any] | BaseModel) -> list[DeliveryChannelConfig]:
normalized = normalize_human_input_node_data_for_graph(node_data)
normalized = adapt_human_input_node_data_for_graph(node_data)
raw_delivery_methods = normalized.get("delivery_methods")
if not isinstance(raw_delivery_methods, list):
return []
@ -229,17 +229,20 @@ def is_human_input_webapp_enabled(node_data: Mapping[str, Any] | BaseModel) -> b
return False
def normalize_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
def adapt_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
normalized = _copy_mapping(node_data)
if normalized is None:
raise TypeError(f"node data must be a mapping, got {type(node_data).__name__}")
if normalized.get("type") != BuiltinNodeTypes.HUMAN_INPUT:
return normalized
return normalize_human_input_node_data_for_graph(normalized)
node_type = normalized.get("type")
if node_type == BuiltinNodeTypes.HUMAN_INPUT:
return adapt_human_input_node_data_for_graph(normalized)
if node_type == BuiltinNodeTypes.TOOL:
return _adapt_tool_node_data_for_graph(normalized)
return normalized
def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
def adapt_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
normalized = _copy_mapping(node_config)
if normalized is None:
raise TypeError(f"node config must be a mapping, got {type(node_config).__name__}")
@ -248,10 +251,65 @@ def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel)
if data_mapping is None:
return normalized
normalized["data"] = normalize_node_data_for_graph(data_mapping)
normalized["data"] = adapt_node_data_for_graph(data_mapping)
return normalized
def _adapt_tool_node_data_for_graph(node_data: Mapping[str, Any]) -> dict[str, Any]:
normalized = dict(node_data)
raw_tool_configurations = normalized.get("tool_configurations")
if not isinstance(raw_tool_configurations, Mapping):
return normalized
existing_tool_parameters = normalized.get("tool_parameters")
normalized_tool_parameters = dict(existing_tool_parameters) if isinstance(existing_tool_parameters, Mapping) else {}
normalized_tool_configurations: dict[str, Any] = {}
found_legacy_tool_inputs = False
for name, value in raw_tool_configurations.items():
if not isinstance(value, Mapping):
normalized_tool_configurations[name] = value
continue
input_type = value.get("type")
input_value = value.get("value")
if input_type not in {"mixed", "variable", "constant"}:
normalized_tool_configurations[name] = value
continue
found_legacy_tool_inputs = True
normalized_tool_parameters.setdefault(name, dict(value))
flattened_value = _flatten_legacy_tool_configuration_value(
input_type=input_type,
input_value=input_value,
)
if flattened_value is not None:
normalized_tool_configurations[name] = flattened_value
if not found_legacy_tool_inputs:
return normalized
normalized["tool_parameters"] = normalized_tool_parameters
normalized["tool_configurations"] = normalized_tool_configurations
return normalized
def _flatten_legacy_tool_configuration_value(*, input_type: Any, input_value: Any) -> str | int | float | bool | None:
if input_type in {"mixed", "constant"} and isinstance(input_value, str | int | float | bool):
return input_value
if (
input_type == "variable"
and isinstance(input_value, list)
and all(isinstance(item, str) for item in input_value)
):
return "{{#" + ".".join(input_value) + "#}}"
return None
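# Worked example (illustrative, not part of this diff): a legacy payload and the
# shape the adapter above produces for it; node and parameter names are made up.
def _legacy_tool_input_example() -> None:
    legacy_node_data = {
        "type": "tool",
        "tool_configurations": {
            "query": {"type": "variable", "value": ["start", "question"]},
            "limit": {"type": "constant", "value": 5},
        },
    }
    adapted = _adapt_tool_node_data_for_graph(legacy_node_data)
    assert adapted["tool_configurations"] == {"query": "{{#start.question#}}", "limit": 5}
    assert "query" in adapted["tool_parameters"]  # the original typed input is preserved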
def _normalize_email_recipients(recipients: Mapping[str, Any]) -> dict[str, Any]:
normalized = dict(recipients)
@ -291,9 +349,9 @@ __all__ = [
"MemberRecipient",
"WebAppDeliveryMethod",
"_WebAppDeliveryConfig",
"adapt_human_input_node_data_for_graph",
"adapt_node_config_for_graph",
"adapt_node_data_for_graph",
"is_human_input_webapp_enabled",
"normalize_human_input_node_data_for_graph",
"normalize_node_config_for_graph",
"normalize_node_data_for_graph",
"parse_human_input_delivery_methods",
]

View File

@ -15,12 +15,12 @@ from core.helper.code_executor.code_executor import (
CodeExecutionError,
CodeExecutor,
)
from core.helper.ssrf_proxy import ssrf_proxy
from core.helper.ssrf_proxy import graphon_ssrf_proxy
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.trigger.constants import TRIGGER_NODE_TYPES
from core.workflow.human_input_compat import normalize_node_config_for_graph
from core.workflow.human_input_adapter import adapt_node_config_for_graph
from core.workflow.node_runtime import (
DifyFileReferenceFactory,
DifyHumanInputNodeRuntime,
@ -46,7 +46,7 @@ from graphon.enums import BuiltinNodeTypes, NodeType
from graphon.file.file_manager import file_manager
from graphon.graph.graph import NodeFactory
from graphon.model_runtime.memory import PromptMessageMemory
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.nodes.base.node import Node
from graphon.nodes.code.code_node import WorkflowCodeExecutor
from graphon.nodes.code.entities import CodeLanguage
@ -121,6 +121,7 @@ def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node]
def resolve_workflow_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
"""Resolve the production node class for the requested type/version."""
node_mapping = get_node_type_classes_mapping().get(node_type)
if not node_mapping:
raise ValueError(f"No class mapping found for node type: {node_type}")
@ -297,7 +298,7 @@ class DifyNodeFactory(NodeFactory):
)
self._jinja2_template_renderer = CodeExecutorJinja2TemplateRenderer()
self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
self._http_request_http_client = ssrf_proxy
self._http_request_http_client = graphon_ssrf_proxy
self._bound_tool_file_manager_factory = lambda: DifyToolFileManager(
self._dify_context,
conversation_id_getter=self._conversation_id,
@ -364,10 +365,14 @@ class DifyNodeFactory(NodeFactory):
(including pydantic ValidationError, which subclasses ValueError),
if node type is unknown, or if no implementation exists for the resolved version
"""
typed_node_config = NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config))
typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
node_id = typed_node_config["id"]
node_data = typed_node_config["data"]
node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version))
# Graph configs are initially validated against permissive shared node data.
# Re-validate using the resolved node class so workflow-local node schemas
# stay explicit and constructors receive the concrete typed payload.
resolved_node_data = self._validate_resolved_node_data(node_class, node_data)
node_type = node_data.type
node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
BuiltinNodeTypes.CODE: lambda: {
@ -391,7 +396,7 @@ class DifyNodeFactory(NodeFactory):
},
BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=node_data,
node_data=resolved_node_data,
wrap_model_instance=True,
include_http_client=True,
include_llm_file_saver=True,
@ -405,7 +410,7 @@ class DifyNodeFactory(NodeFactory):
},
BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=node_data,
node_data=resolved_node_data,
wrap_model_instance=True,
include_http_client=True,
include_llm_file_saver=True,
@ -415,7 +420,7 @@ class DifyNodeFactory(NodeFactory):
),
BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=node_data,
node_data=resolved_node_data,
wrap_model_instance=True,
include_http_client=False,
include_llm_file_saver=False,
@ -436,8 +441,8 @@ class DifyNodeFactory(NodeFactory):
}
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
return node_class(
id=node_id,
config=typed_node_config,
node_id=node_id,
config=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
**node_init_kwargs,
@ -448,7 +453,10 @@ class DifyNodeFactory(NodeFactory):
"""
Re-validate the permissive graph payload with the concrete NodeData model declared by the resolved node class.
"""
return node_class.validate_node_data(node_data)
validate_node_data = getattr(node_class, "validate_node_data", None)
if callable(validate_node_data):
return cast("BaseNodeData", validate_node_data(node_data))
return node_data
@staticmethod
def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:

View File

@ -2,7 +2,7 @@ from __future__ import annotations
from collections.abc import Callable, Generator, Mapping, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, Literal, cast, overload
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -41,7 +41,7 @@ from graphon.model_runtime.entities.llm_entities import (
)
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from graphon.model_runtime.entities.model_entities import AIModelEntity
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.nodes.human_input.entities import HumanInputNodeData
from graphon.nodes.llm.runtime_protocols import (
PreparedLLMProtocol,
@ -64,7 +64,7 @@ from models.dataset import SegmentAttachmentBinding
from models.model import UploadFile
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .human_input_compat import (
from .human_input_adapter import (
BoundRecipient,
DeliveryChannelConfig,
DeliveryMethodType,
@ -173,6 +173,28 @@ class DifyPreparedLLM(PreparedLLMProtocol):
def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int:
return self._model_instance.get_llm_num_tokens(prompt_messages)
@overload
def invoke_llm(
self,
*,
prompt_messages: Sequence[PromptMessage],
model_parameters: Mapping[str, Any],
tools: Sequence[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: Literal[False],
) -> LLMResult: ...
@overload
def invoke_llm(
self,
*,
prompt_messages: Sequence[PromptMessage],
model_parameters: Mapping[str, Any],
tools: Sequence[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: Literal[True],
) -> Generator[LLMResultChunk, None, None]: ...
def invoke_llm(
self,
*,
@ -190,6 +212,28 @@ class DifyPreparedLLM(PreparedLLMProtocol):
stream=stream,
)
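# Generic illustration (not part of this diff) of what the Literal overloads in
# this class buy callers: the return type narrows with the ``stream`` flag, so
# no cast is needed at the call site. The function name and payloads are made up.
from typing import Literal, overload

@overload
def _demo_invoke(stream: Literal[False]) -> str: ...
@overload
def _demo_invoke(stream: Literal[True]) -> list[str]: ...
def _demo_invoke(stream: bool) -> str | list[str]:
    return ["chunk-1", "chunk-2"] if stream else "full result"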
@overload
def invoke_llm_with_structured_output(
self,
*,
prompt_messages: Sequence[PromptMessage],
json_schema: Mapping[str, Any],
model_parameters: Mapping[str, Any],
stop: Sequence[str] | None,
stream: Literal[False],
) -> LLMResultWithStructuredOutput: ...
@overload
def invoke_llm_with_structured_output(
self,
*,
prompt_messages: Sequence[PromptMessage],
json_schema: Mapping[str, Any],
model_parameters: Mapping[str, Any],
stop: Sequence[str] | None,
stream: Literal[True],
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
def invoke_llm_with_structured_output(
self,
*,

View File

@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.workflow.system_variables import SystemVariableKey, get_system_text
from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent
from graphon.nodes.base.node import Node
@ -35,18 +34,18 @@ class AgentNode(Node[AgentNodeData]):
def __init__(
self,
id: str,
config: NodeConfigDict,
node_id: str,
config: AgentNodeData,
*,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
*,
strategy_resolver: AgentStrategyResolver,
presentation_provider: AgentStrategyPresentationProvider,
runtime_support: AgentRuntimeSupport,
message_transformer: AgentMessageTransformer,
) -> None:
super().__init__(
id=id,
node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,

View File

@ -7,7 +7,6 @@ from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.workflow.file_reference import resolve_file_record_id
from core.workflow.system_variables import SystemVariableKey, get_system_segment
from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import (
BuiltinNodeTypes,
NodeExecutionType,
@ -36,13 +35,14 @@ class DatasourceNode(Node[DatasourceNodeData]):
def __init__(
self,
id: str,
config: NodeConfigDict,
node_id: str,
config: DatasourceNodeData,
*,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
):
) -> None:
super().__init__(
id=id,
node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,

View File

@ -7,7 +7,6 @@ from core.rag.index_processor.index_processor_base import SummaryIndexSettingDic
from core.rag.summary_index.summary_index import SummaryIndex
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
from core.workflow.system_variables import SystemVariableKey, get_system_segment, get_system_text
from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus
from graphon.node_events import NodeRunResult
from graphon.nodes.base.node import Node
@ -32,12 +31,18 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
def __init__(
self,
id: str,
config: NodeConfigDict,
node_id: str,
config: KnowledgeIndexNodeData,
*,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
) -> None:
super().__init__(id, config, graph_init_params, graph_runtime_state)
super().__init__(
node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self.index_processor = IndexProcessor()
self.summary_index_service = SummaryIndex()

View File

@ -14,7 +14,6 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict,
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.workflow.file_reference import parse_file_reference
from graphon.entities import GraphInitParams
from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import (
BuiltinNodeTypes,
WorkflowNodeExecutionMetadataKey,
@ -50,6 +49,18 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def _normalize_metadata_filter_scalar(value: object) -> str | int | float | None:
if value is None or isinstance(value, (str, float)):
return value
if isinstance(value, int) and not isinstance(value, bool):
return value
return str(value)
def _normalize_metadata_filter_sequence_item(value: object) -> str:
return value if isinstance(value, str) else str(value)
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
node_type = BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL
@ -59,13 +70,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
def __init__(
self,
id: str,
config: NodeConfigDict,
node_id: str,
config: KnowledgeRetrievalNodeData,
*,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
):
) -> None:
super().__init__(
id=id,
node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
@ -282,18 +294,21 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
resolved_conditions: list[Condition] = []
for cond in conditions.conditions or []:
value = cond.value
resolved_value: str | Sequence[str] | int | float | None
if isinstance(value, str):
segment_group = variable_pool.convert_template(value)
if len(segment_group.value) == 1:
resolved_value = segment_group.value[0].to_object()
resolved_value = _normalize_metadata_filter_scalar(segment_group.value[0].to_object())
else:
resolved_value = segment_group.text
elif isinstance(value, Sequence) and all(isinstance(v, str) for v in value):
resolved_values = []
for v in value: # type: ignore
resolved_values: list[str] = []
for v in value:
segment_group = variable_pool.convert_template(v)
if len(segment_group.value) == 1:
resolved_values.append(segment_group.value[0].to_object())
resolved_values.append(
_normalize_metadata_filter_sequence_item(segment_group.value[0].to_object())
)
else:
resolved_values.append(segment_group.text)
resolved_value = resolved_values

View File

@ -6,9 +6,9 @@ import click
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from core.db.session_factory import session_factory
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from events.document_index_event import document_index_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Document
from models.enums import IndexingStatus
@ -22,24 +22,25 @@ def handle(sender, **kwargs):
document_ids = kwargs.get("document_ids", [])
documents = []
start_at = time.perf_counter()
for document_id in document_ids:
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
with session_factory.create_session() as session:
for document_id in document_ids:
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
document = db.session.scalar(
select(Document).where(
Document.id == document_id,
Document.dataset_id == dataset_id,
document = session.scalar(
select(Document).where(
Document.id == document_id,
Document.dataset_id == dataset_id,
)
)
)
if not document:
raise NotFound("Document not found")
if not document:
raise NotFound("Document not found")
document.indexing_status = IndexingStatus.PARSING
document.processing_started_at = naive_utc_now()
documents.append(document)
db.session.add(document)
db.session.commit()
document.indexing_status = IndexingStatus.PARSING
document.processing_started_at = naive_utc_now()
documents.append(document)
session.add(document)
session.commit()
with contextlib.suppress(Exception):
try:

View File

@ -1,5 +1,5 @@
from core.db.session_factory import session_factory
from events.app_event import app_was_created
from extensions.ext_database import db
from models.model import InstalledApp
@ -12,5 +12,6 @@ def handle(sender, **kwargs):
app_id=app.id,
app_owner_tenant_id=app.tenant_id,
)
db.session.add(installed_app)
db.session.commit()
with session_factory.create_session() as session:
session.add(installed_app)
session.commit()

View File

@ -1,5 +1,5 @@
from core.db.session_factory import session_factory
from events.app_event import app_was_created
from extensions.ext_database import db
from models.enums import CustomizeTokenStrategy
from models.model import Site
@ -22,6 +22,6 @@ def handle(sender, **kwargs):
created_by=app.created_by,
updated_by=app.updated_by,
)
db.session.add(site)
db.session.commit()
with session_factory.create_session() as session:
session.add(site)
session.commit()

View File

@ -10,8 +10,8 @@ from typing import Any
from sqlalchemy import select
from core.app.file_access import FileAccessControllerProtocol
from core.db.session_factory import session_factory
from core.workflow.file_reference import build_file_reference
from extensions.ext_database import db
from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers, standardize_file_type
from models import ToolFile, UploadFile
@ -135,29 +135,30 @@ def _build_from_local_file(
UploadFile.id == upload_file_id,
UploadFile.tenant_id == tenant_id,
)
row = db.session.scalar(access_controller.apply_upload_file_filters(stmt))
if row is None:
raise ValueError("Invalid upload file")
with session_factory.create_session() as session:
row = session.scalar(access_controller.apply_upload_file_filters(stmt))
if row is None:
raise ValueError("Invalid upload file")
detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
file_type = _resolve_file_type(
detected_file_type=detected_file_type,
specified_type=mapping.get("type", "custom"),
strict_type_validation=strict_type_validation,
)
detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
file_type = _resolve_file_type(
detected_file_type=detected_file_type,
specified_type=mapping.get("type", "custom"),
strict_type_validation=strict_type_validation,
)
return File(
id=mapping.get("id"),
filename=row.name,
extension="." + row.extension,
mime_type=row.mime_type,
type=file_type,
transfer_method=transfer_method,
remote_url=row.source_url,
reference=build_file_reference(record_id=str(row.id)),
size=row.size,
storage_key=row.key,
)
return File(
file_id=mapping.get("id"),
filename=row.name,
extension="." + row.extension,
mime_type=row.mime_type,
file_type=file_type,
transfer_method=transfer_method,
remote_url=row.source_url,
reference=build_file_reference(record_id=str(row.id)),
size=row.size,
storage_key=row.key,
)
def _build_from_remote_url(
@ -179,32 +180,33 @@ def _build_from_remote_url(
UploadFile.id == upload_file_id,
UploadFile.tenant_id == tenant_id,
)
upload_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt))
if upload_file is None:
raise ValueError("Invalid upload file")
with session_factory.create_session() as session:
upload_file = session.scalar(access_controller.apply_upload_file_filters(stmt))
if upload_file is None:
raise ValueError("Invalid upload file")
detected_file_type = standardize_file_type(
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
)
file_type = _resolve_file_type(
detected_file_type=detected_file_type,
specified_type=mapping.get("type"),
strict_type_validation=strict_type_validation,
)
detected_file_type = standardize_file_type(
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
)
file_type = _resolve_file_type(
detected_file_type=detected_file_type,
specified_type=mapping.get("type"),
strict_type_validation=strict_type_validation,
)
return File(
id=mapping.get("id"),
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
type=file_type,
transfer_method=transfer_method,
remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)),
reference=build_file_reference(record_id=str(upload_file.id)),
size=upload_file.size,
storage_key=upload_file.key,
)
return File(
file_id=mapping.get("id"),
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
file_type=file_type,
transfer_method=transfer_method,
remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)),
reference=build_file_reference(record_id=str(upload_file.id)),
size=upload_file.size,
storage_key=upload_file.key,
)
url = mapping.get("url") or mapping.get("remote_url")
if not url:
@ -220,9 +222,9 @@ def _build_from_remote_url(
)
return File(
id=mapping.get("id"),
file_id=mapping.get("id"),
filename=filename,
type=file_type,
file_type=file_type,
transfer_method=transfer_method,
remote_url=url,
mime_type=mime_type,
@ -247,30 +249,31 @@ def _build_from_tool_file(
ToolFile.id == tool_file_id,
ToolFile.tenant_id == tenant_id,
)
tool_file = db.session.scalar(access_controller.apply_tool_file_filters(stmt))
if tool_file is None:
raise ValueError(f"ToolFile {tool_file_id} not found")
with session_factory.create_session() as session:
tool_file = session.scalar(access_controller.apply_tool_file_filters(stmt))
if tool_file is None:
raise ValueError(f"ToolFile {tool_file_id} not found")
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
file_type = _resolve_file_type(
detected_file_type=detected_file_type,
specified_type=mapping.get("type"),
strict_type_validation=strict_type_validation,
)
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
file_type = _resolve_file_type(
detected_file_type=detected_file_type,
specified_type=mapping.get("type"),
strict_type_validation=strict_type_validation,
)
return File(
id=mapping.get("id"),
filename=tool_file.name,
type=file_type,
transfer_method=transfer_method,
remote_url=tool_file.original_url,
reference=build_file_reference(record_id=str(tool_file.id)),
extension=extension,
mime_type=tool_file.mimetype,
size=tool_file.size,
storage_key=tool_file.file_key,
)
return File(
file_id=mapping.get("id"),
filename=tool_file.name,
file_type=file_type,
transfer_method=transfer_method,
remote_url=tool_file.original_url,
reference=build_file_reference(record_id=str(tool_file.id)),
extension=extension,
mime_type=tool_file.mimetype,
size=tool_file.size,
storage_key=tool_file.file_key,
)
def _build_from_datasource_file(
@ -289,31 +292,32 @@ def _build_from_datasource_file(
UploadFile.id == datasource_file_id,
UploadFile.tenant_id == tenant_id,
)
datasource_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt))
if datasource_file is None:
raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found")
with session_factory.create_session() as session:
datasource_file = session.scalar(access_controller.apply_upload_file_filters(stmt))
if datasource_file is None:
raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found")
extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin"
detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type)
file_type = _resolve_file_type(
detected_file_type=detected_file_type,
specified_type=mapping.get("type"),
strict_type_validation=strict_type_validation,
)
extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin"
detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type)
file_type = _resolve_file_type(
detected_file_type=detected_file_type,
specified_type=mapping.get("type"),
strict_type_validation=strict_type_validation,
)
return File(
id=mapping.get("datasource_file_id"),
filename=datasource_file.name,
type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE,
remote_url=datasource_file.source_url,
reference=build_file_reference(record_id=str(datasource_file.id)),
extension=extension,
mime_type=datasource_file.mime_type,
size=datasource_file.size,
storage_key=datasource_file.key,
url=datasource_file.source_url,
)
return File(
file_id=mapping.get("datasource_file_id"),
filename=datasource_file.name,
file_type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE,
remote_url=datasource_file.source_url,
reference=build_file_reference(record_id=str(datasource_file.id)),
extension=extension,
mime_type=datasource_file.mime_type,
size=datasource_file.size,
storage_key=datasource_file.key,
url=datasource_file.source_url,
)
def _is_valid_mapping(mapping: Mapping[str, Any]) -> bool:

View File

@ -10,9 +10,9 @@ class _VarTypedDict(TypedDict, total=False):
def serialize_value_type(v: _VarTypedDict | Segment) -> str:
if isinstance(v, Segment):
return v.value_type.exposed_type().value
return str(v.value_type.exposed_type())
else:
value_type = v.get("value_type")
if value_type is None:
raise ValueError("value_type is required but not provided")
return value_type.exposed_type().value
return str(value_type.exposed_type())

View File

@ -57,10 +57,10 @@ class ConversationVariableResponse(ResponseModel):
def _normalize_value_type(cls, value: Any) -> str:
exposed_type = getattr(value, "exposed_type", None)
if callable(exposed_type):
return str(exposed_type().value)
return str(exposed_type())
if isinstance(value, str):
try:
return str(SegmentType(value).exposed_type().value)
return str(SegmentType(value).exposed_type())
except ValueError:
return value
try:

View File

@ -26,7 +26,7 @@ class EnvironmentVariableField(fields.Raw):
"id": value.id,
"name": value.name,
"value": value.value,
"value_type": value.value_type.exposed_type().value,
"value_type": str(value.value_type.exposed_type()),
"description": value.description,
}
if isinstance(value, dict):

View File

@ -1,5 +1,5 @@
import contextvars
from collections.abc import Iterator
from collections.abc import Generator # Changed from Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING
@ -13,7 +13,7 @@ if TYPE_CHECKING:
def preserve_flask_contexts(
flask_app: Flask,
context_vars: contextvars.Context,
) -> Iterator[None]:
) -> Generator[None, None, None]: # Changed from Iterator[None]
"""
A context manager that handles:
1. flask-login's UserProxy copy

View File

@ -6,8 +6,8 @@ from flask_login import current_user
from pydantic import TypeAdapter
from sqlalchemy import select
from core.db.session_factory import session_factory
from core.helper.http_client_pooling import get_pooled_http_client
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.source import DataSourceOauthBinding
@ -95,27 +95,28 @@ class NotionOAuth(OAuthDataSource):
pages=pages,
)
# save data source binding
data_source_binding = db.session.scalar(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
with session_factory.create_session() as session:
data_source_binding = session.scalar(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
)
)
)
if data_source_binding:
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.commit()
else:
new_data_source_binding = DataSourceOauthBinding(
tenant_id=current_user.current_tenant_id,
access_token=access_token,
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
provider="notion",
)
db.session.add(new_data_source_binding)
db.session.commit()
if data_source_binding:
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
session.commit()
else:
new_data_source_binding = DataSourceOauthBinding(
tenant_id=current_user.current_tenant_id,
access_token=access_token,
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
provider="notion",
)
session.add(new_data_source_binding)
session.commit()
def save_internal_access_token(self, access_token: str) -> None:
workspace_name = self.notion_workspace_name(access_token)
@ -130,55 +131,57 @@ class NotionOAuth(OAuthDataSource):
pages=pages,
)
# save data source binding
data_source_binding = db.session.scalar(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
with session_factory.create_session() as session:
data_source_binding = session.scalar(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
)
)
)
if data_source_binding:
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.commit()
else:
new_data_source_binding = DataSourceOauthBinding(
tenant_id=current_user.current_tenant_id,
access_token=access_token,
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
provider="notion",
)
db.session.add(new_data_source_binding)
db.session.commit()
if data_source_binding:
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
session.commit()
else:
new_data_source_binding = DataSourceOauthBinding(
tenant_id=current_user.current_tenant_id,
access_token=access_token,
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
provider="notion",
)
session.add(new_data_source_binding)
session.commit()
def sync_data_source(self, binding_id: str) -> None:
# save data source binding
data_source_binding = db.session.scalar(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.id == binding_id,
DataSourceOauthBinding.disabled == False,
with session_factory.create_session() as session:
data_source_binding = session.scalar(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.id == binding_id,
DataSourceOauthBinding.disabled == False,
)
)
)
if data_source_binding:
# get all authorized pages
pages = self.get_authorized_pages(data_source_binding.access_token)
source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info)
new_source_info = self._build_source_info(
workspace_name=source_info["workspace_name"],
workspace_icon=source_info["workspace_icon"],
workspace_id=source_info["workspace_id"],
pages=pages,
)
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info)
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.commit()
else:
raise ValueError("Data source binding not found")
if data_source_binding:
# get all authorized pages
pages = self.get_authorized_pages(data_source_binding.access_token)
source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info)
new_source_info = self._build_source_info(
workspace_name=source_info["workspace_name"],
workspace_icon=source_info["workspace_icon"],
workspace_id=source_info["workspace_id"],
pages=pages,
)
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info)
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
session.commit()
else:
raise ValueError("Data source binding not found")
def get_authorized_pages(self, access_token: str) -> list[NotionPageSummary]:
pages: list[NotionPageSummary] = []

View File

@ -47,23 +47,17 @@ def _cookie_domain() -> str | None:
def _real_cookie_name(cookie_name: str) -> str:
if is_secure() and _cookie_domain() is None:
return "__Host-" + cookie_name
else:
return cookie_name
return cookie_name
def _try_extract_from_header(request: Request) -> str | None:
auth_header = request.headers.get("Authorization")
if auth_header:
if " " not in auth_header:
return None
else:
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != "bearer":
return None
else:
return auth_token
return None
if not auth_header or " " not in auth_header:
return None
auth_scheme, auth_token = auth_header.split(None, 1)
if auth_scheme.lower() != "bearer":
return None
return auth_token
def extract_refresh_token(request: Request) -> str | None:
@ -90,14 +84,9 @@ def extract_webapp_access_token(request: Request) -> str | None:
def extract_webapp_passport(app_code: str, request: Request) -> str | None:
def _try_extract_passport_token_from_cookie(request: Request) -> str | None:
return request.cookies.get(_real_cookie_name(COOKIE_NAME_PASSPORT + "-" + app_code))
def _try_extract_passport_token_from_header(request: Request) -> str | None:
return request.headers.get(HEADER_NAME_PASSPORT)
ret = _try_extract_passport_token_from_cookie(request) or _try_extract_passport_token_from_header(request)
return ret
return request.cookies.get(_real_cookie_name(COOKIE_NAME_PASSPORT + "-" + app_code)) or request.headers.get(
HEADER_NAME_PASSPORT
)
def set_access_token_to_cookie(request: Request, response: Response, token: str, samesite: str = "Lax"):
@ -209,22 +198,18 @@ def check_csrf_token(request: Request, user_id: str):
if not csrf_token:
_unauthorized()
verified = {}
try:
verified = PassportService().verify(csrf_token)
except:
except Exception:
_unauthorized()
raise # unreachable, but helps the type checker see verified is always bound
if verified.get("sub") != user_id:
_unauthorized()
exp: int | None = verified.get("exp")
if not exp:
if not exp or exp < int(datetime.now(UTC).timestamp()):
_unauthorized()
else:
time_now = int(datetime.now().timestamp())
if exp < time_now:
_unauthorized()
def generate_csrf_token(user_id: str) -> str:

api/libs/url_utils.py (new file, 3 lines added)
View File

@ -0,0 +1,3 @@
def normalize_api_base_url(base_url: str) -> str:
"""Normalize a base URL to always end with /v1, avoiding double /v1 suffixes."""
return base_url.rstrip("/").removesuffix("/v1").rstrip("/") + "/v1"

View File

@ -3,6 +3,7 @@
from datetime import datetime
from typing import Optional
import sqlalchemy as sa
from sqlalchemy import Index, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
@ -36,24 +37,24 @@ class WorkflowComment(Base):
__tablename__ = "workflow_comments"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="workflow_comments_pkey"),
sa.PrimaryKeyConstraint("id", name="workflow_comments_pkey"),
Index("workflow_comments_app_idx", "tenant_id", "app_id"),
Index("workflow_comments_created_at_idx", "created_at"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
position_x: Mapped[float] = mapped_column(db.Float)
position_y: Mapped[float] = mapped_column(db.Float)
content: Mapped[str] = mapped_column(db.Text, nullable=False)
position_x: Mapped[float] = mapped_column(sa.Float)
position_y: Mapped[float] = mapped_column(sa.Float)
content: Mapped[str] = mapped_column(sa.Text, nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
resolved: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
resolved_at: Mapped[datetime | None] = mapped_column(db.DateTime)
resolved: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
resolved_at: Mapped[datetime | None] = mapped_column(sa.DateTime)
resolved_by: Mapped[str | None] = mapped_column(StringUUID)
# Relationships
@ -143,20 +144,20 @@ class WorkflowCommentReply(Base):
__tablename__ = "workflow_comment_replies"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"),
sa.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"),
Index("comment_replies_comment_idx", "comment_id"),
Index("comment_replies_created_at_idx", "created_at"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
comment_id: Mapped[str] = mapped_column(
StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
)
content: Mapped[str] = mapped_column(db.Text, nullable=False)
content: Mapped[str] = mapped_column(sa.Text, nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
# Relationships
comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="replies")
@ -187,18 +188,18 @@ class WorkflowCommentMention(Base):
__tablename__ = "workflow_comment_mentions"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"),
sa.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"),
Index("comment_mentions_comment_idx", "comment_id"),
Index("comment_mentions_reply_idx", "reply_id"),
Index("comment_mentions_user_idx", "mentioned_user_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
comment_id: Mapped[str] = mapped_column(
StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
)
reply_id: Mapped[str | None] = mapped_column(
StringUUID, db.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True
StringUUID, sa.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True
)
mentioned_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

View File

@ -1715,7 +1715,7 @@ class SegmentAttachmentBinding(TypeBase):
)
class DocumentSegmentSummary(Base):
class DocumentSegmentSummary(TypeBase):
__tablename__ = "document_segment_summaries"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="document_segment_summaries_pkey"),
@ -1725,25 +1725,40 @@ class DocumentSegmentSummary(Base):
sa.Index("document_segment_summaries_status_idx", "status"),
)
id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
id: Mapped[str] = mapped_column(
StringUUID,
nullable=False,
insert_default=lambda: str(uuid4()),
default_factory=lambda: str(uuid4()),
init=False,
)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# corresponds to DocumentSegment.id or parent chunk id
chunk_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
summary_content: Mapped[str] = mapped_column(LongText, nullable=True)
summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True)
summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True)
tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
status: Mapped[str] = mapped_column(
EnumText(SummaryStatus, length=32), nullable=False, server_default=sa.text("'generating'")
summary_content: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
summary_index_node_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
summary_index_node_hash: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
status: Mapped[SummaryStatus] = mapped_column(
EnumText(SummaryStatus, length=32),
nullable=False,
server_default=sa.text("'generating'"),
default=SummaryStatus.GENERATING,
)
error: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"), default=True)
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
disabled_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
error: Mapped[str] = mapped_column(LongText, nullable=True)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
disabled_by = mapped_column(StringUUID, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
DateTime,
nullable=False,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
init=False,
)
def __repr__(self):

View File

@ -6,7 +6,7 @@ import sqlalchemy as sa
from pydantic import BaseModel, Field
from sqlalchemy.orm import Mapped, mapped_column, relationship
from core.workflow.human_input_compat import DeliveryMethodType
from core.workflow.human_input_adapter import DeliveryMethodType
from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
from libs.helper import generate_string

View File

@ -25,6 +25,7 @@ from graphon.enums import WorkflowExecutionStatus
from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
from graphon.file import helpers as file_helpers
from libs.helper import generate_string # type: ignore[import-not-found]
from libs.url_utils import normalize_api_base_url
from libs.uuid_utils import uuidv7
from models.utils.file_input_compat import build_file_from_input_mapping
@ -446,7 +447,8 @@ class App(Base):
@property
def api_base_url(self) -> str:
return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"
base = dify_config.SERVICE_API_URL or request.host_url.rstrip("/")
return normalize_api_base_url(base)
@property
def tenant(self) -> Tenant | None:

View File

@ -5,7 +5,8 @@ from functools import lru_cache
from typing import Any
from core.workflow.file_reference import parse_file_reference
from graphon.file import File, FileTransferMethod
from graphon.file import File, FileTransferMethod, FileType
from graphon.file.constants import FILE_MODEL_IDENTITY, maybe_file_object
@lru_cache(maxsize=1)
@ -43,6 +44,124 @@ def resolve_file_mapping_tenant_id(
return tenant_resolver()
def build_file_from_mapping_without_lookup(*, file_mapping: Mapping[str, Any]) -> File:
"""Build a graph `File` directly from serialized metadata."""
def _coerce_file_type(value: Any) -> FileType:
if isinstance(value, FileType):
return value
if isinstance(value, str):
return FileType.value_of(value)
raise ValueError("file type is required in file mapping")
mapping = dict(file_mapping)
transfer_method_value = mapping.get("transfer_method")
if isinstance(transfer_method_value, FileTransferMethod):
transfer_method = transfer_method_value
elif isinstance(transfer_method_value, str):
transfer_method = FileTransferMethod.value_of(transfer_method_value)
else:
raise ValueError("transfer_method is required in file mapping")
file_id = mapping.get("file_id")
if not isinstance(file_id, str) or not file_id:
legacy_id = mapping.get("id")
file_id = legacy_id if isinstance(legacy_id, str) and legacy_id else None
related_id = resolve_file_record_id(mapping)
if related_id is None:
raw_related_id = mapping.get("related_id")
related_id = raw_related_id if isinstance(raw_related_id, str) and raw_related_id else None
remote_url = mapping.get("remote_url")
if not isinstance(remote_url, str) or not remote_url:
url = mapping.get("url")
remote_url = url if isinstance(url, str) and url else None
reference = mapping.get("reference")
if not isinstance(reference, str) or not reference:
reference = None
filename = mapping.get("filename")
if not isinstance(filename, str):
filename = None
extension = mapping.get("extension")
if not isinstance(extension, str):
extension = None
mime_type = mapping.get("mime_type")
if not isinstance(mime_type, str):
mime_type = None
size = mapping.get("size", -1)
if not isinstance(size, int):
size = -1
storage_key = mapping.get("storage_key")
if not isinstance(storage_key, str):
storage_key = None
tenant_id = mapping.get("tenant_id")
if not isinstance(tenant_id, str):
tenant_id = None
dify_model_identity = mapping.get("dify_model_identity")
if not isinstance(dify_model_identity, str):
dify_model_identity = FILE_MODEL_IDENTITY
tool_file_id = mapping.get("tool_file_id")
if not isinstance(tool_file_id, str):
tool_file_id = None
upload_file_id = mapping.get("upload_file_id")
if not isinstance(upload_file_id, str):
upload_file_id = None
datasource_file_id = mapping.get("datasource_file_id")
if not isinstance(datasource_file_id, str):
datasource_file_id = None
return File(
file_id=file_id,
tenant_id=tenant_id,
file_type=_coerce_file_type(mapping.get("file_type", mapping.get("type"))),
transfer_method=transfer_method,
remote_url=remote_url,
reference=reference,
related_id=related_id,
filename=filename,
extension=extension,
mime_type=mime_type,
size=size,
storage_key=storage_key,
dify_model_identity=dify_model_identity,
url=remote_url,
tool_file_id=tool_file_id,
upload_file_id=upload_file_id,
datasource_file_id=datasource_file_id,
)
def rebuild_serialized_graph_files_without_lookup(value: Any) -> Any:
"""Recursively rebuild serialized graph file payloads into `File` objects.
`graphon` 0.2.2 no longer accepts legacy serialized file mappings via
`model_validate_json()`. Dify keeps this recovery path at the model boundary
so historical JSON blobs remain readable without reintroducing global graph
patches or test-local coercion.
"""
if isinstance(value, list):
return [rebuild_serialized_graph_files_without_lookup(item) for item in value]
if isinstance(value, dict):
if maybe_file_object(value):
return build_file_from_mapping_without_lookup(file_mapping=value)
return {key: rebuild_serialized_graph_files_without_lookup(item) for key, item in value.items()}
return value
def build_file_from_stored_mapping(
*,
file_mapping: Mapping[str, Any],
@ -76,12 +195,7 @@ def build_file_from_stored_mapping(
pass
if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None:
remote_url = mapping.get("remote_url")
if not isinstance(remote_url, str) or not remote_url:
url = mapping.get("url")
if isinstance(url, str) and url:
mapping["remote_url"] = url
return File.model_validate(mapping)
return build_file_from_mapping_without_lookup(file_mapping=mapping)
return file_factory.build_from_mapping(
mapping=mapping,

View File

@ -24,7 +24,7 @@ from sqlalchemy.orm import Mapped, mapped_column
from typing_extensions import deprecated
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from core.workflow.human_input_compat import normalize_node_config_for_graph
from core.workflow.human_input_adapter import adapt_node_config_for_graph
from core.workflow.variable_prefixes import (
CONVERSATION_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
@ -64,7 +64,10 @@ from .base import Base, DefaultFieldsDCMixin, TypeBase
from .engine import db
from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom
from .types import EnumText, LongText, StringUUID
from .utils.file_input_compat import build_file_from_stored_mapping
from .utils.file_input_compat import (
build_file_from_mapping_without_lookup,
build_file_from_stored_mapping,
)
logger = logging.getLogger(__name__)
@ -290,7 +293,7 @@ class Workflow(Base): # bug
node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes))
except StopIteration:
raise NodeNotFoundError(node_id)
return NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config))
return NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
@staticmethod
def get_node_type_from_node_config(node_config: NodeConfigDict) -> NodeType:
@ -1688,7 +1691,7 @@ class WorkflowDraftVariable(Base):
return cast(Any, value)
normalized_file = dict(value)
normalized_file.pop("tenant_id", None)
return File.model_validate(normalized_file)
return build_file_from_mapping_without_lookup(file_mapping=normalized_file)
elif isinstance(value, list) and value:
value_list = cast(list[Any], value)
first: Any = value_list[0]
@ -1698,7 +1701,7 @@ class WorkflowDraftVariable(Base):
for item in value_list:
normalized_file = dict(cast(dict[str, Any], item))
normalized_file.pop("tenant_id", None)
file_list.append(File.model_validate(normalized_file))
file_list.append(build_file_from_mapping_without_lookup(file_mapping=normalized_file))
return cast(Any, file_list)
else:
return cast(Any, value)

View File

@ -10,3 +10,6 @@ This directory holds **optional workspace packages** that plug into Dify's API
Provider tests often live next to the package, e.g. `providers/<type>/<backend>/tests/unit_tests/`. Shared fixtures may live under `providers/` (e.g. `conftest.py`).
## Excluding Providers
In order to build with selected providers, use `--no-group vdb-all` and `--no-group trace-all` to disable default ones, then use `--group vdb-<provider>` and `--group trace-<provider>` to enable specific providers.

View File

@ -0,0 +1,78 @@
# Trace providers
This directory holds **optional workspace packages** that send Dify **ops tracing** data (workflows, messages, tools, moderation, etc.) to an external observability backend (Langfuse, LangSmith, OpenTelemetry-style exporters, and others).
Unlike VDB providers, trace plugins are **not** discovered via entry points. The API core imports your package **explicitly** from `core/ops/ops_trace_manager.py` after you register the provider id and mapping.
## Architecture
| Layer | Location | Role |
|--------|----------|------|
| Contracts | `api/core/ops/base_trace_instance.py`, `api/core/ops/entities/trace_entity.py`, `api/core/ops/entities/config_entity.py` | `BaseTraceInstance`, `BaseTracingConfig`, and typed `*TraceInfo` payloads |
| Registry | `api/core/ops/ops_trace_manager.py` | `TracingProviderEnum`, `OpsTraceProviderConfigMap` — maps provider **string** → config class, encrypted keys, and trace class |
| Your package | `api/providers/trace/trace-<name>/` | Pydantic config + subclass of `BaseTraceInstance` |
At runtime, `OpsTraceManager` decrypts stored credentials, builds your config model, caches a trace instance, and calls `trace(trace_info)` with a concrete `BaseTraceInfo` subtype.
## What you implement
### 1. Config model (`BaseTracingConfig`)
Subclass `BaseTracingConfig` from `core.ops.entities.config_entity`. Use Pydantic validators; reuse helpers from `core.ops.utils` (for example `validate_url`, `validate_url_with_path`, `validate_project_name`) where appropriate.
Fields fall into two groups used by the manager:
- **`secret_keys`** — names of fields that are **encrypted at rest** (API keys, tokens, passwords).
- **`other_keys`** — non-secret connection settings (hosts, project names, endpoints).
List these key names in your `OpsTraceProviderConfigMap` entry so encrypt/decrypt and merge logic stay correct.
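
For orientation, a minimal sketch of such a config model is shown below, assuming a hypothetical `mybackend` provider. Only `BaseTracingConfig` and the Pydantic-validator convention come from the contracts above; the field names and defaults are illustrative, not Dify's actual API.

```python
# Hypothetical provider config; field names are illustrative only.
from pydantic import field_validator

from core.ops.entities.config_entity import BaseTracingConfig


class MyBackendConfig(BaseTracingConfig):
    api_key: str                              # secret -> list under `secret_keys`
    host: str = "https://mybackend.example"   # non-secret -> list under `other_keys`
    project: str = "default"                  # non-secret -> list under `other_keys`

    @field_validator("host")
    @classmethod
    def _normalize_host(cls, value: str) -> str:
        # Drop a trailing slash so later URL joins stay predictable.
        return value.rstrip("/")
```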
### 2. Trace instance (`BaseTraceInstance`)
Subclass `BaseTraceInstance` and implement:
```python
def trace(self, trace_info: BaseTraceInfo) -> None:
...
```
Dispatch on the concrete type with `isinstance` (see `trace_langfuse` or `trace_langsmith` for full patterns). Payload types are defined in `core/ops/entities/trace_entity.py`, including:
- `WorkflowTraceInfo`, `WorkflowNodeTraceInfo`, `DraftNodeExecutionTrace`
- `MessageTraceInfo`, `ToolTraceInfo`, `ModerationTraceInfo`, `SuggestedQuestionTraceInfo`
- `DatasetRetrievalTraceInfo`, `GenerateNameTraceInfo`, `PromptGenerationTraceInfo`
You may ignore categories your backend does not support; existing providers often no-op unhandled types.
Optional: use `get_service_account_with_tenant(app_id)` from the base class when you need tenant-scoped account context.
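
A compact sketch of that dispatch pattern, reusing the hypothetical `MyBackendConfig` provider from above; the `_export` helper is a placeholder for your vendor SDK or HTTP client, not part of Dify's API.

```python
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.trace_entity import (
    BaseTraceInfo,
    MessageTraceInfo,
    WorkflowTraceInfo,
)


class MyBackendDataTrace(BaseTraceInstance):
    def trace(self, trace_info: BaseTraceInfo) -> None:
        # Dispatch on the concrete payload type; unsupported categories stay no-ops.
        if isinstance(trace_info, WorkflowTraceInfo):
            self._export("workflow", trace_info)
        elif isinstance(trace_info, MessageTraceInfo):
            self._export("message", trace_info)

    def _export(self, kind: str, trace_info: BaseTraceInfo) -> None:
        ...  # placeholder: send the payload to the backend with its SDK
```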
### 3. Register in the API core
Upstream changes are required so Dify knows your provider exists:
1. **`TracingProviderEnum`** (`api/core/ops/entities/config_entity.py`) — add a new member whose **value** is the stable string stored in app tracing config (e.g. `"mybackend"`).
2. **`OpsTraceProviderConfigMap.__getitem__`** (`api/core/ops/ops_trace_manager.py`) — add a `match` case for that enum member returning:
- `config_class`: your Pydantic config type
- `secret_keys` / `other_keys`: lists of field names as above
- `trace_instance`: your `BaseTraceInstance` subclass
Lazy-import your package inside the case so missing optional installs raise a clear `ImportError`.
If the `match` case is missing, the provider string will not resolve and tracing will be disabled for that app.
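
Sketched below is what the new `case` might look like, assuming a hypothetical `MYBACKEND` enum member and a `dify_trace_mybackend` package. It is a fragment of `OpsTraceProviderConfigMap.__getitem__`; copy the exact dict shape from the existing cases in `ops_trace_manager.py`.

```python
# Fragment of OpsTraceProviderConfigMap.__getitem__ (hypothetical provider).
match provider:
    case TracingProviderEnum.MYBACKEND:
        # Lazy import: a missing optional install fails with a clear ImportError.
        from dify_trace_mybackend.config import MyBackendConfig
        from dify_trace_mybackend.mybackend_trace import MyBackendDataTrace

        return {
            "config_class": MyBackendConfig,
            "secret_keys": ["api_key"],
            "other_keys": ["host", "project"],
            "trace_instance": MyBackendDataTrace,
        }
```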
## Package layout
Each provider is a normal uv workspace member, for example:
- `api/providers/trace/trace-<name>/pyproject.toml` — project name `dify-trace-<name>`, dependencies on vendor SDKs
`api/providers/trace/trace-<name>/src/dify_trace_<name>/` with `config.py`, `<name>_trace.py`, optional `entities/`, and an empty **`py.typed`** file (PEP 561) so the API type checker treats the package as typed; list `py.typed` under `[tool.setuptools.package-data]` for that import name in `pyproject.toml`.
Reference implementations: `trace-langfuse/`, `trace-langsmith/`, `trace-opik/`.
## Wiring into the `api` workspace
In `api/pyproject.toml`:
1. **`[tool.uv.sources]`** — `dify-trace-<name> = { workspace = true }`
2. **`[dependency-groups]`** — add `trace-<name> = ["dify-trace-<name>"]` and include `dify-trace-<name>` in `trace-all` if it should ship with the default bundle
After changing metadata, run **`uv sync`** from `api/`.

View File

@ -0,0 +1,14 @@
[project]
name = "dify-trace-aliyun"
version = "0.0.1"
dependencies = [
# versions inherited from parent
"opentelemetry-api",
"opentelemetry-exporter-otlp-proto-grpc",
"opentelemetry-sdk",
"opentelemetry-semantic-conventions",
]
description = "Dify ops tracing provider (Aliyun)."
[tool.setuptools.packages.find]
where = ["src"]

Some files were not shown because too many files have changed in this diff