Merge branch 'main' into jzh

This commit is contained in:
JzoNg 2026-04-09 09:07:12 +08:00
commit 78240ed199
389 changed files with 12476 additions and 4103 deletions

9
.github/labeler.yml vendored
View File

@ -1,3 +1,10 @@
web: web:
- changed-files: - changed-files:
- any-glob-to-any-file: 'web/**' - any-glob-to-any-file:
- 'web/**'
- 'packages/**'
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- '.npmrc'
- '.nvmrc'

View File

@ -0,0 +1,82 @@
import { execFileSync } from 'node:child_process'
import fs from 'node:fs'
import path from 'node:path'
// Every input arrives through an environment variable; an unset or empty
// variable falls back to the given default.
const readEnv = (name, fallback = '') => process.env[name] || fallback

// The current working directory is treated as the repository root.
const repoRoot = process.cwd()
const baseSha = readEnv('BASE_SHA')
const headSha = readEnv('HEAD_SHA')
// CHANGED_FILES is a whitespace-separated list of file stems (no `.json` suffix).
const files = readEnv('CHANGED_FILES').split(/\s+/).filter(Boolean)
const outputPath = readEnv('I18N_CHANGES_OUTPUT_PATH', '/tmp/i18n-changes.json')

// Location of the en-US JSON file that belongs to a given file stem.
const englishPath = (fileStem) => {
  const fileName = `${fileStem}.json`
  return path.join(repoRoot, 'web', 'i18n', 'en-US', fileName)
}
/**
 * Load the en-US JSON for `fileStem` from the working tree.
 *
 * Returns the parsed content, or `null` when the file does not exist.
 * A file that exists but contains invalid JSON makes `JSON.parse` throw.
 */
const readCurrentJson = (fileStem) => {
  const filePath = englishPath(fileStem)
  return fs.existsSync(filePath)
    ? JSON.parse(fs.readFileSync(filePath, 'utf8'))
    : null
}
/**
 * Load the en-US JSON for `fileStem` as it was at the base commit.
 *
 * Returns the parsed content, or `null` when no base SHA is configured or
 * when `git show` / JSON parsing fails for any reason (best effort: every
 * failure is treated the same as "no base version").
 */
const readBaseJson = (fileStem) => {
  if (baseSha) {
    try {
      const gitObject = `${baseSha}:web/i18n/en-US/${fileStem}.json`
      const rawJson = execFileSync('git', ['show', gitObject], { encoding: 'utf8' })
      return JSON.parse(rawJson)
    }
    catch {
      // Deliberately ignored; fall through to the shared `null` result below.
    }
  }
  return null
}
/**
 * Report whether two values serialize to the same JSON text.
 *
 * The check is purely textual, so objects holding the same entries in a
 * different key order are considered different.
 */
const compareJson = (beforeValue, afterValue) => {
  const beforeText = JSON.stringify(beforeValue)
  const afterText = JSON.stringify(afterValue)
  return beforeText === afterText
}
/**
 * Diff two translation maps key by key.
 *
 * @param {object} beforeJson - Entries at the base commit (`{}` when unavailable).
 * @param {object} afterJson - Entries in the working tree (`{}` when the file is gone).
 * @returns {{ added: object, updated: object, deleted: string[] }}
 *   `added` maps new keys to their value, `updated` maps changed keys to
 *   `{ before, after }`, and `deleted` lists keys that no longer exist.
 */
const diffTranslations = (beforeJson, afterJson) => {
  const added = {}
  const updated = {}
  const deleted = []

  for (const [key, value] of Object.entries(afterJson)) {
    // Own-property check on purpose: the `in` operator also matches names
    // inherited from Object.prototype (`constructor`, `toString`, ...), which
    // would make a translation key with such a name look like it already existed.
    if (!Object.hasOwn(beforeJson, key)) {
      added[key] = value
      continue
    }

    if (!compareJson(beforeJson[key], value)) {
      updated[key] = {
        before: beforeJson[key],
        after: value,
      }
    }
  }

  for (const key of Object.keys(beforeJson)) {
    if (!Object.hasOwn(afterJson, key))
      deleted.push(key)
  }

  return { added, updated, deleted }
}

// Change set per file stem, in the order the stems were supplied.
const changes = {}

for (const fileStem of files) {
  // Read the working-tree file once; `null` means it no longer exists.
  const currentJson = readCurrentJson(fileStem)
  const beforeJson = readBaseJson(fileStem) || {}
  const afterJson = currentJson || {}

  changes[fileStem] = {
    fileDeleted: currentJson === null,
    ...diffTranslations(beforeJson, afterJson),
  }
}
// Persist the collected change set as one compact JSON document.
const report = {
  baseSha,
  headSha,
  files,
  changes,
}
fs.writeFileSync(outputPath, JSON.stringify(report))

View File

@ -39,9 +39,11 @@ jobs:
with: with:
files: | files: |
web/** web/**
packages/**
package.json package.json
pnpm-lock.yaml pnpm-lock.yaml
pnpm-workspace.yaml pnpm-workspace.yaml
.npmrc
.nvmrc .nvmrc
- name: Check api inputs - name: Check api inputs
if: github.event_name != 'merge_group' if: github.event_name != 'merge_group'

View File

@ -8,9 +8,11 @@ on:
- api/Dockerfile - api/Dockerfile
- web/docker/** - web/docker/**
- web/Dockerfile - web/Dockerfile
- packages/**
- package.json - package.json
- pnpm-lock.yaml - pnpm-lock.yaml
- pnpm-workspace.yaml - pnpm-workspace.yaml
- .npmrc
- .nvmrc - .nvmrc
concurrency: concurrency:

View File

@ -65,9 +65,11 @@ jobs:
- 'docker/volumes/sandbox/conf/**' - 'docker/volumes/sandbox/conf/**'
web: web:
- 'web/**' - 'web/**'
- 'packages/**'
- 'package.json' - 'package.json'
- 'pnpm-lock.yaml' - 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml' - 'pnpm-workspace.yaml'
- '.npmrc'
- '.nvmrc' - '.nvmrc'
- '.github/workflows/web-tests.yml' - '.github/workflows/web-tests.yml'
- '.github/actions/setup-web/**' - '.github/actions/setup-web/**'
@ -77,9 +79,11 @@ jobs:
- 'api/uv.lock' - 'api/uv.lock'
- 'e2e/**' - 'e2e/**'
- 'web/**' - 'web/**'
- 'packages/**'
- 'package.json' - 'package.json'
- 'pnpm-lock.yaml' - 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml' - 'pnpm-workspace.yaml'
- '.npmrc'
- '.nvmrc' - '.nvmrc'
- 'docker/docker-compose.middleware.yaml' - 'docker/docker-compose.middleware.yaml'
- 'docker/middleware.env.example' - 'docker/middleware.env.example'

View File

@ -77,9 +77,11 @@ jobs:
with: with:
files: | files: |
web/** web/**
packages/**
package.json package.json
pnpm-lock.yaml pnpm-lock.yaml
pnpm-workspace.yaml pnpm-workspace.yaml
.npmrc
.nvmrc .nvmrc
.github/workflows/style.yml .github/workflows/style.yml
.github/actions/setup-web/** .github/actions/setup-web/**

View File

@ -9,6 +9,7 @@ on:
- package.json - package.json
- pnpm-lock.yaml - pnpm-lock.yaml
- pnpm-workspace.yaml - pnpm-workspace.yaml
- .npmrc
concurrency: concurrency:
group: sdk-tests-${{ github.head_ref || github.run_id }} group: sdk-tests-${{ github.head_ref || github.run_id }}

View File

@ -68,89 +68,7 @@ jobs:
" web/i18n-config/languages.ts | sed 's/[[:space:]]*$//') " web/i18n-config/languages.ts | sed 's/[[:space:]]*$//')
generate_changes_json() { generate_changes_json() {
node <<'NODE' node .github/scripts/generate-i18n-changes.mjs
const { execFileSync } = require('node:child_process')
const fs = require('node:fs')
const path = require('node:path')
const repoRoot = process.cwd()
const baseSha = process.env.BASE_SHA || ''
const headSha = process.env.HEAD_SHA || ''
const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean)
const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`)
const readCurrentJson = (fileStem) => {
const filePath = englishPath(fileStem)
if (!fs.existsSync(filePath))
return null
return JSON.parse(fs.readFileSync(filePath, 'utf8'))
}
const readBaseJson = (fileStem) => {
if (!baseSha)
return null
try {
const relativePath = `web/i18n/en-US/${fileStem}.json`
const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' })
return JSON.parse(content)
}
catch (error) {
return null
}
}
const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue)
const changes = {}
for (const fileStem of files) {
const currentJson = readCurrentJson(fileStem)
const beforeJson = readBaseJson(fileStem) || {}
const afterJson = currentJson || {}
const added = {}
const updated = {}
const deleted = []
for (const [key, value] of Object.entries(afterJson)) {
if (!(key in beforeJson)) {
added[key] = value
continue
}
if (!compareJson(beforeJson[key], value)) {
updated[key] = {
before: beforeJson[key],
after: value,
}
}
}
for (const key of Object.keys(beforeJson)) {
if (!(key in afterJson))
deleted.push(key)
}
changes[fileStem] = {
fileDeleted: currentJson === null,
added,
updated,
deleted,
}
}
fs.writeFileSync(
'/tmp/i18n-changes.json',
JSON.stringify({
baseSha,
headSha,
files,
changes,
})
)
NODE
} }
if [ "${{ github.event_name }}" = "repository_dispatch" ]; then if [ "${{ github.event_name }}" = "repository_dispatch" ]; then
@ -270,7 +188,7 @@ jobs:
Tool rules: Tool rules:
- Use Read for repository files. - Use Read for repository files.
- Use Edit for JSON updates. - Use Edit for JSON updates.
- Use Bash only for `pnpm`. - Use Bash only for `vp`.
- Do not use Bash for `git`, `gh`, or branch management. - Do not use Bash for `git`, `gh`, or branch management.
Required execution plan: Required execution plan:
@ -292,7 +210,7 @@ jobs:
- Read the current English JSON file for any file that still exists so wording, placeholders, and surrounding terminology stay accurate. - Read the current English JSON file for any file that still exists so wording, placeholders, and surrounding terminology stay accurate.
- If `Structured change set available` is `false`, treat this as a scoped full sync and use the current English files plus scoped checks as the source of truth. - If `Structured change set available` is `false`, treat this as a scoped full sync and use the current English files plus scoped checks as the source of truth.
4. Run a scoped pre-check before editing: 4. Run a scoped pre-check before editing:
- `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}` - `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
- Use this command as the source of truth for missing and extra keys inside the current scope. - Use this command as the source of truth for missing and extra keys inside the current scope.
5. Apply translations. 5. Apply translations.
- For every target language and scoped file: - For every target language and scoped file:
@ -300,19 +218,19 @@ jobs:
- If the locale file does not exist yet, create it with `Write` and then continue with `Edit` as needed. - If the locale file does not exist yet, create it with `Write` and then continue with `Edit` as needed.
- ADD missing keys. - ADD missing keys.
- UPDATE stale translations when the English value changed. - UPDATE stale translations when the English value changed.
- DELETE removed keys. Prefer `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }} --auto-remove` for extra keys so deletions stay in scope. - DELETE removed keys. Prefer `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }} --auto-remove` for extra keys so deletions stay in scope.
- Preserve placeholders exactly: `{{variable}}`, `${variable}`, HTML tags, component tags, and variable names. - Preserve placeholders exactly: `{{variable}}`, `${variable}`, HTML tags, component tags, and variable names.
- Match the existing terminology and register used by each locale. - Match the existing terminology and register used by each locale.
- Prefer one Edit per file when stable, but prioritize correctness over batching. - Prefer one Edit per file when stable, but prioritize correctness over batching.
6. Verify only the edited files. 6. Verify only the edited files.
- Run `pnpm --dir ${{ github.workspace }}/web lint:fix --quiet -- <relative edited i18n file paths>` - Run `vp run dify-web#lint:fix --quiet -- <relative edited i18n file paths under web/>`
- Run `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}` - Run `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
- If verification fails, fix the remaining problems before continuing. - If verification fails, fix the remaining problems before continuing.
7. Stop after the scoped locale files are updated and verification passes. 7. Stop after the scoped locale files are updated and verification passes.
- Do not create branches, commits, or pull requests. - Do not create branches, commits, or pull requests.
claude_args: | claude_args: |
--max-turns 120 --max-turns 120
--allowedTools "Read,Write,Edit,Bash(pnpm *),Bash(pnpm:*),Glob,Grep" --allowedTools "Read,Write,Edit,Bash(vp *),Bash(vp:*),Glob,Grep"
- name: Prepare branch metadata - name: Prepare branch metadata
id: pr_meta id: pr_meta
@ -354,6 +272,7 @@ jobs:
- name: Create or update translation PR - name: Create or update translation PR
if: steps.pr_meta.outputs.has_changes == 'true' if: steps.pr_meta.outputs.has_changes == 'true'
env: env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
BRANCH_NAME: ${{ steps.pr_meta.outputs.branch_name }} BRANCH_NAME: ${{ steps.pr_meta.outputs.branch_name }}
FILES_IN_SCOPE: ${{ steps.context.outputs.CHANGED_FILES }} FILES_IN_SCOPE: ${{ steps.context.outputs.CHANGED_FILES }}
TARGET_LANGS: ${{ steps.context.outputs.TARGET_LANGS }} TARGET_LANGS: ${{ steps.context.outputs.TARGET_LANGS }}
@ -402,8 +321,8 @@ jobs:
'', '',
'## Verification', '## Verification',
'', '',
`- \`pnpm --dir web run i18n:check --file ${process.env.FILES_IN_SCOPE} --lang ${process.env.TARGET_LANGS}\``, `- \`vp run dify-web#i18n:check --file ${process.env.FILES_IN_SCOPE} --lang ${process.env.TARGET_LANGS}\``,
`- \`pnpm --dir web lint:fix --quiet -- <edited i18n files>\``, `- \`vp run dify-web#lint:fix --quiet -- <edited i18n files under web/>\``,
'', '',
'## Notes', '## Notes',
'', '',

View File

@ -42,88 +42,7 @@ jobs:
fi fi
export BASE_SHA HEAD_SHA CHANGED_FILES export BASE_SHA HEAD_SHA CHANGED_FILES
node <<'NODE' node .github/scripts/generate-i18n-changes.mjs
const { execFileSync } = require('node:child_process')
const fs = require('node:fs')
const path = require('node:path')
const repoRoot = process.cwd()
const baseSha = process.env.BASE_SHA || ''
const headSha = process.env.HEAD_SHA || ''
const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean)
const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`)
const readCurrentJson = (fileStem) => {
const filePath = englishPath(fileStem)
if (!fs.existsSync(filePath))
return null
return JSON.parse(fs.readFileSync(filePath, 'utf8'))
}
const readBaseJson = (fileStem) => {
if (!baseSha)
return null
try {
const relativePath = `web/i18n/en-US/${fileStem}.json`
const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' })
return JSON.parse(content)
}
catch (error) {
return null
}
}
const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue)
const changes = {}
for (const fileStem of files) {
const beforeJson = readBaseJson(fileStem) || {}
const afterJson = readCurrentJson(fileStem) || {}
const added = {}
const updated = {}
const deleted = []
for (const [key, value] of Object.entries(afterJson)) {
if (!(key in beforeJson)) {
added[key] = value
continue
}
if (!compareJson(beforeJson[key], value)) {
updated[key] = {
before: beforeJson[key],
after: value,
}
}
}
for (const key of Object.keys(beforeJson)) {
if (!(key in afterJson))
deleted.push(key)
}
changes[fileStem] = {
fileDeleted: readCurrentJson(fileStem) === null,
added,
updated,
deleted,
}
}
fs.writeFileSync(
'/tmp/i18n-changes.json',
JSON.stringify({
baseSha,
headSha,
files,
changes,
})
)
NODE
if [ -n "$CHANGED_FILES" ]; then if [ -n "$CHANGED_FILES" ]; then
echo "has_changes=true" >> "$GITHUB_OUTPUT" echo "has_changes=true" >> "$GITHUB_OUTPUT"

View File

@ -81,8 +81,8 @@ if $web_modified; then
if $web_ts_modified; then if $web_ts_modified; then
echo "Running TypeScript type-check:tsgo" echo "Running TypeScript type-check:tsgo"
if ! pnpm run type-check:tsgo; then if ! npm run type-check:tsgo; then
echo "Type check failed. Please run 'pnpm run type-check:tsgo' to fix the errors." echo "Type check failed. Please run 'npm run type-check:tsgo' to fix the errors."
exit 1 exit 1
fi fi
else else
@ -90,8 +90,8 @@ if $web_modified; then
fi fi
echo "Running knip" echo "Running knip"
if ! pnpm run knip; then if ! npm run knip; then
echo "Knip check failed. Please run 'pnpm run knip' to fix the errors." echo "Knip check failed. Please run 'npm run knip' to fix the errors."
exit 1 exit 1
fi fi

View File

@ -0,0 +1,79 @@
from typing import Any, Literal
from pydantic import BaseModel, Field, model_validator
from libs.helper import UUIDStrOrEmpty
# --- Conversation schemas ---
class ConversationRenamePayload(BaseModel):
    # Request body for renaming a conversation: a non-blank `name` must be
    # supplied unless `auto_generate` is set.
    name: str | None = Field(default=None)
    auto_generate: bool = Field(default=False)

    @model_validator(mode="after")
    def validate_name_requirement(self):
        # With auto-generation requested, the name is optional.
        if self.auto_generate:
            return self
        # Reject a missing, empty, or whitespace-only name.
        if not (self.name and self.name.strip()):
            raise ValueError("name is required when auto_generate is false")
        return self
# --- Message schemas ---
class MessageListQuery(BaseModel):
    # Query parameters for listing the messages of one conversation.
    conversation_id: UUIDStrOrEmpty
    # Presumably a pagination cursor; confirm against the consuming handler.
    first_id: UUIDStrOrEmpty | None = Field(default=None)
    # Page size, restricted to 1..100 (20 when omitted).
    limit: int = Field(20, ge=1, le=100)
class MessageFeedbackPayload(BaseModel):
    # Shared request body for submitting feedback on a message.
    # The field descriptions are declared on this base model so that every
    # controller reusing or subclassing it publishes them in its generated
    # API schema instead of silently dropping them.
    rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
    content: str | None = Field(default=None, description="Feedback content")
# --- Saved message schemas ---
class SavedMessageListQuery(BaseModel):
    # Query parameters for listing saved messages.
    # Presumably a pagination cursor; confirm against the consuming handler.
    last_id: UUIDStrOrEmpty | None = Field(default=None)
    # Page size, restricted to 1..100 (20 when omitted).
    limit: int = Field(20, ge=1, le=100)
class SavedMessageCreatePayload(BaseModel):
    # Request body for saving a message; `message_id` is required.
    message_id: UUIDStrOrEmpty = Field(...)
# --- Workflow schemas ---
class DefaultBlockConfigQuery(BaseModel):
    # Query parameters for the default block configuration lookup.
    # `q` is an optional free-form string; its format is defined by the caller.
    q: str | None = Field(default=None)
class WorkflowListQuery(BaseModel):
    # Query parameters for listing workflows.
    # Page number, restricted to 1..99999 (1 when omitted).
    page: int = Field(1, ge=1, le=99999)
    # Page size, restricted to 1..100 (10 when omitted).
    limit: int = Field(10, ge=1, le=100)
    user_id: str | None = Field(default=None)
    named_only: bool = Field(default=False)
class WorkflowRunPayload(BaseModel):
    # Request body for running a workflow: `inputs` is required, `files` is optional.
    inputs: dict[str, Any] = Field(...)
    files: list[dict[str, Any]] | None = Field(default=None)
class WorkflowUpdatePayload(BaseModel):
    # Request body for updating a workflow's marker metadata.
    # Both fields are optional; lengths are capped at 20 and 100 characters.
    marked_name: str | None = Field(None, max_length=20)
    marked_comment: str | None = Field(None, max_length=100)
# --- Audio schemas ---
class TextToAudioPayload(BaseModel):
    # Shared request body for text-to-audio conversion; every field is optional.
    message_id: str | None = None
    voice: str | None = None
    text: str | None = None
    # The description is declared on the shared model so that it appears in the
    # generated API schema of every controller that registers this payload.
    streaming: bool | None = Field(default=None, description="Enable streaming response")

View File

@ -7,7 +7,7 @@ from flask import request
from flask_restx import Resource from flask_restx import Resource
from graphon.enums import WorkflowExecutionStatus from graphon.enums import WorkflowExecutionStatus
from graphon.file import helpers as file_helpers from graphon.file import helpers as file_helpers
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest from werkzeug.exceptions import BadRequest
@ -26,9 +26,11 @@ from controllers.console.wraps import (
setup_required, setup_required,
) )
from core.ops.ops_trace_manager import OpsTraceManager from core.ops.ops_trace_manager import OpsTraceManager
from core.rag.entities import PreProcessingRule, Rule, Segmentation
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.trigger.constants import TRIGGER_NODE_TYPES from core.trigger.constants import TRIGGER_NODE_TYPES
from extensions.ext_database import db from extensions.ext_database import db
from fields.base import ResponseModel
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models import App, DatasetPermissionEnum, Workflow from models import App, DatasetPermissionEnum, Workflow
from models.model import IconType from models.model import IconType
@ -41,10 +43,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
NotionIcon, NotionIcon,
NotionInfo, NotionInfo,
NotionPage, NotionPage,
PreProcessingRule,
RerankingModel, RerankingModel,
Rule,
Segmentation,
WebsiteInfo, WebsiteInfo,
WeightKeywordSetting, WeightKeywordSetting,
WeightModel, WeightModel,
@ -155,16 +154,6 @@ class AppTracePayload(BaseModel):
type JSONValue = Any type JSONValue = Any
class ResponseModel(BaseModel):
model_config = ConfigDict(
from_attributes=True,
extra="ignore",
populate_by_name=True,
serialize_by_alias=True,
protected_namespaces=(),
)
def _to_timestamp(value: datetime | int | None) -> int | None: def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime): if isinstance(value, datetime):
return int(value.timestamp()) return int(value.timestamp())

View File

@ -1,6 +1,6 @@
from flask_restx import Resource, fields, marshal_with from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import sessionmaker
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import ( from controllers.console.wraps import (
@ -71,7 +71,7 @@ class AppImportApi(Resource):
args = AppImportPayload.model_validate(console_ns.payload) args = AppImportPayload.model_validate(console_ns.payload)
# Create service with session # Create service with session
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
import_service = AppDslService(session) import_service = AppDslService(session)
# Import app # Import app
account = current_user account = current_user
@ -92,11 +92,13 @@ class AppImportApi(Resource):
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private") EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
# Return appropriate status code based on result # Return appropriate status code based on result
status = result.status status = result.status
if status == ImportStatus.FAILED: match status:
return result.model_dump(mode="json"), 400 case ImportStatus.FAILED:
elif status == ImportStatus.PENDING: return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 202 case ImportStatus.PENDING:
return result.model_dump(mode="json"), 200 return result.model_dump(mode="json"), 202
case ImportStatus.COMPLETED | ImportStatus.COMPLETED_WITH_WARNINGS:
return result.model_dump(mode="json"), 200
@console_ns.route("/apps/imports/<string:import_id>/confirm") @console_ns.route("/apps/imports/<string:import_id>/confirm")

View File

@ -8,6 +8,7 @@ from pydantic import BaseModel, Field, field_validator
from sqlalchemy import exists, func, select from sqlalchemy import exists, func, select
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.controller_schemas import MessageFeedbackPayload as _MessageFeedbackPayloadBase
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.app.error import ( from controllers.console.app.error import (
@ -59,10 +60,8 @@ class ChatMessagesQuery(BaseModel):
return uuid_value(value) return uuid_value(value)
class MessageFeedbackPayload(BaseModel): class MessageFeedbackPayload(_MessageFeedbackPayloadBase):
message_id: str = Field(..., description="Message ID") message_id: str = Field(..., description="Message ID")
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
content: str | None = Field(default=None, description="Feedback content")
@field_validator("message_id") @field_validator("message_id")
@classmethod @classmethod

View File

@ -14,6 +14,7 @@ from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
import services import services
from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.workflow_run import workflow_run_node_execution_model from controllers.console.app.workflow_run import workflow_run_node_execution_model
@ -142,10 +143,6 @@ class PublishWorkflowPayload(BaseModel):
marked_comment: str | None = Field(default=None, max_length=100) marked_comment: str | None = Field(default=None, max_length=100)
class DefaultBlockConfigQuery(BaseModel):
q: str | None = None
class ConvertToWorkflowPayload(BaseModel): class ConvertToWorkflowPayload(BaseModel):
name: str | None = None name: str | None = None
icon_type: str | None = None icon_type: str | None = None
@ -153,18 +150,6 @@ class ConvertToWorkflowPayload(BaseModel):
icon_background: str | None = None icon_background: str | None = None
class WorkflowListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=10, ge=1, le=100)
user_id: str | None = None
named_only: bool = False
class WorkflowUpdatePayload(BaseModel):
marked_name: str | None = Field(default=None, max_length=20)
marked_comment: str | None = Field(default=None, max_length=100)
class DraftWorkflowTriggerRunPayload(BaseModel): class DraftWorkflowTriggerRunPayload(BaseModel):
node_id: str node_id: str

View File

@ -3,7 +3,7 @@ import secrets
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
@ -20,35 +20,18 @@ from controllers.console.wraps import email_password_login_enabled, setup_requir
from events.tenant_event import tenant_was_created from events.tenant_event import tenant_was_created
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password from libs.password import hash_password
from services.account_service import AccountService, TenantService from services.account_service import AccountService, TenantService
from services.entities.auth_entities import (
ForgotPasswordCheckPayload,
ForgotPasswordResetPayload,
ForgotPasswordSendPayload,
)
from services.feature_service import FeatureService from services.feature_service import FeatureService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ForgotPasswordSendPayload(BaseModel):
email: EmailStr = Field(...)
language: str | None = Field(default=None)
class ForgotPasswordCheckPayload(BaseModel):
email: EmailStr = Field(...)
code: str = Field(...)
token: str = Field(...)
class ForgotPasswordResetPayload(BaseModel):
token: str = Field(...)
new_password: str = Field(...)
password_confirm: str = Field(...)
@field_validator("new_password", "password_confirm")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
class ForgotPasswordEmailResponse(BaseModel): class ForgotPasswordEmailResponse(BaseModel):
result: str = Field(description="Operation result") result: str = Field(description="Operation result")
data: str | None = Field(default=None, description="Reset token") data: str | None = Field(default=None, description="Reset token")

View File

@ -42,6 +42,7 @@ from libs.token import (
) )
from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService
from services.billing_service import BillingService from services.billing_service import BillingService
from services.entities.auth_entities import LoginPayloadBase
from services.errors.account import AccountRegisterError from services.errors.account import AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
from services.feature_service import FeatureService from services.feature_service import FeatureService
@ -49,9 +50,7 @@ from services.feature_service import FeatureService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class LoginPayload(BaseModel): class LoginPayload(LoginPayloadBase):
email: EmailStr = Field(..., description="Email address")
password: str = Field(..., description="Password")
remember_me: bool = Field(default=False, description="Remember me flag") remember_me: bool = Field(default=False, description="Remember me flag")
invite_token: str | None = Field(default=None, description="Invitation token") invite_token: str | None = Field(default=None, description="Invitation token")

View File

@ -83,11 +83,13 @@ class RagPipelineImportApi(Resource):
# Return appropriate status code based on result # Return appropriate status code based on result
status = result.status status = result.status
if status == ImportStatus.FAILED: match status:
return result.model_dump(mode="json"), 400 case ImportStatus.FAILED:
elif status == ImportStatus.PENDING: return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 202 case ImportStatus.PENDING:
return result.model_dump(mode="json"), 200 return result.model_dump(mode="json"), 202
case ImportStatus.COMPLETED | ImportStatus.COMPLETED_WITH_WARNINGS:
return result.model_dump(mode="json"), 200
@console_ns.route("/rag/pipelines/imports/<string:import_id>/confirm") @console_ns.route("/rag/pipelines/imports/<string:import_id>/confirm")

View File

@ -10,6 +10,7 @@ from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
import services import services
from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.app.error import ( from controllers.console.app.error import (
@ -94,22 +95,6 @@ class PublishedWorkflowRunPayload(DraftWorkflowRunPayload):
original_document_id: str | None = None original_document_id: str | None = None
class DefaultBlockConfigQuery(BaseModel):
q: str | None = None
class WorkflowListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=10, ge=1, le=100)
user_id: str | None = None
named_only: bool = False
class WorkflowUpdatePayload(BaseModel):
marked_name: str | None = Field(default=None, max_length=20)
marked_comment: str | None = Field(default=None, max_length=100)
class NodeIdQuery(BaseModel): class NodeIdQuery(BaseModel):
node_id: str node_id: str

View File

@ -2,10 +2,10 @@ import logging
from flask import request from flask import request
from graphon.model_runtime.errors.invoke import InvokeError from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import InternalServerError
import services import services
from controllers.common.controller_schemas import TextToAudioPayload
from controllers.common.schema import register_schema_model from controllers.common.schema import register_schema_model
from controllers.console.app.error import ( from controllers.console.app.error import (
AppUnavailableError, AppUnavailableError,
@ -32,14 +32,6 @@ from .. import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TextToAudioPayload(BaseModel):
message_id: str | None = None
voice: str | None = None
text: str | None = None
streaming: bool | None = Field(default=None, description="Enable streaming response")
register_schema_model(console_ns, TextToAudioPayload) register_schema_model(console_ns, TextToAudioPayload)

View File

@ -1,10 +1,11 @@
from typing import Any from typing import Any
from flask import request from flask import request
from pydantic import BaseModel, Field, TypeAdapter, model_validator from pydantic import BaseModel, Field, TypeAdapter
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import ConversationRenamePayload
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
from controllers.console.explore.error import NotChatAppError from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
@ -32,18 +33,6 @@ class ConversationListQuery(BaseModel):
pinned: bool | None = None pinned: bool | None = None
class ConversationRenamePayload(BaseModel):
name: str | None = None
auto_generate: bool = False
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload) register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload)

View File

@ -3,9 +3,10 @@ from typing import Literal
from flask import request from flask import request
from graphon.model_runtime.errors.invoke import InvokeError from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field, TypeAdapter from pydantic import BaseModel, TypeAdapter
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
from controllers.console.app.error import ( from controllers.console.app.error import (
AppMoreLikeThisDisabledError, AppMoreLikeThisDisabledError,
@ -25,7 +26,6 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from fields.conversation_fields import ResultResponse from fields.conversation_fields import ResultResponse
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse
from libs import helper from libs import helper
from libs.helper import UUIDStrOrEmpty
from libs.login import current_account_with_tenant from libs.login import current_account_with_tenant
from models.enums import FeedbackRating from models.enums import FeedbackRating
from models.model import AppMode from models.model import AppMode
@ -44,17 +44,6 @@ from .. import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel):
conversation_id: UUIDStrOrEmpty
first_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class MessageFeedbackPayload(BaseModel):
rating: Literal["like", "dislike"] | None = None
content: str | None = None
class MoreLikeThisQuery(BaseModel): class MoreLikeThisQuery(BaseModel):
response_mode: Literal["blocking", "streaming"] response_mode: Literal["blocking", "streaming"]

View File

@ -1,28 +1,18 @@
from flask import request from flask import request
from pydantic import BaseModel, Field, TypeAdapter from pydantic import TypeAdapter
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import SavedMessageCreatePayload, SavedMessageListQuery
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.explore.error import NotCompletionAppError from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import ResultResponse from fields.conversation_fields import ResultResponse
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
from libs.helper import UUIDStrOrEmpty
from libs.login import current_account_with_tenant from libs.login import current_account_with_tenant
from services.errors.message import MessageNotExistsError from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService from services.saved_message_service import SavedMessageService
class SavedMessageListQuery(BaseModel):
last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class SavedMessageCreatePayload(BaseModel):
message_id: UUIDStrOrEmpty
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload) register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)

View File

@ -1,11 +1,10 @@
import logging import logging
from typing import Any
from graphon.graph_engine.manager import GraphEngineManager from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.errors.invoke import InvokeError from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel
from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import InternalServerError
from controllers.common.controller_schemas import WorkflowRunPayload
from controllers.common.schema import register_schema_model from controllers.common.schema import register_schema_model
from controllers.console.app.error import ( from controllers.console.app.error import (
CompletionRequestError, CompletionRequestError,
@ -34,12 +33,6 @@ from .. import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class WorkflowRunPayload(BaseModel):
inputs: dict[str, Any]
files: list[dict[str, Any]] | None = None
register_schema_model(console_ns, WorkflowRunPayload) register_schema_model(console_ns, WorkflowRunPayload)

View File

@ -28,7 +28,7 @@ from enums.cloud_plan import CloudPlan
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import TimestampField from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantStatus from models.account import Tenant, TenantCustomConfigDict, TenantStatus
from services.account_service import TenantService from services.account_service import TenantService
from services.billing_service import BillingService, SubscriptionPlan from services.billing_service import BillingService, SubscriptionPlan
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService
@ -240,8 +240,10 @@ class CustomConfigWorkspaceApi(Resource):
args = WorkspaceCustomConfigPayload.model_validate(payload) args = WorkspaceCustomConfigPayload.model_validate(payload)
tenant = db.get_or_404(Tenant, current_tenant_id) tenant = db.get_or_404(Tenant, current_tenant_id)
custom_config_dict = { custom_config_dict: TenantCustomConfigDict = {
"remove_webapp_brand": args.remove_webapp_brand, "remove_webapp_brand": args.remove_webapp_brand
if args.remove_webapp_brand is not None
else tenant.custom_config_dict.get("remove_webapp_brand", False),
"replace_webapp_logo": args.replace_webapp_logo "replace_webapp_logo": args.replace_webapp_logo
if args.replace_webapp_logo is not None if args.replace_webapp_logo is not None
else tenant.custom_config_dict.get("replace_webapp_logo"), else tenant.custom_config_dict.get("replace_webapp_logo"),

View File

@ -9,7 +9,7 @@ from flask import request
from flask_restx import Resource from flask_restx import Resource
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_model from controllers.common.schema import register_schema_model
from controllers.console.wraps import setup_required from controllers.console.wraps import setup_required
@ -55,7 +55,7 @@ class EnterpriseAppDSLImport(Resource):
account.set_tenant_id(workspace_id) account.set_tenant_id(workspace_id)
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
dsl_service = AppDslService(session) dsl_service = AppDslService(session)
result = dsl_service.import_app( result = dsl_service.import_app(
account=account, account=account,
@ -64,7 +64,6 @@ class EnterpriseAppDSLImport(Resource):
name=args.name, name=args.name,
description=args.description, description=args.description,
) )
session.commit()
if result.status == ImportStatus.FAILED: if result.status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400 return result.model_dump(mode="json"), 400

View File

@ -2,11 +2,12 @@ from typing import Any, Literal
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator from pydantic import BaseModel, Field, TypeAdapter, field_validator
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, NotFound from werkzeug.exceptions import BadRequest, NotFound
import services import services
from controllers.common.controller_schemas import ConversationRenamePayload
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotChatAppError from controllers.service_api.app.error import NotChatAppError
@ -34,18 +35,6 @@ class ConversationListQuery(BaseModel):
) )
class ConversationRenamePayload(BaseModel):
name: str | None = Field(default=None, description="New conversation name (required if auto_generate is false)")
auto_generate: bool = Field(default=False, description="Auto-generate conversation name")
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
class ConversationVariablesQuery(BaseModel): class ConversationVariablesQuery(BaseModel):
last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last variable ID for pagination") last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last variable ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return") limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")

View File

@ -1,5 +1,4 @@
import logging import logging
from typing import Literal
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
@ -7,6 +6,7 @@ from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services import services
from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotChatAppError from controllers.service_api.app.error import NotChatAppError
@ -14,7 +14,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import ResultResponse from fields.conversation_fields import ResultResponse
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
from libs.helper import UUIDStrOrEmpty
from models.enums import FeedbackRating from models.enums import FeedbackRating
from models.model import App, AppMode, EndUser from models.model import App, AppMode, EndUser
from services.errors.message import ( from services.errors.message import (
@ -27,17 +26,6 @@ from services.message_service import MessageService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel):
conversation_id: UUIDStrOrEmpty
first_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return")
class MessageFeedbackPayload(BaseModel):
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
content: str | None = Field(default=None, description="Feedback content")
class FeedbackListQuery(BaseModel): class FeedbackListQuery(BaseModel):
page: int = Field(default=1, ge=1, description="Page number") page: int = Field(default=1, ge=1, description="Page number")
limit: int = Field(default=20, ge=1, le=101, description="Number of feedbacks per page") limit: int = Field(default=20, ge=1, le=101, description="Number of feedbacks per page")

View File

@ -1,5 +1,5 @@
import logging import logging
from typing import Any, Literal from typing import Literal
from dateutil.parser import isoparse from dateutil.parser import isoparse
from flask import request from flask import request
@ -11,6 +11,7 @@ from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
from controllers.common.controller_schemas import WorkflowRunPayload as WorkflowRunPayloadBase
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.app.error import ( from controllers.service_api.app.error import (
@ -46,9 +47,7 @@ from services.workflow_app_service import WorkflowAppService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class WorkflowRunPayload(BaseModel): class WorkflowRunPayload(WorkflowRunPayloadBase):
inputs: dict[str, Any]
files: list[dict[str, Any]] | None = None
response_mode: Literal["blocking", "streaming"] | None = None response_mode: Literal["blocking", "streaming"] | None = None

View File

@ -31,6 +31,7 @@ from controllers.service_api.wraps import (
cloud_edition_billing_resource_check, cloud_edition_billing_resource_check,
) )
from core.errors.error import ProviderTokenNotInitError from core.errors.error import ProviderTokenNotInitError
from core.rag.entities import PreProcessingRule, Rule, Segmentation
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db from extensions.ext_database import db
from fields.document_fields import document_fields, document_status_fields from fields.document_fields import document_fields, document_status_fields
@ -40,11 +41,8 @@ from models.enums import SegmentStatus
from services.dataset_service import DatasetService, DocumentService from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import ( from services.entities.knowledge_entities.knowledge_entities import (
KnowledgeConfig, KnowledgeConfig,
PreProcessingRule,
ProcessRule, ProcessRule,
RetrievalModel, RetrievalModel,
Rule,
Segmentation,
) )
from services.file_service import FileService from services.file_service import FileService
from services.summary_index_service import SummaryIndexService from services.summary_index_service import SummaryIndexService

View File

@ -3,10 +3,11 @@ import logging
from flask import request from flask import request
from flask_restx import fields, marshal_with from flask_restx import fields, marshal_with
from graphon.model_runtime.errors.invoke import InvokeError from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, field_validator from pydantic import field_validator
from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import InternalServerError
import services import services
from controllers.common.controller_schemas import TextToAudioPayload as TextToAudioPayloadBase
from controllers.web import web_ns from controllers.web import web_ns
from controllers.web.error import ( from controllers.web.error import (
AppUnavailableError, AppUnavailableError,
@ -34,12 +35,7 @@ from services.errors.audio import (
from ..common.schema import register_schema_models from ..common.schema import register_schema_models
class TextToAudioPayload(BaseModel): class TextToAudioPayload(TextToAudioPayloadBase):
message_id: str | None = None
voice: str | None = None
text: str | None = None
streaming: bool | None = None
@field_validator("message_id") @field_validator("message_id")
@classmethod @classmethod
def validate_message_id(cls, value: str | None) -> str | None: def validate_message_id(cls, value: str | None) -> str | None:

View File

@ -1,10 +1,11 @@
from typing import Literal from typing import Literal
from flask import request from flask import request
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator from pydantic import BaseModel, Field, TypeAdapter, field_validator
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import ConversationRenamePayload
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
from controllers.web import web_ns from controllers.web import web_ns
from controllers.web.error import NotChatAppError from controllers.web.error import NotChatAppError
@ -37,18 +38,6 @@ class ConversationListQuery(BaseModel):
return uuid_value(value) return uuid_value(value)
class ConversationRenamePayload(BaseModel):
name: str | None = None
auto_generate: bool = False
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
register_schema_models(web_ns, ConversationListQuery, ConversationRenamePayload) register_schema_models(web_ns, ConversationListQuery, ConversationRenamePayload)

View File

@ -3,7 +3,6 @@ import secrets
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
@ -19,33 +18,15 @@ from controllers.console.error import EmailSendIpLimitError
from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required
from controllers.web import web_ns from controllers.web import web_ns
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip from libs.helper import extract_remote_ip
from libs.password import hash_password, valid_password from libs.password import hash_password
from models.account import Account from models.account import Account
from services.account_service import AccountService from services.account_service import AccountService
from services.entities.auth_entities import (
ForgotPasswordCheckPayload,
class ForgotPasswordSendPayload(BaseModel): ForgotPasswordResetPayload,
email: EmailStr ForgotPasswordSendPayload,
language: str | None = None )
class ForgotPasswordCheckPayload(BaseModel):
email: EmailStr
code: str
token: str = Field(min_length=1)
class ForgotPasswordResetPayload(BaseModel):
token: str = Field(min_length=1)
new_password: str
password_confirm: str
@field_validator("new_password", "password_confirm")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
register_schema_models(web_ns, ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload) register_schema_models(web_ns, ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload)

View File

@ -29,13 +29,11 @@ from libs.token import (
) )
from services.account_service import AccountService from services.account_service import AccountService
from services.app_service import AppService from services.app_service import AppService
from services.entities.auth_entities import LoginPayloadBase
from services.webapp_auth_service import WebAppAuthService from services.webapp_auth_service import WebAppAuthService
class LoginPayload(BaseModel): class LoginPayload(LoginPayloadBase):
email: EmailStr
password: str
@field_validator("password") @field_validator("password")
@classmethod @classmethod
def validate_password(cls, value: str) -> str: def validate_password(cls, value: str) -> str:

View File

@ -6,6 +6,7 @@ from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field, TypeAdapter, field_validator from pydantic import BaseModel, Field, TypeAdapter, field_validator
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.controller_schemas import MessageFeedbackPayload
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
from controllers.web import web_ns from controllers.web import web_ns
from controllers.web.error import ( from controllers.web.error import (
@ -53,11 +54,6 @@ class MessageListQuery(BaseModel):
return uuid_value(value) return uuid_value(value)
class MessageFeedbackPayload(BaseModel):
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
content: str | None = Field(default=None, description="Feedback content")
class MessageMoreLikeThisQuery(BaseModel): class MessageMoreLikeThisQuery(BaseModel):
response_mode: Literal["blocking", "streaming"] = Field( response_mode: Literal["blocking", "streaming"] = Field(
description="Response mode", description="Response mode",

View File

@ -138,12 +138,15 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
if not app_model or app_model.status != "normal" or not app_model.enable_site: if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound() raise NotFound()
if auth_type == WebAppAuthType.PUBLIC: match auth_type:
return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded) case WebAppAuthType.PUBLIC:
elif auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external": return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded)
raise WebAppAuthRequiredError("Please login as external user.") case WebAppAuthType.EXTERNAL:
elif auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal": if user_auth_type != "external":
raise WebAppAuthRequiredError("Please login as internal user.") raise WebAppAuthRequiredError("Please login as external user.")
case WebAppAuthType.INTERNAL:
if user_auth_type != "internal":
raise WebAppAuthRequiredError("Please login as internal user.")
end_user = None end_user = None
if end_user_id: if end_user_id:

View File

@ -1,27 +1,17 @@
from flask import request from flask import request
from pydantic import BaseModel, Field, TypeAdapter from pydantic import TypeAdapter
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import SavedMessageCreatePayload, SavedMessageListQuery
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
from controllers.web import web_ns from controllers.web import web_ns
from controllers.web.error import NotCompletionAppError from controllers.web.error import NotCompletionAppError
from controllers.web.wraps import WebApiResource from controllers.web.wraps import WebApiResource
from fields.conversation_fields import ResultResponse from fields.conversation_fields import ResultResponse
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
from libs.helper import UUIDStrOrEmpty
from services.errors.message import MessageNotExistsError from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService from services.saved_message_service import SavedMessageService
class SavedMessageListQuery(BaseModel):
last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class SavedMessageCreatePayload(BaseModel):
message_id: UUIDStrOrEmpty
register_schema_models(web_ns, SavedMessageListQuery, SavedMessageCreatePayload) register_schema_models(web_ns, SavedMessageListQuery, SavedMessageCreatePayload)

View File

@ -1,11 +1,10 @@
import logging import logging
from typing import Any
from graphon.graph_engine.manager import GraphEngineManager from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.errors.invoke import InvokeError from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import InternalServerError
from controllers.common.controller_schemas import WorkflowRunPayload
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
from controllers.web import web_ns from controllers.web import web_ns
from controllers.web.error import ( from controllers.web.error import (
@ -30,12 +29,6 @@ from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError from services.errors.llm import InvokeRateLimitError
class WorkflowRunPayload(BaseModel):
inputs: dict[str, Any] = Field(description="Input variables for the workflow")
files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed by the workflow")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
register_schema_models(web_ns, WorkflowRunPayload) register_schema_models(web_ns, WorkflowRunPayload)

View File

@ -79,21 +79,18 @@ class CotChatAgentRunner(CotAgentRunner):
if not agent_scratchpad: if not agent_scratchpad:
assistant_messages = [] assistant_messages = []
else: else:
assistant_message = AssistantPromptMessage(content="") content = ""
assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str
for unit in agent_scratchpad: for unit in agent_scratchpad:
if unit.is_final(): if unit.is_final():
assert isinstance(assistant_message.content, str) content += f"Final Answer: {unit.agent_response}"
assistant_message.content += f"Final Answer: {unit.agent_response}"
else: else:
assert isinstance(assistant_message.content, str) content += f"Thought: {unit.thought}\n\n"
assistant_message.content += f"Thought: {unit.thought}\n\n"
if unit.action_str: if unit.action_str:
assistant_message.content += f"Action: {unit.action_str}\n\n" content += f"Action: {unit.action_str}\n\n"
if unit.observation: if unit.observation:
assistant_message.content += f"Observation: {unit.observation}\n\n" content += f"Observation: {unit.observation}\n\n"
assistant_messages = [assistant_message] assistant_messages = [AssistantPromptMessage(content=content)]
# query messages # query messages
query_messages = self._organize_user_query(self._query, []) query_messages = self._organize_user_query(self._query, [])

View File

@ -1,4 +1,3 @@
from collections.abc import Sequence
from enum import StrEnum, auto from enum import StrEnum, auto
from typing import Any, Literal from typing import Any, Literal
@ -9,6 +8,7 @@ from graphon.variables.input_entities import VariableEntity as WorkflowVariableE
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from core.rag.entities import MetadataFilteringCondition
from models.model import AppMode from models.model import AppMode
@ -111,31 +111,6 @@ class ExternalDataVariableEntity(BaseModel):
config: dict[str, Any] = Field(default_factory=dict) config: dict[str, Any] = Field(default_factory=dict)
SupportedComparisonOperator = Literal[
# for string or array
"contains",
"not contains",
"start with",
"end with",
"is",
"is not",
"empty",
"not empty",
"in",
"not in",
# for number
"=",
"≠",
">",
"<",
"≥",
"≤",
# for time
"before",
"after",
]
class ModelConfig(BaseModel): class ModelConfig(BaseModel):
provider: str provider: str
name: str name: str
@ -143,25 +118,6 @@ class ModelConfig(BaseModel):
completion_params: dict[str, Any] = Field(default_factory=dict) completion_params: dict[str, Any] = Field(default_factory=dict)
class Condition(BaseModel):
"""
Condition detail
"""
name: str
comparison_operator: SupportedComparisonOperator
value: str | Sequence[str] | None | int | float = None
class MetadataFilteringCondition(BaseModel):
"""
Metadata Filtering Condition.
"""
logical_operator: Literal["and", "or"] | None = "and"
conditions: list[Condition] | None = Field(default=None, deprecated=True)
class DatasetRetrieveConfigEntity(BaseModel): class DatasetRetrieveConfigEntity(BaseModel):
""" """
Dataset Retrieve Config Entity. Dataset Retrieve Config Entity.

View File

@ -10,7 +10,7 @@ from graphon.runtime import GraphRuntimeState, VariablePool
from graphon.variable_loader import VariableLoader from graphon.variable_loader import VariableLoader
from graphon.variables.variables import Variable from graphon.variables.variables import Variable
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session, sessionmaker
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
@ -363,7 +363,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
:return: List of conversation variables ready for use :return: List of conversation variables ready for use
""" """
with Session(db.engine) as session: with sessionmaker(bind=db.engine).begin() as session:
existing_variables = self._load_existing_conversation_variables(session) existing_variables = self._load_existing_conversation_variables(session)
if not existing_variables: if not existing_variables:
@ -376,7 +376,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
# Convert to Variable objects for use in the workflow # Convert to Variable objects for use in the workflow
conversation_variables = [var.to_variable() for var in existing_variables] conversation_variables = [var.to_variable() for var in existing_variables]
session.commit()
return cast(list[Variable], conversation_variables) return cast(list[Variable], conversation_variables)
def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]: def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]:

View File

@ -16,7 +16,7 @@ from graphon.model_runtime.utils.encoders import jsonable_encoder
from graphon.nodes import BuiltinNodeTypes from graphon.nodes import BuiltinNodeTypes
from graphon.runtime import GraphRuntimeState from graphon.runtime import GraphRuntimeState
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session, sessionmaker
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -328,13 +328,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
@contextmanager @contextmanager
def _database_session(self): def _database_session(self):
"""Context manager for database sessions.""" """Context manager for database sessions."""
with Session(db.engine, expire_on_commit=False) as session: with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
try: yield session
yield session
session.commit()
except Exception:
session.rollback()
raise
def _ensure_workflow_initialized(self): def _ensure_workflow_initialized(self):
"""Fluent validation for workflow state.""" """Fluent validation for workflow state."""

View File

@ -7,7 +7,7 @@ from typing import Union
from graphon.entities import WorkflowStartReason from graphon.entities import WorkflowStartReason
from graphon.enums import WorkflowExecutionStatus from graphon.enums import WorkflowExecutionStatus
from graphon.runtime import GraphRuntimeState from graphon.runtime import GraphRuntimeState
from sqlalchemy.orm import Session from sqlalchemy.orm import Session, sessionmaker
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
@ -252,13 +252,8 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
@contextmanager @contextmanager
def _database_session(self): def _database_session(self):
"""Context manager for database sessions.""" """Context manager for database sessions."""
with Session(db.engine, expire_on_commit=False) as session: with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
try: yield session
yield session
session.commit()
except Exception:
session.rollback()
raise
def _ensure_workflow_initialized(self): def _ensure_workflow_initialized(self):
"""Fluent validation for workflow state.""" """Fluent validation for workflow state."""

View File

@ -66,7 +66,7 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent, QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent, QueueWorkflowSucceededEvent,
) )
from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities import RetrievalSourceMetadata
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
from core.workflow.system_variables import ( from core.workflow.system_variables import (
build_bootstrap_variables, build_bootstrap_variables,

View File

@ -10,7 +10,7 @@ from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChun
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from core.app.entities.agent_strategy import AgentStrategyInfo from core.app.entities.agent_strategy import AgentStrategyInfo
from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities import RetrievalSourceMetadata
class QueueEvent(StrEnum): class QueueEvent(StrEnum):

View File

@ -9,7 +9,7 @@ from graphon.nodes.human_input.entities import FormInput, UserAction
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from core.app.entities.agent_strategy import AgentStrategyInfo from core.app.entities.agent_strategy import AgentStrategyInfo
from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities import RetrievalSourceMetadata
class AnnotationReplyAccount(BaseModel): class AnnotationReplyAccount(BaseModel):

View File

@ -1,6 +1,6 @@
from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.model_runtime.entities.llm_entities import LLMUsage
from sqlalchemy import update from sqlalchemy import update
from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker
from configs import dify_config from configs import dify_config
from core.entities.model_entities import ModelStatus from core.entities.model_entities import ModelStatus
@ -73,7 +73,7 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
pool_type="paid", pool_type="paid",
) )
else: else:
with Session(db.engine) as session: with sessionmaker(bind=db.engine).begin() as session:
stmt = ( stmt = (
update(Provider) update(Provider)
.where( .where(
@ -90,4 +90,3 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
) )
) )
session.execute(stmt) session.execute(stmt)
session.commit()

View File

@ -12,7 +12,7 @@ from graphon.model_runtime.entities.message_entities import (
) )
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session, sessionmaker
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -266,9 +266,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
event = message.event event = message.event
if isinstance(event, QueueErrorEvent): if isinstance(event, QueueErrorEvent):
with Session(db.engine) as session: with sessionmaker(bind=db.engine).begin() as session:
err = self.handle_error(event=event, session=session, message_id=self._message_id) err = self.handle_error(event=event, session=session, message_id=self._message_id)
session.commit()
yield self.error_to_stream_response(err) yield self.error_to_stream_response(err)
break break
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
@ -288,10 +287,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
answer=output_moderation_answer answer=output_moderation_answer
) )
with Session(db.engine) as session: with sessionmaker(bind=db.engine).begin() as session:
# Save message # Save message
self._save_message(session=session, trace_manager=trace_manager) self._save_message(session=session, trace_manager=trace_manager)
session.commit()
message_end_resp = self._message_end_to_stream_response() message_end_resp = self._message_end_to_stream_response()
yield message_end_resp yield message_end_resp
elif isinstance(event, QueueRetrieverResourcesEvent): elif isinstance(event, QueueRetrieverResourcesEvent):

View File

@ -6,7 +6,7 @@ from sqlalchemy import select, update
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities import RetrievalSourceMetadata
from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_database import db from extensions.ext_database import db

View File

@ -1,22 +1,3 @@
from pydantic import BaseModel, Field, model_validator from core.tools.entities.common_entities import I18nObject, I18nObjectDict
__all__ = ["I18nObject", "I18nObjectDict"]
class I18nObject(BaseModel):
"""
Model class for i18n object.
"""
en_US: str
zh_Hans: str | None = Field(default=None)
pt_BR: str | None = Field(default=None)
ja_JP: str | None = Field(default=None)
@model_validator(mode="after")
def _(self):
self.zh_Hans = self.zh_Hans or self.en_US
self.pt_BR = self.pt_BR or self.en_US
self.ja_JP = self.ja_JP or self.en_US
return self
def to_dict(self) -> dict:
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}

View File

@ -9,7 +9,7 @@ from yarl import URL
from configs import dify_config from configs import dify_config
from core.entities.provider_entities import ProviderConfig from core.entities.provider_entities import ProviderConfig
from core.plugin.entities.oauth import OAuthSchema from core.plugin.entities import OAuthSchema
from core.plugin.entities.parameters import ( from core.plugin.entities.parameters import (
PluginParameter, PluginParameter,
PluginParameterOption, PluginParameterOption,

View File

@ -1 +1,8 @@
from core.entities.plugin_credential_type import PluginCredentialType
DEFAULT_PLUGIN_ID = "langgenius" DEFAULT_PLUGIN_ID = "langgenius"
__all__ = [
"DEFAULT_PLUGIN_ID",
"PluginCredentialType",
]

View File

@ -0,0 +1,9 @@
import enum
class PluginCredentialType(enum.Enum):
MODEL = 0 # must be 0 for API contract compatibility
TOOL = 1 # must be 1 for API contract compatibility
def to_number(self):
return self.value

View File

@ -22,6 +22,7 @@ from sqlalchemy import func, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from constants import HIDDEN_VALUE from constants import HIDDEN_VALUE
from core.entities import PluginCredentialType
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
from core.entities.provider_entities import ( from core.entities.provider_entities import (
CustomConfiguration, CustomConfiguration,
@ -46,7 +47,6 @@ from models.provider import (
TenantPreferredModelProvider, TenantPreferredModelProvider,
) )
from models.provider_ids import ModelProviderID from models.provider_ids import ModelProviderID
from services.enterprise.plugin_manager_service import PluginCredentialType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -2,7 +2,7 @@
Credential utility functions for checking credential existence and policy compliance. Credential utility functions for checking credential existence and policy compliance.
""" """
from services.enterprise.plugin_manager_service import PluginCredentialType from core.entities import PluginCredentialType
def is_credential_exists(credential_id: str, credential_type: "PluginCredentialType") -> bool: def is_credential_exists(credential_id: str, credential_type: "PluginCredentialType") -> bool:

View File

@ -17,6 +17,7 @@ from graphon.model_runtime.model_providers.__base.text_embedding_model import Te
from graphon.model_runtime.model_providers.__base.tts_model import TTSModel from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
from configs import dify_config from configs import dify_config
from core.entities import PluginCredentialType
from core.entities.embedding_type import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.entities.provider_entities import ModelLoadBalancingConfiguration
@ -25,7 +26,6 @@ from core.plugin.impl.model_runtime_factory import create_plugin_provider_manage
from core.provider_manager import ProviderManager from core.provider_manager import ProviderManager
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.provider import ProviderType from models.provider import ProviderType
from services.enterprise.plugin_manager_service import PluginCredentialType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -0,0 +1,5 @@
from core.plugin.entities.oauth import OAuthSchema
__all__ = [
"OAuthSchema",
]

View File

@ -1,5 +1,3 @@
from collections.abc import Sequence
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from core.entities.provider_entities import ProviderConfig from core.entities.provider_entities import ProviderConfig
@ -10,12 +8,12 @@ class OAuthSchema(BaseModel):
OAuth schema OAuth schema
""" """
client_schema: Sequence[ProviderConfig] = Field( client_schema: list[ProviderConfig] = Field(
default_factory=list, default_factory=list,
description="client schema like client_id, client_secret, etc.", description="client schema like client_id, client_secret, etc.",
) )
credentials_schema: Sequence[ProviderConfig] = Field( credentials_schema: list[ProviderConfig] = Field(
default_factory=list, default_factory=list,
description="credentials schema like access_token, refresh_token, etc.", description="credentials schema like access_token, refresh_token, etc.",
) )

View File

@ -209,7 +209,10 @@ class PluginInstaller(BasePluginClient):
"GET", "GET",
f"plugin/{tenant_id}/management/decode/from_identifier", f"plugin/{tenant_id}/management/decode/from_identifier",
PluginDecodeResponse, PluginDecodeResponse,
params={"plugin_unique_identifier": plugin_unique_identifier}, params={
"plugin_unique_identifier": plugin_unique_identifier,
"PluginUniqueIdentifier": plugin_unique_identifier, # compat with daemon <= 0.5.4
},
) )
def fetch_plugin_installation_by_ids( def fetch_plugin_installation_by_ids(

View File

@ -1,11 +1,10 @@
from __future__ import annotations from __future__ import annotations
import contextlib import contextlib
import json
from collections import defaultdict from collections import defaultdict
from collections.abc import Sequence from collections.abc import Sequence
from json import JSONDecodeError from json import JSONDecodeError
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any
from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.entities.provider_entities import ( from graphon.model_runtime.entities.provider_entities import (
@ -15,6 +14,7 @@ from graphon.model_runtime.entities.provider_entities import (
ProviderEntity, ProviderEntity,
) )
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from pydantic import TypeAdapter
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -58,6 +58,8 @@ from services.feature_service import FeatureService
if TYPE_CHECKING: if TYPE_CHECKING:
from graphon.model_runtime.runtime import ModelRuntime from graphon.model_runtime.runtime import ModelRuntime
_credentials_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any])
class ProviderManager: class ProviderManager:
""" """
@ -875,8 +877,8 @@ class ProviderManager:
return {"openai_api_key": encrypted_config} return {"openai_api_key": encrypted_config}
try: try:
credentials = cast(dict, json.loads(encrypted_config)) credentials = _credentials_adapter.validate_json(encrypted_config)
except JSONDecodeError: except (ValueError, JSONDecodeError):
return {} return {}
# Decrypt secret variables # Decrypt secret variables
@ -1015,7 +1017,7 @@ class ProviderManager:
if not cached_provider_credentials: if not cached_provider_credentials:
provider_credentials: dict[str, Any] = {} provider_credentials: dict[str, Any] = {}
if provider_records and provider_records[0].encrypted_config: if provider_records and provider_records[0].encrypted_config:
provider_credentials = json.loads(provider_records[0].encrypted_config) provider_credentials = _credentials_adapter.validate_json(provider_records[0].encrypted_config)
# Get provider credential secret variables # Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables( provider_credential_secret_variables = self._extract_secret_variables(
@ -1162,8 +1164,10 @@ class ProviderManager:
if not cached_provider_model_credentials: if not cached_provider_model_credentials:
try: try:
provider_model_credentials = json.loads(load_balancing_model_config.encrypted_config) provider_model_credentials = _credentials_adapter.validate_json(
except JSONDecodeError: load_balancing_model_config.encrypted_config
)
except (ValueError, JSONDecodeError):
continue continue
# Get decoding rsa key and cipher for decrypting credentials # Get decoding rsa key and cipher for decrypting credentials
@ -1176,7 +1180,7 @@ class ProviderManager:
if variable in provider_model_credentials: if variable in provider_model_credentials:
try: try:
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding( provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_model_credentials.get(variable), provider_model_credentials.get(variable) or "",
self.decoding_rsa_key, self.decoding_rsa_key,
self.decoding_cipher_rsa, self.decoding_cipher_rsa,
) )

View File

@ -15,7 +15,7 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor,
from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.embedding.retrieval import AttachmentInfoDict, RetrievalChildChunk, RetrievalSegments from core.rag.embedding.retrieval import AttachmentInfoDict, RetrievalChildChunk, RetrievalSegments
from core.rag.entities.metadata_entities import MetadataCondition from core.rag.entities import MetadataFilteringCondition
from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.query_type import QueryType from core.rag.index_processor.constant.query_type import QueryType
@ -182,7 +182,9 @@ class RetrievalService:
if not dataset: if not dataset:
return [] return []
metadata_condition = ( metadata_condition = (
MetadataCondition.model_validate(metadata_filtering_conditions) if metadata_filtering_conditions else None MetadataFilteringCondition.model_validate(metadata_filtering_conditions)
if metadata_filtering_conditions
else None
) )
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
dataset.tenant_id, dataset.tenant_id,

View File

@ -37,11 +37,12 @@ class AnalyticdbVector(BaseVector):
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0]) dimension = len(embeddings[0])
self.analyticdb_vector._create_collection_if_not_exists(dimension) self.analyticdb_vector.create_collection_if_not_exists(dimension)
self.analyticdb_vector.add_texts(texts, embeddings) self.analyticdb_vector.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
self.analyticdb_vector.add_texts(documents, embeddings) self.analyticdb_vector.add_texts(documents, embeddings)
return []
def text_exists(self, id: str) -> bool: def text_exists(self, id: str) -> bool:
return self.analyticdb_vector.text_exists(id) return self.analyticdb_vector.text_exists(id)

View File

@ -1,5 +1,5 @@
import json import json
from typing import Any from typing import Any, TypedDict
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
@ -13,6 +13,13 @@ from core.rag.models.document import Document
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
class AnalyticdbClientParamsDict(TypedDict):
access_key_id: str
access_key_secret: str
region_id: str
read_timeout: int
class AnalyticdbVectorOpenAPIConfig(BaseModel): class AnalyticdbVectorOpenAPIConfig(BaseModel):
access_key_id: str access_key_id: str
access_key_secret: str access_key_secret: str
@ -44,13 +51,14 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel):
raise ValueError("config ANALYTICDB_NAMESPACE_PASSWORD is required") raise ValueError("config ANALYTICDB_NAMESPACE_PASSWORD is required")
return values return values
def to_analyticdb_client_params(self): def to_analyticdb_client_params(self) -> AnalyticdbClientParamsDict:
return { result: AnalyticdbClientParamsDict = {
"access_key_id": self.access_key_id, "access_key_id": self.access_key_id,
"access_key_secret": self.access_key_secret, "access_key_secret": self.access_key_secret,
"region_id": self.region_id, "region_id": self.region_id,
"read_timeout": self.read_timeout, "read_timeout": self.read_timeout,
} }
return result
class AnalyticdbVectorOpenAPI: class AnalyticdbVectorOpenAPI:
@ -115,7 +123,7 @@ class AnalyticdbVectorOpenAPI:
else: else:
raise ValueError(f"failed to create namespace {self.config.namespace}: {e}") raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")
def _create_collection_if_not_exists(self, embedding_dimension: int): def create_collection_if_not_exists(self, embedding_dimension: int):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException from Tea.exceptions import TeaException

View File

@ -1,5 +1,6 @@
import json import json
import uuid import uuid
from collections.abc import Iterator
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any from typing import Any
@ -74,7 +75,7 @@ class AnalyticdbVectorBySql:
) )
@contextmanager @contextmanager
def _get_cursor(self): def _get_cursor(self) -> Iterator[Any]:
assert self.pool is not None, "Connection pool is not initialized" assert self.pool is not None, "Connection pool is not initialized"
conn = self.pool.getconn() conn = self.pool.getconn()
cur = conn.cursor() cur = conn.cursor()
@ -130,7 +131,7 @@ class AnalyticdbVectorBySql:
) )
cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.config.namespace}") cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.config.namespace}")
def _create_collection_if_not_exists(self, embedding_dimension: int): def create_collection_if_not_exists(self, embedding_dimension: int):
cache_key = f"vector_indexing_{self._collection_name}" cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock" lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20): with redis_client.lock(lock_name, timeout=20):

View File

@ -30,7 +30,7 @@ from pymochow.model.table import AnnSearch, BM25SearchRequest, HNSWSearchParams,
from configs import dify_config from configs import dify_config
from core.rag.datasource.vdb.field import Field as VDBField from core.rag.datasource.vdb.field import Field as VDBField
from core.rag.datasource.vdb.field import parse_metadata_json from core.rag.datasource.vdb.field import parse_metadata_json
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings from core.rag.embedding.embedding_base import Embeddings
@ -85,8 +85,12 @@ class BaiduVector(BaseVector):
def get_type(self) -> str: def get_type(self) -> str:
return VectorType.BAIDU return VectorType.BAIDU
def to_index_struct(self): def to_index_struct(self) -> VectorIndexStructDict:
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} result: VectorIndexStructDict = {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name},
}
return result
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self._create_table(len(embeddings[0])) self._create_table(len(embeddings[0]))

View File

@ -1,12 +1,12 @@
import json import json
from typing import Any from typing import Any, TypedDict
import chromadb import chromadb
from chromadb import QueryResult, Settings from chromadb import QueryResult, Settings # pyright: ignore[reportPrivateImportUsage]
from pydantic import BaseModel from pydantic import BaseModel
from configs import dify_config from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings from core.rag.embedding.embedding_base import Embeddings
@ -15,6 +15,15 @@ from extensions.ext_redis import redis_client
from models.dataset import Dataset from models.dataset import Dataset
class ChromaParamsDict(TypedDict):
host: str
port: int
ssl: bool
tenant: str
database: str
settings: Settings
class ChromaConfig(BaseModel): class ChromaConfig(BaseModel):
host: str host: str
port: int port: int
@ -23,14 +32,13 @@ class ChromaConfig(BaseModel):
auth_provider: str | None = None auth_provider: str | None = None
auth_credentials: str | None = None auth_credentials: str | None = None
def to_chroma_params(self): def to_chroma_params(self) -> ChromaParamsDict:
settings = Settings( settings = Settings(
# auth # auth
chroma_client_auth_provider=self.auth_provider, chroma_client_auth_provider=self.auth_provider,
chroma_client_auth_credentials=self.auth_credentials, chroma_client_auth_credentials=self.auth_credentials,
) )
result: ChromaParamsDict = {
return {
"host": self.host, "host": self.host,
"port": self.port, "port": self.port,
"ssl": False, "ssl": False,
@ -38,6 +46,7 @@ class ChromaConfig(BaseModel):
"database": self.database, "database": self.database,
"settings": settings, "settings": settings,
} }
return result
class ChromaVector(BaseVector): class ChromaVector(BaseVector):
@ -97,14 +106,15 @@ class ChromaVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
collection = self._client.get_or_create_collection(self._collection_name) collection = self._client.get_or_create_collection(self._collection_name)
document_ids_filter = kwargs.get("document_ids_filter") document_ids_filter = kwargs.get("document_ids_filter")
results: QueryResult
if document_ids_filter: if document_ids_filter:
results: QueryResult = collection.query( results = collection.query(
query_embeddings=query_vector, query_embeddings=query_vector,
n_results=kwargs.get("top_k", 4), n_results=kwargs.get("top_k", 4),
where={"document_id": {"$in": document_ids_filter}}, # type: ignore where={"document_id": {"$in": document_ids_filter}}, # type: ignore
) )
else: else:
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) # type: ignore results = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) # type: ignore
score_threshold = float(kwargs.get("score_threshold") or 0.0) score_threshold = float(kwargs.get("score_threshold") or 0.0)
# Check if results contain data # Check if results contain data
@ -145,7 +155,10 @@ class ChromaVectorFactory(AbstractVectorFactory):
else: else:
dataset_id = dataset.id dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
index_struct_dict = {"type": VectorType.CHROMA, "vector_store": {"class_prefix": collection_name}} index_struct_dict: VectorIndexStructDict = {
"type": VectorType.CHROMA,
"vector_store": {"class_prefix": collection_name},
}
dataset.index_struct = json.dumps(index_struct_dict) dataset.index_struct = json.dumps(index_struct_dict)
return ChromaVector( return ChromaVector(
@ -153,8 +166,8 @@ class ChromaVectorFactory(AbstractVectorFactory):
config=ChromaConfig( config=ChromaConfig(
host=dify_config.CHROMA_HOST or "", host=dify_config.CHROMA_HOST or "",
port=dify_config.CHROMA_PORT, port=dify_config.CHROMA_PORT,
tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT, tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT, # pyright: ignore[reportPrivateImportUsage]
database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE, database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE, # pyright: ignore[reportPrivateImportUsage]
auth_provider=dify_config.CHROMA_AUTH_PROVIDER, auth_provider=dify_config.CHROMA_AUTH_PROVIDER,
auth_credentials=dify_config.CHROMA_AUTH_CREDENTIALS, auth_credentials=dify_config.CHROMA_AUTH_CREDENTIALS,
), ),

View File

@ -1,6 +1,6 @@
import json import json
import logging import logging
from typing import Any from typing import Any, TypedDict
from packaging import version from packaging import version
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
@ -20,6 +20,15 @@ from models.dataset import Dataset
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class MilvusParamsDict(TypedDict):
uri: str
token: str | None
user: str | None
password: str | None
db_name: str
analyzer_params: str | None
class MilvusConfig(BaseModel): class MilvusConfig(BaseModel):
""" """
Configuration class for Milvus connection. Configuration class for Milvus connection.
@ -50,11 +59,11 @@ class MilvusConfig(BaseModel):
raise ValueError("config MILVUS_PASSWORD is required") raise ValueError("config MILVUS_PASSWORD is required")
return values return values
def to_milvus_params(self): def to_milvus_params(self) -> MilvusParamsDict:
""" """
Convert the configuration to a dictionary of Milvus connection parameters. Convert the configuration to a dictionary of Milvus connection parameters.
""" """
return { result: MilvusParamsDict = {
"uri": self.uri, "uri": self.uri,
"token": self.token, "token": self.token,
"user": self.user, "user": self.user,
@ -62,6 +71,7 @@ class MilvusConfig(BaseModel):
"db_name": self.database, "db_name": self.database,
"analyzer_params": self.analyzer_params, "analyzer_params": self.analyzer_params,
} }
return result
class MilvusVector(BaseVector): class MilvusVector(BaseVector):

View File

@ -3,7 +3,7 @@ import os
import uuid import uuid
from collections.abc import Generator, Iterable, Sequence from collections.abc import Generator, Iterable, Sequence
from itertools import islice from itertools import islice
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any, cast
import qdrant_client import qdrant_client
from flask import current_app from flask import current_app
@ -22,7 +22,7 @@ from sqlalchemy import select
from configs import dify_config from configs import dify_config
from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings from core.rag.embedding.embedding_base import Embeddings
@ -32,7 +32,6 @@ from extensions.ext_redis import redis_client
from models.dataset import Dataset, DatasetCollectionBinding from models.dataset import Dataset, DatasetCollectionBinding
if TYPE_CHECKING: if TYPE_CHECKING:
from qdrant_client import grpc # noqa
from qdrant_client.conversions import common_types from qdrant_client.conversions import common_types
from qdrant_client.http import models as rest from qdrant_client.http import models as rest
@ -94,8 +93,12 @@ class QdrantVector(BaseVector):
def get_type(self) -> str: def get_type(self) -> str:
return VectorType.QDRANT return VectorType.QDRANT
def to_index_struct(self): def to_index_struct(self) -> VectorIndexStructDict:
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} result: VectorIndexStructDict = {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name},
}
return result
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
if texts: if texts:
@ -176,7 +179,7 @@ class QdrantVector(BaseVector):
for batch_ids, points in self._generate_rest_batches( for batch_ids, points in self._generate_rest_batches(
texts, embeddings, filtered_metadatas, uuids, 64, self._group_id texts, embeddings, filtered_metadatas, uuids, 64, self._group_id
): ):
self._client.upsert(collection_name=self._collection_name, points=points) self._client.upsert(collection_name=self._collection_name, points=cast("common_types.Points", points))
added_ids.extend(batch_ids) added_ids.extend(batch_ids)
return added_ids return added_ids
@ -468,7 +471,7 @@ class QdrantVector(BaseVector):
def _reload_if_needed(self): def _reload_if_needed(self):
if isinstance(self._client, QdrantLocal): if isinstance(self._client, QdrantLocal):
self._client._load() self._client._load() # pyright: ignore[reportPrivateUsage]
@classmethod @classmethod
def _document_from_scored_point( def _document_from_scored_point(

View File

@ -26,7 +26,7 @@ from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
Base = declarative_base() # type: Any Base: Any = declarative_base()
class RelytConfig(BaseModel): class RelytConfig(BaseModel):

View File

@ -1,7 +1,7 @@
import json import json
import logging import logging
import math import math
from typing import Any from typing import Any, TypedDict
from pydantic import BaseModel from pydantic import BaseModel
from tcvdb_text.encoder import BM25Encoder # type: ignore from tcvdb_text.encoder import BM25Encoder # type: ignore
@ -12,7 +12,7 @@ from tcvectordb.model.document import AnnSearch, Filter, KeywordSearch, Weighted
from configs import dify_config from configs import dify_config
from core.rag.datasource.vdb.field import parse_metadata_json from core.rag.datasource.vdb.field import parse_metadata_json
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings from core.rag.embedding.embedding_base import Embeddings
@ -23,6 +23,13 @@ from models.dataset import Dataset
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TencentParamsDict(TypedDict):
    """Typed shape of the dictionary returned by ``TencentConfig.to_tencent_params``."""

    url: str
    username: str | None
    # Carries the config's ``api_key`` value under the key name ``key``.
    key: str | None
    timeout: float
class TencentConfig(BaseModel): class TencentConfig(BaseModel):
url: str url: str
api_key: str | None = None api_key: str | None = None
@ -36,8 +43,14 @@ class TencentConfig(BaseModel):
max_upsert_batch_size: int = 128 max_upsert_batch_size: int = 128
enable_hybrid_search: bool = False # Flag to enable hybrid search enable_hybrid_search: bool = False # Flag to enable hybrid search
def to_tencent_params(self): def to_tencent_params(self) -> TencentParamsDict:
return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout} result: TencentParamsDict = {
"url": self.url,
"username": self.username,
"key": self.api_key,
"timeout": self.timeout,
}
return result
bm25 = BM25Encoder.default("zh") bm25 = BM25Encoder.default("zh")
@ -83,8 +96,12 @@ class TencentVector(BaseVector):
def get_type(self) -> str: def get_type(self) -> str:
return VectorType.TENCENT return VectorType.TENCENT
def to_index_struct(self): def to_index_struct(self) -> VectorIndexStructDict:
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} result: VectorIndexStructDict = {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name},
}
return result
def _has_collection(self) -> bool: def _has_collection(self) -> bool:
return bool( return bool(

View File

@ -25,7 +25,7 @@ from sqlalchemy import select
from configs import dify_config from configs import dify_config
from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings from core.rag.embedding.embedding_base import Embeddings
@ -91,8 +91,12 @@ class TidbOnQdrantVector(BaseVector):
def get_type(self) -> str: def get_type(self) -> str:
return VectorType.TIDB_ON_QDRANT return VectorType.TIDB_ON_QDRANT
def to_index_struct(self): def to_index_struct(self) -> VectorIndexStructDict:
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} result: VectorIndexStructDict = {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name},
}
return result
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
if texts: if texts:

View File

@ -1,11 +1,20 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any from typing import Any, TypedDict
from core.rag.models.document import Document from core.rag.models.document import Document
class VectorStoreDict(TypedDict):
    """Nested ``vector_store`` entry of a :class:`VectorIndexStructDict`."""

    # Holds the collection name of the vector store.
    class_prefix: str
class VectorIndexStructDict(TypedDict):
    """Typed shape of the index structure dictionary produced by vector stores."""

    # Vector store type identifier (the value a store returns from ``get_type()``).
    type: str
    # Nested entry carrying the collection name under ``class_prefix``.
    vector_store: VectorStoreDict
class BaseVector(ABC): class BaseVector(ABC):
def __init__(self, collection_name: str): def __init__(self, collection_name: str):
self._collection_name = collection_name self._collection_name = collection_name

View File

@ -9,7 +9,7 @@ from sqlalchemy import select
from configs import dify_config from configs import dify_config
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.cached_embedding import CacheEmbedding from core.rag.embedding.cached_embedding import CacheEmbedding
from core.rag.embedding.embedding_base import Embeddings from core.rag.embedding.embedding_base import Embeddings
@ -30,8 +30,11 @@ class AbstractVectorFactory(ABC):
raise NotImplementedError raise NotImplementedError
@staticmethod @staticmethod
def gen_index_struct_dict(vector_type: VectorType, collection_name: str): def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> VectorIndexStructDict:
index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}} index_struct_dict: VectorIndexStructDict = {
"type": vector_type,
"vector_store": {"class_prefix": collection_name},
}
return index_struct_dict return index_struct_dict

View File

@ -24,7 +24,7 @@ from weaviate.exceptions import UnexpectedStatusCodeError
from configs import dify_config from configs import dify_config
from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings from core.rag.embedding.embedding_base import Embeddings
@ -184,9 +184,13 @@ class WeaviateVector(BaseVector):
dataset_id = dataset.id dataset_id = dataset.id
return Dataset.gen_collection_name_by_id(dataset_id) return Dataset.gen_collection_name_by_id(dataset_id)
def to_index_struct(self) -> dict: def to_index_struct(self) -> VectorIndexStructDict:
"""Returns the index structure dictionary for persistence.""" """Returns the index structure dictionary for persistence."""
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} result: VectorIndexStructDict = {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name},
}
return result
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
""" """

View File

@ -0,0 +1,28 @@
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
from core.rag.entities.event import DatasourceCompletedEvent, DatasourceErrorEvent, DatasourceProcessingEvent
from core.rag.entities.index_entities import EconomySetting, EmbeddingSetting, IndexMethod
from core.rag.entities.metadata_entities import Condition, MetadataFilteringCondition, SupportedComparisonOperator
from core.rag.entities.processing_entities import ParentMode, PreProcessingRule, Rule, Segmentation
from core.rag.entities.retrieval_settings import KeywordSetting, VectorSetting, WeightedScoreConfig
__all__ = [
"Condition",
"DatasourceCompletedEvent",
"DatasourceErrorEvent",
"DatasourceProcessingEvent",
"DocumentContext",
"EconomySetting",
"EmbeddingSetting",
"IndexMethod",
"KeywordSetting",
"MetadataFilteringCondition",
"ParentMode",
"PreProcessingRule",
"RetrievalSourceMetadata",
"Rule",
"Segmentation",
"SupportedComparisonOperator",
"VectorSetting",
"WeightedScoreConfig",
]

View File

@ -0,0 +1,30 @@
from typing import Literal
from pydantic import BaseModel
class EmbeddingSetting(BaseModel):
    """
    Embedding Setting.

    Names an embedding model by provider and model name. Both fields are
    required (no defaults are declared).
    """

    embedding_provider_name: str
    embedding_model_name: str
class EconomySetting(BaseModel):
    """
    Economy Setting.

    Holds the single required integer ``keyword_number``.
    """

    # NOTE(review): presumably the number of keywords used by the "economy"
    # indexing technique -- confirm against the consumer of this model.
    keyword_number: int
class IndexMethod(BaseModel):
    """
    Knowledge Index Setting.

    Selects the indexing technique and carries the settings for both
    techniques. ``embedding_setting`` and ``economy_setting`` are both declared
    as required fields, regardless of which technique is selected.
    """

    # Only the two listed string values pass validation.
    indexing_technique: Literal["high_quality", "economy"]
    embedding_setting: EmbeddingSetting
    economy_setting: EconomySetting

View File

@ -38,9 +38,9 @@ class Condition(BaseModel):
value: str | Sequence[str] | None | int | float = None value: str | Sequence[str] | None | int | float = None
class MetadataCondition(BaseModel): class MetadataFilteringCondition(BaseModel):
""" """
Metadata Condition. Metadata Filtering Condition.
""" """
logical_operator: Literal["and", "or"] | None = "and" logical_operator: Literal["and", "or"] | None = "and"

View File

@ -0,0 +1,27 @@
from enum import StrEnum
from typing import Literal
from pydantic import BaseModel
class ParentMode(StrEnum):
    """String-valued enum of parent modes.

    The member values are the same strings accepted by ``Rule.parent_mode``.
    """

    FULL_DOC = "full-doc"
    PARAGRAPH = "paragraph"
class PreProcessingRule(BaseModel):
    """A single pre-processing rule: an identifier plus an on/off flag.

    Both fields are required.
    """

    id: str
    enabled: bool
class Segmentation(BaseModel):
    """Segmentation settings.

    ``max_tokens`` is required; ``separator`` defaults to a newline and
    ``chunk_overlap`` defaults to ``0``.
    """

    separator: str = "\n"
    max_tokens: int
    chunk_overlap: int = 0
class Rule(BaseModel):
    """Processing rule bundle.

    Every field is optional and defaults to ``None``.
    """

    pre_processing_rules: list[PreProcessingRule] | None = None
    segmentation: Segmentation | None = None
    # Accepts the same string values that the ``ParentMode`` enum defines.
    parent_mode: Literal["full-doc", "paragraph"] | None = None
    # Uses the same ``Segmentation`` model as ``segmentation``.
    subchunk_segmentation: Segmentation | None = None

View File

@ -0,0 +1,28 @@
from pydantic import BaseModel
class VectorSetting(BaseModel):
    """
    Vector Setting.

    Carries a vector weight together with the provider and model name of an
    embedding model. All three fields are required.
    """

    vector_weight: float
    embedding_provider_name: str
    embedding_model_name: str
class KeywordSetting(BaseModel):
    """
    Keyword Setting.

    Holds the single required float ``keyword_weight``.
    """

    keyword_weight: float
class WeightedScoreConfig(BaseModel):
    """
    Weighted score Config.

    Groups a ``VectorSetting`` and a ``KeywordSetting``; both are required.
    """

    vector_setting: VectorSetting
    keyword_setting: KeywordSetting

View File

@ -19,12 +19,15 @@ class UnstructuredWordExtractor(BaseExtractor):
def extract(self) -> list[Document]: def extract(self) -> list[Document]:
from unstructured.__version__ import __version__ as __unstructured_version__ from unstructured.__version__ import __version__ as __unstructured_version__
from unstructured.file_utils.filetype import FileType, detect_filetype from unstructured.file_utils.filetype import ( # pyright: ignore[reportPrivateImportUsage]
FileType,
detect_filetype,
)
unstructured_version = tuple(int(x) for x in __unstructured_version__.split(".")) unstructured_version = tuple(int(x) for x in __unstructured_version__.split("."))
# check the file extension # check the file extension
try: try:
import magic # noqa: F401 import magic # noqa: F401 # pyright: ignore[reportUnusedImport]
is_doc = detect_filetype(self._file_path) == FileType.DOC is_doc = detect_filetype(self._file_path) == FileType.DOC
except ImportError: except ImportError:

View File

@ -12,7 +12,7 @@ from core.db.session_factory import session_factory
from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview from core.workflow.nodes.knowledge_index.protocols import IndexingResultDict, Preview, PreviewItem, QaPreview
from models.dataset import Dataset, Document, DocumentSegment from models.dataset import Dataset, Document, DocumentSegment
from .index_processor_factory import IndexProcessorFactory from .index_processor_factory import IndexProcessorFactory
@ -61,7 +61,7 @@ class IndexProcessor:
chunks: Mapping[str, Any], chunks: Mapping[str, Any],
batch: Any, batch: Any,
summary_index_setting: SummaryIndexSettingDict | None = None, summary_index_setting: SummaryIndexSettingDict | None = None,
): ) -> IndexingResultDict:
with session_factory.create_session() as session: with session_factory.create_session() as session:
document = session.query(Document).filter_by(id=document_id).first() document = session.query(Document).filter_by(id=document_id).first()
if not document: if not document:
@ -129,7 +129,7 @@ class IndexProcessor:
} }
) )
return { result: IndexingResultDict = {
"dataset_id": dataset_id, "dataset_id": dataset_id,
"dataset_name": dataset_name_value, "dataset_name": dataset_name_value,
"batch": batch, "batch": batch,
@ -138,6 +138,7 @@ class IndexProcessor:
"created_at": created_at_value.timestamp(), "created_at": created_at_value.timestamp(),
"display_status": "completed", "display_status": "completed",
} }
return result
def get_preview_output( def get_preview_output(
self, self,

View File

@ -32,6 +32,7 @@ from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.entities import Rule
from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.doc_type import DocType
@ -49,7 +50,6 @@ from models.account import Account
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding from models.dataset import Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument from models.dataset import Document as DatasetDocument
from services.account_service import AccountService from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import Rule
from services.summary_index_service import SummaryIndexService from services.summary_index_service import SummaryIndexService
_file_access_controller = DatabaseFileAccessController() _file_access_controller = DatabaseFileAccessController()

View File

@ -17,6 +17,7 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.entities import ParentMode, Rule
from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.doc_type import DocType
@ -30,7 +31,6 @@ from models import Account
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument from models.dataset import Document as DatasetDocument
from services.account_service import AccountService from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
from services.summary_index_service import SummaryIndexService from services.summary_index_service import SummaryIndexService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -19,6 +19,7 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.entities import Rule
from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
@ -30,7 +31,6 @@ from libs import helper
from models.account import Account from models.account import Account
from models.dataset import Dataset, DocumentSegment from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import Rule
from services.summary_index_service import SummaryIndexService from services.summary_index_service import SummaryIndexService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -1,16 +1,6 @@
from pydantic import BaseModel from pydantic import BaseModel
from core.rag.entities import KeywordSetting, VectorSetting
class VectorSetting(BaseModel):
vector_weight: float
embedding_provider_name: str
embedding_model_name: str
class KeywordSetting(BaseModel):
keyword_weight: float
class Weights(BaseModel): class Weights(BaseModel):

View File

@ -15,7 +15,7 @@ from graphon.model_runtime.entities.message_entities import PromptMessage, Promp
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from sqlalchemy import and_, func, literal, or_, select from sqlalchemy import and_, func, literal, or_, select
from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker
from core.app.app_config.entities import ( from core.app.app_config.entities import (
DatasetEntity, DatasetEntity,
@ -39,9 +39,7 @@ from core.prompt.simple_prompt_transform import ModelMode
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities import Condition, DocumentContext, RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.index_processor.constant.query_type import QueryType from core.rag.index_processor.constant.query_type import QueryType
@ -604,7 +602,7 @@ class DatasetRetrieval:
planning_strategy: PlanningStrategy, planning_strategy: PlanningStrategy,
message_id: str | None = None, message_id: str | None = None,
metadata_filter_document_ids: dict[str, list[str]] | None = None, metadata_filter_document_ids: dict[str, list[str]] | None = None,
metadata_condition: MetadataCondition | None = None, metadata_condition: MetadataFilteringCondition | None = None,
): ):
tools = [] tools = []
for dataset in available_datasets: for dataset in available_datasets:
@ -743,7 +741,7 @@ class DatasetRetrieval:
reranking_enable: bool = True, reranking_enable: bool = True,
message_id: str | None = None, message_id: str | None = None,
metadata_filter_document_ids: dict[str, list[str]] | None = None, metadata_filter_document_ids: dict[str, list[str]] | None = None,
metadata_condition: MetadataCondition | None = None, metadata_condition: MetadataFilteringCondition | None = None,
attachment_ids: list[str] | None = None, attachment_ids: list[str] | None = None,
): ):
if not available_datasets: if not available_datasets:
@ -886,7 +884,7 @@ class DatasetRetrieval:
self._send_trace_task(message_id, documents, timer) self._send_trace_task(message_id, documents, timer)
return return
with Session(db.engine) as session: with sessionmaker(bind=db.engine).begin() as session:
# Collect all document_ids and batch fetch DatasetDocuments # Collect all document_ids and batch fetch DatasetDocuments
document_ids = { document_ids = {
doc.metadata["document_id"] doc.metadata["document_id"]
@ -977,7 +975,6 @@ class DatasetRetrieval:
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False, synchronize_session=False,
) )
session.commit()
self._send_trace_task(message_id, documents, timer) self._send_trace_task(message_id, documents, timer)
@ -1063,7 +1060,7 @@ class DatasetRetrieval:
top_k: int, top_k: int,
all_documents: list[Document], all_documents: list[Document],
document_ids_filter: list[str] | None = None, document_ids_filter: list[str] | None = None,
metadata_condition: MetadataCondition | None = None, metadata_condition: MetadataFilteringCondition | None = None,
attachment_ids: list[str] | None = None, attachment_ids: list[str] | None = None,
): ):
with flask_app.app_context(): with flask_app.app_context():
@ -1339,7 +1336,7 @@ class DatasetRetrieval:
metadata_model_config: ModelConfig, metadata_model_config: ModelConfig,
metadata_filtering_conditions: MetadataFilteringCondition | None, metadata_filtering_conditions: MetadataFilteringCondition | None,
inputs: dict, inputs: dict,
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]: ) -> tuple[dict[str, list[str]] | None, MetadataFilteringCondition | None]:
document_query = select(DatasetDocument).where( document_query = select(DatasetDocument).where(
DatasetDocument.dataset_id.in_(dataset_ids), DatasetDocument.dataset_id.in_(dataset_ids),
DatasetDocument.indexing_status == "completed", DatasetDocument.indexing_status == "completed",
@ -1371,7 +1368,7 @@ class DatasetRetrieval:
value=filter.get("value"), value=filter.get("value"),
) )
) )
metadata_condition = MetadataCondition( metadata_condition = MetadataFilteringCondition(
logical_operator=metadata_filtering_conditions.logical_operator logical_operator=metadata_filtering_conditions.logical_operator
if metadata_filtering_conditions if metadata_filtering_conditions
else "or", # type: ignore else "or", # type: ignore
@ -1400,7 +1397,7 @@ class DatasetRetrieval:
expected_value, expected_value,
filters, filters,
) )
metadata_condition = MetadataCondition( metadata_condition = MetadataFilteringCondition(
logical_operator=metadata_filtering_conditions.logical_operator, logical_operator=metadata_filtering_conditions.logical_operator,
conditions=conditions, conditions=conditions,
) )
@ -1723,7 +1720,7 @@ class DatasetRetrieval:
self, self,
flask_app: Flask, flask_app: Flask,
available_datasets: list[Dataset], available_datasets: list[Dataset],
metadata_condition: MetadataCondition | None, metadata_condition: MetadataFilteringCondition | None,
metadata_filter_document_ids: dict[str, list[str]] | None, metadata_filter_document_ids: dict[str, list[str]] | None,
all_documents: list[Document], all_documents: list[Document],
tenant_id: str, tenant_id: str,

View File

@ -1,6 +1,15 @@
from typing import TypedDict
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator
class I18nObjectDict(TypedDict):
    """Typed shape of the dictionary returned by ``I18nObject.to_dict``.

    ``en_US`` is the only entry typed as a plain ``str``; the other locale
    entries may be ``None``.
    """

    zh_Hans: str | None
    en_US: str
    pt_BR: str | None
    ja_JP: str | None
class I18nObject(BaseModel): class I18nObject(BaseModel):
""" """
Model class for i18n object. Model class for i18n object.
@ -18,5 +27,11 @@ class I18nObject(BaseModel):
self.ja_JP = self.ja_JP or self.en_US self.ja_JP = self.ja_JP or self.en_US
return self return self
def to_dict(self): def to_dict(self) -> I18nObjectDict:
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP} result: I18nObjectDict = {
"zh_Hans": self.zh_Hans,
"en_US": self.en_US,
"pt_BR": self.pt_BR,
"ja_JP": self.ja_JP,
}
return result

View File

@ -6,9 +6,20 @@ from collections.abc import Mapping
from enum import StrEnum, auto from enum import StrEnum, auto
from typing import Any, Union from typing import Any, Union
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator from pydantic import (
BaseModel,
ConfigDict,
Field,
TypeAdapter,
ValidationInfo,
field_serializer,
field_validator,
model_validator,
)
from typing_extensions import TypedDict
from core.entities.provider_entities import ProviderConfig from core.entities.provider_entities import ProviderConfig
from core.plugin.entities import OAuthSchema
from core.plugin.entities.parameters import ( from core.plugin.entities.parameters import (
MCPServerParameterType, MCPServerParameterType,
PluginParameter, PluginParameter,
@ -18,11 +29,19 @@ from core.plugin.entities.parameters import (
cast_parameter_value, cast_parameter_value,
init_frontend_parameter, init_frontend_parameter,
) )
from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities import RetrievalSourceMetadata
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY
class EmojiIconDict(TypedDict):
    """Typed dictionary with two required string entries describing an emoji icon."""

    # NOTE(review): presumably the icon's background colour -- confirm with callers.
    background: str
    # NOTE(review): presumably the emoji itself -- confirm with callers.
    content: str
# Module-level pydantic TypeAdapter for EmojiIconDict, constructed once at import time.
emoji_icon_adapter: TypeAdapter[EmojiIconDict] = TypeAdapter(EmojiIconDict)
class ToolLabelEnum(StrEnum): class ToolLabelEnum(StrEnum):
SEARCH = "search" SEARCH = "search"
IMAGE = "image" IMAGE = "image"
@ -410,15 +429,6 @@ class ToolEntity(BaseModel):
return value or {} return value or {}
class OAuthSchema(BaseModel):
client_schema: list[ProviderConfig] = Field(
default_factory=list[ProviderConfig], description="The schema of the OAuth client"
)
credentials_schema: list[ProviderConfig] = Field(
default_factory=list[ProviderConfig], description="The schema of the OAuth credentials"
)
class ToolProviderEntity(BaseModel): class ToolProviderEntity(BaseModel):
identity: ToolProviderIdentity identity: ToolProviderIdentity
plugin_id: str | None = None plugin_id: str | None = None

View File

@ -5,16 +5,19 @@ import time
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping
from os import listdir, path from os import listdir, path
from threading import Lock from threading import Lock
from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, TypedDict, Union, cast from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, Union, cast
import sqlalchemy as sa import sqlalchemy as sa
from graphon.runtime import VariablePool from graphon.runtime import VariablePool
from pydantic import TypeAdapter
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from typing_extensions import TypedDict
from yarl import URL from yarl import URL
import contexts import contexts
from configs import dify_config from configs import dify_config
from core.entities import PluginCredentialType
from core.helper.provider_cache import ToolProviderCredentialsCache from core.helper.provider_cache import ToolProviderCredentialsCache
from core.plugin.impl.tool import PluginToolManager from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_provider import ToolProviderController
@ -27,7 +30,6 @@ from core.tools.utils.uuid_utils import is_valid_uuid
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from extensions.ext_database import db from extensions.ext_database import db
from models.provider_ids import ToolProviderID from models.provider_ids import ToolProviderID
from services.enterprise.plugin_manager_service import PluginCredentialType
from services.tools.mcp_tools_manage_service import MCPToolManageService from services.tools.mcp_tools_manage_service import MCPToolManageService
if TYPE_CHECKING: if TYPE_CHECKING:
@ -49,9 +51,11 @@ from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProvider
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
ApiProviderAuthType, ApiProviderAuthType,
EmojiIconDict,
ToolInvokeFrom, ToolInvokeFrom,
ToolParameter, ToolParameter,
ToolProviderType, ToolProviderType,
emoji_icon_adapter,
) )
from core.tools.errors import ToolProviderNotFoundError from core.tools.errors import ToolProviderNotFoundError
from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_label_manager import ToolLabelManager
@ -72,9 +76,7 @@ class ApiProviderControllerItem(TypedDict):
controller: ApiToolProviderController controller: ApiToolProviderController
class EmojiIconDict(TypedDict): _credentials_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any])
background: str
content: str
class WorkflowToolRuntimeSpec(Protocol): class WorkflowToolRuntimeSpec(Protocol):
@ -203,16 +205,160 @@ class ToolManager:
:return: the tool :return: the tool
""" """
if provider_type == ToolProviderType.BUILT_IN: match provider_type:
# check if the builtin tool need credentials case ToolProviderType.BUILT_IN:
provider_controller = cls.get_builtin_provider(provider_id, tenant_id) provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
builtin_tool = provider_controller.get_tool(tool_name) builtin_tool = provider_controller.get_tool(tool_name)
if not builtin_tool: if not builtin_tool:
raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found") raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")
if not provider_controller.need_credentials:
return builtin_tool.fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
user_id=user_id,
credentials={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
)
builtin_provider = None
if isinstance(provider_controller, PluginToolProviderController):
provider_id_entity = ToolProviderID(provider_id)
if is_valid_uuid(credential_id):
try:
builtin_provider_stmt = select(BuiltinToolProvider).where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
builtin_provider = db.session.scalar(builtin_provider_stmt)
except Exception as e:
builtin_provider = None
logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True)
if builtin_provider is None:
raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}")
if builtin_provider is None:
with Session(db.engine) as session:
builtin_provider = session.scalar(
sa.select(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == str(provider_id_entity))
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
)
if builtin_provider is None:
raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
else:
builtin_provider = db.session.scalar(
select(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)
)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.limit(1)
)
if builtin_provider is None:
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
from core.helper.credential_utils import check_credential_policy_compliance
check_credential_policy_compliance(
credential_id=builtin_provider.id,
provider=provider_id,
credential_type=PluginCredentialType.TOOL,
check_existence=False,
)
encrypter, cache = create_provider_encrypter(
tenant_id=tenant_id,
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type)
],
cache=ToolProviderCredentialsCache(
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
),
)
decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials)
if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
# TODO: circular import
from core.plugin.impl.oauth import OAuthHandler
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
tool_provider = ToolProviderID(provider_id)
provider_name = tool_provider.provider_name
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
oauth_handler = OAuthHandler()
refreshed_credentials = oauth_handler.refresh_credentials(
tenant_id=tenant_id,
user_id=builtin_provider.user_id,
plugin_id=tool_provider.plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=system_credentials or {},
credentials=decrypted_credentials,
)
# update the credentials
builtin_provider.encrypted_credentials = json.dumps(
encrypter.encrypt(refreshed_credentials.credentials)
)
builtin_provider.expires_at = refreshed_credentials.expires_at
db.session.commit()
decrypted_credentials = refreshed_credentials.credentials
cache.delete()
if not provider_controller.need_credentials:
return builtin_tool.fork_tool_runtime( return builtin_tool.fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
user_id=user_id,
credentials=dict(decrypted_credentials),
credential_type=builtin_provider.credential_type,
runtime_parameters={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
)
case ToolProviderType.API:
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
controller=api_provider,
)
return api_provider.get_tool(tool_name).fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
user_id=user_id,
credentials=dict(encrypter.decrypt(credentials)),
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
)
case ToolProviderType.WORKFLOW:
workflow_provider_stmt = select(WorkflowToolProvider).where(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id
)
with Session(db.engine, expire_on_commit=False) as session, session.begin():
workflow_provider = session.scalar(workflow_provider_stmt)
if workflow_provider is None:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id)
if controller_tools is None or len(controller_tools) == 0:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
runtime=ToolRuntime( runtime=ToolRuntime(
tenant_id=tenant_id, tenant_id=tenant_id,
user_id=user_id, user_id=user_id,
@ -221,177 +367,28 @@ class ToolManager:
tool_invoke_from=tool_invoke_from, tool_invoke_from=tool_invoke_from,
) )
) )
builtin_provider = None case ToolProviderType.APP:
if isinstance(provider_controller, PluginToolProviderController): raise NotImplementedError("app provider not implemented")
provider_id_entity = ToolProviderID(provider_id) case ToolProviderType.PLUGIN:
# get specific credentials plugin_tool = cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
if is_valid_uuid(credential_id): runtime = getattr(plugin_tool, "runtime", None)
try: if runtime is not None:
builtin_provider_stmt = select(BuiltinToolProvider).where( runtime.user_id = user_id
BuiltinToolProvider.tenant_id == tenant_id, runtime.invoke_from = invoke_from
BuiltinToolProvider.id == credential_id, runtime.tool_invoke_from = tool_invoke_from
) return plugin_tool
builtin_provider = db.session.scalar(builtin_provider_stmt) case ToolProviderType.MCP:
except Exception as e: mcp_tool = cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name)
builtin_provider = None runtime = getattr(mcp_tool, "runtime", None)
logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True) if runtime is not None:
# if the provider has been deleted, raise an error runtime.user_id = user_id
if builtin_provider is None: runtime.invoke_from = invoke_from
raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}") runtime.tool_invoke_from = tool_invoke_from
return mcp_tool
# fallback to the default provider case ToolProviderType.DATASET_RETRIEVAL:
if builtin_provider is None: raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
# use the default provider case _:
with Session(db.engine) as session: raise ToolProviderNotFoundError(f"provider type {provider_type} not found")
builtin_provider = session.scalar(
sa.select(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == str(provider_id_entity))
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
)
if builtin_provider is None:
raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
else:
builtin_provider = db.session.scalar(
select(BuiltinToolProvider)
.where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.limit(1)
)
if builtin_provider is None:
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
# check if the credential is allowed to be used
from core.helper.credential_utils import check_credential_policy_compliance
check_credential_policy_compliance(
credential_id=builtin_provider.id,
provider=provider_id,
credential_type=PluginCredentialType.TOOL,
check_existence=False,
)
encrypter, cache = create_provider_encrypter(
tenant_id=tenant_id,
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type)
],
cache=ToolProviderCredentialsCache(
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
),
)
# decrypt the credentials
decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials)
# check if the credentials is expired
if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
# TODO: circular import
from core.plugin.impl.oauth import OAuthHandler
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
# refresh the credentials
tool_provider = ToolProviderID(provider_id)
provider_name = tool_provider.provider_name
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
oauth_handler = OAuthHandler()
# refresh the credentials
refreshed_credentials = oauth_handler.refresh_credentials(
tenant_id=tenant_id,
user_id=builtin_provider.user_id,
plugin_id=tool_provider.plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=system_credentials or {},
credentials=decrypted_credentials,
)
# update the credentials
builtin_provider.encrypted_credentials = json.dumps(
encrypter.encrypt(refreshed_credentials.credentials)
)
builtin_provider.expires_at = refreshed_credentials.expires_at
db.session.commit()
decrypted_credentials = refreshed_credentials.credentials
cache.delete()
return builtin_tool.fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
user_id=user_id,
credentials=dict(decrypted_credentials),
credential_type=builtin_provider.credential_type,
runtime_parameters={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
)
elif provider_type == ToolProviderType.API:
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
controller=api_provider,
)
return api_provider.get_tool(tool_name).fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
user_id=user_id,
credentials=dict(encrypter.decrypt(credentials)),
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
)
elif provider_type == ToolProviderType.WORKFLOW:
workflow_provider_stmt = select(WorkflowToolProvider).where(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id
)
with Session(db.engine, expire_on_commit=False) as session, session.begin():
workflow_provider = session.scalar(workflow_provider_stmt)
if workflow_provider is None:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id)
if controller_tools is None or len(controller_tools) == 0:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
user_id=user_id,
credentials={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
)
elif provider_type == ToolProviderType.APP:
raise NotImplementedError("app provider not implemented")
elif provider_type == ToolProviderType.PLUGIN:
plugin_tool = cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
runtime = getattr(plugin_tool, "runtime", None)
if runtime is not None:
runtime.user_id = user_id
runtime.invoke_from = invoke_from
runtime.tool_invoke_from = tool_invoke_from
return plugin_tool
elif provider_type == ToolProviderType.MCP:
mcp_tool = cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name)
runtime = getattr(mcp_tool, "runtime", None)
if runtime is not None:
runtime.user_id = user_id
runtime.invoke_from = invoke_from
runtime.tool_invoke_from = tool_invoke_from
return mcp_tool
else:
raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
@classmethod @classmethod
def get_agent_tool_runtime( def get_agent_tool_runtime(
@ -885,7 +882,7 @@ class ToolManager:
raise ValueError(f"you have not added provider {provider_name}") raise ValueError(f"you have not added provider {provider_name}")
try: try:
credentials = json.loads(provider_obj.credentials_str) or {} credentials = _credentials_adapter.validate_json(provider_obj.credentials_str) or {}
except Exception: except Exception:
credentials = {} credentials = {}
@ -910,7 +907,7 @@ class ToolManager:
masked_credentials = encrypter.mask_plugin_credentials(encrypter.decrypt(credentials)) masked_credentials = encrypter.mask_plugin_credentials(encrypter.decrypt(credentials))
try: try:
icon = json.loads(provider_obj.icon) icon = emoji_icon_adapter.validate_json(provider_obj.icon)
except Exception: except Exception:
icon = {"background": "#252525", "content": "\ud83d\ude01"} icon = {"background": "#252525", "content": "\ud83d\ude01"}
@ -973,7 +970,7 @@ class ToolManager:
if workflow_provider is None: if workflow_provider is None:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
icon = json.loads(workflow_provider.icon) icon = emoji_icon_adapter.validate_json(workflow_provider.icon)
return icon return icon
except Exception: except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"} return {"background": "#252525", "content": "\ud83d\ude01"}
@ -990,7 +987,7 @@ class ToolManager:
if api_provider is None: if api_provider is None:
raise ToolProviderNotFoundError(f"api provider {provider_id} not found") raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
icon = json.loads(api_provider.icon) icon = emoji_icon_adapter.validate_json(api_provider.icon)
return icon return icon
except Exception: except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"} return {"background": "#252525", "content": "\ud83d\ude01"}
@ -1025,31 +1022,31 @@ class ToolManager:
:param provider_id: the id of the provider :param provider_id: the id of the provider
:return: :return:
""" """
provider_type = provider_type match provider_type:
provider_id = provider_id case ToolProviderType.BUILT_IN:
if provider_type == ToolProviderType.BUILT_IN: provider = ToolManager.get_builtin_provider(provider_id, tenant_id)
provider = ToolManager.get_builtin_provider(provider_id, tenant_id) if isinstance(provider, PluginToolProviderController):
if isinstance(provider, PluginToolProviderController): try:
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}
return cls.generate_builtin_tool_icon_url(provider_id)
case ToolProviderType.API:
return cls.generate_api_tool_icon_url(tenant_id, provider_id)
case ToolProviderType.WORKFLOW:
return cls.generate_workflow_tool_icon_url(tenant_id, provider_id)
case ToolProviderType.PLUGIN:
provider = ToolManager.get_plugin_provider(provider_id, tenant_id)
try: try:
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon) return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
except Exception: except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"} return {"background": "#252525", "content": "\ud83d\ude01"}
return cls.generate_builtin_tool_icon_url(provider_id) case ToolProviderType.MCP:
elif provider_type == ToolProviderType.API: return cls.generate_mcp_tool_icon_url(tenant_id, provider_id)
return cls.generate_api_tool_icon_url(tenant_id, provider_id) case ToolProviderType.APP | ToolProviderType.DATASET_RETRIEVAL:
elif provider_type == ToolProviderType.WORKFLOW: raise ValueError(f"provider type {provider_type} not found")
return cls.generate_workflow_tool_icon_url(tenant_id, provider_id) case _:
elif provider_type == ToolProviderType.PLUGIN: raise ValueError(f"provider type {provider_type} not found")
provider = ToolManager.get_plugin_provider(provider_id, tenant_id)
try:
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}
raise ValueError(f"plugin provider {provider_id} not found")
elif provider_type == ToolProviderType.MCP:
return cls.generate_mcp_tool_icon_url(tenant_id, provider_id)
else:
raise ValueError(f"provider type {provider_type} not found")
@classmethod @classmethod
def _convert_tool_parameters_type( def _convert_tool_parameters_type(

View File

@ -8,7 +8,7 @@ from sqlalchemy import select
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities import RetrievalSourceMetadata
from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import Document as RagDocument from core.rag.models.document import Document as RagDocument
from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.rerank.rerank_model import RerankModelRunner

View File

@ -6,8 +6,7 @@ from sqlalchemy import select
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities import DocumentContext, RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.models.document import Document as RetrievalDocument from core.rag.models.document import Document as RetrievalDocument
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.dataset_retrieval import DatasetRetrieval

View File

@ -305,14 +305,15 @@ class WorkflowTool(Tool):
"transfer_method": file.transfer_method.value, "transfer_method": file.transfer_method.value,
"type": file.type.value, "type": file.type.value,
} }
if file.transfer_method == FileTransferMethod.TOOL_FILE: match file.transfer_method:
file_dict["tool_file_id"] = resolve_file_record_id(file.reference) case FileTransferMethod.TOOL_FILE:
elif file.transfer_method == FileTransferMethod.LOCAL_FILE: file_dict["tool_file_id"] = resolve_file_record_id(file.reference)
file_dict["upload_file_id"] = resolve_file_record_id(file.reference) case FileTransferMethod.LOCAL_FILE:
elif file.transfer_method == FileTransferMethod.DATASOURCE_FILE: file_dict["upload_file_id"] = resolve_file_record_id(file.reference)
file_dict["datasource_file_id"] = resolve_file_record_id(file.reference) case FileTransferMethod.DATASOURCE_FILE:
elif file.transfer_method == FileTransferMethod.REMOTE_URL: file_dict["datasource_file_id"] = resolve_file_record_id(file.reference)
file_dict["url"] = file.generate_url() case FileTransferMethod.REMOTE_URL:
file_dict["url"] = file.generate_url()
files.append(file_dict) files.append(file_dict)
except Exception: except Exception:
@ -357,8 +358,11 @@ class WorkflowTool(Tool):
def _update_file_mapping(self, file_dict: dict): def _update_file_mapping(self, file_dict: dict):
file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id")) file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id"))
transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method")) transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
if transfer_method == FileTransferMethod.TOOL_FILE: match transfer_method:
file_dict["tool_file_id"] = file_id case FileTransferMethod.TOOL_FILE:
elif transfer_method == FileTransferMethod.LOCAL_FILE: file_dict["tool_file_id"] = file_id
file_dict["upload_file_id"] = file_id case FileTransferMethod.LOCAL_FILE:
file_dict["upload_file_id"] = file_id
case FileTransferMethod.REMOTE_URL | FileTransferMethod.DATASOURCE_FILE:
pass
return file_dict return file_dict

View File

@ -6,6 +6,7 @@ from typing import Any, Union
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
from core.entities.provider_entities import ProviderConfig from core.entities.provider_entities import ProviderConfig
from core.plugin.entities import OAuthSchema
from core.plugin.entities.parameters import ( from core.plugin.entities.parameters import (
PluginParameterAutoGenerate, PluginParameterAutoGenerate,
PluginParameterOption, PluginParameterOption,
@ -108,13 +109,6 @@ class EventEntity(BaseModel):
return v or [] return v or []
class OAuthSchema(BaseModel):
client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client")
credentials_schema: list[ProviderConfig] = Field(
default_factory=list, description="The schema of the OAuth credentials"
)
class SubscriptionConstructor(BaseModel): class SubscriptionConstructor(BaseModel):
""" """
The subscription constructor of the trigger provider The subscription constructor of the trigger provider

View File

@ -1,9 +1,10 @@
from typing import Literal, Union from typing import Union
from graphon.entities.base_node_data import BaseNodeData from graphon.entities.base_node_data import BaseNodeData
from graphon.enums import NodeType from graphon.enums import NodeType
from pydantic import BaseModel from pydantic import BaseModel
from core.rag.entities.retrieval_settings import WeightedScoreConfig
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
@ -18,50 +19,6 @@ class RerankingModelConfig(BaseModel):
reranking_model_name: str reranking_model_name: str
class VectorSetting(BaseModel):
"""
Vector Setting.
"""
vector_weight: float
embedding_provider_name: str
embedding_model_name: str
class KeywordSetting(BaseModel):
"""
Keyword Setting.
"""
keyword_weight: float
class WeightedScoreConfig(BaseModel):
"""
Weighted score Config.
"""
vector_setting: VectorSetting
keyword_setting: KeywordSetting
class EmbeddingSetting(BaseModel):
"""
Embedding Setting.
"""
embedding_provider_name: str
embedding_model_name: str
class EconomySetting(BaseModel):
"""
Economy Setting.
"""
keyword_number: int
class RetrievalSetting(BaseModel): class RetrievalSetting(BaseModel):
""" """
Retrieval Setting. Retrieval Setting.
@ -77,16 +34,6 @@ class RetrievalSetting(BaseModel):
weights: WeightedScoreConfig | None = None weights: WeightedScoreConfig | None = None
class IndexMethod(BaseModel):
"""
Knowledge Index Setting.
"""
indexing_technique: Literal["high_quality", "economy"]
embedding_setting: EmbeddingSetting
economy_setting: EconomySetting
class FileInfo(BaseModel): class FileInfo(BaseModel):
""" """
File Info. File Info.

View File

@ -1,9 +1,19 @@
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, Protocol from typing import Any, Protocol, TypedDict
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
class IndexingResultDict(TypedDict):
dataset_id: str
dataset_name: str
batch: Any
document_id: str
document_name: str
created_at: float
display_status: str
class PreviewItem(BaseModel): class PreviewItem(BaseModel):
content: str | None = Field(default=None) content: str | None = Field(default=None)
child_chunks: list[str] | None = Field(default=None) child_chunks: list[str] | None = Field(default=None)
@ -34,7 +44,7 @@ class IndexProcessorProtocol(Protocol):
chunks: Mapping[str, Any], chunks: Mapping[str, Any],
batch: Any, batch: Any,
summary_index_setting: dict | None = None, summary_index_setting: dict | None = None,
) -> dict[str, Any]: ... ) -> IndexingResultDict: ...
def get_preview_output( def get_preview_output(
self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None

View File

@ -1,4 +1,3 @@
from collections.abc import Sequence
from typing import Literal from typing import Literal
from graphon.entities.base_node_data import BaseNodeData from graphon.entities.base_node_data import BaseNodeData
@ -6,6 +5,10 @@ from graphon.enums import BuiltinNodeTypes, NodeType
from graphon.nodes.llm.entities import ModelConfig, VisionConfig from graphon.nodes.llm.entities import ModelConfig, VisionConfig
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from core.rag.entities import Condition, MetadataFilteringCondition, WeightedScoreConfig
__all__ = ["Condition"]
class RerankingModelConfig(BaseModel): class RerankingModelConfig(BaseModel):
""" """
@ -16,33 +19,6 @@ class RerankingModelConfig(BaseModel):
model: str model: str
class VectorSetting(BaseModel):
"""
Vector Setting.
"""
vector_weight: float
embedding_provider_name: str
embedding_model_name: str
class KeywordSetting(BaseModel):
"""
Keyword Setting.
"""
keyword_weight: float
class WeightedScoreConfig(BaseModel):
"""
Weighted score Config.
"""
vector_setting: VectorSetting
keyword_setting: KeywordSetting
class MultipleRetrievalConfig(BaseModel): class MultipleRetrievalConfig(BaseModel):
""" """
Multiple Retrieval Config. Multiple Retrieval Config.
@ -64,50 +40,6 @@ class SingleRetrievalConfig(BaseModel):
model: ModelConfig model: ModelConfig
SupportedComparisonOperator = Literal[
# for string or array
"contains",
"not contains",
"start with",
"end with",
"is",
"is not",
"empty",
"not empty",
"in",
"not in",
# for number
"=",
"",
">",
"<",
"",
"",
# for time
"before",
"after",
]
class Condition(BaseModel):
"""
Condition detail
"""
name: str
comparison_operator: SupportedComparisonOperator
value: str | Sequence[str] | None | int | float = None
class MetadataFilteringCondition(BaseModel):
"""
Metadata Filtering Condition.
"""
logical_operator: Literal["and", "or"] | None = "and"
conditions: list[Condition] | None = Field(default=None, deprecated=True)
class KnowledgeRetrievalNodeData(BaseNodeData): class KnowledgeRetrievalNodeData(BaseNodeData):
""" """
Knowledge retrieval Node Data. Knowledge retrieval Node Data.

View File

@ -1,7 +1,7 @@
import logging import logging
import time import time
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import Any from typing import Any, TypedDict
from graphon.entities import GraphInitParams from graphon.entities import GraphInitParams
from graphon.entities.graph_config import NodeConfigDictAdapter from graphon.entities.graph_config import NodeConfigDictAdapter
@ -107,6 +107,26 @@ class _WorkflowChildEngineBuilder:
return child_engine return child_engine
class _NodeConfigDict(TypedDict):
    """Dictionary shape of one entry in the ``nodes`` list of ``SingleNodeGraphDict``."""

    # Node identifier; edges refer to nodes by this value.
    id: str
    # Node width and height; the single-node graph builder documents these as UI layout values.
    width: int
    height: int
    # Node type tag (the single-node graph builder in this file sets it to "custom").
    type: str
    # Free-form node data payload.
    data: dict[str, Any]
class _EdgeConfigDict(TypedDict):
    """Dictionary shape of one entry in the ``edges`` list of ``SingleNodeGraphDict``."""

    # Id of the node the edge starts from.
    source: str
    # Id of the node the edge points to.
    target: str
    # Handle names on the source and target nodes. The camelCase keys are part
    # of the dictionary format — presumably matching the stored graph JSON; do not rename.
    sourceHandle: str
    targetHandle: str
class SingleNodeGraphDict(TypedDict):
    """Graph dictionary returned by the helper that builds a minimal graph for running a single node."""

    # Node entries of the graph.
    nodes: list[_NodeConfigDict]
    # Edge entries connecting the nodes.
    edges: list[_EdgeConfigDict]
class WorkflowEntry: class WorkflowEntry:
def __init__( def __init__(
self, self,
@ -318,7 +338,7 @@ class WorkflowEntry:
node_data: dict[str, Any], node_data: dict[str, Any],
node_width: int = 114, node_width: int = 114,
node_height: int = 514, node_height: int = 514,
) -> dict[str, Any]: ) -> SingleNodeGraphDict:
""" """
Create a minimal graph structure for testing a single node in isolation. Create a minimal graph structure for testing a single node in isolation.
@ -328,14 +348,14 @@ class WorkflowEntry:
:param node_height: height for UI layout (default: 100) :param node_height: height for UI layout (default: 100)
:return: graph dictionary with start node and target node :return: graph dictionary with start node and target node
""" """
node_config = { node_config: _NodeConfigDict = {
"id": node_id, "id": node_id,
"width": node_width, "width": node_width,
"height": node_height, "height": node_height,
"type": "custom", "type": "custom",
"data": node_data, "data": node_data,
} }
start_node_config = { start_node_config: _NodeConfigDict = {
"id": "start", "id": "start",
"width": node_width, "width": node_width,
"height": node_height, "height": node_height,
@ -346,9 +366,9 @@ class WorkflowEntry:
"desc": "Start", "desc": "Start",
}, },
} }
return { return SingleNodeGraphDict(
"nodes": [start_node_config, node_config], nodes=[start_node_config, node_config],
"edges": [ edges=[
{ {
"source": "start", "source": "start",
"target": node_id, "target": node_id,
@ -356,7 +376,7 @@ class WorkflowEntry:
"targetHandle": "target", "targetHandle": "target",
} }
], ],
} )
@classmethod @classmethod
def run_free_node( def run_free_node(

View File

@ -2,7 +2,7 @@ import logging
from typing import cast from typing import cast
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker
from core.workflow.nodes.trigger_schedule.entities import SchedulePlanUpdate from core.workflow.nodes.trigger_schedule.entities import SchedulePlanUpdate
from events.app_event import app_published_workflow_was_updated from events.app_event import app_published_workflow_was_updated
@ -45,7 +45,7 @@ def sync_schedule_from_workflow(tenant_id: str, app_id: str, workflow: Workflow)
Returns: Returns:
Updated or created WorkflowSchedulePlan, or None if no schedule node Updated or created WorkflowSchedulePlan, or None if no schedule node
""" """
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
schedule_config = ScheduleService.extract_schedule_config(workflow) schedule_config = ScheduleService.extract_schedule_config(workflow)
existing_plan = session.scalar( existing_plan = session.scalar(
@ -59,7 +59,6 @@ def sync_schedule_from_workflow(tenant_id: str, app_id: str, workflow: Workflow)
if existing_plan: if existing_plan:
logger.info("No schedule node in workflow for app %s, removing schedule plan", app_id) logger.info("No schedule node in workflow for app %s, removing schedule plan", app_id)
ScheduleService.delete_schedule(session=session, schedule_id=existing_plan.id) ScheduleService.delete_schedule(session=session, schedule_id=existing_plan.id)
session.commit()
return None return None
if existing_plan: if existing_plan:
@ -73,7 +72,6 @@ def sync_schedule_from_workflow(tenant_id: str, app_id: str, workflow: Workflow)
schedule_id=existing_plan.id, schedule_id=existing_plan.id,
updates=updates, updates=updates,
) )
session.commit()
return updated_plan return updated_plan
else: else:
new_plan = ScheduleService.create_schedule( new_plan = ScheduleService.create_schedule(
@ -82,5 +80,4 @@ def sync_schedule_from_workflow(tenant_id: str, app_id: str, workflow: Workflow)
app_id=app_id, app_id=app_id,
config=schedule_config, config=schedule_config,
) )
session.commit()
return new_plan return new_plan

View File

@ -1,7 +1,7 @@
from typing import cast from typing import cast
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker
from core.trigger.constants import TRIGGER_NODE_TYPES from core.trigger.constants import TRIGGER_NODE_TYPES
from events.app_event import app_published_workflow_was_updated from events.app_event import app_published_workflow_was_updated
@ -31,7 +31,7 @@ def handle(sender, **kwargs):
# Extract trigger info from workflow # Extract trigger info from workflow
trigger_infos = get_trigger_infos_from_workflow(published_workflow) trigger_infos = get_trigger_infos_from_workflow(published_workflow)
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
# Get existing app triggers # Get existing app triggers
existing_triggers = ( existing_triggers = (
session.execute( session.execute(
@ -79,8 +79,6 @@ def handle(sender, **kwargs):
existing_trigger.title = new_title existing_trigger.title = new_title
session.add(existing_trigger) session.add(existing_trigger)
session.commit()
def get_trigger_infos_from_workflow(published_workflow: Workflow) -> list[dict]: def get_trigger_infos_from_workflow(published_workflow: Workflow) -> list[dict]:
""" """

View File

@ -354,11 +354,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
) -> WorkflowRun | None: ) -> WorkflowRun | None:
"""Fallback to PostgreSQL query for records not in LogStore (with tenant isolation).""" """Fallback to PostgreSQL query for records not in LogStore (with tenant isolation)."""
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker
from extensions.ext_database import db from extensions.ext_database import db
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
stmt = select(WorkflowRun).where( stmt = select(WorkflowRun).where(
WorkflowRun.id == run_id, WorkflowRun.tenant_id == tenant_id, WorkflowRun.app_id == app_id WorkflowRun.id == run_id, WorkflowRun.tenant_id == tenant_id, WorkflowRun.app_id == app_id
) )
@ -439,11 +439,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
def _fallback_get_workflow_run_by_id(self, run_id: str) -> WorkflowRun | None: def _fallback_get_workflow_run_by_id(self, run_id: str) -> WorkflowRun | None:
"""Fallback to PostgreSQL query for records not in LogStore.""" """Fallback to PostgreSQL query for records not in LogStore."""
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker
from extensions.ext_database import db from extensions.ext_database import db
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
stmt = select(WorkflowRun).where(WorkflowRun.id == run_id) stmt = select(WorkflowRun).where(WorkflowRun.id == run_id)
return session.scalar(stmt) return session.scalar(stmt)

Some files were not shown because too many files have changed in this diff Show More