Mirror of https://github.com/langgenius/dify.git, synced 2026-05-10 22:28:55 +08:00
Merge branch 'feat/create-app' into deploy/dev
This commit is contained in:
commit 35bbf702ed
100  .github/dependabot.yml  vendored
@@ -1,106 +1,6 @@
 version: 2
 updates:
-  - package-ecosystem: "pip"
-    directory: "/api"
-    open-pull-requests-limit: 10
-    schedule:
-      interval: "weekly"
-    groups:
-      flask:
-        patterns:
-          - "flask"
-          - "flask-*"
-          - "werkzeug"
-          - "gunicorn"
-      google:
-        patterns:
-          - "google-*"
-          - "googleapis-*"
-      opentelemetry:
-        patterns:
-          - "opentelemetry-*"
-      pydantic:
-        patterns:
-          - "pydantic"
-          - "pydantic-*"
-      llm:
-        patterns:
-          - "langfuse"
-          - "langsmith"
-          - "litellm"
-          - "mlflow*"
-          - "opik"
-          - "weave*"
-          - "arize*"
-          - "tiktoken"
-          - "transformers"
-      database:
-        patterns:
-          - "sqlalchemy"
-          - "psycopg2*"
-          - "psycogreen"
-          - "redis*"
-          - "alembic*"
-      storage:
-        patterns:
-          - "boto3*"
-          - "botocore*"
-          - "azure-*"
-          - "bce-*"
-          - "cos-python-*"
-          - "esdk-obs-*"
-          - "google-cloud-storage"
-          - "opendal"
-          - "oss2"
-          - "supabase*"
-          - "tos*"
-      vdb:
-        patterns:
-          - "alibabacloud*"
-          - "chromadb"
-          - "clickhouse-*"
-          - "clickzetta-*"
-          - "couchbase"
-          - "elasticsearch"
-          - "opensearch-py"
-          - "oracledb"
-          - "pgvect*"
-          - "pymilvus"
-          - "pymochow"
-          - "pyobvector"
-          - "qdrant-client"
-          - "intersystems-*"
-          - "tablestore"
-          - "tcvectordb"
-          - "tidb-vector"
-          - "upstash-*"
-          - "volcengine-*"
-          - "weaviate-*"
-          - "xinference-*"
-          - "mo-vector"
-          - "mysql-connector-*"
-      dev:
-        patterns:
-          - "coverage"
-          - "dotenv-linter"
-          - "faker"
-          - "lxml-stubs"
-          - "basedpyright"
-          - "ruff"
-          - "pytest*"
-          - "types-*"
-          - "boto3-stubs"
-          - "hypothesis"
-          - "pandas-stubs"
-          - "scipy-stubs"
-          - "import-linter"
-          - "celery-types"
-          - "mypy*"
-          - "pyrefly"
-      python-packages:
-        patterns:
-          - "*"
+  - package-ecosystem: "uv"
+    directory: "/api"
+    open-pull-requests-limit: 10
9  .github/pull_request_template.md  vendored
@@ -7,6 +7,7 @@
 ## Summary

 <!-- Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change. -->
+<!-- If this PR was created by an automated agent, add `From <Tool Name>` as the final line of the description. Example: `From Codex`. -->

 ## Screenshots

@@ -17,7 +18,7 @@
 ## Checklist

 - [ ] This change requires a documentation update, included: [Dify Document](https://github.com/langgenius/dify-docs)
-- [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
-- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
-- [x] I've updated the documentation accordingly.
-- [x] I ran `make lint` and `make type-check` (backend) and `cd web && pnpm exec vp staged` (frontend) to appease the lint gods
+- [ ] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
+- [ ] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
+- [ ] I've updated the documentation accordingly.
+- [ ] I ran `make lint && make type-check` (backend) and `cd web && pnpm exec vp staged` (frontend) to appease the lint gods
4  .github/workflows/api-tests.yml  vendored
@@ -54,7 +54,7 @@ jobs:
         run: uv run --project api bash dev/pytest/pytest_unit_tests.sh

       - name: Upload unit coverage data
-        uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
+        uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
         with:
           name: api-coverage-unit
           path: coverage-unit
@@ -129,7 +129,7 @@ jobs:
           api/tests/test_containers_integration_tests

       - name: Upload integration coverage data
-        uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
+        uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
         with:
           name: api-coverage-integration
           path: coverage-integration
4  .github/workflows/build-push.yml  vendored
@@ -81,7 +81,7 @@ jobs:

       - name: Build Docker image
         id: build
-        uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
+        uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
         with:
           context: ${{ matrix.build_context }}
           file: ${{ matrix.file }}
@@ -101,7 +101,7 @@ jobs:
           touch "/tmp/digests/${sanitized_digest}"

       - name: Upload digest
-        uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
+        uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
         with:
           name: digests-${{ matrix.artifact_context }}-${{ env.PLATFORM_PAIR }}
           path: /tmp/digests/*
2  .github/workflows/docker-build.yml  vendored
@@ -50,7 +50,7 @@ jobs:
         uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0

       - name: Build Docker Image
-        uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
+        uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
         with:
           push: false
           context: ${{ matrix.context }}
4  .github/workflows/pyrefly-diff-comment.yml  vendored
@@ -21,7 +21,7 @@ jobs:
     if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }}
     steps:
       - name: Download pyrefly diff artifact
-        uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
+        uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
         with:
           github-token: ${{ secrets.GITHUB_TOKEN }}
           script: |
@@ -49,7 +49,7 @@ jobs:
         run: unzip -o pyrefly_diff.zip

       - name: Post comment
-        uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
+        uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
         with:
           github-token: ${{ secrets.GITHUB_TOKEN }}
           script: |
4  .github/workflows/pyrefly-diff.yml  vendored
@@ -66,7 +66,7 @@ jobs:
           echo ${{ github.event.pull_request.number }} > pr_number.txt

       - name: Upload pyrefly diff
-        uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
+        uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
         with:
           name: pyrefly_diff
           path: |
@@ -75,7 +75,7 @@ jobs:

       - name: Comment PR with pyrefly diff
         if: ${{ github.event.pull_request.head.repo.full_name == github.repository && steps.line_count_check.outputs.same == 'false' }}
-        uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
+        uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
         with:
           github-token: ${{ secrets.GITHUB_TOKEN }}
           script: |
118  .github/workflows/pyrefly-type-coverage-comment.yml  vendored  Normal file
@@ -0,0 +1,118 @@
+name: Comment with Pyrefly Type Coverage
+
+on:
+  workflow_run:
+    workflows:
+      - Pyrefly Type Coverage
+    types:
+      - completed
+
+permissions: {}
+
+jobs:
+  comment:
+    name: Comment PR with type coverage
+    runs-on: ubuntu-latest
+    permissions:
+      actions: read
+      contents: read
+      issues: write
+      pull-requests: write
+    if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }}
+    steps:
+      - name: Checkout default branch (trusted code)
+        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+
+      - name: Setup Python & UV
+        uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
+        with:
+          enable-cache: true
+
+      - name: Install dependencies
+        run: uv sync --project api --dev
+
+      - name: Download type coverage artifact
+        uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
+        with:
+          github-token: ${{ secrets.GITHUB_TOKEN }}
+          script: |
+            const fs = require('fs');
+            const artifacts = await github.rest.actions.listWorkflowRunArtifacts({
+              owner: context.repo.owner,
+              repo: context.repo.repo,
+              run_id: ${{ github.event.workflow_run.id }},
+            });
+            const match = artifacts.data.artifacts.find((artifact) =>
+              artifact.name === 'pyrefly_type_coverage'
+            );
+            if (!match) {
+              throw new Error('pyrefly_type_coverage artifact not found');
+            }
+            const download = await github.rest.actions.downloadArtifact({
+              owner: context.repo.owner,
+              repo: context.repo.repo,
+              artifact_id: match.id,
+              archive_format: 'zip',
+            });
+            fs.writeFileSync('pyrefly_type_coverage.zip', Buffer.from(download.data));
+
+      - name: Unzip artifact
+        run: unzip -o pyrefly_type_coverage.zip
+
+      - name: Render coverage markdown from structured data
+        id: render
+        run: |
+          comment_body="$(uv run --directory api python api/libs/pyrefly_type_coverage.py \
+            --base base_report.json \
+            < pr_report.json)"
+
+          {
+            echo "### Pyrefly Type Coverage"
+            echo ""
+            echo "$comment_body"
+          } > /tmp/type_coverage_comment.md
+
+      - name: Post comment
+        uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
+        with:
+          github-token: ${{ secrets.GITHUB_TOKEN }}
+          script: |
+            const fs = require('fs');
+            const body = fs.readFileSync('/tmp/type_coverage_comment.md', { encoding: 'utf8' });
+            let prNumber = null;
+            try {
+              prNumber = parseInt(fs.readFileSync('pr_number.txt', { encoding: 'utf8' }), 10);
+            } catch (err) {
+              const prs = context.payload.workflow_run.pull_requests || [];
+              if (prs.length > 0 && prs[0].number) {
+                prNumber = prs[0].number;
+              }
+            }
+            if (!prNumber) {
+              throw new Error('PR number not found in artifact or workflow_run payload');
+            }
+
+            // Update existing comment if one exists, otherwise create new
+            const { data: comments } = await github.rest.issues.listComments({
+              issue_number: prNumber,
+              owner: context.repo.owner,
+              repo: context.repo.repo,
+            });
+            const marker = '### Pyrefly Type Coverage';
+            const existing = comments.find(c => c.body.startsWith(marker));
+
+            if (existing) {
+              await github.rest.issues.updateComment({
+                comment_id: existing.id,
+                owner: context.repo.owner,
+                repo: context.repo.repo,
+                body,
+              });
+            } else {
+              await github.rest.issues.createComment({
+                issue_number: prNumber,
+                owner: context.repo.owner,
+                repo: context.repo.repo,
+                body,
+              });
+            }
120  .github/workflows/pyrefly-type-coverage.yml  vendored  Normal file
@@ -0,0 +1,120 @@
+name: Pyrefly Type Coverage
+
+on:
+  pull_request:
+    paths:
+      - 'api/**/*.py'
+
+permissions:
+  contents: read
+
+jobs:
+  pyrefly-type-coverage:
+    runs-on: ubuntu-latest
+    permissions:
+      contents: read
+      issues: write
+      pull-requests: write
+    steps:
+      - name: Checkout PR branch
+        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+        with:
+          fetch-depth: 0
+
+      - name: Setup Python & UV
+        uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
+        with:
+          enable-cache: true
+
+      - name: Install dependencies
+        run: uv sync --project api --dev
+
+      - name: Run pyrefly report on PR branch
+        run: |
+          uv run --directory api --dev pyrefly report 2>/dev/null > /tmp/pyrefly_report_pr.tmp && \
+            mv /tmp/pyrefly_report_pr.tmp /tmp/pyrefly_report_pr.json || \
+            echo '{}' > /tmp/pyrefly_report_pr.json
+
+      - name: Save helper script from base branch
+        run: |
+          git show ${{ github.event.pull_request.base.sha }}:api/libs/pyrefly_type_coverage.py > /tmp/pyrefly_type_coverage.py 2>/dev/null \
+            || cp api/libs/pyrefly_type_coverage.py /tmp/pyrefly_type_coverage.py
+
+      - name: Checkout base branch
+        run: git checkout ${{ github.base_ref }}
+
+      - name: Run pyrefly report on base branch
+        run: |
+          uv run --directory api --dev pyrefly report 2>/dev/null > /tmp/pyrefly_report_base.tmp && \
+            mv /tmp/pyrefly_report_base.tmp /tmp/pyrefly_report_base.json || \
+            echo '{}' > /tmp/pyrefly_report_base.json
+
+      - name: Generate coverage comparison
+        id: coverage
+        run: |
+          comment_body="$(uv run --directory api python /tmp/pyrefly_type_coverage.py \
+            --base /tmp/pyrefly_report_base.json \
+            < /tmp/pyrefly_report_pr.json)"
+
+          {
+            echo "### Pyrefly Type Coverage"
+            echo ""
+            echo "$comment_body"
+          } | tee -a "$GITHUB_STEP_SUMMARY" > /tmp/type_coverage_comment.md
+
+          # Save structured data for the fork-PR comment workflow
+          cp /tmp/pyrefly_report_pr.json pr_report.json
+          cp /tmp/pyrefly_report_base.json base_report.json
+
+      - name: Save PR number
+        run: |
+          echo ${{ github.event.pull_request.number }} > pr_number.txt
+
+      - name: Upload type coverage artifact
+        uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
+        with:
+          name: pyrefly_type_coverage
+          path: |
+            pr_report.json
+            base_report.json
+            pr_number.txt
+
+      - name: Comment PR with type coverage
+        if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
+        uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
+        with:
+          github-token: ${{ secrets.GITHUB_TOKEN }}
+          script: |
+            const fs = require('fs');
+            const marker = '### Pyrefly Type Coverage';
+            let body;
+            try {
+              body = fs.readFileSync('/tmp/type_coverage_comment.md', { encoding: 'utf8' });
+            } catch {
+              body = `${marker}\n\n_Coverage report unavailable._`;
+            }
+            const prNumber = context.payload.pull_request.number;
+
+            // Update existing comment if one exists, otherwise create new
+            const { data: comments } = await github.rest.issues.listComments({
+              issue_number: prNumber,
+              owner: context.repo.owner,
+              repo: context.repo.repo,
+            });
+            const existing = comments.find(c => c.body.startsWith(marker));
+
+            if (existing) {
+              await github.rest.issues.updateComment({
+                comment_id: existing.id,
+                owner: context.repo.owner,
+                repo: context.repo.repo,
+                body,
+              });
+            } else {
+              await github.rest.issues.createComment({
+                issue_number: prNumber,
+                owner: context.repo.owner,
+                repo: context.repo.repo,
+                body,
+              });
+            }
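The comparison helper `api/libs/pyrefly_type_coverage.py` that both workflows invoke is not part of this commit. As a rough, hedged sketch of a script with the same command-line shape (base-branch report via `--base`, PR-branch report on stdin, markdown on stdout): the JSON field names below are assumptions for illustration, not pyrefly's real report schema.

# Hypothetical sketch only; the real api/libs/pyrefly_type_coverage.py is not
# shown in this diff, and the report keys ("total_attributes",
# "typed_attributes") are assumed rather than taken from pyrefly.
import argparse
import json
import sys


def coverage(report: dict) -> float:
    # Percentage of attributes that carry a known type in the report.
    total = report.get("total_attributes", 0)
    typed = report.get("typed_attributes", 0)
    return 100.0 * typed / total if total else 0.0


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--base", required=True, help="Path to the base-branch report")
    args = parser.parse_args()

    pr_report = json.load(sys.stdin)  # PR-branch report arrives on stdin
    with open(args.base, encoding="utf-8") as f:
        base_report = json.load(f)

    pr_cov, base_cov = coverage(pr_report), coverage(base_report)
    delta = pr_cov - base_cov
    print("| branch | coverage |")
    print("|---|---|")
    print(f"| base | {base_cov:.2f}% |")
    print(f"| PR | {pr_cov:.2f}% ({delta:+.2f}%) |")


if __name__ == "__main__":
    main()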
6  .github/workflows/stale.yml  vendored
@@ -23,8 +23,8 @@ jobs:
          days-before-issue-stale: 15
          days-before-issue-close: 3
          repo-token: ${{ secrets.GITHUB_TOKEN }}
-         stale-issue-message: "Close due to it's no longer active, if you have any questions, you can reopen it."
-         stale-pr-message: "Close due to it's no longer active, if you have any questions, you can reopen it."
+         stale-issue-message: "Closed due to inactivity. If you have any questions, you can reopen it."
+         stale-pr-message: "Closed due to inactivity. If you have any questions, you can reopen it."
          stale-issue-label: 'no-issue-activity'
          stale-pr-label: 'no-pr-activity'
-         any-of-labels: 'duplicate,question,invalid,wontfix,no-issue-activity,no-pr-activity,enhancement,cant-reproduce,help-wanted'
+         any-of-labels: '🌚 invalid,🙋♂️ question,wont-fix,no-issue-activity,no-pr-activity,💪 enhancement,🤔 cant-reproduce,🙏 help wanted'
2  .github/workflows/translate-i18n-claude.yml  vendored
@@ -158,7 +158,7 @@ jobs:

       - name: Run Claude Code for Translation Sync
         if: steps.context.outputs.CHANGED_FILES != ''
-        uses: anthropics/claude-code-action@6e2bd52842c65e914eba5c8badd17560bd26b5de # v1.0.89
+        uses: anthropics/claude-code-action@b47fd721da662d48c5680e154ad16a73ed74d2e0 # v1.0.93
         with:
           anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
           github_token: ${{ secrets.GITHUB_TOKEN }}
2  .github/workflows/trigger-i18n-sync.yml  vendored
@@ -56,7 +56,7 @@ jobs:

       - name: Trigger i18n sync workflow
         if: steps.detect.outputs.has_changes == 'true'
-        uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
+        uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
         env:
           BASE_SHA: ${{ steps.detect.outputs.base_sha }}
           HEAD_SHA: ${{ steps.detect.outputs.head_sha }}
4  .github/workflows/web-e2e.yml  vendored
@@ -53,7 +53,7 @@ jobs:

       - name: Upload Cucumber report
         if: ${{ !cancelled() }}
-        uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
+        uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
         with:
           name: cucumber-report
           path: e2e/cucumber-report
@@ -61,7 +61,7 @@ jobs:

       - name: Upload E2E logs
         if: ${{ !cancelled() }}
-        uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
+        uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
         with:
           name: e2e-logs
           path: e2e/.logs
2  .github/workflows/web-tests.yml  vendored
@@ -43,7 +43,7 @@ jobs:

       - name: Upload blob report
         if: ${{ !cancelled() }}
-        uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
+        uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
         with:
           name: blob-report-${{ matrix.shardIndex }}
           path: web/.vitest-reports/*
@@ -97,8 +97,3 @@ Feel free to reach out if you encounter any issues during the setup process.
 ## Getting Help

 If you ever get stuck or get a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat.
-
-## Automated Agent Contributions
-
-> [!NOTE]
-> If you are an automated agent, just add 🤖🤖🤖 to the end of the PR title to opt-in.
@@ -2,7 +2,6 @@ import base64
 import secrets

 import click
-from sqlalchemy.orm import sessionmaker

 from constants.languages import languages
 from extensions.ext_database import db
@@ -25,30 +24,31 @@ def reset_password(email, new_password, password_confirm):
         return
     normalized_email = email.strip().lower()

-    with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
-        account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
+    account = AccountService.get_account_by_email_with_case_fallback(email.strip())

-        if not account:
-            click.echo(click.style(f"Account not found for email: {email}", fg="red"))
-            return
+    if not account:
+        click.echo(click.style(f"Account not found for email: {email}", fg="red"))
+        return

-        try:
-            valid_password(new_password)
-        except:
-            click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
-            return
+    try:
+        valid_password(new_password)
+    except:
+        click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
+        return

-        # generate password salt
-        salt = secrets.token_bytes(16)
-        base64_salt = base64.b64encode(salt).decode()
+    # generate password salt
+    salt = secrets.token_bytes(16)
+    base64_salt = base64.b64encode(salt).decode()

-        # encrypt password with salt
-        password_hashed = hash_password(new_password, salt)
-        base64_password_hashed = base64.b64encode(password_hashed).decode()
-        account.password = base64_password_hashed
-        account.password_salt = base64_salt
-        AccountService.reset_login_error_rate_limit(normalized_email)
-        click.echo(click.style("Password reset successfully.", fg="green"))
+    # encrypt password with salt
+    password_hashed = hash_password(new_password, salt)
+    base64_password_hashed = base64.b64encode(password_hashed).decode()
+    account = db.session.merge(account)
+    account.password = base64_password_hashed
+    account.password_salt = base64_salt
+    db.session.commit()
+    AccountService.reset_login_error_rate_limit(normalized_email)
+    click.echo(click.style("Password reset successfully.", fg="green"))


 @click.command("reset-email", help="Reset the account email.")
@@ -65,21 +65,22 @@ def reset_email(email, new_email, email_confirm):
         return
     normalized_new_email = new_email.strip().lower()

-    with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
-        account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
+    account = AccountService.get_account_by_email_with_case_fallback(email.strip())

-        if not account:
-            click.echo(click.style(f"Account not found for email: {email}", fg="red"))
-            return
+    if not account:
+        click.echo(click.style(f"Account not found for email: {email}", fg="red"))
+        return

-        try:
-            email_validate(normalized_new_email)
-        except:
-            click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
-            return
+    try:
+        email_validate(normalized_new_email)
+    except:
+        click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
+        return

-        account.email = normalized_new_email
-        click.echo(click.style("Email updated successfully.", fg="green"))
+    account = db.session.merge(account)
+    account.email = normalized_new_email
+    db.session.commit()
+    click.echo(click.style("Email updated successfully.", fg="green"))


 @click.command("create-tenant", help="Create account and tenant.")
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Literal
+from typing import Any, Literal, TypedDict
 from urllib.parse import parse_qsl, quote_plus

 from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
@@ -107,6 +107,17 @@ class KeywordStoreConfig(BaseSettings):
     )


+class SQLAlchemyEngineOptionsDict(TypedDict):
+    pool_size: int
+    max_overflow: int
+    pool_recycle: int
+    pool_pre_ping: bool
+    connect_args: dict[str, str]
+    pool_use_lifo: bool
+    pool_reset_on_return: None
+    pool_timeout: int
+
+
 class DatabaseConfig(BaseSettings):
     # Database type selector
     DB_TYPE: Literal["postgresql", "mysql", "oceanbase", "seekdb"] = Field(
@@ -209,11 +220,11 @@ class DatabaseConfig(BaseSettings):

     @computed_field  # type: ignore[prop-decorator]
     @property
-    def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
+    def SQLALCHEMY_ENGINE_OPTIONS(self) -> SQLAlchemyEngineOptionsDict:
         # Parse DB_EXTRAS for 'options'
         db_extras_dict = dict(parse_qsl(self.DB_EXTRAS))
         options = db_extras_dict.get("options", "")
-        connect_args = {}
+        connect_args: dict[str, str] = {}
         # Use the dynamic SQLALCHEMY_DATABASE_URI_SCHEME property
         if self.SQLALCHEMY_DATABASE_URI_SCHEME.startswith("postgresql"):
             timezone_opt = "-c timezone=UTC"
@@ -223,7 +234,7 @@ class DatabaseConfig(BaseSettings):
             merged_options = timezone_opt
             connect_args = {"options": merged_options}

-        return {
+        result: SQLAlchemyEngineOptionsDict = {
             "pool_size": self.SQLALCHEMY_POOL_SIZE,
             "max_overflow": self.SQLALCHEMY_MAX_OVERFLOW,
             "pool_recycle": self.SQLALCHEMY_POOL_RECYCLE,
@@ -233,6 +244,7 @@ class DatabaseConfig(BaseSettings):
             "pool_reset_on_return": None,
             "pool_timeout": self.SQLALCHEMY_POOL_TIMEOUT,
         }
+        return result


 class CeleryConfig(DatabaseConfig):
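For a minimal illustration of what this hunk buys (standalone sketch, not Dify code): returning a TypedDict instead of dict[str, Any] lets a checker such as pyrefly verify keys and value types at every use site.

# Minimal standalone sketch of the dict[str, Any] -> TypedDict change above.
from typing import TypedDict


class EngineOptions(TypedDict):
    pool_size: int
    pool_pre_ping: bool


def engine_options() -> EngineOptions:
    # The checker verifies that every key exists and has the right value type.
    return {"pool_size": 30, "pool_pre_ping": True}


opts = engine_options()
pool_size: int = opts["pool_size"]  # OK: the checker knows this is an int
# opts["pool_sze"]  # a typo like this would be flagged statically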
@@ -25,7 +25,13 @@ from fields.annotation_fields import (
 )
 from libs.helper import uuid_value
 from libs.login import login_required
-from services.annotation_service import AppAnnotationService
+from services.annotation_service import (
+    AppAnnotationService,
+    EnableAnnotationArgs,
+    UpdateAnnotationArgs,
+    UpdateAnnotationSettingArgs,
+    UpsertAnnotationArgs,
+)

 DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"

@@ -120,7 +126,12 @@ class AnnotationReplyActionApi(Resource):
         args = AnnotationReplyPayload.model_validate(console_ns.payload)
         match action:
             case "enable":
-                result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
+                enable_args: EnableAnnotationArgs = {
+                    "score_threshold": args.score_threshold,
+                    "embedding_provider_name": args.embedding_provider_name,
+                    "embedding_model_name": args.embedding_model_name,
+                }
+                result = AppAnnotationService.enable_app_annotation(enable_args, app_id)
             case "disable":
                 result = AppAnnotationService.disable_app_annotation(app_id)
         return result, 200
@@ -161,7 +172,8 @@ class AppAnnotationSettingUpdateApi(Resource):

         args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload)

-        result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args.model_dump())
+        setting_args: UpdateAnnotationSettingArgs = {"score_threshold": args.score_threshold}
+        result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, setting_args)
         return result, 200


@@ -237,8 +249,16 @@ class AnnotationApi(Resource):
     def post(self, app_id):
         app_id = str(app_id)
         args = CreateAnnotationPayload.model_validate(console_ns.payload)
-        data = args.model_dump(exclude_none=True)
-        annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
+        upsert_args: UpsertAnnotationArgs = {}
+        if args.answer is not None:
+            upsert_args["answer"] = args.answer
+        if args.content is not None:
+            upsert_args["content"] = args.content
+        if args.message_id is not None:
+            upsert_args["message_id"] = args.message_id
+        if args.question is not None:
+            upsert_args["question"] = args.question
+        annotation = AppAnnotationService.up_insert_app_annotation_from_message(upsert_args, app_id)
         return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")

     @setup_required
@@ -315,9 +335,12 @@ class AnnotationUpdateDeleteApi(Resource):
         app_id = str(app_id)
         annotation_id = str(annotation_id)
         args = UpdateAnnotationPayload.model_validate(console_ns.payload)
-        annotation = AppAnnotationService.update_app_annotation_directly(
-            args.model_dump(exclude_none=True), app_id, annotation_id
-        )
+        update_args: UpdateAnnotationArgs = {}
+        if args.answer is not None:
+            update_args["answer"] = args.answer
+        if args.question is not None:
+            update_args["question"] = args.question
+        annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_id, annotation_id)
         return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")

     @setup_required
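The typed argument definitions these call sites import live in services/annotation_service.py, which this commit does not show. A plausible shape, inferred only from how the call sites populate them (every field name and type below is an inference, not the real definition):

# Inferred sketch only; the real TypedDicts are defined in
# services/annotation_service.py, which is not part of this diff.
from typing import TypedDict


class EnableAnnotationArgs(TypedDict):
    score_threshold: float
    embedding_provider_name: str | None
    embedding_model_name: str | None


class UpsertAnnotationArgs(TypedDict, total=False):
    # total=False matches the call site above, which only sets keys
    # whose payload values are not None.
    question: str
    answer: str
    content: str
    message_id: str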
@@ -1,9 +1,11 @@
 import json
-from typing import cast
+from typing import Any, cast

 from flask import request
-from flask_restx import Resource, fields
+from flask_restx import Resource
+from pydantic import BaseModel, Field

+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.app.wraps import get_app_model
 from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
@@ -18,30 +20,30 @@ from models.model import AppMode, AppModelConfig
 from services.app_model_config_service import AppModelConfigService


+class ModelConfigRequest(BaseModel):
+    provider: str | None = Field(default=None, description="Model provider")
+    model: str | None = Field(default=None, description="Model name")
+    configs: dict[str, Any] | None = Field(default=None, description="Model configuration parameters")
+    opening_statement: str | None = Field(default=None, description="Opening statement")
+    suggested_questions: list[str] | None = Field(default=None, description="Suggested questions")
+    more_like_this: dict[str, Any] | None = Field(default=None, description="More like this configuration")
+    speech_to_text: dict[str, Any] | None = Field(default=None, description="Speech to text configuration")
+    text_to_speech: dict[str, Any] | None = Field(default=None, description="Text to speech configuration")
+    retrieval_model: dict[str, Any] | None = Field(default=None, description="Retrieval model configuration")
+    tools: list[dict[str, Any]] | None = Field(default=None, description="Available tools")
+    dataset_configs: dict[str, Any] | None = Field(default=None, description="Dataset configurations")
+    agent_mode: dict[str, Any] | None = Field(default=None, description="Agent mode configuration")
+
+
+register_schema_models(console_ns, ModelConfigRequest)
+
+
 @console_ns.route("/apps/<uuid:app_id>/model-config")
 class ModelConfigResource(Resource):
     @console_ns.doc("update_app_model_config")
     @console_ns.doc(description="Update application model configuration")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.expect(
-        console_ns.model(
-            "ModelConfigRequest",
-            {
-                "provider": fields.String(description="Model provider"),
-                "model": fields.String(description="Model name"),
-                "configs": fields.Raw(description="Model configuration parameters"),
-                "opening_statement": fields.String(description="Opening statement"),
-                "suggested_questions": fields.List(fields.String(), description="Suggested questions"),
-                "more_like_this": fields.Raw(description="More like this configuration"),
-                "speech_to_text": fields.Raw(description="Speech to text configuration"),
-                "text_to_speech": fields.Raw(description="Text to speech configuration"),
-                "retrieval_model": fields.Raw(description="Retrieval model configuration"),
-                "tools": fields.List(fields.Raw(), description="Available tools"),
-                "dataset_configs": fields.Raw(description="Dataset configurations"),
-                "agent_mode": fields.Raw(description="Agent mode configuration"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[ModelConfigRequest.__name__])
     @console_ns.response(200, "Model configuration updated successfully")
     @console_ns.response(400, "Invalid configuration")
     @console_ns.response(404, "App not found")
@@ -1,7 +1,7 @@
 import logging
 from collections.abc import Callable
 from functools import wraps
-from typing import Any
+from typing import Any, TypedDict

 from flask import Response, request
 from flask_restx import Resource, fields, marshal, marshal_with
@@ -86,7 +86,14 @@ def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
     return value_type.exposed_type().value


-def _serialize_full_content(variable: WorkflowDraftVariable) -> dict | None:
+class FullContentDict(TypedDict):
+    size_bytes: int | None
+    value_type: str
+    length: int | None
+    download_url: str
+
+
+def _serialize_full_content(variable: WorkflowDraftVariable) -> FullContentDict | None:
     """Serialize full_content information for large variables."""
     if not variable.is_truncated():
         return None
@@ -94,12 +101,13 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> dict | None:
     variable_file = variable.variable_file
     assert variable_file is not None

-    return {
+    result: FullContentDict = {
         "size_bytes": variable_file.size,
         "value_type": variable_file.value_type.exposed_type().value,
         "length": variable_file.length,
         "download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True),
     }
+    return result


 def _ensure_variable_access(
@@ -1,7 +1,6 @@
 from flask import request
 from flask_restx import Resource
 from pydantic import BaseModel, Field, field_validator
-from sqlalchemy.orm import sessionmaker

 from configs import dify_config
 from constants.languages import languages
@@ -14,7 +13,6 @@ from controllers.console.auth.error import (
     InvalidTokenError,
     PasswordMismatchError,
 )
-from extensions.ext_database import db
 from libs.helper import EmailStr, extract_remote_ip
 from libs.password import valid_password
 from models import Account
@@ -73,8 +71,7 @@ class EmailRegisterSendEmailApi(Resource):
         if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
             raise AccountInFreezeError()

-        with sessionmaker(db.engine).begin() as session:
-            account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
+        account = AccountService.get_account_by_email_with_case_fallback(args.email)
         token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
         return {"result": "success", "data": token}

@@ -145,17 +142,16 @@ class EmailRegisterResetApi(Resource):
         email = register_data.get("email", "")
         normalized_email = email.lower()

-        with sessionmaker(db.engine).begin() as session:
-            account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
+        account = AccountService.get_account_by_email_with_case_fallback(email)

-            if account:
-                raise EmailAlreadyInUseError()
-            else:
-                account = self._create_new_account(normalized_email, args.password_confirm)
-                if not account:
-                    raise AccountNotFoundError()
-            token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
-            AccountService.reset_login_error_rate_limit(normalized_email)
+        if account:
+            raise EmailAlreadyInUseError()
+        else:
+            account = self._create_new_account(normalized_email, args.password_confirm)
+            if not account:
+                raise AccountNotFoundError()
+        token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
+        AccountService.reset_login_error_rate_limit(normalized_email)

         return {"result": "success", "data": token_pair.model_dump()}
@@ -4,7 +4,6 @@ import secrets

 from flask import request
 from flask_restx import Resource
 from pydantic import BaseModel, Field
-from sqlalchemy.orm import sessionmaker

 from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
@@ -85,8 +84,7 @@ class ForgotPasswordSendEmailApi(Resource):
         else:
             language = "en-US"

-        with sessionmaker(db.engine).begin() as session:
-            account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
+        account = AccountService.get_account_by_email_with_case_fallback(args.email)

         token = AccountService.send_reset_password_email(
             account=account,
@@ -184,17 +182,18 @@ class ForgotPasswordResetApi(Resource):
         password_hashed = hash_password(args.new_password, salt)

         email = reset_data.get("email", "")
-        with sessionmaker(db.engine).begin() as session:
-            account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
+        account = AccountService.get_account_by_email_with_case_fallback(email)

-            if account:
-                self._update_existing_account(account, password_hashed, salt, session)
-            else:
-                raise AccountNotFound()
+        if account:
+            account = db.session.merge(account)
+            self._update_existing_account(account, password_hashed, salt)
+            db.session.commit()
+        else:
+            raise AccountNotFound()

         return {"result": "success"}

-    def _update_existing_account(self, account, password_hashed, salt, session):
+    def _update_existing_account(self, account, password_hashed, salt):
         # Update existing account credentials
         account.password = base64.b64encode(password_hashed).decode()
         account.password_salt = base64.b64encode(salt).decode()
@@ -4,7 +4,6 @@ import urllib.parse
 import httpx
 from flask import current_app, redirect, request
 from flask_restx import Resource
-from sqlalchemy.orm import sessionmaker
 from werkzeug.exceptions import Unauthorized

 from configs import dify_config
@@ -180,8 +179,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
     account: Account | None = Account.get_by_openid(provider, user_info.id)

     if not account:
-        with sessionmaker(db.engine).begin() as session:
-            account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)
+        account = AccountService.get_account_by_email_with_case_fallback(user_info.email)

     return account
@@ -4,10 +4,11 @@ from datetime import UTC, datetime, timedelta
 from typing import Literal

 from flask import request
-from flask_restx import Resource, fields
+from flask_restx import Resource
 from pydantic import BaseModel, Field
 from werkzeug.exceptions import BadRequest

+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
 from enums.cloud_plan import CloudPlan
@@ -15,8 +16,6 @@ from extensions.ext_redis import redis_client
 from libs.login import current_account_with_tenant, login_required
 from services.billing_service import BillingService

-DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
-

 class SubscriptionQuery(BaseModel):
     plan: Literal[CloudPlan.PROFESSIONAL, CloudPlan.TEAM] = Field(..., description="Subscription plan")
@@ -27,8 +26,7 @@ class PartnerTenantsPayload(BaseModel):
     click_id: str = Field(..., description="Click Id from partner referral link")


-for model in (SubscriptionQuery, PartnerTenantsPayload):
-    console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+register_schema_models(console_ns, SubscriptionQuery, PartnerTenantsPayload)


 @console_ns.route("/billing/subscription")
@@ -61,12 +59,7 @@ class PartnerTenants(Resource):
     @console_ns.doc("sync_partner_tenants_bindings")
     @console_ns.doc(description="Sync partner tenants bindings")
     @console_ns.doc(params={"partner_key": "Partner key"})
-    @console_ns.expect(
-        console_ns.model(
-            "SyncPartnerTenantsBindingsRequest",
-            {"click_id": fields.String(required=True, description="Click Id from partner referral link")},
-        )
-    )
+    @console_ns.expect(console_ns.models[PartnerTenantsPayload.__name__])
     @console_ns.response(200, "Tenants synced to partner successfully")
     @console_ns.response(400, "Invalid partner information")
     @setup_required
@@ -162,7 +162,9 @@ class DataSourceApi(Resource):
         binding_id = str(binding_id)
         with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
             data_source_binding = session.execute(
-                select(DataSourceOauthBinding).filter_by(id=binding_id, tenant_id=current_tenant_id)
+                select(DataSourceOauthBinding).where(
+                    DataSourceOauthBinding.id == binding_id, DataSourceOauthBinding.tenant_id == current_tenant_id
+                )
             ).scalar_one_or_none()
             if data_source_binding is None:
                 raise NotFound("Data source binding not found.")
@@ -222,11 +224,11 @@ class DataSourceNotionListApi(Resource):
                 raise ValueError("Dataset is not notion type.")

             documents = session.scalars(
-                select(Document).filter_by(
-                    dataset_id=query.dataset_id,
-                    tenant_id=current_tenant_id,
-                    data_source_type="notion_import",
-                    enabled=True,
+                select(Document).where(
+                    Document.dataset_id == query.dataset_id,
+                    Document.tenant_id == current_tenant_id,
+                    Document.data_source_type == "notion_import",
+                    Document.enabled.is_(True),
                 )
             ).all()
             if documents:
@@ -280,7 +280,7 @@ class DatasetDocumentListApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))

-        query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_tenant_id)
+        query = select(Document).where(Document.dataset_id == str(dataset_id), Document.tenant_id == current_tenant_id)

         if status:
             query = DocumentService.apply_display_status_filter(query, status)
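Several hunks in this commit make the same filter_by to where swap. A self-contained sketch of the difference (toy model and in-memory SQLite, names are illustrative, not Dify's): where() takes explicit column expressions, so type checkers and IDEs can verify them, while filter_by() matches keyword names only at runtime.

# Self-contained sketch (toy model, not Dify's) of the filter_by -> where change.
from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Document(Base):
    __tablename__ = "documents"
    id: Mapped[int] = mapped_column(primary_key=True)
    dataset_id: Mapped[str] = mapped_column(String(36))
    tenant_id: Mapped[str] = mapped_column(String(36))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # filter_by matches keyword names at runtime; a typo only fails when executed.
    legacy = select(Document).filter_by(dataset_id="d1", tenant_id="t1")
    # where uses real column attributes, so tools can check them statically.
    modern = select(Document).where(Document.dataset_id == "d1", Document.tenant_id == "t1")
    assert session.scalars(modern).all() == session.scalars(legacy).all() == []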
@@ -227,10 +227,11 @@ class ExternalApiUseCheckApi(Resource):
     @login_required
     @account_initialization_required
     def get(self, external_knowledge_api_id):
+        _, current_tenant_id = current_account_with_tenant()
         external_knowledge_api_id = str(external_knowledge_api_id)

         external_knowledge_api_is_using, count = ExternalDatasetService.external_knowledge_api_use_check(
-            external_knowledge_api_id
+            external_knowledge_api_id, current_tenant_id
         )
         return {"is_using": external_knowledge_api_is_using, "count": count}, 200
@@ -346,89 +346,6 @@ class PublishedRagPipelineRunApi(Resource):
             raise InvokeRateLimitHttpError(ex.description)


-# class RagPipelinePublishedDatasourceNodeRunStatusApi(Resource):
-#     @setup_required
-#     @login_required
-#     @account_initialization_required
-#     @get_rag_pipeline
-#     def post(self, pipeline: Pipeline, node_id: str):
-#         """
-#         Run rag pipeline datasource
-#         """
-#         # The role of the current user in the ta table must be admin, owner, or editor
-#         if not current_user.has_edit_permission:
-#             raise Forbidden()
-#
-#         if not isinstance(current_user, Account):
-#             raise Forbidden()
-#
-#         parser = (reqparse.RequestParser()
-#             .add_argument("job_id", type=str, required=True, nullable=False, location="json")
-#             .add_argument("datasource_type", type=str, required=True, location="json")
-#         )
-#         args = parser.parse_args()
-#
-#         job_id = args.get("job_id")
-#         if job_id == None:
-#             raise ValueError("missing job_id")
-#         datasource_type = args.get("datasource_type")
-#         if datasource_type == None:
-#             raise ValueError("missing datasource_type")
-#
-#         rag_pipeline_service = RagPipelineService()
-#         result = rag_pipeline_service.run_datasource_workflow_node_status(
-#             pipeline=pipeline,
-#             node_id=node_id,
-#             job_id=job_id,
-#             account=current_user,
-#             datasource_type=datasource_type,
-#             is_published=True
-#         )
-#
-#         return result
-
-
-# class RagPipelineDraftDatasourceNodeRunStatusApi(Resource):
-#     @setup_required
-#     @login_required
-#     @account_initialization_required
-#     @get_rag_pipeline
-#     def post(self, pipeline: Pipeline, node_id: str):
-#         """
-#         Run rag pipeline datasource
-#         """
-#         # The role of the current user in the ta table must be admin, owner, or editor
-#         if not current_user.has_edit_permission:
-#             raise Forbidden()
-#
-#         if not isinstance(current_user, Account):
-#             raise Forbidden()
-#
-#         parser = (reqparse.RequestParser()
-#             .add_argument("job_id", type=str, required=True, nullable=False, location="json")
-#             .add_argument("datasource_type", type=str, required=True, location="json")
-#         )
-#         args = parser.parse_args()
-#
-#         job_id = args.get("job_id")
-#         if job_id == None:
-#             raise ValueError("missing job_id")
-#         datasource_type = args.get("datasource_type")
-#         if datasource_type == None:
-#             raise ValueError("missing datasource_type")
-#
-#         rag_pipeline_service = RagPipelineService()
-#         result = rag_pipeline_service.run_datasource_workflow_node_status(
-#             pipeline=pipeline,
-#             node_id=node_id,
-#             job_id=job_id,
-#             account=current_user,
-#             datasource_type=datasource_type,
-#             is_published=False
-#         )
-#
-#         return result
-#
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
 class RagPipelinePublishedDatasourceNodeRunApi(Resource):
     @console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__])
@@ -7,7 +7,8 @@ import logging
 from collections.abc import Generator

 from flask import Response, jsonify, request
-from flask_restx import Resource, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel
 from sqlalchemy import select
 from sqlalchemy.orm import Session, sessionmaker

@@ -33,6 +34,11 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
 logger = logging.getLogger(__name__)


+class HumanInputFormSubmitPayload(BaseModel):
+    inputs: dict
+    action: str
+
+
 def _jsonify_form_definition(form: Form) -> Response:
     payload = form.get_definition().model_dump()
     payload["expiration_time"] = int(form.expiration_time.timestamp())
@@ -84,10 +90,7 @@ class ConsoleHumanInputFormApi(Resource):
             "action": "Approve"
         }
         """
-        parser = reqparse.RequestParser()
-        parser.add_argument("inputs", type=dict, required=True, location="json")
-        parser.add_argument("action", type=str, required=True, location="json")
-        args = parser.parse_args()
+        payload = HumanInputFormSubmitPayload.model_validate(request.get_json())
         current_user, _ = current_account_with_tenant()

         service = HumanInputService(db.engine)
@@ -107,8 +110,8 @@ class ConsoleHumanInputFormApi(Resource):
         service.submit_form_by_token(
             recipient_type=recipient_type,
             form_token=form_token,
-            selected_action_id=args["action"],
-            form_data=args["inputs"],
+            selected_action_id=payload.action,
+            form_data=payload.inputs,
             submission_user_id=current_user.id,
         )
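This hunk and the matching web-controller hunk below replace flask_restx's reqparse with a small pydantic model. A standalone sketch of the validation behavior the swap relies on (not Dify code): model_validate() enforces the same required fields and raises a structured error when one is missing.

# Standalone sketch of the reqparse -> pydantic swap made in these hunks.
from pydantic import BaseModel, ValidationError


class HumanInputFormSubmitPayload(BaseModel):
    inputs: dict
    action: str


payload = HumanInputFormSubmitPayload.model_validate({"inputs": {"q": "hi"}, "action": "Approve"})
assert payload.action == "Approve"

try:
    HumanInputFormSubmitPayload.model_validate({"inputs": {}})  # "action" is missing
except ValidationError as e:
    print(e.errors()[0]["type"])  # prints "missing"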
@@ -8,7 +8,6 @@ from flask import request
 from flask_restx import Resource, fields, marshal_with
 from pydantic import BaseModel, Field, field_validator, model_validator
 from sqlalchemy import select
-from sqlalchemy.orm import sessionmaker

 from configs import dify_config
 from constants.languages import supported_language
@@ -562,8 +561,7 @@ class ChangeEmailSendEmailApi(Resource):

             user_email = current_user.email
         else:
-            with sessionmaker(db.engine).begin() as session:
-                account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
+            account = AccountService.get_account_by_email_with_case_fallback(args.email)
             if account is None:
                 raise AccountNotFound()
             email_for_sending = account.email
@@ -94,10 +94,9 @@ def get_user_tenant[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:


 def plugin_data[**P, R](
-    view: Callable[P, R] | None = None,
-    *,
     payload_type: type[BaseModel],
-) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
+) -> Callable[[Callable[P, R]], Callable[P, R]]:
     def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
         @wraps(view_func)
         def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
@@ -116,7 +115,4 @@

         return decorated_view

-    if view is None:
-        return decorator
-    else:
-        return decorator(view)
+    return decorator
@@ -12,7 +12,12 @@ from controllers.service_api.wraps import validate_app_token
 from extensions.ext_redis import redis_client
 from fields.annotation_fields import Annotation, AnnotationList
 from models.model import App
-from services.annotation_service import AppAnnotationService
+from services.annotation_service import (
+    AppAnnotationService,
+    EnableAnnotationArgs,
+    InsertAnnotationArgs,
+    UpdateAnnotationArgs,
+)


 class AnnotationCreatePayload(BaseModel):
@@ -46,10 +51,15 @@ class AnnotationReplyActionApi(Resource):
     @validate_app_token
     def post(self, app_model: App, action: Literal["enable", "disable"]):
         """Enable or disable annotation reply feature."""
-        args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump()
+        payload = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {})
         match action:
             case "enable":
-                result = AppAnnotationService.enable_app_annotation(args, app_model.id)
+                enable_args: EnableAnnotationArgs = {
+                    "score_threshold": payload.score_threshold,
+                    "embedding_provider_name": payload.embedding_provider_name,
+                    "embedding_model_name": payload.embedding_model_name,
+                }
+                result = AppAnnotationService.enable_app_annotation(enable_args, app_model.id)
             case "disable":
                 result = AppAnnotationService.disable_app_annotation(app_model.id)
         return result, 200
@@ -135,8 +145,9 @@ class AnnotationListApi(Resource):
     @validate_app_token
     def post(self, app_model: App):
         """Create a new annotation."""
-        args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
-        annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id)
+        payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {})
+        insert_args: InsertAnnotationArgs = {"question": payload.question, "answer": payload.answer}
+        annotation = AppAnnotationService.insert_app_annotation_directly(insert_args, app_model.id)
         response = Annotation.model_validate(annotation, from_attributes=True)
         return response.model_dump(mode="json"), HTTPStatus.CREATED

@@ -164,8 +175,9 @@ class AnnotationUpdateDeleteApi(Resource):
     @edit_permission_required
     def put(self, app_model: App, annotation_id: str):
         """Update an existing annotation."""
-        args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
-        annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
+        payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {})
+        update_args: UpdateAnnotationArgs = {"question": payload.question, "answer": payload.answer}
+        annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, annotation_id)
         response = Annotation.model_validate(annotation, from_attributes=True)
         return response.model_dump(mode="json")
@@ -527,7 +527,7 @@ class DocumentListApi(DatasetApiResource):
         if not dataset:
             raise NotFound("Dataset not found.")

-        query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id)
+        query = select(Document).where(Document.dataset_id == dataset_id, Document.tenant_id == tenant_id)

         if query_params.status:
             query = DocumentService.apply_display_status_filter(query, query_params.status)
@@ -3,7 +3,6 @@ import secrets

 from flask import request
 from flask_restx import Resource
-from sqlalchemy.orm import sessionmaker

 from controllers.common.schema import register_schema_models
 from controllers.console.auth.error import (
@@ -62,9 +61,7 @@ class ForgotPasswordSendEmailApi(Resource):
         else:
             language = "en-US"

-        with sessionmaker(db.engine).begin() as session:
-            account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session)
-        token = None
+        account = AccountService.get_account_by_email_with_case_fallback(request_email)
         if account is None:
             raise AuthenticationFailedError()
         else:
@@ -161,13 +158,14 @@ class ForgotPasswordResetApi(Resource):

         email = reset_data.get("email", "")

-        with sessionmaker(db.engine).begin() as session:
-            account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
+        account = AccountService.get_account_by_email_with_case_fallback(email)

-            if account:
-                self._update_existing_account(account, password_hashed, salt)
-            else:
-                raise AuthenticationFailedError()
+        if account:
+            account = db.session.merge(account)
+            self._update_existing_account(account, password_hashed, salt)
+            db.session.commit()
+        else:
+            raise AuthenticationFailedError()

         return {"result": "success"}
@@ -7,7 +7,8 @@ import logging
 from datetime import datetime

 from flask import Response, request
-from flask_restx import Resource, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel
 from sqlalchemy import select
 from werkzeug.exceptions import Forbidden

@@ -23,6 +24,12 @@ from services.human_input_service import Form, FormNotFoundError, HumanInputServ

 logger = logging.getLogger(__name__)


+class HumanInputFormSubmitPayload(BaseModel):
+    inputs: dict
+    action: str
+
+
 _FORM_SUBMIT_RATE_LIMITER = RateLimiter(
     prefix="web_form_submit_rate_limit",
     max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
@@ -112,10 +119,7 @@ class HumanInputFormApi(Resource):
             "action": "Approve"
         }
         """
-        parser = reqparse.RequestParser()
-        parser.add_argument("inputs", type=dict, required=True, location="json")
-        parser.add_argument("action", type=str, required=True, location="json")
-        args = parser.parse_args()
+        payload = HumanInputFormSubmitPayload.model_validate(request.get_json())

         ip_address = extract_remote_ip(request)
         if _FORM_SUBMIT_RATE_LIMITER.is_rate_limited(ip_address):
@@ -135,8 +139,8 @@ class HumanInputFormApi(Resource):
         service.submit_form_by_token(
             recipient_type=recipient_type,
             form_token=form_token,
-            selected_action_id=args["action"],
-            form_data=args["inputs"],
+            selected_action_id=payload.action,
+            form_data=payload.inputs,
             submission_end_user_id=None,
             # submission_end_user_id=_end_user.id,
         )

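The hunks above migrate request parsing from flask_restx's deprecated reqparse to a Pydantic model. A minimal runnable sketch of the same migration, with hypothetical route and model names:

from flask import Flask, jsonify, request
from pydantic import BaseModel, ValidationError

app = Flask(__name__)


class SubmitPayload(BaseModel):
    inputs: dict
    action: str


@app.post("/forms/<form_token>")
def submit_form(form_token: str):
    # One model replaces several reqparse.add_argument() calls and gives
    # attribute access (payload.action) instead of string-keyed lookups.
    try:
        payload = SubmitPayload.model_validate(request.get_json())
    except ValidationError as e:
        return jsonify({"errors": e.errors()}), 400
    return jsonify({"action": payload.action, "inputs": payload.inputs})
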
@@ -2,7 +2,7 @@ import json
 import logging
 from abc import ABC, abstractmethod
 from collections.abc import Generator, Mapping, Sequence
-from typing import Any
+from typing import Any, TypedDict

 from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
 from graphon.model_runtime.entities.message_entities import (
@@ -29,6 +29,13 @@ from models.model import Message
 logger = logging.getLogger(__name__)


+class ActionDict(TypedDict):
+    """Shape produced by AgentScratchpadUnit.Action.to_dict()."""
+
+    action: str
+    action_input: dict[str, Any] | str
+
+
 class CotAgentRunner(BaseAgentRunner, ABC):
     _is_first_iteration = True
     _ignore_observation_providers = ["wenxin"]
@@ -331,7 +338,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):

         return tool_invoke_response, tool_invoke_meta

-    def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action:
+    def _convert_dict_to_action(self, action: ActionDict) -> AgentScratchpadUnit.Action:
         """
         convert dict to action
         """

@@ -1,6 +1,6 @@
 from collections.abc import Mapping, Sequence
 from enum import StrEnum
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any

 from graphon.file import File, FileUploadConfig
 from graphon.model_runtime.entities.model_entities import AIModelEntity
@@ -131,7 +131,7 @@ class AppGenerateEntity(BaseModel):
     extras: dict[str, Any] = Field(default_factory=dict)

     # tracing instance
-    trace_manager: Optional["TraceQueueManager"] = Field(default=None, exclude=True, repr=False)
+    trace_manager: "TraceQueueManager | None" = Field(default=None, exclude=True, repr=False)


 class EasyUIBasedAppGenerateEntity(AppGenerateEntity):

@@ -1,10 +1,10 @@
-from typing import Literal, Optional
+from typing import Any, Literal, TypedDict

 from graphon.model_runtime.utils.encoders import jsonable_encoder
 from pydantic import BaseModel, Field, field_validator

 from core.datasource.entities.datasource_entities import DatasourceParameter
-from core.tools.entities.common_entities import I18nObject
+from core.tools.entities.common_entities import I18nObject, I18nObjectDict


 class DatasourceApiEntity(BaseModel):
@@ -17,7 +17,24 @@ class DatasourceApiEntity(BaseModel):
     output_schema: dict | None = None


-ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]]
+ToolProviderTypeApiLiteral = Literal["builtin", "api", "workflow"] | None
+
+
+class DatasourceProviderApiEntityDict(TypedDict):
+    id: str
+    author: str
+    name: str
+    plugin_id: str | None
+    plugin_unique_identifier: str | None
+    description: I18nObjectDict
+    icon: str | dict
+    label: I18nObjectDict
+    type: str
+    team_credentials: dict | None
+    is_team_authorization: bool
+    allow_delete: bool
+    datasources: list[Any]
+    labels: list[str]


 class DatasourceProviderApiEntity(BaseModel):
@@ -42,7 +59,7 @@ class DatasourceProviderApiEntity(BaseModel):
     def convert_none_to_empty_list(cls, v):
         return v if v is not None else []

-    def to_dict(self) -> dict:
+    def to_dict(self) -> DatasourceProviderApiEntityDict:
         # -------------
         # overwrite datasource parameter types for temp fix
         datasources = jsonable_encoder(self.datasources)
@@ -53,7 +70,7 @@ class DatasourceProviderApiEntity(BaseModel):
                 parameter["type"] = "files"
         # -------------

-        return {
+        result: DatasourceProviderApiEntityDict = {
             "id": self.id,
             "author": self.author,
             "name": self.name,
@@ -69,3 +86,4 @@ class DatasourceProviderApiEntity(BaseModel):
             "datasources": datasources,
             "labels": self.labels,
         }
+        return result

@@ -2,7 +2,7 @@ from __future__ import annotations

 import enum
 from enum import StrEnum
-from typing import Any
+from typing import Any, TypedDict

 from pydantic import BaseModel, Field, ValidationInfo, field_validator
 from yarl import URL
@@ -179,6 +179,12 @@ class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity):
     datasources: list[DatasourceEntity] = Field(default_factory=list)


+class DatasourceInvokeMetaDict(TypedDict):
+    time_cost: float
+    error: str | None
+    tool_config: dict[str, Any] | None
+
+
 class DatasourceInvokeMeta(BaseModel):
     """
     Datasource invoke meta
@@ -202,12 +208,13 @@ class DatasourceInvokeMeta(BaseModel):
         """
         return cls(time_cost=0.0, error=error, tool_config={})

-    def to_dict(self) -> dict:
-        return {
+    def to_dict(self) -> DatasourceInvokeMetaDict:
+        result: DatasourceInvokeMetaDict = {
             "time_cost": self.time_cost,
             "error": self.error,
             "tool_config": self.tool_config,
         }
+        return result


 class DatasourceLabel(BaseModel):

@@ -71,8 +71,8 @@ class DatasourceFileMessageTransformer:
                 if not isinstance(message.message, DatasourceMessage.BlobMessage):
                     raise ValueError("unexpected message type")

-                # FIXME: should do a type check here.
-                assert isinstance(message.message.blob, bytes)
+                if not isinstance(message.message.blob, bytes):
+                    raise TypeError(f"Expected blob to be bytes, got {type(message.message.blob).__name__}")
                 tool_file_manager = ToolFileManager()
                 blob_tool_file: ToolFile | None = tool_file_manager.create_file_by_raw(
                     user_id=user_id,

@@ -32,9 +32,9 @@ class Extensible:

     name: str
     tenant_id: str
-    config: dict | None = None
+    config: dict[str, Any] | None = None

-    def __init__(self, tenant_id: str, config: dict | None = None):
+    def __init__(self, tenant_id: str, config: dict[str, Any] | None = None):
         self.tenant_id = tenant_id
         self.config = config

@@ -1,3 +1,6 @@
+from collections.abc import Mapping
+from typing import Any, TypedDict
+
 from sqlalchemy import select

 from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
@@ -7,6 +10,16 @@ from extensions.ext_database import db
 from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint


+class ApiToolConfig(TypedDict, total=False):
+    """Expected config shape for ApiExternalDataTool.
+
+    Not used directly in method signatures (base class accepts dict[str, Any]);
+    kept here to document the keys this tool reads from config.
+    """
+
+    api_based_extension_id: str
+
+
 class ApiExternalDataTool(ExternalDataTool):
     """
     The api external data tool.
@@ -16,7 +29,7 @@ class ApiExternalDataTool(ExternalDataTool):
     """the unique name of external data tool"""

     @classmethod
-    def validate_config(cls, tenant_id: str, config: dict):
+    def validate_config(cls, tenant_id: str, config: dict[str, Any]):
         """
         Validate the incoming form config data.

@@ -37,7 +50,7 @@ class ApiExternalDataTool(ExternalDataTool):
         if not api_based_extension:
             raise ValueError("api_based_extension_id is invalid")

-    def query(self, inputs: dict, query: str | None = None) -> str:
+    def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str:
         """
         Query the external data tool.

@@ -1,4 +1,6 @@
 from abc import ABC, abstractmethod
+from collections.abc import Mapping
+from typing import Any

 from core.extension.extensible import Extensible, ExtensionModule

@@ -15,14 +17,14 @@ class ExternalDataTool(Extensible, ABC):
     variable: str
     """the tool variable name of app tool"""

-    def __init__(self, tenant_id: str, app_id: str, variable: str, config: dict | None = None):
+    def __init__(self, tenant_id: str, app_id: str, variable: str, config: dict[str, Any] | None = None):
         super().__init__(tenant_id, config)
         self.app_id = app_id
         self.variable = variable

     @classmethod
     @abstractmethod
-    def validate_config(cls, tenant_id: str, config: dict):
+    def validate_config(cls, tenant_id: str, config: dict[str, Any]):
         """
         Validate the incoming form config data.

@@ -33,7 +35,7 @@ class ExternalDataTool(Extensible, ABC):
         raise NotImplementedError

     @abstractmethod
-    def query(self, inputs: dict, query: str | None = None) -> str:
+    def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str:
         """
         Query the external data tool.

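These hunks retype `inputs: dict` as `Mapping[str, Any]`, signalling that implementations read but never mutate the inputs. A minimal runnable sketch of that convention (function body is illustrative):

from collections.abc import Mapping
from typing import Any


def query(inputs: Mapping[str, Any], query: str | None = None) -> str:
    # Mapping is read-only from the callee's point of view: type checkers
    # flag `inputs["k"] = v` here, while callers may still pass a plain dict.
    user = inputs.get("user", "anonymous")
    return f"{user}: {query or ''}"


print(query({"user": "alice"}, "hello"))
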
@@ -3,7 +3,7 @@
 import logging
 import traceback
 from datetime import UTC, datetime
-from typing import Any, TypedDict
+from typing import Any, NotRequired, TypedDict

 import orjson

@@ -16,6 +16,19 @@ class IdentityDict(TypedDict, total=False):
     user_type: str


+class LogDict(TypedDict):
+    ts: str
+    severity: str
+    service: str
+    caller: str
+    message: str
+    trace_id: NotRequired[str]
+    span_id: NotRequired[str]
+    identity: NotRequired[IdentityDict]
+    attributes: NotRequired[dict[str, Any]]
+    stack_trace: NotRequired[str]
+
+
 class StructuredJSONFormatter(logging.Formatter):
     """
     JSON log formatter following the specified schema:
@@ -55,9 +68,9 @@ class StructuredJSONFormatter(logging.Formatter):

         return json.dumps(log_dict, default=str, ensure_ascii=False)

-    def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]:
+    def _build_log_dict(self, record: logging.LogRecord) -> LogDict:
         # Core fields
-        log_dict: dict[str, Any] = {
+        log_dict: LogDict = {
             "ts": datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z"),
             "severity": self.SEVERITY_MAP.get(record.levelno, "INFO"),
             "service": self._service_name,

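The `LogDict` added above mixes required keys with `NotRequired` ones, which is the key distinction from `str | None`: a `NotRequired` key may be absent from the dict entirely. A small sketch with an illustrative name:

from typing import Any, NotRequired, TypedDict


class LogEntry(TypedDict):
    # Required keys: every LogEntry must carry these.
    ts: str
    severity: str
    message: str
    # Optional keys: may be missing entirely, unlike `str | None`,
    # which would force the key to exist with a None value.
    trace_id: NotRequired[str]
    attributes: NotRequired[dict[str, Any]]


entry: LogEntry = {"ts": "2025-01-01T00:00:00Z", "severity": "INFO", "message": "ok"}
entry["trace_id"] = "abc123"  # allowed; checked against NotRequired[str]
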
@@ -146,7 +146,7 @@ def discover_protected_resource_metadata(
                 return ProtectedResourceMetadata.model_validate(response.json())
             elif response.status_code == 404:
                 continue  # Try next URL
-        except (RequestError, ValidationError):
+        except (RequestError, ValidationError, json.JSONDecodeError):
            continue  # Try next URL

     return None
@@ -166,7 +166,7 @@ def discover_oauth_authorization_server_metadata(
                 return OAuthMetadata.model_validate(response.json())
             elif response.status_code == 404:
                 continue  # Try next URL
-        except (RequestError, ValidationError):
+        except (RequestError, ValidationError, json.JSONDecodeError):
            continue  # Try next URL

     return None
@@ -276,7 +276,7 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
            else:
                 return False, ""
         return False, ""
-    except RequestError:
+    except (RequestError, json.JSONDecodeError, IndexError):
         # Not support resource discovery, fall back to well-known OAuth metadata
         return False, ""

@@ -122,7 +122,7 @@ class MCPClientWithAuthRetry(MCPClient):
             logger.exception("Authentication retry failed")
             raise MCPAuthError(f"Authentication retry failed: {e}") from e

-    def _execute_with_retry(self, func: Callable[..., Any], *args, **kwargs) -> Any:
+    def _execute_with_retry[**P, R](self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
         """
         Execute a function with authentication retry logic.

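The signature change above adopts PEP 695/612 generics: `**P` captures the wrapped function's exact parameter list and `R` its return type, so the retry wrapper stops erasing types to `Any`. A minimal runnable sketch of the idea (the retry policy and names here are illustrative, not the client's):

from collections.abc import Callable


def with_retry[**P, R](func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
    # P forwards func's arguments verbatim; R preserves its return type,
    # so callers get a precisely typed result instead of Any.
    last_error: Exception | None = None
    for _ in range(2):  # one retry, as a minimal illustration
        try:
            return func(*args, **kwargs)
        except ConnectionError as e:
            last_error = e
    raise RuntimeError("retries exhausted") from last_error


def fetch(url: str, timeout: float = 5.0) -> str:
    return f"GET {url} (timeout={timeout})"


print(with_retry(fetch, "https://example.com"))  # typed as str, not Any
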
@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from enum import StrEnum
-from typing import Any, TypeVar
+from typing import Any

 from pydantic import BaseModel

@@ -9,12 +9,9 @@ from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAut

 SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]

-SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
-LifespanContextT = TypeVar("LifespanContextT")


 @dataclass
-class RequestContext[SessionT: BaseSession[Any, Any, Any, Any, Any], LifespanContextT]:
+class RequestContext[SessionT: BaseSession, LifespanContextT]:
     request_id: RequestId
     meta: RequestParams.Meta | None
     session: SessionT

@@ -55,7 +55,7 @@ class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResul

     request: ReceiveRequestT
     _session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]"
-    _on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any]
+    _on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], object]

     def __init__(
         self,
@@ -63,7 +63,7 @@ class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResul
         request_meta: RequestParams.Meta | None,
         request: ReceiveRequestT,
         session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]",
-        on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
+        on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], object],
     ):
         self.request_id = request_id
         self.request_meta = request_meta

@@ -31,7 +31,6 @@ ProgressToken = str | int
 Cursor = str
 Role = Literal["user", "assistant"]
 RequestId = Annotated[int | str, Field(union_mode="left_to_right")]
-type AnyFunction = Callable[..., Any]


 class RequestParams(BaseModel):

@@ -61,27 +61,28 @@ class TokenBufferMemory:
         :param is_user_message: whether this is a user message
         :return: PromptMessage
         """
-        if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
-            file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
-        elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
-            app = self.conversation.app
-            if not app:
-                raise ValueError("App not found for conversation")
+        match self.conversation.mode:
+            case AppMode.AGENT_CHAT | AppMode.COMPLETION | AppMode.CHAT:
+                file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
+            case AppMode.ADVANCED_CHAT | AppMode.WORKFLOW:
+                app = self.conversation.app
+                if not app:
+                    raise ValueError("App not found for conversation")

-            if not message.workflow_run_id:
-                raise ValueError("Workflow run ID not found")
+                if not message.workflow_run_id:
+                    raise ValueError("Workflow run ID not found")

-            workflow_run = self.workflow_run_repo.get_workflow_run_by_id(
-                tenant_id=app.tenant_id, app_id=app.id, run_id=message.workflow_run_id
-            )
-            if not workflow_run:
-                raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
-            workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
-            if not workflow:
-                raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
-            file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
-        else:
-            raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
+                workflow_run = self.workflow_run_repo.get_workflow_run_by_id(
+                    tenant_id=app.tenant_id, app_id=app.id, run_id=message.workflow_run_id
+                )
+                if not workflow_run:
+                    raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
+                workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
+                if not workflow:
+                    raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
+                file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
+            case _:
+                raise AssertionError(f"Invalid app mode: {self.conversation.mode}")

         detail = ImagePromptMessageContent.DETAIL.HIGH
         if file_extra_config and app_record:

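The refactor above converts chained `if mode in {...}` checks into a `match` statement with or-patterns and an explicit fallthrough. A small runnable sketch of that shape, over an illustrative subset of the enum:

from enum import StrEnum


class AppMode(StrEnum):  # illustrative subset of the real enum
    CHAT = "chat"
    COMPLETION = "completion"
    WORKFLOW = "workflow"


def upload_config(mode: AppMode) -> str:
    # Or-patterns replace `if mode in {...}` membership tests, and the
    # `case _` arm makes the unhandled-mode branch explicit.
    match mode:
        case AppMode.CHAT | AppMode.COMPLETION:
            return "from model config"
        case AppMode.WORKFLOW:
            return "from workflow features"
        case _:
            raise AssertionError(f"Invalid app mode: {mode}")


print(upload_config(AppMode.CHAT))
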
@@ -6,7 +6,7 @@ from graphon.model_runtime.callbacks.base_callback import Callback
 from graphon.model_runtime.entities.llm_entities import LLMResult
 from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
 from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType
-from graphon.model_runtime.entities.rerank_entities import RerankResult
+from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
 from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
 from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
 from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@@ -172,10 +172,10 @@ class ModelInstance:
                 function=self.model_type_instance.invoke,
                 model=self.model_name,
                 credentials=self.credentials,
-                prompt_messages=prompt_messages,
+                prompt_messages=list(prompt_messages),
                 model_parameters=model_parameters,
-                tools=tools,
-                stop=stop,
+                tools=list(tools) if tools else None,
+                stop=list(stop) if stop else None,
                 stream=stream,
                 callbacks=callbacks,
             ),
@@ -193,15 +193,12 @@ class ModelInstance:
         """
         if not isinstance(self.model_type_instance, LargeLanguageModel):
             raise Exception("Model type instance is not LargeLanguageModel")
-        return cast(
-            int,
-            self._round_robin_invoke(
-                function=self.model_type_instance.get_num_tokens,
-                model=self.model_name,
-                credentials=self.credentials,
-                prompt_messages=prompt_messages,
-                tools=tools,
-            ),
+        return self._round_robin_invoke(
+            function=self.model_type_instance.get_num_tokens,
+            model=self.model_name,
+            credentials=self.credentials,
+            prompt_messages=list(prompt_messages),
+            tools=list(tools) if tools else None,
         )

     def invoke_text_embedding(
@@ -216,15 +213,12 @@ class ModelInstance:
         """
         if not isinstance(self.model_type_instance, TextEmbeddingModel):
             raise Exception("Model type instance is not TextEmbeddingModel")
-        return cast(
-            EmbeddingResult,
-            self._round_robin_invoke(
-                function=self.model_type_instance.invoke,
-                model=self.model_name,
-                credentials=self.credentials,
-                texts=texts,
-                input_type=input_type,
-            ),
+        return self._round_robin_invoke(
+            function=self.model_type_instance.invoke,
+            model=self.model_name,
+            credentials=self.credentials,
+            texts=texts,
+            input_type=input_type,
         )

     def invoke_multimodal_embedding(
@@ -241,15 +235,12 @@ class ModelInstance:
         """
         if not isinstance(self.model_type_instance, TextEmbeddingModel):
             raise Exception("Model type instance is not TextEmbeddingModel")
-        return cast(
-            EmbeddingResult,
-            self._round_robin_invoke(
-                function=self.model_type_instance.invoke,
-                model=self.model_name,
-                credentials=self.credentials,
-                multimodel_documents=multimodel_documents,
-                input_type=input_type,
-            ),
+        return self._round_robin_invoke(
+            function=self.model_type_instance.invoke,
+            model=self.model_name,
+            credentials=self.credentials,
+            multimodel_documents=multimodel_documents,
+            input_type=input_type,
         )

     def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]:
@@ -261,14 +252,11 @@ class ModelInstance:
         """
         if not isinstance(self.model_type_instance, TextEmbeddingModel):
             raise Exception("Model type instance is not TextEmbeddingModel")
-        return cast(
-            list[int],
-            self._round_robin_invoke(
-                function=self.model_type_instance.get_num_tokens,
-                model=self.model_name,
-                credentials=self.credentials,
-                texts=texts,
-            ),
+        return self._round_robin_invoke(
+            function=self.model_type_instance.get_num_tokens,
+            model=self.model_name,
+            credentials=self.credentials,
+            texts=texts,
         )

     def invoke_rerank(
@@ -289,23 +277,20 @@ class ModelInstance:
         """
         if not isinstance(self.model_type_instance, RerankModel):
             raise Exception("Model type instance is not RerankModel")
-        return cast(
-            RerankResult,
-            self._round_robin_invoke(
-                function=self.model_type_instance.invoke,
-                model=self.model_name,
-                credentials=self.credentials,
-                query=query,
-                docs=docs,
-                score_threshold=score_threshold,
-                top_n=top_n,
-            ),
+        return self._round_robin_invoke(
+            function=self.model_type_instance.invoke,
+            model=self.model_name,
+            credentials=self.credentials,
+            query=query,
+            docs=docs,
+            score_threshold=score_threshold,
+            top_n=top_n,
         )

     def invoke_multimodal_rerank(
         self,
-        query: dict,
-        docs: list[dict],
+        query: MultimodalRerankInput,
+        docs: list[MultimodalRerankInput],
         score_threshold: float | None = None,
         top_n: int | None = None,
     ) -> RerankResult:
@@ -320,17 +305,14 @@ class ModelInstance:
         """
         if not isinstance(self.model_type_instance, RerankModel):
             raise Exception("Model type instance is not RerankModel")
-        return cast(
-            RerankResult,
-            self._round_robin_invoke(
-                function=self.model_type_instance.invoke_multimodal_rerank,
-                model=self.model_name,
-                credentials=self.credentials,
-                query=query,
-                docs=docs,
-                score_threshold=score_threshold,
-                top_n=top_n,
-            ),
+        return self._round_robin_invoke(
+            function=self.model_type_instance.invoke_multimodal_rerank,
+            model=self.model_name,
+            credentials=self.credentials,
+            query=query,
+            docs=docs,
+            score_threshold=score_threshold,
+            top_n=top_n,
         )

     def invoke_moderation(self, text: str) -> bool:
@@ -342,14 +324,11 @@ class ModelInstance:
         """
         if not isinstance(self.model_type_instance, ModerationModel):
             raise Exception("Model type instance is not ModerationModel")
-        return cast(
-            bool,
-            self._round_robin_invoke(
-                function=self.model_type_instance.invoke,
-                model=self.model_name,
-                credentials=self.credentials,
-                text=text,
-            ),
+        return self._round_robin_invoke(
+            function=self.model_type_instance.invoke,
+            model=self.model_name,
+            credentials=self.credentials,
+            text=text,
         )

     def invoke_speech2text(self, file: IO[bytes]) -> str:
@@ -361,14 +340,11 @@ class ModelInstance:
         """
         if not isinstance(self.model_type_instance, Speech2TextModel):
             raise Exception("Model type instance is not Speech2TextModel")
-        return cast(
-            str,
-            self._round_robin_invoke(
-                function=self.model_type_instance.invoke,
-                model=self.model_name,
-                credentials=self.credentials,
-                file=file,
-            ),
+        return self._round_robin_invoke(
+            function=self.model_type_instance.invoke,
+            model=self.model_name,
+            credentials=self.credentials,
+            file=file,
         )

     def invoke_tts(self, content_text: str, voice: str = "") -> Iterable[bytes]:
@@ -381,18 +357,15 @@ class ModelInstance:
         """
         if not isinstance(self.model_type_instance, TTSModel):
             raise Exception("Model type instance is not TTSModel")
-        return cast(
-            Iterable[bytes],
-            self._round_robin_invoke(
-                function=self.model_type_instance.invoke,
-                model=self.model_name,
-                credentials=self.credentials,
-                content_text=content_text,
-                voice=voice,
-            ),
+        return self._round_robin_invoke(
+            function=self.model_type_instance.invoke,
+            model=self.model_name,
+            credentials=self.credentials,
+            content_text=content_text,
+            voice=voice,
         )

-    def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs):
+    def _round_robin_invoke[**P, R](self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
         """
         Round-robin invoke
         :param function: function to invoke
@@ -430,9 +403,8 @@ class ModelInstance:
                 continue

             try:
-                if "credentials" in kwargs:
-                    del kwargs["credentials"]
-                return function(*args, **kwargs, credentials=lb_config.credentials)
+                kwargs["credentials"] = lb_config.credentials
+                return function(*args, **kwargs)
             except InvokeRateLimitError as e:
                 # expire in 60 seconds
                 self.load_balancing_manager.cooldown(lb_config, expire=60)

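Once `_round_robin_invoke` is generic over `[**P, R]`, every `cast(...)` wrapper above becomes unnecessary: the dispatcher's return type is whatever the forwarded function returns. A minimal runnable sketch of the dispatch-with-cooldown loop it drives; the pool structure, exception, and names here are illustrative, and note that assigning into `**P` kwargs (as the diff does for credentials) is a runtime-only trick that strict type checkers may flag.

import time
from collections.abc import Callable


class RateLimitedError(Exception):
    pass


def round_robin_invoke[**P, R](
    credentials_pool: list[dict],
    cooldown_until: dict[int, float],
    func: Callable[P, R],
    *args: P.args,
    **kwargs: P.kwargs,
) -> R:
    # Try each credential config in turn, skipping ones that are still
    # cooling down after a rate limit; R keeps func's return type intact.
    for i, creds in enumerate(credentials_pool):
        if cooldown_until.get(i, 0.0) > time.monotonic():
            continue
        try:
            kwargs["credentials"] = creds  # inject the chosen credentials
            return func(*args, **kwargs)
        except RateLimitedError:
            cooldown_until[i] = time.monotonic() + 60.0  # expire in 60 seconds
    raise RuntimeError("all load-balancing configurations are cooling down")


def call_api(prompt: str, *, credentials: dict | None = None) -> str:
    return f"{prompt} via {credentials}"


print(round_robin_invoke([{"key": "a"}], {}, call_api, "hello"))  # typed as str
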
@@ -1,6 +1,6 @@
 import uuid
 from collections.abc import Generator, Mapping
-from typing import Any, Union, cast
+from typing import Any, cast

 from sqlalchemy import select
 from sqlalchemy.orm import Session
@@ -207,7 +207,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
         )

     @classmethod
-    def _get_user(cls, user_id: str) -> Union[EndUser, Account]:
+    def _get_user(cls, user_id: str) -> EndUser | Account:
         """
         get the user by user id
         """

@@ -2,7 +2,7 @@ import json
 import os
 from collections.abc import Mapping, Sequence
 from enum import StrEnum, auto
-from typing import TYPE_CHECKING, Any, cast
+from typing import TYPE_CHECKING, Any, TypedDict, cast

 from graphon.file import file_manager
 from graphon.model_runtime.entities.message_entities import (
@@ -34,6 +34,13 @@ class ModelMode(StrEnum):
 prompt_file_contents: dict[str, Any] = {}


+class PromptTemplateConfigDict(TypedDict):
+    prompt_template: PromptTemplateParser
+    custom_variable_keys: list[str]
+    special_variable_keys: list[str]
+    prompt_rules: dict[str, Any]
+
+
 class SimplePromptTransform(PromptTransform):
     """
     Simple Prompt Transform for Chatbot App Basic Mode.
@@ -105,18 +112,13 @@ class SimplePromptTransform(PromptTransform):
             with_memory_prompt=histories is not None,
         )

-        custom_variable_keys_obj = prompt_template_config["custom_variable_keys"]
-        special_variable_keys_obj = prompt_template_config["special_variable_keys"]
-
-        # Type check for custom_variable_keys
-        if not isinstance(custom_variable_keys_obj, list):
-            raise TypeError(f"Expected list for custom_variable_keys, got {type(custom_variable_keys_obj)}")
-        custom_variable_keys = cast(list[str], custom_variable_keys_obj)
-
-        # Type check for special_variable_keys
-        if not isinstance(special_variable_keys_obj, list):
-            raise TypeError(f"Expected list for special_variable_keys, got {type(special_variable_keys_obj)}")
-        special_variable_keys = cast(list[str], special_variable_keys_obj)
+        custom_variable_keys = prompt_template_config["custom_variable_keys"]
+        if not isinstance(custom_variable_keys, list):
+            raise TypeError(f"Expected list for custom_variable_keys, got {type(custom_variable_keys)}")
+
+        special_variable_keys = prompt_template_config["special_variable_keys"]
+        if not isinstance(special_variable_keys, list):
+            raise TypeError(f"Expected list for special_variable_keys, got {type(special_variable_keys)}")

         variables = {k: inputs[k] for k in custom_variable_keys if k in inputs}

@@ -150,7 +152,7 @@ class SimplePromptTransform(PromptTransform):
         has_context: bool,
         query_in_prompt: bool,
         with_memory_prompt: bool = False,
-    ) -> dict[str, object]:
+    ) -> PromptTemplateConfigDict:
         prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model)

         custom_variable_keys: list[str] = []
@@ -173,12 +175,13 @@ class SimplePromptTransform(PromptTransform):
             prompt += prompt_rules.get("query_prompt", "{{#query#}}")
             special_variable_keys.append("#query#")

-        return {
+        result: PromptTemplateConfigDict = {
             "prompt_template": PromptTemplateParser(template=prompt),
             "custom_variable_keys": custom_variable_keys,
             "special_variable_keys": special_variable_keys,
             "prompt_rules": prompt_rules,
         }
+        return result

     def _get_chat_model_prompt_messages(
         self,

@@ -5,6 +5,7 @@ from typing import Any

 from elasticsearch import Elasticsearch
 from pydantic import BaseModel, model_validator
+from typing_extensions import TypedDict

 from configs import dify_config
 from core.rag.datasource.vdb.field import Field
@@ -19,6 +20,16 @@ from models.dataset import Dataset
 logger = logging.getLogger(__name__)


+class HuaweiElasticsearchParamsDict(TypedDict, total=False):
+    hosts: list[str]
+    verify_certs: bool
+    ssl_show_warn: bool
+    request_timeout: int
+    retry_on_timeout: bool
+    max_retries: int
+    basic_auth: tuple[str, str]
+
+
 def create_ssl_context() -> ssl.SSLContext:
     ssl_context = ssl.create_default_context()
     ssl_context.check_hostname = False
@@ -38,15 +49,15 @@ class HuaweiCloudVectorConfig(BaseModel):
             raise ValueError("config HOSTS is required")
         return values

-    def to_elasticsearch_params(self) -> dict[str, Any]:
-        params = {
-            "hosts": self.hosts.split(","),
-            "verify_certs": False,
-            "ssl_show_warn": False,
-            "request_timeout": 30000,
-            "retry_on_timeout": True,
-            "max_retries": 10,
-        }
+    def to_elasticsearch_params(self) -> HuaweiElasticsearchParamsDict:
+        params = HuaweiElasticsearchParamsDict(
+            hosts=self.hosts.split(","),
+            verify_certs=False,
+            ssl_show_warn=False,
+            request_timeout=30000,
+            retry_on_timeout=True,
+            max_retries=10,
+        )
         if self.username and self.password:
             params["basic_auth"] = (self.username, self.password)
         return params

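This and the following vector-store hunks all construct their connection-parameter dicts by calling a `TypedDict` like a constructor. The call builds a plain dict at runtime, but the keyword arguments are checked against the declared keys. A small sketch with illustrative names:

from typing import TypedDict


class ConnectionParams(TypedDict, total=False):
    hosts: list[str]
    verify_certs: bool
    basic_auth: tuple[str, str]


# Functional-call construction: type-checked kwargs, plain dict result.
params = ConnectionParams(hosts=["https://localhost:9200"], verify_certs=False)

# total=False means keys such as credentials may be added conditionally:
params["basic_auth"] = ("user", "secret")

print(type(params), params)  # <class 'dict'> {...}
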
@@ -7,6 +7,7 @@ from opensearchpy import OpenSearch, helpers
 from opensearchpy.helpers import BulkIndexError
 from pydantic import BaseModel, model_validator
 from tenacity import retry, stop_after_attempt, wait_exponential
+from typing_extensions import TypedDict

 from configs import dify_config
 from core.rag.datasource.vdb.field import Field
@@ -26,6 +27,14 @@ ROUTING_FIELD = "routing_field"
 UGC_INDEX_PREFIX = "ugc_index"


+class LindormOpenSearchParamsDict(TypedDict, total=False):
+    hosts: str | None
+    use_ssl: bool
+    pool_maxsize: int
+    timeout: int
+    http_auth: tuple[str, str]
+
+
 class LindormVectorStoreConfig(BaseModel):
     hosts: str | None
     username: str | None = None
@@ -44,13 +53,13 @@ class LindormVectorStoreConfig(BaseModel):
             raise ValueError("config PASSWORD is required")
         return values

-    def to_opensearch_params(self) -> dict[str, Any]:
-        params: dict[str, Any] = {
-            "hosts": self.hosts,
-            "use_ssl": False,
-            "pool_maxsize": 128,
-            "timeout": 30,
-        }
+    def to_opensearch_params(self) -> LindormOpenSearchParamsDict:
+        params = LindormOpenSearchParamsDict(
+            hosts=self.hosts,
+            use_ssl=False,
+            pool_maxsize=128,
+            timeout=30,
+        )
         if self.username and self.password:
             params["http_auth"] = (self.username, self.password)
         return params

@@ -6,6 +6,7 @@ from uuid import uuid4
 from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
 from opensearchpy.helpers import BulkIndexError
 from pydantic import BaseModel, model_validator
+from typing_extensions import TypedDict

 from configs import dify_config
 from configs.middleware.vdb.opensearch_config import AuthMethod
@@ -21,6 +22,20 @@ from models.dataset import Dataset
 logger = logging.getLogger(__name__)


+class _OpenSearchHostDict(TypedDict):
+    host: str
+    port: int
+
+
+class OpenSearchParamsDict(TypedDict, total=False):
+    hosts: list[_OpenSearchHostDict]
+    use_ssl: bool
+    verify_certs: bool
+    connection_class: type
+    pool_maxsize: int
+    http_auth: tuple[str | None, str | None] | Urllib3AWSV4SignerAuth
+
+
 class OpenSearchConfig(BaseModel):
     host: str
     port: int
@@ -57,14 +72,14 @@ class OpenSearchConfig(BaseModel):
             service=self.aws_service,  # type: ignore[arg-type]
         )

-    def to_opensearch_params(self) -> dict[str, Any]:
-        params = {
-            "hosts": [{"host": self.host, "port": self.port}],
-            "use_ssl": self.secure,
-            "verify_certs": self.verify_certs,
-            "connection_class": Urllib3HttpConnection,
-            "pool_maxsize": 20,
-        }
+    def to_opensearch_params(self) -> OpenSearchParamsDict:
+        params = OpenSearchParamsDict(
+            hosts=[{"host": self.host, "port": self.port}],
+            use_ssl=self.secure,
+            verify_certs=self.verify_certs,
+            connection_class=Urllib3HttpConnection,
+            pool_maxsize=20,
+        )

         if self.auth_method == "basic":
             logger.info("Using basic authentication for OpenSearch Vector DB")

@@ -7,7 +7,7 @@ from pydantic import BaseModel, model_validator
 from sqlalchemy import Column, String, Table, create_engine, insert
 from sqlalchemy import text as sql_text
 from sqlalchemy.dialects.postgresql import JSON, TEXT
-from sqlalchemy.orm import Session
+from sqlalchemy.orm import Session, sessionmaker

 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_type import VectorType
@@ -79,7 +79,7 @@ class RelytVector(BaseVector):
         if redis_client.get(collection_exist_cache_key):
             return
         index_name = f"{self._collection_name}_embedding_index"
-        with Session(self.client) as session:
+        with sessionmaker(bind=self.client).begin() as session:
             drop_statement = sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}"; """)
             session.execute(drop_statement)
             create_statement = sql_text(f"""
@@ -104,7 +104,6 @@ class RelytVector(BaseVector):
             $$);
             """)
             session.execute(index_statement)
-            session.commit()
         redis_client.set(collection_exist_cache_key, 1, ex=3600)

     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
@@ -208,9 +207,8 @@ class RelytVector(BaseVector):
         self.delete_by_uuids(ids)

     def delete(self):
-        with Session(self.client) as session:
+        with sessionmaker(bind=self.client).begin() as session:
             session.execute(sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}";"""))
-            session.commit()

     def text_exists(self, id: str) -> bool:
         with Session(self.client) as session:

@@ -6,7 +6,7 @@ import sqlalchemy
 from pydantic import BaseModel, model_validator
 from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
 from sqlalchemy import text as sql_text
-from sqlalchemy.orm import Session, declarative_base
+from sqlalchemy.orm import Session, declarative_base, sessionmaker

 from configs import dify_config
 from core.rag.datasource.vdb.field import Field, parse_metadata_json
@@ -97,8 +97,7 @@ class TiDBVector(BaseVector):
         if redis_client.get(collection_exist_cache_key):
             return
         tidb_dist_func = self._get_distance_func()
-        with Session(self._engine) as session:
-            session.begin()
+        with sessionmaker(bind=self._engine).begin() as session:
             create_statement = sql_text(f"""
                 CREATE TABLE IF NOT EXISTS {self._collection_name} (
                     id CHAR(36) PRIMARY KEY,
@@ -115,7 +114,6 @@ class TiDBVector(BaseVector):
                 );
             """)
             session.execute(create_statement)
-            session.commit()
         redis_client.set(collection_exist_cache_key, 1, ex=3600)

     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
@@ -238,9 +236,8 @@ class TiDBVector(BaseVector):
         return []

     def delete(self):
-        with Session(self._engine) as session:
+        with sessionmaker(bind=self._engine).begin() as session:
             session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))
-            session.commit()

     def _get_distance_func(self) -> str:
         match self._distance_func:

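The Relyt and TiDB hunks above replace bare `Session(...)` blocks plus explicit `session.commit()` calls with `sessionmaker(...).begin()`, which opens a session and a transaction in one context manager: commit on normal exit, rollback on exception. A minimal runnable sketch against an in-memory SQLite engine:

from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

engine = create_engine("sqlite:///:memory:")

# begin() commits automatically on clean exit and rolls back on error,
# so the explicit session.commit() calls in the old code become redundant.
with sessionmaker(bind=engine).begin() as session:
    session.execute(text("CREATE TABLE IF NOT EXISTS demo (id INTEGER PRIMARY KEY)"))
    session.execute(text("INSERT INTO demo (id) VALUES (1)"))
# committed here

with sessionmaker(bind=engine).begin() as session:
    count = session.execute(text("SELECT COUNT(*) FROM demo")).scalar_one()
print(count)  # 1
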
@@ -41,7 +41,23 @@ class AbstractVectorFactory(ABC):
 class Vector:
     def __init__(self, dataset: Dataset, attributes: list | None = None):
         if attributes is None:
-            attributes = ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"]
+            # `is_summary` and `original_chunk_id` are stored on summary vectors
+            # by `SummaryIndexService` and read back by `RetrievalService` to
+            # route summary hits through their original parent chunks. They
+            # must be listed here so vector backends that use this list as an
+            # explicit return-properties projection (notably Weaviate) actually
+            # return those fields; without them, summary hits silently
+            # collapse into `is_summary = False` branches and the summary
+            # retrieval path is a no-op. See #34884.
+            attributes = [
+                "doc_id",
+                "dataset_id",
+                "document_id",
+                "doc_hash",
+                "doc_type",
+                "is_summary",
+                "original_chunk_id",
+            ]
         self._dataset = dataset
         self._embeddings = self._get_embeddings()
         self._attributes = attributes

@@ -6,7 +6,7 @@ from collections.abc import Mapping
 from typing import Any

 from flask import current_app
-from sqlalchemy import delete, func, select
+from sqlalchemy import delete, func, select, update

 from core.db.session_factory import session_factory
 from core.rag.index_processor.constant.index_type import IndexTechniqueType
@@ -63,11 +63,11 @@ class IndexProcessor:
         summary_index_setting: SummaryIndexSettingDict | None = None,
     ) -> IndexingResultDict:
         with session_factory.create_session() as session:
-            document = session.query(Document).filter_by(id=document_id).first()
+            document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
             if not document:
                 raise KnowledgeIndexNodeError(f"Document {document_id} not found.")

-            dataset = session.query(Dataset).filter_by(id=dataset_id).first()
+            dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
             if not dataset:
                 raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.")

@@ -104,12 +104,12 @@ class IndexProcessor:
             document.indexing_status = "completed"
             document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             document.word_count = (
-                session.query(func.sum(DocumentSegment.word_count))
-                .where(
-                    DocumentSegment.document_id == document_id,
-                    DocumentSegment.dataset_id == dataset_id,
-                )
-                .scalar()
+                session.scalar(
+                    select(func.sum(DocumentSegment.word_count)).where(
+                        DocumentSegment.document_id == document_id,
+                        DocumentSegment.dataset_id == dataset_id,
+                    )
+                )
             ) or 0
             # Update need_summary based on dataset's summary_index_setting
             if summary_index_setting and summary_index_setting.get("enable") is True:
@@ -118,15 +118,17 @@ class IndexProcessor:
                 document.need_summary = False
             session.add(document)
             # update document segment status
-            session.query(DocumentSegment).where(
-                DocumentSegment.document_id == document_id,
-                DocumentSegment.dataset_id == dataset_id,
-            ).update(
-                {
-                    DocumentSegment.status: "completed",
-                    DocumentSegment.enabled: True,
-                    DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
-                }
+            session.execute(
+                update(DocumentSegment)
+                .where(
+                    DocumentSegment.document_id == document_id,
+                    DocumentSegment.dataset_id == dataset_id,
+                )
+                .values(
+                    status="completed",
+                    enabled=True,
+                    completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+                )
             )

             result: IndexingResultDict = {
@@ -151,11 +153,11 @@ class IndexProcessor:
         doc_language = None
         with session_factory.create_session() as session:
             if document_id:
-                document = session.query(Document).filter_by(id=document_id).first()
+                document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
             else:
                 document = None

-            dataset = session.query(Dataset).filter_by(id=dataset_id).first()
+            dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
             if not dataset:
                 raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.")

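The segment-status hunk above moves a bulk UPDATE from the legacy `session.query(...).update({...})` API to a 2.0-style `update()` statement handed to `session.execute()`. A minimal runnable sketch with an illustrative stand-in model:

from sqlalchemy import create_engine, update
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Segment(Base):  # illustrative stand-in for DocumentSegment
    __tablename__ = "segments"
    id: Mapped[int] = mapped_column(primary_key=True)
    document_id: Mapped[str]
    status: Mapped[str] = mapped_column(default="pending")


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Segment(id=1, document_id="doc-1"))
    # 2.0-style bulk UPDATE: one composable statement instead of the
    # legacy Query.update() call, issued for all matching rows at once.
    session.execute(
        update(Segment)
        .where(Segment.document_id == "doc-1")
        .values(status="completed")
    )
    session.commit()
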
@@ -3,8 +3,7 @@
 import logging
 import re
 import uuid
-from collections.abc import Mapping
-from typing import Any, cast
+from typing import Any, TypedDict, cast

 logger = logging.getLogger(__name__)

@@ -55,6 +54,12 @@ from services.summary_index_service import SummaryIndexService
 _file_access_controller = DatabaseFileAccessController()


+class ParagraphFormatPreviewDict(TypedDict):
+    chunk_structure: str
+    preview: list[dict[str, Any]]
+    total_segments: int
+
+
 class ParagraphIndexProcessor(BaseIndexProcessor):
     def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
         text_docs = ExtractProcessor.extract(
@@ -266,16 +271,17 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
             keyword = Keyword(dataset)
             keyword.add_texts(documents)

-    def format_preview(self, chunks: Any) -> Mapping[str, Any]:
+    def format_preview(self, chunks: Any) -> ParagraphFormatPreviewDict:
         if isinstance(chunks, list):
             preview = []
             for content in chunks:
                 preview.append({"content": content})
-            return {
+            result: ParagraphFormatPreviewDict = {
                 "chunk_structure": IndexStructureType.PARAGRAPH_INDEX,
                 "preview": preview,
                 "total_segments": len(chunks),
             }
+            return result
         else:
             raise ValueError("Chunks is not a list")

@@ -3,8 +3,7 @@
 import json
 import logging
 import uuid
-from collections.abc import Mapping
-from typing import Any
+from typing import Any, TypedDict

 from sqlalchemy import delete, select

@@ -36,6 +35,13 @@ from services.summary_index_service import SummaryIndexService
 logger = logging.getLogger(__name__)


+class ParentChildFormatPreviewDict(TypedDict):
+    chunk_structure: str
+    parent_mode: str
+    preview: list[dict[str, Any]]
+    total_segments: int
+
+
 class ParentChildIndexProcessor(BaseIndexProcessor):
     def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
         text_docs = ExtractProcessor.extract(
@@ -153,14 +159,12 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
         if node_ids:
             # Find segments by index_node_id
             with session_factory.create_session() as session:
-                segments = (
-                    session.query(DocumentSegment)
-                    .filter(
+                segments = session.scalars(
+                    select(DocumentSegment).where(
                         DocumentSegment.dataset_id == dataset.id,
                         DocumentSegment.index_node_id.in_(node_ids),
                     )
-                    .all()
-                )
+                ).all()
                 segment_ids = [segment.id for segment in segments]
                 if segment_ids:
                     SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
@@ -351,17 +355,18 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
         if all_multimodal_documents and dataset.is_multimodal:
             vector.create_multimodal(all_multimodal_documents)

-    def format_preview(self, chunks: Any) -> Mapping[str, Any]:
+    def format_preview(self, chunks: Any) -> ParentChildFormatPreviewDict:
         parent_childs = ParentChildStructureChunk.model_validate(chunks)
         preview = []
         for parent_child in parent_childs.parent_child_chunks:
             preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents})
-        return {
+        result: ParentChildFormatPreviewDict = {
             "chunk_structure": IndexStructureType.PARENT_CHILD_INDEX,
             "parent_mode": parent_childs.parent_mode,
             "preview": preview,
             "total_segments": len(parent_childs.parent_child_chunks),
         }
+        return result

     def generate_summary_preview(
         self,

@@ -4,11 +4,11 @@ import logging
 import re
 import threading
 import uuid
-from collections.abc import Mapping
-from typing import Any
+from typing import Any, TypedDict

 import pandas as pd
 from flask import Flask, current_app
+from sqlalchemy import select
 from werkzeug.datastructures import FileStorage

 from core.db.session_factory import session_factory
@@ -36,6 +36,12 @@ from services.summary_index_service import SummaryIndexService
 logger = logging.getLogger(__name__)


+class QAFormatPreviewDict(TypedDict):
+    chunk_structure: str
+    qa_preview: list[dict[str, Any]]
+    total_segments: int
+
+
 class QAIndexProcessor(BaseIndexProcessor):
     def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
         text_docs = ExtractProcessor.extract(
@@ -158,14 +164,12 @@ class QAIndexProcessor(BaseIndexProcessor):
         if node_ids:
             # Find segments by index_node_id
             with session_factory.create_session() as session:
-                segments = (
-                    session.query(DocumentSegment)
-                    .filter(
+                segments = session.scalars(
+                    select(DocumentSegment).where(
                         DocumentSegment.dataset_id == dataset.id,
                         DocumentSegment.index_node_id.in_(node_ids),
                     )
-                    .all()
-                )
+                ).all()
                 segment_ids = [segment.id for segment in segments]
                 if segment_ids:
                     SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
@@ -230,16 +234,17 @@ class QAIndexProcessor(BaseIndexProcessor):
         else:
             raise ValueError("Indexing technique must be high quality.")

-    def format_preview(self, chunks: Any) -> Mapping[str, Any]:
+    def format_preview(self, chunks: Any) -> QAFormatPreviewDict:
         qa_chunks = QAStructureChunk.model_validate(chunks)
         preview = []
         for qa_chunk in qa_chunks.qa_chunks:
             preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer})
-        return {
+        result: QAFormatPreviewDict = {
             "chunk_structure": IndexStructureType.QA_INDEX,
             "qa_preview": preview,
             "total_segments": len(qa_chunks.qa_chunks),
         }
+        return result

     def generate_summary_preview(
         self,

@@ -1,7 +1,7 @@
 import base64

 from graphon.model_runtime.entities.model_entities import ModelType
-from graphon.model_runtime.entities.rerank_entities import RerankResult
+from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult

 from core.model_manager import ModelInstance, ModelManager
 from core.rag.index_processor.constant.doc_type import DocType
@@ -123,7 +123,7 @@ class RerankModelRunner(BaseRerankRunner):
         :param query_type: query type
         :return: rerank result
         """
-        docs = []
+        docs: list[MultimodalRerankInput] = []
         doc_ids = set()
         unique_documents = []
         for document in documents:
@@ -138,26 +138,28 @@ class RerankModelRunner(BaseRerankRunner):
                 if upload_file:
                     blob = storage.load_once(upload_file.key)
                     document_file_base64 = base64.b64encode(blob).decode()
-                    document_file_dict = {
-                        "content": document_file_base64,
-                        "content_type": document.metadata["doc_type"],
-                    }
-                    docs.append(document_file_dict)
+                    docs.append(
+                        MultimodalRerankInput(
+                            content=document_file_base64,
+                            content_type=document.metadata["doc_type"],
+                        )
+                    )
                 else:
-                    document_text_dict = {
-                        "content": document.page_content,
-                        "content_type": document.metadata.get("doc_type") or DocType.TEXT,
-                    }
-                    docs.append(document_text_dict)
+                    docs.append(
+                        MultimodalRerankInput(
+                            content=document.page_content,
+                            content_type=document.metadata.get("doc_type") or DocType.TEXT,
+                        )
+                    )
                 doc_ids.add(document.metadata["doc_id"])
                 unique_documents.append(document)
             elif document.provider == "external":
                 if document not in unique_documents:
                     docs.append(
-                        {
-                            "content": document.page_content,
-                            "content_type": document.metadata.get("doc_type") or DocType.TEXT,
-                        }
+                        MultimodalRerankInput(
+                            content=document.page_content,
+                            content_type=document.metadata.get("doc_type") or DocType.TEXT,
+                        )
                     )
                     unique_documents.append(document)

@@ -171,12 +173,12 @@ class RerankModelRunner(BaseRerankRunner):
             if upload_file:
                 blob = storage.load_once(upload_file.key)
                 file_query = base64.b64encode(blob).decode()
-                file_query_dict = {
-                    "content": file_query,
-                    "content_type": DocType.IMAGE,
-                }
+                file_query_input = MultimodalRerankInput(
+                    content=file_query,
+                    content_type=DocType.IMAGE,
+                )
                 rerank_result = self.rerank_model_instance.invoke_multimodal_rerank(
-                    query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n
+                    query=file_query_input, docs=docs, score_threshold=score_threshold, top_n=top_n
                 )
                 return rerank_result, unique_documents
             else:

@@ -14,7 +14,7 @@ from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMU
 from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
 from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
 from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from sqlalchemy import and_, func, literal, or_, select
+from sqlalchemy import and_, func, literal, or_, select, update
 from sqlalchemy.orm import sessionmaker

 from core.app.app_config.entities import (
@@ -276,8 +276,8 @@ class DatasetRetrieval:
         document_ids = [i.segment.document_id for i in records]

         with session_factory.create_session() as session:
-            datasets = session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all()
-            documents = session.query(DatasetDocument).where(DatasetDocument.id.in_(document_ids)).all()
+            datasets = session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all()
+            documents = session.scalars(select(DatasetDocument).where(DatasetDocument.id.in_(document_ids))).all()

         dataset_map = {i.id: i for i in datasets}
         document_map = {i.id: i for i in documents}
@@ -971,9 +971,11 @@ class DatasetRetrieval:

             # Batch update hit_count for all segments
             if segment_ids_to_update:
-                session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids_to_update)).update(
-                    {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
-                    synchronize_session=False,
+                session.execute(
+                    update(DocumentSegment)
+                    .where(DocumentSegment.id.in_(segment_ids_to_update))
+                    .values(hit_count=DocumentSegment.hit_count + 1)
+                    .execution_options(synchronize_session=False)
                 )

         self._send_trace_task(message_id, documents, timer)
@@ -1822,7 +1824,7 @@ class DatasetRetrieval:
     def _get_available_datasets(self, tenant_id: str, dataset_ids: list[str]) -> list[Dataset]:
         with session_factory.create_session() as session:
             subquery = (
-                session.query(DocumentModel.dataset_id, func.count(DocumentModel.id).label("available_document_count"))
+                select(DocumentModel.dataset_id, func.count(DocumentModel.id).label("available_document_count"))
                 .where(
                     DocumentModel.indexing_status == "completed",
                     DocumentModel.enabled == True,
@@ -1834,13 +1836,12 @@ class DatasetRetrieval:
                 .subquery()
             )

-            results = (
-                session.query(Dataset)
+            results = session.scalars(
+                select(Dataset)
                 .outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
                 .where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
                 .where((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
-                .all()
-            )
+            ).all()

             available_datasets = []
             for dataset in results:

@@ -1,6 +1,8 @@
import concurrent.futures
import logging

from sqlalchemy import select

from core.db.session_factory import session_factory
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict

@@ -21,7 +23,7 @@ class SummaryIndex:
) -> None:
if is_preview:
with session_factory.create_session() as session:
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset or dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
return
@@ -34,32 +36,31 @@ class SummaryIndex:
if not document_id:
return

document = session.query(Document).filter_by(id=document_id).first()
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
# Skip qa_model documents
if document is None or document.doc_form == "qa_model":
return

query = session.query(DocumentSegment).filter_by(
dataset_id=dataset_id,
document_id=document_id,
status="completed",
enabled=True,
)
segments = query.all()
segments = session.scalars(
select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
)
).all()
segment_ids = [segment.id for segment in segments]

if not segment_ids:
return

existing_summaries = (
session.query(DocumentSegmentSummary)
.filter(
existing_summaries = session.scalars(
select(DocumentSegmentSummary).where(
DocumentSegmentSummary.chunk_id.in_(segment_ids),
DocumentSegmentSummary.dataset_id == dataset_id,
DocumentSegmentSummary.status == "completed",
)
.all()
)
).all()
completed_summary_segment_ids = {i.chunk_id for i in existing_summaries}
# Preview mode should process segments that are MISSING completed summaries
pending_segment_ids = [sid for sid in segment_ids if sid not in completed_summary_segment_ids]

@@ -73,7 +74,7 @@ class SummaryIndex:
def process_segment(segment_id: str) -> None:
"""Process a single segment in a thread with a fresh DB session."""
with session_factory.create_session() as session:
segment = session.query(DocumentSegment).filter_by(id=segment_id).first()
segment = session.scalar(select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1))
if segment is None:
return
try:
@@ -450,6 +450,12 @@ class WorkflowToolParameterConfiguration(BaseModel):
form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter")


class ToolInvokeMetaDict(TypedDict):
time_cost: float
error: str | None
tool_config: dict[str, Any] | None


class ToolInvokeMeta(BaseModel):
"""
Tool invoke meta

@@ -473,12 +479,13 @@ class ToolInvokeMeta(BaseModel):
"""
return cls(time_cost=0.0, error=error, tool_config={})

def to_dict(self):
return {
def to_dict(self) -> ToolInvokeMetaDict:
result: ToolInvokeMetaDict = {
"time_cost": self.time_cost,
"error": self.error,
"tool_config": self.tool_config,
}
return result


class ToolLabel(BaseModel):
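A hedged sketch of the pattern this hunk introduces: typing a to_dict payload with TypedDict so static checkers can verify keys and value types (illustrative names only):

from typing import TypedDict

class InvokeMetaDict(TypedDict):
    time_cost: float
    error: str | None

def to_dict(time_cost: float, error: str | None) -> InvokeMetaDict:
    result: InvokeMetaDict = {"time_cost": time_cost, "error": error}
    return result

meta = to_dict(0.0, None)
# A misspelled key such as meta["time_cos"] would now be flagged by the
# type checker instead of failing silently at runtime.
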
@@ -262,6 +262,8 @@ class ToolEngine:
ensure_ascii=False,
)
)
elif response.type == ToolInvokeMessage.MessageType.VARIABLE:
continue
else:
parts.append(str(response.message))
@@ -6,7 +6,6 @@ import os
import time
from collections.abc import Generator
from mimetypes import guess_extension, guess_type
from typing import Union
from uuid import uuid4

import httpx

@@ -158,7 +157,7 @@ class ToolFileManager:

return tool_file

def get_file_binary(self, id: str) -> Union[tuple[bytes, str], None]:
def get_file_binary(self, id: str) -> tuple[bytes, str] | None:
"""
get file binary

@@ -176,7 +175,7 @@ class ToolFileManager:

return blob, tool_file.mimetype

def get_file_binary_by_message_file_id(self, id: str) -> Union[tuple[bytes, str], None]:
def get_file_binary_by_message_file_id(self, id: str) -> tuple[bytes, str] | None:
"""
get file binary
@@ -5,7 +5,7 @@ import time
from collections.abc import Generator, Mapping
from os import listdir, path
from threading import Lock
from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, Union, cast
from typing import TYPE_CHECKING, Any, Literal, Protocol, cast

import sqlalchemy as sa
from graphon.runtime import VariablePool

@@ -100,7 +100,7 @@ class ToolManager:
_builtin_provider_lock = Lock()
_hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
_builtin_providers_loaded = False
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
_builtin_tools_labels: dict[str, I18nObject | None] = {}

@classmethod
def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:

@@ -190,7 +190,7 @@ class ToolManager:
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
credential_id: str | None = None,
) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
) -> BuiltinTool | PluginTool | ApiTool | WorkflowTool | MCPTool:
"""
get the tool runtime

@@ -398,7 +398,7 @@ class ToolManager:
agent_tool: AgentToolEntity,
user_id: str | None = None,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
variable_pool: Optional["VariablePool"] = None,
variable_pool: "VariablePool | None" = None,
) -> Tool:
"""
get the agent tool runtime

@@ -442,7 +442,7 @@ class ToolManager:
workflow_tool: WorkflowToolRuntimeSpec,
user_id: str | None = None,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
variable_pool: Optional["VariablePool"] = None,
variable_pool: "VariablePool | None" = None,
) -> Tool:
"""
get the workflow tool runtime

@@ -634,7 +634,7 @@ class ToolManager:
cls._builtin_providers_loaded = False

@classmethod
def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
def get_tool_label(cls, tool_name: str) -> I18nObject | None:
"""
get the tool label

@@ -682,7 +682,7 @@ class ToolManager:

with Session(db.engine, autoflush=False) as session:
ids = [row.id for row in session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()]
return session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()
return list(session.scalars(select(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids))))

@classmethod
def list_providers_from_api(

@@ -993,7 +993,7 @@ class ToolManager:
return {"background": "#252525", "content": "\ud83d\ude01"}

@classmethod
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | dict[str, str] | str:
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | str:
try:
with Session(db.engine) as session:
mcp_service = MCPToolManageService(session=session)

@@ -1001,7 +1001,7 @@ class ToolManager:
mcp_provider = mcp_service.get_provider_entity(
provider_id=provider_id, tenant_id=tenant_id, by_server_id=True
)
return mcp_provider.provider_icon
return cast(EmojiIconDict | str, mcp_provider.provider_icon)
except ValueError:
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
except Exception:

@@ -1013,7 +1013,7 @@ class ToolManager:
tenant_id: str,
provider_type: ToolProviderType,
provider_id: str,
) -> str | EmojiIconDict | dict[str, str]:
) -> str | EmojiIconDict:
"""
get the tool icon

@@ -1052,7 +1052,7 @@ class ToolManager:
def _convert_tool_parameters_type(
cls,
parameters: list[ToolParameter],
variable_pool: Optional["VariablePool"],
variable_pool: "VariablePool | None",
tool_configurations: Mapping[str, Any],
typ: Literal["agent", "workflow", "tool"] = "workflow",
) -> dict[str, Any]:
@@ -118,7 +118,8 @@ class ToolFileMessageTransformer:
if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
raise ValueError("unexpected message type")

assert isinstance(message.message.blob, bytes)
if not isinstance(message.message.blob, bytes):
raise TypeError(f"Expected blob to be bytes, got {type(message.message.blob).__name__}")
tool_file_manager = ToolFileManager()
tool_file = tool_file_manager.create_file_by_raw(
user_id=user_id,
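A short sketch of why the assert was replaced: asserts are stripped when Python runs with -O, so an explicit isinstance check plus TypeError keeps failing loudly in optimized runs too (stand-in function, not the real transformer):

def read_blob(blob: object) -> bytes:
    # Survives `python -O`, unlike `assert isinstance(blob, bytes)`.
    if not isinstance(blob, bytes):
        raise TypeError(f"Expected blob to be bytes, got {type(blob).__name__}")
    return blob

print(read_blob(b"ok"))
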
@@ -17,10 +17,8 @@ class WorkflowToolConfigurationUtils:
"""
nodes = graph.get("nodes", [])
start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None)

if not start_node:
return []

return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])]

@classmethod
@@ -5,12 +5,30 @@ from typing import Any
import pytz  # type: ignore[import-untyped]
from celery import Celery, Task
from celery.schedules import crontab
from typing_extensions import TypedDict

from configs import dify_config
from dify_app import DifyApp


def get_celery_ssl_options() -> dict[str, Any] | None:
class _CelerySentinelKwargsDict(TypedDict):
socket_timeout: float | None
password: str | None


class CelerySentinelTransportDict(TypedDict):
master_name: str | None
sentinel_kwargs: _CelerySentinelKwargsDict


class CelerySSLOptionsDict(TypedDict):
ssl_cert_reqs: int
ssl_ca_certs: str | None
ssl_certfile: str | None
ssl_keyfile: str | None


def get_celery_ssl_options() -> CelerySSLOptionsDict | None:
"""Get SSL configuration for Celery broker/backend connections."""
# Only apply SSL if we're using Redis as broker/backend
if not dify_config.BROKER_USE_SSL:

@@ -33,26 +51,24 @@ def get_celery_ssl_options() -> dict[str, Any] | None:

ssl_cert_reqs = cert_reqs_map.get(dify_config.REDIS_SSL_CERT_REQS, ssl.CERT_NONE)

ssl_options = {
"ssl_cert_reqs": ssl_cert_reqs,
"ssl_ca_certs": dify_config.REDIS_SSL_CA_CERTS,
"ssl_certfile": dify_config.REDIS_SSL_CERTFILE,
"ssl_keyfile": dify_config.REDIS_SSL_KEYFILE,
}

return ssl_options
return CelerySSLOptionsDict(
ssl_cert_reqs=ssl_cert_reqs,
ssl_ca_certs=dify_config.REDIS_SSL_CA_CERTS,
ssl_certfile=dify_config.REDIS_SSL_CERTFILE,
ssl_keyfile=dify_config.REDIS_SSL_KEYFILE,
)


def get_celery_broker_transport_options() -> dict[str, Any]:
def get_celery_broker_transport_options() -> CelerySentinelTransportDict | dict[str, Any]:
"""Get broker transport options (e.g. Redis Sentinel) for Celery connections."""
if dify_config.CELERY_USE_SENTINEL:
return {
"master_name": dify_config.CELERY_SENTINEL_MASTER_NAME,
"sentinel_kwargs": {
"socket_timeout": dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT,
"password": dify_config.CELERY_SENTINEL_PASSWORD,
},
}
return CelerySentinelTransportDict(
master_name=dify_config.CELERY_SENTINEL_MASTER_NAME,
sentinel_kwargs=_CelerySentinelKwargsDict(
socket_timeout=dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT,
password=dify_config.CELERY_SENTINEL_PASSWORD,
),
)
return {}
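A hedged note on the constructor-call spelling used above: a TypedDict class is a plain dict at runtime, so calling it with keyword arguments builds the same object as a dict literal while letting the checker validate every key. Minimal sketch with illustrative names:

from typing import TypedDict

class SSLOptionsDict(TypedDict):
    ssl_cert_reqs: int
    ssl_ca_certs: str | None

# Both spellings produce an ordinary dict; the call form is key-checked.
opts = SSLOptionsDict(ssl_cert_reqs=0, ssl_ca_certs=None)
assert opts == {"ssl_cert_reqs": 0, "ssl_ca_certs": None}
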
@@ -14,6 +14,7 @@ from redis.cluster import ClusterNode, RedisCluster
from redis.connection import Connection, SSLConnection
from redis.retry import Retry
from redis.sentinel import Sentinel
from typing_extensions import TypedDict

from configs import dify_config
from dify_app import DifyApp

@@ -126,6 +127,41 @@ redis_client: RedisClientWrapper = RedisClientWrapper()
_pubsub_redis_client: redis.Redis | RedisCluster | None = None


class RedisSSLParamsDict(TypedDict):
ssl_cert_reqs: int
ssl_ca_certs: str | None
ssl_certfile: str | None
ssl_keyfile: str | None


class RedisHealthParamsDict(TypedDict):
retry: Retry
socket_timeout: float | None
socket_connect_timeout: float | None
health_check_interval: int | None


class RedisClusterHealthParamsDict(TypedDict):
retry: Retry
socket_timeout: float | None
socket_connect_timeout: float | None


class RedisBaseParamsDict(TypedDict):
username: str | None
password: str | None
db: int
encoding: str
encoding_errors: str
decode_responses: bool
protocol: int
cache_config: CacheConfig | None
retry: Retry
socket_timeout: float | None
socket_connect_timeout: float | None
health_check_interval: int | None


def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]:
"""Get SSL configuration for Redis connection."""
if not dify_config.REDIS_USE_SSL:
@@ -171,17 +207,17 @@ def _get_retry_policy() -> Retry:
)


def _get_connection_health_params() -> dict[str, Any]:
def _get_connection_health_params() -> RedisHealthParamsDict:
"""Get connection health and retry parameters for standalone and Sentinel Redis clients."""
return {
"retry": _get_retry_policy(),
"socket_timeout": dify_config.REDIS_SOCKET_TIMEOUT,
"socket_connect_timeout": dify_config.REDIS_SOCKET_CONNECT_TIMEOUT,
"health_check_interval": dify_config.REDIS_HEALTH_CHECK_INTERVAL,
}
return RedisHealthParamsDict(
retry=_get_retry_policy(),
socket_timeout=dify_config.REDIS_SOCKET_TIMEOUT,
socket_connect_timeout=dify_config.REDIS_SOCKET_CONNECT_TIMEOUT,
health_check_interval=dify_config.REDIS_HEALTH_CHECK_INTERVAL,
)


def _get_cluster_connection_health_params() -> dict[str, Any]:
def _get_cluster_connection_health_params() -> RedisClusterHealthParamsDict:
"""Get retry and timeout parameters for Redis Cluster clients.

RedisCluster does not support ``health_check_interval`` as a constructor

@@ -189,26 +225,31 @@ def _get_cluster_connection_health_params() -> dict[str, Any]:
here. Only ``retry``, ``socket_timeout``, and ``socket_connect_timeout``
are passed through.
"""
params = _get_connection_health_params()
return {k: v for k, v in params.items() if k != "health_check_interval"}


def _get_base_redis_params() -> dict[str, Any]:
"""Get base Redis connection parameters including retry and health policy."""
return {
"username": dify_config.REDIS_USERNAME,
"password": dify_config.REDIS_PASSWORD or None,
"db": dify_config.REDIS_DB,
"encoding": "utf-8",
"encoding_errors": "strict",
"decode_responses": False,
"protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL,
"cache_config": _get_cache_configuration(),
**_get_connection_health_params(),
health_params = _get_connection_health_params()
result: RedisClusterHealthParamsDict = {
"retry": health_params["retry"],
"socket_timeout": health_params["socket_timeout"],
"socket_connect_timeout": health_params["socket_connect_timeout"],
}
return result


def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]:
def _get_base_redis_params() -> RedisBaseParamsDict:
"""Get base Redis connection parameters including retry and health policy."""
return RedisBaseParamsDict(
username=dify_config.REDIS_USERNAME,
password=dify_config.REDIS_PASSWORD or None,
db=dify_config.REDIS_DB,
encoding="utf-8",
encoding_errors="strict",
decode_responses=False,
protocol=dify_config.REDIS_SERIALIZATION_PROTOCOL,
cache_config=_get_cache_configuration(),
**_get_connection_health_params(),
)


def _create_sentinel_client(redis_params: RedisBaseParamsDict) -> Union[redis.Redis, RedisCluster]:
"""Create Redis client using Sentinel configuration."""
if not dify_config.REDIS_SENTINELS:
raise ValueError("REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True")

@@ -232,7 +273,8 @@ def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis,
sentinel_kwargs=sentinel_kwargs,
)

master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params)
params: dict[str, Any] = {**redis_params}
master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **params)
return master
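A hedged sketch of why the widening copy (`params: dict[str, Any] = {**redis_params}`) appears before the **-unpack above: some checkers reject passing a TypedDict directly as **kwargs into a loosely typed callable, so copying into a plain dict[str, Any] first keeps the checker quiet without changing runtime behavior (illustrative names only):

from typing import Any, TypedDict

class BaseParams(TypedDict):
    db: int
    decode_responses: bool

def connect(**kwargs: Any) -> dict[str, Any]:
    return kwargs

base = BaseParams(db=0, decode_responses=False)
params: dict[str, Any] = {**base}  # widen before **-unpacking
print(connect(**params))
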
@@ -259,18 +301,16 @@ def _create_cluster_client() -> Union[redis.Redis, RedisCluster]:
return cluster


def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]:
def _create_standalone_client(redis_params: RedisBaseParamsDict) -> Union[redis.Redis, RedisCluster]:
"""Create standalone Redis client."""
connection_class, ssl_kwargs = _get_ssl_configuration()

params = {**redis_params}
params.update(
{
"host": dify_config.REDIS_HOST,
"port": dify_config.REDIS_PORT,
"connection_class": connection_class,
}
)
params: dict[str, Any] = {
**redis_params,
"host": dify_config.REDIS_HOST,
"port": dify_config.REDIS_PORT,
"connection_class": connection_class,
}

if dify_config.REDIS_MAX_CONNECTIONS:
params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS

@@ -293,8 +333,8 @@ def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> redis.Redis |
kwargs["max_connections"] = max_conns
return RedisCluster.from_url(pubsub_url, **kwargs)

health_params = _get_connection_health_params()
kwargs = {**health_params}
standalone_health_params: dict[str, Any] = dict(_get_connection_health_params())
kwargs = {**standalone_health_params}
if max_conns:
kwargs["max_connections"] = max_conns
return redis.Redis.from_url(pubsub_url, **kwargs)
@@ -37,12 +37,7 @@ def trace_span[**P, R](handler_class: type[SpanHandler] | None = None) -> Callab
handler = _get_handler_instance(handler_class or SpanHandler)
tracer = get_tracer(__name__)

return handler.wrapper(
tracer=tracer,
wrapped=func,
args=args,
kwargs=kwargs,
)
return handler.wrapper(tracer, func, *args, **kwargs)

return cast(Callable[P, R], wrapper)
@@ -1,8 +1,8 @@
import inspect
from collections.abc import Callable, Mapping
from collections.abc import Callable
from typing import Any

from opentelemetry.trace import SpanKind, Status, StatusCode
from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer


class SpanHandler:

@@ -16,9 +16,9 @@ class SpanHandler:
exceptions. Handlers can override the wrapper method to customize behavior.
"""

_signature_cache: dict[Callable[..., Any], inspect.Signature] = {}
_signature_cache: dict[Callable[..., object], inspect.Signature] = {}

def _build_span_name(self, wrapped: Callable[..., Any]) -> str:
def _build_span_name[**P, R](self, wrapped: Callable[P, R]) -> str:
"""
Build the span name from the wrapped function.

@@ -29,11 +29,11 @@ class SpanHandler:
"""
return f"{wrapped.__module__}.{wrapped.__qualname__}"

def _extract_arguments[T](
def _extract_arguments[**P, R](
self,
wrapped: Callable[..., T],
args: tuple[object, ...],
kwargs: Mapping[str, object],
wrapped: Callable[P, R],
*args: P.args,
**kwargs: P.kwargs,
) -> dict[str, Any] | None:
"""
Extract function arguments using inspect.signature.

@@ -59,13 +59,13 @@ class SpanHandler:
except Exception:
return None

def wrapper[T](
def wrapper[**P, R](
self,
tracer: Any,
wrapped: Callable[..., T],
args: tuple[object, ...],
kwargs: Mapping[str, object],
) -> T:
tracer: Tracer,
wrapped: Callable[P, R],
*args: P.args,
**kwargs: P.kwargs,
) -> R:
"""
Fully control the wrapper behavior.
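A minimal runnable sketch of the ParamSpec signature adopted here: the PEP 695 form `def wrapper[**P, R]` threads the wrapped callable's exact parameters and return type through, instead of erasing them to tuple/Mapping/Any (requires Python 3.12, which this codebase targets; toy function, not the real handler):

from collections.abc import Callable

def wrapper[**P, R](wrapped: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
    # A real handler would open a span here before delegating.
    return wrapped(*args, **kwargs)

def add(a: int, b: int) -> int:
    return a + b

total: int = wrapper(add, 1, 2)  # the checker knows this is int
print(total)
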
@@ -1,8 +1,7 @@
import logging
from collections.abc import Callable, Mapping
from typing import Any
from collections.abc import Callable

from opentelemetry.trace import SpanKind, Status, StatusCode
from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer
from opentelemetry.util.types import AttributeValue

from extensions.otel.decorators.handler import SpanHandler

@@ -15,15 +14,15 @@ logger = logging.getLogger(__name__)
class AppGenerateHandler(SpanHandler):
"""Span handler for ``AppGenerateService.generate``."""

def wrapper[T](
def wrapper[**P, R](
self,
tracer: Any,
wrapped: Callable[..., T],
args: tuple[object, ...],
kwargs: Mapping[str, object],
) -> T:
tracer: Tracer,
wrapped: Callable[P, R],
*args: P.args,
**kwargs: P.kwargs,
) -> R:
try:
arguments = self._extract_arguments(wrapped, args, kwargs)
arguments = self._extract_arguments(wrapped, *args, **kwargs)
if not arguments:
return wrapped(*args, **kwargs)
@@ -1,8 +1,7 @@
import logging
from collections.abc import Callable, Mapping
from typing import Any
from collections.abc import Callable

from opentelemetry.trace import SpanKind, Status, StatusCode
from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer
from opentelemetry.util.types import AttributeValue

from extensions.otel.decorators.handler import SpanHandler

@@ -14,15 +13,15 @@ logger = logging.getLogger(__name__)
class WorkflowAppRunnerHandler(SpanHandler):
"""Span handler for ``WorkflowAppRunner.run``."""

def wrapper(
def wrapper[**P, R](
self,
tracer: Any,
wrapped: Callable[..., Any],
args: tuple[Any, ...],
kwargs: Mapping[str, Any],
) -> Any:
tracer: Tracer,
wrapped: Callable[P, R],
*args: P.args,
**kwargs: P.kwargs,
) -> R:
try:
arguments = self._extract_arguments(wrapped, args, kwargs)
arguments = self._extract_arguments(wrapped, *args, **kwargs)
if not arguments:
return wrapped(*args, **kwargs)
@@ -14,9 +14,15 @@ from __future__ import annotations

import logging
import threading
from typing import Any
from typing import TYPE_CHECKING, Any

import redis
from redis.cluster import RedisCluster
from redis.exceptions import LockNotOwnedError, RedisError
from redis.lock import Lock

if TYPE_CHECKING:
from extensions.ext_redis import RedisClientWrapper

logger = logging.getLogger(__name__)

@@ -38,21 +44,21 @@ class DbMigrationAutoRenewLock:
primary error/exit code.
"""

_redis_client: Any
_redis_client: redis.Redis | RedisCluster | RedisClientWrapper
_name: str
_ttl_seconds: float
_renew_interval_seconds: float
_log_context: str | None
_logger: logging.Logger

_lock: Any
_lock: Lock | None
_stop_event: threading.Event | None
_thread: threading.Thread | None
_acquired: bool

def __init__(
self,
redis_client: Any,
redis_client: redis.Redis | RedisCluster | RedisClientWrapper,
name: str,
ttl_seconds: float = 60,
renew_interval_seconds: float | None = None,

@@ -127,7 +133,7 @@ class DbMigrationAutoRenewLock:
)
self._thread.start()

def _heartbeat_loop(self, lock: Any, stop_event: threading.Event) -> None:
def _heartbeat_loop(self, lock: Lock, stop_event: threading.Event) -> None:
while not stop_event.wait(self._renew_interval_seconds):
try:
lock.reacquire()
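A hedged sketch of the TYPE_CHECKING guard used above: RedisClientWrapper is needed only by annotations, so importing it under TYPE_CHECKING avoids a runtime import cycle, and `from __future__ import annotations` keeps the hint lazily evaluated (stand-in import, not the real module):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Stand-in for a cycle-prone import; never executed at runtime.
    from decimal import Decimal

def fmt(value: Decimal) -> str:
    # The annotation stays an unevaluated string at runtime.
    return f"{value:.2f}"
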
@@ -17,7 +17,6 @@ def http_status_message(code):


def register_external_error_handlers(api: Api):
@api.errorhandler(HTTPException)
def handle_http_exception(e: HTTPException):
got_request_exception.send(current_app, exception=e)

@@ -74,27 +73,18 @@ def register_external_error_handlers(api: Api):
headers["Set-Cookie"] = build_force_logout_cookie_headers()
return data, status_code, headers

_ = handle_http_exception

@api.errorhandler(ValueError)
def handle_value_error(e: ValueError):
got_request_exception.send(current_app, exception=e)
status_code = 400
data = {"code": "invalid_param", "message": str(e), "status": status_code}
return data, status_code

_ = handle_value_error

@api.errorhandler(AppInvokeQuotaExceededError)
def handle_quota_exceeded(e: AppInvokeQuotaExceededError):
got_request_exception.send(current_app, exception=e)
status_code = 429
data = {"code": "too_many_requests", "message": str(e), "status": status_code}
return data, status_code

_ = handle_quota_exceeded

@api.errorhandler(Exception)
def handle_general_exception(e: Exception):
got_request_exception.send(current_app, exception=e)

@@ -113,7 +103,10 @@ def register_external_error_handlers(api: Api):

return data, status_code

_ = handle_general_exception
api.errorhandler(HTTPException)(handle_http_exception)
api.errorhandler(ValueError)(handle_value_error)
api.errorhandler(AppInvokeQuotaExceededError)(handle_quota_exceeded)
api.errorhandler(Exception)(handle_general_exception)


class ExternalApi(Api):
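A sketch of the registration change above: `api.errorhandler(Exc)` returns a decorator, so applying it explicitly to an already-defined function is equivalent to stacking it above the def, and it removes the need for the `_ = handler` unused-name suppressions. Shown against a stand-in object rather than the real flask_restx Api:

class FakeApi:
    def __init__(self) -> None:
        self.handlers: dict[type, object] = {}

    def errorhandler(self, exc_type: type):
        def register(func):
            self.handlers[exc_type] = func
            return func
        return register

api = FakeApi()

def handle_value_error(e: ValueError):
    return {"code": "invalid_param", "message": str(e)}, 400

# Decorator applied as a plain call, exactly like the hunk above.
api.errorhandler(ValueError)(handle_value_error)
assert api.handlers[ValueError] is handle_value_error
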
@@ -10,7 +10,7 @@ import uuid
from collections.abc import Callable, Generator, Mapping
from datetime import datetime
from hashlib import sha256
from typing import TYPE_CHECKING, Annotated, Any, Optional, Protocol, Union, cast
from typing import TYPE_CHECKING, Annotated, Any, Protocol, cast
from uuid import UUID
from zoneinfo import available_timezones

@@ -81,7 +81,7 @@ def escape_like_pattern(pattern: str) -> str:
return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")


def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None:
def extract_tenant_id(user: "Account | EndUser") -> str | None:
"""
Extract tenant_id from Account or EndUser object.

@@ -164,7 +164,10 @@ def email(email):
EmailStr = Annotated[str, AfterValidator(email)]

def uuid_value(value: Any) -> str:
def uuid_value(value: str | UUID) -> str:
if isinstance(value, UUID):
return str(value)

if value == "":
return str(value)

@@ -405,7 +408,7 @@ class TokenManager:
def generate_token(
cls,
token_type: str,
account: Optional["Account"] = None,
account: "Account | None" = None,
email: str | None = None,
additional_data: dict | None = None,
) -> str:

@@ -465,9 +468,7 @@ class TokenManager:
return current_token

@classmethod
def _set_current_token_for_account(
cls, account_id: str, token: str, token_type: str, expiry_minutes: Union[int, float]
):
def _set_current_token_for_account(cls, account_id: str, token: str, token_type: str, expiry_minutes: int | float):
key = cls._get_account_token_key(account_id, token_type)
expiry_seconds = int(expiry_minutes * 60)
redis_client.setex(key, expiry_seconds, token)
145 api/libs/pyrefly_type_coverage.py Normal file
@@ -0,0 +1,145 @@
"""Helpers for generating type-coverage summaries from pyrefly report output."""

from __future__ import annotations

import json
import sys
from pathlib import Path
from typing import TypedDict


class CoverageSummary(TypedDict):
n_modules: int
n_typable: int
n_typed: int
n_any: int
n_untyped: int
coverage: float
strict_coverage: float


_REQUIRED_KEYS = frozenset(CoverageSummary.__annotations__)

_EMPTY_SUMMARY: CoverageSummary = {
"n_modules": 0,
"n_typable": 0,
"n_typed": 0,
"n_any": 0,
"n_untyped": 0,
"coverage": 0.0,
"strict_coverage": 0.0,
}


def parse_summary(report_json: str) -> CoverageSummary:
"""Extract the summary section from ``pyrefly report`` JSON output.

Returns an empty summary when *report_json* is empty or malformed so that
the CI workflow can degrade gracefully instead of crashing.
"""
if not report_json or not report_json.strip():
return _EMPTY_SUMMARY.copy()

try:
data = json.loads(report_json)
except json.JSONDecodeError:
return _EMPTY_SUMMARY.copy()

summary = data.get("summary")
if not isinstance(summary, dict) or not _REQUIRED_KEYS.issubset(summary):
return _EMPTY_SUMMARY.copy()

return {
"n_modules": summary["n_modules"],
"n_typable": summary["n_typable"],
"n_typed": summary["n_typed"],
"n_any": summary["n_any"],
"n_untyped": summary["n_untyped"],
"coverage": summary["coverage"],
"strict_coverage": summary["strict_coverage"],
}


def format_summary_markdown(summary: CoverageSummary) -> str:
"""Format a single coverage summary as a Markdown table."""

return (
"| Metric | Value |\n"
"| --- | ---: |\n"
f"| Modules | {summary['n_modules']} |\n"
f"| Typable symbols | {summary['n_typable']:,} |\n"
f"| Typed symbols | {summary['n_typed']:,} |\n"
f"| Untyped symbols | {summary['n_untyped']:,} |\n"
f"| Any symbols | {summary['n_any']:,} |\n"
f"| **Type coverage** | **{summary['coverage']:.2f}%** |\n"
f"| Strict coverage | {summary['strict_coverage']:.2f}% |"
)


def format_comparison_markdown(
base: CoverageSummary,
pr: CoverageSummary,
) -> str:
"""Format a comparison between base and PR coverage as Markdown."""

coverage_delta = pr["coverage"] - base["coverage"]
strict_delta = pr["strict_coverage"] - base["strict_coverage"]
typed_delta = pr["n_typed"] - base["n_typed"]
untyped_delta = pr["n_untyped"] - base["n_untyped"]

def _fmt_delta(value: float, fmt: str = ".2f") -> str:
sign = "+" if value > 0 else ""
return f"{sign}{value:{fmt}}"

lines = [
"| Metric | Base | PR | Delta |",
"| --- | ---: | ---: | ---: |",
(f"| **Type coverage** | {base['coverage']:.2f}% | {pr['coverage']:.2f}% | {_fmt_delta(coverage_delta)}% |"),
(
f"| Strict coverage | {base['strict_coverage']:.2f}% "
f"| {pr['strict_coverage']:.2f}% "
f"| {_fmt_delta(strict_delta)}% |"
),
(f"| Typed symbols | {base['n_typed']:,} | {pr['n_typed']:,} | {_fmt_delta(typed_delta, ',')} |"),
(f"| Untyped symbols | {base['n_untyped']:,} | {pr['n_untyped']:,} | {_fmt_delta(untyped_delta, ',')} |"),
(
f"| Modules | {base['n_modules']} "
f"| {pr['n_modules']} "
f"| {_fmt_delta(pr['n_modules'] - base['n_modules'], ',')} |"
),
]
return "\n".join(lines)


def main() -> int:
"""Read pyrefly report JSON from stdin and print a Markdown summary.

Accepts an optional ``--base <file>`` argument. When provided, the output
includes a base-vs-PR comparison table.
"""

args = sys.argv[1:]

base_file: str | None = None
if "--base" in args:
idx = args.index("--base")
if idx + 1 >= len(args):
sys.stderr.write("error: --base requires a file path\n")
return 1
base_file = args[idx + 1]

pr_report = sys.stdin.read()
pr_summary = parse_summary(pr_report)

if base_file is not None:
base_text = Path(base_file).read_text() if Path(base_file).exists() else ""
base_summary = parse_summary(base_text)
sys.stdout.write(format_comparison_markdown(base_summary, pr_summary) + "\n")
else:
sys.stdout.write(format_summary_markdown(pr_summary) + "\n")

return 0


if __name__ == "__main__":
sys.exit(main())
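A hedged usage sketch for the new helper, feeding it a hand-made report blob whose shape mirrors what parse_summary itself expects; this assumes the script runs with the api/ directory on sys.path so `libs.pyrefly_type_coverage` is importable:

import json

from libs.pyrefly_type_coverage import format_summary_markdown, parse_summary

report = json.dumps({
    "summary": {
        "n_modules": 2,
        "n_typable": 100,
        "n_typed": 90,
        "n_any": 5,
        "n_untyped": 10,
        "coverage": 90.0,
        "strict_coverage": 85.0,
    }
})
# Prints the same Markdown table the CI comment would contain.
print(format_summary_markdown(parse_summary(report)))
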
@@ -24,6 +24,8 @@ class TypeBase(MappedAsDataclass, DeclarativeBase):


class DefaultFieldsMixin:
"""Mixin for models that inherit from Base (non-dataclass)."""

id: Mapped[str] = mapped_column(
StringUUID,
primary_key=True,

@@ -53,6 +55,42 @@ class DefaultFieldsMixin:
return f"<{self.__class__.__name__}(id={self.id})>"


class DefaultFieldsDCMixin(MappedAsDataclass):
"""Mixin for models that inherit from TypeBase (MappedAsDataclass)."""

__abstract__ = True

id: Mapped[str] = mapped_column(
StringUUID,
primary_key=True,
insert_default=lambda: str(uuidv7()),
default_factory=lambda: str(uuidv7()),
init=False,
)

created_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
insert_default=naive_utc_now,
default_factory=naive_utc_now,
init=False,
server_default=func.current_timestamp(),
)

updated_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
insert_default=naive_utc_now,
default_factory=naive_utc_now,
init=False,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
)

def __repr__(self) -> str:
return f"<{self.__class__.__name__}(id={self.id})>"


def gen_uuidv4_string() -> str:
"""gen_uuidv4_string generate a UUIDv4 string.
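A minimal sketch of the double-default pattern in the dataclass mixin above, under the assumption (mirrored from the hunk itself) that `default_factory` feeds the generated __init__ while `insert_default` covers rows created outside it; toy model, not Dify's real base classes:

from datetime import datetime, timezone

from sqlalchemy import DateTime, create_engine
from sqlalchemy.orm import (
    DeclarativeBase, Mapped, MappedAsDataclass, Session, mapped_column,
)

def now() -> datetime:
    return datetime.now(timezone.utc).replace(tzinfo=None)

class Base(MappedAsDataclass, DeclarativeBase):
    pass

class Item(Base):
    __tablename__ = "items"
    id: Mapped[int] = mapped_column(primary_key=True, init=False)
    created_at: Mapped[datetime] = mapped_column(
        DateTime, insert_default=now, default_factory=now, init=False
    )

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add(Item())  # created_at filled by default_factory
    session.commit()
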
@@ -108,6 +108,56 @@ class ExternalKnowledgeApiDict(TypedDict):
created_at: str


class DocumentDict(TypedDict):
id: str
tenant_id: str
dataset_id: str
position: int
data_source_type: str
data_source_info: str | None
dataset_process_rule_id: str | None
batch: str
name: str
created_from: str
created_by: str
created_api_request_id: str | None
created_at: datetime
processing_started_at: datetime | None
file_id: str | None
word_count: int | None
parsing_completed_at: datetime | None
cleaning_completed_at: datetime | None
splitting_completed_at: datetime | None
tokens: int | None
indexing_latency: float | None
completed_at: datetime | None
is_paused: bool | None
paused_by: str | None
paused_at: datetime | None
error: str | None
stopped_at: datetime | None
indexing_status: str
enabled: bool
disabled_at: datetime | None
disabled_by: str | None
archived: bool
archived_reason: str | None
archived_by: str | None
archived_at: datetime | None
updated_at: datetime
doc_type: str | None
doc_metadata: Any
doc_form: IndexStructureType
doc_language: str | None
display_status: str | None
data_source_info_dict: dict[str, Any]
average_segment_length: int
dataset_process_rule: ProcessRuleDict | None
dataset: None
segment_count: int | None
hit_count: int | None


class DatasetPermissionEnum(enum.StrEnum):
ONLY_ME = "only_me"
ALL_TEAM = "all_team_members"
@@ -303,13 +353,17 @@ class Dataset(Base):
if self.provider != "external":
return None
external_knowledge_binding = db.session.scalar(
select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id)
select(ExternalKnowledgeBindings).where(
ExternalKnowledgeBindings.dataset_id == self.id,
ExternalKnowledgeBindings.tenant_id == self.tenant_id,
)
)
if not external_knowledge_binding:
return None
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis).where(
ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id
ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id,
ExternalKnowledgeApis.tenant_id == self.tenant_id,
)
)
if external_knowledge_api is None or external_knowledge_api.settings is None:

@@ -675,8 +729,8 @@ class Document(Base):
)
return built_in_fields

def to_dict(self) -> dict[str, Any]:
return {
def to_dict(self) -> DocumentDict:
result: DocumentDict = {
"id": self.id,
"tenant_id": self.tenant_id,
"dataset_id": self.dataset_id,

@@ -721,10 +775,11 @@ class Document(Base):
"data_source_info_dict": self.data_source_info_dict,
"average_segment_length": self.average_segment_length,
"dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None,
"dataset": None,  # Dataset class doesn't have a to_dict method
"dataset": None,
"segment_count": self.segment_count,
"hit_count": self.hit_count,
}
return result

@classmethod
def from_dict(cls, data: dict[str, Any]):
@@ -674,28 +674,24 @@ class AppModelConfig(TypeBase):
def suggested_questions_list(self) -> list[str]:
return json.loads(self.suggested_questions) if self.suggested_questions else []

def _get_enabled_config(self, value: str | None, *, default_enabled: bool = False) -> EnabledConfig:
return cast(EnabledConfig, json.loads(value) if value else {"enabled": default_enabled})

@property
def suggested_questions_after_answer_dict(self) -> EnabledConfig:
return cast(
EnabledConfig,
json.loads(self.suggested_questions_after_answer)
if self.suggested_questions_after_answer
else {"enabled": False},
)
return self._get_enabled_config(self.suggested_questions_after_answer)

@property
def speech_to_text_dict(self) -> EnabledConfig:
return cast(EnabledConfig, json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False})
return self._get_enabled_config(self.speech_to_text)

@property
def text_to_speech_dict(self) -> EnabledConfig:
return cast(EnabledConfig, json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False})
return self._get_enabled_config(self.text_to_speech)

@property
def retriever_resource_dict(self) -> EnabledConfig:
return cast(
EnabledConfig, json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True}
)
return self._get_enabled_config(self.retriever_resource, default_enabled=True)

@property
def annotation_reply_dict(self) -> AnnotationReplyConfig:

@@ -722,7 +718,7 @@ class AppModelConfig(TypeBase):

@property
def more_like_this_dict(self) -> EnabledConfig:
return cast(EnabledConfig, json.loads(self.more_like_this) if self.more_like_this else {"enabled": False})
return self._get_enabled_config(self.more_like_this)

@property
def sensitive_word_avoidance_dict(self) -> SensitiveWordAvoidanceConfig:
@@ -902,7 +898,7 @@ class InstalledApp(TypeBase):
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))


class TrialApp(Base):
class TrialApp(TypeBase):
__tablename__ = "trial_apps"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="trial_app_pkey"),

@@ -911,18 +907,22 @@ class TrialApp(Base):
sa.UniqueConstraint("app_id", name="unique_trail_app_id"),
)

id = mapped_column(StringUUID, default=gen_uuidv4_string)
app_id = mapped_column(StringUUID, nullable=False)
tenant_id = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
trial_limit = mapped_column(sa.Integer, nullable=False, default=3)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False
)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
trial_limit: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=3)

@property
def app(self) -> App | None:
return db.session.scalar(select(App).where(App.id == self.app_id))


class AccountTrialAppRecord(Base):
class AccountTrialAppRecord(TypeBase):
__tablename__ = "account_trial_app_records"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="user_trial_app_pkey"),

@@ -930,11 +930,15 @@ class AccountTrialAppRecord(Base):
sa.Index("account_trial_app_record_app_id_idx", "app_id"),
sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"),
)
id = mapped_column(StringUUID, default=gen_uuidv4_string)
account_id = mapped_column(StringUUID, nullable=False)
app_id = mapped_column(StringUUID, nullable=False)
count = mapped_column(sa.Integer, nullable=False, default=0)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
id: Mapped[str] = mapped_column(
StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False
)
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)

@property
def app(self) -> App | None:
@@ -1,5 +1,6 @@
import json
from datetime import datetime
from typing import Any, TypedDict
from uuid import uuid4

import sqlalchemy as sa

@@ -38,6 +39,17 @@ class DataSourceOauthBinding(TypeBase):
disabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"), default=False)


class DataSourceApiKeyAuthBindingDict(TypedDict):
id: str
tenant_id: str
category: str
provider: str
credentials: Any
created_at: float
updated_at: float
disabled: bool


class DataSourceApiKeyAuthBinding(TypeBase):
__tablename__ = "data_source_api_key_auth_bindings"
__table_args__ = (

@@ -65,8 +77,8 @@ class DataSourceApiKeyAuthBinding(TypeBase):
)
disabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"), default=False)

def to_dict(self):
return {
def to_dict(self) -> DataSourceApiKeyAuthBindingDict:
result: DataSourceApiKeyAuthBindingDict = {
"id": self.id,
"tenant_id": self.tenant_id,
"category": self.category,

@@ -76,3 +88,4 @@ class DataSourceApiKeyAuthBinding(TypeBase):
"updated_at": self.updated_at.timestamp(),
"disabled": self.disabled,
}
return result
@@ -66,12 +66,15 @@ def build_file_from_stored_mapping(
record_id = resolve_file_record_id(mapping)
transfer_method = FileTransferMethod.value_of(mapping["transfer_method"])

if transfer_method == FileTransferMethod.TOOL_FILE and record_id:
mapping["tool_file_id"] = record_id
elif transfer_method in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL] and record_id:
mapping["upload_file_id"] = record_id
elif transfer_method == FileTransferMethod.DATASOURCE_FILE and record_id:
mapping["datasource_file_id"] = record_id
match transfer_method:
case FileTransferMethod.TOOL_FILE if record_id:
mapping["tool_file_id"] = record_id
case FileTransferMethod.LOCAL_FILE | FileTransferMethod.REMOTE_URL if record_id:
mapping["upload_file_id"] = record_id
case FileTransferMethod.DATASOURCE_FILE if record_id:
mapping["datasource_file_id"] = record_id
case _:
pass

if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None:
remote_url = mapping.get("remote_url")
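A runnable sketch of the match-with-guard shape adopted above, on plain strings so it stands alone; the `if record_id` guard falls through to `case _` when the id is missing, mirroring the old chained elif-and conditions:

def route(method: str, record_id: str | None) -> str:
    match method:
        case "tool_file" if record_id:
            return f"tool:{record_id}"
        case "local_file" | "remote_url" if record_id:
            return f"upload:{record_id}"
        case _:
            return "unrouted"

print(route("local_file", "abc"))  # upload:abc
print(route("tool_file", None))    # unrouted (guard failed)
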
@@ -4,7 +4,7 @@ import logging
from collections.abc import Generator, Mapping, Sequence
from datetime import datetime
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union, cast
from typing import TYPE_CHECKING, Any, Optional, TypedDict, cast
from uuid import uuid4

import sqlalchemy as sa

@@ -121,7 +121,7 @@ class WorkflowType(StrEnum):
raise ValueError(f"invalid workflow type value {value}")

@classmethod
def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType":
def from_app_mode(cls, app_mode: "str | AppMode") -> "WorkflowType":
"""
Get workflow type from app mode.

@@ -1051,7 +1051,7 @@ class WorkflowNodeExecutionModel(Base):  # This model is expected to have `offlo
)
return extras

def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]:
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> "WorkflowNodeExecutionOffload | None":
return next(iter([i for i in self.offload_data if i.type_ == type_]), None)

@property
@@ -4,76 +4,76 @@ version = "1.13.3"
requires-python = "~=3.12.0"

dependencies = [
"aliyun-log-python-sdk~=0.9.37",
"aliyun-log-python-sdk~=0.9.44",
"arize-phoenix-otel~=0.15.0",
"azure-identity==1.25.3",
"beautifulsoup4==4.14.3",
"boto3==1.42.83",
"boto3==1.42.88",
"bs4~=0.0.1",
"cachetools~=5.3.0",
"celery~=5.6.2",
"charset-normalizer>=3.4.4",
"flask~=3.1.2",
"flask-compress>=1.17,<1.25",
"flask-cors~=6.0.0",
"cachetools~=7.0.5",
"celery~=5.6.3",
"charset-normalizer>=3.4.7",
"flask~=3.1.3",
"flask-compress>=1.24,<1.25",
"flask-cors~=6.0.2",
"flask-login~=0.6.3",
"flask-migrate~=4.1.0",
"flask-orjson~=2.0.0",
"flask-sqlalchemy~=3.1.1",
"gevent~=25.9.1",
"gevent~=26.4.0",
"gmpy2~=2.3.0",
"google-api-core>=2.19.1",
"google-api-python-client==2.193.0",
"google-auth>=2.47.0",
"google-api-core>=2.30.3",
"google-api-python-client==2.194.0",
"google-auth>=2.49.2",
"google-auth-httplib2==0.3.1",
"google-cloud-aiplatform>=1.123.0",
"googleapis-common-protos>=1.65.0",
"google-cloud-aiplatform>=1.147.0",
"googleapis-common-protos>=1.74.0",
"graphon>=0.1.2",
"gunicorn~=25.3.0",
"httpx[socks]~=0.28.0",
"httpx[socks]~=0.28.1",
"jieba==0.42.1",
"json-repair>=0.55.1",
"langfuse>=3.0.0,<5.0.0",
"langsmith~=0.7.16",
"json-repair>=0.59.2",
"langfuse>=4.2.0,<5.0.0",
"langsmith~=0.7.30",
"markdown~=3.10.2",
"mlflow-skinny>=3.0.0",
"numpy~=1.26.4",
"mlflow-skinny>=3.11.1",
"numpy~=2.4.4",
"openpyxl~=3.1.5",
"opik~=1.10.37",
"opik~=1.11.2",
"litellm==1.83.0", # Pinned to avoid madoka dependency issue
"opentelemetry-api==1.40.0",
"opentelemetry-distro==0.61b0",
"opentelemetry-exporter-otlp==1.40.0",
"opentelemetry-exporter-otlp-proto-common==1.40.0",
"opentelemetry-exporter-otlp-proto-grpc==1.40.0",
"opentelemetry-exporter-otlp-proto-http==1.40.0",
"opentelemetry-instrumentation==0.61b0",
"opentelemetry-instrumentation-celery==0.61b0",
"opentelemetry-instrumentation-flask==0.61b0",
"opentelemetry-instrumentation-httpx==0.61b0",
"opentelemetry-instrumentation-redis==0.61b0",
"opentelemetry-instrumentation-sqlalchemy==0.61b0",
"opentelemetry-propagator-b3==1.40.0",
"opentelemetry-proto==1.40.0",
"opentelemetry-sdk==1.40.0",
"opentelemetry-semantic-conventions==0.61b0",
"opentelemetry-util-http==0.61b0",
"pandas[excel,output-formatting,performance]~=3.0.1",
"opentelemetry-api==1.41.0",
"opentelemetry-distro==0.62b0",
"opentelemetry-exporter-otlp==1.41.0",
"opentelemetry-exporter-otlp-proto-common==1.41.0",
"opentelemetry-exporter-otlp-proto-grpc==1.41.0",
"opentelemetry-exporter-otlp-proto-http==1.41.0",
"opentelemetry-instrumentation==0.62b0",
"opentelemetry-instrumentation-celery==0.62b0",
"opentelemetry-instrumentation-flask==0.62b0",
"opentelemetry-instrumentation-httpx==0.62b0",
"opentelemetry-instrumentation-redis==0.62b0",
"opentelemetry-instrumentation-sqlalchemy==0.62b0",
"opentelemetry-propagator-b3==1.41.0",
"opentelemetry-proto==1.41.0",
"opentelemetry-sdk==1.41.0",
"opentelemetry-semantic-conventions==0.62b0",
"opentelemetry-util-http==0.62b0",
"pandas[excel,output-formatting,performance]~=3.0.2",
"psycogreen~=1.0.2",
"psycopg2-binary~=2.9.6",
"psycopg2-binary~=2.9.11",
"pycryptodome==3.23.0",
"pydantic~=2.12.5",
"pydantic-settings~=2.13.1",
"pyjwt~=2.12.0",
"pyjwt~=2.12.1",
"pypdfium2==5.6.0",
"python-docx~=1.2.0",
"python-dotenv==1.2.2",
"pyyaml~=6.0.1",
"readabilipy~=0.3.0",
"redis[hiredis]~=7.4.0",
"resend~=2.26.0",
"sentry-sdk[flask]~=2.55.0",
"sqlalchemy~=2.0.29",
"resend~=2.27.0",
"sentry-sdk[flask]~=2.57.0",
"sqlalchemy~=2.0.49",
"starlette==1.0.0",
"tiktoken~=0.12.0",
"transformers~=5.3.0",
@@ -82,12 +82,12 @@ dependencies = [
"yarl~=1.23.0",
"sseclient-py~=1.9.0",
"httpx-sse~=0.4.0",
"sendgrid~=6.12.3",
"sendgrid~=6.12.5",
"flask-restx~=1.3.2",
"packaging~=23.2",
"croniter>=6.0.0",
"weaviate-client==4.20.4",
"apscheduler>=3.11.0",
"packaging~=26.0",
"croniter>=6.2.2",
"weaviate-client==4.20.5",
"apscheduler>=3.11.2",
"weave>=0.52.16",
"fastopenapi[flask]>=0.7.0",
"bleach~=6.3.0",

@@ -111,16 +111,16 @@ package = false
dev = [
"coverage~=7.13.4",
"dotenv-linter~=0.7.0",
"faker~=40.12.0",
"faker~=40.13.0",
"lxml-stubs~=0.5.1",
"basedpyright~=1.39.0",
"ruff~=0.15.5",
"pytest~=9.0.2",
"ruff~=0.15.10",
"pytest~=9.0.3",
"pytest-benchmark~=5.2.3",
"pytest-cov~=7.1.0",
"pytest-env~=1.6.0",
"pytest-mock~=3.15.1",
"testcontainers~=4.14.1",
"testcontainers~=4.14.2",
"types-aiofiles~=25.1.0",
"types-beautifulsoup4~=4.12.0",
"types-cachetools~=6.2.0",

@@ -130,8 +130,8 @@ dev = [
"types-docutils~=0.22.3",
"types-flask-cors~=6.0.0",
"types-flask-migrate~=4.1.0",
"types-gevent~=25.9.0",
"types-greenlet~=3.3.0",
"types-gevent~=26.4.0",
"types-greenlet~=3.4.0",
"types-html5lib~=1.1.11",
"types-markdown~=3.10.2",
"types-oauthlib~=3.3.0",

@@ -149,20 +149,20 @@ dev = [
"types-pyyaml~=6.0.12",
"types-regex~=2026.4.4",
"types-shapely~=2.1.0",
"types-simplejson>=3.20.0",
"types-six>=1.17.0",
"types-tensorflow>=2.18.0",
"types-tqdm>=4.67.0",
"types-simplejson>=3.20.0.20260408",
"types-six>=1.17.0.20260408",
"types-tensorflow>=2.18.0.20260408",
"types-tqdm>=4.67.3.20260408",
"types-ujson>=5.10.0",
"boto3-stubs>=1.38.20",
"types-jmespath>=1.0.2.20240106",
"hypothesis>=6.131.15",
"boto3-stubs>=1.42.88",
"types-jmespath>=1.1.0.20260408",
"hypothesis>=6.151.12",
"types_pyOpenSSL>=24.1.0",
"types_cffi>=1.17.0",
"types_setuptools>=80.9.0",
"types_cffi>=2.0.0.20260408",
"types_setuptools>=82.0.0.20260408",
"pandas-stubs~=3.0.0",
"scipy-stubs>=1.15.3.0",
"types-python-http-client>=3.3.7.20240910",
"types-python-http-client>=3.3.7.20260408",
"import-linter>=2.3",
"types-redis>=4.6.0.20241004",
"celery-types>=0.23.0",

@@ -180,10 +180,10 @@ dev = [
############################################################
storage = [
"azure-storage-blob==12.28.0",
"bce-python-sdk~=0.9.23",
"bce-python-sdk~=0.9.69",
"cos-python-sdk-v5==1.9.41",
"esdk-obs-python==3.26.2",
"google-cloud-storage>=3.0.0",
"google-cloud-storage>=3.10.1",
"opendal~=0.46.0",
"oss2==2.19.1",
"supabase~=2.18.1",

@@ -209,23 +209,23 @@ vdb = [
"elasticsearch==8.14.0",
"opensearch-py==3.1.0",
"oracledb==3.4.2",
"pgvecto-rs[sqlalchemy]~=0.2.1",
"pgvecto-rs[sqlalchemy]~=0.2.2",
"pgvector==0.4.2",
"pymilvus~=2.6.10",
"pymilvus~=2.6.12",
"pymochow==2.4.0",
"pyobvector~=0.2.17",
"qdrant-client==1.9.0",
"intersystems-irispython>=5.1.0",
"tablestore==6.4.3",
"tablestore==6.4.4",
"tcvectordb~=2.1.0",
"tidb-vector==0.0.15",
"upstash-vector==0.8.0",
"volcengine-compat~=1.0.0",
"weaviate-client==4.20.4",
"weaviate-client==4.20.5",
"xinference-client~=2.4.0",
"mo-vector~=0.1.13",
"mysql-connector-python>=9.3.0",
"holo-search-sdk>=0.4.1",
"holo-search-sdk>=0.4.2",
]

[tool.pyrefly]

@@ -47,7 +47,6 @@
"reportMissingTypeArgument": "hint",
"reportUnnecessaryComparison": "hint",
"reportUnnecessaryIsInstance": "hint",
"reportUntypedFunctionDecorator": "hint",
"reportUnnecessaryTypeIgnoreComment": "hint",
"reportAttributeAccessIssue": "hint",
"pythonVersion": "3.12",
@ -9,7 +9,8 @@ from typing import Any, TypedDict, cast
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from sqlalchemy import delete, func, select, update
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
|
||||
|
||||
class InvitationData(TypedDict):
|
||||
@ -800,19 +801,19 @@ class AccountService:
|
||||
return token
|
||||
|
||||
@staticmethod
|
||||
def get_account_by_email_with_case_fallback(email: str, session: Session | None = None) -> Account | None:
|
||||
def get_account_by_email_with_case_fallback(email: str) -> Account | None:
|
||||
"""
|
||||
Retrieve an account by email and fall back to the lowercase email if the original lookup fails.
|
||||
|
||||
This keeps backward compatibility for older records that stored uppercase emails while the
|
||||
rest of the system gradually normalizes new inputs.
|
||||
"""
|
||||
query_session = session or db.session
|
||||
account = query_session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
|
||||
if account or email == email.lower():
|
||||
return account
|
||||
with session_factory.create_session() as session:
|
||||
account = session.execute(select(Account).where(Account.email == email)).scalar_one_or_none()
|
||||
if account or email == email.lower():
|
||||
return account
|
||||
|
||||
return query_session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()
|
||||
return session.execute(select(Account).where(Account.email == email.lower())).scalar_one_or_none()
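The fallback above is a two-step probe: look the email up exactly as given, and only retry with the lowercase form when the first probe misses and lowercasing would actually change the key. A minimal sketch of the same control flow, with a plain dict standing in for the accounts table (the real method runs both probes as SQLAlchemy selects inside a session_factory session):

# Hypothetical stand-in for the accounts table: email -> record.
_accounts = {"Alice@Example.com": {"id": 1}, "bob@example.com": {"id": 2}}

def find_account_with_case_fallback(email: str) -> dict | None:
    # First probe: the email exactly as supplied.
    account = _accounts.get(email)
    # Stop early on a hit, or when lowercasing cannot change the key anyway.
    if account or email == email.lower():
        return account
    # Second probe: fall back to the normalized (lowercase) form.
    return _accounts.get(email.lower())

assert find_account_with_case_fallback("Alice@Example.com") == {"id": 1}
assert find_account_with_case_fallback("BOB@example.com") == {"id": 2}
assert find_account_with_case_fallback("carol@example.com") is None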

@classmethod
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
@ -1516,8 +1517,7 @@ class RegisterService:

check_workspace_member_invite_permission(tenant.id)

with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
account = AccountService.get_account_by_email_with_case_fallback(email)

if not account:
TenantService.check_member_permission(tenant, inviter, None, "add")

@ -1,11 +1,8 @@
import logging
import uuid

import pandas as pd

logger = logging.getLogger(__name__)
from typing import TypedDict

import pandas as pd
from sqlalchemy import delete, or_, select, update
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
@ -24,6 +21,8 @@ from tasks.annotation.disable_annotation_reply_task import disable_annotation_re
from tasks.annotation.enable_annotation_reply_task import enable_annotation_reply_task
from tasks.annotation.update_annotation_to_index_task import update_annotation_to_index_task

logger = logging.getLogger(__name__)


class AnnotationJobStatusDict(TypedDict):
job_id: str
@ -46,9 +45,50 @@ class AnnotationSettingDisabledDict(TypedDict):
enabled: bool


class EnableAnnotationArgs(TypedDict):
"""Expected shape of the args dict passed to enable_app_annotation."""

score_threshold: float
embedding_provider_name: str
embedding_model_name: str


class UpsertAnnotationArgs(TypedDict, total=False):
"""Expected shape of the args dict passed to up_insert_app_annotation_from_message."""

answer: str
content: str
message_id: str
question: str


class InsertAnnotationArgs(TypedDict):
"""Expected shape of the args dict passed to insert_app_annotation_directly."""

question: str
answer: str


class UpdateAnnotationArgs(TypedDict, total=False):
"""Expected shape of the args dict passed to update_app_annotation_directly.

Both fields are optional at the type level; the service validates at runtime
and raises ValueError if either is missing.
"""

answer: str
question: str


class UpdateAnnotationSettingArgs(TypedDict):
"""Expected shape of the args dict passed to update_app_annotation_setting."""

score_threshold: float
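These TypedDict declarations exist so that a type checker can verify the literal dicts the controllers pass in. A sketch of the payoff, using a renamed stand-in for EnableAnnotationArgs and illustrative provider/model values:

from typing import TypedDict

class EnableArgs(TypedDict):
    score_threshold: float
    embedding_provider_name: str
    embedding_model_name: str

def enable(args: EnableArgs) -> float:
    # Required keys can be indexed directly; the checker knows they exist.
    return args["score_threshold"]

# Accepted: all required keys present with the right value types.
enable({
    "score_threshold": 0.8,
    "embedding_provider_name": "openai",
    "embedding_model_name": "text-embedding-3-small",
})
# A checker such as basedpyright would reject a misspelled or missing key:
# enable({"score_treshold": 0.8})  # unknown key, required keys missing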


class AppAnnotationService:
@classmethod
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
def up_insert_app_annotation_from_message(cls, args: UpsertAnnotationArgs, app_id: str) -> MessageAnnotation:
# get app info
current_user, current_tenant_id = current_account_with_tenant()
app = db.session.scalar(
@ -62,8 +102,9 @@ class AppAnnotationService:
if answer is None:
raise ValueError("Either 'answer' or 'content' must be provided")

if args.get("message_id"):
message_id = str(args["message_id"])
raw_message_id = args.get("message_id")
if raw_message_id:
message_id = str(raw_message_id)
message = db.session.scalar(
select(Message).where(Message.id == message_id, Message.app_id == app.id).limit(1)
)
@ -87,9 +128,10 @@ class AppAnnotationService:
account_id=current_user.id,
)
else:
question = args.get("question")
if not question:
maybe_question = args.get("question")
if not maybe_question:
raise ValueError("'question' is required when 'message_id' is not provided")
question = maybe_question

annotation = MessageAnnotation(app_id=app.id, content=answer, question=question, account_id=current_user.id)
db.session.add(annotation)
@ -110,7 +152,7 @@ class AppAnnotationService:
return annotation

@classmethod
def enable_app_annotation(cls, args: dict, app_id: str) -> AnnotationJobStatusDict:
def enable_app_annotation(cls, args: EnableAnnotationArgs, app_id: str) -> AnnotationJobStatusDict:
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
cache_result = redis_client.get(enable_app_annotation_key)
if cache_result is not None:
@ -217,7 +259,7 @@ class AppAnnotationService:
return annotations

@classmethod
def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
def insert_app_annotation_directly(cls, args: InsertAnnotationArgs, app_id: str) -> MessageAnnotation:
# get app info
current_user, current_tenant_id = current_account_with_tenant()
app = db.session.scalar(
@ -251,7 +293,7 @@ class AppAnnotationService:
return annotation

@classmethod
def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
def update_app_annotation_directly(cls, args: UpdateAnnotationArgs, app_id: str, annotation_id: str):
# get app info
_, current_tenant_id = current_account_with_tenant()
app = db.session.scalar(
@ -270,7 +312,11 @@ class AppAnnotationService:
if question is None:
raise ValueError("'question' is required")

annotation.content = args["answer"]
answer = args.get("answer")
if answer is None:
raise ValueError("'answer' is required")

annotation.content = answer
annotation.question = question

db.session.commit()
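Because UpdateAnnotationArgs is declared with total=False, every key access must be treated as optional, which is why args["answer"] above becomes an args.get(...) probe followed by an explicit ValueError; the probe also narrows the value from str | None to str for the checker. The pattern in isolation, assuming nothing beyond the standard library:

from typing import TypedDict

class UpdateArgs(TypedDict, total=False):
    answer: str
    question: str

def validate(args: UpdateArgs) -> tuple[str, str]:
    # Optional at the type level, mandatory at runtime: probe, then fail loudly.
    question = args.get("question")
    if question is None:
        raise ValueError("'question' is required")
    answer = args.get("answer")
    if answer is None:
        raise ValueError("'answer' is required")
    # Both names are now narrowed to plain str.
    return question, answer

assert validate({"question": "q", "answer": "a"}) == ("q", "a")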
@ -613,7 +659,7 @@ class AppAnnotationService:

@classmethod
def update_app_annotation_setting(
cls, app_id: str, annotation_setting_id: str, args: dict
cls, app_id: str, annotation_setting_id: str, args: UpdateAnnotationSettingArgs
) -> AnnotationSettingDict:
current_user, current_tenant_id = current_account_with_tenant()
# get app info

@ -467,61 +467,67 @@ class AppDslService:
)

# Initialize app based on mode
if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow_data = data.get("workflow")
if not workflow_data or not isinstance(workflow_data, dict):
raise ValueError("Missing workflow data for workflow/advanced chat app")
match app_mode:
case AppMode.ADVANCED_CHAT | AppMode.WORKFLOW:
workflow_data = data.get("workflow")
if not workflow_data or not isinstance(workflow_data, dict):
raise ValueError("Missing workflow data for workflow/advanced chat app")

environment_variables_list = workflow_data.get("environment_variables", [])
environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
conversation_variables_list = workflow_data.get("conversation_variables", [])
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
environment_variables_list = workflow_data.get("environment_variables", [])
environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
conversation_variables_list = workflow_data.get("conversation_variables", [])
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj)
for obj in conversation_variables_list
]

workflow_service = WorkflowService()
current_draft_workflow = workflow_service.get_draft_workflow(app_model=app)
if current_draft_workflow:
unique_hash = current_draft_workflow.unique_hash
else:
unique_hash = None
graph = workflow_data.get("graph", {})
for node in graph.get("nodes", []):
if node.get("data", {}).get("type", "") == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
dataset_ids = node["data"].get("dataset_ids", [])
node["data"]["dataset_ids"] = [
decrypted_id
for dataset_id in dataset_ids
if (decrypted_id := self.decrypt_dataset_id(encrypted_data=dataset_id, tenant_id=app.tenant_id))
]
workflow_service.sync_draft_workflow(
app_model=app,
graph=workflow_data.get("graph", {}),
features=workflow_data.get("features", {}),
unique_hash=unique_hash,
account=account,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
)
elif app_mode in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION}:
# Initialize model config
model_config = data.get("model_config")
if not model_config or not isinstance(model_config, dict):
raise ValueError("Missing model_config for chat/agent-chat/completion app")
# Initialize or update model config
if not app.app_model_config:
app_model_config = AppModelConfig(
app_id=app.id, created_by=account.id, updated_by=account.id
).from_model_config_dict(cast(AppModelConfigDict, model_config))
app_model_config.id = str(uuid4())
app.app_model_config_id = app_model_config.id
workflow_service = WorkflowService()
current_draft_workflow = workflow_service.get_draft_workflow(app_model=app)
if current_draft_workflow:
unique_hash = current_draft_workflow.unique_hash
else:
unique_hash = None
graph = workflow_data.get("graph", {})
for node in graph.get("nodes", []):
if node.get("data", {}).get("type", "") == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
dataset_ids = node["data"].get("dataset_ids", [])
node["data"]["dataset_ids"] = [
decrypted_id
for dataset_id in dataset_ids
if (
decrypted_id := self.decrypt_dataset_id(
encrypted_data=dataset_id, tenant_id=app.tenant_id
)
)
]
workflow_service.sync_draft_workflow(
app_model=app,
graph=workflow_data.get("graph", {}),
features=workflow_data.get("features", {}),
unique_hash=unique_hash,
account=account,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
)
case AppMode.CHAT | AppMode.AGENT_CHAT | AppMode.COMPLETION:
# Initialize model config
model_config = data.get("model_config")
if not model_config or not isinstance(model_config, dict):
raise ValueError("Missing model_config for chat/agent-chat/completion app")
# Initialize or update model config
if not app.app_model_config:
app_model_config = AppModelConfig(
app_id=app.id, created_by=account.id, updated_by=account.id
).from_model_config_dict(cast(AppModelConfigDict, model_config))
app_model_config.id = str(uuid4())
app.app_model_config_id = app_model_config.id

self._session.add(app_model_config)
app_model_config_was_updated.send(app, app_model_config=app_model_config)
else:
raise ValueError("Invalid app mode")
self._session.add(app_model_config)
app_model_config_was_updated.send(app, app_model_config=app_model_config)
case _:
raise ValueError("Invalid app mode")
return app
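The refactor trades the if/elif chain for structural pattern matching: case AppMode.ADVANCED_CHAT | AppMode.WORKFLOW: is an or-pattern over enum members, and case _: replaces the trailing else as the exhaustiveness backstop. A reduced sketch of the shape, with a toy Mode enum standing in for AppMode:

from enum import Enum

class Mode(Enum):
    ADVANCED_CHAT = "advanced-chat"
    WORKFLOW = "workflow"
    CHAT = "chat"
    COMPLETION = "completion"

def initialize(mode: Mode) -> str:
    match mode:
        case Mode.ADVANCED_CHAT | Mode.WORKFLOW:
            # One arm handles both workflow-backed modes.
            return "sync draft workflow"
        case Mode.CHAT | Mode.COMPLETION:
            return "build model config"
        case _:
            # Mirrors the ValueError above for unknown modes.
            raise ValueError("Invalid app mode")

assert initialize(Mode.WORKFLOW) == "sync draft workflow"
assert initialize(Mode.COMPLETION) == "build model config"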

@classmethod

@ -4,7 +4,7 @@ import logging
import threading
import uuid
from collections.abc import Callable, Generator, Mapping
from typing import TYPE_CHECKING, Any, Union
from typing import TYPE_CHECKING, Any

from configs import dify_config
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
@ -89,7 +89,7 @@ class AppGenerateService:
def generate(
cls,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
@ -358,11 +358,11 @@ class AppGenerateService:
def generate_more_like_this(
cls,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
message_id: str,
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Union[Mapping, Generator]:
) -> Mapping | Generator:
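The Union rewrites here are behavior-neutral: since Python 3.10, X | Y in an annotation builds the same union as typing.Union[X, Y], so user: Account | EndUser and -> Mapping | Generator only drop an import. A quick demonstration:

from typing import Union

# The two spellings compare equal at runtime...
assert (int | str) == Union[int, str]

# ...and behave identically in annotations and isinstance checks.
def label(value: int | str) -> str:
    # Same semantics as: def label(value: Union[int, str]) -> str
    return "number" if isinstance(value, int) else "text"

assert label(3) == "number"
assert label("x") == "text"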
"""
Generate more like this
:param app_model: app model

@ -7,7 +7,7 @@ with support for different subscription tiers, rate limiting, and execution trac

import json
from datetime import UTC, datetime
from typing import Any, Union
from typing import Any

from celery.result import AsyncResult
from sqlalchemy import select
@ -51,7 +51,7 @@ class AsyncWorkflowService:

@classmethod
def trigger_workflow_async(
cls, session: Session, user: Union[Account, EndUser], trigger_data: TriggerData
cls, session: Session, user: Account | EndUser, trigger_data: TriggerData
) -> AsyncTriggerResponse:
"""
Universal entry point for async workflow execution - THIS METHOD WILL NOT BLOCK
@ -184,7 +184,7 @@ class AsyncWorkflowService:

@classmethod
def reinvoke_trigger(
cls, session: Session, user: Union[Account, EndUser], workflow_trigger_log_id: str
cls, session: Session, user: Account | EndUser, workflow_trigger_log_id: str
) -> AsyncTriggerResponse:
"""
Re-invoke a previously failed or rate-limited trigger - THIS METHOD WILL NOT BLOCK

@ -7,7 +7,7 @@ from concurrent.futures import ThreadPoolExecutor
import click
from flask import Flask, current_app
from graphon.model_runtime.utils.encoders import jsonable_encoder
from sqlalchemy import select
from sqlalchemy import delete, func, select
from sqlalchemy.orm import Session, sessionmaker

from configs import dify_config
@ -62,13 +62,11 @@ class ClearFreePlanTenantExpiredLogs:

for model, table_name in related_tables:
# Query records related to expired messages
records = (
session.query(model)
.where(
records = session.scalars(
select(model).where(
model.message_id.in_(batch_message_ids), # type: ignore
)
.all()
)
).all()

if len(records) == 0:
continue
@ -103,9 +101,13 @@ class ClearFreePlanTenantExpiredLogs:
except Exception:
logger.exception("Failed to save %s records", table_name)

session.query(model).where(
model.id.in_(record_ids), # type: ignore
).delete(synchronize_session=False)
session.execute(
delete(model)
.where(
model.id.in_(record_ids), # type: ignore
)
.execution_options(synchronize_session=False)
)

click.echo(
click.style(
@ -121,15 +123,14 @@ class ClearFreePlanTenantExpiredLogs:
app_ids = [app.id for app in apps]
while True:
with sessionmaker(bind=db.engine, autoflush=False).begin() as session:
messages = (
session.query(Message)
messages = session.scalars(
select(Message)
.where(
Message.app_id.in_(app_ids),
Message.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
)
.limit(batch)
.all()
)
).all()
if len(messages) == 0:
break

@ -147,9 +148,9 @@ class ClearFreePlanTenantExpiredLogs:
message_ids = [message.id for message in messages]

# delete messages
session.query(Message).where(
Message.id.in_(message_ids),
).delete(synchronize_session=False)
session.execute(
delete(Message).where(Message.id.in_(message_ids)).execution_options(synchronize_session=False)
)

cls._clear_message_related_tables(session, tenant_id, message_ids)
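The delete rewrites in this file follow SQLAlchemy's 2.0 idiom: build a delete() statement, attach criteria with .where(), and pass synchronize_session=False through .execution_options() so the ORM skips reconciling in-memory objects during a bulk wipe. A self-contained sketch against an in-memory SQLite database, assuming SQLAlchemy 2.x and a toy Msg model in place of Message:

from sqlalchemy import create_engine, delete, func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Msg(Base):
    __tablename__ = "msg"
    id: Mapped[int] = mapped_column(primary_key=True)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Msg(id=i) for i in range(1, 6)])
    session.commit()
    # Bulk delete, 2.0 style: one DELETE statement, no per-object bookkeeping.
    session.execute(
        delete(Msg).where(Msg.id.in_([1, 2, 3])).execution_options(synchronize_session=False)
    )
    session.commit()
    assert session.scalar(select(func.count(Msg.id))) == 2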

@ -161,15 +162,14 @@ class ClearFreePlanTenantExpiredLogs:

while True:
with sessionmaker(bind=db.engine, autoflush=False).begin() as session:
conversations = (
session.query(Conversation)
conversations = session.scalars(
select(Conversation)
.where(
Conversation.app_id.in_(app_ids),
Conversation.updated_at < datetime.datetime.now() - datetime.timedelta(days=days),
)
.limit(batch)
.all()
)
).all()

if len(conversations) == 0:
break
@ -186,9 +186,11 @@ class ClearFreePlanTenantExpiredLogs:
)

conversation_ids = [conversation.id for conversation in conversations]
session.query(Conversation).where(
Conversation.id.in_(conversation_ids),
).delete(synchronize_session=False)
session.execute(
delete(Conversation)
.where(Conversation.id.in_(conversation_ids))
.execution_options(synchronize_session=False)
)

click.echo(
click.style(
@ -293,15 +295,14 @@ class ClearFreePlanTenantExpiredLogs:

while True:
with sessionmaker(bind=db.engine, autoflush=False).begin() as session:
workflow_app_logs = (
session.query(WorkflowAppLog)
workflow_app_logs = session.scalars(
select(WorkflowAppLog)
.where(
WorkflowAppLog.tenant_id == tenant_id,
WorkflowAppLog.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
)
.limit(batch)
.all()
)
).all()

if len(workflow_app_logs) == 0:
break
@ -321,8 +322,10 @@ class ClearFreePlanTenantExpiredLogs:
workflow_app_log_ids = [workflow_app_log.id for workflow_app_log in workflow_app_logs]

# delete workflow app logs
session.query(WorkflowAppLog).where(WorkflowAppLog.id.in_(workflow_app_log_ids)).delete(
synchronize_session=False
session.execute(
delete(WorkflowAppLog)
.where(WorkflowAppLog.id.in_(workflow_app_log_ids))
.execution_options(synchronize_session=False)
)

click.echo(
@ -344,7 +347,7 @@ class ClearFreePlanTenantExpiredLogs:
current_time = started_at

with sessionmaker(db.engine).begin() as session:
total_tenant_count = session.query(Tenant.id).count()
total_tenant_count = session.scalar(select(func.count(Tenant.id))) or 0

click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))

@ -409,9 +412,12 @@ class ClearFreePlanTenantExpiredLogs:
tenant_count = 0
for test_interval in test_intervals:
tenant_count = (
session.query(Tenant.id)
.where(Tenant.created_at.between(current_time, current_time + test_interval))
.count()
session.scalar(
select(func.count(Tenant.id)).where(
Tenant.created_at.between(current_time, current_time + test_interval)
)
)
or 0
)
if tenant_count <= 100:
interval = test_interval
@ -433,8 +439,8 @@ class ClearFreePlanTenantExpiredLogs:

batch_end = min(current_time + interval, ended_at)

rs = (
session.query(Tenant.id)
rs = session.execute(
select(Tenant.id)
.where(Tenant.created_at.between(current_time, batch_end))
.order_by(Tenant.created_at)
)

@ -528,6 +528,8 @@ class DatasetService:
raise ValueError("External knowledge id is required.")
if not external_knowledge_api_id:
raise ValueError("External knowledge api id is required.")
# Ensure the referenced external API template exists and belongs to the dataset tenant.
ExternalDatasetService.get_external_knowledge_api(external_knowledge_api_id, dataset.tenant_id)
# Update metadata fields
dataset.updated_by = user.id if user else None
dataset.updated_at = naive_utc_now()
@ -552,8 +554,8 @@ class DatasetService:
external_knowledge_api_id: External knowledge API identifier
"""
with sessionmaker(db.engine).begin() as session:
external_knowledge_binding = (
session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first()
external_knowledge_binding = session.scalar(
select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == dataset_id).limit(1)
)

if not external_knowledge_binding:
@ -1454,15 +1456,17 @@ class DocumentService:
document_id_list: list[str] = [str(document_id) for document_id in document_ids]

with session_factory.create_session() as session:
updated_count = (
session.query(Document)
.filter(
result = session.execute(
update(Document)
.where(
Document.id.in_(document_id_list),
Document.dataset_id == dataset_id,
Document.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents
)
.update({Document.need_summary: need_summary}, synchronize_session=False)
.values(need_summary=need_summary)
.execution_options(synchronize_session=False)
)
updated_count = result.rowcount # type: ignore[union-attr,attr-defined]
session.commit()
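Bulk updates get the same treatment: update(Model).where(...).values(...) replaces Query.update(), and the returned result exposes rowcount, which is what the log line below reports. A runnable sketch, assuming SQLAlchemy 2.x and a toy Doc model in place of Document:

from sqlalchemy import create_engine, update
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Doc(Base):
    __tablename__ = "doc"
    id: Mapped[int] = mapped_column(primary_key=True)
    need_summary: Mapped[bool] = mapped_column(default=False)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Doc(id=1), Doc(id=2), Doc(id=3)])
    session.commit()
    result = session.execute(
        update(Doc)
        .where(Doc.id.in_([1, 2]))
        .values(need_summary=True)
        .execution_options(synchronize_session=False)
    )
    session.commit()
    # rowcount reports how many rows the UPDATE touched.
    assert result.rowcount == 2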
logger.info(
"Updated need_summary to %s for %d documents in dataset %s",
@ -2822,6 +2826,10 @@ class DocumentService:

knowledge_config.process_rule.rules.pre_processing_rules = list(unique_pre_processing_rule_dicts.values())

if knowledge_config.process_rule.mode == ProcessRuleMode.HIERARCHICAL:
if not knowledge_config.process_rule.rules.parent_mode:
knowledge_config.process_rule.rules.parent_mode = "paragraph"

if not knowledge_config.process_rule.rules.segmentation:
raise ValueError("Process rule segmentation is required")


@ -4,7 +4,7 @@ from collections.abc import Mapping
from typing import Any

from graphon.model_runtime.entities.provider_entities import FormType
from sqlalchemy import func, select
from sqlalchemy import delete, func, select, update
from sqlalchemy.orm import Session, sessionmaker

from configs import dify_config
@ -54,11 +54,13 @@ class DatasourceProviderService:
remove oauth custom client params
"""
with sessionmaker(bind=db.engine).begin() as session:
session.query(DatasourceOauthTenantParamConfig).filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
).delete()
session.execute(
delete(DatasourceOauthTenantParamConfig).where(
DatasourceOauthTenantParamConfig.tenant_id == tenant_id,
DatasourceOauthTenantParamConfig.provider == datasource_provider_id.provider_name,
DatasourceOauthTenantParamConfig.plugin_id == datasource_provider_id.plugin_id,
)
)

def decrypt_datasource_provider_credentials(
self,
@ -110,15 +112,21 @@ class DatasourceProviderService:
"""
with sessionmaker(bind=db.engine).begin() as session:
if credential_id:
datasource_provider = (
session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
datasource_provider = session.scalar(
select(DatasourceProvider)
.where(DatasourceProvider.tenant_id == tenant_id, DatasourceProvider.id == credential_id)
.limit(1)
)
else:
datasource_provider = (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
datasource_provider = session.scalar(
select(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
.first()
.limit(1)
)
if not datasource_provider:
return {}
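This hunk shows the recurring shape of the migration in this service: session.query(Model).filter_by(...).first() becomes session.scalar(select(Model).where(...).limit(1)), which returns the first matching object or None and pushes LIMIT 1 into the SQL. A compact sketch of the equivalence, with a toy Provider model standing in for DatasourceProvider:

from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Provider(Base):
    __tablename__ = "provider"
    id: Mapped[int] = mapped_column(primary_key=True)
    tenant_id: Mapped[str]

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Provider(id=1, tenant_id="t1"))
    session.commit()
    # First match or None, with the limit expressed in SQL.
    hit = session.scalar(select(Provider).where(Provider.tenant_id == "t1").limit(1))
    miss = session.scalar(select(Provider).where(Provider.tenant_id == "t2").limit(1))
    assert hit is not None and hit.id == 1
    assert miss is None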
@ -173,12 +181,15 @@ class DatasourceProviderService:
get all datasource credentials by provider
"""
with sessionmaker(bind=db.engine).begin() as session:
datasource_providers = (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
datasource_providers = session.scalars(
select(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
.all()
)
).all()
if not datasource_providers:
return []
current_user = get_current_user()
@ -232,15 +243,15 @@ class DatasourceProviderService:
update datasource provider name
"""
with sessionmaker(bind=db.engine).begin() as session:
target_provider = (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
id=credential_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
target_provider = session.scalar(
select(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.id == credential_id,
DatasourceProvider.provider == datasource_provider_id.provider_name,
DatasourceProvider.plugin_id == datasource_provider_id.plugin_id,
)
.first()
.limit(1)
)
if target_provider is None:
raise ValueError("provider not found")
@ -250,16 +261,16 @@ class DatasourceProviderService:

# check name is exist
if (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
name=name,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
session.scalar(
select(func.count(DatasourceProvider.id)).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.name == name,
DatasourceProvider.provider == datasource_provider_id.provider_name,
DatasourceProvider.plugin_id == datasource_provider_id.plugin_id,
)
)
.count()
> 0
):
or 0
) > 0:
raise ValueError("Authorization name is already exists")

target_provider.name = name
@ -273,26 +284,31 @@ class DatasourceProviderService:
"""
with sessionmaker(bind=db.engine).begin() as session:
# get provider
target_provider = (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
id=credential_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
target_provider = session.scalar(
select(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.id == credential_id,
DatasourceProvider.provider == datasource_provider_id.provider_name,
DatasourceProvider.plugin_id == datasource_provider_id.plugin_id,
)
.first()
.limit(1)
)
if target_provider is None:
raise ValueError("provider not found")

# clear default provider
session.query(DatasourceProvider).filter_by(
tenant_id=tenant_id,
provider=target_provider.provider,
plugin_id=target_provider.plugin_id,
is_default=True,
).update({"is_default": False})
session.execute(
update(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == target_provider.provider,
DatasourceProvider.plugin_id == target_provider.plugin_id,
DatasourceProvider.is_default.is_(True),
)
.values(is_default=False)
.execution_options(synchronize_session=False)
)

# set new default provider
target_provider.is_default = True
@ -311,14 +327,14 @@ class DatasourceProviderService:
if client_params is None and enabled is None:
return
with sessionmaker(bind=db.engine).begin() as session:
tenant_oauth_client_params = (
session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
tenant_oauth_client_params = session.scalar(
select(DatasourceOauthTenantParamConfig)
.where(
DatasourceOauthTenantParamConfig.tenant_id == tenant_id,
DatasourceOauthTenantParamConfig.provider == datasource_provider_id.provider_name,
DatasourceOauthTenantParamConfig.plugin_id == datasource_provider_id.plugin_id,
)
.first()
.limit(1)
)

if not tenant_oauth_client_params:
@ -351,9 +367,14 @@ class DatasourceProviderService:
"""
with Session(db.engine).no_autoflush as session:
return (
session.query(DatasourceOauthParamConfig)
.filter_by(provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id)
.first()
session.scalar(
select(DatasourceOauthParamConfig)
.where(
DatasourceOauthParamConfig.provider == datasource_provider_id.provider_name,
DatasourceOauthParamConfig.plugin_id == datasource_provider_id.plugin_id,
)
.limit(1)
)
is not None
)

@ -423,15 +444,15 @@ class DatasourceProviderService:
plugin_id = datasource_provider_id.plugin_id
with Session(db.engine).no_autoflush as session:
# get tenant oauth client params
tenant_oauth_client_params = (
session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
enabled=True,
tenant_oauth_client_params = session.scalar(
select(DatasourceOauthTenantParamConfig)
.where(
DatasourceOauthTenantParamConfig.tenant_id == tenant_id,
DatasourceOauthTenantParamConfig.provider == provider,
DatasourceOauthTenantParamConfig.plugin_id == plugin_id,
DatasourceOauthTenantParamConfig.enabled.is_(True),
)
.first()
.limit(1)
)
if tenant_oauth_client_params:
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
@ -443,8 +464,13 @@ class DatasourceProviderService:
is_verified = PluginService.is_plugin_verified(tenant_id, provider_controller.plugin_unique_identifier)
if is_verified:
# fallback to system oauth client params
oauth_client_params = (
session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
oauth_client_params = session.scalar(
select(DatasourceOauthParamConfig)
.where(
DatasourceOauthParamConfig.provider == provider,
DatasourceOauthParamConfig.plugin_id == plugin_id,
)
.limit(1)
)
if oauth_client_params:
return oauth_client_params.system_credentials
@ -455,15 +481,13 @@ class DatasourceProviderService:
def generate_next_datasource_provider_name(
session: Session, tenant_id: str, provider_id: DatasourceProviderID, credential_type: CredentialType
) -> str:
db_providers = (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
db_providers = session.scalars(
select(DatasourceProvider).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider_id.provider_name,
DatasourceProvider.plugin_id == provider_id.plugin_id,
)
.all()
)
).all()
return generate_incremental_name(
[provider.name for provider in db_providers],
f"{credential_type.get_name()}",
@ -485,8 +509,10 @@ class DatasourceProviderService:
with sessionmaker(bind=db.engine).begin() as session:
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.OAUTH2.value}"
with redis_client.lock(lock, timeout=20):
target_provider = (
session.query(DatasourceProvider).filter_by(id=credential_id, tenant_id=tenant_id).first()
target_provider = session.scalar(
select(DatasourceProvider)
.where(DatasourceProvider.id == credential_id, DatasourceProvider.tenant_id == tenant_id)
.limit(1)
)
if target_provider is None:
raise ValueError("provider not found")
@ -496,25 +522,28 @@ class DatasourceProviderService:
db_provider_name = target_provider.name
else:
name_conflict = (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
name=db_provider_name,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
auth_type=CredentialType.OAUTH2.value,
session.scalar(
select(func.count(DatasourceProvider.id)).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.name == db_provider_name,
DatasourceProvider.provider == provider_id.provider_name,
DatasourceProvider.plugin_id == provider_id.plugin_id,
DatasourceProvider.auth_type == CredentialType.OAUTH2.value,
)
)
.count()
or 0
)
if name_conflict > 0:
db_provider_name = generate_incremental_name(
[
provider.name
for provider in session.query(DatasourceProvider).filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
)
for provider in session.scalars(
select(DatasourceProvider).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider_id.provider_name,
DatasourceProvider.plugin_id == provider_id.plugin_id,
)
).all()
],
db_provider_name,
)
@ -556,25 +585,27 @@ class DatasourceProviderService:
)
else:
if (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
name=db_provider_name,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
auth_type=credential_type.value,
session.scalar(
select(func.count(DatasourceProvider.id)).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.name == db_provider_name,
DatasourceProvider.provider == provider_id.provider_name,
DatasourceProvider.plugin_id == provider_id.plugin_id,
DatasourceProvider.auth_type == credential_type.value,
)
)
.count()
> 0
):
or 0
) > 0:
db_provider_name = generate_incremental_name(
[
provider.name
for provider in session.query(DatasourceProvider).filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
)
for provider in session.scalars(
select(DatasourceProvider).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider_id.provider_name,
DatasourceProvider.plugin_id == provider_id.plugin_id,
)
).all()
],
db_provider_name,
)
@ -627,11 +658,16 @@ class DatasourceProviderService:

# check name is exist
if (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name, name=db_provider_name)
.count()
> 0
):
session.scalar(
select(func.count(DatasourceProvider.id)).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.plugin_id == plugin_id,
DatasourceProvider.provider == provider_name,
DatasourceProvider.name == db_provider_name,
)
)
or 0
) > 0:
raise ValueError("Authorization name is already exists")

try:
@ -918,21 +954,31 @@ class DatasourceProviderService:
"""

with sessionmaker(bind=db.engine).begin() as session:
datasource_provider = (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
.first()
datasource_provider = session.scalar(
select(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.id == auth_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
.limit(1)
)
if not datasource_provider:
raise ValueError("Datasource provider not found")
# update name
if name and name != datasource_provider.name:
if (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, name=name, provider=provider, plugin_id=plugin_id)
.count()
> 0
):
session.scalar(
select(func.count(DatasourceProvider.id)).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.name == name,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
)
or 0
) > 0:
raise ValueError("Authorization name is already exists")
datasource_provider.name = name


@ -1,6 +1,6 @@
import json
from copy import deepcopy
from typing import Any, Union, cast
from typing import Any, cast
from urllib.parse import urlparse

import httpx
@ -148,18 +148,23 @@ class ExternalDatasetService:
db.session.commit()

@staticmethod
def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]:
def external_knowledge_api_use_check(external_knowledge_api_id: str, tenant_id: str) -> tuple[bool, int]:
"""
Return usage for an external knowledge API within a single tenant.

The caller already scopes access by tenant, so this query must do the
same; otherwise the endpoint becomes a cross-tenant UUID oracle.
"""
count = (
db.session.scalar(
select(func.count(ExternalKnowledgeBindings.id)).where(
ExternalKnowledgeBindings.external_knowledge_api_id == external_knowledge_api_id
ExternalKnowledgeBindings.external_knowledge_api_id == external_knowledge_api_id,
ExternalKnowledgeBindings.tenant_id == tenant_id,
)
)
or 0
)
if count > 0:
return True, count
return False, 0
return count > 0, count
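The security intent of this hunk is worth spelling out: without the tenant_id predicate, any authenticated caller could probe arbitrary API UUIDs and learn from the returned count whether they exist in other workspaces; scoping the aggregate to the caller's tenant closes that oracle. A hedged sketch of the scoped check, with a toy Binding model in place of ExternalKnowledgeBindings:

from sqlalchemy import create_engine, func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Binding(Base):
    __tablename__ = "binding"
    id: Mapped[int] = mapped_column(primary_key=True)
    api_id: Mapped[str]
    tenant_id: Mapped[str]

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

def use_check(session: Session, api_id: str, tenant_id: str) -> tuple[bool, int]:
    # Both predicates matter: the tenant filter stops cross-tenant probing.
    count = (
        session.scalar(
            select(func.count(Binding.id)).where(
                Binding.api_id == api_id,
                Binding.tenant_id == tenant_id,
            )
        )
        or 0
    )
    return count > 0, count

with Session(engine) as session:
    session.add(Binding(id=1, api_id="api-1", tenant_id="t1"))
    session.commit()
    assert use_check(session, "api-1", "t1") == (True, 1)
    # Another tenant probing the same UUID learns nothing.
    assert use_check(session, "api-1", "t2") == (False, 0)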

@staticmethod
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
@ -190,9 +195,7 @@ class ExternalDatasetService:
raise ValueError(f"{parameter.get('name')} is required")

@staticmethod
def process_external_api(
settings: ExternalKnowledgeApiSetting, files: Union[None, dict[str, Any]]
) -> httpx.Response:
def process_external_api(settings: ExternalKnowledgeApiSetting, files: dict[str, Any] | None) -> httpx.Response:
"""
do http request depending on api bundle
"""
@ -314,7 +317,10 @@ class ExternalDatasetService:

external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id)
.where(
ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id,
ExternalKnowledgeApis.tenant_id == tenant_id,
)
.limit(1)
)
if external_knowledge_api is None or external_knowledge_api.settings is None:

@ -5,7 +5,7 @@ import uuid
from collections.abc import Iterator, Sequence
from contextlib import contextmanager, suppress
from tempfile import NamedTemporaryFile
from typing import Literal, Union
from typing import Literal
from zipfile import ZIP_DEFLATED, ZipFile

from graphon.file import helpers as file_helpers
@ -52,7 +52,7 @@ class FileService:
filename: str,
content: bytes,
mimetype: str,
user: Union[Account, EndUser],
user: Account | EndUser,
source: Literal["datasets"] | None = None,
source_url: str = "",
) -> UploadFile:
@ -132,8 +132,8 @@ class FileService:
return file_size <= file_size_limit

def get_file_base64(self, file_id: str) -> str:
upload_file = (
self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first()
upload_file = self._session_maker(expire_on_commit=False).scalar(
select(UploadFile).where(UploadFile.id == file_id).limit(1)
)
if not upload_file:
raise NotFound("File not found")
@ -178,7 +178,7 @@ class FileService:
Return a short text preview extracted from a document file.
"""
with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))

if not upload_file:
raise NotFound("File not found")
@ -200,7 +200,7 @@ class FileService:
if not result:
raise NotFound("File not found or signature is invalid")
with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))

if not upload_file:
raise NotFound("File not found or signature is invalid")
@ -220,7 +220,7 @@ class FileService:
raise NotFound("File not found or signature is invalid")

with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))

if not upload_file:
raise NotFound("File not found or signature is invalid")
@ -231,7 +231,7 @@ class FileService:

def get_public_image_preview(self, file_id: str):
with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))

if not upload_file:
raise NotFound("File not found or signature is invalid")
@ -247,7 +247,7 @@ class FileService:

def get_file_content(self, file_id: str) -> str:
with self._session_maker(expire_on_commit=False) as session:
upload_file: UploadFile | None = session.query(UploadFile).where(UploadFile.id == file_id).first()
upload_file: UploadFile | None = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))

if not upload_file:
raise NotFound("File not found")
Some files were not shown because too many files have changed in this diff.