mirror of
https://github.com/langgenius/dify.git
synced 2026-05-13 08:57:28 +08:00
Merge branch 'feat/new-biliing-quota' into deploy/dev
This commit is contained in:
commit
ef7dc9eabb
9
.github/labeler.yml
vendored
9
.github/labeler.yml
vendored
@ -1,3 +1,10 @@
|
|||||||
web:
|
web:
|
||||||
- changed-files:
|
- changed-files:
|
||||||
- any-glob-to-any-file: 'web/**'
|
- any-glob-to-any-file:
|
||||||
|
- 'web/**'
|
||||||
|
- 'packages/**'
|
||||||
|
- 'package.json'
|
||||||
|
- 'pnpm-lock.yaml'
|
||||||
|
- 'pnpm-workspace.yaml'
|
||||||
|
- '.npmrc'
|
||||||
|
- '.nvmrc'
|
||||||
|
|||||||
2
.github/pull_request_template.md
vendored
2
.github/pull_request_template.md
vendored
@ -20,4 +20,4 @@
|
|||||||
- [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
|
- [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
|
||||||
- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
|
- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
|
||||||
- [x] I've updated the documentation accordingly.
|
- [x] I've updated the documentation accordingly.
|
||||||
- [x] I ran `make lint` and `make type-check` (backend) and `cd web && npx lint-staged` (frontend) to appease the lint gods
|
- [x] I ran `make lint` and `make type-check` (backend) and `cd web && pnpm exec vp staged` (frontend) to appease the lint gods
|
||||||
|
|||||||
2
.github/workflows/autofix.yml
vendored
2
.github/workflows/autofix.yml
vendored
@ -39,9 +39,11 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
files: |
|
files: |
|
||||||
web/**
|
web/**
|
||||||
|
packages/**
|
||||||
package.json
|
package.json
|
||||||
pnpm-lock.yaml
|
pnpm-lock.yaml
|
||||||
pnpm-workspace.yaml
|
pnpm-workspace.yaml
|
||||||
|
.npmrc
|
||||||
.nvmrc
|
.nvmrc
|
||||||
- name: Check api inputs
|
- name: Check api inputs
|
||||||
if: github.event_name != 'merge_group'
|
if: github.event_name != 'merge_group'
|
||||||
|
|||||||
4
.github/workflows/build-push.yml
vendored
4
.github/workflows/build-push.yml
vendored
@ -65,7 +65,7 @@ jobs:
|
|||||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Login to Docker Hub
|
- name: Login to Docker Hub
|
||||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
|
uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4.1.0
|
||||||
with:
|
with:
|
||||||
username: ${{ env.DOCKERHUB_USER }}
|
username: ${{ env.DOCKERHUB_USER }}
|
||||||
password: ${{ env.DOCKERHUB_TOKEN }}
|
password: ${{ env.DOCKERHUB_TOKEN }}
|
||||||
@ -130,7 +130,7 @@ jobs:
|
|||||||
merge-multiple: true
|
merge-multiple: true
|
||||||
|
|
||||||
- name: Login to Docker Hub
|
- name: Login to Docker Hub
|
||||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
|
uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4.1.0
|
||||||
with:
|
with:
|
||||||
username: ${{ env.DOCKERHUB_USER }}
|
username: ${{ env.DOCKERHUB_USER }}
|
||||||
password: ${{ env.DOCKERHUB_TOKEN }}
|
password: ${{ env.DOCKERHUB_TOKEN }}
|
||||||
|
|||||||
2
.github/workflows/docker-build.yml
vendored
2
.github/workflows/docker-build.yml
vendored
@ -8,9 +8,11 @@ on:
|
|||||||
- api/Dockerfile
|
- api/Dockerfile
|
||||||
- web/docker/**
|
- web/docker/**
|
||||||
- web/Dockerfile
|
- web/Dockerfile
|
||||||
|
- packages/**
|
||||||
- package.json
|
- package.json
|
||||||
- pnpm-lock.yaml
|
- pnpm-lock.yaml
|
||||||
- pnpm-workspace.yaml
|
- pnpm-workspace.yaml
|
||||||
|
- .npmrc
|
||||||
- .nvmrc
|
- .nvmrc
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
|
|||||||
4
.github/workflows/main-ci.yml
vendored
4
.github/workflows/main-ci.yml
vendored
@ -65,9 +65,11 @@ jobs:
|
|||||||
- 'docker/volumes/sandbox/conf/**'
|
- 'docker/volumes/sandbox/conf/**'
|
||||||
web:
|
web:
|
||||||
- 'web/**'
|
- 'web/**'
|
||||||
|
- 'packages/**'
|
||||||
- 'package.json'
|
- 'package.json'
|
||||||
- 'pnpm-lock.yaml'
|
- 'pnpm-lock.yaml'
|
||||||
- 'pnpm-workspace.yaml'
|
- 'pnpm-workspace.yaml'
|
||||||
|
- '.npmrc'
|
||||||
- '.nvmrc'
|
- '.nvmrc'
|
||||||
- '.github/workflows/web-tests.yml'
|
- '.github/workflows/web-tests.yml'
|
||||||
- '.github/actions/setup-web/**'
|
- '.github/actions/setup-web/**'
|
||||||
@ -77,9 +79,11 @@ jobs:
|
|||||||
- 'api/uv.lock'
|
- 'api/uv.lock'
|
||||||
- 'e2e/**'
|
- 'e2e/**'
|
||||||
- 'web/**'
|
- 'web/**'
|
||||||
|
- 'packages/**'
|
||||||
- 'package.json'
|
- 'package.json'
|
||||||
- 'pnpm-lock.yaml'
|
- 'pnpm-lock.yaml'
|
||||||
- 'pnpm-workspace.yaml'
|
- 'pnpm-workspace.yaml'
|
||||||
|
- '.npmrc'
|
||||||
- '.nvmrc'
|
- '.nvmrc'
|
||||||
- 'docker/docker-compose.middleware.yaml'
|
- 'docker/docker-compose.middleware.yaml'
|
||||||
- 'docker/middleware.env.example'
|
- 'docker/middleware.env.example'
|
||||||
|
|||||||
4
.github/workflows/style.yml
vendored
4
.github/workflows/style.yml
vendored
@ -77,9 +77,11 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
files: |
|
files: |
|
||||||
web/**
|
web/**
|
||||||
|
packages/**
|
||||||
package.json
|
package.json
|
||||||
pnpm-lock.yaml
|
pnpm-lock.yaml
|
||||||
pnpm-workspace.yaml
|
pnpm-workspace.yaml
|
||||||
|
.npmrc
|
||||||
.nvmrc
|
.nvmrc
|
||||||
.github/workflows/style.yml
|
.github/workflows/style.yml
|
||||||
.github/actions/setup-web/**
|
.github/actions/setup-web/**
|
||||||
@ -149,7 +151,7 @@ jobs:
|
|||||||
.editorconfig
|
.editorconfig
|
||||||
|
|
||||||
- name: Super-linter
|
- name: Super-linter
|
||||||
uses: super-linter/super-linter/slim@61abc07d755095a68f4987d1c2c3d1d64408f1f9 # v8.5.0
|
uses: super-linter/super-linter/slim@9e863354e3ff62e0727d37183162c4a88873df41 # v8.6.0
|
||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
if: steps.changed-files.outputs.any_changed == 'true'
|
||||||
env:
|
env:
|
||||||
BASH_SEVERITY: warning
|
BASH_SEVERITY: warning
|
||||||
|
|||||||
1
.github/workflows/tool-test-sdks.yaml
vendored
1
.github/workflows/tool-test-sdks.yaml
vendored
@ -9,6 +9,7 @@ on:
|
|||||||
- package.json
|
- package.json
|
||||||
- pnpm-lock.yaml
|
- pnpm-lock.yaml
|
||||||
- pnpm-workspace.yaml
|
- pnpm-workspace.yaml
|
||||||
|
- .npmrc
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: sdk-tests-${{ github.head_ref || github.run_id }}
|
group: sdk-tests-${{ github.head_ref || github.run_id }}
|
||||||
|
|||||||
2
.github/workflows/translate-i18n-claude.yml
vendored
2
.github/workflows/translate-i18n-claude.yml
vendored
@ -240,7 +240,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Run Claude Code for Translation Sync
|
- name: Run Claude Code for Translation Sync
|
||||||
if: steps.context.outputs.CHANGED_FILES != ''
|
if: steps.context.outputs.CHANGED_FILES != ''
|
||||||
uses: anthropics/claude-code-action@88c168b39e7e64da0286d812b6e9fbebb6708185 # v1.0.82
|
uses: anthropics/claude-code-action@6e2bd52842c65e914eba5c8badd17560bd26b5de # v1.0.89
|
||||||
with:
|
with:
|
||||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|||||||
2
.github/workflows/vdb-tests-full.yml
vendored
2
.github/workflows/vdb-tests-full.yml
vendored
@ -36,7 +36,7 @@ jobs:
|
|||||||
remove_tool_cache: true
|
remove_tool_cache: true
|
||||||
|
|
||||||
- name: Setup UV and Python
|
- name: Setup UV and Python
|
||||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||||
with:
|
with:
|
||||||
enable-cache: true
|
enable-cache: true
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|||||||
@ -89,30 +89,10 @@ if $web_modified; then
|
|||||||
echo "No staged TypeScript changes detected, skipping type-check:tsgo"
|
echo "No staged TypeScript changes detected, skipping type-check:tsgo"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
echo "Running unit tests check"
|
echo "Running knip"
|
||||||
modified_files=$(git diff --cached --name-only -- utils | grep -v '\.spec\.ts$' || true)
|
if ! pnpm run knip; then
|
||||||
|
echo "Knip check failed. Please run 'pnpm run knip' to fix the errors."
|
||||||
if [ -n "$modified_files" ]; then
|
exit 1
|
||||||
for file in $modified_files; do
|
|
||||||
test_file="${file%.*}.spec.ts"
|
|
||||||
echo "Checking for test file: $test_file"
|
|
||||||
|
|
||||||
# check if the test file exists
|
|
||||||
if [ -f "../$test_file" ]; then
|
|
||||||
echo "Detected changes in $file, running corresponding unit tests..."
|
|
||||||
pnpm run test "../$test_file"
|
|
||||||
|
|
||||||
if [ $? -ne 0 ]; then
|
|
||||||
echo "Unit tests failed. Please fix the errors before committing."
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
echo "Unit tests for $file passed."
|
|
||||||
else
|
|
||||||
echo "Warning: $file does not have a corresponding test file."
|
|
||||||
fi
|
|
||||||
|
|
||||||
done
|
|
||||||
echo "All unit tests for modified web/utils files have passed."
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
cd ../
|
cd ../
|
||||||
|
|||||||
18
api/celery_healthcheck.py
Normal file
18
api/celery_healthcheck.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
# This module provides a lightweight Celery instance for use in Docker health checks.
|
||||||
|
# Unlike celery_entrypoint.py, this does NOT import app.py and therefore avoids
|
||||||
|
# initializing all Flask extensions (DB, Redis, storage, blueprints, etc.).
|
||||||
|
# Using this module keeps the health check fast and low-cost.
|
||||||
|
from celery import Celery
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from extensions.ext_celery import get_celery_broker_transport_options, get_celery_ssl_options
|
||||||
|
|
||||||
|
celery = Celery(broker=dify_config.CELERY_BROKER_URL)
|
||||||
|
|
||||||
|
broker_transport_options = get_celery_broker_transport_options()
|
||||||
|
if broker_transport_options:
|
||||||
|
celery.conf.update(broker_transport_options=broker_transport_options)
|
||||||
|
|
||||||
|
ssl_options = get_celery_ssl_options()
|
||||||
|
if ssl_options:
|
||||||
|
celery.conf.update(broker_use_ssl=ssl_options)
|
||||||
@ -1,7 +1,7 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Any
|
from typing import TypedDict
|
||||||
|
|
||||||
import click
|
import click
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
@ -503,7 +503,19 @@ def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
|
|||||||
return [row[0] for row in result]
|
return [row[0] for row in result]
|
||||||
|
|
||||||
|
|
||||||
def _count_orphaned_draft_variables() -> dict[str, Any]:
|
class _AppOrphanCounts(TypedDict):
|
||||||
|
variables: int
|
||||||
|
files: int
|
||||||
|
|
||||||
|
|
||||||
|
class OrphanedDraftVariableStatsDict(TypedDict):
|
||||||
|
total_orphaned_variables: int
|
||||||
|
total_orphaned_files: int
|
||||||
|
orphaned_app_count: int
|
||||||
|
orphaned_by_app: dict[str, _AppOrphanCounts]
|
||||||
|
|
||||||
|
|
||||||
|
def _count_orphaned_draft_variables() -> OrphanedDraftVariableStatsDict:
|
||||||
"""
|
"""
|
||||||
Count orphaned draft variables by app, including associated file counts.
|
Count orphaned draft variables by app, including associated file counts.
|
||||||
|
|
||||||
@ -526,7 +538,7 @@ def _count_orphaned_draft_variables() -> dict[str, Any]:
|
|||||||
|
|
||||||
with db.engine.connect() as conn:
|
with db.engine.connect() as conn:
|
||||||
result = conn.execute(sa.text(variables_query))
|
result = conn.execute(sa.text(variables_query))
|
||||||
orphaned_by_app = {}
|
orphaned_by_app: dict[str, _AppOrphanCounts] = {}
|
||||||
total_files = 0
|
total_files = 0
|
||||||
|
|
||||||
for row in result:
|
for row in result:
|
||||||
|
|||||||
63
api/controllers/common/controller_schemas.py
Normal file
63
api/controllers/common/controller_schemas.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
|
from libs.helper import UUIDStrOrEmpty
|
||||||
|
|
||||||
|
# --- Conversation schemas ---
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationRenamePayload(BaseModel):
|
||||||
|
name: str | None = None
|
||||||
|
auto_generate: bool = False
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_name_requirement(self):
|
||||||
|
if not self.auto_generate:
|
||||||
|
if self.name is None or not self.name.strip():
|
||||||
|
raise ValueError("name is required when auto_generate is false")
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
# --- Message schemas ---
|
||||||
|
|
||||||
|
|
||||||
|
class MessageListQuery(BaseModel):
|
||||||
|
conversation_id: UUIDStrOrEmpty
|
||||||
|
first_id: UUIDStrOrEmpty | None = None
|
||||||
|
limit: int = Field(default=20, ge=1, le=100)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageFeedbackPayload(BaseModel):
|
||||||
|
rating: Literal["like", "dislike"] | None = None
|
||||||
|
content: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# --- Saved message schemas ---
|
||||||
|
|
||||||
|
|
||||||
|
class SavedMessageListQuery(BaseModel):
|
||||||
|
last_id: UUIDStrOrEmpty | None = None
|
||||||
|
limit: int = Field(default=20, ge=1, le=100)
|
||||||
|
|
||||||
|
|
||||||
|
class SavedMessageCreatePayload(BaseModel):
|
||||||
|
message_id: UUIDStrOrEmpty
|
||||||
|
|
||||||
|
|
||||||
|
# --- Workflow schemas ---
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowRunPayload(BaseModel):
|
||||||
|
inputs: dict[str, Any]
|
||||||
|
files: list[dict[str, Any]] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# --- Audio schemas ---
|
||||||
|
|
||||||
|
|
||||||
|
class TextToAudioPayload(BaseModel):
|
||||||
|
message_id: str | None = None
|
||||||
|
voice: str | None = None
|
||||||
|
text: str | None = None
|
||||||
|
streaming: bool | None = None
|
||||||
@ -2,6 +2,7 @@ import csv
|
|||||||
import io
|
import io
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
@ -17,7 +18,7 @@ from core.db.session_factory import session_factory
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.token import extract_access_token
|
from libs.token import extract_access_token
|
||||||
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
|
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
|
||||||
from services.billing_service import BillingService
|
from services.billing_service import BillingService, LangContentDict
|
||||||
|
|
||||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
@ -328,7 +329,7 @@ class UpsertNotificationApi(Resource):
|
|||||||
def post(self):
|
def post(self):
|
||||||
payload = UpsertNotificationPayload.model_validate(console_ns.payload)
|
payload = UpsertNotificationPayload.model_validate(console_ns.payload)
|
||||||
result = BillingService.upsert_notification(
|
result = BillingService.upsert_notification(
|
||||||
contents=[c.model_dump() for c in payload.contents],
|
contents=[cast(LangContentDict, c.model_dump()) for c in payload.contents],
|
||||||
frequency=payload.frequency,
|
frequency=payload.frequency,
|
||||||
status=payload.status,
|
status=payload.status,
|
||||||
notification_id=payload.notification_id,
|
notification_id=payload.notification_id,
|
||||||
|
|||||||
@ -7,7 +7,7 @@ from flask import request
|
|||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from graphon.enums import WorkflowExecutionStatus
|
from graphon.enums import WorkflowExecutionStatus
|
||||||
from graphon.file import helpers as file_helpers
|
from graphon.file import helpers as file_helpers
|
||||||
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator
|
from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
from werkzeug.exceptions import BadRequest
|
from werkzeug.exceptions import BadRequest
|
||||||
@ -26,9 +26,11 @@ from controllers.console.wraps import (
|
|||||||
setup_required,
|
setup_required,
|
||||||
)
|
)
|
||||||
from core.ops.ops_trace_manager import OpsTraceManager
|
from core.ops.ops_trace_manager import OpsTraceManager
|
||||||
|
from core.rag.entities import PreProcessingRule, Rule, Segmentation
|
||||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||||
from core.trigger.constants import TRIGGER_NODE_TYPES
|
from core.trigger.constants import TRIGGER_NODE_TYPES
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
from fields.base import ResponseModel
|
||||||
from libs.login import current_account_with_tenant, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import App, DatasetPermissionEnum, Workflow
|
from models import App, DatasetPermissionEnum, Workflow
|
||||||
from models.model import IconType
|
from models.model import IconType
|
||||||
@ -41,10 +43,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
|||||||
NotionIcon,
|
NotionIcon,
|
||||||
NotionInfo,
|
NotionInfo,
|
||||||
NotionPage,
|
NotionPage,
|
||||||
PreProcessingRule,
|
|
||||||
RerankingModel,
|
RerankingModel,
|
||||||
Rule,
|
|
||||||
Segmentation,
|
|
||||||
WebsiteInfo,
|
WebsiteInfo,
|
||||||
WeightKeywordSetting,
|
WeightKeywordSetting,
|
||||||
WeightModel,
|
WeightModel,
|
||||||
@ -155,16 +154,6 @@ class AppTracePayload(BaseModel):
|
|||||||
type JSONValue = Any
|
type JSONValue = Any
|
||||||
|
|
||||||
|
|
||||||
class ResponseModel(BaseModel):
|
|
||||||
model_config = ConfigDict(
|
|
||||||
from_attributes=True,
|
|
||||||
extra="ignore",
|
|
||||||
populate_by_name=True,
|
|
||||||
serialize_by_alias=True,
|
|
||||||
protected_namespaces=(),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||||
if isinstance(value, datetime):
|
if isinstance(value, datetime):
|
||||||
return int(value.timestamp())
|
return int(value.timestamp())
|
||||||
|
|||||||
@ -193,7 +193,7 @@ workflow_draft_variable_list_model = console_ns.model(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _api_prerequisite(f: Callable[..., Any]) -> Callable[..., Any]:
|
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||||
"""Common prerequisites for all draft workflow variable APIs.
|
"""Common prerequisites for all draft workflow variable APIs.
|
||||||
|
|
||||||
It ensures the following conditions are satisfied:
|
It ensures the following conditions are satisfied:
|
||||||
@ -210,7 +210,7 @@ def _api_prerequisite(f: Callable[..., Any]) -> Callable[..., Any]:
|
|||||||
@edit_permission_required
|
@edit_permission_required
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||||
@wraps(f)
|
@wraps(f)
|
||||||
def wrapper(*args: Any, **kwargs: Any):
|
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||||
return f(*args, **kwargs)
|
return f(*args, **kwargs)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|||||||
@ -66,13 +66,13 @@ class WebhookTriggerApi(Resource):
|
|||||||
|
|
||||||
with sessionmaker(db.engine).begin() as session:
|
with sessionmaker(db.engine).begin() as session:
|
||||||
# Get webhook trigger for this app and node
|
# Get webhook trigger for this app and node
|
||||||
webhook_trigger = (
|
webhook_trigger = session.scalar(
|
||||||
session.query(WorkflowWebhookTrigger)
|
select(WorkflowWebhookTrigger)
|
||||||
.where(
|
.where(
|
||||||
WorkflowWebhookTrigger.app_id == app_model.id,
|
WorkflowWebhookTrigger.app_id == app_model.id,
|
||||||
WorkflowWebhookTrigger.node_id == node_id,
|
WorkflowWebhookTrigger.node_id == node_id,
|
||||||
)
|
)
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
if not webhook_trigger:
|
if not webhook_trigger:
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any
|
from typing import overload
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
@ -23,14 +23,30 @@ def _load_app_model_with_trial(app_id: str) -> App | None:
|
|||||||
return app_model
|
return app_model
|
||||||
|
|
||||||
|
|
||||||
def get_app_model(
|
@overload
|
||||||
view: Callable[..., Any] | None = None,
|
def get_app_model[**P, R](
|
||||||
|
view: Callable[P, R],
|
||||||
*,
|
*,
|
||||||
mode: AppMode | list[AppMode] | None = None,
|
mode: AppMode | list[AppMode] | None = None,
|
||||||
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
|
) -> Callable[P, R]: ...
|
||||||
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_app_model[**P, R](
|
||||||
|
view: None = None,
|
||||||
|
*,
|
||||||
|
mode: AppMode | list[AppMode] | None = None,
|
||||||
|
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
|
||||||
|
|
||||||
|
|
||||||
|
def get_app_model[**P, R](
|
||||||
|
view: Callable[P, R] | None = None,
|
||||||
|
*,
|
||||||
|
mode: AppMode | list[AppMode] | None = None,
|
||||||
|
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
|
||||||
|
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
|
||||||
@wraps(view_func)
|
@wraps(view_func)
|
||||||
def decorated_view(*args: Any, **kwargs: Any):
|
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
if not kwargs.get("app_id"):
|
if not kwargs.get("app_id"):
|
||||||
raise ValueError("missing app_id in path parameters")
|
raise ValueError("missing app_id in path parameters")
|
||||||
|
|
||||||
@ -68,14 +84,30 @@ def get_app_model(
|
|||||||
return decorator(view)
|
return decorator(view)
|
||||||
|
|
||||||
|
|
||||||
def get_app_model_with_trial(
|
@overload
|
||||||
view: Callable[..., Any] | None = None,
|
def get_app_model_with_trial[**P, R](
|
||||||
|
view: Callable[P, R],
|
||||||
*,
|
*,
|
||||||
mode: AppMode | list[AppMode] | None = None,
|
mode: AppMode | list[AppMode] | None = None,
|
||||||
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
|
) -> Callable[P, R]: ...
|
||||||
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_app_model_with_trial[**P, R](
|
||||||
|
view: None = None,
|
||||||
|
*,
|
||||||
|
mode: AppMode | list[AppMode] | None = None,
|
||||||
|
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
|
||||||
|
|
||||||
|
|
||||||
|
def get_app_model_with_trial[**P, R](
|
||||||
|
view: Callable[P, R] | None = None,
|
||||||
|
*,
|
||||||
|
mode: AppMode | list[AppMode] | None = None,
|
||||||
|
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
|
||||||
|
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
|
||||||
@wraps(view_func)
|
@wraps(view_func)
|
||||||
def decorated_view(*args: Any, **kwargs: Any):
|
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
if not kwargs.get("app_id"):
|
if not kwargs.get("app_id"):
|
||||||
raise ValueError("missing app_id in path parameters")
|
raise ValueError("missing app_id in path parameters")
|
||||||
|
|
||||||
|
|||||||
@ -3,7 +3,7 @@ import secrets
|
|||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
@ -20,35 +20,18 @@ from controllers.console.wraps import email_password_login_enabled, setup_requir
|
|||||||
from events.tenant_event import tenant_was_created
|
from events.tenant_event import tenant_was_created
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import EmailStr, extract_remote_ip
|
from libs.helper import EmailStr, extract_remote_ip
|
||||||
from libs.password import hash_password, valid_password
|
from libs.password import hash_password
|
||||||
from services.account_service import AccountService, TenantService
|
from services.account_service import AccountService, TenantService
|
||||||
|
from services.entities.auth_entities import (
|
||||||
|
ForgotPasswordCheckPayload,
|
||||||
|
ForgotPasswordResetPayload,
|
||||||
|
ForgotPasswordSendPayload,
|
||||||
|
)
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
class ForgotPasswordSendPayload(BaseModel):
|
|
||||||
email: EmailStr = Field(...)
|
|
||||||
language: str | None = Field(default=None)
|
|
||||||
|
|
||||||
|
|
||||||
class ForgotPasswordCheckPayload(BaseModel):
|
|
||||||
email: EmailStr = Field(...)
|
|
||||||
code: str = Field(...)
|
|
||||||
token: str = Field(...)
|
|
||||||
|
|
||||||
|
|
||||||
class ForgotPasswordResetPayload(BaseModel):
|
|
||||||
token: str = Field(...)
|
|
||||||
new_password: str = Field(...)
|
|
||||||
password_confirm: str = Field(...)
|
|
||||||
|
|
||||||
@field_validator("new_password", "password_confirm")
|
|
||||||
@classmethod
|
|
||||||
def validate_password(cls, value: str) -> str:
|
|
||||||
return valid_password(value)
|
|
||||||
|
|
||||||
|
|
||||||
class ForgotPasswordEmailResponse(BaseModel):
|
class ForgotPasswordEmailResponse(BaseModel):
|
||||||
result: str = Field(description="Operation result")
|
result: str = Field(description="Operation result")
|
||||||
data: str | None = Field(default=None, description="Reset token")
|
data: str | None = Field(default=None, description="Reset token")
|
||||||
|
|||||||
@ -1,5 +1,3 @@
|
|||||||
from typing import Any
|
|
||||||
|
|
||||||
import flask_login
|
import flask_login
|
||||||
from flask import make_response, request
|
from flask import make_response, request
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
@ -42,8 +40,9 @@ from libs.token import (
|
|||||||
set_csrf_token_to_cookie,
|
set_csrf_token_to_cookie,
|
||||||
set_refresh_token_to_cookie,
|
set_refresh_token_to_cookie,
|
||||||
)
|
)
|
||||||
from services.account_service import AccountService, RegisterService, TenantService
|
from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService
|
||||||
from services.billing_service import BillingService
|
from services.billing_service import BillingService
|
||||||
|
from services.entities.auth_entities import LoginPayloadBase
|
||||||
from services.errors.account import AccountRegisterError
|
from services.errors.account import AccountRegisterError
|
||||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
|
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
@ -51,9 +50,7 @@ from services.feature_service import FeatureService
|
|||||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
class LoginPayload(BaseModel):
|
class LoginPayload(LoginPayloadBase):
|
||||||
email: EmailStr = Field(..., description="Email address")
|
|
||||||
password: str = Field(..., description="Password")
|
|
||||||
remember_me: bool = Field(default=False, description="Remember me flag")
|
remember_me: bool = Field(default=False, description="Remember me flag")
|
||||||
invite_token: str | None = Field(default=None, description="Invitation token")
|
invite_token: str | None = Field(default=None, description="Invitation token")
|
||||||
|
|
||||||
@ -101,7 +98,7 @@ class LoginApi(Resource):
|
|||||||
raise EmailPasswordLoginLimitError()
|
raise EmailPasswordLoginLimitError()
|
||||||
|
|
||||||
invite_token = args.invite_token
|
invite_token = args.invite_token
|
||||||
invitation_data: dict[str, Any] | None = None
|
invitation_data: InvitationDetailDict | None = None
|
||||||
if invite_token:
|
if invite_token:
|
||||||
invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token)
|
invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token)
|
||||||
if invitation_data is None:
|
if invitation_data is None:
|
||||||
|
|||||||
@ -158,10 +158,11 @@ class DataSourceApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def patch(self, binding_id, action: Literal["enable", "disable"]):
|
def patch(self, binding_id, action: Literal["enable", "disable"]):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
binding_id = str(binding_id)
|
binding_id = str(binding_id)
|
||||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||||
data_source_binding = session.execute(
|
data_source_binding = session.execute(
|
||||||
select(DataSourceOauthBinding).filter_by(id=binding_id)
|
select(DataSourceOauthBinding).filter_by(id=binding_id, tenant_id=current_tenant_id)
|
||||||
).scalar_one_or_none()
|
).scalar_one_or_none()
|
||||||
if data_source_binding is None:
|
if data_source_binding is None:
|
||||||
raise NotFound("Data source binding not found.")
|
raise NotFound("Data source binding not found.")
|
||||||
|
|||||||
@ -3,6 +3,7 @@ import logging
|
|||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
@ -86,8 +87,8 @@ class CustomizedPipelineTemplateApi(Resource):
|
|||||||
@enterprise_license_required
|
@enterprise_license_required
|
||||||
def post(self, template_id: str):
|
def post(self, template_id: str):
|
||||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||||
template = (
|
template = session.scalar(
|
||||||
session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first()
|
select(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).limit(1)
|
||||||
)
|
)
|
||||||
if not template:
|
if not template:
|
||||||
raise ValueError("Customized pipeline template not found.")
|
raise ValueError("Customized pipeline template not found.")
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from collections.abc import Callable
|
||||||
from typing import Any, NoReturn
|
from typing import Any, NoReturn
|
||||||
|
|
||||||
from flask import Response, request
|
from flask import Response, request
|
||||||
@ -55,7 +56,7 @@ class WorkflowDraftVariablePatchPayload(BaseModel):
|
|||||||
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
|
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
|
||||||
|
|
||||||
|
|
||||||
def _api_prerequisite(f):
|
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||||
"""Common prerequisites for all draft workflow variable APIs.
|
"""Common prerequisites for all draft workflow variable APIs.
|
||||||
|
|
||||||
It ensures the following conditions are satisfied:
|
It ensures the following conditions are satisfied:
|
||||||
@ -70,7 +71,7 @@ def _api_prerequisite(f):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_rag_pipeline
|
@get_rag_pipeline
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
return f(*args, **kwargs)
|
return f(*args, **kwargs)
|
||||||
|
|||||||
@ -2,10 +2,10 @@ import logging
|
|||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from graphon.model_runtime.errors.invoke import InvokeError
|
from graphon.model_runtime.errors.invoke import InvokeError
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from werkzeug.exceptions import InternalServerError
|
from werkzeug.exceptions import InternalServerError
|
||||||
|
|
||||||
import services
|
import services
|
||||||
|
from controllers.common.controller_schemas import TextToAudioPayload
|
||||||
from controllers.common.schema import register_schema_model
|
from controllers.common.schema import register_schema_model
|
||||||
from controllers.console.app.error import (
|
from controllers.console.app.error import (
|
||||||
AppUnavailableError,
|
AppUnavailableError,
|
||||||
@ -32,14 +32,6 @@ from .. import console_ns
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TextToAudioPayload(BaseModel):
|
|
||||||
message_id: str | None = None
|
|
||||||
voice: str | None = None
|
|
||||||
text: str | None = None
|
|
||||||
streaming: bool | None = Field(default=None, description="Enable streaming response")
|
|
||||||
|
|
||||||
|
|
||||||
register_schema_model(console_ns, TextToAudioPayload)
|
register_schema_model(console_ns, TextToAudioPayload)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,10 +1,11 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from pydantic import BaseModel, Field, TypeAdapter, model_validator
|
from pydantic import BaseModel, Field, TypeAdapter
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
|
from controllers.common.controller_schemas import ConversationRenamePayload
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.console.explore.error import NotChatAppError
|
from controllers.console.explore.error import NotChatAppError
|
||||||
from controllers.console.explore.wraps import InstalledAppResource
|
from controllers.console.explore.wraps import InstalledAppResource
|
||||||
@ -32,18 +33,6 @@ class ConversationListQuery(BaseModel):
|
|||||||
pinned: bool | None = None
|
pinned: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
class ConversationRenamePayload(BaseModel):
|
|
||||||
name: str | None = None
|
|
||||||
auto_generate: bool = False
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
|
||||||
def validate_name_requirement(self):
|
|
||||||
if not self.auto_generate:
|
|
||||||
if self.name is None or not self.name.strip():
|
|
||||||
raise ValueError("name is required when auto_generate is false")
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload)
|
register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -3,9 +3,10 @@ from typing import Literal
|
|||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from graphon.model_runtime.errors.invoke import InvokeError
|
from graphon.model_runtime.errors.invoke import InvokeError
|
||||||
from pydantic import BaseModel, Field, TypeAdapter
|
from pydantic import BaseModel, TypeAdapter
|
||||||
from werkzeug.exceptions import InternalServerError, NotFound
|
from werkzeug.exceptions import InternalServerError, NotFound
|
||||||
|
|
||||||
|
from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.console.app.error import (
|
from controllers.console.app.error import (
|
||||||
AppMoreLikeThisDisabledError,
|
AppMoreLikeThisDisabledError,
|
||||||
@ -25,7 +26,6 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
|
|||||||
from fields.conversation_fields import ResultResponse
|
from fields.conversation_fields import ResultResponse
|
||||||
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse
|
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from libs.helper import UUIDStrOrEmpty
|
|
||||||
from libs.login import current_account_with_tenant
|
from libs.login import current_account_with_tenant
|
||||||
from models.enums import FeedbackRating
|
from models.enums import FeedbackRating
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
@ -44,17 +44,6 @@ from .. import console_ns
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MessageListQuery(BaseModel):
|
|
||||||
conversation_id: UUIDStrOrEmpty
|
|
||||||
first_id: UUIDStrOrEmpty | None = None
|
|
||||||
limit: int = Field(default=20, ge=1, le=100)
|
|
||||||
|
|
||||||
|
|
||||||
class MessageFeedbackPayload(BaseModel):
|
|
||||||
rating: Literal["like", "dislike"] | None = None
|
|
||||||
content: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class MoreLikeThisQuery(BaseModel):
|
class MoreLikeThisQuery(BaseModel):
|
||||||
response_mode: Literal["blocking", "streaming"]
|
response_mode: Literal["blocking", "streaming"]
|
||||||
|
|
||||||
|
|||||||
@ -1,28 +1,18 @@
|
|||||||
from flask import request
|
from flask import request
|
||||||
from pydantic import BaseModel, Field, TypeAdapter
|
from pydantic import TypeAdapter
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
|
from controllers.common.controller_schemas import SavedMessageCreatePayload, SavedMessageListQuery
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.explore.error import NotCompletionAppError
|
from controllers.console.explore.error import NotCompletionAppError
|
||||||
from controllers.console.explore.wraps import InstalledAppResource
|
from controllers.console.explore.wraps import InstalledAppResource
|
||||||
from fields.conversation_fields import ResultResponse
|
from fields.conversation_fields import ResultResponse
|
||||||
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
|
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
|
||||||
from libs.helper import UUIDStrOrEmpty
|
|
||||||
from libs.login import current_account_with_tenant
|
from libs.login import current_account_with_tenant
|
||||||
from services.errors.message import MessageNotExistsError
|
from services.errors.message import MessageNotExistsError
|
||||||
from services.saved_message_service import SavedMessageService
|
from services.saved_message_service import SavedMessageService
|
||||||
|
|
||||||
|
|
||||||
class SavedMessageListQuery(BaseModel):
|
|
||||||
last_id: UUIDStrOrEmpty | None = None
|
|
||||||
limit: int = Field(default=20, ge=1, le=100)
|
|
||||||
|
|
||||||
|
|
||||||
class SavedMessageCreatePayload(BaseModel):
|
|
||||||
message_id: UUIDStrOrEmpty
|
|
||||||
|
|
||||||
|
|
||||||
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)
|
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,11 +1,10 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from graphon.graph_engine.manager import GraphEngineManager
|
from graphon.graph_engine.manager import GraphEngineManager
|
||||||
from graphon.model_runtime.errors.invoke import InvokeError
|
from graphon.model_runtime.errors.invoke import InvokeError
|
||||||
from pydantic import BaseModel
|
|
||||||
from werkzeug.exceptions import InternalServerError
|
from werkzeug.exceptions import InternalServerError
|
||||||
|
|
||||||
|
from controllers.common.controller_schemas import WorkflowRunPayload
|
||||||
from controllers.common.schema import register_schema_model
|
from controllers.common.schema import register_schema_model
|
||||||
from controllers.console.app.error import (
|
from controllers.console.app.error import (
|
||||||
CompletionRequestError,
|
CompletionRequestError,
|
||||||
@ -34,12 +33,6 @@ from .. import console_ns
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class WorkflowRunPayload(BaseModel):
|
|
||||||
inputs: dict[str, Any]
|
|
||||||
files: list[dict[str, Any]] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
register_schema_model(console_ns, WorkflowRunPayload)
|
register_schema_model(console_ns, WorkflowRunPayload)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from typing import TypedDict
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
@ -11,6 +13,21 @@ from services.billing_service import BillingService
|
|||||||
_FALLBACK_LANG = "en-US"
|
_FALLBACK_LANG = "en-US"
|
||||||
|
|
||||||
|
|
||||||
|
class NotificationItemDict(TypedDict):
|
||||||
|
notification_id: str | None
|
||||||
|
frequency: str | None
|
||||||
|
lang: str
|
||||||
|
title: str
|
||||||
|
subtitle: str
|
||||||
|
body: str
|
||||||
|
title_pic_url: str
|
||||||
|
|
||||||
|
|
||||||
|
class NotificationResponseDict(TypedDict):
|
||||||
|
should_show: bool
|
||||||
|
notifications: list[NotificationItemDict]
|
||||||
|
|
||||||
|
|
||||||
def _pick_lang_content(contents: dict, lang: str) -> dict:
|
def _pick_lang_content(contents: dict, lang: str) -> dict:
|
||||||
"""Return the single LangContent for *lang*, falling back to English."""
|
"""Return the single LangContent for *lang*, falling back to English."""
|
||||||
return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {})
|
return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {})
|
||||||
@ -45,28 +62,30 @@ class NotificationApi(Resource):
|
|||||||
result = BillingService.get_account_notification(str(current_user.id))
|
result = BillingService.get_account_notification(str(current_user.id))
|
||||||
|
|
||||||
# Proto JSON uses camelCase field names (Kratos default marshaling).
|
# Proto JSON uses camelCase field names (Kratos default marshaling).
|
||||||
|
response: NotificationResponseDict
|
||||||
if not result.get("shouldShow"):
|
if not result.get("shouldShow"):
|
||||||
return {"should_show": False, "notifications": []}, 200
|
response = {"should_show": False, "notifications": []}
|
||||||
|
return response, 200
|
||||||
|
|
||||||
lang = current_user.interface_language or _FALLBACK_LANG
|
lang = current_user.interface_language or _FALLBACK_LANG
|
||||||
|
|
||||||
notifications = []
|
notifications: list[NotificationItemDict] = []
|
||||||
for notification in result.get("notifications") or []:
|
for notification in result.get("notifications") or []:
|
||||||
contents: dict = notification.get("contents") or {}
|
contents: dict = notification.get("contents") or {}
|
||||||
lang_content = _pick_lang_content(contents, lang)
|
lang_content = _pick_lang_content(contents, lang)
|
||||||
notifications.append(
|
item: NotificationItemDict = {
|
||||||
{
|
"notification_id": notification.get("notificationId"),
|
||||||
"notification_id": notification.get("notificationId"),
|
"frequency": notification.get("frequency"),
|
||||||
"frequency": notification.get("frequency"),
|
"lang": lang_content.get("lang", lang),
|
||||||
"lang": lang_content.get("lang", lang),
|
"title": lang_content.get("title", ""),
|
||||||
"title": lang_content.get("title", ""),
|
"subtitle": lang_content.get("subtitle", ""),
|
||||||
"subtitle": lang_content.get("subtitle", ""),
|
"body": lang_content.get("body", ""),
|
||||||
"body": lang_content.get("body", ""),
|
"title_pic_url": lang_content.get("titlePicUrl", ""),
|
||||||
"title_pic_url": lang_content.get("titlePicUrl", ""),
|
}
|
||||||
}
|
notifications.append(item)
|
||||||
)
|
|
||||||
|
|
||||||
return {"should_show": bool(notifications), "notifications": notifications}, 200
|
response = {"should_show": bool(notifications), "notifications": notifications}
|
||||||
|
return response, 200
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/notification/dismiss")
|
@console_ns.route("/notification/dismiss")
|
||||||
|
|||||||
@ -9,7 +9,14 @@ from controllers.common.schema import register_schema_models
|
|||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||||
from libs.login import current_account_with_tenant, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from services.tag_service import TagService
|
from models.enums import TagType
|
||||||
|
from services.tag_service import (
|
||||||
|
SaveTagPayload,
|
||||||
|
TagBindingCreatePayload,
|
||||||
|
TagBindingDeletePayload,
|
||||||
|
TagService,
|
||||||
|
UpdateTagPayload,
|
||||||
|
)
|
||||||
|
|
||||||
dataset_tag_fields = {
|
dataset_tag_fields = {
|
||||||
"id": fields.String,
|
"id": fields.String,
|
||||||
@ -25,19 +32,19 @@ def build_dataset_tag_fields(api_or_ns: Namespace):
|
|||||||
|
|
||||||
class TagBasePayload(BaseModel):
|
class TagBasePayload(BaseModel):
|
||||||
name: str = Field(description="Tag name", min_length=1, max_length=50)
|
name: str = Field(description="Tag name", min_length=1, max_length=50)
|
||||||
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
|
type: TagType = Field(description="Tag type")
|
||||||
|
|
||||||
|
|
||||||
class TagBindingPayload(BaseModel):
|
class TagBindingPayload(BaseModel):
|
||||||
tag_ids: list[str] = Field(description="Tag IDs to bind")
|
tag_ids: list[str] = Field(description="Tag IDs to bind")
|
||||||
target_id: str = Field(description="Target ID to bind tags to")
|
target_id: str = Field(description="Target ID to bind tags to")
|
||||||
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
|
type: TagType = Field(description="Tag type")
|
||||||
|
|
||||||
|
|
||||||
class TagBindingRemovePayload(BaseModel):
|
class TagBindingRemovePayload(BaseModel):
|
||||||
tag_id: str = Field(description="Tag ID to remove")
|
tag_id: str = Field(description="Tag ID to remove")
|
||||||
target_id: str = Field(description="Target ID to unbind tag from")
|
target_id: str = Field(description="Target ID to unbind tag from")
|
||||||
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
|
type: TagType = Field(description="Tag type")
|
||||||
|
|
||||||
|
|
||||||
class TagListQueryParam(BaseModel):
|
class TagListQueryParam(BaseModel):
|
||||||
@ -82,7 +89,7 @@ class TagListApi(Resource):
|
|||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
payload = TagBasePayload.model_validate(console_ns.payload or {})
|
payload = TagBasePayload.model_validate(console_ns.payload or {})
|
||||||
tag = TagService.save_tags(payload.model_dump())
|
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type))
|
||||||
|
|
||||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||||
|
|
||||||
@ -103,7 +110,7 @@ class TagUpdateDeleteApi(Resource):
|
|||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
payload = TagBasePayload.model_validate(console_ns.payload or {})
|
payload = TagBasePayload.model_validate(console_ns.payload or {})
|
||||||
tag = TagService.update_tags(payload.model_dump(), tag_id)
|
tag = TagService.update_tags(UpdateTagPayload(name=payload.name, type=payload.type), tag_id)
|
||||||
|
|
||||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||||
|
|
||||||
@ -136,7 +143,9 @@ class TagBindingCreateApi(Resource):
|
|||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
payload = TagBindingPayload.model_validate(console_ns.payload or {})
|
payload = TagBindingPayload.model_validate(console_ns.payload or {})
|
||||||
TagService.save_tag_binding(payload.model_dump())
|
TagService.save_tag_binding(
|
||||||
|
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=payload.type)
|
||||||
|
)
|
||||||
|
|
||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
@ -154,6 +163,8 @@ class TagBindingDeleteApi(Resource):
|
|||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
|
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
|
||||||
TagService.delete_tag_binding(payload.model_dump())
|
TagService.delete_tag_binding(
|
||||||
|
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=payload.type)
|
||||||
|
)
|
||||||
|
|
||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
@ -21,12 +22,12 @@ def plugin_permission_required(
|
|||||||
tenant_id = current_tenant_id
|
tenant_id = current_tenant_id
|
||||||
|
|
||||||
with sessionmaker(db.engine).begin() as session:
|
with sessionmaker(db.engine).begin() as session:
|
||||||
permission = (
|
permission = session.scalar(
|
||||||
session.query(TenantPluginPermission)
|
select(TenantPluginPermission)
|
||||||
.where(
|
.where(
|
||||||
TenantPluginPermission.tenant_id == tenant_id,
|
TenantPluginPermission.tenant_id == tenant_id,
|
||||||
)
|
)
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
if not permission:
|
if not permission:
|
||||||
|
|||||||
@ -28,7 +28,7 @@ from enums.cloud_plan import CloudPlan
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import TimestampField
|
from libs.helper import TimestampField
|
||||||
from libs.login import current_account_with_tenant, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.account import Tenant, TenantStatus
|
from models.account import Tenant, TenantCustomConfigDict, TenantStatus
|
||||||
from services.account_service import TenantService
|
from services.account_service import TenantService
|
||||||
from services.billing_service import BillingService, SubscriptionPlan
|
from services.billing_service import BillingService, SubscriptionPlan
|
||||||
from services.enterprise.enterprise_service import EnterpriseService
|
from services.enterprise.enterprise_service import EnterpriseService
|
||||||
@ -240,8 +240,10 @@ class CustomConfigWorkspaceApi(Resource):
|
|||||||
args = WorkspaceCustomConfigPayload.model_validate(payload)
|
args = WorkspaceCustomConfigPayload.model_validate(payload)
|
||||||
tenant = db.get_or_404(Tenant, current_tenant_id)
|
tenant = db.get_or_404(Tenant, current_tenant_id)
|
||||||
|
|
||||||
custom_config_dict = {
|
custom_config_dict: TenantCustomConfigDict = {
|
||||||
"remove_webapp_brand": args.remove_webapp_brand,
|
"remove_webapp_brand": args.remove_webapp_brand
|
||||||
|
if args.remove_webapp_brand is not None
|
||||||
|
else tenant.custom_config_dict.get("remove_webapp_brand", False),
|
||||||
"replace_webapp_logo": args.replace_webapp_logo
|
"replace_webapp_logo": args.replace_webapp_logo
|
||||||
if args.replace_webapp_logo is not None
|
if args.replace_webapp_logo is not None
|
||||||
else tenant.custom_config_dict.get("replace_webapp_logo"),
|
else tenant.custom_config_dict.get("replace_webapp_logo"),
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from flask import request
|
|||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from controllers.common.schema import register_schema_model
|
from controllers.common.schema import register_schema_model
|
||||||
from controllers.console.wraps import setup_required
|
from controllers.console.wraps import setup_required
|
||||||
@ -55,7 +55,7 @@ class EnterpriseAppDSLImport(Resource):
|
|||||||
|
|
||||||
account.set_tenant_id(workspace_id)
|
account.set_tenant_id(workspace_id)
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(db.engine).begin() as session:
|
||||||
dsl_service = AppDslService(session)
|
dsl_service = AppDslService(session)
|
||||||
result = dsl_service.import_app(
|
result = dsl_service.import_app(
|
||||||
account=account,
|
account=account,
|
||||||
@ -64,7 +64,6 @@ class EnterpriseAppDSLImport(Resource):
|
|||||||
name=args.name,
|
name=args.name,
|
||||||
description=args.description,
|
description=args.description,
|
||||||
)
|
)
|
||||||
session.commit()
|
|
||||||
|
|
||||||
if result.status == ImportStatus.FAILED:
|
if result.status == ImportStatus.FAILED:
|
||||||
return result.model_dump(mode="json"), 400
|
return result.model_dump(mode="json"), 400
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from flask import Response
|
|||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from graphon.variables.input_entities import VariableEntity
|
from graphon.variables.input_entities import VariableEntity
|
||||||
from pydantic import BaseModel, Field, ValidationError
|
from pydantic import BaseModel, Field, ValidationError
|
||||||
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
from controllers.common.schema import register_schema_model
|
from controllers.common.schema import register_schema_model
|
||||||
@ -80,11 +81,11 @@ class MCPAppApi(Resource):
|
|||||||
|
|
||||||
def _get_mcp_server_and_app(self, server_code: str, session: Session) -> tuple[AppMCPServer, App]:
|
def _get_mcp_server_and_app(self, server_code: str, session: Session) -> tuple[AppMCPServer, App]:
|
||||||
"""Get and validate MCP server and app in one query session"""
|
"""Get and validate MCP server and app in one query session"""
|
||||||
mcp_server = session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
|
mcp_server = session.scalar(select(AppMCPServer).where(AppMCPServer.server_code == server_code).limit(1))
|
||||||
if not mcp_server:
|
if not mcp_server:
|
||||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server Not Found")
|
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server Not Found")
|
||||||
|
|
||||||
app = session.query(App).where(App.id == mcp_server.app_id).first()
|
app = session.scalar(select(App).where(App.id == mcp_server.app_id).limit(1))
|
||||||
if not app:
|
if not app:
|
||||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App Not Found")
|
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App Not Found")
|
||||||
|
|
||||||
@ -190,12 +191,12 @@ class MCPAppApi(Resource):
|
|||||||
def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str) -> EndUser | None:
|
def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str) -> EndUser | None:
|
||||||
"""Get end user - manages its own database session"""
|
"""Get end user - manages its own database session"""
|
||||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||||
return (
|
return session.scalar(
|
||||||
session.query(EndUser)
|
select(EndUser)
|
||||||
.where(EndUser.tenant_id == tenant_id)
|
.where(EndUser.tenant_id == tenant_id)
|
||||||
.where(EndUser.session_id == mcp_server_id)
|
.where(EndUser.session_id == mcp_server_id)
|
||||||
.where(EndUser.type == "mcp")
|
.where(EndUser.type == "mcp")
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_end_user(
|
def _create_end_user(
|
||||||
|
|||||||
@ -2,11 +2,12 @@ from typing import Any, Literal
|
|||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
|
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
from werkzeug.exceptions import BadRequest, NotFound
|
from werkzeug.exceptions import BadRequest, NotFound
|
||||||
|
|
||||||
import services
|
import services
|
||||||
|
from controllers.common.controller_schemas import ConversationRenamePayload
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.service_api import service_api_ns
|
from controllers.service_api import service_api_ns
|
||||||
from controllers.service_api.app.error import NotChatAppError
|
from controllers.service_api.app.error import NotChatAppError
|
||||||
@ -34,18 +35,6 @@ class ConversationListQuery(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ConversationRenamePayload(BaseModel):
|
|
||||||
name: str | None = Field(default=None, description="New conversation name (required if auto_generate is false)")
|
|
||||||
auto_generate: bool = Field(default=False, description="Auto-generate conversation name")
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
|
||||||
def validate_name_requirement(self):
|
|
||||||
if not self.auto_generate:
|
|
||||||
if self.name is None or not self.name.strip():
|
|
||||||
raise ValueError("name is required when auto_generate is false")
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
class ConversationVariablesQuery(BaseModel):
|
class ConversationVariablesQuery(BaseModel):
|
||||||
last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last variable ID for pagination")
|
last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last variable ID for pagination")
|
||||||
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
|
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
@ -7,6 +6,7 @@ from pydantic import BaseModel, Field, TypeAdapter
|
|||||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||||
|
|
||||||
import services
|
import services
|
||||||
|
from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.service_api import service_api_ns
|
from controllers.service_api import service_api_ns
|
||||||
from controllers.service_api.app.error import NotChatAppError
|
from controllers.service_api.app.error import NotChatAppError
|
||||||
@ -14,7 +14,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
|
|||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from fields.conversation_fields import ResultResponse
|
from fields.conversation_fields import ResultResponse
|
||||||
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
|
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
|
||||||
from libs.helper import UUIDStrOrEmpty
|
|
||||||
from models.enums import FeedbackRating
|
from models.enums import FeedbackRating
|
||||||
from models.model import App, AppMode, EndUser
|
from models.model import App, AppMode, EndUser
|
||||||
from services.errors.message import (
|
from services.errors.message import (
|
||||||
@ -27,17 +26,6 @@ from services.message_service import MessageService
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MessageListQuery(BaseModel):
|
|
||||||
conversation_id: UUIDStrOrEmpty
|
|
||||||
first_id: UUIDStrOrEmpty | None = None
|
|
||||||
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return")
|
|
||||||
|
|
||||||
|
|
||||||
class MessageFeedbackPayload(BaseModel):
|
|
||||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
|
|
||||||
content: str | None = Field(default=None, description="Feedback content")
|
|
||||||
|
|
||||||
|
|
||||||
class FeedbackListQuery(BaseModel):
|
class FeedbackListQuery(BaseModel):
|
||||||
page: int = Field(default=1, ge=1, description="Page number")
|
page: int = Field(default=1, ge=1, description="Page number")
|
||||||
limit: int = Field(default=20, ge=1, le=101, description="Number of feedbacks per page")
|
limit: int = Field(default=20, ge=1, le=101, description="Number of feedbacks per page")
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Literal
|
from typing import Literal
|
||||||
|
|
||||||
from dateutil.parser import isoparse
|
from dateutil.parser import isoparse
|
||||||
from flask import request
|
from flask import request
|
||||||
@ -11,6 +11,7 @@ from pydantic import BaseModel, Field
|
|||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||||
|
|
||||||
|
from controllers.common.controller_schemas import WorkflowRunPayload as WorkflowRunPayloadBase
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.service_api import service_api_ns
|
from controllers.service_api import service_api_ns
|
||||||
from controllers.service_api.app.error import (
|
from controllers.service_api.app.error import (
|
||||||
@ -46,9 +47,7 @@ from services.workflow_app_service import WorkflowAppService
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class WorkflowRunPayload(BaseModel):
|
class WorkflowRunPayload(WorkflowRunPayloadBase):
|
||||||
inputs: dict[str, Any]
|
|
||||||
files: list[dict[str, Any]] | None = None
|
|
||||||
response_mode: Literal["blocking", "streaming"] | None = None
|
response_mode: Literal["blocking", "streaming"] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -22,10 +22,17 @@ from fields.tag_fields import DataSetTag
|
|||||||
from libs.login import current_user
|
from libs.login import current_user
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
from models.dataset import DatasetPermissionEnum
|
from models.dataset import DatasetPermissionEnum
|
||||||
|
from models.enums import TagType
|
||||||
from models.provider_ids import ModelProviderID
|
from models.provider_ids import ModelProviderID
|
||||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||||
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
|
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
|
||||||
from services.tag_service import TagService
|
from services.tag_service import (
|
||||||
|
SaveTagPayload,
|
||||||
|
TagBindingCreatePayload,
|
||||||
|
TagBindingDeletePayload,
|
||||||
|
TagService,
|
||||||
|
UpdateTagPayload,
|
||||||
|
)
|
||||||
|
|
||||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
@ -513,7 +520,7 @@ class DatasetTagsApi(DatasetApiResource):
|
|||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
|
payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
|
||||||
tag = TagService.save_tags({"name": payload.name, "type": "knowledge"})
|
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=TagType.KNOWLEDGE))
|
||||||
|
|
||||||
response = DataSetTag.model_validate(
|
response = DataSetTag.model_validate(
|
||||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||||
@ -536,9 +543,8 @@ class DatasetTagsApi(DatasetApiResource):
|
|||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
payload = TagUpdatePayload.model_validate(service_api_ns.payload or {})
|
payload = TagUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||||
params = {"name": payload.name, "type": "knowledge"}
|
|
||||||
tag_id = payload.tag_id
|
tag_id = payload.tag_id
|
||||||
tag = TagService.update_tags(params, tag_id)
|
tag = TagService.update_tags(UpdateTagPayload(name=payload.name, type=TagType.KNOWLEDGE), tag_id)
|
||||||
|
|
||||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||||
|
|
||||||
@ -585,7 +591,9 @@ class DatasetTagBindingApi(DatasetApiResource):
|
|||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
payload = TagBindingPayload.model_validate(service_api_ns.payload or {})
|
payload = TagBindingPayload.model_validate(service_api_ns.payload or {})
|
||||||
TagService.save_tag_binding({"tag_ids": payload.tag_ids, "target_id": payload.target_id, "type": "knowledge"})
|
TagService.save_tag_binding(
|
||||||
|
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE)
|
||||||
|
)
|
||||||
|
|
||||||
return "", 204
|
return "", 204
|
||||||
|
|
||||||
@ -609,7 +617,9 @@ class DatasetTagUnbindingApi(DatasetApiResource):
|
|||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
|
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
|
||||||
TagService.delete_tag_binding({"tag_id": payload.tag_id, "target_id": payload.target_id, "type": "knowledge"})
|
TagService.delete_tag_binding(
|
||||||
|
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=TagType.KNOWLEDGE)
|
||||||
|
)
|
||||||
|
|
||||||
return "", 204
|
return "", 204
|
||||||
|
|
||||||
|
|||||||
@ -31,6 +31,7 @@ from controllers.service_api.wraps import (
|
|||||||
cloud_edition_billing_resource_check,
|
cloud_edition_billing_resource_check,
|
||||||
)
|
)
|
||||||
from core.errors.error import ProviderTokenNotInitError
|
from core.errors.error import ProviderTokenNotInitError
|
||||||
|
from core.rag.entities import PreProcessingRule, Rule, Segmentation
|
||||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.document_fields import document_fields, document_status_fields
|
from fields.document_fields import document_fields, document_status_fields
|
||||||
@ -40,11 +41,8 @@ from models.enums import SegmentStatus
|
|||||||
from services.dataset_service import DatasetService, DocumentService
|
from services.dataset_service import DatasetService, DocumentService
|
||||||
from services.entities.knowledge_entities.knowledge_entities import (
|
from services.entities.knowledge_entities.knowledge_entities import (
|
||||||
KnowledgeConfig,
|
KnowledgeConfig,
|
||||||
PreProcessingRule,
|
|
||||||
ProcessRule,
|
ProcessRule,
|
||||||
RetrievalModel,
|
RetrievalModel,
|
||||||
Rule,
|
|
||||||
Segmentation,
|
|
||||||
)
|
)
|
||||||
from services.file_service import FileService
|
from services.file_service import FileService
|
||||||
from services.summary_index_service import SummaryIndexService
|
from services.summary_index_service import SummaryIndexService
|
||||||
|
|||||||
@ -4,13 +4,23 @@ Serialization helpers for Service API knowledge pipeline endpoints.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, TypedDict
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from models.model import UploadFile
|
from models.model import UploadFile
|
||||||
|
|
||||||
|
|
||||||
def serialize_upload_file(upload_file: UploadFile) -> dict[str, Any]:
|
class UploadFileDict(TypedDict):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
size: int
|
||||||
|
extension: str
|
||||||
|
mime_type: str | None
|
||||||
|
created_by: str
|
||||||
|
created_at: str | None
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_upload_file(upload_file: UploadFile) -> UploadFileDict:
|
||||||
return {
|
return {
|
||||||
"id": upload_file.id,
|
"id": upload_file.id,
|
||||||
"name": upload_file.name,
|
"name": upload_file.name,
|
||||||
|
|||||||
@ -1,9 +1,10 @@
|
|||||||
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from enum import StrEnum, auto
|
from enum import StrEnum, auto
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, cast, overload
|
from typing import cast, overload
|
||||||
|
|
||||||
from flask import current_app, request
|
from flask import current_app, request
|
||||||
from flask_login import user_logged_in
|
from flask_login import user_logged_in
|
||||||
@ -230,94 +231,73 @@ def cloud_edition_billing_rate_limit_check[**P, R](
|
|||||||
return interceptor
|
return interceptor
|
||||||
|
|
||||||
|
|
||||||
def validate_dataset_token(
|
def validate_dataset_token[R](view: Callable[..., R]) -> Callable[..., R]:
|
||||||
view: Callable[..., Any] | None = None,
|
positional_parameters = [
|
||||||
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
|
parameter
|
||||||
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
|
for parameter in inspect.signature(view).parameters.values()
|
||||||
@wraps(view_func)
|
if parameter.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
|
||||||
def decorated(*args: Any, **kwargs: Any) -> Any:
|
]
|
||||||
api_token = validate_and_get_api_token("dataset")
|
expects_bound_instance = bool(positional_parameters and positional_parameters[0].name in {"self", "cls"})
|
||||||
|
|
||||||
# get url path dataset_id from positional args or kwargs
|
@wraps(view)
|
||||||
# Flask passes URL path parameters as positional arguments
|
def decorated(*args: object, **kwargs: object) -> R:
|
||||||
dataset_id = None
|
api_token = validate_and_get_api_token("dataset")
|
||||||
|
|
||||||
# First try to get from kwargs (explicit parameter)
|
# Flask may pass URL path parameters positionally, so inspect both kwargs and args.
|
||||||
dataset_id = kwargs.get("dataset_id")
|
dataset_id = kwargs.get("dataset_id")
|
||||||
|
|
||||||
# If not in kwargs, try to extract from positional args
|
if not dataset_id and args:
|
||||||
if not dataset_id and args:
|
potential_id = args[0]
|
||||||
# For class methods: args[0] is self, args[1] is dataset_id (if exists)
|
try:
|
||||||
# Check if first arg is likely a class instance (has __dict__ or __class__)
|
str_id = str(potential_id)
|
||||||
if len(args) > 1 and hasattr(args[0], "__dict__"):
|
if len(str_id) == 36 and str_id.count("-") == 4:
|
||||||
# This is a class method, dataset_id should be in args[1]
|
dataset_id = str_id
|
||||||
potential_id = args[1]
|
except Exception:
|
||||||
# Validate it's a string-like UUID, not another object
|
logger.exception("Failed to parse dataset_id from positional args")
|
||||||
try:
|
|
||||||
# Try to convert to string and check if it's a valid UUID format
|
|
||||||
str_id = str(potential_id)
|
|
||||||
# Basic check: UUIDs are 36 chars with hyphens
|
|
||||||
if len(str_id) == 36 and str_id.count("-") == 4:
|
|
||||||
dataset_id = str_id
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to parse dataset_id from class method args")
|
|
||||||
elif len(args) > 0:
|
|
||||||
# Not a class method, check if args[0] looks like a UUID
|
|
||||||
potential_id = args[0]
|
|
||||||
try:
|
|
||||||
str_id = str(potential_id)
|
|
||||||
if len(str_id) == 36 and str_id.count("-") == 4:
|
|
||||||
dataset_id = str_id
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to parse dataset_id from positional args")
|
|
||||||
|
|
||||||
# Validate dataset if dataset_id is provided
|
if dataset_id:
|
||||||
if dataset_id:
|
dataset_id = str(dataset_id)
|
||||||
dataset_id = str(dataset_id)
|
dataset = db.session.scalar(
|
||||||
dataset = db.session.scalar(
|
select(Dataset)
|
||||||
select(Dataset)
|
.where(
|
||||||
.where(
|
Dataset.id == dataset_id,
|
||||||
Dataset.id == dataset_id,
|
Dataset.tenant_id == api_token.tenant_id,
|
||||||
Dataset.tenant_id == api_token.tenant_id,
|
|
||||||
)
|
|
||||||
.limit(1)
|
|
||||||
)
|
)
|
||||||
if not dataset:
|
.limit(1)
|
||||||
raise NotFound("Dataset not found.")
|
)
|
||||||
if not dataset.enable_api:
|
if not dataset:
|
||||||
raise Forbidden("Dataset api access is not enabled.")
|
raise NotFound("Dataset not found.")
|
||||||
tenant_account_join = db.session.execute(
|
if not dataset.enable_api:
|
||||||
select(Tenant, TenantAccountJoin)
|
raise Forbidden("Dataset api access is not enabled.")
|
||||||
.where(Tenant.id == api_token.tenant_id)
|
|
||||||
.where(TenantAccountJoin.tenant_id == Tenant.id)
|
tenant_account_join = db.session.execute(
|
||||||
.where(TenantAccountJoin.role.in_(["owner"]))
|
select(Tenant, TenantAccountJoin)
|
||||||
.where(Tenant.status == TenantStatus.NORMAL)
|
.where(Tenant.id == api_token.tenant_id)
|
||||||
).one_or_none() # TODO: only owner information is required, so only one is returned.
|
.where(TenantAccountJoin.tenant_id == Tenant.id)
|
||||||
if tenant_account_join:
|
.where(TenantAccountJoin.role.in_(["owner"]))
|
||||||
tenant, ta = tenant_account_join
|
.where(Tenant.status == TenantStatus.NORMAL)
|
||||||
account = db.session.get(Account, ta.account_id)
|
).one_or_none() # TODO: only owner information is required, so only one is returned.
|
||||||
# Login admin
|
if tenant_account_join:
|
||||||
if account:
|
tenant, ta = tenant_account_join
|
||||||
account.current_tenant = tenant
|
account = db.session.get(Account, ta.account_id)
|
||||||
current_app.login_manager._update_request_context_with_user(account) # type: ignore
|
# Login admin
|
||||||
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
|
if account:
|
||||||
else:
|
account.current_tenant = tenant
|
||||||
raise Unauthorized("Tenant owner account does not exist.")
|
current_app.login_manager._update_request_context_with_user(account) # type: ignore
|
||||||
|
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
|
||||||
else:
|
else:
|
||||||
raise Unauthorized("Tenant does not exist.")
|
raise Unauthorized("Tenant owner account does not exist.")
|
||||||
if args and isinstance(args[0], Resource):
|
else:
|
||||||
return view_func(args[0], api_token.tenant_id, *args[1:], **kwargs)
|
raise Unauthorized("Tenant does not exist.")
|
||||||
|
|
||||||
return view_func(api_token.tenant_id, *args, **kwargs)
|
if expects_bound_instance:
|
||||||
|
if not args:
|
||||||
|
raise TypeError("validate_dataset_token expected a bound resource instance.")
|
||||||
|
return view(args[0], api_token.tenant_id, *args[1:], **kwargs)
|
||||||
|
|
||||||
return decorated
|
return view(api_token.tenant_id, *args, **kwargs)
|
||||||
|
|
||||||
if view:
|
return decorated
|
||||||
return decorator(view)
|
|
||||||
|
|
||||||
# if view is None, it means that the decorator is used without parentheses
|
|
||||||
# use the decorator as a function for method_decorators
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def validate_and_get_api_token(scope: str | None = None):
|
def validate_and_get_api_token(scope: str | None = None):
|
||||||
|
|||||||
@ -7,7 +7,7 @@ from werkzeug.exceptions import NotFound, RequestEntityTooLarge
|
|||||||
from controllers.trigger import bp
|
from controllers.trigger import bp
|
||||||
from core.trigger.debug.event_bus import TriggerDebugEventBus
|
from core.trigger.debug.event_bus import TriggerDebugEventBus
|
||||||
from core.trigger.debug.events import WebhookDebugEvent, build_webhook_pool_key
|
from core.trigger.debug.events import WebhookDebugEvent, build_webhook_pool_key
|
||||||
from services.trigger.webhook_service import WebhookService
|
from services.trigger.webhook_service import RawWebhookDataDict, WebhookService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -23,6 +23,7 @@ def _prepare_webhook_execution(webhook_id: str, is_debug: bool = False):
|
|||||||
webhook_id, is_debug=is_debug
|
webhook_id, is_debug=is_debug
|
||||||
)
|
)
|
||||||
|
|
||||||
|
webhook_data: RawWebhookDataDict
|
||||||
try:
|
try:
|
||||||
# Use new unified extraction and validation
|
# Use new unified extraction and validation
|
||||||
webhook_data = WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
|
webhook_data = WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
|
||||||
|
|||||||
@ -3,10 +3,11 @@ import logging
|
|||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import fields, marshal_with
|
from flask_restx import fields, marshal_with
|
||||||
from graphon.model_runtime.errors.invoke import InvokeError
|
from graphon.model_runtime.errors.invoke import InvokeError
|
||||||
from pydantic import BaseModel, field_validator
|
from pydantic import field_validator
|
||||||
from werkzeug.exceptions import InternalServerError
|
from werkzeug.exceptions import InternalServerError
|
||||||
|
|
||||||
import services
|
import services
|
||||||
|
from controllers.common.controller_schemas import TextToAudioPayload as TextToAudioPayloadBase
|
||||||
from controllers.web import web_ns
|
from controllers.web import web_ns
|
||||||
from controllers.web.error import (
|
from controllers.web.error import (
|
||||||
AppUnavailableError,
|
AppUnavailableError,
|
||||||
@ -34,12 +35,7 @@ from services.errors.audio import (
|
|||||||
from ..common.schema import register_schema_models
|
from ..common.schema import register_schema_models
|
||||||
|
|
||||||
|
|
||||||
class TextToAudioPayload(BaseModel):
|
class TextToAudioPayload(TextToAudioPayloadBase):
|
||||||
message_id: str | None = None
|
|
||||||
voice: str | None = None
|
|
||||||
text: str | None = None
|
|
||||||
streaming: bool | None = None
|
|
||||||
|
|
||||||
@field_validator("message_id")
|
@field_validator("message_id")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_message_id(cls, value: str | None) -> str | None:
|
def validate_message_id(cls, value: str | None) -> str | None:
|
||||||
|
|||||||
@ -1,10 +1,11 @@
|
|||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
|
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
|
from controllers.common.controller_schemas import ConversationRenamePayload
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.web import web_ns
|
from controllers.web import web_ns
|
||||||
from controllers.web.error import NotChatAppError
|
from controllers.web.error import NotChatAppError
|
||||||
@ -37,18 +38,6 @@ class ConversationListQuery(BaseModel):
|
|||||||
return uuid_value(value)
|
return uuid_value(value)
|
||||||
|
|
||||||
|
|
||||||
class ConversationRenamePayload(BaseModel):
|
|
||||||
name: str | None = None
|
|
||||||
auto_generate: bool = False
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
|
||||||
def validate_name_requirement(self):
|
|
||||||
if not self.auto_generate:
|
|
||||||
if self.name is None or not self.name.strip():
|
|
||||||
raise ValueError("name is required when auto_generate is false")
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
register_schema_models(web_ns, ConversationListQuery, ConversationRenamePayload)
|
register_schema_models(web_ns, ConversationListQuery, ConversationRenamePayload)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -3,7 +3,6 @@ import secrets
|
|||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from pydantic import BaseModel, Field, field_validator
|
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
@ -19,33 +18,15 @@ from controllers.console.error import EmailSendIpLimitError
|
|||||||
from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required
|
from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required
|
||||||
from controllers.web import web_ns
|
from controllers.web import web_ns
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import EmailStr, extract_remote_ip
|
from libs.helper import extract_remote_ip
|
||||||
from libs.password import hash_password, valid_password
|
from libs.password import hash_password
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
from services.account_service import AccountService
|
from services.account_service import AccountService
|
||||||
|
from services.entities.auth_entities import (
|
||||||
|
ForgotPasswordCheckPayload,
|
||||||
class ForgotPasswordSendPayload(BaseModel):
|
ForgotPasswordResetPayload,
|
||||||
email: EmailStr
|
ForgotPasswordSendPayload,
|
||||||
language: str | None = None
|
)
|
||||||
|
|
||||||
|
|
||||||
class ForgotPasswordCheckPayload(BaseModel):
|
|
||||||
email: EmailStr
|
|
||||||
code: str
|
|
||||||
token: str = Field(min_length=1)
|
|
||||||
|
|
||||||
|
|
||||||
class ForgotPasswordResetPayload(BaseModel):
|
|
||||||
token: str = Field(min_length=1)
|
|
||||||
new_password: str
|
|
||||||
password_confirm: str
|
|
||||||
|
|
||||||
@field_validator("new_password", "password_confirm")
|
|
||||||
@classmethod
|
|
||||||
def validate_password(cls, value: str) -> str:
|
|
||||||
return valid_password(value)
|
|
||||||
|
|
||||||
|
|
||||||
register_schema_models(web_ns, ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload)
|
register_schema_models(web_ns, ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload)
|
||||||
|
|
||||||
|
|||||||
@ -29,13 +29,11 @@ from libs.token import (
|
|||||||
)
|
)
|
||||||
from services.account_service import AccountService
|
from services.account_service import AccountService
|
||||||
from services.app_service import AppService
|
from services.app_service import AppService
|
||||||
|
from services.entities.auth_entities import LoginPayloadBase
|
||||||
from services.webapp_auth_service import WebAppAuthService
|
from services.webapp_auth_service import WebAppAuthService
|
||||||
|
|
||||||
|
|
||||||
class LoginPayload(BaseModel):
|
class LoginPayload(LoginPayloadBase):
|
||||||
email: EmailStr
|
|
||||||
password: str
|
|
||||||
|
|
||||||
@field_validator("password")
|
@field_validator("password")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_password(cls, value: str) -> str:
|
def validate_password(cls, value: str) -> str:
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from graphon.model_runtime.errors.invoke import InvokeError
|
|||||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||||
from werkzeug.exceptions import InternalServerError, NotFound
|
from werkzeug.exceptions import InternalServerError, NotFound
|
||||||
|
|
||||||
|
from controllers.common.controller_schemas import MessageFeedbackPayload
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.web import web_ns
|
from controllers.web import web_ns
|
||||||
from controllers.web.error import (
|
from controllers.web.error import (
|
||||||
@ -53,11 +54,6 @@ class MessageListQuery(BaseModel):
|
|||||||
return uuid_value(value)
|
return uuid_value(value)
|
||||||
|
|
||||||
|
|
||||||
class MessageFeedbackPayload(BaseModel):
|
|
||||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
|
|
||||||
content: str | None = Field(default=None, description="Feedback content")
|
|
||||||
|
|
||||||
|
|
||||||
class MessageMoreLikeThisQuery(BaseModel):
|
class MessageMoreLikeThisQuery(BaseModel):
|
||||||
response_mode: Literal["blocking", "streaming"] = Field(
|
response_mode: Literal["blocking", "streaming"] = Field(
|
||||||
description="Response mode",
|
description="Response mode",
|
||||||
|
|||||||
@ -1,27 +1,17 @@
|
|||||||
from flask import request
|
from flask import request
|
||||||
from pydantic import BaseModel, Field, TypeAdapter
|
from pydantic import TypeAdapter
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
|
from controllers.common.controller_schemas import SavedMessageCreatePayload, SavedMessageListQuery
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.web import web_ns
|
from controllers.web import web_ns
|
||||||
from controllers.web.error import NotCompletionAppError
|
from controllers.web.error import NotCompletionAppError
|
||||||
from controllers.web.wraps import WebApiResource
|
from controllers.web.wraps import WebApiResource
|
||||||
from fields.conversation_fields import ResultResponse
|
from fields.conversation_fields import ResultResponse
|
||||||
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
|
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
|
||||||
from libs.helper import UUIDStrOrEmpty
|
|
||||||
from services.errors.message import MessageNotExistsError
|
from services.errors.message import MessageNotExistsError
|
||||||
from services.saved_message_service import SavedMessageService
|
from services.saved_message_service import SavedMessageService
|
||||||
|
|
||||||
|
|
||||||
class SavedMessageListQuery(BaseModel):
|
|
||||||
last_id: UUIDStrOrEmpty | None = None
|
|
||||||
limit: int = Field(default=20, ge=1, le=100)
|
|
||||||
|
|
||||||
|
|
||||||
class SavedMessageCreatePayload(BaseModel):
|
|
||||||
message_id: UUIDStrOrEmpty
|
|
||||||
|
|
||||||
|
|
||||||
register_schema_models(web_ns, SavedMessageListQuery, SavedMessageCreatePayload)
|
register_schema_models(web_ns, SavedMessageListQuery, SavedMessageCreatePayload)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,11 +1,10 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from graphon.graph_engine.manager import GraphEngineManager
|
from graphon.graph_engine.manager import GraphEngineManager
|
||||||
from graphon.model_runtime.errors.invoke import InvokeError
|
from graphon.model_runtime.errors.invoke import InvokeError
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from werkzeug.exceptions import InternalServerError
|
from werkzeug.exceptions import InternalServerError
|
||||||
|
|
||||||
|
from controllers.common.controller_schemas import WorkflowRunPayload
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.web import web_ns
|
from controllers.web import web_ns
|
||||||
from controllers.web.error import (
|
from controllers.web.error import (
|
||||||
@ -30,12 +29,6 @@ from models.model import App, AppMode, EndUser
|
|||||||
from services.app_generate_service import AppGenerateService
|
from services.app_generate_service import AppGenerateService
|
||||||
from services.errors.llm import InvokeRateLimitError
|
from services.errors.llm import InvokeRateLimitError
|
||||||
|
|
||||||
|
|
||||||
class WorkflowRunPayload(BaseModel):
|
|
||||||
inputs: dict[str, Any] = Field(description="Input variables for the workflow")
|
|
||||||
files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed by the workflow")
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
register_schema_models(web_ns, WorkflowRunPayload)
|
register_schema_models(web_ns, WorkflowRunPayload)
|
||||||
|
|||||||
@ -79,21 +79,18 @@ class CotChatAgentRunner(CotAgentRunner):
|
|||||||
if not agent_scratchpad:
|
if not agent_scratchpad:
|
||||||
assistant_messages = []
|
assistant_messages = []
|
||||||
else:
|
else:
|
||||||
assistant_message = AssistantPromptMessage(content="")
|
content = ""
|
||||||
assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str
|
|
||||||
for unit in agent_scratchpad:
|
for unit in agent_scratchpad:
|
||||||
if unit.is_final():
|
if unit.is_final():
|
||||||
assert isinstance(assistant_message.content, str)
|
content += f"Final Answer: {unit.agent_response}"
|
||||||
assistant_message.content += f"Final Answer: {unit.agent_response}"
|
|
||||||
else:
|
else:
|
||||||
assert isinstance(assistant_message.content, str)
|
content += f"Thought: {unit.thought}\n\n"
|
||||||
assistant_message.content += f"Thought: {unit.thought}\n\n"
|
|
||||||
if unit.action_str:
|
if unit.action_str:
|
||||||
assistant_message.content += f"Action: {unit.action_str}\n\n"
|
content += f"Action: {unit.action_str}\n\n"
|
||||||
if unit.observation:
|
if unit.observation:
|
||||||
assistant_message.content += f"Observation: {unit.observation}\n\n"
|
content += f"Observation: {unit.observation}\n\n"
|
||||||
|
|
||||||
assistant_messages = [assistant_message]
|
assistant_messages = [AssistantPromptMessage(content=content)]
|
||||||
|
|
||||||
# query messages
|
# query messages
|
||||||
query_messages = self._organize_user_query(self._query, [])
|
query_messages = self._organize_user_query(self._query, [])
|
||||||
|
|||||||
@ -5,6 +5,10 @@ from configs import dify_config
|
|||||||
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureToggleDict(TypedDict):
|
||||||
|
enabled: bool
|
||||||
|
|
||||||
|
|
||||||
class SystemParametersDict(TypedDict):
|
class SystemParametersDict(TypedDict):
|
||||||
image_file_size_limit: int
|
image_file_size_limit: int
|
||||||
video_file_size_limit: int
|
video_file_size_limit: int
|
||||||
@ -16,12 +20,12 @@ class SystemParametersDict(TypedDict):
|
|||||||
class AppParametersDict(TypedDict):
|
class AppParametersDict(TypedDict):
|
||||||
opening_statement: str | None
|
opening_statement: str | None
|
||||||
suggested_questions: list[str]
|
suggested_questions: list[str]
|
||||||
suggested_questions_after_answer: dict[str, Any]
|
suggested_questions_after_answer: FeatureToggleDict
|
||||||
speech_to_text: dict[str, Any]
|
speech_to_text: FeatureToggleDict
|
||||||
text_to_speech: dict[str, Any]
|
text_to_speech: FeatureToggleDict
|
||||||
retriever_resource: dict[str, Any]
|
retriever_resource: FeatureToggleDict
|
||||||
annotation_reply: dict[str, Any]
|
annotation_reply: FeatureToggleDict
|
||||||
more_like_this: dict[str, Any]
|
more_like_this: FeatureToggleDict
|
||||||
user_input_form: list[dict[str, Any]]
|
user_input_form: list[dict[str, Any]]
|
||||||
sensitive_word_avoidance: dict[str, Any]
|
sensitive_word_avoidance: dict[str, Any]
|
||||||
file_upload: dict[str, Any]
|
file_upload: dict[str, Any]
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
from collections.abc import Sequence
|
|
||||||
from enum import StrEnum, auto
|
from enum import StrEnum, auto
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
@ -9,6 +8,7 @@ from graphon.variables.input_entities import VariableEntity as WorkflowVariableE
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
|
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
|
||||||
|
from core.rag.entities import MetadataFilteringCondition
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
|
|
||||||
|
|
||||||
@ -111,31 +111,6 @@ class ExternalDataVariableEntity(BaseModel):
|
|||||||
config: dict[str, Any] = Field(default_factory=dict)
|
config: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
SupportedComparisonOperator = Literal[
|
|
||||||
# for string or array
|
|
||||||
"contains",
|
|
||||||
"not contains",
|
|
||||||
"start with",
|
|
||||||
"end with",
|
|
||||||
"is",
|
|
||||||
"is not",
|
|
||||||
"empty",
|
|
||||||
"not empty",
|
|
||||||
"in",
|
|
||||||
"not in",
|
|
||||||
# for number
|
|
||||||
"=",
|
|
||||||
"≠",
|
|
||||||
">",
|
|
||||||
"<",
|
|
||||||
"≥",
|
|
||||||
"≤",
|
|
||||||
# for time
|
|
||||||
"before",
|
|
||||||
"after",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig(BaseModel):
|
class ModelConfig(BaseModel):
|
||||||
provider: str
|
provider: str
|
||||||
name: str
|
name: str
|
||||||
@ -143,25 +118,6 @@ class ModelConfig(BaseModel):
|
|||||||
completion_params: dict[str, Any] = Field(default_factory=dict)
|
completion_params: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class Condition(BaseModel):
|
|
||||||
"""
|
|
||||||
Condition detail
|
|
||||||
"""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
comparison_operator: SupportedComparisonOperator
|
|
||||||
value: str | Sequence[str] | None | int | float = None
|
|
||||||
|
|
||||||
|
|
||||||
class MetadataFilteringCondition(BaseModel):
|
|
||||||
"""
|
|
||||||
Metadata Filtering Condition.
|
|
||||||
"""
|
|
||||||
|
|
||||||
logical_operator: Literal["and", "or"] | None = "and"
|
|
||||||
conditions: list[Condition] | None = Field(default=None, deprecated=True)
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetRetrieveConfigEntity(BaseModel):
|
class DatasetRetrieveConfigEntity(BaseModel):
|
||||||
"""
|
"""
|
||||||
Dataset Retrieve Config Entity.
|
Dataset Retrieve Config Entity.
|
||||||
|
|||||||
@ -107,13 +107,13 @@ class AppGenerateResponseConverter(ABC):
|
|||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _error_to_stream_response(cls, e: Exception):
|
def _error_to_stream_response(cls, e: Exception) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Error to stream response.
|
Error to stream response.
|
||||||
:param e: exception
|
:param e: exception
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
error_responses = {
|
error_responses: dict[type[Exception], dict[str, Any]] = {
|
||||||
ValueError: {"code": "invalid_param", "status": 400},
|
ValueError: {"code": "invalid_param", "status": 400},
|
||||||
ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
|
ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
|
||||||
QuotaExceededError: {
|
QuotaExceededError: {
|
||||||
@ -127,7 +127,7 @@ class AppGenerateResponseConverter(ABC):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Determine the response based on the type of exception
|
# Determine the response based on the type of exception
|
||||||
data = None
|
data: dict[str, Any] | None = None
|
||||||
for k, v in error_responses.items():
|
for k, v in error_responses.items():
|
||||||
if isinstance(e, k):
|
if isinstance(e, k):
|
||||||
data = v
|
data = v
|
||||||
|
|||||||
@ -66,7 +66,7 @@ from core.app.entities.queue_entities import (
|
|||||||
QueueWorkflowStartedEvent,
|
QueueWorkflowStartedEvent,
|
||||||
QueueWorkflowSucceededEvent,
|
QueueWorkflowSucceededEvent,
|
||||||
)
|
)
|
||||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
from core.rag.entities import RetrievalSourceMetadata
|
||||||
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
|
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
|
||||||
from core.workflow.system_variables import (
|
from core.workflow.system_variables import (
|
||||||
build_bootstrap_variables,
|
build_bootstrap_variables,
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChun
|
|||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
from core.app.entities.agent_strategy import AgentStrategyInfo
|
from core.app.entities.agent_strategy import AgentStrategyInfo
|
||||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
from core.rag.entities import RetrievalSourceMetadata
|
||||||
|
|
||||||
|
|
||||||
class QueueEvent(StrEnum):
|
class QueueEvent(StrEnum):
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from graphon.nodes.human_input.entities import FormInput, UserAction
|
|||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
from core.app.entities.agent_strategy import AgentStrategyInfo
|
from core.app.entities.agent_strategy import AgentStrategyInfo
|
||||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
from core.rag.entities import RetrievalSourceMetadata
|
||||||
|
|
||||||
|
|
||||||
class AnnotationReplyAccount(BaseModel):
|
class AnnotationReplyAccount(BaseModel):
|
||||||
|
|||||||
@ -509,8 +509,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
with Session(db.engine, expire_on_commit=False) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
agent_thought: MessageAgentThought | None = (
|
agent_thought: MessageAgentThought | None = session.scalar(
|
||||||
session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
|
select(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).limit(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
if agent_thought:
|
if agent_thought:
|
||||||
|
|||||||
@ -6,7 +6,7 @@ from sqlalchemy import select, update
|
|||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
|
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
|
||||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
from core.rag.entities import RetrievalSourceMetadata
|
||||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import Document
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
|||||||
@ -345,8 +345,8 @@ class DatasourceManager:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File:
|
def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File:
|
||||||
with session_factory.create_session() as session:
|
with session_factory.create_session() as session:
|
||||||
upload_file = (
|
upload_file = session.scalar(
|
||||||
session.query(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).first()
|
select(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).limit(1)
|
||||||
)
|
)
|
||||||
if not upload_file:
|
if not upload_file:
|
||||||
raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")
|
raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")
|
||||||
|
|||||||
@ -1,22 +1,3 @@
|
|||||||
from pydantic import BaseModel, Field, model_validator
|
from core.tools.entities.common_entities import I18nObject, I18nObjectDict
|
||||||
|
|
||||||
|
__all__ = ["I18nObject", "I18nObjectDict"]
|
||||||
class I18nObject(BaseModel):
|
|
||||||
"""
|
|
||||||
Model class for i18n object.
|
|
||||||
"""
|
|
||||||
|
|
||||||
en_US: str
|
|
||||||
zh_Hans: str | None = Field(default=None)
|
|
||||||
pt_BR: str | None = Field(default=None)
|
|
||||||
ja_JP: str | None = Field(default=None)
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
|
||||||
def _(self):
|
|
||||||
self.zh_Hans = self.zh_Hans or self.en_US
|
|
||||||
self.pt_BR = self.pt_BR or self.en_US
|
|
||||||
self.ja_JP = self.ja_JP or self.en_US
|
|
||||||
return self
|
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
|
||||||
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}
|
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from yarl import URL
|
|||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.entities.provider_entities import ProviderConfig
|
from core.entities.provider_entities import ProviderConfig
|
||||||
from core.plugin.entities.oauth import OAuthSchema
|
from core.plugin.entities import OAuthSchema
|
||||||
from core.plugin.entities.parameters import (
|
from core.plugin.entities.parameters import (
|
||||||
PluginParameter,
|
PluginParameter,
|
||||||
PluginParameterOption,
|
PluginParameterOption,
|
||||||
|
|||||||
@ -1 +1,8 @@
|
|||||||
|
from core.entities.plugin_credential_type import PluginCredentialType
|
||||||
|
|
||||||
DEFAULT_PLUGIN_ID = "langgenius"
|
DEFAULT_PLUGIN_ID = "langgenius"
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DEFAULT_PLUGIN_ID",
|
||||||
|
"PluginCredentialType",
|
||||||
|
]
|
||||||
|
|||||||
9
api/core/entities/plugin_credential_type.py
Normal file
9
api/core/entities/plugin_credential_type.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
import enum
|
||||||
|
|
||||||
|
|
||||||
|
class PluginCredentialType(enum.Enum):
|
||||||
|
MODEL = 0 # must be 0 for API contract compatibility
|
||||||
|
TOOL = 1 # must be 1 for API contract compatibility
|
||||||
|
|
||||||
|
def to_number(self):
|
||||||
|
return self.value
|
||||||
@ -22,6 +22,7 @@ from sqlalchemy import func, select
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from constants import HIDDEN_VALUE
|
from constants import HIDDEN_VALUE
|
||||||
|
from core.entities import PluginCredentialType
|
||||||
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
|
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
|
||||||
from core.entities.provider_entities import (
|
from core.entities.provider_entities import (
|
||||||
CustomConfiguration,
|
CustomConfiguration,
|
||||||
@ -46,7 +47,6 @@ from models.provider import (
|
|||||||
TenantPreferredModelProvider,
|
TenantPreferredModelProvider,
|
||||||
)
|
)
|
||||||
from models.provider_ids import ModelProviderID
|
from models.provider_ids import ModelProviderID
|
||||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
Credential utility functions for checking credential existence and policy compliance.
|
Credential utility functions for checking credential existence and policy compliance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
from core.entities import PluginCredentialType
|
||||||
|
|
||||||
|
|
||||||
def is_credential_exists(credential_id: str, credential_type: "PluginCredentialType") -> bool:
|
def is_credential_exists(credential_id: str, credential_type: "PluginCredentialType") -> bool:
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Protocol, cast
|
from typing import Protocol, TypedDict, cast
|
||||||
|
|
||||||
import json_repair
|
import json_repair
|
||||||
from graphon.enums import WorkflowNodeExecutionMetadataKey
|
from graphon.enums import WorkflowNodeExecutionMetadataKey
|
||||||
@ -49,6 +49,17 @@ class WorkflowServiceInterface(Protocol):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class CodeGenerateResultDict(TypedDict):
|
||||||
|
code: str
|
||||||
|
language: str
|
||||||
|
error: str
|
||||||
|
|
||||||
|
|
||||||
|
class StructuredOutputResultDict(TypedDict):
|
||||||
|
output: str
|
||||||
|
error: str
|
||||||
|
|
||||||
|
|
||||||
class LLMGenerator:
|
class LLMGenerator:
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate_conversation_name(
|
def generate_conversation_name(
|
||||||
@ -293,7 +304,7 @@ class LLMGenerator:
|
|||||||
cls,
|
cls,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
args: RuleCodeGeneratePayload,
|
args: RuleCodeGeneratePayload,
|
||||||
):
|
) -> CodeGenerateResultDict:
|
||||||
if args.code_language == "python":
|
if args.code_language == "python":
|
||||||
prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
|
prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
|
||||||
else:
|
else:
|
||||||
@ -362,7 +373,9 @@ class LLMGenerator:
|
|||||||
return answer.strip()
|
return answer.strip()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate_structured_output(cls, tenant_id: str, args: RuleStructuredOutputPayload):
|
def generate_structured_output(
|
||||||
|
cls, tenant_id: str, args: RuleStructuredOutputPayload
|
||||||
|
) -> StructuredOutputResultDict:
|
||||||
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
|
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
|
||||||
model_instance = model_manager.get_model_instance(
|
model_instance = model_manager.get_model_instance(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
@ -454,7 +467,7 @@ class LLMGenerator:
|
|||||||
):
|
):
|
||||||
session = db.session()
|
session = db.session()
|
||||||
|
|
||||||
app: App | None = session.query(App).where(App.id == flow_id).first()
|
app: App | None = session.scalar(select(App).where(App.id == flow_id).limit(1))
|
||||||
if not app:
|
if not app:
|
||||||
raise ValueError("App not found.")
|
raise ValueError("App not found.")
|
||||||
workflow = workflow_service.get_draft_workflow(app_model=app)
|
workflow = workflow_service.get_draft_workflow(app_model=app)
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import logging
|
|||||||
import flask
|
import flask
|
||||||
|
|
||||||
from core.logging.context import get_request_id, get_trace_id
|
from core.logging.context import get_request_id, get_trace_id
|
||||||
|
from core.logging.structured_formatter import IdentityDict
|
||||||
|
|
||||||
|
|
||||||
class TraceContextFilter(logging.Filter):
|
class TraceContextFilter(logging.Filter):
|
||||||
@ -60,7 +61,7 @@ class IdentityContextFilter(logging.Filter):
|
|||||||
record.user_type = identity.get("user_type", "")
|
record.user_type = identity.get("user_type", "")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _extract_identity(self) -> dict[str, str]:
|
def _extract_identity(self) -> IdentityDict:
|
||||||
"""Extract identity from current_user if in request context."""
|
"""Extract identity from current_user if in request context."""
|
||||||
try:
|
try:
|
||||||
if not flask.has_request_context():
|
if not flask.has_request_context():
|
||||||
@ -77,7 +78,7 @@ class IdentityContextFilter(logging.Filter):
|
|||||||
from models import Account
|
from models import Account
|
||||||
from models.model import EndUser
|
from models.model import EndUser
|
||||||
|
|
||||||
identity: dict[str, str] = {}
|
identity: IdentityDict = {}
|
||||||
|
|
||||||
if isinstance(user, Account):
|
if isinstance(user, Account):
|
||||||
if user.current_tenant_id:
|
if user.current_tenant_id:
|
||||||
|
|||||||
@ -3,13 +3,19 @@
|
|||||||
import logging
|
import logging
|
||||||
import traceback
|
import traceback
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import Any
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
import orjson
|
import orjson
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
|
||||||
|
|
||||||
|
class IdentityDict(TypedDict, total=False):
|
||||||
|
tenant_id: str
|
||||||
|
user_id: str
|
||||||
|
user_type: str
|
||||||
|
|
||||||
|
|
||||||
class StructuredJSONFormatter(logging.Formatter):
|
class StructuredJSONFormatter(logging.Formatter):
|
||||||
"""
|
"""
|
||||||
JSON log formatter following the specified schema:
|
JSON log formatter following the specified schema:
|
||||||
@ -84,7 +90,7 @@ class StructuredJSONFormatter(logging.Formatter):
|
|||||||
|
|
||||||
return log_dict
|
return log_dict
|
||||||
|
|
||||||
def _extract_identity(self, record: logging.LogRecord) -> dict[str, str] | None:
|
def _extract_identity(self, record: logging.LogRecord) -> IdentityDict | None:
|
||||||
tenant_id = getattr(record, "tenant_id", None)
|
tenant_id = getattr(record, "tenant_id", None)
|
||||||
user_id = getattr(record, "user_id", None)
|
user_id = getattr(record, "user_id", None)
|
||||||
user_type = getattr(record, "user_type", None)
|
user_type = getattr(record, "user_type", None)
|
||||||
@ -92,7 +98,7 @@ class StructuredJSONFormatter(logging.Formatter):
|
|||||||
if not any([tenant_id, user_id, user_type]):
|
if not any([tenant_id, user_id, user_type]):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
identity: dict[str, str] = {}
|
identity: IdentityDict = {}
|
||||||
if tenant_id:
|
if tenant_id:
|
||||||
identity["tenant_id"] = tenant_id
|
identity["tenant_id"] = tenant_id
|
||||||
if user_id:
|
if user_id:
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any, cast
|
from typing import Any, NotRequired, TypedDict, cast
|
||||||
|
|
||||||
from graphon.variables.input_entities import VariableEntity, VariableEntityType
|
from graphon.variables.input_entities import VariableEntity, VariableEntityType
|
||||||
|
|
||||||
@ -15,6 +15,17 @@ from services.app_generate_service import AppGenerateService
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ToolParameterSchemaDict(TypedDict):
|
||||||
|
type: str
|
||||||
|
properties: dict[str, Any]
|
||||||
|
required: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
class ToolArgumentsDict(TypedDict):
|
||||||
|
query: NotRequired[str]
|
||||||
|
inputs: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
def handle_mcp_request(
|
def handle_mcp_request(
|
||||||
app: App,
|
app: App,
|
||||||
request: mcp_types.ClientRequest,
|
request: mcp_types.ClientRequest,
|
||||||
@ -119,7 +130,7 @@ def handle_list_tools(
|
|||||||
mcp_types.Tool(
|
mcp_types.Tool(
|
||||||
name=app_name,
|
name=app_name,
|
||||||
description=description,
|
description=description,
|
||||||
inputSchema=parameter_schema,
|
inputSchema=cast(dict[str, Any], parameter_schema),
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -154,7 +165,7 @@ def build_parameter_schema(
|
|||||||
app_mode: str,
|
app_mode: str,
|
||||||
user_input_form: list[VariableEntity],
|
user_input_form: list[VariableEntity],
|
||||||
parameters_dict: dict[str, str],
|
parameters_dict: dict[str, str],
|
||||||
) -> dict[str, Any]:
|
) -> ToolParameterSchemaDict:
|
||||||
"""Build parameter schema for the tool"""
|
"""Build parameter schema for the tool"""
|
||||||
parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict)
|
parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict)
|
||||||
|
|
||||||
@ -174,7 +185,7 @@ def build_parameter_schema(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> dict[str, Any]:
|
def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> ToolArgumentsDict:
|
||||||
"""Prepare arguments based on app mode"""
|
"""Prepare arguments based on app mode"""
|
||||||
if app.mode == AppMode.WORKFLOW:
|
if app.mode == AppMode.WORKFLOW:
|
||||||
return {"inputs": arguments}
|
return {"inputs": arguments}
|
||||||
|
|||||||
@ -4,7 +4,7 @@ from collections.abc import Callable
|
|||||||
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
|
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
from typing import Any, Self, cast
|
from typing import Any, Self
|
||||||
|
|
||||||
from httpx import HTTPStatusError
|
from httpx import HTTPStatusError
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@ -338,12 +338,11 @@ class BaseSession[
|
|||||||
validated_request = self._receive_request_type.model_validate(
|
validated_request = self._receive_request_type.model_validate(
|
||||||
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||||
)
|
)
|
||||||
validated_request = cast(ReceiveRequestT, validated_request)
|
|
||||||
|
|
||||||
responder = RequestResponder[ReceiveRequestT, SendResultT](
|
responder = RequestResponder[ReceiveRequestT, SendResultT](
|
||||||
request_id=message.message.root.id,
|
request_id=message.message.root.id,
|
||||||
request_meta=validated_request.root.params.meta if validated_request.root.params else None,
|
request_meta=validated_request.root.params.meta if validated_request.root.params else None,
|
||||||
request=validated_request,
|
request=validated_request, # type: ignore[arg-type] # mypy can't narrow constrained TypeVar from model_validate
|
||||||
session=self,
|
session=self,
|
||||||
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
|
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
|
||||||
)
|
)
|
||||||
@ -359,15 +358,14 @@ class BaseSession[
|
|||||||
notification = self._receive_notification_type.model_validate(
|
notification = self._receive_notification_type.model_validate(
|
||||||
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||||
)
|
)
|
||||||
notification = cast(ReceiveNotificationT, notification)
|
|
||||||
# Handle cancellation notifications
|
# Handle cancellation notifications
|
||||||
if isinstance(notification.root, CancelledNotification):
|
if isinstance(notification.root, CancelledNotification):
|
||||||
cancelled_id = notification.root.params.requestId
|
cancelled_id = notification.root.params.requestId
|
||||||
if cancelled_id in self._in_flight:
|
if cancelled_id in self._in_flight:
|
||||||
self._in_flight[cancelled_id].cancel()
|
self._in_flight[cancelled_id].cancel()
|
||||||
else:
|
else:
|
||||||
self._received_notification(notification)
|
self._received_notification(notification) # type: ignore[arg-type]
|
||||||
self._handle_incoming(notification)
|
self._handle_incoming(notification) # type: ignore[arg-type]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# For other validation errors, log and continue
|
# For other validation errors, log and continue
|
||||||
logger.warning("Failed to validate notification: %s. Message was: %s", e, message.message.root)
|
logger.warning("Failed to validate notification: %s. Message was: %s", e, message.message.root)
|
||||||
|
|||||||
@ -17,6 +17,7 @@ from graphon.model_runtime.model_providers.__base.text_embedding_model import Te
|
|||||||
from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
|
from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
from core.entities import PluginCredentialType
|
||||||
from core.entities.embedding_type import EmbeddingInputType
|
from core.entities.embedding_type import EmbeddingInputType
|
||||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
||||||
from core.entities.provider_entities import ModelLoadBalancingConfiguration
|
from core.entities.provider_entities import ModelLoadBalancingConfiguration
|
||||||
@ -25,7 +26,6 @@ from core.plugin.impl.model_runtime_factory import create_plugin_provider_manage
|
|||||||
from core.provider_manager import ProviderManager
|
from core.provider_manager import ProviderManager
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from models.provider import ProviderType
|
from models.provider import ProviderType
|
||||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
from graphon.entities import WorkflowNodeExecution
|
from graphon.entities import WorkflowNodeExecution
|
||||||
from graphon.enums import WorkflowNodeExecutionStatus
|
from graphon.enums import WorkflowNodeExecutionStatus
|
||||||
@ -56,10 +56,22 @@ def create_links_from_trace_id(trace_id: str | None) -> list[Link]:
|
|||||||
return links
|
return links
|
||||||
|
|
||||||
|
|
||||||
def extract_retrieval_documents(documents: list[Document]) -> list[dict[str, Any]]:
|
class RetrievalDocumentMetadataDict(TypedDict):
|
||||||
documents_data = []
|
dataset_id: Any
|
||||||
|
doc_id: Any
|
||||||
|
document_id: Any
|
||||||
|
|
||||||
|
|
||||||
|
class RetrievalDocumentDict(TypedDict):
|
||||||
|
content: str
|
||||||
|
metadata: RetrievalDocumentMetadataDict
|
||||||
|
score: Any
|
||||||
|
|
||||||
|
|
||||||
|
def extract_retrieval_documents(documents: list[Document]) -> list[RetrievalDocumentDict]:
|
||||||
|
documents_data: list[RetrievalDocumentDict] = []
|
||||||
for document in documents:
|
for document in documents:
|
||||||
document_data = {
|
document_data: RetrievalDocumentDict = {
|
||||||
"content": document.page_content,
|
"content": document.page_content,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"dataset_id": document.metadata.get("dataset_id"),
|
"dataset_id": document.metadata.get("dataset_id"),
|
||||||
@ -83,7 +95,7 @@ def create_common_span_attributes(
|
|||||||
framework: str = DEFAULT_FRAMEWORK_NAME,
|
framework: str = DEFAULT_FRAMEWORK_NAME,
|
||||||
inputs: str = "",
|
inputs: str = "",
|
||||||
outputs: str = "",
|
outputs: str = "",
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, str]:
|
||||||
return {
|
return {
|
||||||
GEN_AI_SESSION_ID: session_id,
|
GEN_AI_SESSION_ID: session_id,
|
||||||
GEN_AI_USER_ID: user_id,
|
GEN_AI_USER_ID: user_id,
|
||||||
|
|||||||
@ -56,8 +56,10 @@ class BaseTraceInstance(ABC):
|
|||||||
if not service_account:
|
if not service_account:
|
||||||
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
||||||
|
|
||||||
current_tenant = (
|
current_tenant = session.scalar(
|
||||||
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
|
select(TenantAccountJoin)
|
||||||
|
.where(TenantAccountJoin.account_id == service_account.id, TenantAccountJoin.current.is_(True))
|
||||||
|
.limit(1)
|
||||||
)
|
)
|
||||||
if not current_tenant:
|
if not current_tenant:
|
||||||
raise ValueError(f"Current tenant not found for account {service_account.id}")
|
raise ValueError(f"Current tenant not found for account {service_account.id}")
|
||||||
|
|||||||
@ -241,8 +241,10 @@ class TencentDataTrace(BaseTraceInstance):
|
|||||||
if not service_account:
|
if not service_account:
|
||||||
raise ValueError(f"Creator account not found for app {app_id}")
|
raise ValueError(f"Creator account not found for app {app_id}")
|
||||||
|
|
||||||
current_tenant = (
|
current_tenant = session.scalar(
|
||||||
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
|
select(TenantAccountJoin)
|
||||||
|
.where(TenantAccountJoin.account_id == service_account.id, TenantAccountJoin.current.is_(True))
|
||||||
|
.limit(1)
|
||||||
)
|
)
|
||||||
if not current_tenant:
|
if not current_tenant:
|
||||||
raise ValueError(f"Current tenant not found for account {service_account.id}")
|
raise ValueError(f"Current tenant not found for account {service_account.id}")
|
||||||
|
|||||||
5
api/core/plugin/entities/__init__.py
Normal file
5
api/core/plugin/entities/__init__.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
from core.plugin.entities.oauth import OAuthSchema
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"OAuthSchema",
|
||||||
|
]
|
||||||
@ -1,5 +1,3 @@
|
|||||||
from collections.abc import Sequence
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from core.entities.provider_entities import ProviderConfig
|
from core.entities.provider_entities import ProviderConfig
|
||||||
@ -10,12 +8,12 @@ class OAuthSchema(BaseModel):
|
|||||||
OAuth schema
|
OAuth schema
|
||||||
"""
|
"""
|
||||||
|
|
||||||
client_schema: Sequence[ProviderConfig] = Field(
|
client_schema: list[ProviderConfig] = Field(
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
description="client schema like client_id, client_secret, etc.",
|
description="client schema like client_id, client_secret, etc.",
|
||||||
)
|
)
|
||||||
|
|
||||||
credentials_schema: Sequence[ProviderConfig] = Field(
|
credentials_schema: list[ProviderConfig] = Field(
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
description="credentials schema like access_token, refresh_token, etc.",
|
description="credentials schema like access_token, refresh_token, etc.",
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,11 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import json
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
from typing import TYPE_CHECKING, Any, cast
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from graphon.model_runtime.entities.model_entities import ModelType
|
from graphon.model_runtime.entities.model_entities import ModelType
|
||||||
from graphon.model_runtime.entities.provider_entities import (
|
from graphon.model_runtime.entities.provider_entities import (
|
||||||
@ -15,6 +14,7 @@ from graphon.model_runtime.entities.provider_entities import (
|
|||||||
ProviderEntity,
|
ProviderEntity,
|
||||||
)
|
)
|
||||||
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||||
|
from pydantic import TypeAdapter
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@ -58,6 +58,8 @@ from services.feature_service import FeatureService
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from graphon.model_runtime.runtime import ModelRuntime
|
from graphon.model_runtime.runtime import ModelRuntime
|
||||||
|
|
||||||
|
_credentials_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any])
|
||||||
|
|
||||||
|
|
||||||
class ProviderManager:
|
class ProviderManager:
|
||||||
"""
|
"""
|
||||||
@ -875,8 +877,8 @@ class ProviderManager:
|
|||||||
return {"openai_api_key": encrypted_config}
|
return {"openai_api_key": encrypted_config}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
credentials = cast(dict, json.loads(encrypted_config))
|
credentials = _credentials_adapter.validate_json(encrypted_config)
|
||||||
except JSONDecodeError:
|
except (ValueError, JSONDecodeError):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
# Decrypt secret variables
|
# Decrypt secret variables
|
||||||
@ -1015,7 +1017,7 @@ class ProviderManager:
|
|||||||
if not cached_provider_credentials:
|
if not cached_provider_credentials:
|
||||||
provider_credentials: dict[str, Any] = {}
|
provider_credentials: dict[str, Any] = {}
|
||||||
if provider_records and provider_records[0].encrypted_config:
|
if provider_records and provider_records[0].encrypted_config:
|
||||||
provider_credentials = json.loads(provider_records[0].encrypted_config)
|
provider_credentials = _credentials_adapter.validate_json(provider_records[0].encrypted_config)
|
||||||
|
|
||||||
# Get provider credential secret variables
|
# Get provider credential secret variables
|
||||||
provider_credential_secret_variables = self._extract_secret_variables(
|
provider_credential_secret_variables = self._extract_secret_variables(
|
||||||
@ -1162,8 +1164,10 @@ class ProviderManager:
|
|||||||
|
|
||||||
if not cached_provider_model_credentials:
|
if not cached_provider_model_credentials:
|
||||||
try:
|
try:
|
||||||
provider_model_credentials = json.loads(load_balancing_model_config.encrypted_config)
|
provider_model_credentials = _credentials_adapter.validate_json(
|
||||||
except JSONDecodeError:
|
load_balancing_model_config.encrypted_config
|
||||||
|
)
|
||||||
|
except (ValueError, JSONDecodeError):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Get decoding rsa key and cipher for decrypting credentials
|
# Get decoding rsa key and cipher for decrypting credentials
|
||||||
@ -1176,7 +1180,7 @@ class ProviderManager:
|
|||||||
if variable in provider_model_credentials:
|
if variable in provider_model_credentials:
|
||||||
try:
|
try:
|
||||||
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||||
provider_model_credentials.get(variable),
|
provider_model_credentials.get(variable) or "",
|
||||||
self.decoding_rsa_key,
|
self.decoding_rsa_key,
|
||||||
self.decoding_cipher_rsa,
|
self.decoding_cipher_rsa,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -15,7 +15,7 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor,
|
|||||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||||
from core.rag.datasource.vdb.vector_factory import Vector
|
from core.rag.datasource.vdb.vector_factory import Vector
|
||||||
from core.rag.embedding.retrieval import AttachmentInfoDict, RetrievalChildChunk, RetrievalSegments
|
from core.rag.embedding.retrieval import AttachmentInfoDict, RetrievalChildChunk, RetrievalSegments
|
||||||
from core.rag.entities.metadata_entities import MetadataCondition
|
from core.rag.entities import MetadataFilteringCondition
|
||||||
from core.rag.index_processor.constant.doc_type import DocType
|
from core.rag.index_processor.constant.doc_type import DocType
|
||||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||||
from core.rag.index_processor.constant.query_type import QueryType
|
from core.rag.index_processor.constant.query_type import QueryType
|
||||||
@ -182,7 +182,9 @@ class RetrievalService:
|
|||||||
if not dataset:
|
if not dataset:
|
||||||
return []
|
return []
|
||||||
metadata_condition = (
|
metadata_condition = (
|
||||||
MetadataCondition.model_validate(metadata_filtering_conditions) if metadata_filtering_conditions else None
|
MetadataFilteringCondition.model_validate(metadata_filtering_conditions)
|
||||||
|
if metadata_filtering_conditions
|
||||||
|
else None
|
||||||
)
|
)
|
||||||
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||||
dataset.tenant_id,
|
dataset.tenant_id,
|
||||||
@ -240,7 +242,7 @@ class RetrievalService:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def _get_dataset(cls, dataset_id: str) -> Dataset | None:
|
def _get_dataset(cls, dataset_id: str) -> Dataset | None:
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
return session.query(Dataset).where(Dataset.id == dataset_id).first()
|
return session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def keyword_search(
|
def keyword_search(
|
||||||
@ -573,15 +575,13 @@ class RetrievalService:
|
|||||||
|
|
||||||
# Batch query summaries for segments retrieved via summary (only enabled summaries)
|
# Batch query summaries for segments retrieved via summary (only enabled summaries)
|
||||||
if summary_segment_ids:
|
if summary_segment_ids:
|
||||||
summaries = (
|
summaries = session.scalars(
|
||||||
session.query(DocumentSegmentSummary)
|
select(DocumentSegmentSummary).where(
|
||||||
.filter(
|
|
||||||
DocumentSegmentSummary.chunk_id.in_(list(summary_segment_ids)),
|
DocumentSegmentSummary.chunk_id.in_(list(summary_segment_ids)),
|
||||||
DocumentSegmentSummary.status == "completed",
|
DocumentSegmentSummary.status == "completed",
|
||||||
DocumentSegmentSummary.enabled == True, # Only retrieve enabled summaries
|
DocumentSegmentSummary.enabled.is_(True), # Only retrieve enabled summaries
|
||||||
)
|
)
|
||||||
.all()
|
).all()
|
||||||
)
|
|
||||||
for summary in summaries:
|
for summary in summaries:
|
||||||
if summary.summary_content:
|
if summary.summary_content:
|
||||||
segment_summary_map[summary.chunk_id] = summary.summary_content
|
segment_summary_map[summary.chunk_id] = summary.summary_content
|
||||||
@ -851,12 +851,12 @@ class RetrievalService:
|
|||||||
def get_segment_attachment_info(
|
def get_segment_attachment_info(
|
||||||
cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session
|
cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session
|
||||||
) -> SegmentAttachmentResult | None:
|
) -> SegmentAttachmentResult | None:
|
||||||
upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first()
|
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == attachment_id).limit(1))
|
||||||
if upload_file:
|
if upload_file:
|
||||||
attachment_binding = (
|
attachment_binding = session.scalar(
|
||||||
session.query(SegmentAttachmentBinding)
|
select(SegmentAttachmentBinding)
|
||||||
.where(SegmentAttachmentBinding.attachment_id == upload_file.id)
|
.where(SegmentAttachmentBinding.attachment_id == upload_file.id)
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
if attachment_binding:
|
if attachment_binding:
|
||||||
attachment_info: AttachmentInfoDict = {
|
attachment_info: AttachmentInfoDict = {
|
||||||
@ -875,14 +875,12 @@ class RetrievalService:
|
|||||||
cls, attachment_ids: list[str], session: Session
|
cls, attachment_ids: list[str], session: Session
|
||||||
) -> list[SegmentAttachmentInfoResult]:
|
) -> list[SegmentAttachmentInfoResult]:
|
||||||
attachment_infos: list[SegmentAttachmentInfoResult] = []
|
attachment_infos: list[SegmentAttachmentInfoResult] = []
|
||||||
upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
|
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(attachment_ids))).all()
|
||||||
if upload_files:
|
if upload_files:
|
||||||
upload_file_ids = [upload_file.id for upload_file in upload_files]
|
upload_file_ids = [upload_file.id for upload_file in upload_files]
|
||||||
attachment_bindings = (
|
attachment_bindings = session.scalars(
|
||||||
session.query(SegmentAttachmentBinding)
|
select(SegmentAttachmentBinding).where(SegmentAttachmentBinding.attachment_id.in_(upload_file_ids))
|
||||||
.where(SegmentAttachmentBinding.attachment_id.in_(upload_file_ids))
|
).all()
|
||||||
.all()
|
|
||||||
)
|
|
||||||
attachment_binding_map = {binding.attachment_id: binding for binding in attachment_bindings}
|
attachment_binding_map = {binding.attachment_id: binding for binding in attachment_bindings}
|
||||||
|
|
||||||
if attachment_bindings:
|
if attachment_bindings:
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Any
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
|
|
||||||
@ -13,6 +13,13 @@ from core.rag.models.document import Document
|
|||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
|
|
||||||
|
|
||||||
|
class AnalyticdbClientParamsDict(TypedDict):
|
||||||
|
access_key_id: str
|
||||||
|
access_key_secret: str
|
||||||
|
region_id: str
|
||||||
|
read_timeout: int
|
||||||
|
|
||||||
|
|
||||||
class AnalyticdbVectorOpenAPIConfig(BaseModel):
|
class AnalyticdbVectorOpenAPIConfig(BaseModel):
|
||||||
access_key_id: str
|
access_key_id: str
|
||||||
access_key_secret: str
|
access_key_secret: str
|
||||||
@ -44,13 +51,14 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel):
|
|||||||
raise ValueError("config ANALYTICDB_NAMESPACE_PASSWORD is required")
|
raise ValueError("config ANALYTICDB_NAMESPACE_PASSWORD is required")
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def to_analyticdb_client_params(self):
|
def to_analyticdb_client_params(self) -> AnalyticdbClientParamsDict:
|
||||||
return {
|
result: AnalyticdbClientParamsDict = {
|
||||||
"access_key_id": self.access_key_id,
|
"access_key_id": self.access_key_id,
|
||||||
"access_key_secret": self.access_key_secret,
|
"access_key_secret": self.access_key_secret,
|
||||||
"region_id": self.region_id,
|
"region_id": self.region_id,
|
||||||
"read_timeout": self.read_timeout,
|
"read_timeout": self.read_timeout,
|
||||||
}
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class AnalyticdbVectorOpenAPI:
|
class AnalyticdbVectorOpenAPI:
|
||||||
|
|||||||
@ -30,7 +30,7 @@ from pymochow.model.table import AnnSearch, BM25SearchRequest, HNSWSearchParams,
|
|||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.rag.datasource.vdb.field import Field as VDBField
|
from core.rag.datasource.vdb.field import Field as VDBField
|
||||||
from core.rag.datasource.vdb.field import parse_metadata_json
|
from core.rag.datasource.vdb.field import parse_metadata_json
|
||||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
|
||||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||||
from core.rag.datasource.vdb.vector_type import VectorType
|
from core.rag.datasource.vdb.vector_type import VectorType
|
||||||
from core.rag.embedding.embedding_base import Embeddings
|
from core.rag.embedding.embedding_base import Embeddings
|
||||||
@ -85,8 +85,12 @@ class BaiduVector(BaseVector):
|
|||||||
def get_type(self) -> str:
|
def get_type(self) -> str:
|
||||||
return VectorType.BAIDU
|
return VectorType.BAIDU
|
||||||
|
|
||||||
def to_index_struct(self):
|
def to_index_struct(self) -> VectorIndexStructDict:
|
||||||
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
|
result: VectorIndexStructDict = {
|
||||||
|
"type": self.get_type(),
|
||||||
|
"vector_store": {"class_prefix": self._collection_name},
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||||
self._create_table(len(embeddings[0]))
|
self._create_table(len(embeddings[0]))
|
||||||
|
|||||||
@ -1,12 +1,12 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Any
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
import chromadb
|
import chromadb
|
||||||
from chromadb import QueryResult, Settings
|
from chromadb import QueryResult, Settings
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
|
||||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||||
from core.rag.datasource.vdb.vector_type import VectorType
|
from core.rag.datasource.vdb.vector_type import VectorType
|
||||||
from core.rag.embedding.embedding_base import Embeddings
|
from core.rag.embedding.embedding_base import Embeddings
|
||||||
@ -15,6 +15,15 @@ from extensions.ext_redis import redis_client
|
|||||||
from models.dataset import Dataset
|
from models.dataset import Dataset
|
||||||
|
|
||||||
|
|
||||||
|
class ChromaParamsDict(TypedDict):
|
||||||
|
host: str
|
||||||
|
port: int
|
||||||
|
ssl: bool
|
||||||
|
tenant: str
|
||||||
|
database: str
|
||||||
|
settings: Settings
|
||||||
|
|
||||||
|
|
||||||
class ChromaConfig(BaseModel):
|
class ChromaConfig(BaseModel):
|
||||||
host: str
|
host: str
|
||||||
port: int
|
port: int
|
||||||
@ -23,14 +32,13 @@ class ChromaConfig(BaseModel):
|
|||||||
auth_provider: str | None = None
|
auth_provider: str | None = None
|
||||||
auth_credentials: str | None = None
|
auth_credentials: str | None = None
|
||||||
|
|
||||||
def to_chroma_params(self):
|
def to_chroma_params(self) -> ChromaParamsDict:
|
||||||
settings = Settings(
|
settings = Settings(
|
||||||
# auth
|
# auth
|
||||||
chroma_client_auth_provider=self.auth_provider,
|
chroma_client_auth_provider=self.auth_provider,
|
||||||
chroma_client_auth_credentials=self.auth_credentials,
|
chroma_client_auth_credentials=self.auth_credentials,
|
||||||
)
|
)
|
||||||
|
result: ChromaParamsDict = {
|
||||||
return {
|
|
||||||
"host": self.host,
|
"host": self.host,
|
||||||
"port": self.port,
|
"port": self.port,
|
||||||
"ssl": False,
|
"ssl": False,
|
||||||
@ -38,6 +46,7 @@ class ChromaConfig(BaseModel):
|
|||||||
"database": self.database,
|
"database": self.database,
|
||||||
"settings": settings,
|
"settings": settings,
|
||||||
}
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class ChromaVector(BaseVector):
|
class ChromaVector(BaseVector):
|
||||||
@ -145,7 +154,10 @@ class ChromaVectorFactory(AbstractVectorFactory):
|
|||||||
else:
|
else:
|
||||||
dataset_id = dataset.id
|
dataset_id = dataset.id
|
||||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||||
index_struct_dict = {"type": VectorType.CHROMA, "vector_store": {"class_prefix": collection_name}}
|
index_struct_dict: VectorIndexStructDict = {
|
||||||
|
"type": VectorType.CHROMA,
|
||||||
|
"vector_store": {"class_prefix": collection_name},
|
||||||
|
}
|
||||||
dataset.index_struct = json.dumps(index_struct_dict)
|
dataset.index_struct = json.dumps(index_struct_dict)
|
||||||
|
|
||||||
return ChromaVector(
|
return ChromaVector(
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
@ -20,6 +20,15 @@ from models.dataset import Dataset
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MilvusParamsDict(TypedDict):
|
||||||
|
uri: str
|
||||||
|
token: str | None
|
||||||
|
user: str | None
|
||||||
|
password: str | None
|
||||||
|
db_name: str
|
||||||
|
analyzer_params: str | None
|
||||||
|
|
||||||
|
|
||||||
class MilvusConfig(BaseModel):
|
class MilvusConfig(BaseModel):
|
||||||
"""
|
"""
|
||||||
Configuration class for Milvus connection.
|
Configuration class for Milvus connection.
|
||||||
@ -50,11 +59,11 @@ class MilvusConfig(BaseModel):
|
|||||||
raise ValueError("config MILVUS_PASSWORD is required")
|
raise ValueError("config MILVUS_PASSWORD is required")
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def to_milvus_params(self):
|
def to_milvus_params(self) -> MilvusParamsDict:
|
||||||
"""
|
"""
|
||||||
Convert the configuration to a dictionary of Milvus connection parameters.
|
Convert the configuration to a dictionary of Milvus connection parameters.
|
||||||
"""
|
"""
|
||||||
return {
|
result: MilvusParamsDict = {
|
||||||
"uri": self.uri,
|
"uri": self.uri,
|
||||||
"token": self.token,
|
"token": self.token,
|
||||||
"user": self.user,
|
"user": self.user,
|
||||||
@ -62,6 +71,7 @@ class MilvusConfig(BaseModel):
|
|||||||
"db_name": self.database,
|
"db_name": self.database,
|
||||||
"analyzer_params": self.analyzer_params,
|
"analyzer_params": self.analyzer_params,
|
||||||
}
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class MilvusVector(BaseVector):
|
class MilvusVector(BaseVector):
|
||||||
@ -352,6 +362,7 @@ class MilvusVector(BaseVector):
|
|||||||
|
|
||||||
# Create Index params for the collection
|
# Create Index params for the collection
|
||||||
index_params_obj = IndexParams()
|
index_params_obj = IndexParams()
|
||||||
|
assert index_params is not None
|
||||||
index_params_obj.add_index(field_name=Field.VECTOR, **index_params)
|
index_params_obj.add_index(field_name=Field.VECTOR, **index_params)
|
||||||
|
|
||||||
# Create Sparse Vector Index for the collection
|
# Create Sparse Vector Index for the collection
|
||||||
|
|||||||
@ -22,7 +22,7 @@ from sqlalchemy import select
|
|||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.rag.datasource.vdb.field import Field
|
from core.rag.datasource.vdb.field import Field
|
||||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
|
||||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||||
from core.rag.datasource.vdb.vector_type import VectorType
|
from core.rag.datasource.vdb.vector_type import VectorType
|
||||||
from core.rag.embedding.embedding_base import Embeddings
|
from core.rag.embedding.embedding_base import Embeddings
|
||||||
@ -94,8 +94,12 @@ class QdrantVector(BaseVector):
|
|||||||
def get_type(self) -> str:
|
def get_type(self) -> str:
|
||||||
return VectorType.QDRANT
|
return VectorType.QDRANT
|
||||||
|
|
||||||
def to_index_struct(self):
|
def to_index_struct(self) -> VectorIndexStructDict:
|
||||||
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
|
result: VectorIndexStructDict = {
|
||||||
|
"type": self.get_type(),
|
||||||
|
"vector_store": {"class_prefix": self._collection_name},
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||||
if texts:
|
if texts:
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from typing import Any
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from tcvdb_text.encoder import BM25Encoder # type: ignore
|
from tcvdb_text.encoder import BM25Encoder # type: ignore
|
||||||
@ -12,7 +12,7 @@ from tcvectordb.model.document import AnnSearch, Filter, KeywordSearch, Weighted
|
|||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.rag.datasource.vdb.field import parse_metadata_json
|
from core.rag.datasource.vdb.field import parse_metadata_json
|
||||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
|
||||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||||
from core.rag.datasource.vdb.vector_type import VectorType
|
from core.rag.datasource.vdb.vector_type import VectorType
|
||||||
from core.rag.embedding.embedding_base import Embeddings
|
from core.rag.embedding.embedding_base import Embeddings
|
||||||
@ -23,6 +23,13 @@ from models.dataset import Dataset
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TencentParamsDict(TypedDict):
|
||||||
|
url: str
|
||||||
|
username: str | None
|
||||||
|
key: str | None
|
||||||
|
timeout: float
|
||||||
|
|
||||||
|
|
||||||
class TencentConfig(BaseModel):
|
class TencentConfig(BaseModel):
|
||||||
url: str
|
url: str
|
||||||
api_key: str | None = None
|
api_key: str | None = None
|
||||||
@ -36,8 +43,14 @@ class TencentConfig(BaseModel):
|
|||||||
max_upsert_batch_size: int = 128
|
max_upsert_batch_size: int = 128
|
||||||
enable_hybrid_search: bool = False # Flag to enable hybrid search
|
enable_hybrid_search: bool = False # Flag to enable hybrid search
|
||||||
|
|
||||||
def to_tencent_params(self):
|
def to_tencent_params(self) -> TencentParamsDict:
|
||||||
return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout}
|
result: TencentParamsDict = {
|
||||||
|
"url": self.url,
|
||||||
|
"username": self.username,
|
||||||
|
"key": self.api_key,
|
||||||
|
"timeout": self.timeout,
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
bm25 = BM25Encoder.default("zh")
|
bm25 = BM25Encoder.default("zh")
|
||||||
@ -83,8 +96,12 @@ class TencentVector(BaseVector):
|
|||||||
def get_type(self) -> str:
|
def get_type(self) -> str:
|
||||||
return VectorType.TENCENT
|
return VectorType.TENCENT
|
||||||
|
|
||||||
def to_index_struct(self):
|
def to_index_struct(self) -> VectorIndexStructDict:
|
||||||
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
|
result: VectorIndexStructDict = {
|
||||||
|
"type": self.get_type(),
|
||||||
|
"vector_store": {"class_prefix": self._collection_name},
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
def _has_collection(self) -> bool:
|
def _has_collection(self) -> bool:
|
||||||
return bool(
|
return bool(
|
||||||
|
|||||||
@ -25,7 +25,7 @@ from sqlalchemy import select
|
|||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.rag.datasource.vdb.field import Field
|
from core.rag.datasource.vdb.field import Field
|
||||||
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
|
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
|
||||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
|
||||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||||
from core.rag.datasource.vdb.vector_type import VectorType
|
from core.rag.datasource.vdb.vector_type import VectorType
|
||||||
from core.rag.embedding.embedding_base import Embeddings
|
from core.rag.embedding.embedding_base import Embeddings
|
||||||
@ -91,8 +91,12 @@ class TidbOnQdrantVector(BaseVector):
|
|||||||
def get_type(self) -> str:
|
def get_type(self) -> str:
|
||||||
return VectorType.TIDB_ON_QDRANT
|
return VectorType.TIDB_ON_QDRANT
|
||||||
|
|
||||||
def to_index_struct(self):
|
def to_index_struct(self) -> VectorIndexStructDict:
|
||||||
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
|
result: VectorIndexStructDict = {
|
||||||
|
"type": self.get_type(),
|
||||||
|
"vector_store": {"class_prefix": self._collection_name},
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||||
if texts:
|
if texts:
|
||||||
|
|||||||
@ -1,11 +1,20 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import Document
|
||||||
|
|
||||||
|
|
||||||
|
class VectorStoreDict(TypedDict):
|
||||||
|
class_prefix: str
|
||||||
|
|
||||||
|
|
||||||
|
class VectorIndexStructDict(TypedDict):
|
||||||
|
type: str
|
||||||
|
vector_store: VectorStoreDict
|
||||||
|
|
||||||
|
|
||||||
class BaseVector(ABC):
|
class BaseVector(ABC):
|
||||||
def __init__(self, collection_name: str):
|
def __init__(self, collection_name: str):
|
||||||
self._collection_name = collection_name
|
self._collection_name = collection_name
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from sqlalchemy import select
|
|||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.model_manager import ModelManager
|
from core.model_manager import ModelManager
|
||||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
|
||||||
from core.rag.datasource.vdb.vector_type import VectorType
|
from core.rag.datasource.vdb.vector_type import VectorType
|
||||||
from core.rag.embedding.cached_embedding import CacheEmbedding
|
from core.rag.embedding.cached_embedding import CacheEmbedding
|
||||||
from core.rag.embedding.embedding_base import Embeddings
|
from core.rag.embedding.embedding_base import Embeddings
|
||||||
@ -30,8 +30,11 @@ class AbstractVectorFactory(ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def gen_index_struct_dict(vector_type: VectorType, collection_name: str):
|
def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> VectorIndexStructDict:
|
||||||
index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}}
|
index_struct_dict: VectorIndexStructDict = {
|
||||||
|
"type": vector_type,
|
||||||
|
"vector_store": {"class_prefix": collection_name},
|
||||||
|
}
|
||||||
return index_struct_dict
|
return index_struct_dict
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -24,7 +24,7 @@ from weaviate.exceptions import UnexpectedStatusCodeError
|
|||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.rag.datasource.vdb.field import Field
|
from core.rag.datasource.vdb.field import Field
|
||||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
|
||||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||||
from core.rag.datasource.vdb.vector_type import VectorType
|
from core.rag.datasource.vdb.vector_type import VectorType
|
||||||
from core.rag.embedding.embedding_base import Embeddings
|
from core.rag.embedding.embedding_base import Embeddings
|
||||||
@ -184,9 +184,13 @@ class WeaviateVector(BaseVector):
|
|||||||
dataset_id = dataset.id
|
dataset_id = dataset.id
|
||||||
return Dataset.gen_collection_name_by_id(dataset_id)
|
return Dataset.gen_collection_name_by_id(dataset_id)
|
||||||
|
|
||||||
def to_index_struct(self) -> dict:
|
def to_index_struct(self) -> VectorIndexStructDict:
|
||||||
"""Returns the index structure dictionary for persistence."""
|
"""Returns the index structure dictionary for persistence."""
|
||||||
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
|
result: VectorIndexStructDict = {
|
||||||
|
"type": self.get_type(),
|
||||||
|
"vector_store": {"class_prefix": self._collection_name},
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||||
"""
|
"""
|
||||||
|
|||||||
28
api/core/rag/entities/__init__.py
Normal file
28
api/core/rag/entities/__init__.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||||
|
from core.rag.entities.context_entities import DocumentContext
|
||||||
|
from core.rag.entities.event import DatasourceCompletedEvent, DatasourceErrorEvent, DatasourceProcessingEvent
|
||||||
|
from core.rag.entities.index_entities import EconomySetting, EmbeddingSetting, IndexMethod
|
||||||
|
from core.rag.entities.metadata_entities import Condition, MetadataFilteringCondition, SupportedComparisonOperator
|
||||||
|
from core.rag.entities.processing_entities import ParentMode, PreProcessingRule, Rule, Segmentation
|
||||||
|
from core.rag.entities.retrieval_settings import KeywordSetting, VectorSetting, WeightedScoreConfig
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Condition",
|
||||||
|
"DatasourceCompletedEvent",
|
||||||
|
"DatasourceErrorEvent",
|
||||||
|
"DatasourceProcessingEvent",
|
||||||
|
"DocumentContext",
|
||||||
|
"EconomySetting",
|
||||||
|
"EmbeddingSetting",
|
||||||
|
"IndexMethod",
|
||||||
|
"KeywordSetting",
|
||||||
|
"MetadataFilteringCondition",
|
||||||
|
"ParentMode",
|
||||||
|
"PreProcessingRule",
|
||||||
|
"RetrievalSourceMetadata",
|
||||||
|
"Rule",
|
||||||
|
"Segmentation",
|
||||||
|
"SupportedComparisonOperator",
|
||||||
|
"VectorSetting",
|
||||||
|
"WeightedScoreConfig",
|
||||||
|
]
|
||||||
30
api/core/rag/entities/index_entities.py
Normal file
30
api/core/rag/entities/index_entities.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingSetting(BaseModel):
|
||||||
|
"""
|
||||||
|
Embedding Setting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
embedding_provider_name: str
|
||||||
|
embedding_model_name: str
|
||||||
|
|
||||||
|
|
||||||
|
class EconomySetting(BaseModel):
|
||||||
|
"""
|
||||||
|
Economy Setting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
keyword_number: int
|
||||||
|
|
||||||
|
|
||||||
|
class IndexMethod(BaseModel):
|
||||||
|
"""
|
||||||
|
Knowledge Index Setting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
indexing_technique: Literal["high_quality", "economy"]
|
||||||
|
embedding_setting: EmbeddingSetting
|
||||||
|
economy_setting: EconomySetting
|
||||||
@ -38,9 +38,9 @@ class Condition(BaseModel):
|
|||||||
value: str | Sequence[str] | None | int | float = None
|
value: str | Sequence[str] | None | int | float = None
|
||||||
|
|
||||||
|
|
||||||
class MetadataCondition(BaseModel):
|
class MetadataFilteringCondition(BaseModel):
|
||||||
"""
|
"""
|
||||||
Metadata Condition.
|
Metadata Filtering Condition.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logical_operator: Literal["and", "or"] | None = "and"
|
logical_operator: Literal["and", "or"] | None = "and"
|
||||||
|
|||||||
27
api/core/rag/entities/processing_entities.py
Normal file
27
api/core/rag/entities/processing_entities.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
from enum import StrEnum
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class ParentMode(StrEnum):
|
||||||
|
FULL_DOC = "full-doc"
|
||||||
|
PARAGRAPH = "paragraph"
|
||||||
|
|
||||||
|
|
||||||
|
class PreProcessingRule(BaseModel):
|
||||||
|
id: str
|
||||||
|
enabled: bool
|
||||||
|
|
||||||
|
|
||||||
|
class Segmentation(BaseModel):
|
||||||
|
separator: str = "\n"
|
||||||
|
max_tokens: int
|
||||||
|
chunk_overlap: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class Rule(BaseModel):
|
||||||
|
pre_processing_rules: list[PreProcessingRule] | None = None
|
||||||
|
segmentation: Segmentation | None = None
|
||||||
|
parent_mode: Literal["full-doc", "paragraph"] | None = None
|
||||||
|
subchunk_segmentation: Segmentation | None = None
|
||||||
28
api/core/rag/entities/retrieval_settings.py
Normal file
28
api/core/rag/entities/retrieval_settings.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class VectorSetting(BaseModel):
|
||||||
|
"""
|
||||||
|
Vector Setting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
vector_weight: float
|
||||||
|
embedding_provider_name: str
|
||||||
|
embedding_model_name: str
|
||||||
|
|
||||||
|
|
||||||
|
class KeywordSetting(BaseModel):
|
||||||
|
"""
|
||||||
|
Keyword Setting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
keyword_weight: float
|
||||||
|
|
||||||
|
|
||||||
|
class WeightedScoreConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
Weighted score Config.
|
||||||
|
"""
|
||||||
|
|
||||||
|
vector_setting: VectorSetting
|
||||||
|
keyword_setting: KeywordSetting
|
||||||
@ -12,7 +12,7 @@ from core.db.session_factory import session_factory
|
|||||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||||
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
|
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
|
||||||
from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
|
from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
|
||||||
from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview
|
from core.workflow.nodes.knowledge_index.protocols import IndexingResultDict, Preview, PreviewItem, QaPreview
|
||||||
from models.dataset import Dataset, Document, DocumentSegment
|
from models.dataset import Dataset, Document, DocumentSegment
|
||||||
|
|
||||||
from .index_processor_factory import IndexProcessorFactory
|
from .index_processor_factory import IndexProcessorFactory
|
||||||
@ -61,7 +61,7 @@ class IndexProcessor:
|
|||||||
chunks: Mapping[str, Any],
|
chunks: Mapping[str, Any],
|
||||||
batch: Any,
|
batch: Any,
|
||||||
summary_index_setting: SummaryIndexSettingDict | None = None,
|
summary_index_setting: SummaryIndexSettingDict | None = None,
|
||||||
):
|
) -> IndexingResultDict:
|
||||||
with session_factory.create_session() as session:
|
with session_factory.create_session() as session:
|
||||||
document = session.query(Document).filter_by(id=document_id).first()
|
document = session.query(Document).filter_by(id=document_id).first()
|
||||||
if not document:
|
if not document:
|
||||||
@ -129,7 +129,7 @@ class IndexProcessor:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
result: IndexingResultDict = {
|
||||||
"dataset_id": dataset_id,
|
"dataset_id": dataset_id,
|
||||||
"dataset_name": dataset_name_value,
|
"dataset_name": dataset_name_value,
|
||||||
"batch": batch,
|
"batch": batch,
|
||||||
@ -138,6 +138,7 @@ class IndexProcessor:
|
|||||||
"created_at": created_at_value.timestamp(),
|
"created_at": created_at_value.timestamp(),
|
||||||
"display_status": "completed",
|
"display_status": "completed",
|
||||||
}
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
def get_preview_output(
|
def get_preview_output(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -32,6 +32,7 @@ from core.rag.datasource.keyword.keyword_factory import Keyword
|
|||||||
from core.rag.datasource.retrieval_service import RetrievalService
|
from core.rag.datasource.retrieval_service import RetrievalService
|
||||||
from core.rag.datasource.vdb.vector_factory import Vector
|
from core.rag.datasource.vdb.vector_factory import Vector
|
||||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||||
|
from core.rag.entities import Rule
|
||||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||||
from core.rag.index_processor.constant.doc_type import DocType
|
from core.rag.index_processor.constant.doc_type import DocType
|
||||||
@ -49,7 +50,6 @@ from models.account import Account
|
|||||||
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding
|
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding
|
||||||
from models.dataset import Document as DatasetDocument
|
from models.dataset import Document as DatasetDocument
|
||||||
from services.account_service import AccountService
|
from services.account_service import AccountService
|
||||||
from services.entities.knowledge_entities.knowledge_entities import Rule
|
|
||||||
from services.summary_index_service import SummaryIndexService
|
from services.summary_index_service import SummaryIndexService
|
||||||
|
|
||||||
_file_access_controller = DatabaseFileAccessController()
|
_file_access_controller = DatabaseFileAccessController()
|
||||||
|
|||||||
@ -17,6 +17,7 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict
|
|||||||
from core.rag.datasource.retrieval_service import RetrievalService
|
from core.rag.datasource.retrieval_service import RetrievalService
|
||||||
from core.rag.datasource.vdb.vector_factory import Vector
|
from core.rag.datasource.vdb.vector_factory import Vector
|
||||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||||
|
from core.rag.entities import ParentMode, Rule
|
||||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||||
from core.rag.index_processor.constant.doc_type import DocType
|
from core.rag.index_processor.constant.doc_type import DocType
|
||||||
@ -30,7 +31,6 @@ from models import Account
|
|||||||
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
|
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
|
||||||
from models.dataset import Document as DatasetDocument
|
from models.dataset import Document as DatasetDocument
|
||||||
from services.account_service import AccountService
|
from services.account_service import AccountService
|
||||||
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
|
|
||||||
from services.summary_index_service import SummaryIndexService
|
from services.summary_index_service import SummaryIndexService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@ -19,6 +19,7 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict
|
|||||||
from core.rag.datasource.retrieval_service import RetrievalService
|
from core.rag.datasource.retrieval_service import RetrievalService
|
||||||
from core.rag.datasource.vdb.vector_factory import Vector
|
from core.rag.datasource.vdb.vector_factory import Vector
|
||||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||||
|
from core.rag.entities import Rule
|
||||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||||
@ -30,7 +31,6 @@ from libs import helper
|
|||||||
from models.account import Account
|
from models.account import Account
|
||||||
from models.dataset import Dataset, DocumentSegment
|
from models.dataset import Dataset, DocumentSegment
|
||||||
from models.dataset import Document as DatasetDocument
|
from models.dataset import Document as DatasetDocument
|
||||||
from services.entities.knowledge_entities.knowledge_entities import Rule
|
|
||||||
from services.summary_index_service import SummaryIndexService
|
from services.summary_index_service import SummaryIndexService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@ -1,16 +1,6 @@
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from core.rag.entities import KeywordSetting, VectorSetting
|
||||||
class VectorSetting(BaseModel):
|
|
||||||
vector_weight: float
|
|
||||||
|
|
||||||
embedding_provider_name: str
|
|
||||||
|
|
||||||
embedding_model_name: str
|
|
||||||
|
|
||||||
|
|
||||||
class KeywordSetting(BaseModel):
|
|
||||||
keyword_weight: float
|
|
||||||
|
|
||||||
|
|
||||||
class Weights(BaseModel):
|
class Weights(BaseModel):
|
||||||
|
|||||||
@ -39,9 +39,7 @@ from core.prompt.simple_prompt_transform import ModelMode
|
|||||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
|
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
|
||||||
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
||||||
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
|
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
|
||||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
from core.rag.entities import Condition, DocumentContext, RetrievalSourceMetadata
|
||||||
from core.rag.entities.context_entities import DocumentContext
|
|
||||||
from core.rag.entities.metadata_entities import Condition, MetadataCondition
|
|
||||||
from core.rag.index_processor.constant.doc_type import DocType
|
from core.rag.index_processor.constant.doc_type import DocType
|
||||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||||
from core.rag.index_processor.constant.query_type import QueryType
|
from core.rag.index_processor.constant.query_type import QueryType
|
||||||
@ -604,7 +602,7 @@ class DatasetRetrieval:
|
|||||||
planning_strategy: PlanningStrategy,
|
planning_strategy: PlanningStrategy,
|
||||||
message_id: str | None = None,
|
message_id: str | None = None,
|
||||||
metadata_filter_document_ids: dict[str, list[str]] | None = None,
|
metadata_filter_document_ids: dict[str, list[str]] | None = None,
|
||||||
metadata_condition: MetadataCondition | None = None,
|
metadata_condition: MetadataFilteringCondition | None = None,
|
||||||
):
|
):
|
||||||
tools = []
|
tools = []
|
||||||
for dataset in available_datasets:
|
for dataset in available_datasets:
|
||||||
@ -743,7 +741,7 @@ class DatasetRetrieval:
|
|||||||
reranking_enable: bool = True,
|
reranking_enable: bool = True,
|
||||||
message_id: str | None = None,
|
message_id: str | None = None,
|
||||||
metadata_filter_document_ids: dict[str, list[str]] | None = None,
|
metadata_filter_document_ids: dict[str, list[str]] | None = None,
|
||||||
metadata_condition: MetadataCondition | None = None,
|
metadata_condition: MetadataFilteringCondition | None = None,
|
||||||
attachment_ids: list[str] | None = None,
|
attachment_ids: list[str] | None = None,
|
||||||
):
|
):
|
||||||
if not available_datasets:
|
if not available_datasets:
|
||||||
@ -1063,7 +1061,7 @@ class DatasetRetrieval:
|
|||||||
top_k: int,
|
top_k: int,
|
||||||
all_documents: list[Document],
|
all_documents: list[Document],
|
||||||
document_ids_filter: list[str] | None = None,
|
document_ids_filter: list[str] | None = None,
|
||||||
metadata_condition: MetadataCondition | None = None,
|
metadata_condition: MetadataFilteringCondition | None = None,
|
||||||
attachment_ids: list[str] | None = None,
|
attachment_ids: list[str] | None = None,
|
||||||
):
|
):
|
||||||
with flask_app.app_context():
|
with flask_app.app_context():
|
||||||
@ -1339,7 +1337,7 @@ class DatasetRetrieval:
|
|||||||
metadata_model_config: ModelConfig,
|
metadata_model_config: ModelConfig,
|
||||||
metadata_filtering_conditions: MetadataFilteringCondition | None,
|
metadata_filtering_conditions: MetadataFilteringCondition | None,
|
||||||
inputs: dict,
|
inputs: dict,
|
||||||
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]:
|
) -> tuple[dict[str, list[str]] | None, MetadataFilteringCondition | None]:
|
||||||
document_query = select(DatasetDocument).where(
|
document_query = select(DatasetDocument).where(
|
||||||
DatasetDocument.dataset_id.in_(dataset_ids),
|
DatasetDocument.dataset_id.in_(dataset_ids),
|
||||||
DatasetDocument.indexing_status == "completed",
|
DatasetDocument.indexing_status == "completed",
|
||||||
@ -1371,7 +1369,7 @@ class DatasetRetrieval:
|
|||||||
value=filter.get("value"),
|
value=filter.get("value"),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
metadata_condition = MetadataCondition(
|
metadata_condition = MetadataFilteringCondition(
|
||||||
logical_operator=metadata_filtering_conditions.logical_operator
|
logical_operator=metadata_filtering_conditions.logical_operator
|
||||||
if metadata_filtering_conditions
|
if metadata_filtering_conditions
|
||||||
else "or", # type: ignore
|
else "or", # type: ignore
|
||||||
@ -1400,7 +1398,7 @@ class DatasetRetrieval:
|
|||||||
expected_value,
|
expected_value,
|
||||||
filters,
|
filters,
|
||||||
)
|
)
|
||||||
metadata_condition = MetadataCondition(
|
metadata_condition = MetadataFilteringCondition(
|
||||||
logical_operator=metadata_filtering_conditions.logical_operator,
|
logical_operator=metadata_filtering_conditions.logical_operator,
|
||||||
conditions=conditions,
|
conditions=conditions,
|
||||||
)
|
)
|
||||||
@ -1723,7 +1721,7 @@ class DatasetRetrieval:
|
|||||||
self,
|
self,
|
||||||
flask_app: Flask,
|
flask_app: Flask,
|
||||||
available_datasets: list[Dataset],
|
available_datasets: list[Dataset],
|
||||||
metadata_condition: MetadataCondition | None,
|
metadata_condition: MetadataFilteringCondition | None,
|
||||||
metadata_filter_document_ids: dict[str, list[str]] | None,
|
metadata_filter_document_ids: dict[str, list[str]] | None,
|
||||||
all_documents: list[Document],
|
all_documents: list[Document],
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user