mirror of
https://github.com/langgenius/dify.git
synced 2026-04-15 09:57:03 +08:00
Merge remote-tracking branch 'myori/main' into feat/collaboration2
This commit is contained in:
commit
ee2b021395
1
.github/pull_request_template.md
vendored
1
.github/pull_request_template.md
vendored
@ -7,6 +7,7 @@
|
||||
## Summary
|
||||
|
||||
<!-- Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change. -->
|
||||
<!-- If this PR was created by an automated agent, add `From <Tool Name>` as the final line of the description. Example: `From Codex`. -->
|
||||
|
||||
## Screenshots
|
||||
|
||||
|
||||
118
.github/workflows/pyrefly-type-coverage-comment.yml
vendored
Normal file
118
.github/workflows/pyrefly-type-coverage-comment.yml
vendored
Normal file
@ -0,0 +1,118 @@
|
||||
name: Comment with Pyrefly Type Coverage
|
||||
|
||||
on:
|
||||
workflow_run:
|
||||
workflows:
|
||||
- Pyrefly Type Coverage
|
||||
types:
|
||||
- completed
|
||||
|
||||
permissions: {}
|
||||
|
||||
jobs:
|
||||
comment:
|
||||
name: Comment PR with type coverage
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
issues: write
|
||||
pull-requests: write
|
||||
if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }}
|
||||
steps:
|
||||
- name: Checkout default branch (trusted code)
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Setup Python & UV
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Download type coverage artifact
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
const fs = require('fs');
|
||||
const artifacts = await github.rest.actions.listWorkflowRunArtifacts({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
run_id: ${{ github.event.workflow_run.id }},
|
||||
});
|
||||
const match = artifacts.data.artifacts.find((artifact) =>
|
||||
artifact.name === 'pyrefly_type_coverage'
|
||||
);
|
||||
if (!match) {
|
||||
throw new Error('pyrefly_type_coverage artifact not found');
|
||||
}
|
||||
const download = await github.rest.actions.downloadArtifact({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
artifact_id: match.id,
|
||||
archive_format: 'zip',
|
||||
});
|
||||
fs.writeFileSync('pyrefly_type_coverage.zip', Buffer.from(download.data));
|
||||
|
||||
- name: Unzip artifact
|
||||
run: unzip -o pyrefly_type_coverage.zip
|
||||
|
||||
- name: Render coverage markdown from structured data
|
||||
id: render
|
||||
run: |
|
||||
comment_body="$(uv run --directory api python api/libs/pyrefly_type_coverage.py \
|
||||
--base base_report.json \
|
||||
< pr_report.json)"
|
||||
|
||||
{
|
||||
echo "### Pyrefly Type Coverage"
|
||||
echo ""
|
||||
echo "$comment_body"
|
||||
} > /tmp/type_coverage_comment.md
|
||||
|
||||
- name: Post comment
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
const fs = require('fs');
|
||||
const body = fs.readFileSync('/tmp/type_coverage_comment.md', { encoding: 'utf8' });
|
||||
let prNumber = null;
|
||||
try {
|
||||
prNumber = parseInt(fs.readFileSync('pr_number.txt', { encoding: 'utf8' }), 10);
|
||||
} catch (err) {
|
||||
const prs = context.payload.workflow_run.pull_requests || [];
|
||||
if (prs.length > 0 && prs[0].number) {
|
||||
prNumber = prs[0].number;
|
||||
}
|
||||
}
|
||||
if (!prNumber) {
|
||||
throw new Error('PR number not found in artifact or workflow_run payload');
|
||||
}
|
||||
|
||||
// Update existing comment if one exists, otherwise create new
|
||||
const { data: comments } = await github.rest.issues.listComments({
|
||||
issue_number: prNumber,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
});
|
||||
const marker = '### Pyrefly Type Coverage';
|
||||
const existing = comments.find(c => c.body.startsWith(marker));
|
||||
|
||||
if (existing) {
|
||||
await github.rest.issues.updateComment({
|
||||
comment_id: existing.id,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body,
|
||||
});
|
||||
} else {
|
||||
await github.rest.issues.createComment({
|
||||
issue_number: prNumber,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body,
|
||||
});
|
||||
}
|
||||
120
.github/workflows/pyrefly-type-coverage.yml
vendored
Normal file
120
.github/workflows/pyrefly-type-coverage.yml
vendored
Normal file
@ -0,0 +1,120 @@
|
||||
name: Pyrefly Type Coverage
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- 'api/**/*.py'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
pyrefly-type-coverage:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Checkout PR branch
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python & UV
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Run pyrefly report on PR branch
|
||||
run: |
|
||||
uv run --directory api --dev pyrefly report 2>/dev/null > /tmp/pyrefly_report_pr.tmp && \
|
||||
mv /tmp/pyrefly_report_pr.tmp /tmp/pyrefly_report_pr.json || \
|
||||
echo '{}' > /tmp/pyrefly_report_pr.json
|
||||
|
||||
- name: Save helper script from base branch
|
||||
run: |
|
||||
git show ${{ github.event.pull_request.base.sha }}:api/libs/pyrefly_type_coverage.py > /tmp/pyrefly_type_coverage.py 2>/dev/null \
|
||||
|| cp api/libs/pyrefly_type_coverage.py /tmp/pyrefly_type_coverage.py
|
||||
|
||||
- name: Checkout base branch
|
||||
run: git checkout ${{ github.base_ref }}
|
||||
|
||||
- name: Run pyrefly report on base branch
|
||||
run: |
|
||||
uv run --directory api --dev pyrefly report 2>/dev/null > /tmp/pyrefly_report_base.tmp && \
|
||||
mv /tmp/pyrefly_report_base.tmp /tmp/pyrefly_report_base.json || \
|
||||
echo '{}' > /tmp/pyrefly_report_base.json
|
||||
|
||||
- name: Generate coverage comparison
|
||||
id: coverage
|
||||
run: |
|
||||
comment_body="$(uv run --directory api python /tmp/pyrefly_type_coverage.py \
|
||||
--base /tmp/pyrefly_report_base.json \
|
||||
< /tmp/pyrefly_report_pr.json)"
|
||||
|
||||
{
|
||||
echo "### Pyrefly Type Coverage"
|
||||
echo ""
|
||||
echo "$comment_body"
|
||||
} | tee -a "$GITHUB_STEP_SUMMARY" > /tmp/type_coverage_comment.md
|
||||
|
||||
# Save structured data for the fork-PR comment workflow
|
||||
cp /tmp/pyrefly_report_pr.json pr_report.json
|
||||
cp /tmp/pyrefly_report_base.json base_report.json
|
||||
|
||||
- name: Save PR number
|
||||
run: |
|
||||
echo ${{ github.event.pull_request.number }} > pr_number.txt
|
||||
|
||||
- name: Upload type coverage artifact
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
with:
|
||||
name: pyrefly_type_coverage
|
||||
path: |
|
||||
pr_report.json
|
||||
base_report.json
|
||||
pr_number.txt
|
||||
|
||||
- name: Comment PR with type coverage
|
||||
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
const fs = require('fs');
|
||||
const marker = '### Pyrefly Type Coverage';
|
||||
let body;
|
||||
try {
|
||||
body = fs.readFileSync('/tmp/type_coverage_comment.md', { encoding: 'utf8' });
|
||||
} catch {
|
||||
body = `${marker}\n\n_Coverage report unavailable._`;
|
||||
}
|
||||
const prNumber = context.payload.pull_request.number;
|
||||
|
||||
// Update existing comment if one exists, otherwise create new
|
||||
const { data: comments } = await github.rest.issues.listComments({
|
||||
issue_number: prNumber,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
});
|
||||
const existing = comments.find(c => c.body.startsWith(marker));
|
||||
|
||||
if (existing) {
|
||||
await github.rest.issues.updateComment({
|
||||
comment_id: existing.id,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body,
|
||||
});
|
||||
} else {
|
||||
await github.rest.issues.createComment({
|
||||
issue_number: prNumber,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body,
|
||||
});
|
||||
}
|
||||
@ -97,8 +97,3 @@ Feel free to reach out if you encounter any issues during the setup process.
|
||||
## Getting Help
|
||||
|
||||
If you ever get stuck or get a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat.
|
||||
|
||||
## Automated Agent Contributions
|
||||
|
||||
> [!NOTE]
|
||||
> If you are an automated agent, just add 🤖🤖🤖 to the end of the PR title to opt-in.
|
||||
|
||||
@ -2,7 +2,6 @@ import base64
|
||||
import secrets
|
||||
|
||||
import click
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from constants.languages import languages
|
||||
from extensions.ext_database import db
|
||||
@ -25,30 +24,31 @@ def reset_password(email, new_password, password_confirm):
|
||||
return
|
||||
normalized_email = email.strip().lower()
|
||||
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email.strip())
|
||||
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
|
||||
try:
|
||||
valid_password(new_password)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
|
||||
return
|
||||
try:
|
||||
valid_password(new_password)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
|
||||
return
|
||||
|
||||
# generate password salt
|
||||
salt = secrets.token_bytes(16)
|
||||
base64_salt = base64.b64encode(salt).decode()
|
||||
# generate password salt
|
||||
salt = secrets.token_bytes(16)
|
||||
base64_salt = base64.b64encode(salt).decode()
|
||||
|
||||
# encrypt password with salt
|
||||
password_hashed = hash_password(new_password, salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
click.echo(click.style("Password reset successfully.", fg="green"))
|
||||
# encrypt password with salt
|
||||
password_hashed = hash_password(new_password, salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account = db.session.merge(account)
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
db.session.commit()
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
click.echo(click.style("Password reset successfully.", fg="green"))
|
||||
|
||||
|
||||
@click.command("reset-email", help="Reset the account email.")
|
||||
@ -65,21 +65,22 @@ def reset_email(email, new_email, email_confirm):
|
||||
return
|
||||
normalized_new_email = new_email.strip().lower()
|
||||
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email.strip())
|
||||
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
|
||||
try:
|
||||
email_validate(normalized_new_email)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
|
||||
return
|
||||
try:
|
||||
email_validate(normalized_new_email)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
|
||||
return
|
||||
|
||||
account.email = normalized_new_email
|
||||
click.echo(click.style("Email updated successfully.", fg="green"))
|
||||
account = db.session.merge(account)
|
||||
account.email = normalized_new_email
|
||||
db.session.commit()
|
||||
click.echo(click.style("Email updated successfully.", fg="green"))
|
||||
|
||||
|
||||
@click.command("create-tenant", help="Create account and tenant.")
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
@ -14,7 +13,6 @@ from controllers.console.auth.error import (
|
||||
InvalidTokenError,
|
||||
PasswordMismatchError,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.password import valid_password
|
||||
from models import Account
|
||||
@ -73,8 +71,7 @@ class EmailRegisterSendEmailApi(Resource):
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email)
|
||||
token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
@ -145,17 +142,16 @@ class EmailRegisterResetApi(Resource):
|
||||
email = register_data.get("email", "")
|
||||
normalized_email = email.lower()
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||
|
||||
if account:
|
||||
raise EmailAlreadyInUseError()
|
||||
else:
|
||||
account = self._create_new_account(normalized_email, args.password_confirm)
|
||||
if not account:
|
||||
raise AccountNotFoundError()
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
if account:
|
||||
raise EmailAlreadyInUseError()
|
||||
else:
|
||||
account = self._create_new_account(normalized_email, args.password_confirm)
|
||||
if not account:
|
||||
raise AccountNotFoundError()
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
|
||||
|
||||
@ -4,7 +4,6 @@ import secrets
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
@ -85,8 +84,7 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
else:
|
||||
language = "en-US"
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email)
|
||||
|
||||
token = AccountService.send_reset_password_email(
|
||||
account=account,
|
||||
@ -184,17 +182,18 @@ class ForgotPasswordResetApi(Resource):
|
||||
password_hashed = hash_password(args.new_password, salt)
|
||||
|
||||
email = reset_data.get("email", "")
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||
|
||||
if account:
|
||||
self._update_existing_account(account, password_hashed, salt, session)
|
||||
else:
|
||||
raise AccountNotFound()
|
||||
if account:
|
||||
account = db.session.merge(account)
|
||||
self._update_existing_account(account, password_hashed, salt)
|
||||
db.session.commit()
|
||||
else:
|
||||
raise AccountNotFound()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
def _update_existing_account(self, account, password_hashed, salt, session):
|
||||
def _update_existing_account(self, account, password_hashed, salt):
|
||||
# Update existing account credentials
|
||||
account.password = base64.b64encode(password_hashed).decode()
|
||||
account.password_salt = base64.b64encode(salt).decode()
|
||||
|
||||
@ -4,7 +4,6 @@ import urllib.parse
|
||||
import httpx
|
||||
from flask import current_app, redirect, request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
@ -180,8 +179,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
|
||||
account: Account | None = Account.get_by_openid(provider, user_info.id)
|
||||
|
||||
if not account:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(user_info.email)
|
||||
|
||||
return account
|
||||
|
||||
|
||||
@ -227,10 +227,11 @@ class ExternalApiUseCheckApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, external_knowledge_api_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||
|
||||
external_knowledge_api_is_using, count = ExternalDatasetService.external_knowledge_api_use_check(
|
||||
external_knowledge_api_id
|
||||
external_knowledge_api_id, current_tenant_id
|
||||
)
|
||||
return {"is_using": external_knowledge_api_is_using, "count": count}, 200
|
||||
|
||||
|
||||
@ -9,7 +9,6 @@ from flask_restx import Resource, fields, marshal_with
|
||||
from graphon.file import helpers as file_helpers
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
@ -580,8 +579,7 @@ class ChangeEmailSendEmailApi(Resource):
|
||||
|
||||
user_email = current_user.email
|
||||
else:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email)
|
||||
if account is None:
|
||||
raise AccountNotFound()
|
||||
email_for_sending = account.email
|
||||
|
||||
@ -3,7 +3,6 @@ import secrets
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.auth.error import (
|
||||
@ -62,9 +61,7 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
else:
|
||||
language = "en-US"
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session)
|
||||
token = None
|
||||
account = AccountService.get_account_by_email_with_case_fallback(request_email)
|
||||
if account is None:
|
||||
raise AuthenticationFailedError()
|
||||
else:
|
||||
@ -161,13 +158,14 @@ class ForgotPasswordResetApi(Resource):
|
||||
|
||||
email = reset_data.get("email", "")
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||
|
||||
if account:
|
||||
self._update_existing_account(account, password_hashed, salt)
|
||||
else:
|
||||
raise AuthenticationFailedError()
|
||||
if account:
|
||||
account = db.session.merge(account)
|
||||
self._update_existing_account(account, password_hashed, salt)
|
||||
db.session.commit()
|
||||
else:
|
||||
raise AuthenticationFailedError()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from graphon.file import File, FileUploadConfig
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity
|
||||
@ -131,7 +131,7 @@ class AppGenerateEntity(BaseModel):
|
||||
extras: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# tracing instance
|
||||
trace_manager: Optional["TraceQueueManager"] = Field(default=None, exclude=True, repr=False)
|
||||
trace_manager: "TraceQueueManager | None" = Field(default=None, exclude=True, repr=False)
|
||||
|
||||
|
||||
class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
from typing import Literal, Optional
|
||||
from typing import Any, Literal, TypedDict
|
||||
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.datasource.entities.datasource_entities import DatasourceParameter
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.common_entities import I18nObject, I18nObjectDict
|
||||
|
||||
|
||||
class DatasourceApiEntity(BaseModel):
|
||||
@ -17,7 +17,24 @@ class DatasourceApiEntity(BaseModel):
|
||||
output_schema: dict | None = None
|
||||
|
||||
|
||||
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]]
|
||||
ToolProviderTypeApiLiteral = Literal["builtin", "api", "workflow"] | None
|
||||
|
||||
|
||||
class DatasourceProviderApiEntityDict(TypedDict):
|
||||
id: str
|
||||
author: str
|
||||
name: str
|
||||
plugin_id: str | None
|
||||
plugin_unique_identifier: str | None
|
||||
description: I18nObjectDict
|
||||
icon: str | dict
|
||||
label: I18nObjectDict
|
||||
type: str
|
||||
team_credentials: dict | None
|
||||
is_team_authorization: bool
|
||||
allow_delete: bool
|
||||
datasources: list[Any]
|
||||
labels: list[str]
|
||||
|
||||
|
||||
class DatasourceProviderApiEntity(BaseModel):
|
||||
@ -42,7 +59,7 @@ class DatasourceProviderApiEntity(BaseModel):
|
||||
def convert_none_to_empty_list(cls, v):
|
||||
return v if v is not None else []
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
def to_dict(self) -> DatasourceProviderApiEntityDict:
|
||||
# -------------
|
||||
# overwrite datasource parameter types for temp fix
|
||||
datasources = jsonable_encoder(self.datasources)
|
||||
@ -53,7 +70,7 @@ class DatasourceProviderApiEntity(BaseModel):
|
||||
parameter["type"] = "files"
|
||||
# -------------
|
||||
|
||||
return {
|
||||
result: DatasourceProviderApiEntityDict = {
|
||||
"id": self.id,
|
||||
"author": self.author,
|
||||
"name": self.name,
|
||||
@ -69,3 +86,4 @@ class DatasourceProviderApiEntity(BaseModel):
|
||||
"datasources": datasources,
|
||||
"labels": self.labels,
|
||||
}
|
||||
return result
|
||||
|
||||
@ -71,8 +71,8 @@ class DatasourceFileMessageTransformer:
|
||||
if not isinstance(message.message, DatasourceMessage.BlobMessage):
|
||||
raise ValueError("unexpected message type")
|
||||
|
||||
# FIXME: should do a type check here.
|
||||
assert isinstance(message.message.blob, bytes)
|
||||
if not isinstance(message.message.blob, bytes):
|
||||
raise TypeError(f"Expected blob to be bytes, got {type(message.message.blob).__name__}")
|
||||
tool_file_manager = ToolFileManager()
|
||||
blob_tool_file: ToolFile | None = tool_file_manager.create_file_by_raw(
|
||||
user_id=user_id,
|
||||
|
||||
@ -122,7 +122,7 @@ class MCPClientWithAuthRetry(MCPClient):
|
||||
logger.exception("Authentication retry failed")
|
||||
raise MCPAuthError(f"Authentication retry failed: {e}") from e
|
||||
|
||||
def _execute_with_retry(self, func: Callable[..., Any], *args, **kwargs) -> Any:
|
||||
def _execute_with_retry[**P, R](self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""
|
||||
Execute a function with authentication retry logic.
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from typing import Any, TypeVar
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -9,12 +9,9 @@ from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAut
|
||||
|
||||
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]
|
||||
|
||||
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
|
||||
LifespanContextT = TypeVar("LifespanContextT")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestContext[SessionT: BaseSession[Any, Any, Any, Any, Any], LifespanContextT]:
|
||||
class RequestContext[SessionT: BaseSession, LifespanContextT]:
|
||||
request_id: RequestId
|
||||
meta: RequestParams.Meta | None
|
||||
session: SessionT
|
||||
|
||||
@ -55,7 +55,7 @@ class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResul
|
||||
|
||||
request: ReceiveRequestT
|
||||
_session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]"
|
||||
_on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any]
|
||||
_on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], object]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -63,7 +63,7 @@ class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResul
|
||||
request_meta: RequestParams.Meta | None,
|
||||
request: ReceiveRequestT,
|
||||
session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]",
|
||||
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
|
||||
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], object],
|
||||
):
|
||||
self.request_id = request_id
|
||||
self.request_meta = request_meta
|
||||
|
||||
@ -31,7 +31,6 @@ ProgressToken = str | int
|
||||
Cursor = str
|
||||
Role = Literal["user", "assistant"]
|
||||
RequestId = Annotated[int | str, Field(union_mode="left_to_right")]
|
||||
type AnyFunction = Callable[..., Any]
|
||||
|
||||
|
||||
class RequestParams(BaseModel):
|
||||
|
||||
@ -6,7 +6,7 @@ from graphon.model_runtime.callbacks.base_callback import Callback
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType
|
||||
from graphon.model_runtime.entities.rerank_entities import RerankResult
|
||||
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
|
||||
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
@ -172,10 +172,10 @@ class ModelInstance:
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
prompt_messages=list(prompt_messages),
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
tools=list(tools) if tools else None,
|
||||
stop=list(stop) if stop else None,
|
||||
stream=stream,
|
||||
callbacks=callbacks,
|
||||
),
|
||||
@ -193,15 +193,12 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, LargeLanguageModel):
|
||||
raise Exception("Model type instance is not LargeLanguageModel")
|
||||
return cast(
|
||||
int,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
tools=tools,
|
||||
),
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=list(prompt_messages),
|
||||
tools=list(tools) if tools else None,
|
||||
)
|
||||
|
||||
def invoke_text_embedding(
|
||||
@ -216,15 +213,12 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||
return cast(
|
||||
EmbeddingResult,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
input_type=input_type,
|
||||
),
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
input_type=input_type,
|
||||
)
|
||||
|
||||
def invoke_multimodal_embedding(
|
||||
@ -241,15 +235,12 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||
return cast(
|
||||
EmbeddingResult,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
multimodel_documents=multimodel_documents,
|
||||
input_type=input_type,
|
||||
),
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
multimodel_documents=multimodel_documents,
|
||||
input_type=input_type,
|
||||
)
|
||||
|
||||
def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]:
|
||||
@ -261,14 +252,11 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||
return cast(
|
||||
list[int],
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
),
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
)
|
||||
|
||||
def invoke_rerank(
|
||||
@ -289,23 +277,20 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, RerankModel):
|
||||
raise Exception("Model type instance is not RerankModel")
|
||||
return cast(
|
||||
RerankResult,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
query=query,
|
||||
docs=docs,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_n,
|
||||
),
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
query=query,
|
||||
docs=docs,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_n,
|
||||
)
|
||||
|
||||
def invoke_multimodal_rerank(
|
||||
self,
|
||||
query: dict,
|
||||
docs: list[dict],
|
||||
query: MultimodalRerankInput,
|
||||
docs: list[MultimodalRerankInput],
|
||||
score_threshold: float | None = None,
|
||||
top_n: int | None = None,
|
||||
) -> RerankResult:
|
||||
@ -320,17 +305,14 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, RerankModel):
|
||||
raise Exception("Model type instance is not RerankModel")
|
||||
return cast(
|
||||
RerankResult,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke_multimodal_rerank,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
query=query,
|
||||
docs=docs,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_n,
|
||||
),
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke_multimodal_rerank,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
query=query,
|
||||
docs=docs,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_n,
|
||||
)
|
||||
|
||||
def invoke_moderation(self, text: str) -> bool:
|
||||
@ -342,14 +324,11 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, ModerationModel):
|
||||
raise Exception("Model type instance is not ModerationModel")
|
||||
return cast(
|
||||
bool,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
text=text,
|
||||
),
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
text=text,
|
||||
)
|
||||
|
||||
def invoke_speech2text(self, file: IO[bytes]) -> str:
|
||||
@ -361,14 +340,11 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, Speech2TextModel):
|
||||
raise Exception("Model type instance is not Speech2TextModel")
|
||||
return cast(
|
||||
str,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
file=file,
|
||||
),
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
file=file,
|
||||
)
|
||||
|
||||
def invoke_tts(self, content_text: str, voice: str = "") -> Iterable[bytes]:
|
||||
@ -381,18 +357,15 @@ class ModelInstance:
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, TTSModel):
|
||||
raise Exception("Model type instance is not TTSModel")
|
||||
return cast(
|
||||
Iterable[bytes],
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
content_text=content_text,
|
||||
voice=voice,
|
||||
),
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
content_text=content_text,
|
||||
voice=voice,
|
||||
)
|
||||
|
||||
def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs):
|
||||
def _round_robin_invoke[**P, R](self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""
|
||||
Round-robin invoke
|
||||
:param function: function to invoke
|
||||
@ -430,9 +403,8 @@ class ModelInstance:
|
||||
continue
|
||||
|
||||
try:
|
||||
if "credentials" in kwargs:
|
||||
del kwargs["credentials"]
|
||||
return function(*args, **kwargs, credentials=lb_config.credentials)
|
||||
kwargs["credentials"] = lb_config.credentials
|
||||
return function(*args, **kwargs)
|
||||
except InvokeRateLimitError as e:
|
||||
# expire in 60 seconds
|
||||
self.load_balancing_manager.cooldown(lb_config, expire=60)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Union, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@ -207,7 +207,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_user(cls, user_id: str) -> Union[EndUser, Account]:
|
||||
def _get_user(cls, user_id: str) -> EndUser | Account:
|
||||
"""
|
||||
get the user by user id
|
||||
"""
|
||||
|
||||
@ -7,7 +7,7 @@ from pydantic import BaseModel, model_validator
|
||||
from sqlalchemy import Column, String, Table, create_engine, insert
|
||||
from sqlalchemy import text as sql_text
|
||||
from sqlalchemy.dialects.postgresql import JSON, TEXT
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
@ -79,7 +79,7 @@ class RelytVector(BaseVector):
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
index_name = f"{self._collection_name}_embedding_index"
|
||||
with Session(self.client) as session:
|
||||
with sessionmaker(bind=self.client).begin() as session:
|
||||
drop_statement = sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}"; """)
|
||||
session.execute(drop_statement)
|
||||
create_statement = sql_text(f"""
|
||||
@ -104,7 +104,6 @@ class RelytVector(BaseVector):
|
||||
$$);
|
||||
""")
|
||||
session.execute(index_statement)
|
||||
session.commit()
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
@ -208,9 +207,8 @@ class RelytVector(BaseVector):
|
||||
self.delete_by_uuids(ids)
|
||||
|
||||
def delete(self):
|
||||
with Session(self.client) as session:
|
||||
with sessionmaker(bind=self.client).begin() as session:
|
||||
session.execute(sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}";"""))
|
||||
session.commit()
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
with Session(self.client) as session:
|
||||
|
||||
@ -6,7 +6,7 @@ import sqlalchemy
|
||||
from pydantic import BaseModel, model_validator
|
||||
from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
|
||||
from sqlalchemy import text as sql_text
|
||||
from sqlalchemy.orm import Session, declarative_base
|
||||
from sqlalchemy.orm import Session, declarative_base, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field, parse_metadata_json
|
||||
@ -97,8 +97,7 @@ class TiDBVector(BaseVector):
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
tidb_dist_func = self._get_distance_func()
|
||||
with Session(self._engine) as session:
|
||||
session.begin()
|
||||
with sessionmaker(bind=self._engine).begin() as session:
|
||||
create_statement = sql_text(f"""
|
||||
CREATE TABLE IF NOT EXISTS {self._collection_name} (
|
||||
id CHAR(36) PRIMARY KEY,
|
||||
@ -115,7 +114,6 @@ class TiDBVector(BaseVector):
|
||||
);
|
||||
""")
|
||||
session.execute(create_statement)
|
||||
session.commit()
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
@ -238,9 +236,8 @@ class TiDBVector(BaseVector):
|
||||
return []
|
||||
|
||||
def delete(self):
|
||||
with Session(self._engine) as session:
|
||||
with sessionmaker(bind=self._engine).begin() as session:
|
||||
session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))
|
||||
session.commit()
|
||||
|
||||
def _get_distance_func(self) -> str:
|
||||
match self._distance_func:
|
||||
|
||||
@ -3,8 +3,7 @@
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
from typing import Any, TypedDict, cast
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -55,6 +54,12 @@ from services.summary_index_service import SummaryIndexService
|
||||
_file_access_controller = DatabaseFileAccessController()
|
||||
|
||||
|
||||
class ParagraphFormatPreviewDict(TypedDict):
|
||||
chunk_structure: str
|
||||
preview: list[dict[str, Any]]
|
||||
total_segments: int
|
||||
|
||||
|
||||
class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
|
||||
text_docs = ExtractProcessor.extract(
|
||||
@ -266,16 +271,17 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
keyword = Keyword(dataset)
|
||||
keyword.add_texts(documents)
|
||||
|
||||
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
|
||||
def format_preview(self, chunks: Any) -> ParagraphFormatPreviewDict:
|
||||
if isinstance(chunks, list):
|
||||
preview = []
|
||||
for content in chunks:
|
||||
preview.append({"content": content})
|
||||
return {
|
||||
result: ParagraphFormatPreviewDict = {
|
||||
"chunk_structure": IndexStructureType.PARAGRAPH_INDEX,
|
||||
"preview": preview,
|
||||
"total_segments": len(chunks),
|
||||
}
|
||||
return result
|
||||
else:
|
||||
raise ValueError("Chunks is not a list")
|
||||
|
||||
|
||||
@ -3,8 +3,7 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
@ -36,6 +35,13 @@ from services.summary_index_service import SummaryIndexService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ParentChildFormatPreviewDict(TypedDict):
|
||||
chunk_structure: str
|
||||
parent_mode: str
|
||||
preview: list[dict[str, Any]]
|
||||
total_segments: int
|
||||
|
||||
|
||||
class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
|
||||
text_docs = ExtractProcessor.extract(
|
||||
@ -351,17 +357,18 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
if all_multimodal_documents and dataset.is_multimodal:
|
||||
vector.create_multimodal(all_multimodal_documents)
|
||||
|
||||
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
|
||||
def format_preview(self, chunks: Any) -> ParentChildFormatPreviewDict:
|
||||
parent_childs = ParentChildStructureChunk.model_validate(chunks)
|
||||
preview = []
|
||||
for parent_child in parent_childs.parent_child_chunks:
|
||||
preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents})
|
||||
return {
|
||||
result: ParentChildFormatPreviewDict = {
|
||||
"chunk_structure": IndexStructureType.PARENT_CHILD_INDEX,
|
||||
"parent_mode": parent_childs.parent_mode,
|
||||
"preview": preview,
|
||||
"total_segments": len(parent_childs.parent_child_chunks),
|
||||
}
|
||||
return result
|
||||
|
||||
def generate_summary_preview(
|
||||
self,
|
||||
|
||||
@ -4,8 +4,7 @@ import logging
|
||||
import re
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import pandas as pd
|
||||
from flask import Flask, current_app
|
||||
@ -36,6 +35,12 @@ from services.summary_index_service import SummaryIndexService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QAFormatPreviewDict(TypedDict):
|
||||
chunk_structure: str
|
||||
qa_preview: list[dict[str, Any]]
|
||||
total_segments: int
|
||||
|
||||
|
||||
class QAIndexProcessor(BaseIndexProcessor):
|
||||
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
|
||||
text_docs = ExtractProcessor.extract(
|
||||
@ -230,16 +235,17 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
else:
|
||||
raise ValueError("Indexing technique must be high quality.")
|
||||
|
||||
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
|
||||
def format_preview(self, chunks: Any) -> QAFormatPreviewDict:
|
||||
qa_chunks = QAStructureChunk.model_validate(chunks)
|
||||
preview = []
|
||||
for qa_chunk in qa_chunks.qa_chunks:
|
||||
preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer})
|
||||
return {
|
||||
result: QAFormatPreviewDict = {
|
||||
"chunk_structure": IndexStructureType.QA_INDEX,
|
||||
"qa_preview": preview,
|
||||
"total_segments": len(qa_chunks.qa_chunks),
|
||||
}
|
||||
return result
|
||||
|
||||
def generate_summary_preview(
|
||||
self,
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import base64
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.entities.rerank_entities import RerankResult
|
||||
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
|
||||
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
@ -123,7 +123,7 @@ class RerankModelRunner(BaseRerankRunner):
|
||||
:param query_type: query type
|
||||
:return: rerank result
|
||||
"""
|
||||
docs = []
|
||||
docs: list[MultimodalRerankInput] = []
|
||||
doc_ids = set()
|
||||
unique_documents = []
|
||||
for document in documents:
|
||||
@ -138,26 +138,28 @@ class RerankModelRunner(BaseRerankRunner):
|
||||
if upload_file:
|
||||
blob = storage.load_once(upload_file.key)
|
||||
document_file_base64 = base64.b64encode(blob).decode()
|
||||
document_file_dict = {
|
||||
"content": document_file_base64,
|
||||
"content_type": document.metadata["doc_type"],
|
||||
}
|
||||
docs.append(document_file_dict)
|
||||
docs.append(
|
||||
MultimodalRerankInput(
|
||||
content=document_file_base64,
|
||||
content_type=document.metadata["doc_type"],
|
||||
)
|
||||
)
|
||||
else:
|
||||
document_text_dict = {
|
||||
"content": document.page_content,
|
||||
"content_type": document.metadata.get("doc_type") or DocType.TEXT,
|
||||
}
|
||||
docs.append(document_text_dict)
|
||||
docs.append(
|
||||
MultimodalRerankInput(
|
||||
content=document.page_content,
|
||||
content_type=document.metadata.get("doc_type") or DocType.TEXT,
|
||||
)
|
||||
)
|
||||
doc_ids.add(document.metadata["doc_id"])
|
||||
unique_documents.append(document)
|
||||
elif document.provider == "external":
|
||||
if document not in unique_documents:
|
||||
docs.append(
|
||||
{
|
||||
"content": document.page_content,
|
||||
"content_type": document.metadata.get("doc_type") or DocType.TEXT,
|
||||
}
|
||||
MultimodalRerankInput(
|
||||
content=document.page_content,
|
||||
content_type=document.metadata.get("doc_type") or DocType.TEXT,
|
||||
)
|
||||
)
|
||||
unique_documents.append(document)
|
||||
|
||||
@ -171,12 +173,12 @@ class RerankModelRunner(BaseRerankRunner):
|
||||
if upload_file:
|
||||
blob = storage.load_once(upload_file.key)
|
||||
file_query = base64.b64encode(blob).decode()
|
||||
file_query_dict = {
|
||||
"content": file_query,
|
||||
"content_type": DocType.IMAGE,
|
||||
}
|
||||
file_query_input = MultimodalRerankInput(
|
||||
content=file_query,
|
||||
content_type=DocType.IMAGE,
|
||||
)
|
||||
rerank_result = self.rerank_model_instance.invoke_multimodal_rerank(
|
||||
query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n
|
||||
query=file_query_input, docs=docs, score_threshold=score_threshold, top_n=top_n
|
||||
)
|
||||
return rerank_result, unique_documents
|
||||
else:
|
||||
|
||||
@ -6,7 +6,6 @@ import os
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from mimetypes import guess_extension, guess_type
|
||||
from typing import Union
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
@ -158,7 +157,7 @@ class ToolFileManager:
|
||||
|
||||
return tool_file
|
||||
|
||||
def get_file_binary(self, id: str) -> Union[tuple[bytes, str], None]:
|
||||
def get_file_binary(self, id: str) -> tuple[bytes, str] | None:
|
||||
"""
|
||||
get file binary
|
||||
|
||||
@ -176,7 +175,7 @@ class ToolFileManager:
|
||||
|
||||
return blob, tool_file.mimetype
|
||||
|
||||
def get_file_binary_by_message_file_id(self, id: str) -> Union[tuple[bytes, str], None]:
|
||||
def get_file_binary_by_message_file_id(self, id: str) -> tuple[bytes, str] | None:
|
||||
"""
|
||||
get file binary
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ import time
|
||||
from collections.abc import Generator, Mapping
|
||||
from os import listdir, path
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, Protocol, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from graphon.runtime import VariablePool
|
||||
@ -100,7 +100,7 @@ class ToolManager:
|
||||
_builtin_provider_lock = Lock()
|
||||
_hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
|
||||
_builtin_providers_loaded = False
|
||||
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
|
||||
_builtin_tools_labels: dict[str, I18nObject | None] = {}
|
||||
|
||||
@classmethod
|
||||
def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
|
||||
@ -190,7 +190,7 @@ class ToolManager:
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
|
||||
credential_id: str | None = None,
|
||||
) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
|
||||
) -> BuiltinTool | PluginTool | ApiTool | WorkflowTool | MCPTool:
|
||||
"""
|
||||
get the tool runtime
|
||||
|
||||
@ -398,7 +398,7 @@ class ToolManager:
|
||||
agent_tool: AgentToolEntity,
|
||||
user_id: str | None = None,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
variable_pool: Optional["VariablePool"] = None,
|
||||
variable_pool: "VariablePool | None" = None,
|
||||
) -> Tool:
|
||||
"""
|
||||
get the agent tool runtime
|
||||
@ -442,7 +442,7 @@ class ToolManager:
|
||||
workflow_tool: WorkflowToolRuntimeSpec,
|
||||
user_id: str | None = None,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
variable_pool: Optional["VariablePool"] = None,
|
||||
variable_pool: "VariablePool | None" = None,
|
||||
) -> Tool:
|
||||
"""
|
||||
get the workflow tool runtime
|
||||
@ -634,7 +634,7 @@ class ToolManager:
|
||||
cls._builtin_providers_loaded = False
|
||||
|
||||
@classmethod
|
||||
def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
|
||||
def get_tool_label(cls, tool_name: str) -> I18nObject | None:
|
||||
"""
|
||||
get the tool label
|
||||
|
||||
@ -993,7 +993,7 @@ class ToolManager:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
@classmethod
|
||||
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | dict[str, str] | str:
|
||||
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | str:
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
@ -1001,7 +1001,7 @@ class ToolManager:
|
||||
mcp_provider = mcp_service.get_provider_entity(
|
||||
provider_id=provider_id, tenant_id=tenant_id, by_server_id=True
|
||||
)
|
||||
return mcp_provider.provider_icon
|
||||
return cast(EmojiIconDict | str, mcp_provider.provider_icon)
|
||||
except ValueError:
|
||||
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
|
||||
except Exception:
|
||||
@ -1013,7 +1013,7 @@ class ToolManager:
|
||||
tenant_id: str,
|
||||
provider_type: ToolProviderType,
|
||||
provider_id: str,
|
||||
) -> str | EmojiIconDict | dict[str, str]:
|
||||
) -> str | EmojiIconDict:
|
||||
"""
|
||||
get the tool icon
|
||||
|
||||
@ -1052,7 +1052,7 @@ class ToolManager:
|
||||
def _convert_tool_parameters_type(
|
||||
cls,
|
||||
parameters: list[ToolParameter],
|
||||
variable_pool: Optional["VariablePool"],
|
||||
variable_pool: "VariablePool | None",
|
||||
tool_configurations: Mapping[str, Any],
|
||||
typ: Literal["agent", "workflow", "tool"] = "workflow",
|
||||
) -> dict[str, Any]:
|
||||
|
||||
@ -118,7 +118,8 @@ class ToolFileMessageTransformer:
|
||||
if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
|
||||
raise ValueError("unexpected message type")
|
||||
|
||||
assert isinstance(message.message.blob, bytes)
|
||||
if not isinstance(message.message.blob, bytes):
|
||||
raise TypeError(f"Expected blob to be bytes, got {type(message.message.blob).__name__}")
|
||||
tool_file_manager = ToolFileManager()
|
||||
tool_file = tool_file_manager.create_file_by_raw(
|
||||
user_id=user_id,
|
||||
|
||||
@ -14,6 +14,7 @@ from redis.cluster import ClusterNode, RedisCluster
|
||||
from redis.connection import Connection, SSLConnection
|
||||
from redis.retry import Retry
|
||||
from redis.sentinel import Sentinel
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from dify_app import DifyApp
|
||||
@ -126,6 +127,35 @@ redis_client: RedisClientWrapper = RedisClientWrapper()
|
||||
_pubsub_redis_client: redis.Redis | RedisCluster | None = None
|
||||
|
||||
|
||||
class RedisSSLParamsDict(TypedDict):
|
||||
ssl_cert_reqs: int
|
||||
ssl_ca_certs: str | None
|
||||
ssl_certfile: str | None
|
||||
ssl_keyfile: str | None
|
||||
|
||||
|
||||
class RedisHealthParamsDict(TypedDict):
|
||||
retry: Retry
|
||||
socket_timeout: float | None
|
||||
socket_connect_timeout: float | None
|
||||
health_check_interval: int | None
|
||||
|
||||
|
||||
class RedisBaseParamsDict(TypedDict):
|
||||
username: str | None
|
||||
password: str | None
|
||||
db: int
|
||||
encoding: str
|
||||
encoding_errors: str
|
||||
decode_responses: bool
|
||||
protocol: int
|
||||
cache_config: CacheConfig | None
|
||||
retry: Retry
|
||||
socket_timeout: float | None
|
||||
socket_connect_timeout: float | None
|
||||
health_check_interval: int | None
|
||||
|
||||
|
||||
def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]:
|
||||
"""Get SSL configuration for Redis connection."""
|
||||
if not dify_config.REDIS_USE_SSL:
|
||||
@ -171,14 +201,14 @@ def _get_retry_policy() -> Retry:
|
||||
)
|
||||
|
||||
|
||||
def _get_connection_health_params() -> dict[str, Any]:
|
||||
def _get_connection_health_params() -> RedisHealthParamsDict:
|
||||
"""Get connection health and retry parameters for standalone and Sentinel Redis clients."""
|
||||
return {
|
||||
"retry": _get_retry_policy(),
|
||||
"socket_timeout": dify_config.REDIS_SOCKET_TIMEOUT,
|
||||
"socket_connect_timeout": dify_config.REDIS_SOCKET_CONNECT_TIMEOUT,
|
||||
"health_check_interval": dify_config.REDIS_HEALTH_CHECK_INTERVAL,
|
||||
}
|
||||
return RedisHealthParamsDict(
|
||||
retry=_get_retry_policy(),
|
||||
socket_timeout=dify_config.REDIS_SOCKET_TIMEOUT,
|
||||
socket_connect_timeout=dify_config.REDIS_SOCKET_CONNECT_TIMEOUT,
|
||||
health_check_interval=dify_config.REDIS_HEALTH_CHECK_INTERVAL,
|
||||
)
|
||||
|
||||
|
||||
def _get_cluster_connection_health_params() -> dict[str, Any]:
|
||||
@ -189,26 +219,26 @@ def _get_cluster_connection_health_params() -> dict[str, Any]:
|
||||
here. Only ``retry``, ``socket_timeout``, and ``socket_connect_timeout``
|
||||
are passed through.
|
||||
"""
|
||||
params = _get_connection_health_params()
|
||||
params: dict[str, Any] = dict(_get_connection_health_params())
|
||||
return {k: v for k, v in params.items() if k != "health_check_interval"}
|
||||
|
||||
|
||||
def _get_base_redis_params() -> dict[str, Any]:
|
||||
def _get_base_redis_params() -> RedisBaseParamsDict:
|
||||
"""Get base Redis connection parameters including retry and health policy."""
|
||||
return {
|
||||
"username": dify_config.REDIS_USERNAME,
|
||||
"password": dify_config.REDIS_PASSWORD or None,
|
||||
"db": dify_config.REDIS_DB,
|
||||
"encoding": "utf-8",
|
||||
"encoding_errors": "strict",
|
||||
"decode_responses": False,
|
||||
"protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL,
|
||||
"cache_config": _get_cache_configuration(),
|
||||
return RedisBaseParamsDict(
|
||||
username=dify_config.REDIS_USERNAME,
|
||||
password=dify_config.REDIS_PASSWORD or None,
|
||||
db=dify_config.REDIS_DB,
|
||||
encoding="utf-8",
|
||||
encoding_errors="strict",
|
||||
decode_responses=False,
|
||||
protocol=dify_config.REDIS_SERIALIZATION_PROTOCOL,
|
||||
cache_config=_get_cache_configuration(),
|
||||
**_get_connection_health_params(),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]:
|
||||
def _create_sentinel_client(redis_params: RedisBaseParamsDict) -> Union[redis.Redis, RedisCluster]:
|
||||
"""Create Redis client using Sentinel configuration."""
|
||||
if not dify_config.REDIS_SENTINELS:
|
||||
raise ValueError("REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True")
|
||||
@ -232,7 +262,8 @@ def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis,
|
||||
sentinel_kwargs=sentinel_kwargs,
|
||||
)
|
||||
|
||||
master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params)
|
||||
params: dict[str, Any] = {**redis_params}
|
||||
master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **params)
|
||||
return master
|
||||
|
||||
|
||||
@ -259,18 +290,16 @@ def _create_cluster_client() -> Union[redis.Redis, RedisCluster]:
|
||||
return cluster
|
||||
|
||||
|
||||
def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]:
|
||||
def _create_standalone_client(redis_params: RedisBaseParamsDict) -> Union[redis.Redis, RedisCluster]:
|
||||
"""Create standalone Redis client."""
|
||||
connection_class, ssl_kwargs = _get_ssl_configuration()
|
||||
|
||||
params = {**redis_params}
|
||||
params.update(
|
||||
{
|
||||
"host": dify_config.REDIS_HOST,
|
||||
"port": dify_config.REDIS_PORT,
|
||||
"connection_class": connection_class,
|
||||
}
|
||||
)
|
||||
params: dict[str, Any] = {
|
||||
**redis_params,
|
||||
"host": dify_config.REDIS_HOST,
|
||||
"port": dify_config.REDIS_PORT,
|
||||
"connection_class": connection_class,
|
||||
}
|
||||
|
||||
if dify_config.REDIS_MAX_CONNECTIONS:
|
||||
params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS
|
||||
@ -293,8 +322,8 @@ def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> redis.Redis |
|
||||
kwargs["max_connections"] = max_conns
|
||||
return RedisCluster.from_url(pubsub_url, **kwargs)
|
||||
|
||||
health_params = _get_connection_health_params()
|
||||
kwargs = {**health_params}
|
||||
standalone_health_params: dict[str, Any] = dict(_get_connection_health_params())
|
||||
kwargs = {**standalone_health_params}
|
||||
if max_conns:
|
||||
kwargs["max_connections"] = max_conns
|
||||
return redis.Redis.from_url(pubsub_url, **kwargs)
|
||||
|
||||
@ -37,12 +37,7 @@ def trace_span[**P, R](handler_class: type[SpanHandler] | None = None) -> Callab
|
||||
handler = _get_handler_instance(handler_class or SpanHandler)
|
||||
tracer = get_tracer(__name__)
|
||||
|
||||
return handler.wrapper(
|
||||
tracer=tracer,
|
||||
wrapped=func,
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
return handler.wrapper(tracer, func, *args, **kwargs)
|
||||
|
||||
return cast(Callable[P, R], wrapper)
|
||||
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
import inspect
|
||||
from collections.abc import Callable, Mapping
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from opentelemetry.trace import SpanKind, Status, StatusCode
|
||||
from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer
|
||||
|
||||
|
||||
class SpanHandler:
|
||||
@ -16,9 +16,9 @@ class SpanHandler:
|
||||
exceptions. Handlers can override the wrapper method to customize behavior.
|
||||
"""
|
||||
|
||||
_signature_cache: dict[Callable[..., Any], inspect.Signature] = {}
|
||||
_signature_cache: dict[Callable[..., object], inspect.Signature] = {}
|
||||
|
||||
def _build_span_name(self, wrapped: Callable[..., Any]) -> str:
|
||||
def _build_span_name[**P, R](self, wrapped: Callable[P, R]) -> str:
|
||||
"""
|
||||
Build the span name from the wrapped function.
|
||||
|
||||
@ -29,11 +29,11 @@ class SpanHandler:
|
||||
"""
|
||||
return f"{wrapped.__module__}.{wrapped.__qualname__}"
|
||||
|
||||
def _extract_arguments[T](
|
||||
def _extract_arguments[**P, R](
|
||||
self,
|
||||
wrapped: Callable[..., T],
|
||||
args: tuple[object, ...],
|
||||
kwargs: Mapping[str, object],
|
||||
wrapped: Callable[P, R],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Extract function arguments using inspect.signature.
|
||||
@ -59,13 +59,13 @@ class SpanHandler:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def wrapper[T](
|
||||
def wrapper[**P, R](
|
||||
self,
|
||||
tracer: Any,
|
||||
wrapped: Callable[..., T],
|
||||
args: tuple[object, ...],
|
||||
kwargs: Mapping[str, object],
|
||||
) -> T:
|
||||
tracer: Tracer,
|
||||
wrapped: Callable[P, R],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> R:
|
||||
"""
|
||||
Fully control the wrapper behavior.
|
||||
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Callable, Mapping
|
||||
from typing import Any
|
||||
from collections.abc import Callable
|
||||
|
||||
from opentelemetry.trace import SpanKind, Status, StatusCode
|
||||
from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer
|
||||
from opentelemetry.util.types import AttributeValue
|
||||
|
||||
from extensions.otel.decorators.handler import SpanHandler
|
||||
@ -15,15 +14,15 @@ logger = logging.getLogger(__name__)
|
||||
class AppGenerateHandler(SpanHandler):
|
||||
"""Span handler for ``AppGenerateService.generate``."""
|
||||
|
||||
def wrapper[T](
|
||||
def wrapper[**P, R](
|
||||
self,
|
||||
tracer: Any,
|
||||
wrapped: Callable[..., T],
|
||||
args: tuple[object, ...],
|
||||
kwargs: Mapping[str, object],
|
||||
) -> T:
|
||||
tracer: Tracer,
|
||||
wrapped: Callable[P, R],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> R:
|
||||
try:
|
||||
arguments = self._extract_arguments(wrapped, args, kwargs)
|
||||
arguments = self._extract_arguments(wrapped, *args, **kwargs)
|
||||
if not arguments:
|
||||
return wrapped(*args, **kwargs)
|
||||
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Callable, Mapping
|
||||
from typing import Any
|
||||
from collections.abc import Callable
|
||||
|
||||
from opentelemetry.trace import SpanKind, Status, StatusCode
|
||||
from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer
|
||||
from opentelemetry.util.types import AttributeValue
|
||||
|
||||
from extensions.otel.decorators.handler import SpanHandler
|
||||
@ -14,15 +13,15 @@ logger = logging.getLogger(__name__)
|
||||
class WorkflowAppRunnerHandler(SpanHandler):
|
||||
"""Span handler for ``WorkflowAppRunner.run``."""
|
||||
|
||||
def wrapper(
|
||||
def wrapper[**P, R](
|
||||
self,
|
||||
tracer: Any,
|
||||
wrapped: Callable[..., Any],
|
||||
args: tuple[Any, ...],
|
||||
kwargs: Mapping[str, Any],
|
||||
) -> Any:
|
||||
tracer: Tracer,
|
||||
wrapped: Callable[P, R],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> R:
|
||||
try:
|
||||
arguments = self._extract_arguments(wrapped, args, kwargs)
|
||||
arguments = self._extract_arguments(wrapped, *args, **kwargs)
|
||||
if not arguments:
|
||||
return wrapped(*args, **kwargs)
|
||||
|
||||
|
||||
@ -14,9 +14,15 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import redis
|
||||
from redis.cluster import RedisCluster
|
||||
from redis.exceptions import LockNotOwnedError, RedisError
|
||||
from redis.lock import Lock
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from extensions.ext_redis import RedisClientWrapper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -38,21 +44,21 @@ class DbMigrationAutoRenewLock:
|
||||
primary error/exit code.
|
||||
"""
|
||||
|
||||
_redis_client: Any
|
||||
_redis_client: redis.Redis | RedisCluster | RedisClientWrapper
|
||||
_name: str
|
||||
_ttl_seconds: float
|
||||
_renew_interval_seconds: float
|
||||
_log_context: str | None
|
||||
_logger: logging.Logger
|
||||
|
||||
_lock: Any
|
||||
_lock: Lock | None
|
||||
_stop_event: threading.Event | None
|
||||
_thread: threading.Thread | None
|
||||
_acquired: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: Any,
|
||||
redis_client: redis.Redis | RedisCluster | RedisClientWrapper,
|
||||
name: str,
|
||||
ttl_seconds: float = 60,
|
||||
renew_interval_seconds: float | None = None,
|
||||
@ -127,7 +133,7 @@ class DbMigrationAutoRenewLock:
|
||||
)
|
||||
self._thread.start()
|
||||
|
||||
def _heartbeat_loop(self, lock: Any, stop_event: threading.Event) -> None:
|
||||
def _heartbeat_loop(self, lock: Lock, stop_event: threading.Event) -> None:
|
||||
while not stop_event.wait(self._renew_interval_seconds):
|
||||
try:
|
||||
lock.reacquire()
|
||||
|
||||
@ -10,7 +10,7 @@ import uuid
|
||||
from collections.abc import Callable, Generator, Mapping
|
||||
from datetime import datetime
|
||||
from hashlib import sha256
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Optional, Protocol, Union, cast
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Protocol, cast
|
||||
from uuid import UUID
|
||||
from zoneinfo import available_timezones
|
||||
|
||||
@ -81,7 +81,7 @@ def escape_like_pattern(pattern: str) -> str:
|
||||
return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||
|
||||
|
||||
def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None:
|
||||
def extract_tenant_id(user: "Account | EndUser") -> str | None:
|
||||
"""
|
||||
Extract tenant_id from Account or EndUser object.
|
||||
|
||||
@ -164,7 +164,10 @@ def email(email):
|
||||
EmailStr = Annotated[str, AfterValidator(email)]
|
||||
|
||||
|
||||
def uuid_value(value: Any) -> str:
|
||||
def uuid_value(value: str | UUID) -> str:
|
||||
if isinstance(value, UUID):
|
||||
return str(value)
|
||||
|
||||
if value == "":
|
||||
return str(value)
|
||||
|
||||
@ -405,7 +408,7 @@ class TokenManager:
|
||||
def generate_token(
|
||||
cls,
|
||||
token_type: str,
|
||||
account: Optional["Account"] = None,
|
||||
account: "Account | None" = None,
|
||||
email: str | None = None,
|
||||
additional_data: dict | None = None,
|
||||
) -> str:
|
||||
@ -465,9 +468,7 @@ class TokenManager:
|
||||
return current_token
|
||||
|
||||
@classmethod
|
||||
def _set_current_token_for_account(
|
||||
cls, account_id: str, token: str, token_type: str, expiry_minutes: Union[int, float]
|
||||
):
|
||||
def _set_current_token_for_account(cls, account_id: str, token: str, token_type: str, expiry_minutes: int | float):
|
||||
key = cls._get_account_token_key(account_id, token_type)
|
||||
expiry_seconds = int(expiry_minutes * 60)
|
||||
redis_client.setex(key, expiry_seconds, token)
|
||||
|
||||
145
api/libs/pyrefly_type_coverage.py
Normal file
145
api/libs/pyrefly_type_coverage.py
Normal file
@ -0,0 +1,145 @@
|
||||
"""Helpers for generating type-coverage summaries from pyrefly report output."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import TypedDict
|
||||
|
||||
|
||||
class CoverageSummary(TypedDict):
|
||||
n_modules: int
|
||||
n_typable: int
|
||||
n_typed: int
|
||||
n_any: int
|
||||
n_untyped: int
|
||||
coverage: float
|
||||
strict_coverage: float
|
||||
|
||||
|
||||
_REQUIRED_KEYS = frozenset(CoverageSummary.__annotations__)
|
||||
|
||||
_EMPTY_SUMMARY: CoverageSummary = {
|
||||
"n_modules": 0,
|
||||
"n_typable": 0,
|
||||
"n_typed": 0,
|
||||
"n_any": 0,
|
||||
"n_untyped": 0,
|
||||
"coverage": 0.0,
|
||||
"strict_coverage": 0.0,
|
||||
}
|
||||
|
||||
|
||||
def parse_summary(report_json: str) -> CoverageSummary:
|
||||
"""Extract the summary section from ``pyrefly report`` JSON output.
|
||||
|
||||
Returns an empty summary when *report_json* is empty or malformed so that
|
||||
the CI workflow can degrade gracefully instead of crashing.
|
||||
"""
|
||||
if not report_json or not report_json.strip():
|
||||
return _EMPTY_SUMMARY.copy()
|
||||
|
||||
try:
|
||||
data = json.loads(report_json)
|
||||
except json.JSONDecodeError:
|
||||
return _EMPTY_SUMMARY.copy()
|
||||
|
||||
summary = data.get("summary")
|
||||
if not isinstance(summary, dict) or not _REQUIRED_KEYS.issubset(summary):
|
||||
return _EMPTY_SUMMARY.copy()
|
||||
|
||||
return {
|
||||
"n_modules": summary["n_modules"],
|
||||
"n_typable": summary["n_typable"],
|
||||
"n_typed": summary["n_typed"],
|
||||
"n_any": summary["n_any"],
|
||||
"n_untyped": summary["n_untyped"],
|
||||
"coverage": summary["coverage"],
|
||||
"strict_coverage": summary["strict_coverage"],
|
||||
}
|
||||
|
||||
|
||||
def format_summary_markdown(summary: CoverageSummary) -> str:
|
||||
"""Format a single coverage summary as a Markdown table."""
|
||||
|
||||
return (
|
||||
"| Metric | Value |\n"
|
||||
"| --- | ---: |\n"
|
||||
f"| Modules | {summary['n_modules']} |\n"
|
||||
f"| Typable symbols | {summary['n_typable']:,} |\n"
|
||||
f"| Typed symbols | {summary['n_typed']:,} |\n"
|
||||
f"| Untyped symbols | {summary['n_untyped']:,} |\n"
|
||||
f"| Any symbols | {summary['n_any']:,} |\n"
|
||||
f"| **Type coverage** | **{summary['coverage']:.2f}%** |\n"
|
||||
f"| Strict coverage | {summary['strict_coverage']:.2f}% |"
|
||||
)
|
||||
|
||||
|
||||
def format_comparison_markdown(
|
||||
base: CoverageSummary,
|
||||
pr: CoverageSummary,
|
||||
) -> str:
|
||||
"""Format a comparison between base and PR coverage as Markdown."""
|
||||
|
||||
coverage_delta = pr["coverage"] - base["coverage"]
|
||||
strict_delta = pr["strict_coverage"] - base["strict_coverage"]
|
||||
typed_delta = pr["n_typed"] - base["n_typed"]
|
||||
untyped_delta = pr["n_untyped"] - base["n_untyped"]
|
||||
|
||||
def _fmt_delta(value: float, fmt: str = ".2f") -> str:
|
||||
sign = "+" if value > 0 else ""
|
||||
return f"{sign}{value:{fmt}}"
|
||||
|
||||
lines = [
|
||||
"| Metric | Base | PR | Delta |",
|
||||
"| --- | ---: | ---: | ---: |",
|
||||
(f"| **Type coverage** | {base['coverage']:.2f}% | {pr['coverage']:.2f}% | {_fmt_delta(coverage_delta)}% |"),
|
||||
(
|
||||
f"| Strict coverage | {base['strict_coverage']:.2f}% "
|
||||
f"| {pr['strict_coverage']:.2f}% "
|
||||
f"| {_fmt_delta(strict_delta)}% |"
|
||||
),
|
||||
(f"| Typed symbols | {base['n_typed']:,} | {pr['n_typed']:,} | {_fmt_delta(typed_delta, ',')} |"),
|
||||
(f"| Untyped symbols | {base['n_untyped']:,} | {pr['n_untyped']:,} | {_fmt_delta(untyped_delta, ',')} |"),
|
||||
(
|
||||
f"| Modules | {base['n_modules']} "
|
||||
f"| {pr['n_modules']} "
|
||||
f"| {_fmt_delta(pr['n_modules'] - base['n_modules'], ',')} |"
|
||||
),
|
||||
]
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
"""Read pyrefly report JSON from stdin and print a Markdown summary.
|
||||
|
||||
Accepts an optional ``--base <file>`` argument. When provided, the output
|
||||
includes a base-vs-PR comparison table.
|
||||
"""
|
||||
|
||||
args = sys.argv[1:]
|
||||
|
||||
base_file: str | None = None
|
||||
if "--base" in args:
|
||||
idx = args.index("--base")
|
||||
if idx + 1 >= len(args):
|
||||
sys.stderr.write("error: --base requires a file path\n")
|
||||
return 1
|
||||
base_file = args[idx + 1]
|
||||
|
||||
pr_report = sys.stdin.read()
|
||||
pr_summary = parse_summary(pr_report)
|
||||
|
||||
if base_file is not None:
|
||||
base_text = Path(base_file).read_text() if Path(base_file).exists() else ""
|
||||
base_summary = parse_summary(base_text)
|
||||
sys.stdout.write(format_comparison_markdown(base_summary, pr_summary) + "\n")
|
||||
else:
|
||||
sys.stdout.write(format_summary_markdown(pr_summary) + "\n")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@ -24,6 +24,8 @@ class TypeBase(MappedAsDataclass, DeclarativeBase):
|
||||
|
||||
|
||||
class DefaultFieldsMixin:
|
||||
"""Mixin for models that inherit from Base (non-dataclass)."""
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
StringUUID,
|
||||
primary_key=True,
|
||||
@ -53,6 +55,42 @@ class DefaultFieldsMixin:
|
||||
return f"<{self.__class__.__name__}(id={self.id})>"
|
||||
|
||||
|
||||
class DefaultFieldsDCMixin(MappedAsDataclass):
|
||||
"""Mixin for models that inherit from TypeBase (MappedAsDataclass)."""
|
||||
|
||||
__abstract__ = True
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
StringUUID,
|
||||
primary_key=True,
|
||||
insert_default=lambda: str(uuidv7()),
|
||||
default_factory=lambda: str(uuidv7()),
|
||||
init=False,
|
||||
)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
nullable=False,
|
||||
insert_default=naive_utc_now,
|
||||
default_factory=naive_utc_now,
|
||||
init=False,
|
||||
server_default=func.current_timestamp(),
|
||||
)
|
||||
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
nullable=False,
|
||||
insert_default=naive_utc_now,
|
||||
default_factory=naive_utc_now,
|
||||
init=False,
|
||||
server_default=func.current_timestamp(),
|
||||
onupdate=func.current_timestamp(),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}(id={self.id})>"
|
||||
|
||||
|
||||
def gen_uuidv4_string() -> str:
|
||||
"""gen_uuidv4_string generate a UUIDv4 string.
|
||||
|
||||
|
||||
@ -913,11 +913,7 @@ class TrialApp(TypeBase):
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
insert_default=func.current_timestamp(),
|
||||
server_default=func.current_timestamp(),
|
||||
init=False,
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
trial_limit: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=3)
|
||||
|
||||
@ -941,11 +937,7 @@ class AccountTrialAppRecord(TypeBase):
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
insert_default=func.current_timestamp(),
|
||||
server_default=func.current_timestamp(),
|
||||
init=False,
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@ -4,7 +4,7 @@ import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypedDict, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
@ -121,7 +121,7 @@ class WorkflowType(StrEnum):
|
||||
raise ValueError(f"invalid workflow type value {value}")
|
||||
|
||||
@classmethod
|
||||
def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType":
|
||||
def from_app_mode(cls, app_mode: "str | AppMode") -> "WorkflowType":
|
||||
"""
|
||||
Get workflow type from app mode.
|
||||
|
||||
@ -1051,7 +1051,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
||||
)
|
||||
return extras
|
||||
|
||||
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]:
|
||||
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> "WorkflowNodeExecutionOffload | None":
|
||||
return next(iter([i for i in self.offload_data if i.type_ == type_]), None)
|
||||
|
||||
@property
|
||||
|
||||
@ -9,7 +9,8 @@ from typing import Any, TypedDict, cast
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from sqlalchemy import delete, func, select, update
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
|
||||
|
||||
class InvitationData(TypedDict):
|
||||
@ -800,19 +801,19 @@ class AccountService:
|
||||
return token
|
||||
|
||||
@staticmethod
|
||||
def get_account_by_email_with_case_fallback(email: str, session: Session | None = None) -> Account | None:
|
||||
def get_account_by_email_with_case_fallback(email: str) -> Account | None:
|
||||
"""
|
||||
Retrieve an account by email and fall back to the lowercase email if the original lookup fails.
|
||||
|
||||
This keeps backward compatibility for older records that stored uppercase emails while the
|
||||
rest of the system gradually normalizes new inputs.
|
||||
"""
|
||||
query_session = session or db.session
|
||||
account = query_session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
|
||||
if account or email == email.lower():
|
||||
return account
|
||||
with session_factory.create_session() as session:
|
||||
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
|
||||
if account or email == email.lower():
|
||||
return account
|
||||
|
||||
return query_session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()
|
||||
return session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()
|
||||
|
||||
@classmethod
|
||||
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
|
||||
@ -1516,8 +1517,7 @@ class RegisterService:
|
||||
|
||||
check_workspace_member_invite_permission(tenant.id)
|
||||
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||
|
||||
if not account:
|
||||
TenantService.check_member_permission(tenant, inviter, None, "add")
|
||||
|
||||
@ -4,7 +4,7 @@ import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Callable, Generator, Mapping
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
@ -88,7 +88,7 @@ class AppGenerateService:
|
||||
def generate(
|
||||
cls,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
@ -356,11 +356,11 @@ class AppGenerateService:
|
||||
def generate_more_like_this(
|
||||
cls,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
message_id: str,
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
) -> Union[Mapping, Generator]:
|
||||
) -> Mapping | Generator:
|
||||
"""
|
||||
Generate more like this
|
||||
:param app_model: app model
|
||||
|
||||
@ -7,7 +7,7 @@ with support for different subscription tiers, rate limiting, and execution trac
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Union
|
||||
from typing import Any
|
||||
|
||||
from celery.result import AsyncResult
|
||||
from sqlalchemy import select
|
||||
@ -50,7 +50,7 @@ class AsyncWorkflowService:
|
||||
|
||||
@classmethod
|
||||
def trigger_workflow_async(
|
||||
cls, session: Session, user: Union[Account, EndUser], trigger_data: TriggerData
|
||||
cls, session: Session, user: Account | EndUser, trigger_data: TriggerData
|
||||
) -> AsyncTriggerResponse:
|
||||
"""
|
||||
Universal entry point for async workflow execution - THIS METHOD WILL NOT BLOCK
|
||||
@ -177,7 +177,7 @@ class AsyncWorkflowService:
|
||||
|
||||
@classmethod
|
||||
def reinvoke_trigger(
|
||||
cls, session: Session, user: Union[Account, EndUser], workflow_trigger_log_id: str
|
||||
cls, session: Session, user: Account | EndUser, workflow_trigger_log_id: str
|
||||
) -> AsyncTriggerResponse:
|
||||
"""
|
||||
Re-invoke a previously failed or rate-limited trigger - THIS METHOD WILL NOT BLOCK
|
||||
|
||||
@ -2822,6 +2822,10 @@ class DocumentService:
|
||||
|
||||
knowledge_config.process_rule.rules.pre_processing_rules = list(unique_pre_processing_rule_dicts.values())
|
||||
|
||||
if knowledge_config.process_rule.mode == ProcessRuleMode.HIERARCHICAL:
|
||||
if not knowledge_config.process_rule.rules.parent_mode:
|
||||
knowledge_config.process_rule.rules.parent_mode = "paragraph"
|
||||
|
||||
if not knowledge_config.process_rule.rules.segmentation:
|
||||
raise ValueError("Process rule segmentation is required")
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
from copy import deepcopy
|
||||
from typing import Any, Union, cast
|
||||
from typing import Any, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
@ -148,18 +148,23 @@ class ExternalDatasetService:
|
||||
db.session.commit()
|
||||
|
||||
@staticmethod
|
||||
def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]:
|
||||
def external_knowledge_api_use_check(external_knowledge_api_id: str, tenant_id: str) -> tuple[bool, int]:
|
||||
"""
|
||||
Return usage for an external knowledge API within a single tenant.
|
||||
|
||||
The caller already scopes access by tenant, so this query must do the
|
||||
same; otherwise the endpoint becomes a cross-tenant UUID oracle.
|
||||
"""
|
||||
count = (
|
||||
db.session.scalar(
|
||||
select(func.count(ExternalKnowledgeBindings.id)).where(
|
||||
ExternalKnowledgeBindings.external_knowledge_api_id == external_knowledge_api_id
|
||||
ExternalKnowledgeBindings.external_knowledge_api_id == external_knowledge_api_id,
|
||||
ExternalKnowledgeBindings.tenant_id == tenant_id,
|
||||
)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
if count > 0:
|
||||
return True, count
|
||||
return False, 0
|
||||
return count > 0, count
|
||||
|
||||
@staticmethod
|
||||
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
|
||||
@ -190,9 +195,7 @@ class ExternalDatasetService:
|
||||
raise ValueError(f"{parameter.get('name')} is required")
|
||||
|
||||
@staticmethod
|
||||
def process_external_api(
|
||||
settings: ExternalKnowledgeApiSetting, files: Union[None, dict[str, Any]]
|
||||
) -> httpx.Response:
|
||||
def process_external_api(settings: ExternalKnowledgeApiSetting, files: dict[str, Any] | None) -> httpx.Response:
|
||||
"""
|
||||
do http request depending on api bundle
|
||||
"""
|
||||
|
||||
@ -5,7 +5,7 @@ import uuid
|
||||
from collections.abc import Iterator, Sequence
|
||||
from contextlib import contextmanager, suppress
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Literal, Union
|
||||
from typing import Literal
|
||||
from zipfile import ZIP_DEFLATED, ZipFile
|
||||
|
||||
from graphon.file import helpers as file_helpers
|
||||
@ -52,7 +52,7 @@ class FileService:
|
||||
filename: str,
|
||||
content: bytes,
|
||||
mimetype: str,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
source: Literal["datasets"] | None = None,
|
||||
source_url: str = "",
|
||||
) -> UploadFile:
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, TypedDict, Union
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.entities.provider_entities import (
|
||||
@ -626,7 +626,7 @@ class ModelLoadBalancingService:
|
||||
|
||||
def _get_credential_schema(
|
||||
self, provider_configuration: ProviderConfiguration
|
||||
) -> Union[ModelCredentialSchema, ProviderCredentialSchema]:
|
||||
) -> ModelCredentialSchema | ProviderCredentialSchema:
|
||||
"""Get form schemas."""
|
||||
if provider_configuration.provider.model_credential_schema:
|
||||
return provider_configuration.provider.model_credential_schema
|
||||
|
||||
@ -1,14 +1,13 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from extensions.ext_database import db
|
||||
from core.db.session_factory import session_factory
|
||||
from models.account import TenantPluginPermission
|
||||
|
||||
|
||||
class PluginPermissionService:
|
||||
@staticmethod
|
||||
def get_permission(tenant_id: str) -> TenantPluginPermission | None:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
with session_factory.create_session() as session:
|
||||
return session.scalar(
|
||||
select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1)
|
||||
)
|
||||
@ -19,7 +18,7 @@ class PluginPermissionService:
|
||||
install_permission: TenantPluginPermission.InstallPermission,
|
||||
debug_permission: TenantPluginPermission.DebugPermission,
|
||||
):
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
permission = session.scalar(
|
||||
select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1)
|
||||
)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Union
|
||||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
|
||||
@ -17,7 +17,7 @@ class PipelineGenerateService:
|
||||
def generate(
|
||||
cls,
|
||||
pipeline: Pipeline,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
|
||||
@ -5,7 +5,7 @@ import threading
|
||||
import time
|
||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Union, cast
|
||||
from typing import Any, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from flask_login import current_user
|
||||
@ -1387,7 +1387,7 @@ class RagPipelineService:
|
||||
"uninstalled_recommended_plugins": uninstalled_plugin_list,
|
||||
}
|
||||
|
||||
def retry_error_document(self, dataset: Dataset, document: Document, user: Union[Account, EndUser]):
|
||||
def retry_error_document(self, dataset: Dataset, document: Document, user: Account | EndUser):
|
||||
"""
|
||||
Retry error document
|
||||
"""
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Union
|
||||
from typing import Any
|
||||
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
from yarl import URL
|
||||
@ -69,7 +69,7 @@ class ToolTransformService:
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity, PluginDatasourceProviderEntity]):
|
||||
def repack_provider(tenant_id: str, provider: dict | ToolProviderApiEntity | PluginDatasourceProviderEntity):
|
||||
"""
|
||||
repack provider
|
||||
|
||||
|
||||
@ -7,15 +7,16 @@ with appropriate retry policies and error handling.
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from typing import Any, NotRequired
|
||||
|
||||
from celery import shared_task
|
||||
from graphon.runtime import GraphRuntimeState
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext
|
||||
from core.app.layers.timeslice_layer import TimeSliceLayer
|
||||
@ -42,6 +43,13 @@ from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue, AsyncWorkf
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowGeneratorArgsDict(TypedDict):
|
||||
inputs: dict[str, Any]
|
||||
files: list[Any]
|
||||
_skip_prepare_user_inputs: bool
|
||||
workflow_id: NotRequired[str]
|
||||
|
||||
|
||||
@shared_task(queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE)
|
||||
def execute_workflow_professional(task_data_dict: dict[str, Any]):
|
||||
"""Execute workflow for professional tier with highest priority"""
|
||||
@ -90,15 +98,13 @@ def execute_workflow_sandbox(task_data_dict: dict[str, Any]):
|
||||
)
|
||||
|
||||
|
||||
def _build_generator_args(trigger_data: TriggerData) -> dict[str, Any]:
|
||||
def _build_generator_args(trigger_data: TriggerData) -> WorkflowGeneratorArgsDict:
|
||||
"""Build args passed into WorkflowAppGenerator.generate for Celery executions."""
|
||||
|
||||
args: dict[str, Any] = {
|
||||
return {
|
||||
"inputs": dict(trigger_data.inputs),
|
||||
"files": list(trigger_data.files),
|
||||
SKIP_PREPARE_USER_INPUTS_KEY: True,
|
||||
"_skip_prepare_user_inputs": True,
|
||||
}
|
||||
return args
|
||||
|
||||
|
||||
def _execute_workflow_common(
|
||||
|
||||
@ -158,7 +158,10 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase():
|
||||
second_result.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first_result, second_result]
|
||||
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
|
||||
with patch("services.account_service.session_factory") as mock_factory:
|
||||
mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com")
|
||||
|
||||
assert result is expected_account
|
||||
assert mock_session.execute.call_count == 2
|
||||
|
||||
@ -113,12 +113,14 @@ class TestForgotPasswordCheckApi:
|
||||
class TestForgotPasswordResetApi:
|
||||
@patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.forgot_password.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_reset_fetches_account_with_original_email(
|
||||
self,
|
||||
mock_get_reset_data,
|
||||
mock_revoke_token,
|
||||
mock_db,
|
||||
mock_get_account,
|
||||
mock_update_account,
|
||||
app,
|
||||
@ -126,6 +128,7 @@ class TestForgotPasswordResetApi:
|
||||
mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"}
|
||||
mock_account = MagicMock()
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_db.session.merge.return_value = mock_account
|
||||
|
||||
wraps_features = SimpleNamespace(enable_email_password_login=True)
|
||||
with (
|
||||
@ -161,7 +164,10 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase():
|
||||
second_result.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first_result, second_result]
|
||||
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session)
|
||||
with patch("services.account_service.session_factory") as mock_factory:
|
||||
mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com")
|
||||
|
||||
assert result is expected_account
|
||||
assert mock_session.execute.call_count == 2
|
||||
|
||||
@ -437,7 +437,10 @@ class TestAccountGeneration:
|
||||
second_result.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first_result, second_result]
|
||||
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
|
||||
with patch("services.account_service.session_factory") as mock_factory:
|
||||
mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com")
|
||||
|
||||
assert result is expected_account
|
||||
assert mock_session.execute.call_count == 2
|
||||
|
||||
@ -335,10 +335,12 @@ class TestForgotPasswordResetApi:
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.forgot_password.db")
|
||||
@patch("controllers.console.auth.forgot_password.TenantService.get_join_tenants")
|
||||
def test_reset_password_success(
|
||||
self,
|
||||
mock_get_tenants,
|
||||
mock_db,
|
||||
mock_get_account,
|
||||
mock_revoke_token,
|
||||
mock_get_data,
|
||||
@ -356,6 +358,7 @@ class TestForgotPasswordResetApi:
|
||||
# Arrange
|
||||
mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_db.session.merge.return_value = mock_account
|
||||
mock_get_tenants.return_value = [MagicMock()]
|
||||
|
||||
# Act
|
||||
|
||||
@ -37,10 +37,8 @@ class TestForgotPasswordSendEmailApi:
|
||||
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.web.forgot_password.extract_remote_ip", return_value="127.0.0.1")
|
||||
@patch("controllers.web.forgot_password.sessionmaker")
|
||||
def test_should_normalize_email_before_sending(
|
||||
self,
|
||||
mock_session_cls,
|
||||
mock_extract_ip,
|
||||
mock_rate_limit,
|
||||
mock_get_account,
|
||||
@ -50,19 +48,16 @@ class TestForgotPasswordSendEmailApi:
|
||||
mock_account = MagicMock()
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_send_mail.return_value = "token-123"
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
|
||||
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
|
||||
with app.test_request_context(
|
||||
"/web/forgot-password",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "language": "zh-Hans"},
|
||||
):
|
||||
response = ForgotPasswordSendEmailApi().post()
|
||||
with app.test_request_context(
|
||||
"/web/forgot-password",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "language": "zh-Hans"},
|
||||
):
|
||||
response = ForgotPasswordSendEmailApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": "token-123"}
|
||||
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
|
||||
mock_get_account.assert_called_once_with("User@Example.com")
|
||||
mock_send_mail.assert_called_once_with(account=mock_account, email="user@example.com", language="zh-Hans")
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_rate_limit.assert_called_once_with("127.0.0.1")
|
||||
@ -153,14 +148,14 @@ class TestForgotPasswordResetApi:
|
||||
|
||||
@patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.web.forgot_password.sessionmaker")
|
||||
@patch("controllers.web.forgot_password.db")
|
||||
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_should_fetch_account_with_fallback(
|
||||
self,
|
||||
mock_get_reset_data,
|
||||
mock_revoke_token,
|
||||
mock_session_cls,
|
||||
mock_db,
|
||||
mock_get_account,
|
||||
mock_update_account,
|
||||
app,
|
||||
@ -168,29 +163,27 @@ class TestForgotPasswordResetApi:
|
||||
mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com", "code": "1234"}
|
||||
mock_account = MagicMock()
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_db.session.merge.return_value = mock_account
|
||||
|
||||
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
|
||||
with app.test_request_context(
|
||||
"/web/forgot-password/resets",
|
||||
method="POST",
|
||||
json={
|
||||
"token": "token-123",
|
||||
"new_password": "ValidPass123!",
|
||||
"password_confirm": "ValidPass123!",
|
||||
},
|
||||
):
|
||||
response = ForgotPasswordResetApi().post()
|
||||
with app.test_request_context(
|
||||
"/web/forgot-password/resets",
|
||||
method="POST",
|
||||
json={
|
||||
"token": "token-123",
|
||||
"new_password": "ValidPass123!",
|
||||
"password_confirm": "ValidPass123!",
|
||||
},
|
||||
):
|
||||
response = ForgotPasswordResetApi().post()
|
||||
|
||||
assert response == {"result": "success"}
|
||||
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
|
||||
mock_get_account.assert_called_once_with("User@Example.com")
|
||||
mock_update_account.assert_called_once()
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
|
||||
@patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value")
|
||||
@patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef")
|
||||
@patch("controllers.web.forgot_password.sessionmaker")
|
||||
@patch("controllers.web.forgot_password.db")
|
||||
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@ -199,7 +192,7 @@ class TestForgotPasswordResetApi:
|
||||
mock_get_account,
|
||||
mock_get_reset_data,
|
||||
mock_revoke_token,
|
||||
mock_session_cls,
|
||||
mock_db,
|
||||
mock_token_bytes,
|
||||
mock_hash_password,
|
||||
app,
|
||||
@ -207,20 +200,18 @@ class TestForgotPasswordResetApi:
|
||||
mock_get_reset_data.return_value = {"phase": "reset", "email": "user@example.com"}
|
||||
account = MagicMock()
|
||||
mock_get_account.return_value = account
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_db.session.merge.return_value = account
|
||||
|
||||
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
|
||||
with app.test_request_context(
|
||||
"/web/forgot-password/resets",
|
||||
method="POST",
|
||||
json={
|
||||
"token": "reset-token",
|
||||
"new_password": "StrongPass123!",
|
||||
"password_confirm": "StrongPass123!",
|
||||
},
|
||||
):
|
||||
response = ForgotPasswordResetApi().post()
|
||||
with app.test_request_context(
|
||||
"/web/forgot-password/resets",
|
||||
method="POST",
|
||||
json={
|
||||
"token": "reset-token",
|
||||
"new_password": "StrongPass123!",
|
||||
"password_confirm": "StrongPass123!",
|
||||
},
|
||||
):
|
||||
response = ForgotPasswordResetApi().post()
|
||||
|
||||
assert response == {"result": "success"}
|
||||
mock_get_reset_data.assert_called_once_with("reset-token")
|
||||
|
||||
@ -1,239 +1,193 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, cast
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.rag.models.document import Document
|
||||
from models.dataset import Dataset
|
||||
from models.dataset import Dataset, DatasetQuery
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
|
||||
class TestHitTestingService:
|
||||
"""Test suite for HitTestingService"""
|
||||
def _create_dataset(db_session: Session, *, provider: str = "vendor", **kwargs: Any) -> Dataset:
|
||||
tenant_id = str(uuid4())
|
||||
created_by = str(uuid4())
|
||||
ds = Dataset(
|
||||
tenant_id=kwargs.get("tenant_id", tenant_id),
|
||||
name=kwargs.get("name", "test-dataset"),
|
||||
created_by=kwargs.get("created_by", created_by),
|
||||
provider=provider,
|
||||
)
|
||||
db_session.add(ds)
|
||||
db_session.commit()
|
||||
db_session.refresh(ds)
|
||||
return ds
|
||||
|
||||
# ===== Utility Method Tests =====
|
||||
|
||||
class TestHitTestingService:
|
||||
# ── Utility methods (pure logic, no DB) ────────────────────────────
|
||||
|
||||
def test_escape_query_for_search_should_escape_double_quotes(self):
|
||||
"""Test that escape_query_for_search escapes double quotes correctly"""
|
||||
# Arrange
|
||||
query = 'test "query" with quotes'
|
||||
expected = 'test \\"query\\" with quotes'
|
||||
|
||||
# Act
|
||||
result = HitTestingService.escape_query_for_search(query)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
assert result == 'test \\"query\\" with quotes'
|
||||
|
||||
def test_hit_testing_args_check_should_pass_with_valid_query(self):
|
||||
"""Test that hit_testing_args_check passes with a valid query"""
|
||||
# Arrange
|
||||
args = {"query": "valid query"}
|
||||
|
||||
# Act & Assert (should not raise)
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
HitTestingService.hit_testing_args_check({"query": "valid query"})
|
||||
|
||||
def test_hit_testing_args_check_should_pass_with_valid_attachments(self):
|
||||
"""Test that hit_testing_args_check passes with valid attachment_ids"""
|
||||
# Arrange
|
||||
args = {"attachment_ids": ["id1", "id2"]}
|
||||
|
||||
# Act & Assert (should not raise)
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
HitTestingService.hit_testing_args_check({"attachment_ids": ["id1", "id2"]})
|
||||
|
||||
def test_hit_testing_args_check_should_raise_error_when_no_query_or_attachments(self):
|
||||
"""Test that hit_testing_args_check raises ValueError if both query and attachment_ids are missing"""
|
||||
# Arrange
|
||||
args = {}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
assert "Query or attachment_ids is required" in str(exc_info.value)
|
||||
with pytest.raises(ValueError, match="Query or attachment_ids is required"):
|
||||
HitTestingService.hit_testing_args_check({})
|
||||
|
||||
def test_hit_testing_args_check_should_raise_error_when_query_too_long(self):
|
||||
"""Test that hit_testing_args_check raises ValueError if query exceeds 250 characters"""
|
||||
# Arrange
|
||||
args = {"query": "a" * 251}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
assert "Query cannot exceed 250 characters" in str(exc_info.value)
|
||||
with pytest.raises(ValueError, match="Query cannot exceed 250 characters"):
|
||||
HitTestingService.hit_testing_args_check({"query": "a" * 251})
|
||||
|
||||
def test_hit_testing_args_check_should_raise_error_when_attachments_not_list(self):
|
||||
"""Test that hit_testing_args_check raises ValueError if attachment_ids is not a list"""
|
||||
# Arrange
|
||||
args = {"attachment_ids": "not a list"}
|
||||
with pytest.raises(ValueError, match="Attachment_ids must be a list"):
|
||||
HitTestingService.hit_testing_args_check({"attachment_ids": "not a list"})
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
assert "Attachment_ids must be a list" in str(exc_info.value)
|
||||
|
||||
# ===== Response Formatting Tests =====
|
||||
# ── Response formatting ────────────────────────────────────────────
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.format_retrieval_documents")
|
||||
def test_compact_retrieve_response_should_format_correctly(self, mock_format):
|
||||
"""Test that compact_retrieve_response formats the response correctly"""
|
||||
# Arrange
|
||||
query = "test query"
|
||||
mock_doc = MagicMock(spec=Document)
|
||||
documents = [mock_doc]
|
||||
|
||||
mock_record = MagicMock()
|
||||
mock_record.model_dump.return_value = {"content": "formatted content"}
|
||||
mock_format.return_value = [mock_record]
|
||||
|
||||
# Act
|
||||
result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, documents))
|
||||
result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, [mock_doc]))
|
||||
|
||||
# Assert
|
||||
assert cast(dict[str, Any], result["query"])["content"] == query
|
||||
assert len(result["records"]) == 1
|
||||
assert cast(dict[str, Any], result["records"][0])["content"] == "formatted content"
|
||||
mock_format.assert_called_once_with(documents)
|
||||
mock_format.assert_called_once_with([mock_doc])
|
||||
|
||||
def test_compact_external_retrieve_response_should_return_records_for_external_provider(self):
|
||||
"""Test that compact_external_retrieve_response returns records when dataset provider is external"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.provider = "external"
|
||||
query = "test query"
|
||||
def test_compact_external_retrieve_response_should_return_records_for_external_provider(
|
||||
self, db_session_with_containers: Session
|
||||
):
|
||||
dataset = _create_dataset(db_session_with_containers, provider="external")
|
||||
documents = [
|
||||
{"content": "c1", "title": "t1", "score": 0.9, "metadata": {"m1": "v1"}},
|
||||
{"content": "c2", "title": "t2", "score": 0.8, "metadata": {"m2": "v2"}},
|
||||
]
|
||||
|
||||
# Act
|
||||
result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents))
|
||||
result = cast(
|
||||
dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, "test query", documents)
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert cast(dict[str, Any], result["query"])["content"] == query
|
||||
assert cast(dict[str, Any], result["query"])["content"] == "test query"
|
||||
assert len(result["records"]) == 2
|
||||
assert cast(dict[str, Any], result["records"][0])["content"] == "c1"
|
||||
assert cast(dict[str, Any], result["records"][1])["title"] == "t2"
|
||||
|
||||
def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider(self):
|
||||
"""Test that compact_external_retrieve_response returns empty records for non-external provider"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.provider = "not_external"
|
||||
query = "test query"
|
||||
documents = [{"content": "c1"}]
|
||||
def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider(
|
||||
self, db_session_with_containers: Session
|
||||
):
|
||||
dataset = _create_dataset(db_session_with_containers, provider="vendor")
|
||||
|
||||
# Act
|
||||
result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents))
|
||||
result = cast(
|
||||
dict[str, Any],
|
||||
HitTestingService.compact_external_retrieve_response(dataset, "test query", [{"content": "c1"}]),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert cast(dict[str, Any], result["query"])["content"] == query
|
||||
assert cast(dict[str, Any], result["query"])["content"] == "test query"
|
||||
assert result["records"] == []
|
||||
|
||||
# ===== External Retrieve Tests =====
|
||||
# ── External retrieve (real DB) ────────────────────────────────────
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.external_retrieve")
|
||||
@patch("extensions.ext_database.db.session.add")
|
||||
@patch("extensions.ext_database.db.session.commit")
|
||||
def test_external_retrieve_should_succeed_for_external_provider(self, mock_commit, mock_add, mock_ext_retrieve):
|
||||
"""Test that external_retrieve successfully retrieves from external provider and commits query"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.id = "dataset_id"
|
||||
dataset.provider = "external"
|
||||
query = 'test "query"'
|
||||
def test_external_retrieve_should_succeed_for_external_provider(
|
||||
self, mock_ext_retrieve, db_session_with_containers: Session
|
||||
):
|
||||
dataset = _create_dataset(db_session_with_containers, provider="external")
|
||||
account_id = str(uuid4())
|
||||
account = MagicMock()
|
||||
account.id = "account_id"
|
||||
|
||||
account.id = account_id
|
||||
mock_ext_retrieve.return_value = [{"content": "ext content", "score": 1.0}]
|
||||
|
||||
# Act
|
||||
before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
|
||||
|
||||
result = cast(
|
||||
dict[str, Any],
|
||||
HitTestingService.external_retrieve(
|
||||
dataset=dataset,
|
||||
query=query,
|
||||
query='test "query"',
|
||||
account=account,
|
||||
external_retrieval_model={"model": "test"},
|
||||
metadata_filtering_conditions={"key": "val"},
|
||||
),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert cast(dict[str, Any], result["query"])["content"] == query
|
||||
assert cast(dict[str, Any], result["query"])["content"] == 'test "query"'
|
||||
assert cast(dict[str, Any], result["records"][0])["content"] == "ext content"
|
||||
|
||||
# Verify call to RetrievalService.external_retrieve with escaped query
|
||||
mock_ext_retrieve.assert_called_once_with(
|
||||
dataset_id="dataset_id",
|
||||
dataset_id=dataset.id,
|
||||
query='test \\"query\\"',
|
||||
external_retrieval_model={"model": "test"},
|
||||
metadata_filtering_conditions={"key": "val"},
|
||||
)
|
||||
|
||||
# Verify DatasetQuery record was added and committed
|
||||
mock_add.assert_called_once()
|
||||
mock_commit.assert_called_once()
|
||||
db_session_with_containers.expire_all()
|
||||
after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
|
||||
assert after_count == before_count + 1
|
||||
|
||||
def test_external_retrieve_should_return_empty_for_non_external_provider(self):
|
||||
"""Test that external_retrieve returns empty results immediately if provider is not external"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.provider = "not_external"
|
||||
query = "test query"
|
||||
def test_external_retrieve_should_return_empty_for_non_external_provider(self, db_session_with_containers: Session):
|
||||
dataset = _create_dataset(db_session_with_containers, provider="vendor")
|
||||
account = MagicMock()
|
||||
|
||||
# Act
|
||||
result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, query, account))
|
||||
result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, "test query", account))
|
||||
|
||||
# Assert
|
||||
assert cast(dict[str, Any], result["query"])["content"] == query
|
||||
assert cast(dict[str, Any], result["query"])["content"] == "test query"
|
||||
assert result["records"] == []
|
||||
|
||||
# ===== Retrieve Tests =====
|
||||
# ── Retrieve (real DB) ─────────────────────────────────────────────
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
||||
@patch("extensions.ext_database.db.session.add")
|
||||
@patch("extensions.ext_database.db.session.commit")
|
||||
def test_retrieve_should_use_default_model_when_none_provided(self, mock_commit, mock_add, mock_retrieve):
|
||||
"""Test that retrieve uses default model when retrieval_model is not provided"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.id = "dataset_id"
|
||||
def test_retrieve_should_use_default_model_when_none_provided(
|
||||
self, mock_retrieve, db_session_with_containers: Session
|
||||
):
|
||||
dataset = _create_dataset(db_session_with_containers)
|
||||
dataset.retrieval_model = None
|
||||
query = "test query"
|
||||
account = MagicMock()
|
||||
account.id = "account_id"
|
||||
|
||||
account.id = str(uuid4())
|
||||
mock_retrieve.return_value = []
|
||||
|
||||
# Act
|
||||
before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
|
||||
|
||||
result = cast(
|
||||
dict[str, Any],
|
||||
HitTestingService.retrieve(
|
||||
dataset=dataset, query=query, account=account, retrieval_model=None, external_retrieval_model={}
|
||||
dataset=dataset, query="test query", account=account, retrieval_model=None, external_retrieval_model={}
|
||||
),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert cast(dict[str, Any], result["query"])["content"] == query
|
||||
assert cast(dict[str, Any], result["query"])["content"] == "test query"
|
||||
mock_retrieve.assert_called_once()
|
||||
# Verify top_k from default_retrieval_model (4)
|
||||
assert mock_retrieve.call_args.kwargs["top_k"] == 4
|
||||
mock_commit.assert_called_once()
|
||||
|
||||
db_session_with_containers.expire_all()
|
||||
after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
|
||||
assert after_count == before_count + 1
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
||||
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition")
|
||||
@patch("extensions.ext_database.db.session.add")
|
||||
@patch("extensions.ext_database.db.session.commit")
|
||||
def test_retrieve_should_handle_metadata_filtering(self, mock_commit, mock_add, mock_get_meta, mock_retrieve):
|
||||
"""Test that retrieve correctly calls metadata filtering when conditions are present"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.id = "dataset_id"
|
||||
query = "test query"
|
||||
def test_retrieve_should_handle_metadata_filtering(
|
||||
self, mock_get_meta, mock_retrieve, db_session_with_containers: Session
|
||||
):
|
||||
dataset = _create_dataset(db_session_with_containers)
|
||||
account = MagicMock()
|
||||
account.id = "account_id"
|
||||
account.id = str(uuid4())
|
||||
|
||||
retrieval_model = {
|
||||
"search_method": "semantic_search",
|
||||
@ -242,29 +196,27 @@ class TestHitTestingService:
|
||||
"reranking_enable": False,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
# Mock metadata filtering response
|
||||
mock_get_meta.return_value = ({"dataset_id": ["doc_id1"]}, "condition_string")
|
||||
mock_get_meta.return_value = ({dataset.id: ["doc_id1"]}, "condition_string")
|
||||
mock_retrieve.return_value = []
|
||||
|
||||
# Act
|
||||
HitTestingService.retrieve(
|
||||
dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={}
|
||||
dataset=dataset,
|
||||
query="test query",
|
||||
account=account,
|
||||
retrieval_model=retrieval_model,
|
||||
external_retrieval_model={},
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_get_meta.assert_called_once()
|
||||
mock_retrieve.assert_called_once()
|
||||
assert mock_retrieve.call_args.kwargs["document_ids_filter"] == ["doc_id1"]
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
||||
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition")
|
||||
def test_retrieve_should_return_empty_if_metadata_filtering_fails(self, mock_get_meta, mock_retrieve):
|
||||
"""Test that retrieve returns empty response if metadata filtering returns condition but no document IDs"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.id = "dataset_id"
|
||||
query = "test query"
|
||||
def test_retrieve_should_return_empty_if_metadata_filtering_fails(
|
||||
self, mock_get_meta, mock_retrieve, db_session_with_containers: Session
|
||||
):
|
||||
dataset = _create_dataset(db_session_with_containers)
|
||||
account = MagicMock()
|
||||
|
||||
retrieval_model = {
|
||||
@ -274,37 +226,27 @@ class TestHitTestingService:
|
||||
"reranking_enable": False,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
# Mock metadata filtering response: condition returned but no IDs
|
||||
mock_get_meta.return_value = ({}, "condition_string")
|
||||
|
||||
# Act
|
||||
result = cast(
|
||||
dict[str, Any],
|
||||
HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=query,
|
||||
query="test query",
|
||||
account=account,
|
||||
retrieval_model=retrieval_model,
|
||||
external_retrieval_model={},
|
||||
),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result["records"] == []
|
||||
mock_retrieve.assert_not_called()
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
||||
@patch("extensions.ext_database.db.session.add")
|
||||
@patch("extensions.ext_database.db.session.commit")
|
||||
def test_retrieve_should_handle_attachments(self, mock_commit, mock_add, mock_retrieve):
|
||||
"""Test that retrieve handles attachment_ids and adds them to DatasetQuery"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.id = "dataset_id"
|
||||
query = "test query"
|
||||
def test_retrieve_should_handle_attachments(self, mock_retrieve, db_session_with_containers: Session):
|
||||
dataset = _create_dataset(db_session_with_containers)
|
||||
account = MagicMock()
|
||||
account.id = "account_id"
|
||||
account.id = str(uuid4())
|
||||
attachment_ids = ["att1", "att2"]
|
||||
|
||||
retrieval_model = {
|
||||
@ -315,21 +257,19 @@ class TestHitTestingService:
|
||||
}
|
||||
mock_retrieve.return_value = []
|
||||
|
||||
# Act
|
||||
HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=query,
|
||||
query="test query",
|
||||
account=account,
|
||||
retrieval_model=retrieval_model,
|
||||
external_retrieval_model={},
|
||||
attachment_ids=attachment_ids,
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_retrieve.assert_called_once_with(
|
||||
retrieval_method=ANY,
|
||||
dataset_id="dataset_id",
|
||||
query=query,
|
||||
dataset_id=dataset.id,
|
||||
query="test query",
|
||||
attachment_ids=attachment_ids,
|
||||
top_k=4,
|
||||
score_threshold=0.0,
|
||||
@ -338,26 +278,27 @@ class TestHitTestingService:
|
||||
weights=None,
|
||||
document_ids_filter=None,
|
||||
)
|
||||
# Verify DatasetQuery record (there should be 2 queries: 1 text, 2 images)
|
||||
# The content is json.dumps([{"content_type": "text_query", ...}, {"content_type": "image_query", ...}])
|
||||
called_query = mock_add.call_args[0][0]
|
||||
query_content = json.loads(called_query.content)
|
||||
|
||||
# Verify DatasetQuery was persisted with correct content structure
|
||||
db_session_with_containers.expire_all()
|
||||
latest = db_session_with_containers.scalar(
|
||||
select(DatasetQuery)
|
||||
.where(DatasetQuery.dataset_id == dataset.id)
|
||||
.order_by(DatasetQuery.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
assert latest is not None
|
||||
query_content = json.loads(latest.content)
|
||||
assert len(query_content) == 3 # 1 text + 2 images
|
||||
assert query_content[0]["content_type"] == "text_query"
|
||||
assert query_content[1]["content_type"] == "image_query"
|
||||
assert query_content[1]["content"] == "att1"
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
||||
@patch("extensions.ext_database.db.session.add")
|
||||
@patch("extensions.ext_database.db.session.commit")
|
||||
def test_retrieve_should_handle_reranking_and_threshold(self, mock_commit, mock_add, mock_retrieve):
|
||||
"""Test that retrieve passes reranking and threshold parameters correctly"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.id = "dataset_id"
|
||||
query = "test query"
|
||||
def test_retrieve_should_handle_reranking_and_threshold(self, mock_retrieve, db_session_with_containers: Session):
|
||||
dataset = _create_dataset(db_session_with_containers)
|
||||
account = MagicMock()
|
||||
account.id = "account_id"
|
||||
account.id = str(uuid4())
|
||||
|
||||
retrieval_model = {
|
||||
"search_method": "hybrid_search",
|
||||
@ -371,12 +312,14 @@ class TestHitTestingService:
|
||||
}
|
||||
mock_retrieve.return_value = []
|
||||
|
||||
# Act
|
||||
HitTestingService.retrieve(
|
||||
dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={}
|
||||
dataset=dataset,
|
||||
query="test query",
|
||||
account=account,
|
||||
retrieval_model=retrieval_model,
|
||||
external_retrieval_model={},
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_retrieve.assert_called_once()
|
||||
kwargs = mock_retrieve.call_args.kwargs
|
||||
assert kwargs["score_threshold"] == 0.5
|
||||
@ -0,0 +1,363 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.ops.entities.config_entity import TracingProviderEnum
|
||||
from models.model import TraceAppConfig
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.app_service import AppService
|
||||
from services.ops_service import OpsService
|
||||
from tests.test_containers_integration_tests.helpers import generate_valid_password
|
||||
|
||||
|
||||
class TestOpsService:
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
with (
|
||||
patch("services.app_service.FeatureService") as mock_feature_service,
|
||||
patch("services.app_service.EnterpriseService") as mock_enterprise_service,
|
||||
patch("services.app_service.ModelManager.for_tenant") as mock_model_manager,
|
||||
patch("services.account_service.FeatureService") as mock_account_feature_service,
|
||||
):
|
||||
mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False
|
||||
mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None
|
||||
mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None
|
||||
mock_account_feature_service.get_system_features.return_value.is_allow_register = True
|
||||
mock_model_instance = mock_model_manager.return_value
|
||||
mock_model_instance.get_default_model_instance.return_value = None
|
||||
mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo")
|
||||
yield {
|
||||
"feature_service": mock_feature_service,
|
||||
"enterprise_service": mock_enterprise_service,
|
||||
"model_manager": mock_model_manager,
|
||||
"account_feature_service": mock_account_feature_service,
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ops_trace_manager(self):
|
||||
with patch("services.ops_service.OpsTraceManager") as mock:
|
||||
yield mock
|
||||
|
||||
def _create_app(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
fake = Faker()
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(
|
||||
tenant.id,
|
||||
{
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🤖",
|
||||
"icon_background": "#FF6B6B",
|
||||
},
|
||||
account,
|
||||
)
|
||||
return app, account
|
||||
|
||||
_SENTINEL = object()
|
||||
|
||||
def _insert_trace_config(
|
||||
self,
|
||||
db_session: Session,
|
||||
app_id: str,
|
||||
provider: str,
|
||||
tracing_config: dict | None | object = _SENTINEL,
|
||||
) -> TraceAppConfig:
|
||||
trace_config = TraceAppConfig(
|
||||
app_id=app_id,
|
||||
tracing_provider=provider,
|
||||
tracing_config=tracing_config if tracing_config is not self._SENTINEL else {"some": "config"},
|
||||
)
|
||||
db_session.add(trace_config)
|
||||
db_session.commit()
|
||||
return trace_config
|
||||
|
||||
# ── get_tracing_app_config ─────────────────────────────────────────
|
||||
|
||||
def test_get_tracing_app_config_no_config(self, db_session_with_containers: Session, mock_ops_trace_manager):
|
||||
result = OpsService.get_tracing_app_config(str(uuid.uuid4()), "arize")
|
||||
assert result is None
|
||||
|
||||
def test_get_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager):
|
||||
fake_app_id = str(uuid.uuid4())
|
||||
self._insert_trace_config(db_session_with_containers, fake_app_id, "arize")
|
||||
result = OpsService.get_tracing_app_config(fake_app_id, "arize")
|
||||
assert result is None
|
||||
|
||||
def test_get_tracing_app_config_none_config(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies, mock_ops_trace_manager
|
||||
):
|
||||
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||
self._insert_trace_config(db_session_with_containers, app.id, "arize", tracing_config=None)
|
||||
|
||||
with pytest.raises(ValueError, match="Tracing config cannot be None."):
|
||||
OpsService.get_tracing_app_config(app.id, "arize")
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("provider", "default_url"),
|
||||
[
|
||||
("arize", "https://app.arize.com/"),
|
||||
("phoenix", "https://app.phoenix.arize.com/projects/"),
|
||||
("langsmith", "https://smith.langchain.com/"),
|
||||
("opik", "https://www.comet.com/opik/"),
|
||||
("weave", "https://wandb.ai/"),
|
||||
("aliyun", "https://arms.console.aliyun.com/"),
|
||||
("tencent", "https://console.cloud.tencent.com/apm"),
|
||||
("mlflow", "http://localhost:5000/"),
|
||||
("databricks", "https://www.databricks.com/"),
|
||||
],
|
||||
)
|
||||
def test_get_tracing_app_config_providers_exception(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies, provider, default_url
|
||||
):
|
||||
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||
mock_otm.decrypt_tracing_config.return_value = {}
|
||||
mock_otm.obfuscated_decrypt_token.return_value = {}
|
||||
mock_otm.get_trace_config_project_url.side_effect = Exception("error")
|
||||
mock_otm.get_trace_config_project_key.side_effect = Exception("error")
|
||||
|
||||
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||
self._insert_trace_config(db_session_with_containers, app.id, provider)
|
||||
|
||||
result = OpsService.get_tracing_app_config(app.id, provider)
|
||||
|
||||
assert result is not None
|
||||
assert result["tracing_config"]["project_url"] == default_url
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider",
|
||||
["arize", "phoenix", "langsmith", "opik", "weave", "aliyun", "tencent", "mlflow", "databricks"],
|
||||
)
|
||||
def test_get_tracing_app_config_providers_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies, provider
|
||||
):
|
||||
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||
mock_otm.decrypt_tracing_config.return_value = {}
|
||||
mock_otm.obfuscated_decrypt_token.return_value = {"project_url": "success_url"}
|
||||
mock_otm.get_trace_config_project_url.return_value = "success_url"
|
||||
|
||||
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||
self._insert_trace_config(db_session_with_containers, app.id, provider)
|
||||
|
||||
result = OpsService.get_tracing_app_config(app.id, provider)
|
||||
|
||||
assert result is not None
|
||||
assert result["tracing_config"]["project_url"] == "success_url"
|
||||
|
||||
def test_get_tracing_app_config_langfuse_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||
mock_otm.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"}
|
||||
mock_otm.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"}
|
||||
mock_otm.get_trace_config_project_key.return_value = "key"
|
||||
|
||||
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||
self._insert_trace_config(db_session_with_containers, app.id, "langfuse")
|
||||
|
||||
result = OpsService.get_tracing_app_config(app.id, "langfuse")
|
||||
|
||||
assert result is not None
|
||||
assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/project/key"
|
||||
|
||||
def test_get_tracing_app_config_langfuse_exception(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||
mock_otm.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"}
|
||||
mock_otm.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"}
|
||||
mock_otm.get_trace_config_project_key.side_effect = Exception("error")
|
||||
|
||||
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||
self._insert_trace_config(db_session_with_containers, app.id, "langfuse")
|
||||
|
||||
result = OpsService.get_tracing_app_config(app.id, "langfuse")
|
||||
|
||||
assert result is not None
|
||||
assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/"
|
||||
|
||||
# ── create_tracing_app_config ──────────────────────────────────────
|
||||
|
||||
def test_create_tracing_app_config_invalid_provider(self, db_session_with_containers: Session):
|
||||
result = OpsService.create_tracing_app_config(str(uuid.uuid4()), "invalid_provider", {})
|
||||
assert result == {"error": "Invalid tracing provider: invalid_provider"}
|
||||
|
||||
def test_create_tracing_app_config_invalid_credentials(
|
||||
self, db_session_with_containers: Session, mock_ops_trace_manager
|
||||
):
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = False
|
||||
result = OpsService.create_tracing_app_config(
|
||||
str(uuid.uuid4()), TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}
|
||||
)
|
||||
assert result == {"error": "Invalid Credentials"}
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("provider", "config"),
|
||||
[
|
||||
(TracingProviderEnum.ARIZE, {}),
|
||||
(TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}),
|
||||
(TracingProviderEnum.LANGSMITH, {"api_key": "k", "project": "p"}),
|
||||
(TracingProviderEnum.ALIYUN, {"license_key": "k", "endpoint": "https://aliyun.com"}),
|
||||
],
|
||||
)
|
||||
def test_create_tracing_app_config_project_url_exception(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies, provider, config
|
||||
):
|
||||
# Existing config causes the service to return None before reaching the DB insert
|
||||
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||
mock_otm.check_trace_config_is_effective.return_value = True
|
||||
mock_otm.get_trace_config_project_url.side_effect = Exception("error")
|
||||
mock_otm.get_trace_config_project_key.side_effect = Exception("error")
|
||||
|
||||
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||
self._insert_trace_config(db_session_with_containers, app.id, str(provider))
|
||||
|
||||
result = OpsService.create_tracing_app_config(app.id, provider, config)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_create_tracing_app_config_langfuse_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||
mock_otm.check_trace_config_is_effective.return_value = True
|
||||
mock_otm.get_trace_config_project_key.return_value = "key"
|
||||
mock_otm.encrypt_tracing_config.return_value = {}
|
||||
|
||||
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||
result = OpsService.create_tracing_app_config(
|
||||
app.id,
|
||||
TracingProviderEnum.LANGFUSE,
|
||||
{"public_key": "p", "secret_key": "s", "host": "https://api.langfuse.com"},
|
||||
)
|
||||
|
||||
assert result == {"result": "success"}
|
||||
|
||||
def test_create_tracing_app_config_already_exists(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||
mock_otm.check_trace_config_is_effective.return_value = True
|
||||
|
||||
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||
self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE))
|
||||
|
||||
result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {})
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_create_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager):
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
result = OpsService.create_tracing_app_config(str(uuid.uuid4()), TracingProviderEnum.ARIZE, {})
|
||||
assert result is None
|
||||
|
||||
def test_create_tracing_app_config_with_empty_other_keys(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
# "project" is in other_keys for Arize; providing "" triggers default substitution
|
||||
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||
mock_otm.check_trace_config_is_effective.return_value = True
|
||||
mock_otm.get_trace_config_project_url.side_effect = Exception("no url")
|
||||
mock_otm.encrypt_tracing_config.return_value = {}
|
||||
|
||||
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||
result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {"project": ""})
|
||||
|
||||
assert result == {"result": "success"}
|
||||
|
||||
def test_create_tracing_app_config_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||
mock_otm.check_trace_config_is_effective.return_value = True
|
||||
mock_otm.get_trace_config_project_url.return_value = "http://project_url"
|
||||
mock_otm.encrypt_tracing_config.return_value = {"encrypted": "config"}
|
||||
|
||||
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||
result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {})
|
||||
|
||||
assert result == {"result": "success"}
|
||||
|
||||
# ── update_tracing_app_config ──────────────────────────────────────
|
||||
|
||||
def test_update_tracing_app_config_invalid_provider(self, db_session_with_containers: Session):
|
||||
with pytest.raises(ValueError, match="Invalid tracing provider: invalid_provider"):
|
||||
OpsService.update_tracing_app_config(str(uuid.uuid4()), "invalid_provider", {})
|
||||
|
||||
def test_update_tracing_app_config_no_config(self, db_session_with_containers: Session, mock_ops_trace_manager):
|
||||
result = OpsService.update_tracing_app_config(str(uuid.uuid4()), TracingProviderEnum.ARIZE, {})
|
||||
assert result is None
|
||||
|
||||
def test_update_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager):
|
||||
fake_app_id = str(uuid.uuid4())
|
||||
self._insert_trace_config(db_session_with_containers, fake_app_id, str(TracingProviderEnum.ARIZE))
|
||||
mock_ops_trace_manager.encrypt_tracing_config.return_value = {}
|
||||
result = OpsService.update_tracing_app_config(fake_app_id, TracingProviderEnum.ARIZE, {})
|
||||
assert result is None
|
||||
|
||||
def test_update_tracing_app_config_invalid_credentials(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||
mock_otm.encrypt_tracing_config.return_value = {}
|
||||
mock_otm.decrypt_tracing_config.return_value = {}
|
||||
mock_otm.check_trace_config_is_effective.return_value = False
|
||||
|
||||
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||
self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE))
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid Credentials"):
|
||||
OpsService.update_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {})
|
||||
|
||||
def test_update_tracing_app_config_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||
mock_otm.encrypt_tracing_config.return_value = {"updated": "config"}
|
||||
mock_otm.decrypt_tracing_config.return_value = {}
|
||||
mock_otm.check_trace_config_is_effective.return_value = True
|
||||
|
||||
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||
self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE))
|
||||
|
||||
result = OpsService.update_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {})
|
||||
|
||||
assert result is not None
|
||||
assert result["app_id"] == app.id
|
||||
|
||||
# ── delete_tracing_app_config ──────────────────────────────────────
|
||||
|
||||
def test_delete_tracing_app_config_no_config(self, db_session_with_containers: Session):
|
||||
result = OpsService.delete_tracing_app_config(str(uuid.uuid4()), "arize")
|
||||
assert result is None
|
||||
|
||||
def test_delete_tracing_app_config_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||
self._insert_trace_config(db_session_with_containers, app.id, "arize")
|
||||
|
||||
result = OpsService.delete_tracing_app_config(app.id, "arize")
|
||||
|
||||
assert result is True
|
||||
remaining = db_session_with_containers.scalar(
|
||||
select(TraceAppConfig)
|
||||
.where(TraceAppConfig.app_id == app.id, TraceAppConfig.tracing_provider == "arize")
|
||||
.limit(1)
|
||||
)
|
||||
assert remaining is None
|
||||
@ -233,11 +233,10 @@ class TestWebAppAuthService:
|
||||
assert result.status == AccountStatus.ACTIVE
|
||||
|
||||
# Verify database state
|
||||
|
||||
db_session_with_containers.refresh(result)
|
||||
assert result.id is not None
|
||||
assert result.password is not None
|
||||
assert result.password_salt is not None
|
||||
refreshed = db_session_with_containers.get(Account, result.id)
|
||||
assert refreshed is not None
|
||||
assert refreshed.password is not None
|
||||
assert refreshed.password_salt is not None
|
||||
|
||||
def test_authenticate_account_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
@ -414,9 +413,8 @@ class TestWebAppAuthService:
|
||||
assert result.status == AccountStatus.ACTIVE
|
||||
|
||||
# Verify database state
|
||||
|
||||
db_session_with_containers.refresh(result)
|
||||
assert result.id is not None
|
||||
refreshed = db_session_with_containers.get(Account, result.id)
|
||||
assert refreshed is not None
|
||||
|
||||
def test_get_user_through_email_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from importlib import import_module
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
@ -11,6 +12,7 @@ from controllers.console.datasets.external import (
|
||||
BedrockRetrievalApi,
|
||||
ExternalApiTemplateApi,
|
||||
ExternalApiTemplateListApi,
|
||||
ExternalApiUseCheckApi,
|
||||
ExternalDatasetCreateApi,
|
||||
ExternalKnowledgeHitTestingApi,
|
||||
)
|
||||
@ -19,6 +21,8 @@ from services.external_knowledge_service import ExternalDatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
from services.knowledge_service import ExternalDatasetTestService
|
||||
|
||||
external_controller = import_module("controllers.console.datasets.external")
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
@ -44,10 +48,11 @@ def current_user():
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_auth(mocker, current_user):
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.external.current_account_with_tenant",
|
||||
return_value=(current_user, "tenant-1"),
|
||||
def mock_auth(monkeypatch, current_user):
|
||||
monkeypatch.setattr(
|
||||
external_controller,
|
||||
"current_account_with_tenant",
|
||||
lambda: (current_user, "tenant-1"),
|
||||
)
|
||||
|
||||
|
||||
@ -136,6 +141,26 @@ class TestExternalApiTemplateApi:
|
||||
method(api, "api-id")
|
||||
|
||||
|
||||
class TestExternalApiUseCheckApi:
|
||||
def test_get_scopes_usage_check_to_current_tenant(self, app):
|
||||
api = ExternalApiUseCheckApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
ExternalDatasetService,
|
||||
"external_knowledge_api_use_check",
|
||||
return_value=(True, 2),
|
||||
) as mock_use_check,
|
||||
):
|
||||
response, status = method(api, "api-id")
|
||||
|
||||
assert status == 200
|
||||
assert response == {"is_using": True, "count": 2}
|
||||
mock_use_check.assert_called_once_with("api-id", "tenant-1")
|
||||
|
||||
|
||||
class TestExternalDatasetCreateApi:
|
||||
def test_create_success(self, app):
|
||||
api = ExternalDatasetCreateApi()
|
||||
|
||||
@ -233,15 +233,20 @@ class TestCheckEmailUnique:
|
||||
|
||||
|
||||
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
|
||||
session = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
first = MagicMock()
|
||||
first.scalar_one_or_none.return_value = None
|
||||
second = MagicMock()
|
||||
expected_account = MagicMock()
|
||||
second.scalar_one_or_none.return_value = expected_account
|
||||
session.execute.side_effect = [first, second]
|
||||
mock_session.execute.side_effect = [first, second]
|
||||
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=session)
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with patch("services.account_service.session_factory", mock_factory):
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com")
|
||||
|
||||
assert result is expected_account
|
||||
assert session.execute.call_count == 2
|
||||
assert mock_session.execute.call_count == 2
|
||||
|
||||
@ -4,9 +4,7 @@ from unittest.mock import Mock
|
||||
|
||||
from core.mcp.entities import (
|
||||
SUPPORTED_PROTOCOL_VERSIONS,
|
||||
LifespanContextT,
|
||||
RequestContext,
|
||||
SessionT,
|
||||
)
|
||||
from core.mcp.session.base_session import BaseSession
|
||||
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestParams
|
||||
@ -198,42 +196,3 @@ class TestRequestContext:
|
||||
assert "RequestContext" in repr_str
|
||||
assert "test-123" in repr_str
|
||||
assert "MockSession" in repr_str
|
||||
|
||||
|
||||
class TestTypeVariables:
|
||||
"""Test type variables defined in the module."""
|
||||
|
||||
def test_session_type_var(self):
|
||||
"""Test SessionT type variable."""
|
||||
|
||||
# Create a custom session class
|
||||
class CustomSession(BaseSession):
|
||||
pass
|
||||
|
||||
# Use in generic context
|
||||
def process_session(session: SessionT) -> SessionT:
|
||||
return session
|
||||
|
||||
mock_session = Mock(spec=CustomSession)
|
||||
result = process_session(mock_session)
|
||||
assert result == mock_session
|
||||
|
||||
def test_lifespan_context_type_var(self):
|
||||
"""Test LifespanContextT type variable."""
|
||||
|
||||
# Use in generic context
|
||||
def process_lifespan(context: LifespanContextT) -> LifespanContextT:
|
||||
return context
|
||||
|
||||
# Test with different types
|
||||
str_context = "string-context"
|
||||
assert process_lifespan(str_context) == str_context
|
||||
|
||||
dict_context = {"key": "value"}
|
||||
assert process_lifespan(dict_context) == dict_context
|
||||
|
||||
class CustomContext:
|
||||
pass
|
||||
|
||||
custom_context = CustomContext()
|
||||
assert process_lifespan(custom_context) == custom_context
|
||||
|
||||
@ -39,6 +39,25 @@ class _FakeSession:
|
||||
return None
|
||||
|
||||
|
||||
class _FakeBeginContext:
|
||||
def __init__(self, session):
|
||||
self._session = session
|
||||
|
||||
def __enter__(self):
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
|
||||
def _patch_both(monkeypatch, module, session):
|
||||
"""Patch both Session and sessionmaker on the module."""
|
||||
monkeypatch.setattr(module, "Session", lambda _client: session)
|
||||
monkeypatch.setattr(
|
||||
module, "sessionmaker", lambda **kwargs: MagicMock(begin=MagicMock(return_value=_FakeBeginContext(session)))
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def relyt_module(monkeypatch):
|
||||
for name, module in _build_fake_relyt_modules().items():
|
||||
@ -108,13 +127,13 @@ def test_create_collection_cache_and_sql_execution(relyt_module, monkeypatch):
|
||||
|
||||
monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=1))
|
||||
session = _FakeSession()
|
||||
monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
|
||||
_patch_both(monkeypatch, relyt_module, session)
|
||||
vector.create_collection(3)
|
||||
session.execute.assert_not_called()
|
||||
|
||||
monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=None))
|
||||
session = _FakeSession()
|
||||
monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
|
||||
_patch_both(monkeypatch, relyt_module, session)
|
||||
vector.create_collection(3)
|
||||
executed_sql = [str(call.args[0]) for call in session.execute.call_args_list]
|
||||
assert any("DROP TABLE IF EXISTS" in sql for sql in executed_sql)
|
||||
@ -265,15 +284,15 @@ def test_search_by_vector_filters_by_score_and_ids(relyt_module):
|
||||
|
||||
|
||||
# 8. delete commits session
|
||||
def test_delete_commits_session(relyt_module, monkeypatch):
|
||||
def test_delete_drops_table(relyt_module, monkeypatch):
|
||||
vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
|
||||
vector._collection_name = "collection_1"
|
||||
vector.client = MagicMock()
|
||||
vector.embedding_dimension = 3
|
||||
session = _FakeSession()
|
||||
monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
|
||||
_patch_both(monkeypatch, relyt_module, session)
|
||||
vector.delete()
|
||||
session.commit.assert_called_once()
|
||||
session.execute.assert_called_once()
|
||||
|
||||
|
||||
def test_relyt_factory_existing_and_generated_collection(relyt_module, monkeypatch):
|
||||
|
||||
@ -137,14 +137,15 @@ def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monke
|
||||
|
||||
session = MagicMock()
|
||||
|
||||
class _SessionCtx:
|
||||
class _BeginCtx:
|
||||
def __enter__(self):
|
||||
return session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx())
|
||||
mock_sm = MagicMock(begin=MagicMock(return_value=_BeginCtx()))
|
||||
monkeypatch.setattr(tidb_module, "sessionmaker", lambda **kwargs: mock_sm)
|
||||
|
||||
vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
|
||||
vector._collection_name = "collection_1"
|
||||
@ -153,11 +154,9 @@ def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monke
|
||||
|
||||
vector._create_collection(3)
|
||||
|
||||
session.begin.assert_called_once()
|
||||
sql = str(session.execute.call_args.args[0])
|
||||
assert "VECTOR<FLOAT>(3)" in sql
|
||||
assert "VEC_L2_DISTANCE" in sql
|
||||
session.commit.assert_called_once()
|
||||
tidb_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
@ -396,23 +395,22 @@ def test_search_by_vector_filters_and_scores(tidb_module, monkeypatch):
|
||||
def test_delete_drops_table(tidb_module, monkeypatch):
|
||||
session = MagicMock()
|
||||
session.execute.return_value = None
|
||||
session.commit = MagicMock()
|
||||
|
||||
class _SessionCtx:
|
||||
class _BeginCtx:
|
||||
def __enter__(self):
|
||||
return session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx())
|
||||
mock_sm = MagicMock(begin=MagicMock(return_value=_BeginCtx()))
|
||||
monkeypatch.setattr(tidb_module, "sessionmaker", lambda **kwargs: mock_sm)
|
||||
vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
|
||||
vector._collection_name = "collection_1"
|
||||
vector._engine = MagicMock()
|
||||
vector.delete()
|
||||
drop_sql = str(session.execute.call_args.args[0])
|
||||
assert "DROP TABLE IF EXISTS collection_1" in drop_sql
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeypatch):
|
||||
|
||||
@ -39,7 +39,7 @@ class TestAppGenerateHandler:
|
||||
"root_node_id": None,
|
||||
}
|
||||
|
||||
arguments = handler._extract_arguments(AppGenerateService.generate, (), kwargs)
|
||||
arguments = handler._extract_arguments(AppGenerateService.generate, **kwargs)
|
||||
|
||||
assert arguments is not None, "Failed to extract arguments from AppGenerateService.generate"
|
||||
assert "app_model" in arguments, "Handler uses app_model but parameter is missing"
|
||||
@ -70,14 +70,11 @@ class TestAppGenerateHandler:
|
||||
handler.wrapper(
|
||||
tracer,
|
||||
dummy_func,
|
||||
(),
|
||||
{
|
||||
"app_model": mock_app_model,
|
||||
"user": mock_account_user,
|
||||
"args": {"workflow_id": test_workflow_id},
|
||||
"invoke_from": InvokeFrom.DEBUGGER,
|
||||
"streaming": False,
|
||||
},
|
||||
app_model=mock_app_model,
|
||||
user=mock_account_user,
|
||||
args={"workflow_id": test_workflow_id},
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
spans = memory_span_exporter.get_finished_spans()
|
||||
|
||||
@ -63,7 +63,7 @@ class TestWorkflowAppRunnerHandler:
|
||||
def runner_run(self):
|
||||
return "result"
|
||||
|
||||
handler.wrapper(tracer, runner_run, (mock_workflow_runner,), {})
|
||||
handler.wrapper(tracer, runner_run, mock_workflow_runner)
|
||||
|
||||
spans = memory_span_exporter.get_finished_spans()
|
||||
assert len(spans) == 1
|
||||
|
||||
@ -28,7 +28,7 @@ class TestSpanHandlerExtractArguments:
|
||||
|
||||
args = (1, 2, 3)
|
||||
kwargs = {}
|
||||
result = handler._extract_arguments(func, args, kwargs)
|
||||
result = handler._extract_arguments(func, *args, **kwargs)
|
||||
|
||||
assert result is not None
|
||||
assert result["a"] == 1
|
||||
@ -44,7 +44,7 @@ class TestSpanHandlerExtractArguments:
|
||||
|
||||
args = ()
|
||||
kwargs = {"a": 1, "b": 2, "c": 3}
|
||||
result = handler._extract_arguments(func, args, kwargs)
|
||||
result = handler._extract_arguments(func, *args, **kwargs)
|
||||
|
||||
assert result is not None
|
||||
assert result["a"] == 1
|
||||
@ -60,7 +60,7 @@ class TestSpanHandlerExtractArguments:
|
||||
|
||||
args = (1,)
|
||||
kwargs = {"b": 2, "c": 3}
|
||||
result = handler._extract_arguments(func, args, kwargs)
|
||||
result = handler._extract_arguments(func, *args, **kwargs)
|
||||
|
||||
assert result is not None
|
||||
assert result["a"] == 1
|
||||
@ -76,7 +76,7 @@ class TestSpanHandlerExtractArguments:
|
||||
|
||||
args = (1,)
|
||||
kwargs = {}
|
||||
result = handler._extract_arguments(func, args, kwargs)
|
||||
result = handler._extract_arguments(func, *args, **kwargs)
|
||||
|
||||
assert result is not None
|
||||
assert result["a"] == 1
|
||||
@ -94,7 +94,7 @@ class TestSpanHandlerExtractArguments:
|
||||
instance = MyClass()
|
||||
args = (1, 2)
|
||||
kwargs = {}
|
||||
result = handler._extract_arguments(instance.method, args, kwargs)
|
||||
result = handler._extract_arguments(instance.method, *args, **kwargs)
|
||||
|
||||
assert result is not None
|
||||
assert result["a"] == 1
|
||||
@ -109,7 +109,7 @@ class TestSpanHandlerExtractArguments:
|
||||
|
||||
args = (1,)
|
||||
kwargs = {}
|
||||
result = handler._extract_arguments(func, args, kwargs)
|
||||
result = handler._extract_arguments(func, *args, **kwargs)
|
||||
|
||||
assert result is None
|
||||
|
||||
@ -122,11 +122,11 @@ class TestSpanHandlerExtractArguments:
|
||||
|
||||
assert func not in handler._signature_cache
|
||||
|
||||
handler._extract_arguments(func, (1, 2), {})
|
||||
handler._extract_arguments(func, 1, 2)
|
||||
assert func in handler._signature_cache
|
||||
|
||||
cached_sig = handler._signature_cache[func]
|
||||
handler._extract_arguments(func, (3, 4), {})
|
||||
handler._extract_arguments(func, 3, 4)
|
||||
assert handler._signature_cache[func] is cached_sig
|
||||
|
||||
|
||||
@ -142,7 +142,7 @@ class TestSpanHandlerWrapper:
|
||||
def test_func():
|
||||
return "result"
|
||||
|
||||
result = handler.wrapper(tracer, test_func, (), {})
|
||||
result = handler.wrapper(tracer, test_func)
|
||||
|
||||
assert result == "result"
|
||||
spans = memory_span_exporter.get_finished_spans()
|
||||
@ -159,7 +159,7 @@ class TestSpanHandlerWrapper:
|
||||
def test_func():
|
||||
return "result"
|
||||
|
||||
handler.wrapper(tracer, test_func, (), {})
|
||||
handler.wrapper(tracer, test_func)
|
||||
|
||||
spans = memory_span_exporter.get_finished_spans()
|
||||
assert len(spans) == 1
|
||||
@ -174,7 +174,7 @@ class TestSpanHandlerWrapper:
|
||||
def test_func():
|
||||
return "result"
|
||||
|
||||
handler.wrapper(tracer, test_func, (), {})
|
||||
handler.wrapper(tracer, test_func)
|
||||
|
||||
spans = memory_span_exporter.get_finished_spans()
|
||||
assert len(spans) == 1
|
||||
@ -190,7 +190,7 @@ class TestSpanHandlerWrapper:
|
||||
raise ValueError("test error")
|
||||
|
||||
with pytest.raises(ValueError, match="test error"):
|
||||
handler.wrapper(tracer, test_func, (), {})
|
||||
handler.wrapper(tracer, test_func)
|
||||
|
||||
spans = memory_span_exporter.get_finished_spans()
|
||||
assert len(spans) == 1
|
||||
@ -208,7 +208,7 @@ class TestSpanHandlerWrapper:
|
||||
raise ValueError("test error")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
handler.wrapper(tracer, test_func, (), {})
|
||||
handler.wrapper(tracer, test_func)
|
||||
|
||||
spans = memory_span_exporter.get_finished_spans()
|
||||
assert len(spans) == 1
|
||||
@ -225,7 +225,7 @@ class TestSpanHandlerWrapper:
|
||||
raise ValueError("test error")
|
||||
|
||||
with pytest.raises(ValueError, match="test error"):
|
||||
handler.wrapper(tracer, test_func, (), {})
|
||||
handler.wrapper(tracer, test_func)
|
||||
|
||||
@patch("extensions.otel.decorators.base.dify_config.ENABLE_OTEL", True)
|
||||
def test_wrapper_passes_arguments_correctly(self, tracer_provider_with_memory_exporter, memory_span_exporter):
|
||||
@ -236,7 +236,7 @@ class TestSpanHandlerWrapper:
|
||||
def test_func(a, b, c=10):
|
||||
return a + b + c
|
||||
|
||||
result = handler.wrapper(tracer, test_func, (1, 2), {"c": 3})
|
||||
result = handler.wrapper(tracer, test_func, 1, 2, c=3)
|
||||
|
||||
assert result == 6
|
||||
|
||||
@ -249,7 +249,7 @@ class TestSpanHandlerWrapper:
|
||||
def my_function(x):
|
||||
return x * 2
|
||||
|
||||
result = handler.wrapper(tracer, my_function, (5,), {})
|
||||
result = handler.wrapper(tracer, my_function, 5)
|
||||
|
||||
assert result == 10
|
||||
spans = memory_span_exporter.get_finished_spans()
|
||||
|
||||
138
api/tests/unit_tests/libs/test_pyrefly_type_coverage.py
Normal file
138
api/tests/unit_tests/libs/test_pyrefly_type_coverage.py
Normal file
@ -0,0 +1,138 @@
|
||||
import json
|
||||
|
||||
from libs.pyrefly_type_coverage import (
|
||||
CoverageSummary,
|
||||
format_comparison_markdown,
|
||||
format_summary_markdown,
|
||||
parse_summary,
|
||||
)
|
||||
|
||||
|
||||
def _make_report(summary: dict) -> str:
|
||||
return json.dumps({"module_reports": [], "summary": summary})
|
||||
|
||||
|
||||
_SAMPLE_SUMMARY: dict = {
|
||||
"n_modules": 100,
|
||||
"n_typable": 1000,
|
||||
"n_typed": 400,
|
||||
"n_any": 50,
|
||||
"n_untyped": 550,
|
||||
"coverage": 45.0,
|
||||
"strict_coverage": 40.0,
|
||||
"n_functions": 200,
|
||||
"n_methods": 300,
|
||||
"n_function_params": 150,
|
||||
"n_method_params": 250,
|
||||
"n_classes": 80,
|
||||
"n_attrs": 40,
|
||||
"n_properties": 20,
|
||||
"n_type_ignores": 10,
|
||||
}
|
||||
|
||||
|
||||
def _make_summary(
|
||||
*,
|
||||
n_modules: int = 100,
|
||||
n_typable: int = 1000,
|
||||
n_typed: int = 400,
|
||||
n_any: int = 50,
|
||||
n_untyped: int = 550,
|
||||
coverage: float = 45.0,
|
||||
strict_coverage: float = 40.0,
|
||||
) -> CoverageSummary:
|
||||
return {
|
||||
"n_modules": n_modules,
|
||||
"n_typable": n_typable,
|
||||
"n_typed": n_typed,
|
||||
"n_any": n_any,
|
||||
"n_untyped": n_untyped,
|
||||
"coverage": coverage,
|
||||
"strict_coverage": strict_coverage,
|
||||
}
|
||||
|
||||
|
||||
def test_parse_summary_extracts_fields() -> None:
|
||||
report_json = _make_report(_SAMPLE_SUMMARY)
|
||||
|
||||
result = parse_summary(report_json)
|
||||
|
||||
assert result["n_modules"] == 100
|
||||
assert result["n_typable"] == 1000
|
||||
assert result["n_typed"] == 400
|
||||
assert result["n_any"] == 50
|
||||
assert result["n_untyped"] == 550
|
||||
assert result["coverage"] == 45.0
|
||||
assert result["strict_coverage"] == 40.0
|
||||
|
||||
|
||||
def test_parse_summary_handles_empty_input() -> None:
|
||||
assert parse_summary("")["n_modules"] == 0
|
||||
assert parse_summary(" ")["n_modules"] == 0
|
||||
|
||||
|
||||
def test_parse_summary_handles_invalid_json() -> None:
|
||||
assert parse_summary("not json")["n_modules"] == 0
|
||||
|
||||
|
||||
def test_parse_summary_handles_missing_summary_key() -> None:
|
||||
assert parse_summary(json.dumps({"other": 1}))["n_modules"] == 0
|
||||
|
||||
|
||||
def test_parse_summary_handles_incomplete_summary() -> None:
|
||||
partial = json.dumps({"summary": {"n_modules": 5}})
|
||||
assert parse_summary(partial)["n_modules"] == 0
|
||||
|
||||
|
||||
def test_format_summary_markdown_contains_key_metrics() -> None:
|
||||
summary = _make_summary()
|
||||
|
||||
result = format_summary_markdown(summary)
|
||||
|
||||
assert "**Type coverage**" in result
|
||||
assert "45.00%" in result
|
||||
assert "40.00%" in result
|
||||
assert "| Modules | 100 |" in result
|
||||
|
||||
|
||||
def test_format_comparison_markdown_shows_positive_delta() -> None:
|
||||
base = _make_summary()
|
||||
pr = _make_summary(
|
||||
n_modules=101,
|
||||
n_typable=1010,
|
||||
n_typed=420,
|
||||
n_untyped=540,
|
||||
coverage=46.53,
|
||||
strict_coverage=41.58,
|
||||
)
|
||||
|
||||
result = format_comparison_markdown(base, pr)
|
||||
|
||||
assert "| Base | PR | Delta |" in result
|
||||
assert "+1.53%" in result
|
||||
assert "+1.58%" in result
|
||||
assert "+20" in result
|
||||
|
||||
|
||||
def test_format_comparison_markdown_shows_negative_delta() -> None:
|
||||
base = _make_summary()
|
||||
pr = _make_summary(
|
||||
n_typed=390,
|
||||
n_any=60,
|
||||
coverage=44.0,
|
||||
strict_coverage=39.0,
|
||||
)
|
||||
|
||||
result = format_comparison_markdown(base, pr)
|
||||
|
||||
assert "-1.00%" in result
|
||||
assert "-10" in result
|
||||
|
||||
|
||||
def test_format_comparison_markdown_shows_zero_delta() -> None:
|
||||
summary = _make_summary()
|
||||
|
||||
result = format_comparison_markdown(summary, summary)
|
||||
|
||||
assert "0.00%" in result
|
||||
assert "| 0 |" in result
|
||||
@ -396,10 +396,11 @@ class TestExternalDatasetServiceUsageAndBindings:
|
||||
|
||||
mock_db_session.scalar.return_value = 3
|
||||
|
||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1")
|
||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1", "tenant-1")
|
||||
|
||||
assert in_use is True
|
||||
assert count == 3
|
||||
assert "tenant_id" in str(mock_db_session.scalar.call_args.args[0])
|
||||
|
||||
def test_external_knowledge_api_use_check_not_in_use(self, mock_db_session: MagicMock):
|
||||
"""
|
||||
@ -408,7 +409,7 @@ class TestExternalDatasetServiceUsageAndBindings:
|
||||
|
||||
mock_db_session.scalar.return_value = 0
|
||||
|
||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1")
|
||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1", "tenant-1")
|
||||
|
||||
assert in_use is False
|
||||
assert count == 0
|
||||
|
||||
@ -6,23 +6,25 @@ MODULE = "services.plugin.plugin_permission_service"
|
||||
|
||||
|
||||
def _patched_session():
|
||||
"""Patch sessionmaker(bind=db.engine).begin() to return a mock session as context manager."""
|
||||
"""Patch session_factory.create_session() to return a mock session as context manager."""
|
||||
session = MagicMock()
|
||||
mock_sessionmaker = MagicMock()
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__ = MagicMock(return_value=session)
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
patcher = patch(f"{MODULE}.sessionmaker", mock_sessionmaker)
|
||||
db_patcher = patch(f"{MODULE}.db")
|
||||
return patcher, db_patcher, session
|
||||
session.__enter__ = MagicMock(return_value=session)
|
||||
session.__exit__ = MagicMock(return_value=False)
|
||||
session.begin.return_value.__enter__ = MagicMock(return_value=session)
|
||||
session.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.create_session.return_value = session
|
||||
patcher = patch(f"{MODULE}.session_factory", mock_factory)
|
||||
return patcher, session
|
||||
|
||||
|
||||
class TestGetPermission:
|
||||
def test_returns_permission_when_found(self):
|
||||
p1, p2, session = _patched_session()
|
||||
p1, session = _patched_session()
|
||||
permission = MagicMock()
|
||||
session.scalar.return_value = permission
|
||||
|
||||
with p1, p2:
|
||||
with p1:
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
|
||||
result = PluginPermissionService.get_permission("t1")
|
||||
@ -30,10 +32,10 @@ class TestGetPermission:
|
||||
assert result is permission
|
||||
|
||||
def test_returns_none_when_not_found(self):
|
||||
p1, p2, session = _patched_session()
|
||||
p1, session = _patched_session()
|
||||
session.scalar.return_value = None
|
||||
|
||||
with p1, p2:
|
||||
with p1:
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
|
||||
result = PluginPermissionService.get_permission("t1")
|
||||
@ -43,10 +45,10 @@ class TestGetPermission:
|
||||
|
||||
class TestChangePermission:
|
||||
def test_creates_new_permission_when_not_exists(self):
|
||||
p1, p2, session = _patched_session()
|
||||
p1, session = _patched_session()
|
||||
session.scalar.return_value = None
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginPermission") as perm_cls:
|
||||
with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginPermission") as perm_cls:
|
||||
perm_cls.return_value = MagicMock()
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
|
||||
@ -54,20 +56,24 @@ class TestChangePermission:
|
||||
"t1", TenantPluginPermission.InstallPermission.EVERYONE, TenantPluginPermission.DebugPermission.EVERYONE
|
||||
)
|
||||
|
||||
assert result is True
|
||||
session.begin.assert_called_once()
|
||||
session.add.assert_called_once()
|
||||
|
||||
def test_updates_existing_permission(self):
|
||||
p1, p2, session = _patched_session()
|
||||
p1, session = _patched_session()
|
||||
existing = MagicMock()
|
||||
session.scalar.return_value = existing
|
||||
|
||||
with p1, p2:
|
||||
with p1:
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
|
||||
result = PluginPermissionService.change_permission(
|
||||
"t1", TenantPluginPermission.InstallPermission.ADMINS, TenantPluginPermission.DebugPermission.ADMINS
|
||||
)
|
||||
|
||||
assert result is True
|
||||
session.begin.assert_called_once()
|
||||
assert existing.install_permission == TenantPluginPermission.InstallPermission.ADMINS
|
||||
assert existing.debug_permission == TenantPluginPermission.DebugPermission.ADMINS
|
||||
session.add.assert_not_called()
|
||||
|
||||
@ -1427,16 +1427,7 @@ class TestRegisterService:
|
||||
mock_tenant.name = "Test Workspace"
|
||||
mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter")
|
||||
|
||||
# Mock database queries - need to mock the sessionmaker query
|
||||
mock_session = MagicMock()
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None # No existing account
|
||||
|
||||
mock_sessionmaker = MagicMock()
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
patch("services.account_service.sessionmaker", mock_sessionmaker),
|
||||
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
||||
):
|
||||
mock_lookup.return_value = None
|
||||
@ -1475,7 +1466,7 @@ class TestRegisterService:
|
||||
status=AccountStatus.PENDING,
|
||||
is_setup=True,
|
||||
)
|
||||
mock_lookup.assert_called_once_with("newuser@example.com", session=mock_session)
|
||||
mock_lookup.assert_called_once_with("newuser@example.com")
|
||||
|
||||
def test_invite_new_member_normalizes_new_account_email(
|
||||
self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies
|
||||
@ -1486,13 +1477,7 @@ class TestRegisterService:
|
||||
mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter")
|
||||
mixed_email = "Invitee@Example.com"
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_sessionmaker = MagicMock()
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
patch("services.account_service.sessionmaker", mock_sessionmaker),
|
||||
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
||||
):
|
||||
mock_lookup.return_value = None
|
||||
@ -1525,7 +1510,7 @@ class TestRegisterService:
|
||||
status=AccountStatus.PENDING,
|
||||
is_setup=True,
|
||||
)
|
||||
mock_lookup.assert_called_once_with(mixed_email, session=mock_session)
|
||||
mock_lookup.assert_called_once_with(mixed_email)
|
||||
mock_check_permission.assert_called_once_with(mock_tenant, mock_inviter, None, "add")
|
||||
mock_create_member.assert_called_once_with(mock_tenant, mock_new_account, "normal")
|
||||
mock_switch_tenant.assert_called_once_with(mock_new_account, mock_tenant.id)
|
||||
@ -1545,16 +1530,7 @@ class TestRegisterService:
|
||||
account_id="existing-user-456", email="existing@example.com", status="pending"
|
||||
)
|
||||
|
||||
# Mock database queries - need to mock the sessionmaker query
|
||||
mock_session = MagicMock()
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_existing_account
|
||||
|
||||
mock_sessionmaker = MagicMock()
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
patch("services.account_service.sessionmaker", mock_sessionmaker),
|
||||
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
||||
):
|
||||
mock_lookup.return_value = mock_existing_account
|
||||
@ -1584,7 +1560,7 @@ class TestRegisterService:
|
||||
mock_create_member.assert_called_once_with(mock_tenant, mock_existing_account, "normal")
|
||||
mock_generate_token.assert_called_once_with(mock_tenant, mock_existing_account)
|
||||
mock_task_dependencies.delay.assert_called_once()
|
||||
mock_lookup.assert_called_once_with("existing@example.com", session=mock_session)
|
||||
mock_lookup.assert_called_once_with("existing@example.com")
|
||||
|
||||
def test_invite_new_member_already_in_tenant(self, mock_db_dependencies, mock_redis_dependencies):
|
||||
"""Test inviting a member who is already in the tenant."""
|
||||
|
||||
@ -1069,6 +1069,33 @@ class TestDocumentServiceCreateValidation:
|
||||
assert len(knowledge_config.process_rule.rules.pre_processing_rules) == 1
|
||||
assert knowledge_config.process_rule.rules.pre_processing_rules[0].enabled is False
|
||||
|
||||
def test_process_rule_args_validate_hierarchical_defaults_parent_mode_to_paragraph(self):
|
||||
knowledge_config = KnowledgeConfig(
|
||||
indexing_technique="economy",
|
||||
data_source=DataSource(
|
||||
info_list=InfoList(
|
||||
data_source_type="upload_file",
|
||||
file_info_list=FileInfo(file_ids=["file-1"]),
|
||||
)
|
||||
),
|
||||
process_rule=ProcessRule(
|
||||
mode="hierarchical",
|
||||
rules=Rule(
|
||||
pre_processing_rules=[
|
||||
PreProcessingRule(id="remove_extra_spaces", enabled=True),
|
||||
],
|
||||
segmentation=Segmentation(separator="\n", max_tokens=1024),
|
||||
subchunk_segmentation=Segmentation(separator="\n", max_tokens=512),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
DocumentService.process_rule_args_validate(knowledge_config)
|
||||
|
||||
assert knowledge_config.process_rule is not None
|
||||
assert knowledge_config.process_rule.rules is not None
|
||||
assert knowledge_config.process_rule.rules.parent_mode == "paragraph"
|
||||
|
||||
|
||||
class TestDocumentServiceSaveDocumentWithDatasetId:
|
||||
"""Unit tests for non-SQL validation branches in save_document_with_dataset_id."""
|
||||
|
||||
@ -974,26 +974,29 @@ class TestExternalDatasetServiceAPIUseCheck:
|
||||
"""Test API use check when API has one binding."""
|
||||
# Arrange
|
||||
api_id = "api-123"
|
||||
tenant_id = "tenant-123"
|
||||
|
||||
mock_db.session.scalar.return_value = 1
|
||||
|
||||
# Act
|
||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
|
||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id, tenant_id)
|
||||
|
||||
# Assert
|
||||
assert in_use is True
|
||||
assert count == 1
|
||||
assert "tenant_id" in str(mock_db.session.scalar.call_args.args[0])
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_external_knowledge_api_use_check_in_use_multiple(self, mock_db, factory):
|
||||
"""Test API use check with multiple bindings."""
|
||||
# Arrange
|
||||
api_id = "api-123"
|
||||
tenant_id = "tenant-123"
|
||||
|
||||
mock_db.session.scalar.return_value = 10
|
||||
|
||||
# Act
|
||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
|
||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id, tenant_id)
|
||||
|
||||
# Assert
|
||||
assert in_use is True
|
||||
@ -1004,11 +1007,12 @@ class TestExternalDatasetServiceAPIUseCheck:
|
||||
"""Test API use check when API is not in use."""
|
||||
# Arrange
|
||||
api_id = "api-123"
|
||||
tenant_id = "tenant-123"
|
||||
|
||||
mock_db.session.scalar.return_value = 0
|
||||
|
||||
# Act
|
||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
|
||||
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id, tenant_id)
|
||||
|
||||
# Assert
|
||||
assert in_use is False
|
||||
|
||||
@ -1,392 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.ops.entities.config_entity import TracingProviderEnum
|
||||
from models.model import App, TraceAppConfig
|
||||
from services.ops_service import OpsService
|
||||
|
||||
|
||||
class TestOpsService:
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_get_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
result = OpsService.get_tracing_app_config("app_id", "arize")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_get_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
trace_config = MagicMock(spec=TraceAppConfig)
|
||||
mock_db.session.scalar.return_value = trace_config
|
||||
mock_db.session.get.return_value = None
|
||||
|
||||
# Act
|
||||
result = OpsService.get_tracing_app_config("app_id", "arize")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_get_tracing_app_config_none_config(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
trace_config = MagicMock(spec=TraceAppConfig)
|
||||
trace_config.tracing_config = None
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.scalar.return_value = trace_config
|
||||
mock_db.session.get.return_value = app
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Tracing config cannot be None."):
|
||||
OpsService.get_tracing_app_config("app_id", "arize")
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
@pytest.mark.parametrize(
|
||||
("provider", "default_url"),
|
||||
[
|
||||
("arize", "https://app.arize.com/"),
|
||||
("phoenix", "https://app.phoenix.arize.com/projects/"),
|
||||
("langsmith", "https://smith.langchain.com/"),
|
||||
("opik", "https://www.comet.com/opik/"),
|
||||
("weave", "https://wandb.ai/"),
|
||||
("aliyun", "https://arms.console.aliyun.com/"),
|
||||
("tencent", "https://console.cloud.tencent.com/apm"),
|
||||
("mlflow", "http://localhost:5000/"),
|
||||
("databricks", "https://www.databricks.com/"),
|
||||
],
|
||||
)
|
||||
def test_get_tracing_app_config_providers_exception(self, mock_ops_trace_manager, mock_db, provider, default_url):
|
||||
# Arrange
|
||||
trace_config = MagicMock(spec=TraceAppConfig)
|
||||
trace_config.tracing_config = {"some": "config"}
|
||||
trace_config.to_dict.return_value = {"tracing_config": {"project_url": default_url}}
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.scalar.return_value = trace_config
|
||||
mock_db.session.get.return_value = app
|
||||
|
||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
|
||||
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {}
|
||||
mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error")
|
||||
mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error")
|
||||
|
||||
# Act
|
||||
result = OpsService.get_tracing_app_config("app_id", provider)
|
||||
|
||||
# Assert
|
||||
assert result["tracing_config"]["project_url"] == default_url
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
@pytest.mark.parametrize(
|
||||
"provider", ["arize", "phoenix", "langsmith", "opik", "weave", "aliyun", "tencent", "mlflow", "databricks"]
|
||||
)
|
||||
def test_get_tracing_app_config_providers_success(self, mock_ops_trace_manager, mock_db, provider):
|
||||
# Arrange
|
||||
trace_config = MagicMock(spec=TraceAppConfig)
|
||||
trace_config.tracing_config = {"some": "config"}
|
||||
trace_config.to_dict.return_value = {"tracing_config": {"project_url": "success_url"}}
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.scalar.return_value = trace_config
|
||||
mock_db.session.get.return_value = app
|
||||
|
||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
|
||||
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {}
|
||||
mock_ops_trace_manager.get_trace_config_project_url.return_value = "success_url"
|
||||
|
||||
# Act
|
||||
result = OpsService.get_tracing_app_config("app_id", provider)
|
||||
|
||||
# Assert
|
||||
assert result["tracing_config"]["project_url"] == "success_url"
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_get_tracing_app_config_langfuse_success(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
trace_config = MagicMock(spec=TraceAppConfig)
|
||||
trace_config.tracing_config = {"some": "config"}
|
||||
trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/project/key"}}
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.scalar.return_value = trace_config
|
||||
mock_db.session.get.return_value = app
|
||||
|
||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"}
|
||||
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"}
|
||||
mock_ops_trace_manager.get_trace_config_project_key.return_value = "key"
|
||||
|
||||
# Act
|
||||
result = OpsService.get_tracing_app_config("app_id", "langfuse")
|
||||
|
||||
# Assert
|
||||
assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/project/key"
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_get_tracing_app_config_langfuse_exception(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
trace_config = MagicMock(spec=TraceAppConfig)
|
||||
trace_config.tracing_config = {"some": "config"}
|
||||
trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/"}}
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.scalar.return_value = trace_config
|
||||
mock_db.session.get.return_value = app
|
||||
|
||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"}
|
||||
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"}
|
||||
mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error")
|
||||
|
||||
# Act
|
||||
result = OpsService.get_tracing_app_config("app_id", "langfuse")
|
||||
|
||||
# Assert
|
||||
assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/"
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_create_tracing_app_config_invalid_provider(self, mock_ops_trace_manager, mock_db):
|
||||
# Act
|
||||
result = OpsService.create_tracing_app_config("app_id", "invalid_provider", {})
|
||||
|
||||
# Assert
|
||||
assert result == {"error": "Invalid tracing provider: invalid_provider"}
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_create_tracing_app_config_invalid_credentials(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.LANGFUSE
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = False
|
||||
|
||||
# Act
|
||||
result = OpsService.create_tracing_app_config("app_id", provider, {"public_key": "p", "secret_key": "s"})
|
||||
|
||||
# Assert
|
||||
assert result == {"error": "Invalid Credentials"}
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
@pytest.mark.parametrize(
|
||||
("provider", "config"),
|
||||
[
|
||||
(TracingProviderEnum.ARIZE, {}),
|
||||
(TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}),
|
||||
(TracingProviderEnum.LANGSMITH, {"api_key": "k", "project": "p"}),
|
||||
(TracingProviderEnum.ALIYUN, {"license_key": "k", "endpoint": "https://aliyun.com"}),
|
||||
],
|
||||
)
|
||||
def test_create_tracing_app_config_project_url_exception(self, mock_ops_trace_manager, mock_db, provider, config):
|
||||
# Arrange
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error")
|
||||
mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error")
|
||||
mock_db.session.scalar.return_value = MagicMock(spec=TraceAppConfig)
|
||||
|
||||
# Act
|
||||
result = OpsService.create_tracing_app_config("app_id", provider, config)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_create_tracing_app_config_langfuse_success(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.LANGFUSE
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
mock_ops_trace_manager.get_trace_config_project_key.return_value = "key"
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.scalar.return_value = None
|
||||
mock_db.session.get.return_value = app
|
||||
mock_ops_trace_manager.encrypt_tracing_config.return_value = {}
|
||||
|
||||
# Act
|
||||
result = OpsService.create_tracing_app_config(
|
||||
"app_id", provider, {"public_key": "p", "secret_key": "s", "host": "https://api.langfuse.com"}
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "success"}
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_create_tracing_app_config_already_exists(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
mock_db.session.scalar.return_value = MagicMock(spec=TraceAppConfig)
|
||||
|
||||
# Act
|
||||
result = OpsService.create_tracing_app_config("app_id", provider, {})
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_create_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
mock_db.session.scalar.return_value = None
|
||||
mock_db.session.get.return_value = None
|
||||
|
||||
# Act
|
||||
result = OpsService.create_tracing_app_config("app_id", provider, {})
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_create_tracing_app_config_with_empty_other_keys(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.scalar.return_value = None
|
||||
mock_db.session.get.return_value = app
|
||||
mock_ops_trace_manager.encrypt_tracing_config.return_value = {}
|
||||
|
||||
# Act
|
||||
# 'project' is in other_keys for Arize
|
||||
# provide an empty string for the project in the tracing_config
|
||||
# create_tracing_app_config will replace it with the default from the model
|
||||
result = OpsService.create_tracing_app_config("app_id", provider, {"project": ""})
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "success"}
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_create_tracing_app_config_success(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
mock_ops_trace_manager.get_trace_config_project_url.return_value = "http://project_url"
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.scalar.return_value = None
|
||||
mock_db.session.get.return_value = app
|
||||
mock_ops_trace_manager.encrypt_tracing_config.return_value = {"encrypted": "config"}
|
||||
|
||||
# Act
|
||||
result = OpsService.create_tracing_app_config("app_id", provider, {})
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "success"}
|
||||
mock_db.session.add.assert_called()
|
||||
mock_db.session.commit.assert_called()
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_update_tracing_app_config_invalid_provider(self, mock_ops_trace_manager, mock_db):
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Invalid tracing provider: invalid_provider"):
|
||||
OpsService.update_tracing_app_config("app_id", "invalid_provider", {})
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_update_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
result = OpsService.update_tracing_app_config("app_id", provider, {})
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_update_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
current_config = MagicMock(spec=TraceAppConfig)
|
||||
mock_db.session.scalar.return_value = current_config
|
||||
mock_db.session.get.return_value = None
|
||||
|
||||
# Act
|
||||
result = OpsService.update_tracing_app_config("app_id", provider, {})
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_update_tracing_app_config_invalid_credentials(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
current_config = MagicMock(spec=TraceAppConfig)
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.scalar.return_value = current_config
|
||||
mock_db.session.get.return_value = app
|
||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = False
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Invalid Credentials"):
|
||||
OpsService.update_tracing_app_config("app_id", provider, {})
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_update_tracing_app_config_success(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
current_config = MagicMock(spec=TraceAppConfig)
|
||||
current_config.to_dict.return_value = {"some": "data"}
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.scalar.return_value = current_config
|
||||
mock_db.session.get.return_value = app
|
||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
|
||||
# Act
|
||||
result = OpsService.update_tracing_app_config("app_id", provider, {})
|
||||
|
||||
# Assert
|
||||
assert result == {"some": "data"}
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
def test_delete_tracing_app_config_no_config(self, mock_db):
|
||||
# Arrange
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
result = OpsService.delete_tracing_app_config("app_id", "arize")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
def test_delete_tracing_app_config_success(self, mock_db):
|
||||
# Arrange
|
||||
trace_config = MagicMock(spec=TraceAppConfig)
|
||||
mock_db.session.scalar.return_value = trace_config
|
||||
|
||||
# Act
|
||||
result = OpsService.delete_tracing_app_config("app_id", "arize")
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_db.session.delete.assert_called_with(trace_config)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
@ -2,7 +2,7 @@
|
||||
|
||||
import type { Area } from 'react-easy-crop'
|
||||
import type { OnImageInput } from '@/app/components/base/app-icon-picker/ImageInput'
|
||||
import type { AvatarProps } from '@/app/components/base/avatar'
|
||||
import type { AvatarProps } from '@/app/components/base/ui/avatar'
|
||||
import type { ImageFile } from '@/types/app'
|
||||
import { RiDeleteBin5Line, RiPencilLine } from '@remixicon/react'
|
||||
import * as React from 'react'
|
||||
@ -10,10 +10,10 @@ import { useCallback, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import ImageInput from '@/app/components/base/app-icon-picker/ImageInput'
|
||||
import getCroppedImg from '@/app/components/base/app-icon-picker/utils'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import { useLocalFileUploader } from '@/app/components/base/image-uploader/hooks'
|
||||
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||
import { Dialog, DialogContent } from '@/app/components/base/ui/dialog'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import { DISABLE_UPLOAD_IMAGE_AS_ICON } from '@/config'
|
||||
|
||||
@ -6,9 +6,9 @@ import {
|
||||
import { Fragment } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { resetUser } from '@/app/components/base/amplitude/utils'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import { LogOut01 } from '@/app/components/base/icons/src/vender/line/general'
|
||||
import PremiumBadge from '@/app/components/base/premium-badge'
|
||||
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { useRouter } from '@/next/navigation'
|
||||
import { useLogout, useUserProfile } from '@/service/use-common'
|
||||
|
||||
@ -10,9 +10,9 @@ import {
|
||||
import * as React from 'react'
|
||||
import { useEffect, useRef } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { setPostLoginRedirect } from '@/app/signin/utils/post-login-redirect'
|
||||
|
||||
@ -5,12 +5,12 @@ import { RiAddCircleFill, RiArrowRightSLine, RiOrganizationChart } from '@remixi
|
||||
import { useDebounce } from 'ahooks'
|
||||
import { useCallback, useEffect, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||
import { useSelector } from '@/context/app-context'
|
||||
import { SubjectType } from '@/models/access-control'
|
||||
import { useSearchForWhiteListCandidates } from '@/service/access-control'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import useAccessControlStore from '../../../../context/access-control-store'
|
||||
import { Avatar } from '../../base/avatar'
|
||||
import Button from '../../base/button'
|
||||
import Checkbox from '../../base/checkbox'
|
||||
import Input from '../../base/input'
|
||||
|
||||
@ -3,10 +3,10 @@ import type { AccessControlAccount, AccessControlGroup } from '@/models/access-c
|
||||
import { RiAlertFill, RiCloseCircleFill, RiLockLine, RiOrganizationChart } from '@remixicon/react'
|
||||
import { useCallback, useEffect } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||
import { AccessMode } from '@/models/access-control'
|
||||
import { useAppWhiteListSubjects } from '@/service/access-control'
|
||||
import useAccessControlStore from '../../../../context/access-control-store'
|
||||
import { Avatar } from '../../base/avatar'
|
||||
import Loading from '../../base/loading'
|
||||
import Tooltip from '../../base/tooltip'
|
||||
import AddMemberOrGroupDialog from './add-member-or-group-pop'
|
||||
|
||||
@ -90,7 +90,7 @@ vi.mock('@/app/components/base/chat/chat', () => ({
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/avatar', () => ({
|
||||
vi.mock('@/app/components/base/ui/avatar', () => ({
|
||||
Avatar: ({ name }: { name: string }) => <div data-testid="avatar">{name}</div>,
|
||||
}))
|
||||
|
||||
|
||||
@ -7,11 +7,11 @@ import {
|
||||
useCallback,
|
||||
useMemo,
|
||||
} from 'react'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import Chat from '@/app/components/base/chat/chat'
|
||||
import { useChat } from '@/app/components/base/chat/chat/hooks'
|
||||
import { getLastAnswer } from '@/app/components/base/chat/utils'
|
||||
import { useFeatures } from '@/app/components/base/features/hooks'
|
||||
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||
import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useDebugConfigurationContext } from '@/context/debug-configuration'
|
||||
|
||||
@ -3,11 +3,11 @@ import type { ChatConfig, ChatItem, OnSend } from '@/app/components/base/chat/ty
|
||||
import type { FileEntity } from '@/app/components/base/file-uploader/types'
|
||||
import { memo, useCallback, useImperativeHandle, useMemo } from 'react'
|
||||
import { useStore as useAppStore } from '@/app/components/app/store'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import Chat from '@/app/components/base/chat/chat'
|
||||
import { useChat } from '@/app/components/base/chat/chat/hooks'
|
||||
import { getLastAnswer, isValidGeneratedAnswer } from '@/app/components/base/chat/utils'
|
||||
import { useFeatures } from '@/app/components/base/features/hooks'
|
||||
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||
import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useDebugConfigurationContext } from '@/context/debug-configuration'
|
||||
|
||||
@ -11,6 +11,7 @@ import AppIcon from '@/app/components/base/app-icon'
|
||||
import InputsForm from '@/app/components/base/chat/chat-with-history/inputs-form'
|
||||
import SuggestedQuestions from '@/app/components/base/chat/chat/answer/suggested-questions'
|
||||
import { Markdown } from '@/app/components/base/markdown'
|
||||
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||
import { InputVarType } from '@/app/components/workflow/types'
|
||||
import {
|
||||
AppSourceType,
|
||||
@ -23,7 +24,6 @@ import { submitHumanInputForm as submitHumanInputFormService } from '@/service/w
|
||||
import { TransferMethod } from '@/types/app'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { formatBooleanInputs } from '@/utils/model-config'
|
||||
import { Avatar } from '../../avatar'
|
||||
import Chat from '../chat'
|
||||
import { useChat } from '../chat/hooks'
|
||||
import { getLastAnswer, isValidGeneratedAnswer } from '../utils'
|
||||
|
||||
@ -12,6 +12,7 @@ import SuggestedQuestions from '@/app/components/base/chat/chat/answer/suggested
|
||||
import InputsForm from '@/app/components/base/chat/embedded-chatbot/inputs-form'
|
||||
import LogoAvatar from '@/app/components/base/logo/logo-embedded-chat-avatar'
|
||||
import { Markdown } from '@/app/components/base/markdown'
|
||||
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||
import { InputVarType } from '@/app/components/workflow/types'
|
||||
import {
|
||||
AppSourceType,
|
||||
@ -23,7 +24,6 @@ import {
|
||||
import { submitHumanInputForm as submitHumanInputFormService } from '@/service/workflow'
|
||||
import { TransferMethod } from '@/types/app'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { Avatar } from '../../avatar'
|
||||
import Chat from '../chat'
|
||||
import { useChat } from '../chat/hooks'
|
||||
import { getLastAnswer, isValidGeneratedAnswer } from '../utils'
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import { Avatar } from '../index'
|
||||
import { Avatar } from '..'
|
||||
|
||||
describe('Avatar', () => {
|
||||
describe('Rendering', () => {
|
||||
@ -53,8 +53,8 @@ function AvatarRoot({
|
||||
return (
|
||||
<BaseAvatar.Root
|
||||
className={cn(
|
||||
'relative inline-flex shrink-0 select-none items-center justify-center overflow-hidden rounded-full bg-primary-600',
|
||||
isAvatarPresetSize(size) && avatarSizeClasses[size].root,
|
||||
'relative inline-flex shrink-0 items-center justify-center overflow-hidden rounded-full bg-primary-600 select-none',
|
||||
avatarSizeClasses[size].root,
|
||||
className,
|
||||
)}
|
||||
style={resolvedStyle}
|
||||
@ -104,7 +104,7 @@ function AvatarImage({
|
||||
}: AvatarImageProps) {
|
||||
return (
|
||||
<BaseAvatar.Image
|
||||
className={cn('absolute inset-0 size-full object-cover', className)}
|
||||
className={cn('inset-0 absolute size-full object-cover', className)}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
@ -4,13 +4,13 @@ import { useDebounceFn } from 'ahooks'
|
||||
import * as React from 'react'
|
||||
import { useCallback, useMemo, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import Input from '@/app/components/base/input'
|
||||
import {
|
||||
PortalToFollowElem,
|
||||
PortalToFollowElemContent,
|
||||
PortalToFollowElemTrigger,
|
||||
} from '@/app/components/base/portal-to-follow-elem'
|
||||
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||
import { useSelector as useAppContextWithSelector } from '@/context/app-context'
|
||||
import { DatasetPermission } from '@/models/datasets'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
@ -4,9 +4,9 @@ import type { MouseEventHandler, ReactNode } from 'react'
|
||||
import { useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { resetUser } from '@/app/components/base/amplitude/utils'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import PremiumBadge from '@/app/components/base/premium-badge'
|
||||
import ThemeSwitcher from '@/app/components/base/theme-switcher'
|
||||
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||
import { DropdownMenu, DropdownMenuContent, DropdownMenuGroup, DropdownMenuItem, DropdownMenuLinkItem, DropdownMenuSeparator, DropdownMenuTrigger } from '@/app/components/base/ui/dropdown-menu'
|
||||
import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants'
|
||||
import { IS_CLOUD_EDITION } from '@/config'
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
import type { InvitationResult } from '@/models/common'
|
||||
import { useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||
import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip'
|
||||
import { NUM_INFINITE } from '@/app/components/billing/config'
|
||||
import { Plan } from '@/app/components/billing/type'
|
||||
|
||||
@ -3,9 +3,9 @@ import type { FC } from 'react'
|
||||
import * as React from 'react'
|
||||
import { useMemo, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import Input from '@/app/components/base/input'
|
||||
import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '@/app/components/base/portal-to-follow-elem'
|
||||
import { Avatar } from '@/app/components/base/ui/avatar'
|
||||
import { useMembers } from '@/service/use-common'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
|
||||
46
web/app/components/workflow/__tests__/block-icon.spec.tsx
Normal file
46
web/app/components/workflow/__tests__/block-icon.spec.tsx
Normal file
@ -0,0 +1,46 @@
|
||||
import { render } from '@testing-library/react'
|
||||
import { API_PREFIX } from '@/config'
|
||||
import BlockIcon, { VarBlockIcon } from '../block-icon'
|
||||
import { BlockEnum } from '../types'
|
||||
|
||||
describe('BlockIcon', () => {
|
||||
it('renders the default workflow icon container for regular nodes', () => {
|
||||
const { container } = render(<BlockIcon type={BlockEnum.Start} size="xs" className="extra-class" />)
|
||||
|
||||
const iconContainer = container.firstElementChild
|
||||
expect(iconContainer).toHaveClass('w-4', 'h-4', 'bg-util-colors-blue-brand-blue-brand-500', 'extra-class')
|
||||
expect(iconContainer?.querySelector('svg')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('normalizes protected plugin icon urls for tool-like nodes', () => {
|
||||
const { container } = render(
|
||||
<BlockIcon
|
||||
type={BlockEnum.Tool}
|
||||
toolIcon="/foo/workspaces/current/plugin/icon/plugin-tool.png"
|
||||
/>,
|
||||
)
|
||||
|
||||
const iconContainer = container.firstElementChild as HTMLElement
|
||||
const backgroundIcon = iconContainer.querySelector('div') as HTMLElement
|
||||
|
||||
expect(iconContainer).not.toHaveClass('bg-util-colors-blue-blue-500')
|
||||
expect(backgroundIcon.style.backgroundImage).toContain(
|
||||
`${API_PREFIX}/workspaces/current/plugin/icon/plugin-tool.png`,
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('VarBlockIcon', () => {
|
||||
it('renders the compact icon variant without the default container wrapper', () => {
|
||||
const { container } = render(
|
||||
<VarBlockIcon
|
||||
type={BlockEnum.Answer}
|
||||
className="custom-var-icon"
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(container.querySelector('.custom-var-icon')).toBeInTheDocument()
|
||||
expect(container.querySelector('svg')).toBeInTheDocument()
|
||||
expect(container.querySelector('.bg-util-colors-warning-warning-500')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
39
web/app/components/workflow/__tests__/context.spec.tsx
Normal file
39
web/app/components/workflow/__tests__/context.spec.tsx
Normal file
@ -0,0 +1,39 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { WorkflowContextProvider } from '../context'
|
||||
import { useStore, useWorkflowStore } from '../store'
|
||||
|
||||
const StoreConsumer = () => {
|
||||
const showSingleRunPanel = useStore(s => s.showSingleRunPanel)
|
||||
const store = useWorkflowStore()
|
||||
|
||||
return (
|
||||
<button onClick={() => store.getState().setShowSingleRunPanel(!showSingleRunPanel)}>
|
||||
{showSingleRunPanel ? 'open' : 'closed'}
|
||||
</button>
|
||||
)
|
||||
}
|
||||
|
||||
describe('WorkflowContextProvider', () => {
|
||||
it('provides the workflow store to descendants and keeps the same store across rerenders', async () => {
|
||||
const user = userEvent.setup()
|
||||
const { rerender } = render(
|
||||
<WorkflowContextProvider>
|
||||
<StoreConsumer />
|
||||
</WorkflowContextProvider>,
|
||||
)
|
||||
|
||||
expect(screen.getByRole('button', { name: 'closed' })).toBeInTheDocument()
|
||||
|
||||
await user.click(screen.getByRole('button', { name: 'closed' }))
|
||||
expect(screen.getByRole('button', { name: 'open' })).toBeInTheDocument()
|
||||
|
||||
rerender(
|
||||
<WorkflowContextProvider>
|
||||
<StoreConsumer />
|
||||
</WorkflowContextProvider>,
|
||||
)
|
||||
|
||||
expect(screen.getByRole('button', { name: 'open' })).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
67
web/app/components/workflow/__tests__/index.spec.tsx
Normal file
67
web/app/components/workflow/__tests__/index.spec.tsx
Normal file
@ -0,0 +1,67 @@
|
||||
import type { Edge, Node } from '../types'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import { useStoreApi } from 'reactflow'
|
||||
import { useDatasetsDetailStore } from '../datasets-detail-store/store'
|
||||
import WorkflowWithDefaultContext from '../index'
|
||||
import { BlockEnum } from '../types'
|
||||
import { useWorkflowHistoryStore } from '../workflow-history-store'
|
||||
|
||||
const nodes: Node[] = [
|
||||
{
|
||||
id: 'node-start',
|
||||
type: 'custom',
|
||||
position: { x: 0, y: 0 },
|
||||
data: {
|
||||
title: 'Start',
|
||||
desc: '',
|
||||
type: BlockEnum.Start,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
const edges: Edge[] = [
|
||||
{
|
||||
id: 'edge-1',
|
||||
source: 'node-start',
|
||||
target: 'node-end',
|
||||
sourceHandle: null,
|
||||
targetHandle: null,
|
||||
type: 'custom',
|
||||
data: {
|
||||
sourceType: BlockEnum.Start,
|
||||
targetType: BlockEnum.End,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
const ContextConsumer = () => {
|
||||
const { store, shortcutsEnabled } = useWorkflowHistoryStore()
|
||||
const datasetCount = useDatasetsDetailStore(state => Object.keys(state.datasetsDetail).length)
|
||||
const reactFlowStore = useStoreApi()
|
||||
|
||||
return (
|
||||
<div>
|
||||
{`history:${store.getState().nodes.length}`}
|
||||
{` shortcuts:${String(shortcutsEnabled)}`}
|
||||
{` datasets:${datasetCount}`}
|
||||
{` reactflow:${String(!!reactFlowStore)}`}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
describe('WorkflowWithDefaultContext', () => {
|
||||
it('wires the ReactFlow, workflow history, and datasets detail providers around its children', () => {
|
||||
render(
|
||||
<WorkflowWithDefaultContext
|
||||
nodes={nodes}
|
||||
edges={edges}
|
||||
>
|
||||
<ContextConsumer />
|
||||
</WorkflowWithDefaultContext>,
|
||||
)
|
||||
|
||||
expect(
|
||||
screen.getByText('history:1 shortcuts:true datasets:0 reactflow:true'),
|
||||
).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,51 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import ShortcutsName from '../shortcuts-name'
|
||||
|
||||
describe('ShortcutsName', () => {
|
||||
const originalNavigator = globalThis.navigator
|
||||
|
||||
afterEach(() => {
|
||||
Object.defineProperty(globalThis, 'navigator', {
|
||||
value: originalNavigator,
|
||||
writable: true,
|
||||
configurable: true,
|
||||
})
|
||||
})
|
||||
|
||||
it('renders mac-friendly key labels and style variants', () => {
|
||||
Object.defineProperty(globalThis, 'navigator', {
|
||||
value: { userAgent: 'Macintosh' },
|
||||
writable: true,
|
||||
configurable: true,
|
||||
})
|
||||
|
||||
const { container } = render(
|
||||
<ShortcutsName
|
||||
keys={['ctrl', 'shift', 's']}
|
||||
bgColor="white"
|
||||
textColor="secondary"
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('⌘')).toBeInTheDocument()
|
||||
expect(screen.getByText('⇧')).toBeInTheDocument()
|
||||
expect(screen.getByText('s')).toBeInTheDocument()
|
||||
expect(container.querySelector('.system-kbd')).toHaveClass(
|
||||
'bg-components-kbd-bg-white',
|
||||
'text-text-tertiary',
|
||||
)
|
||||
})
|
||||
|
||||
it('keeps raw key names on non-mac systems', () => {
|
||||
Object.defineProperty(globalThis, 'navigator', {
|
||||
value: { userAgent: 'Windows NT' },
|
||||
writable: true,
|
||||
configurable: true,
|
||||
})
|
||||
|
||||
render(<ShortcutsName keys={['ctrl', 'alt']} />)
|
||||
|
||||
expect(screen.getByText('ctrl')).toBeInTheDocument()
|
||||
expect(screen.getByText('alt')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,97 @@
|
||||
import type { Edge, Node } from '../types'
|
||||
import type { WorkflowHistoryState } from '../workflow-history-store'
|
||||
import { render, renderHook, screen } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { BlockEnum } from '../types'
|
||||
import { useWorkflowHistoryStore, WorkflowHistoryProvider } from '../workflow-history-store'
|
||||
|
||||
const nodes: Node[] = [
|
||||
{
|
||||
id: 'node-1',
|
||||
type: 'custom',
|
||||
position: { x: 0, y: 0 },
|
||||
data: {
|
||||
title: 'Start',
|
||||
desc: '',
|
||||
type: BlockEnum.Start,
|
||||
selected: true,
|
||||
},
|
||||
selected: true,
|
||||
},
|
||||
]
|
||||
|
||||
const edges: Edge[] = [
|
||||
{
|
||||
id: 'edge-1',
|
||||
source: 'node-1',
|
||||
target: 'node-2',
|
||||
sourceHandle: null,
|
||||
targetHandle: null,
|
||||
type: 'custom',
|
||||
selected: true,
|
||||
data: {
|
||||
sourceType: BlockEnum.Start,
|
||||
targetType: BlockEnum.End,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
const HistoryConsumer = () => {
|
||||
const { store, shortcutsEnabled, setShortcutsEnabled } = useWorkflowHistoryStore()
|
||||
|
||||
return (
|
||||
<button onClick={() => setShortcutsEnabled(!shortcutsEnabled)}>
|
||||
{`nodes:${store.getState().nodes.length} shortcuts:${String(shortcutsEnabled)}`}
|
||||
</button>
|
||||
)
|
||||
}
|
||||
|
||||
describe('WorkflowHistoryProvider', () => {
|
||||
it('provides workflow history state and shortcut toggles', async () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
render(
|
||||
<WorkflowHistoryProvider
|
||||
nodes={nodes}
|
||||
edges={edges}
|
||||
>
|
||||
<HistoryConsumer />
|
||||
</WorkflowHistoryProvider>,
|
||||
)
|
||||
|
||||
expect(screen.getByRole('button', { name: 'nodes:1 shortcuts:true' })).toBeInTheDocument()
|
||||
|
||||
await user.click(screen.getByRole('button', { name: 'nodes:1 shortcuts:true' }))
|
||||
expect(screen.getByRole('button', { name: 'nodes:1 shortcuts:false' })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('sanitizes selected flags when history state is replaced through the exposed store api', () => {
|
||||
const wrapper = ({ children }: { children: React.ReactNode }) => (
|
||||
<WorkflowHistoryProvider
|
||||
nodes={nodes}
|
||||
edges={edges}
|
||||
>
|
||||
{children}
|
||||
</WorkflowHistoryProvider>
|
||||
)
|
||||
|
||||
const { result } = renderHook(() => useWorkflowHistoryStore(), { wrapper })
|
||||
const nextState: WorkflowHistoryState = {
|
||||
workflowHistoryEvent: undefined,
|
||||
workflowHistoryEventMeta: undefined,
|
||||
nodes,
|
||||
edges,
|
||||
}
|
||||
|
||||
result.current.store.setState(nextState)
|
||||
|
||||
expect(result.current.store.getState().nodes[0].data.selected).toBe(false)
|
||||
expect(result.current.store.getState().edges[0].selected).toBe(false)
|
||||
})
|
||||
|
||||
it('throws when consumed outside the provider', () => {
|
||||
expect(() => renderHook(() => useWorkflowHistoryStore())).toThrow(
|
||||
'useWorkflowHistoryStoreApi must be used within a WorkflowHistoryProvider',
|
||||
)
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,140 @@
|
||||
import { render, screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { useMarketplacePlugins } from '@/app/components/plugins/marketplace/hooks'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { useGetLanguage } from '@/context/i18n'
|
||||
import useTheme from '@/hooks/use-theme'
|
||||
import { Theme } from '@/types/app'
|
||||
import AllTools from '../all-tools'
|
||||
import { createGlobalPublicStoreState, createToolProvider } from './factories'
|
||||
|
||||
vi.mock('@/context/global-public-context', () => ({
|
||||
useGlobalPublicStore: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/i18n', () => ({
|
||||
useGetLanguage: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/hooks/use-theme', () => ({
|
||||
default: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/plugins/marketplace/hooks', () => ({
|
||||
useMarketplacePlugins: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({
|
||||
useMCPToolAvailability: () => ({
|
||||
allowed: true,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/utils/var', async importOriginal => ({
|
||||
...(await importOriginal<typeof import('@/utils/var')>()),
|
||||
getMarketplaceUrl: () => 'https://marketplace.test/tools',
|
||||
}))
|
||||
|
||||
const mockUseMarketplacePlugins = vi.mocked(useMarketplacePlugins)
|
||||
const mockUseGlobalPublicStore = vi.mocked(useGlobalPublicStore)
|
||||
const mockUseGetLanguage = vi.mocked(useGetLanguage)
|
||||
const mockUseTheme = vi.mocked(useTheme)
|
||||
|
||||
const createMarketplacePluginsMock = () => ({
|
||||
plugins: [],
|
||||
total: 0,
|
||||
resetPlugins: vi.fn(),
|
||||
queryPlugins: vi.fn(),
|
||||
queryPluginsWithDebounced: vi.fn(),
|
||||
cancelQueryPluginsWithDebounced: vi.fn(),
|
||||
isLoading: false,
|
||||
isFetchingNextPage: false,
|
||||
hasNextPage: false,
|
||||
fetchNextPage: vi.fn(),
|
||||
page: 0,
|
||||
})
|
||||
|
||||
describe('AllTools', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockUseGlobalPublicStore.mockImplementation(selector => selector(createGlobalPublicStoreState(false)))
|
||||
mockUseGetLanguage.mockReturnValue('en_US')
|
||||
mockUseTheme.mockReturnValue({ theme: Theme.light } as ReturnType<typeof useTheme>)
|
||||
mockUseMarketplacePlugins.mockReturnValue(createMarketplacePluginsMock())
|
||||
})
|
||||
|
||||
it('filters tools by the active tab', async () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
render(
|
||||
<AllTools
|
||||
searchText=""
|
||||
tags={[]}
|
||||
onSelect={vi.fn()}
|
||||
buildInTools={[createToolProvider({
|
||||
id: 'provider-built-in',
|
||||
label: { en_US: 'Built In Provider', zh_Hans: 'Built In Provider' },
|
||||
})]}
|
||||
customTools={[createToolProvider({
|
||||
id: 'provider-custom',
|
||||
type: 'custom',
|
||||
label: { en_US: 'Custom Provider', zh_Hans: 'Custom Provider' },
|
||||
})]}
|
||||
workflowTools={[]}
|
||||
mcpTools={[]}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('Built In Provider')).toBeInTheDocument()
|
||||
expect(screen.getByText('Custom Provider')).toBeInTheDocument()
|
||||
|
||||
await user.click(screen.getByText('workflow.tabs.customTool'))
|
||||
|
||||
expect(screen.getByText('Custom Provider')).toBeInTheDocument()
|
||||
expect(screen.queryByText('Built In Provider')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('filters the rendered tools by the search text', () => {
|
||||
render(
|
||||
<AllTools
|
||||
searchText="report"
|
||||
tags={[]}
|
||||
onSelect={vi.fn()}
|
||||
buildInTools={[
|
||||
createToolProvider({
|
||||
id: 'provider-report',
|
||||
label: { en_US: 'Report Toolkit', zh_Hans: 'Report Toolkit' },
|
||||
}),
|
||||
createToolProvider({
|
||||
id: 'provider-other',
|
||||
label: { en_US: 'Other Toolkit', zh_Hans: 'Other Toolkit' },
|
||||
}),
|
||||
]}
|
||||
customTools={[]}
|
||||
workflowTools={[]}
|
||||
mcpTools={[]}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('Report Toolkit')).toBeInTheDocument()
|
||||
expect(screen.queryByText('Other Toolkit')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('shows the empty state when no tool matches the current filter', async () => {
|
||||
render(
|
||||
<AllTools
|
||||
searchText="missing"
|
||||
tags={[]}
|
||||
onSelect={vi.fn()}
|
||||
buildInTools={[]}
|
||||
customTools={[]}
|
||||
workflowTools={[]}
|
||||
mcpTools={[]}
|
||||
/>,
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('workflow.tabs.noPluginsFound')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,79 @@
|
||||
import type { NodeDefault } from '../../types'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { BlockEnum } from '../../types'
|
||||
import Blocks from '../blocks'
|
||||
import { BlockClassificationEnum } from '../types'
|
||||
|
||||
const runtimeState = vi.hoisted(() => ({
|
||||
nodes: [] as Array<{ data: { type?: BlockEnum } }>,
|
||||
}))
|
||||
|
||||
vi.mock('reactflow', () => ({
|
||||
useStoreApi: () => ({
|
||||
getState: () => ({
|
||||
getNodes: () => runtimeState.nodes,
|
||||
}),
|
||||
}),
|
||||
}))
|
||||
|
||||
const createBlock = (type: BlockEnum, title: string, classification = BlockClassificationEnum.Default): NodeDefault => ({
|
||||
metaData: {
|
||||
classification,
|
||||
sort: 0,
|
||||
type,
|
||||
title,
|
||||
author: 'Dify',
|
||||
description: `${title} description`,
|
||||
},
|
||||
defaultValue: {},
|
||||
checkValid: () => ({ isValid: true }),
|
||||
})
|
||||
|
||||
describe('Blocks', () => {
|
||||
beforeEach(() => {
|
||||
runtimeState.nodes = []
|
||||
})
|
||||
|
||||
it('renders grouped blocks, filters duplicate knowledge-base nodes, and selects a block', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onSelect = vi.fn()
|
||||
|
||||
runtimeState.nodes = [{ data: { type: BlockEnum.KnowledgeBase } }]
|
||||
|
||||
render(
|
||||
<Blocks
|
||||
searchText=""
|
||||
onSelect={onSelect}
|
||||
availableBlocksTypes={[BlockEnum.LLM, BlockEnum.LoopEnd, BlockEnum.KnowledgeBase]}
|
||||
blocks={[
|
||||
createBlock(BlockEnum.LLM, 'LLM'),
|
||||
createBlock(BlockEnum.LoopEnd, 'Exit Loop', BlockClassificationEnum.Logic),
|
||||
createBlock(BlockEnum.KnowledgeBase, 'Knowledge Retrieval'),
|
||||
]}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('LLM')).toBeInTheDocument()
|
||||
expect(screen.getByText('Exit Loop')).toBeInTheDocument()
|
||||
expect(screen.getByText('workflow.nodes.loop.loopNode')).toBeInTheDocument()
|
||||
expect(screen.queryByText('Knowledge Retrieval')).not.toBeInTheDocument()
|
||||
|
||||
await user.click(screen.getByText('LLM'))
|
||||
|
||||
expect(onSelect).toHaveBeenCalledWith(BlockEnum.LLM)
|
||||
})
|
||||
|
||||
it('shows the empty state when no block matches the search', () => {
|
||||
render(
|
||||
<Blocks
|
||||
searchText="missing"
|
||||
onSelect={vi.fn()}
|
||||
availableBlocksTypes={[BlockEnum.LLM]}
|
||||
blocks={[createBlock(BlockEnum.LLM, 'LLM')]}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('workflow.tabs.noResult')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,101 @@
|
||||
import type { ToolWithProvider } from '../../types'
|
||||
import type { Plugin } from '@/app/components/plugins/types'
|
||||
import type { Tool } from '@/app/components/tools/types'
|
||||
import { PluginCategoryEnum } from '@/app/components/plugins/types'
|
||||
import { CollectionType } from '@/app/components/tools/types'
|
||||
import { defaultSystemFeatures } from '@/types/feature'
|
||||
|
||||
export const createTool = (
|
||||
name: string,
|
||||
label: string,
|
||||
description = `${label} description`,
|
||||
): Tool => ({
|
||||
name,
|
||||
author: 'author',
|
||||
label: {
|
||||
en_US: label,
|
||||
zh_Hans: label,
|
||||
},
|
||||
description: {
|
||||
en_US: description,
|
||||
zh_Hans: description,
|
||||
},
|
||||
parameters: [],
|
||||
labels: [],
|
||||
output_schema: {},
|
||||
})
|
||||
|
||||
export const createToolProvider = (
|
||||
overrides: Partial<ToolWithProvider> = {},
|
||||
): ToolWithProvider => ({
|
||||
id: 'provider-1',
|
||||
name: 'provider-one',
|
||||
author: 'Provider Author',
|
||||
description: {
|
||||
en_US: 'Provider description',
|
||||
zh_Hans: 'Provider description',
|
||||
},
|
||||
icon: 'icon',
|
||||
icon_dark: 'icon-dark',
|
||||
label: {
|
||||
en_US: 'Provider One',
|
||||
zh_Hans: 'Provider One',
|
||||
},
|
||||
type: CollectionType.builtIn,
|
||||
team_credentials: {},
|
||||
is_team_authorization: false,
|
||||
allow_delete: false,
|
||||
labels: [],
|
||||
plugin_id: 'plugin-1',
|
||||
tools: [createTool('tool-a', 'Tool A')],
|
||||
meta: { version: '1.0.0' } as ToolWithProvider['meta'],
|
||||
plugin_unique_identifier: 'plugin-1@1.0.0',
|
||||
...overrides,
|
||||
})
|
||||
|
||||
export const createPlugin = (overrides: Partial<Plugin> = {}): Plugin => ({
|
||||
type: 'plugin',
|
||||
org: 'org',
|
||||
author: 'author',
|
||||
name: 'Plugin One',
|
||||
plugin_id: 'plugin-1',
|
||||
version: '1.0.0',
|
||||
latest_version: '1.0.0',
|
||||
latest_package_identifier: 'plugin-1@1.0.0',
|
||||
icon: 'icon',
|
||||
verified: true,
|
||||
label: {
|
||||
en_US: 'Plugin One',
|
||||
zh_Hans: 'Plugin One',
|
||||
},
|
||||
brief: {
|
||||
en_US: 'Plugin description',
|
||||
zh_Hans: 'Plugin description',
|
||||
},
|
||||
description: {
|
||||
en_US: 'Plugin description',
|
||||
zh_Hans: 'Plugin description',
|
||||
},
|
||||
introduction: 'Plugin introduction',
|
||||
repository: 'https://example.com/plugin',
|
||||
category: PluginCategoryEnum.tool,
|
||||
tags: [],
|
||||
badges: [],
|
||||
install_count: 0,
|
||||
endpoint: {
|
||||
settings: [],
|
||||
},
|
||||
verification: {
|
||||
authorized_category: 'community',
|
||||
},
|
||||
from: 'github',
|
||||
...overrides,
|
||||
})
|
||||
|
||||
export const createGlobalPublicStoreState = (enableMarketplace: boolean) => ({
|
||||
systemFeatures: {
|
||||
...defaultSystemFeatures,
|
||||
enable_marketplace: enableMarketplace,
|
||||
},
|
||||
setSystemFeatures: vi.fn(),
|
||||
})
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user