mirror of
https://github.com/langgenius/dify.git
synced 2026-04-26 18:27:15 +08:00
Merge remote-tracking branch 'myori/main' into feat/collaboration2
This commit is contained in:
commit
ee2b021395
1
.github/pull_request_template.md
vendored
1
.github/pull_request_template.md
vendored
@ -7,6 +7,7 @@
|
|||||||
## Summary
|
## Summary
|
||||||
|
|
||||||
<!-- Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change. -->
|
<!-- Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change. -->
|
||||||
|
<!-- If this PR was created by an automated agent, add `From <Tool Name>` as the final line of the description. Example: `From Codex`. -->
|
||||||
|
|
||||||
## Screenshots
|
## Screenshots
|
||||||
|
|
||||||
|
|||||||
118
.github/workflows/pyrefly-type-coverage-comment.yml
vendored
Normal file
118
.github/workflows/pyrefly-type-coverage-comment.yml
vendored
Normal file
@ -0,0 +1,118 @@
|
|||||||
|
name: Comment with Pyrefly Type Coverage
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_run:
|
||||||
|
workflows:
|
||||||
|
- Pyrefly Type Coverage
|
||||||
|
types:
|
||||||
|
- completed
|
||||||
|
|
||||||
|
permissions: {}
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
comment:
|
||||||
|
name: Comment PR with type coverage
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
actions: read
|
||||||
|
contents: read
|
||||||
|
issues: write
|
||||||
|
pull-requests: write
|
||||||
|
if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }}
|
||||||
|
steps:
|
||||||
|
- name: Checkout default branch (trusted code)
|
||||||
|
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||||
|
|
||||||
|
- name: Setup Python & UV
|
||||||
|
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||||
|
with:
|
||||||
|
enable-cache: true
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: uv sync --project api --dev
|
||||||
|
|
||||||
|
- name: Download type coverage artifact
|
||||||
|
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||||
|
with:
|
||||||
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
script: |
|
||||||
|
const fs = require('fs');
|
||||||
|
const artifacts = await github.rest.actions.listWorkflowRunArtifacts({
|
||||||
|
owner: context.repo.owner,
|
||||||
|
repo: context.repo.repo,
|
||||||
|
run_id: ${{ github.event.workflow_run.id }},
|
||||||
|
});
|
||||||
|
const match = artifacts.data.artifacts.find((artifact) =>
|
||||||
|
artifact.name === 'pyrefly_type_coverage'
|
||||||
|
);
|
||||||
|
if (!match) {
|
||||||
|
throw new Error('pyrefly_type_coverage artifact not found');
|
||||||
|
}
|
||||||
|
const download = await github.rest.actions.downloadArtifact({
|
||||||
|
owner: context.repo.owner,
|
||||||
|
repo: context.repo.repo,
|
||||||
|
artifact_id: match.id,
|
||||||
|
archive_format: 'zip',
|
||||||
|
});
|
||||||
|
fs.writeFileSync('pyrefly_type_coverage.zip', Buffer.from(download.data));
|
||||||
|
|
||||||
|
- name: Unzip artifact
|
||||||
|
run: unzip -o pyrefly_type_coverage.zip
|
||||||
|
|
||||||
|
- name: Render coverage markdown from structured data
|
||||||
|
id: render
|
||||||
|
run: |
|
||||||
|
comment_body="$(uv run --directory api python api/libs/pyrefly_type_coverage.py \
|
||||||
|
--base base_report.json \
|
||||||
|
< pr_report.json)"
|
||||||
|
|
||||||
|
{
|
||||||
|
echo "### Pyrefly Type Coverage"
|
||||||
|
echo ""
|
||||||
|
echo "$comment_body"
|
||||||
|
} > /tmp/type_coverage_comment.md
|
||||||
|
|
||||||
|
- name: Post comment
|
||||||
|
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||||
|
with:
|
||||||
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
script: |
|
||||||
|
const fs = require('fs');
|
||||||
|
const body = fs.readFileSync('/tmp/type_coverage_comment.md', { encoding: 'utf8' });
|
||||||
|
let prNumber = null;
|
||||||
|
try {
|
||||||
|
prNumber = parseInt(fs.readFileSync('pr_number.txt', { encoding: 'utf8' }), 10);
|
||||||
|
} catch (err) {
|
||||||
|
const prs = context.payload.workflow_run.pull_requests || [];
|
||||||
|
if (prs.length > 0 && prs[0].number) {
|
||||||
|
prNumber = prs[0].number;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!prNumber) {
|
||||||
|
throw new Error('PR number not found in artifact or workflow_run payload');
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update existing comment if one exists, otherwise create new
|
||||||
|
const { data: comments } = await github.rest.issues.listComments({
|
||||||
|
issue_number: prNumber,
|
||||||
|
owner: context.repo.owner,
|
||||||
|
repo: context.repo.repo,
|
||||||
|
});
|
||||||
|
const marker = '### Pyrefly Type Coverage';
|
||||||
|
const existing = comments.find(c => c.body.startsWith(marker));
|
||||||
|
|
||||||
|
if (existing) {
|
||||||
|
await github.rest.issues.updateComment({
|
||||||
|
comment_id: existing.id,
|
||||||
|
owner: context.repo.owner,
|
||||||
|
repo: context.repo.repo,
|
||||||
|
body,
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
await github.rest.issues.createComment({
|
||||||
|
issue_number: prNumber,
|
||||||
|
owner: context.repo.owner,
|
||||||
|
repo: context.repo.repo,
|
||||||
|
body,
|
||||||
|
});
|
||||||
|
}
|
||||||
120
.github/workflows/pyrefly-type-coverage.yml
vendored
Normal file
120
.github/workflows/pyrefly-type-coverage.yml
vendored
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
name: Pyrefly Type Coverage
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
paths:
|
||||||
|
- 'api/**/*.py'
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
pyrefly-type-coverage:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
issues: write
|
||||||
|
pull-requests: write
|
||||||
|
steps:
|
||||||
|
- name: Checkout PR branch
|
||||||
|
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Setup Python & UV
|
||||||
|
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||||
|
with:
|
||||||
|
enable-cache: true
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: uv sync --project api --dev
|
||||||
|
|
||||||
|
- name: Run pyrefly report on PR branch
|
||||||
|
run: |
|
||||||
|
uv run --directory api --dev pyrefly report 2>/dev/null > /tmp/pyrefly_report_pr.tmp && \
|
||||||
|
mv /tmp/pyrefly_report_pr.tmp /tmp/pyrefly_report_pr.json || \
|
||||||
|
echo '{}' > /tmp/pyrefly_report_pr.json
|
||||||
|
|
||||||
|
- name: Save helper script from base branch
|
||||||
|
run: |
|
||||||
|
git show ${{ github.event.pull_request.base.sha }}:api/libs/pyrefly_type_coverage.py > /tmp/pyrefly_type_coverage.py 2>/dev/null \
|
||||||
|
|| cp api/libs/pyrefly_type_coverage.py /tmp/pyrefly_type_coverage.py
|
||||||
|
|
||||||
|
- name: Checkout base branch
|
||||||
|
run: git checkout ${{ github.base_ref }}
|
||||||
|
|
||||||
|
- name: Run pyrefly report on base branch
|
||||||
|
run: |
|
||||||
|
uv run --directory api --dev pyrefly report 2>/dev/null > /tmp/pyrefly_report_base.tmp && \
|
||||||
|
mv /tmp/pyrefly_report_base.tmp /tmp/pyrefly_report_base.json || \
|
||||||
|
echo '{}' > /tmp/pyrefly_report_base.json
|
||||||
|
|
||||||
|
- name: Generate coverage comparison
|
||||||
|
id: coverage
|
||||||
|
run: |
|
||||||
|
comment_body="$(uv run --directory api python /tmp/pyrefly_type_coverage.py \
|
||||||
|
--base /tmp/pyrefly_report_base.json \
|
||||||
|
< /tmp/pyrefly_report_pr.json)"
|
||||||
|
|
||||||
|
{
|
||||||
|
echo "### Pyrefly Type Coverage"
|
||||||
|
echo ""
|
||||||
|
echo "$comment_body"
|
||||||
|
} | tee -a "$GITHUB_STEP_SUMMARY" > /tmp/type_coverage_comment.md
|
||||||
|
|
||||||
|
# Save structured data for the fork-PR comment workflow
|
||||||
|
cp /tmp/pyrefly_report_pr.json pr_report.json
|
||||||
|
cp /tmp/pyrefly_report_base.json base_report.json
|
||||||
|
|
||||||
|
- name: Save PR number
|
||||||
|
run: |
|
||||||
|
echo ${{ github.event.pull_request.number }} > pr_number.txt
|
||||||
|
|
||||||
|
- name: Upload type coverage artifact
|
||||||
|
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||||
|
with:
|
||||||
|
name: pyrefly_type_coverage
|
||||||
|
path: |
|
||||||
|
pr_report.json
|
||||||
|
base_report.json
|
||||||
|
pr_number.txt
|
||||||
|
|
||||||
|
- name: Comment PR with type coverage
|
||||||
|
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
|
||||||
|
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||||
|
with:
|
||||||
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
script: |
|
||||||
|
const fs = require('fs');
|
||||||
|
const marker = '### Pyrefly Type Coverage';
|
||||||
|
let body;
|
||||||
|
try {
|
||||||
|
body = fs.readFileSync('/tmp/type_coverage_comment.md', { encoding: 'utf8' });
|
||||||
|
} catch {
|
||||||
|
body = `${marker}\n\n_Coverage report unavailable._`;
|
||||||
|
}
|
||||||
|
const prNumber = context.payload.pull_request.number;
|
||||||
|
|
||||||
|
// Update existing comment if one exists, otherwise create new
|
||||||
|
const { data: comments } = await github.rest.issues.listComments({
|
||||||
|
issue_number: prNumber,
|
||||||
|
owner: context.repo.owner,
|
||||||
|
repo: context.repo.repo,
|
||||||
|
});
|
||||||
|
const existing = comments.find(c => c.body.startsWith(marker));
|
||||||
|
|
||||||
|
if (existing) {
|
||||||
|
await github.rest.issues.updateComment({
|
||||||
|
comment_id: existing.id,
|
||||||
|
owner: context.repo.owner,
|
||||||
|
repo: context.repo.repo,
|
||||||
|
body,
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
await github.rest.issues.createComment({
|
||||||
|
issue_number: prNumber,
|
||||||
|
owner: context.repo.owner,
|
||||||
|
repo: context.repo.repo,
|
||||||
|
body,
|
||||||
|
});
|
||||||
|
}
|
||||||
@ -97,8 +97,3 @@ Feel free to reach out if you encounter any issues during the setup process.
|
|||||||
## Getting Help
|
## Getting Help
|
||||||
|
|
||||||
If you ever get stuck or get a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat.
|
If you ever get stuck or get a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat.
|
||||||
|
|
||||||
## Automated Agent Contributions
|
|
||||||
|
|
||||||
> [!NOTE]
|
|
||||||
> If you are an automated agent, just add 🤖🤖🤖 to the end of the PR title to opt-in.
|
|
||||||
|
|||||||
@ -2,7 +2,6 @@ import base64
|
|||||||
import secrets
|
import secrets
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
|
|
||||||
from constants.languages import languages
|
from constants.languages import languages
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -25,30 +24,31 @@ def reset_password(email, new_password, password_confirm):
|
|||||||
return
|
return
|
||||||
normalized_email = email.strip().lower()
|
normalized_email = email.strip().lower()
|
||||||
|
|
||||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
account = AccountService.get_account_by_email_with_case_fallback(email.strip())
|
||||||
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
|
|
||||||
|
|
||||||
if not account:
|
if not account:
|
||||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
valid_password(new_password)
|
valid_password(new_password)
|
||||||
except:
|
except:
|
||||||
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
|
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
|
||||||
return
|
return
|
||||||
|
|
||||||
# generate password salt
|
# generate password salt
|
||||||
salt = secrets.token_bytes(16)
|
salt = secrets.token_bytes(16)
|
||||||
base64_salt = base64.b64encode(salt).decode()
|
base64_salt = base64.b64encode(salt).decode()
|
||||||
|
|
||||||
# encrypt password with salt
|
# encrypt password with salt
|
||||||
password_hashed = hash_password(new_password, salt)
|
password_hashed = hash_password(new_password, salt)
|
||||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||||
account.password = base64_password_hashed
|
account = db.session.merge(account)
|
||||||
account.password_salt = base64_salt
|
account.password = base64_password_hashed
|
||||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
account.password_salt = base64_salt
|
||||||
click.echo(click.style("Password reset successfully.", fg="green"))
|
db.session.commit()
|
||||||
|
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||||
|
click.echo(click.style("Password reset successfully.", fg="green"))
|
||||||
|
|
||||||
|
|
||||||
@click.command("reset-email", help="Reset the account email.")
|
@click.command("reset-email", help="Reset the account email.")
|
||||||
@ -65,21 +65,22 @@ def reset_email(email, new_email, email_confirm):
|
|||||||
return
|
return
|
||||||
normalized_new_email = new_email.strip().lower()
|
normalized_new_email = new_email.strip().lower()
|
||||||
|
|
||||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
account = AccountService.get_account_by_email_with_case_fallback(email.strip())
|
||||||
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
|
|
||||||
|
|
||||||
if not account:
|
if not account:
|
||||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
email_validate(normalized_new_email)
|
email_validate(normalized_new_email)
|
||||||
except:
|
except:
|
||||||
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
|
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
|
||||||
return
|
return
|
||||||
|
|
||||||
account.email = normalized_new_email
|
account = db.session.merge(account)
|
||||||
click.echo(click.style("Email updated successfully.", fg="green"))
|
account.email = normalized_new_email
|
||||||
|
db.session.commit()
|
||||||
|
click.echo(click.style("Email updated successfully.", fg="green"))
|
||||||
|
|
||||||
|
|
||||||
@click.command("create-tenant", help="Create account and tenant.")
|
@click.command("create-tenant", help="Create account and tenant.")
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from constants.languages import languages
|
from constants.languages import languages
|
||||||
@ -14,7 +13,6 @@ from controllers.console.auth.error import (
|
|||||||
InvalidTokenError,
|
InvalidTokenError,
|
||||||
PasswordMismatchError,
|
PasswordMismatchError,
|
||||||
)
|
)
|
||||||
from extensions.ext_database import db
|
|
||||||
from libs.helper import EmailStr, extract_remote_ip
|
from libs.helper import EmailStr, extract_remote_ip
|
||||||
from libs.password import valid_password
|
from libs.password import valid_password
|
||||||
from models import Account
|
from models import Account
|
||||||
@ -73,8 +71,7 @@ class EmailRegisterSendEmailApi(Resource):
|
|||||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
|
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
|
||||||
raise AccountInFreezeError()
|
raise AccountInFreezeError()
|
||||||
|
|
||||||
with sessionmaker(db.engine).begin() as session:
|
account = AccountService.get_account_by_email_with_case_fallback(args.email)
|
||||||
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
|
|
||||||
token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
|
token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
|
||||||
return {"result": "success", "data": token}
|
return {"result": "success", "data": token}
|
||||||
|
|
||||||
@ -145,17 +142,16 @@ class EmailRegisterResetApi(Resource):
|
|||||||
email = register_data.get("email", "")
|
email = register_data.get("email", "")
|
||||||
normalized_email = email.lower()
|
normalized_email = email.lower()
|
||||||
|
|
||||||
with sessionmaker(db.engine).begin() as session:
|
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
|
||||||
|
|
||||||
if account:
|
if account:
|
||||||
raise EmailAlreadyInUseError()
|
raise EmailAlreadyInUseError()
|
||||||
else:
|
else:
|
||||||
account = self._create_new_account(normalized_email, args.password_confirm)
|
account = self._create_new_account(normalized_email, args.password_confirm)
|
||||||
if not account:
|
if not account:
|
||||||
raise AccountNotFoundError()
|
raise AccountNotFoundError()
|
||||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||||
|
|
||||||
return {"result": "success", "data": token_pair.model_dump()}
|
return {"result": "success", "data": token_pair.model_dump()}
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,6 @@ import secrets
|
|||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
|
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
@ -85,8 +84,7 @@ class ForgotPasswordSendEmailApi(Resource):
|
|||||||
else:
|
else:
|
||||||
language = "en-US"
|
language = "en-US"
|
||||||
|
|
||||||
with sessionmaker(db.engine).begin() as session:
|
account = AccountService.get_account_by_email_with_case_fallback(args.email)
|
||||||
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
|
|
||||||
|
|
||||||
token = AccountService.send_reset_password_email(
|
token = AccountService.send_reset_password_email(
|
||||||
account=account,
|
account=account,
|
||||||
@ -184,17 +182,18 @@ class ForgotPasswordResetApi(Resource):
|
|||||||
password_hashed = hash_password(args.new_password, salt)
|
password_hashed = hash_password(args.new_password, salt)
|
||||||
|
|
||||||
email = reset_data.get("email", "")
|
email = reset_data.get("email", "")
|
||||||
with sessionmaker(db.engine).begin() as session:
|
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
|
||||||
|
|
||||||
if account:
|
if account:
|
||||||
self._update_existing_account(account, password_hashed, salt, session)
|
account = db.session.merge(account)
|
||||||
else:
|
self._update_existing_account(account, password_hashed, salt)
|
||||||
raise AccountNotFound()
|
db.session.commit()
|
||||||
|
else:
|
||||||
|
raise AccountNotFound()
|
||||||
|
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
def _update_existing_account(self, account, password_hashed, salt, session):
|
def _update_existing_account(self, account, password_hashed, salt):
|
||||||
# Update existing account credentials
|
# Update existing account credentials
|
||||||
account.password = base64.b64encode(password_hashed).decode()
|
account.password = base64.b64encode(password_hashed).decode()
|
||||||
account.password_salt = base64.b64encode(salt).decode()
|
account.password_salt = base64.b64encode(salt).decode()
|
||||||
|
|||||||
@ -4,7 +4,6 @@ import urllib.parse
|
|||||||
import httpx
|
import httpx
|
||||||
from flask import current_app, redirect, request
|
from flask import current_app, redirect, request
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
from werkzeug.exceptions import Unauthorized
|
from werkzeug.exceptions import Unauthorized
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
@ -180,8 +179,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
|
|||||||
account: Account | None = Account.get_by_openid(provider, user_info.id)
|
account: Account | None = Account.get_by_openid(provider, user_info.id)
|
||||||
|
|
||||||
if not account:
|
if not account:
|
||||||
with sessionmaker(db.engine).begin() as session:
|
account = AccountService.get_account_by_email_with_case_fallback(user_info.email)
|
||||||
account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)
|
|
||||||
|
|
||||||
return account
|
return account
|
||||||
|
|
||||||
|
|||||||
@ -227,10 +227,11 @@ class ExternalApiUseCheckApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, external_knowledge_api_id):
|
def get(self, external_knowledge_api_id):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||||
|
|
||||||
external_knowledge_api_is_using, count = ExternalDatasetService.external_knowledge_api_use_check(
|
external_knowledge_api_is_using, count = ExternalDatasetService.external_knowledge_api_use_check(
|
||||||
external_knowledge_api_id
|
external_knowledge_api_id, current_tenant_id
|
||||||
)
|
)
|
||||||
return {"is_using": external_knowledge_api_is_using, "count": count}, 200
|
return {"is_using": external_knowledge_api_is_using, "count": count}, 200
|
||||||
|
|
||||||
|
|||||||
@ -9,7 +9,6 @@ from flask_restx import Resource, fields, marshal_with
|
|||||||
from graphon.file import helpers as file_helpers
|
from graphon.file import helpers as file_helpers
|
||||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from constants.languages import supported_language
|
from constants.languages import supported_language
|
||||||
@ -580,8 +579,7 @@ class ChangeEmailSendEmailApi(Resource):
|
|||||||
|
|
||||||
user_email = current_user.email
|
user_email = current_user.email
|
||||||
else:
|
else:
|
||||||
with sessionmaker(db.engine).begin() as session:
|
account = AccountService.get_account_by_email_with_case_fallback(args.email)
|
||||||
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
|
|
||||||
if account is None:
|
if account is None:
|
||||||
raise AccountNotFound()
|
raise AccountNotFound()
|
||||||
email_for_sending = account.email
|
email_for_sending = account.email
|
||||||
|
|||||||
@ -3,7 +3,6 @@ import secrets
|
|||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
|
|
||||||
from controllers.common.schema import register_schema_models
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.console.auth.error import (
|
from controllers.console.auth.error import (
|
||||||
@ -62,9 +61,7 @@ class ForgotPasswordSendEmailApi(Resource):
|
|||||||
else:
|
else:
|
||||||
language = "en-US"
|
language = "en-US"
|
||||||
|
|
||||||
with sessionmaker(db.engine).begin() as session:
|
account = AccountService.get_account_by_email_with_case_fallback(request_email)
|
||||||
account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session)
|
|
||||||
token = None
|
|
||||||
if account is None:
|
if account is None:
|
||||||
raise AuthenticationFailedError()
|
raise AuthenticationFailedError()
|
||||||
else:
|
else:
|
||||||
@ -161,13 +158,14 @@ class ForgotPasswordResetApi(Resource):
|
|||||||
|
|
||||||
email = reset_data.get("email", "")
|
email = reset_data.get("email", "")
|
||||||
|
|
||||||
with sessionmaker(db.engine).begin() as session:
|
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
|
||||||
|
|
||||||
if account:
|
if account:
|
||||||
self._update_existing_account(account, password_hashed, salt)
|
account = db.session.merge(account)
|
||||||
else:
|
self._update_existing_account(account, password_hashed, salt)
|
||||||
raise AuthenticationFailedError()
|
db.session.commit()
|
||||||
|
else:
|
||||||
|
raise AuthenticationFailedError()
|
||||||
|
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from graphon.file import File, FileUploadConfig
|
from graphon.file import File, FileUploadConfig
|
||||||
from graphon.model_runtime.entities.model_entities import AIModelEntity
|
from graphon.model_runtime.entities.model_entities import AIModelEntity
|
||||||
@ -131,7 +131,7 @@ class AppGenerateEntity(BaseModel):
|
|||||||
extras: dict[str, Any] = Field(default_factory=dict)
|
extras: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
# tracing instance
|
# tracing instance
|
||||||
trace_manager: Optional["TraceQueueManager"] = Field(default=None, exclude=True, repr=False)
|
trace_manager: "TraceQueueManager | None" = Field(default=None, exclude=True, repr=False)
|
||||||
|
|
||||||
|
|
||||||
class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
|
class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
from typing import Literal, Optional
|
from typing import Any, Literal, TypedDict
|
||||||
|
|
||||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from core.datasource.entities.datasource_entities import DatasourceParameter
|
from core.datasource.entities.datasource_entities import DatasourceParameter
|
||||||
from core.tools.entities.common_entities import I18nObject
|
from core.tools.entities.common_entities import I18nObject, I18nObjectDict
|
||||||
|
|
||||||
|
|
||||||
class DatasourceApiEntity(BaseModel):
|
class DatasourceApiEntity(BaseModel):
|
||||||
@ -17,7 +17,24 @@ class DatasourceApiEntity(BaseModel):
|
|||||||
output_schema: dict | None = None
|
output_schema: dict | None = None
|
||||||
|
|
||||||
|
|
||||||
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]]
|
ToolProviderTypeApiLiteral = Literal["builtin", "api", "workflow"] | None
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceProviderApiEntityDict(TypedDict):
|
||||||
|
id: str
|
||||||
|
author: str
|
||||||
|
name: str
|
||||||
|
plugin_id: str | None
|
||||||
|
plugin_unique_identifier: str | None
|
||||||
|
description: I18nObjectDict
|
||||||
|
icon: str | dict
|
||||||
|
label: I18nObjectDict
|
||||||
|
type: str
|
||||||
|
team_credentials: dict | None
|
||||||
|
is_team_authorization: bool
|
||||||
|
allow_delete: bool
|
||||||
|
datasources: list[Any]
|
||||||
|
labels: list[str]
|
||||||
|
|
||||||
|
|
||||||
class DatasourceProviderApiEntity(BaseModel):
|
class DatasourceProviderApiEntity(BaseModel):
|
||||||
@ -42,7 +59,7 @@ class DatasourceProviderApiEntity(BaseModel):
|
|||||||
def convert_none_to_empty_list(cls, v):
|
def convert_none_to_empty_list(cls, v):
|
||||||
return v if v is not None else []
|
return v if v is not None else []
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> DatasourceProviderApiEntityDict:
|
||||||
# -------------
|
# -------------
|
||||||
# overwrite datasource parameter types for temp fix
|
# overwrite datasource parameter types for temp fix
|
||||||
datasources = jsonable_encoder(self.datasources)
|
datasources = jsonable_encoder(self.datasources)
|
||||||
@ -53,7 +70,7 @@ class DatasourceProviderApiEntity(BaseModel):
|
|||||||
parameter["type"] = "files"
|
parameter["type"] = "files"
|
||||||
# -------------
|
# -------------
|
||||||
|
|
||||||
return {
|
result: DatasourceProviderApiEntityDict = {
|
||||||
"id": self.id,
|
"id": self.id,
|
||||||
"author": self.author,
|
"author": self.author,
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
@ -69,3 +86,4 @@ class DatasourceProviderApiEntity(BaseModel):
|
|||||||
"datasources": datasources,
|
"datasources": datasources,
|
||||||
"labels": self.labels,
|
"labels": self.labels,
|
||||||
}
|
}
|
||||||
|
return result
|
||||||
|
|||||||
@ -71,8 +71,8 @@ class DatasourceFileMessageTransformer:
|
|||||||
if not isinstance(message.message, DatasourceMessage.BlobMessage):
|
if not isinstance(message.message, DatasourceMessage.BlobMessage):
|
||||||
raise ValueError("unexpected message type")
|
raise ValueError("unexpected message type")
|
||||||
|
|
||||||
# FIXME: should do a type check here.
|
if not isinstance(message.message.blob, bytes):
|
||||||
assert isinstance(message.message.blob, bytes)
|
raise TypeError(f"Expected blob to be bytes, got {type(message.message.blob).__name__}")
|
||||||
tool_file_manager = ToolFileManager()
|
tool_file_manager = ToolFileManager()
|
||||||
blob_tool_file: ToolFile | None = tool_file_manager.create_file_by_raw(
|
blob_tool_file: ToolFile | None = tool_file_manager.create_file_by_raw(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
|||||||
@ -122,7 +122,7 @@ class MCPClientWithAuthRetry(MCPClient):
|
|||||||
logger.exception("Authentication retry failed")
|
logger.exception("Authentication retry failed")
|
||||||
raise MCPAuthError(f"Authentication retry failed: {e}") from e
|
raise MCPAuthError(f"Authentication retry failed: {e}") from e
|
||||||
|
|
||||||
def _execute_with_retry(self, func: Callable[..., Any], *args, **kwargs) -> Any:
|
def _execute_with_retry[**P, R](self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
"""
|
"""
|
||||||
Execute a function with authentication retry logic.
|
Execute a function with authentication retry logic.
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import Any, TypeVar
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@ -9,12 +9,9 @@ from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAut
|
|||||||
|
|
||||||
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]
|
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]
|
||||||
|
|
||||||
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
|
|
||||||
LifespanContextT = TypeVar("LifespanContextT")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RequestContext[SessionT: BaseSession[Any, Any, Any, Any, Any], LifespanContextT]:
|
class RequestContext[SessionT: BaseSession, LifespanContextT]:
|
||||||
request_id: RequestId
|
request_id: RequestId
|
||||||
meta: RequestParams.Meta | None
|
meta: RequestParams.Meta | None
|
||||||
session: SessionT
|
session: SessionT
|
||||||
|
|||||||
@ -55,7 +55,7 @@ class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResul
|
|||||||
|
|
||||||
request: ReceiveRequestT
|
request: ReceiveRequestT
|
||||||
_session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]"
|
_session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]"
|
||||||
_on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any]
|
_on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], object]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -63,7 +63,7 @@ class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResul
|
|||||||
request_meta: RequestParams.Meta | None,
|
request_meta: RequestParams.Meta | None,
|
||||||
request: ReceiveRequestT,
|
request: ReceiveRequestT,
|
||||||
session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]",
|
session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]",
|
||||||
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
|
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], object],
|
||||||
):
|
):
|
||||||
self.request_id = request_id
|
self.request_id = request_id
|
||||||
self.request_meta = request_meta
|
self.request_meta = request_meta
|
||||||
|
|||||||
@ -31,7 +31,6 @@ ProgressToken = str | int
|
|||||||
Cursor = str
|
Cursor = str
|
||||||
Role = Literal["user", "assistant"]
|
Role = Literal["user", "assistant"]
|
||||||
RequestId = Annotated[int | str, Field(union_mode="left_to_right")]
|
RequestId = Annotated[int | str, Field(union_mode="left_to_right")]
|
||||||
type AnyFunction = Callable[..., Any]
|
|
||||||
|
|
||||||
|
|
||||||
class RequestParams(BaseModel):
|
class RequestParams(BaseModel):
|
||||||
|
|||||||
@ -6,7 +6,7 @@ from graphon.model_runtime.callbacks.base_callback import Callback
|
|||||||
from graphon.model_runtime.entities.llm_entities import LLMResult
|
from graphon.model_runtime.entities.llm_entities import LLMResult
|
||||||
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType
|
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType
|
||||||
from graphon.model_runtime.entities.rerank_entities import RerankResult
|
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
|
||||||
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
|
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
|
||||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
|
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
|
||||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
@ -172,10 +172,10 @@ class ModelInstance:
|
|||||||
function=self.model_type_instance.invoke,
|
function=self.model_type_instance.invoke,
|
||||||
model=self.model_name,
|
model=self.model_name,
|
||||||
credentials=self.credentials,
|
credentials=self.credentials,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=list(prompt_messages),
|
||||||
model_parameters=model_parameters,
|
model_parameters=model_parameters,
|
||||||
tools=tools,
|
tools=list(tools) if tools else None,
|
||||||
stop=stop,
|
stop=list(stop) if stop else None,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
),
|
),
|
||||||
@ -193,15 +193,12 @@ class ModelInstance:
|
|||||||
"""
|
"""
|
||||||
if not isinstance(self.model_type_instance, LargeLanguageModel):
|
if not isinstance(self.model_type_instance, LargeLanguageModel):
|
||||||
raise Exception("Model type instance is not LargeLanguageModel")
|
raise Exception("Model type instance is not LargeLanguageModel")
|
||||||
return cast(
|
return self._round_robin_invoke(
|
||||||
int,
|
function=self.model_type_instance.get_num_tokens,
|
||||||
self._round_robin_invoke(
|
model=self.model_name,
|
||||||
function=self.model_type_instance.get_num_tokens,
|
credentials=self.credentials,
|
||||||
model=self.model_name,
|
prompt_messages=list(prompt_messages),
|
||||||
credentials=self.credentials,
|
tools=list(tools) if tools else None,
|
||||||
prompt_messages=prompt_messages,
|
|
||||||
tools=tools,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke_text_embedding(
|
def invoke_text_embedding(
|
||||||
@ -216,15 +213,12 @@ class ModelInstance:
|
|||||||
"""
|
"""
|
||||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||||
return cast(
|
return self._round_robin_invoke(
|
||||||
EmbeddingResult,
|
function=self.model_type_instance.invoke,
|
||||||
self._round_robin_invoke(
|
model=self.model_name,
|
||||||
function=self.model_type_instance.invoke,
|
credentials=self.credentials,
|
||||||
model=self.model_name,
|
texts=texts,
|
||||||
credentials=self.credentials,
|
input_type=input_type,
|
||||||
texts=texts,
|
|
||||||
input_type=input_type,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke_multimodal_embedding(
|
def invoke_multimodal_embedding(
|
||||||
@ -241,15 +235,12 @@ class ModelInstance:
|
|||||||
"""
|
"""
|
||||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||||
return cast(
|
return self._round_robin_invoke(
|
||||||
EmbeddingResult,
|
function=self.model_type_instance.invoke,
|
||||||
self._round_robin_invoke(
|
model=self.model_name,
|
||||||
function=self.model_type_instance.invoke,
|
credentials=self.credentials,
|
||||||
model=self.model_name,
|
multimodel_documents=multimodel_documents,
|
||||||
credentials=self.credentials,
|
input_type=input_type,
|
||||||
multimodel_documents=multimodel_documents,
|
|
||||||
input_type=input_type,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]:
|
def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]:
|
||||||
@ -261,14 +252,11 @@ class ModelInstance:
|
|||||||
"""
|
"""
|
||||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||||
return cast(
|
return self._round_robin_invoke(
|
||||||
list[int],
|
function=self.model_type_instance.get_num_tokens,
|
||||||
self._round_robin_invoke(
|
model=self.model_name,
|
||||||
function=self.model_type_instance.get_num_tokens,
|
credentials=self.credentials,
|
||||||
model=self.model_name,
|
texts=texts,
|
||||||
credentials=self.credentials,
|
|
||||||
texts=texts,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke_rerank(
|
def invoke_rerank(
|
||||||
@ -289,23 +277,20 @@ class ModelInstance:
|
|||||||
"""
|
"""
|
||||||
if not isinstance(self.model_type_instance, RerankModel):
|
if not isinstance(self.model_type_instance, RerankModel):
|
||||||
raise Exception("Model type instance is not RerankModel")
|
raise Exception("Model type instance is not RerankModel")
|
||||||
return cast(
|
return self._round_robin_invoke(
|
||||||
RerankResult,
|
function=self.model_type_instance.invoke,
|
||||||
self._round_robin_invoke(
|
model=self.model_name,
|
||||||
function=self.model_type_instance.invoke,
|
credentials=self.credentials,
|
||||||
model=self.model_name,
|
query=query,
|
||||||
credentials=self.credentials,
|
docs=docs,
|
||||||
query=query,
|
score_threshold=score_threshold,
|
||||||
docs=docs,
|
top_n=top_n,
|
||||||
score_threshold=score_threshold,
|
|
||||||
top_n=top_n,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke_multimodal_rerank(
|
def invoke_multimodal_rerank(
|
||||||
self,
|
self,
|
||||||
query: dict,
|
query: MultimodalRerankInput,
|
||||||
docs: list[dict],
|
docs: list[MultimodalRerankInput],
|
||||||
score_threshold: float | None = None,
|
score_threshold: float | None = None,
|
||||||
top_n: int | None = None,
|
top_n: int | None = None,
|
||||||
) -> RerankResult:
|
) -> RerankResult:
|
||||||
@ -320,17 +305,14 @@ class ModelInstance:
|
|||||||
"""
|
"""
|
||||||
if not isinstance(self.model_type_instance, RerankModel):
|
if not isinstance(self.model_type_instance, RerankModel):
|
||||||
raise Exception("Model type instance is not RerankModel")
|
raise Exception("Model type instance is not RerankModel")
|
||||||
return cast(
|
return self._round_robin_invoke(
|
||||||
RerankResult,
|
function=self.model_type_instance.invoke_multimodal_rerank,
|
||||||
self._round_robin_invoke(
|
model=self.model_name,
|
||||||
function=self.model_type_instance.invoke_multimodal_rerank,
|
credentials=self.credentials,
|
||||||
model=self.model_name,
|
query=query,
|
||||||
credentials=self.credentials,
|
docs=docs,
|
||||||
query=query,
|
score_threshold=score_threshold,
|
||||||
docs=docs,
|
top_n=top_n,
|
||||||
score_threshold=score_threshold,
|
|
||||||
top_n=top_n,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke_moderation(self, text: str) -> bool:
|
def invoke_moderation(self, text: str) -> bool:
|
||||||
@ -342,14 +324,11 @@ class ModelInstance:
|
|||||||
"""
|
"""
|
||||||
if not isinstance(self.model_type_instance, ModerationModel):
|
if not isinstance(self.model_type_instance, ModerationModel):
|
||||||
raise Exception("Model type instance is not ModerationModel")
|
raise Exception("Model type instance is not ModerationModel")
|
||||||
return cast(
|
return self._round_robin_invoke(
|
||||||
bool,
|
function=self.model_type_instance.invoke,
|
||||||
self._round_robin_invoke(
|
model=self.model_name,
|
||||||
function=self.model_type_instance.invoke,
|
credentials=self.credentials,
|
||||||
model=self.model_name,
|
text=text,
|
||||||
credentials=self.credentials,
|
|
||||||
text=text,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke_speech2text(self, file: IO[bytes]) -> str:
|
def invoke_speech2text(self, file: IO[bytes]) -> str:
|
||||||
@ -361,14 +340,11 @@ class ModelInstance:
|
|||||||
"""
|
"""
|
||||||
if not isinstance(self.model_type_instance, Speech2TextModel):
|
if not isinstance(self.model_type_instance, Speech2TextModel):
|
||||||
raise Exception("Model type instance is not Speech2TextModel")
|
raise Exception("Model type instance is not Speech2TextModel")
|
||||||
return cast(
|
return self._round_robin_invoke(
|
||||||
str,
|
function=self.model_type_instance.invoke,
|
||||||
self._round_robin_invoke(
|
model=self.model_name,
|
||||||
function=self.model_type_instance.invoke,
|
credentials=self.credentials,
|
||||||
model=self.model_name,
|
file=file,
|
||||||
credentials=self.credentials,
|
|
||||||
file=file,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke_tts(self, content_text: str, voice: str = "") -> Iterable[bytes]:
|
def invoke_tts(self, content_text: str, voice: str = "") -> Iterable[bytes]:
|
||||||
@ -381,18 +357,15 @@ class ModelInstance:
|
|||||||
"""
|
"""
|
||||||
if not isinstance(self.model_type_instance, TTSModel):
|
if not isinstance(self.model_type_instance, TTSModel):
|
||||||
raise Exception("Model type instance is not TTSModel")
|
raise Exception("Model type instance is not TTSModel")
|
||||||
return cast(
|
return self._round_robin_invoke(
|
||||||
Iterable[bytes],
|
function=self.model_type_instance.invoke,
|
||||||
self._round_robin_invoke(
|
model=self.model_name,
|
||||||
function=self.model_type_instance.invoke,
|
credentials=self.credentials,
|
||||||
model=self.model_name,
|
content_text=content_text,
|
||||||
credentials=self.credentials,
|
voice=voice,
|
||||||
content_text=content_text,
|
|
||||||
voice=voice,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs):
|
def _round_robin_invoke[**P, R](self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
"""
|
"""
|
||||||
Round-robin invoke
|
Round-robin invoke
|
||||||
:param function: function to invoke
|
:param function: function to invoke
|
||||||
@ -430,9 +403,8 @@ class ModelInstance:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if "credentials" in kwargs:
|
kwargs["credentials"] = lb_config.credentials
|
||||||
del kwargs["credentials"]
|
return function(*args, **kwargs)
|
||||||
return function(*args, **kwargs, credentials=lb_config.credentials)
|
|
||||||
except InvokeRateLimitError as e:
|
except InvokeRateLimitError as e:
|
||||||
# expire in 60 seconds
|
# expire in 60 seconds
|
||||||
self.load_balancing_manager.cooldown(lb_config, expire=60)
|
self.load_balancing_manager.cooldown(lb_config, expire=60)
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Generator, Mapping
|
from collections.abc import Generator, Mapping
|
||||||
from typing import Any, Union, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@ -207,7 +207,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_user(cls, user_id: str) -> Union[EndUser, Account]:
|
def _get_user(cls, user_id: str) -> EndUser | Account:
|
||||||
"""
|
"""
|
||||||
get the user by user id
|
get the user by user id
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -7,7 +7,7 @@ from pydantic import BaseModel, model_validator
|
|||||||
from sqlalchemy import Column, String, Table, create_engine, insert
|
from sqlalchemy import Column, String, Table, create_engine, insert
|
||||||
from sqlalchemy import text as sql_text
|
from sqlalchemy import text as sql_text
|
||||||
from sqlalchemy.dialects.postgresql import JSON, TEXT
|
from sqlalchemy.dialects.postgresql import JSON, TEXT
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||||
from core.rag.datasource.vdb.vector_type import VectorType
|
from core.rag.datasource.vdb.vector_type import VectorType
|
||||||
@ -79,7 +79,7 @@ class RelytVector(BaseVector):
|
|||||||
if redis_client.get(collection_exist_cache_key):
|
if redis_client.get(collection_exist_cache_key):
|
||||||
return
|
return
|
||||||
index_name = f"{self._collection_name}_embedding_index"
|
index_name = f"{self._collection_name}_embedding_index"
|
||||||
with Session(self.client) as session:
|
with sessionmaker(bind=self.client).begin() as session:
|
||||||
drop_statement = sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}"; """)
|
drop_statement = sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}"; """)
|
||||||
session.execute(drop_statement)
|
session.execute(drop_statement)
|
||||||
create_statement = sql_text(f"""
|
create_statement = sql_text(f"""
|
||||||
@ -104,7 +104,6 @@ class RelytVector(BaseVector):
|
|||||||
$$);
|
$$);
|
||||||
""")
|
""")
|
||||||
session.execute(index_statement)
|
session.execute(index_statement)
|
||||||
session.commit()
|
|
||||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||||
|
|
||||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||||
@ -208,9 +207,8 @@ class RelytVector(BaseVector):
|
|||||||
self.delete_by_uuids(ids)
|
self.delete_by_uuids(ids)
|
||||||
|
|
||||||
def delete(self):
|
def delete(self):
|
||||||
with Session(self.client) as session:
|
with sessionmaker(bind=self.client).begin() as session:
|
||||||
session.execute(sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}";"""))
|
session.execute(sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}";"""))
|
||||||
session.commit()
|
|
||||||
|
|
||||||
def text_exists(self, id: str) -> bool:
|
def text_exists(self, id: str) -> bool:
|
||||||
with Session(self.client) as session:
|
with Session(self.client) as session:
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import sqlalchemy
|
|||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
|
from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
|
||||||
from sqlalchemy import text as sql_text
|
from sqlalchemy import text as sql_text
|
||||||
from sqlalchemy.orm import Session, declarative_base
|
from sqlalchemy.orm import Session, declarative_base, sessionmaker
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.rag.datasource.vdb.field import Field, parse_metadata_json
|
from core.rag.datasource.vdb.field import Field, parse_metadata_json
|
||||||
@ -97,8 +97,7 @@ class TiDBVector(BaseVector):
|
|||||||
if redis_client.get(collection_exist_cache_key):
|
if redis_client.get(collection_exist_cache_key):
|
||||||
return
|
return
|
||||||
tidb_dist_func = self._get_distance_func()
|
tidb_dist_func = self._get_distance_func()
|
||||||
with Session(self._engine) as session:
|
with sessionmaker(bind=self._engine).begin() as session:
|
||||||
session.begin()
|
|
||||||
create_statement = sql_text(f"""
|
create_statement = sql_text(f"""
|
||||||
CREATE TABLE IF NOT EXISTS {self._collection_name} (
|
CREATE TABLE IF NOT EXISTS {self._collection_name} (
|
||||||
id CHAR(36) PRIMARY KEY,
|
id CHAR(36) PRIMARY KEY,
|
||||||
@ -115,7 +114,6 @@ class TiDBVector(BaseVector):
|
|||||||
);
|
);
|
||||||
""")
|
""")
|
||||||
session.execute(create_statement)
|
session.execute(create_statement)
|
||||||
session.commit()
|
|
||||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||||
|
|
||||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||||
@ -238,9 +236,8 @@ class TiDBVector(BaseVector):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
def delete(self):
|
def delete(self):
|
||||||
with Session(self._engine) as session:
|
with sessionmaker(bind=self._engine).begin() as session:
|
||||||
session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))
|
session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))
|
||||||
session.commit()
|
|
||||||
|
|
||||||
def _get_distance_func(self) -> str:
|
def _get_distance_func(self) -> str:
|
||||||
match self._distance_func:
|
match self._distance_func:
|
||||||
|
|||||||
@ -3,8 +3,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Mapping
|
from typing import Any, TypedDict, cast
|
||||||
from typing import Any, cast
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -55,6 +54,12 @@ from services.summary_index_service import SummaryIndexService
|
|||||||
_file_access_controller = DatabaseFileAccessController()
|
_file_access_controller = DatabaseFileAccessController()
|
||||||
|
|
||||||
|
|
||||||
|
class ParagraphFormatPreviewDict(TypedDict):
|
||||||
|
chunk_structure: str
|
||||||
|
preview: list[dict[str, Any]]
|
||||||
|
total_segments: int
|
||||||
|
|
||||||
|
|
||||||
class ParagraphIndexProcessor(BaseIndexProcessor):
|
class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||||
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
|
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
|
||||||
text_docs = ExtractProcessor.extract(
|
text_docs = ExtractProcessor.extract(
|
||||||
@ -266,16 +271,17 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
|||||||
keyword = Keyword(dataset)
|
keyword = Keyword(dataset)
|
||||||
keyword.add_texts(documents)
|
keyword.add_texts(documents)
|
||||||
|
|
||||||
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
|
def format_preview(self, chunks: Any) -> ParagraphFormatPreviewDict:
|
||||||
if isinstance(chunks, list):
|
if isinstance(chunks, list):
|
||||||
preview = []
|
preview = []
|
||||||
for content in chunks:
|
for content in chunks:
|
||||||
preview.append({"content": content})
|
preview.append({"content": content})
|
||||||
return {
|
result: ParagraphFormatPreviewDict = {
|
||||||
"chunk_structure": IndexStructureType.PARAGRAPH_INDEX,
|
"chunk_structure": IndexStructureType.PARAGRAPH_INDEX,
|
||||||
"preview": preview,
|
"preview": preview,
|
||||||
"total_segments": len(chunks),
|
"total_segments": len(chunks),
|
||||||
}
|
}
|
||||||
|
return result
|
||||||
else:
|
else:
|
||||||
raise ValueError("Chunks is not a list")
|
raise ValueError("Chunks is not a list")
|
||||||
|
|
||||||
|
|||||||
@ -3,8 +3,7 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Mapping
|
from typing import Any, TypedDict
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from sqlalchemy import delete, select
|
from sqlalchemy import delete, select
|
||||||
|
|
||||||
@ -36,6 +35,13 @@ from services.summary_index_service import SummaryIndexService
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ParentChildFormatPreviewDict(TypedDict):
|
||||||
|
chunk_structure: str
|
||||||
|
parent_mode: str
|
||||||
|
preview: list[dict[str, Any]]
|
||||||
|
total_segments: int
|
||||||
|
|
||||||
|
|
||||||
class ParentChildIndexProcessor(BaseIndexProcessor):
|
class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||||
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
|
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
|
||||||
text_docs = ExtractProcessor.extract(
|
text_docs = ExtractProcessor.extract(
|
||||||
@ -351,17 +357,18 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
|||||||
if all_multimodal_documents and dataset.is_multimodal:
|
if all_multimodal_documents and dataset.is_multimodal:
|
||||||
vector.create_multimodal(all_multimodal_documents)
|
vector.create_multimodal(all_multimodal_documents)
|
||||||
|
|
||||||
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
|
def format_preview(self, chunks: Any) -> ParentChildFormatPreviewDict:
|
||||||
parent_childs = ParentChildStructureChunk.model_validate(chunks)
|
parent_childs = ParentChildStructureChunk.model_validate(chunks)
|
||||||
preview = []
|
preview = []
|
||||||
for parent_child in parent_childs.parent_child_chunks:
|
for parent_child in parent_childs.parent_child_chunks:
|
||||||
preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents})
|
preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents})
|
||||||
return {
|
result: ParentChildFormatPreviewDict = {
|
||||||
"chunk_structure": IndexStructureType.PARENT_CHILD_INDEX,
|
"chunk_structure": IndexStructureType.PARENT_CHILD_INDEX,
|
||||||
"parent_mode": parent_childs.parent_mode,
|
"parent_mode": parent_childs.parent_mode,
|
||||||
"preview": preview,
|
"preview": preview,
|
||||||
"total_segments": len(parent_childs.parent_child_chunks),
|
"total_segments": len(parent_childs.parent_child_chunks),
|
||||||
}
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
def generate_summary_preview(
|
def generate_summary_preview(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -4,8 +4,7 @@ import logging
|
|||||||
import re
|
import re
|
||||||
import threading
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Mapping
|
from typing import Any, TypedDict
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
@ -36,6 +35,12 @@ from services.summary_index_service import SummaryIndexService
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class QAFormatPreviewDict(TypedDict):
|
||||||
|
chunk_structure: str
|
||||||
|
qa_preview: list[dict[str, Any]]
|
||||||
|
total_segments: int
|
||||||
|
|
||||||
|
|
||||||
class QAIndexProcessor(BaseIndexProcessor):
|
class QAIndexProcessor(BaseIndexProcessor):
|
||||||
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
|
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
|
||||||
text_docs = ExtractProcessor.extract(
|
text_docs = ExtractProcessor.extract(
|
||||||
@ -230,16 +235,17 @@ class QAIndexProcessor(BaseIndexProcessor):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Indexing technique must be high quality.")
|
raise ValueError("Indexing technique must be high quality.")
|
||||||
|
|
||||||
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
|
def format_preview(self, chunks: Any) -> QAFormatPreviewDict:
|
||||||
qa_chunks = QAStructureChunk.model_validate(chunks)
|
qa_chunks = QAStructureChunk.model_validate(chunks)
|
||||||
preview = []
|
preview = []
|
||||||
for qa_chunk in qa_chunks.qa_chunks:
|
for qa_chunk in qa_chunks.qa_chunks:
|
||||||
preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer})
|
preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer})
|
||||||
return {
|
result: QAFormatPreviewDict = {
|
||||||
"chunk_structure": IndexStructureType.QA_INDEX,
|
"chunk_structure": IndexStructureType.QA_INDEX,
|
||||||
"qa_preview": preview,
|
"qa_preview": preview,
|
||||||
"total_segments": len(qa_chunks.qa_chunks),
|
"total_segments": len(qa_chunks.qa_chunks),
|
||||||
}
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
def generate_summary_preview(
|
def generate_summary_preview(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import base64
|
import base64
|
||||||
|
|
||||||
from graphon.model_runtime.entities.model_entities import ModelType
|
from graphon.model_runtime.entities.model_entities import ModelType
|
||||||
from graphon.model_runtime.entities.rerank_entities import RerankResult
|
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
|
||||||
|
|
||||||
from core.model_manager import ModelInstance, ModelManager
|
from core.model_manager import ModelInstance, ModelManager
|
||||||
from core.rag.index_processor.constant.doc_type import DocType
|
from core.rag.index_processor.constant.doc_type import DocType
|
||||||
@ -123,7 +123,7 @@ class RerankModelRunner(BaseRerankRunner):
|
|||||||
:param query_type: query type
|
:param query_type: query type
|
||||||
:return: rerank result
|
:return: rerank result
|
||||||
"""
|
"""
|
||||||
docs = []
|
docs: list[MultimodalRerankInput] = []
|
||||||
doc_ids = set()
|
doc_ids = set()
|
||||||
unique_documents = []
|
unique_documents = []
|
||||||
for document in documents:
|
for document in documents:
|
||||||
@ -138,26 +138,28 @@ class RerankModelRunner(BaseRerankRunner):
|
|||||||
if upload_file:
|
if upload_file:
|
||||||
blob = storage.load_once(upload_file.key)
|
blob = storage.load_once(upload_file.key)
|
||||||
document_file_base64 = base64.b64encode(blob).decode()
|
document_file_base64 = base64.b64encode(blob).decode()
|
||||||
document_file_dict = {
|
docs.append(
|
||||||
"content": document_file_base64,
|
MultimodalRerankInput(
|
||||||
"content_type": document.metadata["doc_type"],
|
content=document_file_base64,
|
||||||
}
|
content_type=document.metadata["doc_type"],
|
||||||
docs.append(document_file_dict)
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
document_text_dict = {
|
docs.append(
|
||||||
"content": document.page_content,
|
MultimodalRerankInput(
|
||||||
"content_type": document.metadata.get("doc_type") or DocType.TEXT,
|
content=document.page_content,
|
||||||
}
|
content_type=document.metadata.get("doc_type") or DocType.TEXT,
|
||||||
docs.append(document_text_dict)
|
)
|
||||||
|
)
|
||||||
doc_ids.add(document.metadata["doc_id"])
|
doc_ids.add(document.metadata["doc_id"])
|
||||||
unique_documents.append(document)
|
unique_documents.append(document)
|
||||||
elif document.provider == "external":
|
elif document.provider == "external":
|
||||||
if document not in unique_documents:
|
if document not in unique_documents:
|
||||||
docs.append(
|
docs.append(
|
||||||
{
|
MultimodalRerankInput(
|
||||||
"content": document.page_content,
|
content=document.page_content,
|
||||||
"content_type": document.metadata.get("doc_type") or DocType.TEXT,
|
content_type=document.metadata.get("doc_type") or DocType.TEXT,
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
unique_documents.append(document)
|
unique_documents.append(document)
|
||||||
|
|
||||||
@ -171,12 +173,12 @@ class RerankModelRunner(BaseRerankRunner):
|
|||||||
if upload_file:
|
if upload_file:
|
||||||
blob = storage.load_once(upload_file.key)
|
blob = storage.load_once(upload_file.key)
|
||||||
file_query = base64.b64encode(blob).decode()
|
file_query = base64.b64encode(blob).decode()
|
||||||
file_query_dict = {
|
file_query_input = MultimodalRerankInput(
|
||||||
"content": file_query,
|
content=file_query,
|
||||||
"content_type": DocType.IMAGE,
|
content_type=DocType.IMAGE,
|
||||||
}
|
)
|
||||||
rerank_result = self.rerank_model_instance.invoke_multimodal_rerank(
|
rerank_result = self.rerank_model_instance.invoke_multimodal_rerank(
|
||||||
query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n
|
query=file_query_input, docs=docs, score_threshold=score_threshold, top_n=top_n
|
||||||
)
|
)
|
||||||
return rerank_result, unique_documents
|
return rerank_result, unique_documents
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -6,7 +6,6 @@ import os
|
|||||||
import time
|
import time
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from mimetypes import guess_extension, guess_type
|
from mimetypes import guess_extension, guess_type
|
||||||
from typing import Union
|
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
@ -158,7 +157,7 @@ class ToolFileManager:
|
|||||||
|
|
||||||
return tool_file
|
return tool_file
|
||||||
|
|
||||||
def get_file_binary(self, id: str) -> Union[tuple[bytes, str], None]:
|
def get_file_binary(self, id: str) -> tuple[bytes, str] | None:
|
||||||
"""
|
"""
|
||||||
get file binary
|
get file binary
|
||||||
|
|
||||||
@ -176,7 +175,7 @@ class ToolFileManager:
|
|||||||
|
|
||||||
return blob, tool_file.mimetype
|
return blob, tool_file.mimetype
|
||||||
|
|
||||||
def get_file_binary_by_message_file_id(self, id: str) -> Union[tuple[bytes, str], None]:
|
def get_file_binary_by_message_file_id(self, id: str) -> tuple[bytes, str] | None:
|
||||||
"""
|
"""
|
||||||
get file binary
|
get file binary
|
||||||
|
|
||||||
|
|||||||
@ -5,7 +5,7 @@ import time
|
|||||||
from collections.abc import Generator, Mapping
|
from collections.abc import Generator, Mapping
|
||||||
from os import listdir, path
|
from os import listdir, path
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, Union, cast
|
from typing import TYPE_CHECKING, Any, Literal, Protocol, cast
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from graphon.runtime import VariablePool
|
from graphon.runtime import VariablePool
|
||||||
@ -100,7 +100,7 @@ class ToolManager:
|
|||||||
_builtin_provider_lock = Lock()
|
_builtin_provider_lock = Lock()
|
||||||
_hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
|
_hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
|
||||||
_builtin_providers_loaded = False
|
_builtin_providers_loaded = False
|
||||||
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
|
_builtin_tools_labels: dict[str, I18nObject | None] = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
|
def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
|
||||||
@ -190,7 +190,7 @@ class ToolManager:
|
|||||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||||
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
|
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
|
||||||
credential_id: str | None = None,
|
credential_id: str | None = None,
|
||||||
) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
|
) -> BuiltinTool | PluginTool | ApiTool | WorkflowTool | MCPTool:
|
||||||
"""
|
"""
|
||||||
get the tool runtime
|
get the tool runtime
|
||||||
|
|
||||||
@ -398,7 +398,7 @@ class ToolManager:
|
|||||||
agent_tool: AgentToolEntity,
|
agent_tool: AgentToolEntity,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||||
variable_pool: Optional["VariablePool"] = None,
|
variable_pool: "VariablePool | None" = None,
|
||||||
) -> Tool:
|
) -> Tool:
|
||||||
"""
|
"""
|
||||||
get the agent tool runtime
|
get the agent tool runtime
|
||||||
@ -442,7 +442,7 @@ class ToolManager:
|
|||||||
workflow_tool: WorkflowToolRuntimeSpec,
|
workflow_tool: WorkflowToolRuntimeSpec,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||||
variable_pool: Optional["VariablePool"] = None,
|
variable_pool: "VariablePool | None" = None,
|
||||||
) -> Tool:
|
) -> Tool:
|
||||||
"""
|
"""
|
||||||
get the workflow tool runtime
|
get the workflow tool runtime
|
||||||
@ -634,7 +634,7 @@ class ToolManager:
|
|||||||
cls._builtin_providers_loaded = False
|
cls._builtin_providers_loaded = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
|
def get_tool_label(cls, tool_name: str) -> I18nObject | None:
|
||||||
"""
|
"""
|
||||||
get the tool label
|
get the tool label
|
||||||
|
|
||||||
@ -993,7 +993,7 @@ class ToolManager:
|
|||||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | dict[str, str] | str:
|
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | str:
|
||||||
try:
|
try:
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
mcp_service = MCPToolManageService(session=session)
|
mcp_service = MCPToolManageService(session=session)
|
||||||
@ -1001,7 +1001,7 @@ class ToolManager:
|
|||||||
mcp_provider = mcp_service.get_provider_entity(
|
mcp_provider = mcp_service.get_provider_entity(
|
||||||
provider_id=provider_id, tenant_id=tenant_id, by_server_id=True
|
provider_id=provider_id, tenant_id=tenant_id, by_server_id=True
|
||||||
)
|
)
|
||||||
return mcp_provider.provider_icon
|
return cast(EmojiIconDict | str, mcp_provider.provider_icon)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
|
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -1013,7 +1013,7 @@ class ToolManager:
|
|||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
provider_type: ToolProviderType,
|
provider_type: ToolProviderType,
|
||||||
provider_id: str,
|
provider_id: str,
|
||||||
) -> str | EmojiIconDict | dict[str, str]:
|
) -> str | EmojiIconDict:
|
||||||
"""
|
"""
|
||||||
get the tool icon
|
get the tool icon
|
||||||
|
|
||||||
@ -1052,7 +1052,7 @@ class ToolManager:
|
|||||||
def _convert_tool_parameters_type(
|
def _convert_tool_parameters_type(
|
||||||
cls,
|
cls,
|
||||||
parameters: list[ToolParameter],
|
parameters: list[ToolParameter],
|
||||||
variable_pool: Optional["VariablePool"],
|
variable_pool: "VariablePool | None",
|
||||||
tool_configurations: Mapping[str, Any],
|
tool_configurations: Mapping[str, Any],
|
||||||
typ: Literal["agent", "workflow", "tool"] = "workflow",
|
typ: Literal["agent", "workflow", "tool"] = "workflow",
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
|
|||||||
@ -118,7 +118,8 @@ class ToolFileMessageTransformer:
|
|||||||
if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
|
if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
|
||||||
raise ValueError("unexpected message type")
|
raise ValueError("unexpected message type")
|
||||||
|
|
||||||
assert isinstance(message.message.blob, bytes)
|
if not isinstance(message.message.blob, bytes):
|
||||||
|
raise TypeError(f"Expected blob to be bytes, got {type(message.message.blob).__name__}")
|
||||||
tool_file_manager = ToolFileManager()
|
tool_file_manager = ToolFileManager()
|
||||||
tool_file = tool_file_manager.create_file_by_raw(
|
tool_file = tool_file_manager.create_file_by_raw(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
|||||||
@ -14,6 +14,7 @@ from redis.cluster import ClusterNode, RedisCluster
|
|||||||
from redis.connection import Connection, SSLConnection
|
from redis.connection import Connection, SSLConnection
|
||||||
from redis.retry import Retry
|
from redis.retry import Retry
|
||||||
from redis.sentinel import Sentinel
|
from redis.sentinel import Sentinel
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from dify_app import DifyApp
|
from dify_app import DifyApp
|
||||||
@ -126,6 +127,35 @@ redis_client: RedisClientWrapper = RedisClientWrapper()
|
|||||||
_pubsub_redis_client: redis.Redis | RedisCluster | None = None
|
_pubsub_redis_client: redis.Redis | RedisCluster | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class RedisSSLParamsDict(TypedDict):
|
||||||
|
ssl_cert_reqs: int
|
||||||
|
ssl_ca_certs: str | None
|
||||||
|
ssl_certfile: str | None
|
||||||
|
ssl_keyfile: str | None
|
||||||
|
|
||||||
|
|
||||||
|
class RedisHealthParamsDict(TypedDict):
|
||||||
|
retry: Retry
|
||||||
|
socket_timeout: float | None
|
||||||
|
socket_connect_timeout: float | None
|
||||||
|
health_check_interval: int | None
|
||||||
|
|
||||||
|
|
||||||
|
class RedisBaseParamsDict(TypedDict):
|
||||||
|
username: str | None
|
||||||
|
password: str | None
|
||||||
|
db: int
|
||||||
|
encoding: str
|
||||||
|
encoding_errors: str
|
||||||
|
decode_responses: bool
|
||||||
|
protocol: int
|
||||||
|
cache_config: CacheConfig | None
|
||||||
|
retry: Retry
|
||||||
|
socket_timeout: float | None
|
||||||
|
socket_connect_timeout: float | None
|
||||||
|
health_check_interval: int | None
|
||||||
|
|
||||||
|
|
||||||
def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]:
|
def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]:
|
||||||
"""Get SSL configuration for Redis connection."""
|
"""Get SSL configuration for Redis connection."""
|
||||||
if not dify_config.REDIS_USE_SSL:
|
if not dify_config.REDIS_USE_SSL:
|
||||||
@ -171,14 +201,14 @@ def _get_retry_policy() -> Retry:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_connection_health_params() -> dict[str, Any]:
|
def _get_connection_health_params() -> RedisHealthParamsDict:
|
||||||
"""Get connection health and retry parameters for standalone and Sentinel Redis clients."""
|
"""Get connection health and retry parameters for standalone and Sentinel Redis clients."""
|
||||||
return {
|
return RedisHealthParamsDict(
|
||||||
"retry": _get_retry_policy(),
|
retry=_get_retry_policy(),
|
||||||
"socket_timeout": dify_config.REDIS_SOCKET_TIMEOUT,
|
socket_timeout=dify_config.REDIS_SOCKET_TIMEOUT,
|
||||||
"socket_connect_timeout": dify_config.REDIS_SOCKET_CONNECT_TIMEOUT,
|
socket_connect_timeout=dify_config.REDIS_SOCKET_CONNECT_TIMEOUT,
|
||||||
"health_check_interval": dify_config.REDIS_HEALTH_CHECK_INTERVAL,
|
health_check_interval=dify_config.REDIS_HEALTH_CHECK_INTERVAL,
|
||||||
}
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_cluster_connection_health_params() -> dict[str, Any]:
|
def _get_cluster_connection_health_params() -> dict[str, Any]:
|
||||||
@ -189,26 +219,26 @@ def _get_cluster_connection_health_params() -> dict[str, Any]:
|
|||||||
here. Only ``retry``, ``socket_timeout``, and ``socket_connect_timeout``
|
here. Only ``retry``, ``socket_timeout``, and ``socket_connect_timeout``
|
||||||
are passed through.
|
are passed through.
|
||||||
"""
|
"""
|
||||||
params = _get_connection_health_params()
|
params: dict[str, Any] = dict(_get_connection_health_params())
|
||||||
return {k: v for k, v in params.items() if k != "health_check_interval"}
|
return {k: v for k, v in params.items() if k != "health_check_interval"}
|
||||||
|
|
||||||
|
|
||||||
def _get_base_redis_params() -> dict[str, Any]:
|
def _get_base_redis_params() -> RedisBaseParamsDict:
|
||||||
"""Get base Redis connection parameters including retry and health policy."""
|
"""Get base Redis connection parameters including retry and health policy."""
|
||||||
return {
|
return RedisBaseParamsDict(
|
||||||
"username": dify_config.REDIS_USERNAME,
|
username=dify_config.REDIS_USERNAME,
|
||||||
"password": dify_config.REDIS_PASSWORD or None,
|
password=dify_config.REDIS_PASSWORD or None,
|
||||||
"db": dify_config.REDIS_DB,
|
db=dify_config.REDIS_DB,
|
||||||
"encoding": "utf-8",
|
encoding="utf-8",
|
||||||
"encoding_errors": "strict",
|
encoding_errors="strict",
|
||||||
"decode_responses": False,
|
decode_responses=False,
|
||||||
"protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL,
|
protocol=dify_config.REDIS_SERIALIZATION_PROTOCOL,
|
||||||
"cache_config": _get_cache_configuration(),
|
cache_config=_get_cache_configuration(),
|
||||||
**_get_connection_health_params(),
|
**_get_connection_health_params(),
|
||||||
}
|
)
|
||||||
|
|
||||||
|
|
||||||
def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]:
|
def _create_sentinel_client(redis_params: RedisBaseParamsDict) -> Union[redis.Redis, RedisCluster]:
|
||||||
"""Create Redis client using Sentinel configuration."""
|
"""Create Redis client using Sentinel configuration."""
|
||||||
if not dify_config.REDIS_SENTINELS:
|
if not dify_config.REDIS_SENTINELS:
|
||||||
raise ValueError("REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True")
|
raise ValueError("REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True")
|
||||||
@ -232,7 +262,8 @@ def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis,
|
|||||||
sentinel_kwargs=sentinel_kwargs,
|
sentinel_kwargs=sentinel_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params)
|
params: dict[str, Any] = {**redis_params}
|
||||||
|
master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **params)
|
||||||
return master
|
return master
|
||||||
|
|
||||||
|
|
||||||
@ -259,18 +290,16 @@ def _create_cluster_client() -> Union[redis.Redis, RedisCluster]:
|
|||||||
return cluster
|
return cluster
|
||||||
|
|
||||||
|
|
||||||
def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]:
|
def _create_standalone_client(redis_params: RedisBaseParamsDict) -> Union[redis.Redis, RedisCluster]:
|
||||||
"""Create standalone Redis client."""
|
"""Create standalone Redis client."""
|
||||||
connection_class, ssl_kwargs = _get_ssl_configuration()
|
connection_class, ssl_kwargs = _get_ssl_configuration()
|
||||||
|
|
||||||
params = {**redis_params}
|
params: dict[str, Any] = {
|
||||||
params.update(
|
**redis_params,
|
||||||
{
|
"host": dify_config.REDIS_HOST,
|
||||||
"host": dify_config.REDIS_HOST,
|
"port": dify_config.REDIS_PORT,
|
||||||
"port": dify_config.REDIS_PORT,
|
"connection_class": connection_class,
|
||||||
"connection_class": connection_class,
|
}
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if dify_config.REDIS_MAX_CONNECTIONS:
|
if dify_config.REDIS_MAX_CONNECTIONS:
|
||||||
params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS
|
params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS
|
||||||
@ -293,8 +322,8 @@ def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> redis.Redis |
|
|||||||
kwargs["max_connections"] = max_conns
|
kwargs["max_connections"] = max_conns
|
||||||
return RedisCluster.from_url(pubsub_url, **kwargs)
|
return RedisCluster.from_url(pubsub_url, **kwargs)
|
||||||
|
|
||||||
health_params = _get_connection_health_params()
|
standalone_health_params: dict[str, Any] = dict(_get_connection_health_params())
|
||||||
kwargs = {**health_params}
|
kwargs = {**standalone_health_params}
|
||||||
if max_conns:
|
if max_conns:
|
||||||
kwargs["max_connections"] = max_conns
|
kwargs["max_connections"] = max_conns
|
||||||
return redis.Redis.from_url(pubsub_url, **kwargs)
|
return redis.Redis.from_url(pubsub_url, **kwargs)
|
||||||
|
|||||||
@ -37,12 +37,7 @@ def trace_span[**P, R](handler_class: type[SpanHandler] | None = None) -> Callab
|
|||||||
handler = _get_handler_instance(handler_class or SpanHandler)
|
handler = _get_handler_instance(handler_class or SpanHandler)
|
||||||
tracer = get_tracer(__name__)
|
tracer = get_tracer(__name__)
|
||||||
|
|
||||||
return handler.wrapper(
|
return handler.wrapper(tracer, func, *args, **kwargs)
|
||||||
tracer=tracer,
|
|
||||||
wrapped=func,
|
|
||||||
args=args,
|
|
||||||
kwargs=kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
return cast(Callable[P, R], wrapper)
|
return cast(Callable[P, R], wrapper)
|
||||||
|
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
import inspect
|
import inspect
|
||||||
from collections.abc import Callable, Mapping
|
from collections.abc import Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from opentelemetry.trace import SpanKind, Status, StatusCode
|
from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer
|
||||||
|
|
||||||
|
|
||||||
class SpanHandler:
|
class SpanHandler:
|
||||||
@ -16,9 +16,9 @@ class SpanHandler:
|
|||||||
exceptions. Handlers can override the wrapper method to customize behavior.
|
exceptions. Handlers can override the wrapper method to customize behavior.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_signature_cache: dict[Callable[..., Any], inspect.Signature] = {}
|
_signature_cache: dict[Callable[..., object], inspect.Signature] = {}
|
||||||
|
|
||||||
def _build_span_name(self, wrapped: Callable[..., Any]) -> str:
|
def _build_span_name[**P, R](self, wrapped: Callable[P, R]) -> str:
|
||||||
"""
|
"""
|
||||||
Build the span name from the wrapped function.
|
Build the span name from the wrapped function.
|
||||||
|
|
||||||
@ -29,11 +29,11 @@ class SpanHandler:
|
|||||||
"""
|
"""
|
||||||
return f"{wrapped.__module__}.{wrapped.__qualname__}"
|
return f"{wrapped.__module__}.{wrapped.__qualname__}"
|
||||||
|
|
||||||
def _extract_arguments[T](
|
def _extract_arguments[**P, R](
|
||||||
self,
|
self,
|
||||||
wrapped: Callable[..., T],
|
wrapped: Callable[P, R],
|
||||||
args: tuple[object, ...],
|
*args: P.args,
|
||||||
kwargs: Mapping[str, object],
|
**kwargs: P.kwargs,
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""
|
"""
|
||||||
Extract function arguments using inspect.signature.
|
Extract function arguments using inspect.signature.
|
||||||
@ -59,13 +59,13 @@ class SpanHandler:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def wrapper[T](
|
def wrapper[**P, R](
|
||||||
self,
|
self,
|
||||||
tracer: Any,
|
tracer: Tracer,
|
||||||
wrapped: Callable[..., T],
|
wrapped: Callable[P, R],
|
||||||
args: tuple[object, ...],
|
*args: P.args,
|
||||||
kwargs: Mapping[str, object],
|
**kwargs: P.kwargs,
|
||||||
) -> T:
|
) -> R:
|
||||||
"""
|
"""
|
||||||
Fully control the wrapper behavior.
|
Fully control the wrapper behavior.
|
||||||
|
|
||||||
|
|||||||
@ -1,8 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from collections.abc import Callable, Mapping
|
from collections.abc import Callable
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from opentelemetry.trace import SpanKind, Status, StatusCode
|
from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer
|
||||||
from opentelemetry.util.types import AttributeValue
|
from opentelemetry.util.types import AttributeValue
|
||||||
|
|
||||||
from extensions.otel.decorators.handler import SpanHandler
|
from extensions.otel.decorators.handler import SpanHandler
|
||||||
@ -15,15 +14,15 @@ logger = logging.getLogger(__name__)
|
|||||||
class AppGenerateHandler(SpanHandler):
|
class AppGenerateHandler(SpanHandler):
|
||||||
"""Span handler for ``AppGenerateService.generate``."""
|
"""Span handler for ``AppGenerateService.generate``."""
|
||||||
|
|
||||||
def wrapper[T](
|
def wrapper[**P, R](
|
||||||
self,
|
self,
|
||||||
tracer: Any,
|
tracer: Tracer,
|
||||||
wrapped: Callable[..., T],
|
wrapped: Callable[P, R],
|
||||||
args: tuple[object, ...],
|
*args: P.args,
|
||||||
kwargs: Mapping[str, object],
|
**kwargs: P.kwargs,
|
||||||
) -> T:
|
) -> R:
|
||||||
try:
|
try:
|
||||||
arguments = self._extract_arguments(wrapped, args, kwargs)
|
arguments = self._extract_arguments(wrapped, *args, **kwargs)
|
||||||
if not arguments:
|
if not arguments:
|
||||||
return wrapped(*args, **kwargs)
|
return wrapped(*args, **kwargs)
|
||||||
|
|
||||||
|
|||||||
@ -1,8 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from collections.abc import Callable, Mapping
|
from collections.abc import Callable
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from opentelemetry.trace import SpanKind, Status, StatusCode
|
from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer
|
||||||
from opentelemetry.util.types import AttributeValue
|
from opentelemetry.util.types import AttributeValue
|
||||||
|
|
||||||
from extensions.otel.decorators.handler import SpanHandler
|
from extensions.otel.decorators.handler import SpanHandler
|
||||||
@ -14,15 +13,15 @@ logger = logging.getLogger(__name__)
|
|||||||
class WorkflowAppRunnerHandler(SpanHandler):
|
class WorkflowAppRunnerHandler(SpanHandler):
|
||||||
"""Span handler for ``WorkflowAppRunner.run``."""
|
"""Span handler for ``WorkflowAppRunner.run``."""
|
||||||
|
|
||||||
def wrapper(
|
def wrapper[**P, R](
|
||||||
self,
|
self,
|
||||||
tracer: Any,
|
tracer: Tracer,
|
||||||
wrapped: Callable[..., Any],
|
wrapped: Callable[P, R],
|
||||||
args: tuple[Any, ...],
|
*args: P.args,
|
||||||
kwargs: Mapping[str, Any],
|
**kwargs: P.kwargs,
|
||||||
) -> Any:
|
) -> R:
|
||||||
try:
|
try:
|
||||||
arguments = self._extract_arguments(wrapped, args, kwargs)
|
arguments = self._extract_arguments(wrapped, *args, **kwargs)
|
||||||
if not arguments:
|
if not arguments:
|
||||||
return wrapped(*args, **kwargs)
|
return wrapped(*args, **kwargs)
|
||||||
|
|
||||||
|
|||||||
@ -14,9 +14,15 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
import redis
|
||||||
|
from redis.cluster import RedisCluster
|
||||||
from redis.exceptions import LockNotOwnedError, RedisError
|
from redis.exceptions import LockNotOwnedError, RedisError
|
||||||
|
from redis.lock import Lock
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from extensions.ext_redis import RedisClientWrapper
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -38,21 +44,21 @@ class DbMigrationAutoRenewLock:
|
|||||||
primary error/exit code.
|
primary error/exit code.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_redis_client: Any
|
_redis_client: redis.Redis | RedisCluster | RedisClientWrapper
|
||||||
_name: str
|
_name: str
|
||||||
_ttl_seconds: float
|
_ttl_seconds: float
|
||||||
_renew_interval_seconds: float
|
_renew_interval_seconds: float
|
||||||
_log_context: str | None
|
_log_context: str | None
|
||||||
_logger: logging.Logger
|
_logger: logging.Logger
|
||||||
|
|
||||||
_lock: Any
|
_lock: Lock | None
|
||||||
_stop_event: threading.Event | None
|
_stop_event: threading.Event | None
|
||||||
_thread: threading.Thread | None
|
_thread: threading.Thread | None
|
||||||
_acquired: bool
|
_acquired: bool
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
redis_client: Any,
|
redis_client: redis.Redis | RedisCluster | RedisClientWrapper,
|
||||||
name: str,
|
name: str,
|
||||||
ttl_seconds: float = 60,
|
ttl_seconds: float = 60,
|
||||||
renew_interval_seconds: float | None = None,
|
renew_interval_seconds: float | None = None,
|
||||||
@ -127,7 +133,7 @@ class DbMigrationAutoRenewLock:
|
|||||||
)
|
)
|
||||||
self._thread.start()
|
self._thread.start()
|
||||||
|
|
||||||
def _heartbeat_loop(self, lock: Any, stop_event: threading.Event) -> None:
|
def _heartbeat_loop(self, lock: Lock, stop_event: threading.Event) -> None:
|
||||||
while not stop_event.wait(self._renew_interval_seconds):
|
while not stop_event.wait(self._renew_interval_seconds):
|
||||||
try:
|
try:
|
||||||
lock.reacquire()
|
lock.reacquire()
|
||||||
|
|||||||
@ -10,7 +10,7 @@ import uuid
|
|||||||
from collections.abc import Callable, Generator, Mapping
|
from collections.abc import Callable, Generator, Mapping
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from typing import TYPE_CHECKING, Annotated, Any, Optional, Protocol, Union, cast
|
from typing import TYPE_CHECKING, Annotated, Any, Protocol, cast
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from zoneinfo import available_timezones
|
from zoneinfo import available_timezones
|
||||||
|
|
||||||
@ -81,7 +81,7 @@ def escape_like_pattern(pattern: str) -> str:
|
|||||||
return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||||
|
|
||||||
|
|
||||||
def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None:
|
def extract_tenant_id(user: "Account | EndUser") -> str | None:
|
||||||
"""
|
"""
|
||||||
Extract tenant_id from Account or EndUser object.
|
Extract tenant_id from Account or EndUser object.
|
||||||
|
|
||||||
@ -164,7 +164,10 @@ def email(email):
|
|||||||
EmailStr = Annotated[str, AfterValidator(email)]
|
EmailStr = Annotated[str, AfterValidator(email)]
|
||||||
|
|
||||||
|
|
||||||
def uuid_value(value: Any) -> str:
|
def uuid_value(value: str | UUID) -> str:
|
||||||
|
if isinstance(value, UUID):
|
||||||
|
return str(value)
|
||||||
|
|
||||||
if value == "":
|
if value == "":
|
||||||
return str(value)
|
return str(value)
|
||||||
|
|
||||||
@ -405,7 +408,7 @@ class TokenManager:
|
|||||||
def generate_token(
|
def generate_token(
|
||||||
cls,
|
cls,
|
||||||
token_type: str,
|
token_type: str,
|
||||||
account: Optional["Account"] = None,
|
account: "Account | None" = None,
|
||||||
email: str | None = None,
|
email: str | None = None,
|
||||||
additional_data: dict | None = None,
|
additional_data: dict | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
@ -465,9 +468,7 @@ class TokenManager:
|
|||||||
return current_token
|
return current_token
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _set_current_token_for_account(
|
def _set_current_token_for_account(cls, account_id: str, token: str, token_type: str, expiry_minutes: int | float):
|
||||||
cls, account_id: str, token: str, token_type: str, expiry_minutes: Union[int, float]
|
|
||||||
):
|
|
||||||
key = cls._get_account_token_key(account_id, token_type)
|
key = cls._get_account_token_key(account_id, token_type)
|
||||||
expiry_seconds = int(expiry_minutes * 60)
|
expiry_seconds = int(expiry_minutes * 60)
|
||||||
redis_client.setex(key, expiry_seconds, token)
|
redis_client.setex(key, expiry_seconds, token)
|
||||||
|
|||||||
145
api/libs/pyrefly_type_coverage.py
Normal file
145
api/libs/pyrefly_type_coverage.py
Normal file
@ -0,0 +1,145 @@
|
|||||||
|
"""Helpers for generating type-coverage summaries from pyrefly report output."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TypedDict
|
||||||
|
|
||||||
|
|
||||||
|
class CoverageSummary(TypedDict):
|
||||||
|
n_modules: int
|
||||||
|
n_typable: int
|
||||||
|
n_typed: int
|
||||||
|
n_any: int
|
||||||
|
n_untyped: int
|
||||||
|
coverage: float
|
||||||
|
strict_coverage: float
|
||||||
|
|
||||||
|
|
||||||
|
_REQUIRED_KEYS = frozenset(CoverageSummary.__annotations__)
|
||||||
|
|
||||||
|
_EMPTY_SUMMARY: CoverageSummary = {
|
||||||
|
"n_modules": 0,
|
||||||
|
"n_typable": 0,
|
||||||
|
"n_typed": 0,
|
||||||
|
"n_any": 0,
|
||||||
|
"n_untyped": 0,
|
||||||
|
"coverage": 0.0,
|
||||||
|
"strict_coverage": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def parse_summary(report_json: str) -> CoverageSummary:
|
||||||
|
"""Extract the summary section from ``pyrefly report`` JSON output.
|
||||||
|
|
||||||
|
Returns an empty summary when *report_json* is empty or malformed so that
|
||||||
|
the CI workflow can degrade gracefully instead of crashing.
|
||||||
|
"""
|
||||||
|
if not report_json or not report_json.strip():
|
||||||
|
return _EMPTY_SUMMARY.copy()
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = json.loads(report_json)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return _EMPTY_SUMMARY.copy()
|
||||||
|
|
||||||
|
summary = data.get("summary")
|
||||||
|
if not isinstance(summary, dict) or not _REQUIRED_KEYS.issubset(summary):
|
||||||
|
return _EMPTY_SUMMARY.copy()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"n_modules": summary["n_modules"],
|
||||||
|
"n_typable": summary["n_typable"],
|
||||||
|
"n_typed": summary["n_typed"],
|
||||||
|
"n_any": summary["n_any"],
|
||||||
|
"n_untyped": summary["n_untyped"],
|
||||||
|
"coverage": summary["coverage"],
|
||||||
|
"strict_coverage": summary["strict_coverage"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def format_summary_markdown(summary: CoverageSummary) -> str:
|
||||||
|
"""Format a single coverage summary as a Markdown table."""
|
||||||
|
|
||||||
|
return (
|
||||||
|
"| Metric | Value |\n"
|
||||||
|
"| --- | ---: |\n"
|
||||||
|
f"| Modules | {summary['n_modules']} |\n"
|
||||||
|
f"| Typable symbols | {summary['n_typable']:,} |\n"
|
||||||
|
f"| Typed symbols | {summary['n_typed']:,} |\n"
|
||||||
|
f"| Untyped symbols | {summary['n_untyped']:,} |\n"
|
||||||
|
f"| Any symbols | {summary['n_any']:,} |\n"
|
||||||
|
f"| **Type coverage** | **{summary['coverage']:.2f}%** |\n"
|
||||||
|
f"| Strict coverage | {summary['strict_coverage']:.2f}% |"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def format_comparison_markdown(
|
||||||
|
base: CoverageSummary,
|
||||||
|
pr: CoverageSummary,
|
||||||
|
) -> str:
|
||||||
|
"""Format a comparison between base and PR coverage as Markdown."""
|
||||||
|
|
||||||
|
coverage_delta = pr["coverage"] - base["coverage"]
|
||||||
|
strict_delta = pr["strict_coverage"] - base["strict_coverage"]
|
||||||
|
typed_delta = pr["n_typed"] - base["n_typed"]
|
||||||
|
untyped_delta = pr["n_untyped"] - base["n_untyped"]
|
||||||
|
|
||||||
|
def _fmt_delta(value: float, fmt: str = ".2f") -> str:
|
||||||
|
sign = "+" if value > 0 else ""
|
||||||
|
return f"{sign}{value:{fmt}}"
|
||||||
|
|
||||||
|
lines = [
|
||||||
|
"| Metric | Base | PR | Delta |",
|
||||||
|
"| --- | ---: | ---: | ---: |",
|
||||||
|
(f"| **Type coverage** | {base['coverage']:.2f}% | {pr['coverage']:.2f}% | {_fmt_delta(coverage_delta)}% |"),
|
||||||
|
(
|
||||||
|
f"| Strict coverage | {base['strict_coverage']:.2f}% "
|
||||||
|
f"| {pr['strict_coverage']:.2f}% "
|
||||||
|
f"| {_fmt_delta(strict_delta)}% |"
|
||||||
|
),
|
||||||
|
(f"| Typed symbols | {base['n_typed']:,} | {pr['n_typed']:,} | {_fmt_delta(typed_delta, ',')} |"),
|
||||||
|
(f"| Untyped symbols | {base['n_untyped']:,} | {pr['n_untyped']:,} | {_fmt_delta(untyped_delta, ',')} |"),
|
||||||
|
(
|
||||||
|
f"| Modules | {base['n_modules']} "
|
||||||
|
f"| {pr['n_modules']} "
|
||||||
|
f"| {_fmt_delta(pr['n_modules'] - base['n_modules'], ',')} |"
|
||||||
|
),
|
||||||
|
]
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
"""Read pyrefly report JSON from stdin and print a Markdown summary.
|
||||||
|
|
||||||
|
Accepts an optional ``--base <file>`` argument. When provided, the output
|
||||||
|
includes a base-vs-PR comparison table.
|
||||||
|
"""
|
||||||
|
|
||||||
|
args = sys.argv[1:]
|
||||||
|
|
||||||
|
base_file: str | None = None
|
||||||
|
if "--base" in args:
|
||||||
|
idx = args.index("--base")
|
||||||
|
if idx + 1 >= len(args):
|
||||||
|
sys.stderr.write("error: --base requires a file path\n")
|
||||||
|
return 1
|
||||||
|
base_file = args[idx + 1]
|
||||||
|
|
||||||
|
pr_report = sys.stdin.read()
|
||||||
|
pr_summary = parse_summary(pr_report)
|
||||||
|
|
||||||
|
if base_file is not None:
|
||||||
|
base_text = Path(base_file).read_text() if Path(base_file).exists() else ""
|
||||||
|
base_summary = parse_summary(base_text)
|
||||||
|
sys.stdout.write(format_comparison_markdown(base_summary, pr_summary) + "\n")
|
||||||
|
else:
|
||||||
|
sys.stdout.write(format_summary_markdown(pr_summary) + "\n")
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
@ -24,6 +24,8 @@ class TypeBase(MappedAsDataclass, DeclarativeBase):
|
|||||||
|
|
||||||
|
|
||||||
class DefaultFieldsMixin:
|
class DefaultFieldsMixin:
|
||||||
|
"""Mixin for models that inherit from Base (non-dataclass)."""
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(
|
id: Mapped[str] = mapped_column(
|
||||||
StringUUID,
|
StringUUID,
|
||||||
primary_key=True,
|
primary_key=True,
|
||||||
@ -53,6 +55,42 @@ class DefaultFieldsMixin:
|
|||||||
return f"<{self.__class__.__name__}(id={self.id})>"
|
return f"<{self.__class__.__name__}(id={self.id})>"
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultFieldsDCMixin(MappedAsDataclass):
|
||||||
|
"""Mixin for models that inherit from TypeBase (MappedAsDataclass)."""
|
||||||
|
|
||||||
|
__abstract__ = True
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
StringUUID,
|
||||||
|
primary_key=True,
|
||||||
|
insert_default=lambda: str(uuidv7()),
|
||||||
|
default_factory=lambda: str(uuidv7()),
|
||||||
|
init=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime,
|
||||||
|
nullable=False,
|
||||||
|
insert_default=naive_utc_now,
|
||||||
|
default_factory=naive_utc_now,
|
||||||
|
init=False,
|
||||||
|
server_default=func.current_timestamp(),
|
||||||
|
)
|
||||||
|
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime,
|
||||||
|
nullable=False,
|
||||||
|
insert_default=naive_utc_now,
|
||||||
|
default_factory=naive_utc_now,
|
||||||
|
init=False,
|
||||||
|
server_default=func.current_timestamp(),
|
||||||
|
onupdate=func.current_timestamp(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<{self.__class__.__name__}(id={self.id})>"
|
||||||
|
|
||||||
|
|
||||||
def gen_uuidv4_string() -> str:
|
def gen_uuidv4_string() -> str:
|
||||||
"""gen_uuidv4_string generate a UUIDv4 string.
|
"""gen_uuidv4_string generate a UUIDv4 string.
|
||||||
|
|
||||||
|
|||||||
@ -913,11 +913,7 @@ class TrialApp(TypeBase):
|
|||||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
sa.DateTime,
|
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||||
nullable=False,
|
|
||||||
insert_default=func.current_timestamp(),
|
|
||||||
server_default=func.current_timestamp(),
|
|
||||||
init=False,
|
|
||||||
)
|
)
|
||||||
trial_limit: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=3)
|
trial_limit: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=3)
|
||||||
|
|
||||||
@ -941,11 +937,7 @@ class AccountTrialAppRecord(TypeBase):
|
|||||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
|
count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
sa.DateTime,
|
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||||
nullable=False,
|
|
||||||
insert_default=func.current_timestamp(),
|
|
||||||
server_default=func.current_timestamp(),
|
|
||||||
init=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import logging
|
|||||||
from collections.abc import Generator, Mapping, Sequence
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union, cast
|
from typing import TYPE_CHECKING, Any, Optional, TypedDict, cast
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
@ -121,7 +121,7 @@ class WorkflowType(StrEnum):
|
|||||||
raise ValueError(f"invalid workflow type value {value}")
|
raise ValueError(f"invalid workflow type value {value}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType":
|
def from_app_mode(cls, app_mode: "str | AppMode") -> "WorkflowType":
|
||||||
"""
|
"""
|
||||||
Get workflow type from app mode.
|
Get workflow type from app mode.
|
||||||
|
|
||||||
@ -1051,7 +1051,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
|||||||
)
|
)
|
||||||
return extras
|
return extras
|
||||||
|
|
||||||
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]:
|
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> "WorkflowNodeExecutionOffload | None":
|
||||||
return next(iter([i for i in self.offload_data if i.type_ == type_]), None)
|
return next(iter([i for i in self.offload_data if i.type_ == type_]), None)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@ -9,7 +9,8 @@ from typing import Any, TypedDict, cast
|
|||||||
|
|
||||||
from pydantic import BaseModel, TypeAdapter
|
from pydantic import BaseModel, TypeAdapter
|
||||||
from sqlalchemy import delete, func, select, update
|
from sqlalchemy import delete, func, select, update
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
|
||||||
|
from core.db.session_factory import session_factory
|
||||||
|
|
||||||
|
|
||||||
class InvitationData(TypedDict):
|
class InvitationData(TypedDict):
|
||||||
@ -800,19 +801,19 @@ class AccountService:
|
|||||||
return token
|
return token
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_account_by_email_with_case_fallback(email: str, session: Session | None = None) -> Account | None:
|
def get_account_by_email_with_case_fallback(email: str) -> Account | None:
|
||||||
"""
|
"""
|
||||||
Retrieve an account by email and fall back to the lowercase email if the original lookup fails.
|
Retrieve an account by email and fall back to the lowercase email if the original lookup fails.
|
||||||
|
|
||||||
This keeps backward compatibility for older records that stored uppercase emails while the
|
This keeps backward compatibility for older records that stored uppercase emails while the
|
||||||
rest of the system gradually normalizes new inputs.
|
rest of the system gradually normalizes new inputs.
|
||||||
"""
|
"""
|
||||||
query_session = session or db.session
|
with session_factory.create_session() as session:
|
||||||
account = query_session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
|
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
|
||||||
if account or email == email.lower():
|
if account or email == email.lower():
|
||||||
return account
|
return account
|
||||||
|
|
||||||
return query_session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()
|
return session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
|
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
|
||||||
@ -1516,8 +1517,7 @@ class RegisterService:
|
|||||||
|
|
||||||
check_workspace_member_invite_permission(tenant.id)
|
check_workspace_member_invite_permission(tenant.id)
|
||||||
|
|
||||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
|
||||||
|
|
||||||
if not account:
|
if not account:
|
||||||
TenantService.check_member_permission(tenant, inviter, None, "add")
|
TenantService.check_member_permission(tenant, inviter, None, "add")
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import logging
|
|||||||
import threading
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Callable, Generator, Mapping
|
from collections.abc import Callable, Generator, Mapping
|
||||||
from typing import TYPE_CHECKING, Any, Union
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||||
@ -88,7 +88,7 @@ class AppGenerateService:
|
|||||||
def generate(
|
def generate(
|
||||||
cls,
|
cls,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
args: Mapping[str, Any],
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
streaming: bool = True,
|
streaming: bool = True,
|
||||||
@ -356,11 +356,11 @@ class AppGenerateService:
|
|||||||
def generate_more_like_this(
|
def generate_more_like_this(
|
||||||
cls,
|
cls,
|
||||||
app_model: App,
|
app_model: App,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
message_id: str,
|
message_id: str,
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
streaming: bool = True,
|
streaming: bool = True,
|
||||||
) -> Union[Mapping, Generator]:
|
) -> Mapping | Generator:
|
||||||
"""
|
"""
|
||||||
Generate more like this
|
Generate more like this
|
||||||
:param app_model: app model
|
:param app_model: app model
|
||||||
|
|||||||
@ -7,7 +7,7 @@ with support for different subscription tiers, rate limiting, and execution trac
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import Any, Union
|
from typing import Any
|
||||||
|
|
||||||
from celery.result import AsyncResult
|
from celery.result import AsyncResult
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
@ -50,7 +50,7 @@ class AsyncWorkflowService:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def trigger_workflow_async(
|
def trigger_workflow_async(
|
||||||
cls, session: Session, user: Union[Account, EndUser], trigger_data: TriggerData
|
cls, session: Session, user: Account | EndUser, trigger_data: TriggerData
|
||||||
) -> AsyncTriggerResponse:
|
) -> AsyncTriggerResponse:
|
||||||
"""
|
"""
|
||||||
Universal entry point for async workflow execution - THIS METHOD WILL NOT BLOCK
|
Universal entry point for async workflow execution - THIS METHOD WILL NOT BLOCK
|
||||||
@ -177,7 +177,7 @@ class AsyncWorkflowService:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def reinvoke_trigger(
|
def reinvoke_trigger(
|
||||||
cls, session: Session, user: Union[Account, EndUser], workflow_trigger_log_id: str
|
cls, session: Session, user: Account | EndUser, workflow_trigger_log_id: str
|
||||||
) -> AsyncTriggerResponse:
|
) -> AsyncTriggerResponse:
|
||||||
"""
|
"""
|
||||||
Re-invoke a previously failed or rate-limited trigger - THIS METHOD WILL NOT BLOCK
|
Re-invoke a previously failed or rate-limited trigger - THIS METHOD WILL NOT BLOCK
|
||||||
|
|||||||
@ -2822,6 +2822,10 @@ class DocumentService:
|
|||||||
|
|
||||||
knowledge_config.process_rule.rules.pre_processing_rules = list(unique_pre_processing_rule_dicts.values())
|
knowledge_config.process_rule.rules.pre_processing_rules = list(unique_pre_processing_rule_dicts.values())
|
||||||
|
|
||||||
|
if knowledge_config.process_rule.mode == ProcessRuleMode.HIERARCHICAL:
|
||||||
|
if not knowledge_config.process_rule.rules.parent_mode:
|
||||||
|
knowledge_config.process_rule.rules.parent_mode = "paragraph"
|
||||||
|
|
||||||
if not knowledge_config.process_rule.rules.segmentation:
|
if not knowledge_config.process_rule.rules.segmentation:
|
||||||
raise ValueError("Process rule segmentation is required")
|
raise ValueError("Process rule segmentation is required")
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Any, Union, cast
|
from typing import Any, cast
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
@ -148,18 +148,23 @@ class ExternalDatasetService:
|
|||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]:
|
def external_knowledge_api_use_check(external_knowledge_api_id: str, tenant_id: str) -> tuple[bool, int]:
|
||||||
|
"""
|
||||||
|
Return usage for an external knowledge API within a single tenant.
|
||||||
|
|
||||||
|
The caller already scopes access by tenant, so this query must do the
|
||||||
|
same; otherwise the endpoint becomes a cross-tenant UUID oracle.
|
||||||
|
"""
|
||||||
count = (
|
count = (
|
||||||
db.session.scalar(
|
db.session.scalar(
|
||||||
select(func.count(ExternalKnowledgeBindings.id)).where(
|
select(func.count(ExternalKnowledgeBindings.id)).where(
|
||||||
ExternalKnowledgeBindings.external_knowledge_api_id == external_knowledge_api_id
|
ExternalKnowledgeBindings.external_knowledge_api_id == external_knowledge_api_id,
|
||||||
|
ExternalKnowledgeBindings.tenant_id == tenant_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
or 0
|
or 0
|
||||||
)
|
)
|
||||||
if count > 0:
|
return count > 0, count
|
||||||
return True, count
|
|
||||||
return False, 0
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
|
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
|
||||||
@ -190,9 +195,7 @@ class ExternalDatasetService:
|
|||||||
raise ValueError(f"{parameter.get('name')} is required")
|
raise ValueError(f"{parameter.get('name')} is required")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def process_external_api(
|
def process_external_api(settings: ExternalKnowledgeApiSetting, files: dict[str, Any] | None) -> httpx.Response:
|
||||||
settings: ExternalKnowledgeApiSetting, files: Union[None, dict[str, Any]]
|
|
||||||
) -> httpx.Response:
|
|
||||||
"""
|
"""
|
||||||
do http request depending on api bundle
|
do http request depending on api bundle
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -5,7 +5,7 @@ import uuid
|
|||||||
from collections.abc import Iterator, Sequence
|
from collections.abc import Iterator, Sequence
|
||||||
from contextlib import contextmanager, suppress
|
from contextlib import contextmanager, suppress
|
||||||
from tempfile import NamedTemporaryFile
|
from tempfile import NamedTemporaryFile
|
||||||
from typing import Literal, Union
|
from typing import Literal
|
||||||
from zipfile import ZIP_DEFLATED, ZipFile
|
from zipfile import ZIP_DEFLATED, ZipFile
|
||||||
|
|
||||||
from graphon.file import helpers as file_helpers
|
from graphon.file import helpers as file_helpers
|
||||||
@ -52,7 +52,7 @@ class FileService:
|
|||||||
filename: str,
|
filename: str,
|
||||||
content: bytes,
|
content: bytes,
|
||||||
mimetype: str,
|
mimetype: str,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
source: Literal["datasets"] | None = None,
|
source: Literal["datasets"] | None = None,
|
||||||
source_url: str = "",
|
source_url: str = "",
|
||||||
) -> UploadFile:
|
) -> UploadFile:
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, TypedDict, Union
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
from graphon.model_runtime.entities.model_entities import ModelType
|
from graphon.model_runtime.entities.model_entities import ModelType
|
||||||
from graphon.model_runtime.entities.provider_entities import (
|
from graphon.model_runtime.entities.provider_entities import (
|
||||||
@ -626,7 +626,7 @@ class ModelLoadBalancingService:
|
|||||||
|
|
||||||
def _get_credential_schema(
|
def _get_credential_schema(
|
||||||
self, provider_configuration: ProviderConfiguration
|
self, provider_configuration: ProviderConfiguration
|
||||||
) -> Union[ModelCredentialSchema, ProviderCredentialSchema]:
|
) -> ModelCredentialSchema | ProviderCredentialSchema:
|
||||||
"""Get form schemas."""
|
"""Get form schemas."""
|
||||||
if provider_configuration.provider.model_credential_schema:
|
if provider_configuration.provider.model_credential_schema:
|
||||||
return provider_configuration.provider.model_credential_schema
|
return provider_configuration.provider.model_credential_schema
|
||||||
|
|||||||
@ -1,14 +1,13 @@
|
|||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
|
|
||||||
from extensions.ext_database import db
|
from core.db.session_factory import session_factory
|
||||||
from models.account import TenantPluginPermission
|
from models.account import TenantPluginPermission
|
||||||
|
|
||||||
|
|
||||||
class PluginPermissionService:
|
class PluginPermissionService:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_permission(tenant_id: str) -> TenantPluginPermission | None:
|
def get_permission(tenant_id: str) -> TenantPluginPermission | None:
|
||||||
with sessionmaker(bind=db.engine).begin() as session:
|
with session_factory.create_session() as session:
|
||||||
return session.scalar(
|
return session.scalar(
|
||||||
select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1)
|
select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1)
|
||||||
)
|
)
|
||||||
@ -19,7 +18,7 @@ class PluginPermissionService:
|
|||||||
install_permission: TenantPluginPermission.InstallPermission,
|
install_permission: TenantPluginPermission.InstallPermission,
|
||||||
debug_permission: TenantPluginPermission.DebugPermission,
|
debug_permission: TenantPluginPermission.DebugPermission,
|
||||||
):
|
):
|
||||||
with sessionmaker(bind=db.engine).begin() as session:
|
with session_factory.create_session() as session, session.begin():
|
||||||
permission = session.scalar(
|
permission = session.scalar(
|
||||||
select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1)
|
select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1)
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any, Union
|
from typing import Any
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
|
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
|
||||||
@ -17,7 +17,7 @@ class PipelineGenerateService:
|
|||||||
def generate(
|
def generate(
|
||||||
cls,
|
cls,
|
||||||
pipeline: Pipeline,
|
pipeline: Pipeline,
|
||||||
user: Union[Account, EndUser],
|
user: Account | EndUser,
|
||||||
args: Mapping[str, Any],
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
streaming: bool = True,
|
streaming: bool = True,
|
||||||
|
|||||||
@ -5,7 +5,7 @@ import threading
|
|||||||
import time
|
import time
|
||||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import Any, Union, cast
|
from typing import Any, cast
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
@ -1387,7 +1387,7 @@ class RagPipelineService:
|
|||||||
"uninstalled_recommended_plugins": uninstalled_plugin_list,
|
"uninstalled_recommended_plugins": uninstalled_plugin_list,
|
||||||
}
|
}
|
||||||
|
|
||||||
def retry_error_document(self, dataset: Dataset, document: Document, user: Union[Account, EndUser]):
|
def retry_error_document(self, dataset: Dataset, document: Document, user: Account | EndUser):
|
||||||
"""
|
"""
|
||||||
Retry error document
|
Retry error document
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any, Union
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import TypeAdapter, ValidationError
|
from pydantic import TypeAdapter, ValidationError
|
||||||
from yarl import URL
|
from yarl import URL
|
||||||
@ -69,7 +69,7 @@ class ToolTransformService:
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity, PluginDatasourceProviderEntity]):
|
def repack_provider(tenant_id: str, provider: dict | ToolProviderApiEntity | PluginDatasourceProviderEntity):
|
||||||
"""
|
"""
|
||||||
repack provider
|
repack provider
|
||||||
|
|
||||||
|
|||||||
@ -7,15 +7,16 @@ with appropriate retry policies and error handling.
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import Any
|
from typing import Any, NotRequired
|
||||||
|
|
||||||
from celery import shared_task
|
from celery import shared_task
|
||||||
from graphon.runtime import GraphRuntimeState
|
from graphon.runtime import GraphRuntimeState
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
|
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext
|
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext
|
||||||
from core.app.layers.timeslice_layer import TimeSliceLayer
|
from core.app.layers.timeslice_layer import TimeSliceLayer
|
||||||
@ -42,6 +43,13 @@ from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue, AsyncWorkf
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowGeneratorArgsDict(TypedDict):
|
||||||
|
inputs: dict[str, Any]
|
||||||
|
files: list[Any]
|
||||||
|
_skip_prepare_user_inputs: bool
|
||||||
|
workflow_id: NotRequired[str]
|
||||||
|
|
||||||
|
|
||||||
@shared_task(queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE)
|
@shared_task(queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE)
|
||||||
def execute_workflow_professional(task_data_dict: dict[str, Any]):
|
def execute_workflow_professional(task_data_dict: dict[str, Any]):
|
||||||
"""Execute workflow for professional tier with highest priority"""
|
"""Execute workflow for professional tier with highest priority"""
|
||||||
@ -90,15 +98,13 @@ def execute_workflow_sandbox(task_data_dict: dict[str, Any]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _build_generator_args(trigger_data: TriggerData) -> dict[str, Any]:
|
def _build_generator_args(trigger_data: TriggerData) -> WorkflowGeneratorArgsDict:
|
||||||
"""Build args passed into WorkflowAppGenerator.generate for Celery executions."""
|
"""Build args passed into WorkflowAppGenerator.generate for Celery executions."""
|
||||||
|
return {
|
||||||
args: dict[str, Any] = {
|
|
||||||
"inputs": dict(trigger_data.inputs),
|
"inputs": dict(trigger_data.inputs),
|
||||||
"files": list(trigger_data.files),
|
"files": list(trigger_data.files),
|
||||||
SKIP_PREPARE_USER_INPUTS_KEY: True,
|
"_skip_prepare_user_inputs": True,
|
||||||
}
|
}
|
||||||
return args
|
|
||||||
|
|
||||||
|
|
||||||
def _execute_workflow_common(
|
def _execute_workflow_common(
|
||||||
|
|||||||
@ -158,7 +158,10 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase():
|
|||||||
second_result.scalar_one_or_none.return_value = expected_account
|
second_result.scalar_one_or_none.return_value = expected_account
|
||||||
mock_session.execute.side_effect = [first_result, second_result]
|
mock_session.execute.side_effect = [first_result, second_result]
|
||||||
|
|
||||||
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
|
with patch("services.account_service.session_factory") as mock_factory:
|
||||||
|
mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||||
|
mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||||
|
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com")
|
||||||
|
|
||||||
assert result is expected_account
|
assert result is expected_account
|
||||||
assert mock_session.execute.call_count == 2
|
assert mock_session.execute.call_count == 2
|
||||||
|
|||||||
@ -113,12 +113,14 @@ class TestForgotPasswordCheckApi:
|
|||||||
class TestForgotPasswordResetApi:
|
class TestForgotPasswordResetApi:
|
||||||
@patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account")
|
@patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account")
|
||||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||||
|
@patch("controllers.console.auth.forgot_password.db")
|
||||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||||
def test_reset_fetches_account_with_original_email(
|
def test_reset_fetches_account_with_original_email(
|
||||||
self,
|
self,
|
||||||
mock_get_reset_data,
|
mock_get_reset_data,
|
||||||
mock_revoke_token,
|
mock_revoke_token,
|
||||||
|
mock_db,
|
||||||
mock_get_account,
|
mock_get_account,
|
||||||
mock_update_account,
|
mock_update_account,
|
||||||
app,
|
app,
|
||||||
@ -126,6 +128,7 @@ class TestForgotPasswordResetApi:
|
|||||||
mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"}
|
mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"}
|
||||||
mock_account = MagicMock()
|
mock_account = MagicMock()
|
||||||
mock_get_account.return_value = mock_account
|
mock_get_account.return_value = mock_account
|
||||||
|
mock_db.session.merge.return_value = mock_account
|
||||||
|
|
||||||
wraps_features = SimpleNamespace(enable_email_password_login=True)
|
wraps_features = SimpleNamespace(enable_email_password_login=True)
|
||||||
with (
|
with (
|
||||||
@ -161,7 +164,10 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase():
|
|||||||
second_result.scalar_one_or_none.return_value = expected_account
|
second_result.scalar_one_or_none.return_value = expected_account
|
||||||
mock_session.execute.side_effect = [first_result, second_result]
|
mock_session.execute.side_effect = [first_result, second_result]
|
||||||
|
|
||||||
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session)
|
with patch("services.account_service.session_factory") as mock_factory:
|
||||||
|
mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||||
|
mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||||
|
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com")
|
||||||
|
|
||||||
assert result is expected_account
|
assert result is expected_account
|
||||||
assert mock_session.execute.call_count == 2
|
assert mock_session.execute.call_count == 2
|
||||||
|
|||||||
@ -437,7 +437,10 @@ class TestAccountGeneration:
|
|||||||
second_result.scalar_one_or_none.return_value = expected_account
|
second_result.scalar_one_or_none.return_value = expected_account
|
||||||
mock_session.execute.side_effect = [first_result, second_result]
|
mock_session.execute.side_effect = [first_result, second_result]
|
||||||
|
|
||||||
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
|
with patch("services.account_service.session_factory") as mock_factory:
|
||||||
|
mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||||
|
mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||||
|
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com")
|
||||||
|
|
||||||
assert result is expected_account
|
assert result is expected_account
|
||||||
assert mock_session.execute.call_count == 2
|
assert mock_session.execute.call_count == 2
|
||||||
|
|||||||
@ -335,10 +335,12 @@ class TestForgotPasswordResetApi:
|
|||||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||||
|
@patch("controllers.console.auth.forgot_password.db")
|
||||||
@patch("controllers.console.auth.forgot_password.TenantService.get_join_tenants")
|
@patch("controllers.console.auth.forgot_password.TenantService.get_join_tenants")
|
||||||
def test_reset_password_success(
|
def test_reset_password_success(
|
||||||
self,
|
self,
|
||||||
mock_get_tenants,
|
mock_get_tenants,
|
||||||
|
mock_db,
|
||||||
mock_get_account,
|
mock_get_account,
|
||||||
mock_revoke_token,
|
mock_revoke_token,
|
||||||
mock_get_data,
|
mock_get_data,
|
||||||
@ -356,6 +358,7 @@ class TestForgotPasswordResetApi:
|
|||||||
# Arrange
|
# Arrange
|
||||||
mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
|
mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
|
||||||
mock_get_account.return_value = mock_account
|
mock_get_account.return_value = mock_account
|
||||||
|
mock_db.session.merge.return_value = mock_account
|
||||||
mock_get_tenants.return_value = [MagicMock()]
|
mock_get_tenants.return_value = [MagicMock()]
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
|||||||
@ -37,10 +37,8 @@ class TestForgotPasswordSendEmailApi:
|
|||||||
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||||
@patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
|
@patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
|
||||||
@patch("controllers.web.forgot_password.extract_remote_ip", return_value="127.0.0.1")
|
@patch("controllers.web.forgot_password.extract_remote_ip", return_value="127.0.0.1")
|
||||||
@patch("controllers.web.forgot_password.sessionmaker")
|
|
||||||
def test_should_normalize_email_before_sending(
|
def test_should_normalize_email_before_sending(
|
||||||
self,
|
self,
|
||||||
mock_session_cls,
|
|
||||||
mock_extract_ip,
|
mock_extract_ip,
|
||||||
mock_rate_limit,
|
mock_rate_limit,
|
||||||
mock_get_account,
|
mock_get_account,
|
||||||
@ -50,19 +48,16 @@ class TestForgotPasswordSendEmailApi:
|
|||||||
mock_account = MagicMock()
|
mock_account = MagicMock()
|
||||||
mock_get_account.return_value = mock_account
|
mock_get_account.return_value = mock_account
|
||||||
mock_send_mail.return_value = "token-123"
|
mock_send_mail.return_value = "token-123"
|
||||||
mock_session = MagicMock()
|
|
||||||
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
|
|
||||||
|
|
||||||
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
|
with app.test_request_context(
|
||||||
with app.test_request_context(
|
"/web/forgot-password",
|
||||||
"/web/forgot-password",
|
method="POST",
|
||||||
method="POST",
|
json={"email": "User@Example.com", "language": "zh-Hans"},
|
||||||
json={"email": "User@Example.com", "language": "zh-Hans"},
|
):
|
||||||
):
|
response = ForgotPasswordSendEmailApi().post()
|
||||||
response = ForgotPasswordSendEmailApi().post()
|
|
||||||
|
|
||||||
assert response == {"result": "success", "data": "token-123"}
|
assert response == {"result": "success", "data": "token-123"}
|
||||||
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
|
mock_get_account.assert_called_once_with("User@Example.com")
|
||||||
mock_send_mail.assert_called_once_with(account=mock_account, email="user@example.com", language="zh-Hans")
|
mock_send_mail.assert_called_once_with(account=mock_account, email="user@example.com", language="zh-Hans")
|
||||||
mock_extract_ip.assert_called_once()
|
mock_extract_ip.assert_called_once()
|
||||||
mock_rate_limit.assert_called_once_with("127.0.0.1")
|
mock_rate_limit.assert_called_once_with("127.0.0.1")
|
||||||
@ -153,14 +148,14 @@ class TestForgotPasswordResetApi:
|
|||||||
|
|
||||||
@patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account")
|
@patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account")
|
||||||
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||||
@patch("controllers.web.forgot_password.sessionmaker")
|
@patch("controllers.web.forgot_password.db")
|
||||||
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||||
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
||||||
def test_should_fetch_account_with_fallback(
|
def test_should_fetch_account_with_fallback(
|
||||||
self,
|
self,
|
||||||
mock_get_reset_data,
|
mock_get_reset_data,
|
||||||
mock_revoke_token,
|
mock_revoke_token,
|
||||||
mock_session_cls,
|
mock_db,
|
||||||
mock_get_account,
|
mock_get_account,
|
||||||
mock_update_account,
|
mock_update_account,
|
||||||
app,
|
app,
|
||||||
@ -168,29 +163,27 @@ class TestForgotPasswordResetApi:
|
|||||||
mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com", "code": "1234"}
|
mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com", "code": "1234"}
|
||||||
mock_account = MagicMock()
|
mock_account = MagicMock()
|
||||||
mock_get_account.return_value = mock_account
|
mock_get_account.return_value = mock_account
|
||||||
mock_session = MagicMock()
|
mock_db.session.merge.return_value = mock_account
|
||||||
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
|
|
||||||
|
|
||||||
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
|
with app.test_request_context(
|
||||||
with app.test_request_context(
|
"/web/forgot-password/resets",
|
||||||
"/web/forgot-password/resets",
|
method="POST",
|
||||||
method="POST",
|
json={
|
||||||
json={
|
"token": "token-123",
|
||||||
"token": "token-123",
|
"new_password": "ValidPass123!",
|
||||||
"new_password": "ValidPass123!",
|
"password_confirm": "ValidPass123!",
|
||||||
"password_confirm": "ValidPass123!",
|
},
|
||||||
},
|
):
|
||||||
):
|
response = ForgotPasswordResetApi().post()
|
||||||
response = ForgotPasswordResetApi().post()
|
|
||||||
|
|
||||||
assert response == {"result": "success"}
|
assert response == {"result": "success"}
|
||||||
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
|
mock_get_account.assert_called_once_with("User@Example.com")
|
||||||
mock_update_account.assert_called_once()
|
mock_update_account.assert_called_once()
|
||||||
mock_revoke_token.assert_called_once_with("token-123")
|
mock_revoke_token.assert_called_once_with("token-123")
|
||||||
|
|
||||||
@patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value")
|
@patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value")
|
||||||
@patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef")
|
@patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef")
|
||||||
@patch("controllers.web.forgot_password.sessionmaker")
|
@patch("controllers.web.forgot_password.db")
|
||||||
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||||
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
||||||
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||||
@ -199,7 +192,7 @@ class TestForgotPasswordResetApi:
|
|||||||
mock_get_account,
|
mock_get_account,
|
||||||
mock_get_reset_data,
|
mock_get_reset_data,
|
||||||
mock_revoke_token,
|
mock_revoke_token,
|
||||||
mock_session_cls,
|
mock_db,
|
||||||
mock_token_bytes,
|
mock_token_bytes,
|
||||||
mock_hash_password,
|
mock_hash_password,
|
||||||
app,
|
app,
|
||||||
@ -207,20 +200,18 @@ class TestForgotPasswordResetApi:
|
|||||||
mock_get_reset_data.return_value = {"phase": "reset", "email": "user@example.com"}
|
mock_get_reset_data.return_value = {"phase": "reset", "email": "user@example.com"}
|
||||||
account = MagicMock()
|
account = MagicMock()
|
||||||
mock_get_account.return_value = account
|
mock_get_account.return_value = account
|
||||||
mock_session = MagicMock()
|
mock_db.session.merge.return_value = account
|
||||||
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
|
|
||||||
|
|
||||||
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
|
with app.test_request_context(
|
||||||
with app.test_request_context(
|
"/web/forgot-password/resets",
|
||||||
"/web/forgot-password/resets",
|
method="POST",
|
||||||
method="POST",
|
json={
|
||||||
json={
|
"token": "reset-token",
|
||||||
"token": "reset-token",
|
"new_password": "StrongPass123!",
|
||||||
"new_password": "StrongPass123!",
|
"password_confirm": "StrongPass123!",
|
||||||
"password_confirm": "StrongPass123!",
|
},
|
||||||
},
|
):
|
||||||
):
|
response = ForgotPasswordResetApi().post()
|
||||||
response = ForgotPasswordResetApi().post()
|
|
||||||
|
|
||||||
assert response == {"result": "success"}
|
assert response == {"result": "success"}
|
||||||
mock_get_reset_data.assert_called_once_with("reset-token")
|
mock_get_reset_data.assert_called_once_with("reset-token")
|
||||||
|
|||||||
@ -1,239 +1,193 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
from unittest.mock import ANY, MagicMock, patch
|
from unittest.mock import ANY, MagicMock, patch
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import Document
|
||||||
from models.dataset import Dataset
|
from models.dataset import Dataset, DatasetQuery
|
||||||
from services.hit_testing_service import HitTestingService
|
from services.hit_testing_service import HitTestingService
|
||||||
|
|
||||||
|
|
||||||
class TestHitTestingService:
|
def _create_dataset(db_session: Session, *, provider: str = "vendor", **kwargs: Any) -> Dataset:
|
||||||
"""Test suite for HitTestingService"""
|
tenant_id = str(uuid4())
|
||||||
|
created_by = str(uuid4())
|
||||||
|
ds = Dataset(
|
||||||
|
tenant_id=kwargs.get("tenant_id", tenant_id),
|
||||||
|
name=kwargs.get("name", "test-dataset"),
|
||||||
|
created_by=kwargs.get("created_by", created_by),
|
||||||
|
provider=provider,
|
||||||
|
)
|
||||||
|
db_session.add(ds)
|
||||||
|
db_session.commit()
|
||||||
|
db_session.refresh(ds)
|
||||||
|
return ds
|
||||||
|
|
||||||
# ===== Utility Method Tests =====
|
|
||||||
|
class TestHitTestingService:
|
||||||
|
# ── Utility methods (pure logic, no DB) ────────────────────────────
|
||||||
|
|
||||||
def test_escape_query_for_search_should_escape_double_quotes(self):
|
def test_escape_query_for_search_should_escape_double_quotes(self):
|
||||||
"""Test that escape_query_for_search escapes double quotes correctly"""
|
|
||||||
# Arrange
|
|
||||||
query = 'test "query" with quotes'
|
query = 'test "query" with quotes'
|
||||||
expected = 'test \\"query\\" with quotes'
|
|
||||||
|
|
||||||
# Act
|
|
||||||
result = HitTestingService.escape_query_for_search(query)
|
result = HitTestingService.escape_query_for_search(query)
|
||||||
|
assert result == 'test \\"query\\" with quotes'
|
||||||
# Assert
|
|
||||||
assert result == expected
|
|
||||||
|
|
||||||
def test_hit_testing_args_check_should_pass_with_valid_query(self):
|
def test_hit_testing_args_check_should_pass_with_valid_query(self):
|
||||||
"""Test that hit_testing_args_check passes with a valid query"""
|
HitTestingService.hit_testing_args_check({"query": "valid query"})
|
||||||
# Arrange
|
|
||||||
args = {"query": "valid query"}
|
|
||||||
|
|
||||||
# Act & Assert (should not raise)
|
|
||||||
HitTestingService.hit_testing_args_check(args)
|
|
||||||
|
|
||||||
def test_hit_testing_args_check_should_pass_with_valid_attachments(self):
|
def test_hit_testing_args_check_should_pass_with_valid_attachments(self):
|
||||||
"""Test that hit_testing_args_check passes with valid attachment_ids"""
|
HitTestingService.hit_testing_args_check({"attachment_ids": ["id1", "id2"]})
|
||||||
# Arrange
|
|
||||||
args = {"attachment_ids": ["id1", "id2"]}
|
|
||||||
|
|
||||||
# Act & Assert (should not raise)
|
|
||||||
HitTestingService.hit_testing_args_check(args)
|
|
||||||
|
|
||||||
def test_hit_testing_args_check_should_raise_error_when_no_query_or_attachments(self):
|
def test_hit_testing_args_check_should_raise_error_when_no_query_or_attachments(self):
|
||||||
"""Test that hit_testing_args_check raises ValueError if both query and attachment_ids are missing"""
|
with pytest.raises(ValueError, match="Query or attachment_ids is required"):
|
||||||
# Arrange
|
HitTestingService.hit_testing_args_check({})
|
||||||
args = {}
|
|
||||||
|
|
||||||
# Act & Assert
|
|
||||||
with pytest.raises(ValueError) as exc_info:
|
|
||||||
HitTestingService.hit_testing_args_check(args)
|
|
||||||
assert "Query or attachment_ids is required" in str(exc_info.value)
|
|
||||||
|
|
||||||
def test_hit_testing_args_check_should_raise_error_when_query_too_long(self):
|
def test_hit_testing_args_check_should_raise_error_when_query_too_long(self):
|
||||||
"""Test that hit_testing_args_check raises ValueError if query exceeds 250 characters"""
|
with pytest.raises(ValueError, match="Query cannot exceed 250 characters"):
|
||||||
# Arrange
|
HitTestingService.hit_testing_args_check({"query": "a" * 251})
|
||||||
args = {"query": "a" * 251}
|
|
||||||
|
|
||||||
# Act & Assert
|
|
||||||
with pytest.raises(ValueError) as exc_info:
|
|
||||||
HitTestingService.hit_testing_args_check(args)
|
|
||||||
assert "Query cannot exceed 250 characters" in str(exc_info.value)
|
|
||||||
|
|
||||||
def test_hit_testing_args_check_should_raise_error_when_attachments_not_list(self):
|
def test_hit_testing_args_check_should_raise_error_when_attachments_not_list(self):
|
||||||
"""Test that hit_testing_args_check raises ValueError if attachment_ids is not a list"""
|
with pytest.raises(ValueError, match="Attachment_ids must be a list"):
|
||||||
# Arrange
|
HitTestingService.hit_testing_args_check({"attachment_ids": "not a list"})
|
||||||
args = {"attachment_ids": "not a list"}
|
|
||||||
|
|
||||||
# Act & Assert
|
# ── Response formatting ────────────────────────────────────────────
|
||||||
with pytest.raises(ValueError) as exc_info:
|
|
||||||
HitTestingService.hit_testing_args_check(args)
|
|
||||||
assert "Attachment_ids must be a list" in str(exc_info.value)
|
|
||||||
|
|
||||||
# ===== Response Formatting Tests =====
|
|
||||||
|
|
||||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.format_retrieval_documents")
|
@patch("core.rag.datasource.retrieval_service.RetrievalService.format_retrieval_documents")
|
||||||
def test_compact_retrieve_response_should_format_correctly(self, mock_format):
|
def test_compact_retrieve_response_should_format_correctly(self, mock_format):
|
||||||
"""Test that compact_retrieve_response formats the response correctly"""
|
|
||||||
# Arrange
|
|
||||||
query = "test query"
|
query = "test query"
|
||||||
mock_doc = MagicMock(spec=Document)
|
mock_doc = MagicMock(spec=Document)
|
||||||
documents = [mock_doc]
|
|
||||||
|
|
||||||
mock_record = MagicMock()
|
mock_record = MagicMock()
|
||||||
mock_record.model_dump.return_value = {"content": "formatted content"}
|
mock_record.model_dump.return_value = {"content": "formatted content"}
|
||||||
mock_format.return_value = [mock_record]
|
mock_format.return_value = [mock_record]
|
||||||
|
|
||||||
# Act
|
result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, [mock_doc]))
|
||||||
result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, documents))
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert cast(dict[str, Any], result["query"])["content"] == query
|
assert cast(dict[str, Any], result["query"])["content"] == query
|
||||||
assert len(result["records"]) == 1
|
assert len(result["records"]) == 1
|
||||||
assert cast(dict[str, Any], result["records"][0])["content"] == "formatted content"
|
assert cast(dict[str, Any], result["records"][0])["content"] == "formatted content"
|
||||||
mock_format.assert_called_once_with(documents)
|
mock_format.assert_called_once_with([mock_doc])
|
||||||
|
|
||||||
def test_compact_external_retrieve_response_should_return_records_for_external_provider(self):
|
def test_compact_external_retrieve_response_should_return_records_for_external_provider(
|
||||||
"""Test that compact_external_retrieve_response returns records when dataset provider is external"""
|
self, db_session_with_containers: Session
|
||||||
# Arrange
|
):
|
||||||
dataset = MagicMock(spec=Dataset)
|
dataset = _create_dataset(db_session_with_containers, provider="external")
|
||||||
dataset.provider = "external"
|
|
||||||
query = "test query"
|
|
||||||
documents = [
|
documents = [
|
||||||
{"content": "c1", "title": "t1", "score": 0.9, "metadata": {"m1": "v1"}},
|
{"content": "c1", "title": "t1", "score": 0.9, "metadata": {"m1": "v1"}},
|
||||||
{"content": "c2", "title": "t2", "score": 0.8, "metadata": {"m2": "v2"}},
|
{"content": "c2", "title": "t2", "score": 0.8, "metadata": {"m2": "v2"}},
|
||||||
]
|
]
|
||||||
|
|
||||||
# Act
|
result = cast(
|
||||||
result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents))
|
dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, "test query", documents)
|
||||||
|
)
|
||||||
|
|
||||||
# Assert
|
assert cast(dict[str, Any], result["query"])["content"] == "test query"
|
||||||
assert cast(dict[str, Any], result["query"])["content"] == query
|
|
||||||
assert len(result["records"]) == 2
|
assert len(result["records"]) == 2
|
||||||
assert cast(dict[str, Any], result["records"][0])["content"] == "c1"
|
assert cast(dict[str, Any], result["records"][0])["content"] == "c1"
|
||||||
assert cast(dict[str, Any], result["records"][1])["title"] == "t2"
|
assert cast(dict[str, Any], result["records"][1])["title"] == "t2"
|
||||||
|
|
||||||
def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider(self):
|
def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider(
|
||||||
"""Test that compact_external_retrieve_response returns empty records for non-external provider"""
|
self, db_session_with_containers: Session
|
||||||
# Arrange
|
):
|
||||||
dataset = MagicMock(spec=Dataset)
|
dataset = _create_dataset(db_session_with_containers, provider="vendor")
|
||||||
dataset.provider = "not_external"
|
|
||||||
query = "test query"
|
|
||||||
documents = [{"content": "c1"}]
|
|
||||||
|
|
||||||
# Act
|
result = cast(
|
||||||
result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents))
|
dict[str, Any],
|
||||||
|
HitTestingService.compact_external_retrieve_response(dataset, "test query", [{"content": "c1"}]),
|
||||||
|
)
|
||||||
|
|
||||||
# Assert
|
assert cast(dict[str, Any], result["query"])["content"] == "test query"
|
||||||
assert cast(dict[str, Any], result["query"])["content"] == query
|
|
||||||
assert result["records"] == []
|
assert result["records"] == []
|
||||||
|
|
||||||
# ===== External Retrieve Tests =====
|
# ── External retrieve (real DB) ────────────────────────────────────
|
||||||
|
|
||||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.external_retrieve")
|
@patch("core.rag.datasource.retrieval_service.RetrievalService.external_retrieve")
|
||||||
@patch("extensions.ext_database.db.session.add")
|
def test_external_retrieve_should_succeed_for_external_provider(
|
||||||
@patch("extensions.ext_database.db.session.commit")
|
self, mock_ext_retrieve, db_session_with_containers: Session
|
||||||
def test_external_retrieve_should_succeed_for_external_provider(self, mock_commit, mock_add, mock_ext_retrieve):
|
):
|
||||||
"""Test that external_retrieve successfully retrieves from external provider and commits query"""
|
dataset = _create_dataset(db_session_with_containers, provider="external")
|
||||||
# Arrange
|
account_id = str(uuid4())
|
||||||
dataset = MagicMock(spec=Dataset)
|
|
||||||
dataset.id = "dataset_id"
|
|
||||||
dataset.provider = "external"
|
|
||||||
query = 'test "query"'
|
|
||||||
account = MagicMock()
|
account = MagicMock()
|
||||||
account.id = "account_id"
|
account.id = account_id
|
||||||
|
|
||||||
mock_ext_retrieve.return_value = [{"content": "ext content", "score": 1.0}]
|
mock_ext_retrieve.return_value = [{"content": "ext content", "score": 1.0}]
|
||||||
|
|
||||||
# Act
|
before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
|
||||||
|
|
||||||
result = cast(
|
result = cast(
|
||||||
dict[str, Any],
|
dict[str, Any],
|
||||||
HitTestingService.external_retrieve(
|
HitTestingService.external_retrieve(
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
query=query,
|
query='test "query"',
|
||||||
account=account,
|
account=account,
|
||||||
external_retrieval_model={"model": "test"},
|
external_retrieval_model={"model": "test"},
|
||||||
metadata_filtering_conditions={"key": "val"},
|
metadata_filtering_conditions={"key": "val"},
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert
|
assert cast(dict[str, Any], result["query"])["content"] == 'test "query"'
|
||||||
assert cast(dict[str, Any], result["query"])["content"] == query
|
|
||||||
assert cast(dict[str, Any], result["records"][0])["content"] == "ext content"
|
assert cast(dict[str, Any], result["records"][0])["content"] == "ext content"
|
||||||
|
|
||||||
# Verify call to RetrievalService.external_retrieve with escaped query
|
|
||||||
mock_ext_retrieve.assert_called_once_with(
|
mock_ext_retrieve.assert_called_once_with(
|
||||||
dataset_id="dataset_id",
|
dataset_id=dataset.id,
|
||||||
query='test \\"query\\"',
|
query='test \\"query\\"',
|
||||||
external_retrieval_model={"model": "test"},
|
external_retrieval_model={"model": "test"},
|
||||||
metadata_filtering_conditions={"key": "val"},
|
metadata_filtering_conditions={"key": "val"},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify DatasetQuery record was added and committed
|
db_session_with_containers.expire_all()
|
||||||
mock_add.assert_called_once()
|
after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
|
||||||
mock_commit.assert_called_once()
|
assert after_count == before_count + 1
|
||||||
|
|
||||||
def test_external_retrieve_should_return_empty_for_non_external_provider(self):
|
def test_external_retrieve_should_return_empty_for_non_external_provider(self, db_session_with_containers: Session):
|
||||||
"""Test that external_retrieve returns empty results immediately if provider is not external"""
|
dataset = _create_dataset(db_session_with_containers, provider="vendor")
|
||||||
# Arrange
|
|
||||||
dataset = MagicMock(spec=Dataset)
|
|
||||||
dataset.provider = "not_external"
|
|
||||||
query = "test query"
|
|
||||||
account = MagicMock()
|
account = MagicMock()
|
||||||
|
|
||||||
# Act
|
result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, "test query", account))
|
||||||
result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, query, account))
|
|
||||||
|
|
||||||
# Assert
|
assert cast(dict[str, Any], result["query"])["content"] == "test query"
|
||||||
assert cast(dict[str, Any], result["query"])["content"] == query
|
|
||||||
assert result["records"] == []
|
assert result["records"] == []
|
||||||
|
|
||||||
# ===== Retrieve Tests =====
|
# ── Retrieve (real DB) ─────────────────────────────────────────────
|
||||||
|
|
||||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
||||||
@patch("extensions.ext_database.db.session.add")
|
def test_retrieve_should_use_default_model_when_none_provided(
|
||||||
@patch("extensions.ext_database.db.session.commit")
|
self, mock_retrieve, db_session_with_containers: Session
|
||||||
def test_retrieve_should_use_default_model_when_none_provided(self, mock_commit, mock_add, mock_retrieve):
|
):
|
||||||
"""Test that retrieve uses default model when retrieval_model is not provided"""
|
dataset = _create_dataset(db_session_with_containers)
|
||||||
# Arrange
|
|
||||||
dataset = MagicMock(spec=Dataset)
|
|
||||||
dataset.id = "dataset_id"
|
|
||||||
dataset.retrieval_model = None
|
dataset.retrieval_model = None
|
||||||
query = "test query"
|
|
||||||
account = MagicMock()
|
account = MagicMock()
|
||||||
account.id = "account_id"
|
account.id = str(uuid4())
|
||||||
|
|
||||||
mock_retrieve.return_value = []
|
mock_retrieve.return_value = []
|
||||||
|
|
||||||
# Act
|
before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
|
||||||
|
|
||||||
result = cast(
|
result = cast(
|
||||||
dict[str, Any],
|
dict[str, Any],
|
||||||
HitTestingService.retrieve(
|
HitTestingService.retrieve(
|
||||||
dataset=dataset, query=query, account=account, retrieval_model=None, external_retrieval_model={}
|
dataset=dataset, query="test query", account=account, retrieval_model=None, external_retrieval_model={}
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert
|
assert cast(dict[str, Any], result["query"])["content"] == "test query"
|
||||||
assert cast(dict[str, Any], result["query"])["content"] == query
|
|
||||||
mock_retrieve.assert_called_once()
|
mock_retrieve.assert_called_once()
|
||||||
# Verify top_k from default_retrieval_model (4)
|
|
||||||
assert mock_retrieve.call_args.kwargs["top_k"] == 4
|
assert mock_retrieve.call_args.kwargs["top_k"] == 4
|
||||||
mock_commit.assert_called_once()
|
|
||||||
|
db_session_with_containers.expire_all()
|
||||||
|
after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
|
||||||
|
assert after_count == before_count + 1
|
||||||
|
|
||||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
||||||
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition")
|
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition")
|
||||||
@patch("extensions.ext_database.db.session.add")
|
def test_retrieve_should_handle_metadata_filtering(
|
||||||
@patch("extensions.ext_database.db.session.commit")
|
self, mock_get_meta, mock_retrieve, db_session_with_containers: Session
|
||||||
def test_retrieve_should_handle_metadata_filtering(self, mock_commit, mock_add, mock_get_meta, mock_retrieve):
|
):
|
||||||
"""Test that retrieve correctly calls metadata filtering when conditions are present"""
|
dataset = _create_dataset(db_session_with_containers)
|
||||||
# Arrange
|
|
||||||
dataset = MagicMock(spec=Dataset)
|
|
||||||
dataset.id = "dataset_id"
|
|
||||||
query = "test query"
|
|
||||||
account = MagicMock()
|
account = MagicMock()
|
||||||
account.id = "account_id"
|
account.id = str(uuid4())
|
||||||
|
|
||||||
retrieval_model = {
|
retrieval_model = {
|
||||||
"search_method": "semantic_search",
|
"search_method": "semantic_search",
|
||||||
@ -242,29 +196,27 @@ class TestHitTestingService:
|
|||||||
"reranking_enable": False,
|
"reranking_enable": False,
|
||||||
"score_threshold_enabled": False,
|
"score_threshold_enabled": False,
|
||||||
}
|
}
|
||||||
|
mock_get_meta.return_value = ({dataset.id: ["doc_id1"]}, "condition_string")
|
||||||
# Mock metadata filtering response
|
|
||||||
mock_get_meta.return_value = ({"dataset_id": ["doc_id1"]}, "condition_string")
|
|
||||||
mock_retrieve.return_value = []
|
mock_retrieve.return_value = []
|
||||||
|
|
||||||
# Act
|
|
||||||
HitTestingService.retrieve(
|
HitTestingService.retrieve(
|
||||||
dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={}
|
dataset=dataset,
|
||||||
|
query="test query",
|
||||||
|
account=account,
|
||||||
|
retrieval_model=retrieval_model,
|
||||||
|
external_retrieval_model={},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert
|
|
||||||
mock_get_meta.assert_called_once()
|
mock_get_meta.assert_called_once()
|
||||||
mock_retrieve.assert_called_once()
|
mock_retrieve.assert_called_once()
|
||||||
assert mock_retrieve.call_args.kwargs["document_ids_filter"] == ["doc_id1"]
|
assert mock_retrieve.call_args.kwargs["document_ids_filter"] == ["doc_id1"]
|
||||||
|
|
||||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
||||||
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition")
|
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition")
|
||||||
def test_retrieve_should_return_empty_if_metadata_filtering_fails(self, mock_get_meta, mock_retrieve):
|
def test_retrieve_should_return_empty_if_metadata_filtering_fails(
|
||||||
"""Test that retrieve returns empty response if metadata filtering returns condition but no document IDs"""
|
self, mock_get_meta, mock_retrieve, db_session_with_containers: Session
|
||||||
# Arrange
|
):
|
||||||
dataset = MagicMock(spec=Dataset)
|
dataset = _create_dataset(db_session_with_containers)
|
||||||
dataset.id = "dataset_id"
|
|
||||||
query = "test query"
|
|
||||||
account = MagicMock()
|
account = MagicMock()
|
||||||
|
|
||||||
retrieval_model = {
|
retrieval_model = {
|
||||||
@ -274,37 +226,27 @@ class TestHitTestingService:
|
|||||||
"reranking_enable": False,
|
"reranking_enable": False,
|
||||||
"score_threshold_enabled": False,
|
"score_threshold_enabled": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Mock metadata filtering response: condition returned but no IDs
|
|
||||||
mock_get_meta.return_value = ({}, "condition_string")
|
mock_get_meta.return_value = ({}, "condition_string")
|
||||||
|
|
||||||
# Act
|
|
||||||
result = cast(
|
result = cast(
|
||||||
dict[str, Any],
|
dict[str, Any],
|
||||||
HitTestingService.retrieve(
|
HitTestingService.retrieve(
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
query=query,
|
query="test query",
|
||||||
account=account,
|
account=account,
|
||||||
retrieval_model=retrieval_model,
|
retrieval_model=retrieval_model,
|
||||||
external_retrieval_model={},
|
external_retrieval_model={},
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert result["records"] == []
|
assert result["records"] == []
|
||||||
mock_retrieve.assert_not_called()
|
mock_retrieve.assert_not_called()
|
||||||
|
|
||||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
||||||
@patch("extensions.ext_database.db.session.add")
|
def test_retrieve_should_handle_attachments(self, mock_retrieve, db_session_with_containers: Session):
|
||||||
@patch("extensions.ext_database.db.session.commit")
|
dataset = _create_dataset(db_session_with_containers)
|
||||||
def test_retrieve_should_handle_attachments(self, mock_commit, mock_add, mock_retrieve):
|
|
||||||
"""Test that retrieve handles attachment_ids and adds them to DatasetQuery"""
|
|
||||||
# Arrange
|
|
||||||
dataset = MagicMock(spec=Dataset)
|
|
||||||
dataset.id = "dataset_id"
|
|
||||||
query = "test query"
|
|
||||||
account = MagicMock()
|
account = MagicMock()
|
||||||
account.id = "account_id"
|
account.id = str(uuid4())
|
||||||
attachment_ids = ["att1", "att2"]
|
attachment_ids = ["att1", "att2"]
|
||||||
|
|
||||||
retrieval_model = {
|
retrieval_model = {
|
||||||
@ -315,21 +257,19 @@ class TestHitTestingService:
|
|||||||
}
|
}
|
||||||
mock_retrieve.return_value = []
|
mock_retrieve.return_value = []
|
||||||
|
|
||||||
# Act
|
|
||||||
HitTestingService.retrieve(
|
HitTestingService.retrieve(
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
query=query,
|
query="test query",
|
||||||
account=account,
|
account=account,
|
||||||
retrieval_model=retrieval_model,
|
retrieval_model=retrieval_model,
|
||||||
external_retrieval_model={},
|
external_retrieval_model={},
|
||||||
attachment_ids=attachment_ids,
|
attachment_ids=attachment_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert
|
|
||||||
mock_retrieve.assert_called_once_with(
|
mock_retrieve.assert_called_once_with(
|
||||||
retrieval_method=ANY,
|
retrieval_method=ANY,
|
||||||
dataset_id="dataset_id",
|
dataset_id=dataset.id,
|
||||||
query=query,
|
query="test query",
|
||||||
attachment_ids=attachment_ids,
|
attachment_ids=attachment_ids,
|
||||||
top_k=4,
|
top_k=4,
|
||||||
score_threshold=0.0,
|
score_threshold=0.0,
|
||||||
@ -338,26 +278,27 @@ class TestHitTestingService:
|
|||||||
weights=None,
|
weights=None,
|
||||||
document_ids_filter=None,
|
document_ids_filter=None,
|
||||||
)
|
)
|
||||||
# Verify DatasetQuery record (there should be 2 queries: 1 text, 2 images)
|
|
||||||
# The content is json.dumps([{"content_type": "text_query", ...}, {"content_type": "image_query", ...}])
|
# Verify DatasetQuery was persisted with correct content structure
|
||||||
called_query = mock_add.call_args[0][0]
|
db_session_with_containers.expire_all()
|
||||||
query_content = json.loads(called_query.content)
|
latest = db_session_with_containers.scalar(
|
||||||
|
select(DatasetQuery)
|
||||||
|
.where(DatasetQuery.dataset_id == dataset.id)
|
||||||
|
.order_by(DatasetQuery.created_at.desc())
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
assert latest is not None
|
||||||
|
query_content = json.loads(latest.content)
|
||||||
assert len(query_content) == 3 # 1 text + 2 images
|
assert len(query_content) == 3 # 1 text + 2 images
|
||||||
assert query_content[0]["content_type"] == "text_query"
|
assert query_content[0]["content_type"] == "text_query"
|
||||||
assert query_content[1]["content_type"] == "image_query"
|
assert query_content[1]["content_type"] == "image_query"
|
||||||
assert query_content[1]["content"] == "att1"
|
assert query_content[1]["content"] == "att1"
|
||||||
|
|
||||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
||||||
@patch("extensions.ext_database.db.session.add")
|
def test_retrieve_should_handle_reranking_and_threshold(self, mock_retrieve, db_session_with_containers: Session):
|
||||||
@patch("extensions.ext_database.db.session.commit")
|
dataset = _create_dataset(db_session_with_containers)
|
||||||
def test_retrieve_should_handle_reranking_and_threshold(self, mock_commit, mock_add, mock_retrieve):
|
|
||||||
"""Test that retrieve passes reranking and threshold parameters correctly"""
|
|
||||||
# Arrange
|
|
||||||
dataset = MagicMock(spec=Dataset)
|
|
||||||
dataset.id = "dataset_id"
|
|
||||||
query = "test query"
|
|
||||||
account = MagicMock()
|
account = MagicMock()
|
||||||
account.id = "account_id"
|
account.id = str(uuid4())
|
||||||
|
|
||||||
retrieval_model = {
|
retrieval_model = {
|
||||||
"search_method": "hybrid_search",
|
"search_method": "hybrid_search",
|
||||||
@ -371,12 +312,14 @@ class TestHitTestingService:
|
|||||||
}
|
}
|
||||||
mock_retrieve.return_value = []
|
mock_retrieve.return_value = []
|
||||||
|
|
||||||
# Act
|
|
||||||
HitTestingService.retrieve(
|
HitTestingService.retrieve(
|
||||||
dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={}
|
dataset=dataset,
|
||||||
|
query="test query",
|
||||||
|
account=account,
|
||||||
|
retrieval_model=retrieval_model,
|
||||||
|
external_retrieval_model={},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert
|
|
||||||
mock_retrieve.assert_called_once()
|
mock_retrieve.assert_called_once()
|
||||||
kwargs = mock_retrieve.call_args.kwargs
|
kwargs = mock_retrieve.call_args.kwargs
|
||||||
assert kwargs["score_threshold"] == 0.5
|
assert kwargs["score_threshold"] == 0.5
|
||||||
@ -0,0 +1,363 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from faker import Faker
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from core.ops.entities.config_entity import TracingProviderEnum
|
||||||
|
from models.model import TraceAppConfig
|
||||||
|
from services.account_service import AccountService, TenantService
|
||||||
|
from services.app_service import AppService
|
||||||
|
from services.ops_service import OpsService
|
||||||
|
from tests.test_containers_integration_tests.helpers import generate_valid_password
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpsService:
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_external_service_dependencies(self):
|
||||||
|
with (
|
||||||
|
patch("services.app_service.FeatureService") as mock_feature_service,
|
||||||
|
patch("services.app_service.EnterpriseService") as mock_enterprise_service,
|
||||||
|
patch("services.app_service.ModelManager.for_tenant") as mock_model_manager,
|
||||||
|
patch("services.account_service.FeatureService") as mock_account_feature_service,
|
||||||
|
):
|
||||||
|
mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False
|
||||||
|
mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None
|
||||||
|
mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None
|
||||||
|
mock_account_feature_service.get_system_features.return_value.is_allow_register = True
|
||||||
|
mock_model_instance = mock_model_manager.return_value
|
||||||
|
mock_model_instance.get_default_model_instance.return_value = None
|
||||||
|
mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo")
|
||||||
|
yield {
|
||||||
|
"feature_service": mock_feature_service,
|
||||||
|
"enterprise_service": mock_enterprise_service,
|
||||||
|
"model_manager": mock_model_manager,
|
||||||
|
"account_feature_service": mock_account_feature_service,
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_ops_trace_manager(self):
|
||||||
|
with patch("services.ops_service.OpsTraceManager") as mock:
|
||||||
|
yield mock
|
||||||
|
|
||||||
|
def _create_app(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||||
|
fake = Faker()
|
||||||
|
account = AccountService.create_account(
|
||||||
|
email=fake.email(),
|
||||||
|
name=fake.name(),
|
||||||
|
interface_language="en-US",
|
||||||
|
password=generate_valid_password(fake),
|
||||||
|
)
|
||||||
|
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||||
|
tenant = account.current_tenant
|
||||||
|
app_service = AppService()
|
||||||
|
app = app_service.create_app(
|
||||||
|
tenant.id,
|
||||||
|
{
|
||||||
|
"name": fake.company(),
|
||||||
|
"description": fake.text(max_nb_chars=100),
|
||||||
|
"mode": "chat",
|
||||||
|
"icon_type": "emoji",
|
||||||
|
"icon": "🤖",
|
||||||
|
"icon_background": "#FF6B6B",
|
||||||
|
},
|
||||||
|
account,
|
||||||
|
)
|
||||||
|
return app, account
|
||||||
|
|
||||||
|
_SENTINEL = object()
|
||||||
|
|
||||||
|
def _insert_trace_config(
|
||||||
|
self,
|
||||||
|
db_session: Session,
|
||||||
|
app_id: str,
|
||||||
|
provider: str,
|
||||||
|
tracing_config: dict | None | object = _SENTINEL,
|
||||||
|
) -> TraceAppConfig:
|
||||||
|
trace_config = TraceAppConfig(
|
||||||
|
app_id=app_id,
|
||||||
|
tracing_provider=provider,
|
||||||
|
tracing_config=tracing_config if tracing_config is not self._SENTINEL else {"some": "config"},
|
||||||
|
)
|
||||||
|
db_session.add(trace_config)
|
||||||
|
db_session.commit()
|
||||||
|
return trace_config
|
||||||
|
|
||||||
|
# ── get_tracing_app_config ─────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_get_tracing_app_config_no_config(self, db_session_with_containers: Session, mock_ops_trace_manager):
|
||||||
|
result = OpsService.get_tracing_app_config(str(uuid.uuid4()), "arize")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_get_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager):
|
||||||
|
fake_app_id = str(uuid.uuid4())
|
||||||
|
self._insert_trace_config(db_session_with_containers, fake_app_id, "arize")
|
||||||
|
result = OpsService.get_tracing_app_config(fake_app_id, "arize")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_get_tracing_app_config_none_config(
|
||||||
|
self, db_session_with_containers: Session, mock_external_service_dependencies, mock_ops_trace_manager
|
||||||
|
):
|
||||||
|
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||||
|
self._insert_trace_config(db_session_with_containers, app.id, "arize", tracing_config=None)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Tracing config cannot be None."):
|
||||||
|
OpsService.get_tracing_app_config(app.id, "arize")
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("provider", "default_url"),
|
||||||
|
[
|
||||||
|
("arize", "https://app.arize.com/"),
|
||||||
|
("phoenix", "https://app.phoenix.arize.com/projects/"),
|
||||||
|
("langsmith", "https://smith.langchain.com/"),
|
||||||
|
("opik", "https://www.comet.com/opik/"),
|
||||||
|
("weave", "https://wandb.ai/"),
|
||||||
|
("aliyun", "https://arms.console.aliyun.com/"),
|
||||||
|
("tencent", "https://console.cloud.tencent.com/apm"),
|
||||||
|
("mlflow", "http://localhost:5000/"),
|
||||||
|
("databricks", "https://www.databricks.com/"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_get_tracing_app_config_providers_exception(
|
||||||
|
self, db_session_with_containers: Session, mock_external_service_dependencies, provider, default_url
|
||||||
|
):
|
||||||
|
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||||
|
mock_otm.decrypt_tracing_config.return_value = {}
|
||||||
|
mock_otm.obfuscated_decrypt_token.return_value = {}
|
||||||
|
mock_otm.get_trace_config_project_url.side_effect = Exception("error")
|
||||||
|
mock_otm.get_trace_config_project_key.side_effect = Exception("error")
|
||||||
|
|
||||||
|
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||||
|
self._insert_trace_config(db_session_with_containers, app.id, provider)
|
||||||
|
|
||||||
|
result = OpsService.get_tracing_app_config(app.id, provider)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result["tracing_config"]["project_url"] == default_url
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"provider",
|
||||||
|
["arize", "phoenix", "langsmith", "opik", "weave", "aliyun", "tencent", "mlflow", "databricks"],
|
||||||
|
)
|
||||||
|
def test_get_tracing_app_config_providers_success(
|
||||||
|
self, db_session_with_containers: Session, mock_external_service_dependencies, provider
|
||||||
|
):
|
||||||
|
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||||
|
mock_otm.decrypt_tracing_config.return_value = {}
|
||||||
|
mock_otm.obfuscated_decrypt_token.return_value = {"project_url": "success_url"}
|
||||||
|
mock_otm.get_trace_config_project_url.return_value = "success_url"
|
||||||
|
|
||||||
|
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||||
|
self._insert_trace_config(db_session_with_containers, app.id, provider)
|
||||||
|
|
||||||
|
result = OpsService.get_tracing_app_config(app.id, provider)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result["tracing_config"]["project_url"] == "success_url"
|
||||||
|
|
||||||
|
def test_get_tracing_app_config_langfuse_success(
|
||||||
|
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||||
|
mock_otm.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"}
|
||||||
|
mock_otm.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"}
|
||||||
|
mock_otm.get_trace_config_project_key.return_value = "key"
|
||||||
|
|
||||||
|
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||||
|
self._insert_trace_config(db_session_with_containers, app.id, "langfuse")
|
||||||
|
|
||||||
|
result = OpsService.get_tracing_app_config(app.id, "langfuse")
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/project/key"
|
||||||
|
|
||||||
|
def test_get_tracing_app_config_langfuse_exception(
|
||||||
|
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||||
|
mock_otm.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"}
|
||||||
|
mock_otm.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"}
|
||||||
|
mock_otm.get_trace_config_project_key.side_effect = Exception("error")
|
||||||
|
|
||||||
|
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||||
|
self._insert_trace_config(db_session_with_containers, app.id, "langfuse")
|
||||||
|
|
||||||
|
result = OpsService.get_tracing_app_config(app.id, "langfuse")
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/"
|
||||||
|
|
||||||
|
# ── create_tracing_app_config ──────────────────────────────────────
|
||||||
|
|
||||||
|
def test_create_tracing_app_config_invalid_provider(self, db_session_with_containers: Session):
|
||||||
|
result = OpsService.create_tracing_app_config(str(uuid.uuid4()), "invalid_provider", {})
|
||||||
|
assert result == {"error": "Invalid tracing provider: invalid_provider"}
|
||||||
|
|
||||||
|
def test_create_tracing_app_config_invalid_credentials(
|
||||||
|
self, db_session_with_containers: Session, mock_ops_trace_manager
|
||||||
|
):
|
||||||
|
mock_ops_trace_manager.check_trace_config_is_effective.return_value = False
|
||||||
|
result = OpsService.create_tracing_app_config(
|
||||||
|
str(uuid.uuid4()), TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}
|
||||||
|
)
|
||||||
|
assert result == {"error": "Invalid Credentials"}
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("provider", "config"),
|
||||||
|
[
|
||||||
|
(TracingProviderEnum.ARIZE, {}),
|
||||||
|
(TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}),
|
||||||
|
(TracingProviderEnum.LANGSMITH, {"api_key": "k", "project": "p"}),
|
||||||
|
(TracingProviderEnum.ALIYUN, {"license_key": "k", "endpoint": "https://aliyun.com"}),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_create_tracing_app_config_project_url_exception(
|
||||||
|
self, db_session_with_containers: Session, mock_external_service_dependencies, provider, config
|
||||||
|
):
|
||||||
|
# Existing config causes the service to return None before reaching the DB insert
|
||||||
|
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||||
|
mock_otm.check_trace_config_is_effective.return_value = True
|
||||||
|
mock_otm.get_trace_config_project_url.side_effect = Exception("error")
|
||||||
|
mock_otm.get_trace_config_project_key.side_effect = Exception("error")
|
||||||
|
|
||||||
|
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||||
|
self._insert_trace_config(db_session_with_containers, app.id, str(provider))
|
||||||
|
|
||||||
|
result = OpsService.create_tracing_app_config(app.id, provider, config)
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_create_tracing_app_config_langfuse_success(
|
||||||
|
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||||
|
mock_otm.check_trace_config_is_effective.return_value = True
|
||||||
|
mock_otm.get_trace_config_project_key.return_value = "key"
|
||||||
|
mock_otm.encrypt_tracing_config.return_value = {}
|
||||||
|
|
||||||
|
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||||
|
result = OpsService.create_tracing_app_config(
|
||||||
|
app.id,
|
||||||
|
TracingProviderEnum.LANGFUSE,
|
||||||
|
{"public_key": "p", "secret_key": "s", "host": "https://api.langfuse.com"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == {"result": "success"}
|
||||||
|
|
||||||
|
def test_create_tracing_app_config_already_exists(
|
||||||
|
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||||
|
mock_otm.check_trace_config_is_effective.return_value = True
|
||||||
|
|
||||||
|
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||||
|
self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE))
|
||||||
|
|
||||||
|
result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {})
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_create_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager):
|
||||||
|
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||||
|
result = OpsService.create_tracing_app_config(str(uuid.uuid4()), TracingProviderEnum.ARIZE, {})
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_create_tracing_app_config_with_empty_other_keys(
|
||||||
|
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
# "project" is in other_keys for Arize; providing "" triggers default substitution
|
||||||
|
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||||
|
mock_otm.check_trace_config_is_effective.return_value = True
|
||||||
|
mock_otm.get_trace_config_project_url.side_effect = Exception("no url")
|
||||||
|
mock_otm.encrypt_tracing_config.return_value = {}
|
||||||
|
|
||||||
|
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||||
|
result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {"project": ""})
|
||||||
|
|
||||||
|
assert result == {"result": "success"}
|
||||||
|
|
||||||
|
def test_create_tracing_app_config_success(
|
||||||
|
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||||
|
mock_otm.check_trace_config_is_effective.return_value = True
|
||||||
|
mock_otm.get_trace_config_project_url.return_value = "http://project_url"
|
||||||
|
mock_otm.encrypt_tracing_config.return_value = {"encrypted": "config"}
|
||||||
|
|
||||||
|
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||||
|
result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {})
|
||||||
|
|
||||||
|
assert result == {"result": "success"}
|
||||||
|
|
||||||
|
# ── update_tracing_app_config ──────────────────────────────────────
|
||||||
|
|
||||||
|
def test_update_tracing_app_config_invalid_provider(self, db_session_with_containers: Session):
|
||||||
|
with pytest.raises(ValueError, match="Invalid tracing provider: invalid_provider"):
|
||||||
|
OpsService.update_tracing_app_config(str(uuid.uuid4()), "invalid_provider", {})
|
||||||
|
|
||||||
|
def test_update_tracing_app_config_no_config(self, db_session_with_containers: Session, mock_ops_trace_manager):
|
||||||
|
result = OpsService.update_tracing_app_config(str(uuid.uuid4()), TracingProviderEnum.ARIZE, {})
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_update_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager):
|
||||||
|
fake_app_id = str(uuid.uuid4())
|
||||||
|
self._insert_trace_config(db_session_with_containers, fake_app_id, str(TracingProviderEnum.ARIZE))
|
||||||
|
mock_ops_trace_manager.encrypt_tracing_config.return_value = {}
|
||||||
|
result = OpsService.update_tracing_app_config(fake_app_id, TracingProviderEnum.ARIZE, {})
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_update_tracing_app_config_invalid_credentials(
|
||||||
|
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||||
|
mock_otm.encrypt_tracing_config.return_value = {}
|
||||||
|
mock_otm.decrypt_tracing_config.return_value = {}
|
||||||
|
mock_otm.check_trace_config_is_effective.return_value = False
|
||||||
|
|
||||||
|
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||||
|
self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE))
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Invalid Credentials"):
|
||||||
|
OpsService.update_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {})
|
||||||
|
|
||||||
|
def test_update_tracing_app_config_success(
|
||||||
|
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||||
|
mock_otm.encrypt_tracing_config.return_value = {"updated": "config"}
|
||||||
|
mock_otm.decrypt_tracing_config.return_value = {}
|
||||||
|
mock_otm.check_trace_config_is_effective.return_value = True
|
||||||
|
|
||||||
|
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||||
|
self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE))
|
||||||
|
|
||||||
|
result = OpsService.update_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {})
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result["app_id"] == app.id
|
||||||
|
|
||||||
|
# ── delete_tracing_app_config ──────────────────────────────────────
|
||||||
|
|
||||||
|
def test_delete_tracing_app_config_no_config(self, db_session_with_containers: Session):
|
||||||
|
result = OpsService.delete_tracing_app_config(str(uuid.uuid4()), "arize")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_delete_tracing_app_config_success(
|
||||||
|
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||||
|
):
|
||||||
|
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||||
|
self._insert_trace_config(db_session_with_containers, app.id, "arize")
|
||||||
|
|
||||||
|
result = OpsService.delete_tracing_app_config(app.id, "arize")
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
remaining = db_session_with_containers.scalar(
|
||||||
|
select(TraceAppConfig)
|
||||||
|
.where(TraceAppConfig.app_id == app.id, TraceAppConfig.tracing_provider == "arize")
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
assert remaining is None
|
||||||
@ -233,11 +233,10 @@ class TestWebAppAuthService:
|
|||||||
assert result.status == AccountStatus.ACTIVE
|
assert result.status == AccountStatus.ACTIVE
|
||||||
|
|
||||||
# Verify database state
|
# Verify database state
|
||||||
|
refreshed = db_session_with_containers.get(Account, result.id)
|
||||||
db_session_with_containers.refresh(result)
|
assert refreshed is not None
|
||||||
assert result.id is not None
|
assert refreshed.password is not None
|
||||||
assert result.password is not None
|
assert refreshed.password_salt is not None
|
||||||
assert result.password_salt is not None
|
|
||||||
|
|
||||||
def test_authenticate_account_not_found(
|
def test_authenticate_account_not_found(
|
||||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||||
@ -414,9 +413,8 @@ class TestWebAppAuthService:
|
|||||||
assert result.status == AccountStatus.ACTIVE
|
assert result.status == AccountStatus.ACTIVE
|
||||||
|
|
||||||
# Verify database state
|
# Verify database state
|
||||||
|
refreshed = db_session_with_containers.get(Account, result.id)
|
||||||
db_session_with_containers.refresh(result)
|
assert refreshed is not None
|
||||||
assert result.id is not None
|
|
||||||
|
|
||||||
def test_get_user_through_email_not_found(
|
def test_get_user_through_email_not_found(
|
||||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
from importlib import import_module
|
||||||
from unittest.mock import MagicMock, PropertyMock, patch
|
from unittest.mock import MagicMock, PropertyMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -11,6 +12,7 @@ from controllers.console.datasets.external import (
|
|||||||
BedrockRetrievalApi,
|
BedrockRetrievalApi,
|
||||||
ExternalApiTemplateApi,
|
ExternalApiTemplateApi,
|
||||||
ExternalApiTemplateListApi,
|
ExternalApiTemplateListApi,
|
||||||
|
ExternalApiUseCheckApi,
|
||||||
ExternalDatasetCreateApi,
|
ExternalDatasetCreateApi,
|
||||||
ExternalKnowledgeHitTestingApi,
|
ExternalKnowledgeHitTestingApi,
|
||||||
)
|
)
|
||||||
@ -19,6 +21,8 @@ from services.external_knowledge_service import ExternalDatasetService
|
|||||||
from services.hit_testing_service import HitTestingService
|
from services.hit_testing_service import HitTestingService
|
||||||
from services.knowledge_service import ExternalDatasetTestService
|
from services.knowledge_service import ExternalDatasetTestService
|
||||||
|
|
||||||
|
external_controller = import_module("controllers.console.datasets.external")
|
||||||
|
|
||||||
|
|
||||||
def unwrap(func):
|
def unwrap(func):
|
||||||
while hasattr(func, "__wrapped__"):
|
while hasattr(func, "__wrapped__"):
|
||||||
@ -44,10 +48,11 @@ def current_user():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def mock_auth(mocker, current_user):
|
def mock_auth(monkeypatch, current_user):
|
||||||
mocker.patch(
|
monkeypatch.setattr(
|
||||||
"controllers.console.datasets.external.current_account_with_tenant",
|
external_controller,
|
||||||
return_value=(current_user, "tenant-1"),
|
"current_account_with_tenant",
|
||||||
|
lambda: (current_user, "tenant-1"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -136,6 +141,26 @@ class TestExternalApiTemplateApi:
|
|||||||
method(api, "api-id")
|
method(api, "api-id")
|
||||||
|
|
||||||
|
|
||||||
|
class TestExternalApiUseCheckApi:
|
||||||
|
def test_get_scopes_usage_check_to_current_tenant(self, app):
|
||||||
|
api = ExternalApiUseCheckApi()
|
||||||
|
method = unwrap(api.get)
|
||||||
|
|
||||||
|
with (
|
||||||
|
app.test_request_context("/"),
|
||||||
|
patch.object(
|
||||||
|
ExternalDatasetService,
|
||||||
|
"external_knowledge_api_use_check",
|
||||||
|
return_value=(True, 2),
|
||||||
|
) as mock_use_check,
|
||||||
|
):
|
||||||
|
response, status = method(api, "api-id")
|
||||||
|
|
||||||
|
assert status == 200
|
||||||
|
assert response == {"is_using": True, "count": 2}
|
||||||
|
mock_use_check.assert_called_once_with("api-id", "tenant-1")
|
||||||
|
|
||||||
|
|
||||||
class TestExternalDatasetCreateApi:
|
class TestExternalDatasetCreateApi:
|
||||||
def test_create_success(self, app):
|
def test_create_success(self, app):
|
||||||
api = ExternalDatasetCreateApi()
|
api = ExternalDatasetCreateApi()
|
||||||
|
|||||||
@ -233,15 +233,20 @@ class TestCheckEmailUnique:
|
|||||||
|
|
||||||
|
|
||||||
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
|
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
|
||||||
session = MagicMock()
|
mock_session = MagicMock()
|
||||||
first = MagicMock()
|
first = MagicMock()
|
||||||
first.scalar_one_or_none.return_value = None
|
first.scalar_one_or_none.return_value = None
|
||||||
second = MagicMock()
|
second = MagicMock()
|
||||||
expected_account = MagicMock()
|
expected_account = MagicMock()
|
||||||
second.scalar_one_or_none.return_value = expected_account
|
second.scalar_one_or_none.return_value = expected_account
|
||||||
session.execute.side_effect = [first, second]
|
mock_session.execute.side_effect = [first, second]
|
||||||
|
|
||||||
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=session)
|
mock_factory = MagicMock()
|
||||||
|
mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||||
|
mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||||
|
|
||||||
|
with patch("services.account_service.session_factory", mock_factory):
|
||||||
|
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com")
|
||||||
|
|
||||||
assert result is expected_account
|
assert result is expected_account
|
||||||
assert session.execute.call_count == 2
|
assert mock_session.execute.call_count == 2
|
||||||
|
|||||||
@ -4,9 +4,7 @@ from unittest.mock import Mock
|
|||||||
|
|
||||||
from core.mcp.entities import (
|
from core.mcp.entities import (
|
||||||
SUPPORTED_PROTOCOL_VERSIONS,
|
SUPPORTED_PROTOCOL_VERSIONS,
|
||||||
LifespanContextT,
|
|
||||||
RequestContext,
|
RequestContext,
|
||||||
SessionT,
|
|
||||||
)
|
)
|
||||||
from core.mcp.session.base_session import BaseSession
|
from core.mcp.session.base_session import BaseSession
|
||||||
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestParams
|
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestParams
|
||||||
@ -198,42 +196,3 @@ class TestRequestContext:
|
|||||||
assert "RequestContext" in repr_str
|
assert "RequestContext" in repr_str
|
||||||
assert "test-123" in repr_str
|
assert "test-123" in repr_str
|
||||||
assert "MockSession" in repr_str
|
assert "MockSession" in repr_str
|
||||||
|
|
||||||
|
|
||||||
class TestTypeVariables:
|
|
||||||
"""Test type variables defined in the module."""
|
|
||||||
|
|
||||||
def test_session_type_var(self):
|
|
||||||
"""Test SessionT type variable."""
|
|
||||||
|
|
||||||
# Create a custom session class
|
|
||||||
class CustomSession(BaseSession):
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Use in generic context
|
|
||||||
def process_session(session: SessionT) -> SessionT:
|
|
||||||
return session
|
|
||||||
|
|
||||||
mock_session = Mock(spec=CustomSession)
|
|
||||||
result = process_session(mock_session)
|
|
||||||
assert result == mock_session
|
|
||||||
|
|
||||||
def test_lifespan_context_type_var(self):
|
|
||||||
"""Test LifespanContextT type variable."""
|
|
||||||
|
|
||||||
# Use in generic context
|
|
||||||
def process_lifespan(context: LifespanContextT) -> LifespanContextT:
|
|
||||||
return context
|
|
||||||
|
|
||||||
# Test with different types
|
|
||||||
str_context = "string-context"
|
|
||||||
assert process_lifespan(str_context) == str_context
|
|
||||||
|
|
||||||
dict_context = {"key": "value"}
|
|
||||||
assert process_lifespan(dict_context) == dict_context
|
|
||||||
|
|
||||||
class CustomContext:
|
|
||||||
pass
|
|
||||||
|
|
||||||
custom_context = CustomContext()
|
|
||||||
assert process_lifespan(custom_context) == custom_context
|
|
||||||
|
|||||||
@ -39,6 +39,25 @@ class _FakeSession:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeBeginContext:
|
||||||
|
def __init__(self, session):
|
||||||
|
self._session = session
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self._session
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc, tb):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_both(monkeypatch, module, session):
|
||||||
|
"""Patch both Session and sessionmaker on the module."""
|
||||||
|
monkeypatch.setattr(module, "Session", lambda _client: session)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
module, "sessionmaker", lambda **kwargs: MagicMock(begin=MagicMock(return_value=_FakeBeginContext(session)))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def relyt_module(monkeypatch):
|
def relyt_module(monkeypatch):
|
||||||
for name, module in _build_fake_relyt_modules().items():
|
for name, module in _build_fake_relyt_modules().items():
|
||||||
@ -108,13 +127,13 @@ def test_create_collection_cache_and_sql_execution(relyt_module, monkeypatch):
|
|||||||
|
|
||||||
monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=1))
|
monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=1))
|
||||||
session = _FakeSession()
|
session = _FakeSession()
|
||||||
monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
|
_patch_both(monkeypatch, relyt_module, session)
|
||||||
vector.create_collection(3)
|
vector.create_collection(3)
|
||||||
session.execute.assert_not_called()
|
session.execute.assert_not_called()
|
||||||
|
|
||||||
monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=None))
|
monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=None))
|
||||||
session = _FakeSession()
|
session = _FakeSession()
|
||||||
monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
|
_patch_both(monkeypatch, relyt_module, session)
|
||||||
vector.create_collection(3)
|
vector.create_collection(3)
|
||||||
executed_sql = [str(call.args[0]) for call in session.execute.call_args_list]
|
executed_sql = [str(call.args[0]) for call in session.execute.call_args_list]
|
||||||
assert any("DROP TABLE IF EXISTS" in sql for sql in executed_sql)
|
assert any("DROP TABLE IF EXISTS" in sql for sql in executed_sql)
|
||||||
@ -265,15 +284,15 @@ def test_search_by_vector_filters_by_score_and_ids(relyt_module):
|
|||||||
|
|
||||||
|
|
||||||
# 8. delete commits session
|
# 8. delete commits session
|
||||||
def test_delete_commits_session(relyt_module, monkeypatch):
|
def test_delete_drops_table(relyt_module, monkeypatch):
|
||||||
vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
|
vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
|
||||||
vector._collection_name = "collection_1"
|
vector._collection_name = "collection_1"
|
||||||
vector.client = MagicMock()
|
vector.client = MagicMock()
|
||||||
vector.embedding_dimension = 3
|
vector.embedding_dimension = 3
|
||||||
session = _FakeSession()
|
session = _FakeSession()
|
||||||
monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
|
_patch_both(monkeypatch, relyt_module, session)
|
||||||
vector.delete()
|
vector.delete()
|
||||||
session.commit.assert_called_once()
|
session.execute.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
def test_relyt_factory_existing_and_generated_collection(relyt_module, monkeypatch):
|
def test_relyt_factory_existing_and_generated_collection(relyt_module, monkeypatch):
|
||||||
|
|||||||
@ -137,14 +137,15 @@ def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monke
|
|||||||
|
|
||||||
session = MagicMock()
|
session = MagicMock()
|
||||||
|
|
||||||
class _SessionCtx:
|
class _BeginCtx:
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
return session
|
return session
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc, tb):
|
def __exit__(self, exc_type, exc, tb):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx())
|
mock_sm = MagicMock(begin=MagicMock(return_value=_BeginCtx()))
|
||||||
|
monkeypatch.setattr(tidb_module, "sessionmaker", lambda **kwargs: mock_sm)
|
||||||
|
|
||||||
vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
|
vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
|
||||||
vector._collection_name = "collection_1"
|
vector._collection_name = "collection_1"
|
||||||
@ -153,11 +154,9 @@ def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monke
|
|||||||
|
|
||||||
vector._create_collection(3)
|
vector._create_collection(3)
|
||||||
|
|
||||||
session.begin.assert_called_once()
|
|
||||||
sql = str(session.execute.call_args.args[0])
|
sql = str(session.execute.call_args.args[0])
|
||||||
assert "VECTOR<FLOAT>(3)" in sql
|
assert "VECTOR<FLOAT>(3)" in sql
|
||||||
assert "VEC_L2_DISTANCE" in sql
|
assert "VEC_L2_DISTANCE" in sql
|
||||||
session.commit.assert_called_once()
|
|
||||||
tidb_module.redis_client.set.assert_called_once()
|
tidb_module.redis_client.set.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
@ -396,23 +395,22 @@ def test_search_by_vector_filters_and_scores(tidb_module, monkeypatch):
|
|||||||
def test_delete_drops_table(tidb_module, monkeypatch):
|
def test_delete_drops_table(tidb_module, monkeypatch):
|
||||||
session = MagicMock()
|
session = MagicMock()
|
||||||
session.execute.return_value = None
|
session.execute.return_value = None
|
||||||
session.commit = MagicMock()
|
|
||||||
|
|
||||||
class _SessionCtx:
|
class _BeginCtx:
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
return session
|
return session
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc, tb):
|
def __exit__(self, exc_type, exc, tb):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx())
|
mock_sm = MagicMock(begin=MagicMock(return_value=_BeginCtx()))
|
||||||
|
monkeypatch.setattr(tidb_module, "sessionmaker", lambda **kwargs: mock_sm)
|
||||||
vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
|
vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
|
||||||
vector._collection_name = "collection_1"
|
vector._collection_name = "collection_1"
|
||||||
vector._engine = MagicMock()
|
vector._engine = MagicMock()
|
||||||
vector.delete()
|
vector.delete()
|
||||||
drop_sql = str(session.execute.call_args.args[0])
|
drop_sql = str(session.execute.call_args.args[0])
|
||||||
assert "DROP TABLE IF EXISTS collection_1" in drop_sql
|
assert "DROP TABLE IF EXISTS collection_1" in drop_sql
|
||||||
session.commit.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeypatch):
|
def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeypatch):
|
||||||
|
|||||||
@ -39,7 +39,7 @@ class TestAppGenerateHandler:
|
|||||||
"root_node_id": None,
|
"root_node_id": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
arguments = handler._extract_arguments(AppGenerateService.generate, (), kwargs)
|
arguments = handler._extract_arguments(AppGenerateService.generate, **kwargs)
|
||||||
|
|
||||||
assert arguments is not None, "Failed to extract arguments from AppGenerateService.generate"
|
assert arguments is not None, "Failed to extract arguments from AppGenerateService.generate"
|
||||||
assert "app_model" in arguments, "Handler uses app_model but parameter is missing"
|
assert "app_model" in arguments, "Handler uses app_model but parameter is missing"
|
||||||
@ -70,14 +70,11 @@ class TestAppGenerateHandler:
|
|||||||
handler.wrapper(
|
handler.wrapper(
|
||||||
tracer,
|
tracer,
|
||||||
dummy_func,
|
dummy_func,
|
||||||
(),
|
app_model=mock_app_model,
|
||||||
{
|
user=mock_account_user,
|
||||||
"app_model": mock_app_model,
|
args={"workflow_id": test_workflow_id},
|
||||||
"user": mock_account_user,
|
invoke_from=InvokeFrom.DEBUGGER,
|
||||||
"args": {"workflow_id": test_workflow_id},
|
streaming=False,
|
||||||
"invoke_from": InvokeFrom.DEBUGGER,
|
|
||||||
"streaming": False,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
spans = memory_span_exporter.get_finished_spans()
|
spans = memory_span_exporter.get_finished_spans()
|
||||||
|
|||||||
@ -63,7 +63,7 @@ class TestWorkflowAppRunnerHandler:
|
|||||||
def runner_run(self):
|
def runner_run(self):
|
||||||
return "result"
|
return "result"
|
||||||
|
|
||||||
handler.wrapper(tracer, runner_run, (mock_workflow_runner,), {})
|
handler.wrapper(tracer, runner_run, mock_workflow_runner)
|
||||||
|
|
||||||
spans = memory_span_exporter.get_finished_spans()
|
spans = memory_span_exporter.get_finished_spans()
|
||||||
assert len(spans) == 1
|
assert len(spans) == 1
|
||||||
|
|||||||
@ -28,7 +28,7 @@ class TestSpanHandlerExtractArguments:
|
|||||||
|
|
||||||
args = (1, 2, 3)
|
args = (1, 2, 3)
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
result = handler._extract_arguments(func, args, kwargs)
|
result = handler._extract_arguments(func, *args, **kwargs)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result["a"] == 1
|
assert result["a"] == 1
|
||||||
@ -44,7 +44,7 @@ class TestSpanHandlerExtractArguments:
|
|||||||
|
|
||||||
args = ()
|
args = ()
|
||||||
kwargs = {"a": 1, "b": 2, "c": 3}
|
kwargs = {"a": 1, "b": 2, "c": 3}
|
||||||
result = handler._extract_arguments(func, args, kwargs)
|
result = handler._extract_arguments(func, *args, **kwargs)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result["a"] == 1
|
assert result["a"] == 1
|
||||||
@ -60,7 +60,7 @@ class TestSpanHandlerExtractArguments:
|
|||||||
|
|
||||||
args = (1,)
|
args = (1,)
|
||||||
kwargs = {"b": 2, "c": 3}
|
kwargs = {"b": 2, "c": 3}
|
||||||
result = handler._extract_arguments(func, args, kwargs)
|
result = handler._extract_arguments(func, *args, **kwargs)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result["a"] == 1
|
assert result["a"] == 1
|
||||||
@ -76,7 +76,7 @@ class TestSpanHandlerExtractArguments:
|
|||||||
|
|
||||||
args = (1,)
|
args = (1,)
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
result = handler._extract_arguments(func, args, kwargs)
|
result = handler._extract_arguments(func, *args, **kwargs)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result["a"] == 1
|
assert result["a"] == 1
|
||||||
@ -94,7 +94,7 @@ class TestSpanHandlerExtractArguments:
|
|||||||
instance = MyClass()
|
instance = MyClass()
|
||||||
args = (1, 2)
|
args = (1, 2)
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
result = handler._extract_arguments(instance.method, args, kwargs)
|
result = handler._extract_arguments(instance.method, *args, **kwargs)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result["a"] == 1
|
assert result["a"] == 1
|
||||||
@ -109,7 +109,7 @@ class TestSpanHandlerExtractArguments:
|
|||||||
|
|
||||||
args = (1,)
|
args = (1,)
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
result = handler._extract_arguments(func, args, kwargs)
|
result = handler._extract_arguments(func, *args, **kwargs)
|
||||||
|
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
@ -122,11 +122,11 @@ class TestSpanHandlerExtractArguments:
|
|||||||
|
|
||||||
assert func not in handler._signature_cache
|
assert func not in handler._signature_cache
|
||||||
|
|
||||||
handler._extract_arguments(func, (1, 2), {})
|
handler._extract_arguments(func, 1, 2)
|
||||||
assert func in handler._signature_cache
|
assert func in handler._signature_cache
|
||||||
|
|
||||||
cached_sig = handler._signature_cache[func]
|
cached_sig = handler._signature_cache[func]
|
||||||
handler._extract_arguments(func, (3, 4), {})
|
handler._extract_arguments(func, 3, 4)
|
||||||
assert handler._signature_cache[func] is cached_sig
|
assert handler._signature_cache[func] is cached_sig
|
||||||
|
|
||||||
|
|
||||||
@ -142,7 +142,7 @@ class TestSpanHandlerWrapper:
|
|||||||
def test_func():
|
def test_func():
|
||||||
return "result"
|
return "result"
|
||||||
|
|
||||||
result = handler.wrapper(tracer, test_func, (), {})
|
result = handler.wrapper(tracer, test_func)
|
||||||
|
|
||||||
assert result == "result"
|
assert result == "result"
|
||||||
spans = memory_span_exporter.get_finished_spans()
|
spans = memory_span_exporter.get_finished_spans()
|
||||||
@ -159,7 +159,7 @@ class TestSpanHandlerWrapper:
|
|||||||
def test_func():
|
def test_func():
|
||||||
return "result"
|
return "result"
|
||||||
|
|
||||||
handler.wrapper(tracer, test_func, (), {})
|
handler.wrapper(tracer, test_func)
|
||||||
|
|
||||||
spans = memory_span_exporter.get_finished_spans()
|
spans = memory_span_exporter.get_finished_spans()
|
||||||
assert len(spans) == 1
|
assert len(spans) == 1
|
||||||
@ -174,7 +174,7 @@ class TestSpanHandlerWrapper:
|
|||||||
def test_func():
|
def test_func():
|
||||||
return "result"
|
return "result"
|
||||||
|
|
||||||
handler.wrapper(tracer, test_func, (), {})
|
handler.wrapper(tracer, test_func)
|
||||||
|
|
||||||
spans = memory_span_exporter.get_finished_spans()
|
spans = memory_span_exporter.get_finished_spans()
|
||||||
assert len(spans) == 1
|
assert len(spans) == 1
|
||||||
@ -190,7 +190,7 @@ class TestSpanHandlerWrapper:
|
|||||||
raise ValueError("test error")
|
raise ValueError("test error")
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="test error"):
|
with pytest.raises(ValueError, match="test error"):
|
||||||
handler.wrapper(tracer, test_func, (), {})
|
handler.wrapper(tracer, test_func)
|
||||||
|
|
||||||
spans = memory_span_exporter.get_finished_spans()
|
spans = memory_span_exporter.get_finished_spans()
|
||||||
assert len(spans) == 1
|
assert len(spans) == 1
|
||||||
@ -208,7 +208,7 @@ class TestSpanHandlerWrapper:
|
|||||||
raise ValueError("test error")
|
raise ValueError("test error")
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
handler.wrapper(tracer, test_func, (), {})
|
handler.wrapper(tracer, test_func)
|
||||||
|
|
||||||
spans = memory_span_exporter.get_finished_spans()
|
spans = memory_span_exporter.get_finished_spans()
|
||||||
assert len(spans) == 1
|
assert len(spans) == 1
|
||||||
@ -225,7 +225,7 @@ class TestSpanHandlerWrapper:
|
|||||||
raise ValueError("test error")
|
raise ValueError("test error")
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="test error"):
|
with pytest.raises(ValueError, match="test error"):
|
||||||
handler.wrapper(tracer, test_func, (), {})
|
handler.wrapper(tracer, test_func)
|
||||||
|
|
||||||
@patch("extensions.otel.decorators.base.dify_config.ENABLE_OTEL", True)
|
@patch("extensions.otel.decorators.base.dify_config.ENABLE_OTEL", True)
|
||||||
def test_wrapper_passes_arguments_correctly(self, tracer_provider_with_memory_exporter, memory_span_exporter):
|
def test_wrapper_passes_arguments_correctly(self, tracer_provider_with_memory_exporter, memory_span_exporter):
|
||||||
@ -236,7 +236,7 @@ class TestSpanHandlerWrapper:
|
|||||||
def test_func(a, b, c=10):
|
def test_func(a, b, c=10):
|
||||||
return a + b + c
|
return a + b + c
|
||||||
|
|
||||||
result = handler.wrapper(tracer, test_func, (1, 2), {"c": 3})
|
result = handler.wrapper(tracer, test_func, 1, 2, c=3)
|
||||||
|
|
||||||
assert result == 6
|
assert result == 6
|
||||||
|
|
||||||
@ -249,7 +249,7 @@ class TestSpanHandlerWrapper:
|
|||||||
def my_function(x):
|
def my_function(x):
|
||||||
return x * 2
|
return x * 2
|
||||||
|
|
||||||
result = handler.wrapper(tracer, my_function, (5,), {})
|
result = handler.wrapper(tracer, my_function, 5)
|
||||||
|
|
||||||
assert result == 10
|
assert result == 10
|
||||||
spans = memory_span_exporter.get_finished_spans()
|
spans = memory_span_exporter.get_finished_spans()
|
||||||
|
|||||||
138
api/tests/unit_tests/libs/test_pyrefly_type_coverage.py
Normal file
138
api/tests/unit_tests/libs/test_pyrefly_type_coverage.py
Normal file
@ -0,0 +1,138 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
from libs.pyrefly_type_coverage import (
|
||||||
|
CoverageSummary,
|
||||||
|
format_comparison_markdown,
|
||||||
|
format_summary_markdown,
|
||||||
|
parse_summary,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_report(summary: dict) -> str:
|
||||||
|
return json.dumps({"module_reports": [], "summary": summary})
|
||||||
|
|
||||||
|
|
||||||
|
_SAMPLE_SUMMARY: dict = {
|
||||||
|
"n_modules": 100,
|
||||||
|
"n_typable": 1000,
|
||||||
|
"n_typed": 400,
|
||||||
|
"n_any": 50,
|
||||||
|
"n_untyped": 550,
|
||||||
|
"coverage": 45.0,
|
||||||
|
"strict_coverage": 40.0,
|
||||||
|
"n_functions": 200,
|
||||||
|
"n_methods": 300,
|
||||||
|
"n_function_params": 150,
|
||||||
|
"n_method_params": 250,
|
||||||
|
"n_classes": 80,
|
||||||
|
"n_attrs": 40,
|
||||||
|
"n_properties": 20,
|
||||||
|
"n_type_ignores": 10,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_summary(
|
||||||
|
*,
|
||||||
|
n_modules: int = 100,
|
||||||
|
n_typable: int = 1000,
|
||||||
|
n_typed: int = 400,
|
||||||
|
n_any: int = 50,
|
||||||
|
n_untyped: int = 550,
|
||||||
|
coverage: float = 45.0,
|
||||||
|
strict_coverage: float = 40.0,
|
||||||
|
) -> CoverageSummary:
|
||||||
|
return {
|
||||||
|
"n_modules": n_modules,
|
||||||
|
"n_typable": n_typable,
|
||||||
|
"n_typed": n_typed,
|
||||||
|
"n_any": n_any,
|
||||||
|
"n_untyped": n_untyped,
|
||||||
|
"coverage": coverage,
|
||||||
|
"strict_coverage": strict_coverage,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_summary_extracts_fields() -> None:
|
||||||
|
report_json = _make_report(_SAMPLE_SUMMARY)
|
||||||
|
|
||||||
|
result = parse_summary(report_json)
|
||||||
|
|
||||||
|
assert result["n_modules"] == 100
|
||||||
|
assert result["n_typable"] == 1000
|
||||||
|
assert result["n_typed"] == 400
|
||||||
|
assert result["n_any"] == 50
|
||||||
|
assert result["n_untyped"] == 550
|
||||||
|
assert result["coverage"] == 45.0
|
||||||
|
assert result["strict_coverage"] == 40.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_summary_handles_empty_input() -> None:
|
||||||
|
assert parse_summary("")["n_modules"] == 0
|
||||||
|
assert parse_summary(" ")["n_modules"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_summary_handles_invalid_json() -> None:
|
||||||
|
assert parse_summary("not json")["n_modules"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_summary_handles_missing_summary_key() -> None:
|
||||||
|
assert parse_summary(json.dumps({"other": 1}))["n_modules"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_summary_handles_incomplete_summary() -> None:
|
||||||
|
partial = json.dumps({"summary": {"n_modules": 5}})
|
||||||
|
assert parse_summary(partial)["n_modules"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_summary_markdown_contains_key_metrics() -> None:
|
||||||
|
summary = _make_summary()
|
||||||
|
|
||||||
|
result = format_summary_markdown(summary)
|
||||||
|
|
||||||
|
assert "**Type coverage**" in result
|
||||||
|
assert "45.00%" in result
|
||||||
|
assert "40.00%" in result
|
||||||
|
assert "| Modules | 100 |" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_comparison_markdown_shows_positive_delta() -> None:
|
||||||
|
base = _make_summary()
|
||||||
|
pr = _make_summary(
|
||||||
|
n_modules=101,
|
||||||
|
n_typable=1010,
|
||||||
|
n_typed=420,
|
||||||
|
n_untyped=540,
|
||||||
|
coverage=46.53,
|
||||||
|
strict_coverage=41.58,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = format_comparison_markdown(base, pr)
|
||||||
|
|
||||||
|
assert "| Base | PR | Delta |" in result
|
||||||
|
assert "+1.53%" in result
|
||||||
|
assert "+1.58%" in result
|
||||||
|
assert "+20" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_comparison_markdown_shows_negative_delta() -> None:
|
||||||
|
base = _make_summary()
|
||||||
|
pr = _make_summary(
|
||||||
|
n_typed=390,
|
||||||
|
n_any=60,
|
||||||
|
coverage=44.0,
|
||||||
|
strict_coverage=39.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = format_comparison_markdown(base, pr)
|
||||||
|
|
||||||
|
assert "-1.00%" in result
|
||||||
|
assert "-10" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_comparison_markdown_shows_zero_delta() -> None:
|
||||||
|
summary = _make_summary()
|
||||||
|
|
||||||
|
result = format_comparison_markdown(summary, summary)
|
||||||
|
|
||||||
|
assert "0.00%" in result
|
||||||
|
assert "| 0 |" in result
|
||||||
@ -396,10 +396,11 @@ class TestExternalDatasetServiceUsageAndBindings:
|
|||||||
|
|
||||||
mock_db_session.scalar.return_value = 3
|
mock_db_session.scalar.return_value = 3
|
||||||
|
|
||||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1")
|
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1", "tenant-1")
|
||||||
|
|
||||||
assert in_use is True
|
assert in_use is True
|
||||||
assert count == 3
|
assert count == 3
|
||||||
|
assert "tenant_id" in str(mock_db_session.scalar.call_args.args[0])
|
||||||
|
|
||||||
def test_external_knowledge_api_use_check_not_in_use(self, mock_db_session: MagicMock):
|
def test_external_knowledge_api_use_check_not_in_use(self, mock_db_session: MagicMock):
|
||||||
"""
|
"""
|
||||||
@ -408,7 +409,7 @@ class TestExternalDatasetServiceUsageAndBindings:
|
|||||||
|
|
||||||
mock_db_session.scalar.return_value = 0
|
mock_db_session.scalar.return_value = 0
|
||||||
|
|
||||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1")
|
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1", "tenant-1")
|
||||||
|
|
||||||
assert in_use is False
|
assert in_use is False
|
||||||
assert count == 0
|
assert count == 0
|
||||||
|
|||||||
@ -6,23 +6,25 @@ MODULE = "services.plugin.plugin_permission_service"
|
|||||||
|
|
||||||
|
|
||||||
def _patched_session():
|
def _patched_session():
|
||||||
"""Patch sessionmaker(bind=db.engine).begin() to return a mock session as context manager."""
|
"""Patch session_factory.create_session() to return a mock session as context manager."""
|
||||||
session = MagicMock()
|
session = MagicMock()
|
||||||
mock_sessionmaker = MagicMock()
|
session.__enter__ = MagicMock(return_value=session)
|
||||||
mock_sessionmaker.return_value.begin.return_value.__enter__ = MagicMock(return_value=session)
|
session.__exit__ = MagicMock(return_value=False)
|
||||||
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
session.begin.return_value.__enter__ = MagicMock(return_value=session)
|
||||||
patcher = patch(f"{MODULE}.sessionmaker", mock_sessionmaker)
|
session.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||||
db_patcher = patch(f"{MODULE}.db")
|
mock_factory = MagicMock()
|
||||||
return patcher, db_patcher, session
|
mock_factory.create_session.return_value = session
|
||||||
|
patcher = patch(f"{MODULE}.session_factory", mock_factory)
|
||||||
|
return patcher, session
|
||||||
|
|
||||||
|
|
||||||
class TestGetPermission:
|
class TestGetPermission:
|
||||||
def test_returns_permission_when_found(self):
|
def test_returns_permission_when_found(self):
|
||||||
p1, p2, session = _patched_session()
|
p1, session = _patched_session()
|
||||||
permission = MagicMock()
|
permission = MagicMock()
|
||||||
session.scalar.return_value = permission
|
session.scalar.return_value = permission
|
||||||
|
|
||||||
with p1, p2:
|
with p1:
|
||||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||||
|
|
||||||
result = PluginPermissionService.get_permission("t1")
|
result = PluginPermissionService.get_permission("t1")
|
||||||
@ -30,10 +32,10 @@ class TestGetPermission:
|
|||||||
assert result is permission
|
assert result is permission
|
||||||
|
|
||||||
def test_returns_none_when_not_found(self):
|
def test_returns_none_when_not_found(self):
|
||||||
p1, p2, session = _patched_session()
|
p1, session = _patched_session()
|
||||||
session.scalar.return_value = None
|
session.scalar.return_value = None
|
||||||
|
|
||||||
with p1, p2:
|
with p1:
|
||||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||||
|
|
||||||
result = PluginPermissionService.get_permission("t1")
|
result = PluginPermissionService.get_permission("t1")
|
||||||
@ -43,10 +45,10 @@ class TestGetPermission:
|
|||||||
|
|
||||||
class TestChangePermission:
|
class TestChangePermission:
|
||||||
def test_creates_new_permission_when_not_exists(self):
|
def test_creates_new_permission_when_not_exists(self):
|
||||||
p1, p2, session = _patched_session()
|
p1, session = _patched_session()
|
||||||
session.scalar.return_value = None
|
session.scalar.return_value = None
|
||||||
|
|
||||||
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginPermission") as perm_cls:
|
with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginPermission") as perm_cls:
|
||||||
perm_cls.return_value = MagicMock()
|
perm_cls.return_value = MagicMock()
|
||||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||||
|
|
||||||
@ -54,20 +56,24 @@ class TestChangePermission:
|
|||||||
"t1", TenantPluginPermission.InstallPermission.EVERYONE, TenantPluginPermission.DebugPermission.EVERYONE
|
"t1", TenantPluginPermission.InstallPermission.EVERYONE, TenantPluginPermission.DebugPermission.EVERYONE
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
session.begin.assert_called_once()
|
||||||
session.add.assert_called_once()
|
session.add.assert_called_once()
|
||||||
|
|
||||||
def test_updates_existing_permission(self):
|
def test_updates_existing_permission(self):
|
||||||
p1, p2, session = _patched_session()
|
p1, session = _patched_session()
|
||||||
existing = MagicMock()
|
existing = MagicMock()
|
||||||
session.scalar.return_value = existing
|
session.scalar.return_value = existing
|
||||||
|
|
||||||
with p1, p2:
|
with p1:
|
||||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||||
|
|
||||||
result = PluginPermissionService.change_permission(
|
result = PluginPermissionService.change_permission(
|
||||||
"t1", TenantPluginPermission.InstallPermission.ADMINS, TenantPluginPermission.DebugPermission.ADMINS
|
"t1", TenantPluginPermission.InstallPermission.ADMINS, TenantPluginPermission.DebugPermission.ADMINS
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
session.begin.assert_called_once()
|
||||||
assert existing.install_permission == TenantPluginPermission.InstallPermission.ADMINS
|
assert existing.install_permission == TenantPluginPermission.InstallPermission.ADMINS
|
||||||
assert existing.debug_permission == TenantPluginPermission.DebugPermission.ADMINS
|
assert existing.debug_permission == TenantPluginPermission.DebugPermission.ADMINS
|
||||||
session.add.assert_not_called()
|
session.add.assert_not_called()
|
||||||
|
|||||||
@ -1427,16 +1427,7 @@ class TestRegisterService:
|
|||||||
mock_tenant.name = "Test Workspace"
|
mock_tenant.name = "Test Workspace"
|
||||||
mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter")
|
mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter")
|
||||||
|
|
||||||
# Mock database queries - need to mock the sessionmaker query
|
|
||||||
mock_session = MagicMock()
|
|
||||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None # No existing account
|
|
||||||
|
|
||||||
mock_sessionmaker = MagicMock()
|
|
||||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
|
||||||
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
|
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("services.account_service.sessionmaker", mock_sessionmaker),
|
|
||||||
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
||||||
):
|
):
|
||||||
mock_lookup.return_value = None
|
mock_lookup.return_value = None
|
||||||
@ -1475,7 +1466,7 @@ class TestRegisterService:
|
|||||||
status=AccountStatus.PENDING,
|
status=AccountStatus.PENDING,
|
||||||
is_setup=True,
|
is_setup=True,
|
||||||
)
|
)
|
||||||
mock_lookup.assert_called_once_with("newuser@example.com", session=mock_session)
|
mock_lookup.assert_called_once_with("newuser@example.com")
|
||||||
|
|
||||||
def test_invite_new_member_normalizes_new_account_email(
|
def test_invite_new_member_normalizes_new_account_email(
|
||||||
self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies
|
self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies
|
||||||
@ -1486,13 +1477,7 @@ class TestRegisterService:
|
|||||||
mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter")
|
mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter")
|
||||||
mixed_email = "Invitee@Example.com"
|
mixed_email = "Invitee@Example.com"
|
||||||
|
|
||||||
mock_session = MagicMock()
|
|
||||||
mock_sessionmaker = MagicMock()
|
|
||||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
|
||||||
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
|
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("services.account_service.sessionmaker", mock_sessionmaker),
|
|
||||||
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
||||||
):
|
):
|
||||||
mock_lookup.return_value = None
|
mock_lookup.return_value = None
|
||||||
@ -1525,7 +1510,7 @@ class TestRegisterService:
|
|||||||
status=AccountStatus.PENDING,
|
status=AccountStatus.PENDING,
|
||||||
is_setup=True,
|
is_setup=True,
|
||||||
)
|
)
|
||||||
mock_lookup.assert_called_once_with(mixed_email, session=mock_session)
|
mock_lookup.assert_called_once_with(mixed_email)
|
||||||
mock_check_permission.assert_called_once_with(mock_tenant, mock_inviter, None, "add")
|
mock_check_permission.assert_called_once_with(mock_tenant, mock_inviter, None, "add")
|
||||||
mock_create_member.assert_called_once_with(mock_tenant, mock_new_account, "normal")
|
mock_create_member.assert_called_once_with(mock_tenant, mock_new_account, "normal")
|
||||||
mock_switch_tenant.assert_called_once_with(mock_new_account, mock_tenant.id)
|
mock_switch_tenant.assert_called_once_with(mock_new_account, mock_tenant.id)
|
||||||
@ -1545,16 +1530,7 @@ class TestRegisterService:
|
|||||||
account_id="existing-user-456", email="existing@example.com", status="pending"
|
account_id="existing-user-456", email="existing@example.com", status="pending"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock database queries - need to mock the sessionmaker query
|
|
||||||
mock_session = MagicMock()
|
|
||||||
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_existing_account
|
|
||||||
|
|
||||||
mock_sessionmaker = MagicMock()
|
|
||||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
|
||||||
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
|
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("services.account_service.sessionmaker", mock_sessionmaker),
|
|
||||||
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
||||||
):
|
):
|
||||||
mock_lookup.return_value = mock_existing_account
|
mock_lookup.return_value = mock_existing_account
|
||||||
@ -1584,7 +1560,7 @@ class TestRegisterService:
|
|||||||
mock_create_member.assert_called_once_with(mock_tenant, mock_existing_account, "normal")
|
mock_create_member.assert_called_once_with(mock_tenant, mock_existing_account, "normal")
|
||||||
mock_generate_token.assert_called_once_with(mock_tenant, mock_existing_account)
|
mock_generate_token.assert_called_once_with(mock_tenant, mock_existing_account)
|
||||||
mock_task_dependencies.delay.assert_called_once()
|
mock_task_dependencies.delay.assert_called_once()
|
||||||
mock_lookup.assert_called_once_with("existing@example.com", session=mock_session)
|
mock_lookup.assert_called_once_with("existing@example.com")
|
||||||
|
|
||||||
def test_invite_new_member_already_in_tenant(self, mock_db_dependencies, mock_redis_dependencies):
|
def test_invite_new_member_already_in_tenant(self, mock_db_dependencies, mock_redis_dependencies):
|
||||||
"""Test inviting a member who is already in the tenant."""
|
"""Test inviting a member who is already in the tenant."""
|
||||||
|
|||||||
@ -1069,6 +1069,33 @@ class TestDocumentServiceCreateValidation:
|
|||||||
assert len(knowledge_config.process_rule.rules.pre_processing_rules) == 1
|
assert len(knowledge_config.process_rule.rules.pre_processing_rules) == 1
|
||||||
assert knowledge_config.process_rule.rules.pre_processing_rules[0].enabled is False
|
assert knowledge_config.process_rule.rules.pre_processing_rules[0].enabled is False
|
||||||
|
|
||||||
|
def test_process_rule_args_validate_hierarchical_defaults_parent_mode_to_paragraph(self):
|
||||||
|
knowledge_config = KnowledgeConfig(
|
||||||
|
indexing_technique="economy",
|
||||||
|
data_source=DataSource(
|
||||||
|
info_list=InfoList(
|
||||||
|
data_source_type="upload_file",
|
||||||
|
file_info_list=FileInfo(file_ids=["file-1"]),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
process_rule=ProcessRule(
|
||||||
|
mode="hierarchical",
|
||||||
|
rules=Rule(
|
||||||
|
pre_processing_rules=[
|
||||||
|
PreProcessingRule(id="remove_extra_spaces", enabled=True),
|
||||||
|
],
|
||||||
|
segmentation=Segmentation(separator="\n", max_tokens=1024),
|
||||||
|
subchunk_segmentation=Segmentation(separator="\n", max_tokens=512),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
DocumentService.process_rule_args_validate(knowledge_config)
|
||||||
|
|
||||||
|
assert knowledge_config.process_rule is not None
|
||||||
|
assert knowledge_config.process_rule.rules is not None
|
||||||
|
assert knowledge_config.process_rule.rules.parent_mode == "paragraph"
|
||||||
|
|
||||||
|
|
||||||
class TestDocumentServiceSaveDocumentWithDatasetId:
|
class TestDocumentServiceSaveDocumentWithDatasetId:
|
||||||
"""Unit tests for non-SQL validation branches in save_document_with_dataset_id."""
|
"""Unit tests for non-SQL validation branches in save_document_with_dataset_id."""
|
||||||
|
|||||||
@ -974,26 +974,29 @@ class TestExternalDatasetServiceAPIUseCheck:
|
|||||||
"""Test API use check when API has one binding."""
|
"""Test API use check when API has one binding."""
|
||||||
# Arrange
|
# Arrange
|
||||||
api_id = "api-123"
|
api_id = "api-123"
|
||||||
|
tenant_id = "tenant-123"
|
||||||
|
|
||||||
mock_db.session.scalar.return_value = 1
|
mock_db.session.scalar.return_value = 1
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
|
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id, tenant_id)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert in_use is True
|
assert in_use is True
|
||||||
assert count == 1
|
assert count == 1
|
||||||
|
assert "tenant_id" in str(mock_db.session.scalar.call_args.args[0])
|
||||||
|
|
||||||
@patch("services.external_knowledge_service.db")
|
@patch("services.external_knowledge_service.db")
|
||||||
def test_external_knowledge_api_use_check_in_use_multiple(self, mock_db, factory):
|
def test_external_knowledge_api_use_check_in_use_multiple(self, mock_db, factory):
|
||||||
"""Test API use check with multiple bindings."""
|
"""Test API use check with multiple bindings."""
|
||||||
# Arrange
|
# Arrange
|
||||||
api_id = "api-123"
|
api_id = "api-123"
|
||||||
|
tenant_id = "tenant-123"
|
||||||
|
|
||||||
mock_db.session.scalar.return_value = 10
|
mock_db.session.scalar.return_value = 10
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
|
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id, tenant_id)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert in_use is True
|
assert in_use is True
|
||||||
@ -1004,11 +1007,12 @@ class TestExternalDatasetServiceAPIUseCheck:
|
|||||||
"""Test API use check when API is not in use."""
|
"""Test API use check when API is not in use."""
|
||||||
# Arrange
|
# Arrange
|
||||||
api_id = "api-123"
|
api_id = "api-123"
|
||||||
|
tenant_id = "tenant-123"
|
||||||
|
|
||||||
mock_db.session.scalar.return_value = 0
|
mock_db.session.scalar.return_value = 0
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
|
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id, tenant_id)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert in_use is False
|
assert in_use is False
|
||||||
|
|||||||
@ -1,392 +0,0 @@
|
|||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from core.ops.entities.config_entity import TracingProviderEnum
|
|
||||||
from models.model import App, TraceAppConfig
|
|
||||||
from services.ops_service import OpsService
|
|
||||||
|
|
||||||
|
|
||||||
class TestOpsService:
|
|
||||||
@patch("services.ops_service.db")
|
|
||||||
@patch("services.ops_service.OpsTraceManager")
|
|
||||||
def test_get_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db):
|
|
||||||
# Arrange
|
|
||||||
mock_db.session.scalar.return_value = None
|
|
||||||
|
|
||||||
# Act
|
|
||||||
result = OpsService.get_tracing_app_config("app_id", "arize")
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
@patch("services.ops_service.db")
|
|
||||||
@patch("services.ops_service.OpsTraceManager")
|
|
||||||
def test_get_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db):
|
|
||||||
# Arrange
|
|
||||||
trace_config = MagicMock(spec=TraceAppConfig)
|
|
||||||
mock_db.session.scalar.return_value = trace_config
|
|
||||||
mock_db.session.get.return_value = None
|
|
||||||
|
|
||||||
# Act
|
|
||||||
result = OpsService.get_tracing_app_config("app_id", "arize")
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
@patch("services.ops_service.db")
|
|
||||||
@patch("services.ops_service.OpsTraceManager")
|
|
||||||
def test_get_tracing_app_config_none_config(self, mock_ops_trace_manager, mock_db):
|
|
||||||
# Arrange
|
|
||||||
trace_config = MagicMock(spec=TraceAppConfig)
|
|
||||||
trace_config.tracing_config = None
|
|
||||||
app = MagicMock(spec=App)
|
|
||||||
app.tenant_id = "tenant_id"
|
|
||||||
mock_db.session.scalar.return_value = trace_config
|
|
||||||
mock_db.session.get.return_value = app
|
|
||||||
|
|
||||||
# Act & Assert
|
|
||||||
with pytest.raises(ValueError, match="Tracing config cannot be None."):
|
|
||||||
OpsService.get_tracing_app_config("app_id", "arize")
|
|
||||||
|
|
||||||
@patch("services.ops_service.db")
|
|
||||||
@patch("services.ops_service.OpsTraceManager")
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
("provider", "default_url"),
|
|
||||||
[
|
|
||||||
("arize", "https://app.arize.com/"),
|
|
||||||
("phoenix", "https://app.phoenix.arize.com/projects/"),
|
|
||||||
("langsmith", "https://smith.langchain.com/"),
|
|
||||||
("opik", "https://www.comet.com/opik/"),
|
|
||||||
("weave", "https://wandb.ai/"),
|
|
||||||
("aliyun", "https://arms.console.aliyun.com/"),
|
|
||||||
("tencent", "https://console.cloud.tencent.com/apm"),
|
|
||||||
("mlflow", "http://localhost:5000/"),
|
|
||||||
("databricks", "https://www.databricks.com/"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_get_tracing_app_config_providers_exception(self, mock_ops_trace_manager, mock_db, provider, default_url):
|
|
||||||
# Arrange
|
|
||||||
trace_config = MagicMock(spec=TraceAppConfig)
|
|
||||||
trace_config.tracing_config = {"some": "config"}
|
|
||||||
trace_config.to_dict.return_value = {"tracing_config": {"project_url": default_url}}
|
|
||||||
app = MagicMock(spec=App)
|
|
||||||
app.tenant_id = "tenant_id"
|
|
||||||
mock_db.session.scalar.return_value = trace_config
|
|
||||||
mock_db.session.get.return_value = app
|
|
||||||
|
|
||||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
|
|
||||||
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {}
|
|
||||||
mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error")
|
|
||||||
mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error")
|
|
||||||
|
|
||||||
# Act
|
|
||||||
result = OpsService.get_tracing_app_config("app_id", provider)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert result["tracing_config"]["project_url"] == default_url
|
|
||||||
|
|
||||||
@patch("services.ops_service.db")
|
|
||||||
@patch("services.ops_service.OpsTraceManager")
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"provider", ["arize", "phoenix", "langsmith", "opik", "weave", "aliyun", "tencent", "mlflow", "databricks"]
|
|
||||||
)
|
|
||||||
def test_get_tracing_app_config_providers_success(self, mock_ops_trace_manager, mock_db, provider):
|
|
||||||
# Arrange
|
|
||||||
trace_config = MagicMock(spec=TraceAppConfig)
|
|
||||||
trace_config.tracing_config = {"some": "config"}
|
|
||||||
trace_config.to_dict.return_value = {"tracing_config": {"project_url": "success_url"}}
|
|
||||||
app = MagicMock(spec=App)
|
|
||||||
app.tenant_id = "tenant_id"
|
|
||||||
mock_db.session.scalar.return_value = trace_config
|
|
||||||
mock_db.session.get.return_value = app
|
|
||||||
|
|
||||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
|
|
||||||
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {}
|
|
||||||
mock_ops_trace_manager.get_trace_config_project_url.return_value = "success_url"
|
|
||||||
|
|
||||||
# Act
|
|
||||||
result = OpsService.get_tracing_app_config("app_id", provider)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert result["tracing_config"]["project_url"] == "success_url"
|
|
||||||
|
|
||||||
@patch("services.ops_service.db")
|
|
||||||
@patch("services.ops_service.OpsTraceManager")
|
|
||||||
def test_get_tracing_app_config_langfuse_success(self, mock_ops_trace_manager, mock_db):
|
|
||||||
# Arrange
|
|
||||||
trace_config = MagicMock(spec=TraceAppConfig)
|
|
||||||
trace_config.tracing_config = {"some": "config"}
|
|
||||||
trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/project/key"}}
|
|
||||||
app = MagicMock(spec=App)
|
|
||||||
app.tenant_id = "tenant_id"
|
|
||||||
mock_db.session.scalar.return_value = trace_config
|
|
||||||
mock_db.session.get.return_value = app
|
|
||||||
|
|
||||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"}
|
|
||||||
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"}
|
|
||||||
mock_ops_trace_manager.get_trace_config_project_key.return_value = "key"
|
|
||||||
|
|
||||||
# Act
|
|
||||||
result = OpsService.get_tracing_app_config("app_id", "langfuse")
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/project/key"
|
|
||||||
|
|
||||||
@patch("services.ops_service.db")
|
|
||||||
@patch("services.ops_service.OpsTraceManager")
|
|
||||||
def test_get_tracing_app_config_langfuse_exception(self, mock_ops_trace_manager, mock_db):
|
|
||||||
# Arrange
|
|
||||||
trace_config = MagicMock(spec=TraceAppConfig)
|
|
||||||
trace_config.tracing_config = {"some": "config"}
|
|
||||||
trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/"}}
|
|
||||||
app = MagicMock(spec=App)
|
|
||||||
app.tenant_id = "tenant_id"
|
|
||||||
mock_db.session.scalar.return_value = trace_config
|
|
||||||
mock_db.session.get.return_value = app
|
|
||||||
|
|
||||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"}
|
|
||||||
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"}
|
|
||||||
mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error")
|
|
||||||
|
|
||||||
# Act
|
|
||||||
result = OpsService.get_tracing_app_config("app_id", "langfuse")
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/"
|
|
||||||
|
|
||||||
@patch("services.ops_service.db")
|
|
||||||
@patch("services.ops_service.OpsTraceManager")
|
|
||||||
def test_create_tracing_app_config_invalid_provider(self, mock_ops_trace_manager, mock_db):
|
|
||||||
# Act
|
|
||||||
result = OpsService.create_tracing_app_config("app_id", "invalid_provider", {})
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert result == {"error": "Invalid tracing provider: invalid_provider"}
|
|
||||||
|
|
||||||
@patch("services.ops_service.db")
|
|
||||||
@patch("services.ops_service.OpsTraceManager")
|
|
||||||
def test_create_tracing_app_config_invalid_credentials(self, mock_ops_trace_manager, mock_db):
|
|
||||||
# Arrange
|
|
||||||
provider = TracingProviderEnum.LANGFUSE
|
|
||||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = False
|
|
||||||
|
|
||||||
# Act
|
|
||||||
result = OpsService.create_tracing_app_config("app_id", provider, {"public_key": "p", "secret_key": "s"})
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert result == {"error": "Invalid Credentials"}
|
|
||||||
|
|
||||||
@patch("services.ops_service.db")
|
|
||||||
@patch("services.ops_service.OpsTraceManager")
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
("provider", "config"),
|
|
||||||
[
|
|
||||||
(TracingProviderEnum.ARIZE, {}),
|
|
||||||
(TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}),
|
|
||||||
(TracingProviderEnum.LANGSMITH, {"api_key": "k", "project": "p"}),
|
|
||||||
(TracingProviderEnum.ALIYUN, {"license_key": "k", "endpoint": "https://aliyun.com"}),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_create_tracing_app_config_project_url_exception(self, mock_ops_trace_manager, mock_db, provider, config):
|
|
||||||
# Arrange
|
|
||||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
|
||||||
mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error")
|
|
||||||
mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error")
|
|
||||||
mock_db.session.scalar.return_value = MagicMock(spec=TraceAppConfig)
|
|
||||||
|
|
||||||
# Act
|
|
||||||
result = OpsService.create_tracing_app_config("app_id", provider, config)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
@patch("services.ops_service.db")
|
|
||||||
@patch("services.ops_service.OpsTraceManager")
|
|
||||||
def test_create_tracing_app_config_langfuse_success(self, mock_ops_trace_manager, mock_db):
|
|
||||||
# Arrange
|
|
||||||
provider = TracingProviderEnum.LANGFUSE
|
|
||||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
|
||||||
mock_ops_trace_manager.get_trace_config_project_key.return_value = "key"
|
|
||||||
app = MagicMock(spec=App)
|
|
||||||
app.tenant_id = "tenant_id"
|
|
||||||
mock_db.session.scalar.return_value = None
|
|
||||||
mock_db.session.get.return_value = app
|
|
||||||
mock_ops_trace_manager.encrypt_tracing_config.return_value = {}
|
|
||||||
|
|
||||||
# Act
|
|
||||||
result = OpsService.create_tracing_app_config(
|
|
||||||
"app_id", provider, {"public_key": "p", "secret_key": "s", "host": "https://api.langfuse.com"}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert result == {"result": "success"}
|
|
||||||
|
|
||||||
@patch("services.ops_service.db")
|
|
||||||
@patch("services.ops_service.OpsTraceManager")
|
|
||||||
def test_create_tracing_app_config_already_exists(self, mock_ops_trace_manager, mock_db):
|
|
||||||
# Arrange
|
|
||||||
provider = TracingProviderEnum.ARIZE
|
|
||||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
|
||||||
mock_db.session.scalar.return_value = MagicMock(spec=TraceAppConfig)
|
|
||||||
|
|
||||||
# Act
|
|
||||||
result = OpsService.create_tracing_app_config("app_id", provider, {})
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
@patch("services.ops_service.db")
|
|
||||||
@patch("services.ops_service.OpsTraceManager")
|
|
||||||
def test_create_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db):
|
|
||||||
# Arrange
|
|
||||||
provider = TracingProviderEnum.ARIZE
|
|
||||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
|
||||||
mock_db.session.scalar.return_value = None
|
|
||||||
mock_db.session.get.return_value = None
|
|
||||||
|
|
||||||
# Act
|
|
||||||
result = OpsService.create_tracing_app_config("app_id", provider, {})
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
@patch("services.ops_service.db")
|
|
||||||
@patch("services.ops_service.OpsTraceManager")
|
|
||||||
def test_create_tracing_app_config_with_empty_other_keys(self, mock_ops_trace_manager, mock_db):
|
|
||||||
# Arrange
|
|
||||||
provider = TracingProviderEnum.ARIZE
|
|
||||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
|
||||||
app = MagicMock(spec=App)
|
|
||||||
app.tenant_id = "tenant_id"
|
|
||||||
mock_db.session.scalar.return_value = None
|
|
||||||
mock_db.session.get.return_value = app
|
|
||||||
mock_ops_trace_manager.encrypt_tracing_config.return_value = {}
|
|
||||||
|
|
||||||
# Act
|
|
||||||
# 'project' is in other_keys for Arize
|
|
||||||
# provide an empty string for the project in the tracing_config
|
|
||||||
# create_tracing_app_config will replace it with the default from the model
|
|
||||||
result = OpsService.create_tracing_app_config("app_id", provider, {"project": ""})
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert result == {"result": "success"}
|
|
||||||
|
|
||||||
@patch("services.ops_service.db")
|
|
||||||
@patch("services.ops_service.OpsTraceManager")
|
|
||||||
def test_create_tracing_app_config_success(self, mock_ops_trace_manager, mock_db):
|
|
||||||
# Arrange
|
|
||||||
provider = TracingProviderEnum.ARIZE
|
|
||||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
|
||||||
mock_ops_trace_manager.get_trace_config_project_url.return_value = "http://project_url"
|
|
||||||
app = MagicMock(spec=App)
|
|
||||||
app.tenant_id = "tenant_id"
|
|
||||||
mock_db.session.scalar.return_value = None
|
|
||||||
mock_db.session.get.return_value = app
|
|
||||||
mock_ops_trace_manager.encrypt_tracing_config.return_value = {"encrypted": "config"}
|
|
||||||
|
|
||||||
# Act
|
|
||||||
result = OpsService.create_tracing_app_config("app_id", provider, {})
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert result == {"result": "success"}
|
|
||||||
mock_db.session.add.assert_called()
|
|
||||||
mock_db.session.commit.assert_called()
|
|
||||||
|
|
||||||
@patch("services.ops_service.db")
|
|
||||||
@patch("services.ops_service.OpsTraceManager")
|
|
||||||
def test_update_tracing_app_config_invalid_provider(self, mock_ops_trace_manager, mock_db):
|
|
||||||
# Act & Assert
|
|
||||||
with pytest.raises(ValueError, match="Invalid tracing provider: invalid_provider"):
|
|
||||||
OpsService.update_tracing_app_config("app_id", "invalid_provider", {})
|
|
||||||
|
|
||||||
@patch("services.ops_service.db")
|
|
||||||
@patch("services.ops_service.OpsTraceManager")
|
|
||||||
def test_update_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db):
|
|
||||||
# Arrange
|
|
||||||
provider = TracingProviderEnum.ARIZE
|
|
||||||
mock_db.session.scalar.return_value = None
|
|
||||||
|
|
||||||
# Act
|
|
||||||
result = OpsService.update_tracing_app_config("app_id", provider, {})
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
@patch("services.ops_service.db")
|
|
||||||
@patch("services.ops_service.OpsTraceManager")
|
|
||||||
def test_update_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db):
|
|
||||||
# Arrange
|
|
||||||
provider = TracingProviderEnum.ARIZE
|
|
||||||
current_config = MagicMock(spec=TraceAppConfig)
|
|
||||||
mock_db.session.scalar.return_value = current_config
|
|
||||||
mock_db.session.get.return_value = None
|
|
||||||
|
|
||||||
# Act
|
|
||||||
result = OpsService.update_tracing_app_config("app_id", provider, {})
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
@patch("services.ops_service.db")
|
|
||||||
@patch("services.ops_service.OpsTraceManager")
|
|
||||||
def test_update_tracing_app_config_invalid_credentials(self, mock_ops_trace_manager, mock_db):
|
|
||||||
# Arrange
|
|
||||||
provider = TracingProviderEnum.ARIZE
|
|
||||||
current_config = MagicMock(spec=TraceAppConfig)
|
|
||||||
app = MagicMock(spec=App)
|
|
||||||
app.tenant_id = "tenant_id"
|
|
||||||
mock_db.session.scalar.return_value = current_config
|
|
||||||
mock_db.session.get.return_value = app
|
|
||||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
|
|
||||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = False
|
|
||||||
|
|
||||||
# Act & Assert
|
|
||||||
with pytest.raises(ValueError, match="Invalid Credentials"):
|
|
||||||
OpsService.update_tracing_app_config("app_id", provider, {})
|
|
||||||
|
|
||||||
@patch("services.ops_service.db")
|
|
||||||
@patch("services.ops_service.OpsTraceManager")
|
|
||||||
def test_update_tracing_app_config_success(self, mock_ops_trace_manager, mock_db):
|
|
||||||
# Arrange
|
|
||||||
provider = TracingProviderEnum.ARIZE
|
|
||||||
current_config = MagicMock(spec=TraceAppConfig)
|
|
||||||
current_config.to_dict.return_value = {"some": "data"}
|
|
||||||
app = MagicMock(spec=App)
|
|
||||||
app.tenant_id = "tenant_id"
|
|
||||||
mock_db.session.scalar.return_value = current_config
|
|
||||||
mock_db.session.get.return_value = app
|
|
||||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
|
|
||||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
|
||||||
|
|
||||||
# Act
|
|
||||||
result = OpsService.update_tracing_app_config("app_id", provider, {})
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert result == {"some": "data"}
|
|
||||||
mock_db.session.commit.assert_called_once()
|
|
||||||
|
|
||||||
@patch("services.ops_service.db")
|
|
||||||
def test_delete_tracing_app_config_no_config(self, mock_db):
|
|
||||||
# Arrange
|
|
||||||
mock_db.session.scalar.return_value = None
|
|
||||||
|
|
||||||
# Act
|
|
||||||
result = OpsService.delete_tracing_app_config("app_id", "arize")
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
@patch("services.ops_service.db")
|
|
||||||
def test_delete_tracing_app_config_success(self, mock_db):
|
|
||||||
# Arrange
|
|
||||||
trace_config = MagicMock(spec=TraceAppConfig)
|
|
||||||
mock_db.session.scalar.return_value = trace_config
|
|
||||||
|
|
||||||
# Act
|
|
||||||
result = OpsService.delete_tracing_app_config("app_id", "arize")
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert result is True
|
|
||||||
mock_db.session.delete.assert_called_with(trace_config)
|
|
||||||
mock_db.session.commit.assert_called_once()
|
|
||||||
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import type { Area } from 'react-easy-crop'
|
import type { Area } from 'react-easy-crop'
|
||||||
import type { OnImageInput } from '@/app/components/base/app-icon-picker/ImageInput'
|
import type { OnImageInput } from '@/app/components/base/app-icon-picker/ImageInput'
|
||||||
import type { AvatarProps } from '@/app/components/base/avatar'
|
import type { AvatarProps } from '@/app/components/base/ui/avatar'
|
||||||
import type { ImageFile } from '@/types/app'
|
import type { ImageFile } from '@/types/app'
|
||||||
import { RiDeleteBin5Line, RiPencilLine } from '@remixicon/react'
|
import { RiDeleteBin5Line, RiPencilLine } from '@remixicon/react'
|
||||||
import * as React from 'react'
|
import * as React from 'react'
|
||||||
@ -10,10 +10,10 @@ import { useCallback, useState } from 'react'
|
|||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
import ImageInput from '@/app/components/base/app-icon-picker/ImageInput'
|
import ImageInput from '@/app/components/base/app-icon-picker/ImageInput'
|
||||||
import getCroppedImg from '@/app/components/base/app-icon-picker/utils'
|
import getCroppedImg from '@/app/components/base/app-icon-picker/utils'
|
||||||
import { Avatar } from '@/app/components/base/avatar'
|
|
||||||
import Button from '@/app/components/base/button'
|
import Button from '@/app/components/base/button'
|
||||||
import Divider from '@/app/components/base/divider'
|
import Divider from '@/app/components/base/divider'
|
||||||
import { useLocalFileUploader } from '@/app/components/base/image-uploader/hooks'
|
import { useLocalFileUploader } from '@/app/components/base/image-uploader/hooks'
|
||||||
|
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||||
import { Dialog, DialogContent } from '@/app/components/base/ui/dialog'
|
import { Dialog, DialogContent } from '@/app/components/base/ui/dialog'
|
||||||
import { toast } from '@/app/components/base/ui/toast'
|
import { toast } from '@/app/components/base/ui/toast'
|
||||||
import { DISABLE_UPLOAD_IMAGE_AS_ICON } from '@/config'
|
import { DISABLE_UPLOAD_IMAGE_AS_ICON } from '@/config'
|
||||||
|
|||||||
@ -6,9 +6,9 @@ import {
|
|||||||
import { Fragment } from 'react'
|
import { Fragment } from 'react'
|
||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
import { resetUser } from '@/app/components/base/amplitude/utils'
|
import { resetUser } from '@/app/components/base/amplitude/utils'
|
||||||
import { Avatar } from '@/app/components/base/avatar'
|
|
||||||
import { LogOut01 } from '@/app/components/base/icons/src/vender/line/general'
|
import { LogOut01 } from '@/app/components/base/icons/src/vender/line/general'
|
||||||
import PremiumBadge from '@/app/components/base/premium-badge'
|
import PremiumBadge from '@/app/components/base/premium-badge'
|
||||||
|
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||||
import { useProviderContext } from '@/context/provider-context'
|
import { useProviderContext } from '@/context/provider-context'
|
||||||
import { useRouter } from '@/next/navigation'
|
import { useRouter } from '@/next/navigation'
|
||||||
import { useLogout, useUserProfile } from '@/service/use-common'
|
import { useLogout, useUserProfile } from '@/service/use-common'
|
||||||
|
|||||||
@ -10,9 +10,9 @@ import {
|
|||||||
import * as React from 'react'
|
import * as React from 'react'
|
||||||
import { useEffect, useRef } from 'react'
|
import { useEffect, useRef } from 'react'
|
||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
import { Avatar } from '@/app/components/base/avatar'
|
|
||||||
import Button from '@/app/components/base/button'
|
import Button from '@/app/components/base/button'
|
||||||
import Loading from '@/app/components/base/loading'
|
import Loading from '@/app/components/base/loading'
|
||||||
|
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||||
import { toast } from '@/app/components/base/ui/toast'
|
import { toast } from '@/app/components/base/ui/toast'
|
||||||
import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||||
import { setPostLoginRedirect } from '@/app/signin/utils/post-login-redirect'
|
import { setPostLoginRedirect } from '@/app/signin/utils/post-login-redirect'
|
||||||
|
|||||||
@ -5,12 +5,12 @@ import { RiAddCircleFill, RiArrowRightSLine, RiOrganizationChart } from '@remixi
|
|||||||
import { useDebounce } from 'ahooks'
|
import { useDebounce } from 'ahooks'
|
||||||
import { useCallback, useEffect, useRef, useState } from 'react'
|
import { useCallback, useEffect, useRef, useState } from 'react'
|
||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
|
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||||
import { useSelector } from '@/context/app-context'
|
import { useSelector } from '@/context/app-context'
|
||||||
import { SubjectType } from '@/models/access-control'
|
import { SubjectType } from '@/models/access-control'
|
||||||
import { useSearchForWhiteListCandidates } from '@/service/access-control'
|
import { useSearchForWhiteListCandidates } from '@/service/access-control'
|
||||||
import { cn } from '@/utils/classnames'
|
import { cn } from '@/utils/classnames'
|
||||||
import useAccessControlStore from '../../../../context/access-control-store'
|
import useAccessControlStore from '../../../../context/access-control-store'
|
||||||
import { Avatar } from '../../base/avatar'
|
|
||||||
import Button from '../../base/button'
|
import Button from '../../base/button'
|
||||||
import Checkbox from '../../base/checkbox'
|
import Checkbox from '../../base/checkbox'
|
||||||
import Input from '../../base/input'
|
import Input from '../../base/input'
|
||||||
|
|||||||
@ -3,10 +3,10 @@ import type { AccessControlAccount, AccessControlGroup } from '@/models/access-c
|
|||||||
import { RiAlertFill, RiCloseCircleFill, RiLockLine, RiOrganizationChart } from '@remixicon/react'
|
import { RiAlertFill, RiCloseCircleFill, RiLockLine, RiOrganizationChart } from '@remixicon/react'
|
||||||
import { useCallback, useEffect } from 'react'
|
import { useCallback, useEffect } from 'react'
|
||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
|
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||||
import { AccessMode } from '@/models/access-control'
|
import { AccessMode } from '@/models/access-control'
|
||||||
import { useAppWhiteListSubjects } from '@/service/access-control'
|
import { useAppWhiteListSubjects } from '@/service/access-control'
|
||||||
import useAccessControlStore from '../../../../context/access-control-store'
|
import useAccessControlStore from '../../../../context/access-control-store'
|
||||||
import { Avatar } from '../../base/avatar'
|
|
||||||
import Loading from '../../base/loading'
|
import Loading from '../../base/loading'
|
||||||
import Tooltip from '../../base/tooltip'
|
import Tooltip from '../../base/tooltip'
|
||||||
import AddMemberOrGroupDialog from './add-member-or-group-pop'
|
import AddMemberOrGroupDialog from './add-member-or-group-pop'
|
||||||
|
|||||||
@ -90,7 +90,7 @@ vi.mock('@/app/components/base/chat/chat', () => ({
|
|||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
|
|
||||||
vi.mock('@/app/components/base/avatar', () => ({
|
vi.mock('@/app/components/base/ui/avatar', () => ({
|
||||||
Avatar: ({ name }: { name: string }) => <div data-testid="avatar">{name}</div>,
|
Avatar: ({ name }: { name: string }) => <div data-testid="avatar">{name}</div>,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|||||||
@ -7,11 +7,11 @@ import {
|
|||||||
useCallback,
|
useCallback,
|
||||||
useMemo,
|
useMemo,
|
||||||
} from 'react'
|
} from 'react'
|
||||||
import { Avatar } from '@/app/components/base/avatar'
|
|
||||||
import Chat from '@/app/components/base/chat/chat'
|
import Chat from '@/app/components/base/chat/chat'
|
||||||
import { useChat } from '@/app/components/base/chat/chat/hooks'
|
import { useChat } from '@/app/components/base/chat/chat/hooks'
|
||||||
import { getLastAnswer } from '@/app/components/base/chat/utils'
|
import { getLastAnswer } from '@/app/components/base/chat/utils'
|
||||||
import { useFeatures } from '@/app/components/base/features/hooks'
|
import { useFeatures } from '@/app/components/base/features/hooks'
|
||||||
|
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||||
import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||||
import { useAppContext } from '@/context/app-context'
|
import { useAppContext } from '@/context/app-context'
|
||||||
import { useDebugConfigurationContext } from '@/context/debug-configuration'
|
import { useDebugConfigurationContext } from '@/context/debug-configuration'
|
||||||
|
|||||||
@ -3,11 +3,11 @@ import type { ChatConfig, ChatItem, OnSend } from '@/app/components/base/chat/ty
|
|||||||
import type { FileEntity } from '@/app/components/base/file-uploader/types'
|
import type { FileEntity } from '@/app/components/base/file-uploader/types'
|
||||||
import { memo, useCallback, useImperativeHandle, useMemo } from 'react'
|
import { memo, useCallback, useImperativeHandle, useMemo } from 'react'
|
||||||
import { useStore as useAppStore } from '@/app/components/app/store'
|
import { useStore as useAppStore } from '@/app/components/app/store'
|
||||||
import { Avatar } from '@/app/components/base/avatar'
|
|
||||||
import Chat from '@/app/components/base/chat/chat'
|
import Chat from '@/app/components/base/chat/chat'
|
||||||
import { useChat } from '@/app/components/base/chat/chat/hooks'
|
import { useChat } from '@/app/components/base/chat/chat/hooks'
|
||||||
import { getLastAnswer, isValidGeneratedAnswer } from '@/app/components/base/chat/utils'
|
import { getLastAnswer, isValidGeneratedAnswer } from '@/app/components/base/chat/utils'
|
||||||
import { useFeatures } from '@/app/components/base/features/hooks'
|
import { useFeatures } from '@/app/components/base/features/hooks'
|
||||||
|
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||||
import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||||
import { useAppContext } from '@/context/app-context'
|
import { useAppContext } from '@/context/app-context'
|
||||||
import { useDebugConfigurationContext } from '@/context/debug-configuration'
|
import { useDebugConfigurationContext } from '@/context/debug-configuration'
|
||||||
|
|||||||
@ -11,6 +11,7 @@ import AppIcon from '@/app/components/base/app-icon'
|
|||||||
import InputsForm from '@/app/components/base/chat/chat-with-history/inputs-form'
|
import InputsForm from '@/app/components/base/chat/chat-with-history/inputs-form'
|
||||||
import SuggestedQuestions from '@/app/components/base/chat/chat/answer/suggested-questions'
|
import SuggestedQuestions from '@/app/components/base/chat/chat/answer/suggested-questions'
|
||||||
import { Markdown } from '@/app/components/base/markdown'
|
import { Markdown } from '@/app/components/base/markdown'
|
||||||
|
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||||
import { InputVarType } from '@/app/components/workflow/types'
|
import { InputVarType } from '@/app/components/workflow/types'
|
||||||
import {
|
import {
|
||||||
AppSourceType,
|
AppSourceType,
|
||||||
@ -23,7 +24,6 @@ import { submitHumanInputForm as submitHumanInputFormService } from '@/service/w
|
|||||||
import { TransferMethod } from '@/types/app'
|
import { TransferMethod } from '@/types/app'
|
||||||
import { cn } from '@/utils/classnames'
|
import { cn } from '@/utils/classnames'
|
||||||
import { formatBooleanInputs } from '@/utils/model-config'
|
import { formatBooleanInputs } from '@/utils/model-config'
|
||||||
import { Avatar } from '../../avatar'
|
|
||||||
import Chat from '../chat'
|
import Chat from '../chat'
|
||||||
import { useChat } from '../chat/hooks'
|
import { useChat } from '../chat/hooks'
|
||||||
import { getLastAnswer, isValidGeneratedAnswer } from '../utils'
|
import { getLastAnswer, isValidGeneratedAnswer } from '../utils'
|
||||||
|
|||||||
@ -12,6 +12,7 @@ import SuggestedQuestions from '@/app/components/base/chat/chat/answer/suggested
|
|||||||
import InputsForm from '@/app/components/base/chat/embedded-chatbot/inputs-form'
|
import InputsForm from '@/app/components/base/chat/embedded-chatbot/inputs-form'
|
||||||
import LogoAvatar from '@/app/components/base/logo/logo-embedded-chat-avatar'
|
import LogoAvatar from '@/app/components/base/logo/logo-embedded-chat-avatar'
|
||||||
import { Markdown } from '@/app/components/base/markdown'
|
import { Markdown } from '@/app/components/base/markdown'
|
||||||
|
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||||
import { InputVarType } from '@/app/components/workflow/types'
|
import { InputVarType } from '@/app/components/workflow/types'
|
||||||
import {
|
import {
|
||||||
AppSourceType,
|
AppSourceType,
|
||||||
@ -23,7 +24,6 @@ import {
|
|||||||
import { submitHumanInputForm as submitHumanInputFormService } from '@/service/workflow'
|
import { submitHumanInputForm as submitHumanInputFormService } from '@/service/workflow'
|
||||||
import { TransferMethod } from '@/types/app'
|
import { TransferMethod } from '@/types/app'
|
||||||
import { cn } from '@/utils/classnames'
|
import { cn } from '@/utils/classnames'
|
||||||
import { Avatar } from '../../avatar'
|
|
||||||
import Chat from '../chat'
|
import Chat from '../chat'
|
||||||
import { useChat } from '../chat/hooks'
|
import { useChat } from '../chat/hooks'
|
||||||
import { getLastAnswer, isValidGeneratedAnswer } from '../utils'
|
import { getLastAnswer, isValidGeneratedAnswer } from '../utils'
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
import { render, screen } from '@testing-library/react'
|
import { render, screen } from '@testing-library/react'
|
||||||
import { Avatar } from '../index'
|
import { Avatar } from '..'
|
||||||
|
|
||||||
describe('Avatar', () => {
|
describe('Avatar', () => {
|
||||||
describe('Rendering', () => {
|
describe('Rendering', () => {
|
||||||
@ -53,8 +53,8 @@ function AvatarRoot({
|
|||||||
return (
|
return (
|
||||||
<BaseAvatar.Root
|
<BaseAvatar.Root
|
||||||
className={cn(
|
className={cn(
|
||||||
'relative inline-flex shrink-0 select-none items-center justify-center overflow-hidden rounded-full bg-primary-600',
|
'relative inline-flex shrink-0 items-center justify-center overflow-hidden rounded-full bg-primary-600 select-none',
|
||||||
isAvatarPresetSize(size) && avatarSizeClasses[size].root,
|
avatarSizeClasses[size].root,
|
||||||
className,
|
className,
|
||||||
)}
|
)}
|
||||||
style={resolvedStyle}
|
style={resolvedStyle}
|
||||||
@ -104,7 +104,7 @@ function AvatarImage({
|
|||||||
}: AvatarImageProps) {
|
}: AvatarImageProps) {
|
||||||
return (
|
return (
|
||||||
<BaseAvatar.Image
|
<BaseAvatar.Image
|
||||||
className={cn('absolute inset-0 size-full object-cover', className)}
|
className={cn('inset-0 absolute size-full object-cover', className)}
|
||||||
{...props}
|
{...props}
|
||||||
/>
|
/>
|
||||||
)
|
)
|
||||||
@ -4,13 +4,13 @@ import { useDebounceFn } from 'ahooks'
|
|||||||
import * as React from 'react'
|
import * as React from 'react'
|
||||||
import { useCallback, useMemo, useState } from 'react'
|
import { useCallback, useMemo, useState } from 'react'
|
||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
import { Avatar } from '@/app/components/base/avatar'
|
|
||||||
import Input from '@/app/components/base/input'
|
import Input from '@/app/components/base/input'
|
||||||
import {
|
import {
|
||||||
PortalToFollowElem,
|
PortalToFollowElem,
|
||||||
PortalToFollowElemContent,
|
PortalToFollowElemContent,
|
||||||
PortalToFollowElemTrigger,
|
PortalToFollowElemTrigger,
|
||||||
} from '@/app/components/base/portal-to-follow-elem'
|
} from '@/app/components/base/portal-to-follow-elem'
|
||||||
|
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||||
import { useSelector as useAppContextWithSelector } from '@/context/app-context'
|
import { useSelector as useAppContextWithSelector } from '@/context/app-context'
|
||||||
import { DatasetPermission } from '@/models/datasets'
|
import { DatasetPermission } from '@/models/datasets'
|
||||||
import { cn } from '@/utils/classnames'
|
import { cn } from '@/utils/classnames'
|
||||||
|
|||||||
@ -4,9 +4,9 @@ import type { MouseEventHandler, ReactNode } from 'react'
|
|||||||
import { useState } from 'react'
|
import { useState } from 'react'
|
||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
import { resetUser } from '@/app/components/base/amplitude/utils'
|
import { resetUser } from '@/app/components/base/amplitude/utils'
|
||||||
import { Avatar } from '@/app/components/base/avatar'
|
|
||||||
import PremiumBadge from '@/app/components/base/premium-badge'
|
import PremiumBadge from '@/app/components/base/premium-badge'
|
||||||
import ThemeSwitcher from '@/app/components/base/theme-switcher'
|
import ThemeSwitcher from '@/app/components/base/theme-switcher'
|
||||||
|
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||||
import { DropdownMenu, DropdownMenuContent, DropdownMenuGroup, DropdownMenuItem, DropdownMenuLinkItem, DropdownMenuSeparator, DropdownMenuTrigger } from '@/app/components/base/ui/dropdown-menu'
|
import { DropdownMenu, DropdownMenuContent, DropdownMenuGroup, DropdownMenuItem, DropdownMenuLinkItem, DropdownMenuSeparator, DropdownMenuTrigger } from '@/app/components/base/ui/dropdown-menu'
|
||||||
import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants'
|
import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants'
|
||||||
import { IS_CLOUD_EDITION } from '@/config'
|
import { IS_CLOUD_EDITION } from '@/config'
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
import type { InvitationResult } from '@/models/common'
|
import type { InvitationResult } from '@/models/common'
|
||||||
import { useState } from 'react'
|
import { useState } from 'react'
|
||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
import { Avatar } from '@/app/components/base/avatar'
|
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||||
import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip'
|
import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip'
|
||||||
import { NUM_INFINITE } from '@/app/components/billing/config'
|
import { NUM_INFINITE } from '@/app/components/billing/config'
|
||||||
import { Plan } from '@/app/components/billing/type'
|
import { Plan } from '@/app/components/billing/type'
|
||||||
|
|||||||
@ -3,9 +3,9 @@ import type { FC } from 'react'
|
|||||||
import * as React from 'react'
|
import * as React from 'react'
|
||||||
import { useMemo, useState } from 'react'
|
import { useMemo, useState } from 'react'
|
||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
import { Avatar } from '@/app/components/base/avatar'
|
|
||||||
import Input from '@/app/components/base/input'
|
import Input from '@/app/components/base/input'
|
||||||
import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '@/app/components/base/portal-to-follow-elem'
|
import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '@/app/components/base/portal-to-follow-elem'
|
||||||
|
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||||
import { useMembers } from '@/service/use-common'
|
import { useMembers } from '@/service/use-common'
|
||||||
import { cn } from '@/utils/classnames'
|
import { cn } from '@/utils/classnames'
|
||||||
|
|
||||||
|
|||||||
46
web/app/components/workflow/__tests__/block-icon.spec.tsx
Normal file
46
web/app/components/workflow/__tests__/block-icon.spec.tsx
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
import { render } from '@testing-library/react'
|
||||||
|
import { API_PREFIX } from '@/config'
|
||||||
|
import BlockIcon, { VarBlockIcon } from '../block-icon'
|
||||||
|
import { BlockEnum } from '../types'
|
||||||
|
|
||||||
|
describe('BlockIcon', () => {
|
||||||
|
it('renders the default workflow icon container for regular nodes', () => {
|
||||||
|
const { container } = render(<BlockIcon type={BlockEnum.Start} size="xs" className="extra-class" />)
|
||||||
|
|
||||||
|
const iconContainer = container.firstElementChild
|
||||||
|
expect(iconContainer).toHaveClass('w-4', 'h-4', 'bg-util-colors-blue-brand-blue-brand-500', 'extra-class')
|
||||||
|
expect(iconContainer?.querySelector('svg')).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('normalizes protected plugin icon urls for tool-like nodes', () => {
|
||||||
|
const { container } = render(
|
||||||
|
<BlockIcon
|
||||||
|
type={BlockEnum.Tool}
|
||||||
|
toolIcon="/foo/workspaces/current/plugin/icon/plugin-tool.png"
|
||||||
|
/>,
|
||||||
|
)
|
||||||
|
|
||||||
|
const iconContainer = container.firstElementChild as HTMLElement
|
||||||
|
const backgroundIcon = iconContainer.querySelector('div') as HTMLElement
|
||||||
|
|
||||||
|
expect(iconContainer).not.toHaveClass('bg-util-colors-blue-blue-500')
|
||||||
|
expect(backgroundIcon.style.backgroundImage).toContain(
|
||||||
|
`${API_PREFIX}/workspaces/current/plugin/icon/plugin-tool.png`,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('VarBlockIcon', () => {
|
||||||
|
it('renders the compact icon variant without the default container wrapper', () => {
|
||||||
|
const { container } = render(
|
||||||
|
<VarBlockIcon
|
||||||
|
type={BlockEnum.Answer}
|
||||||
|
className="custom-var-icon"
|
||||||
|
/>,
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(container.querySelector('.custom-var-icon')).toBeInTheDocument()
|
||||||
|
expect(container.querySelector('svg')).toBeInTheDocument()
|
||||||
|
expect(container.querySelector('.bg-util-colors-warning-warning-500')).not.toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
39
web/app/components/workflow/__tests__/context.spec.tsx
Normal file
39
web/app/components/workflow/__tests__/context.spec.tsx
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
import { render, screen } from '@testing-library/react'
|
||||||
|
import userEvent from '@testing-library/user-event'
|
||||||
|
import { WorkflowContextProvider } from '../context'
|
||||||
|
import { useStore, useWorkflowStore } from '../store'
|
||||||
|
|
||||||
|
const StoreConsumer = () => {
|
||||||
|
const showSingleRunPanel = useStore(s => s.showSingleRunPanel)
|
||||||
|
const store = useWorkflowStore()
|
||||||
|
|
||||||
|
return (
|
||||||
|
<button onClick={() => store.getState().setShowSingleRunPanel(!showSingleRunPanel)}>
|
||||||
|
{showSingleRunPanel ? 'open' : 'closed'}
|
||||||
|
</button>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('WorkflowContextProvider', () => {
|
||||||
|
it('provides the workflow store to descendants and keeps the same store across rerenders', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
const { rerender } = render(
|
||||||
|
<WorkflowContextProvider>
|
||||||
|
<StoreConsumer />
|
||||||
|
</WorkflowContextProvider>,
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(screen.getByRole('button', { name: 'closed' })).toBeInTheDocument()
|
||||||
|
|
||||||
|
await user.click(screen.getByRole('button', { name: 'closed' }))
|
||||||
|
expect(screen.getByRole('button', { name: 'open' })).toBeInTheDocument()
|
||||||
|
|
||||||
|
rerender(
|
||||||
|
<WorkflowContextProvider>
|
||||||
|
<StoreConsumer />
|
||||||
|
</WorkflowContextProvider>,
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(screen.getByRole('button', { name: 'open' })).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
67
web/app/components/workflow/__tests__/index.spec.tsx
Normal file
67
web/app/components/workflow/__tests__/index.spec.tsx
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
import type { Edge, Node } from '../types'
|
||||||
|
import { render, screen } from '@testing-library/react'
|
||||||
|
import { useStoreApi } from 'reactflow'
|
||||||
|
import { useDatasetsDetailStore } from '../datasets-detail-store/store'
|
||||||
|
import WorkflowWithDefaultContext from '../index'
|
||||||
|
import { BlockEnum } from '../types'
|
||||||
|
import { useWorkflowHistoryStore } from '../workflow-history-store'
|
||||||
|
|
||||||
|
const nodes: Node[] = [
|
||||||
|
{
|
||||||
|
id: 'node-start',
|
||||||
|
type: 'custom',
|
||||||
|
position: { x: 0, y: 0 },
|
||||||
|
data: {
|
||||||
|
title: 'Start',
|
||||||
|
desc: '',
|
||||||
|
type: BlockEnum.Start,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
const edges: Edge[] = [
|
||||||
|
{
|
||||||
|
id: 'edge-1',
|
||||||
|
source: 'node-start',
|
||||||
|
target: 'node-end',
|
||||||
|
sourceHandle: null,
|
||||||
|
targetHandle: null,
|
||||||
|
type: 'custom',
|
||||||
|
data: {
|
||||||
|
sourceType: BlockEnum.Start,
|
||||||
|
targetType: BlockEnum.End,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
const ContextConsumer = () => {
|
||||||
|
const { store, shortcutsEnabled } = useWorkflowHistoryStore()
|
||||||
|
const datasetCount = useDatasetsDetailStore(state => Object.keys(state.datasetsDetail).length)
|
||||||
|
const reactFlowStore = useStoreApi()
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
{`history:${store.getState().nodes.length}`}
|
||||||
|
{` shortcuts:${String(shortcutsEnabled)}`}
|
||||||
|
{` datasets:${datasetCount}`}
|
||||||
|
{` reactflow:${String(!!reactFlowStore)}`}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('WorkflowWithDefaultContext', () => {
|
||||||
|
it('wires the ReactFlow, workflow history, and datasets detail providers around its children', () => {
|
||||||
|
render(
|
||||||
|
<WorkflowWithDefaultContext
|
||||||
|
nodes={nodes}
|
||||||
|
edges={edges}
|
||||||
|
>
|
||||||
|
<ContextConsumer />
|
||||||
|
</WorkflowWithDefaultContext>,
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(
|
||||||
|
screen.getByText('history:1 shortcuts:true datasets:0 reactflow:true'),
|
||||||
|
).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
@ -0,0 +1,51 @@
|
|||||||
|
import { render, screen } from '@testing-library/react'
|
||||||
|
import ShortcutsName from '../shortcuts-name'
|
||||||
|
|
||||||
|
describe('ShortcutsName', () => {
|
||||||
|
const originalNavigator = globalThis.navigator
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
Object.defineProperty(globalThis, 'navigator', {
|
||||||
|
value: originalNavigator,
|
||||||
|
writable: true,
|
||||||
|
configurable: true,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('renders mac-friendly key labels and style variants', () => {
|
||||||
|
Object.defineProperty(globalThis, 'navigator', {
|
||||||
|
value: { userAgent: 'Macintosh' },
|
||||||
|
writable: true,
|
||||||
|
configurable: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
const { container } = render(
|
||||||
|
<ShortcutsName
|
||||||
|
keys={['ctrl', 'shift', 's']}
|
||||||
|
bgColor="white"
|
||||||
|
textColor="secondary"
|
||||||
|
/>,
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(screen.getByText('⌘')).toBeInTheDocument()
|
||||||
|
expect(screen.getByText('⇧')).toBeInTheDocument()
|
||||||
|
expect(screen.getByText('s')).toBeInTheDocument()
|
||||||
|
expect(container.querySelector('.system-kbd')).toHaveClass(
|
||||||
|
'bg-components-kbd-bg-white',
|
||||||
|
'text-text-tertiary',
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('keeps raw key names on non-mac systems', () => {
|
||||||
|
Object.defineProperty(globalThis, 'navigator', {
|
||||||
|
value: { userAgent: 'Windows NT' },
|
||||||
|
writable: true,
|
||||||
|
configurable: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
render(<ShortcutsName keys={['ctrl', 'alt']} />)
|
||||||
|
|
||||||
|
expect(screen.getByText('ctrl')).toBeInTheDocument()
|
||||||
|
expect(screen.getByText('alt')).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
@ -0,0 +1,97 @@
|
|||||||
|
import type { Edge, Node } from '../types'
|
||||||
|
import type { WorkflowHistoryState } from '../workflow-history-store'
|
||||||
|
import { render, renderHook, screen } from '@testing-library/react'
|
||||||
|
import userEvent from '@testing-library/user-event'
|
||||||
|
import { BlockEnum } from '../types'
|
||||||
|
import { useWorkflowHistoryStore, WorkflowHistoryProvider } from '../workflow-history-store'
|
||||||
|
|
||||||
|
const nodes: Node[] = [
|
||||||
|
{
|
||||||
|
id: 'node-1',
|
||||||
|
type: 'custom',
|
||||||
|
position: { x: 0, y: 0 },
|
||||||
|
data: {
|
||||||
|
title: 'Start',
|
||||||
|
desc: '',
|
||||||
|
type: BlockEnum.Start,
|
||||||
|
selected: true,
|
||||||
|
},
|
||||||
|
selected: true,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
const edges: Edge[] = [
|
||||||
|
{
|
||||||
|
id: 'edge-1',
|
||||||
|
source: 'node-1',
|
||||||
|
target: 'node-2',
|
||||||
|
sourceHandle: null,
|
||||||
|
targetHandle: null,
|
||||||
|
type: 'custom',
|
||||||
|
selected: true,
|
||||||
|
data: {
|
||||||
|
sourceType: BlockEnum.Start,
|
||||||
|
targetType: BlockEnum.End,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
const HistoryConsumer = () => {
|
||||||
|
const { store, shortcutsEnabled, setShortcutsEnabled } = useWorkflowHistoryStore()
|
||||||
|
|
||||||
|
return (
|
||||||
|
<button onClick={() => setShortcutsEnabled(!shortcutsEnabled)}>
|
||||||
|
{`nodes:${store.getState().nodes.length} shortcuts:${String(shortcutsEnabled)}`}
|
||||||
|
</button>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('WorkflowHistoryProvider', () => {
|
||||||
|
it('provides workflow history state and shortcut toggles', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
|
||||||
|
render(
|
||||||
|
<WorkflowHistoryProvider
|
||||||
|
nodes={nodes}
|
||||||
|
edges={edges}
|
||||||
|
>
|
||||||
|
<HistoryConsumer />
|
||||||
|
</WorkflowHistoryProvider>,
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(screen.getByRole('button', { name: 'nodes:1 shortcuts:true' })).toBeInTheDocument()
|
||||||
|
|
||||||
|
await user.click(screen.getByRole('button', { name: 'nodes:1 shortcuts:true' }))
|
||||||
|
expect(screen.getByRole('button', { name: 'nodes:1 shortcuts:false' })).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('sanitizes selected flags when history state is replaced through the exposed store api', () => {
|
||||||
|
const wrapper = ({ children }: { children: React.ReactNode }) => (
|
||||||
|
<WorkflowHistoryProvider
|
||||||
|
nodes={nodes}
|
||||||
|
edges={edges}
|
||||||
|
>
|
||||||
|
{children}
|
||||||
|
</WorkflowHistoryProvider>
|
||||||
|
)
|
||||||
|
|
||||||
|
const { result } = renderHook(() => useWorkflowHistoryStore(), { wrapper })
|
||||||
|
const nextState: WorkflowHistoryState = {
|
||||||
|
workflowHistoryEvent: undefined,
|
||||||
|
workflowHistoryEventMeta: undefined,
|
||||||
|
nodes,
|
||||||
|
edges,
|
||||||
|
}
|
||||||
|
|
||||||
|
result.current.store.setState(nextState)
|
||||||
|
|
||||||
|
expect(result.current.store.getState().nodes[0].data.selected).toBe(false)
|
||||||
|
expect(result.current.store.getState().edges[0].selected).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('throws when consumed outside the provider', () => {
|
||||||
|
expect(() => renderHook(() => useWorkflowHistoryStore())).toThrow(
|
||||||
|
'useWorkflowHistoryStoreApi must be used within a WorkflowHistoryProvider',
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
@ -0,0 +1,140 @@
|
|||||||
|
import { render, screen, waitFor } from '@testing-library/react'
|
||||||
|
import userEvent from '@testing-library/user-event'
|
||||||
|
import { useMarketplacePlugins } from '@/app/components/plugins/marketplace/hooks'
|
||||||
|
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||||
|
import { useGetLanguage } from '@/context/i18n'
|
||||||
|
import useTheme from '@/hooks/use-theme'
|
||||||
|
import { Theme } from '@/types/app'
|
||||||
|
import AllTools from '../all-tools'
|
||||||
|
import { createGlobalPublicStoreState, createToolProvider } from './factories'
|
||||||
|
|
||||||
|
vi.mock('@/context/global-public-context', () => ({
|
||||||
|
useGlobalPublicStore: vi.fn(),
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@/context/i18n', () => ({
|
||||||
|
useGetLanguage: vi.fn(),
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@/hooks/use-theme', () => ({
|
||||||
|
default: vi.fn(),
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@/app/components/plugins/marketplace/hooks', () => ({
|
||||||
|
useMarketplacePlugins: vi.fn(),
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({
|
||||||
|
useMCPToolAvailability: () => ({
|
||||||
|
allowed: true,
|
||||||
|
}),
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@/utils/var', async importOriginal => ({
|
||||||
|
...(await importOriginal<typeof import('@/utils/var')>()),
|
||||||
|
getMarketplaceUrl: () => 'https://marketplace.test/tools',
|
||||||
|
}))
|
||||||
|
|
||||||
|
const mockUseMarketplacePlugins = vi.mocked(useMarketplacePlugins)
|
||||||
|
const mockUseGlobalPublicStore = vi.mocked(useGlobalPublicStore)
|
||||||
|
const mockUseGetLanguage = vi.mocked(useGetLanguage)
|
||||||
|
const mockUseTheme = vi.mocked(useTheme)
|
||||||
|
|
||||||
|
const createMarketplacePluginsMock = () => ({
|
||||||
|
plugins: [],
|
||||||
|
total: 0,
|
||||||
|
resetPlugins: vi.fn(),
|
||||||
|
queryPlugins: vi.fn(),
|
||||||
|
queryPluginsWithDebounced: vi.fn(),
|
||||||
|
cancelQueryPluginsWithDebounced: vi.fn(),
|
||||||
|
isLoading: false,
|
||||||
|
isFetchingNextPage: false,
|
||||||
|
hasNextPage: false,
|
||||||
|
fetchNextPage: vi.fn(),
|
||||||
|
page: 0,
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('AllTools', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.clearAllMocks()
|
||||||
|
mockUseGlobalPublicStore.mockImplementation(selector => selector(createGlobalPublicStoreState(false)))
|
||||||
|
mockUseGetLanguage.mockReturnValue('en_US')
|
||||||
|
mockUseTheme.mockReturnValue({ theme: Theme.light } as ReturnType<typeof useTheme>)
|
||||||
|
mockUseMarketplacePlugins.mockReturnValue(createMarketplacePluginsMock())
|
||||||
|
})
|
||||||
|
|
||||||
|
it('filters tools by the active tab', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
|
||||||
|
render(
|
||||||
|
<AllTools
|
||||||
|
searchText=""
|
||||||
|
tags={[]}
|
||||||
|
onSelect={vi.fn()}
|
||||||
|
buildInTools={[createToolProvider({
|
||||||
|
id: 'provider-built-in',
|
||||||
|
label: { en_US: 'Built In Provider', zh_Hans: 'Built In Provider' },
|
||||||
|
})]}
|
||||||
|
customTools={[createToolProvider({
|
||||||
|
id: 'provider-custom',
|
||||||
|
type: 'custom',
|
||||||
|
label: { en_US: 'Custom Provider', zh_Hans: 'Custom Provider' },
|
||||||
|
})]}
|
||||||
|
workflowTools={[]}
|
||||||
|
mcpTools={[]}
|
||||||
|
/>,
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(screen.getByText('Built In Provider')).toBeInTheDocument()
|
||||||
|
expect(screen.getByText('Custom Provider')).toBeInTheDocument()
|
||||||
|
|
||||||
|
await user.click(screen.getByText('workflow.tabs.customTool'))
|
||||||
|
|
||||||
|
expect(screen.getByText('Custom Provider')).toBeInTheDocument()
|
||||||
|
expect(screen.queryByText('Built In Provider')).not.toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('filters the rendered tools by the search text', () => {
|
||||||
|
render(
|
||||||
|
<AllTools
|
||||||
|
searchText="report"
|
||||||
|
tags={[]}
|
||||||
|
onSelect={vi.fn()}
|
||||||
|
buildInTools={[
|
||||||
|
createToolProvider({
|
||||||
|
id: 'provider-report',
|
||||||
|
label: { en_US: 'Report Toolkit', zh_Hans: 'Report Toolkit' },
|
||||||
|
}),
|
||||||
|
createToolProvider({
|
||||||
|
id: 'provider-other',
|
||||||
|
label: { en_US: 'Other Toolkit', zh_Hans: 'Other Toolkit' },
|
||||||
|
}),
|
||||||
|
]}
|
||||||
|
customTools={[]}
|
||||||
|
workflowTools={[]}
|
||||||
|
mcpTools={[]}
|
||||||
|
/>,
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(screen.getByText('Report Toolkit')).toBeInTheDocument()
|
||||||
|
expect(screen.queryByText('Other Toolkit')).not.toBeInTheDocument()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('shows the empty state when no tool matches the current filter', async () => {
|
||||||
|
render(
|
||||||
|
<AllTools
|
||||||
|
searchText="missing"
|
||||||
|
tags={[]}
|
||||||
|
onSelect={vi.fn()}
|
||||||
|
buildInTools={[]}
|
||||||
|
customTools={[]}
|
||||||
|
workflowTools={[]}
|
||||||
|
mcpTools={[]}
|
||||||
|
/>,
|
||||||
|
)
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText('workflow.tabs.noPluginsFound')).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
@ -0,0 +1,79 @@
|
|||||||
|
import type { NodeDefault } from '../../types'
|
||||||
|
import { render, screen } from '@testing-library/react'
|
||||||
|
import userEvent from '@testing-library/user-event'
|
||||||
|
import { BlockEnum } from '../../types'
|
||||||
|
import Blocks from '../blocks'
|
||||||
|
import { BlockClassificationEnum } from '../types'
|
||||||
|
|
||||||
|
const runtimeState = vi.hoisted(() => ({
|
||||||
|
nodes: [] as Array<{ data: { type?: BlockEnum } }>,
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('reactflow', () => ({
|
||||||
|
useStoreApi: () => ({
|
||||||
|
getState: () => ({
|
||||||
|
getNodes: () => runtimeState.nodes,
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}))
|
||||||
|
|
||||||
|
const createBlock = (type: BlockEnum, title: string, classification = BlockClassificationEnum.Default): NodeDefault => ({
|
||||||
|
metaData: {
|
||||||
|
classification,
|
||||||
|
sort: 0,
|
||||||
|
type,
|
||||||
|
title,
|
||||||
|
author: 'Dify',
|
||||||
|
description: `${title} description`,
|
||||||
|
},
|
||||||
|
defaultValue: {},
|
||||||
|
checkValid: () => ({ isValid: true }),
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Blocks', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
runtimeState.nodes = []
|
||||||
|
})
|
||||||
|
|
||||||
|
it('renders grouped blocks, filters duplicate knowledge-base nodes, and selects a block', async () => {
|
||||||
|
const user = userEvent.setup()
|
||||||
|
const onSelect = vi.fn()
|
||||||
|
|
||||||
|
runtimeState.nodes = [{ data: { type: BlockEnum.KnowledgeBase } }]
|
||||||
|
|
||||||
|
render(
|
||||||
|
<Blocks
|
||||||
|
searchText=""
|
||||||
|
onSelect={onSelect}
|
||||||
|
availableBlocksTypes={[BlockEnum.LLM, BlockEnum.LoopEnd, BlockEnum.KnowledgeBase]}
|
||||||
|
blocks={[
|
||||||
|
createBlock(BlockEnum.LLM, 'LLM'),
|
||||||
|
createBlock(BlockEnum.LoopEnd, 'Exit Loop', BlockClassificationEnum.Logic),
|
||||||
|
createBlock(BlockEnum.KnowledgeBase, 'Knowledge Retrieval'),
|
||||||
|
]}
|
||||||
|
/>,
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(screen.getByText('LLM')).toBeInTheDocument()
|
||||||
|
expect(screen.getByText('Exit Loop')).toBeInTheDocument()
|
||||||
|
expect(screen.getByText('workflow.nodes.loop.loopNode')).toBeInTheDocument()
|
||||||
|
expect(screen.queryByText('Knowledge Retrieval')).not.toBeInTheDocument()
|
||||||
|
|
||||||
|
await user.click(screen.getByText('LLM'))
|
||||||
|
|
||||||
|
expect(onSelect).toHaveBeenCalledWith(BlockEnum.LLM)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('shows the empty state when no block matches the search', () => {
|
||||||
|
render(
|
||||||
|
<Blocks
|
||||||
|
searchText="missing"
|
||||||
|
onSelect={vi.fn()}
|
||||||
|
availableBlocksTypes={[BlockEnum.LLM]}
|
||||||
|
blocks={[createBlock(BlockEnum.LLM, 'LLM')]}
|
||||||
|
/>,
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(screen.getByText('workflow.tabs.noResult')).toBeInTheDocument()
|
||||||
|
})
|
||||||
|
})
|
||||||
@ -0,0 +1,101 @@
|
|||||||
|
import type { ToolWithProvider } from '../../types'
|
||||||
|
import type { Plugin } from '@/app/components/plugins/types'
|
||||||
|
import type { Tool } from '@/app/components/tools/types'
|
||||||
|
import { PluginCategoryEnum } from '@/app/components/plugins/types'
|
||||||
|
import { CollectionType } from '@/app/components/tools/types'
|
||||||
|
import { defaultSystemFeatures } from '@/types/feature'
|
||||||
|
|
||||||
|
export const createTool = (
|
||||||
|
name: string,
|
||||||
|
label: string,
|
||||||
|
description = `${label} description`,
|
||||||
|
): Tool => ({
|
||||||
|
name,
|
||||||
|
author: 'author',
|
||||||
|
label: {
|
||||||
|
en_US: label,
|
||||||
|
zh_Hans: label,
|
||||||
|
},
|
||||||
|
description: {
|
||||||
|
en_US: description,
|
||||||
|
zh_Hans: description,
|
||||||
|
},
|
||||||
|
parameters: [],
|
||||||
|
labels: [],
|
||||||
|
output_schema: {},
|
||||||
|
})
|
||||||
|
|
||||||
|
export const createToolProvider = (
|
||||||
|
overrides: Partial<ToolWithProvider> = {},
|
||||||
|
): ToolWithProvider => ({
|
||||||
|
id: 'provider-1',
|
||||||
|
name: 'provider-one',
|
||||||
|
author: 'Provider Author',
|
||||||
|
description: {
|
||||||
|
en_US: 'Provider description',
|
||||||
|
zh_Hans: 'Provider description',
|
||||||
|
},
|
||||||
|
icon: 'icon',
|
||||||
|
icon_dark: 'icon-dark',
|
||||||
|
label: {
|
||||||
|
en_US: 'Provider One',
|
||||||
|
zh_Hans: 'Provider One',
|
||||||
|
},
|
||||||
|
type: CollectionType.builtIn,
|
||||||
|
team_credentials: {},
|
||||||
|
is_team_authorization: false,
|
||||||
|
allow_delete: false,
|
||||||
|
labels: [],
|
||||||
|
plugin_id: 'plugin-1',
|
||||||
|
tools: [createTool('tool-a', 'Tool A')],
|
||||||
|
meta: { version: '1.0.0' } as ToolWithProvider['meta'],
|
||||||
|
plugin_unique_identifier: 'plugin-1@1.0.0',
|
||||||
|
...overrides,
|
||||||
|
})
|
||||||
|
|
||||||
|
export const createPlugin = (overrides: Partial<Plugin> = {}): Plugin => ({
|
||||||
|
type: 'plugin',
|
||||||
|
org: 'org',
|
||||||
|
author: 'author',
|
||||||
|
name: 'Plugin One',
|
||||||
|
plugin_id: 'plugin-1',
|
||||||
|
version: '1.0.0',
|
||||||
|
latest_version: '1.0.0',
|
||||||
|
latest_package_identifier: 'plugin-1@1.0.0',
|
||||||
|
icon: 'icon',
|
||||||
|
verified: true,
|
||||||
|
label: {
|
||||||
|
en_US: 'Plugin One',
|
||||||
|
zh_Hans: 'Plugin One',
|
||||||
|
},
|
||||||
|
brief: {
|
||||||
|
en_US: 'Plugin description',
|
||||||
|
zh_Hans: 'Plugin description',
|
||||||
|
},
|
||||||
|
description: {
|
||||||
|
en_US: 'Plugin description',
|
||||||
|
zh_Hans: 'Plugin description',
|
||||||
|
},
|
||||||
|
introduction: 'Plugin introduction',
|
||||||
|
repository: 'https://example.com/plugin',
|
||||||
|
category: PluginCategoryEnum.tool,
|
||||||
|
tags: [],
|
||||||
|
badges: [],
|
||||||
|
install_count: 0,
|
||||||
|
endpoint: {
|
||||||
|
settings: [],
|
||||||
|
},
|
||||||
|
verification: {
|
||||||
|
authorized_category: 'community',
|
||||||
|
},
|
||||||
|
from: 'github',
|
||||||
|
...overrides,
|
||||||
|
})
|
||||||
|
|
||||||
|
export const createGlobalPublicStoreState = (enableMarketplace: boolean) => ({
|
||||||
|
systemFeatures: {
|
||||||
|
...defaultSystemFeatures,
|
||||||
|
enable_marketplace: enableMarketplace,
|
||||||
|
},
|
||||||
|
setSystemFeatures: vi.fn(),
|
||||||
|
})
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user