Merge branch 'feat/new-biliing-quota' into deploy/dev

This commit is contained in:
hj24 2026-04-08 15:02:50 +08:00
commit ef7dc9eabb
929 changed files with 23722 additions and 15509 deletions

9
.github/labeler.yml vendored
View File

@ -1,3 +1,10 @@
web: web:
- changed-files: - changed-files:
- any-glob-to-any-file: 'web/**' - any-glob-to-any-file:
- 'web/**'
- 'packages/**'
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- '.npmrc'
- '.nvmrc'

View File

@ -20,4 +20,4 @@
- [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!) - [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change. - [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
- [x] I've updated the documentation accordingly. - [x] I've updated the documentation accordingly.
- [x] I ran `make lint` and `make type-check` (backend) and `cd web && npx lint-staged` (frontend) to appease the lint gods - [x] I ran `make lint` and `make type-check` (backend) and `cd web && pnpm exec vp staged` (frontend) to appease the lint gods

View File

@ -39,9 +39,11 @@ jobs:
with: with:
files: | files: |
web/** web/**
packages/**
package.json package.json
pnpm-lock.yaml pnpm-lock.yaml
pnpm-workspace.yaml pnpm-workspace.yaml
.npmrc
.nvmrc .nvmrc
- name: Check api inputs - name: Check api inputs
if: github.event_name != 'merge_group' if: github.event_name != 'merge_group'

View File

@ -65,7 +65,7 @@ jobs:
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Login to Docker Hub - name: Login to Docker Hub
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0 uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4.1.0
with: with:
username: ${{ env.DOCKERHUB_USER }} username: ${{ env.DOCKERHUB_USER }}
password: ${{ env.DOCKERHUB_TOKEN }} password: ${{ env.DOCKERHUB_TOKEN }}
@ -130,7 +130,7 @@ jobs:
merge-multiple: true merge-multiple: true
- name: Login to Docker Hub - name: Login to Docker Hub
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0 uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4.1.0
with: with:
username: ${{ env.DOCKERHUB_USER }} username: ${{ env.DOCKERHUB_USER }}
password: ${{ env.DOCKERHUB_TOKEN }} password: ${{ env.DOCKERHUB_TOKEN }}

View File

@ -8,9 +8,11 @@ on:
- api/Dockerfile - api/Dockerfile
- web/docker/** - web/docker/**
- web/Dockerfile - web/Dockerfile
- packages/**
- package.json - package.json
- pnpm-lock.yaml - pnpm-lock.yaml
- pnpm-workspace.yaml - pnpm-workspace.yaml
- .npmrc
- .nvmrc - .nvmrc
concurrency: concurrency:

View File

@ -65,9 +65,11 @@ jobs:
- 'docker/volumes/sandbox/conf/**' - 'docker/volumes/sandbox/conf/**'
web: web:
- 'web/**' - 'web/**'
- 'packages/**'
- 'package.json' - 'package.json'
- 'pnpm-lock.yaml' - 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml' - 'pnpm-workspace.yaml'
- '.npmrc'
- '.nvmrc' - '.nvmrc'
- '.github/workflows/web-tests.yml' - '.github/workflows/web-tests.yml'
- '.github/actions/setup-web/**' - '.github/actions/setup-web/**'
@ -77,9 +79,11 @@ jobs:
- 'api/uv.lock' - 'api/uv.lock'
- 'e2e/**' - 'e2e/**'
- 'web/**' - 'web/**'
- 'packages/**'
- 'package.json' - 'package.json'
- 'pnpm-lock.yaml' - 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml' - 'pnpm-workspace.yaml'
- '.npmrc'
- '.nvmrc' - '.nvmrc'
- 'docker/docker-compose.middleware.yaml' - 'docker/docker-compose.middleware.yaml'
- 'docker/middleware.env.example' - 'docker/middleware.env.example'

View File

@ -77,9 +77,11 @@ jobs:
with: with:
files: | files: |
web/** web/**
packages/**
package.json package.json
pnpm-lock.yaml pnpm-lock.yaml
pnpm-workspace.yaml pnpm-workspace.yaml
.npmrc
.nvmrc .nvmrc
.github/workflows/style.yml .github/workflows/style.yml
.github/actions/setup-web/** .github/actions/setup-web/**
@ -149,7 +151,7 @@ jobs:
.editorconfig .editorconfig
- name: Super-linter - name: Super-linter
uses: super-linter/super-linter/slim@61abc07d755095a68f4987d1c2c3d1d64408f1f9 # v8.5.0 uses: super-linter/super-linter/slim@9e863354e3ff62e0727d37183162c4a88873df41 # v8.6.0
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
env: env:
BASH_SEVERITY: warning BASH_SEVERITY: warning

View File

@ -9,6 +9,7 @@ on:
- package.json - package.json
- pnpm-lock.yaml - pnpm-lock.yaml
- pnpm-workspace.yaml - pnpm-workspace.yaml
- .npmrc
concurrency: concurrency:
group: sdk-tests-${{ github.head_ref || github.run_id }} group: sdk-tests-${{ github.head_ref || github.run_id }}

View File

@ -240,7 +240,7 @@ jobs:
- name: Run Claude Code for Translation Sync - name: Run Claude Code for Translation Sync
if: steps.context.outputs.CHANGED_FILES != '' if: steps.context.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@88c168b39e7e64da0286d812b6e9fbebb6708185 # v1.0.82 uses: anthropics/claude-code-action@6e2bd52842c65e914eba5c8badd17560bd26b5de # v1.0.89
with: with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }} github_token: ${{ secrets.GITHUB_TOKEN }}

View File

@ -36,7 +36,7 @@ jobs:
remove_tool_cache: true remove_tool_cache: true
- name: Setup UV and Python - name: Setup UV and Python
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
with: with:
enable-cache: true enable-cache: true
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}

View File

@ -89,30 +89,10 @@ if $web_modified; then
echo "No staged TypeScript changes detected, skipping type-check:tsgo" echo "No staged TypeScript changes detected, skipping type-check:tsgo"
fi fi
echo "Running unit tests check" echo "Running knip"
modified_files=$(git diff --cached --name-only -- utils | grep -v '\.spec\.ts$' || true) if ! pnpm run knip; then
echo "Knip check failed. Please run 'pnpm run knip' to fix the errors."
if [ -n "$modified_files" ]; then exit 1
for file in $modified_files; do
test_file="${file%.*}.spec.ts"
echo "Checking for test file: $test_file"
# check if the test file exists
if [ -f "../$test_file" ]; then
echo "Detected changes in $file, running corresponding unit tests..."
pnpm run test "../$test_file"
if [ $? -ne 0 ]; then
echo "Unit tests failed. Please fix the errors before committing."
exit 1
fi
echo "Unit tests for $file passed."
else
echo "Warning: $file does not have a corresponding test file."
fi
done
echo "All unit tests for modified web/utils files have passed."
fi fi
cd ../ cd ../

18
api/celery_healthcheck.py Normal file
View File

@ -0,0 +1,18 @@
# This module provides a lightweight Celery instance for use in Docker health checks.
# Unlike celery_entrypoint.py, this does NOT import app.py and therefore avoids
# initializing all Flask extensions (DB, Redis, storage, blueprints, etc.).
# Using this module keeps the health check fast and low-cost.
from celery import Celery
from configs import dify_config
from extensions.ext_celery import get_celery_broker_transport_options, get_celery_ssl_options
celery = Celery(broker=dify_config.CELERY_BROKER_URL)
broker_transport_options = get_celery_broker_transport_options()
if broker_transport_options:
celery.conf.update(broker_transport_options=broker_transport_options)
ssl_options = get_celery_ssl_options()
if ssl_options:
celery.conf.update(broker_use_ssl=ssl_options)

View File

@ -1,7 +1,7 @@
import datetime import datetime
import logging import logging
import time import time
from typing import Any from typing import TypedDict
import click import click
import sqlalchemy as sa import sqlalchemy as sa
@ -503,7 +503,19 @@ def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
return [row[0] for row in result] return [row[0] for row in result]
def _count_orphaned_draft_variables() -> dict[str, Any]: class _AppOrphanCounts(TypedDict):
variables: int
files: int
class OrphanedDraftVariableStatsDict(TypedDict):
total_orphaned_variables: int
total_orphaned_files: int
orphaned_app_count: int
orphaned_by_app: dict[str, _AppOrphanCounts]
def _count_orphaned_draft_variables() -> OrphanedDraftVariableStatsDict:
""" """
Count orphaned draft variables by app, including associated file counts. Count orphaned draft variables by app, including associated file counts.
@ -526,7 +538,7 @@ def _count_orphaned_draft_variables() -> dict[str, Any]:
with db.engine.connect() as conn: with db.engine.connect() as conn:
result = conn.execute(sa.text(variables_query)) result = conn.execute(sa.text(variables_query))
orphaned_by_app = {} orphaned_by_app: dict[str, _AppOrphanCounts] = {}
total_files = 0 total_files = 0
for row in result: for row in result:

View File

@ -0,0 +1,63 @@
from typing import Any, Literal
from pydantic import BaseModel, Field, model_validator
from libs.helper import UUIDStrOrEmpty
# --- Conversation schemas ---
class ConversationRenamePayload(BaseModel):
name: str | None = None
auto_generate: bool = False
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
# --- Message schemas ---
class MessageListQuery(BaseModel):
conversation_id: UUIDStrOrEmpty
first_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class MessageFeedbackPayload(BaseModel):
rating: Literal["like", "dislike"] | None = None
content: str | None = None
# --- Saved message schemas ---
class SavedMessageListQuery(BaseModel):
last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class SavedMessageCreatePayload(BaseModel):
message_id: UUIDStrOrEmpty
# --- Workflow schemas ---
class WorkflowRunPayload(BaseModel):
inputs: dict[str, Any]
files: list[dict[str, Any]] | None = None
# --- Audio schemas ---
class TextToAudioPayload(BaseModel):
message_id: str | None = None
voice: str | None = None
text: str | None = None
streaming: bool | None = None

View File

@ -2,6 +2,7 @@ import csv
import io import io
from collections.abc import Callable from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import cast
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
@ -17,7 +18,7 @@ from core.db.session_factory import session_factory
from extensions.ext_database import db from extensions.ext_database import db
from libs.token import extract_access_token from libs.token import extract_access_token
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
from services.billing_service import BillingService from services.billing_service import BillingService, LangContentDict
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@ -328,7 +329,7 @@ class UpsertNotificationApi(Resource):
def post(self): def post(self):
payload = UpsertNotificationPayload.model_validate(console_ns.payload) payload = UpsertNotificationPayload.model_validate(console_ns.payload)
result = BillingService.upsert_notification( result = BillingService.upsert_notification(
contents=[c.model_dump() for c in payload.contents], contents=[cast(LangContentDict, c.model_dump()) for c in payload.contents],
frequency=payload.frequency, frequency=payload.frequency,
status=payload.status, status=payload.status,
notification_id=payload.notification_id, notification_id=payload.notification_id,

View File

@ -7,7 +7,7 @@ from flask import request
from flask_restx import Resource from flask_restx import Resource
from graphon.enums import WorkflowExecutionStatus from graphon.enums import WorkflowExecutionStatus
from graphon.file import helpers as file_helpers from graphon.file import helpers as file_helpers
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest from werkzeug.exceptions import BadRequest
@ -26,9 +26,11 @@ from controllers.console.wraps import (
setup_required, setup_required,
) )
from core.ops.ops_trace_manager import OpsTraceManager from core.ops.ops_trace_manager import OpsTraceManager
from core.rag.entities import PreProcessingRule, Rule, Segmentation
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.trigger.constants import TRIGGER_NODE_TYPES from core.trigger.constants import TRIGGER_NODE_TYPES
from extensions.ext_database import db from extensions.ext_database import db
from fields.base import ResponseModel
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models import App, DatasetPermissionEnum, Workflow from models import App, DatasetPermissionEnum, Workflow
from models.model import IconType from models.model import IconType
@ -41,10 +43,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
NotionIcon, NotionIcon,
NotionInfo, NotionInfo,
NotionPage, NotionPage,
PreProcessingRule,
RerankingModel, RerankingModel,
Rule,
Segmentation,
WebsiteInfo, WebsiteInfo,
WeightKeywordSetting, WeightKeywordSetting,
WeightModel, WeightModel,
@ -155,16 +154,6 @@ class AppTracePayload(BaseModel):
type JSONValue = Any type JSONValue = Any
class ResponseModel(BaseModel):
model_config = ConfigDict(
from_attributes=True,
extra="ignore",
populate_by_name=True,
serialize_by_alias=True,
protected_namespaces=(),
)
def _to_timestamp(value: datetime | int | None) -> int | None: def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime): if isinstance(value, datetime):
return int(value.timestamp()) return int(value.timestamp())

View File

@ -193,7 +193,7 @@ workflow_draft_variable_list_model = console_ns.model(
) )
def _api_prerequisite(f: Callable[..., Any]) -> Callable[..., Any]: def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
"""Common prerequisites for all draft workflow variable APIs. """Common prerequisites for all draft workflow variable APIs.
It ensures the following conditions are satisfied: It ensures the following conditions are satisfied:
@ -210,7 +210,7 @@ def _api_prerequisite(f: Callable[..., Any]) -> Callable[..., Any]:
@edit_permission_required @edit_permission_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@wraps(f) @wraps(f)
def wrapper(*args: Any, **kwargs: Any): def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
return f(*args, **kwargs) return f(*args, **kwargs)
return wrapper return wrapper

View File

@ -66,13 +66,13 @@ class WebhookTriggerApi(Resource):
with sessionmaker(db.engine).begin() as session: with sessionmaker(db.engine).begin() as session:
# Get webhook trigger for this app and node # Get webhook trigger for this app and node
webhook_trigger = ( webhook_trigger = session.scalar(
session.query(WorkflowWebhookTrigger) select(WorkflowWebhookTrigger)
.where( .where(
WorkflowWebhookTrigger.app_id == app_model.id, WorkflowWebhookTrigger.app_id == app_model.id,
WorkflowWebhookTrigger.node_id == node_id, WorkflowWebhookTrigger.node_id == node_id,
) )
.first() .limit(1)
) )
if not webhook_trigger: if not webhook_trigger:

View File

@ -1,6 +1,6 @@
from collections.abc import Callable from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import Any from typing import overload
from sqlalchemy import select from sqlalchemy import select
@ -23,14 +23,30 @@ def _load_app_model_with_trial(app_id: str) -> App | None:
return app_model return app_model
def get_app_model( @overload
view: Callable[..., Any] | None = None, def get_app_model[**P, R](
view: Callable[P, R],
*, *,
mode: AppMode | list[AppMode] | None = None, mode: AppMode | list[AppMode] | None = None,
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]: ) -> Callable[P, R]: ...
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
@overload
def get_app_model[**P, R](
view: None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
def get_app_model[**P, R](
view: Callable[P, R] | None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
@wraps(view_func) @wraps(view_func)
def decorated_view(*args: Any, **kwargs: Any): def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
if not kwargs.get("app_id"): if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters") raise ValueError("missing app_id in path parameters")
@ -68,14 +84,30 @@ def get_app_model(
return decorator(view) return decorator(view)
def get_app_model_with_trial( @overload
view: Callable[..., Any] | None = None, def get_app_model_with_trial[**P, R](
view: Callable[P, R],
*, *,
mode: AppMode | list[AppMode] | None = None, mode: AppMode | list[AppMode] | None = None,
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]: ) -> Callable[P, R]: ...
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
@overload
def get_app_model_with_trial[**P, R](
view: None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
def get_app_model_with_trial[**P, R](
view: Callable[P, R] | None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
@wraps(view_func) @wraps(view_func)
def decorated_view(*args: Any, **kwargs: Any): def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
if not kwargs.get("app_id"): if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters") raise ValueError("missing app_id in path parameters")

View File

@ -3,7 +3,7 @@ import secrets
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
@ -20,35 +20,18 @@ from controllers.console.wraps import email_password_login_enabled, setup_requir
from events.tenant_event import tenant_was_created from events.tenant_event import tenant_was_created
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password from libs.password import hash_password
from services.account_service import AccountService, TenantService from services.account_service import AccountService, TenantService
from services.entities.auth_entities import (
ForgotPasswordCheckPayload,
ForgotPasswordResetPayload,
ForgotPasswordSendPayload,
)
from services.feature_service import FeatureService from services.feature_service import FeatureService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ForgotPasswordSendPayload(BaseModel):
email: EmailStr = Field(...)
language: str | None = Field(default=None)
class ForgotPasswordCheckPayload(BaseModel):
email: EmailStr = Field(...)
code: str = Field(...)
token: str = Field(...)
class ForgotPasswordResetPayload(BaseModel):
token: str = Field(...)
new_password: str = Field(...)
password_confirm: str = Field(...)
@field_validator("new_password", "password_confirm")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
class ForgotPasswordEmailResponse(BaseModel): class ForgotPasswordEmailResponse(BaseModel):
result: str = Field(description="Operation result") result: str = Field(description="Operation result")
data: str | None = Field(default=None, description="Reset token") data: str | None = Field(default=None, description="Reset token")

View File

@ -1,5 +1,3 @@
from typing import Any
import flask_login import flask_login
from flask import make_response, request from flask import make_response, request
from flask_restx import Resource from flask_restx import Resource
@ -42,8 +40,9 @@ from libs.token import (
set_csrf_token_to_cookie, set_csrf_token_to_cookie,
set_refresh_token_to_cookie, set_refresh_token_to_cookie,
) )
from services.account_service import AccountService, RegisterService, TenantService from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService
from services.billing_service import BillingService from services.billing_service import BillingService
from services.entities.auth_entities import LoginPayloadBase
from services.errors.account import AccountRegisterError from services.errors.account import AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
from services.feature_service import FeatureService from services.feature_service import FeatureService
@ -51,9 +50,7 @@ from services.feature_service import FeatureService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class LoginPayload(BaseModel): class LoginPayload(LoginPayloadBase):
email: EmailStr = Field(..., description="Email address")
password: str = Field(..., description="Password")
remember_me: bool = Field(default=False, description="Remember me flag") remember_me: bool = Field(default=False, description="Remember me flag")
invite_token: str | None = Field(default=None, description="Invitation token") invite_token: str | None = Field(default=None, description="Invitation token")
@ -101,7 +98,7 @@ class LoginApi(Resource):
raise EmailPasswordLoginLimitError() raise EmailPasswordLoginLimitError()
invite_token = args.invite_token invite_token = args.invite_token
invitation_data: dict[str, Any] | None = None invitation_data: InvitationDetailDict | None = None
if invite_token: if invite_token:
invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token) invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token)
if invitation_data is None: if invitation_data is None:

View File

@ -158,10 +158,11 @@ class DataSourceApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def patch(self, binding_id, action: Literal["enable", "disable"]): def patch(self, binding_id, action: Literal["enable", "disable"]):
_, current_tenant_id = current_account_with_tenant()
binding_id = str(binding_id) binding_id = str(binding_id)
with sessionmaker(db.engine, expire_on_commit=False).begin() as session: with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
data_source_binding = session.execute( data_source_binding = session.execute(
select(DataSourceOauthBinding).filter_by(id=binding_id) select(DataSourceOauthBinding).filter_by(id=binding_id, tenant_id=current_tenant_id)
).scalar_one_or_none() ).scalar_one_or_none()
if data_source_binding is None: if data_source_binding is None:
raise NotFound("Data source binding not found.") raise NotFound("Data source binding not found.")

View File

@ -3,6 +3,7 @@ import logging
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
@ -86,8 +87,8 @@ class CustomizedPipelineTemplateApi(Resource):
@enterprise_license_required @enterprise_license_required
def post(self, template_id: str): def post(self, template_id: str):
with sessionmaker(db.engine, expire_on_commit=False).begin() as session: with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
template = ( template = session.scalar(
session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first() select(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).limit(1)
) )
if not template: if not template:
raise ValueError("Customized pipeline template not found.") raise ValueError("Customized pipeline template not found.")

View File

@ -1,4 +1,5 @@
import logging import logging
from collections.abc import Callable
from typing import Any, NoReturn from typing import Any, NoReturn
from flask import Response, request from flask import Response, request
@ -55,7 +56,7 @@ class WorkflowDraftVariablePatchPayload(BaseModel):
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload) register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
def _api_prerequisite(f): def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
"""Common prerequisites for all draft workflow variable APIs. """Common prerequisites for all draft workflow variable APIs.
It ensures the following conditions are satisfied: It ensures the following conditions are satisfied:
@ -70,7 +71,7 @@ def _api_prerequisite(f):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
def wrapper(*args, **kwargs): def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
if not isinstance(current_user, Account) or not current_user.has_edit_permission: if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
return f(*args, **kwargs) return f(*args, **kwargs)

View File

@ -2,10 +2,10 @@ import logging
from flask import request from flask import request
from graphon.model_runtime.errors.invoke import InvokeError from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import InternalServerError
import services import services
from controllers.common.controller_schemas import TextToAudioPayload
from controllers.common.schema import register_schema_model from controllers.common.schema import register_schema_model
from controllers.console.app.error import ( from controllers.console.app.error import (
AppUnavailableError, AppUnavailableError,
@ -32,14 +32,6 @@ from .. import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TextToAudioPayload(BaseModel):
message_id: str | None = None
voice: str | None = None
text: str | None = None
streaming: bool | None = Field(default=None, description="Enable streaming response")
register_schema_model(console_ns, TextToAudioPayload) register_schema_model(console_ns, TextToAudioPayload)

View File

@ -1,10 +1,11 @@
from typing import Any from typing import Any
from flask import request from flask import request
from pydantic import BaseModel, Field, TypeAdapter, model_validator from pydantic import BaseModel, Field, TypeAdapter
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import ConversationRenamePayload
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
from controllers.console.explore.error import NotChatAppError from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
@ -32,18 +33,6 @@ class ConversationListQuery(BaseModel):
pinned: bool | None = None pinned: bool | None = None
class ConversationRenamePayload(BaseModel):
name: str | None = None
auto_generate: bool = False
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload) register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload)

View File

@ -3,9 +3,10 @@ from typing import Literal
from flask import request from flask import request
from graphon.model_runtime.errors.invoke import InvokeError from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field, TypeAdapter from pydantic import BaseModel, TypeAdapter
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
from controllers.console.app.error import ( from controllers.console.app.error import (
AppMoreLikeThisDisabledError, AppMoreLikeThisDisabledError,
@ -25,7 +26,6 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from fields.conversation_fields import ResultResponse from fields.conversation_fields import ResultResponse
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse
from libs import helper from libs import helper
from libs.helper import UUIDStrOrEmpty
from libs.login import current_account_with_tenant from libs.login import current_account_with_tenant
from models.enums import FeedbackRating from models.enums import FeedbackRating
from models.model import AppMode from models.model import AppMode
@ -44,17 +44,6 @@ from .. import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel):
conversation_id: UUIDStrOrEmpty
first_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class MessageFeedbackPayload(BaseModel):
rating: Literal["like", "dislike"] | None = None
content: str | None = None
class MoreLikeThisQuery(BaseModel): class MoreLikeThisQuery(BaseModel):
response_mode: Literal["blocking", "streaming"] response_mode: Literal["blocking", "streaming"]

View File

@ -1,28 +1,18 @@
from flask import request from flask import request
from pydantic import BaseModel, Field, TypeAdapter from pydantic import TypeAdapter
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import SavedMessageCreatePayload, SavedMessageListQuery
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.explore.error import NotCompletionAppError from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import ResultResponse from fields.conversation_fields import ResultResponse
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
from libs.helper import UUIDStrOrEmpty
from libs.login import current_account_with_tenant from libs.login import current_account_with_tenant
from services.errors.message import MessageNotExistsError from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService from services.saved_message_service import SavedMessageService
class SavedMessageListQuery(BaseModel):
last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class SavedMessageCreatePayload(BaseModel):
message_id: UUIDStrOrEmpty
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload) register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)

View File

@ -1,11 +1,10 @@
import logging import logging
from typing import Any
from graphon.graph_engine.manager import GraphEngineManager from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.errors.invoke import InvokeError from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel
from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import InternalServerError
from controllers.common.controller_schemas import WorkflowRunPayload
from controllers.common.schema import register_schema_model from controllers.common.schema import register_schema_model
from controllers.console.app.error import ( from controllers.console.app.error import (
CompletionRequestError, CompletionRequestError,
@ -34,12 +33,6 @@ from .. import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class WorkflowRunPayload(BaseModel):
inputs: dict[str, Any]
files: list[dict[str, Any]] | None = None
register_schema_model(console_ns, WorkflowRunPayload) register_schema_model(console_ns, WorkflowRunPayload)

View File

@ -1,3 +1,5 @@
from typing import TypedDict
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -11,6 +13,21 @@ from services.billing_service import BillingService
_FALLBACK_LANG = "en-US" _FALLBACK_LANG = "en-US"
class NotificationItemDict(TypedDict):
notification_id: str | None
frequency: str | None
lang: str
title: str
subtitle: str
body: str
title_pic_url: str
class NotificationResponseDict(TypedDict):
should_show: bool
notifications: list[NotificationItemDict]
def _pick_lang_content(contents: dict, lang: str) -> dict: def _pick_lang_content(contents: dict, lang: str) -> dict:
"""Return the single LangContent for *lang*, falling back to English.""" """Return the single LangContent for *lang*, falling back to English."""
return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {}) return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {})
@ -45,28 +62,30 @@ class NotificationApi(Resource):
result = BillingService.get_account_notification(str(current_user.id)) result = BillingService.get_account_notification(str(current_user.id))
# Proto JSON uses camelCase field names (Kratos default marshaling). # Proto JSON uses camelCase field names (Kratos default marshaling).
response: NotificationResponseDict
if not result.get("shouldShow"): if not result.get("shouldShow"):
return {"should_show": False, "notifications": []}, 200 response = {"should_show": False, "notifications": []}
return response, 200
lang = current_user.interface_language or _FALLBACK_LANG lang = current_user.interface_language or _FALLBACK_LANG
notifications = [] notifications: list[NotificationItemDict] = []
for notification in result.get("notifications") or []: for notification in result.get("notifications") or []:
contents: dict = notification.get("contents") or {} contents: dict = notification.get("contents") or {}
lang_content = _pick_lang_content(contents, lang) lang_content = _pick_lang_content(contents, lang)
notifications.append( item: NotificationItemDict = {
{ "notification_id": notification.get("notificationId"),
"notification_id": notification.get("notificationId"), "frequency": notification.get("frequency"),
"frequency": notification.get("frequency"), "lang": lang_content.get("lang", lang),
"lang": lang_content.get("lang", lang), "title": lang_content.get("title", ""),
"title": lang_content.get("title", ""), "subtitle": lang_content.get("subtitle", ""),
"subtitle": lang_content.get("subtitle", ""), "body": lang_content.get("body", ""),
"body": lang_content.get("body", ""), "title_pic_url": lang_content.get("titlePicUrl", ""),
"title_pic_url": lang_content.get("titlePicUrl", ""), }
} notifications.append(item)
)
return {"should_show": bool(notifications), "notifications": notifications}, 200 response = {"should_show": bool(notifications), "notifications": notifications}
return response, 200
@console_ns.route("/notification/dismiss") @console_ns.route("/notification/dismiss")

View File

@ -9,7 +9,14 @@ from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from services.tag_service import TagService from models.enums import TagType
from services.tag_service import (
SaveTagPayload,
TagBindingCreatePayload,
TagBindingDeletePayload,
TagService,
UpdateTagPayload,
)
dataset_tag_fields = { dataset_tag_fields = {
"id": fields.String, "id": fields.String,
@ -25,19 +32,19 @@ def build_dataset_tag_fields(api_or_ns: Namespace):
class TagBasePayload(BaseModel): class TagBasePayload(BaseModel):
name: str = Field(description="Tag name", min_length=1, max_length=50) name: str = Field(description="Tag name", min_length=1, max_length=50)
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type") type: TagType = Field(description="Tag type")
class TagBindingPayload(BaseModel): class TagBindingPayload(BaseModel):
tag_ids: list[str] = Field(description="Tag IDs to bind") tag_ids: list[str] = Field(description="Tag IDs to bind")
target_id: str = Field(description="Target ID to bind tags to") target_id: str = Field(description="Target ID to bind tags to")
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type") type: TagType = Field(description="Tag type")
class TagBindingRemovePayload(BaseModel): class TagBindingRemovePayload(BaseModel):
tag_id: str = Field(description="Tag ID to remove") tag_id: str = Field(description="Tag ID to remove")
target_id: str = Field(description="Target ID to unbind tag from") target_id: str = Field(description="Target ID to unbind tag from")
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type") type: TagType = Field(description="Tag type")
class TagListQueryParam(BaseModel): class TagListQueryParam(BaseModel):
@ -82,7 +89,7 @@ class TagListApi(Resource):
raise Forbidden() raise Forbidden()
payload = TagBasePayload.model_validate(console_ns.payload or {}) payload = TagBasePayload.model_validate(console_ns.payload or {})
tag = TagService.save_tags(payload.model_dump()) tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type))
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
@ -103,7 +110,7 @@ class TagUpdateDeleteApi(Resource):
raise Forbidden() raise Forbidden()
payload = TagBasePayload.model_validate(console_ns.payload or {}) payload = TagBasePayload.model_validate(console_ns.payload or {})
tag = TagService.update_tags(payload.model_dump(), tag_id) tag = TagService.update_tags(UpdateTagPayload(name=payload.name, type=payload.type), tag_id)
binding_count = TagService.get_tag_binding_count(tag_id) binding_count = TagService.get_tag_binding_count(tag_id)
@ -136,7 +143,9 @@ class TagBindingCreateApi(Resource):
raise Forbidden() raise Forbidden()
payload = TagBindingPayload.model_validate(console_ns.payload or {}) payload = TagBindingPayload.model_validate(console_ns.payload or {})
TagService.save_tag_binding(payload.model_dump()) TagService.save_tag_binding(
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=payload.type)
)
return {"result": "success"}, 200 return {"result": "success"}, 200
@ -154,6 +163,8 @@ class TagBindingDeleteApi(Resource):
raise Forbidden() raise Forbidden()
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {}) payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(payload.model_dump()) TagService.delete_tag_binding(
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=payload.type)
)
return {"result": "success"}, 200 return {"result": "success"}, 200

View File

@ -1,6 +1,7 @@
from collections.abc import Callable from collections.abc import Callable
from functools import wraps from functools import wraps
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
@ -21,12 +22,12 @@ def plugin_permission_required(
tenant_id = current_tenant_id tenant_id = current_tenant_id
with sessionmaker(db.engine).begin() as session: with sessionmaker(db.engine).begin() as session:
permission = ( permission = session.scalar(
session.query(TenantPluginPermission) select(TenantPluginPermission)
.where( .where(
TenantPluginPermission.tenant_id == tenant_id, TenantPluginPermission.tenant_id == tenant_id,
) )
.first() .limit(1)
) )
if not permission: if not permission:

View File

@ -28,7 +28,7 @@ from enums.cloud_plan import CloudPlan
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import TimestampField from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantStatus from models.account import Tenant, TenantCustomConfigDict, TenantStatus
from services.account_service import TenantService from services.account_service import TenantService
from services.billing_service import BillingService, SubscriptionPlan from services.billing_service import BillingService, SubscriptionPlan
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService
@ -240,8 +240,10 @@ class CustomConfigWorkspaceApi(Resource):
args = WorkspaceCustomConfigPayload.model_validate(payload) args = WorkspaceCustomConfigPayload.model_validate(payload)
tenant = db.get_or_404(Tenant, current_tenant_id) tenant = db.get_or_404(Tenant, current_tenant_id)
custom_config_dict = { custom_config_dict: TenantCustomConfigDict = {
"remove_webapp_brand": args.remove_webapp_brand, "remove_webapp_brand": args.remove_webapp_brand
if args.remove_webapp_brand is not None
else tenant.custom_config_dict.get("remove_webapp_brand", False),
"replace_webapp_logo": args.replace_webapp_logo "replace_webapp_logo": args.replace_webapp_logo
if args.replace_webapp_logo is not None if args.replace_webapp_logo is not None
else tenant.custom_config_dict.get("replace_webapp_logo"), else tenant.custom_config_dict.get("replace_webapp_logo"),

View File

@ -9,7 +9,7 @@ from flask import request
from flask_restx import Resource from flask_restx import Resource
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_model from controllers.common.schema import register_schema_model
from controllers.console.wraps import setup_required from controllers.console.wraps import setup_required
@ -55,7 +55,7 @@ class EnterpriseAppDSLImport(Resource):
account.set_tenant_id(workspace_id) account.set_tenant_id(workspace_id)
with Session(db.engine) as session: with sessionmaker(db.engine).begin() as session:
dsl_service = AppDslService(session) dsl_service = AppDslService(session)
result = dsl_service.import_app( result = dsl_service.import_app(
account=account, account=account,
@ -64,7 +64,6 @@ class EnterpriseAppDSLImport(Resource):
name=args.name, name=args.name,
description=args.description, description=args.description,
) )
session.commit()
if result.status == ImportStatus.FAILED: if result.status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400 return result.model_dump(mode="json"), 400

View File

@ -4,6 +4,7 @@ from flask import Response
from flask_restx import Resource from flask_restx import Resource
from graphon.variables.input_entities import VariableEntity from graphon.variables.input_entities import VariableEntity
from pydantic import BaseModel, Field, ValidationError from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Session, sessionmaker
from controllers.common.schema import register_schema_model from controllers.common.schema import register_schema_model
@ -80,11 +81,11 @@ class MCPAppApi(Resource):
def _get_mcp_server_and_app(self, server_code: str, session: Session) -> tuple[AppMCPServer, App]: def _get_mcp_server_and_app(self, server_code: str, session: Session) -> tuple[AppMCPServer, App]:
"""Get and validate MCP server and app in one query session""" """Get and validate MCP server and app in one query session"""
mcp_server = session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first() mcp_server = session.scalar(select(AppMCPServer).where(AppMCPServer.server_code == server_code).limit(1))
if not mcp_server: if not mcp_server:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server Not Found") raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server Not Found")
app = session.query(App).where(App.id == mcp_server.app_id).first() app = session.scalar(select(App).where(App.id == mcp_server.app_id).limit(1))
if not app: if not app:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App Not Found") raise MCPRequestError(mcp_types.INVALID_REQUEST, "App Not Found")
@ -190,12 +191,12 @@ class MCPAppApi(Resource):
def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str) -> EndUser | None: def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str) -> EndUser | None:
"""Get end user - manages its own database session""" """Get end user - manages its own database session"""
with sessionmaker(db.engine, expire_on_commit=False).begin() as session: with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
return ( return session.scalar(
session.query(EndUser) select(EndUser)
.where(EndUser.tenant_id == tenant_id) .where(EndUser.tenant_id == tenant_id)
.where(EndUser.session_id == mcp_server_id) .where(EndUser.session_id == mcp_server_id)
.where(EndUser.type == "mcp") .where(EndUser.type == "mcp")
.first() .limit(1)
) )
def _create_end_user( def _create_end_user(

View File

@ -2,11 +2,12 @@ from typing import Any, Literal
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator from pydantic import BaseModel, Field, TypeAdapter, field_validator
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, NotFound from werkzeug.exceptions import BadRequest, NotFound
import services import services
from controllers.common.controller_schemas import ConversationRenamePayload
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotChatAppError from controllers.service_api.app.error import NotChatAppError
@ -34,18 +35,6 @@ class ConversationListQuery(BaseModel):
) )
class ConversationRenamePayload(BaseModel):
name: str | None = Field(default=None, description="New conversation name (required if auto_generate is false)")
auto_generate: bool = Field(default=False, description="Auto-generate conversation name")
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
class ConversationVariablesQuery(BaseModel): class ConversationVariablesQuery(BaseModel):
last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last variable ID for pagination") last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last variable ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return") limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")

View File

@ -1,5 +1,4 @@
import logging import logging
from typing import Literal
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
@ -7,6 +6,7 @@ from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services import services
from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotChatAppError from controllers.service_api.app.error import NotChatAppError
@ -14,7 +14,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import ResultResponse from fields.conversation_fields import ResultResponse
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
from libs.helper import UUIDStrOrEmpty
from models.enums import FeedbackRating from models.enums import FeedbackRating
from models.model import App, AppMode, EndUser from models.model import App, AppMode, EndUser
from services.errors.message import ( from services.errors.message import (
@ -27,17 +26,6 @@ from services.message_service import MessageService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel):
conversation_id: UUIDStrOrEmpty
first_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return")
class MessageFeedbackPayload(BaseModel):
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
content: str | None = Field(default=None, description="Feedback content")
class FeedbackListQuery(BaseModel): class FeedbackListQuery(BaseModel):
page: int = Field(default=1, ge=1, description="Page number") page: int = Field(default=1, ge=1, description="Page number")
limit: int = Field(default=20, ge=1, le=101, description="Number of feedbacks per page") limit: int = Field(default=20, ge=1, le=101, description="Number of feedbacks per page")

View File

@ -1,5 +1,5 @@
import logging import logging
from typing import Any, Literal from typing import Literal
from dateutil.parser import isoparse from dateutil.parser import isoparse
from flask import request from flask import request
@ -11,6 +11,7 @@ from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
from controllers.common.controller_schemas import WorkflowRunPayload as WorkflowRunPayloadBase
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.app.error import ( from controllers.service_api.app.error import (
@ -46,9 +47,7 @@ from services.workflow_app_service import WorkflowAppService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class WorkflowRunPayload(BaseModel): class WorkflowRunPayload(WorkflowRunPayloadBase):
inputs: dict[str, Any]
files: list[dict[str, Any]] | None = None
response_mode: Literal["blocking", "streaming"] | None = None response_mode: Literal["blocking", "streaming"] | None = None

View File

@ -22,10 +22,17 @@ from fields.tag_fields import DataSetTag
from libs.login import current_user from libs.login import current_user
from models.account import Account from models.account import Account
from models.dataset import DatasetPermissionEnum from models.dataset import DatasetPermissionEnum
from models.enums import TagType
from models.provider_ids import ModelProviderID from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
from services.tag_service import TagService from services.tag_service import (
SaveTagPayload,
TagBindingCreatePayload,
TagBindingDeletePayload,
TagService,
UpdateTagPayload,
)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@ -513,7 +520,7 @@ class DatasetTagsApi(DatasetApiResource):
raise Forbidden() raise Forbidden()
payload = TagCreatePayload.model_validate(service_api_ns.payload or {}) payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
tag = TagService.save_tags({"name": payload.name, "type": "knowledge"}) tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=TagType.KNOWLEDGE))
response = DataSetTag.model_validate( response = DataSetTag.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
@ -536,9 +543,8 @@ class DatasetTagsApi(DatasetApiResource):
raise Forbidden() raise Forbidden()
payload = TagUpdatePayload.model_validate(service_api_ns.payload or {}) payload = TagUpdatePayload.model_validate(service_api_ns.payload or {})
params = {"name": payload.name, "type": "knowledge"}
tag_id = payload.tag_id tag_id = payload.tag_id
tag = TagService.update_tags(params, tag_id) tag = TagService.update_tags(UpdateTagPayload(name=payload.name, type=TagType.KNOWLEDGE), tag_id)
binding_count = TagService.get_tag_binding_count(tag_id) binding_count = TagService.get_tag_binding_count(tag_id)
@ -585,7 +591,9 @@ class DatasetTagBindingApi(DatasetApiResource):
raise Forbidden() raise Forbidden()
payload = TagBindingPayload.model_validate(service_api_ns.payload or {}) payload = TagBindingPayload.model_validate(service_api_ns.payload or {})
TagService.save_tag_binding({"tag_ids": payload.tag_ids, "target_id": payload.target_id, "type": "knowledge"}) TagService.save_tag_binding(
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE)
)
return "", 204 return "", 204
@ -609,7 +617,9 @@ class DatasetTagUnbindingApi(DatasetApiResource):
raise Forbidden() raise Forbidden()
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {}) payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag_binding({"tag_id": payload.tag_id, "target_id": payload.target_id, "type": "knowledge"}) TagService.delete_tag_binding(
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=TagType.KNOWLEDGE)
)
return "", 204 return "", 204

View File

@ -31,6 +31,7 @@ from controllers.service_api.wraps import (
cloud_edition_billing_resource_check, cloud_edition_billing_resource_check,
) )
from core.errors.error import ProviderTokenNotInitError from core.errors.error import ProviderTokenNotInitError
from core.rag.entities import PreProcessingRule, Rule, Segmentation
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db from extensions.ext_database import db
from fields.document_fields import document_fields, document_status_fields from fields.document_fields import document_fields, document_status_fields
@ -40,11 +41,8 @@ from models.enums import SegmentStatus
from services.dataset_service import DatasetService, DocumentService from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import ( from services.entities.knowledge_entities.knowledge_entities import (
KnowledgeConfig, KnowledgeConfig,
PreProcessingRule,
ProcessRule, ProcessRule,
RetrievalModel, RetrievalModel,
Rule,
Segmentation,
) )
from services.file_service import FileService from services.file_service import FileService
from services.summary_index_service import SummaryIndexService from services.summary_index_service import SummaryIndexService

View File

@ -4,13 +4,23 @@ Serialization helpers for Service API knowledge pipeline endpoints.
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, TypedDict
if TYPE_CHECKING: if TYPE_CHECKING:
from models.model import UploadFile from models.model import UploadFile
def serialize_upload_file(upload_file: UploadFile) -> dict[str, Any]: class UploadFileDict(TypedDict):
id: str
name: str
size: int
extension: str
mime_type: str | None
created_by: str
created_at: str | None
def serialize_upload_file(upload_file: UploadFile) -> UploadFileDict:
return { return {
"id": upload_file.id, "id": upload_file.id,
"name": upload_file.name, "name": upload_file.name,

View File

@ -1,9 +1,10 @@
import inspect
import logging import logging
import time import time
from collections.abc import Callable from collections.abc import Callable
from enum import StrEnum, auto from enum import StrEnum, auto
from functools import wraps from functools import wraps
from typing import Any, cast, overload from typing import cast, overload
from flask import current_app, request from flask import current_app, request
from flask_login import user_logged_in from flask_login import user_logged_in
@ -230,94 +231,73 @@ def cloud_edition_billing_rate_limit_check[**P, R](
return interceptor return interceptor
def validate_dataset_token( def validate_dataset_token[R](view: Callable[..., R]) -> Callable[..., R]:
view: Callable[..., Any] | None = None, positional_parameters = [
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]: parameter
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]: for parameter in inspect.signature(view).parameters.values()
@wraps(view_func) if parameter.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
def decorated(*args: Any, **kwargs: Any) -> Any: ]
api_token = validate_and_get_api_token("dataset") expects_bound_instance = bool(positional_parameters and positional_parameters[0].name in {"self", "cls"})
# get url path dataset_id from positional args or kwargs @wraps(view)
# Flask passes URL path parameters as positional arguments def decorated(*args: object, **kwargs: object) -> R:
dataset_id = None api_token = validate_and_get_api_token("dataset")
# First try to get from kwargs (explicit parameter) # Flask may pass URL path parameters positionally, so inspect both kwargs and args.
dataset_id = kwargs.get("dataset_id") dataset_id = kwargs.get("dataset_id")
# If not in kwargs, try to extract from positional args if not dataset_id and args:
if not dataset_id and args: potential_id = args[0]
# For class methods: args[0] is self, args[1] is dataset_id (if exists) try:
# Check if first arg is likely a class instance (has __dict__ or __class__) str_id = str(potential_id)
if len(args) > 1 and hasattr(args[0], "__dict__"): if len(str_id) == 36 and str_id.count("-") == 4:
# This is a class method, dataset_id should be in args[1] dataset_id = str_id
potential_id = args[1] except Exception:
# Validate it's a string-like UUID, not another object logger.exception("Failed to parse dataset_id from positional args")
try:
# Try to convert to string and check if it's a valid UUID format
str_id = str(potential_id)
# Basic check: UUIDs are 36 chars with hyphens
if len(str_id) == 36 and str_id.count("-") == 4:
dataset_id = str_id
except Exception:
logger.exception("Failed to parse dataset_id from class method args")
elif len(args) > 0:
# Not a class method, check if args[0] looks like a UUID
potential_id = args[0]
try:
str_id = str(potential_id)
if len(str_id) == 36 and str_id.count("-") == 4:
dataset_id = str_id
except Exception:
logger.exception("Failed to parse dataset_id from positional args")
# Validate dataset if dataset_id is provided if dataset_id:
if dataset_id: dataset_id = str(dataset_id)
dataset_id = str(dataset_id) dataset = db.session.scalar(
dataset = db.session.scalar( select(Dataset)
select(Dataset) .where(
.where( Dataset.id == dataset_id,
Dataset.id == dataset_id, Dataset.tenant_id == api_token.tenant_id,
Dataset.tenant_id == api_token.tenant_id,
)
.limit(1)
) )
if not dataset: .limit(1)
raise NotFound("Dataset not found.") )
if not dataset.enable_api: if not dataset:
raise Forbidden("Dataset api access is not enabled.") raise NotFound("Dataset not found.")
tenant_account_join = db.session.execute( if not dataset.enable_api:
select(Tenant, TenantAccountJoin) raise Forbidden("Dataset api access is not enabled.")
.where(Tenant.id == api_token.tenant_id)
.where(TenantAccountJoin.tenant_id == Tenant.id) tenant_account_join = db.session.execute(
.where(TenantAccountJoin.role.in_(["owner"])) select(Tenant, TenantAccountJoin)
.where(Tenant.status == TenantStatus.NORMAL) .where(Tenant.id == api_token.tenant_id)
).one_or_none() # TODO: only owner information is required, so only one is returned. .where(TenantAccountJoin.tenant_id == Tenant.id)
if tenant_account_join: .where(TenantAccountJoin.role.in_(["owner"]))
tenant, ta = tenant_account_join .where(Tenant.status == TenantStatus.NORMAL)
account = db.session.get(Account, ta.account_id) ).one_or_none() # TODO: only owner information is required, so only one is returned.
# Login admin if tenant_account_join:
if account: tenant, ta = tenant_account_join
account.current_tenant = tenant account = db.session.get(Account, ta.account_id)
current_app.login_manager._update_request_context_with_user(account) # type: ignore # Login admin
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore if account:
else: account.current_tenant = tenant
raise Unauthorized("Tenant owner account does not exist.") current_app.login_manager._update_request_context_with_user(account) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
else: else:
raise Unauthorized("Tenant does not exist.") raise Unauthorized("Tenant owner account does not exist.")
if args and isinstance(args[0], Resource): else:
return view_func(args[0], api_token.tenant_id, *args[1:], **kwargs) raise Unauthorized("Tenant does not exist.")
return view_func(api_token.tenant_id, *args, **kwargs) if expects_bound_instance:
if not args:
raise TypeError("validate_dataset_token expected a bound resource instance.")
return view(args[0], api_token.tenant_id, *args[1:], **kwargs)
return decorated return view(api_token.tenant_id, *args, **kwargs)
if view: return decorated
return decorator(view)
# if view is None, it means that the decorator is used without parentheses
# use the decorator as a function for method_decorators
return decorator
def validate_and_get_api_token(scope: str | None = None): def validate_and_get_api_token(scope: str | None = None):

View File

@ -7,7 +7,7 @@ from werkzeug.exceptions import NotFound, RequestEntityTooLarge
from controllers.trigger import bp from controllers.trigger import bp
from core.trigger.debug.event_bus import TriggerDebugEventBus from core.trigger.debug.event_bus import TriggerDebugEventBus
from core.trigger.debug.events import WebhookDebugEvent, build_webhook_pool_key from core.trigger.debug.events import WebhookDebugEvent, build_webhook_pool_key
from services.trigger.webhook_service import WebhookService from services.trigger.webhook_service import RawWebhookDataDict, WebhookService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -23,6 +23,7 @@ def _prepare_webhook_execution(webhook_id: str, is_debug: bool = False):
webhook_id, is_debug=is_debug webhook_id, is_debug=is_debug
) )
webhook_data: RawWebhookDataDict
try: try:
# Use new unified extraction and validation # Use new unified extraction and validation
webhook_data = WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config) webhook_data = WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)

View File

@ -3,10 +3,11 @@ import logging
from flask import request from flask import request
from flask_restx import fields, marshal_with from flask_restx import fields, marshal_with
from graphon.model_runtime.errors.invoke import InvokeError from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, field_validator from pydantic import field_validator
from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import InternalServerError
import services import services
from controllers.common.controller_schemas import TextToAudioPayload as TextToAudioPayloadBase
from controllers.web import web_ns from controllers.web import web_ns
from controllers.web.error import ( from controllers.web.error import (
AppUnavailableError, AppUnavailableError,
@ -34,12 +35,7 @@ from services.errors.audio import (
from ..common.schema import register_schema_models from ..common.schema import register_schema_models
class TextToAudioPayload(BaseModel): class TextToAudioPayload(TextToAudioPayloadBase):
message_id: str | None = None
voice: str | None = None
text: str | None = None
streaming: bool | None = None
@field_validator("message_id") @field_validator("message_id")
@classmethod @classmethod
def validate_message_id(cls, value: str | None) -> str | None: def validate_message_id(cls, value: str | None) -> str | None:

View File

@ -1,10 +1,11 @@
from typing import Literal from typing import Literal
from flask import request from flask import request
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator from pydantic import BaseModel, Field, TypeAdapter, field_validator
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import ConversationRenamePayload
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
from controllers.web import web_ns from controllers.web import web_ns
from controllers.web.error import NotChatAppError from controllers.web.error import NotChatAppError
@ -37,18 +38,6 @@ class ConversationListQuery(BaseModel):
return uuid_value(value) return uuid_value(value)
class ConversationRenamePayload(BaseModel):
name: str | None = None
auto_generate: bool = False
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
register_schema_models(web_ns, ConversationListQuery, ConversationRenamePayload) register_schema_models(web_ns, ConversationListQuery, ConversationRenamePayload)

View File

@ -3,7 +3,6 @@ import secrets
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
@ -19,33 +18,15 @@ from controllers.console.error import EmailSendIpLimitError
from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required
from controllers.web import web_ns from controllers.web import web_ns
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip from libs.helper import extract_remote_ip
from libs.password import hash_password, valid_password from libs.password import hash_password
from models.account import Account from models.account import Account
from services.account_service import AccountService from services.account_service import AccountService
from services.entities.auth_entities import (
ForgotPasswordCheckPayload,
class ForgotPasswordSendPayload(BaseModel): ForgotPasswordResetPayload,
email: EmailStr ForgotPasswordSendPayload,
language: str | None = None )
class ForgotPasswordCheckPayload(BaseModel):
email: EmailStr
code: str
token: str = Field(min_length=1)
class ForgotPasswordResetPayload(BaseModel):
token: str = Field(min_length=1)
new_password: str
password_confirm: str
@field_validator("new_password", "password_confirm")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
register_schema_models(web_ns, ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload) register_schema_models(web_ns, ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload)

View File

@ -29,13 +29,11 @@ from libs.token import (
) )
from services.account_service import AccountService from services.account_service import AccountService
from services.app_service import AppService from services.app_service import AppService
from services.entities.auth_entities import LoginPayloadBase
from services.webapp_auth_service import WebAppAuthService from services.webapp_auth_service import WebAppAuthService
class LoginPayload(BaseModel): class LoginPayload(LoginPayloadBase):
email: EmailStr
password: str
@field_validator("password") @field_validator("password")
@classmethod @classmethod
def validate_password(cls, value: str) -> str: def validate_password(cls, value: str) -> str:

View File

@ -6,6 +6,7 @@ from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field, TypeAdapter, field_validator from pydantic import BaseModel, Field, TypeAdapter, field_validator
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.controller_schemas import MessageFeedbackPayload
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
from controllers.web import web_ns from controllers.web import web_ns
from controllers.web.error import ( from controllers.web.error import (
@ -53,11 +54,6 @@ class MessageListQuery(BaseModel):
return uuid_value(value) return uuid_value(value)
class MessageFeedbackPayload(BaseModel):
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
content: str | None = Field(default=None, description="Feedback content")
class MessageMoreLikeThisQuery(BaseModel): class MessageMoreLikeThisQuery(BaseModel):
response_mode: Literal["blocking", "streaming"] = Field( response_mode: Literal["blocking", "streaming"] = Field(
description="Response mode", description="Response mode",

View File

@ -1,27 +1,17 @@
from flask import request from flask import request
from pydantic import BaseModel, Field, TypeAdapter from pydantic import TypeAdapter
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import SavedMessageCreatePayload, SavedMessageListQuery
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
from controllers.web import web_ns from controllers.web import web_ns
from controllers.web.error import NotCompletionAppError from controllers.web.error import NotCompletionAppError
from controllers.web.wraps import WebApiResource from controllers.web.wraps import WebApiResource
from fields.conversation_fields import ResultResponse from fields.conversation_fields import ResultResponse
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
from libs.helper import UUIDStrOrEmpty
from services.errors.message import MessageNotExistsError from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService from services.saved_message_service import SavedMessageService
class SavedMessageListQuery(BaseModel):
last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class SavedMessageCreatePayload(BaseModel):
message_id: UUIDStrOrEmpty
register_schema_models(web_ns, SavedMessageListQuery, SavedMessageCreatePayload) register_schema_models(web_ns, SavedMessageListQuery, SavedMessageCreatePayload)

View File

@ -1,11 +1,10 @@
import logging import logging
from typing import Any
from graphon.graph_engine.manager import GraphEngineManager from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.errors.invoke import InvokeError from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import InternalServerError
from controllers.common.controller_schemas import WorkflowRunPayload
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
from controllers.web import web_ns from controllers.web import web_ns
from controllers.web.error import ( from controllers.web.error import (
@ -30,12 +29,6 @@ from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError from services.errors.llm import InvokeRateLimitError
class WorkflowRunPayload(BaseModel):
inputs: dict[str, Any] = Field(description="Input variables for the workflow")
files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed by the workflow")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
register_schema_models(web_ns, WorkflowRunPayload) register_schema_models(web_ns, WorkflowRunPayload)

View File

@ -79,21 +79,18 @@ class CotChatAgentRunner(CotAgentRunner):
if not agent_scratchpad: if not agent_scratchpad:
assistant_messages = [] assistant_messages = []
else: else:
assistant_message = AssistantPromptMessage(content="") content = ""
assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str
for unit in agent_scratchpad: for unit in agent_scratchpad:
if unit.is_final(): if unit.is_final():
assert isinstance(assistant_message.content, str) content += f"Final Answer: {unit.agent_response}"
assistant_message.content += f"Final Answer: {unit.agent_response}"
else: else:
assert isinstance(assistant_message.content, str) content += f"Thought: {unit.thought}\n\n"
assistant_message.content += f"Thought: {unit.thought}\n\n"
if unit.action_str: if unit.action_str:
assistant_message.content += f"Action: {unit.action_str}\n\n" content += f"Action: {unit.action_str}\n\n"
if unit.observation: if unit.observation:
assistant_message.content += f"Observation: {unit.observation}\n\n" content += f"Observation: {unit.observation}\n\n"
assistant_messages = [assistant_message] assistant_messages = [AssistantPromptMessage(content=content)]
# query messages # query messages
query_messages = self._organize_user_query(self._query, []) query_messages = self._organize_user_query(self._query, [])

View File

@ -5,6 +5,10 @@ from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS from constants import DEFAULT_FILE_NUMBER_LIMITS
class FeatureToggleDict(TypedDict):
enabled: bool
class SystemParametersDict(TypedDict): class SystemParametersDict(TypedDict):
image_file_size_limit: int image_file_size_limit: int
video_file_size_limit: int video_file_size_limit: int
@ -16,12 +20,12 @@ class SystemParametersDict(TypedDict):
class AppParametersDict(TypedDict): class AppParametersDict(TypedDict):
opening_statement: str | None opening_statement: str | None
suggested_questions: list[str] suggested_questions: list[str]
suggested_questions_after_answer: dict[str, Any] suggested_questions_after_answer: FeatureToggleDict
speech_to_text: dict[str, Any] speech_to_text: FeatureToggleDict
text_to_speech: dict[str, Any] text_to_speech: FeatureToggleDict
retriever_resource: dict[str, Any] retriever_resource: FeatureToggleDict
annotation_reply: dict[str, Any] annotation_reply: FeatureToggleDict
more_like_this: dict[str, Any] more_like_this: FeatureToggleDict
user_input_form: list[dict[str, Any]] user_input_form: list[dict[str, Any]]
sensitive_word_avoidance: dict[str, Any] sensitive_word_avoidance: dict[str, Any]
file_upload: dict[str, Any] file_upload: dict[str, Any]

View File

@ -1,4 +1,3 @@
from collections.abc import Sequence
from enum import StrEnum, auto from enum import StrEnum, auto
from typing import Any, Literal from typing import Any, Literal
@ -9,6 +8,7 @@ from graphon.variables.input_entities import VariableEntity as WorkflowVariableE
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from core.rag.entities import MetadataFilteringCondition
from models.model import AppMode from models.model import AppMode
@ -111,31 +111,6 @@ class ExternalDataVariableEntity(BaseModel):
config: dict[str, Any] = Field(default_factory=dict) config: dict[str, Any] = Field(default_factory=dict)
SupportedComparisonOperator = Literal[
# for string or array
"contains",
"not contains",
"start with",
"end with",
"is",
"is not",
"empty",
"not empty",
"in",
"not in",
# for number
"=",
"",
">",
"<",
"",
"",
# for time
"before",
"after",
]
class ModelConfig(BaseModel): class ModelConfig(BaseModel):
provider: str provider: str
name: str name: str
@ -143,25 +118,6 @@ class ModelConfig(BaseModel):
completion_params: dict[str, Any] = Field(default_factory=dict) completion_params: dict[str, Any] = Field(default_factory=dict)
class Condition(BaseModel):
"""
Condition detail
"""
name: str
comparison_operator: SupportedComparisonOperator
value: str | Sequence[str] | None | int | float = None
class MetadataFilteringCondition(BaseModel):
"""
Metadata Filtering Condition.
"""
logical_operator: Literal["and", "or"] | None = "and"
conditions: list[Condition] | None = Field(default=None, deprecated=True)
class DatasetRetrieveConfigEntity(BaseModel): class DatasetRetrieveConfigEntity(BaseModel):
""" """
Dataset Retrieve Config Entity. Dataset Retrieve Config Entity.

View File

@ -107,13 +107,13 @@ class AppGenerateResponseConverter(ABC):
return metadata return metadata
@classmethod @classmethod
def _error_to_stream_response(cls, e: Exception): def _error_to_stream_response(cls, e: Exception) -> dict[str, Any]:
""" """
Error to stream response. Error to stream response.
:param e: exception :param e: exception
:return: :return:
""" """
error_responses = { error_responses: dict[type[Exception], dict[str, Any]] = {
ValueError: {"code": "invalid_param", "status": 400}, ValueError: {"code": "invalid_param", "status": 400},
ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400}, ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
QuotaExceededError: { QuotaExceededError: {
@ -127,7 +127,7 @@ class AppGenerateResponseConverter(ABC):
} }
# Determine the response based on the type of exception # Determine the response based on the type of exception
data = None data: dict[str, Any] | None = None
for k, v in error_responses.items(): for k, v in error_responses.items():
if isinstance(e, k): if isinstance(e, k):
data = v data = v

View File

@ -66,7 +66,7 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent, QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent, QueueWorkflowSucceededEvent,
) )
from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities import RetrievalSourceMetadata
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
from core.workflow.system_variables import ( from core.workflow.system_variables import (
build_bootstrap_variables, build_bootstrap_variables,

View File

@ -10,7 +10,7 @@ from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChun
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from core.app.entities.agent_strategy import AgentStrategyInfo from core.app.entities.agent_strategy import AgentStrategyInfo
from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities import RetrievalSourceMetadata
class QueueEvent(StrEnum): class QueueEvent(StrEnum):

View File

@ -9,7 +9,7 @@ from graphon.nodes.human_input.entities import FormInput, UserAction
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from core.app.entities.agent_strategy import AgentStrategyInfo from core.app.entities.agent_strategy import AgentStrategyInfo
from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities import RetrievalSourceMetadata
class AnnotationReplyAccount(BaseModel): class AnnotationReplyAccount(BaseModel):

View File

@ -509,8 +509,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
:return: :return:
""" """
with Session(db.engine, expire_on_commit=False) as session: with Session(db.engine, expire_on_commit=False) as session:
agent_thought: MessageAgentThought | None = ( agent_thought: MessageAgentThought | None = session.scalar(
session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first() select(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).limit(1)
) )
if agent_thought: if agent_thought:

View File

@ -6,7 +6,7 @@ from sqlalchemy import select, update
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities import RetrievalSourceMetadata
from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_database import db from extensions.ext_database import db

View File

@ -345,8 +345,8 @@ class DatasourceManager:
@classmethod @classmethod
def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File: def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File:
with session_factory.create_session() as session: with session_factory.create_session() as session:
upload_file = ( upload_file = session.scalar(
session.query(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).first() select(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).limit(1)
) )
if not upload_file: if not upload_file:
raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}") raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")

View File

@ -1,22 +1,3 @@
from pydantic import BaseModel, Field, model_validator from core.tools.entities.common_entities import I18nObject, I18nObjectDict
__all__ = ["I18nObject", "I18nObjectDict"]
class I18nObject(BaseModel):
"""
Model class for i18n object.
"""
en_US: str
zh_Hans: str | None = Field(default=None)
pt_BR: str | None = Field(default=None)
ja_JP: str | None = Field(default=None)
@model_validator(mode="after")
def _(self):
self.zh_Hans = self.zh_Hans or self.en_US
self.pt_BR = self.pt_BR or self.en_US
self.ja_JP = self.ja_JP or self.en_US
return self
def to_dict(self) -> dict:
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}

View File

@ -9,7 +9,7 @@ from yarl import URL
from configs import dify_config from configs import dify_config
from core.entities.provider_entities import ProviderConfig from core.entities.provider_entities import ProviderConfig
from core.plugin.entities.oauth import OAuthSchema from core.plugin.entities import OAuthSchema
from core.plugin.entities.parameters import ( from core.plugin.entities.parameters import (
PluginParameter, PluginParameter,
PluginParameterOption, PluginParameterOption,

View File

@ -1 +1,8 @@
from core.entities.plugin_credential_type import PluginCredentialType
DEFAULT_PLUGIN_ID = "langgenius" DEFAULT_PLUGIN_ID = "langgenius"
__all__ = [
"DEFAULT_PLUGIN_ID",
"PluginCredentialType",
]

View File

@ -0,0 +1,9 @@
import enum
class PluginCredentialType(enum.Enum):
MODEL = 0 # must be 0 for API contract compatibility
TOOL = 1 # must be 1 for API contract compatibility
def to_number(self):
return self.value

View File

@ -22,6 +22,7 @@ from sqlalchemy import func, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from constants import HIDDEN_VALUE from constants import HIDDEN_VALUE
from core.entities import PluginCredentialType
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
from core.entities.provider_entities import ( from core.entities.provider_entities import (
CustomConfiguration, CustomConfiguration,
@ -46,7 +47,6 @@ from models.provider import (
TenantPreferredModelProvider, TenantPreferredModelProvider,
) )
from models.provider_ids import ModelProviderID from models.provider_ids import ModelProviderID
from services.enterprise.plugin_manager_service import PluginCredentialType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -2,7 +2,7 @@
Credential utility functions for checking credential existence and policy compliance. Credential utility functions for checking credential existence and policy compliance.
""" """
from services.enterprise.plugin_manager_service import PluginCredentialType from core.entities import PluginCredentialType
def is_credential_exists(credential_id: str, credential_type: "PluginCredentialType") -> bool: def is_credential_exists(credential_id: str, credential_type: "PluginCredentialType") -> bool:

View File

@ -2,7 +2,7 @@ import json
import logging import logging
import re import re
from collections.abc import Sequence from collections.abc import Sequence
from typing import Protocol, cast from typing import Protocol, TypedDict, cast
import json_repair import json_repair
from graphon.enums import WorkflowNodeExecutionMetadataKey from graphon.enums import WorkflowNodeExecutionMetadataKey
@ -49,6 +49,17 @@ class WorkflowServiceInterface(Protocol):
pass pass
class CodeGenerateResultDict(TypedDict):
code: str
language: str
error: str
class StructuredOutputResultDict(TypedDict):
output: str
error: str
class LLMGenerator: class LLMGenerator:
@classmethod @classmethod
def generate_conversation_name( def generate_conversation_name(
@ -293,7 +304,7 @@ class LLMGenerator:
cls, cls,
tenant_id: str, tenant_id: str,
args: RuleCodeGeneratePayload, args: RuleCodeGeneratePayload,
): ) -> CodeGenerateResultDict:
if args.code_language == "python": if args.code_language == "python":
prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE) prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
else: else:
@ -362,7 +373,9 @@ class LLMGenerator:
return answer.strip() return answer.strip()
@classmethod @classmethod
def generate_structured_output(cls, tenant_id: str, args: RuleStructuredOutputPayload): def generate_structured_output(
cls, tenant_id: str, args: RuleStructuredOutputPayload
) -> StructuredOutputResultDict:
model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_model_instance( model_instance = model_manager.get_model_instance(
tenant_id=tenant_id, tenant_id=tenant_id,
@ -454,7 +467,7 @@ class LLMGenerator:
): ):
session = db.session() session = db.session()
app: App | None = session.query(App).where(App.id == flow_id).first() app: App | None = session.scalar(select(App).where(App.id == flow_id).limit(1))
if not app: if not app:
raise ValueError("App not found.") raise ValueError("App not found.")
workflow = workflow_service.get_draft_workflow(app_model=app) workflow = workflow_service.get_draft_workflow(app_model=app)

View File

@ -6,6 +6,7 @@ import logging
import flask import flask
from core.logging.context import get_request_id, get_trace_id from core.logging.context import get_request_id, get_trace_id
from core.logging.structured_formatter import IdentityDict
class TraceContextFilter(logging.Filter): class TraceContextFilter(logging.Filter):
@ -60,7 +61,7 @@ class IdentityContextFilter(logging.Filter):
record.user_type = identity.get("user_type", "") record.user_type = identity.get("user_type", "")
return True return True
def _extract_identity(self) -> dict[str, str]: def _extract_identity(self) -> IdentityDict:
"""Extract identity from current_user if in request context.""" """Extract identity from current_user if in request context."""
try: try:
if not flask.has_request_context(): if not flask.has_request_context():
@ -77,7 +78,7 @@ class IdentityContextFilter(logging.Filter):
from models import Account from models import Account
from models.model import EndUser from models.model import EndUser
identity: dict[str, str] = {} identity: IdentityDict = {}
if isinstance(user, Account): if isinstance(user, Account):
if user.current_tenant_id: if user.current_tenant_id:

View File

@ -3,13 +3,19 @@
import logging import logging
import traceback import traceback
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any from typing import Any, TypedDict
import orjson import orjson
from configs import dify_config from configs import dify_config
class IdentityDict(TypedDict, total=False):
tenant_id: str
user_id: str
user_type: str
class StructuredJSONFormatter(logging.Formatter): class StructuredJSONFormatter(logging.Formatter):
""" """
JSON log formatter following the specified schema: JSON log formatter following the specified schema:
@ -84,7 +90,7 @@ class StructuredJSONFormatter(logging.Formatter):
return log_dict return log_dict
def _extract_identity(self, record: logging.LogRecord) -> dict[str, str] | None: def _extract_identity(self, record: logging.LogRecord) -> IdentityDict | None:
tenant_id = getattr(record, "tenant_id", None) tenant_id = getattr(record, "tenant_id", None)
user_id = getattr(record, "user_id", None) user_id = getattr(record, "user_id", None)
user_type = getattr(record, "user_type", None) user_type = getattr(record, "user_type", None)
@ -92,7 +98,7 @@ class StructuredJSONFormatter(logging.Formatter):
if not any([tenant_id, user_id, user_type]): if not any([tenant_id, user_id, user_type]):
return None return None
identity: dict[str, str] = {} identity: IdentityDict = {}
if tenant_id: if tenant_id:
identity["tenant_id"] = tenant_id identity["tenant_id"] = tenant_id
if user_id: if user_id:

View File

@ -1,7 +1,7 @@
import json import json
import logging import logging
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, cast from typing import Any, NotRequired, TypedDict, cast
from graphon.variables.input_entities import VariableEntity, VariableEntityType from graphon.variables.input_entities import VariableEntity, VariableEntityType
@ -15,6 +15,17 @@ from services.app_generate_service import AppGenerateService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ToolParameterSchemaDict(TypedDict):
type: str
properties: dict[str, Any]
required: list[str]
class ToolArgumentsDict(TypedDict):
query: NotRequired[str]
inputs: dict[str, Any]
def handle_mcp_request( def handle_mcp_request(
app: App, app: App,
request: mcp_types.ClientRequest, request: mcp_types.ClientRequest,
@ -119,7 +130,7 @@ def handle_list_tools(
mcp_types.Tool( mcp_types.Tool(
name=app_name, name=app_name,
description=description, description=description,
inputSchema=parameter_schema, inputSchema=cast(dict[str, Any], parameter_schema),
) )
], ],
) )
@ -154,7 +165,7 @@ def build_parameter_schema(
app_mode: str, app_mode: str,
user_input_form: list[VariableEntity], user_input_form: list[VariableEntity],
parameters_dict: dict[str, str], parameters_dict: dict[str, str],
) -> dict[str, Any]: ) -> ToolParameterSchemaDict:
"""Build parameter schema for the tool""" """Build parameter schema for the tool"""
parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict) parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict)
@ -174,7 +185,7 @@ def build_parameter_schema(
} }
def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> dict[str, Any]: def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> ToolArgumentsDict:
"""Prepare arguments based on app mode""" """Prepare arguments based on app mode"""
if app.mode == AppMode.WORKFLOW: if app.mode == AppMode.WORKFLOW:
return {"inputs": arguments} return {"inputs": arguments}

View File

@ -4,7 +4,7 @@ from collections.abc import Callable
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
from datetime import timedelta from datetime import timedelta
from types import TracebackType from types import TracebackType
from typing import Any, Self, cast from typing import Any, Self
from httpx import HTTPStatusError from httpx import HTTPStatusError
from pydantic import BaseModel from pydantic import BaseModel
@ -338,12 +338,11 @@ class BaseSession[
validated_request = self._receive_request_type.model_validate( validated_request = self._receive_request_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True) message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
) )
validated_request = cast(ReceiveRequestT, validated_request)
responder = RequestResponder[ReceiveRequestT, SendResultT]( responder = RequestResponder[ReceiveRequestT, SendResultT](
request_id=message.message.root.id, request_id=message.message.root.id,
request_meta=validated_request.root.params.meta if validated_request.root.params else None, request_meta=validated_request.root.params.meta if validated_request.root.params else None,
request=validated_request, request=validated_request, # type: ignore[arg-type] # mypy can't narrow constrained TypeVar from model_validate
session=self, session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None), on_complete=lambda r: self._in_flight.pop(r.request_id, None),
) )
@ -359,15 +358,14 @@ class BaseSession[
notification = self._receive_notification_type.model_validate( notification = self._receive_notification_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True) message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
) )
notification = cast(ReceiveNotificationT, notification)
# Handle cancellation notifications # Handle cancellation notifications
if isinstance(notification.root, CancelledNotification): if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId cancelled_id = notification.root.params.requestId
if cancelled_id in self._in_flight: if cancelled_id in self._in_flight:
self._in_flight[cancelled_id].cancel() self._in_flight[cancelled_id].cancel()
else: else:
self._received_notification(notification) self._received_notification(notification) # type: ignore[arg-type]
self._handle_incoming(notification) self._handle_incoming(notification) # type: ignore[arg-type]
except Exception as e: except Exception as e:
# For other validation errors, log and continue # For other validation errors, log and continue
logger.warning("Failed to validate notification: %s. Message was: %s", e, message.message.root) logger.warning("Failed to validate notification: %s. Message was: %s", e, message.message.root)

View File

@ -17,6 +17,7 @@ from graphon.model_runtime.model_providers.__base.text_embedding_model import Te
from graphon.model_runtime.model_providers.__base.tts_model import TTSModel from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
from configs import dify_config from configs import dify_config
from core.entities import PluginCredentialType
from core.entities.embedding_type import EmbeddingInputType from core.entities.embedding_type import EmbeddingInputType
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.entities.provider_entities import ModelLoadBalancingConfiguration
@ -25,7 +26,6 @@ from core.plugin.impl.model_runtime_factory import create_plugin_provider_manage
from core.provider_manager import ProviderManager from core.provider_manager import ProviderManager
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.provider import ProviderType from models.provider import ProviderType
from services.enterprise.plugin_manager_service import PluginCredentialType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -1,6 +1,6 @@
import json import json
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any from typing import Any, TypedDict
from graphon.entities import WorkflowNodeExecution from graphon.entities import WorkflowNodeExecution
from graphon.enums import WorkflowNodeExecutionStatus from graphon.enums import WorkflowNodeExecutionStatus
@ -56,10 +56,22 @@ def create_links_from_trace_id(trace_id: str | None) -> list[Link]:
return links return links
def extract_retrieval_documents(documents: list[Document]) -> list[dict[str, Any]]: class RetrievalDocumentMetadataDict(TypedDict):
documents_data = [] dataset_id: Any
doc_id: Any
document_id: Any
class RetrievalDocumentDict(TypedDict):
content: str
metadata: RetrievalDocumentMetadataDict
score: Any
def extract_retrieval_documents(documents: list[Document]) -> list[RetrievalDocumentDict]:
documents_data: list[RetrievalDocumentDict] = []
for document in documents: for document in documents:
document_data = { document_data: RetrievalDocumentDict = {
"content": document.page_content, "content": document.page_content,
"metadata": { "metadata": {
"dataset_id": document.metadata.get("dataset_id"), "dataset_id": document.metadata.get("dataset_id"),
@ -83,7 +95,7 @@ def create_common_span_attributes(
framework: str = DEFAULT_FRAMEWORK_NAME, framework: str = DEFAULT_FRAMEWORK_NAME,
inputs: str = "", inputs: str = "",
outputs: str = "", outputs: str = "",
) -> dict[str, Any]: ) -> dict[str, str]:
return { return {
GEN_AI_SESSION_ID: session_id, GEN_AI_SESSION_ID: session_id,
GEN_AI_USER_ID: user_id, GEN_AI_USER_ID: user_id,

View File

@ -56,8 +56,10 @@ class BaseTraceInstance(ABC):
if not service_account: if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
current_tenant = ( current_tenant = session.scalar(
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first() select(TenantAccountJoin)
.where(TenantAccountJoin.account_id == service_account.id, TenantAccountJoin.current.is_(True))
.limit(1)
) )
if not current_tenant: if not current_tenant:
raise ValueError(f"Current tenant not found for account {service_account.id}") raise ValueError(f"Current tenant not found for account {service_account.id}")

View File

@ -241,8 +241,10 @@ class TencentDataTrace(BaseTraceInstance):
if not service_account: if not service_account:
raise ValueError(f"Creator account not found for app {app_id}") raise ValueError(f"Creator account not found for app {app_id}")
current_tenant = ( current_tenant = session.scalar(
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first() select(TenantAccountJoin)
.where(TenantAccountJoin.account_id == service_account.id, TenantAccountJoin.current.is_(True))
.limit(1)
) )
if not current_tenant: if not current_tenant:
raise ValueError(f"Current tenant not found for account {service_account.id}") raise ValueError(f"Current tenant not found for account {service_account.id}")

View File

@ -0,0 +1,5 @@
from core.plugin.entities.oauth import OAuthSchema
__all__ = [
"OAuthSchema",
]

View File

@ -1,5 +1,3 @@
from collections.abc import Sequence
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from core.entities.provider_entities import ProviderConfig from core.entities.provider_entities import ProviderConfig
@ -10,12 +8,12 @@ class OAuthSchema(BaseModel):
OAuth schema OAuth schema
""" """
client_schema: Sequence[ProviderConfig] = Field( client_schema: list[ProviderConfig] = Field(
default_factory=list, default_factory=list,
description="client schema like client_id, client_secret, etc.", description="client schema like client_id, client_secret, etc.",
) )
credentials_schema: Sequence[ProviderConfig] = Field( credentials_schema: list[ProviderConfig] = Field(
default_factory=list, default_factory=list,
description="credentials schema like access_token, refresh_token, etc.", description="credentials schema like access_token, refresh_token, etc.",
) )

View File

@ -1,11 +1,10 @@
from __future__ import annotations from __future__ import annotations
import contextlib import contextlib
import json
from collections import defaultdict from collections import defaultdict
from collections.abc import Sequence from collections.abc import Sequence
from json import JSONDecodeError from json import JSONDecodeError
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any
from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.entities.provider_entities import ( from graphon.model_runtime.entities.provider_entities import (
@ -15,6 +14,7 @@ from graphon.model_runtime.entities.provider_entities import (
ProviderEntity, ProviderEntity,
) )
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from pydantic import TypeAdapter
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -58,6 +58,8 @@ from services.feature_service import FeatureService
if TYPE_CHECKING: if TYPE_CHECKING:
from graphon.model_runtime.runtime import ModelRuntime from graphon.model_runtime.runtime import ModelRuntime
_credentials_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any])
class ProviderManager: class ProviderManager:
""" """
@ -875,8 +877,8 @@ class ProviderManager:
return {"openai_api_key": encrypted_config} return {"openai_api_key": encrypted_config}
try: try:
credentials = cast(dict, json.loads(encrypted_config)) credentials = _credentials_adapter.validate_json(encrypted_config)
except JSONDecodeError: except (ValueError, JSONDecodeError):
return {} return {}
# Decrypt secret variables # Decrypt secret variables
@ -1015,7 +1017,7 @@ class ProviderManager:
if not cached_provider_credentials: if not cached_provider_credentials:
provider_credentials: dict[str, Any] = {} provider_credentials: dict[str, Any] = {}
if provider_records and provider_records[0].encrypted_config: if provider_records and provider_records[0].encrypted_config:
provider_credentials = json.loads(provider_records[0].encrypted_config) provider_credentials = _credentials_adapter.validate_json(provider_records[0].encrypted_config)
# Get provider credential secret variables # Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables( provider_credential_secret_variables = self._extract_secret_variables(
@ -1162,8 +1164,10 @@ class ProviderManager:
if not cached_provider_model_credentials: if not cached_provider_model_credentials:
try: try:
provider_model_credentials = json.loads(load_balancing_model_config.encrypted_config) provider_model_credentials = _credentials_adapter.validate_json(
except JSONDecodeError: load_balancing_model_config.encrypted_config
)
except (ValueError, JSONDecodeError):
continue continue
# Get decoding rsa key and cipher for decrypting credentials # Get decoding rsa key and cipher for decrypting credentials
@ -1176,7 +1180,7 @@ class ProviderManager:
if variable in provider_model_credentials: if variable in provider_model_credentials:
try: try:
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding( provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_model_credentials.get(variable), provider_model_credentials.get(variable) or "",
self.decoding_rsa_key, self.decoding_rsa_key,
self.decoding_cipher_rsa, self.decoding_cipher_rsa,
) )

View File

@ -15,7 +15,7 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor,
from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.embedding.retrieval import AttachmentInfoDict, RetrievalChildChunk, RetrievalSegments from core.rag.embedding.retrieval import AttachmentInfoDict, RetrievalChildChunk, RetrievalSegments
from core.rag.entities.metadata_entities import MetadataCondition from core.rag.entities import MetadataFilteringCondition
from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.query_type import QueryType from core.rag.index_processor.constant.query_type import QueryType
@ -182,7 +182,9 @@ class RetrievalService:
if not dataset: if not dataset:
return [] return []
metadata_condition = ( metadata_condition = (
MetadataCondition.model_validate(metadata_filtering_conditions) if metadata_filtering_conditions else None MetadataFilteringCondition.model_validate(metadata_filtering_conditions)
if metadata_filtering_conditions
else None
) )
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
dataset.tenant_id, dataset.tenant_id,
@ -240,7 +242,7 @@ class RetrievalService:
@classmethod @classmethod
def _get_dataset(cls, dataset_id: str) -> Dataset | None: def _get_dataset(cls, dataset_id: str) -> Dataset | None:
with Session(db.engine) as session: with Session(db.engine) as session:
return session.query(Dataset).where(Dataset.id == dataset_id).first() return session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
@classmethod @classmethod
def keyword_search( def keyword_search(
@ -573,15 +575,13 @@ class RetrievalService:
# Batch query summaries for segments retrieved via summary (only enabled summaries) # Batch query summaries for segments retrieved via summary (only enabled summaries)
if summary_segment_ids: if summary_segment_ids:
summaries = ( summaries = session.scalars(
session.query(DocumentSegmentSummary) select(DocumentSegmentSummary).where(
.filter(
DocumentSegmentSummary.chunk_id.in_(list(summary_segment_ids)), DocumentSegmentSummary.chunk_id.in_(list(summary_segment_ids)),
DocumentSegmentSummary.status == "completed", DocumentSegmentSummary.status == "completed",
DocumentSegmentSummary.enabled == True, # Only retrieve enabled summaries DocumentSegmentSummary.enabled.is_(True), # Only retrieve enabled summaries
) )
.all() ).all()
)
for summary in summaries: for summary in summaries:
if summary.summary_content: if summary.summary_content:
segment_summary_map[summary.chunk_id] = summary.summary_content segment_summary_map[summary.chunk_id] = summary.summary_content
@ -851,12 +851,12 @@ class RetrievalService:
def get_segment_attachment_info( def get_segment_attachment_info(
cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session
) -> SegmentAttachmentResult | None: ) -> SegmentAttachmentResult | None:
upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first() upload_file = session.scalar(select(UploadFile).where(UploadFile.id == attachment_id).limit(1))
if upload_file: if upload_file:
attachment_binding = ( attachment_binding = session.scalar(
session.query(SegmentAttachmentBinding) select(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.attachment_id == upload_file.id) .where(SegmentAttachmentBinding.attachment_id == upload_file.id)
.first() .limit(1)
) )
if attachment_binding: if attachment_binding:
attachment_info: AttachmentInfoDict = { attachment_info: AttachmentInfoDict = {
@ -875,14 +875,12 @@ class RetrievalService:
cls, attachment_ids: list[str], session: Session cls, attachment_ids: list[str], session: Session
) -> list[SegmentAttachmentInfoResult]: ) -> list[SegmentAttachmentInfoResult]:
attachment_infos: list[SegmentAttachmentInfoResult] = [] attachment_infos: list[SegmentAttachmentInfoResult] = []
upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all() upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(attachment_ids))).all()
if upload_files: if upload_files:
upload_file_ids = [upload_file.id for upload_file in upload_files] upload_file_ids = [upload_file.id for upload_file in upload_files]
attachment_bindings = ( attachment_bindings = session.scalars(
session.query(SegmentAttachmentBinding) select(SegmentAttachmentBinding).where(SegmentAttachmentBinding.attachment_id.in_(upload_file_ids))
.where(SegmentAttachmentBinding.attachment_id.in_(upload_file_ids)) ).all()
.all()
)
attachment_binding_map = {binding.attachment_id: binding for binding in attachment_bindings} attachment_binding_map = {binding.attachment_id: binding for binding in attachment_bindings}
if attachment_bindings: if attachment_bindings:

View File

@ -1,5 +1,5 @@
import json import json
from typing import Any from typing import Any, TypedDict
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
@ -13,6 +13,13 @@ from core.rag.models.document import Document
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
class AnalyticdbClientParamsDict(TypedDict):
access_key_id: str
access_key_secret: str
region_id: str
read_timeout: int
class AnalyticdbVectorOpenAPIConfig(BaseModel): class AnalyticdbVectorOpenAPIConfig(BaseModel):
access_key_id: str access_key_id: str
access_key_secret: str access_key_secret: str
@ -44,13 +51,14 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel):
raise ValueError("config ANALYTICDB_NAMESPACE_PASSWORD is required") raise ValueError("config ANALYTICDB_NAMESPACE_PASSWORD is required")
return values return values
def to_analyticdb_client_params(self): def to_analyticdb_client_params(self) -> AnalyticdbClientParamsDict:
return { result: AnalyticdbClientParamsDict = {
"access_key_id": self.access_key_id, "access_key_id": self.access_key_id,
"access_key_secret": self.access_key_secret, "access_key_secret": self.access_key_secret,
"region_id": self.region_id, "region_id": self.region_id,
"read_timeout": self.read_timeout, "read_timeout": self.read_timeout,
} }
return result
class AnalyticdbVectorOpenAPI: class AnalyticdbVectorOpenAPI:

View File

@ -30,7 +30,7 @@ from pymochow.model.table import AnnSearch, BM25SearchRequest, HNSWSearchParams,
from configs import dify_config from configs import dify_config
from core.rag.datasource.vdb.field import Field as VDBField from core.rag.datasource.vdb.field import Field as VDBField
from core.rag.datasource.vdb.field import parse_metadata_json from core.rag.datasource.vdb.field import parse_metadata_json
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings from core.rag.embedding.embedding_base import Embeddings
@ -85,8 +85,12 @@ class BaiduVector(BaseVector):
def get_type(self) -> str: def get_type(self) -> str:
return VectorType.BAIDU return VectorType.BAIDU
def to_index_struct(self): def to_index_struct(self) -> VectorIndexStructDict:
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} result: VectorIndexStructDict = {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name},
}
return result
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self._create_table(len(embeddings[0])) self._create_table(len(embeddings[0]))

View File

@ -1,12 +1,12 @@
import json import json
from typing import Any from typing import Any, TypedDict
import chromadb import chromadb
from chromadb import QueryResult, Settings from chromadb import QueryResult, Settings
from pydantic import BaseModel from pydantic import BaseModel
from configs import dify_config from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings from core.rag.embedding.embedding_base import Embeddings
@ -15,6 +15,15 @@ from extensions.ext_redis import redis_client
from models.dataset import Dataset from models.dataset import Dataset
class ChromaParamsDict(TypedDict):
host: str
port: int
ssl: bool
tenant: str
database: str
settings: Settings
class ChromaConfig(BaseModel): class ChromaConfig(BaseModel):
host: str host: str
port: int port: int
@ -23,14 +32,13 @@ class ChromaConfig(BaseModel):
auth_provider: str | None = None auth_provider: str | None = None
auth_credentials: str | None = None auth_credentials: str | None = None
def to_chroma_params(self): def to_chroma_params(self) -> ChromaParamsDict:
settings = Settings( settings = Settings(
# auth # auth
chroma_client_auth_provider=self.auth_provider, chroma_client_auth_provider=self.auth_provider,
chroma_client_auth_credentials=self.auth_credentials, chroma_client_auth_credentials=self.auth_credentials,
) )
result: ChromaParamsDict = {
return {
"host": self.host, "host": self.host,
"port": self.port, "port": self.port,
"ssl": False, "ssl": False,
@ -38,6 +46,7 @@ class ChromaConfig(BaseModel):
"database": self.database, "database": self.database,
"settings": settings, "settings": settings,
} }
return result
class ChromaVector(BaseVector): class ChromaVector(BaseVector):
@ -145,7 +154,10 @@ class ChromaVectorFactory(AbstractVectorFactory):
else: else:
dataset_id = dataset.id dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
index_struct_dict = {"type": VectorType.CHROMA, "vector_store": {"class_prefix": collection_name}} index_struct_dict: VectorIndexStructDict = {
"type": VectorType.CHROMA,
"vector_store": {"class_prefix": collection_name},
}
dataset.index_struct = json.dumps(index_struct_dict) dataset.index_struct = json.dumps(index_struct_dict)
return ChromaVector( return ChromaVector(

View File

@ -1,6 +1,6 @@
import json import json
import logging import logging
from typing import Any from typing import Any, TypedDict
from packaging import version from packaging import version
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
@ -20,6 +20,15 @@ from models.dataset import Dataset
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class MilvusParamsDict(TypedDict):
uri: str
token: str | None
user: str | None
password: str | None
db_name: str
analyzer_params: str | None
class MilvusConfig(BaseModel): class MilvusConfig(BaseModel):
""" """
Configuration class for Milvus connection. Configuration class for Milvus connection.
@ -50,11 +59,11 @@ class MilvusConfig(BaseModel):
raise ValueError("config MILVUS_PASSWORD is required") raise ValueError("config MILVUS_PASSWORD is required")
return values return values
def to_milvus_params(self): def to_milvus_params(self) -> MilvusParamsDict:
""" """
Convert the configuration to a dictionary of Milvus connection parameters. Convert the configuration to a dictionary of Milvus connection parameters.
""" """
return { result: MilvusParamsDict = {
"uri": self.uri, "uri": self.uri,
"token": self.token, "token": self.token,
"user": self.user, "user": self.user,
@ -62,6 +71,7 @@ class MilvusConfig(BaseModel):
"db_name": self.database, "db_name": self.database,
"analyzer_params": self.analyzer_params, "analyzer_params": self.analyzer_params,
} }
return result
class MilvusVector(BaseVector): class MilvusVector(BaseVector):
@ -352,6 +362,7 @@ class MilvusVector(BaseVector):
# Create Index params for the collection # Create Index params for the collection
index_params_obj = IndexParams() index_params_obj = IndexParams()
assert index_params is not None
index_params_obj.add_index(field_name=Field.VECTOR, **index_params) index_params_obj.add_index(field_name=Field.VECTOR, **index_params)
# Create Sparse Vector Index for the collection # Create Sparse Vector Index for the collection

View File

@ -22,7 +22,7 @@ from sqlalchemy import select
from configs import dify_config from configs import dify_config
from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings from core.rag.embedding.embedding_base import Embeddings
@ -94,8 +94,12 @@ class QdrantVector(BaseVector):
def get_type(self) -> str: def get_type(self) -> str:
return VectorType.QDRANT return VectorType.QDRANT
def to_index_struct(self): def to_index_struct(self) -> VectorIndexStructDict:
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} result: VectorIndexStructDict = {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name},
}
return result
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
if texts: if texts:

View File

@ -1,7 +1,7 @@
import json import json
import logging import logging
import math import math
from typing import Any from typing import Any, TypedDict
from pydantic import BaseModel from pydantic import BaseModel
from tcvdb_text.encoder import BM25Encoder # type: ignore from tcvdb_text.encoder import BM25Encoder # type: ignore
@ -12,7 +12,7 @@ from tcvectordb.model.document import AnnSearch, Filter, KeywordSearch, Weighted
from configs import dify_config from configs import dify_config
from core.rag.datasource.vdb.field import parse_metadata_json from core.rag.datasource.vdb.field import parse_metadata_json
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings from core.rag.embedding.embedding_base import Embeddings
@ -23,6 +23,13 @@ from models.dataset import Dataset
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TencentParamsDict(TypedDict):
url: str
username: str | None
key: str | None
timeout: float
class TencentConfig(BaseModel): class TencentConfig(BaseModel):
url: str url: str
api_key: str | None = None api_key: str | None = None
@ -36,8 +43,14 @@ class TencentConfig(BaseModel):
max_upsert_batch_size: int = 128 max_upsert_batch_size: int = 128
enable_hybrid_search: bool = False # Flag to enable hybrid search enable_hybrid_search: bool = False # Flag to enable hybrid search
def to_tencent_params(self): def to_tencent_params(self) -> TencentParamsDict:
return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout} result: TencentParamsDict = {
"url": self.url,
"username": self.username,
"key": self.api_key,
"timeout": self.timeout,
}
return result
bm25 = BM25Encoder.default("zh") bm25 = BM25Encoder.default("zh")
@ -83,8 +96,12 @@ class TencentVector(BaseVector):
def get_type(self) -> str: def get_type(self) -> str:
return VectorType.TENCENT return VectorType.TENCENT
def to_index_struct(self): def to_index_struct(self) -> VectorIndexStructDict:
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} result: VectorIndexStructDict = {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name},
}
return result
def _has_collection(self) -> bool: def _has_collection(self) -> bool:
return bool( return bool(

View File

@ -25,7 +25,7 @@ from sqlalchemy import select
from configs import dify_config from configs import dify_config
from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings from core.rag.embedding.embedding_base import Embeddings
@ -91,8 +91,12 @@ class TidbOnQdrantVector(BaseVector):
def get_type(self) -> str: def get_type(self) -> str:
return VectorType.TIDB_ON_QDRANT return VectorType.TIDB_ON_QDRANT
def to_index_struct(self): def to_index_struct(self) -> VectorIndexStructDict:
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} result: VectorIndexStructDict = {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name},
}
return result
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
if texts: if texts:

View File

@ -1,11 +1,20 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any from typing import Any, TypedDict
from core.rag.models.document import Document from core.rag.models.document import Document
class VectorStoreDict(TypedDict):
class_prefix: str
class VectorIndexStructDict(TypedDict):
type: str
vector_store: VectorStoreDict
class BaseVector(ABC): class BaseVector(ABC):
def __init__(self, collection_name: str): def __init__(self, collection_name: str):
self._collection_name = collection_name self._collection_name = collection_name

View File

@ -9,7 +9,7 @@ from sqlalchemy import select
from configs import dify_config from configs import dify_config
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.cached_embedding import CacheEmbedding from core.rag.embedding.cached_embedding import CacheEmbedding
from core.rag.embedding.embedding_base import Embeddings from core.rag.embedding.embedding_base import Embeddings
@ -30,8 +30,11 @@ class AbstractVectorFactory(ABC):
raise NotImplementedError raise NotImplementedError
@staticmethod @staticmethod
def gen_index_struct_dict(vector_type: VectorType, collection_name: str): def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> VectorIndexStructDict:
index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}} index_struct_dict: VectorIndexStructDict = {
"type": vector_type,
"vector_store": {"class_prefix": collection_name},
}
return index_struct_dict return index_struct_dict

View File

@ -24,7 +24,7 @@ from weaviate.exceptions import UnexpectedStatusCodeError
from configs import dify_config from configs import dify_config
from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings from core.rag.embedding.embedding_base import Embeddings
@ -184,9 +184,13 @@ class WeaviateVector(BaseVector):
dataset_id = dataset.id dataset_id = dataset.id
return Dataset.gen_collection_name_by_id(dataset_id) return Dataset.gen_collection_name_by_id(dataset_id)
def to_index_struct(self) -> dict: def to_index_struct(self) -> VectorIndexStructDict:
"""Returns the index structure dictionary for persistence.""" """Returns the index structure dictionary for persistence."""
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} result: VectorIndexStructDict = {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name},
}
return result
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
""" """

View File

@ -0,0 +1,28 @@
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
from core.rag.entities.event import DatasourceCompletedEvent, DatasourceErrorEvent, DatasourceProcessingEvent
from core.rag.entities.index_entities import EconomySetting, EmbeddingSetting, IndexMethod
from core.rag.entities.metadata_entities import Condition, MetadataFilteringCondition, SupportedComparisonOperator
from core.rag.entities.processing_entities import ParentMode, PreProcessingRule, Rule, Segmentation
from core.rag.entities.retrieval_settings import KeywordSetting, VectorSetting, WeightedScoreConfig
__all__ = [
"Condition",
"DatasourceCompletedEvent",
"DatasourceErrorEvent",
"DatasourceProcessingEvent",
"DocumentContext",
"EconomySetting",
"EmbeddingSetting",
"IndexMethod",
"KeywordSetting",
"MetadataFilteringCondition",
"ParentMode",
"PreProcessingRule",
"RetrievalSourceMetadata",
"Rule",
"Segmentation",
"SupportedComparisonOperator",
"VectorSetting",
"WeightedScoreConfig",
]

View File

@ -0,0 +1,30 @@
from typing import Literal
from pydantic import BaseModel
class EmbeddingSetting(BaseModel):
"""
Embedding Setting.
"""
embedding_provider_name: str
embedding_model_name: str
class EconomySetting(BaseModel):
"""
Economy Setting.
"""
keyword_number: int
class IndexMethod(BaseModel):
"""
Knowledge Index Setting.
"""
indexing_technique: Literal["high_quality", "economy"]
embedding_setting: EmbeddingSetting
economy_setting: EconomySetting

View File

@ -38,9 +38,9 @@ class Condition(BaseModel):
value: str | Sequence[str] | None | int | float = None value: str | Sequence[str] | None | int | float = None
class MetadataCondition(BaseModel): class MetadataFilteringCondition(BaseModel):
""" """
Metadata Condition. Metadata Filtering Condition.
""" """
logical_operator: Literal["and", "or"] | None = "and" logical_operator: Literal["and", "or"] | None = "and"

View File

@ -0,0 +1,27 @@
from enum import StrEnum
from typing import Literal
from pydantic import BaseModel
class ParentMode(StrEnum):
FULL_DOC = "full-doc"
PARAGRAPH = "paragraph"
class PreProcessingRule(BaseModel):
id: str
enabled: bool
class Segmentation(BaseModel):
separator: str = "\n"
max_tokens: int
chunk_overlap: int = 0
class Rule(BaseModel):
pre_processing_rules: list[PreProcessingRule] | None = None
segmentation: Segmentation | None = None
parent_mode: Literal["full-doc", "paragraph"] | None = None
subchunk_segmentation: Segmentation | None = None

View File

@ -0,0 +1,28 @@
from pydantic import BaseModel
class VectorSetting(BaseModel):
"""
Vector Setting.
"""
vector_weight: float
embedding_provider_name: str
embedding_model_name: str
class KeywordSetting(BaseModel):
"""
Keyword Setting.
"""
keyword_weight: float
class WeightedScoreConfig(BaseModel):
"""
Weighted score Config.
"""
vector_setting: VectorSetting
keyword_setting: KeywordSetting

View File

@ -12,7 +12,7 @@ from core.db.session_factory import session_factory
from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview from core.workflow.nodes.knowledge_index.protocols import IndexingResultDict, Preview, PreviewItem, QaPreview
from models.dataset import Dataset, Document, DocumentSegment from models.dataset import Dataset, Document, DocumentSegment
from .index_processor_factory import IndexProcessorFactory from .index_processor_factory import IndexProcessorFactory
@ -61,7 +61,7 @@ class IndexProcessor:
chunks: Mapping[str, Any], chunks: Mapping[str, Any],
batch: Any, batch: Any,
summary_index_setting: SummaryIndexSettingDict | None = None, summary_index_setting: SummaryIndexSettingDict | None = None,
): ) -> IndexingResultDict:
with session_factory.create_session() as session: with session_factory.create_session() as session:
document = session.query(Document).filter_by(id=document_id).first() document = session.query(Document).filter_by(id=document_id).first()
if not document: if not document:
@ -129,7 +129,7 @@ class IndexProcessor:
} }
) )
return { result: IndexingResultDict = {
"dataset_id": dataset_id, "dataset_id": dataset_id,
"dataset_name": dataset_name_value, "dataset_name": dataset_name_value,
"batch": batch, "batch": batch,
@ -138,6 +138,7 @@ class IndexProcessor:
"created_at": created_at_value.timestamp(), "created_at": created_at_value.timestamp(),
"display_status": "completed", "display_status": "completed",
} }
return result
def get_preview_output( def get_preview_output(
self, self,

View File

@ -32,6 +32,7 @@ from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.entities import Rule
from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.doc_type import DocType
@ -49,7 +50,6 @@ from models.account import Account
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding from models.dataset import Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument from models.dataset import Document as DatasetDocument
from services.account_service import AccountService from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import Rule
from services.summary_index_service import SummaryIndexService from services.summary_index_service import SummaryIndexService
_file_access_controller = DatabaseFileAccessController() _file_access_controller = DatabaseFileAccessController()

View File

@ -17,6 +17,7 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.entities import ParentMode, Rule
from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.doc_type import DocType
@ -30,7 +31,6 @@ from models import Account
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument from models.dataset import Document as DatasetDocument
from services.account_service import AccountService from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
from services.summary_index_service import SummaryIndexService from services.summary_index_service import SummaryIndexService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -19,6 +19,7 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.entities import Rule
from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
@ -30,7 +31,6 @@ from libs import helper
from models.account import Account from models.account import Account
from models.dataset import Dataset, DocumentSegment from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import Rule
from services.summary_index_service import SummaryIndexService from services.summary_index_service import SummaryIndexService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -1,16 +1,6 @@
from pydantic import BaseModel from pydantic import BaseModel
from core.rag.entities import KeywordSetting, VectorSetting
class VectorSetting(BaseModel):
vector_weight: float
embedding_provider_name: str
embedding_model_name: str
class KeywordSetting(BaseModel):
keyword_weight: float
class Weights(BaseModel): class Weights(BaseModel):

View File

@ -39,9 +39,7 @@ from core.prompt.simple_prompt_transform import ModelMode
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities import Condition, DocumentContext, RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.index_processor.constant.query_type import QueryType from core.rag.index_processor.constant.query_type import QueryType
@ -604,7 +602,7 @@ class DatasetRetrieval:
planning_strategy: PlanningStrategy, planning_strategy: PlanningStrategy,
message_id: str | None = None, message_id: str | None = None,
metadata_filter_document_ids: dict[str, list[str]] | None = None, metadata_filter_document_ids: dict[str, list[str]] | None = None,
metadata_condition: MetadataCondition | None = None, metadata_condition: MetadataFilteringCondition | None = None,
): ):
tools = [] tools = []
for dataset in available_datasets: for dataset in available_datasets:
@ -743,7 +741,7 @@ class DatasetRetrieval:
reranking_enable: bool = True, reranking_enable: bool = True,
message_id: str | None = None, message_id: str | None = None,
metadata_filter_document_ids: dict[str, list[str]] | None = None, metadata_filter_document_ids: dict[str, list[str]] | None = None,
metadata_condition: MetadataCondition | None = None, metadata_condition: MetadataFilteringCondition | None = None,
attachment_ids: list[str] | None = None, attachment_ids: list[str] | None = None,
): ):
if not available_datasets: if not available_datasets:
@ -1063,7 +1061,7 @@ class DatasetRetrieval:
top_k: int, top_k: int,
all_documents: list[Document], all_documents: list[Document],
document_ids_filter: list[str] | None = None, document_ids_filter: list[str] | None = None,
metadata_condition: MetadataCondition | None = None, metadata_condition: MetadataFilteringCondition | None = None,
attachment_ids: list[str] | None = None, attachment_ids: list[str] | None = None,
): ):
with flask_app.app_context(): with flask_app.app_context():
@ -1339,7 +1337,7 @@ class DatasetRetrieval:
metadata_model_config: ModelConfig, metadata_model_config: ModelConfig,
metadata_filtering_conditions: MetadataFilteringCondition | None, metadata_filtering_conditions: MetadataFilteringCondition | None,
inputs: dict, inputs: dict,
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]: ) -> tuple[dict[str, list[str]] | None, MetadataFilteringCondition | None]:
document_query = select(DatasetDocument).where( document_query = select(DatasetDocument).where(
DatasetDocument.dataset_id.in_(dataset_ids), DatasetDocument.dataset_id.in_(dataset_ids),
DatasetDocument.indexing_status == "completed", DatasetDocument.indexing_status == "completed",
@ -1371,7 +1369,7 @@ class DatasetRetrieval:
value=filter.get("value"), value=filter.get("value"),
) )
) )
metadata_condition = MetadataCondition( metadata_condition = MetadataFilteringCondition(
logical_operator=metadata_filtering_conditions.logical_operator logical_operator=metadata_filtering_conditions.logical_operator
if metadata_filtering_conditions if metadata_filtering_conditions
else "or", # type: ignore else "or", # type: ignore
@ -1400,7 +1398,7 @@ class DatasetRetrieval:
expected_value, expected_value,
filters, filters,
) )
metadata_condition = MetadataCondition( metadata_condition = MetadataFilteringCondition(
logical_operator=metadata_filtering_conditions.logical_operator, logical_operator=metadata_filtering_conditions.logical_operator,
conditions=conditions, conditions=conditions,
) )
@ -1723,7 +1721,7 @@ class DatasetRetrieval:
self, self,
flask_app: Flask, flask_app: Flask,
available_datasets: list[Dataset], available_datasets: list[Dataset],
metadata_condition: MetadataCondition | None, metadata_condition: MetadataFilteringCondition | None,
metadata_filter_document_ids: dict[str, list[str]] | None, metadata_filter_document_ids: dict[str, list[str]] | None,
all_documents: list[Document], all_documents: list[Document],
tenant_id: str, tenant_id: str,

Some files were not shown because too many files have changed in this diff Show More