Merge branch 'feat/create-app' into deploy/dev

CodingOnStar 2026-04-13 14:39:50 +08:00
commit 35bbf702ed
545 changed files with 33700 additions and 10910 deletions

.github/dependabot.yml vendored
View File

@ -1,106 +1,6 @@
version: 2
updates:
- package-ecosystem: "pip"
directory: "/api"
open-pull-requests-limit: 10
schedule:
interval: "weekly"
groups:
flask:
patterns:
- "flask"
- "flask-*"
- "werkzeug"
- "gunicorn"
google:
patterns:
- "google-*"
- "googleapis-*"
opentelemetry:
patterns:
- "opentelemetry-*"
pydantic:
patterns:
- "pydantic"
- "pydantic-*"
llm:
patterns:
- "langfuse"
- "langsmith"
- "litellm"
- "mlflow*"
- "opik"
- "weave*"
- "arize*"
- "tiktoken"
- "transformers"
database:
patterns:
- "sqlalchemy"
- "psycopg2*"
- "psycogreen"
- "redis*"
- "alembic*"
storage:
patterns:
- "boto3*"
- "botocore*"
- "azure-*"
- "bce-*"
- "cos-python-*"
- "esdk-obs-*"
- "google-cloud-storage"
- "opendal"
- "oss2"
- "supabase*"
- "tos*"
vdb:
patterns:
- "alibabacloud*"
- "chromadb"
- "clickhouse-*"
- "clickzetta-*"
- "couchbase"
- "elasticsearch"
- "opensearch-py"
- "oracledb"
- "pgvect*"
- "pymilvus"
- "pymochow"
- "pyobvector"
- "qdrant-client"
- "intersystems-*"
- "tablestore"
- "tcvectordb"
- "tidb-vector"
- "upstash-*"
- "volcengine-*"
- "weaviate-*"
- "xinference-*"
- "mo-vector"
- "mysql-connector-*"
dev:
patterns:
- "coverage"
- "dotenv-linter"
- "faker"
- "lxml-stubs"
- "basedpyright"
- "ruff"
- "pytest*"
- "types-*"
- "boto3-stubs"
- "hypothesis"
- "pandas-stubs"
- "scipy-stubs"
- "import-linter"
- "celery-types"
- "mypy*"
- "pyrefly"
python-packages:
patterns:
- "*"
- package-ecosystem: "uv"
directory: "/api"
open-pull-requests-limit: 10

View File

@ -7,6 +7,7 @@
## Summary
<!-- Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change. -->
<!-- If this PR was created by an automated agent, add `From <Tool Name>` as the final line of the description. Example: `From Codex`. -->
## Screenshots
@ -17,7 +18,7 @@
## Checklist
- [ ] This change requires a documentation update, included: [Dify Document](https://github.com/langgenius/dify-docs)
- [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
- [x] I've updated the documentation accordingly.
- [x] I ran `make lint` and `make type-check` (backend) and `cd web && pnpm exec vp staged` (frontend) to appease the lint gods
- [ ] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
- [ ] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
- [ ] I've updated the documentation accordingly.
- [ ] I ran `make lint && make type-check` (backend) and `cd web && pnpm exec vp staged` (frontend) to appease the lint gods

View File

@ -54,7 +54,7 @@ jobs:
run: uv run --project api bash dev/pytest/pytest_unit_tests.sh
- name: Upload unit coverage data
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: api-coverage-unit
path: coverage-unit
@ -129,7 +129,7 @@ jobs:
api/tests/test_containers_integration_tests
- name: Upload integration coverage data
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: api-coverage-integration
path: coverage-integration

View File

@ -81,7 +81,7 @@ jobs:
- name: Build Docker image
id: build
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
with:
context: ${{ matrix.build_context }}
file: ${{ matrix.file }}
@ -101,7 +101,7 @@ jobs:
touch "/tmp/digests/${sanitized_digest}"
- name: Upload digest
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: digests-${{ matrix.artifact_context }}-${{ env.PLATFORM_PAIR }}
path: /tmp/digests/*

View File

@ -50,7 +50,7 @@ jobs:
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
- name: Build Docker Image
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
with:
push: false
context: ${{ matrix.context }}

View File

@ -21,7 +21,7 @@ jobs:
if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }}
steps:
- name: Download pyrefly diff artifact
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
@ -49,7 +49,7 @@ jobs:
run: unzip -o pyrefly_diff.zip
- name: Post comment
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |

View File

@ -66,7 +66,7 @@ jobs:
echo ${{ github.event.pull_request.number }} > pr_number.txt
- name: Upload pyrefly diff
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: pyrefly_diff
path: |
@ -75,7 +75,7 @@ jobs:
- name: Comment PR with pyrefly diff
if: ${{ github.event.pull_request.head.repo.full_name == github.repository && steps.line_count_check.outputs.same == 'false' }}
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |

View File

@ -0,0 +1,118 @@
name: Comment with Pyrefly Type Coverage
on:
workflow_run:
workflows:
- Pyrefly Type Coverage
types:
- completed
permissions: {}
jobs:
comment:
name: Comment PR with type coverage
runs-on: ubuntu-latest
permissions:
actions: read
contents: read
issues: write
pull-requests: write
if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }}
steps:
- name: Checkout default branch (trusted code)
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: Setup Python & UV
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
with:
enable-cache: true
- name: Install dependencies
run: uv sync --project api --dev
- name: Download type coverage artifact
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
const fs = require('fs');
const artifacts = await github.rest.actions.listWorkflowRunArtifacts({
owner: context.repo.owner,
repo: context.repo.repo,
run_id: ${{ github.event.workflow_run.id }},
});
const match = artifacts.data.artifacts.find((artifact) =>
artifact.name === 'pyrefly_type_coverage'
);
if (!match) {
throw new Error('pyrefly_type_coverage artifact not found');
}
const download = await github.rest.actions.downloadArtifact({
owner: context.repo.owner,
repo: context.repo.repo,
artifact_id: match.id,
archive_format: 'zip',
});
fs.writeFileSync('pyrefly_type_coverage.zip', Buffer.from(download.data));
- name: Unzip artifact
run: unzip -o pyrefly_type_coverage.zip
- name: Render coverage markdown from structured data
id: render
run: |
comment_body="$(uv run --directory api python api/libs/pyrefly_type_coverage.py \
--base base_report.json \
< pr_report.json)"
{
echo "### Pyrefly Type Coverage"
echo ""
echo "$comment_body"
} > /tmp/type_coverage_comment.md
- name: Post comment
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
const fs = require('fs');
const body = fs.readFileSync('/tmp/type_coverage_comment.md', { encoding: 'utf8' });
let prNumber = null;
try {
prNumber = parseInt(fs.readFileSync('pr_number.txt', { encoding: 'utf8' }), 10);
} catch (err) {
const prs = context.payload.workflow_run.pull_requests || [];
if (prs.length > 0 && prs[0].number) {
prNumber = prs[0].number;
}
}
if (!prNumber) {
throw new Error('PR number not found in artifact or workflow_run payload');
}
// Update existing comment if one exists, otherwise create new
const { data: comments } = await github.rest.issues.listComments({
issue_number: prNumber,
owner: context.repo.owner,
repo: context.repo.repo,
});
const marker = '### Pyrefly Type Coverage';
const existing = comments.find(c => c.body.startsWith(marker));
if (existing) {
await github.rest.issues.updateComment({
comment_id: existing.id,
owner: context.repo.owner,
repo: context.repo.repo,
body,
});
} else {
await github.rest.issues.createComment({
issue_number: prNumber,
owner: context.repo.owner,
repo: context.repo.repo,
body,
});
}
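
The "Post comment" step above implements a marker-based upsert: find the bot's earlier comment by its heading and edit it in place, otherwise create a new one, so each run refreshes a single report instead of stacking comments. A minimal Python sketch of the same pattern against the GitHub REST API (owner, repo, PR number, and token are placeholder assumptions, not values from this workflow):

import httpx

MARKER = "### Pyrefly Type Coverage"
API = "https://api.github.com"

def upsert_comment(owner: str, repo: str, pr_number: int, body: str, token: str) -> None:
    headers = {"Authorization": f"Bearer {token}", "Accept": "application/vnd.github+json"}
    # PR conversation comments live on the issues endpoints.
    url = f"{API}/repos/{owner}/{repo}/issues/{pr_number}/comments"
    comments = httpx.get(url, headers=headers).json()
    existing = next((c for c in comments if c["body"].startswith(MARKER)), None)
    if existing:
        # Refresh the previous report rather than posting a duplicate.
        httpx.patch(f"{API}/repos/{owner}/{repo}/issues/comments/{existing['id']}",
                    headers=headers, json={"body": body})
    else:
        httpx.post(url, headers=headers, json={"body": body})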

View File

@ -0,0 +1,120 @@
name: Pyrefly Type Coverage
on:
pull_request:
paths:
- 'api/**/*.py'
permissions:
contents: read
jobs:
pyrefly-type-coverage:
runs-on: ubuntu-latest
permissions:
contents: read
issues: write
pull-requests: write
steps:
- name: Checkout PR branch
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
- name: Setup Python & UV
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
with:
enable-cache: true
- name: Install dependencies
run: uv sync --project api --dev
- name: Run pyrefly report on PR branch
run: |
uv run --directory api --dev pyrefly report 2>/dev/null > /tmp/pyrefly_report_pr.tmp && \
mv /tmp/pyrefly_report_pr.tmp /tmp/pyrefly_report_pr.json || \
echo '{}' > /tmp/pyrefly_report_pr.json
- name: Save helper script from base branch
run: |
git show ${{ github.event.pull_request.base.sha }}:api/libs/pyrefly_type_coverage.py > /tmp/pyrefly_type_coverage.py 2>/dev/null \
|| cp api/libs/pyrefly_type_coverage.py /tmp/pyrefly_type_coverage.py
- name: Checkout base branch
run: git checkout ${{ github.base_ref }}
- name: Run pyrefly report on base branch
run: |
uv run --directory api --dev pyrefly report 2>/dev/null > /tmp/pyrefly_report_base.tmp && \
mv /tmp/pyrefly_report_base.tmp /tmp/pyrefly_report_base.json || \
echo '{}' > /tmp/pyrefly_report_base.json
- name: Generate coverage comparison
id: coverage
run: |
comment_body="$(uv run --directory api python /tmp/pyrefly_type_coverage.py \
--base /tmp/pyrefly_report_base.json \
< /tmp/pyrefly_report_pr.json)"
{
echo "### Pyrefly Type Coverage"
echo ""
echo "$comment_body"
} | tee -a "$GITHUB_STEP_SUMMARY" > /tmp/type_coverage_comment.md
# Save structured data for the fork-PR comment workflow
cp /tmp/pyrefly_report_pr.json pr_report.json
cp /tmp/pyrefly_report_base.json base_report.json
- name: Save PR number
run: |
echo ${{ github.event.pull_request.number }} > pr_number.txt
- name: Upload type coverage artifact
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: pyrefly_type_coverage
path: |
pr_report.json
base_report.json
pr_number.txt
- name: Comment PR with type coverage
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
const fs = require('fs');
const marker = '### Pyrefly Type Coverage';
let body;
try {
body = fs.readFileSync('/tmp/type_coverage_comment.md', { encoding: 'utf8' });
} catch {
body = `${marker}\n\n_Coverage report unavailable._`;
}
const prNumber = context.payload.pull_request.number;
// Update existing comment if one exists, otherwise create new
const { data: comments } = await github.rest.issues.listComments({
issue_number: prNumber,
owner: context.repo.owner,
repo: context.repo.repo,
});
const existing = comments.find(c => c.body.startsWith(marker));
if (existing) {
await github.rest.issues.updateComment({
comment_id: existing.id,
owner: context.repo.owner,
repo: context.repo.repo,
body,
});
} else {
await github.rest.issues.createComment({
issue_number: prNumber,
owner: context.repo.owner,
repo: context.repo.repo,
body,
});
}

View File

@ -23,8 +23,8 @@ jobs:
days-before-issue-stale: 15
days-before-issue-close: 3
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "Close due to it's no longer active, if you have any questions, you can reopen it."
stale-pr-message: "Close due to it's no longer active, if you have any questions, you can reopen it."
stale-issue-message: "Closed due to inactivity. If you have any questions, you can reopen it."
stale-pr-message: "Closed due to inactivity. If you have any questions, you can reopen it."
stale-issue-label: 'no-issue-activity'
stale-pr-label: 'no-pr-activity'
any-of-labels: 'duplicate,question,invalid,wontfix,no-issue-activity,no-pr-activity,enhancement,cant-reproduce,help-wanted'
any-of-labels: '🌚 invalid,🙋‍♂️ question,wont-fix,no-issue-activity,no-pr-activity,💪 enhancement,🤔 cant-reproduce,🙏 help wanted'

View File

@ -158,7 +158,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.context.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@6e2bd52842c65e914eba5c8badd17560bd26b5de # v1.0.89
uses: anthropics/claude-code-action@b47fd721da662d48c5680e154ad16a73ed74d2e0 # v1.0.93
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}

View File

@ -56,7 +56,7 @@ jobs:
- name: Trigger i18n sync workflow
if: steps.detect.outputs.has_changes == 'true'
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
env:
BASE_SHA: ${{ steps.detect.outputs.base_sha }}
HEAD_SHA: ${{ steps.detect.outputs.head_sha }}

View File

@ -53,7 +53,7 @@ jobs:
- name: Upload Cucumber report
if: ${{ !cancelled() }}
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: cucumber-report
path: e2e/cucumber-report
@ -61,7 +61,7 @@ jobs:
- name: Upload E2E logs
if: ${{ !cancelled() }}
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: e2e-logs
path: e2e/.logs

View File

@ -43,7 +43,7 @@ jobs:
- name: Upload blob report
if: ${{ !cancelled() }}
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
name: blob-report-${{ matrix.shardIndex }}
path: web/.vitest-reports/*

View File

@ -97,8 +97,3 @@ Feel free to reach out if you encounter any issues during the setup process.
## Getting Help
If you ever get stuck or have a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat.

## Automated Agent Contributions
> [!NOTE]
> If you are an automated agent just add 🤖🤖🤖 to the end of the PR title to opt-in.

View File

@ -2,7 +2,6 @@ import base64
import secrets
import click
from sqlalchemy.orm import sessionmaker
from constants.languages import languages
from extensions.ext_database import db
@ -25,30 +24,31 @@ def reset_password(email, new_password, password_confirm):
return
normalized_email = email.strip().lower()
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
account = AccountService.get_account_by_email_with_case_fallback(email.strip())
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
try:
valid_password(new_password)
except:
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
return
try:
valid_password(new_password)
except:
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
return
# generate password salt
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()
# generate password salt
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()
# encrypt password with salt
password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
AccountService.reset_login_error_rate_limit(normalized_email)
click.echo(click.style("Password reset successfully.", fg="green"))
# encrypt password with salt
password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account = db.session.merge(account)
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.commit()
AccountService.reset_login_error_rate_limit(normalized_email)
click.echo(click.style("Password reset successfully.", fg="green"))
@click.command("reset-email", help="Reset the account email.")
@ -65,21 +65,22 @@ def reset_email(email, new_email, email_confirm):
return
normalized_new_email = new_email.strip().lower()
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
account = AccountService.get_account_by_email_with_case_fallback(email.strip())
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
try:
email_validate(normalized_new_email)
except:
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
return
try:
email_validate(normalized_new_email)
except:
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
return
account.email = normalized_new_email
click.echo(click.style("Email updated successfully.", fg="green"))
account = db.session.merge(account)
account.email = normalized_new_email
db.session.commit()
click.echo(click.style("Email updated successfully.", fg="green"))
@click.command("create-tenant", help="Create account and tenant.")

View File

@ -1,5 +1,5 @@
import os
from typing import Any, Literal
from typing import Any, Literal, TypedDict
from urllib.parse import parse_qsl, quote_plus
from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
@ -107,6 +107,17 @@ class KeywordStoreConfig(BaseSettings):
)
class SQLAlchemyEngineOptionsDict(TypedDict):
pool_size: int
max_overflow: int
pool_recycle: int
pool_pre_ping: bool
connect_args: dict[str, str]
pool_use_lifo: bool
pool_reset_on_return: None
pool_timeout: int
class DatabaseConfig(BaseSettings):
# Database type selector
DB_TYPE: Literal["postgresql", "mysql", "oceanbase", "seekdb"] = Field(
@ -209,11 +220,11 @@ class DatabaseConfig(BaseSettings):
@computed_field # type: ignore[prop-decorator]
@property
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
def SQLALCHEMY_ENGINE_OPTIONS(self) -> SQLAlchemyEngineOptionsDict:
# Parse DB_EXTRAS for 'options'
db_extras_dict = dict(parse_qsl(self.DB_EXTRAS))
options = db_extras_dict.get("options", "")
connect_args = {}
connect_args: dict[str, str] = {}
# Use the dynamic SQLALCHEMY_DATABASE_URI_SCHEME property
if self.SQLALCHEMY_DATABASE_URI_SCHEME.startswith("postgresql"):
timezone_opt = "-c timezone=UTC"
@ -223,7 +234,7 @@ class DatabaseConfig(BaseSettings):
merged_options = timezone_opt
connect_args = {"options": merged_options}
return {
result: SQLAlchemyEngineOptionsDict = {
"pool_size": self.SQLALCHEMY_POOL_SIZE,
"max_overflow": self.SQLALCHEMY_MAX_OVERFLOW,
"pool_recycle": self.SQLALCHEMY_POOL_RECYCLE,
@ -233,6 +244,7 @@ class DatabaseConfig(BaseSettings):
"pool_reset_on_return": None,
"pool_timeout": self.SQLALCHEMY_POOL_TIMEOUT,
}
return result
class CeleryConfig(DatabaseConfig):
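
Note the shape of the change: the literal is first bound to a name annotated as SQLAlchemyEngineOptionsDict and only then returned. That binding is what makes the type checker verify every key and value against the TypedDict. A small illustration of the effect:

from typing import TypedDict

class EngineOptions(TypedDict):
    pool_size: int
    pool_pre_ping: bool

def make_options(size: int) -> EngineOptions:
    opts: EngineOptions = {"pool_size": size, "pool_pre_ping": True}
    # opts["pool_szie"] = 1   # rejected by pyrefly/mypy: unknown TypedDict key
    return opts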

View File

@ -25,7 +25,13 @@ from fields.annotation_fields import (
)
from libs.helper import uuid_value
from libs.login import login_required
from services.annotation_service import AppAnnotationService
from services.annotation_service import (
AppAnnotationService,
EnableAnnotationArgs,
UpdateAnnotationArgs,
UpdateAnnotationSettingArgs,
UpsertAnnotationArgs,
)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@ -120,7 +126,12 @@ class AnnotationReplyActionApi(Resource):
args = AnnotationReplyPayload.model_validate(console_ns.payload)
match action:
case "enable":
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
enable_args: EnableAnnotationArgs = {
"score_threshold": args.score_threshold,
"embedding_provider_name": args.embedding_provider_name,
"embedding_model_name": args.embedding_model_name,
}
result = AppAnnotationService.enable_app_annotation(enable_args, app_id)
case "disable":
result = AppAnnotationService.disable_app_annotation(app_id)
return result, 200
@ -161,7 +172,8 @@ class AppAnnotationSettingUpdateApi(Resource):
args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload)
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args.model_dump())
setting_args: UpdateAnnotationSettingArgs = {"score_threshold": args.score_threshold}
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, setting_args)
return result, 200
@ -237,8 +249,16 @@ class AnnotationApi(Resource):
def post(self, app_id):
app_id = str(app_id)
args = CreateAnnotationPayload.model_validate(console_ns.payload)
data = args.model_dump(exclude_none=True)
annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
upsert_args: UpsertAnnotationArgs = {}
if args.answer is not None:
upsert_args["answer"] = args.answer
if args.content is not None:
upsert_args["content"] = args.content
if args.message_id is not None:
upsert_args["message_id"] = args.message_id
if args.question is not None:
upsert_args["question"] = args.question
annotation = AppAnnotationService.up_insert_app_annotation_from_message(upsert_args, app_id)
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
@setup_required
@ -315,9 +335,12 @@ class AnnotationUpdateDeleteApi(Resource):
app_id = str(app_id)
annotation_id = str(annotation_id)
args = UpdateAnnotationPayload.model_validate(console_ns.payload)
annotation = AppAnnotationService.update_app_annotation_directly(
args.model_dump(exclude_none=True), app_id, annotation_id
)
update_args: UpdateAnnotationArgs = {}
if args.answer is not None:
update_args["answer"] = args.answer
if args.question is not None:
update_args["question"] = args.question
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_id, annotation_id)
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
@setup_required
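
Replacing model_dump(exclude_none=True) with an explicitly built TypedDict trades a dynamic dict for one the checker can verify key by key. A sketch of the conditional-build pattern, assuming (as the empty-dict start implies) that the args type marks its keys optional:

from typing import TypedDict

class UpdateArgs(TypedDict, total=False):
    question: str
    answer: str

def build_update_args(question: str | None, answer: str | None) -> UpdateArgs:
    args: UpdateArgs = {}
    if question is not None:
        args["question"] = question  # only keys that were actually provided are passed on
    if answer is not None:
        args["answer"] = answer
    return args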

View File

@ -1,9 +1,11 @@
import json
from typing import cast
from typing import Any, cast
from flask import request
from flask_restx import Resource, fields
from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
@ -18,30 +20,30 @@ from models.model import AppMode, AppModelConfig
from services.app_model_config_service import AppModelConfigService
class ModelConfigRequest(BaseModel):
provider: str | None = Field(default=None, description="Model provider")
model: str | None = Field(default=None, description="Model name")
configs: dict[str, Any] | None = Field(default=None, description="Model configuration parameters")
opening_statement: str | None = Field(default=None, description="Opening statement")
suggested_questions: list[str] | None = Field(default=None, description="Suggested questions")
more_like_this: dict[str, Any] | None = Field(default=None, description="More like this configuration")
speech_to_text: dict[str, Any] | None = Field(default=None, description="Speech to text configuration")
text_to_speech: dict[str, Any] | None = Field(default=None, description="Text to speech configuration")
retrieval_model: dict[str, Any] | None = Field(default=None, description="Retrieval model configuration")
tools: list[dict[str, Any]] | None = Field(default=None, description="Available tools")
dataset_configs: dict[str, Any] | None = Field(default=None, description="Dataset configurations")
agent_mode: dict[str, Any] | None = Field(default=None, description="Agent mode configuration")
register_schema_models(console_ns, ModelConfigRequest)
@console_ns.route("/apps/<uuid:app_id>/model-config")
class ModelConfigResource(Resource):
@console_ns.doc("update_app_model_config")
@console_ns.doc(description="Update application model configuration")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.model(
"ModelConfigRequest",
{
"provider": fields.String(description="Model provider"),
"model": fields.String(description="Model name"),
"configs": fields.Raw(description="Model configuration parameters"),
"opening_statement": fields.String(description="Opening statement"),
"suggested_questions": fields.List(fields.String(), description="Suggested questions"),
"more_like_this": fields.Raw(description="More like this configuration"),
"speech_to_text": fields.Raw(description="Speech to text configuration"),
"text_to_speech": fields.Raw(description="Text to speech configuration"),
"retrieval_model": fields.Raw(description="Retrieval model configuration"),
"tools": fields.List(fields.Raw(), description="Available tools"),
"dataset_configs": fields.Raw(description="Dataset configurations"),
"agent_mode": fields.Raw(description="Agent mode configuration"),
},
)
)
@console_ns.expect(console_ns.models[ModelConfigRequest.__name__])
@console_ns.response(200, "Model configuration updated successfully")
@console_ns.response(400, "Invalid configuration")
@console_ns.response(404, "App not found")
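
The hand-written console_ns.model(...) schema is gone; the pydantic model is registered once and referenced from console_ns.models by class name. Judging from the billing diff later in this commit, register_schema_models plausibly reduces to a loop over schema_model, roughly:

DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"

def register_schema_models(ns, *models):
    # Hypothetical reconstruction: register each pydantic model's JSON schema
    # on the flask-restx namespace under the model's class name.
    for model in models:
        ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))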

View File

@ -1,7 +1,7 @@
import logging
from collections.abc import Callable
from functools import wraps
from typing import Any
from typing import Any, TypedDict
from flask import Response, request
from flask_restx import Resource, fields, marshal, marshal_with
@ -86,7 +86,14 @@ def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
return value_type.exposed_type().value
def _serialize_full_content(variable: WorkflowDraftVariable) -> dict | None:
class FullContentDict(TypedDict):
size_bytes: int | None
value_type: str
length: int | None
download_url: str
def _serialize_full_content(variable: WorkflowDraftVariable) -> FullContentDict | None:
"""Serialize full_content information for large variables."""
if not variable.is_truncated():
return None
@ -94,12 +101,13 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> dict | None:
variable_file = variable.variable_file
assert variable_file is not None
return {
result: FullContentDict = {
"size_bytes": variable_file.size,
"value_type": variable_file.value_type.exposed_type().value,
"length": variable_file.length,
"download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True),
}
return result
def _ensure_variable_access(

View File

@ -1,7 +1,6 @@
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from constants.languages import languages
@ -14,7 +13,6 @@ from controllers.console.auth.error import (
InvalidTokenError,
PasswordMismatchError,
)
from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip
from libs.password import valid_password
from models import Account
@ -73,8 +71,7 @@ class EmailRegisterSendEmailApi(Resource):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountInFreezeError()
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
account = AccountService.get_account_by_email_with_case_fallback(args.email)
token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
return {"result": "success", "data": token}
@ -145,17 +142,16 @@ class EmailRegisterResetApi(Resource):
email = register_data.get("email", "")
normalized_email = email.lower()
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
account = AccountService.get_account_by_email_with_case_fallback(email)
if account:
raise EmailAlreadyInUseError()
else:
account = self._create_new_account(normalized_email, args.password_confirm)
if not account:
raise AccountNotFoundError()
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(normalized_email)
if account:
raise EmailAlreadyInUseError()
else:
account = self._create_new_account(normalized_email, args.password_confirm)
if not account:
raise AccountNotFoundError()
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(normalized_email)
return {"result": "success", "data": token_pair.model_dump()}

View File

@ -4,7 +4,6 @@ import secrets
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
@ -85,8 +84,7 @@ class ForgotPasswordSendEmailApi(Resource):
else:
language = "en-US"
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
account = AccountService.get_account_by_email_with_case_fallback(args.email)
token = AccountService.send_reset_password_email(
account=account,
@ -184,17 +182,18 @@ class ForgotPasswordResetApi(Resource):
password_hashed = hash_password(args.new_password, salt)
email = reset_data.get("email", "")
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
account = AccountService.get_account_by_email_with_case_fallback(email)
if account:
self._update_existing_account(account, password_hashed, salt, session)
else:
raise AccountNotFound()
if account:
account = db.session.merge(account)
self._update_existing_account(account, password_hashed, salt)
db.session.commit()
else:
raise AccountNotFound()
return {"result": "success"}
def _update_existing_account(self, account, password_hashed, salt, session):
def _update_existing_account(self, account, password_hashed, salt):
# Update existing account credentials
account.password = base64.b64encode(password_hashed).decode()
account.password_salt = base64.b64encode(salt).decode()

View File

@ -4,7 +4,6 @@ import urllib.parse
import httpx
from flask import current_app, redirect, request
from flask_restx import Resource
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Unauthorized
from configs import dify_config
@ -180,8 +179,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
account: Account | None = Account.get_by_openid(provider, user_info.id)
if not account:
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)
account = AccountService.get_account_by_email_with_case_fallback(user_info.email)
return account

View File

@ -4,10 +4,11 @@ from datetime import UTC, datetime, timedelta
from typing import Literal
from flask import request
from flask_restx import Resource, fields
from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import BadRequest
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from enums.cloud_plan import CloudPlan
@ -15,8 +16,6 @@ from extensions.ext_redis import redis_client
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class SubscriptionQuery(BaseModel):
plan: Literal[CloudPlan.PROFESSIONAL, CloudPlan.TEAM] = Field(..., description="Subscription plan")
@ -27,8 +26,7 @@ class PartnerTenantsPayload(BaseModel):
click_id: str = Field(..., description="Click Id from partner referral link")
for model in (SubscriptionQuery, PartnerTenantsPayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
register_schema_models(console_ns, SubscriptionQuery, PartnerTenantsPayload)
@console_ns.route("/billing/subscription")
@ -61,12 +59,7 @@ class PartnerTenants(Resource):
@console_ns.doc("sync_partner_tenants_bindings")
@console_ns.doc(description="Sync partner tenants bindings")
@console_ns.doc(params={"partner_key": "Partner key"})
@console_ns.expect(
console_ns.model(
"SyncPartnerTenantsBindingsRequest",
{"click_id": fields.String(required=True, description="Click Id from partner referral link")},
)
)
@console_ns.expect(console_ns.models[PartnerTenantsPayload.__name__])
@console_ns.response(200, "Tenants synced to partner successfully")
@console_ns.response(400, "Invalid partner information")
@setup_required

View File

@ -162,7 +162,9 @@ class DataSourceApi(Resource):
binding_id = str(binding_id)
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
data_source_binding = session.execute(
select(DataSourceOauthBinding).filter_by(id=binding_id, tenant_id=current_tenant_id)
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.id == binding_id, DataSourceOauthBinding.tenant_id == current_tenant_id
)
).scalar_one_or_none()
if data_source_binding is None:
raise NotFound("Data source binding not found.")
@ -222,11 +224,11 @@ class DataSourceNotionListApi(Resource):
raise ValueError("Dataset is not notion type.")
documents = session.scalars(
select(Document).filter_by(
dataset_id=query.dataset_id,
tenant_id=current_tenant_id,
data_source_type="notion_import",
enabled=True,
select(Document).where(
Document.dataset_id == query.dataset_id,
Document.tenant_id == current_tenant_id,
Document.data_source_type == "notion_import",
Document.enabled.is_(True),
)
).all()
if documents:
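
filter_by(...) resolves keyword names against the primary entity at runtime, while .where(...) takes explicit column expressions that editors and type checkers can verify, and makes boolean comparisons explicit via .is_(True). A self-contained sketch of the two spellings with a toy model:

from sqlalchemy import String, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class Document(Base):  # toy stand-in for the model in the diff
    __tablename__ = "documents"
    id: Mapped[int] = mapped_column(primary_key=True)
    dataset_id: Mapped[str] = mapped_column(String)
    enabled: Mapped[bool] = mapped_column(default=True)

stmt_old = select(Document).filter_by(dataset_id="d1", enabled=True)   # keyword shortcuts
stmt_new = select(Document).where(
    Document.dataset_id == "d1",
    Document.enabled.is_(True),  # explicit, checker-visible column expressions
)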

View File

@ -280,7 +280,7 @@ class DatasetDocumentListApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_tenant_id)
query = select(Document).where(Document.dataset_id == str(dataset_id), Document.tenant_id == current_tenant_id)
if status:
query = DocumentService.apply_display_status_filter(query, status)

View File

@ -227,10 +227,11 @@ class ExternalApiUseCheckApi(Resource):
@login_required
@account_initialization_required
def get(self, external_knowledge_api_id):
_, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id = str(external_knowledge_api_id)
external_knowledge_api_is_using, count = ExternalDatasetService.external_knowledge_api_use_check(
external_knowledge_api_id
external_knowledge_api_id, current_tenant_id
)
return {"is_using": external_knowledge_api_is_using, "count": count}, 200

View File

@ -346,89 +346,6 @@ class PublishedRagPipelineRunApi(Resource):
raise InvokeRateLimitHttpError(ex.description)
# class RagPipelinePublishedDatasourceNodeRunStatusApi(Resource):
# @setup_required
# @login_required
# @account_initialization_required
# @get_rag_pipeline
# def post(self, pipeline: Pipeline, node_id: str):
# """
# Run rag pipeline datasource
# """
# # The role of the current user in the ta table must be admin, owner, or editor
# if not current_user.has_edit_permission:
# raise Forbidden()
#
# if not isinstance(current_user, Account):
# raise Forbidden()
#
# parser = (reqparse.RequestParser()
# .add_argument("job_id", type=str, required=True, nullable=False, location="json")
# .add_argument("datasource_type", type=str, required=True, location="json")
# )
# args = parser.parse_args()
#
# job_id = args.get("job_id")
# if job_id == None:
# raise ValueError("missing job_id")
# datasource_type = args.get("datasource_type")
# if datasource_type == None:
# raise ValueError("missing datasource_type")
#
# rag_pipeline_service = RagPipelineService()
# result = rag_pipeline_service.run_datasource_workflow_node_status(
# pipeline=pipeline,
# node_id=node_id,
# job_id=job_id,
# account=current_user,
# datasource_type=datasource_type,
# is_published=True
# )
#
# return result
# class RagPipelineDraftDatasourceNodeRunStatusApi(Resource):
# @setup_required
# @login_required
# @account_initialization_required
# @get_rag_pipeline
# def post(self, pipeline: Pipeline, node_id: str):
# """
# Run rag pipeline datasource
# """
# # The role of the current user in the ta table must be admin, owner, or editor
# if not current_user.has_edit_permission:
# raise Forbidden()
#
# if not isinstance(current_user, Account):
# raise Forbidden()
#
# parser = (reqparse.RequestParser()
# .add_argument("job_id", type=str, required=True, nullable=False, location="json")
# .add_argument("datasource_type", type=str, required=True, location="json")
# )
# args = parser.parse_args()
#
# job_id = args.get("job_id")
# if job_id == None:
# raise ValueError("missing job_id")
# datasource_type = args.get("datasource_type")
# if datasource_type == None:
# raise ValueError("missing datasource_type")
#
# rag_pipeline_service = RagPipelineService()
# result = rag_pipeline_service.run_datasource_workflow_node_status(
# pipeline=pipeline,
# node_id=node_id,
# job_id=job_id,
# account=current_user,
# datasource_type=datasource_type,
# is_published=False
# )
#
# return result
#
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
class RagPipelinePublishedDatasourceNodeRunApi(Resource):
@console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__])

View File

@ -7,7 +7,8 @@ import logging
from collections.abc import Generator
from flask import Response, jsonify, request
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
@ -33,6 +34,11 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
logger = logging.getLogger(__name__)
class HumanInputFormSubmitPayload(BaseModel):
inputs: dict
action: str
def _jsonify_form_definition(form: Form) -> Response:
payload = form.get_definition().model_dump()
payload["expiration_time"] = int(form.expiration_time.timestamp())
@ -84,10 +90,7 @@ class ConsoleHumanInputFormApi(Resource):
"action": "Approve"
}
"""
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("action", type=str, required=True, location="json")
args = parser.parse_args()
payload = HumanInputFormSubmitPayload.model_validate(request.get_json())
current_user, _ = current_account_with_tenant()
service = HumanInputService(db.engine)
@ -107,8 +110,8 @@ class ConsoleHumanInputFormApi(Resource):
service.submit_form_by_token(
recipient_type=recipient_type,
form_token=form_token,
selected_action_id=args["action"],
form_data=args["inputs"],
selected_action_id=payload.action,
form_data=payload.inputs,
submission_user_id=current_user.id,
)
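
The handler swaps flask-restx's reqparse for a pydantic model validated against the raw JSON body, so malformed payloads raise a ValidationError up front instead of being silently coerced. A standalone sketch of the validation behaviour:

from pydantic import BaseModel, ValidationError

class HumanInputFormSubmitPayload(BaseModel):
    inputs: dict
    action: str

payload = HumanInputFormSubmitPayload.model_validate({"inputs": {"q": "hi"}, "action": "Approve"})
assert payload.action == "Approve"

try:
    HumanInputFormSubmitPayload.model_validate({"inputs": "oops"})  # wrong type and missing field
except ValidationError as exc:
    print(exc.error_count())  # 2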

View File

@ -8,7 +8,6 @@ from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from constants.languages import supported_language
@ -562,8 +561,7 @@ class ChangeEmailSendEmailApi(Resource):
user_email = current_user.email
else:
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
account = AccountService.get_account_by_email_with_case_fallback(args.email)
if account is None:
raise AccountNotFound()
email_for_sending = account.email

View File

@ -94,10 +94,9 @@ def get_user_tenant[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
def plugin_data[**P, R](
view: Callable[P, R] | None = None,
*,
payload_type: type[BaseModel],
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
) -> Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
@ -116,7 +115,4 @@ def plugin_data[**P, R](
return decorated_view
if view is None:
return decorator
else:
return decorator(view)
return decorator

View File

@ -12,7 +12,12 @@ from controllers.service_api.wraps import validate_app_token
from extensions.ext_redis import redis_client
from fields.annotation_fields import Annotation, AnnotationList
from models.model import App
from services.annotation_service import AppAnnotationService
from services.annotation_service import (
AppAnnotationService,
EnableAnnotationArgs,
InsertAnnotationArgs,
UpdateAnnotationArgs,
)
class AnnotationCreatePayload(BaseModel):
@ -46,10 +51,15 @@ class AnnotationReplyActionApi(Resource):
@validate_app_token
def post(self, app_model: App, action: Literal["enable", "disable"]):
"""Enable or disable annotation reply feature."""
args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump()
payload = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {})
match action:
case "enable":
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
enable_args: EnableAnnotationArgs = {
"score_threshold": payload.score_threshold,
"embedding_provider_name": payload.embedding_provider_name,
"embedding_model_name": payload.embedding_model_name,
}
result = AppAnnotationService.enable_app_annotation(enable_args, app_model.id)
case "disable":
result = AppAnnotationService.disable_app_annotation(app_model.id)
return result, 200
@ -135,8 +145,9 @@ class AnnotationListApi(Resource):
@validate_app_token
def post(self, app_model: App):
"""Create a new annotation."""
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id)
payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {})
insert_args: InsertAnnotationArgs = {"question": payload.question, "answer": payload.answer}
annotation = AppAnnotationService.insert_app_annotation_directly(insert_args, app_model.id)
response = Annotation.model_validate(annotation, from_attributes=True)
return response.model_dump(mode="json"), HTTPStatus.CREATED
@ -164,8 +175,9 @@ class AnnotationUpdateDeleteApi(Resource):
@edit_permission_required
def put(self, app_model: App, annotation_id: str):
"""Update an existing annotation."""
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {})
update_args: UpdateAnnotationArgs = {"question": payload.question, "answer": payload.answer}
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, annotation_id)
response = Annotation.model_validate(annotation, from_attributes=True)
return response.model_dump(mode="json")

View File

@ -527,7 +527,7 @@ class DocumentListApi(DatasetApiResource):
if not dataset:
raise NotFound("Dataset not found.")
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id)
query = select(Document).where(Document.dataset_id == dataset_id, Document.tenant_id == tenant_id)
if query_params.status:
query = DocumentService.apply_display_status_filter(query, query_params.status)

View File

@ -3,7 +3,6 @@ import secrets
from flask import request
from flask_restx import Resource
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
from controllers.console.auth.error import (
@ -62,9 +61,7 @@ class ForgotPasswordSendEmailApi(Resource):
else:
language = "en-US"
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session)
token = None
account = AccountService.get_account_by_email_with_case_fallback(request_email)
if account is None:
raise AuthenticationFailedError()
else:
@ -161,13 +158,14 @@ class ForgotPasswordResetApi(Resource):
email = reset_data.get("email", "")
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
account = AccountService.get_account_by_email_with_case_fallback(email)
if account:
self._update_existing_account(account, password_hashed, salt)
else:
raise AuthenticationFailedError()
if account:
account = db.session.merge(account)
self._update_existing_account(account, password_hashed, salt)
db.session.commit()
else:
raise AuthenticationFailedError()
return {"result": "success"}

View File

@ -7,7 +7,8 @@ import logging
from datetime import datetime
from flask import Response, request
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select
from werkzeug.exceptions import Forbidden
@ -23,6 +24,12 @@ from services.human_input_service import Form, FormNotFoundError, HumanInputServ
logger = logging.getLogger(__name__)
class HumanInputFormSubmitPayload(BaseModel):
inputs: dict
action: str
_FORM_SUBMIT_RATE_LIMITER = RateLimiter(
prefix="web_form_submit_rate_limit",
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
@ -112,10 +119,7 @@ class HumanInputFormApi(Resource):
"action": "Approve"
}
"""
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("action", type=str, required=True, location="json")
args = parser.parse_args()
payload = HumanInputFormSubmitPayload.model_validate(request.get_json())
ip_address = extract_remote_ip(request)
if _FORM_SUBMIT_RATE_LIMITER.is_rate_limited(ip_address):
@ -135,8 +139,8 @@ class HumanInputFormApi(Resource):
service.submit_form_by_token(
recipient_type=recipient_type,
form_token=form_token,
selected_action_id=args["action"],
form_data=args["inputs"],
selected_action_id=payload.action,
form_data=payload.inputs,
submission_end_user_id=None,
# submission_end_user_id=_end_user.id,
)

View File

@ -2,7 +2,7 @@ import json
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping, Sequence
from typing import Any
from typing import Any, TypedDict
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from graphon.model_runtime.entities.message_entities import (
@ -29,6 +29,13 @@ from models.model import Message
logger = logging.getLogger(__name__)
class ActionDict(TypedDict):
"""Shape produced by AgentScratchpadUnit.Action.to_dict()."""
action: str
action_input: dict[str, Any] | str
class CotAgentRunner(BaseAgentRunner, ABC):
_is_first_iteration = True
_ignore_observation_providers = ["wenxin"]
@ -331,7 +338,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
return tool_invoke_response, tool_invoke_meta
def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action:
def _convert_dict_to_action(self, action: ActionDict) -> AgentScratchpadUnit.Action:
"""
convert dict to action
"""

View File

@ -1,6 +1,6 @@
from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
from graphon.file import File, FileUploadConfig
from graphon.model_runtime.entities.model_entities import AIModelEntity
@ -131,7 +131,7 @@ class AppGenerateEntity(BaseModel):
extras: dict[str, Any] = Field(default_factory=dict)
# tracing instance
trace_manager: Optional["TraceQueueManager"] = Field(default=None, exclude=True, repr=False)
trace_manager: "TraceQueueManager | None" = Field(default=None, exclude=True, repr=False)
class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
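
Because the whole annotation is a string, it is evaluated lazily when pydantic resolves forward references, so the PEP 604 spelling "X | None" is safe even though TraceQueueManager is only imported under TYPE_CHECKING. A small self-referential example of the same quoting trick:

from pydantic import BaseModel, Field

class Node(BaseModel):
    # Quoted annotation: both the forward reference to Node and the
    # PEP 604 union are resolved lazily, not at class-body execution time.
    parent: "Node | None" = Field(default=None, repr=False)

Node.model_rebuild()
child = Node(parent=Node())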

View File

@ -1,10 +1,10 @@
from typing import Literal, Optional
from typing import Any, Literal, TypedDict
from graphon.model_runtime.utils.encoders import jsonable_encoder
from pydantic import BaseModel, Field, field_validator
from core.datasource.entities.datasource_entities import DatasourceParameter
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.common_entities import I18nObject, I18nObjectDict
class DatasourceApiEntity(BaseModel):
@ -17,7 +17,24 @@ class DatasourceApiEntity(BaseModel):
output_schema: dict | None = None
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]]
ToolProviderTypeApiLiteral = Literal["builtin", "api", "workflow"] | None
class DatasourceProviderApiEntityDict(TypedDict):
id: str
author: str
name: str
plugin_id: str | None
plugin_unique_identifier: str | None
description: I18nObjectDict
icon: str | dict
label: I18nObjectDict
type: str
team_credentials: dict | None
is_team_authorization: bool
allow_delete: bool
datasources: list[Any]
labels: list[str]
class DatasourceProviderApiEntity(BaseModel):
@ -42,7 +59,7 @@ class DatasourceProviderApiEntity(BaseModel):
def convert_none_to_empty_list(cls, v):
return v if v is not None else []
def to_dict(self) -> dict:
def to_dict(self) -> DatasourceProviderApiEntityDict:
# -------------
# overwrite datasource parameter types for temp fix
datasources = jsonable_encoder(self.datasources)
@ -53,7 +70,7 @@ class DatasourceProviderApiEntity(BaseModel):
parameter["type"] = "files"
# -------------
return {
result: DatasourceProviderApiEntityDict = {
"id": self.id,
"author": self.author,
"name": self.name,
@ -69,3 +86,4 @@ class DatasourceProviderApiEntity(BaseModel):
"datasources": datasources,
"labels": self.labels,
}
return result

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import enum
from enum import StrEnum
from typing import Any
from typing import Any, TypedDict
from pydantic import BaseModel, Field, ValidationInfo, field_validator
from yarl import URL
@ -179,6 +179,12 @@ class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity):
datasources: list[DatasourceEntity] = Field(default_factory=list)
class DatasourceInvokeMetaDict(TypedDict):
time_cost: float
error: str | None
tool_config: dict[str, Any] | None
class DatasourceInvokeMeta(BaseModel):
"""
Datasource invoke meta
@ -202,12 +208,13 @@ class DatasourceInvokeMeta(BaseModel):
"""
return cls(time_cost=0.0, error=error, tool_config={})
def to_dict(self) -> dict:
return {
def to_dict(self) -> DatasourceInvokeMetaDict:
result: DatasourceInvokeMetaDict = {
"time_cost": self.time_cost,
"error": self.error,
"tool_config": self.tool_config,
}
return result
class DatasourceLabel(BaseModel):

View File

@ -71,8 +71,8 @@ class DatasourceFileMessageTransformer:
if not isinstance(message.message, DatasourceMessage.BlobMessage):
raise ValueError("unexpected message type")
# FIXME: should do a type check here.
assert isinstance(message.message.blob, bytes)
if not isinstance(message.message.blob, bytes):
raise TypeError(f"Expected blob to be bytes, got {type(message.message.blob).__name__}")
tool_file_manager = ToolFileManager()
blob_tool_file: ToolFile | None = tool_file_manager.create_file_by_raw(
user_id=user_id,
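
The assert is replaced because assertions vanish entirely under python -O; an explicit isinstance check raises a TypeError in every run mode. The pattern in isolation:

def require_bytes(blob: object) -> bytes:
    # `assert isinstance(blob, bytes)` would be stripped under `python -O`;
    # raising keeps the validation in optimized mode too.
    if not isinstance(blob, bytes):
        raise TypeError(f"Expected blob to be bytes, got {type(blob).__name__}")
    return blob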

View File

@ -32,9 +32,9 @@ class Extensible:
name: str
tenant_id: str
config: dict | None = None
config: dict[str, Any] | None = None
def __init__(self, tenant_id: str, config: dict | None = None):
def __init__(self, tenant_id: str, config: dict[str, Any] | None = None):
self.tenant_id = tenant_id
self.config = config

View File

@ -1,3 +1,6 @@
from collections.abc import Mapping
from typing import Any, TypedDict
from sqlalchemy import select
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
@ -7,6 +10,16 @@ from extensions.ext_database import db
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
class ApiToolConfig(TypedDict, total=False):
"""Expected config shape for ApiExternalDataTool.
Not used directly in method signatures (base class accepts dict[str, Any]);
kept here to document the keys this tool reads from config.
"""
api_based_extension_id: str
class ApiExternalDataTool(ExternalDataTool):
"""
The api external data tool.
@ -16,7 +29,7 @@ class ApiExternalDataTool(ExternalDataTool):
"""the unique name of external data tool"""
@classmethod
def validate_config(cls, tenant_id: str, config: dict):
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -37,7 +50,7 @@ class ApiExternalDataTool(ExternalDataTool):
if not api_based_extension:
raise ValueError("api_based_extension_id is invalid")
def query(self, inputs: dict, query: str | None = None) -> str:
def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str:
"""
Query the external data tool.

View File

@ -1,4 +1,6 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import Any
from core.extension.extensible import Extensible, ExtensionModule
@ -15,14 +17,14 @@ class ExternalDataTool(Extensible, ABC):
variable: str
"""the tool variable name of app tool"""
def __init__(self, tenant_id: str, app_id: str, variable: str, config: dict | None = None):
def __init__(self, tenant_id: str, app_id: str, variable: str, config: dict[str, Any] | None = None):
super().__init__(tenant_id, config)
self.app_id = app_id
self.variable = variable
@classmethod
@abstractmethod
def validate_config(cls, tenant_id: str, config: dict):
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -33,7 +35,7 @@ class ExternalDataTool(Extensible, ABC):
raise NotImplementedError
@abstractmethod
def query(self, inputs: dict, query: str | None = None) -> str:
def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str:
"""
Query the external data tool.
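
Widening inputs from dict to Mapping[str, Any] declares read-only intent: the tool can look keys up but not mutate them, and callers may pass any mapping, not just a plain dict. For instance:

from collections.abc import Mapping
from types import MappingProxyType
from typing import Any

def query(inputs: Mapping[str, Any]) -> str:
    # Mapping has no __setitem__, so accidental mutation is a type error here.
    return str(inputs.get("point", ""))

query({"point": "weather"})                    # plain dict still accepted
query(MappingProxyType({"point": "weather"}))  # read-only views now accepted too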

View File

@ -3,7 +3,7 @@
import logging
import traceback
from datetime import UTC, datetime
from typing import Any, TypedDict
from typing import Any, NotRequired, TypedDict
import orjson
@ -16,6 +16,19 @@ class IdentityDict(TypedDict, total=False):
user_type: str
class LogDict(TypedDict):
ts: str
severity: str
service: str
caller: str
message: str
trace_id: NotRequired[str]
span_id: NotRequired[str]
identity: NotRequired[IdentityDict]
attributes: NotRequired[dict[str, Any]]
stack_trace: NotRequired[str]
class StructuredJSONFormatter(logging.Formatter):
"""
JSON log formatter following the specified schema:
@ -55,9 +68,9 @@ class StructuredJSONFormatter(logging.Formatter):
return json.dumps(log_dict, default=str, ensure_ascii=False)
def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]:
def _build_log_dict(self, record: logging.LogRecord) -> LogDict:
# Core fields
log_dict: dict[str, Any] = {
log_dict: LogDict = {
"ts": datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z"),
"severity": self.SEVERITY_MAP.get(record.levelno, "INFO"),
"service": self._service_name,

View File

@ -146,7 +146,7 @@ def discover_protected_resource_metadata(
return ProtectedResourceMetadata.model_validate(response.json())
elif response.status_code == 404:
continue # Try next URL
except (RequestError, ValidationError):
except (RequestError, ValidationError, json.JSONDecodeError):
continue # Try next URL
return None
@ -166,7 +166,7 @@ def discover_oauth_authorization_server_metadata(
return OAuthMetadata.model_validate(response.json())
elif response.status_code == 404:
continue # Try next URL
except (RequestError, ValidationError):
except (RequestError, ValidationError, json.JSONDecodeError):
continue # Try next URL
return None
@ -276,7 +276,7 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
else:
return False, ""
return False, ""
except RequestError:
except (RequestError, json.JSONDecodeError, IndexError):
# Not support resource discovery, fall back to well-known OAuth metadata
return False, ""
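Context for the widened except clauses: httpx raises RequestError for transport failures, but response.json() on a non-JSON body (an HTML error page, say) raises json.JSONDecodeError, which previously escaped these probe loops. A hedged sketch of the failure mode, assuming httpx as this module's imports suggest:

import json

import httpx

def probe(urls: list[str]) -> dict | None:
    for url in urls:
        try:
            response = httpx.get(url, timeout=5.0)
            if response.status_code == 200:
                return response.json()  # may raise json.JSONDecodeError
        except (httpx.RequestError, json.JSONDecodeError):
            continue  # try the next URL, mirroring the discovery helpers above
    return None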

View File

@ -122,7 +122,7 @@ class MCPClientWithAuthRetry(MCPClient):
logger.exception("Authentication retry failed")
raise MCPAuthError(f"Authentication retry failed: {e}") from e
def _execute_with_retry(self, func: Callable[..., Any], *args, **kwargs) -> Any:
def _execute_with_retry[**P, R](self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
"""
Execute a function with authentication retry logic.
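The new signature is a PEP 695 ParamSpec generic (Python 3.12+): the wrapper forwards *args/**kwargs typed against the callee, so call sites keep full signature checking and the return type flows through without a cast. A self-contained sketch of the pattern; the three-attempt policy is illustrative, not the MCP client's actual logic:

from collections.abc import Callable

def execute_with_retry[**P, R](func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
    last_error: Exception | None = None
    for _ in range(3):  # illustrative retry budget
        try:
            return func(*args, **kwargs)
        except Exception as e:  # the real code retries only on auth errors
            last_error = e
    raise RuntimeError("retries exhausted") from last_error

def add(a: int, b: int) -> int:
    return a + b

total: int = execute_with_retry(add, 1, b=2)  # fully checked; returns int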

View File

@ -1,6 +1,6 @@
from dataclasses import dataclass
from enum import StrEnum
from typing import Any, TypeVar
from typing import Any
from pydantic import BaseModel
@ -9,12 +9,9 @@ from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAut
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
LifespanContextT = TypeVar("LifespanContextT")
@dataclass
class RequestContext[SessionT: BaseSession[Any, Any, Any, Any, Any], LifespanContextT]:
class RequestContext[SessionT: BaseSession, LifespanContextT]:
request_id: RequestId
meta: RequestParams.Meta | None
session: SessionT
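The standalone TypeVar declarations are deleted because PEP 695 moves the type parameters, bounds included, into the class header itself (`class RequestContext[SessionT: BaseSession, LifespanContextT]`). A minimal runnable sketch of the syntax, with the bound omitted for brevity:

from dataclasses import dataclass

@dataclass
class RequestContext[SessionT, LifespanContextT]:
    request_id: int | str
    session: SessionT
    lifespan_context: LifespanContextT

ctx = RequestContext[str, dict](request_id=1, session="s", lifespan_context={})
print(ctx.session)  # s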

View File

@ -55,7 +55,7 @@ class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResul
request: ReceiveRequestT
_session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]"
_on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any]
_on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], object]
def __init__(
self,
@ -63,7 +63,7 @@ class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResul
request_meta: RequestParams.Meta | None,
request: ReceiveRequestT,
session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]",
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], object],
):
self.request_id = request_id
self.request_meta = request_meta

View File

@ -31,7 +31,6 @@ ProgressToken = str | int
Cursor = str
Role = Literal["user", "assistant"]
RequestId = Annotated[int | str, Field(union_mode="left_to_right")]
type AnyFunction = Callable[..., Any]
class RequestParams(BaseModel):

View File

@ -61,27 +61,28 @@ class TokenBufferMemory:
:param is_user_message: whether this is a user message
:return: PromptMessage
"""
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
app = self.conversation.app
if not app:
raise ValueError("App not found for conversation")
match self.conversation.mode:
case AppMode.AGENT_CHAT | AppMode.COMPLETION | AppMode.CHAT:
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
case AppMode.ADVANCED_CHAT | AppMode.WORKFLOW:
app = self.conversation.app
if not app:
raise ValueError("App not found for conversation")
if not message.workflow_run_id:
raise ValueError("Workflow run ID not found")
if not message.workflow_run_id:
raise ValueError("Workflow run ID not found")
workflow_run = self.workflow_run_repo.get_workflow_run_by_id(
tenant_id=app.tenant_id, app_id=app.id, run_id=message.workflow_run_id
)
if not workflow_run:
raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
if not workflow:
raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
else:
raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
workflow_run = self.workflow_run_repo.get_workflow_run_by_id(
tenant_id=app.tenant_id, app_id=app.id, run_id=message.workflow_run_id
)
if not workflow_run:
raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
if not workflow:
raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
case _:
raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
detail = ImagePromptMessageContent.DETAIL.HIGH
if file_extra_config and app_record:
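For reference, the shape of the match statement this refactor adopts: `|` groups several enum members into one arm, and `case _` turns an unrecognized mode into a loud AssertionError instead of a silent fall-through. A sketch with stand-in AppMode values:

from enum import StrEnum

class AppMode(StrEnum):  # values here are stand-ins for illustration
    CHAT = "chat"
    AGENT_CHAT = "agent-chat"
    COMPLETION = "completion"
    ADVANCED_CHAT = "advanced-chat"
    WORKFLOW = "workflow"

def config_source(mode: AppMode) -> str:
    match mode:
        case AppMode.AGENT_CHAT | AppMode.COMPLETION | AppMode.CHAT:
            return "model_config"
        case AppMode.ADVANCED_CHAT | AppMode.WORKFLOW:
            return "workflow_features"
        case _:
            raise AssertionError(f"Invalid app mode: {mode}")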

View File

@ -6,7 +6,7 @@ from graphon.model_runtime.callbacks.base_callback import Callback
from graphon.model_runtime.entities.llm_entities import LLMResult
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType
from graphon.model_runtime.entities.rerank_entities import RerankResult
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@ -172,10 +172,10 @@ class ModelInstance:
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
prompt_messages=prompt_messages,
prompt_messages=list(prompt_messages),
model_parameters=model_parameters,
tools=tools,
stop=stop,
tools=list(tools) if tools else None,
stop=list(stop) if stop else None,
stream=stream,
callbacks=callbacks,
),
@ -193,15 +193,12 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")
return cast(
int,
self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model_name,
credentials=self.credentials,
prompt_messages=prompt_messages,
tools=tools,
),
return self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model_name,
credentials=self.credentials,
prompt_messages=list(prompt_messages),
tools=list(tools) if tools else None,
)
def invoke_text_embedding(
@ -216,15 +213,12 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return cast(
EmbeddingResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
texts=texts,
input_type=input_type,
),
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
texts=texts,
input_type=input_type,
)
def invoke_multimodal_embedding(
@ -241,15 +235,12 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return cast(
EmbeddingResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
multimodel_documents=multimodel_documents,
input_type=input_type,
),
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
multimodel_documents=multimodel_documents,
input_type=input_type,
)
def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]:
@ -261,14 +252,11 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return cast(
list[int],
self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model_name,
credentials=self.credentials,
texts=texts,
),
return self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model_name,
credentials=self.credentials,
texts=texts,
)
def invoke_rerank(
@ -289,23 +277,20 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
return cast(
RerankResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
),
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
)
def invoke_multimodal_rerank(
self,
query: dict,
docs: list[dict],
query: MultimodalRerankInput,
docs: list[MultimodalRerankInput],
score_threshold: float | None = None,
top_n: int | None = None,
) -> RerankResult:
@ -320,17 +305,14 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
return cast(
RerankResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke_multimodal_rerank,
model=self.model_name,
credentials=self.credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
),
return self._round_robin_invoke(
function=self.model_type_instance.invoke_multimodal_rerank,
model=self.model_name,
credentials=self.credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
)
def invoke_moderation(self, text: str) -> bool:
@ -342,14 +324,11 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, ModerationModel):
raise Exception("Model type instance is not ModerationModel")
return cast(
bool,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
text=text,
),
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
text=text,
)
def invoke_speech2text(self, file: IO[bytes]) -> str:
@ -361,14 +340,11 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, Speech2TextModel):
raise Exception("Model type instance is not Speech2TextModel")
return cast(
str,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
file=file,
),
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
file=file,
)
def invoke_tts(self, content_text: str, voice: str = "") -> Iterable[bytes]:
@ -381,18 +357,15 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
return cast(
Iterable[bytes],
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
content_text=content_text,
voice=voice,
),
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
content_text=content_text,
voice=voice,
)
def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs):
def _round_robin_invoke[**P, R](self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
"""
Round-robin invoke
:param function: function to invoke
@ -430,9 +403,8 @@ class ModelInstance:
continue
try:
if "credentials" in kwargs:
del kwargs["credentials"]
return function(*args, **kwargs, credentials=lb_config.credentials)
kwargs["credentials"] = lb_config.credentials
return function(*args, **kwargs)
except InvokeRateLimitError as e:
# expire in 60 seconds
self.load_balancing_manager.cooldown(lb_config, expire=60)
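Two related changes run through this file: `_round_robin_invoke` becomes generic over the callee's signature, which is what lets every `invoke_*` wrapper above drop its `cast(...)`, and the load-balancing credentials are now assigned via `kwargs["credentials"] = ...` instead of deleting the key and re-passing it, the same effect, since a keyword argument cannot be supplied twice. A reduced sketch of the cast removal:

from collections.abc import Callable

def round_robin_invoke[**P, R](function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
    # cooldown/failover logic elided -- the point is that R is inferred
    return function(*args, **kwargs)

def get_num_tokens(text: str) -> int:
    return len(text.split())

count: int = round_robin_invoke(get_num_tokens, "hello world")  # no cast() needed
print(count)  # 2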

View File

@ -1,6 +1,6 @@
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Union, cast
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -207,7 +207,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
)
@classmethod
def _get_user(cls, user_id: str) -> Union[EndUser, Account]:
def _get_user(cls, user_id: str) -> EndUser | Account:
"""
get the user by user id
"""

View File

@ -2,7 +2,7 @@ import json
import os
from collections.abc import Mapping, Sequence
from enum import StrEnum, auto
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, TypedDict, cast
from graphon.file import file_manager
from graphon.model_runtime.entities.message_entities import (
@ -34,6 +34,13 @@ class ModelMode(StrEnum):
prompt_file_contents: dict[str, Any] = {}
class PromptTemplateConfigDict(TypedDict):
prompt_template: PromptTemplateParser
custom_variable_keys: list[str]
special_variable_keys: list[str]
prompt_rules: dict[str, Any]
class SimplePromptTransform(PromptTransform):
"""
Simple Prompt Transform for Chatbot App Basic Mode.
@ -105,18 +112,13 @@ class SimplePromptTransform(PromptTransform):
with_memory_prompt=histories is not None,
)
custom_variable_keys_obj = prompt_template_config["custom_variable_keys"]
special_variable_keys_obj = prompt_template_config["special_variable_keys"]
custom_variable_keys = prompt_template_config["custom_variable_keys"]
if not isinstance(custom_variable_keys, list):
raise TypeError(f"Expected list for custom_variable_keys, got {type(custom_variable_keys)}")
# Type check for custom_variable_keys
if not isinstance(custom_variable_keys_obj, list):
raise TypeError(f"Expected list for custom_variable_keys, got {type(custom_variable_keys_obj)}")
custom_variable_keys = cast(list[str], custom_variable_keys_obj)
# Type check for special_variable_keys
if not isinstance(special_variable_keys_obj, list):
raise TypeError(f"Expected list for special_variable_keys, got {type(special_variable_keys_obj)}")
special_variable_keys = cast(list[str], special_variable_keys_obj)
special_variable_keys = prompt_template_config["special_variable_keys"]
if not isinstance(special_variable_keys, list):
raise TypeError(f"Expected list for special_variable_keys, got {type(special_variable_keys)}")
variables = {k: inputs[k] for k in custom_variable_keys if k in inputs}
@ -150,7 +152,7 @@ class SimplePromptTransform(PromptTransform):
has_context: bool,
query_in_prompt: bool,
with_memory_prompt: bool = False,
) -> dict[str, object]:
) -> PromptTemplateConfigDict:
prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model)
custom_variable_keys: list[str] = []
@ -173,12 +175,13 @@ class SimplePromptTransform(PromptTransform):
prompt += prompt_rules.get("query_prompt", "{{#query#}}")
special_variable_keys.append("#query#")
return {
result: PromptTemplateConfigDict = {
"prompt_template": PromptTemplateParser(template=prompt),
"custom_variable_keys": custom_variable_keys,
"special_variable_keys": special_variable_keys,
"prompt_rules": prompt_rules,
}
return result
def _get_chat_model_prompt_messages(
self,

View File

@ -5,6 +5,7 @@ from typing import Any
from elasticsearch import Elasticsearch
from pydantic import BaseModel, model_validator
from typing_extensions import TypedDict
from configs import dify_config
from core.rag.datasource.vdb.field import Field
@ -19,6 +20,16 @@ from models.dataset import Dataset
logger = logging.getLogger(__name__)
class HuaweiElasticsearchParamsDict(TypedDict, total=False):
hosts: list[str]
verify_certs: bool
ssl_show_warn: bool
request_timeout: int
retry_on_timeout: bool
max_retries: int
basic_auth: tuple[str, str]
def create_ssl_context() -> ssl.SSLContext:
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
@ -38,15 +49,15 @@ class HuaweiCloudVectorConfig(BaseModel):
raise ValueError("config HOSTS is required")
return values
def to_elasticsearch_params(self) -> dict[str, Any]:
params = {
"hosts": self.hosts.split(","),
"verify_certs": False,
"ssl_show_warn": False,
"request_timeout": 30000,
"retry_on_timeout": True,
"max_retries": 10,
}
def to_elasticsearch_params(self) -> HuaweiElasticsearchParamsDict:
params = HuaweiElasticsearchParamsDict(
hosts=self.hosts.split(","),
verify_certs=False,
ssl_show_warn=False,
request_timeout=30000,
retry_on_timeout=True,
max_retries=10,
)
if self.username and self.password:
params["basic_auth"] = (self.username, self.password)
return params
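The connection-params TypedDicts here and in the next two files share one idea: `total=False` makes every key optional, so the baseline options are set at construction and the auth tuple is added only when credentials exist, while the checker still rejects misspelled keys or wrong value types. A runnable sketch of the pattern (class and field names are illustrative):

from typing_extensions import TypedDict

class ESParamsDict(TypedDict, total=False):
    hosts: list[str]
    verify_certs: bool
    basic_auth: tuple[str, str]

def to_params(hosts: str, username: str | None, password: str | None) -> ESParamsDict:
    params = ESParamsDict(hosts=hosts.split(","), verify_certs=False)  # call syntax builds a plain dict
    if username and password:
        params["basic_auth"] = (username, password)
    return params

print(to_params("es1:9200,es2:9200", "admin", "secret"))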

View File

@ -7,6 +7,7 @@ from opensearchpy import OpenSearch, helpers
from opensearchpy.helpers import BulkIndexError
from pydantic import BaseModel, model_validator
from tenacity import retry, stop_after_attempt, wait_exponential
from typing_extensions import TypedDict
from configs import dify_config
from core.rag.datasource.vdb.field import Field
@ -26,6 +27,14 @@ ROUTING_FIELD = "routing_field"
UGC_INDEX_PREFIX = "ugc_index"
class LindormOpenSearchParamsDict(TypedDict, total=False):
hosts: str | None
use_ssl: bool
pool_maxsize: int
timeout: int
http_auth: tuple[str, str]
class LindormVectorStoreConfig(BaseModel):
hosts: str | None
username: str | None = None
@ -44,13 +53,13 @@ class LindormVectorStoreConfig(BaseModel):
raise ValueError("config PASSWORD is required")
return values
def to_opensearch_params(self) -> dict[str, Any]:
params: dict[str, Any] = {
"hosts": self.hosts,
"use_ssl": False,
"pool_maxsize": 128,
"timeout": 30,
}
def to_opensearch_params(self) -> LindormOpenSearchParamsDict:
params = LindormOpenSearchParamsDict(
hosts=self.hosts,
use_ssl=False,
pool_maxsize=128,
timeout=30,
)
if self.username and self.password:
params["http_auth"] = (self.username, self.password)
return params

View File

@ -6,6 +6,7 @@ from uuid import uuid4
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
from opensearchpy.helpers import BulkIndexError
from pydantic import BaseModel, model_validator
from typing_extensions import TypedDict
from configs import dify_config
from configs.middleware.vdb.opensearch_config import AuthMethod
@ -21,6 +22,20 @@ from models.dataset import Dataset
logger = logging.getLogger(__name__)
class _OpenSearchHostDict(TypedDict):
host: str
port: int
class OpenSearchParamsDict(TypedDict, total=False):
hosts: list[_OpenSearchHostDict]
use_ssl: bool
verify_certs: bool
connection_class: type
pool_maxsize: int
http_auth: tuple[str | None, str | None] | Urllib3AWSV4SignerAuth
class OpenSearchConfig(BaseModel):
host: str
port: int
@ -57,14 +72,14 @@ class OpenSearchConfig(BaseModel):
service=self.aws_service, # type: ignore[arg-type]
)
def to_opensearch_params(self) -> dict[str, Any]:
params = {
"hosts": [{"host": self.host, "port": self.port}],
"use_ssl": self.secure,
"verify_certs": self.verify_certs,
"connection_class": Urllib3HttpConnection,
"pool_maxsize": 20,
}
def to_opensearch_params(self) -> OpenSearchParamsDict:
params = OpenSearchParamsDict(
hosts=[{"host": self.host, "port": self.port}],
use_ssl=self.secure,
verify_certs=self.verify_certs,
connection_class=Urllib3HttpConnection,
pool_maxsize=20,
)
if self.auth_method == "basic":
logger.info("Using basic authentication for OpenSearch Vector DB")

View File

@ -7,7 +7,7 @@ from pydantic import BaseModel, model_validator
from sqlalchemy import Column, String, Table, create_engine, insert
from sqlalchemy import text as sql_text
from sqlalchemy.dialects.postgresql import JSON, TEXT
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
@ -79,7 +79,7 @@ class RelytVector(BaseVector):
if redis_client.get(collection_exist_cache_key):
return
index_name = f"{self._collection_name}_embedding_index"
with Session(self.client) as session:
with sessionmaker(bind=self.client).begin() as session:
drop_statement = sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}"; """)
session.execute(drop_statement)
create_statement = sql_text(f"""
@ -104,7 +104,6 @@ class RelytVector(BaseVector):
$$);
""")
session.execute(index_statement)
session.commit()
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
@ -208,9 +207,8 @@ class RelytVector(BaseVector):
self.delete_by_uuids(ids)
def delete(self):
with Session(self.client) as session:
with sessionmaker(bind=self.client).begin() as session:
session.execute(sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}";"""))
session.commit()
def text_exists(self, id: str) -> bool:
with Session(self.client) as session:
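The Session(...) / session.commit() pairs in this file and the next become sessionmaker(bind=...).begin(), which opens a session and a transaction in one context manager: commit on clean exit, rollback on exception, hence the deleted commit() calls. A runnable sketch against in-memory SQLite:

from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

engine = create_engine("sqlite:///:memory:")

with sessionmaker(bind=engine).begin() as session:
    session.execute(text("CREATE TABLE t (id INTEGER PRIMARY KEY)"))
    session.execute(text("INSERT INTO t (id) VALUES (1)"))
# committed here; an exception inside the block would have rolled back instead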

View File

@ -6,7 +6,7 @@ import sqlalchemy
from pydantic import BaseModel, model_validator
from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
from sqlalchemy import text as sql_text
from sqlalchemy.orm import Session, declarative_base
from sqlalchemy.orm import Session, declarative_base, sessionmaker
from configs import dify_config
from core.rag.datasource.vdb.field import Field, parse_metadata_json
@ -97,8 +97,7 @@ class TiDBVector(BaseVector):
if redis_client.get(collection_exist_cache_key):
return
tidb_dist_func = self._get_distance_func()
with Session(self._engine) as session:
session.begin()
with sessionmaker(bind=self._engine).begin() as session:
create_statement = sql_text(f"""
CREATE TABLE IF NOT EXISTS {self._collection_name} (
id CHAR(36) PRIMARY KEY,
@ -115,7 +114,6 @@ class TiDBVector(BaseVector):
);
""")
session.execute(create_statement)
session.commit()
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
@ -238,9 +236,8 @@ class TiDBVector(BaseVector):
return []
def delete(self):
with Session(self._engine) as session:
with sessionmaker(bind=self._engine).begin() as session:
session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))
session.commit()
def _get_distance_func(self) -> str:
match self._distance_func:

View File

@ -41,7 +41,23 @@ class AbstractVectorFactory(ABC):
class Vector:
def __init__(self, dataset: Dataset, attributes: list | None = None):
if attributes is None:
attributes = ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"]
# `is_summary` and `original_chunk_id` are stored on summary vectors
# by `SummaryIndexService` and read back by `RetrievalService` to
# route summary hits through their original parent chunks. They
# must be listed here so vector backends that use this list as an
# explicit return-properties projection (notably Weaviate) actually
# return those fields; without them, summary hits silently
# collapse into `is_summary = False` branches and the summary
# retrieval path is a no-op. See #34884.
attributes = [
"doc_id",
"dataset_id",
"document_id",
"doc_hash",
"doc_type",
"is_summary",
"original_chunk_id",
]
self._dataset = dataset
self._embeddings = self._get_embeddings()
self._attributes = attributes
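A tiny illustration of the routing the comment above describes, with hypothetical hit dicts (the services themselves are outside this diff): once `is_summary` and `original_chunk_id` come back from the vector store, a summary hit can be swapped for its original parent chunk:

hits = [
    {"doc_id": "sum-1", "is_summary": True, "original_chunk_id": "chunk-9", "score": 0.92},
    {"doc_id": "chunk-2", "is_summary": False, "original_chunk_id": None, "score": 0.88},
]

resolved = [h["original_chunk_id"] if h["is_summary"] else h["doc_id"] for h in hits]
print(resolved)  # ['chunk-9', 'chunk-2']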

View File

@ -6,7 +6,7 @@ from collections.abc import Mapping
from typing import Any
from flask import current_app
from sqlalchemy import delete, func, select
from sqlalchemy import delete, func, select, update
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.index_type import IndexTechniqueType
@ -63,11 +63,11 @@ class IndexProcessor:
summary_index_setting: SummaryIndexSettingDict | None = None,
) -> IndexingResultDict:
with session_factory.create_session() as session:
document = session.query(Document).filter_by(id=document_id).first()
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
if not document:
raise KnowledgeIndexNodeError(f"Document {document_id} not found.")
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset:
raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.")
@ -104,12 +104,12 @@ class IndexProcessor:
document.indexing_status = "completed"
document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
document.word_count = (
session.query(func.sum(DocumentSegment.word_count))
.where(
DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id,
session.scalar(
select(func.sum(DocumentSegment.word_count)).where(
DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id,
)
)
.scalar()
) or 0
# Update need_summary based on dataset's summary_index_setting
if summary_index_setting and summary_index_setting.get("enable") is True:
@ -118,15 +118,17 @@ class IndexProcessor:
document.need_summary = False
session.add(document)
# update document segment status
session.query(DocumentSegment).where(
DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id,
).update(
{
DocumentSegment.status: "completed",
DocumentSegment.enabled: True,
DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
}
session.execute(
update(DocumentSegment)
.where(
DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id,
)
.values(
status="completed",
enabled=True,
completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
)
)
result: IndexingResultDict = {
@ -151,11 +153,11 @@ class IndexProcessor:
doc_language = None
with session_factory.create_session() as session:
if document_id:
document = session.query(Document).filter_by(id=document_id).first()
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
else:
document = None
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset:
raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.")
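This file (and DatasetRetrieval below) migrates from the legacy Query API to SQLAlchemy 2.0 style: session.scalar(select(...).limit(1)) replaces session.query(...).first(), and session.execute(update(...).values(...)) replaces Query.update({...}). A self-contained sketch with an assumed minimal model:

from sqlalchemy import String, create_engine, select, update
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker

class Base(DeclarativeBase):
    pass

class Document(Base):
    __tablename__ = "documents"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    status: Mapped[str] = mapped_column(String, default="waiting")

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with sessionmaker(bind=engine).begin() as session:
    session.add(Document(id="d1"))
    doc = session.scalar(select(Document).where(Document.id == "d1").limit(1))  # first() equivalent
    assert doc is not None
    session.execute(update(Document).where(Document.id == "d1").values(status="completed"))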

View File

@ -3,8 +3,7 @@
import logging
import re
import uuid
from collections.abc import Mapping
from typing import Any, cast
from typing import Any, TypedDict, cast
logger = logging.getLogger(__name__)
@ -55,6 +54,12 @@ from services.summary_index_service import SummaryIndexService
_file_access_controller = DatabaseFileAccessController()
class ParagraphFormatPreviewDict(TypedDict):
chunk_structure: str
preview: list[dict[str, Any]]
total_segments: int
class ParagraphIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract(
@ -266,16 +271,17 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
keyword = Keyword(dataset)
keyword.add_texts(documents)
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
def format_preview(self, chunks: Any) -> ParagraphFormatPreviewDict:
if isinstance(chunks, list):
preview = []
for content in chunks:
preview.append({"content": content})
return {
result: ParagraphFormatPreviewDict = {
"chunk_structure": IndexStructureType.PARAGRAPH_INDEX,
"preview": preview,
"total_segments": len(chunks),
}
return result
else:
raise ValueError("Chunks is not a list")

View File

@ -3,8 +3,7 @@
import json
import logging
import uuid
from collections.abc import Mapping
from typing import Any
from typing import Any, TypedDict
from sqlalchemy import delete, select
@ -36,6 +35,13 @@ from services.summary_index_service import SummaryIndexService
logger = logging.getLogger(__name__)
class ParentChildFormatPreviewDict(TypedDict):
chunk_structure: str
parent_mode: str
preview: list[dict[str, Any]]
total_segments: int
class ParentChildIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract(
@ -153,14 +159,12 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
if node_ids:
# Find segments by index_node_id
with session_factory.create_session() as session:
segments = (
session.query(DocumentSegment)
.filter(
segments = session.scalars(
select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
)
.all()
)
).all()
segment_ids = [segment.id for segment in segments]
if segment_ids:
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
@ -351,17 +355,18 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
if all_multimodal_documents and dataset.is_multimodal:
vector.create_multimodal(all_multimodal_documents)
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
def format_preview(self, chunks: Any) -> ParentChildFormatPreviewDict:
parent_childs = ParentChildStructureChunk.model_validate(chunks)
preview = []
for parent_child in parent_childs.parent_child_chunks:
preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents})
return {
result: ParentChildFormatPreviewDict = {
"chunk_structure": IndexStructureType.PARENT_CHILD_INDEX,
"parent_mode": parent_childs.parent_mode,
"preview": preview,
"total_segments": len(parent_childs.parent_child_chunks),
}
return result
def generate_summary_preview(
self,

View File

@ -4,11 +4,11 @@ import logging
import re
import threading
import uuid
from collections.abc import Mapping
from typing import Any
from typing import Any, TypedDict
import pandas as pd
from flask import Flask, current_app
from sqlalchemy import select
from werkzeug.datastructures import FileStorage
from core.db.session_factory import session_factory
@ -36,6 +36,12 @@ from services.summary_index_service import SummaryIndexService
logger = logging.getLogger(__name__)
class QAFormatPreviewDict(TypedDict):
chunk_structure: str
qa_preview: list[dict[str, Any]]
total_segments: int
class QAIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract(
@ -158,14 +164,12 @@ class QAIndexProcessor(BaseIndexProcessor):
if node_ids:
# Find segments by index_node_id
with session_factory.create_session() as session:
segments = (
session.query(DocumentSegment)
.filter(
segments = session.scalars(
select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
)
.all()
)
).all()
segment_ids = [segment.id for segment in segments]
if segment_ids:
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
@ -230,16 +234,17 @@ class QAIndexProcessor(BaseIndexProcessor):
else:
raise ValueError("Indexing technique must be high quality.")
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
def format_preview(self, chunks: Any) -> QAFormatPreviewDict:
qa_chunks = QAStructureChunk.model_validate(chunks)
preview = []
for qa_chunk in qa_chunks.qa_chunks:
preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer})
return {
result: QAFormatPreviewDict = {
"chunk_structure": IndexStructureType.QA_INDEX,
"qa_preview": preview,
"total_segments": len(qa_chunks.qa_chunks),
}
return result
def generate_summary_preview(
self,

View File

@ -1,7 +1,7 @@
import base64
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.entities.rerank_entities import RerankResult
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from core.model_manager import ModelInstance, ModelManager
from core.rag.index_processor.constant.doc_type import DocType
@ -123,7 +123,7 @@ class RerankModelRunner(BaseRerankRunner):
:param query_type: query type
:return: rerank result
"""
docs = []
docs: list[MultimodalRerankInput] = []
doc_ids = set()
unique_documents = []
for document in documents:
@ -138,26 +138,28 @@ class RerankModelRunner(BaseRerankRunner):
if upload_file:
blob = storage.load_once(upload_file.key)
document_file_base64 = base64.b64encode(blob).decode()
document_file_dict = {
"content": document_file_base64,
"content_type": document.metadata["doc_type"],
}
docs.append(document_file_dict)
docs.append(
MultimodalRerankInput(
content=document_file_base64,
content_type=document.metadata["doc_type"],
)
)
else:
document_text_dict = {
"content": document.page_content,
"content_type": document.metadata.get("doc_type") or DocType.TEXT,
}
docs.append(document_text_dict)
docs.append(
MultimodalRerankInput(
content=document.page_content,
content_type=document.metadata.get("doc_type") or DocType.TEXT,
)
)
doc_ids.add(document.metadata["doc_id"])
unique_documents.append(document)
elif document.provider == "external":
if document not in unique_documents:
docs.append(
{
"content": document.page_content,
"content_type": document.metadata.get("doc_type") or DocType.TEXT,
}
MultimodalRerankInput(
content=document.page_content,
content_type=document.metadata.get("doc_type") or DocType.TEXT,
)
)
unique_documents.append(document)
@ -171,12 +173,12 @@ class RerankModelRunner(BaseRerankRunner):
if upload_file:
blob = storage.load_once(upload_file.key)
file_query = base64.b64encode(blob).decode()
file_query_dict = {
"content": file_query,
"content_type": DocType.IMAGE,
}
file_query_input = MultimodalRerankInput(
content=file_query,
content_type=DocType.IMAGE,
)
rerank_result = self.rerank_model_instance.invoke_multimodal_rerank(
query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n
query=file_query_input, docs=docs, score_threshold=score_threshold, top_n=top_n
)
return rerank_result, unique_documents
else:

View File

@ -14,7 +14,7 @@ from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMU
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from sqlalchemy import and_, func, literal, or_, select
from sqlalchemy import and_, func, literal, or_, select, update
from sqlalchemy.orm import sessionmaker
from core.app.app_config.entities import (
@ -276,8 +276,8 @@ class DatasetRetrieval:
document_ids = [i.segment.document_id for i in records]
with session_factory.create_session() as session:
datasets = session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all()
documents = session.query(DatasetDocument).where(DatasetDocument.id.in_(document_ids)).all()
datasets = session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all()
documents = session.scalars(select(DatasetDocument).where(DatasetDocument.id.in_(document_ids))).all()
dataset_map = {i.id: i for i in datasets}
document_map = {i.id: i for i in documents}
@ -971,9 +971,11 @@ class DatasetRetrieval:
# Batch update hit_count for all segments
if segment_ids_to_update:
session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids_to_update)).update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False,
session.execute(
update(DocumentSegment)
.where(DocumentSegment.id.in_(segment_ids_to_update))
.values(hit_count=DocumentSegment.hit_count + 1)
.execution_options(synchronize_session=False)
)
self._send_trace_task(message_id, documents, timer)
@ -1822,7 +1824,7 @@ class DatasetRetrieval:
def _get_available_datasets(self, tenant_id: str, dataset_ids: list[str]) -> list[Dataset]:
with session_factory.create_session() as session:
subquery = (
session.query(DocumentModel.dataset_id, func.count(DocumentModel.id).label("available_document_count"))
select(DocumentModel.dataset_id, func.count(DocumentModel.id).label("available_document_count"))
.where(
DocumentModel.indexing_status == "completed",
DocumentModel.enabled == True,
@ -1834,13 +1836,12 @@ class DatasetRetrieval:
.subquery()
)
results = (
session.query(Dataset)
results = session.scalars(
select(Dataset)
.outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
.where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
.where((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
.all()
)
).all()
available_datasets = []
for dataset in results:

View File

@ -1,6 +1,8 @@
import concurrent.futures
import logging
from sqlalchemy import select
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
@ -21,7 +23,7 @@ class SummaryIndex:
) -> None:
if is_preview:
with session_factory.create_session() as session:
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset or dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
return
@ -34,32 +36,31 @@ class SummaryIndex:
if not document_id:
return
document = session.query(Document).filter_by(id=document_id).first()
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
# Skip qa_model documents
if document is None or document.doc_form == "qa_model":
return
query = session.query(DocumentSegment).filter_by(
dataset_id=dataset_id,
document_id=document_id,
status="completed",
enabled=True,
)
segments = query.all()
segments = session.scalars(
select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
)
).all()
segment_ids = [segment.id for segment in segments]
if not segment_ids:
return
existing_summaries = (
session.query(DocumentSegmentSummary)
.filter(
existing_summaries = session.scalars(
select(DocumentSegmentSummary).where(
DocumentSegmentSummary.chunk_id.in_(segment_ids),
DocumentSegmentSummary.dataset_id == dataset_id,
DocumentSegmentSummary.status == "completed",
)
.all()
)
).all()
completed_summary_segment_ids = {i.chunk_id for i in existing_summaries}
# Preview mode should process segments that are MISSING completed summaries
pending_segment_ids = [sid for sid in segment_ids if sid not in completed_summary_segment_ids]
@ -73,7 +74,7 @@ class SummaryIndex:
def process_segment(segment_id: str) -> None:
"""Process a single segment in a thread with a fresh DB session."""
with session_factory.create_session() as session:
segment = session.query(DocumentSegment).filter_by(id=segment_id).first()
segment = session.scalar(select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1))
if segment is None:
return
try:

View File

@ -450,6 +450,12 @@ class WorkflowToolParameterConfiguration(BaseModel):
form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter")
class ToolInvokeMetaDict(TypedDict):
time_cost: float
error: str | None
tool_config: dict[str, Any] | None
class ToolInvokeMeta(BaseModel):
"""
Tool invoke meta
@ -473,12 +479,13 @@ class ToolInvokeMeta(BaseModel):
"""
return cls(time_cost=0.0, error=error, tool_config={})
def to_dict(self):
return {
def to_dict(self) -> ToolInvokeMetaDict:
result: ToolInvokeMetaDict = {
"time_cost": self.time_cost,
"error": self.error,
"tool_config": self.tool_config,
}
return result
class ToolLabel(BaseModel):

View File

@ -262,6 +262,8 @@ class ToolEngine:
ensure_ascii=False,
)
)
elif response.type == ToolInvokeMessage.MessageType.VARIABLE:
continue
else:
parts.append(str(response.message))

View File

@ -6,7 +6,6 @@ import os
import time
from collections.abc import Generator
from mimetypes import guess_extension, guess_type
from typing import Union
from uuid import uuid4
import httpx
@ -158,7 +157,7 @@ class ToolFileManager:
return tool_file
def get_file_binary(self, id: str) -> Union[tuple[bytes, str], None]:
def get_file_binary(self, id: str) -> tuple[bytes, str] | None:
"""
get file binary
@ -176,7 +175,7 @@ class ToolFileManager:
return blob, tool_file.mimetype
def get_file_binary_by_message_file_id(self, id: str) -> Union[tuple[bytes, str], None]:
def get_file_binary_by_message_file_id(self, id: str) -> tuple[bytes, str] | None:
"""
get file binary

View File

@ -5,7 +5,7 @@ import time
from collections.abc import Generator, Mapping
from os import listdir, path
from threading import Lock
from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, Union, cast
from typing import TYPE_CHECKING, Any, Literal, Protocol, cast
import sqlalchemy as sa
from graphon.runtime import VariablePool
@ -100,7 +100,7 @@ class ToolManager:
_builtin_provider_lock = Lock()
_hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
_builtin_providers_loaded = False
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
_builtin_tools_labels: dict[str, I18nObject | None] = {}
@classmethod
def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
@ -190,7 +190,7 @@ class ToolManager:
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
credential_id: str | None = None,
) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
) -> BuiltinTool | PluginTool | ApiTool | WorkflowTool | MCPTool:
"""
get the tool runtime
@ -398,7 +398,7 @@ class ToolManager:
agent_tool: AgentToolEntity,
user_id: str | None = None,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
variable_pool: Optional["VariablePool"] = None,
variable_pool: "VariablePool | None" = None,
) -> Tool:
"""
get the agent tool runtime
@ -442,7 +442,7 @@ class ToolManager:
workflow_tool: WorkflowToolRuntimeSpec,
user_id: str | None = None,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
variable_pool: Optional["VariablePool"] = None,
variable_pool: "VariablePool | None" = None,
) -> Tool:
"""
get the workflow tool runtime
@ -634,7 +634,7 @@ class ToolManager:
cls._builtin_providers_loaded = False
@classmethod
def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
def get_tool_label(cls, tool_name: str) -> I18nObject | None:
"""
get the tool label
@ -682,7 +682,7 @@ class ToolManager:
with Session(db.engine, autoflush=False) as session:
ids = [row.id for row in session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()]
return session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()
return list(session.scalars(select(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids))))
@classmethod
def list_providers_from_api(
@ -993,7 +993,7 @@ class ToolManager:
return {"background": "#252525", "content": "\ud83d\ude01"}
@classmethod
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | dict[str, str] | str:
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | str:
try:
with Session(db.engine) as session:
mcp_service = MCPToolManageService(session=session)
@ -1001,7 +1001,7 @@ class ToolManager:
mcp_provider = mcp_service.get_provider_entity(
provider_id=provider_id, tenant_id=tenant_id, by_server_id=True
)
return mcp_provider.provider_icon
return cast(EmojiIconDict | str, mcp_provider.provider_icon)
except ValueError:
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
except Exception:
@ -1013,7 +1013,7 @@ class ToolManager:
tenant_id: str,
provider_type: ToolProviderType,
provider_id: str,
) -> str | EmojiIconDict | dict[str, str]:
) -> str | EmojiIconDict:
"""
get the tool icon
@ -1052,7 +1052,7 @@ class ToolManager:
def _convert_tool_parameters_type(
cls,
parameters: list[ToolParameter],
variable_pool: Optional["VariablePool"],
variable_pool: "VariablePool | None",
tool_configurations: Mapping[str, Any],
typ: Literal["agent", "workflow", "tool"] = "workflow",
) -> dict[str, Any]:

View File

@ -118,7 +118,8 @@ class ToolFileMessageTransformer:
if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
raise ValueError("unexpected message type")
assert isinstance(message.message.blob, bytes)
if not isinstance(message.message.blob, bytes):
raise TypeError(f"Expected blob to be bytes, got {type(message.message.blob).__name__}")
tool_file_manager = ToolFileManager()
tool_file = tool_file_manager.create_file_by_raw(
user_id=user_id,

View File

@ -17,10 +17,8 @@ class WorkflowToolConfigurationUtils:
"""
nodes = graph.get("nodes", [])
start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None)
if not start_node:
return []
return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])]
@classmethod

View File

@ -5,12 +5,30 @@ from typing import Any
import pytz # type: ignore[import-untyped]
from celery import Celery, Task
from celery.schedules import crontab
from typing_extensions import TypedDict
from configs import dify_config
from dify_app import DifyApp
def get_celery_ssl_options() -> dict[str, Any] | None:
class _CelerySentinelKwargsDict(TypedDict):
socket_timeout: float | None
password: str | None
class CelerySentinelTransportDict(TypedDict):
master_name: str | None
sentinel_kwargs: _CelerySentinelKwargsDict
class CelerySSLOptionsDict(TypedDict):
ssl_cert_reqs: int
ssl_ca_certs: str | None
ssl_certfile: str | None
ssl_keyfile: str | None
def get_celery_ssl_options() -> CelerySSLOptionsDict | None:
"""Get SSL configuration for Celery broker/backend connections."""
# Only apply SSL if we're using Redis as broker/backend
if not dify_config.BROKER_USE_SSL:
@ -33,26 +51,24 @@ def get_celery_ssl_options() -> dict[str, Any] | None:
ssl_cert_reqs = cert_reqs_map.get(dify_config.REDIS_SSL_CERT_REQS, ssl.CERT_NONE)
ssl_options = {
"ssl_cert_reqs": ssl_cert_reqs,
"ssl_ca_certs": dify_config.REDIS_SSL_CA_CERTS,
"ssl_certfile": dify_config.REDIS_SSL_CERTFILE,
"ssl_keyfile": dify_config.REDIS_SSL_KEYFILE,
}
return ssl_options
return CelerySSLOptionsDict(
ssl_cert_reqs=ssl_cert_reqs,
ssl_ca_certs=dify_config.REDIS_SSL_CA_CERTS,
ssl_certfile=dify_config.REDIS_SSL_CERTFILE,
ssl_keyfile=dify_config.REDIS_SSL_KEYFILE,
)
def get_celery_broker_transport_options() -> dict[str, Any]:
def get_celery_broker_transport_options() -> CelerySentinelTransportDict | dict[str, Any]:
"""Get broker transport options (e.g. Redis Sentinel) for Celery connections."""
if dify_config.CELERY_USE_SENTINEL:
return {
"master_name": dify_config.CELERY_SENTINEL_MASTER_NAME,
"sentinel_kwargs": {
"socket_timeout": dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT,
"password": dify_config.CELERY_SENTINEL_PASSWORD,
},
}
return CelerySentinelTransportDict(
master_name=dify_config.CELERY_SENTINEL_MASTER_NAME,
sentinel_kwargs=_CelerySentinelKwargsDict(
socket_timeout=dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT,
password=dify_config.CELERY_SENTINEL_PASSWORD,
),
)
return {}
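Same TypedDict-as-kwargs idea as the vector-store configs, with one addition: the dicts nest, and the call syntax composes, so Celery ultimately receives plain dicts at runtime. A sketch using the field names from the hunk:

from typing_extensions import TypedDict

class SentinelKwargsDict(TypedDict):
    socket_timeout: float | None
    password: str | None

class SentinelTransportDict(TypedDict):
    master_name: str | None
    sentinel_kwargs: SentinelKwargsDict

opts = SentinelTransportDict(
    master_name="mymaster",
    sentinel_kwargs=SentinelKwargsDict(socket_timeout=1.0, password=None),
)
print(opts["sentinel_kwargs"]["socket_timeout"])  # 1.0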

View File

@ -14,6 +14,7 @@ from redis.cluster import ClusterNode, RedisCluster
from redis.connection import Connection, SSLConnection
from redis.retry import Retry
from redis.sentinel import Sentinel
from typing_extensions import TypedDict
from configs import dify_config
from dify_app import DifyApp
@ -126,6 +127,41 @@ redis_client: RedisClientWrapper = RedisClientWrapper()
_pubsub_redis_client: redis.Redis | RedisCluster | None = None
class RedisSSLParamsDict(TypedDict):
ssl_cert_reqs: int
ssl_ca_certs: str | None
ssl_certfile: str | None
ssl_keyfile: str | None
class RedisHealthParamsDict(TypedDict):
retry: Retry
socket_timeout: float | None
socket_connect_timeout: float | None
health_check_interval: int | None
class RedisClusterHealthParamsDict(TypedDict):
retry: Retry
socket_timeout: float | None
socket_connect_timeout: float | None
class RedisBaseParamsDict(TypedDict):
username: str | None
password: str | None
db: int
encoding: str
encoding_errors: str
decode_responses: bool
protocol: int
cache_config: CacheConfig | None
retry: Retry
socket_timeout: float | None
socket_connect_timeout: float | None
health_check_interval: int | None
def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]:
"""Get SSL configuration for Redis connection."""
if not dify_config.REDIS_USE_SSL:
@ -171,17 +207,17 @@ def _get_retry_policy() -> Retry:
)
def _get_connection_health_params() -> dict[str, Any]:
def _get_connection_health_params() -> RedisHealthParamsDict:
"""Get connection health and retry parameters for standalone and Sentinel Redis clients."""
return {
"retry": _get_retry_policy(),
"socket_timeout": dify_config.REDIS_SOCKET_TIMEOUT,
"socket_connect_timeout": dify_config.REDIS_SOCKET_CONNECT_TIMEOUT,
"health_check_interval": dify_config.REDIS_HEALTH_CHECK_INTERVAL,
}
return RedisHealthParamsDict(
retry=_get_retry_policy(),
socket_timeout=dify_config.REDIS_SOCKET_TIMEOUT,
socket_connect_timeout=dify_config.REDIS_SOCKET_CONNECT_TIMEOUT,
health_check_interval=dify_config.REDIS_HEALTH_CHECK_INTERVAL,
)
def _get_cluster_connection_health_params() -> dict[str, Any]:
def _get_cluster_connection_health_params() -> RedisClusterHealthParamsDict:
"""Get retry and timeout parameters for Redis Cluster clients.
RedisCluster does not support ``health_check_interval`` as a constructor
@ -189,26 +225,31 @@ def _get_cluster_connection_health_params() -> dict[str, Any]:
here. Only ``retry``, ``socket_timeout``, and ``socket_connect_timeout``
are passed through.
"""
params = _get_connection_health_params()
return {k: v for k, v in params.items() if k != "health_check_interval"}
def _get_base_redis_params() -> dict[str, Any]:
"""Get base Redis connection parameters including retry and health policy."""
return {
"username": dify_config.REDIS_USERNAME,
"password": dify_config.REDIS_PASSWORD or None,
"db": dify_config.REDIS_DB,
"encoding": "utf-8",
"encoding_errors": "strict",
"decode_responses": False,
"protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL,
"cache_config": _get_cache_configuration(),
**_get_connection_health_params(),
health_params = _get_connection_health_params()
result: RedisClusterHealthParamsDict = {
"retry": health_params["retry"],
"socket_timeout": health_params["socket_timeout"],
"socket_connect_timeout": health_params["socket_connect_timeout"],
}
return result
def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]:
def _get_base_redis_params() -> RedisBaseParamsDict:
"""Get base Redis connection parameters including retry and health policy."""
return RedisBaseParamsDict(
username=dify_config.REDIS_USERNAME,
password=dify_config.REDIS_PASSWORD or None,
db=dify_config.REDIS_DB,
encoding="utf-8",
encoding_errors="strict",
decode_responses=False,
protocol=dify_config.REDIS_SERIALIZATION_PROTOCOL,
cache_config=_get_cache_configuration(),
**_get_connection_health_params(),
)
def _create_sentinel_client(redis_params: RedisBaseParamsDict) -> Union[redis.Redis, RedisCluster]:
"""Create Redis client using Sentinel configuration."""
if not dify_config.REDIS_SENTINELS:
raise ValueError("REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True")
@ -232,7 +273,8 @@ def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis,
sentinel_kwargs=sentinel_kwargs,
)
master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params)
params: dict[str, Any] = {**redis_params}
master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **params)
return master
@ -259,18 +301,16 @@ def _create_cluster_client() -> Union[redis.Redis, RedisCluster]:
return cluster
def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]:
def _create_standalone_client(redis_params: RedisBaseParamsDict) -> Union[redis.Redis, RedisCluster]:
"""Create standalone Redis client."""
connection_class, ssl_kwargs = _get_ssl_configuration()
params = {**redis_params}
params.update(
{
"host": dify_config.REDIS_HOST,
"port": dify_config.REDIS_PORT,
"connection_class": connection_class,
}
)
params: dict[str, Any] = {
**redis_params,
"host": dify_config.REDIS_HOST,
"port": dify_config.REDIS_PORT,
"connection_class": connection_class,
}
if dify_config.REDIS_MAX_CONNECTIONS:
params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS
@ -293,8 +333,8 @@ def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> redis.Redis |
kwargs["max_connections"] = max_conns
return RedisCluster.from_url(pubsub_url, **kwargs)
health_params = _get_connection_health_params()
kwargs = {**health_params}
standalone_health_params: dict[str, Any] = dict(_get_connection_health_params())
kwargs = {**standalone_health_params}
if max_conns:
kwargs["max_connections"] = max_conns
return redis.Redis.from_url(pubsub_url, **kwargs)
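The dict[str, Any] spreads added before sentinel.master_for() and in _create_standalone_client exist because a TypedDict rejects keys outside its schema: host, port, connection_class, and the conditional max_connections are not part of RedisBaseParamsDict, so the typed params are widened to a plain dict at the call boundary. A minimal sketch:

from typing import Any

from typing_extensions import TypedDict

class BaseParams(TypedDict):
    db: int
    decode_responses: bool

base = BaseParams(db=0, decode_responses=False)
# base["host"] = "localhost"  # rejected by the checker: "host" is not a BaseParams key
params: dict[str, Any] = {**base, "host": "localhost", "port": 6379}
params["max_connections"] = 50  # a plain dict accepts ad-hoc keys
print(sorted(params))  # ['db', 'decode_responses', 'host', 'max_connections', 'port']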

View File

@ -37,12 +37,7 @@ def trace_span[**P, R](handler_class: type[SpanHandler] | None = None) -> Callab
handler = _get_handler_instance(handler_class or SpanHandler)
tracer = get_tracer(__name__)
return handler.wrapper(
tracer=tracer,
wrapped=func,
args=args,
kwargs=kwargs,
)
return handler.wrapper(tracer, func, *args, **kwargs)
return cast(Callable[P, R], wrapper)

View File

@ -1,8 +1,8 @@
import inspect
from collections.abc import Callable, Mapping
from collections.abc import Callable
from typing import Any
from opentelemetry.trace import SpanKind, Status, StatusCode
from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer
class SpanHandler:
@ -16,9 +16,9 @@ class SpanHandler:
exceptions. Handlers can override the wrapper method to customize behavior.
"""
_signature_cache: dict[Callable[..., Any], inspect.Signature] = {}
_signature_cache: dict[Callable[..., object], inspect.Signature] = {}
def _build_span_name(self, wrapped: Callable[..., Any]) -> str:
def _build_span_name[**P, R](self, wrapped: Callable[P, R]) -> str:
"""
Build the span name from the wrapped function.
@ -29,11 +29,11 @@ class SpanHandler:
"""
return f"{wrapped.__module__}.{wrapped.__qualname__}"
def _extract_arguments[T](
def _extract_arguments[**P, R](
self,
wrapped: Callable[..., T],
args: tuple[object, ...],
kwargs: Mapping[str, object],
wrapped: Callable[P, R],
*args: P.args,
**kwargs: P.kwargs,
) -> dict[str, Any] | None:
"""
Extract function arguments using inspect.signature.
@ -59,13 +59,13 @@ class SpanHandler:
except Exception:
return None
def wrapper[T](
def wrapper[**P, R](
self,
tracer: Any,
wrapped: Callable[..., T],
args: tuple[object, ...],
kwargs: Mapping[str, object],
) -> T:
tracer: Tracer,
wrapped: Callable[P, R],
*args: P.args,
**kwargs: P.kwargs,
) -> R:
"""
Fully control the wrapper behavior.
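_extract_arguments (its body sits outside the hunks shown) presumably relies on inspect.signature to map positional and keyword arguments back to parameter names so the span can record them. A sketch of that mechanism under the new ParamSpec signature:

import inspect
from collections.abc import Callable

def extract_arguments[**P, R](wrapped: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> dict[str, object]:
    bound = inspect.signature(wrapped).bind(*args, **kwargs)
    bound.apply_defaults()  # fill in unpassed defaults too
    return dict(bound.arguments)

def generate(app_id: str, streaming: bool = True) -> None: ...

print(extract_arguments(generate, "app-123"))  # {'app_id': 'app-123', 'streaming': True}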

View File

@ -1,8 +1,7 @@
import logging
from collections.abc import Callable, Mapping
from typing import Any
from collections.abc import Callable
from opentelemetry.trace import SpanKind, Status, StatusCode
from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer
from opentelemetry.util.types import AttributeValue
from extensions.otel.decorators.handler import SpanHandler
@ -15,15 +14,15 @@ logger = logging.getLogger(__name__)
class AppGenerateHandler(SpanHandler):
"""Span handler for ``AppGenerateService.generate``."""
def wrapper[T](
def wrapper[**P, R](
self,
tracer: Any,
wrapped: Callable[..., T],
args: tuple[object, ...],
kwargs: Mapping[str, object],
) -> T:
tracer: Tracer,
wrapped: Callable[P, R],
*args: P.args,
**kwargs: P.kwargs,
) -> R:
try:
arguments = self._extract_arguments(wrapped, args, kwargs)
arguments = self._extract_arguments(wrapped, *args, **kwargs)
if not arguments:
return wrapped(*args, **kwargs)

View File

@ -1,8 +1,7 @@
import logging
from collections.abc import Callable, Mapping
from typing import Any
from collections.abc import Callable
from opentelemetry.trace import SpanKind, Status, StatusCode
from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer
from opentelemetry.util.types import AttributeValue
from extensions.otel.decorators.handler import SpanHandler
@ -14,15 +13,15 @@ logger = logging.getLogger(__name__)
class WorkflowAppRunnerHandler(SpanHandler):
"""Span handler for ``WorkflowAppRunner.run``."""
def wrapper(
def wrapper[**P, R](
self,
tracer: Any,
wrapped: Callable[..., Any],
args: tuple[Any, ...],
kwargs: Mapping[str, Any],
) -> Any:
tracer: Tracer,
wrapped: Callable[P, R],
*args: P.args,
**kwargs: P.kwargs,
) -> R:
try:
arguments = self._extract_arguments(wrapped, args, kwargs)
arguments = self._extract_arguments(wrapped, *args, **kwargs)
if not arguments:
return wrapped(*args, **kwargs)

View File

@ -14,9 +14,15 @@ from __future__ import annotations
import logging
import threading
from typing import Any
from typing import TYPE_CHECKING, Any
import redis
from redis.cluster import RedisCluster
from redis.exceptions import LockNotOwnedError, RedisError
from redis.lock import Lock
if TYPE_CHECKING:
from extensions.ext_redis import RedisClientWrapper
logger = logging.getLogger(__name__)
@ -38,21 +44,21 @@ class DbMigrationAutoRenewLock:
primary error/exit code.
"""
_redis_client: Any
_redis_client: redis.Redis | RedisCluster | RedisClientWrapper
_name: str
_ttl_seconds: float
_renew_interval_seconds: float
_log_context: str | None
_logger: logging.Logger
_lock: Any
_lock: Lock | None
_stop_event: threading.Event | None
_thread: threading.Thread | None
_acquired: bool
def __init__(
self,
redis_client: Any,
redis_client: redis.Redis | RedisCluster | RedisClientWrapper,
name: str,
ttl_seconds: float = 60,
renew_interval_seconds: float | None = None,
@ -127,7 +133,7 @@ class DbMigrationAutoRenewLock:
)
self._thread.start()
def _heartbeat_loop(self, lock: Any, stop_event: threading.Event) -> None:
def _heartbeat_loop(self, lock: Lock, stop_event: threading.Event) -> None:
while not stop_event.wait(self._renew_interval_seconds):
try:
lock.reacquire()

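A rough sketch of the heartbeat pattern the lock relies on, assuming a plain redis-py client; everything except redis.lock.Lock's own API is illustrative:

# Illustrative only: a daemon thread extends a redis-py Lock's TTL until
# stop_event is set, mirroring the _heartbeat_loop shown above.
import threading

import redis

def run_with_auto_renew(client: redis.Redis, name: str, ttl: float = 60.0) -> None:
    lock = client.lock(name, timeout=ttl)
    if not lock.acquire(blocking=False):
        return  # another process holds the migration lock
    stop_event = threading.Event()

    def heartbeat() -> None:
        # Renew at a fraction of the TTL so the lock cannot expire mid-work.
        while not stop_event.wait(ttl / 3):
            lock.reacquire()  # resets the TTL; raises if ownership was lost

    thread = threading.Thread(target=heartbeat, daemon=True)
    thread.start()
    try:
        ...  # long-running migration work goes here
    finally:
        stop_event.set()
        thread.join()
        lock.release()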
View File

@ -17,7 +17,6 @@ def http_status_message(code):
def register_external_error_handlers(api: Api):
@api.errorhandler(HTTPException)
def handle_http_exception(e: HTTPException):
got_request_exception.send(current_app, exception=e)
@ -74,27 +73,18 @@ def register_external_error_handlers(api: Api):
headers["Set-Cookie"] = build_force_logout_cookie_headers()
return data, status_code, headers
_ = handle_http_exception
@api.errorhandler(ValueError)
def handle_value_error(e: ValueError):
got_request_exception.send(current_app, exception=e)
status_code = 400
data = {"code": "invalid_param", "message": str(e), "status": status_code}
return data, status_code
_ = handle_value_error
@api.errorhandler(AppInvokeQuotaExceededError)
def handle_quota_exceeded(e: AppInvokeQuotaExceededError):
got_request_exception.send(current_app, exception=e)
status_code = 429
data = {"code": "too_many_requests", "message": str(e), "status": status_code}
return data, status_code
_ = handle_quota_exceeded
@api.errorhandler(Exception)
def handle_general_exception(e: Exception):
got_request_exception.send(current_app, exception=e)
@ -113,7 +103,10 @@ def register_external_error_handlers(api: Api):
return data, status_code
_ = handle_general_exception
api.errorhandler(HTTPException)(handle_http_exception)
api.errorhandler(ValueError)(handle_value_error)
api.errorhandler(AppInvokeQuotaExceededError)(handle_quota_exceeded)
api.errorhandler(Exception)(handle_general_exception)
class ExternalApi(Api):

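This hunk replaces decorator-time registration (which left each function name otherwise unused, hence the removed `_ = handler` assignments) with explicit registration calls. A toy equivalence, assuming Flask-RESTX's Api:

# Applying a decorator is just a call, so these two forms register the same
# handler; the explicit call keeps the function name referenced for linters.
from flask_restx import Api

api = Api()

def handle_value_error(e: ValueError):
    return {"code": "invalid_param", "message": str(e), "status": 400}, 400

# Decorator form:  @api.errorhandler(ValueError) above the def
# Explicit form:   register after the definitions, no placeholder assignments
api.errorhandler(ValueError)(handle_value_error)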
View File

@ -10,7 +10,7 @@ import uuid
from collections.abc import Callable, Generator, Mapping
from datetime import datetime
from hashlib import sha256
from typing import TYPE_CHECKING, Annotated, Any, Optional, Protocol, Union, cast
from typing import TYPE_CHECKING, Annotated, Any, Protocol, cast
from uuid import UUID
from zoneinfo import available_timezones
@ -81,7 +81,7 @@ def escape_like_pattern(pattern: str) -> str:
return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None:
def extract_tenant_id(user: "Account | EndUser") -> str | None:
"""
Extract tenant_id from Account or EndUser object.
@ -164,7 +164,10 @@ def email(email):
EmailStr = Annotated[str, AfterValidator(email)]
def uuid_value(value: Any) -> str:
def uuid_value(value: str | UUID) -> str:
if isinstance(value, UUID):
return str(value)
if value == "":
return str(value)
@ -405,7 +408,7 @@ class TokenManager:
def generate_token(
cls,
token_type: str,
account: Optional["Account"] = None,
account: "Account | None" = None,
email: str | None = None,
additional_data: dict | None = None,
) -> str:
@ -465,9 +468,7 @@ class TokenManager:
return current_token
@classmethod
def _set_current_token_for_account(
cls, account_id: str, token: str, token_type: str, expiry_minutes: Union[int, float]
):
def _set_current_token_for_account(cls, account_id: str, token: str, token_type: str, expiry_minutes: int | float):
key = cls._get_account_token_key(account_id, token_type)
expiry_seconds = int(expiry_minutes * 60)
redis_client.setex(key, expiry_seconds, token)

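The narrowed uuid_value contract in miniature (simplified: the real helper above also special-cases the empty string):

# str | UUID in, canonical string form out.
from uuid import UUID, uuid4

def uuid_value_sketch(value: str | UUID) -> str:
    if isinstance(value, UUID):
        return str(value)
    return str(UUID(value))  # validates the string by round-tripping

u = uuid4()
assert uuid_value_sketch(u) == uuid_value_sketch(str(u)) == str(u)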
View File

@ -0,0 +1,145 @@
"""Helpers for generating type-coverage summaries from pyrefly report output."""
from __future__ import annotations
import json
import sys
from pathlib import Path
from typing import TypedDict
class CoverageSummary(TypedDict):
n_modules: int
n_typable: int
n_typed: int
n_any: int
n_untyped: int
coverage: float
strict_coverage: float
_REQUIRED_KEYS = frozenset(CoverageSummary.__annotations__)
_EMPTY_SUMMARY: CoverageSummary = {
"n_modules": 0,
"n_typable": 0,
"n_typed": 0,
"n_any": 0,
"n_untyped": 0,
"coverage": 0.0,
"strict_coverage": 0.0,
}
def parse_summary(report_json: str) -> CoverageSummary:
"""Extract the summary section from ``pyrefly report`` JSON output.
Returns an empty summary when *report_json* is empty or malformed so that
the CI workflow can degrade gracefully instead of crashing.
"""
if not report_json or not report_json.strip():
return _EMPTY_SUMMARY.copy()
try:
data = json.loads(report_json)
except json.JSONDecodeError:
return _EMPTY_SUMMARY.copy()
summary = data.get("summary")
if not isinstance(summary, dict) or not _REQUIRED_KEYS.issubset(summary):
return _EMPTY_SUMMARY.copy()
return {
"n_modules": summary["n_modules"],
"n_typable": summary["n_typable"],
"n_typed": summary["n_typed"],
"n_any": summary["n_any"],
"n_untyped": summary["n_untyped"],
"coverage": summary["coverage"],
"strict_coverage": summary["strict_coverage"],
}
def format_summary_markdown(summary: CoverageSummary) -> str:
"""Format a single coverage summary as a Markdown table."""
return (
"| Metric | Value |\n"
"| --- | ---: |\n"
f"| Modules | {summary['n_modules']} |\n"
f"| Typable symbols | {summary['n_typable']:,} |\n"
f"| Typed symbols | {summary['n_typed']:,} |\n"
f"| Untyped symbols | {summary['n_untyped']:,} |\n"
f"| Any symbols | {summary['n_any']:,} |\n"
f"| **Type coverage** | **{summary['coverage']:.2f}%** |\n"
f"| Strict coverage | {summary['strict_coverage']:.2f}% |"
)
def format_comparison_markdown(
base: CoverageSummary,
pr: CoverageSummary,
) -> str:
"""Format a comparison between base and PR coverage as Markdown."""
coverage_delta = pr["coverage"] - base["coverage"]
strict_delta = pr["strict_coverage"] - base["strict_coverage"]
typed_delta = pr["n_typed"] - base["n_typed"]
untyped_delta = pr["n_untyped"] - base["n_untyped"]
def _fmt_delta(value: float, fmt: str = ".2f") -> str:
sign = "+" if value > 0 else ""
return f"{sign}{value:{fmt}}"
lines = [
"| Metric | Base | PR | Delta |",
"| --- | ---: | ---: | ---: |",
(f"| **Type coverage** | {base['coverage']:.2f}% | {pr['coverage']:.2f}% | {_fmt_delta(coverage_delta)}% |"),
(
f"| Strict coverage | {base['strict_coverage']:.2f}% "
f"| {pr['strict_coverage']:.2f}% "
f"| {_fmt_delta(strict_delta)}% |"
),
(f"| Typed symbols | {base['n_typed']:,} | {pr['n_typed']:,} | {_fmt_delta(typed_delta, ',')} |"),
(f"| Untyped symbols | {base['n_untyped']:,} | {pr['n_untyped']:,} | {_fmt_delta(untyped_delta, ',')} |"),
(
f"| Modules | {base['n_modules']} "
f"| {pr['n_modules']} "
f"| {_fmt_delta(pr['n_modules'] - base['n_modules'], ',')} |"
),
]
return "\n".join(lines)
def main() -> int:
"""Read pyrefly report JSON from stdin and print a Markdown summary.
Accepts an optional ``--base <file>`` argument. When provided, the output
includes a base-vs-PR comparison table.
"""
args = sys.argv[1:]
base_file: str | None = None
if "--base" in args:
idx = args.index("--base")
if idx + 1 >= len(args):
sys.stderr.write("error: --base requires a file path\n")
return 1
base_file = args[idx + 1]
pr_report = sys.stdin.read()
pr_summary = parse_summary(pr_report)
if base_file is not None:
base_text = Path(base_file).read_text() if Path(base_file).exists() else ""
base_summary = parse_summary(base_text)
sys.stdout.write(format_comparison_markdown(base_summary, pr_summary) + "\n")
else:
sys.stdout.write(format_summary_markdown(pr_summary) + "\n")
return 0
if __name__ == "__main__":
sys.exit(main())

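A smoke-test sketch for the coverage helper; the script path below is hypothetical, since this hunk does not show where the new file lives in the repo:

# Feed a fake pyrefly report over stdin and print the Markdown table.
import json
import subprocess

report = json.dumps({"summary": {
    "n_modules": 12, "n_typable": 1000, "n_typed": 900,
    "n_any": 40, "n_untyped": 100,
    "coverage": 90.0, "strict_coverage": 86.0,
}})
out = subprocess.run(
    ["python", "dev/pyrefly_coverage.py"],  # hypothetical path
    input=report, capture_output=True, text=True, check=True,
)
print(out.stdout)  # Markdown table ending in "**Type coverage** | **90.00%**"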
View File

@ -24,6 +24,8 @@ class TypeBase(MappedAsDataclass, DeclarativeBase):
class DefaultFieldsMixin:
"""Mixin for models that inherit from Base (non-dataclass)."""
id: Mapped[str] = mapped_column(
StringUUID,
primary_key=True,
@ -53,6 +55,42 @@ class DefaultFieldsMixin:
return f"<{self.__class__.__name__}(id={self.id})>"
class DefaultFieldsDCMixin(MappedAsDataclass):
"""Mixin for models that inherit from TypeBase (MappedAsDataclass)."""
__abstract__ = True
id: Mapped[str] = mapped_column(
StringUUID,
primary_key=True,
insert_default=lambda: str(uuidv7()),
default_factory=lambda: str(uuidv7()),
init=False,
)
created_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
insert_default=naive_utc_now,
default_factory=naive_utc_now,
init=False,
server_default=func.current_timestamp(),
)
updated_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
insert_default=naive_utc_now,
default_factory=naive_utc_now,
init=False,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}(id={self.id})>"
def gen_uuidv4_string() -> str:
"""gen_uuidv4_string generate a UUIDv4 string.

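A self-contained sketch of the dataclass-mixin idea above, simplified to a plain string primary key and default_factory only (the real mixin also sets insert_default and uses StringUUID/uuidv7):

import uuid
from datetime import datetime

import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column

class TypeBase(MappedAsDataclass, DeclarativeBase):
    pass

class Widget(TypeBase):
    __tablename__ = "widgets"

    # init=False keeps generated columns out of the dataclass constructor.
    id: Mapped[str] = mapped_column(
        sa.String(36), primary_key=True,
        default_factory=lambda: str(uuid.uuid4()), init=False,
    )
    name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
    created_at: Mapped[datetime] = mapped_column(
        sa.DateTime, nullable=False, default_factory=datetime.now, init=False,
    )

w = Widget(name="demo")  # id and created_at come from factories, not the caller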
View File

@ -108,6 +108,56 @@ class ExternalKnowledgeApiDict(TypedDict):
created_at: str
class DocumentDict(TypedDict):
id: str
tenant_id: str
dataset_id: str
position: int
data_source_type: str
data_source_info: str | None
dataset_process_rule_id: str | None
batch: str
name: str
created_from: str
created_by: str
created_api_request_id: str | None
created_at: datetime
processing_started_at: datetime | None
file_id: str | None
word_count: int | None
parsing_completed_at: datetime | None
cleaning_completed_at: datetime | None
splitting_completed_at: datetime | None
tokens: int | None
indexing_latency: float | None
completed_at: datetime | None
is_paused: bool | None
paused_by: str | None
paused_at: datetime | None
error: str | None
stopped_at: datetime | None
indexing_status: str
enabled: bool
disabled_at: datetime | None
disabled_by: str | None
archived: bool
archived_reason: str | None
archived_by: str | None
archived_at: datetime | None
updated_at: datetime
doc_type: str | None
doc_metadata: Any
doc_form: IndexStructureType
doc_language: str | None
display_status: str | None
data_source_info_dict: dict[str, Any]
average_segment_length: int
dataset_process_rule: ProcessRuleDict | None
dataset: None
segment_count: int | None
hit_count: int | None
class DatasetPermissionEnum(enum.StrEnum):
ONLY_ME = "only_me"
ALL_TEAM = "all_team_members"
@ -303,13 +353,17 @@ class Dataset(Base):
if self.provider != "external":
return None
external_knowledge_binding = db.session.scalar(
select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id)
select(ExternalKnowledgeBindings).where(
ExternalKnowledgeBindings.dataset_id == self.id,
ExternalKnowledgeBindings.tenant_id == self.tenant_id,
)
)
if not external_knowledge_binding:
return None
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis).where(
ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id
ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id,
ExternalKnowledgeApis.tenant_id == self.tenant_id,
)
)
if external_knowledge_api is None or external_knowledge_api.settings is None:
@ -675,8 +729,8 @@ class Document(Base):
)
return built_in_fields
def to_dict(self) -> dict[str, Any]:
return {
def to_dict(self) -> DocumentDict:
result: DocumentDict = {
"id": self.id,
"tenant_id": self.tenant_id,
"dataset_id": self.dataset_id,
@ -721,10 +775,11 @@ class Document(Base):
"data_source_info_dict": self.data_source_info_dict,
"average_segment_length": self.average_segment_length,
"dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None,
"dataset": None, # Dataset class doesn't have a to_dict method
"dataset": None,
"segment_count": self.segment_count,
"hit_count": self.hit_count,
}
return result
@classmethod
def from_dict(cls, data: dict[str, Any]):

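Why the TypedDict return type matters, in miniature (field names here are illustrative): a checker can reject unknown keys and wrong value types that a plain dict[str, Any] would let through:

from typing import TypedDict

class SegmentDict(TypedDict):
    id: str
    word_count: int | None

def to_dict_ok() -> SegmentDict:
    return {"id": "abc", "word_count": None}

def to_dict_checked() -> SegmentDict:
    # return {"id": "abc", "words": 3}   # rejected: unknown key "words"
    return {"id": "abc", "word_count": 3}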
View File

@ -674,28 +674,24 @@ class AppModelConfig(TypeBase):
def suggested_questions_list(self) -> list[str]:
return json.loads(self.suggested_questions) if self.suggested_questions else []
def _get_enabled_config(self, value: str | None, *, default_enabled: bool = False) -> EnabledConfig:
return cast(EnabledConfig, json.loads(value) if value else {"enabled": default_enabled})
@property
def suggested_questions_after_answer_dict(self) -> EnabledConfig:
return cast(
EnabledConfig,
json.loads(self.suggested_questions_after_answer)
if self.suggested_questions_after_answer
else {"enabled": False},
)
return self._get_enabled_config(self.suggested_questions_after_answer)
@property
def speech_to_text_dict(self) -> EnabledConfig:
return cast(EnabledConfig, json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False})
return self._get_enabled_config(self.speech_to_text)
@property
def text_to_speech_dict(self) -> EnabledConfig:
return cast(EnabledConfig, json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False})
return self._get_enabled_config(self.text_to_speech)
@property
def retriever_resource_dict(self) -> EnabledConfig:
return cast(
EnabledConfig, json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True}
)
return self._get_enabled_config(self.retriever_resource, default_enabled=True)
@property
def annotation_reply_dict(self) -> AnnotationReplyConfig:
@ -722,7 +718,7 @@ class AppModelConfig(TypeBase):
@property
def more_like_this_dict(self) -> EnabledConfig:
return cast(EnabledConfig, json.loads(self.more_like_this) if self.more_like_this else {"enabled": False})
return self._get_enabled_config(self.more_like_this)
@property
def sensitive_word_avoidance_dict(self) -> SensitiveWordAvoidanceConfig:
@ -902,7 +898,7 @@ class InstalledApp(TypeBase):
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
class TrialApp(Base):
class TrialApp(TypeBase):
__tablename__ = "trial_apps"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="trial_app_pkey"),
@ -911,18 +907,22 @@ class TrialApp(Base):
sa.UniqueConstraint("app_id", name="unique_trail_app_id"),
)
id = mapped_column(StringUUID, default=gen_uuidv4_string)
app_id = mapped_column(StringUUID, nullable=False)
tenant_id = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
trial_limit = mapped_column(sa.Integer, nullable=False, default=3)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False
)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
trial_limit: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=3)
@property
def app(self) -> App | None:
return db.session.scalar(select(App).where(App.id == self.app_id))
class AccountTrialAppRecord(Base):
class AccountTrialAppRecord(TypeBase):
__tablename__ = "account_trial_app_records"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="user_trial_app_pkey"),
@ -930,11 +930,15 @@ class AccountTrialAppRecord(Base):
sa.Index("account_trial_app_record_app_id_idx", "app_id"),
sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"),
)
id = mapped_column(StringUUID, default=gen_uuidv4_string)
account_id = mapped_column(StringUUID, nullable=False)
app_id = mapped_column(StringUUID, nullable=False)
count = mapped_column(sa.Integer, nullable=False, default=0)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
id: Mapped[str] = mapped_column(
StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False
)
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
@property
def app(self) -> App | None:

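The five near-identical properties collapse into one parse-or-default helper. The same pattern in isolation, assuming EnabledConfig is a TypedDict with an `enabled` flag as in this file:

import json
from typing import TypedDict, cast

class EnabledConfig(TypedDict):
    enabled: bool

def get_enabled_config(value: str | None, *, default_enabled: bool = False) -> EnabledConfig:
    # Parse the stored JSON when present; otherwise fall back to the default flag.
    return cast(EnabledConfig, json.loads(value) if value else {"enabled": default_enabled})

assert get_enabled_config(None) == {"enabled": False}
assert get_enabled_config(None, default_enabled=True) == {"enabled": True}
assert get_enabled_config('{"enabled": true}') == {"enabled": True}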
View File

@ -1,5 +1,6 @@
import json
from datetime import datetime
from typing import Any, TypedDict
from uuid import uuid4
import sqlalchemy as sa
@ -38,6 +39,17 @@ class DataSourceOauthBinding(TypeBase):
disabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"), default=False)
class DataSourceApiKeyAuthBindingDict(TypedDict):
id: str
tenant_id: str
category: str
provider: str
credentials: Any
created_at: float
updated_at: float
disabled: bool
class DataSourceApiKeyAuthBinding(TypeBase):
__tablename__ = "data_source_api_key_auth_bindings"
__table_args__ = (
@ -65,8 +77,8 @@ class DataSourceApiKeyAuthBinding(TypeBase):
)
disabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"), default=False)
def to_dict(self):
return {
def to_dict(self) -> DataSourceApiKeyAuthBindingDict:
result: DataSourceApiKeyAuthBindingDict = {
"id": self.id,
"tenant_id": self.tenant_id,
"category": self.category,
@ -76,3 +88,4 @@ class DataSourceApiKeyAuthBinding(TypeBase):
"updated_at": self.updated_at.timestamp(),
"disabled": self.disabled,
}
return result

View File

@ -66,12 +66,15 @@ def build_file_from_stored_mapping(
record_id = resolve_file_record_id(mapping)
transfer_method = FileTransferMethod.value_of(mapping["transfer_method"])
if transfer_method == FileTransferMethod.TOOL_FILE and record_id:
mapping["tool_file_id"] = record_id
elif transfer_method in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL] and record_id:
mapping["upload_file_id"] = record_id
elif transfer_method == FileTransferMethod.DATASOURCE_FILE and record_id:
mapping["datasource_file_id"] = record_id
match transfer_method:
case FileTransferMethod.TOOL_FILE if record_id:
mapping["tool_file_id"] = record_id
case FileTransferMethod.LOCAL_FILE | FileTransferMethod.REMOTE_URL if record_id:
mapping["upload_file_id"] = record_id
case FileTransferMethod.DATASOURCE_FILE if record_id:
mapping["datasource_file_id"] = record_id
case _:
pass
if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None:
remote_url = mapping.get("remote_url")

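A minimal sketch of `match` with guards and `|` alternation as used above; the enum is an illustrative stand-in for FileTransferMethod:

from enum import StrEnum

class Method(StrEnum):
    TOOL = "tool"
    LOCAL = "local"
    REMOTE = "remote"

def key_for(method: Method, record_id: str | None) -> str | None:
    match method:
        case Method.TOOL if record_id:
            return "tool_file_id"
        case Method.LOCAL | Method.REMOTE if record_id:
            return "upload_file_id"
        case _:
            return None

assert key_for(Method.TOOL, "x") == "tool_file_id"
assert key_for(Method.REMOTE, None) is None  # guard fails, falls through to _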
View File

@ -4,7 +4,7 @@ import logging
from collections.abc import Generator, Mapping, Sequence
from datetime import datetime
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union, cast
from typing import TYPE_CHECKING, Any, Optional, TypedDict, cast
from uuid import uuid4
import sqlalchemy as sa
@ -121,7 +121,7 @@ class WorkflowType(StrEnum):
raise ValueError(f"invalid workflow type value {value}")
@classmethod
def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType":
def from_app_mode(cls, app_mode: "str | AppMode") -> "WorkflowType":
"""
Get workflow type from app mode.
@ -1051,7 +1051,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
)
return extras
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]:
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> "WorkflowNodeExecutionOffload | None":
return next(iter([i for i in self.offload_data if i.type_ == type_]), None)
@property

View File

@ -4,76 +4,76 @@ version = "1.13.3"
requires-python = "~=3.12.0"
dependencies = [
"aliyun-log-python-sdk~=0.9.37",
"aliyun-log-python-sdk~=0.9.44",
"arize-phoenix-otel~=0.15.0",
"azure-identity==1.25.3",
"beautifulsoup4==4.14.3",
"boto3==1.42.83",
"boto3==1.42.88",
"bs4~=0.0.1",
"cachetools~=5.3.0",
"celery~=5.6.2",
"charset-normalizer>=3.4.4",
"flask~=3.1.2",
"flask-compress>=1.17,<1.25",
"flask-cors~=6.0.0",
"cachetools~=7.0.5",
"celery~=5.6.3",
"charset-normalizer>=3.4.7",
"flask~=3.1.3",
"flask-compress>=1.24,<1.25",
"flask-cors~=6.0.2",
"flask-login~=0.6.3",
"flask-migrate~=4.1.0",
"flask-orjson~=2.0.0",
"flask-sqlalchemy~=3.1.1",
"gevent~=25.9.1",
"gevent~=26.4.0",
"gmpy2~=2.3.0",
"google-api-core>=2.19.1",
"google-api-python-client==2.193.0",
"google-auth>=2.47.0",
"google-api-core>=2.30.3",
"google-api-python-client==2.194.0",
"google-auth>=2.49.2",
"google-auth-httplib2==0.3.1",
"google-cloud-aiplatform>=1.123.0",
"googleapis-common-protos>=1.65.0",
"google-cloud-aiplatform>=1.147.0",
"googleapis-common-protos>=1.74.0",
"graphon>=0.1.2",
"gunicorn~=25.3.0",
"httpx[socks]~=0.28.0",
"httpx[socks]~=0.28.1",
"jieba==0.42.1",
"json-repair>=0.55.1",
"langfuse>=3.0.0,<5.0.0",
"langsmith~=0.7.16",
"json-repair>=0.59.2",
"langfuse>=4.2.0,<5.0.0",
"langsmith~=0.7.30",
"markdown~=3.10.2",
"mlflow-skinny>=3.0.0",
"numpy~=1.26.4",
"mlflow-skinny>=3.11.1",
"numpy~=2.4.4",
"openpyxl~=3.1.5",
"opik~=1.10.37",
"opik~=1.11.2",
"litellm==1.83.0", # Pinned to avoid madoka dependency issue
"opentelemetry-api==1.40.0",
"opentelemetry-distro==0.61b0",
"opentelemetry-exporter-otlp==1.40.0",
"opentelemetry-exporter-otlp-proto-common==1.40.0",
"opentelemetry-exporter-otlp-proto-grpc==1.40.0",
"opentelemetry-exporter-otlp-proto-http==1.40.0",
"opentelemetry-instrumentation==0.61b0",
"opentelemetry-instrumentation-celery==0.61b0",
"opentelemetry-instrumentation-flask==0.61b0",
"opentelemetry-instrumentation-httpx==0.61b0",
"opentelemetry-instrumentation-redis==0.61b0",
"opentelemetry-instrumentation-sqlalchemy==0.61b0",
"opentelemetry-propagator-b3==1.40.0",
"opentelemetry-proto==1.40.0",
"opentelemetry-sdk==1.40.0",
"opentelemetry-semantic-conventions==0.61b0",
"opentelemetry-util-http==0.61b0",
"pandas[excel,output-formatting,performance]~=3.0.1",
"opentelemetry-api==1.41.0",
"opentelemetry-distro==0.62b0",
"opentelemetry-exporter-otlp==1.41.0",
"opentelemetry-exporter-otlp-proto-common==1.41.0",
"opentelemetry-exporter-otlp-proto-grpc==1.41.0",
"opentelemetry-exporter-otlp-proto-http==1.41.0",
"opentelemetry-instrumentation==0.62b0",
"opentelemetry-instrumentation-celery==0.62b0",
"opentelemetry-instrumentation-flask==0.62b0",
"opentelemetry-instrumentation-httpx==0.62b0",
"opentelemetry-instrumentation-redis==0.62b0",
"opentelemetry-instrumentation-sqlalchemy==0.62b0",
"opentelemetry-propagator-b3==1.41.0",
"opentelemetry-proto==1.41.0",
"opentelemetry-sdk==1.41.0",
"opentelemetry-semantic-conventions==0.62b0",
"opentelemetry-util-http==0.62b0",
"pandas[excel,output-formatting,performance]~=3.0.2",
"psycogreen~=1.0.2",
"psycopg2-binary~=2.9.6",
"psycopg2-binary~=2.9.11",
"pycryptodome==3.23.0",
"pydantic~=2.12.5",
"pydantic-settings~=2.13.1",
"pyjwt~=2.12.0",
"pyjwt~=2.12.1",
"pypdfium2==5.6.0",
"python-docx~=1.2.0",
"python-dotenv==1.2.2",
"pyyaml~=6.0.1",
"readabilipy~=0.3.0",
"redis[hiredis]~=7.4.0",
"resend~=2.26.0",
"sentry-sdk[flask]~=2.55.0",
"sqlalchemy~=2.0.29",
"resend~=2.27.0",
"sentry-sdk[flask]~=2.57.0",
"sqlalchemy~=2.0.49",
"starlette==1.0.0",
"tiktoken~=0.12.0",
"transformers~=5.3.0",
@ -82,12 +82,12 @@ dependencies = [
"yarl~=1.23.0",
"sseclient-py~=1.9.0",
"httpx-sse~=0.4.0",
"sendgrid~=6.12.3",
"sendgrid~=6.12.5",
"flask-restx~=1.3.2",
"packaging~=23.2",
"croniter>=6.0.0",
"weaviate-client==4.20.4",
"apscheduler>=3.11.0",
"packaging~=26.0",
"croniter>=6.2.2",
"weaviate-client==4.20.5",
"apscheduler>=3.11.2",
"weave>=0.52.16",
"fastopenapi[flask]>=0.7.0",
"bleach~=6.3.0",
@ -111,16 +111,16 @@ package = false
dev = [
"coverage~=7.13.4",
"dotenv-linter~=0.7.0",
"faker~=40.12.0",
"faker~=40.13.0",
"lxml-stubs~=0.5.1",
"basedpyright~=1.39.0",
"ruff~=0.15.5",
"pytest~=9.0.2",
"ruff~=0.15.10",
"pytest~=9.0.3",
"pytest-benchmark~=5.2.3",
"pytest-cov~=7.1.0",
"pytest-env~=1.6.0",
"pytest-mock~=3.15.1",
"testcontainers~=4.14.1",
"testcontainers~=4.14.2",
"types-aiofiles~=25.1.0",
"types-beautifulsoup4~=4.12.0",
"types-cachetools~=6.2.0",
@ -130,8 +130,8 @@ dev = [
"types-docutils~=0.22.3",
"types-flask-cors~=6.0.0",
"types-flask-migrate~=4.1.0",
"types-gevent~=25.9.0",
"types-greenlet~=3.3.0",
"types-gevent~=26.4.0",
"types-greenlet~=3.4.0",
"types-html5lib~=1.1.11",
"types-markdown~=3.10.2",
"types-oauthlib~=3.3.0",
@ -149,20 +149,20 @@ dev = [
"types-pyyaml~=6.0.12",
"types-regex~=2026.4.4",
"types-shapely~=2.1.0",
"types-simplejson>=3.20.0",
"types-six>=1.17.0",
"types-tensorflow>=2.18.0",
"types-tqdm>=4.67.0",
"types-simplejson>=3.20.0.20260408",
"types-six>=1.17.0.20260408",
"types-tensorflow>=2.18.0.20260408",
"types-tqdm>=4.67.3.20260408",
"types-ujson>=5.10.0",
"boto3-stubs>=1.38.20",
"types-jmespath>=1.0.2.20240106",
"hypothesis>=6.131.15",
"boto3-stubs>=1.42.88",
"types-jmespath>=1.1.0.20260408",
"hypothesis>=6.151.12",
"types_pyOpenSSL>=24.1.0",
"types_cffi>=1.17.0",
"types_setuptools>=80.9.0",
"types_cffi>=2.0.0.20260408",
"types_setuptools>=82.0.0.20260408",
"pandas-stubs~=3.0.0",
"scipy-stubs>=1.15.3.0",
"types-python-http-client>=3.3.7.20240910",
"types-python-http-client>=3.3.7.20260408",
"import-linter>=2.3",
"types-redis>=4.6.0.20241004",
"celery-types>=0.23.0",
@ -180,10 +180,10 @@ dev = [
############################################################
storage = [
"azure-storage-blob==12.28.0",
"bce-python-sdk~=0.9.23",
"bce-python-sdk~=0.9.69",
"cos-python-sdk-v5==1.9.41",
"esdk-obs-python==3.26.2",
"google-cloud-storage>=3.0.0",
"google-cloud-storage>=3.10.1",
"opendal~=0.46.0",
"oss2==2.19.1",
"supabase~=2.18.1",
@ -209,23 +209,23 @@ vdb = [
"elasticsearch==8.14.0",
"opensearch-py==3.1.0",
"oracledb==3.4.2",
"pgvecto-rs[sqlalchemy]~=0.2.1",
"pgvecto-rs[sqlalchemy]~=0.2.2",
"pgvector==0.4.2",
"pymilvus~=2.6.10",
"pymilvus~=2.6.12",
"pymochow==2.4.0",
"pyobvector~=0.2.17",
"qdrant-client==1.9.0",
"intersystems-irispython>=5.1.0",
"tablestore==6.4.3",
"tablestore==6.4.4",
"tcvectordb~=2.1.0",
"tidb-vector==0.0.15",
"upstash-vector==0.8.0",
"volcengine-compat~=1.0.0",
"weaviate-client==4.20.4",
"weaviate-client==4.20.5",
"xinference-client~=2.4.0",
"mo-vector~=0.1.13",
"mysql-connector-python>=9.3.0",
"holo-search-sdk>=0.4.1",
"holo-search-sdk>=0.4.2",
]
[tool.pyrefly]

View File

@ -47,7 +47,6 @@
"reportMissingTypeArgument": "hint",
"reportUnnecessaryComparison": "hint",
"reportUnnecessaryIsInstance": "hint",
"reportUntypedFunctionDecorator": "hint",
"reportUnnecessaryTypeIgnoreComment": "hint",
"reportAttributeAccessIssue": "hint",
"pythonVersion": "3.12",

View File

@ -9,7 +9,8 @@ from typing import Any, TypedDict, cast
from pydantic import BaseModel, TypeAdapter
from sqlalchemy import delete, func, select, update
from sqlalchemy.orm import Session, sessionmaker
from core.db.session_factory import session_factory
class InvitationData(TypedDict):
@ -800,19 +801,19 @@ class AccountService:
return token
@staticmethod
def get_account_by_email_with_case_fallback(email: str, session: Session | None = None) -> Account | None:
def get_account_by_email_with_case_fallback(email: str) -> Account | None:
"""
Retrieve an account by email and fall back to the lowercase email if the original lookup fails.
This keeps backward compatibility for older records that stored uppercase emails while the
rest of the system gradually normalizes new inputs.
"""
query_session = session or db.session
account = query_session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
if account or email == email.lower():
return account
with session_factory.create_session() as session:
account = session.execute(select(Account).where(Account.email == email)).scalar_one_or_none()
if account or email == email.lower():
return account
return query_session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()
return session.execute(select(Account).where(Account.email == email.lower())).scalar_one_or_none()
@classmethod
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
@ -1516,8 +1517,7 @@ class RegisterService:
check_workspace_member_invite_permission(tenant.id)
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
account = AccountService.get_account_by_email_with_case_fallback(email)
if not account:
TenantService.check_member_permission(tenant, inviter, None, "add")

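The lookup now owns its session instead of accepting one. The fallback logic in isolation, with stand-in Session and model arguments:

from sqlalchemy import select
from sqlalchemy.orm import Session

def find_account(session: Session, Account, email: str):
    account = session.execute(select(Account).where(Account.email == email)).scalar_one_or_none()
    if account or email == email.lower():
        return account
    # Fall back once to the normalized (lowercase) form for legacy rows.
    return session.execute(
        select(Account).where(Account.email == email.lower())
    ).scalar_one_or_none()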
View File

@ -1,11 +1,8 @@
import logging
import uuid
import pandas as pd
logger = logging.getLogger(__name__)
from typing import TypedDict
import pandas as pd
from sqlalchemy import delete, or_, select, update
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
@ -24,6 +21,8 @@ from tasks.annotation.disable_annotation_reply_task import disable_annotation_re
from tasks.annotation.enable_annotation_reply_task import enable_annotation_reply_task
from tasks.annotation.update_annotation_to_index_task import update_annotation_to_index_task
logger = logging.getLogger(__name__)
class AnnotationJobStatusDict(TypedDict):
job_id: str
@ -46,9 +45,50 @@ class AnnotationSettingDisabledDict(TypedDict):
enabled: bool
class EnableAnnotationArgs(TypedDict):
"""Expected shape of the args dict passed to enable_app_annotation."""
score_threshold: float
embedding_provider_name: str
embedding_model_name: str
class UpsertAnnotationArgs(TypedDict, total=False):
"""Expected shape of the args dict passed to up_insert_app_annotation_from_message."""
answer: str
content: str
message_id: str
question: str
class InsertAnnotationArgs(TypedDict):
"""Expected shape of the args dict passed to insert_app_annotation_directly."""
question: str
answer: str
class UpdateAnnotationArgs(TypedDict, total=False):
"""Expected shape of the args dict passed to update_app_annotation_directly.
Both fields are optional at the type level; the service validates at runtime
and raises ValueError if either is missing.
"""
answer: str
question: str
class UpdateAnnotationSettingArgs(TypedDict):
"""Expected shape of the args dict passed to update_app_annotation_setting."""
score_threshold: float
class AppAnnotationService:
@classmethod
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
def up_insert_app_annotation_from_message(cls, args: UpsertAnnotationArgs, app_id: str) -> MessageAnnotation:
# get app info
current_user, current_tenant_id = current_account_with_tenant()
app = db.session.scalar(
@ -62,8 +102,9 @@ class AppAnnotationService:
if answer is None:
raise ValueError("Either 'answer' or 'content' must be provided")
if args.get("message_id"):
message_id = str(args["message_id"])
raw_message_id = args.get("message_id")
if raw_message_id:
message_id = str(raw_message_id)
message = db.session.scalar(
select(Message).where(Message.id == message_id, Message.app_id == app.id).limit(1)
)
@ -87,9 +128,10 @@ class AppAnnotationService:
account_id=current_user.id,
)
else:
question = args.get("question")
if not question:
maybe_question = args.get("question")
if not maybe_question:
raise ValueError("'question' is required when 'message_id' is not provided")
question = maybe_question
annotation = MessageAnnotation(app_id=app.id, content=answer, question=question, account_id=current_user.id)
db.session.add(annotation)
@ -110,7 +152,7 @@ class AppAnnotationService:
return annotation
@classmethod
def enable_app_annotation(cls, args: dict, app_id: str) -> AnnotationJobStatusDict:
def enable_app_annotation(cls, args: EnableAnnotationArgs, app_id: str) -> AnnotationJobStatusDict:
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
cache_result = redis_client.get(enable_app_annotation_key)
if cache_result is not None:
@ -217,7 +259,7 @@ class AppAnnotationService:
return annotations
@classmethod
def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
def insert_app_annotation_directly(cls, args: InsertAnnotationArgs, app_id: str) -> MessageAnnotation:
# get app info
current_user, current_tenant_id = current_account_with_tenant()
app = db.session.scalar(
@ -251,7 +293,7 @@ class AppAnnotationService:
return annotation
@classmethod
def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
def update_app_annotation_directly(cls, args: UpdateAnnotationArgs, app_id: str, annotation_id: str):
# get app info
_, current_tenant_id = current_account_with_tenant()
app = db.session.scalar(
@ -270,7 +312,11 @@ class AppAnnotationService:
if question is None:
raise ValueError("'question' is required")
annotation.content = args["answer"]
answer = args.get("answer")
if answer is None:
raise ValueError("'answer' is required")
annotation.content = answer
annotation.question = question
db.session.commit()
@ -613,7 +659,7 @@ class AppAnnotationService:
@classmethod
def update_app_annotation_setting(
cls, app_id: str, annotation_setting_id: str, args: dict
cls, app_id: str, annotation_setting_id: str, args: UpdateAnnotationSettingArgs
) -> AnnotationSettingDict:
current_user, current_tenant_id = current_account_with_tenant()
# get app info

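The total=False TypedDicts above make fields optional for the checker while the service still validates at runtime. The pattern in miniature:

from typing import TypedDict

class UpdateArgs(TypedDict, total=False):
    question: str
    answer: str

def apply_update(args: UpdateArgs) -> tuple[str, str]:
    question = args.get("question")
    if question is None:
        raise ValueError("'question' is required")
    answer = args.get("answer")
    if answer is None:
        raise ValueError("'answer' is required")
    return question, answer  # both narrowed to str by the None checks

assert apply_update({"question": "q", "answer": "a"}) == ("q", "a")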
View File

@ -467,61 +467,67 @@ class AppDslService:
)
# Initialize app based on mode
if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow_data = data.get("workflow")
if not workflow_data or not isinstance(workflow_data, dict):
raise ValueError("Missing workflow data for workflow/advanced chat app")
match app_mode:
case AppMode.ADVANCED_CHAT | AppMode.WORKFLOW:
workflow_data = data.get("workflow")
if not workflow_data or not isinstance(workflow_data, dict):
raise ValueError("Missing workflow data for workflow/advanced chat app")
environment_variables_list = workflow_data.get("environment_variables", [])
environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
conversation_variables_list = workflow_data.get("conversation_variables", [])
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
environment_variables_list = workflow_data.get("environment_variables", [])
environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
conversation_variables_list = workflow_data.get("conversation_variables", [])
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj)
for obj in conversation_variables_list
]
workflow_service = WorkflowService()
current_draft_workflow = workflow_service.get_draft_workflow(app_model=app)
if current_draft_workflow:
unique_hash = current_draft_workflow.unique_hash
else:
unique_hash = None
graph = workflow_data.get("graph", {})
for node in graph.get("nodes", []):
if node.get("data", {}).get("type", "") == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
dataset_ids = node["data"].get("dataset_ids", [])
node["data"]["dataset_ids"] = [
decrypted_id
for dataset_id in dataset_ids
if (decrypted_id := self.decrypt_dataset_id(encrypted_data=dataset_id, tenant_id=app.tenant_id))
]
workflow_service.sync_draft_workflow(
app_model=app,
graph=workflow_data.get("graph", {}),
features=workflow_data.get("features", {}),
unique_hash=unique_hash,
account=account,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
)
elif app_mode in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION}:
# Initialize model config
model_config = data.get("model_config")
if not model_config or not isinstance(model_config, dict):
raise ValueError("Missing model_config for chat/agent-chat/completion app")
# Initialize or update model config
if not app.app_model_config:
app_model_config = AppModelConfig(
app_id=app.id, created_by=account.id, updated_by=account.id
).from_model_config_dict(cast(AppModelConfigDict, model_config))
app_model_config.id = str(uuid4())
app.app_model_config_id = app_model_config.id
workflow_service = WorkflowService()
current_draft_workflow = workflow_service.get_draft_workflow(app_model=app)
if current_draft_workflow:
unique_hash = current_draft_workflow.unique_hash
else:
unique_hash = None
graph = workflow_data.get("graph", {})
for node in graph.get("nodes", []):
if node.get("data", {}).get("type", "") == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
dataset_ids = node["data"].get("dataset_ids", [])
node["data"]["dataset_ids"] = [
decrypted_id
for dataset_id in dataset_ids
if (
decrypted_id := self.decrypt_dataset_id(
encrypted_data=dataset_id, tenant_id=app.tenant_id
)
)
]
workflow_service.sync_draft_workflow(
app_model=app,
graph=workflow_data.get("graph", {}),
features=workflow_data.get("features", {}),
unique_hash=unique_hash,
account=account,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
)
case AppMode.CHAT | AppMode.AGENT_CHAT | AppMode.COMPLETION:
# Initialize model config
model_config = data.get("model_config")
if not model_config or not isinstance(model_config, dict):
raise ValueError("Missing model_config for chat/agent-chat/completion app")
# Initialize or update model config
if not app.app_model_config:
app_model_config = AppModelConfig(
app_id=app.id, created_by=account.id, updated_by=account.id
).from_model_config_dict(cast(AppModelConfigDict, model_config))
app_model_config.id = str(uuid4())
app.app_model_config_id = app_model_config.id
self._session.add(app_model_config)
app_model_config_was_updated.send(app, app_model_config=app_model_config)
else:
raise ValueError("Invalid app mode")
self._session.add(app_model_config)
app_model_config_was_updated.send(app, app_model_config=app_model_config)
case _:
raise ValueError("Invalid app mode")
return app
@classmethod

View File

@ -4,7 +4,7 @@ import logging
import threading
import uuid
from collections.abc import Callable, Generator, Mapping
from typing import TYPE_CHECKING, Any, Union
from typing import TYPE_CHECKING, Any
from configs import dify_config
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
@ -89,7 +89,7 @@ class AppGenerateService:
def generate(
cls,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
@ -358,11 +358,11 @@ class AppGenerateService:
def generate_more_like_this(
cls,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
message_id: str,
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Union[Mapping, Generator]:
) -> Mapping | Generator:
"""
Generate more like this
:param app_model: app model

View File

@ -7,7 +7,7 @@ with support for different subscription tiers, rate limiting, and execution trac
import json
from datetime import UTC, datetime
from typing import Any, Union
from typing import Any
from celery.result import AsyncResult
from sqlalchemy import select
@ -51,7 +51,7 @@ class AsyncWorkflowService:
@classmethod
def trigger_workflow_async(
cls, session: Session, user: Union[Account, EndUser], trigger_data: TriggerData
cls, session: Session, user: Account | EndUser, trigger_data: TriggerData
) -> AsyncTriggerResponse:
"""
Universal entry point for async workflow execution - THIS METHOD WILL NOT BLOCK
@ -184,7 +184,7 @@ class AsyncWorkflowService:
@classmethod
def reinvoke_trigger(
cls, session: Session, user: Union[Account, EndUser], workflow_trigger_log_id: str
cls, session: Session, user: Account | EndUser, workflow_trigger_log_id: str
) -> AsyncTriggerResponse:
"""
Re-invoke a previously failed or rate-limited trigger - THIS METHOD WILL NOT BLOCK

View File

@ -7,7 +7,7 @@ from concurrent.futures import ThreadPoolExecutor
import click
from flask import Flask, current_app
from graphon.model_runtime.utils.encoders import jsonable_encoder
from sqlalchemy import select
from sqlalchemy import delete, func, select
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
@ -62,13 +62,11 @@ class ClearFreePlanTenantExpiredLogs:
for model, table_name in related_tables:
# Query records related to expired messages
records = (
session.query(model)
.where(
records = session.scalars(
select(model).where(
model.message_id.in_(batch_message_ids), # type: ignore
)
.all()
)
).all()
if len(records) == 0:
continue
@ -103,9 +101,13 @@ class ClearFreePlanTenantExpiredLogs:
except Exception:
logger.exception("Failed to save %s records", table_name)
session.query(model).where(
model.id.in_(record_ids), # type: ignore
).delete(synchronize_session=False)
session.execute(
delete(model)
.where(
model.id.in_(record_ids), # type: ignore
)
.execution_options(synchronize_session=False)
)
click.echo(
click.style(
@ -121,15 +123,14 @@ class ClearFreePlanTenantExpiredLogs:
app_ids = [app.id for app in apps]
while True:
with sessionmaker(bind=db.engine, autoflush=False).begin() as session:
messages = (
session.query(Message)
messages = session.scalars(
select(Message)
.where(
Message.app_id.in_(app_ids),
Message.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
)
.limit(batch)
.all()
)
).all()
if len(messages) == 0:
break
@ -147,9 +148,9 @@ class ClearFreePlanTenantExpiredLogs:
message_ids = [message.id for message in messages]
# delete messages
session.query(Message).where(
Message.id.in_(message_ids),
).delete(synchronize_session=False)
session.execute(
delete(Message).where(Message.id.in_(message_ids)).execution_options(synchronize_session=False)
)
cls._clear_message_related_tables(session, tenant_id, message_ids)
@ -161,15 +162,14 @@ class ClearFreePlanTenantExpiredLogs:
while True:
with sessionmaker(bind=db.engine, autoflush=False).begin() as session:
conversations = (
session.query(Conversation)
conversations = session.scalars(
select(Conversation)
.where(
Conversation.app_id.in_(app_ids),
Conversation.updated_at < datetime.datetime.now() - datetime.timedelta(days=days),
)
.limit(batch)
.all()
)
).all()
if len(conversations) == 0:
break
@ -186,9 +186,11 @@ class ClearFreePlanTenantExpiredLogs:
)
conversation_ids = [conversation.id for conversation in conversations]
session.query(Conversation).where(
Conversation.id.in_(conversation_ids),
).delete(synchronize_session=False)
session.execute(
delete(Conversation)
.where(Conversation.id.in_(conversation_ids))
.execution_options(synchronize_session=False)
)
click.echo(
click.style(
@ -293,15 +295,14 @@ class ClearFreePlanTenantExpiredLogs:
while True:
with sessionmaker(bind=db.engine, autoflush=False).begin() as session:
workflow_app_logs = (
session.query(WorkflowAppLog)
workflow_app_logs = session.scalars(
select(WorkflowAppLog)
.where(
WorkflowAppLog.tenant_id == tenant_id,
WorkflowAppLog.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
)
.limit(batch)
.all()
)
).all()
if len(workflow_app_logs) == 0:
break
@ -321,8 +322,10 @@ class ClearFreePlanTenantExpiredLogs:
workflow_app_log_ids = [workflow_app_log.id for workflow_app_log in workflow_app_logs]
# delete workflow app logs
session.query(WorkflowAppLog).where(WorkflowAppLog.id.in_(workflow_app_log_ids)).delete(
synchronize_session=False
session.execute(
delete(WorkflowAppLog)
.where(WorkflowAppLog.id.in_(workflow_app_log_ids))
.execution_options(synchronize_session=False)
)
click.echo(
@ -344,7 +347,7 @@ class ClearFreePlanTenantExpiredLogs:
current_time = started_at
with sessionmaker(db.engine).begin() as session:
total_tenant_count = session.query(Tenant.id).count()
total_tenant_count = session.scalar(select(func.count(Tenant.id))) or 0
click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))
@ -409,9 +412,12 @@ class ClearFreePlanTenantExpiredLogs:
tenant_count = 0
for test_interval in test_intervals:
tenant_count = (
session.query(Tenant.id)
.where(Tenant.created_at.between(current_time, current_time + test_interval))
.count()
session.scalar(
select(func.count(Tenant.id)).where(
Tenant.created_at.between(current_time, current_time + test_interval)
)
)
or 0
)
if tenant_count <= 100:
interval = test_interval
@ -433,8 +439,8 @@ class ClearFreePlanTenantExpiredLogs:
batch_end = min(current_time + interval, ended_at)
rs = (
session.query(Tenant.id)
rs = session.execute(
select(Tenant.id)
.where(Tenant.created_at.between(current_time, batch_end))
.order_by(Tenant.created_at)
)

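This file's hunks all follow the same legacy-Query-to-2.0 rewrite. A condensed sketch of the shape, with a stand-in model:

# session.query(M).where(...).all()        -> session.scalars(select(M)...).all()
# session.query(M).where(...).delete(...)  -> session.execute(delete(M)...)
from sqlalchemy import delete, select
from sqlalchemy.orm import Session

def purge_batch(session: Session, Message, ids: list[str]) -> None:
    rows = session.scalars(select(Message).where(Message.id.in_(ids)).limit(100)).all()
    if rows:
        session.execute(
            delete(Message)
            .where(Message.id.in_([r.id for r in rows]))
            .execution_options(synchronize_session=False)
        )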
View File

@ -528,6 +528,8 @@ class DatasetService:
raise ValueError("External knowledge id is required.")
if not external_knowledge_api_id:
raise ValueError("External knowledge api id is required.")
# Ensure the referenced external API template exists and belongs to the dataset tenant.
ExternalDatasetService.get_external_knowledge_api(external_knowledge_api_id, dataset.tenant_id)
# Update metadata fields
dataset.updated_by = user.id if user else None
dataset.updated_at = naive_utc_now()
@ -552,8 +554,8 @@ class DatasetService:
external_knowledge_api_id: External knowledge API identifier
"""
with sessionmaker(db.engine).begin() as session:
external_knowledge_binding = (
session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first()
external_knowledge_binding = session.scalar(
select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == dataset_id).limit(1)
)
if not external_knowledge_binding:
@ -1454,15 +1456,17 @@ class DocumentService:
document_id_list: list[str] = [str(document_id) for document_id in document_ids]
with session_factory.create_session() as session:
updated_count = (
session.query(Document)
.filter(
result = session.execute(
update(Document)
.where(
Document.id.in_(document_id_list),
Document.dataset_id == dataset_id,
Document.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents
)
.update({Document.need_summary: need_summary}, synchronize_session=False)
.values(need_summary=need_summary)
.execution_options(synchronize_session=False)
)
updated_count = result.rowcount # type: ignore[union-attr,attr-defined]
session.commit()
logger.info(
"Updated need_summary to %s for %d documents in dataset %s",
@ -2822,6 +2826,10 @@ class DocumentService:
knowledge_config.process_rule.rules.pre_processing_rules = list(unique_pre_processing_rule_dicts.values())
if knowledge_config.process_rule.mode == ProcessRuleMode.HIERARCHICAL:
if not knowledge_config.process_rule.rules.parent_mode:
knowledge_config.process_rule.rules.parent_mode = "paragraph"
if not knowledge_config.process_rule.rules.segmentation:
raise ValueError("Process rule segmentation is required")

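A sketch of the bulk-UPDATE rewrite above, including the rowcount read (the model is a stand-in):

from sqlalchemy import update
from sqlalchemy.orm import Session

def set_need_summary(session: Session, Document, ids: list[str], flag: bool) -> int:
    result = session.execute(
        update(Document)
        .where(Document.id.in_(ids))
        .values(need_summary=flag)
        .execution_options(synchronize_session=False)
    )
    return result.rowcount  # rows the UPDATE actually touched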
View File

@ -4,7 +4,7 @@ from collections.abc import Mapping
from typing import Any
from graphon.model_runtime.entities.provider_entities import FormType
from sqlalchemy import func, select
from sqlalchemy import delete, func, select, update
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
@ -54,11 +54,13 @@ class DatasourceProviderService:
remove oauth custom client params
"""
with sessionmaker(bind=db.engine).begin() as session:
session.query(DatasourceOauthTenantParamConfig).filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
).delete()
session.execute(
delete(DatasourceOauthTenantParamConfig).where(
DatasourceOauthTenantParamConfig.tenant_id == tenant_id,
DatasourceOauthTenantParamConfig.provider == datasource_provider_id.provider_name,
DatasourceOauthTenantParamConfig.plugin_id == datasource_provider_id.plugin_id,
)
)
def decrypt_datasource_provider_credentials(
self,
@ -110,15 +112,21 @@ class DatasourceProviderService:
"""
with sessionmaker(bind=db.engine).begin() as session:
if credential_id:
datasource_provider = (
session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
datasource_provider = session.scalar(
select(DatasourceProvider)
.where(DatasourceProvider.tenant_id == tenant_id, DatasourceProvider.id == credential_id)
.limit(1)
)
else:
datasource_provider = (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
datasource_provider = session.scalar(
select(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
.first()
.limit(1)
)
if not datasource_provider:
return {}
@ -173,12 +181,15 @@ class DatasourceProviderService:
get all datasource credentials by provider
"""
with sessionmaker(bind=db.engine).begin() as session:
datasource_providers = (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
datasource_providers = session.scalars(
select(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
.all()
)
).all()
if not datasource_providers:
return []
current_user = get_current_user()
@ -232,15 +243,15 @@ class DatasourceProviderService:
update datasource provider name
"""
with sessionmaker(bind=db.engine).begin() as session:
target_provider = (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
id=credential_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
target_provider = session.scalar(
select(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.id == credential_id,
DatasourceProvider.provider == datasource_provider_id.provider_name,
DatasourceProvider.plugin_id == datasource_provider_id.plugin_id,
)
.first()
.limit(1)
)
if target_provider is None:
raise ValueError("provider not found")
@ -250,16 +261,16 @@ class DatasourceProviderService:
# check if the name already exists
if (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
name=name,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
session.scalar(
select(func.count(DatasourceProvider.id)).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.name == name,
DatasourceProvider.provider == datasource_provider_id.provider_name,
DatasourceProvider.plugin_id == datasource_provider_id.plugin_id,
)
)
.count()
> 0
):
or 0
) > 0:
raise ValueError("Authorization name is already exists")
target_provider.name = name
@ -273,26 +284,31 @@ class DatasourceProviderService:
"""
with sessionmaker(bind=db.engine).begin() as session:
# get provider
target_provider = (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
id=credential_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
target_provider = session.scalar(
select(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.id == credential_id,
DatasourceProvider.provider == datasource_provider_id.provider_name,
DatasourceProvider.plugin_id == datasource_provider_id.plugin_id,
)
.first()
.limit(1)
)
if target_provider is None:
raise ValueError("provider not found")
# clear default provider
session.query(DatasourceProvider).filter_by(
tenant_id=tenant_id,
provider=target_provider.provider,
plugin_id=target_provider.plugin_id,
is_default=True,
).update({"is_default": False})
session.execute(
update(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == target_provider.provider,
DatasourceProvider.plugin_id == target_provider.plugin_id,
DatasourceProvider.is_default.is_(True),
)
.values(is_default=False)
.execution_options(synchronize_session=False)
)
# set new default provider
target_provider.is_default = True
@ -311,14 +327,14 @@ class DatasourceProviderService:
if client_params is None and enabled is None:
return
with sessionmaker(bind=db.engine).begin() as session:
tenant_oauth_client_params = (
session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
tenant_oauth_client_params = session.scalar(
select(DatasourceOauthTenantParamConfig)
.where(
DatasourceOauthTenantParamConfig.tenant_id == tenant_id,
DatasourceOauthTenantParamConfig.provider == datasource_provider_id.provider_name,
DatasourceOauthTenantParamConfig.plugin_id == datasource_provider_id.plugin_id,
)
.first()
.limit(1)
)
if not tenant_oauth_client_params:
@ -351,9 +367,14 @@ class DatasourceProviderService:
"""
with Session(db.engine).no_autoflush as session:
return (
session.query(DatasourceOauthParamConfig)
.filter_by(provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id)
.first()
session.scalar(
select(DatasourceOauthParamConfig)
.where(
DatasourceOauthParamConfig.provider == datasource_provider_id.provider_name,
DatasourceOauthParamConfig.plugin_id == datasource_provider_id.plugin_id,
)
.limit(1)
)
is not None
)
@ -423,15 +444,15 @@ class DatasourceProviderService:
plugin_id = datasource_provider_id.plugin_id
with Session(db.engine).no_autoflush as session:
# get tenant oauth client params
tenant_oauth_client_params = (
session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
enabled=True,
tenant_oauth_client_params = session.scalar(
select(DatasourceOauthTenantParamConfig)
.where(
DatasourceOauthTenantParamConfig.tenant_id == tenant_id,
DatasourceOauthTenantParamConfig.provider == provider,
DatasourceOauthTenantParamConfig.plugin_id == plugin_id,
DatasourceOauthTenantParamConfig.enabled.is_(True),
)
.first()
.limit(1)
)
if tenant_oauth_client_params:
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
@ -443,8 +464,13 @@ class DatasourceProviderService:
is_verified = PluginService.is_plugin_verified(tenant_id, provider_controller.plugin_unique_identifier)
if is_verified:
# fallback to system oauth client params
oauth_client_params = (
session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
oauth_client_params = session.scalar(
select(DatasourceOauthParamConfig)
.where(
DatasourceOauthParamConfig.provider == provider,
DatasourceOauthParamConfig.plugin_id == plugin_id,
)
.limit(1)
)
if oauth_client_params:
return oauth_client_params.system_credentials
@ -455,15 +481,13 @@ class DatasourceProviderService:
def generate_next_datasource_provider_name(
session: Session, tenant_id: str, provider_id: DatasourceProviderID, credential_type: CredentialType
) -> str:
db_providers = (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
db_providers = session.scalars(
select(DatasourceProvider).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider_id.provider_name,
DatasourceProvider.plugin_id == provider_id.plugin_id,
)
.all()
)
).all()
return generate_incremental_name(
[provider.name for provider in db_providers],
f"{credential_type.get_name()}",
@ -485,8 +509,10 @@ class DatasourceProviderService:
with sessionmaker(bind=db.engine).begin() as session:
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.OAUTH2.value}"
with redis_client.lock(lock, timeout=20):
target_provider = (
session.query(DatasourceProvider).filter_by(id=credential_id, tenant_id=tenant_id).first()
target_provider = session.scalar(
select(DatasourceProvider)
.where(DatasourceProvider.id == credential_id, DatasourceProvider.tenant_id == tenant_id)
.limit(1)
)
if target_provider is None:
raise ValueError("provider not found")
@ -496,25 +522,28 @@ class DatasourceProviderService:
db_provider_name = target_provider.name
else:
name_conflict = (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
name=db_provider_name,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
auth_type=CredentialType.OAUTH2.value,
session.scalar(
select(func.count(DatasourceProvider.id)).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.name == db_provider_name,
DatasourceProvider.provider == provider_id.provider_name,
DatasourceProvider.plugin_id == provider_id.plugin_id,
DatasourceProvider.auth_type == CredentialType.OAUTH2.value,
)
)
.count()
or 0
)
if name_conflict > 0:
db_provider_name = generate_incremental_name(
[
provider.name
for provider in session.query(DatasourceProvider).filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
)
for provider in session.scalars(
select(DatasourceProvider).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider_id.provider_name,
DatasourceProvider.plugin_id == provider_id.plugin_id,
)
).all()
],
db_provider_name,
)
@ -556,25 +585,27 @@ class DatasourceProviderService:
)
else:
if (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
name=db_provider_name,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
auth_type=credential_type.value,
session.scalar(
select(func.count(DatasourceProvider.id)).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.name == db_provider_name,
DatasourceProvider.provider == provider_id.provider_name,
DatasourceProvider.plugin_id == provider_id.plugin_id,
DatasourceProvider.auth_type == credential_type.value,
)
)
.count()
> 0
):
or 0
) > 0:
db_provider_name = generate_incremental_name(
[
provider.name
for provider in session.query(DatasourceProvider).filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
)
for provider in session.scalars(
select(DatasourceProvider).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider_id.provider_name,
DatasourceProvider.plugin_id == provider_id.plugin_id,
)
).all()
],
db_provider_name,
)
@ -627,11 +658,16 @@ class DatasourceProviderService:
        # check whether the name already exists
if (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name, name=db_provider_name)
.count()
> 0
):
session.scalar(
select(func.count(DatasourceProvider.id)).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.plugin_id == plugin_id,
DatasourceProvider.provider == provider_name,
DatasourceProvider.name == db_provider_name,
)
)
or 0
) > 0:
raise ValueError("Authorization name is already exists")
try:
@ -918,21 +954,31 @@ class DatasourceProviderService:
"""
with sessionmaker(bind=db.engine).begin() as session:
datasource_provider = (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
.first()
datasource_provider = session.scalar(
select(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.id == auth_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
.limit(1)
)
if not datasource_provider:
raise ValueError("Datasource provider not found")
# update name
if name and name != datasource_provider.name:
if (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, name=name, provider=provider, plugin_id=plugin_id)
.count()
> 0
):
session.scalar(
select(func.count(DatasourceProvider.id)).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.name == name,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
)
or 0
) > 0:
raise ValueError("Authorization name is already exists")
datasource_provider.name = name

View File

@ -1,6 +1,6 @@
import json
from copy import deepcopy
from typing import Any, Union, cast
from typing import Any, cast
from urllib.parse import urlparse
import httpx
@ -148,18 +148,23 @@ class ExternalDatasetService:
db.session.commit()
@staticmethod
def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]:
def external_knowledge_api_use_check(external_knowledge_api_id: str, tenant_id: str) -> tuple[bool, int]:
"""
Return usage for an external knowledge API within a single tenant.
The caller already scopes access by tenant, so this query must do the
same; otherwise the endpoint becomes a cross-tenant UUID oracle.
"""
count = (
db.session.scalar(
select(func.count(ExternalKnowledgeBindings.id)).where(
ExternalKnowledgeBindings.external_knowledge_api_id == external_knowledge_api_id
ExternalKnowledgeBindings.external_knowledge_api_id == external_knowledge_api_id,
ExternalKnowledgeBindings.tenant_id == tenant_id,
)
)
or 0
)
if count > 0:
return True, count
return False, 0
return count > 0, count
@staticmethod
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
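The docstring's warning is worth making concrete: `tenant_id` must come from the caller's authenticated context, never from request input, or the added filter is moot. A hedged usage sketch of the use-check above — the variable names are illustrative, not from this diff:

# Illustrative caller: api_id comes from route parameters and
# current_tenant_id from the authenticated session in a real handler.
api_id = "00000000-0000-0000-0000-000000000000"  # placeholder value
current_tenant_id = "tenant-123"                 # placeholder value

in_use, binding_count = ExternalDatasetService.external_knowledge_api_use_check(
    external_knowledge_api_id=api_id,
    tenant_id=current_tenant_id,
)
if in_use:
    raise ValueError(f"External knowledge API is bound to {binding_count} dataset(s).")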
@ -190,9 +195,7 @@ class ExternalDatasetService:
raise ValueError(f"{parameter.get('name')} is required")
@staticmethod
def process_external_api(
settings: ExternalKnowledgeApiSetting, files: Union[None, dict[str, Any]]
) -> httpx.Response:
def process_external_api(settings: ExternalKnowledgeApiSetting, files: dict[str, Any] | None) -> httpx.Response:
"""
        Perform the HTTP request described by the API bundle settings.
"""
@ -314,7 +317,10 @@ class ExternalDatasetService:
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id)
.where(
ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id,
ExternalKnowledgeApis.tenant_id == tenant_id,
)
.limit(1)
)
if external_knowledge_api is None or external_knowledge_api.settings is None:

View File

@ -5,7 +5,7 @@ import uuid
from collections.abc import Iterator, Sequence
from contextlib import contextmanager, suppress
from tempfile import NamedTemporaryFile
from typing import Literal, Union
from typing import Literal
from zipfile import ZIP_DEFLATED, ZipFile
from graphon.file import helpers as file_helpers
@ -52,7 +52,7 @@ class FileService:
filename: str,
content: bytes,
mimetype: str,
user: Union[Account, EndUser],
user: Account | EndUser,
source: Literal["datasets"] | None = None,
source_url: str = "",
) -> UploadFile:
@ -132,8 +132,8 @@ class FileService:
return file_size <= file_size_limit
def get_file_base64(self, file_id: str) -> str:
upload_file = (
self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first()
upload_file = self._session_maker(expire_on_commit=False).scalar(
select(UploadFile).where(UploadFile.id == file_id).limit(1)
)
if not upload_file:
raise NotFound("File not found")
@ -178,7 +178,7 @@ class FileService:
Return a short text preview extracted from a document file.
"""
with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
if not upload_file:
raise NotFound("File not found")
@ -200,7 +200,7 @@ class FileService:
if not result:
raise NotFound("File not found or signature is invalid")
with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
if not upload_file:
raise NotFound("File not found or signature is invalid")
@ -220,7 +220,7 @@ class FileService:
raise NotFound("File not found or signature is invalid")
with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
if not upload_file:
raise NotFound("File not found or signature is invalid")
@ -231,7 +231,7 @@ class FileService:
def get_public_image_preview(self, file_id: str):
with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
if not upload_file:
raise NotFound("File not found or signature is invalid")
@ -247,7 +247,7 @@ class FileService:
def get_file_content(self, file_id: str) -> str:
with self._session_maker(expire_on_commit=False) as session:
upload_file: UploadFile | None = session.query(UploadFile).where(UploadFile.id == file_id).first()
upload_file: UploadFile | None = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
if not upload_file:
raise NotFound("File not found")
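The `UploadFile` lookup plus `NotFound` guard now repeats across half a dozen `FileService` methods. A hypothetical consolidation — the helper is not part of this diff, and assumes `_session_maker`, `UploadFile`, and `NotFound` behave as the hunks above show:

from sqlalchemy import select


def _get_upload_file_or_404(self, file_id: str) -> UploadFile:
    # Hypothetical helper, not in the diff: one place for the repeated lookup.
    with self._session_maker(expire_on_commit=False) as session:
        upload_file = session.scalar(
            select(UploadFile).where(UploadFile.id == file_id).limit(1)
        )
    if not upload_file:
        raise NotFound("File not found")
    return upload_file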

Some files were not shown because too many files have changed in this diff