mirror of
https://github.com/langgenius/dify.git
synced 2026-05-13 08:57:28 +08:00
Merge branch 'main' into jzh
This commit is contained in:
commit
78240ed199
9
.github/labeler.yml
vendored
9
.github/labeler.yml
vendored
@ -1,3 +1,10 @@
|
|||||||
web:
|
web:
|
||||||
- changed-files:
|
- changed-files:
|
||||||
- any-glob-to-any-file: 'web/**'
|
- any-glob-to-any-file:
|
||||||
|
- 'web/**'
|
||||||
|
- 'packages/**'
|
||||||
|
- 'package.json'
|
||||||
|
- 'pnpm-lock.yaml'
|
||||||
|
- 'pnpm-workspace.yaml'
|
||||||
|
- '.npmrc'
|
||||||
|
- '.nvmrc'
|
||||||
|
|||||||
82
.github/scripts/generate-i18n-changes.mjs
vendored
Normal file
82
.github/scripts/generate-i18n-changes.mjs
vendored
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
import { execFileSync } from 'node:child_process'
|
||||||
|
import fs from 'node:fs'
|
||||||
|
import path from 'node:path'
|
||||||
|
|
||||||
|
const repoRoot = process.cwd()
|
||||||
|
const baseSha = process.env.BASE_SHA || ''
|
||||||
|
const headSha = process.env.HEAD_SHA || ''
|
||||||
|
const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean)
|
||||||
|
const outputPath = process.env.I18N_CHANGES_OUTPUT_PATH || '/tmp/i18n-changes.json'
|
||||||
|
|
||||||
|
const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`)
|
||||||
|
|
||||||
|
const readCurrentJson = (fileStem) => {
|
||||||
|
const filePath = englishPath(fileStem)
|
||||||
|
if (!fs.existsSync(filePath))
|
||||||
|
return null
|
||||||
|
|
||||||
|
return JSON.parse(fs.readFileSync(filePath, 'utf8'))
|
||||||
|
}
|
||||||
|
|
||||||
|
const readBaseJson = (fileStem) => {
|
||||||
|
if (!baseSha)
|
||||||
|
return null
|
||||||
|
|
||||||
|
try {
|
||||||
|
const relativePath = `web/i18n/en-US/${fileStem}.json`
|
||||||
|
const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' })
|
||||||
|
return JSON.parse(content)
|
||||||
|
}
|
||||||
|
catch {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue)
|
||||||
|
|
||||||
|
const changes = {}
|
||||||
|
|
||||||
|
for (const fileStem of files) {
|
||||||
|
const currentJson = readCurrentJson(fileStem)
|
||||||
|
const beforeJson = readBaseJson(fileStem) || {}
|
||||||
|
const afterJson = currentJson || {}
|
||||||
|
const added = {}
|
||||||
|
const updated = {}
|
||||||
|
const deleted = []
|
||||||
|
|
||||||
|
for (const [key, value] of Object.entries(afterJson)) {
|
||||||
|
if (!(key in beforeJson)) {
|
||||||
|
added[key] = value
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!compareJson(beforeJson[key], value)) {
|
||||||
|
updated[key] = {
|
||||||
|
before: beforeJson[key],
|
||||||
|
after: value,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const key of Object.keys(beforeJson)) {
|
||||||
|
if (!(key in afterJson))
|
||||||
|
deleted.push(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
changes[fileStem] = {
|
||||||
|
fileDeleted: currentJson === null,
|
||||||
|
added,
|
||||||
|
updated,
|
||||||
|
deleted,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fs.writeFileSync(
|
||||||
|
outputPath,
|
||||||
|
JSON.stringify({
|
||||||
|
baseSha,
|
||||||
|
headSha,
|
||||||
|
files,
|
||||||
|
changes,
|
||||||
|
})
|
||||||
|
)
|
||||||
2
.github/workflows/autofix.yml
vendored
2
.github/workflows/autofix.yml
vendored
@ -39,9 +39,11 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
files: |
|
files: |
|
||||||
web/**
|
web/**
|
||||||
|
packages/**
|
||||||
package.json
|
package.json
|
||||||
pnpm-lock.yaml
|
pnpm-lock.yaml
|
||||||
pnpm-workspace.yaml
|
pnpm-workspace.yaml
|
||||||
|
.npmrc
|
||||||
.nvmrc
|
.nvmrc
|
||||||
- name: Check api inputs
|
- name: Check api inputs
|
||||||
if: github.event_name != 'merge_group'
|
if: github.event_name != 'merge_group'
|
||||||
|
|||||||
2
.github/workflows/docker-build.yml
vendored
2
.github/workflows/docker-build.yml
vendored
@ -8,9 +8,11 @@ on:
|
|||||||
- api/Dockerfile
|
- api/Dockerfile
|
||||||
- web/docker/**
|
- web/docker/**
|
||||||
- web/Dockerfile
|
- web/Dockerfile
|
||||||
|
- packages/**
|
||||||
- package.json
|
- package.json
|
||||||
- pnpm-lock.yaml
|
- pnpm-lock.yaml
|
||||||
- pnpm-workspace.yaml
|
- pnpm-workspace.yaml
|
||||||
|
- .npmrc
|
||||||
- .nvmrc
|
- .nvmrc
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
|
|||||||
4
.github/workflows/main-ci.yml
vendored
4
.github/workflows/main-ci.yml
vendored
@ -65,9 +65,11 @@ jobs:
|
|||||||
- 'docker/volumes/sandbox/conf/**'
|
- 'docker/volumes/sandbox/conf/**'
|
||||||
web:
|
web:
|
||||||
- 'web/**'
|
- 'web/**'
|
||||||
|
- 'packages/**'
|
||||||
- 'package.json'
|
- 'package.json'
|
||||||
- 'pnpm-lock.yaml'
|
- 'pnpm-lock.yaml'
|
||||||
- 'pnpm-workspace.yaml'
|
- 'pnpm-workspace.yaml'
|
||||||
|
- '.npmrc'
|
||||||
- '.nvmrc'
|
- '.nvmrc'
|
||||||
- '.github/workflows/web-tests.yml'
|
- '.github/workflows/web-tests.yml'
|
||||||
- '.github/actions/setup-web/**'
|
- '.github/actions/setup-web/**'
|
||||||
@ -77,9 +79,11 @@ jobs:
|
|||||||
- 'api/uv.lock'
|
- 'api/uv.lock'
|
||||||
- 'e2e/**'
|
- 'e2e/**'
|
||||||
- 'web/**'
|
- 'web/**'
|
||||||
|
- 'packages/**'
|
||||||
- 'package.json'
|
- 'package.json'
|
||||||
- 'pnpm-lock.yaml'
|
- 'pnpm-lock.yaml'
|
||||||
- 'pnpm-workspace.yaml'
|
- 'pnpm-workspace.yaml'
|
||||||
|
- '.npmrc'
|
||||||
- '.nvmrc'
|
- '.nvmrc'
|
||||||
- 'docker/docker-compose.middleware.yaml'
|
- 'docker/docker-compose.middleware.yaml'
|
||||||
- 'docker/middleware.env.example'
|
- 'docker/middleware.env.example'
|
||||||
|
|||||||
2
.github/workflows/style.yml
vendored
2
.github/workflows/style.yml
vendored
@ -77,9 +77,11 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
files: |
|
files: |
|
||||||
web/**
|
web/**
|
||||||
|
packages/**
|
||||||
package.json
|
package.json
|
||||||
pnpm-lock.yaml
|
pnpm-lock.yaml
|
||||||
pnpm-workspace.yaml
|
pnpm-workspace.yaml
|
||||||
|
.npmrc
|
||||||
.nvmrc
|
.nvmrc
|
||||||
.github/workflows/style.yml
|
.github/workflows/style.yml
|
||||||
.github/actions/setup-web/**
|
.github/actions/setup-web/**
|
||||||
|
|||||||
1
.github/workflows/tool-test-sdks.yaml
vendored
1
.github/workflows/tool-test-sdks.yaml
vendored
@ -9,6 +9,7 @@ on:
|
|||||||
- package.json
|
- package.json
|
||||||
- pnpm-lock.yaml
|
- pnpm-lock.yaml
|
||||||
- pnpm-workspace.yaml
|
- pnpm-workspace.yaml
|
||||||
|
- .npmrc
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: sdk-tests-${{ github.head_ref || github.run_id }}
|
group: sdk-tests-${{ github.head_ref || github.run_id }}
|
||||||
|
|||||||
101
.github/workflows/translate-i18n-claude.yml
vendored
101
.github/workflows/translate-i18n-claude.yml
vendored
@ -68,89 +68,7 @@ jobs:
|
|||||||
" web/i18n-config/languages.ts | sed 's/[[:space:]]*$//')
|
" web/i18n-config/languages.ts | sed 's/[[:space:]]*$//')
|
||||||
|
|
||||||
generate_changes_json() {
|
generate_changes_json() {
|
||||||
node <<'NODE'
|
node .github/scripts/generate-i18n-changes.mjs
|
||||||
const { execFileSync } = require('node:child_process')
|
|
||||||
const fs = require('node:fs')
|
|
||||||
const path = require('node:path')
|
|
||||||
|
|
||||||
const repoRoot = process.cwd()
|
|
||||||
const baseSha = process.env.BASE_SHA || ''
|
|
||||||
const headSha = process.env.HEAD_SHA || ''
|
|
||||||
const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean)
|
|
||||||
|
|
||||||
const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`)
|
|
||||||
|
|
||||||
const readCurrentJson = (fileStem) => {
|
|
||||||
const filePath = englishPath(fileStem)
|
|
||||||
if (!fs.existsSync(filePath))
|
|
||||||
return null
|
|
||||||
|
|
||||||
return JSON.parse(fs.readFileSync(filePath, 'utf8'))
|
|
||||||
}
|
|
||||||
|
|
||||||
const readBaseJson = (fileStem) => {
|
|
||||||
if (!baseSha)
|
|
||||||
return null
|
|
||||||
|
|
||||||
try {
|
|
||||||
const relativePath = `web/i18n/en-US/${fileStem}.json`
|
|
||||||
const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' })
|
|
||||||
return JSON.parse(content)
|
|
||||||
}
|
|
||||||
catch (error) {
|
|
||||||
return null
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue)
|
|
||||||
|
|
||||||
const changes = {}
|
|
||||||
|
|
||||||
for (const fileStem of files) {
|
|
||||||
const currentJson = readCurrentJson(fileStem)
|
|
||||||
const beforeJson = readBaseJson(fileStem) || {}
|
|
||||||
const afterJson = currentJson || {}
|
|
||||||
const added = {}
|
|
||||||
const updated = {}
|
|
||||||
const deleted = []
|
|
||||||
|
|
||||||
for (const [key, value] of Object.entries(afterJson)) {
|
|
||||||
if (!(key in beforeJson)) {
|
|
||||||
added[key] = value
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!compareJson(beforeJson[key], value)) {
|
|
||||||
updated[key] = {
|
|
||||||
before: beforeJson[key],
|
|
||||||
after: value,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (const key of Object.keys(beforeJson)) {
|
|
||||||
if (!(key in afterJson))
|
|
||||||
deleted.push(key)
|
|
||||||
}
|
|
||||||
|
|
||||||
changes[fileStem] = {
|
|
||||||
fileDeleted: currentJson === null,
|
|
||||||
added,
|
|
||||||
updated,
|
|
||||||
deleted,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fs.writeFileSync(
|
|
||||||
'/tmp/i18n-changes.json',
|
|
||||||
JSON.stringify({
|
|
||||||
baseSha,
|
|
||||||
headSha,
|
|
||||||
files,
|
|
||||||
changes,
|
|
||||||
})
|
|
||||||
)
|
|
||||||
NODE
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if [ "${{ github.event_name }}" = "repository_dispatch" ]; then
|
if [ "${{ github.event_name }}" = "repository_dispatch" ]; then
|
||||||
@ -270,7 +188,7 @@ jobs:
|
|||||||
Tool rules:
|
Tool rules:
|
||||||
- Use Read for repository files.
|
- Use Read for repository files.
|
||||||
- Use Edit for JSON updates.
|
- Use Edit for JSON updates.
|
||||||
- Use Bash only for `pnpm`.
|
- Use Bash only for `vp`.
|
||||||
- Do not use Bash for `git`, `gh`, or branch management.
|
- Do not use Bash for `git`, `gh`, or branch management.
|
||||||
|
|
||||||
Required execution plan:
|
Required execution plan:
|
||||||
@ -292,7 +210,7 @@ jobs:
|
|||||||
- Read the current English JSON file for any file that still exists so wording, placeholders, and surrounding terminology stay accurate.
|
- Read the current English JSON file for any file that still exists so wording, placeholders, and surrounding terminology stay accurate.
|
||||||
- If `Structured change set available` is `false`, treat this as a scoped full sync and use the current English files plus scoped checks as the source of truth.
|
- If `Structured change set available` is `false`, treat this as a scoped full sync and use the current English files plus scoped checks as the source of truth.
|
||||||
4. Run a scoped pre-check before editing:
|
4. Run a scoped pre-check before editing:
|
||||||
- `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
|
- `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
|
||||||
- Use this command as the source of truth for missing and extra keys inside the current scope.
|
- Use this command as the source of truth for missing and extra keys inside the current scope.
|
||||||
5. Apply translations.
|
5. Apply translations.
|
||||||
- For every target language and scoped file:
|
- For every target language and scoped file:
|
||||||
@ -300,19 +218,19 @@ jobs:
|
|||||||
- If the locale file does not exist yet, create it with `Write` and then continue with `Edit` as needed.
|
- If the locale file does not exist yet, create it with `Write` and then continue with `Edit` as needed.
|
||||||
- ADD missing keys.
|
- ADD missing keys.
|
||||||
- UPDATE stale translations when the English value changed.
|
- UPDATE stale translations when the English value changed.
|
||||||
- DELETE removed keys. Prefer `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }} --auto-remove` for extra keys so deletions stay in scope.
|
- DELETE removed keys. Prefer `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }} --auto-remove` for extra keys so deletions stay in scope.
|
||||||
- Preserve placeholders exactly: `{{variable}}`, `${variable}`, HTML tags, component tags, and variable names.
|
- Preserve placeholders exactly: `{{variable}}`, `${variable}`, HTML tags, component tags, and variable names.
|
||||||
- Match the existing terminology and register used by each locale.
|
- Match the existing terminology and register used by each locale.
|
||||||
- Prefer one Edit per file when stable, but prioritize correctness over batching.
|
- Prefer one Edit per file when stable, but prioritize correctness over batching.
|
||||||
6. Verify only the edited files.
|
6. Verify only the edited files.
|
||||||
- Run `pnpm --dir ${{ github.workspace }}/web lint:fix --quiet -- <relative edited i18n file paths>`
|
- Run `vp run dify-web#lint:fix --quiet -- <relative edited i18n file paths under web/>`
|
||||||
- Run `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
|
- Run `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
|
||||||
- If verification fails, fix the remaining problems before continuing.
|
- If verification fails, fix the remaining problems before continuing.
|
||||||
7. Stop after the scoped locale files are updated and verification passes.
|
7. Stop after the scoped locale files are updated and verification passes.
|
||||||
- Do not create branches, commits, or pull requests.
|
- Do not create branches, commits, or pull requests.
|
||||||
claude_args: |
|
claude_args: |
|
||||||
--max-turns 120
|
--max-turns 120
|
||||||
--allowedTools "Read,Write,Edit,Bash(pnpm *),Bash(pnpm:*),Glob,Grep"
|
--allowedTools "Read,Write,Edit,Bash(vp *),Bash(vp:*),Glob,Grep"
|
||||||
|
|
||||||
- name: Prepare branch metadata
|
- name: Prepare branch metadata
|
||||||
id: pr_meta
|
id: pr_meta
|
||||||
@ -354,6 +272,7 @@ jobs:
|
|||||||
- name: Create or update translation PR
|
- name: Create or update translation PR
|
||||||
if: steps.pr_meta.outputs.has_changes == 'true'
|
if: steps.pr_meta.outputs.has_changes == 'true'
|
||||||
env:
|
env:
|
||||||
|
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
BRANCH_NAME: ${{ steps.pr_meta.outputs.branch_name }}
|
BRANCH_NAME: ${{ steps.pr_meta.outputs.branch_name }}
|
||||||
FILES_IN_SCOPE: ${{ steps.context.outputs.CHANGED_FILES }}
|
FILES_IN_SCOPE: ${{ steps.context.outputs.CHANGED_FILES }}
|
||||||
TARGET_LANGS: ${{ steps.context.outputs.TARGET_LANGS }}
|
TARGET_LANGS: ${{ steps.context.outputs.TARGET_LANGS }}
|
||||||
@ -402,8 +321,8 @@ jobs:
|
|||||||
'',
|
'',
|
||||||
'## Verification',
|
'## Verification',
|
||||||
'',
|
'',
|
||||||
`- \`pnpm --dir web run i18n:check --file ${process.env.FILES_IN_SCOPE} --lang ${process.env.TARGET_LANGS}\``,
|
`- \`vp run dify-web#i18n:check --file ${process.env.FILES_IN_SCOPE} --lang ${process.env.TARGET_LANGS}\``,
|
||||||
`- \`pnpm --dir web lint:fix --quiet -- <edited i18n files>\``,
|
`- \`vp run dify-web#lint:fix --quiet -- <edited i18n files under web/>\``,
|
||||||
'',
|
'',
|
||||||
'## Notes',
|
'## Notes',
|
||||||
'',
|
'',
|
||||||
|
|||||||
83
.github/workflows/trigger-i18n-sync.yml
vendored
83
.github/workflows/trigger-i18n-sync.yml
vendored
@ -42,88 +42,7 @@ jobs:
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
export BASE_SHA HEAD_SHA CHANGED_FILES
|
export BASE_SHA HEAD_SHA CHANGED_FILES
|
||||||
node <<'NODE'
|
node .github/scripts/generate-i18n-changes.mjs
|
||||||
const { execFileSync } = require('node:child_process')
|
|
||||||
const fs = require('node:fs')
|
|
||||||
const path = require('node:path')
|
|
||||||
|
|
||||||
const repoRoot = process.cwd()
|
|
||||||
const baseSha = process.env.BASE_SHA || ''
|
|
||||||
const headSha = process.env.HEAD_SHA || ''
|
|
||||||
const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean)
|
|
||||||
|
|
||||||
const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`)
|
|
||||||
|
|
||||||
const readCurrentJson = (fileStem) => {
|
|
||||||
const filePath = englishPath(fileStem)
|
|
||||||
if (!fs.existsSync(filePath))
|
|
||||||
return null
|
|
||||||
|
|
||||||
return JSON.parse(fs.readFileSync(filePath, 'utf8'))
|
|
||||||
}
|
|
||||||
|
|
||||||
const readBaseJson = (fileStem) => {
|
|
||||||
if (!baseSha)
|
|
||||||
return null
|
|
||||||
|
|
||||||
try {
|
|
||||||
const relativePath = `web/i18n/en-US/${fileStem}.json`
|
|
||||||
const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' })
|
|
||||||
return JSON.parse(content)
|
|
||||||
}
|
|
||||||
catch (error) {
|
|
||||||
return null
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue)
|
|
||||||
|
|
||||||
const changes = {}
|
|
||||||
|
|
||||||
for (const fileStem of files) {
|
|
||||||
const beforeJson = readBaseJson(fileStem) || {}
|
|
||||||
const afterJson = readCurrentJson(fileStem) || {}
|
|
||||||
const added = {}
|
|
||||||
const updated = {}
|
|
||||||
const deleted = []
|
|
||||||
|
|
||||||
for (const [key, value] of Object.entries(afterJson)) {
|
|
||||||
if (!(key in beforeJson)) {
|
|
||||||
added[key] = value
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!compareJson(beforeJson[key], value)) {
|
|
||||||
updated[key] = {
|
|
||||||
before: beforeJson[key],
|
|
||||||
after: value,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (const key of Object.keys(beforeJson)) {
|
|
||||||
if (!(key in afterJson))
|
|
||||||
deleted.push(key)
|
|
||||||
}
|
|
||||||
|
|
||||||
changes[fileStem] = {
|
|
||||||
fileDeleted: readCurrentJson(fileStem) === null,
|
|
||||||
added,
|
|
||||||
updated,
|
|
||||||
deleted,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fs.writeFileSync(
|
|
||||||
'/tmp/i18n-changes.json',
|
|
||||||
JSON.stringify({
|
|
||||||
baseSha,
|
|
||||||
headSha,
|
|
||||||
files,
|
|
||||||
changes,
|
|
||||||
})
|
|
||||||
)
|
|
||||||
NODE
|
|
||||||
|
|
||||||
if [ -n "$CHANGED_FILES" ]; then
|
if [ -n "$CHANGED_FILES" ]; then
|
||||||
echo "has_changes=true" >> "$GITHUB_OUTPUT"
|
echo "has_changes=true" >> "$GITHUB_OUTPUT"
|
||||||
|
|||||||
@ -81,8 +81,8 @@ if $web_modified; then
|
|||||||
|
|
||||||
if $web_ts_modified; then
|
if $web_ts_modified; then
|
||||||
echo "Running TypeScript type-check:tsgo"
|
echo "Running TypeScript type-check:tsgo"
|
||||||
if ! pnpm run type-check:tsgo; then
|
if ! npm run type-check:tsgo; then
|
||||||
echo "Type check failed. Please run 'pnpm run type-check:tsgo' to fix the errors."
|
echo "Type check failed. Please run 'npm run type-check:tsgo' to fix the errors."
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
else
|
else
|
||||||
@ -90,8 +90,8 @@ if $web_modified; then
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
echo "Running knip"
|
echo "Running knip"
|
||||||
if ! pnpm run knip; then
|
if ! npm run knip; then
|
||||||
echo "Knip check failed. Please run 'pnpm run knip' to fix the errors."
|
echo "Knip check failed. Please run 'npm run knip' to fix the errors."
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|||||||
79
api/controllers/common/controller_schemas.py
Normal file
79
api/controllers/common/controller_schemas.py
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
|
from libs.helper import UUIDStrOrEmpty
|
||||||
|
|
||||||
|
# --- Conversation schemas ---
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationRenamePayload(BaseModel):
|
||||||
|
name: str | None = None
|
||||||
|
auto_generate: bool = False
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_name_requirement(self):
|
||||||
|
if not self.auto_generate:
|
||||||
|
if self.name is None or not self.name.strip():
|
||||||
|
raise ValueError("name is required when auto_generate is false")
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
# --- Message schemas ---
|
||||||
|
|
||||||
|
|
||||||
|
class MessageListQuery(BaseModel):
|
||||||
|
conversation_id: UUIDStrOrEmpty
|
||||||
|
first_id: UUIDStrOrEmpty | None = None
|
||||||
|
limit: int = Field(default=20, ge=1, le=100)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageFeedbackPayload(BaseModel):
|
||||||
|
rating: Literal["like", "dislike"] | None = None
|
||||||
|
content: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# --- Saved message schemas ---
|
||||||
|
|
||||||
|
|
||||||
|
class SavedMessageListQuery(BaseModel):
|
||||||
|
last_id: UUIDStrOrEmpty | None = None
|
||||||
|
limit: int = Field(default=20, ge=1, le=100)
|
||||||
|
|
||||||
|
|
||||||
|
class SavedMessageCreatePayload(BaseModel):
|
||||||
|
message_id: UUIDStrOrEmpty
|
||||||
|
|
||||||
|
|
||||||
|
# --- Workflow schemas ---
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultBlockConfigQuery(BaseModel):
|
||||||
|
q: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowListQuery(BaseModel):
|
||||||
|
page: int = Field(default=1, ge=1, le=99999)
|
||||||
|
limit: int = Field(default=10, ge=1, le=100)
|
||||||
|
user_id: str | None = None
|
||||||
|
named_only: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowRunPayload(BaseModel):
|
||||||
|
inputs: dict[str, Any]
|
||||||
|
files: list[dict[str, Any]] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowUpdatePayload(BaseModel):
|
||||||
|
marked_name: str | None = Field(default=None, max_length=20)
|
||||||
|
marked_comment: str | None = Field(default=None, max_length=100)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Audio schemas ---
|
||||||
|
|
||||||
|
|
||||||
|
class TextToAudioPayload(BaseModel):
|
||||||
|
message_id: str | None = None
|
||||||
|
voice: str | None = None
|
||||||
|
text: str | None = None
|
||||||
|
streaming: bool | None = None
|
||||||
@ -7,7 +7,7 @@ from flask import request
|
|||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from graphon.enums import WorkflowExecutionStatus
|
from graphon.enums import WorkflowExecutionStatus
|
||||||
from graphon.file import helpers as file_helpers
|
from graphon.file import helpers as file_helpers
|
||||||
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator
|
from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
from werkzeug.exceptions import BadRequest
|
from werkzeug.exceptions import BadRequest
|
||||||
@ -26,9 +26,11 @@ from controllers.console.wraps import (
|
|||||||
setup_required,
|
setup_required,
|
||||||
)
|
)
|
||||||
from core.ops.ops_trace_manager import OpsTraceManager
|
from core.ops.ops_trace_manager import OpsTraceManager
|
||||||
|
from core.rag.entities import PreProcessingRule, Rule, Segmentation
|
||||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||||
from core.trigger.constants import TRIGGER_NODE_TYPES
|
from core.trigger.constants import TRIGGER_NODE_TYPES
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
from fields.base import ResponseModel
|
||||||
from libs.login import current_account_with_tenant, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import App, DatasetPermissionEnum, Workflow
|
from models import App, DatasetPermissionEnum, Workflow
|
||||||
from models.model import IconType
|
from models.model import IconType
|
||||||
@ -41,10 +43,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
|||||||
NotionIcon,
|
NotionIcon,
|
||||||
NotionInfo,
|
NotionInfo,
|
||||||
NotionPage,
|
NotionPage,
|
||||||
PreProcessingRule,
|
|
||||||
RerankingModel,
|
RerankingModel,
|
||||||
Rule,
|
|
||||||
Segmentation,
|
|
||||||
WebsiteInfo,
|
WebsiteInfo,
|
||||||
WeightKeywordSetting,
|
WeightKeywordSetting,
|
||||||
WeightModel,
|
WeightModel,
|
||||||
@ -155,16 +154,6 @@ class AppTracePayload(BaseModel):
|
|||||||
type JSONValue = Any
|
type JSONValue = Any
|
||||||
|
|
||||||
|
|
||||||
class ResponseModel(BaseModel):
|
|
||||||
model_config = ConfigDict(
|
|
||||||
from_attributes=True,
|
|
||||||
extra="ignore",
|
|
||||||
populate_by_name=True,
|
|
||||||
serialize_by_alias=True,
|
|
||||||
protected_namespaces=(),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||||
if isinstance(value, datetime):
|
if isinstance(value, datetime):
|
||||||
return int(value.timestamp())
|
return int(value.timestamp())
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from flask_restx import Resource, fields, marshal_with
|
from flask_restx import Resource, fields, marshal_with
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from controllers.console.app.wraps import get_app_model
|
from controllers.console.app.wraps import get_app_model
|
||||||
from controllers.console.wraps import (
|
from controllers.console.wraps import (
|
||||||
@ -71,7 +71,7 @@ class AppImportApi(Resource):
|
|||||||
args = AppImportPayload.model_validate(console_ns.payload)
|
args = AppImportPayload.model_validate(console_ns.payload)
|
||||||
|
|
||||||
# Create service with session
|
# Create service with session
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(db.engine).begin() as session:
|
||||||
import_service = AppDslService(session)
|
import_service = AppDslService(session)
|
||||||
# Import app
|
# Import app
|
||||||
account = current_user
|
account = current_user
|
||||||
@ -92,11 +92,13 @@ class AppImportApi(Resource):
|
|||||||
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
|
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
|
||||||
# Return appropriate status code based on result
|
# Return appropriate status code based on result
|
||||||
status = result.status
|
status = result.status
|
||||||
if status == ImportStatus.FAILED:
|
match status:
|
||||||
return result.model_dump(mode="json"), 400
|
case ImportStatus.FAILED:
|
||||||
elif status == ImportStatus.PENDING:
|
return result.model_dump(mode="json"), 400
|
||||||
return result.model_dump(mode="json"), 202
|
case ImportStatus.PENDING:
|
||||||
return result.model_dump(mode="json"), 200
|
return result.model_dump(mode="json"), 202
|
||||||
|
case ImportStatus.COMPLETED | ImportStatus.COMPLETED_WITH_WARNINGS:
|
||||||
|
return result.model_dump(mode="json"), 200
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/imports/<string:import_id>/confirm")
|
@console_ns.route("/apps/imports/<string:import_id>/confirm")
|
||||||
|
|||||||
@ -8,6 +8,7 @@ from pydantic import BaseModel, Field, field_validator
|
|||||||
from sqlalchemy import exists, func, select
|
from sqlalchemy import exists, func, select
|
||||||
from werkzeug.exceptions import InternalServerError, NotFound
|
from werkzeug.exceptions import InternalServerError, NotFound
|
||||||
|
|
||||||
|
from controllers.common.controller_schemas import MessageFeedbackPayload as _MessageFeedbackPayloadBase
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.app.error import (
|
from controllers.console.app.error import (
|
||||||
@ -59,10 +60,8 @@ class ChatMessagesQuery(BaseModel):
|
|||||||
return uuid_value(value)
|
return uuid_value(value)
|
||||||
|
|
||||||
|
|
||||||
class MessageFeedbackPayload(BaseModel):
|
class MessageFeedbackPayload(_MessageFeedbackPayloadBase):
|
||||||
message_id: str = Field(..., description="Message ID")
|
message_id: str = Field(..., description="Message ID")
|
||||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
|
|
||||||
content: str | None = Field(default=None, description="Feedback content")
|
|
||||||
|
|
||||||
@field_validator("message_id")
|
@field_validator("message_id")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@ -14,6 +14,7 @@ from sqlalchemy.orm import sessionmaker
|
|||||||
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
|
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
|
||||||
|
|
||||||
import services
|
import services
|
||||||
|
from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
|
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||||
from controllers.console.app.workflow_run import workflow_run_node_execution_model
|
from controllers.console.app.workflow_run import workflow_run_node_execution_model
|
||||||
@ -142,10 +143,6 @@ class PublishWorkflowPayload(BaseModel):
|
|||||||
marked_comment: str | None = Field(default=None, max_length=100)
|
marked_comment: str | None = Field(default=None, max_length=100)
|
||||||
|
|
||||||
|
|
||||||
class DefaultBlockConfigQuery(BaseModel):
|
|
||||||
q: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class ConvertToWorkflowPayload(BaseModel):
|
class ConvertToWorkflowPayload(BaseModel):
|
||||||
name: str | None = None
|
name: str | None = None
|
||||||
icon_type: str | None = None
|
icon_type: str | None = None
|
||||||
@ -153,18 +150,6 @@ class ConvertToWorkflowPayload(BaseModel):
|
|||||||
icon_background: str | None = None
|
icon_background: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class WorkflowListQuery(BaseModel):
|
|
||||||
page: int = Field(default=1, ge=1, le=99999)
|
|
||||||
limit: int = Field(default=10, ge=1, le=100)
|
|
||||||
user_id: str | None = None
|
|
||||||
named_only: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
class WorkflowUpdatePayload(BaseModel):
|
|
||||||
marked_name: str | None = Field(default=None, max_length=20)
|
|
||||||
marked_comment: str | None = Field(default=None, max_length=100)
|
|
||||||
|
|
||||||
|
|
||||||
class DraftWorkflowTriggerRunPayload(BaseModel):
|
class DraftWorkflowTriggerRunPayload(BaseModel):
|
||||||
node_id: str
|
node_id: str
|
||||||
|
|
||||||
|
|||||||
@ -3,7 +3,7 @@ import secrets
|
|||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
@ -20,35 +20,18 @@ from controllers.console.wraps import email_password_login_enabled, setup_requir
|
|||||||
from events.tenant_event import tenant_was_created
|
from events.tenant_event import tenant_was_created
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import EmailStr, extract_remote_ip
|
from libs.helper import EmailStr, extract_remote_ip
|
||||||
from libs.password import hash_password, valid_password
|
from libs.password import hash_password
|
||||||
from services.account_service import AccountService, TenantService
|
from services.account_service import AccountService, TenantService
|
||||||
|
from services.entities.auth_entities import (
|
||||||
|
ForgotPasswordCheckPayload,
|
||||||
|
ForgotPasswordResetPayload,
|
||||||
|
ForgotPasswordSendPayload,
|
||||||
|
)
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
class ForgotPasswordSendPayload(BaseModel):
|
|
||||||
email: EmailStr = Field(...)
|
|
||||||
language: str | None = Field(default=None)
|
|
||||||
|
|
||||||
|
|
||||||
class ForgotPasswordCheckPayload(BaseModel):
|
|
||||||
email: EmailStr = Field(...)
|
|
||||||
code: str = Field(...)
|
|
||||||
token: str = Field(...)
|
|
||||||
|
|
||||||
|
|
||||||
class ForgotPasswordResetPayload(BaseModel):
|
|
||||||
token: str = Field(...)
|
|
||||||
new_password: str = Field(...)
|
|
||||||
password_confirm: str = Field(...)
|
|
||||||
|
|
||||||
@field_validator("new_password", "password_confirm")
|
|
||||||
@classmethod
|
|
||||||
def validate_password(cls, value: str) -> str:
|
|
||||||
return valid_password(value)
|
|
||||||
|
|
||||||
|
|
||||||
class ForgotPasswordEmailResponse(BaseModel):
|
class ForgotPasswordEmailResponse(BaseModel):
|
||||||
result: str = Field(description="Operation result")
|
result: str = Field(description="Operation result")
|
||||||
data: str | None = Field(default=None, description="Reset token")
|
data: str | None = Field(default=None, description="Reset token")
|
||||||
|
|||||||
@ -42,6 +42,7 @@ from libs.token import (
|
|||||||
)
|
)
|
||||||
from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService
|
from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService
|
||||||
from services.billing_service import BillingService
|
from services.billing_service import BillingService
|
||||||
|
from services.entities.auth_entities import LoginPayloadBase
|
||||||
from services.errors.account import AccountRegisterError
|
from services.errors.account import AccountRegisterError
|
||||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
|
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
@ -49,9 +50,7 @@ from services.feature_service import FeatureService
|
|||||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
class LoginPayload(BaseModel):
|
class LoginPayload(LoginPayloadBase):
|
||||||
email: EmailStr = Field(..., description="Email address")
|
|
||||||
password: str = Field(..., description="Password")
|
|
||||||
remember_me: bool = Field(default=False, description="Remember me flag")
|
remember_me: bool = Field(default=False, description="Remember me flag")
|
||||||
invite_token: str | None = Field(default=None, description="Invitation token")
|
invite_token: str | None = Field(default=None, description="Invitation token")
|
||||||
|
|
||||||
|
|||||||
@ -83,11 +83,13 @@ class RagPipelineImportApi(Resource):
|
|||||||
|
|
||||||
# Return appropriate status code based on result
|
# Return appropriate status code based on result
|
||||||
status = result.status
|
status = result.status
|
||||||
if status == ImportStatus.FAILED:
|
match status:
|
||||||
return result.model_dump(mode="json"), 400
|
case ImportStatus.FAILED:
|
||||||
elif status == ImportStatus.PENDING:
|
return result.model_dump(mode="json"), 400
|
||||||
return result.model_dump(mode="json"), 202
|
case ImportStatus.PENDING:
|
||||||
return result.model_dump(mode="json"), 200
|
return result.model_dump(mode="json"), 202
|
||||||
|
case ImportStatus.COMPLETED | ImportStatus.COMPLETED_WITH_WARNINGS:
|
||||||
|
return result.model_dump(mode="json"), 200
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/rag/pipelines/imports/<string:import_id>/confirm")
|
@console_ns.route("/rag/pipelines/imports/<string:import_id>/confirm")
|
||||||
|
|||||||
@ -10,6 +10,7 @@ from sqlalchemy.orm import sessionmaker
|
|||||||
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
|
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
|
||||||
|
|
||||||
import services
|
import services
|
||||||
|
from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.app.error import (
|
from controllers.console.app.error import (
|
||||||
@ -94,22 +95,6 @@ class PublishedWorkflowRunPayload(DraftWorkflowRunPayload):
|
|||||||
original_document_id: str | None = None
|
original_document_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class DefaultBlockConfigQuery(BaseModel):
|
|
||||||
q: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class WorkflowListQuery(BaseModel):
|
|
||||||
page: int = Field(default=1, ge=1, le=99999)
|
|
||||||
limit: int = Field(default=10, ge=1, le=100)
|
|
||||||
user_id: str | None = None
|
|
||||||
named_only: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
class WorkflowUpdatePayload(BaseModel):
|
|
||||||
marked_name: str | None = Field(default=None, max_length=20)
|
|
||||||
marked_comment: str | None = Field(default=None, max_length=100)
|
|
||||||
|
|
||||||
|
|
||||||
class NodeIdQuery(BaseModel):
|
class NodeIdQuery(BaseModel):
|
||||||
node_id: str
|
node_id: str
|
||||||
|
|
||||||
|
|||||||
@ -2,10 +2,10 @@ import logging
|
|||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from graphon.model_runtime.errors.invoke import InvokeError
|
from graphon.model_runtime.errors.invoke import InvokeError
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from werkzeug.exceptions import InternalServerError
|
from werkzeug.exceptions import InternalServerError
|
||||||
|
|
||||||
import services
|
import services
|
||||||
|
from controllers.common.controller_schemas import TextToAudioPayload
|
||||||
from controllers.common.schema import register_schema_model
|
from controllers.common.schema import register_schema_model
|
||||||
from controllers.console.app.error import (
|
from controllers.console.app.error import (
|
||||||
AppUnavailableError,
|
AppUnavailableError,
|
||||||
@ -32,14 +32,6 @@ from .. import console_ns
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TextToAudioPayload(BaseModel):
|
|
||||||
message_id: str | None = None
|
|
||||||
voice: str | None = None
|
|
||||||
text: str | None = None
|
|
||||||
streaming: bool | None = Field(default=None, description="Enable streaming response")
|
|
||||||
|
|
||||||
|
|
||||||
register_schema_model(console_ns, TextToAudioPayload)
|
register_schema_model(console_ns, TextToAudioPayload)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,10 +1,11 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from pydantic import BaseModel, Field, TypeAdapter, model_validator
|
from pydantic import BaseModel, Field, TypeAdapter
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
|
from controllers.common.controller_schemas import ConversationRenamePayload
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.console.explore.error import NotChatAppError
|
from controllers.console.explore.error import NotChatAppError
|
||||||
from controllers.console.explore.wraps import InstalledAppResource
|
from controllers.console.explore.wraps import InstalledAppResource
|
||||||
@ -32,18 +33,6 @@ class ConversationListQuery(BaseModel):
|
|||||||
pinned: bool | None = None
|
pinned: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
class ConversationRenamePayload(BaseModel):
|
|
||||||
name: str | None = None
|
|
||||||
auto_generate: bool = False
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
|
||||||
def validate_name_requirement(self):
|
|
||||||
if not self.auto_generate:
|
|
||||||
if self.name is None or not self.name.strip():
|
|
||||||
raise ValueError("name is required when auto_generate is false")
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload)
|
register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -3,9 +3,10 @@ from typing import Literal
|
|||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from graphon.model_runtime.errors.invoke import InvokeError
|
from graphon.model_runtime.errors.invoke import InvokeError
|
||||||
from pydantic import BaseModel, Field, TypeAdapter
|
from pydantic import BaseModel, TypeAdapter
|
||||||
from werkzeug.exceptions import InternalServerError, NotFound
|
from werkzeug.exceptions import InternalServerError, NotFound
|
||||||
|
|
||||||
|
from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.console.app.error import (
|
from controllers.console.app.error import (
|
||||||
AppMoreLikeThisDisabledError,
|
AppMoreLikeThisDisabledError,
|
||||||
@ -25,7 +26,6 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
|
|||||||
from fields.conversation_fields import ResultResponse
|
from fields.conversation_fields import ResultResponse
|
||||||
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse
|
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from libs.helper import UUIDStrOrEmpty
|
|
||||||
from libs.login import current_account_with_tenant
|
from libs.login import current_account_with_tenant
|
||||||
from models.enums import FeedbackRating
|
from models.enums import FeedbackRating
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
@ -44,17 +44,6 @@ from .. import console_ns
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MessageListQuery(BaseModel):
|
|
||||||
conversation_id: UUIDStrOrEmpty
|
|
||||||
first_id: UUIDStrOrEmpty | None = None
|
|
||||||
limit: int = Field(default=20, ge=1, le=100)
|
|
||||||
|
|
||||||
|
|
||||||
class MessageFeedbackPayload(BaseModel):
|
|
||||||
rating: Literal["like", "dislike"] | None = None
|
|
||||||
content: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class MoreLikeThisQuery(BaseModel):
|
class MoreLikeThisQuery(BaseModel):
|
||||||
response_mode: Literal["blocking", "streaming"]
|
response_mode: Literal["blocking", "streaming"]
|
||||||
|
|
||||||
|
|||||||
@ -1,28 +1,18 @@
|
|||||||
from flask import request
|
from flask import request
|
||||||
from pydantic import BaseModel, Field, TypeAdapter
|
from pydantic import TypeAdapter
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
|
from controllers.common.controller_schemas import SavedMessageCreatePayload, SavedMessageListQuery
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.explore.error import NotCompletionAppError
|
from controllers.console.explore.error import NotCompletionAppError
|
||||||
from controllers.console.explore.wraps import InstalledAppResource
|
from controllers.console.explore.wraps import InstalledAppResource
|
||||||
from fields.conversation_fields import ResultResponse
|
from fields.conversation_fields import ResultResponse
|
||||||
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
|
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
|
||||||
from libs.helper import UUIDStrOrEmpty
|
|
||||||
from libs.login import current_account_with_tenant
|
from libs.login import current_account_with_tenant
|
||||||
from services.errors.message import MessageNotExistsError
|
from services.errors.message import MessageNotExistsError
|
||||||
from services.saved_message_service import SavedMessageService
|
from services.saved_message_service import SavedMessageService
|
||||||
|
|
||||||
|
|
||||||
class SavedMessageListQuery(BaseModel):
|
|
||||||
last_id: UUIDStrOrEmpty | None = None
|
|
||||||
limit: int = Field(default=20, ge=1, le=100)
|
|
||||||
|
|
||||||
|
|
||||||
class SavedMessageCreatePayload(BaseModel):
|
|
||||||
message_id: UUIDStrOrEmpty
|
|
||||||
|
|
||||||
|
|
||||||
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)
|
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,11 +1,10 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from graphon.graph_engine.manager import GraphEngineManager
|
from graphon.graph_engine.manager import GraphEngineManager
|
||||||
from graphon.model_runtime.errors.invoke import InvokeError
|
from graphon.model_runtime.errors.invoke import InvokeError
|
||||||
from pydantic import BaseModel
|
|
||||||
from werkzeug.exceptions import InternalServerError
|
from werkzeug.exceptions import InternalServerError
|
||||||
|
|
||||||
|
from controllers.common.controller_schemas import WorkflowRunPayload
|
||||||
from controllers.common.schema import register_schema_model
|
from controllers.common.schema import register_schema_model
|
||||||
from controllers.console.app.error import (
|
from controllers.console.app.error import (
|
||||||
CompletionRequestError,
|
CompletionRequestError,
|
||||||
@ -34,12 +33,6 @@ from .. import console_ns
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class WorkflowRunPayload(BaseModel):
|
|
||||||
inputs: dict[str, Any]
|
|
||||||
files: list[dict[str, Any]] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
register_schema_model(console_ns, WorkflowRunPayload)
|
register_schema_model(console_ns, WorkflowRunPayload)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -28,7 +28,7 @@ from enums.cloud_plan import CloudPlan
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import TimestampField
|
from libs.helper import TimestampField
|
||||||
from libs.login import current_account_with_tenant, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.account import Tenant, TenantStatus
|
from models.account import Tenant, TenantCustomConfigDict, TenantStatus
|
||||||
from services.account_service import TenantService
|
from services.account_service import TenantService
|
||||||
from services.billing_service import BillingService, SubscriptionPlan
|
from services.billing_service import BillingService, SubscriptionPlan
|
||||||
from services.enterprise.enterprise_service import EnterpriseService
|
from services.enterprise.enterprise_service import EnterpriseService
|
||||||
@ -240,8 +240,10 @@ class CustomConfigWorkspaceApi(Resource):
|
|||||||
args = WorkspaceCustomConfigPayload.model_validate(payload)
|
args = WorkspaceCustomConfigPayload.model_validate(payload)
|
||||||
tenant = db.get_or_404(Tenant, current_tenant_id)
|
tenant = db.get_or_404(Tenant, current_tenant_id)
|
||||||
|
|
||||||
custom_config_dict = {
|
custom_config_dict: TenantCustomConfigDict = {
|
||||||
"remove_webapp_brand": args.remove_webapp_brand,
|
"remove_webapp_brand": args.remove_webapp_brand
|
||||||
|
if args.remove_webapp_brand is not None
|
||||||
|
else tenant.custom_config_dict.get("remove_webapp_brand", False),
|
||||||
"replace_webapp_logo": args.replace_webapp_logo
|
"replace_webapp_logo": args.replace_webapp_logo
|
||||||
if args.replace_webapp_logo is not None
|
if args.replace_webapp_logo is not None
|
||||||
else tenant.custom_config_dict.get("replace_webapp_logo"),
|
else tenant.custom_config_dict.get("replace_webapp_logo"),
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from flask import request
|
|||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from controllers.common.schema import register_schema_model
|
from controllers.common.schema import register_schema_model
|
||||||
from controllers.console.wraps import setup_required
|
from controllers.console.wraps import setup_required
|
||||||
@ -55,7 +55,7 @@ class EnterpriseAppDSLImport(Resource):
|
|||||||
|
|
||||||
account.set_tenant_id(workspace_id)
|
account.set_tenant_id(workspace_id)
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(db.engine).begin() as session:
|
||||||
dsl_service = AppDslService(session)
|
dsl_service = AppDslService(session)
|
||||||
result = dsl_service.import_app(
|
result = dsl_service.import_app(
|
||||||
account=account,
|
account=account,
|
||||||
@ -64,7 +64,6 @@ class EnterpriseAppDSLImport(Resource):
|
|||||||
name=args.name,
|
name=args.name,
|
||||||
description=args.description,
|
description=args.description,
|
||||||
)
|
)
|
||||||
session.commit()
|
|
||||||
|
|
||||||
if result.status == ImportStatus.FAILED:
|
if result.status == ImportStatus.FAILED:
|
||||||
return result.model_dump(mode="json"), 400
|
return result.model_dump(mode="json"), 400
|
||||||
|
|||||||
@ -2,11 +2,12 @@ from typing import Any, Literal
|
|||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
|
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
from werkzeug.exceptions import BadRequest, NotFound
|
from werkzeug.exceptions import BadRequest, NotFound
|
||||||
|
|
||||||
import services
|
import services
|
||||||
|
from controllers.common.controller_schemas import ConversationRenamePayload
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.service_api import service_api_ns
|
from controllers.service_api import service_api_ns
|
||||||
from controllers.service_api.app.error import NotChatAppError
|
from controllers.service_api.app.error import NotChatAppError
|
||||||
@ -34,18 +35,6 @@ class ConversationListQuery(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ConversationRenamePayload(BaseModel):
|
|
||||||
name: str | None = Field(default=None, description="New conversation name (required if auto_generate is false)")
|
|
||||||
auto_generate: bool = Field(default=False, description="Auto-generate conversation name")
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
|
||||||
def validate_name_requirement(self):
|
|
||||||
if not self.auto_generate:
|
|
||||||
if self.name is None or not self.name.strip():
|
|
||||||
raise ValueError("name is required when auto_generate is false")
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
class ConversationVariablesQuery(BaseModel):
|
class ConversationVariablesQuery(BaseModel):
|
||||||
last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last variable ID for pagination")
|
last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last variable ID for pagination")
|
||||||
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
|
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
@ -7,6 +6,7 @@ from pydantic import BaseModel, Field, TypeAdapter
|
|||||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||||
|
|
||||||
import services
|
import services
|
||||||
|
from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.service_api import service_api_ns
|
from controllers.service_api import service_api_ns
|
||||||
from controllers.service_api.app.error import NotChatAppError
|
from controllers.service_api.app.error import NotChatAppError
|
||||||
@ -14,7 +14,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
|
|||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from fields.conversation_fields import ResultResponse
|
from fields.conversation_fields import ResultResponse
|
||||||
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
|
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
|
||||||
from libs.helper import UUIDStrOrEmpty
|
|
||||||
from models.enums import FeedbackRating
|
from models.enums import FeedbackRating
|
||||||
from models.model import App, AppMode, EndUser
|
from models.model import App, AppMode, EndUser
|
||||||
from services.errors.message import (
|
from services.errors.message import (
|
||||||
@ -27,17 +26,6 @@ from services.message_service import MessageService
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MessageListQuery(BaseModel):
|
|
||||||
conversation_id: UUIDStrOrEmpty
|
|
||||||
first_id: UUIDStrOrEmpty | None = None
|
|
||||||
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return")
|
|
||||||
|
|
||||||
|
|
||||||
class MessageFeedbackPayload(BaseModel):
|
|
||||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
|
|
||||||
content: str | None = Field(default=None, description="Feedback content")
|
|
||||||
|
|
||||||
|
|
||||||
class FeedbackListQuery(BaseModel):
|
class FeedbackListQuery(BaseModel):
|
||||||
page: int = Field(default=1, ge=1, description="Page number")
|
page: int = Field(default=1, ge=1, description="Page number")
|
||||||
limit: int = Field(default=20, ge=1, le=101, description="Number of feedbacks per page")
|
limit: int = Field(default=20, ge=1, le=101, description="Number of feedbacks per page")
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Literal
|
from typing import Literal
|
||||||
|
|
||||||
from dateutil.parser import isoparse
|
from dateutil.parser import isoparse
|
||||||
from flask import request
|
from flask import request
|
||||||
@ -11,6 +11,7 @@ from pydantic import BaseModel, Field
|
|||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||||
|
|
||||||
|
from controllers.common.controller_schemas import WorkflowRunPayload as WorkflowRunPayloadBase
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.service_api import service_api_ns
|
from controllers.service_api import service_api_ns
|
||||||
from controllers.service_api.app.error import (
|
from controllers.service_api.app.error import (
|
||||||
@ -46,9 +47,7 @@ from services.workflow_app_service import WorkflowAppService
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class WorkflowRunPayload(BaseModel):
|
class WorkflowRunPayload(WorkflowRunPayloadBase):
|
||||||
inputs: dict[str, Any]
|
|
||||||
files: list[dict[str, Any]] | None = None
|
|
||||||
response_mode: Literal["blocking", "streaming"] | None = None
|
response_mode: Literal["blocking", "streaming"] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -31,6 +31,7 @@ from controllers.service_api.wraps import (
|
|||||||
cloud_edition_billing_resource_check,
|
cloud_edition_billing_resource_check,
|
||||||
)
|
)
|
||||||
from core.errors.error import ProviderTokenNotInitError
|
from core.errors.error import ProviderTokenNotInitError
|
||||||
|
from core.rag.entities import PreProcessingRule, Rule, Segmentation
|
||||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.document_fields import document_fields, document_status_fields
|
from fields.document_fields import document_fields, document_status_fields
|
||||||
@ -40,11 +41,8 @@ from models.enums import SegmentStatus
|
|||||||
from services.dataset_service import DatasetService, DocumentService
|
from services.dataset_service import DatasetService, DocumentService
|
||||||
from services.entities.knowledge_entities.knowledge_entities import (
|
from services.entities.knowledge_entities.knowledge_entities import (
|
||||||
KnowledgeConfig,
|
KnowledgeConfig,
|
||||||
PreProcessingRule,
|
|
||||||
ProcessRule,
|
ProcessRule,
|
||||||
RetrievalModel,
|
RetrievalModel,
|
||||||
Rule,
|
|
||||||
Segmentation,
|
|
||||||
)
|
)
|
||||||
from services.file_service import FileService
|
from services.file_service import FileService
|
||||||
from services.summary_index_service import SummaryIndexService
|
from services.summary_index_service import SummaryIndexService
|
||||||
|
|||||||
@ -3,10 +3,11 @@ import logging
|
|||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import fields, marshal_with
|
from flask_restx import fields, marshal_with
|
||||||
from graphon.model_runtime.errors.invoke import InvokeError
|
from graphon.model_runtime.errors.invoke import InvokeError
|
||||||
from pydantic import BaseModel, field_validator
|
from pydantic import field_validator
|
||||||
from werkzeug.exceptions import InternalServerError
|
from werkzeug.exceptions import InternalServerError
|
||||||
|
|
||||||
import services
|
import services
|
||||||
|
from controllers.common.controller_schemas import TextToAudioPayload as TextToAudioPayloadBase
|
||||||
from controllers.web import web_ns
|
from controllers.web import web_ns
|
||||||
from controllers.web.error import (
|
from controllers.web.error import (
|
||||||
AppUnavailableError,
|
AppUnavailableError,
|
||||||
@ -34,12 +35,7 @@ from services.errors.audio import (
|
|||||||
from ..common.schema import register_schema_models
|
from ..common.schema import register_schema_models
|
||||||
|
|
||||||
|
|
||||||
class TextToAudioPayload(BaseModel):
|
class TextToAudioPayload(TextToAudioPayloadBase):
|
||||||
message_id: str | None = None
|
|
||||||
voice: str | None = None
|
|
||||||
text: str | None = None
|
|
||||||
streaming: bool | None = None
|
|
||||||
|
|
||||||
@field_validator("message_id")
|
@field_validator("message_id")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_message_id(cls, value: str | None) -> str | None:
|
def validate_message_id(cls, value: str | None) -> str | None:
|
||||||
|
|||||||
@ -1,10 +1,11 @@
|
|||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
|
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
|
from controllers.common.controller_schemas import ConversationRenamePayload
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.web import web_ns
|
from controllers.web import web_ns
|
||||||
from controllers.web.error import NotChatAppError
|
from controllers.web.error import NotChatAppError
|
||||||
@ -37,18 +38,6 @@ class ConversationListQuery(BaseModel):
|
|||||||
return uuid_value(value)
|
return uuid_value(value)
|
||||||
|
|
||||||
|
|
||||||
class ConversationRenamePayload(BaseModel):
|
|
||||||
name: str | None = None
|
|
||||||
auto_generate: bool = False
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
|
||||||
def validate_name_requirement(self):
|
|
||||||
if not self.auto_generate:
|
|
||||||
if self.name is None or not self.name.strip():
|
|
||||||
raise ValueError("name is required when auto_generate is false")
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
register_schema_models(web_ns, ConversationListQuery, ConversationRenamePayload)
|
register_schema_models(web_ns, ConversationListQuery, ConversationRenamePayload)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -3,7 +3,6 @@ import secrets
|
|||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from pydantic import BaseModel, Field, field_validator
|
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
@ -19,33 +18,15 @@ from controllers.console.error import EmailSendIpLimitError
|
|||||||
from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required
|
from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required
|
||||||
from controllers.web import web_ns
|
from controllers.web import web_ns
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import EmailStr, extract_remote_ip
|
from libs.helper import extract_remote_ip
|
||||||
from libs.password import hash_password, valid_password
|
from libs.password import hash_password
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
from services.account_service import AccountService
|
from services.account_service import AccountService
|
||||||
|
from services.entities.auth_entities import (
|
||||||
|
ForgotPasswordCheckPayload,
|
||||||
class ForgotPasswordSendPayload(BaseModel):
|
ForgotPasswordResetPayload,
|
||||||
email: EmailStr
|
ForgotPasswordSendPayload,
|
||||||
language: str | None = None
|
)
|
||||||
|
|
||||||
|
|
||||||
class ForgotPasswordCheckPayload(BaseModel):
|
|
||||||
email: EmailStr
|
|
||||||
code: str
|
|
||||||
token: str = Field(min_length=1)
|
|
||||||
|
|
||||||
|
|
||||||
class ForgotPasswordResetPayload(BaseModel):
|
|
||||||
token: str = Field(min_length=1)
|
|
||||||
new_password: str
|
|
||||||
password_confirm: str
|
|
||||||
|
|
||||||
@field_validator("new_password", "password_confirm")
|
|
||||||
@classmethod
|
|
||||||
def validate_password(cls, value: str) -> str:
|
|
||||||
return valid_password(value)
|
|
||||||
|
|
||||||
|
|
||||||
register_schema_models(web_ns, ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload)
|
register_schema_models(web_ns, ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload)
|
||||||
|
|
||||||
|
|||||||
@ -29,13 +29,11 @@ from libs.token import (
|
|||||||
)
|
)
|
||||||
from services.account_service import AccountService
|
from services.account_service import AccountService
|
||||||
from services.app_service import AppService
|
from services.app_service import AppService
|
||||||
|
from services.entities.auth_entities import LoginPayloadBase
|
||||||
from services.webapp_auth_service import WebAppAuthService
|
from services.webapp_auth_service import WebAppAuthService
|
||||||
|
|
||||||
|
|
||||||
class LoginPayload(BaseModel):
|
class LoginPayload(LoginPayloadBase):
|
||||||
email: EmailStr
|
|
||||||
password: str
|
|
||||||
|
|
||||||
@field_validator("password")
|
@field_validator("password")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_password(cls, value: str) -> str:
|
def validate_password(cls, value: str) -> str:
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from graphon.model_runtime.errors.invoke import InvokeError
|
|||||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||||
from werkzeug.exceptions import InternalServerError, NotFound
|
from werkzeug.exceptions import InternalServerError, NotFound
|
||||||
|
|
||||||
|
from controllers.common.controller_schemas import MessageFeedbackPayload
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.web import web_ns
|
from controllers.web import web_ns
|
||||||
from controllers.web.error import (
|
from controllers.web.error import (
|
||||||
@ -53,11 +54,6 @@ class MessageListQuery(BaseModel):
|
|||||||
return uuid_value(value)
|
return uuid_value(value)
|
||||||
|
|
||||||
|
|
||||||
class MessageFeedbackPayload(BaseModel):
|
|
||||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
|
|
||||||
content: str | None = Field(default=None, description="Feedback content")
|
|
||||||
|
|
||||||
|
|
||||||
class MessageMoreLikeThisQuery(BaseModel):
|
class MessageMoreLikeThisQuery(BaseModel):
|
||||||
response_mode: Literal["blocking", "streaming"] = Field(
|
response_mode: Literal["blocking", "streaming"] = Field(
|
||||||
description="Response mode",
|
description="Response mode",
|
||||||
|
|||||||
@ -138,12 +138,15 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
|
|||||||
if not app_model or app_model.status != "normal" or not app_model.enable_site:
|
if not app_model or app_model.status != "normal" or not app_model.enable_site:
|
||||||
raise NotFound()
|
raise NotFound()
|
||||||
|
|
||||||
if auth_type == WebAppAuthType.PUBLIC:
|
match auth_type:
|
||||||
return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded)
|
case WebAppAuthType.PUBLIC:
|
||||||
elif auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external":
|
return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded)
|
||||||
raise WebAppAuthRequiredError("Please login as external user.")
|
case WebAppAuthType.EXTERNAL:
|
||||||
elif auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal":
|
if user_auth_type != "external":
|
||||||
raise WebAppAuthRequiredError("Please login as internal user.")
|
raise WebAppAuthRequiredError("Please login as external user.")
|
||||||
|
case WebAppAuthType.INTERNAL:
|
||||||
|
if user_auth_type != "internal":
|
||||||
|
raise WebAppAuthRequiredError("Please login as internal user.")
|
||||||
|
|
||||||
end_user = None
|
end_user = None
|
||||||
if end_user_id:
|
if end_user_id:
|
||||||
|
|||||||
@ -1,27 +1,17 @@
|
|||||||
from flask import request
|
from flask import request
|
||||||
from pydantic import BaseModel, Field, TypeAdapter
|
from pydantic import TypeAdapter
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
|
from controllers.common.controller_schemas import SavedMessageCreatePayload, SavedMessageListQuery
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.web import web_ns
|
from controllers.web import web_ns
|
||||||
from controllers.web.error import NotCompletionAppError
|
from controllers.web.error import NotCompletionAppError
|
||||||
from controllers.web.wraps import WebApiResource
|
from controllers.web.wraps import WebApiResource
|
||||||
from fields.conversation_fields import ResultResponse
|
from fields.conversation_fields import ResultResponse
|
||||||
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
|
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
|
||||||
from libs.helper import UUIDStrOrEmpty
|
|
||||||
from services.errors.message import MessageNotExistsError
|
from services.errors.message import MessageNotExistsError
|
||||||
from services.saved_message_service import SavedMessageService
|
from services.saved_message_service import SavedMessageService
|
||||||
|
|
||||||
|
|
||||||
class SavedMessageListQuery(BaseModel):
|
|
||||||
last_id: UUIDStrOrEmpty | None = None
|
|
||||||
limit: int = Field(default=20, ge=1, le=100)
|
|
||||||
|
|
||||||
|
|
||||||
class SavedMessageCreatePayload(BaseModel):
|
|
||||||
message_id: UUIDStrOrEmpty
|
|
||||||
|
|
||||||
|
|
||||||
register_schema_models(web_ns, SavedMessageListQuery, SavedMessageCreatePayload)
|
register_schema_models(web_ns, SavedMessageListQuery, SavedMessageCreatePayload)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,11 +1,10 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from graphon.graph_engine.manager import GraphEngineManager
|
from graphon.graph_engine.manager import GraphEngineManager
|
||||||
from graphon.model_runtime.errors.invoke import InvokeError
|
from graphon.model_runtime.errors.invoke import InvokeError
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from werkzeug.exceptions import InternalServerError
|
from werkzeug.exceptions import InternalServerError
|
||||||
|
|
||||||
|
from controllers.common.controller_schemas import WorkflowRunPayload
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.web import web_ns
|
from controllers.web import web_ns
|
||||||
from controllers.web.error import (
|
from controllers.web.error import (
|
||||||
@ -30,12 +29,6 @@ from models.model import App, AppMode, EndUser
|
|||||||
from services.app_generate_service import AppGenerateService
|
from services.app_generate_service import AppGenerateService
|
||||||
from services.errors.llm import InvokeRateLimitError
|
from services.errors.llm import InvokeRateLimitError
|
||||||
|
|
||||||
|
|
||||||
class WorkflowRunPayload(BaseModel):
|
|
||||||
inputs: dict[str, Any] = Field(description="Input variables for the workflow")
|
|
||||||
files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed by the workflow")
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
register_schema_models(web_ns, WorkflowRunPayload)
|
register_schema_models(web_ns, WorkflowRunPayload)
|
||||||
|
|||||||
@ -79,21 +79,18 @@ class CotChatAgentRunner(CotAgentRunner):
|
|||||||
if not agent_scratchpad:
|
if not agent_scratchpad:
|
||||||
assistant_messages = []
|
assistant_messages = []
|
||||||
else:
|
else:
|
||||||
assistant_message = AssistantPromptMessage(content="")
|
content = ""
|
||||||
assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str
|
|
||||||
for unit in agent_scratchpad:
|
for unit in agent_scratchpad:
|
||||||
if unit.is_final():
|
if unit.is_final():
|
||||||
assert isinstance(assistant_message.content, str)
|
content += f"Final Answer: {unit.agent_response}"
|
||||||
assistant_message.content += f"Final Answer: {unit.agent_response}"
|
|
||||||
else:
|
else:
|
||||||
assert isinstance(assistant_message.content, str)
|
content += f"Thought: {unit.thought}\n\n"
|
||||||
assistant_message.content += f"Thought: {unit.thought}\n\n"
|
|
||||||
if unit.action_str:
|
if unit.action_str:
|
||||||
assistant_message.content += f"Action: {unit.action_str}\n\n"
|
content += f"Action: {unit.action_str}\n\n"
|
||||||
if unit.observation:
|
if unit.observation:
|
||||||
assistant_message.content += f"Observation: {unit.observation}\n\n"
|
content += f"Observation: {unit.observation}\n\n"
|
||||||
|
|
||||||
assistant_messages = [assistant_message]
|
assistant_messages = [AssistantPromptMessage(content=content)]
|
||||||
|
|
||||||
# query messages
|
# query messages
|
||||||
query_messages = self._organize_user_query(self._query, [])
|
query_messages = self._organize_user_query(self._query, [])
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
from collections.abc import Sequence
|
|
||||||
from enum import StrEnum, auto
|
from enum import StrEnum, auto
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
@ -9,6 +8,7 @@ from graphon.variables.input_entities import VariableEntity as WorkflowVariableE
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
|
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
|
||||||
|
from core.rag.entities import MetadataFilteringCondition
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
|
|
||||||
|
|
||||||
@ -111,31 +111,6 @@ class ExternalDataVariableEntity(BaseModel):
|
|||||||
config: dict[str, Any] = Field(default_factory=dict)
|
config: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
SupportedComparisonOperator = Literal[
|
|
||||||
# for string or array
|
|
||||||
"contains",
|
|
||||||
"not contains",
|
|
||||||
"start with",
|
|
||||||
"end with",
|
|
||||||
"is",
|
|
||||||
"is not",
|
|
||||||
"empty",
|
|
||||||
"not empty",
|
|
||||||
"in",
|
|
||||||
"not in",
|
|
||||||
# for number
|
|
||||||
"=",
|
|
||||||
"≠",
|
|
||||||
">",
|
|
||||||
"<",
|
|
||||||
"≥",
|
|
||||||
"≤",
|
|
||||||
# for time
|
|
||||||
"before",
|
|
||||||
"after",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig(BaseModel):
|
class ModelConfig(BaseModel):
|
||||||
provider: str
|
provider: str
|
||||||
name: str
|
name: str
|
||||||
@ -143,25 +118,6 @@ class ModelConfig(BaseModel):
|
|||||||
completion_params: dict[str, Any] = Field(default_factory=dict)
|
completion_params: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class Condition(BaseModel):
|
|
||||||
"""
|
|
||||||
Condition detail
|
|
||||||
"""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
comparison_operator: SupportedComparisonOperator
|
|
||||||
value: str | Sequence[str] | None | int | float = None
|
|
||||||
|
|
||||||
|
|
||||||
class MetadataFilteringCondition(BaseModel):
|
|
||||||
"""
|
|
||||||
Metadata Filtering Condition.
|
|
||||||
"""
|
|
||||||
|
|
||||||
logical_operator: Literal["and", "or"] | None = "and"
|
|
||||||
conditions: list[Condition] | None = Field(default=None, deprecated=True)
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetRetrieveConfigEntity(BaseModel):
|
class DatasetRetrieveConfigEntity(BaseModel):
|
||||||
"""
|
"""
|
||||||
Dataset Retrieve Config Entity.
|
Dataset Retrieve Config Entity.
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from graphon.runtime import GraphRuntimeState, VariablePool
|
|||||||
from graphon.variable_loader import VariableLoader
|
from graphon.variable_loader import VariableLoader
|
||||||
from graphon.variables.variables import Variable
|
from graphon.variables.variables import Variable
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||||
@ -363,7 +363,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||||||
|
|
||||||
:return: List of conversation variables ready for use
|
:return: List of conversation variables ready for use
|
||||||
"""
|
"""
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(bind=db.engine).begin() as session:
|
||||||
existing_variables = self._load_existing_conversation_variables(session)
|
existing_variables = self._load_existing_conversation_variables(session)
|
||||||
|
|
||||||
if not existing_variables:
|
if not existing_variables:
|
||||||
@ -376,7 +376,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||||||
# Convert to Variable objects for use in the workflow
|
# Convert to Variable objects for use in the workflow
|
||||||
conversation_variables = [var.to_variable() for var in existing_variables]
|
conversation_variables = [var.to_variable() for var in existing_variables]
|
||||||
|
|
||||||
session.commit()
|
|
||||||
return cast(list[Variable], conversation_variables)
|
return cast(list[Variable], conversation_variables)
|
||||||
|
|
||||||
def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]:
|
def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]:
|
||||||
|
|||||||
@ -16,7 +16,7 @@ from graphon.model_runtime.utils.encoders import jsonable_encoder
|
|||||||
from graphon.nodes import BuiltinNodeTypes
|
from graphon.nodes import BuiltinNodeTypes
|
||||||
from graphon.runtime import GraphRuntimeState
|
from graphon.runtime import GraphRuntimeState
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||||
@ -328,13 +328,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||||||
@contextmanager
|
@contextmanager
|
||||||
def _database_session(self):
|
def _database_session(self):
|
||||||
"""Context manager for database sessions."""
|
"""Context manager for database sessions."""
|
||||||
with Session(db.engine, expire_on_commit=False) as session:
|
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||||
try:
|
yield session
|
||||||
yield session
|
|
||||||
session.commit()
|
|
||||||
except Exception:
|
|
||||||
session.rollback()
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _ensure_workflow_initialized(self):
|
def _ensure_workflow_initialized(self):
|
||||||
"""Fluent validation for workflow state."""
|
"""Fluent validation for workflow state."""
|
||||||
|
|||||||
@ -7,7 +7,7 @@ from typing import Union
|
|||||||
from graphon.entities import WorkflowStartReason
|
from graphon.entities import WorkflowStartReason
|
||||||
from graphon.enums import WorkflowExecutionStatus
|
from graphon.enums import WorkflowExecutionStatus
|
||||||
from graphon.runtime import GraphRuntimeState
|
from graphon.runtime import GraphRuntimeState
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||||
@ -252,13 +252,8 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||||||
@contextmanager
|
@contextmanager
|
||||||
def _database_session(self):
|
def _database_session(self):
|
||||||
"""Context manager for database sessions."""
|
"""Context manager for database sessions."""
|
||||||
with Session(db.engine, expire_on_commit=False) as session:
|
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||||
try:
|
yield session
|
||||||
yield session
|
|
||||||
session.commit()
|
|
||||||
except Exception:
|
|
||||||
session.rollback()
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _ensure_workflow_initialized(self):
|
def _ensure_workflow_initialized(self):
|
||||||
"""Fluent validation for workflow state."""
|
"""Fluent validation for workflow state."""
|
||||||
|
|||||||
@ -66,7 +66,7 @@ from core.app.entities.queue_entities import (
|
|||||||
QueueWorkflowStartedEvent,
|
QueueWorkflowStartedEvent,
|
||||||
QueueWorkflowSucceededEvent,
|
QueueWorkflowSucceededEvent,
|
||||||
)
|
)
|
||||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
from core.rag.entities import RetrievalSourceMetadata
|
||||||
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
|
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
|
||||||
from core.workflow.system_variables import (
|
from core.workflow.system_variables import (
|
||||||
build_bootstrap_variables,
|
build_bootstrap_variables,
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChun
|
|||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
from core.app.entities.agent_strategy import AgentStrategyInfo
|
from core.app.entities.agent_strategy import AgentStrategyInfo
|
||||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
from core.rag.entities import RetrievalSourceMetadata
|
||||||
|
|
||||||
|
|
||||||
class QueueEvent(StrEnum):
|
class QueueEvent(StrEnum):
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from graphon.nodes.human_input.entities import FormInput, UserAction
|
|||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
from core.app.entities.agent_strategy import AgentStrategyInfo
|
from core.app.entities.agent_strategy import AgentStrategyInfo
|
||||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
from core.rag.entities import RetrievalSourceMetadata
|
||||||
|
|
||||||
|
|
||||||
class AnnotationReplyAccount(BaseModel):
|
class AnnotationReplyAccount(BaseModel):
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from graphon.model_runtime.entities.llm_entities import LLMUsage
|
from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||||
from sqlalchemy import update
|
from sqlalchemy import update
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.entities.model_entities import ModelStatus
|
from core.entities.model_entities import ModelStatus
|
||||||
@ -73,7 +73,7 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
|
|||||||
pool_type="paid",
|
pool_type="paid",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(bind=db.engine).begin() as session:
|
||||||
stmt = (
|
stmt = (
|
||||||
update(Provider)
|
update(Provider)
|
||||||
.where(
|
.where(
|
||||||
@ -90,4 +90,3 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
session.execute(stmt)
|
session.execute(stmt)
|
||||||
session.commit()
|
|
||||||
|
|||||||
@ -12,7 +12,7 @@ from graphon.model_runtime.entities.message_entities import (
|
|||||||
)
|
)
|
||||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||||
@ -266,9 +266,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||||||
event = message.event
|
event = message.event
|
||||||
|
|
||||||
if isinstance(event, QueueErrorEvent):
|
if isinstance(event, QueueErrorEvent):
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(bind=db.engine).begin() as session:
|
||||||
err = self.handle_error(event=event, session=session, message_id=self._message_id)
|
err = self.handle_error(event=event, session=session, message_id=self._message_id)
|
||||||
session.commit()
|
|
||||||
yield self.error_to_stream_response(err)
|
yield self.error_to_stream_response(err)
|
||||||
break
|
break
|
||||||
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
|
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
|
||||||
@ -288,10 +287,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||||||
answer=output_moderation_answer
|
answer=output_moderation_answer
|
||||||
)
|
)
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(bind=db.engine).begin() as session:
|
||||||
# Save message
|
# Save message
|
||||||
self._save_message(session=session, trace_manager=trace_manager)
|
self._save_message(session=session, trace_manager=trace_manager)
|
||||||
session.commit()
|
|
||||||
message_end_resp = self._message_end_to_stream_response()
|
message_end_resp = self._message_end_to_stream_response()
|
||||||
yield message_end_resp
|
yield message_end_resp
|
||||||
elif isinstance(event, QueueRetrieverResourcesEvent):
|
elif isinstance(event, QueueRetrieverResourcesEvent):
|
||||||
|
|||||||
@ -6,7 +6,7 @@ from sqlalchemy import select, update
|
|||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
|
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
|
||||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
from core.rag.entities import RetrievalSourceMetadata
|
||||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import Document
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
|||||||
@ -1,22 +1,3 @@
|
|||||||
from pydantic import BaseModel, Field, model_validator
|
from core.tools.entities.common_entities import I18nObject, I18nObjectDict
|
||||||
|
|
||||||
|
__all__ = ["I18nObject", "I18nObjectDict"]
|
||||||
class I18nObject(BaseModel):
|
|
||||||
"""
|
|
||||||
Model class for i18n object.
|
|
||||||
"""
|
|
||||||
|
|
||||||
en_US: str
|
|
||||||
zh_Hans: str | None = Field(default=None)
|
|
||||||
pt_BR: str | None = Field(default=None)
|
|
||||||
ja_JP: str | None = Field(default=None)
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
|
||||||
def _(self):
|
|
||||||
self.zh_Hans = self.zh_Hans or self.en_US
|
|
||||||
self.pt_BR = self.pt_BR or self.en_US
|
|
||||||
self.ja_JP = self.ja_JP or self.en_US
|
|
||||||
return self
|
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
|
||||||
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}
|
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from yarl import URL
|
|||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.entities.provider_entities import ProviderConfig
|
from core.entities.provider_entities import ProviderConfig
|
||||||
from core.plugin.entities.oauth import OAuthSchema
|
from core.plugin.entities import OAuthSchema
|
||||||
from core.plugin.entities.parameters import (
|
from core.plugin.entities.parameters import (
|
||||||
PluginParameter,
|
PluginParameter,
|
||||||
PluginParameterOption,
|
PluginParameterOption,
|
||||||
|
|||||||
@ -1 +1,8 @@
|
|||||||
|
from core.entities.plugin_credential_type import PluginCredentialType
|
||||||
|
|
||||||
DEFAULT_PLUGIN_ID = "langgenius"
|
DEFAULT_PLUGIN_ID = "langgenius"
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DEFAULT_PLUGIN_ID",
|
||||||
|
"PluginCredentialType",
|
||||||
|
]
|
||||||
|
|||||||
9
api/core/entities/plugin_credential_type.py
Normal file
9
api/core/entities/plugin_credential_type.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
import enum
|
||||||
|
|
||||||
|
|
||||||
|
class PluginCredentialType(enum.Enum):
|
||||||
|
MODEL = 0 # must be 0 for API contract compatibility
|
||||||
|
TOOL = 1 # must be 1 for API contract compatibility
|
||||||
|
|
||||||
|
def to_number(self):
|
||||||
|
return self.value
|
||||||
@ -22,6 +22,7 @@ from sqlalchemy import func, select
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from constants import HIDDEN_VALUE
|
from constants import HIDDEN_VALUE
|
||||||
|
from core.entities import PluginCredentialType
|
||||||
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
|
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
|
||||||
from core.entities.provider_entities import (
|
from core.entities.provider_entities import (
|
||||||
CustomConfiguration,
|
CustomConfiguration,
|
||||||
@ -46,7 +47,6 @@ from models.provider import (
|
|||||||
TenantPreferredModelProvider,
|
TenantPreferredModelProvider,
|
||||||
)
|
)
|
||||||
from models.provider_ids import ModelProviderID
|
from models.provider_ids import ModelProviderID
|
||||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
Credential utility functions for checking credential existence and policy compliance.
|
Credential utility functions for checking credential existence and policy compliance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
from core.entities import PluginCredentialType
|
||||||
|
|
||||||
|
|
||||||
def is_credential_exists(credential_id: str, credential_type: "PluginCredentialType") -> bool:
|
def is_credential_exists(credential_id: str, credential_type: "PluginCredentialType") -> bool:
|
||||||
|
|||||||
@ -17,6 +17,7 @@ from graphon.model_runtime.model_providers.__base.text_embedding_model import Te
|
|||||||
from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
|
from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
from core.entities import PluginCredentialType
|
||||||
from core.entities.embedding_type import EmbeddingInputType
|
from core.entities.embedding_type import EmbeddingInputType
|
||||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
||||||
from core.entities.provider_entities import ModelLoadBalancingConfiguration
|
from core.entities.provider_entities import ModelLoadBalancingConfiguration
|
||||||
@ -25,7 +26,6 @@ from core.plugin.impl.model_runtime_factory import create_plugin_provider_manage
|
|||||||
from core.provider_manager import ProviderManager
|
from core.provider_manager import ProviderManager
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from models.provider import ProviderType
|
from models.provider import ProviderType
|
||||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
5
api/core/plugin/entities/__init__.py
Normal file
5
api/core/plugin/entities/__init__.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
from core.plugin.entities.oauth import OAuthSchema
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"OAuthSchema",
|
||||||
|
]
|
||||||
@ -1,5 +1,3 @@
|
|||||||
from collections.abc import Sequence
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from core.entities.provider_entities import ProviderConfig
|
from core.entities.provider_entities import ProviderConfig
|
||||||
@ -10,12 +8,12 @@ class OAuthSchema(BaseModel):
|
|||||||
OAuth schema
|
OAuth schema
|
||||||
"""
|
"""
|
||||||
|
|
||||||
client_schema: Sequence[ProviderConfig] = Field(
|
client_schema: list[ProviderConfig] = Field(
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
description="client schema like client_id, client_secret, etc.",
|
description="client schema like client_id, client_secret, etc.",
|
||||||
)
|
)
|
||||||
|
|
||||||
credentials_schema: Sequence[ProviderConfig] = Field(
|
credentials_schema: list[ProviderConfig] = Field(
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
description="credentials schema like access_token, refresh_token, etc.",
|
description="credentials schema like access_token, refresh_token, etc.",
|
||||||
)
|
)
|
||||||
|
|||||||
@ -209,7 +209,10 @@ class PluginInstaller(BasePluginClient):
|
|||||||
"GET",
|
"GET",
|
||||||
f"plugin/{tenant_id}/management/decode/from_identifier",
|
f"plugin/{tenant_id}/management/decode/from_identifier",
|
||||||
PluginDecodeResponse,
|
PluginDecodeResponse,
|
||||||
params={"plugin_unique_identifier": plugin_unique_identifier},
|
params={
|
||||||
|
"plugin_unique_identifier": plugin_unique_identifier,
|
||||||
|
"PluginUniqueIdentifier": plugin_unique_identifier, # compat with daemon <= 0.5.4
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def fetch_plugin_installation_by_ids(
|
def fetch_plugin_installation_by_ids(
|
||||||
|
|||||||
@ -1,11 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import json
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
from typing import TYPE_CHECKING, Any, cast
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from graphon.model_runtime.entities.model_entities import ModelType
|
from graphon.model_runtime.entities.model_entities import ModelType
|
||||||
from graphon.model_runtime.entities.provider_entities import (
|
from graphon.model_runtime.entities.provider_entities import (
|
||||||
@ -15,6 +14,7 @@ from graphon.model_runtime.entities.provider_entities import (
|
|||||||
ProviderEntity,
|
ProviderEntity,
|
||||||
)
|
)
|
||||||
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||||
|
from pydantic import TypeAdapter
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@ -58,6 +58,8 @@ from services.feature_service import FeatureService
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from graphon.model_runtime.runtime import ModelRuntime
|
from graphon.model_runtime.runtime import ModelRuntime
|
||||||
|
|
||||||
|
_credentials_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any])
|
||||||
|
|
||||||
|
|
||||||
class ProviderManager:
|
class ProviderManager:
|
||||||
"""
|
"""
|
||||||
@ -875,8 +877,8 @@ class ProviderManager:
|
|||||||
return {"openai_api_key": encrypted_config}
|
return {"openai_api_key": encrypted_config}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
credentials = cast(dict, json.loads(encrypted_config))
|
credentials = _credentials_adapter.validate_json(encrypted_config)
|
||||||
except JSONDecodeError:
|
except (ValueError, JSONDecodeError):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
# Decrypt secret variables
|
# Decrypt secret variables
|
||||||
@ -1015,7 +1017,7 @@ class ProviderManager:
|
|||||||
if not cached_provider_credentials:
|
if not cached_provider_credentials:
|
||||||
provider_credentials: dict[str, Any] = {}
|
provider_credentials: dict[str, Any] = {}
|
||||||
if provider_records and provider_records[0].encrypted_config:
|
if provider_records and provider_records[0].encrypted_config:
|
||||||
provider_credentials = json.loads(provider_records[0].encrypted_config)
|
provider_credentials = _credentials_adapter.validate_json(provider_records[0].encrypted_config)
|
||||||
|
|
||||||
# Get provider credential secret variables
|
# Get provider credential secret variables
|
||||||
provider_credential_secret_variables = self._extract_secret_variables(
|
provider_credential_secret_variables = self._extract_secret_variables(
|
||||||
@ -1162,8 +1164,10 @@ class ProviderManager:
|
|||||||
|
|
||||||
if not cached_provider_model_credentials:
|
if not cached_provider_model_credentials:
|
||||||
try:
|
try:
|
||||||
provider_model_credentials = json.loads(load_balancing_model_config.encrypted_config)
|
provider_model_credentials = _credentials_adapter.validate_json(
|
||||||
except JSONDecodeError:
|
load_balancing_model_config.encrypted_config
|
||||||
|
)
|
||||||
|
except (ValueError, JSONDecodeError):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Get decoding rsa key and cipher for decrypting credentials
|
# Get decoding rsa key and cipher for decrypting credentials
|
||||||
@ -1176,7 +1180,7 @@ class ProviderManager:
|
|||||||
if variable in provider_model_credentials:
|
if variable in provider_model_credentials:
|
||||||
try:
|
try:
|
||||||
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||||
provider_model_credentials.get(variable),
|
provider_model_credentials.get(variable) or "",
|
||||||
self.decoding_rsa_key,
|
self.decoding_rsa_key,
|
||||||
self.decoding_cipher_rsa,
|
self.decoding_cipher_rsa,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -15,7 +15,7 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor,
|
|||||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||||
from core.rag.datasource.vdb.vector_factory import Vector
|
from core.rag.datasource.vdb.vector_factory import Vector
|
||||||
from core.rag.embedding.retrieval import AttachmentInfoDict, RetrievalChildChunk, RetrievalSegments
|
from core.rag.embedding.retrieval import AttachmentInfoDict, RetrievalChildChunk, RetrievalSegments
|
||||||
from core.rag.entities.metadata_entities import MetadataCondition
|
from core.rag.entities import MetadataFilteringCondition
|
||||||
from core.rag.index_processor.constant.doc_type import DocType
|
from core.rag.index_processor.constant.doc_type import DocType
|
||||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||||
from core.rag.index_processor.constant.query_type import QueryType
|
from core.rag.index_processor.constant.query_type import QueryType
|
||||||
@ -182,7 +182,9 @@ class RetrievalService:
|
|||||||
if not dataset:
|
if not dataset:
|
||||||
return []
|
return []
|
||||||
metadata_condition = (
|
metadata_condition = (
|
||||||
MetadataCondition.model_validate(metadata_filtering_conditions) if metadata_filtering_conditions else None
|
MetadataFilteringCondition.model_validate(metadata_filtering_conditions)
|
||||||
|
if metadata_filtering_conditions
|
||||||
|
else None
|
||||||
)
|
)
|
||||||
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||||
dataset.tenant_id,
|
dataset.tenant_id,
|
||||||
|
|||||||
@ -37,11 +37,12 @@ class AnalyticdbVector(BaseVector):
|
|||||||
|
|
||||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||||
dimension = len(embeddings[0])
|
dimension = len(embeddings[0])
|
||||||
self.analyticdb_vector._create_collection_if_not_exists(dimension)
|
self.analyticdb_vector.create_collection_if_not_exists(dimension)
|
||||||
self.analyticdb_vector.add_texts(texts, embeddings)
|
self.analyticdb_vector.add_texts(texts, embeddings)
|
||||||
|
|
||||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
|
||||||
self.analyticdb_vector.add_texts(documents, embeddings)
|
self.analyticdb_vector.add_texts(documents, embeddings)
|
||||||
|
return []
|
||||||
|
|
||||||
def text_exists(self, id: str) -> bool:
|
def text_exists(self, id: str) -> bool:
|
||||||
return self.analyticdb_vector.text_exists(id)
|
return self.analyticdb_vector.text_exists(id)
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Any
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
|
|
||||||
@ -13,6 +13,13 @@ from core.rag.models.document import Document
|
|||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
|
|
||||||
|
|
||||||
|
class AnalyticdbClientParamsDict(TypedDict):
|
||||||
|
access_key_id: str
|
||||||
|
access_key_secret: str
|
||||||
|
region_id: str
|
||||||
|
read_timeout: int
|
||||||
|
|
||||||
|
|
||||||
class AnalyticdbVectorOpenAPIConfig(BaseModel):
|
class AnalyticdbVectorOpenAPIConfig(BaseModel):
|
||||||
access_key_id: str
|
access_key_id: str
|
||||||
access_key_secret: str
|
access_key_secret: str
|
||||||
@ -44,13 +51,14 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel):
|
|||||||
raise ValueError("config ANALYTICDB_NAMESPACE_PASSWORD is required")
|
raise ValueError("config ANALYTICDB_NAMESPACE_PASSWORD is required")
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def to_analyticdb_client_params(self):
|
def to_analyticdb_client_params(self) -> AnalyticdbClientParamsDict:
|
||||||
return {
|
result: AnalyticdbClientParamsDict = {
|
||||||
"access_key_id": self.access_key_id,
|
"access_key_id": self.access_key_id,
|
||||||
"access_key_secret": self.access_key_secret,
|
"access_key_secret": self.access_key_secret,
|
||||||
"region_id": self.region_id,
|
"region_id": self.region_id,
|
||||||
"read_timeout": self.read_timeout,
|
"read_timeout": self.read_timeout,
|
||||||
}
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class AnalyticdbVectorOpenAPI:
|
class AnalyticdbVectorOpenAPI:
|
||||||
@ -115,7 +123,7 @@ class AnalyticdbVectorOpenAPI:
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")
|
raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")
|
||||||
|
|
||||||
def _create_collection_if_not_exists(self, embedding_dimension: int):
|
def create_collection_if_not_exists(self, embedding_dimension: int):
|
||||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||||
from Tea.exceptions import TeaException
|
from Tea.exceptions import TeaException
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
|
from collections.abc import Iterator
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@ -74,7 +75,7 @@ class AnalyticdbVectorBySql:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def _get_cursor(self):
|
def _get_cursor(self) -> Iterator[Any]:
|
||||||
assert self.pool is not None, "Connection pool is not initialized"
|
assert self.pool is not None, "Connection pool is not initialized"
|
||||||
conn = self.pool.getconn()
|
conn = self.pool.getconn()
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
@ -130,7 +131,7 @@ class AnalyticdbVectorBySql:
|
|||||||
)
|
)
|
||||||
cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.config.namespace}")
|
cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.config.namespace}")
|
||||||
|
|
||||||
def _create_collection_if_not_exists(self, embedding_dimension: int):
|
def create_collection_if_not_exists(self, embedding_dimension: int):
|
||||||
cache_key = f"vector_indexing_{self._collection_name}"
|
cache_key = f"vector_indexing_{self._collection_name}"
|
||||||
lock_name = f"{cache_key}_lock"
|
lock_name = f"{cache_key}_lock"
|
||||||
with redis_client.lock(lock_name, timeout=20):
|
with redis_client.lock(lock_name, timeout=20):
|
||||||
|
|||||||
@ -30,7 +30,7 @@ from pymochow.model.table import AnnSearch, BM25SearchRequest, HNSWSearchParams,
|
|||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.rag.datasource.vdb.field import Field as VDBField
|
from core.rag.datasource.vdb.field import Field as VDBField
|
||||||
from core.rag.datasource.vdb.field import parse_metadata_json
|
from core.rag.datasource.vdb.field import parse_metadata_json
|
||||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
|
||||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||||
from core.rag.datasource.vdb.vector_type import VectorType
|
from core.rag.datasource.vdb.vector_type import VectorType
|
||||||
from core.rag.embedding.embedding_base import Embeddings
|
from core.rag.embedding.embedding_base import Embeddings
|
||||||
@ -85,8 +85,12 @@ class BaiduVector(BaseVector):
|
|||||||
def get_type(self) -> str:
|
def get_type(self) -> str:
|
||||||
return VectorType.BAIDU
|
return VectorType.BAIDU
|
||||||
|
|
||||||
def to_index_struct(self):
|
def to_index_struct(self) -> VectorIndexStructDict:
|
||||||
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
|
result: VectorIndexStructDict = {
|
||||||
|
"type": self.get_type(),
|
||||||
|
"vector_store": {"class_prefix": self._collection_name},
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||||
self._create_table(len(embeddings[0]))
|
self._create_table(len(embeddings[0]))
|
||||||
|
|||||||
@ -1,12 +1,12 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Any
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
import chromadb
|
import chromadb
|
||||||
from chromadb import QueryResult, Settings
|
from chromadb import QueryResult, Settings # pyright: ignore[reportPrivateImportUsage]
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
|
||||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||||
from core.rag.datasource.vdb.vector_type import VectorType
|
from core.rag.datasource.vdb.vector_type import VectorType
|
||||||
from core.rag.embedding.embedding_base import Embeddings
|
from core.rag.embedding.embedding_base import Embeddings
|
||||||
@ -15,6 +15,15 @@ from extensions.ext_redis import redis_client
|
|||||||
from models.dataset import Dataset
|
from models.dataset import Dataset
|
||||||
|
|
||||||
|
|
||||||
|
class ChromaParamsDict(TypedDict):
|
||||||
|
host: str
|
||||||
|
port: int
|
||||||
|
ssl: bool
|
||||||
|
tenant: str
|
||||||
|
database: str
|
||||||
|
settings: Settings
|
||||||
|
|
||||||
|
|
||||||
class ChromaConfig(BaseModel):
|
class ChromaConfig(BaseModel):
|
||||||
host: str
|
host: str
|
||||||
port: int
|
port: int
|
||||||
@ -23,14 +32,13 @@ class ChromaConfig(BaseModel):
|
|||||||
auth_provider: str | None = None
|
auth_provider: str | None = None
|
||||||
auth_credentials: str | None = None
|
auth_credentials: str | None = None
|
||||||
|
|
||||||
def to_chroma_params(self):
|
def to_chroma_params(self) -> ChromaParamsDict:
|
||||||
settings = Settings(
|
settings = Settings(
|
||||||
# auth
|
# auth
|
||||||
chroma_client_auth_provider=self.auth_provider,
|
chroma_client_auth_provider=self.auth_provider,
|
||||||
chroma_client_auth_credentials=self.auth_credentials,
|
chroma_client_auth_credentials=self.auth_credentials,
|
||||||
)
|
)
|
||||||
|
result: ChromaParamsDict = {
|
||||||
return {
|
|
||||||
"host": self.host,
|
"host": self.host,
|
||||||
"port": self.port,
|
"port": self.port,
|
||||||
"ssl": False,
|
"ssl": False,
|
||||||
@ -38,6 +46,7 @@ class ChromaConfig(BaseModel):
|
|||||||
"database": self.database,
|
"database": self.database,
|
||||||
"settings": settings,
|
"settings": settings,
|
||||||
}
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class ChromaVector(BaseVector):
|
class ChromaVector(BaseVector):
|
||||||
@ -97,14 +106,15 @@ class ChromaVector(BaseVector):
|
|||||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||||
collection = self._client.get_or_create_collection(self._collection_name)
|
collection = self._client.get_or_create_collection(self._collection_name)
|
||||||
document_ids_filter = kwargs.get("document_ids_filter")
|
document_ids_filter = kwargs.get("document_ids_filter")
|
||||||
|
results: QueryResult
|
||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
results: QueryResult = collection.query(
|
results = collection.query(
|
||||||
query_embeddings=query_vector,
|
query_embeddings=query_vector,
|
||||||
n_results=kwargs.get("top_k", 4),
|
n_results=kwargs.get("top_k", 4),
|
||||||
where={"document_id": {"$in": document_ids_filter}}, # type: ignore
|
where={"document_id": {"$in": document_ids_filter}}, # type: ignore
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) # type: ignore
|
results = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) # type: ignore
|
||||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||||
|
|
||||||
# Check if results contain data
|
# Check if results contain data
|
||||||
@ -145,7 +155,10 @@ class ChromaVectorFactory(AbstractVectorFactory):
|
|||||||
else:
|
else:
|
||||||
dataset_id = dataset.id
|
dataset_id = dataset.id
|
||||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||||
index_struct_dict = {"type": VectorType.CHROMA, "vector_store": {"class_prefix": collection_name}}
|
index_struct_dict: VectorIndexStructDict = {
|
||||||
|
"type": VectorType.CHROMA,
|
||||||
|
"vector_store": {"class_prefix": collection_name},
|
||||||
|
}
|
||||||
dataset.index_struct = json.dumps(index_struct_dict)
|
dataset.index_struct = json.dumps(index_struct_dict)
|
||||||
|
|
||||||
return ChromaVector(
|
return ChromaVector(
|
||||||
@ -153,8 +166,8 @@ class ChromaVectorFactory(AbstractVectorFactory):
|
|||||||
config=ChromaConfig(
|
config=ChromaConfig(
|
||||||
host=dify_config.CHROMA_HOST or "",
|
host=dify_config.CHROMA_HOST or "",
|
||||||
port=dify_config.CHROMA_PORT,
|
port=dify_config.CHROMA_PORT,
|
||||||
tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT,
|
tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT, # pyright: ignore[reportPrivateImportUsage]
|
||||||
database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE,
|
database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE, # pyright: ignore[reportPrivateImportUsage]
|
||||||
auth_provider=dify_config.CHROMA_AUTH_PROVIDER,
|
auth_provider=dify_config.CHROMA_AUTH_PROVIDER,
|
||||||
auth_credentials=dify_config.CHROMA_AUTH_CREDENTIALS,
|
auth_credentials=dify_config.CHROMA_AUTH_CREDENTIALS,
|
||||||
),
|
),
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
@ -20,6 +20,15 @@ from models.dataset import Dataset
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MilvusParamsDict(TypedDict):
|
||||||
|
uri: str
|
||||||
|
token: str | None
|
||||||
|
user: str | None
|
||||||
|
password: str | None
|
||||||
|
db_name: str
|
||||||
|
analyzer_params: str | None
|
||||||
|
|
||||||
|
|
||||||
class MilvusConfig(BaseModel):
|
class MilvusConfig(BaseModel):
|
||||||
"""
|
"""
|
||||||
Configuration class for Milvus connection.
|
Configuration class for Milvus connection.
|
||||||
@ -50,11 +59,11 @@ class MilvusConfig(BaseModel):
|
|||||||
raise ValueError("config MILVUS_PASSWORD is required")
|
raise ValueError("config MILVUS_PASSWORD is required")
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def to_milvus_params(self):
|
def to_milvus_params(self) -> MilvusParamsDict:
|
||||||
"""
|
"""
|
||||||
Convert the configuration to a dictionary of Milvus connection parameters.
|
Convert the configuration to a dictionary of Milvus connection parameters.
|
||||||
"""
|
"""
|
||||||
return {
|
result: MilvusParamsDict = {
|
||||||
"uri": self.uri,
|
"uri": self.uri,
|
||||||
"token": self.token,
|
"token": self.token,
|
||||||
"user": self.user,
|
"user": self.user,
|
||||||
@ -62,6 +71,7 @@ class MilvusConfig(BaseModel):
|
|||||||
"db_name": self.database,
|
"db_name": self.database,
|
||||||
"analyzer_params": self.analyzer_params,
|
"analyzer_params": self.analyzer_params,
|
||||||
}
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class MilvusVector(BaseVector):
|
class MilvusVector(BaseVector):
|
||||||
|
|||||||
@ -3,7 +3,7 @@ import os
|
|||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Generator, Iterable, Sequence
|
from collections.abc import Generator, Iterable, Sequence
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
|
|
||||||
import qdrant_client
|
import qdrant_client
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
@ -22,7 +22,7 @@ from sqlalchemy import select
|
|||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.rag.datasource.vdb.field import Field
|
from core.rag.datasource.vdb.field import Field
|
||||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
|
||||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||||
from core.rag.datasource.vdb.vector_type import VectorType
|
from core.rag.datasource.vdb.vector_type import VectorType
|
||||||
from core.rag.embedding.embedding_base import Embeddings
|
from core.rag.embedding.embedding_base import Embeddings
|
||||||
@ -32,7 +32,6 @@ from extensions.ext_redis import redis_client
|
|||||||
from models.dataset import Dataset, DatasetCollectionBinding
|
from models.dataset import Dataset, DatasetCollectionBinding
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from qdrant_client import grpc # noqa
|
|
||||||
from qdrant_client.conversions import common_types
|
from qdrant_client.conversions import common_types
|
||||||
from qdrant_client.http import models as rest
|
from qdrant_client.http import models as rest
|
||||||
|
|
||||||
@ -94,8 +93,12 @@ class QdrantVector(BaseVector):
|
|||||||
def get_type(self) -> str:
|
def get_type(self) -> str:
|
||||||
return VectorType.QDRANT
|
return VectorType.QDRANT
|
||||||
|
|
||||||
def to_index_struct(self):
|
def to_index_struct(self) -> VectorIndexStructDict:
|
||||||
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
|
result: VectorIndexStructDict = {
|
||||||
|
"type": self.get_type(),
|
||||||
|
"vector_store": {"class_prefix": self._collection_name},
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||||
if texts:
|
if texts:
|
||||||
@ -176,7 +179,7 @@ class QdrantVector(BaseVector):
|
|||||||
for batch_ids, points in self._generate_rest_batches(
|
for batch_ids, points in self._generate_rest_batches(
|
||||||
texts, embeddings, filtered_metadatas, uuids, 64, self._group_id
|
texts, embeddings, filtered_metadatas, uuids, 64, self._group_id
|
||||||
):
|
):
|
||||||
self._client.upsert(collection_name=self._collection_name, points=points)
|
self._client.upsert(collection_name=self._collection_name, points=cast("common_types.Points", points))
|
||||||
added_ids.extend(batch_ids)
|
added_ids.extend(batch_ids)
|
||||||
|
|
||||||
return added_ids
|
return added_ids
|
||||||
@ -468,7 +471,7 @@ class QdrantVector(BaseVector):
|
|||||||
|
|
||||||
def _reload_if_needed(self):
|
def _reload_if_needed(self):
|
||||||
if isinstance(self._client, QdrantLocal):
|
if isinstance(self._client, QdrantLocal):
|
||||||
self._client._load()
|
self._client._load() # pyright: ignore[reportPrivateUsage]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _document_from_scored_point(
|
def _document_from_scored_point(
|
||||||
|
|||||||
@ -26,7 +26,7 @@ from extensions.ext_redis import redis_client
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
Base = declarative_base() # type: Any
|
Base: Any = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
class RelytConfig(BaseModel):
|
class RelytConfig(BaseModel):
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from typing import Any
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from tcvdb_text.encoder import BM25Encoder # type: ignore
|
from tcvdb_text.encoder import BM25Encoder # type: ignore
|
||||||
@ -12,7 +12,7 @@ from tcvectordb.model.document import AnnSearch, Filter, KeywordSearch, Weighted
|
|||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.rag.datasource.vdb.field import parse_metadata_json
|
from core.rag.datasource.vdb.field import parse_metadata_json
|
||||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
|
||||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||||
from core.rag.datasource.vdb.vector_type import VectorType
|
from core.rag.datasource.vdb.vector_type import VectorType
|
||||||
from core.rag.embedding.embedding_base import Embeddings
|
from core.rag.embedding.embedding_base import Embeddings
|
||||||
@ -23,6 +23,13 @@ from models.dataset import Dataset
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TencentParamsDict(TypedDict):
|
||||||
|
url: str
|
||||||
|
username: str | None
|
||||||
|
key: str | None
|
||||||
|
timeout: float
|
||||||
|
|
||||||
|
|
||||||
class TencentConfig(BaseModel):
|
class TencentConfig(BaseModel):
|
||||||
url: str
|
url: str
|
||||||
api_key: str | None = None
|
api_key: str | None = None
|
||||||
@ -36,8 +43,14 @@ class TencentConfig(BaseModel):
|
|||||||
max_upsert_batch_size: int = 128
|
max_upsert_batch_size: int = 128
|
||||||
enable_hybrid_search: bool = False # Flag to enable hybrid search
|
enable_hybrid_search: bool = False # Flag to enable hybrid search
|
||||||
|
|
||||||
def to_tencent_params(self):
|
def to_tencent_params(self) -> TencentParamsDict:
|
||||||
return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout}
|
result: TencentParamsDict = {
|
||||||
|
"url": self.url,
|
||||||
|
"username": self.username,
|
||||||
|
"key": self.api_key,
|
||||||
|
"timeout": self.timeout,
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
bm25 = BM25Encoder.default("zh")
|
bm25 = BM25Encoder.default("zh")
|
||||||
@ -83,8 +96,12 @@ class TencentVector(BaseVector):
|
|||||||
def get_type(self) -> str:
|
def get_type(self) -> str:
|
||||||
return VectorType.TENCENT
|
return VectorType.TENCENT
|
||||||
|
|
||||||
def to_index_struct(self):
|
def to_index_struct(self) -> VectorIndexStructDict:
|
||||||
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
|
result: VectorIndexStructDict = {
|
||||||
|
"type": self.get_type(),
|
||||||
|
"vector_store": {"class_prefix": self._collection_name},
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
def _has_collection(self) -> bool:
|
def _has_collection(self) -> bool:
|
||||||
return bool(
|
return bool(
|
||||||
|
|||||||
@ -25,7 +25,7 @@ from sqlalchemy import select
|
|||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.rag.datasource.vdb.field import Field
|
from core.rag.datasource.vdb.field import Field
|
||||||
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
|
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
|
||||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
|
||||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||||
from core.rag.datasource.vdb.vector_type import VectorType
|
from core.rag.datasource.vdb.vector_type import VectorType
|
||||||
from core.rag.embedding.embedding_base import Embeddings
|
from core.rag.embedding.embedding_base import Embeddings
|
||||||
@ -91,8 +91,12 @@ class TidbOnQdrantVector(BaseVector):
|
|||||||
def get_type(self) -> str:
|
def get_type(self) -> str:
|
||||||
return VectorType.TIDB_ON_QDRANT
|
return VectorType.TIDB_ON_QDRANT
|
||||||
|
|
||||||
def to_index_struct(self):
|
def to_index_struct(self) -> VectorIndexStructDict:
|
||||||
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
|
result: VectorIndexStructDict = {
|
||||||
|
"type": self.get_type(),
|
||||||
|
"vector_store": {"class_prefix": self._collection_name},
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||||
if texts:
|
if texts:
|
||||||
|
|||||||
@ -1,11 +1,20 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import Document
|
||||||
|
|
||||||
|
|
||||||
|
class VectorStoreDict(TypedDict):
|
||||||
|
class_prefix: str
|
||||||
|
|
||||||
|
|
||||||
|
class VectorIndexStructDict(TypedDict):
|
||||||
|
type: str
|
||||||
|
vector_store: VectorStoreDict
|
||||||
|
|
||||||
|
|
||||||
class BaseVector(ABC):
|
class BaseVector(ABC):
|
||||||
def __init__(self, collection_name: str):
|
def __init__(self, collection_name: str):
|
||||||
self._collection_name = collection_name
|
self._collection_name = collection_name
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from sqlalchemy import select
|
|||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.model_manager import ModelManager
|
from core.model_manager import ModelManager
|
||||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
|
||||||
from core.rag.datasource.vdb.vector_type import VectorType
|
from core.rag.datasource.vdb.vector_type import VectorType
|
||||||
from core.rag.embedding.cached_embedding import CacheEmbedding
|
from core.rag.embedding.cached_embedding import CacheEmbedding
|
||||||
from core.rag.embedding.embedding_base import Embeddings
|
from core.rag.embedding.embedding_base import Embeddings
|
||||||
@ -30,8 +30,11 @@ class AbstractVectorFactory(ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def gen_index_struct_dict(vector_type: VectorType, collection_name: str):
|
def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> VectorIndexStructDict:
|
||||||
index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}}
|
index_struct_dict: VectorIndexStructDict = {
|
||||||
|
"type": vector_type,
|
||||||
|
"vector_store": {"class_prefix": collection_name},
|
||||||
|
}
|
||||||
return index_struct_dict
|
return index_struct_dict
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -24,7 +24,7 @@ from weaviate.exceptions import UnexpectedStatusCodeError
|
|||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.rag.datasource.vdb.field import Field
|
from core.rag.datasource.vdb.field import Field
|
||||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
|
||||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||||
from core.rag.datasource.vdb.vector_type import VectorType
|
from core.rag.datasource.vdb.vector_type import VectorType
|
||||||
from core.rag.embedding.embedding_base import Embeddings
|
from core.rag.embedding.embedding_base import Embeddings
|
||||||
@ -184,9 +184,13 @@ class WeaviateVector(BaseVector):
|
|||||||
dataset_id = dataset.id
|
dataset_id = dataset.id
|
||||||
return Dataset.gen_collection_name_by_id(dataset_id)
|
return Dataset.gen_collection_name_by_id(dataset_id)
|
||||||
|
|
||||||
def to_index_struct(self) -> dict:
|
def to_index_struct(self) -> VectorIndexStructDict:
|
||||||
"""Returns the index structure dictionary for persistence."""
|
"""Returns the index structure dictionary for persistence."""
|
||||||
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
|
result: VectorIndexStructDict = {
|
||||||
|
"type": self.get_type(),
|
||||||
|
"vector_store": {"class_prefix": self._collection_name},
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||||
"""
|
"""
|
||||||
|
|||||||
28
api/core/rag/entities/__init__.py
Normal file
28
api/core/rag/entities/__init__.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||||
|
from core.rag.entities.context_entities import DocumentContext
|
||||||
|
from core.rag.entities.event import DatasourceCompletedEvent, DatasourceErrorEvent, DatasourceProcessingEvent
|
||||||
|
from core.rag.entities.index_entities import EconomySetting, EmbeddingSetting, IndexMethod
|
||||||
|
from core.rag.entities.metadata_entities import Condition, MetadataFilteringCondition, SupportedComparisonOperator
|
||||||
|
from core.rag.entities.processing_entities import ParentMode, PreProcessingRule, Rule, Segmentation
|
||||||
|
from core.rag.entities.retrieval_settings import KeywordSetting, VectorSetting, WeightedScoreConfig
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Condition",
|
||||||
|
"DatasourceCompletedEvent",
|
||||||
|
"DatasourceErrorEvent",
|
||||||
|
"DatasourceProcessingEvent",
|
||||||
|
"DocumentContext",
|
||||||
|
"EconomySetting",
|
||||||
|
"EmbeddingSetting",
|
||||||
|
"IndexMethod",
|
||||||
|
"KeywordSetting",
|
||||||
|
"MetadataFilteringCondition",
|
||||||
|
"ParentMode",
|
||||||
|
"PreProcessingRule",
|
||||||
|
"RetrievalSourceMetadata",
|
||||||
|
"Rule",
|
||||||
|
"Segmentation",
|
||||||
|
"SupportedComparisonOperator",
|
||||||
|
"VectorSetting",
|
||||||
|
"WeightedScoreConfig",
|
||||||
|
]
|
||||||
30
api/core/rag/entities/index_entities.py
Normal file
30
api/core/rag/entities/index_entities.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingSetting(BaseModel):
|
||||||
|
"""
|
||||||
|
Embedding Setting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
embedding_provider_name: str
|
||||||
|
embedding_model_name: str
|
||||||
|
|
||||||
|
|
||||||
|
class EconomySetting(BaseModel):
|
||||||
|
"""
|
||||||
|
Economy Setting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
keyword_number: int
|
||||||
|
|
||||||
|
|
||||||
|
class IndexMethod(BaseModel):
|
||||||
|
"""
|
||||||
|
Knowledge Index Setting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
indexing_technique: Literal["high_quality", "economy"]
|
||||||
|
embedding_setting: EmbeddingSetting
|
||||||
|
economy_setting: EconomySetting
|
||||||
@ -38,9 +38,9 @@ class Condition(BaseModel):
|
|||||||
value: str | Sequence[str] | None | int | float = None
|
value: str | Sequence[str] | None | int | float = None
|
||||||
|
|
||||||
|
|
||||||
class MetadataCondition(BaseModel):
|
class MetadataFilteringCondition(BaseModel):
|
||||||
"""
|
"""
|
||||||
Metadata Condition.
|
Metadata Filtering Condition.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logical_operator: Literal["and", "or"] | None = "and"
|
logical_operator: Literal["and", "or"] | None = "and"
|
||||||
|
|||||||
27
api/core/rag/entities/processing_entities.py
Normal file
27
api/core/rag/entities/processing_entities.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
from enum import StrEnum
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class ParentMode(StrEnum):
|
||||||
|
FULL_DOC = "full-doc"
|
||||||
|
PARAGRAPH = "paragraph"
|
||||||
|
|
||||||
|
|
||||||
|
class PreProcessingRule(BaseModel):
|
||||||
|
id: str
|
||||||
|
enabled: bool
|
||||||
|
|
||||||
|
|
||||||
|
class Segmentation(BaseModel):
|
||||||
|
separator: str = "\n"
|
||||||
|
max_tokens: int
|
||||||
|
chunk_overlap: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class Rule(BaseModel):
|
||||||
|
pre_processing_rules: list[PreProcessingRule] | None = None
|
||||||
|
segmentation: Segmentation | None = None
|
||||||
|
parent_mode: Literal["full-doc", "paragraph"] | None = None
|
||||||
|
subchunk_segmentation: Segmentation | None = None
|
||||||
28
api/core/rag/entities/retrieval_settings.py
Normal file
28
api/core/rag/entities/retrieval_settings.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class VectorSetting(BaseModel):
|
||||||
|
"""
|
||||||
|
Vector Setting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
vector_weight: float
|
||||||
|
embedding_provider_name: str
|
||||||
|
embedding_model_name: str
|
||||||
|
|
||||||
|
|
||||||
|
class KeywordSetting(BaseModel):
|
||||||
|
"""
|
||||||
|
Keyword Setting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
keyword_weight: float
|
||||||
|
|
||||||
|
|
||||||
|
class WeightedScoreConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
Weighted score Config.
|
||||||
|
"""
|
||||||
|
|
||||||
|
vector_setting: VectorSetting
|
||||||
|
keyword_setting: KeywordSetting
|
||||||
@ -19,12 +19,15 @@ class UnstructuredWordExtractor(BaseExtractor):
|
|||||||
|
|
||||||
def extract(self) -> list[Document]:
|
def extract(self) -> list[Document]:
|
||||||
from unstructured.__version__ import __version__ as __unstructured_version__
|
from unstructured.__version__ import __version__ as __unstructured_version__
|
||||||
from unstructured.file_utils.filetype import FileType, detect_filetype
|
from unstructured.file_utils.filetype import ( # pyright: ignore[reportPrivateImportUsage]
|
||||||
|
FileType,
|
||||||
|
detect_filetype,
|
||||||
|
)
|
||||||
|
|
||||||
unstructured_version = tuple(int(x) for x in __unstructured_version__.split("."))
|
unstructured_version = tuple(int(x) for x in __unstructured_version__.split("."))
|
||||||
# check the file extension
|
# check the file extension
|
||||||
try:
|
try:
|
||||||
import magic # noqa: F401
|
import magic # noqa: F401 # pyright: ignore[reportUnusedImport]
|
||||||
|
|
||||||
is_doc = detect_filetype(self._file_path) == FileType.DOC
|
is_doc = detect_filetype(self._file_path) == FileType.DOC
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|||||||
@ -12,7 +12,7 @@ from core.db.session_factory import session_factory
|
|||||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||||
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
|
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
|
||||||
from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
|
from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
|
||||||
from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview
|
from core.workflow.nodes.knowledge_index.protocols import IndexingResultDict, Preview, PreviewItem, QaPreview
|
||||||
from models.dataset import Dataset, Document, DocumentSegment
|
from models.dataset import Dataset, Document, DocumentSegment
|
||||||
|
|
||||||
from .index_processor_factory import IndexProcessorFactory
|
from .index_processor_factory import IndexProcessorFactory
|
||||||
@ -61,7 +61,7 @@ class IndexProcessor:
|
|||||||
chunks: Mapping[str, Any],
|
chunks: Mapping[str, Any],
|
||||||
batch: Any,
|
batch: Any,
|
||||||
summary_index_setting: SummaryIndexSettingDict | None = None,
|
summary_index_setting: SummaryIndexSettingDict | None = None,
|
||||||
):
|
) -> IndexingResultDict:
|
||||||
with session_factory.create_session() as session:
|
with session_factory.create_session() as session:
|
||||||
document = session.query(Document).filter_by(id=document_id).first()
|
document = session.query(Document).filter_by(id=document_id).first()
|
||||||
if not document:
|
if not document:
|
||||||
@ -129,7 +129,7 @@ class IndexProcessor:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
result: IndexingResultDict = {
|
||||||
"dataset_id": dataset_id,
|
"dataset_id": dataset_id,
|
||||||
"dataset_name": dataset_name_value,
|
"dataset_name": dataset_name_value,
|
||||||
"batch": batch,
|
"batch": batch,
|
||||||
@ -138,6 +138,7 @@ class IndexProcessor:
|
|||||||
"created_at": created_at_value.timestamp(),
|
"created_at": created_at_value.timestamp(),
|
||||||
"display_status": "completed",
|
"display_status": "completed",
|
||||||
}
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
def get_preview_output(
|
def get_preview_output(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -32,6 +32,7 @@ from core.rag.datasource.keyword.keyword_factory import Keyword
|
|||||||
from core.rag.datasource.retrieval_service import RetrievalService
|
from core.rag.datasource.retrieval_service import RetrievalService
|
||||||
from core.rag.datasource.vdb.vector_factory import Vector
|
from core.rag.datasource.vdb.vector_factory import Vector
|
||||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||||
|
from core.rag.entities import Rule
|
||||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||||
from core.rag.index_processor.constant.doc_type import DocType
|
from core.rag.index_processor.constant.doc_type import DocType
|
||||||
@ -49,7 +50,6 @@ from models.account import Account
|
|||||||
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding
|
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding
|
||||||
from models.dataset import Document as DatasetDocument
|
from models.dataset import Document as DatasetDocument
|
||||||
from services.account_service import AccountService
|
from services.account_service import AccountService
|
||||||
from services.entities.knowledge_entities.knowledge_entities import Rule
|
|
||||||
from services.summary_index_service import SummaryIndexService
|
from services.summary_index_service import SummaryIndexService
|
||||||
|
|
||||||
_file_access_controller = DatabaseFileAccessController()
|
_file_access_controller = DatabaseFileAccessController()
|
||||||
|
|||||||
@ -17,6 +17,7 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict
|
|||||||
from core.rag.datasource.retrieval_service import RetrievalService
|
from core.rag.datasource.retrieval_service import RetrievalService
|
||||||
from core.rag.datasource.vdb.vector_factory import Vector
|
from core.rag.datasource.vdb.vector_factory import Vector
|
||||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||||
|
from core.rag.entities import ParentMode, Rule
|
||||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||||
from core.rag.index_processor.constant.doc_type import DocType
|
from core.rag.index_processor.constant.doc_type import DocType
|
||||||
@ -30,7 +31,6 @@ from models import Account
|
|||||||
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
|
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
|
||||||
from models.dataset import Document as DatasetDocument
|
from models.dataset import Document as DatasetDocument
|
||||||
from services.account_service import AccountService
|
from services.account_service import AccountService
|
||||||
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
|
|
||||||
from services.summary_index_service import SummaryIndexService
|
from services.summary_index_service import SummaryIndexService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@ -19,6 +19,7 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict
|
|||||||
from core.rag.datasource.retrieval_service import RetrievalService
|
from core.rag.datasource.retrieval_service import RetrievalService
|
||||||
from core.rag.datasource.vdb.vector_factory import Vector
|
from core.rag.datasource.vdb.vector_factory import Vector
|
||||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||||
|
from core.rag.entities import Rule
|
||||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||||
@ -30,7 +31,6 @@ from libs import helper
|
|||||||
from models.account import Account
|
from models.account import Account
|
||||||
from models.dataset import Dataset, DocumentSegment
|
from models.dataset import Dataset, DocumentSegment
|
||||||
from models.dataset import Document as DatasetDocument
|
from models.dataset import Document as DatasetDocument
|
||||||
from services.entities.knowledge_entities.knowledge_entities import Rule
|
|
||||||
from services.summary_index_service import SummaryIndexService
|
from services.summary_index_service import SummaryIndexService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@ -1,16 +1,6 @@
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from core.rag.entities import KeywordSetting, VectorSetting
|
||||||
class VectorSetting(BaseModel):
|
|
||||||
vector_weight: float
|
|
||||||
|
|
||||||
embedding_provider_name: str
|
|
||||||
|
|
||||||
embedding_model_name: str
|
|
||||||
|
|
||||||
|
|
||||||
class KeywordSetting(BaseModel):
|
|
||||||
keyword_weight: float
|
|
||||||
|
|
||||||
|
|
||||||
class Weights(BaseModel):
|
class Weights(BaseModel):
|
||||||
|
|||||||
@ -15,7 +15,7 @@ from graphon.model_runtime.entities.message_entities import PromptMessage, Promp
|
|||||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
|
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
from sqlalchemy import and_, func, literal, or_, select
|
from sqlalchemy import and_, func, literal, or_, select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from core.app.app_config.entities import (
|
from core.app.app_config.entities import (
|
||||||
DatasetEntity,
|
DatasetEntity,
|
||||||
@ -39,9 +39,7 @@ from core.prompt.simple_prompt_transform import ModelMode
|
|||||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
|
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
|
||||||
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
||||||
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
|
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
|
||||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
from core.rag.entities import Condition, DocumentContext, RetrievalSourceMetadata
|
||||||
from core.rag.entities.context_entities import DocumentContext
|
|
||||||
from core.rag.entities.metadata_entities import Condition, MetadataCondition
|
|
||||||
from core.rag.index_processor.constant.doc_type import DocType
|
from core.rag.index_processor.constant.doc_type import DocType
|
||||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||||
from core.rag.index_processor.constant.query_type import QueryType
|
from core.rag.index_processor.constant.query_type import QueryType
|
||||||
@ -604,7 +602,7 @@ class DatasetRetrieval:
|
|||||||
planning_strategy: PlanningStrategy,
|
planning_strategy: PlanningStrategy,
|
||||||
message_id: str | None = None,
|
message_id: str | None = None,
|
||||||
metadata_filter_document_ids: dict[str, list[str]] | None = None,
|
metadata_filter_document_ids: dict[str, list[str]] | None = None,
|
||||||
metadata_condition: MetadataCondition | None = None,
|
metadata_condition: MetadataFilteringCondition | None = None,
|
||||||
):
|
):
|
||||||
tools = []
|
tools = []
|
||||||
for dataset in available_datasets:
|
for dataset in available_datasets:
|
||||||
@ -743,7 +741,7 @@ class DatasetRetrieval:
|
|||||||
reranking_enable: bool = True,
|
reranking_enable: bool = True,
|
||||||
message_id: str | None = None,
|
message_id: str | None = None,
|
||||||
metadata_filter_document_ids: dict[str, list[str]] | None = None,
|
metadata_filter_document_ids: dict[str, list[str]] | None = None,
|
||||||
metadata_condition: MetadataCondition | None = None,
|
metadata_condition: MetadataFilteringCondition | None = None,
|
||||||
attachment_ids: list[str] | None = None,
|
attachment_ids: list[str] | None = None,
|
||||||
):
|
):
|
||||||
if not available_datasets:
|
if not available_datasets:
|
||||||
@ -886,7 +884,7 @@ class DatasetRetrieval:
|
|||||||
self._send_trace_task(message_id, documents, timer)
|
self._send_trace_task(message_id, documents, timer)
|
||||||
return
|
return
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(bind=db.engine).begin() as session:
|
||||||
# Collect all document_ids and batch fetch DatasetDocuments
|
# Collect all document_ids and batch fetch DatasetDocuments
|
||||||
document_ids = {
|
document_ids = {
|
||||||
doc.metadata["document_id"]
|
doc.metadata["document_id"]
|
||||||
@ -977,7 +975,6 @@ class DatasetRetrieval:
|
|||||||
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
|
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
|
||||||
synchronize_session=False,
|
synchronize_session=False,
|
||||||
)
|
)
|
||||||
session.commit()
|
|
||||||
|
|
||||||
self._send_trace_task(message_id, documents, timer)
|
self._send_trace_task(message_id, documents, timer)
|
||||||
|
|
||||||
@ -1063,7 +1060,7 @@ class DatasetRetrieval:
|
|||||||
top_k: int,
|
top_k: int,
|
||||||
all_documents: list[Document],
|
all_documents: list[Document],
|
||||||
document_ids_filter: list[str] | None = None,
|
document_ids_filter: list[str] | None = None,
|
||||||
metadata_condition: MetadataCondition | None = None,
|
metadata_condition: MetadataFilteringCondition | None = None,
|
||||||
attachment_ids: list[str] | None = None,
|
attachment_ids: list[str] | None = None,
|
||||||
):
|
):
|
||||||
with flask_app.app_context():
|
with flask_app.app_context():
|
||||||
@ -1339,7 +1336,7 @@ class DatasetRetrieval:
|
|||||||
metadata_model_config: ModelConfig,
|
metadata_model_config: ModelConfig,
|
||||||
metadata_filtering_conditions: MetadataFilteringCondition | None,
|
metadata_filtering_conditions: MetadataFilteringCondition | None,
|
||||||
inputs: dict,
|
inputs: dict,
|
||||||
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]:
|
) -> tuple[dict[str, list[str]] | None, MetadataFilteringCondition | None]:
|
||||||
document_query = select(DatasetDocument).where(
|
document_query = select(DatasetDocument).where(
|
||||||
DatasetDocument.dataset_id.in_(dataset_ids),
|
DatasetDocument.dataset_id.in_(dataset_ids),
|
||||||
DatasetDocument.indexing_status == "completed",
|
DatasetDocument.indexing_status == "completed",
|
||||||
@ -1371,7 +1368,7 @@ class DatasetRetrieval:
|
|||||||
value=filter.get("value"),
|
value=filter.get("value"),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
metadata_condition = MetadataCondition(
|
metadata_condition = MetadataFilteringCondition(
|
||||||
logical_operator=metadata_filtering_conditions.logical_operator
|
logical_operator=metadata_filtering_conditions.logical_operator
|
||||||
if metadata_filtering_conditions
|
if metadata_filtering_conditions
|
||||||
else "or", # type: ignore
|
else "or", # type: ignore
|
||||||
@ -1400,7 +1397,7 @@ class DatasetRetrieval:
|
|||||||
expected_value,
|
expected_value,
|
||||||
filters,
|
filters,
|
||||||
)
|
)
|
||||||
metadata_condition = MetadataCondition(
|
metadata_condition = MetadataFilteringCondition(
|
||||||
logical_operator=metadata_filtering_conditions.logical_operator,
|
logical_operator=metadata_filtering_conditions.logical_operator,
|
||||||
conditions=conditions,
|
conditions=conditions,
|
||||||
)
|
)
|
||||||
@ -1723,7 +1720,7 @@ class DatasetRetrieval:
|
|||||||
self,
|
self,
|
||||||
flask_app: Flask,
|
flask_app: Flask,
|
||||||
available_datasets: list[Dataset],
|
available_datasets: list[Dataset],
|
||||||
metadata_condition: MetadataCondition | None,
|
metadata_condition: MetadataFilteringCondition | None,
|
||||||
metadata_filter_document_ids: dict[str, list[str]] | None,
|
metadata_filter_document_ids: dict[str, list[str]] | None,
|
||||||
all_documents: list[Document],
|
all_documents: list[Document],
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
|
|||||||
@ -1,6 +1,15 @@
|
|||||||
|
from typing import TypedDict
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
|
|
||||||
|
class I18nObjectDict(TypedDict):
|
||||||
|
zh_Hans: str | None
|
||||||
|
en_US: str
|
||||||
|
pt_BR: str | None
|
||||||
|
ja_JP: str | None
|
||||||
|
|
||||||
|
|
||||||
class I18nObject(BaseModel):
|
class I18nObject(BaseModel):
|
||||||
"""
|
"""
|
||||||
Model class for i18n object.
|
Model class for i18n object.
|
||||||
@ -18,5 +27,11 @@ class I18nObject(BaseModel):
|
|||||||
self.ja_JP = self.ja_JP or self.en_US
|
self.ja_JP = self.ja_JP or self.en_US
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self) -> I18nObjectDict:
|
||||||
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}
|
result: I18nObjectDict = {
|
||||||
|
"zh_Hans": self.zh_Hans,
|
||||||
|
"en_US": self.en_US,
|
||||||
|
"pt_BR": self.pt_BR,
|
||||||
|
"ja_JP": self.ja_JP,
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|||||||
@ -6,9 +6,20 @@ from collections.abc import Mapping
|
|||||||
from enum import StrEnum, auto
|
from enum import StrEnum, auto
|
||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
ConfigDict,
|
||||||
|
Field,
|
||||||
|
TypeAdapter,
|
||||||
|
ValidationInfo,
|
||||||
|
field_serializer,
|
||||||
|
field_validator,
|
||||||
|
model_validator,
|
||||||
|
)
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from core.entities.provider_entities import ProviderConfig
|
from core.entities.provider_entities import ProviderConfig
|
||||||
|
from core.plugin.entities import OAuthSchema
|
||||||
from core.plugin.entities.parameters import (
|
from core.plugin.entities.parameters import (
|
||||||
MCPServerParameterType,
|
MCPServerParameterType,
|
||||||
PluginParameter,
|
PluginParameter,
|
||||||
@ -18,11 +29,19 @@ from core.plugin.entities.parameters import (
|
|||||||
cast_parameter_value,
|
cast_parameter_value,
|
||||||
init_frontend_parameter,
|
init_frontend_parameter,
|
||||||
)
|
)
|
||||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
from core.rag.entities import RetrievalSourceMetadata
|
||||||
from core.tools.entities.common_entities import I18nObject
|
from core.tools.entities.common_entities import I18nObject
|
||||||
from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY
|
from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY
|
||||||
|
|
||||||
|
|
||||||
|
class EmojiIconDict(TypedDict):
|
||||||
|
background: str
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
emoji_icon_adapter: TypeAdapter[EmojiIconDict] = TypeAdapter(EmojiIconDict)
|
||||||
|
|
||||||
|
|
||||||
class ToolLabelEnum(StrEnum):
|
class ToolLabelEnum(StrEnum):
|
||||||
SEARCH = "search"
|
SEARCH = "search"
|
||||||
IMAGE = "image"
|
IMAGE = "image"
|
||||||
@ -410,15 +429,6 @@ class ToolEntity(BaseModel):
|
|||||||
return value or {}
|
return value or {}
|
||||||
|
|
||||||
|
|
||||||
class OAuthSchema(BaseModel):
|
|
||||||
client_schema: list[ProviderConfig] = Field(
|
|
||||||
default_factory=list[ProviderConfig], description="The schema of the OAuth client"
|
|
||||||
)
|
|
||||||
credentials_schema: list[ProviderConfig] = Field(
|
|
||||||
default_factory=list[ProviderConfig], description="The schema of the OAuth credentials"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ToolProviderEntity(BaseModel):
|
class ToolProviderEntity(BaseModel):
|
||||||
identity: ToolProviderIdentity
|
identity: ToolProviderIdentity
|
||||||
plugin_id: str | None = None
|
plugin_id: str | None = None
|
||||||
|
|||||||
@ -5,16 +5,19 @@ import time
|
|||||||
from collections.abc import Generator, Mapping
|
from collections.abc import Generator, Mapping
|
||||||
from os import listdir, path
|
from os import listdir, path
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, TypedDict, Union, cast
|
from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, Union, cast
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from graphon.runtime import VariablePool
|
from graphon.runtime import VariablePool
|
||||||
|
from pydantic import TypeAdapter
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
from typing_extensions import TypedDict
|
||||||
from yarl import URL
|
from yarl import URL
|
||||||
|
|
||||||
import contexts
|
import contexts
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
from core.entities import PluginCredentialType
|
||||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||||
from core.plugin.impl.tool import PluginToolManager
|
from core.plugin.impl.tool import PluginToolManager
|
||||||
from core.tools.__base.tool_provider import ToolProviderController
|
from core.tools.__base.tool_provider import ToolProviderController
|
||||||
@ -27,7 +30,6 @@ from core.tools.utils.uuid_utils import is_valid_uuid
|
|||||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.provider_ids import ToolProviderID
|
from models.provider_ids import ToolProviderID
|
||||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
|
||||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -49,9 +51,11 @@ from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProvider
|
|||||||
from core.tools.entities.common_entities import I18nObject
|
from core.tools.entities.common_entities import I18nObject
|
||||||
from core.tools.entities.tool_entities import (
|
from core.tools.entities.tool_entities import (
|
||||||
ApiProviderAuthType,
|
ApiProviderAuthType,
|
||||||
|
EmojiIconDict,
|
||||||
ToolInvokeFrom,
|
ToolInvokeFrom,
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
ToolProviderType,
|
ToolProviderType,
|
||||||
|
emoji_icon_adapter,
|
||||||
)
|
)
|
||||||
from core.tools.errors import ToolProviderNotFoundError
|
from core.tools.errors import ToolProviderNotFoundError
|
||||||
from core.tools.tool_label_manager import ToolLabelManager
|
from core.tools.tool_label_manager import ToolLabelManager
|
||||||
@ -72,9 +76,7 @@ class ApiProviderControllerItem(TypedDict):
|
|||||||
controller: ApiToolProviderController
|
controller: ApiToolProviderController
|
||||||
|
|
||||||
|
|
||||||
class EmojiIconDict(TypedDict):
|
_credentials_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any])
|
||||||
background: str
|
|
||||||
content: str
|
|
||||||
|
|
||||||
|
|
||||||
class WorkflowToolRuntimeSpec(Protocol):
|
class WorkflowToolRuntimeSpec(Protocol):
|
||||||
@ -203,16 +205,160 @@ class ToolManager:
|
|||||||
|
|
||||||
:return: the tool
|
:return: the tool
|
||||||
"""
|
"""
|
||||||
if provider_type == ToolProviderType.BUILT_IN:
|
match provider_type:
|
||||||
# check if the builtin tool need credentials
|
case ToolProviderType.BUILT_IN:
|
||||||
provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
|
provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
|
||||||
|
|
||||||
builtin_tool = provider_controller.get_tool(tool_name)
|
builtin_tool = provider_controller.get_tool(tool_name)
|
||||||
if not builtin_tool:
|
if not builtin_tool:
|
||||||
raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")
|
raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")
|
||||||
|
|
||||||
|
if not provider_controller.need_credentials:
|
||||||
|
return builtin_tool.fork_tool_runtime(
|
||||||
|
runtime=ToolRuntime(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
credentials={},
|
||||||
|
invoke_from=invoke_from,
|
||||||
|
tool_invoke_from=tool_invoke_from,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
builtin_provider = None
|
||||||
|
if isinstance(provider_controller, PluginToolProviderController):
|
||||||
|
provider_id_entity = ToolProviderID(provider_id)
|
||||||
|
if is_valid_uuid(credential_id):
|
||||||
|
try:
|
||||||
|
builtin_provider_stmt = select(BuiltinToolProvider).where(
|
||||||
|
BuiltinToolProvider.tenant_id == tenant_id,
|
||||||
|
BuiltinToolProvider.id == credential_id,
|
||||||
|
)
|
||||||
|
builtin_provider = db.session.scalar(builtin_provider_stmt)
|
||||||
|
except Exception as e:
|
||||||
|
builtin_provider = None
|
||||||
|
logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True)
|
||||||
|
if builtin_provider is None:
|
||||||
|
raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}")
|
||||||
|
|
||||||
|
if builtin_provider is None:
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
builtin_provider = session.scalar(
|
||||||
|
sa.select(BuiltinToolProvider)
|
||||||
|
.where(
|
||||||
|
BuiltinToolProvider.tenant_id == tenant_id,
|
||||||
|
(BuiltinToolProvider.provider == str(provider_id_entity))
|
||||||
|
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
|
||||||
|
)
|
||||||
|
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||||
|
)
|
||||||
|
if builtin_provider is None:
|
||||||
|
raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
|
||||||
|
else:
|
||||||
|
builtin_provider = db.session.scalar(
|
||||||
|
select(BuiltinToolProvider)
|
||||||
|
.where(
|
||||||
|
BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)
|
||||||
|
)
|
||||||
|
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
|
||||||
|
if builtin_provider is None:
|
||||||
|
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
|
||||||
|
|
||||||
|
from core.helper.credential_utils import check_credential_policy_compliance
|
||||||
|
|
||||||
|
check_credential_policy_compliance(
|
||||||
|
credential_id=builtin_provider.id,
|
||||||
|
provider=provider_id,
|
||||||
|
credential_type=PluginCredentialType.TOOL,
|
||||||
|
check_existence=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
encrypter, cache = create_provider_encrypter(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
config=[
|
||||||
|
x.to_basic_provider_config()
|
||||||
|
for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type)
|
||||||
|
],
|
||||||
|
cache=ToolProviderCredentialsCache(
|
||||||
|
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials)
|
||||||
|
|
||||||
|
if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
|
||||||
|
# TODO: circular import
|
||||||
|
from core.plugin.impl.oauth import OAuthHandler
|
||||||
|
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||||
|
|
||||||
|
tool_provider = ToolProviderID(provider_id)
|
||||||
|
provider_name = tool_provider.provider_name
|
||||||
|
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
|
||||||
|
system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
|
||||||
|
|
||||||
|
oauth_handler = OAuthHandler()
|
||||||
|
refreshed_credentials = oauth_handler.refresh_credentials(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
user_id=builtin_provider.user_id,
|
||||||
|
plugin_id=tool_provider.plugin_id,
|
||||||
|
provider=provider_name,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
system_credentials=system_credentials or {},
|
||||||
|
credentials=decrypted_credentials,
|
||||||
|
)
|
||||||
|
# update the credentials
|
||||||
|
builtin_provider.encrypted_credentials = json.dumps(
|
||||||
|
encrypter.encrypt(refreshed_credentials.credentials)
|
||||||
|
)
|
||||||
|
builtin_provider.expires_at = refreshed_credentials.expires_at
|
||||||
|
db.session.commit()
|
||||||
|
decrypted_credentials = refreshed_credentials.credentials
|
||||||
|
cache.delete()
|
||||||
|
|
||||||
if not provider_controller.need_credentials:
|
|
||||||
return builtin_tool.fork_tool_runtime(
|
return builtin_tool.fork_tool_runtime(
|
||||||
|
runtime=ToolRuntime(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
credentials=dict(decrypted_credentials),
|
||||||
|
credential_type=builtin_provider.credential_type,
|
||||||
|
runtime_parameters={},
|
||||||
|
invoke_from=invoke_from,
|
||||||
|
tool_invoke_from=tool_invoke_from,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
case ToolProviderType.API:
|
||||||
|
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
|
||||||
|
encrypter, _ = create_tool_provider_encrypter(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
controller=api_provider,
|
||||||
|
)
|
||||||
|
return api_provider.get_tool(tool_name).fork_tool_runtime(
|
||||||
|
runtime=ToolRuntime(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
credentials=dict(encrypter.decrypt(credentials)),
|
||||||
|
invoke_from=invoke_from,
|
||||||
|
tool_invoke_from=tool_invoke_from,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
case ToolProviderType.WORKFLOW:
|
||||||
|
workflow_provider_stmt = select(WorkflowToolProvider).where(
|
||||||
|
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id
|
||||||
|
)
|
||||||
|
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||||
|
workflow_provider = session.scalar(workflow_provider_stmt)
|
||||||
|
|
||||||
|
if workflow_provider is None:
|
||||||
|
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||||
|
|
||||||
|
controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
|
||||||
|
controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id)
|
||||||
|
if controller_tools is None or len(controller_tools) == 0:
|
||||||
|
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||||
|
|
||||||
|
return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
|
||||||
runtime=ToolRuntime(
|
runtime=ToolRuntime(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@ -221,177 +367,28 @@ class ToolManager:
|
|||||||
tool_invoke_from=tool_invoke_from,
|
tool_invoke_from=tool_invoke_from,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
builtin_provider = None
|
case ToolProviderType.APP:
|
||||||
if isinstance(provider_controller, PluginToolProviderController):
|
raise NotImplementedError("app provider not implemented")
|
||||||
provider_id_entity = ToolProviderID(provider_id)
|
case ToolProviderType.PLUGIN:
|
||||||
# get specific credentials
|
plugin_tool = cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
|
||||||
if is_valid_uuid(credential_id):
|
runtime = getattr(plugin_tool, "runtime", None)
|
||||||
try:
|
if runtime is not None:
|
||||||
builtin_provider_stmt = select(BuiltinToolProvider).where(
|
runtime.user_id = user_id
|
||||||
BuiltinToolProvider.tenant_id == tenant_id,
|
runtime.invoke_from = invoke_from
|
||||||
BuiltinToolProvider.id == credential_id,
|
runtime.tool_invoke_from = tool_invoke_from
|
||||||
)
|
return plugin_tool
|
||||||
builtin_provider = db.session.scalar(builtin_provider_stmt)
|
case ToolProviderType.MCP:
|
||||||
except Exception as e:
|
mcp_tool = cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name)
|
||||||
builtin_provider = None
|
runtime = getattr(mcp_tool, "runtime", None)
|
||||||
logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True)
|
if runtime is not None:
|
||||||
# if the provider has been deleted, raise an error
|
runtime.user_id = user_id
|
||||||
if builtin_provider is None:
|
runtime.invoke_from = invoke_from
|
||||||
raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}")
|
runtime.tool_invoke_from = tool_invoke_from
|
||||||
|
return mcp_tool
|
||||||
# fallback to the default provider
|
case ToolProviderType.DATASET_RETRIEVAL:
|
||||||
if builtin_provider is None:
|
raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
|
||||||
# use the default provider
|
case _:
|
||||||
with Session(db.engine) as session:
|
raise ToolProviderNotFoundError(f"provider type {provider_type} not found")
|
||||||
builtin_provider = session.scalar(
|
|
||||||
sa.select(BuiltinToolProvider)
|
|
||||||
.where(
|
|
||||||
BuiltinToolProvider.tenant_id == tenant_id,
|
|
||||||
(BuiltinToolProvider.provider == str(provider_id_entity))
|
|
||||||
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
|
|
||||||
)
|
|
||||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
|
||||||
)
|
|
||||||
if builtin_provider is None:
|
|
||||||
raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
|
|
||||||
else:
|
|
||||||
builtin_provider = db.session.scalar(
|
|
||||||
select(BuiltinToolProvider)
|
|
||||||
.where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
|
|
||||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
|
||||||
.limit(1)
|
|
||||||
)
|
|
||||||
|
|
||||||
if builtin_provider is None:
|
|
||||||
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
|
|
||||||
|
|
||||||
# check if the credential is allowed to be used
|
|
||||||
from core.helper.credential_utils import check_credential_policy_compliance
|
|
||||||
|
|
||||||
check_credential_policy_compliance(
|
|
||||||
credential_id=builtin_provider.id,
|
|
||||||
provider=provider_id,
|
|
||||||
credential_type=PluginCredentialType.TOOL,
|
|
||||||
check_existence=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
encrypter, cache = create_provider_encrypter(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
config=[
|
|
||||||
x.to_basic_provider_config()
|
|
||||||
for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type)
|
|
||||||
],
|
|
||||||
cache=ToolProviderCredentialsCache(
|
|
||||||
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# decrypt the credentials
|
|
||||||
decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials)
|
|
||||||
|
|
||||||
# check if the credentials is expired
|
|
||||||
if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
|
|
||||||
# TODO: circular import
|
|
||||||
from core.plugin.impl.oauth import OAuthHandler
|
|
||||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
|
||||||
|
|
||||||
# refresh the credentials
|
|
||||||
tool_provider = ToolProviderID(provider_id)
|
|
||||||
provider_name = tool_provider.provider_name
|
|
||||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
|
|
||||||
system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
|
|
||||||
|
|
||||||
oauth_handler = OAuthHandler()
|
|
||||||
# refresh the credentials
|
|
||||||
refreshed_credentials = oauth_handler.refresh_credentials(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
user_id=builtin_provider.user_id,
|
|
||||||
plugin_id=tool_provider.plugin_id,
|
|
||||||
provider=provider_name,
|
|
||||||
redirect_uri=redirect_uri,
|
|
||||||
system_credentials=system_credentials or {},
|
|
||||||
credentials=decrypted_credentials,
|
|
||||||
)
|
|
||||||
# update the credentials
|
|
||||||
builtin_provider.encrypted_credentials = json.dumps(
|
|
||||||
encrypter.encrypt(refreshed_credentials.credentials)
|
|
||||||
)
|
|
||||||
builtin_provider.expires_at = refreshed_credentials.expires_at
|
|
||||||
db.session.commit()
|
|
||||||
decrypted_credentials = refreshed_credentials.credentials
|
|
||||||
cache.delete()
|
|
||||||
|
|
||||||
return builtin_tool.fork_tool_runtime(
|
|
||||||
runtime=ToolRuntime(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
user_id=user_id,
|
|
||||||
credentials=dict(decrypted_credentials),
|
|
||||||
credential_type=builtin_provider.credential_type,
|
|
||||||
runtime_parameters={},
|
|
||||||
invoke_from=invoke_from,
|
|
||||||
tool_invoke_from=tool_invoke_from,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
elif provider_type == ToolProviderType.API:
|
|
||||||
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
|
|
||||||
encrypter, _ = create_tool_provider_encrypter(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
controller=api_provider,
|
|
||||||
)
|
|
||||||
return api_provider.get_tool(tool_name).fork_tool_runtime(
|
|
||||||
runtime=ToolRuntime(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
user_id=user_id,
|
|
||||||
credentials=dict(encrypter.decrypt(credentials)),
|
|
||||||
invoke_from=invoke_from,
|
|
||||||
tool_invoke_from=tool_invoke_from,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif provider_type == ToolProviderType.WORKFLOW:
|
|
||||||
workflow_provider_stmt = select(WorkflowToolProvider).where(
|
|
||||||
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id
|
|
||||||
)
|
|
||||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
|
||||||
workflow_provider = session.scalar(workflow_provider_stmt)
|
|
||||||
|
|
||||||
if workflow_provider is None:
|
|
||||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
|
||||||
|
|
||||||
controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
|
|
||||||
controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id)
|
|
||||||
if controller_tools is None or len(controller_tools) == 0:
|
|
||||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
|
||||||
|
|
||||||
return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
|
|
||||||
runtime=ToolRuntime(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
user_id=user_id,
|
|
||||||
credentials={},
|
|
||||||
invoke_from=invoke_from,
|
|
||||||
tool_invoke_from=tool_invoke_from,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif provider_type == ToolProviderType.APP:
|
|
||||||
raise NotImplementedError("app provider not implemented")
|
|
||||||
elif provider_type == ToolProviderType.PLUGIN:
|
|
||||||
plugin_tool = cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
|
|
||||||
runtime = getattr(plugin_tool, "runtime", None)
|
|
||||||
if runtime is not None:
|
|
||||||
runtime.user_id = user_id
|
|
||||||
runtime.invoke_from = invoke_from
|
|
||||||
runtime.tool_invoke_from = tool_invoke_from
|
|
||||||
return plugin_tool
|
|
||||||
elif provider_type == ToolProviderType.MCP:
|
|
||||||
mcp_tool = cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name)
|
|
||||||
runtime = getattr(mcp_tool, "runtime", None)
|
|
||||||
if runtime is not None:
|
|
||||||
runtime.user_id = user_id
|
|
||||||
runtime.invoke_from = invoke_from
|
|
||||||
runtime.tool_invoke_from = tool_invoke_from
|
|
||||||
return mcp_tool
|
|
||||||
else:
|
|
||||||
raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_agent_tool_runtime(
|
def get_agent_tool_runtime(
|
||||||
@ -885,7 +882,7 @@ class ToolManager:
|
|||||||
raise ValueError(f"you have not added provider {provider_name}")
|
raise ValueError(f"you have not added provider {provider_name}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
credentials = json.loads(provider_obj.credentials_str) or {}
|
credentials = _credentials_adapter.validate_json(provider_obj.credentials_str) or {}
|
||||||
except Exception:
|
except Exception:
|
||||||
credentials = {}
|
credentials = {}
|
||||||
|
|
||||||
@ -910,7 +907,7 @@ class ToolManager:
|
|||||||
masked_credentials = encrypter.mask_plugin_credentials(encrypter.decrypt(credentials))
|
masked_credentials = encrypter.mask_plugin_credentials(encrypter.decrypt(credentials))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
icon = json.loads(provider_obj.icon)
|
icon = emoji_icon_adapter.validate_json(provider_obj.icon)
|
||||||
except Exception:
|
except Exception:
|
||||||
icon = {"background": "#252525", "content": "\ud83d\ude01"}
|
icon = {"background": "#252525", "content": "\ud83d\ude01"}
|
||||||
|
|
||||||
@ -973,7 +970,7 @@ class ToolManager:
|
|||||||
if workflow_provider is None:
|
if workflow_provider is None:
|
||||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||||
|
|
||||||
icon = json.loads(workflow_provider.icon)
|
icon = emoji_icon_adapter.validate_json(workflow_provider.icon)
|
||||||
return icon
|
return icon
|
||||||
except Exception:
|
except Exception:
|
||||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||||
@ -990,7 +987,7 @@ class ToolManager:
|
|||||||
if api_provider is None:
|
if api_provider is None:
|
||||||
raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
|
raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
|
||||||
|
|
||||||
icon = json.loads(api_provider.icon)
|
icon = emoji_icon_adapter.validate_json(api_provider.icon)
|
||||||
return icon
|
return icon
|
||||||
except Exception:
|
except Exception:
|
||||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||||
@ -1025,31 +1022,31 @@ class ToolManager:
|
|||||||
:param provider_id: the id of the provider
|
:param provider_id: the id of the provider
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
provider_type = provider_type
|
match provider_type:
|
||||||
provider_id = provider_id
|
case ToolProviderType.BUILT_IN:
|
||||||
if provider_type == ToolProviderType.BUILT_IN:
|
provider = ToolManager.get_builtin_provider(provider_id, tenant_id)
|
||||||
provider = ToolManager.get_builtin_provider(provider_id, tenant_id)
|
if isinstance(provider, PluginToolProviderController):
|
||||||
if isinstance(provider, PluginToolProviderController):
|
try:
|
||||||
|
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
|
||||||
|
except Exception:
|
||||||
|
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||||
|
return cls.generate_builtin_tool_icon_url(provider_id)
|
||||||
|
case ToolProviderType.API:
|
||||||
|
return cls.generate_api_tool_icon_url(tenant_id, provider_id)
|
||||||
|
case ToolProviderType.WORKFLOW:
|
||||||
|
return cls.generate_workflow_tool_icon_url(tenant_id, provider_id)
|
||||||
|
case ToolProviderType.PLUGIN:
|
||||||
|
provider = ToolManager.get_plugin_provider(provider_id, tenant_id)
|
||||||
try:
|
try:
|
||||||
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
|
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
|
||||||
except Exception:
|
except Exception:
|
||||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||||
return cls.generate_builtin_tool_icon_url(provider_id)
|
case ToolProviderType.MCP:
|
||||||
elif provider_type == ToolProviderType.API:
|
return cls.generate_mcp_tool_icon_url(tenant_id, provider_id)
|
||||||
return cls.generate_api_tool_icon_url(tenant_id, provider_id)
|
case ToolProviderType.APP | ToolProviderType.DATASET_RETRIEVAL:
|
||||||
elif provider_type == ToolProviderType.WORKFLOW:
|
raise ValueError(f"provider type {provider_type} not found")
|
||||||
return cls.generate_workflow_tool_icon_url(tenant_id, provider_id)
|
case _:
|
||||||
elif provider_type == ToolProviderType.PLUGIN:
|
raise ValueError(f"provider type {provider_type} not found")
|
||||||
provider = ToolManager.get_plugin_provider(provider_id, tenant_id)
|
|
||||||
try:
|
|
||||||
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
|
|
||||||
except Exception:
|
|
||||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
|
||||||
raise ValueError(f"plugin provider {provider_id} not found")
|
|
||||||
elif provider_type == ToolProviderType.MCP:
|
|
||||||
return cls.generate_mcp_tool_icon_url(tenant_id, provider_id)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"provider type {provider_type} not found")
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _convert_tool_parameters_type(
|
def _convert_tool_parameters_type(
|
||||||
|
|||||||
@ -8,7 +8,7 @@ from sqlalchemy import select
|
|||||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||||
from core.model_manager import ModelManager
|
from core.model_manager import ModelManager
|
||||||
from core.rag.datasource.retrieval_service import RetrievalService
|
from core.rag.datasource.retrieval_service import RetrievalService
|
||||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
from core.rag.entities import RetrievalSourceMetadata
|
||||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||||
from core.rag.models.document import Document as RagDocument
|
from core.rag.models.document import Document as RagDocument
|
||||||
from core.rag.rerank.rerank_model import RerankModelRunner
|
from core.rag.rerank.rerank_model import RerankModelRunner
|
||||||
|
|||||||
@ -6,8 +6,7 @@ from sqlalchemy import select
|
|||||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
|
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
|
||||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
|
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
|
||||||
from core.rag.datasource.retrieval_service import RetrievalService
|
from core.rag.datasource.retrieval_service import RetrievalService
|
||||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
from core.rag.entities import DocumentContext, RetrievalSourceMetadata
|
||||||
from core.rag.entities.context_entities import DocumentContext
|
|
||||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||||
from core.rag.models.document import Document as RetrievalDocument
|
from core.rag.models.document import Document as RetrievalDocument
|
||||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||||
|
|||||||
@ -305,14 +305,15 @@ class WorkflowTool(Tool):
|
|||||||
"transfer_method": file.transfer_method.value,
|
"transfer_method": file.transfer_method.value,
|
||||||
"type": file.type.value,
|
"type": file.type.value,
|
||||||
}
|
}
|
||||||
if file.transfer_method == FileTransferMethod.TOOL_FILE:
|
match file.transfer_method:
|
||||||
file_dict["tool_file_id"] = resolve_file_record_id(file.reference)
|
case FileTransferMethod.TOOL_FILE:
|
||||||
elif file.transfer_method == FileTransferMethod.LOCAL_FILE:
|
file_dict["tool_file_id"] = resolve_file_record_id(file.reference)
|
||||||
file_dict["upload_file_id"] = resolve_file_record_id(file.reference)
|
case FileTransferMethod.LOCAL_FILE:
|
||||||
elif file.transfer_method == FileTransferMethod.DATASOURCE_FILE:
|
file_dict["upload_file_id"] = resolve_file_record_id(file.reference)
|
||||||
file_dict["datasource_file_id"] = resolve_file_record_id(file.reference)
|
case FileTransferMethod.DATASOURCE_FILE:
|
||||||
elif file.transfer_method == FileTransferMethod.REMOTE_URL:
|
file_dict["datasource_file_id"] = resolve_file_record_id(file.reference)
|
||||||
file_dict["url"] = file.generate_url()
|
case FileTransferMethod.REMOTE_URL:
|
||||||
|
file_dict["url"] = file.generate_url()
|
||||||
|
|
||||||
files.append(file_dict)
|
files.append(file_dict)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -357,8 +358,11 @@ class WorkflowTool(Tool):
|
|||||||
def _update_file_mapping(self, file_dict: dict):
|
def _update_file_mapping(self, file_dict: dict):
|
||||||
file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id"))
|
file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id"))
|
||||||
transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
|
transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
|
||||||
if transfer_method == FileTransferMethod.TOOL_FILE:
|
match transfer_method:
|
||||||
file_dict["tool_file_id"] = file_id
|
case FileTransferMethod.TOOL_FILE:
|
||||||
elif transfer_method == FileTransferMethod.LOCAL_FILE:
|
file_dict["tool_file_id"] = file_id
|
||||||
file_dict["upload_file_id"] = file_id
|
case FileTransferMethod.LOCAL_FILE:
|
||||||
|
file_dict["upload_file_id"] = file_id
|
||||||
|
case FileTransferMethod.REMOTE_URL | FileTransferMethod.DATASOURCE_FILE:
|
||||||
|
pass
|
||||||
return file_dict
|
return file_dict
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from typing import Any, Union
|
|||||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
|
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
|
||||||
|
|
||||||
from core.entities.provider_entities import ProviderConfig
|
from core.entities.provider_entities import ProviderConfig
|
||||||
|
from core.plugin.entities import OAuthSchema
|
||||||
from core.plugin.entities.parameters import (
|
from core.plugin.entities.parameters import (
|
||||||
PluginParameterAutoGenerate,
|
PluginParameterAutoGenerate,
|
||||||
PluginParameterOption,
|
PluginParameterOption,
|
||||||
@ -108,13 +109,6 @@ class EventEntity(BaseModel):
|
|||||||
return v or []
|
return v or []
|
||||||
|
|
||||||
|
|
||||||
class OAuthSchema(BaseModel):
|
|
||||||
client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client")
|
|
||||||
credentials_schema: list[ProviderConfig] = Field(
|
|
||||||
default_factory=list, description="The schema of the OAuth credentials"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SubscriptionConstructor(BaseModel):
|
class SubscriptionConstructor(BaseModel):
|
||||||
"""
|
"""
|
||||||
The subscription constructor of the trigger provider
|
The subscription constructor of the trigger provider
|
||||||
|
|||||||
@ -1,9 +1,10 @@
|
|||||||
from typing import Literal, Union
|
from typing import Union
|
||||||
|
|
||||||
from graphon.entities.base_node_data import BaseNodeData
|
from graphon.entities.base_node_data import BaseNodeData
|
||||||
from graphon.enums import NodeType
|
from graphon.enums import NodeType
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from core.rag.entities.retrieval_settings import WeightedScoreConfig
|
||||||
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
|
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
|
||||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||||
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
|
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
|
||||||
@ -18,50 +19,6 @@ class RerankingModelConfig(BaseModel):
|
|||||||
reranking_model_name: str
|
reranking_model_name: str
|
||||||
|
|
||||||
|
|
||||||
class VectorSetting(BaseModel):
|
|
||||||
"""
|
|
||||||
Vector Setting.
|
|
||||||
"""
|
|
||||||
|
|
||||||
vector_weight: float
|
|
||||||
embedding_provider_name: str
|
|
||||||
embedding_model_name: str
|
|
||||||
|
|
||||||
|
|
||||||
class KeywordSetting(BaseModel):
|
|
||||||
"""
|
|
||||||
Keyword Setting.
|
|
||||||
"""
|
|
||||||
|
|
||||||
keyword_weight: float
|
|
||||||
|
|
||||||
|
|
||||||
class WeightedScoreConfig(BaseModel):
|
|
||||||
"""
|
|
||||||
Weighted score Config.
|
|
||||||
"""
|
|
||||||
|
|
||||||
vector_setting: VectorSetting
|
|
||||||
keyword_setting: KeywordSetting
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingSetting(BaseModel):
|
|
||||||
"""
|
|
||||||
Embedding Setting.
|
|
||||||
"""
|
|
||||||
|
|
||||||
embedding_provider_name: str
|
|
||||||
embedding_model_name: str
|
|
||||||
|
|
||||||
|
|
||||||
class EconomySetting(BaseModel):
|
|
||||||
"""
|
|
||||||
Economy Setting.
|
|
||||||
"""
|
|
||||||
|
|
||||||
keyword_number: int
|
|
||||||
|
|
||||||
|
|
||||||
class RetrievalSetting(BaseModel):
|
class RetrievalSetting(BaseModel):
|
||||||
"""
|
"""
|
||||||
Retrieval Setting.
|
Retrieval Setting.
|
||||||
@ -77,16 +34,6 @@ class RetrievalSetting(BaseModel):
|
|||||||
weights: WeightedScoreConfig | None = None
|
weights: WeightedScoreConfig | None = None
|
||||||
|
|
||||||
|
|
||||||
class IndexMethod(BaseModel):
|
|
||||||
"""
|
|
||||||
Knowledge Index Setting.
|
|
||||||
"""
|
|
||||||
|
|
||||||
indexing_technique: Literal["high_quality", "economy"]
|
|
||||||
embedding_setting: EmbeddingSetting
|
|
||||||
economy_setting: EconomySetting
|
|
||||||
|
|
||||||
|
|
||||||
class FileInfo(BaseModel):
|
class FileInfo(BaseModel):
|
||||||
"""
|
"""
|
||||||
File Info.
|
File Info.
|
||||||
|
|||||||
@ -1,9 +1,19 @@
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any, Protocol
|
from typing import Any, Protocol, TypedDict
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class IndexingResultDict(TypedDict):
|
||||||
|
dataset_id: str
|
||||||
|
dataset_name: str
|
||||||
|
batch: Any
|
||||||
|
document_id: str
|
||||||
|
document_name: str
|
||||||
|
created_at: float
|
||||||
|
display_status: str
|
||||||
|
|
||||||
|
|
||||||
class PreviewItem(BaseModel):
|
class PreviewItem(BaseModel):
|
||||||
content: str | None = Field(default=None)
|
content: str | None = Field(default=None)
|
||||||
child_chunks: list[str] | None = Field(default=None)
|
child_chunks: list[str] | None = Field(default=None)
|
||||||
@ -34,7 +44,7 @@ class IndexProcessorProtocol(Protocol):
|
|||||||
chunks: Mapping[str, Any],
|
chunks: Mapping[str, Any],
|
||||||
batch: Any,
|
batch: Any,
|
||||||
summary_index_setting: dict | None = None,
|
summary_index_setting: dict | None = None,
|
||||||
) -> dict[str, Any]: ...
|
) -> IndexingResultDict: ...
|
||||||
|
|
||||||
def get_preview_output(
|
def get_preview_output(
|
||||||
self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None
|
self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
from collections.abc import Sequence
|
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from graphon.entities.base_node_data import BaseNodeData
|
from graphon.entities.base_node_data import BaseNodeData
|
||||||
@ -6,6 +5,10 @@ from graphon.enums import BuiltinNodeTypes, NodeType
|
|||||||
from graphon.nodes.llm.entities import ModelConfig, VisionConfig
|
from graphon.nodes.llm.entities import ModelConfig, VisionConfig
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from core.rag.entities import Condition, MetadataFilteringCondition, WeightedScoreConfig
|
||||||
|
|
||||||
|
__all__ = ["Condition"]
|
||||||
|
|
||||||
|
|
||||||
class RerankingModelConfig(BaseModel):
|
class RerankingModelConfig(BaseModel):
|
||||||
"""
|
"""
|
||||||
@ -16,33 +19,6 @@ class RerankingModelConfig(BaseModel):
|
|||||||
model: str
|
model: str
|
||||||
|
|
||||||
|
|
||||||
class VectorSetting(BaseModel):
|
|
||||||
"""
|
|
||||||
Vector Setting.
|
|
||||||
"""
|
|
||||||
|
|
||||||
vector_weight: float
|
|
||||||
embedding_provider_name: str
|
|
||||||
embedding_model_name: str
|
|
||||||
|
|
||||||
|
|
||||||
class KeywordSetting(BaseModel):
|
|
||||||
"""
|
|
||||||
Keyword Setting.
|
|
||||||
"""
|
|
||||||
|
|
||||||
keyword_weight: float
|
|
||||||
|
|
||||||
|
|
||||||
class WeightedScoreConfig(BaseModel):
|
|
||||||
"""
|
|
||||||
Weighted score Config.
|
|
||||||
"""
|
|
||||||
|
|
||||||
vector_setting: VectorSetting
|
|
||||||
keyword_setting: KeywordSetting
|
|
||||||
|
|
||||||
|
|
||||||
class MultipleRetrievalConfig(BaseModel):
|
class MultipleRetrievalConfig(BaseModel):
|
||||||
"""
|
"""
|
||||||
Multiple Retrieval Config.
|
Multiple Retrieval Config.
|
||||||
@ -64,50 +40,6 @@ class SingleRetrievalConfig(BaseModel):
|
|||||||
model: ModelConfig
|
model: ModelConfig
|
||||||
|
|
||||||
|
|
||||||
SupportedComparisonOperator = Literal[
|
|
||||||
# for string or array
|
|
||||||
"contains",
|
|
||||||
"not contains",
|
|
||||||
"start with",
|
|
||||||
"end with",
|
|
||||||
"is",
|
|
||||||
"is not",
|
|
||||||
"empty",
|
|
||||||
"not empty",
|
|
||||||
"in",
|
|
||||||
"not in",
|
|
||||||
# for number
|
|
||||||
"=",
|
|
||||||
"≠",
|
|
||||||
">",
|
|
||||||
"<",
|
|
||||||
"≥",
|
|
||||||
"≤",
|
|
||||||
# for time
|
|
||||||
"before",
|
|
||||||
"after",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class Condition(BaseModel):
|
|
||||||
"""
|
|
||||||
Condition detail
|
|
||||||
"""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
comparison_operator: SupportedComparisonOperator
|
|
||||||
value: str | Sequence[str] | None | int | float = None
|
|
||||||
|
|
||||||
|
|
||||||
class MetadataFilteringCondition(BaseModel):
|
|
||||||
"""
|
|
||||||
Metadata Filtering Condition.
|
|
||||||
"""
|
|
||||||
|
|
||||||
logical_operator: Literal["and", "or"] | None = "and"
|
|
||||||
conditions: list[Condition] | None = Field(default=None, deprecated=True)
|
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeRetrievalNodeData(BaseNodeData):
|
class KnowledgeRetrievalNodeData(BaseNodeData):
|
||||||
"""
|
"""
|
||||||
Knowledge retrieval Node Data.
|
Knowledge retrieval Node Data.
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import Generator, Mapping, Sequence
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
from typing import Any
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
from graphon.entities import GraphInitParams
|
from graphon.entities import GraphInitParams
|
||||||
from graphon.entities.graph_config import NodeConfigDictAdapter
|
from graphon.entities.graph_config import NodeConfigDictAdapter
|
||||||
@ -107,6 +107,26 @@ class _WorkflowChildEngineBuilder:
|
|||||||
return child_engine
|
return child_engine
|
||||||
|
|
||||||
|
|
||||||
|
class _NodeConfigDict(TypedDict):
|
||||||
|
id: str
|
||||||
|
width: int
|
||||||
|
height: int
|
||||||
|
type: str
|
||||||
|
data: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class _EdgeConfigDict(TypedDict):
|
||||||
|
source: str
|
||||||
|
target: str
|
||||||
|
sourceHandle: str
|
||||||
|
targetHandle: str
|
||||||
|
|
||||||
|
|
||||||
|
class SingleNodeGraphDict(TypedDict):
|
||||||
|
nodes: list[_NodeConfigDict]
|
||||||
|
edges: list[_EdgeConfigDict]
|
||||||
|
|
||||||
|
|
||||||
class WorkflowEntry:
|
class WorkflowEntry:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -318,7 +338,7 @@ class WorkflowEntry:
|
|||||||
node_data: dict[str, Any],
|
node_data: dict[str, Any],
|
||||||
node_width: int = 114,
|
node_width: int = 114,
|
||||||
node_height: int = 514,
|
node_height: int = 514,
|
||||||
) -> dict[str, Any]:
|
) -> SingleNodeGraphDict:
|
||||||
"""
|
"""
|
||||||
Create a minimal graph structure for testing a single node in isolation.
|
Create a minimal graph structure for testing a single node in isolation.
|
||||||
|
|
||||||
@ -328,14 +348,14 @@ class WorkflowEntry:
|
|||||||
:param node_height: height for UI layout (default: 100)
|
:param node_height: height for UI layout (default: 100)
|
||||||
:return: graph dictionary with start node and target node
|
:return: graph dictionary with start node and target node
|
||||||
"""
|
"""
|
||||||
node_config = {
|
node_config: _NodeConfigDict = {
|
||||||
"id": node_id,
|
"id": node_id,
|
||||||
"width": node_width,
|
"width": node_width,
|
||||||
"height": node_height,
|
"height": node_height,
|
||||||
"type": "custom",
|
"type": "custom",
|
||||||
"data": node_data,
|
"data": node_data,
|
||||||
}
|
}
|
||||||
start_node_config = {
|
start_node_config: _NodeConfigDict = {
|
||||||
"id": "start",
|
"id": "start",
|
||||||
"width": node_width,
|
"width": node_width,
|
||||||
"height": node_height,
|
"height": node_height,
|
||||||
@ -346,9 +366,9 @@ class WorkflowEntry:
|
|||||||
"desc": "Start",
|
"desc": "Start",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
return {
|
return SingleNodeGraphDict(
|
||||||
"nodes": [start_node_config, node_config],
|
nodes=[start_node_config, node_config],
|
||||||
"edges": [
|
edges=[
|
||||||
{
|
{
|
||||||
"source": "start",
|
"source": "start",
|
||||||
"target": node_id,
|
"target": node_id,
|
||||||
@ -356,7 +376,7 @@ class WorkflowEntry:
|
|||||||
"targetHandle": "target",
|
"targetHandle": "target",
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
}
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def run_free_node(
|
def run_free_node(
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import logging
|
|||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from core.workflow.nodes.trigger_schedule.entities import SchedulePlanUpdate
|
from core.workflow.nodes.trigger_schedule.entities import SchedulePlanUpdate
|
||||||
from events.app_event import app_published_workflow_was_updated
|
from events.app_event import app_published_workflow_was_updated
|
||||||
@ -45,7 +45,7 @@ def sync_schedule_from_workflow(tenant_id: str, app_id: str, workflow: Workflow)
|
|||||||
Returns:
|
Returns:
|
||||||
Updated or created WorkflowSchedulePlan, or None if no schedule node
|
Updated or created WorkflowSchedulePlan, or None if no schedule node
|
||||||
"""
|
"""
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(db.engine).begin() as session:
|
||||||
schedule_config = ScheduleService.extract_schedule_config(workflow)
|
schedule_config = ScheduleService.extract_schedule_config(workflow)
|
||||||
|
|
||||||
existing_plan = session.scalar(
|
existing_plan = session.scalar(
|
||||||
@ -59,7 +59,6 @@ def sync_schedule_from_workflow(tenant_id: str, app_id: str, workflow: Workflow)
|
|||||||
if existing_plan:
|
if existing_plan:
|
||||||
logger.info("No schedule node in workflow for app %s, removing schedule plan", app_id)
|
logger.info("No schedule node in workflow for app %s, removing schedule plan", app_id)
|
||||||
ScheduleService.delete_schedule(session=session, schedule_id=existing_plan.id)
|
ScheduleService.delete_schedule(session=session, schedule_id=existing_plan.id)
|
||||||
session.commit()
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if existing_plan:
|
if existing_plan:
|
||||||
@ -73,7 +72,6 @@ def sync_schedule_from_workflow(tenant_id: str, app_id: str, workflow: Workflow)
|
|||||||
schedule_id=existing_plan.id,
|
schedule_id=existing_plan.id,
|
||||||
updates=updates,
|
updates=updates,
|
||||||
)
|
)
|
||||||
session.commit()
|
|
||||||
return updated_plan
|
return updated_plan
|
||||||
else:
|
else:
|
||||||
new_plan = ScheduleService.create_schedule(
|
new_plan = ScheduleService.create_schedule(
|
||||||
@ -82,5 +80,4 @@ def sync_schedule_from_workflow(tenant_id: str, app_id: str, workflow: Workflow)
|
|||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
config=schedule_config,
|
config=schedule_config,
|
||||||
)
|
)
|
||||||
session.commit()
|
|
||||||
return new_plan
|
return new_plan
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from core.trigger.constants import TRIGGER_NODE_TYPES
|
from core.trigger.constants import TRIGGER_NODE_TYPES
|
||||||
from events.app_event import app_published_workflow_was_updated
|
from events.app_event import app_published_workflow_was_updated
|
||||||
@ -31,7 +31,7 @@ def handle(sender, **kwargs):
|
|||||||
# Extract trigger info from workflow
|
# Extract trigger info from workflow
|
||||||
trigger_infos = get_trigger_infos_from_workflow(published_workflow)
|
trigger_infos = get_trigger_infos_from_workflow(published_workflow)
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(db.engine).begin() as session:
|
||||||
# Get existing app triggers
|
# Get existing app triggers
|
||||||
existing_triggers = (
|
existing_triggers = (
|
||||||
session.execute(
|
session.execute(
|
||||||
@ -79,8 +79,6 @@ def handle(sender, **kwargs):
|
|||||||
existing_trigger.title = new_title
|
existing_trigger.title = new_title
|
||||||
session.add(existing_trigger)
|
session.add(existing_trigger)
|
||||||
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
|
|
||||||
def get_trigger_infos_from_workflow(published_workflow: Workflow) -> list[dict]:
|
def get_trigger_infos_from_workflow(published_workflow: Workflow) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -354,11 +354,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
|||||||
) -> WorkflowRun | None:
|
) -> WorkflowRun | None:
|
||||||
"""Fallback to PostgreSQL query for records not in LogStore (with tenant isolation)."""
|
"""Fallback to PostgreSQL query for records not in LogStore (with tenant isolation)."""
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(db.engine).begin() as session:
|
||||||
stmt = select(WorkflowRun).where(
|
stmt = select(WorkflowRun).where(
|
||||||
WorkflowRun.id == run_id, WorkflowRun.tenant_id == tenant_id, WorkflowRun.app_id == app_id
|
WorkflowRun.id == run_id, WorkflowRun.tenant_id == tenant_id, WorkflowRun.app_id == app_id
|
||||||
)
|
)
|
||||||
@ -439,11 +439,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
|
|||||||
def _fallback_get_workflow_run_by_id(self, run_id: str) -> WorkflowRun | None:
|
def _fallback_get_workflow_run_by_id(self, run_id: str) -> WorkflowRun | None:
|
||||||
"""Fallback to PostgreSQL query for records not in LogStore."""
|
"""Fallback to PostgreSQL query for records not in LogStore."""
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(db.engine).begin() as session:
|
||||||
stmt = select(WorkflowRun).where(WorkflowRun.id == run_id)
|
stmt = select(WorkflowRun).where(WorkflowRun.id == run_id)
|
||||||
return session.scalar(stmt)
|
return session.scalar(stmt)
|
||||||
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user