mirror of
https://github.com/langgenius/dify.git
synced 2026-05-10 14:14:17 +08:00
Merge branch 'main' into yanli/fix-iter-log
This commit is contained in:
commit
f659eb48c6
13
.gemini/config.yaml
Normal file
13
.gemini/config.yaml
Normal file
@ -0,0 +1,13 @@
|
||||
have_fun: false
|
||||
memory_config:
|
||||
disabled: false
|
||||
code_review:
|
||||
disable: true
|
||||
comment_severity_threshold: MEDIUM
|
||||
max_review_comments: -1
|
||||
pull_request_opened:
|
||||
help: false
|
||||
summary: false
|
||||
code_review: false
|
||||
include_drafts: false
|
||||
ignore_patterns: []
|
||||
9
.github/actions/setup-web/action.yml
vendored
9
.github/actions/setup-web/action.yml
vendored
@ -4,10 +4,9 @@ runs:
|
||||
using: composite
|
||||
steps:
|
||||
- name: Setup Vite+
|
||||
uses: voidzero-dev/setup-vp@4a524139920f87f9f7080d3b8545acac019e1852 # v1.0.0
|
||||
uses: voidzero-dev/setup-vp@20553a7a7429c429a74894104a2835d7fed28a72 # v1.3.0
|
||||
with:
|
||||
node-version-file: web/.nvmrc
|
||||
working-directory: web
|
||||
node-version-file: .nvmrc
|
||||
cache: true
|
||||
cache-dependency-path: web/pnpm-lock.yaml
|
||||
run-install: |
|
||||
cwd: ./web
|
||||
run-install: true
|
||||
|
||||
5
.github/workflows/autofix.yml
vendored
5
.github/workflows/autofix.yml
vendored
@ -94,11 +94,6 @@ jobs:
|
||||
find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \;
|
||||
find . -name "*.py.bak" -type f -delete
|
||||
|
||||
# mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
|
||||
- name: mdformat
|
||||
run: |
|
||||
uvx --python 3.13 mdformat . --exclude ".agents/skills/**"
|
||||
|
||||
- name: Setup web environment
|
||||
if: steps.web-changes.outputs.any_changed == 'true'
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
29
.github/workflows/style.yml
vendored
29
.github/workflows/style.yml
vendored
@ -84,20 +84,20 @@ jobs:
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Restore ESLint cache
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
id: eslint-cache-restore
|
||||
uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
|
||||
with:
|
||||
path: web/.eslintcache
|
||||
key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
|
||||
|
||||
- name: Web style check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: |
|
||||
vp run lint:ci
|
||||
# pnpm run lint:report
|
||||
# continue-on-error: true
|
||||
|
||||
# - name: Annotate Code
|
||||
# if: steps.changed-files.outputs.any_changed == 'true' && github.event_name == 'pull_request'
|
||||
# uses: DerLev/eslint-annotations@51347b3a0abfb503fc8734d5ae31c4b151297fae
|
||||
# with:
|
||||
# eslint-report: web/eslint_report.json
|
||||
# github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: vp run lint:ci
|
||||
|
||||
- name: Web tsslint
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
@ -114,6 +114,13 @@ jobs:
|
||||
working-directory: ./web
|
||||
run: vp run knip
|
||||
|
||||
- name: Save ESLint cache
|
||||
if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true'
|
||||
uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
|
||||
with:
|
||||
path: web/.eslintcache
|
||||
key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }}
|
||||
|
||||
superlinter:
|
||||
name: SuperLinter
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
2
.github/workflows/translate-i18n-claude.yml
vendored
2
.github/workflows/translate-i18n-claude.yml
vendored
@ -120,7 +120,7 @@ jobs:
|
||||
|
||||
- name: Run Claude Code for Translation Sync
|
||||
if: steps.detect_changes.outputs.CHANGED_FILES != ''
|
||||
uses: anthropics/claude-code-action@df37d2f0760a4b5683a6e617c9325bc1a36443f6 # v1.0.75
|
||||
uses: anthropics/claude-code-action@ff9acae5886d41a99ed4ec14b7dc147d55834722 # v1.0.77
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
@ -97,3 +97,8 @@ Feel free to reach out if you encounter any issues during the setup process.
|
||||
## Getting Help
|
||||
|
||||
If you ever get stuck or get a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat.
|
||||
|
||||
## Automated Agent Contributions
|
||||
|
||||
> [!NOTE]
|
||||
> If you are an automated agent, just add 🤖🤖🤖 to the end of the PR title to opt-in.
|
||||
|
||||
@ -353,6 +353,9 @@ BAIDU_VECTOR_DB_SHARD=1
|
||||
BAIDU_VECTOR_DB_REPLICAS=3
|
||||
BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER=DEFAULT_ANALYZER
|
||||
BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE=COARSE_MODE
|
||||
BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT=500
|
||||
BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO=0.05
|
||||
BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS=300
|
||||
|
||||
# Upstash configuration
|
||||
UPSTASH_VECTOR_URL=your-server-url
|
||||
|
||||
@ -10,6 +10,7 @@ from configs import dify_config
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||
from core.rag.models.document import ChildDocument, Document
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
|
||||
@ -85,7 +86,7 @@ def migrate_annotation_vector_database():
|
||||
dataset = Dataset(
|
||||
id=app.id,
|
||||
tenant_id=app.tenant_id,
|
||||
indexing_technique="high_quality",
|
||||
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
|
||||
embedding_model_provider=dataset_collection_binding.provider_name,
|
||||
embedding_model=dataset_collection_binding.model_name,
|
||||
collection_binding_id=dataset_collection_binding.id,
|
||||
@ -177,7 +178,9 @@ def migrate_knowledge_vector_database():
|
||||
while True:
|
||||
try:
|
||||
stmt = (
|
||||
select(Dataset).where(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc())
|
||||
select(Dataset)
|
||||
.where(Dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY)
|
||||
.order_by(Dataset.created_at.desc())
|
||||
)
|
||||
|
||||
datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
|
||||
@ -269,7 +272,7 @@ def migrate_knowledge_vector_database():
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
if dataset_document.doc_form == "hierarchical_model":
|
||||
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
child_chunks = segment.get_child_chunks()
|
||||
if child_chunks:
|
||||
child_documents = []
|
||||
|
||||
@ -51,3 +51,18 @@ class BaiduVectorDBConfig(BaseSettings):
|
||||
description="Parser mode for inverted index in Baidu Vector Database (default is COARSE_MODE)",
|
||||
default="COARSE_MODE",
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT: int = Field(
|
||||
description="Auto build row count increment threshold (default is 500)",
|
||||
default=500,
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO: float = Field(
|
||||
description="Auto build row count increment ratio threshold (default is 0.05)",
|
||||
default=0.05,
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS: int = Field(
|
||||
description="Timeout in seconds for rebuilding the index in Baidu Vector Database (default is 3600 seconds)",
|
||||
default=300,
|
||||
)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import flask_restx
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from flask_restx._http import HTTPStatus
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
@ -9,6 +9,7 @@ from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import Dataset
|
||||
from models.enums import ApiTokenType
|
||||
from models.model import ApiToken, App
|
||||
from services.api_token_service import ApiTokenCache
|
||||
|
||||
@ -33,16 +34,10 @@ api_key_list_model = console_ns.model(
|
||||
|
||||
|
||||
def _get_resource(resource_id, tenant_id, resource_model):
|
||||
if resource_model == App:
|
||||
with Session(db.engine) as session:
|
||||
resource = session.execute(
|
||||
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
|
||||
).scalar_one_or_none()
|
||||
else:
|
||||
with Session(db.engine) as session:
|
||||
resource = session.execute(
|
||||
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
|
||||
).scalar_one_or_none()
|
||||
with Session(db.engine) as session:
|
||||
resource = session.execute(
|
||||
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if resource is None:
|
||||
flask_restx.abort(HTTPStatus.NOT_FOUND, message=f"{resource_model.__name__} not found.")
|
||||
@ -53,7 +48,7 @@ def _get_resource(resource_id, tenant_id, resource_model):
|
||||
class BaseApiKeyListResource(Resource):
|
||||
method_decorators = [account_initialization_required, login_required, setup_required]
|
||||
|
||||
resource_type: str | None = None
|
||||
resource_type: ApiTokenType | None = None
|
||||
resource_model: type | None = None
|
||||
resource_id_field: str | None = None
|
||||
token_prefix: str | None = None
|
||||
@ -80,10 +75,13 @@ class BaseApiKeyListResource(Resource):
|
||||
resource_id = str(resource_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
current_key_count = (
|
||||
db.session.query(ApiToken)
|
||||
.where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
|
||||
.count()
|
||||
current_key_count: int = (
|
||||
db.session.scalar(
|
||||
select(func.count(ApiToken.id)).where(
|
||||
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
|
||||
)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
if current_key_count >= self.max_keys:
|
||||
@ -94,6 +92,7 @@ class BaseApiKeyListResource(Resource):
|
||||
)
|
||||
|
||||
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
|
||||
assert self.resource_type is not None, "resource_type must be set"
|
||||
api_token = ApiToken()
|
||||
setattr(api_token, self.resource_id_field, resource_id)
|
||||
api_token.tenant_id = current_tenant_id
|
||||
@ -107,7 +106,7 @@ class BaseApiKeyListResource(Resource):
|
||||
class BaseApiKeyResource(Resource):
|
||||
method_decorators = [account_initialization_required, login_required, setup_required]
|
||||
|
||||
resource_type: str | None = None
|
||||
resource_type: ApiTokenType | None = None
|
||||
resource_model: type | None = None
|
||||
resource_id_field: str | None = None
|
||||
|
||||
@ -119,14 +118,14 @@ class BaseApiKeyResource(Resource):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
key = (
|
||||
db.session.query(ApiToken)
|
||||
key = db.session.scalar(
|
||||
select(ApiToken)
|
||||
.where(
|
||||
getattr(ApiToken, self.resource_id_field) == resource_id,
|
||||
ApiToken.type == self.resource_type,
|
||||
ApiToken.id == api_key_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if key is None:
|
||||
@ -137,7 +136,7 @@ class BaseApiKeyResource(Resource):
|
||||
assert key is not None # nosec - for type checker only
|
||||
ApiTokenCache.delete(key.token, key.type)
|
||||
|
||||
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
|
||||
db.session.execute(delete(ApiToken).where(ApiToken.id == api_key_id))
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
@ -162,7 +161,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
|
||||
"""Create a new API key for an app"""
|
||||
return super().post(resource_id)
|
||||
|
||||
resource_type = "app"
|
||||
resource_type = ApiTokenType.APP
|
||||
resource_model = App
|
||||
resource_id_field = "app_id"
|
||||
token_prefix = "app-"
|
||||
@ -178,7 +177,7 @@ class AppApiKeyResource(BaseApiKeyResource):
|
||||
"""Delete an API key for an app"""
|
||||
return super().delete(resource_id, api_key_id)
|
||||
|
||||
resource_type = "app"
|
||||
resource_type = ApiTokenType.APP
|
||||
resource_model = App
|
||||
resource_id_field = "app_id"
|
||||
|
||||
@ -202,7 +201,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
|
||||
"""Create a new API key for a dataset"""
|
||||
return super().post(resource_id)
|
||||
|
||||
resource_type = "dataset"
|
||||
resource_type = ApiTokenType.DATASET
|
||||
resource_model = Dataset
|
||||
resource_id_field = "dataset_id"
|
||||
token_prefix = "ds-"
|
||||
@ -218,6 +217,6 @@ class DatasetApiKeyResource(BaseApiKeyResource):
|
||||
"""Delete an API key for a dataset"""
|
||||
return super().delete(resource_id, api_key_id)
|
||||
|
||||
resource_type = "dataset"
|
||||
resource_type = ApiTokenType.DATASET
|
||||
resource_model = Dataset
|
||||
resource_id_field = "dataset_id"
|
||||
|
||||
@ -95,7 +95,7 @@ class CreateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
|
||||
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
|
||||
icon_type: str | None = Field(default=None, description="Icon type")
|
||||
icon_type: IconType | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
@ -103,7 +103,7 @@ class CreateAppPayload(BaseModel):
|
||||
class UpdateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
|
||||
icon_type: str | None = Field(default=None, description="Icon type")
|
||||
icon_type: IconType | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
|
||||
@ -113,7 +113,7 @@ class UpdateAppPayload(BaseModel):
|
||||
class CopyAppPayload(BaseModel):
|
||||
name: str | None = Field(default=None, description="Name for the copied app")
|
||||
description: str | None = Field(default=None, description="Description for the copied app", max_length=400)
|
||||
icon_type: str | None = Field(default=None, description="Icon type")
|
||||
icon_type: IconType | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
@ -594,7 +594,7 @@ class AppApi(Resource):
|
||||
args_dict: AppService.ArgsDict = {
|
||||
"name": args.name,
|
||||
"description": args.description or "",
|
||||
"icon_type": args.icon_type or "",
|
||||
"icon_type": args.icon_type,
|
||||
"icon": args.icon or "",
|
||||
"icon_background": args.icon_background or "",
|
||||
"use_icon_as_answer_icon": args.use_icon_as_answer_icon or False,
|
||||
|
||||
@ -5,7 +5,7 @@ from flask import abort, request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import selectinload
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
@ -376,8 +376,12 @@ class CompletionConversationApi(Resource):
|
||||
|
||||
# FIXME, the type ignore in this file
|
||||
if args.annotation_status == "annotated":
|
||||
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
query = (
|
||||
query.options(selectinload(Conversation.message_annotations)) # type: ignore[arg-type]
|
||||
.join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
)
|
||||
.distinct()
|
||||
)
|
||||
elif args.annotation_status == "not_annotated":
|
||||
query = (
|
||||
@ -454,9 +458,7 @@ class ChatConversationApi(Resource):
|
||||
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
subquery = (
|
||||
db.session.query(
|
||||
Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id")
|
||||
)
|
||||
sa.select(Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id"))
|
||||
.outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id)
|
||||
.subquery()
|
||||
)
|
||||
@ -511,8 +513,12 @@ class ChatConversationApi(Resource):
|
||||
|
||||
match args.annotation_status:
|
||||
case "annotated":
|
||||
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
query = (
|
||||
query.options(selectinload(Conversation.message_annotations)) # type: ignore[arg-type]
|
||||
.join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
)
|
||||
.distinct()
|
||||
)
|
||||
case "not_annotated":
|
||||
query = (
|
||||
@ -587,10 +593,8 @@ class ChatConversationDetailApi(Resource):
|
||||
|
||||
def _get_conversation(app_model, conversation_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
|
||||
.first()
|
||||
conversation = db.session.scalar(
|
||||
sa.select(Conversation).where(Conversation.id == conversation_id, Conversation.app_id == app_model.id).limit(1)
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
|
||||
@ -168,7 +168,7 @@ class InstructionGenerateApi(Resource):
|
||||
try:
|
||||
# Generate from nothing for a workflow node
|
||||
if (args.current in (code_template, "")) and args.node_id != "":
|
||||
app = db.session.query(App).where(App.id == args.flow_id).first()
|
||||
app = db.session.get(App, args.flow_id)
|
||||
if not app:
|
||||
return {"error": f"app {args.flow_id} not found"}, 400
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app)
|
||||
|
||||
@ -2,6 +2,7 @@ import json
|
||||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
@ -47,7 +48,7 @@ class AppMCPServerController(Resource):
|
||||
@get_app_model
|
||||
@marshal_with(app_server_model)
|
||||
def get(self, app_model):
|
||||
server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first()
|
||||
server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1))
|
||||
return server
|
||||
|
||||
@console_ns.doc("create_app_mcp_server")
|
||||
@ -98,7 +99,7 @@ class AppMCPServerController(Resource):
|
||||
@edit_permission_required
|
||||
def put(self, app_model):
|
||||
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
|
||||
server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first()
|
||||
server = db.session.get(AppMCPServer, payload.id)
|
||||
if not server:
|
||||
raise NotFound()
|
||||
|
||||
@ -135,11 +136,10 @@ class AppMCPServerRefreshController(Resource):
|
||||
@edit_permission_required
|
||||
def get(self, server_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
server = (
|
||||
db.session.query(AppMCPServer)
|
||||
.where(AppMCPServer.id == server_id)
|
||||
.where(AppMCPServer.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
server = db.session.scalar(
|
||||
select(AppMCPServer)
|
||||
.where(AppMCPServer.id == server_id, AppMCPServer.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not server:
|
||||
raise NotFound()
|
||||
|
||||
@ -4,7 +4,7 @@ from typing import Literal
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import exists, select
|
||||
from sqlalchemy import exists, func, select
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
@ -30,6 +30,7 @@ from fields.raws import FilesContainedField
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
@ -243,27 +244,25 @@ class ChatMessageListApi(Resource):
|
||||
def get(self, app_model):
|
||||
args = ChatMessagesQuery.model_validate(request.args.to_dict())
|
||||
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
conversation = db.session.scalar(
|
||||
select(Conversation)
|
||||
.where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
if args.first_id:
|
||||
first_message = (
|
||||
db.session.query(Message)
|
||||
.where(Message.conversation_id == conversation.id, Message.id == args.first_id)
|
||||
.first()
|
||||
first_message = db.session.scalar(
|
||||
select(Message).where(Message.conversation_id == conversation.id, Message.id == args.first_id).limit(1)
|
||||
)
|
||||
|
||||
if not first_message:
|
||||
raise NotFound("First message not found")
|
||||
|
||||
history_messages = (
|
||||
db.session.query(Message)
|
||||
history_messages = db.session.scalars(
|
||||
select(Message)
|
||||
.where(
|
||||
Message.conversation_id == conversation.id,
|
||||
Message.created_at < first_message.created_at,
|
||||
@ -271,16 +270,14 @@ class ChatMessageListApi(Resource):
|
||||
)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(args.limit)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
else:
|
||||
history_messages = (
|
||||
db.session.query(Message)
|
||||
history_messages = db.session.scalars(
|
||||
select(Message)
|
||||
.where(Message.conversation_id == conversation.id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(args.limit)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
# Initialize has_more based on whether we have a full page
|
||||
if len(history_messages) == args.limit:
|
||||
@ -325,7 +322,9 @@ class MessageFeedbackApi(Resource):
|
||||
|
||||
message_id = str(args.message_id)
|
||||
|
||||
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||
message = db.session.scalar(
|
||||
select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1)
|
||||
)
|
||||
|
||||
if not message:
|
||||
raise NotFound("Message Not Exists.")
|
||||
@ -335,7 +334,7 @@ class MessageFeedbackApi(Resource):
|
||||
if not args.rating and feedback:
|
||||
db.session.delete(feedback)
|
||||
elif args.rating and feedback:
|
||||
feedback.rating = args.rating
|
||||
feedback.rating = FeedbackRating(args.rating)
|
||||
feedback.content = args.content
|
||||
elif not args.rating and not feedback:
|
||||
raise ValueError("rating cannot be None when feedback not exists")
|
||||
@ -347,9 +346,9 @@ class MessageFeedbackApi(Resource):
|
||||
app_id=app_model.id,
|
||||
conversation_id=message.conversation_id,
|
||||
message_id=message.id,
|
||||
rating=rating_value,
|
||||
rating=FeedbackRating(rating_value),
|
||||
content=args.content,
|
||||
from_source="admin",
|
||||
from_source=FeedbackFromSource.ADMIN,
|
||||
from_account_id=current_user.id,
|
||||
)
|
||||
db.session.add(feedback)
|
||||
@ -374,7 +373,9 @@ class MessageAnnotationCountApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count()
|
||||
count = db.session.scalar(
|
||||
select(func.count(MessageAnnotation.id)).where(MessageAnnotation.app_id == app_model.id)
|
||||
)
|
||||
|
||||
return {"count": count}
|
||||
|
||||
@ -478,7 +479,9 @@ class MessageApi(Resource):
|
||||
def get(self, app_model, message_id: str):
|
||||
message_id = str(message_id)
|
||||
|
||||
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||
message = db.session.scalar(
|
||||
select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1)
|
||||
)
|
||||
|
||||
if not message:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
@ -69,9 +69,7 @@ class ModelConfigResource(Resource):
|
||||
|
||||
if app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
|
||||
# get original app model config
|
||||
original_app_model_config = (
|
||||
db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first()
|
||||
)
|
||||
original_app_model_config = db.session.get(AppModelConfig, app_model.app_model_config_id)
|
||||
if original_app_model_config is None:
|
||||
raise ValueError("Original app model config not found")
|
||||
agent_mode = original_app_model_config.agent_mode_dict
|
||||
|
||||
@ -2,6 +2,7 @@ from typing import Literal
|
||||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from constants.languages import supported_language
|
||||
@ -75,7 +76,7 @@ class AppSite(Resource):
|
||||
def post(self, app_model):
|
||||
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
|
||||
current_user, _ = current_account_with_tenant()
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
if not site:
|
||||
raise NotFound
|
||||
|
||||
@ -124,7 +125,7 @@ class AppSiteAccessTokenReset(Resource):
|
||||
@marshal_with(app_site_model)
|
||||
def post(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
|
||||
if not site:
|
||||
raise NotFound
|
||||
|
||||
@ -7,7 +7,7 @@ from flask import abort, request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import console_ns
|
||||
@ -46,13 +46,14 @@ from models import App
|
||||
from models.model import AppMode
|
||||
from models.workflow import Workflow
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
LISTENING_RETRY_IN = 2000
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published"
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
@ -284,7 +285,9 @@ class DraftWorkflowApi(Resource):
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
try:
|
||||
environment_variables_list = args.get("environment_variables") or []
|
||||
environment_variables_list = Workflow.normalize_environment_variable_mappings(
|
||||
args.get("environment_variables") or [],
|
||||
)
|
||||
environment_variables = [
|
||||
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
|
||||
]
|
||||
@ -994,6 +997,43 @@ class PublishedAllWorkflowApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/<string:workflow_id>/restore")
|
||||
class DraftWorkflowRestoreApi(Resource):
|
||||
@console_ns.doc("restore_workflow_to_draft")
|
||||
@console_ns.doc(description="Restore a published workflow version into the draft workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "workflow_id": "Published workflow ID"})
|
||||
@console_ns.response(200, "Workflow restored successfully")
|
||||
@console_ns.response(400, "Source workflow must be published")
|
||||
@console_ns.response(404, "Workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, workflow_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
try:
|
||||
workflow = workflow_service.restore_published_workflow_to_draft(
|
||||
app_model=app_model,
|
||||
workflow_id=workflow_id,
|
||||
account=current_user,
|
||||
)
|
||||
except IsDraftWorkflowError as exc:
|
||||
raise BadRequest(RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE) from exc
|
||||
except WorkflowNotFoundError as exc:
|
||||
raise NotFound(str(exc)) from exc
|
||||
except ValueError as exc:
|
||||
raise BadRequest(str(exc)) from exc
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"hash": workflow.unique_hash,
|
||||
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/<string:workflow_id>")
|
||||
class WorkflowByIdApi(Resource):
|
||||
@console_ns.doc("update_workflow_by_id")
|
||||
|
||||
@ -2,6 +2,8 @@ from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar, Union
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from controllers.console.app.error import AppNotFoundError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant
|
||||
@ -15,16 +17,14 @@ R1 = TypeVar("R1")
|
||||
|
||||
def _load_app_model(app_id: str) -> App | None:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app_model = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
app_model = db.session.scalar(
|
||||
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
|
||||
)
|
||||
return app_model
|
||||
|
||||
|
||||
def _load_app_model_with_trial(app_id: str) -> App | None:
|
||||
app_model = db.session.query(App).where(App.id == app_id, App.status == "normal").first()
|
||||
app_model = db.session.scalar(select(App).where(App.id == app_id, App.status == "normal").limit(1))
|
||||
return app_model
|
||||
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
@ -73,7 +73,7 @@ class EmailRegisterSendEmailApi(Resource):
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
|
||||
token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
|
||||
return {"result": "success", "data": token}
|
||||
@ -145,7 +145,7 @@ class EmailRegisterResetApi(Resource):
|
||||
email = register_data.get("email", "")
|
||||
normalized_email = email.lower()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
||||
|
||||
if account:
|
||||
|
||||
@ -4,7 +4,7 @@ import secrets
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
@ -102,7 +102,7 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
else:
|
||||
language = "en-US"
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
|
||||
|
||||
token = AccountService.send_reset_password_email(
|
||||
@ -201,7 +201,7 @@ class ForgotPasswordResetApi(Resource):
|
||||
password_hashed = hash_password(args.new_password, salt)
|
||||
|
||||
email = reset_data.get("email", "")
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
||||
|
||||
if account:
|
||||
@ -215,7 +215,6 @@ class ForgotPasswordResetApi(Resource):
|
||||
# Update existing account credentials
|
||||
account.password = base64.b64encode(password_hashed).decode()
|
||||
account.password_salt = base64.b64encode(salt).decode()
|
||||
session.commit()
|
||||
|
||||
# Create workspace if needed
|
||||
if (
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
import logging
|
||||
import urllib.parse
|
||||
|
||||
import httpx
|
||||
from flask import current_app, redirect, request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
@ -112,6 +113,9 @@ class OAuthCallback(Resource):
|
||||
error_text = e.response.text
|
||||
logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
|
||||
return {"error": "OAuth process failed"}, 400
|
||||
except ValueError as e:
|
||||
logger.warning("OAuth error with %s", provider, exc_info=True)
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={urllib.parse.quote(str(e))}")
|
||||
|
||||
if invite_token and RegisterService.is_valid_invite_token(invite_token):
|
||||
invitation = RegisterService.get_invitation_by_token(token=invite_token)
|
||||
@ -176,7 +180,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
|
||||
account: Account | None = Account.get_by_openid(provider, user_info.id)
|
||||
|
||||
if not account:
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)
|
||||
|
||||
return account
|
||||
|
||||
@ -3,7 +3,7 @@ from typing import Any, cast
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import func, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
@ -29,6 +29,7 @@ from core.provider_manager import ProviderManager
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_database import db
|
||||
@ -54,7 +55,7 @@ from fields.document_fields import document_status_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DatasetPermission, DatasetPermissionEnum
|
||||
from models.enums import SegmentStatus
|
||||
from models.enums import ApiTokenType, SegmentStatus
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.api_token_service import ApiTokenCache
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
@ -355,7 +356,7 @@ class DatasetListApi(Resource):
|
||||
|
||||
for item in data:
|
||||
# convert embedding_model_provider to plugin standard format
|
||||
if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
|
||||
if item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY and item["embedding_model_provider"]:
|
||||
item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
|
||||
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
|
||||
if item_model in model_names:
|
||||
@ -436,7 +437,7 @@ class DatasetApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
if dataset.embedding_model_provider:
|
||||
provider_id = ModelProviderID(dataset.embedding_model_provider)
|
||||
data["embedding_model_provider"] = str(provider_id)
|
||||
@ -454,7 +455,7 @@ class DatasetApi(Resource):
|
||||
for embedding_model in embedding_models:
|
||||
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
|
||||
|
||||
if data["indexing_technique"] == "high_quality":
|
||||
if data["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY:
|
||||
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
|
||||
if item_model in model_names:
|
||||
data["embedding_available"] = True
|
||||
@ -485,7 +486,7 @@ class DatasetApi(Resource):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
# check embedding model setting
|
||||
if (
|
||||
payload.indexing_technique == "high_quality"
|
||||
payload.indexing_technique == IndexTechniqueType.HIGH_QUALITY
|
||||
and payload.embedding_model_provider is not None
|
||||
and payload.embedding_model is not None
|
||||
):
|
||||
@ -738,20 +739,23 @@ class DatasetIndexingStatusApi(Resource):
|
||||
documents_status = []
|
||||
for document in documents:
|
||||
completed_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
total_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
# Create a dictionary with document attributes and additional fields
|
||||
document_dict = {
|
||||
@ -777,7 +781,7 @@ class DatasetIndexingStatusApi(Resource):
|
||||
class DatasetApiKeyApi(Resource):
|
||||
max_keys = 10
|
||||
token_prefix = "dataset-"
|
||||
resource_type = "dataset"
|
||||
resource_type = ApiTokenType.DATASET
|
||||
|
||||
@console_ns.doc("get_dataset_api_keys")
|
||||
@console_ns.doc(description="Get dataset API keys")
|
||||
@ -802,9 +806,12 @@ class DatasetApiKeyApi(Resource):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
current_key_count = (
|
||||
db.session.query(ApiToken)
|
||||
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
|
||||
.count()
|
||||
db.session.scalar(
|
||||
select(func.count(ApiToken.id)).where(
|
||||
ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id
|
||||
)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
if current_key_count >= self.max_keys:
|
||||
@ -826,7 +833,7 @@ class DatasetApiKeyApi(Resource):
|
||||
|
||||
@console_ns.route("/datasets/api-keys/<uuid:api_key_id>")
|
||||
class DatasetApiDeleteApi(Resource):
|
||||
resource_type = "dataset"
|
||||
resource_type = ApiTokenType.DATASET
|
||||
|
||||
@console_ns.doc("delete_dataset_api_key")
|
||||
@console_ns.doc(description="Delete dataset API key")
|
||||
@ -839,14 +846,14 @@ class DatasetApiDeleteApi(Resource):
|
||||
def delete(self, api_key_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
api_key_id = str(api_key_id)
|
||||
key = (
|
||||
db.session.query(ApiToken)
|
||||
key = db.session.scalar(
|
||||
select(ApiToken)
|
||||
.where(
|
||||
ApiToken.tenant_id == current_tenant_id,
|
||||
ApiToken.type == self.resource_type,
|
||||
ApiToken.id == api_key_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if key is None:
|
||||
@ -857,7 +864,7 @@ class DatasetApiDeleteApi(Resource):
|
||||
assert key is not None # nosec - for type checker only
|
||||
ApiTokenCache.delete(key.token, key.type)
|
||||
|
||||
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
|
||||
db.session.delete(key)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
@ -10,7 +10,7 @@ import sqlalchemy as sa
|
||||
from flask import request, send_file
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import asc, desc, select
|
||||
from sqlalchemy import asc, desc, func, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
@ -27,6 +27,7 @@ from core.model_manager import ModelManager
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from extensions.ext_database import db
|
||||
@ -211,12 +212,11 @@ class GetProcessRuleApi(Resource):
|
||||
raise Forbidden(str(e))
|
||||
|
||||
# get the latest process rule
|
||||
dataset_process_rule = (
|
||||
db.session.query(DatasetProcessRule)
|
||||
dataset_process_rule = db.session.scalar(
|
||||
select(DatasetProcessRule)
|
||||
.where(DatasetProcessRule.dataset_id == document.dataset_id)
|
||||
.order_by(DatasetProcessRule.created_at.desc())
|
||||
.limit(1)
|
||||
.one_or_none()
|
||||
)
|
||||
if dataset_process_rule:
|
||||
mode = dataset_process_rule.mode
|
||||
@ -330,21 +330,23 @@ class DatasetDocumentListApi(Resource):
|
||||
if fetch:
|
||||
for document in documents:
|
||||
completed_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
total_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
document.completed_segments = completed_segments
|
||||
document.total_segments = total_segments
|
||||
@ -448,7 +450,7 @@ class DatasetInitApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {})
|
||||
if knowledge_config.indexing_technique == "high_quality":
|
||||
if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
|
||||
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
|
||||
try:
|
||||
@ -462,7 +464,7 @@ class DatasetInitApi(Resource):
|
||||
is_multimodal = DatasetService.check_is_multimodal_model(
|
||||
current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model
|
||||
)
|
||||
knowledge_config.is_multimodal = is_multimodal
|
||||
knowledge_config.is_multimodal = is_multimodal # pyrefly: ignore[bad-assignment]
|
||||
except InvokeAuthorizationError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
|
||||
@ -521,10 +523,10 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
if data_source_info and "upload_file_id" in data_source_info:
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
|
||||
file = (
|
||||
db.session.query(UploadFile)
|
||||
file = db.session.scalar(
|
||||
select(UploadFile)
|
||||
.where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
# raise error if file not found
|
||||
@ -586,10 +588,10 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
if not data_source_info:
|
||||
continue
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
file_detail = (
|
||||
db.session.query(UploadFile)
|
||||
file_detail = db.session.scalar(
|
||||
select(UploadFile)
|
||||
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if file_detail is None:
|
||||
@ -672,20 +674,23 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
|
||||
documents_status = []
|
||||
for document in documents:
|
||||
completed_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
total_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
# Create a dictionary with document attributes and additional fields
|
||||
document_dict = {
|
||||
@ -723,18 +728,23 @@ class DocumentIndexingStatusApi(DocumentResource):
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
completed_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document_id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document_id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
total_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != SegmentStatus.RE_SEGMENT)
|
||||
.count()
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.document_id == str(document_id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
# Create a dictionary with document attributes and additional fields
|
||||
@ -1258,11 +1268,11 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
log = (
|
||||
db.session.query(DocumentPipelineExecutionLog)
|
||||
.filter_by(document_id=document_id)
|
||||
log = db.session.scalar(
|
||||
select(DocumentPipelineExecutionLog)
|
||||
.where(DocumentPipelineExecutionLog.document_id == document_id)
|
||||
.order_by(DocumentPipelineExecutionLog.created_at.desc())
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not log:
|
||||
return {
|
||||
@ -1328,7 +1338,7 @@ class DocumentGenerateSummaryApi(Resource):
|
||||
raise BadRequest("document_list cannot be empty.")
|
||||
|
||||
# Check if dataset configuration supports summary generation
|
||||
if dataset.indexing_technique != "high_quality":
|
||||
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
|
||||
raise ValueError(
|
||||
f"Summary generation is only available for 'high_quality' indexing technique. "
|
||||
f"Current indexing technique: {dataset.indexing_technique}"
|
||||
|
||||
@ -26,6 +26,7 @@ from controllers.console.wraps import (
|
||||
)
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_manager import ModelManager
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
@ -45,7 +46,7 @@ def _get_segment_with_summary(segment, dataset_id):
|
||||
"""Helper function to marshal segment and add summary information."""
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
segment_dict = dict(marshal(segment, segment_fields))
|
||||
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore
|
||||
# Query summary for this segment (only enabled summaries)
|
||||
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
|
||||
segment_dict["summary"] = summary.summary_content if summary else None
|
||||
@ -206,7 +207,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
# Add summary to each segment
|
||||
segments_with_summary = []
|
||||
for segment in segments.items:
|
||||
segment_dict = dict(marshal(segment, segment_fields))
|
||||
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore
|
||||
segment_dict["summary"] = summaries.get(segment.id)
|
||||
segments_with_summary.append(segment_dict)
|
||||
|
||||
@ -279,7 +280,7 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
# check embedding model setting
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
@ -333,7 +334,7 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
# check embedding model setting
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
@ -383,7 +384,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
# check embedding model setting
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
@ -401,10 +402,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
@ -447,10 +448,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
@ -494,7 +495,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
payload = BatchImportPayload.model_validate(console_ns.payload or {})
|
||||
upload_file_id = payload.upload_file_id
|
||||
|
||||
upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
|
||||
upload_file = db.session.scalar(select(UploadFile).where(UploadFile.id == upload_file_id).limit(1))
|
||||
if not upload_file:
|
||||
raise NotFound("UploadFile not found.")
|
||||
|
||||
@ -559,17 +560,17 @@ class ChildChunkAddApi(Resource):
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
# check embedding model setting
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
@ -616,10 +617,10 @@ class ChildChunkAddApi(Resource):
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
@ -666,10 +667,10 @@ class ChildChunkAddApi(Resource):
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
@ -714,24 +715,24 @@ class ChildChunkUpdateApi(Resource):
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
# check child chunk
|
||||
child_chunk_id = str(child_chunk_id)
|
||||
child_chunk = (
|
||||
db.session.query(ChildChunk)
|
||||
child_chunk = db.session.scalar(
|
||||
select(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.id == str(child_chunk_id),
|
||||
ChildChunk.tenant_id == current_tenant_id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.document_id == document_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not child_chunk:
|
||||
raise NotFound("Child chunk not found.")
|
||||
@ -771,24 +772,24 @@ class ChildChunkUpdateApi(Resource):
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
# check child chunk
|
||||
child_chunk_id = str(child_chunk_id)
|
||||
child_chunk = (
|
||||
db.session.query(ChildChunk)
|
||||
child_chunk = db.session.scalar(
|
||||
select(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.id == str(child_chunk_id),
|
||||
ChildChunk.tenant_id == current_tenant_id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.document_id == document_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not child_chunk:
|
||||
raise NotFound("Child chunk not found.")
|
||||
|
||||
@ -6,7 +6,7 @@ from flask import abort, request
|
||||
from flask_restx import Resource, marshal_with # type: ignore
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_models
|
||||
@ -16,7 +16,11 @@ from controllers.console.app.error import (
|
||||
DraftWorkflowNotExist,
|
||||
DraftWorkflowNotSync,
|
||||
)
|
||||
from controllers.console.app.workflow import workflow_model, workflow_pagination_model
|
||||
from controllers.console.app.workflow import (
|
||||
RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE,
|
||||
workflow_model,
|
||||
workflow_pagination_model,
|
||||
)
|
||||
from controllers.console.app.workflow_run import (
|
||||
workflow_run_detail_model,
|
||||
workflow_run_node_execution_list_model,
|
||||
@ -42,7 +46,8 @@ from libs.login import current_account_with_tenant, current_user, login_required
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
from models.model import EndUser
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from models.workflow import Workflow
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
@ -203,9 +208,12 @@ class DraftRagPipelineApi(Resource):
|
||||
abort(415)
|
||||
|
||||
payload = DraftWorkflowSyncPayload.model_validate(payload_dict)
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
|
||||
try:
|
||||
environment_variables_list = payload.environment_variables or []
|
||||
environment_variables_list = Workflow.normalize_environment_variable_mappings(
|
||||
payload.environment_variables or [],
|
||||
)
|
||||
environment_variables = [
|
||||
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
|
||||
]
|
||||
@ -213,7 +221,6 @@ class DraftRagPipelineApi(Resource):
|
||||
conversation_variables = [
|
||||
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||
]
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
workflow = rag_pipeline_service.sync_draft_workflow(
|
||||
pipeline=pipeline,
|
||||
graph=payload.graph,
|
||||
@ -705,6 +712,36 @@ class PublishedAllRagPipelineApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>/restore")
|
||||
class RagPipelineDraftWorkflowRestoreApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline, workflow_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
|
||||
try:
|
||||
workflow = rag_pipeline_service.restore_published_workflow_to_draft(
|
||||
pipeline=pipeline,
|
||||
workflow_id=workflow_id,
|
||||
account=current_user,
|
||||
)
|
||||
except IsDraftWorkflowError as exc:
|
||||
# Use a stable, predefined message to keep the 400 response consistent
|
||||
raise BadRequest(RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE) from exc
|
||||
except WorkflowNotFoundError as exc:
|
||||
raise NotFound(str(exc)) from exc
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"hash": workflow.unique_hash,
|
||||
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>")
|
||||
class RagPipelineByIdApi(Resource):
|
||||
@setup_required
|
||||
|
||||
@ -2,6 +2,8 @@ from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from controllers.console.datasets.error import PipelineNotFoundError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant
|
||||
@ -24,10 +26,8 @@ def get_rag_pipeline(view_func: Callable[P, R]):
|
||||
|
||||
del kwargs["pipeline_id"]
|
||||
|
||||
pipeline = (
|
||||
db.session.query(Pipeline)
|
||||
.where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
pipeline = db.session.scalar(
|
||||
select(Pipeline).where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id).limit(1)
|
||||
)
|
||||
|
||||
if not pipeline:
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy import select
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.explore.wraps import explore_banner_enabled
|
||||
@ -17,14 +18,18 @@ class BannerApi(Resource):
|
||||
language = request.args.get("language", "en-US")
|
||||
|
||||
# Build base query for enabled banners
|
||||
base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == BannerStatus.ENABLED)
|
||||
base_query = select(ExporleBanner).where(ExporleBanner.status == BannerStatus.ENABLED)
|
||||
|
||||
# Try to get banners in the requested language
|
||||
banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all()
|
||||
banners = db.session.scalars(
|
||||
base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort)
|
||||
).all()
|
||||
|
||||
# Fallback to en-US if no banners found and language is not en-US
|
||||
if not banners and language != "en-US":
|
||||
banners = base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort).all()
|
||||
banners = db.session.scalars(
|
||||
base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort)
|
||||
).all()
|
||||
# Convert banners to serializable format
|
||||
result = []
|
||||
for banner in banners:
|
||||
|
||||
@ -133,13 +133,15 @@ class InstalledAppsListApi(Resource):
|
||||
def post(self):
|
||||
payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).first()
|
||||
recommended_app = db.session.scalar(
|
||||
select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).limit(1)
|
||||
)
|
||||
if recommended_app is None:
|
||||
raise NotFound("Recommended app not found")
|
||||
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
app = db.session.query(App).where(App.id == payload.app_id).first()
|
||||
app = db.session.get(App, payload.app_id)
|
||||
|
||||
if app is None:
|
||||
raise NotFound("App entity not found")
|
||||
@ -147,10 +149,10 @@ class InstalledAppsListApi(Resource):
|
||||
if not app.is_public:
|
||||
raise Forbidden("You can't install a non-public app")
|
||||
|
||||
installed_app = (
|
||||
db.session.query(InstalledApp)
|
||||
installed_app = db.session.scalar(
|
||||
select(InstalledApp)
|
||||
.where(and_(InstalledApp.app_id == payload.app_id, InstalledApp.tenant_id == current_tenant_id))
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if installed_app is None:
|
||||
|
||||
@ -27,6 +27,7 @@ from fields.message_fields import MessageInfiniteScrollPagination, MessageListIt
|
||||
from libs import helper
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.enums import FeedbackRating
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
@ -116,7 +117,7 @@ class MessageFeedbackApi(InstalledAppResource):
|
||||
app_model=app_model,
|
||||
message_id=message_id,
|
||||
user=current_user,
|
||||
rating=payload.rating,
|
||||
rating=FeedbackRating(payload.rating) if payload.rating else None,
|
||||
content=payload.content,
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import Any, Literal, cast
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
@ -476,7 +477,7 @@ class TrialSitApi(Resource):
|
||||
|
||||
Returns the site configuration for the application including theme, icons, and text.
|
||||
"""
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
|
||||
if not site:
|
||||
raise Forbidden()
|
||||
@ -541,13 +542,7 @@ class AppWorkflowApi(Resource):
|
||||
if not app_model.workflow_id:
|
||||
raise AppUnavailableError()
|
||||
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.where(
|
||||
Workflow.id == app_model.workflow_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
workflow = db.session.get(Workflow, app_model.workflow_id)
|
||||
return workflow
|
||||
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import Concatenate, ParamSpec, TypeVar
|
||||
|
||||
from flask import abort
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed
|
||||
@ -24,10 +25,10 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non
|
||||
@wraps(view)
|
||||
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
installed_app = (
|
||||
db.session.query(InstalledApp)
|
||||
installed_app = db.session.scalar(
|
||||
select(InstalledApp)
|
||||
.where(InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if installed_app is None:
|
||||
@ -78,7 +79,7 @@ def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
|
||||
def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first()
|
||||
trial_app = db.session.scalar(select(TrialApp).where(TrialApp.app_id == str(app_id)).limit(1))
|
||||
|
||||
if trial_app is None:
|
||||
raise TrialAppNotAllowed()
|
||||
@ -87,10 +88,10 @@ def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
|
||||
if app is None:
|
||||
raise TrialAppNotAllowed()
|
||||
|
||||
account_trial_app_record = (
|
||||
db.session.query(AccountTrialAppRecord)
|
||||
account_trial_app_record = db.session.scalar(
|
||||
select(AccountTrialAppRecord)
|
||||
.where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if account_trial_app_record:
|
||||
if account_trial_app_record.count >= trial_app.trial_limit:
|
||||
|
||||
@ -2,6 +2,7 @@ from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.fastopenapi import console_router
|
||||
@ -100,6 +101,6 @@ def setup_system(payload: SetupRequestPayload) -> SetupResponse:
|
||||
|
||||
def get_setup_status() -> DifySetup | bool | None:
|
||||
if dify_config.EDITION == "SELF_HOSTED":
|
||||
return db.session.query(DifySetup).first()
|
||||
return db.session.scalar(select(DifySetup).limit(1))
|
||||
|
||||
return True
|
||||
|
||||
@ -212,13 +212,13 @@ class AccountInitApi(Resource):
|
||||
raise ValueError("invitation_code is required")
|
||||
|
||||
# check invitation code
|
||||
invitation_code = (
|
||||
db.session.query(InvitationCode)
|
||||
invitation_code = db.session.scalar(
|
||||
select(InvitationCode)
|
||||
.where(
|
||||
InvitationCode.code == args.invitation_code,
|
||||
InvitationCode.status == InvitationCodeStatus.UNUSED,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not invitation_code:
|
||||
|
||||
@ -171,7 +171,7 @@ class MemberCancelInviteApi(Resource):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
member = db.session.query(Account).where(Account.id == str(member_id)).first()
|
||||
member = db.session.get(Account, str(member_id))
|
||||
if member is None:
|
||||
abort(404)
|
||||
else:
|
||||
|
||||
@ -7,6 +7,7 @@ from sqlalchemy import select
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.common.errors import (
|
||||
FilenameNotExistsError,
|
||||
FileTooLargeError,
|
||||
@ -29,6 +30,7 @@ from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.account import Tenant, TenantStatus
|
||||
from services.account_service import TenantService
|
||||
from services.billing_service import BillingService, SubscriptionPlan
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
from services.file_service import FileService
|
||||
@ -108,9 +110,29 @@ class TenantListApi(Resource):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
tenants = TenantService.get_join_tenants(current_user)
|
||||
tenant_dicts = []
|
||||
is_enterprise_only = dify_config.ENTERPRISE_ENABLED and not dify_config.BILLING_ENABLED
|
||||
is_saas = dify_config.EDITION == "CLOUD" and dify_config.BILLING_ENABLED
|
||||
tenant_plans: dict[str, SubscriptionPlan] = {}
|
||||
|
||||
if is_saas:
|
||||
tenant_ids = [tenant.id for tenant in tenants]
|
||||
if tenant_ids:
|
||||
tenant_plans = BillingService.get_plan_bulk(tenant_ids)
|
||||
if not tenant_plans:
|
||||
logger.warning("get_plan_bulk returned empty result, falling back to legacy feature path")
|
||||
|
||||
for tenant in tenants:
|
||||
features = FeatureService.get_features(tenant.id)
|
||||
plan: str = CloudPlan.SANDBOX
|
||||
if is_saas:
|
||||
tenant_plan = tenant_plans.get(tenant.id)
|
||||
if tenant_plan:
|
||||
plan = tenant_plan["plan"] or CloudPlan.SANDBOX
|
||||
else:
|
||||
features = FeatureService.get_features(tenant.id)
|
||||
plan = features.billing.subscription.plan or CloudPlan.SANDBOX
|
||||
elif not is_enterprise_only:
|
||||
features = FeatureService.get_features(tenant.id)
|
||||
plan = features.billing.subscription.plan or CloudPlan.SANDBOX
|
||||
|
||||
# Create a dictionary with tenant attributes
|
||||
tenant_dict = {
|
||||
@ -118,7 +140,7 @@ class TenantListApi(Resource):
|
||||
"name": tenant.name,
|
||||
"status": tenant.status,
|
||||
"created_at": tenant.created_at,
|
||||
"plan": features.billing.subscription.plan if features.billing.enabled else CloudPlan.SANDBOX,
|
||||
"plan": plan,
|
||||
"current": tenant.id == current_tenant_id if current_tenant_id else False,
|
||||
}
|
||||
|
||||
@ -198,7 +220,7 @@ class SwitchWorkspaceApi(Resource):
|
||||
except Exception:
|
||||
raise AccountNotLinkTenantError("Account not link tenant")
|
||||
|
||||
new_tenant = db.session.query(Tenant).get(args.tenant_id) # Get new tenant
|
||||
new_tenant = db.session.get(Tenant, args.tenant_id) # Get new tenant
|
||||
if new_tenant is None:
|
||||
raise ValueError("Tenant not found")
|
||||
|
||||
|
||||
@ -7,6 +7,7 @@ from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import abort, request
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console.auth.error import AuthenticationFailedError, EmailCodeError
|
||||
@ -218,13 +219,9 @@ def setup_required(view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
# check setup
|
||||
if (
|
||||
dify_config.EDITION == "SELF_HOSTED"
|
||||
and os.environ.get("INIT_PASSWORD")
|
||||
and not db.session.query(DifySetup).first()
|
||||
):
|
||||
raise NotInitValidateError()
|
||||
elif dify_config.EDITION == "SELF_HOSTED" and not db.session.query(DifySetup).first():
|
||||
if dify_config.EDITION == "SELF_HOSTED" and not db.session.scalar(select(DifySetup).limit(1)):
|
||||
if os.environ.get("INIT_PASSWORD"):
|
||||
raise NotInitValidateError()
|
||||
raise NotSetupError()
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
@ -5,6 +5,7 @@ from typing import ParamSpec, TypeVar
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
@ -36,23 +37,16 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
||||
user_model = None
|
||||
|
||||
if is_anonymous:
|
||||
user_model = (
|
||||
session.query(EndUser)
|
||||
user_model = session.scalar(
|
||||
select(EndUser)
|
||||
.where(
|
||||
EndUser.session_id == user_id,
|
||||
EndUser.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
else:
|
||||
user_model = (
|
||||
session.query(EndUser)
|
||||
.where(
|
||||
EndUser.id == user_id,
|
||||
EndUser.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
user_model = session.get(EndUser, user_id)
|
||||
|
||||
if not user_model:
|
||||
user_model = EndUser(
|
||||
@ -85,16 +79,7 @@ def get_user_tenant(view_func: Callable[P, R]):
|
||||
if not user_id:
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
|
||||
try:
|
||||
tenant_model = (
|
||||
db.session.query(Tenant)
|
||||
.where(
|
||||
Tenant.id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
except Exception:
|
||||
raise ValueError("tenant not found")
|
||||
tenant_model = db.session.get(Tenant, tenant_id)
|
||||
|
||||
if not tenant_model:
|
||||
raise ValueError("tenant not found")
|
||||
|
||||
@ -2,6 +2,7 @@ import json
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.wraps import setup_required
|
||||
@ -42,7 +43,7 @@ class EnterpriseWorkspace(Resource):
|
||||
def post(self):
|
||||
args = WorkspaceCreatePayload.model_validate(inner_api_ns.payload or {})
|
||||
|
||||
account = db.session.query(Account).filter_by(email=args.owner_email).first()
|
||||
account = db.session.scalar(select(Account).where(Account.email == args.owner_email).limit(1))
|
||||
if account is None:
|
||||
return {"message": "owner account not found."}, 404
|
||||
|
||||
|
||||
@ -75,7 +75,7 @@ def enterprise_inner_api_user_auth(view: Callable[P, R]):
|
||||
if signature_base64 != token:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
kwargs["user"] = db.session.query(EndUser).where(EndUser.id == user_id).first()
|
||||
kwargs["user"] = db.session.get(EndUser, user_id)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
|
||||
@ -15,6 +15,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from models.enums import FeedbackRating
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.errors.message import (
|
||||
FirstMessageNotExistsError,
|
||||
@ -116,7 +117,7 @@ class MessageFeedbackApi(Resource):
|
||||
app_model=app_model,
|
||||
message_id=message_id,
|
||||
user=end_user,
|
||||
rating=payload.rating,
|
||||
rating=FeedbackRating(payload.rating) if payload.rating else None,
|
||||
content=payload.content,
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
|
||||
@ -15,6 +15,7 @@ from controllers.service_api.wraps import (
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
)
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from fields.tag_fields import DataSetTag
|
||||
@ -153,15 +154,20 @@ class DatasetListApi(DatasetApiResource):
|
||||
|
||||
data = marshal(datasets, dataset_detail_fields)
|
||||
for item in data:
|
||||
if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
|
||||
item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
|
||||
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
|
||||
if (
|
||||
item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY # pyrefly: ignore[bad-index]
|
||||
and item["embedding_model_provider"] # pyrefly: ignore[bad-index]
|
||||
):
|
||||
item["embedding_model_provider"] = str( # pyrefly: ignore[unsupported-operation]
|
||||
ModelProviderID(item["embedding_model_provider"]) # pyrefly: ignore[bad-index]
|
||||
)
|
||||
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" # pyrefly: ignore[bad-index]
|
||||
if item_model in model_names:
|
||||
item["embedding_available"] = True
|
||||
item["embedding_available"] = True # type: ignore
|
||||
else:
|
||||
item["embedding_available"] = False
|
||||
item["embedding_available"] = False # type: ignore
|
||||
else:
|
||||
item["embedding_available"] = True
|
||||
item["embedding_available"] = True # type: ignore
|
||||
response = {
|
||||
"data": data,
|
||||
"has_more": len(datasets) == query.limit,
|
||||
@ -265,7 +271,7 @@ class DatasetApi(DatasetApiResource):
|
||||
for embedding_model in embedding_models:
|
||||
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
|
||||
|
||||
if data.get("indexing_technique") == "high_quality":
|
||||
if data.get("indexing_technique") == IndexTechniqueType.HIGH_QUALITY:
|
||||
item_model = f"{data.get('embedding_model')}:{data.get('embedding_model_provider')}"
|
||||
if item_model in model_names:
|
||||
data["embedding_available"] = True
|
||||
@ -315,7 +321,7 @@ class DatasetApi(DatasetApiResource):
|
||||
# check embedding model setting
|
||||
embedding_model_provider = payload.embedding_model_provider
|
||||
embedding_model = payload.embedding_model
|
||||
if payload.indexing_technique == "high_quality" or embedding_model_provider:
|
||||
if payload.indexing_technique == IndexTechniqueType.HIGH_QUALITY or embedding_model_provider:
|
||||
if embedding_model_provider and embedding_model:
|
||||
DatasetService.check_embedding_model_setting(
|
||||
dataset.tenant_id, embedding_model_provider, embedding_model
|
||||
|
||||
@ -17,6 +17,7 @@ from controllers.service_api.wraps import (
|
||||
)
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_manager import ModelManager
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_database import db
|
||||
from fields.segment_fields import child_chunk_fields, segment_fields
|
||||
@ -103,7 +104,7 @@ class SegmentApi(DatasetApiResource):
|
||||
if not document.enabled:
|
||||
raise NotFound("Document is disabled.")
|
||||
# check embedding model setting
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
@ -157,7 +158,7 @@ class SegmentApi(DatasetApiResource):
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check embedding model setting
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
@ -262,7 +263,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
# check embedding model setting
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
@ -358,7 +359,7 @@ class ChildChunkApi(DatasetApiResource):
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
# check embedding model setting
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
|
||||
@ -8,6 +8,7 @@ from datetime import datetime
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
@ -147,11 +148,11 @@ class HumanInputFormApi(Resource):
|
||||
|
||||
def _get_app_site_from_form(form: Form) -> tuple[App, Site]:
|
||||
"""Resolve App/Site for the form's app and validate tenant status."""
|
||||
app_model = db.session.query(App).where(App.id == form.app_id).first()
|
||||
app_model = db.session.get(App, form.app_id)
|
||||
if app_model is None or app_model.tenant_id != form.tenant_id:
|
||||
raise NotFoundError("Form not found")
|
||||
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
if site is None:
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
@ -25,6 +25,7 @@ from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from models.enums import FeedbackRating
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
@ -157,7 +158,7 @@ class MessageFeedbackApi(WebApiResource):
|
||||
app_model=app_model,
|
||||
message_id=message_id,
|
||||
user=end_user,
|
||||
rating=payload.rating,
|
||||
rating=FeedbackRating(payload.rating) if payload.rating else None,
|
||||
content=payload.content,
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from typing import cast
|
||||
|
||||
from flask_restx import fields, marshal, marshal_with
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
@ -72,7 +73,7 @@ class AppSiteApi(WebApiResource):
|
||||
def get(self, app_model, end_user):
|
||||
"""Retrieve app site info."""
|
||||
# get site
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
|
||||
if not site:
|
||||
raise Forbidden()
|
||||
|
||||
@ -76,7 +76,7 @@ from dify_graph.system_variable import SystemVariable
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, Conversation, EndUser, Message, MessageFile
|
||||
from models.enums import CreatorUserRole, MessageStatus
|
||||
from models.enums import CreatorUserRole, MessageFileBelongsTo, MessageStatus
|
||||
from models.execution_extra_content import HumanInputContent
|
||||
from models.workflow import Workflow
|
||||
|
||||
@ -939,7 +939,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
type=file["type"],
|
||||
transfer_method=file["transfer_method"],
|
||||
url=file["remote_url"],
|
||||
belongs_to="assistant",
|
||||
belongs_to=MessageFileBelongsTo.ASSISTANT,
|
||||
upload_file_id=file["related_id"],
|
||||
created_by_role=CreatorUserRole.ACCOUNT
|
||||
if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
|
||||
|
||||
@ -74,11 +74,22 @@ class AppGenerateResponseConverter(ABC):
|
||||
for resource in metadata["retriever_resources"]:
|
||||
updated_resources.append(
|
||||
{
|
||||
"dataset_id": resource.get("dataset_id"),
|
||||
"dataset_name": resource.get("dataset_name"),
|
||||
"document_id": resource.get("document_id"),
|
||||
"segment_id": resource.get("segment_id", ""),
|
||||
"position": resource["position"],
|
||||
"data_source_type": resource.get("data_source_type"),
|
||||
"document_name": resource["document_name"],
|
||||
"score": resource["score"],
|
||||
"hit_count": resource.get("hit_count"),
|
||||
"word_count": resource.get("word_count"),
|
||||
"segment_position": resource.get("segment_position"),
|
||||
"index_node_hash": resource.get("index_node_hash"),
|
||||
"content": resource["content"],
|
||||
"page": resource.get("page"),
|
||||
"title": resource.get("title"),
|
||||
"files": resource.get("files"),
|
||||
"summary": resource.get("summary"),
|
||||
}
|
||||
)
|
||||
|
||||
@ -40,7 +40,7 @@ from dify_graph.model_runtime.entities.message_entities import (
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from extensions.ext_database import db
|
||||
from models.enums import CreatorUserRole
|
||||
from models.enums import CreatorUserRole, MessageFileBelongsTo
|
||||
from models.model import App, AppMode, Message, MessageAnnotation, MessageFile
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -419,7 +419,7 @@ class AppRunner:
|
||||
message_id=message_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
belongs_to="assistant",
|
||||
belongs_to=MessageFileBelongsTo.ASSISTANT,
|
||||
url=f"/files/tools/{tool_file.id}",
|
||||
upload_file_id=tool_file.id,
|
||||
created_by_role=(
|
||||
|
||||
@ -33,7 +33,7 @@ from extensions.ext_redis import get_pubsub_broadcast_channel
|
||||
from libs.broadcast_channel.channel import Topic
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account
|
||||
from models.enums import CreatorUserRole
|
||||
from models.enums import ConversationFromSource, CreatorUserRole, MessageFileBelongsTo
|
||||
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
|
||||
from services.errors.app_model_config import AppModelConfigBrokenError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
@ -130,10 +130,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
end_user_id = None
|
||||
account_id = None
|
||||
if application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
|
||||
from_source = "api"
|
||||
from_source = ConversationFromSource.API
|
||||
end_user_id = application_generate_entity.user_id
|
||||
else:
|
||||
from_source = "console"
|
||||
from_source = ConversationFromSource.CONSOLE
|
||||
account_id = application_generate_entity.user_id
|
||||
|
||||
if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity):
|
||||
@ -225,7 +225,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
message_id=message.id,
|
||||
type=file.type,
|
||||
transfer_method=file.transfer_method,
|
||||
belongs_to="user",
|
||||
belongs_to=MessageFileBelongsTo.USER,
|
||||
url=file.remote_url,
|
||||
upload_file_id=file.related_id,
|
||||
created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER),
|
||||
|
||||
@ -705,7 +705,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
app_id=self._application_generate_entity.app_config.app_id,
|
||||
workflow_id=self._workflow.id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
created_from=created_from.value,
|
||||
created_from=created_from,
|
||||
created_by_role=self._created_by_role,
|
||||
created_by=self._user_id,
|
||||
)
|
||||
|
||||
@ -4,9 +4,10 @@ from sqlalchemy import select
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from models.enums import CollectionBindingType
|
||||
from models.enums import CollectionBindingType, ConversationFromSource
|
||||
from models.model import App, AppAnnotationSetting, Message, MessageAnnotation
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from services.dataset_service import DatasetCollectionBindingService
|
||||
@ -50,7 +51,7 @@ class AnnotationReplyFeature:
|
||||
dataset = Dataset(
|
||||
id=app_record.id,
|
||||
tenant_id=app_record.tenant_id,
|
||||
indexing_technique="high_quality",
|
||||
indexing_technique=IndexTechniqueType.HIGH_QUALITY,
|
||||
embedding_model_provider=embedding_provider_name,
|
||||
embedding_model=embedding_model_name,
|
||||
collection_binding_id=dataset_collection_binding.id,
|
||||
@ -68,9 +69,9 @@ class AnnotationReplyFeature:
|
||||
annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
|
||||
if annotation:
|
||||
if invoke_from in {InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP}:
|
||||
from_source = "api"
|
||||
from_source = ConversationFromSource.API
|
||||
else:
|
||||
from_source = "console"
|
||||
from_source = ConversationFromSource.CONSOLE
|
||||
|
||||
# insert annotation history
|
||||
AppAnnotationService.add_annotation_history(
|
||||
|
||||
@ -19,6 +19,7 @@ class RateLimit:
|
||||
_REQUEST_MAX_ALIVE_TIME = 10 * 60 # 10 minutes
|
||||
_ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes
|
||||
_instance_dict: dict[str, "RateLimit"] = {}
|
||||
max_active_requests: int
|
||||
|
||||
def __new__(cls, client_id: str, max_active_requests: int):
|
||||
if client_id not in cls._instance_dict:
|
||||
@ -27,7 +28,13 @@ class RateLimit:
|
||||
return cls._instance_dict[client_id]
|
||||
|
||||
def __init__(self, client_id: str, max_active_requests: int):
|
||||
flush_cache = hasattr(self, "max_active_requests") and self.max_active_requests != max_active_requests
|
||||
self.max_active_requests = max_active_requests
|
||||
# Only flush here if this instance has already been fully initialized,
|
||||
# i.e. the Redis key attributes exist. Otherwise, rely on the flush at
|
||||
# the end of initialization below.
|
||||
if flush_cache and hasattr(self, "active_requests_key") and hasattr(self, "max_active_requests_key"):
|
||||
self.flush_cache(use_local_value=True)
|
||||
# must be called after max_active_requests is set
|
||||
if self.disabled():
|
||||
return
|
||||
@ -41,8 +48,6 @@ class RateLimit:
|
||||
self.flush_cache(use_local_value=True)
|
||||
|
||||
def flush_cache(self, use_local_value=False):
|
||||
if self.disabled():
|
||||
return
|
||||
self.last_recalculate_time = time.time()
|
||||
# flush max active requests
|
||||
if use_local_value or not redis_client.exists(self.max_active_requests_key):
|
||||
@ -50,7 +55,8 @@ class RateLimit:
|
||||
else:
|
||||
self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode("utf-8"))
|
||||
redis_client.expire(self.max_active_requests_key, timedelta(days=1))
|
||||
|
||||
if self.disabled():
|
||||
return
|
||||
# flush max active requests (in-transit request list)
|
||||
if not redis_client.exists(self.active_requests_key):
|
||||
return
|
||||
|
||||
@ -6,16 +6,23 @@ from dify_graph.graph_events.graph import GraphRunPausedEvent
|
||||
class SuspendLayer(GraphEngineLayer):
|
||||
""" """
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._paused = False
|
||||
|
||||
def on_graph_start(self):
|
||||
pass
|
||||
self._paused = False
|
||||
|
||||
def on_event(self, event: GraphEngineEvent):
|
||||
"""
|
||||
Handle the paused event, stash runtime state into storage and wait for resume.
|
||||
"""
|
||||
if isinstance(event, GraphRunPausedEvent):
|
||||
pass
|
||||
self._paused = True
|
||||
|
||||
def on_graph_end(self, error: Exception | None):
|
||||
""" """
|
||||
pass
|
||||
self._paused = False
|
||||
|
||||
def is_paused(self) -> bool:
|
||||
return self._paused
|
||||
|
||||
@ -34,6 +34,7 @@ from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.tools.signature import sign_tool_file
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.enums import MessageFileBelongsTo
|
||||
from models.model import AppMode, Conversation, MessageAnnotation, MessageFile
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
@ -233,7 +234,7 @@ class MessageCycleManager:
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=message_file.id,
|
||||
type=message_file.type,
|
||||
belongs_to=message_file.belongs_to or "user",
|
||||
belongs_to=message_file.belongs_to or MessageFileBelongsTo.USER,
|
||||
url=url,
|
||||
)
|
||||
|
||||
|
||||
@ -128,14 +128,14 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
self._handle_graph_run_paused(event)
|
||||
return
|
||||
|
||||
if isinstance(event, NodeRunStartedEvent):
|
||||
self._handle_node_started(event)
|
||||
return
|
||||
|
||||
if isinstance(event, NodeRunRetryEvent):
|
||||
self._handle_node_retry(event)
|
||||
return
|
||||
|
||||
if isinstance(event, NodeRunStartedEvent):
|
||||
self._handle_node_started(event)
|
||||
return
|
||||
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
self._handle_node_succeeded(event)
|
||||
return
|
||||
|
||||
@ -21,7 +21,7 @@ from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from core.rag.models.document import ChildDocument, Document
|
||||
@ -271,7 +271,7 @@ class IndexingRunner:
|
||||
doc_form: str | None = None,
|
||||
doc_language: str = "English",
|
||||
dataset_id: str | None = None,
|
||||
indexing_technique: str = "economy",
|
||||
indexing_technique: str = IndexTechniqueType.ECONOMY,
|
||||
) -> IndexingEstimate:
|
||||
"""
|
||||
Estimate the indexing for the document.
|
||||
@ -289,7 +289,7 @@ class IndexingRunner:
|
||||
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not found.")
|
||||
if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality":
|
||||
if IndexTechniqueType.HIGH_QUALITY in {dataset.indexing_technique, indexing_technique}:
|
||||
if dataset.embedding_model_provider:
|
||||
embedding_model_instance = self.model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
@ -303,7 +303,7 @@ class IndexingRunner:
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
)
|
||||
else:
|
||||
if indexing_technique == "high_quality":
|
||||
if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
embedding_model_instance = self.model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
@ -573,7 +573,7 @@ class IndexingRunner:
|
||||
"""
|
||||
|
||||
embedding_model_instance = None
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
embedding_model_instance = self.model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
@ -587,7 +587,7 @@ class IndexingRunner:
|
||||
create_keyword_thread = None
|
||||
if (
|
||||
dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX
|
||||
and dataset.indexing_technique == "economy"
|
||||
and dataset.indexing_technique == IndexTechniqueType.ECONOMY
|
||||
):
|
||||
# create keyword index
|
||||
create_keyword_thread = threading.Thread(
|
||||
@ -597,7 +597,7 @@ class IndexingRunner:
|
||||
create_keyword_thread.start()
|
||||
|
||||
max_workers = 10
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = []
|
||||
|
||||
@ -628,7 +628,7 @@ class IndexingRunner:
|
||||
tokens += future.result()
|
||||
if (
|
||||
dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX
|
||||
and dataset.indexing_technique == "economy"
|
||||
and dataset.indexing_technique == IndexTechniqueType.ECONOMY
|
||||
and create_keyword_thread is not None
|
||||
):
|
||||
create_keyword_thread.join()
|
||||
@ -654,7 +654,7 @@ class IndexingRunner:
|
||||
raise ValueError("no dataset found")
|
||||
keyword = Keyword(dataset)
|
||||
keyword.create(documents)
|
||||
if dataset.indexing_technique != "high_quality":
|
||||
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
|
||||
document_ids = [document.metadata["doc_id"] for document in documents]
|
||||
db.session.query(DocumentSegment).where(
|
||||
DocumentSegment.document_id == document_id,
|
||||
@ -764,7 +764,7 @@ class IndexingRunner:
|
||||
) -> list[Document]:
|
||||
# get embedding model instance
|
||||
embedding_model_instance = None
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
if dataset.embedding_model_provider:
|
||||
embedding_model_instance = self.model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
|
||||
@ -181,10 +181,6 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
arize_phoenix_config: ArizeConfig | PhoenixConfig,
|
||||
):
|
||||
super().__init__(arize_phoenix_config)
|
||||
import logging
|
||||
|
||||
logging.basicConfig()
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
self.arize_phoenix_config = arize_phoenix_config
|
||||
self.tracer, self.processor = setup_tracer(arize_phoenix_config)
|
||||
self.project = arize_phoenix_config.project
|
||||
|
||||
@ -67,7 +67,8 @@ class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
|
||||
if field_name == "inputs":
|
||||
data = {
|
||||
"messages": [
|
||||
dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) for msg in v
|
||||
dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) # type: ignore
|
||||
for msg in v
|
||||
]
|
||||
if isinstance(v, list)
|
||||
else v,
|
||||
|
||||
@ -209,8 +209,7 @@ class PluginInstaller(BasePluginClient):
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/decode/from_identifier",
|
||||
PluginDecodeResponse,
|
||||
data={"plugin_unique_identifier": plugin_unique_identifier},
|
||||
headers={"Content-Type": "application/json"},
|
||||
params={"plugin_unique_identifier": plugin_unique_identifier},
|
||||
)
|
||||
|
||||
def fetch_plugin_installation_by_ids(
|
||||
|
||||
@ -918,11 +918,11 @@ class ProviderManager:
|
||||
|
||||
trail_pool = CreditPoolService.get_pool(
|
||||
tenant_id=tenant_id,
|
||||
pool_type=ProviderQuotaType.TRIAL.value,
|
||||
pool_type=ProviderQuotaType.TRIAL,
|
||||
)
|
||||
paid_pool = CreditPoolService.get_pool(
|
||||
tenant_id=tenant_id,
|
||||
pool_type=ProviderQuotaType.PAID.value,
|
||||
pool_type=ProviderQuotaType.PAID,
|
||||
)
|
||||
else:
|
||||
trail_pool = None
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
|
||||
class CleanProcessor:
|
||||
@classmethod
|
||||
def clean(cls, text: str, process_rule: dict) -> str:
|
||||
def clean(cls, text: str, process_rule: dict[str, Any] | None) -> str:
|
||||
# default clean
|
||||
# remove invalid symbol
|
||||
text = re.sub(r"<\|", "<", text)
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import Any
|
||||
import orjson
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
||||
@ -15,6 +16,11 @@ from extensions.ext_storage import storage
|
||||
from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment
|
||||
|
||||
|
||||
class PreSegmentData(TypedDict):
|
||||
segment: DocumentSegment
|
||||
keywords: list[str]
|
||||
|
||||
|
||||
class KeywordTableConfig(BaseModel):
|
||||
max_keywords_per_chunk: int = 10
|
||||
|
||||
@ -128,7 +134,7 @@ class Jieba(BaseKeyword):
|
||||
file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt"
|
||||
storage.delete(file_key)
|
||||
|
||||
def _save_dataset_keyword_table(self, keyword_table):
|
||||
def _save_dataset_keyword_table(self, keyword_table: dict[str, set[str]] | None):
|
||||
keyword_table_dict = {
|
||||
"__type__": "keyword_table",
|
||||
"__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table},
|
||||
@ -144,7 +150,7 @@ class Jieba(BaseKeyword):
|
||||
storage.delete(file_key)
|
||||
storage.save(file_key, dumps_with_sets(keyword_table_dict).encode("utf-8"))
|
||||
|
||||
def _get_dataset_keyword_table(self) -> dict | None:
|
||||
def _get_dataset_keyword_table(self) -> dict[str, set[str]] | None:
|
||||
dataset_keyword_table = self.dataset.dataset_keyword_table
|
||||
if dataset_keyword_table:
|
||||
keyword_table_dict = dataset_keyword_table.keyword_table_dict
|
||||
@ -169,14 +175,16 @@ class Jieba(BaseKeyword):
|
||||
|
||||
return {}
|
||||
|
||||
def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]):
|
||||
def _add_text_to_keyword_table(
|
||||
self, keyword_table: dict[str, set[str]], id: str, keywords: list[str]
|
||||
) -> dict[str, set[str]]:
|
||||
for keyword in keywords:
|
||||
if keyword not in keyword_table:
|
||||
keyword_table[keyword] = set()
|
||||
keyword_table[keyword].add(id)
|
||||
return keyword_table
|
||||
|
||||
def _delete_ids_from_keyword_table(self, keyword_table: dict, ids: list[str]):
|
||||
def _delete_ids_from_keyword_table(self, keyword_table: dict[str, set[str]], ids: list[str]) -> dict[str, set[str]]:
|
||||
# get set of ids that correspond to node
|
||||
node_idxs_to_delete = set(ids)
|
||||
|
||||
@ -193,7 +201,7 @@ class Jieba(BaseKeyword):
|
||||
|
||||
return keyword_table
|
||||
|
||||
def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4):
|
||||
def _retrieve_ids_by_query(self, keyword_table: dict[str, set[str]], query: str, k: int = 4) -> list[str]:
|
||||
keyword_table_handler = JiebaKeywordTableHandler()
|
||||
keywords = keyword_table_handler.extract_keywords(query)
|
||||
|
||||
@ -228,7 +236,7 @@ class Jieba(BaseKeyword):
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table or {}, node_id, keywords)
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
def multi_create_segment_keywords(self, pre_segment_data_list: list):
|
||||
def multi_create_segment_keywords(self, pre_segment_data_list: list[PreSegmentData]):
|
||||
keyword_table_handler = JiebaKeywordTableHandler()
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
for pre_segment_data in pre_segment_data_list:
|
||||
|
||||
@ -103,7 +103,7 @@ class RetrievalService:
|
||||
reranking_mode: str = "reranking_model",
|
||||
weights: WeightsDict | None = None,
|
||||
document_ids_filter: list[str] | None = None,
|
||||
attachment_ids: list | None = None,
|
||||
attachment_ids: list[str] | None = None,
|
||||
):
|
||||
if not query and not attachment_ids:
|
||||
return []
|
||||
@ -250,8 +250,8 @@ class RetrievalService:
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
top_k: int,
|
||||
all_documents: list,
|
||||
exceptions: list,
|
||||
all_documents: list[Document],
|
||||
exceptions: list[str],
|
||||
document_ids_filter: list[str] | None = None,
|
||||
):
|
||||
with flask_app.app_context():
|
||||
@ -279,9 +279,9 @@ class RetrievalService:
|
||||
top_k: int,
|
||||
score_threshold: float | None,
|
||||
reranking_model: RerankingModelDict | None,
|
||||
all_documents: list,
|
||||
all_documents: list[Document],
|
||||
retrieval_method: RetrievalMethod,
|
||||
exceptions: list,
|
||||
exceptions: list[str],
|
||||
document_ids_filter: list[str] | None = None,
|
||||
query_type: QueryType = QueryType.TEXT_QUERY,
|
||||
):
|
||||
@ -373,9 +373,9 @@ class RetrievalService:
|
||||
top_k: int,
|
||||
score_threshold: float | None,
|
||||
reranking_model: RerankingModelDict | None,
|
||||
all_documents: list,
|
||||
all_documents: list[Document],
|
||||
retrieval_method: str,
|
||||
exceptions: list,
|
||||
exceptions: list[str],
|
||||
document_ids_filter: list[str] | None = None,
|
||||
):
|
||||
with flask_app.app_context():
|
||||
|
||||
@ -13,6 +13,7 @@ from pymochow.exception import ServerError # type: ignore
|
||||
from pymochow.model.database import Database
|
||||
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState # type: ignore
|
||||
from pymochow.model.schema import (
|
||||
AutoBuildRowCountIncrement,
|
||||
Field,
|
||||
FilteringIndex,
|
||||
HNSWParams,
|
||||
@ -51,6 +52,9 @@ class BaiduConfig(BaseModel):
|
||||
replicas: int = 3
|
||||
inverted_index_analyzer: str = "DEFAULT_ANALYZER"
|
||||
inverted_index_parser_mode: str = "COARSE_MODE"
|
||||
auto_build_row_count_increment: int = 500
|
||||
auto_build_row_count_increment_ratio: float = 0.05
|
||||
rebuild_index_timeout_in_seconds: int = 300
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@ -107,18 +111,6 @@ class BaiduVector(BaseVector):
|
||||
rows.append(row)
|
||||
table.upsert(rows=rows)
|
||||
|
||||
# rebuild vector index after upsert finished
|
||||
table.rebuild_index(self.vector_index)
|
||||
timeout = 3600 # 1 hour timeout
|
||||
start_time = time.time()
|
||||
while True:
|
||||
time.sleep(1)
|
||||
index = table.describe_index(self.vector_index)
|
||||
if index.state == IndexState.NORMAL:
|
||||
break
|
||||
if time.time() - start_time > timeout:
|
||||
raise TimeoutError(f"Index rebuild timeout after {timeout} seconds")
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
res = self._db.table(self._collection_name).query(primary_key={VDBField.PRIMARY_KEY: id})
|
||||
if res and res.code == 0:
|
||||
@ -232,8 +224,14 @@ class BaiduVector(BaseVector):
|
||||
return self._client.database(self._client_config.database)
|
||||
|
||||
def _table_existed(self) -> bool:
|
||||
tables = self._db.list_table()
|
||||
return any(table.table_name == self._collection_name for table in tables)
|
||||
try:
|
||||
table = self._db.table(self._collection_name)
|
||||
except ServerError as e:
|
||||
if e.code == ServerErrCode.TABLE_NOT_EXIST:
|
||||
return False
|
||||
else:
|
||||
raise
|
||||
return True
|
||||
|
||||
def _create_table(self, dimension: int):
|
||||
# Try to grab distributed lock and create table
|
||||
@ -287,6 +285,11 @@ class BaiduVector(BaseVector):
|
||||
field=VDBField.VECTOR,
|
||||
metric_type=metric_type,
|
||||
params=HNSWParams(m=16, efconstruction=200),
|
||||
auto_build=True,
|
||||
auto_build_index_policy=AutoBuildRowCountIncrement(
|
||||
row_count_increment=self._client_config.auto_build_row_count_increment,
|
||||
row_count_increment_ratio=self._client_config.auto_build_row_count_increment_ratio,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
@ -335,7 +338,7 @@ class BaiduVector(BaseVector):
|
||||
)
|
||||
|
||||
# Wait for table created
|
||||
timeout = 300 # 5 minutes timeout
|
||||
timeout = self._client_config.rebuild_index_timeout_in_seconds # default 5 minutes timeout
|
||||
start_time = time.time()
|
||||
while True:
|
||||
time.sleep(1)
|
||||
@ -345,6 +348,20 @@ class BaiduVector(BaseVector):
|
||||
if time.time() - start_time > timeout:
|
||||
raise TimeoutError(f"Table creation timeout after {timeout} seconds")
|
||||
redis_client.set(table_exist_cache_key, 1, ex=3600)
|
||||
# rebuild vector index immediately after table created, make sure index is ready
|
||||
table.rebuild_index(self.vector_index)
|
||||
timeout = 3600 # 1 hour timeout
|
||||
self._wait_for_index_ready(table, timeout)
|
||||
|
||||
def _wait_for_index_ready(self, table, timeout: int = 3600):
|
||||
start_time = time.time()
|
||||
while True:
|
||||
time.sleep(1)
|
||||
index = table.describe_index(self.vector_index)
|
||||
if index.state == IndexState.NORMAL:
|
||||
break
|
||||
if time.time() - start_time > timeout:
|
||||
raise TimeoutError(f"Index rebuild timeout after {timeout} seconds")
|
||||
|
||||
|
||||
class BaiduVectorFactory(AbstractVectorFactory):
|
||||
@ -369,5 +386,8 @@ class BaiduVectorFactory(AbstractVectorFactory):
|
||||
replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS,
|
||||
inverted_index_analyzer=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER,
|
||||
inverted_index_parser_mode=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE,
|
||||
auto_build_row_count_increment=dify_config.BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT,
|
||||
auto_build_row_count_increment_ratio=dify_config.BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO,
|
||||
rebuild_index_timeout_in_seconds=dify_config.BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS,
|
||||
),
|
||||
)
|
||||
|
||||
@ -124,13 +124,13 @@ class HuaweiCloudVector(BaseVector):
|
||||
)
|
||||
)
|
||||
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
docs = []
|
||||
for doc, score in docs_and_scores:
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
if score >= score_threshold:
|
||||
if doc.metadata is not None:
|
||||
doc.metadata["score"] = score
|
||||
docs.append(doc)
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
@ -33,6 +33,7 @@ from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset, TidbAuthBinding
|
||||
from models.enums import TidbAuthBindingStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qdrant_client import grpc # noqa
|
||||
@ -284,27 +285,29 @@ class TidbOnQdrantVector(BaseVector):
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
|
||||
for node_id in ids:
|
||||
try:
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="metadata.doc_id",
|
||||
match=models.MatchValue(value=node_id),
|
||||
),
|
||||
],
|
||||
)
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(filter=filter),
|
||||
)
|
||||
except UnexpectedResponse as e:
|
||||
# Collection does not exist, so return
|
||||
if e.status_code == 404:
|
||||
return
|
||||
# Some other error occurred, so re-raise the exception
|
||||
else:
|
||||
raise e
|
||||
if not ids:
|
||||
return
|
||||
|
||||
try:
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="metadata.doc_id",
|
||||
match=models.MatchAny(any=ids),
|
||||
),
|
||||
],
|
||||
)
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(filter=filter),
|
||||
)
|
||||
except UnexpectedResponse as e:
|
||||
# Collection does not exist, so return
|
||||
if e.status_code == 404:
|
||||
return
|
||||
# Some other error occurred, so re-raise the exception
|
||||
else:
|
||||
raise e
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
all_collection_name = []
|
||||
@ -450,7 +453,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
||||
password=new_cluster["password"],
|
||||
tenant_id=dataset.tenant_id,
|
||||
active=True,
|
||||
status="ACTIVE",
|
||||
status=TidbAuthBindingStatus.ACTIVE,
|
||||
)
|
||||
db.session.add(new_tidb_auth_binding)
|
||||
db.session.commit()
|
||||
|
||||
@ -9,6 +9,7 @@ from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import TidbAuthBinding
|
||||
from models.enums import TidbAuthBindingStatus
|
||||
|
||||
|
||||
class TidbService:
|
||||
@ -170,7 +171,7 @@ class TidbService:
|
||||
userPrefix = item["userPrefix"]
|
||||
if state == "ACTIVE" and len(userPrefix) > 0:
|
||||
cluster_info = tidb_serverless_list_map[item["clusterId"]]
|
||||
cluster_info.status = "ACTIVE"
|
||||
cluster_info.status = TidbAuthBindingStatus.ACTIVE
|
||||
cluster_info.account = f"{userPrefix}.root"
|
||||
db.session.add(cluster_info)
|
||||
db.session.commit()
|
||||
|
||||
@ -6,6 +6,7 @@ from typing import Any
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from core.rag.models.document import AttachmentDocument, Document
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_database import db
|
||||
@ -71,7 +72,7 @@ class DatasetDocumentStore:
|
||||
if max_position is None:
|
||||
max_position = 0
|
||||
embedding_model = None
|
||||
if self._dataset.indexing_technique == "high_quality":
|
||||
if self._dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=self._dataset.tenant_id,
|
||||
|
||||
@ -95,15 +95,11 @@ class FirecrawlApp:
|
||||
if response.status_code == 200:
|
||||
crawl_status_response = response.json()
|
||||
if crawl_status_response.get("status") == "completed":
|
||||
total = crawl_status_response.get("total", 0)
|
||||
if total == 0:
|
||||
# Normalize to avoid None bypassing the zero-guard when the API returns null.
|
||||
total = crawl_status_response.get("total") or 0
|
||||
if total <= 0:
|
||||
raise Exception("Failed to check crawl status. Error: No page found")
|
||||
data = crawl_status_response.get("data", [])
|
||||
url_data_list: list[FirecrawlDocumentData] = []
|
||||
for item in data:
|
||||
if isinstance(item, dict) and "metadata" in item and "markdown" in item:
|
||||
url_data = self._extract_common_fields(item)
|
||||
url_data_list.append(url_data)
|
||||
url_data_list = self._collect_all_crawl_pages(crawl_status_response, headers)
|
||||
if url_data_list:
|
||||
file_key = "website_files/" + job_id + ".txt"
|
||||
try:
|
||||
@ -120,6 +116,36 @@ class FirecrawlApp:
|
||||
self._handle_error(response, "check crawl status")
|
||||
raise RuntimeError("unreachable: _handle_error always raises")
|
||||
|
||||
def _collect_all_crawl_pages(
|
||||
self, first_page: dict[str, Any], headers: dict[str, str]
|
||||
) -> list[FirecrawlDocumentData]:
|
||||
"""Collect all crawl result pages by following pagination links.
|
||||
|
||||
Raises an exception if any paginated request fails, to avoid returning
|
||||
partial data that is inconsistent with the reported total.
|
||||
|
||||
The number of pages processed is capped at ``total`` (the
|
||||
server-reported page count) to guard against infinite loops caused by
|
||||
a misbehaving server that keeps returning a ``next`` URL.
|
||||
"""
|
||||
total: int = first_page.get("total") or 0
|
||||
url_data_list: list[FirecrawlDocumentData] = []
|
||||
current_page = first_page
|
||||
pages_processed = 0
|
||||
while True:
|
||||
for item in current_page.get("data", []):
|
||||
if isinstance(item, dict) and "metadata" in item and "markdown" in item:
|
||||
url_data_list.append(self._extract_common_fields(item))
|
||||
next_url: str | None = current_page.get("next")
|
||||
pages_processed += 1
|
||||
if not next_url or pages_processed >= total:
|
||||
break
|
||||
response = self._get_request(next_url, headers)
|
||||
if response.status_code != 200:
|
||||
self._handle_error(response, "fetch next crawl page")
|
||||
current_page = response.json()
|
||||
return url_data_list
|
||||
|
||||
def _format_crawl_status_response(
|
||||
self,
|
||||
status: str,
|
||||
|
||||
@ -366,7 +366,7 @@ class WordExtractor(BaseExtractor):
|
||||
paragraph_content = []
|
||||
# State for legacy HYPERLINK fields
|
||||
hyperlink_field_url = None
|
||||
hyperlink_field_text_parts: list = []
|
||||
hyperlink_field_text_parts: list[str] = []
|
||||
is_collecting_field_text = False
|
||||
# Iterate through paragraph elements in document order
|
||||
for child in paragraph._element:
|
||||
|
||||
@ -9,6 +9,7 @@ from flask import current_app
|
||||
from sqlalchemy import delete, func, select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
|
||||
from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
|
||||
from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview
|
||||
@ -159,7 +160,7 @@ class IndexProcessor:
|
||||
tenant_id = dataset.tenant_id
|
||||
|
||||
preview_output = self.format_preview(chunk_structure, chunks)
|
||||
if indexing_technique != "high_quality":
|
||||
if indexing_technique != IndexTechniqueType.HIGH_QUALITY:
|
||||
return preview_output
|
||||
|
||||
if not summary_index_setting or not summary_index_setting.get("enable"):
|
||||
|
||||
@ -22,7 +22,7 @@ from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
|
||||
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
@ -117,7 +117,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
with_keywords: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
vector = Vector(dataset)
|
||||
vector.create(documents)
|
||||
if multimodal_documents and dataset.is_multimodal:
|
||||
@ -155,7 +155,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
# Delete all summaries for the dataset
|
||||
SummaryIndexService.delete_summaries_for_segments(dataset, None)
|
||||
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
vector = Vector(dataset)
|
||||
if node_ids:
|
||||
vector.delete_by_ids(node_ids)
|
||||
@ -253,12 +253,12 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
|
||||
# add document segments
|
||||
doc_store.add_documents(docs=documents, save_child=False)
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
vector = Vector(dataset)
|
||||
vector.create(documents)
|
||||
if all_multimodal_documents and dataset.is_multimodal:
|
||||
vector.create_multimodal(all_multimodal_documents)
|
||||
elif dataset.indexing_technique == "economy":
|
||||
elif dataset.indexing_technique == IndexTechniqueType.ECONOMY:
|
||||
keyword = Keyword(dataset)
|
||||
keyword.add_texts(documents)
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@ from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
|
||||
from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
@ -128,7 +128,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
with_keywords: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
vector = Vector(dataset)
|
||||
for document in documents:
|
||||
child_documents = document.children
|
||||
@ -166,7 +166,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
# Delete all summaries for the dataset
|
||||
SummaryIndexService.delete_summaries_for_segments(dataset, None)
|
||||
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
delete_child_chunks = kwargs.get("delete_child_chunks") or False
|
||||
precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids")
|
||||
vector = Vector(dataset)
|
||||
@ -332,7 +332,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
|
||||
# add document segments
|
||||
doc_store.add_documents(docs=documents, save_child=True)
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
all_child_documents = []
|
||||
all_multimodal_documents = []
|
||||
for doc in documents:
|
||||
|
||||
@ -21,7 +21,7 @@ from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict
|
||||
from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
@ -141,7 +141,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
with_keywords: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
vector = Vector(dataset)
|
||||
vector.create(documents)
|
||||
if multimodal_documents and dataset.is_multimodal:
|
||||
@ -224,7 +224,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
# save node to document segment
|
||||
doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id)
|
||||
doc_store.add_documents(docs=documents, save_child=False)
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
vector = Vector(dataset)
|
||||
vector.create(documents)
|
||||
else:
|
||||
|
||||
@ -591,7 +591,7 @@ class DatasetRetrieval:
|
||||
user_id: str,
|
||||
user_from: str,
|
||||
query: str,
|
||||
available_datasets: list,
|
||||
available_datasets: list[Dataset],
|
||||
model_instance: ModelInstance,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
planning_strategy: PlanningStrategy,
|
||||
@ -633,15 +633,15 @@ class DatasetRetrieval:
|
||||
if dataset_id:
|
||||
# get retrieval model config
|
||||
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(dataset_stmt)
|
||||
if dataset:
|
||||
selected_dataset = db.session.scalar(dataset_stmt)
|
||||
if selected_dataset:
|
||||
results = []
|
||||
if dataset.provider == "external":
|
||||
if selected_dataset.provider == "external":
|
||||
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
tenant_id=dataset.tenant_id,
|
||||
tenant_id=selected_dataset.tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
query=query,
|
||||
external_retrieval_parameters=dataset.retrieval_model,
|
||||
external_retrieval_parameters=selected_dataset.retrieval_model,
|
||||
metadata_condition=metadata_condition,
|
||||
)
|
||||
for external_document in external_documents:
|
||||
@ -654,28 +654,28 @@ class DatasetRetrieval:
|
||||
document.metadata["score"] = external_document.get("score")
|
||||
document.metadata["title"] = external_document.get("title")
|
||||
document.metadata["dataset_id"] = dataset_id
|
||||
document.metadata["dataset_name"] = dataset.name
|
||||
document.metadata["dataset_name"] = selected_dataset.name
|
||||
results.append(document)
|
||||
else:
|
||||
if metadata_condition and not metadata_filter_document_ids:
|
||||
return []
|
||||
document_ids_filter = None
|
||||
if metadata_filter_document_ids:
|
||||
document_ids = metadata_filter_document_ids.get(dataset.id, [])
|
||||
document_ids = metadata_filter_document_ids.get(selected_dataset.id, [])
|
||||
if document_ids:
|
||||
document_ids_filter = document_ids
|
||||
else:
|
||||
return []
|
||||
retrieval_model_config: DefaultRetrievalModelDict = (
|
||||
cast(DefaultRetrievalModelDict, dataset.retrieval_model)
|
||||
if dataset.retrieval_model
|
||||
cast(DefaultRetrievalModelDict, selected_dataset.retrieval_model)
|
||||
if selected_dataset.retrieval_model
|
||||
else default_retrieval_model
|
||||
)
|
||||
|
||||
# get top k
|
||||
top_k = retrieval_model_config["top_k"]
|
||||
# get retrieval method
|
||||
if dataset.indexing_technique == "economy":
|
||||
if selected_dataset.indexing_technique == IndexTechniqueType.ECONOMY:
|
||||
retrieval_method = RetrievalMethod.KEYWORD_SEARCH
|
||||
else:
|
||||
retrieval_method = retrieval_model_config["search_method"]
|
||||
@ -694,7 +694,7 @@ class DatasetRetrieval:
|
||||
with measure_time() as timer:
|
||||
results = RetrievalService.retrieve(
|
||||
retrieval_method=retrieval_method,
|
||||
dataset_id=dataset.id,
|
||||
dataset_id=selected_dataset.id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
@ -726,7 +726,7 @@ class DatasetRetrieval:
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_from: str,
|
||||
available_datasets: list,
|
||||
available_datasets: list[Dataset],
|
||||
query: str | None,
|
||||
top_k: int,
|
||||
score_threshold: float,
|
||||
@ -752,7 +752,7 @@ class DatasetRetrieval:
|
||||
"The configured knowledge base list have different indexing technique, please set reranking model."
|
||||
)
|
||||
index_type = available_datasets[0].indexing_technique
|
||||
if index_type == "high_quality":
|
||||
if index_type == IndexTechniqueType.HIGH_QUALITY:
|
||||
embedding_model_check = all(
|
||||
item.embedding_model == available_datasets[0].embedding_model for item in available_datasets
|
||||
)
|
||||
@ -1028,7 +1028,7 @@ class DatasetRetrieval:
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
top_k: int,
|
||||
all_documents: list,
|
||||
all_documents: list[Document],
|
||||
document_ids_filter: list[str] | None = None,
|
||||
metadata_condition: MetadataCondition | None = None,
|
||||
attachment_ids: list[str] | None = None,
|
||||
@ -1068,7 +1068,7 @@ class DatasetRetrieval:
|
||||
else default_retrieval_model
|
||||
)
|
||||
|
||||
if dataset.indexing_technique == "economy":
|
||||
if dataset.indexing_technique == IndexTechniqueType.ECONOMY:
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
|
||||
@ -1298,7 +1298,7 @@ class DatasetRetrieval:
|
||||
|
||||
def get_metadata_filter_condition(
|
||||
self,
|
||||
dataset_ids: list,
|
||||
dataset_ids: list[str],
|
||||
query: str,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
@ -1400,7 +1400,7 @@ class DatasetRetrieval:
|
||||
return output
|
||||
|
||||
def _automatic_metadata_filter_func(
|
||||
self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
|
||||
self, dataset_ids: list[str], query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
|
||||
) -> list[dict[str, Any]] | None:
|
||||
# get all metadata field
|
||||
metadata_stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
|
||||
@ -1598,7 +1598,7 @@ class DatasetRetrieval:
|
||||
)
|
||||
|
||||
def _get_prompt_template(
|
||||
self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list, query: str
|
||||
self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list[str], query: str
|
||||
):
|
||||
model_mode = ModelMode(mode)
|
||||
input_text = query
|
||||
@ -1690,7 +1690,7 @@ class DatasetRetrieval:
|
||||
def _multiple_retrieve_thread(
|
||||
self,
|
||||
flask_app: Flask,
|
||||
available_datasets: list,
|
||||
available_datasets: list[Dataset],
|
||||
metadata_condition: MetadataCondition | None,
|
||||
metadata_filter_document_ids: dict[str, list[str]] | None,
|
||||
all_documents: list[Document],
|
||||
|
||||
@ -2,6 +2,7 @@ import concurrent.futures
|
||||
import logging
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
|
||||
from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
@ -21,7 +22,7 @@ class SummaryIndex:
|
||||
if is_preview:
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
|
||||
if not dataset or dataset.indexing_technique != "high_quality":
|
||||
if not dataset or dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
|
||||
return
|
||||
|
||||
if summary_index_setting is None:
|
||||
|
||||
@ -50,7 +50,7 @@ class BuiltinTool(Tool):
|
||||
return ModelInvocationUtils.invoke(
|
||||
user_id=user_id,
|
||||
tenant_id=self.runtime.tenant_id or "",
|
||||
tool_type="builtin",
|
||||
tool_type=ToolProviderType.BUILT_IN,
|
||||
tool_name=self.entity.identity.name,
|
||||
prompt_messages=prompt_messages,
|
||||
)
|
||||
|
||||
@ -34,7 +34,7 @@ from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from dify_graph.file import FileType
|
||||
from dify_graph.file.models import FileTransferMethod
|
||||
from extensions.ext_database import db
|
||||
from models.enums import CreatorUserRole
|
||||
from models.enums import CreatorUserRole, MessageFileBelongsTo
|
||||
from models.model import Message, MessageFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -352,7 +352,7 @@ class ToolEngine:
|
||||
message_id=agent_message.id,
|
||||
type=file_type,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
belongs_to="assistant",
|
||||
belongs_to=MessageFileBelongsTo.ASSISTANT,
|
||||
url=message.url,
|
||||
upload_file_id=tool_file_id,
|
||||
created_by_role=(
|
||||
|
||||
@ -38,7 +38,7 @@ class ToolLabelManager:
|
||||
db.session.add(
|
||||
ToolLabelBinding(
|
||||
tool_id=provider_id,
|
||||
tool_type=controller.provider_type.value,
|
||||
tool_type=controller.provider_type,
|
||||
label_name=label,
|
||||
)
|
||||
)
|
||||
@ -58,7 +58,7 @@ class ToolLabelManager:
|
||||
raise ValueError("Unsupported tool type")
|
||||
stmt = select(ToolLabelBinding.label_name).where(
|
||||
ToolLabelBinding.tool_id == provider_id,
|
||||
ToolLabelBinding.tool_type == controller.provider_type.value,
|
||||
ToolLabelBinding.tool_type == controller.provider_type,
|
||||
)
|
||||
labels = db.session.scalars(stmt).all()
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa
|
||||
from core.model_manager import ModelManager
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from core.rag.models.document import Document as RagDocument
|
||||
from core.rag.rerank.rerank_model import RerankModelRunner
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
@ -169,7 +170,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
retrieval_model = dataset.retrieval_model or default_retrieval_model
|
||||
|
||||
if dataset.indexing_technique == "economy":
|
||||
if dataset.indexing_technique == IndexTechniqueType.ECONOMY:
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
|
||||
|
||||
@ -8,6 +8,7 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict,
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from core.rag.models.document import Document as RetrievalDocument
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
@ -140,7 +141,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
retrieval_model = dataset.retrieval_model or default_retrieval_model
|
||||
retrieval_resource_list: list[RetrievalSourceMetadata] = []
|
||||
if dataset.indexing_technique == "economy":
|
||||
if dataset.indexing_technique == IndexTechniqueType.ECONOMY:
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
|
||||
@ -173,7 +174,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.on_tool_end(documents)
|
||||
document_score_list = {}
|
||||
if dataset.indexing_technique != "economy":
|
||||
if dataset.indexing_technique != IndexTechniqueType.ECONOMY:
|
||||
for item in documents:
|
||||
if item.metadata is not None and item.metadata.get("score"):
|
||||
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
|
||||
|
||||
@ -9,6 +9,7 @@ from decimal import Decimal
|
||||
from typing import cast
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessage
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
@ -78,7 +79,7 @@ class ModelInvocationUtils:
|
||||
|
||||
@staticmethod
|
||||
def invoke(
|
||||
user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage]
|
||||
user_id: str, tenant_id: str, tool_type: ToolProviderType, tool_name: str, prompt_messages: list[PromptMessage]
|
||||
) -> LLMResult:
|
||||
"""
|
||||
invoke model with parameters in user's own context
|
||||
|
||||
@ -101,6 +101,9 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
timeout=self._get_request_timeout(self.node_data),
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
http_request_config=self._http_request_config,
|
||||
# Must be 0 to disable executor-level retries, as the graph engine handles them.
|
||||
# This is critical to prevent nested retries.
|
||||
max_retries=0,
|
||||
ssl_verify=self.node_data.ssl_verify,
|
||||
http_client=self._http_client,
|
||||
file_manager=self._file_manager,
|
||||
|
||||
@ -1,6 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
@ -36,6 +39,11 @@ from .exc import (
|
||||
)
|
||||
from .protocols import TemplateRenderer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VARIABLE_PATTERN = re.compile(r"\{\{#[^#]+#\}\}")
|
||||
MAX_RESOLVED_VALUE_LENGTH = 1024
|
||||
|
||||
|
||||
def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity:
|
||||
model_schema = cast(LargeLanguageModel, model_instance.model_type_instance).get_model_schema(
|
||||
@ -475,3 +483,61 @@ def _append_file_prompts(
|
||||
prompt_messages[-1] = UserPromptMessage(content=file_prompts + existing_contents)
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=file_prompts))
|
||||
|
||||
|
||||
def _coerce_resolved_value(raw: str) -> int | float | bool | str:
|
||||
"""Try to restore the original type from a resolved template string.
|
||||
|
||||
Variable references are always resolved to text, but completion params may
|
||||
expect numeric or boolean values (e.g. a variable that holds "0.7" mapped to
|
||||
the ``temperature`` parameter). This helper attempts a JSON parse so that
|
||||
``"0.7"`` → ``0.7``, ``"true"`` → ``True``, etc. Plain strings that are not
|
||||
valid JSON literals are returned as-is.
|
||||
"""
|
||||
stripped = raw.strip()
|
||||
if not stripped:
|
||||
return raw
|
||||
|
||||
try:
|
||||
parsed: object = json.loads(stripped)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
return raw
|
||||
|
||||
if isinstance(parsed, (int, float, bool)):
|
||||
return parsed
|
||||
return raw
|
||||
|
||||
|
||||
def resolve_completion_params_variables(
|
||||
completion_params: Mapping[str, Any],
|
||||
variable_pool: VariablePool,
|
||||
) -> dict[str, Any]:
|
||||
"""Resolve variable references (``{{#node_id.var#}}``) in string-typed completion params.
|
||||
|
||||
Security notes:
|
||||
- Resolved values are length-capped to ``MAX_RESOLVED_VALUE_LENGTH`` to
|
||||
prevent denial-of-service through excessively large variable payloads.
|
||||
- This follows the same ``VariablePool.convert_template`` pattern used across
|
||||
Dify (Answer Node, HTTP Request Node, Agent Node, etc.). The downstream
|
||||
model plugin receives these values as structured JSON key-value pairs — they
|
||||
are never concatenated into raw HTTP headers or SQL queries.
|
||||
- Numeric/boolean coercion is applied so that variables holding ``"0.7"`` are
|
||||
restored to their native type rather than sent as a bare string.
|
||||
"""
|
||||
resolved: dict[str, Any] = {}
|
||||
for key, value in completion_params.items():
|
||||
if isinstance(value, str) and VARIABLE_PATTERN.search(value):
|
||||
segment_group = variable_pool.convert_template(value)
|
||||
text = segment_group.text
|
||||
if len(text) > MAX_RESOLVED_VALUE_LENGTH:
|
||||
logger.warning(
|
||||
"Resolved value for param '%s' truncated from %d to %d chars",
|
||||
key,
|
||||
len(text),
|
||||
MAX_RESOLVED_VALUE_LENGTH,
|
||||
)
|
||||
text = text[:MAX_RESOLVED_VALUE_LENGTH]
|
||||
resolved[key] = _coerce_resolved_value(text)
|
||||
else:
|
||||
resolved[key] = value
|
||||
return resolved
|
||||
|
||||
@ -202,6 +202,10 @@ class LLMNode(Node[LLMNodeData]):
|
||||
|
||||
# fetch model config
|
||||
model_instance = self._model_instance
|
||||
# Resolve variable references in string-typed completion params
|
||||
model_instance.parameters = llm_utils.resolve_completion_params_variables(
|
||||
model_instance.parameters, variable_pool
|
||||
)
|
||||
model_name = model_instance.model_name
|
||||
model_provider = model_instance.provider
|
||||
model_stop = model_instance.stop
|
||||
|
||||
@ -164,6 +164,10 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
)
|
||||
|
||||
model_instance = self._model_instance
|
||||
# Resolve variable references in string-typed completion params
|
||||
model_instance.parameters = llm_utils.resolve_completion_params_variables(
|
||||
model_instance.parameters, variable_pool
|
||||
)
|
||||
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
|
||||
raise InvalidModelTypeError("Model is not a Large Language Model")
|
||||
|
||||
|
||||
@ -114,6 +114,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
variables = {"query": query}
|
||||
# fetch model instance
|
||||
model_instance = self._model_instance
|
||||
# Resolve variable references in string-typed completion params
|
||||
model_instance.parameters = llm_utils.resolve_completion_params_variables(
|
||||
model_instance.parameters, variable_pool
|
||||
)
|
||||
memory = self._memory
|
||||
# fetch instruction
|
||||
node_data.instruction = node_data.instruction or ""
|
||||
|
||||
@ -125,7 +125,8 @@ class BroadcastChannel(Protocol):
|
||||
a specific topic, all subscription should receive the published message.
|
||||
|
||||
There are no restriction for the persistence of messages. Once a subscription is created, it
|
||||
should receive all subsequent messages published.
|
||||
should receive all subsequent messages published. However, a subscription should not receive
|
||||
any message published before the subscription is established.
|
||||
|
||||
`BroadcastChannel` implementations must be thread-safe and support concurrent use by multiple threads.
|
||||
"""
|
||||
|
||||
@ -63,21 +63,45 @@ class _StreamsSubscription(Subscription):
|
||||
def __init__(self, client: Redis | RedisCluster, key: str):
|
||||
self._client = client
|
||||
self._key = key
|
||||
self._closed = threading.Event()
|
||||
self._last_id = "0-0"
|
||||
|
||||
self._queue: queue.Queue[object] = queue.Queue()
|
||||
self._start_lock = threading.Lock()
|
||||
|
||||
# The `_lock` lock is used to
|
||||
#
|
||||
# 1. protect the _listener attribute
|
||||
# 2. prevent repeated releases of underlying resoueces. (The _closed flag.)
|
||||
#
|
||||
# INVARIANT: the implementation must hold the lock while
|
||||
# reading and writing the _listener / `_closed` attribute.
|
||||
self._lock = threading.Lock()
|
||||
self._closed: bool = False
|
||||
# self._closed = threading.Event()
|
||||
self._listener: threading.Thread | None = None
|
||||
|
||||
def _listen(self) -> None:
|
||||
try:
|
||||
while not self._closed.is_set():
|
||||
streams = self._client.xread({self._key: self._last_id}, block=1000, count=100)
|
||||
"""The `_listen` method handles the message retrieval loop. It requires a dedicated thread
|
||||
and is not intended for direct invocation.
|
||||
|
||||
The thread is started by `_start_if_needed`.
|
||||
"""
|
||||
|
||||
# since this method runs in a dedicated thread, acquiring `_lock` inside this method won't cause
|
||||
# deadlock.
|
||||
|
||||
# Setting initial last id to `$` to signal redis that we only want new messages.
|
||||
#
|
||||
# ref: https://redis.io/docs/latest/commands/xread/#the-special--id
|
||||
last_id = "$"
|
||||
try:
|
||||
while True:
|
||||
with self._lock:
|
||||
if self._closed:
|
||||
break
|
||||
streams = self._client.xread({self._key: last_id}, block=1000, count=100)
|
||||
if not streams:
|
||||
continue
|
||||
|
||||
for _key, entries in streams:
|
||||
for _, entries in streams:
|
||||
for entry_id, fields in entries:
|
||||
data = None
|
||||
if isinstance(fields, dict):
|
||||
@ -89,37 +113,48 @@ class _StreamsSubscription(Subscription):
|
||||
data_bytes = bytes(data)
|
||||
if data_bytes is not None:
|
||||
self._queue.put_nowait(data_bytes)
|
||||
self._last_id = entry_id
|
||||
last_id = entry_id
|
||||
finally:
|
||||
self._queue.put_nowait(self._SENTINEL)
|
||||
self._listener = None
|
||||
with self._lock:
|
||||
self._listener = None
|
||||
self._closed = True
|
||||
|
||||
def _start_if_needed(self) -> None:
|
||||
"""This method must be called with `_lock` held."""
|
||||
if self._listener is not None:
|
||||
return
|
||||
# Ensure only one listener thread is created under concurrent calls
|
||||
with self._start_lock:
|
||||
if self._listener is not None or self._closed.is_set():
|
||||
return
|
||||
self._listener = threading.Thread(
|
||||
target=self._listen,
|
||||
name=f"redis-streams-sub-{self._key}",
|
||||
daemon=True,
|
||||
)
|
||||
self._listener.start()
|
||||
if self._listener is not None or self._closed:
|
||||
return
|
||||
self._listener = threading.Thread(
|
||||
target=self._listen,
|
||||
name=f"redis-streams-sub-{self._key}",
|
||||
daemon=True,
|
||||
)
|
||||
self._listener.start()
|
||||
|
||||
def __iter__(self) -> Iterator[bytes]:
|
||||
# Iterator delegates to receive with timeout; stops on closure.
|
||||
self._start_if_needed()
|
||||
while not self._closed.is_set():
|
||||
item = self.receive(timeout=1)
|
||||
with self._lock:
|
||||
self._start_if_needed()
|
||||
|
||||
while True:
|
||||
with self._lock:
|
||||
if self._closed:
|
||||
return
|
||||
try:
|
||||
item = self.receive(timeout=1)
|
||||
except SubscriptionClosedError:
|
||||
return
|
||||
if item is not None:
|
||||
yield item
|
||||
|
||||
def receive(self, timeout: float | None = 0.1) -> bytes | None:
|
||||
if self._closed.is_set():
|
||||
raise SubscriptionClosedError("The Redis streams subscription is closed")
|
||||
self._start_if_needed()
|
||||
with self._lock:
|
||||
if self._closed:
|
||||
raise SubscriptionClosedError("The Redis streams subscription is closed")
|
||||
self._start_if_needed()
|
||||
|
||||
try:
|
||||
if timeout is None:
|
||||
@ -129,29 +164,33 @@ class _StreamsSubscription(Subscription):
|
||||
except queue.Empty:
|
||||
return None
|
||||
|
||||
if item is self._SENTINEL or self._closed.is_set():
|
||||
if item is self._SENTINEL:
|
||||
raise SubscriptionClosedError("The Redis streams subscription is closed")
|
||||
assert isinstance(item, (bytes, bytearray)), "Unexpected item type in stream queue"
|
||||
return bytes(item)
|
||||
|
||||
def close(self) -> None:
|
||||
if self._closed.is_set():
|
||||
return
|
||||
self._closed.set()
|
||||
listener = self._listener
|
||||
if listener is not None:
|
||||
with self._lock:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
listener = self._listener
|
||||
if listener is not None:
|
||||
self._listener = None
|
||||
# We close the listener outside of the with block to avoid holding the
|
||||
# lock for a long time.
|
||||
if listener is not None and listener.is_alive():
|
||||
listener.join(timeout=2.0)
|
||||
if listener.is_alive():
|
||||
logger.warning(
|
||||
"Streams subscription listener for key %s did not stop within timeout; keeping reference.",
|
||||
self._key,
|
||||
)
|
||||
else:
|
||||
self._listener = None
|
||||
|
||||
# Context manager helpers
|
||||
def __enter__(self) -> Self:
|
||||
self._start_if_needed()
|
||||
with self._lock:
|
||||
self._start_if_needed()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback) -> bool | None:
|
||||
|
||||
@ -18,15 +18,23 @@ if TYPE_CHECKING:
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
def _resolve_current_user() -> EndUser | Account | None:
|
||||
"""
|
||||
Resolve the current user proxy to its underlying user object.
|
||||
This keeps unit tests working when they patch `current_user` directly
|
||||
instead of bootstrapping a full Flask-Login manager.
|
||||
"""
|
||||
user_proxy = current_user
|
||||
get_current_object = getattr(user_proxy, "_get_current_object", None)
|
||||
return get_current_object() if callable(get_current_object) else user_proxy # type: ignore
|
||||
|
||||
|
||||
def current_account_with_tenant():
|
||||
"""
|
||||
Resolve the underlying account for the current user proxy and ensure tenant context exists.
|
||||
Allows tests to supply plain Account mocks without the LocalProxy helper.
|
||||
"""
|
||||
user_proxy = current_user
|
||||
|
||||
get_current_object = getattr(user_proxy, "_get_current_object", None)
|
||||
user = get_current_object() if callable(get_current_object) else user_proxy # type: ignore
|
||||
user = _resolve_current_user()
|
||||
|
||||
if not isinstance(user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
@ -79,9 +87,10 @@ def login_required(func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]
|
||||
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
|
||||
return current_app.ensure_sync(func)(*args, **kwargs)
|
||||
|
||||
user = _get_user()
|
||||
user = _resolve_current_user()
|
||||
if user is None or not user.is_authenticated:
|
||||
return current_app.login_manager.unauthorized() # type: ignore
|
||||
g._login_user = user
|
||||
# we put csrf validation here for less conflicts
|
||||
# TODO: maybe find a better place for it.
|
||||
check_csrf_token(request, user.id)
|
||||
|
||||
@ -1,16 +1,19 @@
|
||||
import logging
|
||||
import sys
|
||||
import urllib.parse
|
||||
from dataclasses import dataclass
|
||||
from typing import NotRequired
|
||||
|
||||
import httpx
|
||||
from pydantic import TypeAdapter
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import TypedDict
|
||||
else:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
JsonObject = dict[str, object]
|
||||
JsonObjectList = list[JsonObject]
|
||||
|
||||
@ -30,8 +33,8 @@ class GitHubEmailRecord(TypedDict, total=False):
|
||||
class GitHubRawUserInfo(TypedDict):
|
||||
id: int | str
|
||||
login: str
|
||||
name: NotRequired[str]
|
||||
email: NotRequired[str]
|
||||
name: NotRequired[str | None]
|
||||
email: NotRequired[str | None]
|
||||
|
||||
|
||||
class GoogleRawUserInfo(TypedDict):
|
||||
@ -127,9 +130,14 @@ class GitHubOAuth(OAuth):
|
||||
response.raise_for_status()
|
||||
user_info = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(_json_object(response))
|
||||
|
||||
email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
|
||||
email_info = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response))
|
||||
primary_email = next((email for email in email_info if email.get("primary") is True), None)
|
||||
try:
|
||||
email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
|
||||
email_response.raise_for_status()
|
||||
email_info = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response))
|
||||
primary_email = next((email for email in email_info if email.get("primary") is True), None)
|
||||
except (httpx.HTTPStatusError, ValidationError):
|
||||
logger.warning("Failed to retrieve email from GitHub /user/emails endpoint", exc_info=True)
|
||||
primary_email = None
|
||||
|
||||
return {**user_info, "email": primary_email.get("email", "") if primary_email else ""}
|
||||
|
||||
@ -137,8 +145,11 @@ class GitHubOAuth(OAuth):
|
||||
payload = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(raw_info)
|
||||
email = payload.get("email")
|
||||
if not email:
|
||||
email = f"{payload['id']}+{payload['login']}@users.noreply.github.com"
|
||||
return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name", "")), email=email)
|
||||
raise ValueError(
|
||||
'Dify currently not supports the "Keep my email addresses private" feature,'
|
||||
" please disable it and login again"
|
||||
)
|
||||
return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name") or ""), email=email)
|
||||
|
||||
|
||||
class GoogleOAuth(OAuth):
|
||||
|
||||
@ -20,7 +20,7 @@ from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||
from core.rag.index_processor.constant.query_type import QueryType
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.signature import sign_upload_file
|
||||
@ -43,7 +43,9 @@ from .enums import (
|
||||
IndexingStatus,
|
||||
ProcessRuleMode,
|
||||
SegmentStatus,
|
||||
SegmentType,
|
||||
SummaryStatus,
|
||||
TidbAuthBindingStatus,
|
||||
)
|
||||
from .model import App, Tag, TagBinding, UploadFile
|
||||
from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index
|
||||
@ -135,7 +137,7 @@ class Dataset(Base):
|
||||
default=DatasetPermissionEnum.ONLY_ME,
|
||||
)
|
||||
data_source_type = mapped_column(EnumText(DataSourceType, length=255))
|
||||
indexing_technique: Mapped[str | None] = mapped_column(String(255))
|
||||
indexing_technique: Mapped[IndexTechniqueType | None] = mapped_column(EnumText(IndexTechniqueType, length=255))
|
||||
index_struct = mapped_column(LongText, nullable=True)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
@ -494,7 +496,9 @@ class Document(Base):
|
||||
)
|
||||
doc_type = mapped_column(EnumText(DocumentDocType, length=40), nullable=True)
|
||||
doc_metadata = mapped_column(AdjustedJSON, nullable=True)
|
||||
doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'"))
|
||||
doc_form: Mapped[IndexStructureType] = mapped_column(
|
||||
EnumText(IndexStructureType, length=255), nullable=False, server_default=sa.text("'text_model'")
|
||||
)
|
||||
doc_language = mapped_column(String(255), nullable=True)
|
||||
need_summary: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
|
||||
@ -998,7 +1002,9 @@ class ChildChunk(Base):
|
||||
# indexing fields
|
||||
index_node_id = mapped_column(String(255), nullable=True)
|
||||
index_node_hash = mapped_column(String(255), nullable=True)
|
||||
type = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'"))
|
||||
type: Mapped[SegmentType] = mapped_column(
|
||||
EnumText(SegmentType, length=255), nullable=False, server_default=sa.text("'automatic'")
|
||||
)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
@ -1239,7 +1245,9 @@ class TidbAuthBinding(TypeBase):
|
||||
cluster_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
cluster_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
status: Mapped[str] = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'CREATING'"))
|
||||
status: Mapped[TidbAuthBindingStatus] = mapped_column(
|
||||
EnumText(TidbAuthBindingStatus, length=255), nullable=False, server_default=sa.text("'CREATING'")
|
||||
)
|
||||
account: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
password: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
|
||||
@ -158,6 +158,13 @@ class FeedbackFromSource(StrEnum):
|
||||
ADMIN = "admin"
|
||||
|
||||
|
||||
class FeedbackRating(StrEnum):
|
||||
"""MessageFeedback rating"""
|
||||
|
||||
LIKE = "like"
|
||||
DISLIKE = "dislike"
|
||||
|
||||
|
||||
class InvokeFrom(StrEnum):
|
||||
"""How a conversation/message was invoked"""
|
||||
|
||||
@ -215,6 +222,13 @@ class DatasetMetadataType(StrEnum):
|
||||
TIME = "time"
|
||||
|
||||
|
||||
class SegmentType(StrEnum):
|
||||
"""Document segment type"""
|
||||
|
||||
AUTOMATIC = "automatic"
|
||||
CUSTOMIZED = "customized"
|
||||
|
||||
|
||||
class SegmentStatus(StrEnum):
|
||||
"""Document segment status"""
|
||||
|
||||
@ -316,3 +330,10 @@ class ProviderQuotaType(StrEnum):
|
||||
if member.value == value:
|
||||
return member
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
|
||||
class ApiTokenType(StrEnum):
|
||||
"""API Token type"""
|
||||
|
||||
APP = "app"
|
||||
DATASET = "dataset"
|
||||
|
||||
@ -66,8 +66,8 @@ class HumanInputContent(ExecutionExtraContent):
|
||||
form_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
|
||||
|
||||
@classmethod
|
||||
def new(cls, form_id: str, message_id: str | None) -> "HumanInputContent":
|
||||
return cls(form_id=form_id, message_id=message_id)
|
||||
def new(cls, *, workflow_run_id: str, form_id: str, message_id: str | None) -> "HumanInputContent":
|
||||
return cls(workflow_run_id=workflow_run_id, form_id=form_id, message_id=message_id)
|
||||
|
||||
form: Mapped["HumanInputForm"] = relationship(
|
||||
"HumanInputForm",
|
||||
|
||||
@ -21,7 +21,7 @@ from configs import dify_config
|
||||
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
||||
from core.tools.signature import sign_tool_file
|
||||
from dify_graph.enums import WorkflowExecutionStatus
|
||||
from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||
from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
|
||||
from dify_graph.file import helpers as file_helpers
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.helper import generate_string # type: ignore[import-not-found]
|
||||
@ -31,13 +31,21 @@ from .account import Account, Tenant
|
||||
from .base import Base, TypeBase, gen_uuidv4_string
|
||||
from .engine import db
|
||||
from .enums import (
|
||||
ApiTokenType,
|
||||
AppMCPServerStatus,
|
||||
AppStatus,
|
||||
BannerStatus,
|
||||
ConversationFromSource,
|
||||
ConversationStatus,
|
||||
CreatorUserRole,
|
||||
FeedbackFromSource,
|
||||
FeedbackRating,
|
||||
InvokeFrom,
|
||||
MessageChainType,
|
||||
MessageFileBelongsTo,
|
||||
MessageStatus,
|
||||
ProviderQuotaType,
|
||||
TagType,
|
||||
)
|
||||
from .provider_ids import GenericProviderID
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
@ -581,7 +589,9 @@ class AppModelConfig(TypeBase):
|
||||
__tablename__ = "app_model_configs"
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="app_model_config_pkey"), sa.Index("app_app_id_idx", "app_id"))
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
|
||||
id: Mapped[str] = mapped_column(
|
||||
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
|
||||
)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
|
||||
model_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
|
||||
@ -930,7 +940,9 @@ class AccountTrialAppRecord(Base):
|
||||
class ExporleBanner(TypeBase):
|
||||
__tablename__ = "exporle_banners"
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),)
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv4_string, init=False)
|
||||
id: Mapped[str] = mapped_column(
|
||||
StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False
|
||||
)
|
||||
content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
|
||||
link: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
sort: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
@ -1019,10 +1031,12 @@ class Conversation(Base):
|
||||
#
|
||||
# Its value corresponds to the members of `InvokeFrom`.
|
||||
# (api/core/app/entities/app_invoke_entities.py)
|
||||
invoke_from = mapped_column(String(255), nullable=True)
|
||||
invoke_from: Mapped[InvokeFrom | None] = mapped_column(EnumText(InvokeFrom, length=255), nullable=True)
|
||||
|
||||
# ref: ConversationSource.
|
||||
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
from_source: Mapped[ConversationFromSource] = mapped_column(
|
||||
EnumText(ConversationFromSource, length=255), nullable=False
|
||||
)
|
||||
from_end_user_id = mapped_column(StringUUID)
|
||||
from_account_id = mapped_column(StringUUID)
|
||||
read_at = mapped_column(sa.DateTime)
|
||||
@ -1165,7 +1179,7 @@ class Conversation(Base):
|
||||
select(func.count(MessageFeedback.id)).where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "user",
|
||||
MessageFeedback.rating == "like",
|
||||
MessageFeedback.rating == FeedbackRating.LIKE,
|
||||
)
|
||||
)
|
||||
or 0
|
||||
@ -1176,7 +1190,7 @@ class Conversation(Base):
|
||||
select(func.count(MessageFeedback.id)).where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "user",
|
||||
MessageFeedback.rating == "dislike",
|
||||
MessageFeedback.rating == FeedbackRating.DISLIKE,
|
||||
)
|
||||
)
|
||||
or 0
|
||||
@ -1191,7 +1205,7 @@ class Conversation(Base):
|
||||
select(func.count(MessageFeedback.id)).where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "admin",
|
||||
MessageFeedback.rating == "like",
|
||||
MessageFeedback.rating == FeedbackRating.LIKE,
|
||||
)
|
||||
)
|
||||
or 0
|
||||
@ -1202,7 +1216,7 @@ class Conversation(Base):
|
||||
select(func.count(MessageFeedback.id)).where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "admin",
|
||||
MessageFeedback.rating == "dislike",
|
||||
MessageFeedback.rating == FeedbackRating.DISLIKE,
|
||||
)
|
||||
)
|
||||
or 0
|
||||
@ -1371,8 +1385,10 @@ class Message(Base):
|
||||
)
|
||||
error: Mapped[str | None] = mapped_column(LongText)
|
||||
message_metadata: Mapped[str | None] = mapped_column(LongText)
|
||||
invoke_from: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
invoke_from: Mapped[InvokeFrom | None] = mapped_column(EnumText(InvokeFrom, length=255), nullable=True)
|
||||
from_source: Mapped[ConversationFromSource] = mapped_column(
|
||||
EnumText(ConversationFromSource, length=255), nullable=False
|
||||
)
|
||||
from_end_user_id: Mapped[str | None] = mapped_column(StringUUID)
|
||||
from_account_id: Mapped[str | None] = mapped_column(StringUUID)
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=func.current_timestamp())
|
||||
@ -1725,8 +1741,8 @@ class MessageFeedback(TypeBase):
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
rating: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
rating: Mapped[FeedbackRating] = mapped_column(EnumText(FeedbackRating, length=255), nullable=False)
|
||||
from_source: Mapped[FeedbackFromSource] = mapped_column(EnumText(FeedbackFromSource, length=255), nullable=False)
|
||||
content: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
from_end_user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
from_account_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
@ -1773,13 +1789,15 @@ class MessageFile(TypeBase):
|
||||
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
|
||||
)
|
||||
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
type: Mapped[FileType] = mapped_column(EnumText(FileType, length=255), nullable=False)
|
||||
transfer_method: Mapped[FileTransferMethod] = mapped_column(
|
||||
EnumText(FileTransferMethod, length=255), nullable=False
|
||||
)
|
||||
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
belongs_to: Mapped[Literal["user", "assistant"] | None] = mapped_column(String(255), nullable=True, default=None)
|
||||
belongs_to: Mapped[MessageFileBelongsTo | None] = mapped_column(
|
||||
EnumText(MessageFileBelongsTo, length=255), nullable=True, default=None
|
||||
)
|
||||
url: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
upload_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
@ -1833,7 +1851,9 @@ class AppAnnotationHitHistory(TypeBase):
|
||||
sa.Index("app_annotation_hit_histories_message_idx", "message_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
|
||||
id: Mapped[str] = mapped_column(
|
||||
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
|
||||
)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
annotation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
source: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||
@ -2083,7 +2103,7 @@ class ApiToken(Base): # bug: this uses setattr so idk the field.
|
||||
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
app_id = mapped_column(StringUUID, nullable=True)
|
||||
tenant_id = mapped_column(StringUUID, nullable=True)
|
||||
type = mapped_column(String(16), nullable=False)
|
||||
type: Mapped[ApiTokenType] = mapped_column(EnumText(ApiTokenType, length=16), nullable=False)
|
||||
token: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
last_used_at = mapped_column(sa.DateTime, nullable=True)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
@ -2393,7 +2413,7 @@ class Tag(TypeBase):
|
||||
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
|
||||
)
|
||||
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
type: Mapped[str] = mapped_column(String(16), nullable=False)
|
||||
type: Mapped[TagType] = mapped_column(EnumText(TagType, length=16), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
@ -2478,7 +2498,9 @@ class TenantCreditPool(TypeBase):
|
||||
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
|
||||
)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
pool_type: Mapped[str] = mapped_column(String(40), nullable=False, default="trial", server_default="trial")
|
||||
pool_type: Mapped[ProviderQuotaType] = mapped_column(
|
||||
EnumText(ProviderQuotaType, length=40), nullable=False, default=ProviderQuotaType.TRIAL, server_default="trial"
|
||||
)
|
||||
quota_limit: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
|
||||
quota_used: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
|
||||
@ -13,12 +13,16 @@ from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderSchemaType,
|
||||
ToolProviderType,
|
||||
WorkflowToolParameterConfiguration,
|
||||
)
|
||||
|
||||
from .base import TypeBase
|
||||
from .engine import db
|
||||
from .model import Account, App, Tenant
|
||||
from .types import LongText, StringUUID
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.entities.mcp_provider import MCPProviderEntity
|
||||
@ -141,7 +145,9 @@ class ApiToolProvider(TypeBase):
|
||||
icon: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
# original schema
|
||||
schema: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||
schema_type_str: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
schema_type_str: Mapped[ApiProviderSchemaType] = mapped_column(
|
||||
EnumText(ApiProviderSchemaType, length=40), nullable=False
|
||||
)
|
||||
# who created this tool
|
||||
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# tenant id
|
||||
@ -208,7 +214,7 @@ class ToolLabelBinding(TypeBase):
|
||||
# tool id
|
||||
tool_id: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
# tool type
|
||||
tool_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
tool_type: Mapped[ToolProviderType] = mapped_column(EnumText(ToolProviderType, length=40), nullable=False)
|
||||
# label name
|
||||
label_name: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
|
||||
@ -386,7 +392,7 @@ class ToolModelInvoke(TypeBase):
|
||||
# provider
|
||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
# type
|
||||
tool_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
tool_type: Mapped[ToolProviderType] = mapped_column(EnumText(ToolProviderType, length=40), nullable=False)
|
||||
# tool name
|
||||
tool_name: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||
# invoke parameters
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
@ -302,26 +303,40 @@ class Workflow(Base): # bug
|
||||
def features(self) -> str:
|
||||
"""
|
||||
Convert old features structure to new features structure.
|
||||
|
||||
This property avoids rewriting the underlying JSON when normalization
|
||||
produces no effective change, to prevent marking the row dirty on read.
|
||||
"""
|
||||
if not self._features:
|
||||
return self._features
|
||||
|
||||
features = json.loads(self._features)
|
||||
if features.get("file_upload", {}).get("image", {}).get("enabled", False):
|
||||
image_enabled = True
|
||||
image_number_limits = int(features["file_upload"]["image"].get("number_limits", DEFAULT_FILE_NUMBER_LIMITS))
|
||||
image_transfer_methods = features["file_upload"]["image"].get(
|
||||
"transfer_methods", ["remote_url", "local_file"]
|
||||
)
|
||||
features["file_upload"]["enabled"] = image_enabled
|
||||
features["file_upload"]["number_limits"] = image_number_limits
|
||||
features["file_upload"]["allowed_file_upload_methods"] = image_transfer_methods
|
||||
features["file_upload"]["allowed_file_types"] = features["file_upload"].get("allowed_file_types", ["image"])
|
||||
features["file_upload"]["allowed_file_extensions"] = features["file_upload"].get(
|
||||
"allowed_file_extensions", []
|
||||
)
|
||||
del features["file_upload"]["image"]
|
||||
self._features = json.dumps(features)
|
||||
# Parse once and deep-copy before normalization to detect in-place changes.
|
||||
original_dict = self._decode_features_payload(self._features)
|
||||
if original_dict is None:
|
||||
return self._features
|
||||
|
||||
# Fast-path: if the legacy file_upload.image.enabled shape is absent, skip
|
||||
# deep-copy and normalization entirely and return the stored JSON.
|
||||
file_upload_payload = original_dict.get("file_upload")
|
||||
if not isinstance(file_upload_payload, dict):
|
||||
return self._features
|
||||
file_upload = cast(dict[str, Any], file_upload_payload)
|
||||
|
||||
image_payload = file_upload.get("image")
|
||||
if not isinstance(image_payload, dict):
|
||||
return self._features
|
||||
image = cast(dict[str, Any], image_payload)
|
||||
if "enabled" not in image:
|
||||
return self._features
|
||||
|
||||
normalized_dict = self._normalize_features_payload(copy.deepcopy(original_dict))
|
||||
|
||||
if normalized_dict == original_dict:
|
||||
# No effective change; return stored JSON unchanged.
|
||||
return self._features
|
||||
|
||||
# Normalization changed the payload: persist the normalized JSON.
|
||||
self._features = json.dumps(normalized_dict)
|
||||
return self._features
|
||||
|
||||
@features.setter
|
||||
@ -332,6 +347,44 @@ class Workflow(Base): # bug
|
||||
def features_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.features) if self.features else {}
|
||||
|
||||
@property
|
||||
def serialized_features(self) -> str:
|
||||
"""Return the stored features JSON without triggering compatibility rewrites."""
|
||||
return self._features
|
||||
|
||||
@property
|
||||
def normalized_features_dict(self) -> dict[str, Any]:
|
||||
"""Decode features with legacy normalization without mutating the model state."""
|
||||
if not self._features:
|
||||
return {}
|
||||
|
||||
features = self._decode_features_payload(self._features)
|
||||
return self._normalize_features_payload(features) if features is not None else {}
|
||||
|
||||
@staticmethod
|
||||
def _decode_features_payload(features: str) -> dict[str, Any] | None:
|
||||
"""Decode workflow features JSON when it contains an object payload."""
|
||||
payload = json.loads(features)
|
||||
return cast(dict[str, Any], payload) if isinstance(payload, dict) else None
|
||||
|
||||
@staticmethod
|
||||
def _normalize_features_payload(features: dict[str, Any]) -> dict[str, Any]:
|
||||
if features.get("file_upload", {}).get("image", {}).get("enabled", False):
|
||||
image_number_limits = int(features["file_upload"]["image"].get("number_limits", DEFAULT_FILE_NUMBER_LIMITS))
|
||||
image_transfer_methods = features["file_upload"]["image"].get(
|
||||
"transfer_methods", ["remote_url", "local_file"]
|
||||
)
|
||||
features["file_upload"]["enabled"] = True
|
||||
features["file_upload"]["number_limits"] = image_number_limits
|
||||
features["file_upload"]["allowed_file_upload_methods"] = image_transfer_methods
|
||||
features["file_upload"]["allowed_file_types"] = features["file_upload"].get("allowed_file_types", ["image"])
|
||||
features["file_upload"]["allowed_file_extensions"] = features["file_upload"].get(
|
||||
"allowed_file_extensions", []
|
||||
)
|
||||
del features["file_upload"]["image"]
|
||||
|
||||
return features
|
||||
|
||||
def walk_nodes(
|
||||
self, specific_node_type: NodeType | None = None
|
||||
) -> Generator[tuple[str, Mapping[str, Any]], None, None]:
|
||||
@ -517,6 +570,31 @@ class Workflow(Base): # bug
|
||||
)
|
||||
self._environment_variables = environment_variables_json
|
||||
|
||||
@staticmethod
|
||||
def normalize_environment_variable_mappings(
|
||||
mappings: Sequence[Mapping[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Convert masked secret placeholders into the draft hidden sentinel.
|
||||
|
||||
Regular draft sync requests should preserve existing secrets without shipping
|
||||
plaintext values back from the client. The dedicated restore endpoint now
|
||||
copies published secrets server-side, so draft sync only needs to normalize
|
||||
the UI mask into `HIDDEN_VALUE`.
|
||||
"""
|
||||
masked_secret_value = encrypter.full_mask_token()
|
||||
normalized_mappings: list[dict[str, Any]] = []
|
||||
|
||||
for mapping in mappings:
|
||||
normalized_mapping = dict(mapping)
|
||||
if (
|
||||
normalized_mapping.get("value_type") == SegmentType.SECRET.value
|
||||
and normalized_mapping.get("value") == masked_secret_value
|
||||
):
|
||||
normalized_mapping["value"] = HIDDEN_VALUE
|
||||
normalized_mappings.append(normalized_mapping)
|
||||
|
||||
return normalized_mappings
|
||||
|
||||
def to_dict(self, *, include_secret: bool = False) -> WorkflowContentDict:
|
||||
environment_variables = list(self.environment_variables)
|
||||
environment_variables = [
|
||||
@ -564,6 +642,12 @@ class Workflow(Base): # bug
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
def copy_serialized_variable_storage_from(self, source_workflow: "Workflow") -> None:
|
||||
"""Copy stored variable JSON directly for same-tenant restore flows."""
|
||||
self._environment_variables = source_workflow._environment_variables
|
||||
self._conversation_variables = source_workflow._conversation_variables
|
||||
self._rag_pipeline_variables = source_workflow._rag_pipeline_variables
|
||||
|
||||
@staticmethod
|
||||
def version_from_datetime(d: datetime) -> str:
|
||||
return str(d)
|
||||
@ -1137,7 +1221,9 @@ class WorkflowAppLog(TypeBase):
|
||||
app_id: Mapped[str] = mapped_column(StringUUID)
|
||||
workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
workflow_run_id: Mapped[str] = mapped_column(StringUUID)
|
||||
created_from: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_from: Mapped[WorkflowAppLogCreatedFrom] = mapped_column(
|
||||
EnumText(WorkflowAppLogCreatedFrom, length=255), nullable=False
|
||||
)
|
||||
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
@ -1217,10 +1303,14 @@ class WorkflowArchiveLog(TypeBase):
|
||||
|
||||
log_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
log_created_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
log_created_from: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
log_created_from: Mapped[WorkflowAppLogCreatedFrom | None] = mapped_column(
|
||||
EnumText(WorkflowAppLogCreatedFrom, length=255), nullable=True
|
||||
)
|
||||
|
||||
run_version: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
run_status: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
run_status: Mapped[WorkflowExecutionStatus] = mapped_column(
|
||||
EnumText(WorkflowExecutionStatus, length=255), nullable=False
|
||||
)
|
||||
run_triggered_from: Mapped[WorkflowRunTriggeredFrom] = mapped_column(
|
||||
EnumText(WorkflowRunTriggeredFrom, length=255), nullable=False
|
||||
)
|
||||
|
||||
@ -8,7 +8,7 @@ dependencies = [
|
||||
"arize-phoenix-otel~=0.15.0",
|
||||
"azure-identity==1.25.3",
|
||||
"beautifulsoup4==4.14.3",
|
||||
"boto3==1.42.68",
|
||||
"boto3==1.42.73",
|
||||
"bs4~=0.0.1",
|
||||
"cachetools~=5.3.0",
|
||||
"celery~=5.6.2",
|
||||
@ -23,7 +23,7 @@ dependencies = [
|
||||
"gevent~=25.9.1",
|
||||
"gmpy2~=2.3.0",
|
||||
"google-api-core>=2.19.1",
|
||||
"google-api-python-client==2.192.0",
|
||||
"google-api-python-client==2.193.0",
|
||||
"google-auth>=2.47.0",
|
||||
"google-auth-httplib2==0.3.0",
|
||||
"google-cloud-aiplatform>=1.123.0",
|
||||
@ -40,7 +40,7 @@ dependencies = [
|
||||
"numpy~=1.26.4",
|
||||
"openpyxl~=3.1.5",
|
||||
"opik~=1.10.37",
|
||||
"litellm==1.82.2", # Pinned to avoid madoka dependency issue
|
||||
"litellm==1.82.6", # Pinned to avoid madoka dependency issue
|
||||
"opentelemetry-api==1.28.0",
|
||||
"opentelemetry-distro==0.49b0",
|
||||
"opentelemetry-exporter-otlp==1.28.0",
|
||||
@ -72,13 +72,14 @@ dependencies = [
|
||||
"pyyaml~=6.0.1",
|
||||
"readabilipy~=0.3.0",
|
||||
"redis[hiredis]~=7.3.0",
|
||||
"resend~=2.23.0",
|
||||
"sentry-sdk[flask]~=2.54.0",
|
||||
"resend~=2.26.0",
|
||||
"sentry-sdk[flask]~=2.55.0",
|
||||
"sqlalchemy~=2.0.29",
|
||||
"starlette==0.52.1",
|
||||
"starlette==1.0.0",
|
||||
"tiktoken~=0.12.0",
|
||||
"transformers~=5.3.0",
|
||||
"unstructured[docx,epub,md,ppt,pptx]~=0.21.5",
|
||||
"pypandoc~=1.13",
|
||||
"yarl~=1.23.0",
|
||||
"webvtt-py~=0.5.1",
|
||||
"sseclient-py~=1.9.0",
|
||||
@ -91,7 +92,7 @@ dependencies = [
|
||||
"apscheduler>=3.11.0",
|
||||
"weave>=0.52.16",
|
||||
"fastopenapi[flask]>=0.7.0",
|
||||
"bleach~=6.2.0",
|
||||
"bleach~=6.3.0",
|
||||
]
|
||||
# Before adding new dependency, consider place it in
|
||||
# alphabet order (a-z) and suitable group.
|
||||
@ -118,7 +119,7 @@ dev = [
|
||||
"ruff~=0.15.5",
|
||||
"pytest~=9.0.2",
|
||||
"pytest-benchmark~=5.2.3",
|
||||
"pytest-cov~=7.0.0",
|
||||
"pytest-cov~=7.1.0",
|
||||
"pytest-env~=1.6.0",
|
||||
"pytest-mock~=3.15.1",
|
||||
"testcontainers~=4.14.1",
|
||||
@ -173,7 +174,7 @@ dev = [
|
||||
"sseclient-py>=1.8.0",
|
||||
"pytest-timeout>=2.4.0",
|
||||
"pytest-xdist>=3.8.0",
|
||||
"pyrefly>=0.55.0",
|
||||
"pyrefly>=0.57.1",
|
||||
]
|
||||
|
||||
############################################################
|
||||
@ -202,7 +203,7 @@ tools = ["cloudscraper~=1.2.71", "nltk~=3.9.1"]
|
||||
# Required by vector store clients
|
||||
############################################################
|
||||
vdb = [
|
||||
"alibabacloud_gpdb20160503~=3.8.0",
|
||||
"alibabacloud_gpdb20160503~=5.1.0",
|
||||
"alibabacloud_tea_openapi~=0.4.3",
|
||||
"chromadb==0.5.20",
|
||||
"clickhouse-connect~=0.14.1",
|
||||
|
||||
@ -8,6 +8,7 @@ from configs import dify_config
|
||||
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import TidbAuthBinding
|
||||
from models.enums import TidbAuthBindingStatus
|
||||
|
||||
|
||||
@app.celery.task(queue="dataset")
|
||||
@ -57,7 +58,7 @@ def create_clusters(batch_size):
|
||||
account=new_cluster["account"],
|
||||
password=new_cluster["password"],
|
||||
active=False,
|
||||
status="CREATING",
|
||||
status=TidbAuthBindingStatus.CREATING,
|
||||
)
|
||||
db.session.add(tidb_auth_binding)
|
||||
db.session.commit()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user