Merge remote-tracking branch 'myori/main' into feat/collaboration2

This commit is contained in:
hjlarry 2026-04-10 22:47:40 +08:00
commit ee2b021395
380 changed files with 28622 additions and 3985 deletions

View File

@ -7,6 +7,7 @@
## Summary
<!-- Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change. -->
<!-- If this PR was created by an automated agent, add `From <Tool Name>` as the final line of the description. Example: `From Codex`. -->
## Screenshots

View File

@ -0,0 +1,118 @@
# Posts the Pyrefly type-coverage report as a PR comment for fork PRs.
# Runs via workflow_run in the trusted default-branch context, so it has a
# GITHUB_TOKEN that can write comments without ever executing fork code.
name: Comment with Pyrefly Type Coverage
on:
  workflow_run:
    workflows:
      - Pyrefly Type Coverage
    types:
      - completed
# Deny-all by default; the job grants only what it needs.
permissions: {}
jobs:
  comment:
    name: Comment PR with type coverage
    runs-on: ubuntu-latest
    permissions:
      actions: read # download artifacts from the triggering run
      contents: read
      issues: write # PR comments are issue comments
      pull-requests: write
    # Only successful runs for PRs from forks; same-repo PRs are commented
    # directly by the "Pyrefly Type Coverage" workflow itself.
    # NOTE(review): workflow_run.pull_requests is frequently empty for fork
    # PRs, which would make this expression fail — confirm on a real fork PR.
    if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }}
    steps:
      - name: Checkout default branch (trusted code)
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
      - name: Setup Python & UV
        uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
        with:
          enable-cache: true
      - name: Install dependencies
        run: uv sync --project api --dev
      - name: Download type coverage artifact
        uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            const fs = require('fs');
            // Read the run id from the event payload instead of splicing a
            // ${{ }} expression into the script: same value, but the script
            // stays valid JavaScript and avoids expression-injection patterns.
            const runId = context.payload.workflow_run.id;
            const artifacts = await github.rest.actions.listWorkflowRunArtifacts({
              owner: context.repo.owner,
              repo: context.repo.repo,
              run_id: runId,
            });
            const match = artifacts.data.artifacts.find((artifact) =>
              artifact.name === 'pyrefly_type_coverage'
            );
            if (!match) {
              throw new Error('pyrefly_type_coverage artifact not found');
            }
            const download = await github.rest.actions.downloadArtifact({
              owner: context.repo.owner,
              repo: context.repo.repo,
              artifact_id: match.id,
              archive_format: 'zip',
            });
            fs.writeFileSync('pyrefly_type_coverage.zip', Buffer.from(download.data));
      - name: Unzip artifact
        run: unzip -o pyrefly_type_coverage.zip
      - name: Render coverage markdown from structured data
        id: render
        # NOTE(review): `uv run --directory api` changes the working directory
        # before running, so the relative paths `api/libs/...`,
        # `base_report.json` may not resolve from the repo root as written —
        # the sibling workflow uses absolute /tmp paths; confirm this one works.
        run: |
          comment_body="$(uv run --directory api python api/libs/pyrefly_type_coverage.py \
          --base base_report.json \
          < pr_report.json)"
          {
            echo "### Pyrefly Type Coverage"
            echo ""
            echo "$comment_body"
          } > /tmp/type_coverage_comment.md
      - name: Post comment
        uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            const fs = require('fs');
            const body = fs.readFileSync('/tmp/type_coverage_comment.md', { encoding: 'utf8' });
            // Prefer the PR number recorded inside the artifact; fall back to
            // the workflow_run payload (may be empty for fork PRs).
            let prNumber = null;
            try {
              prNumber = parseInt(fs.readFileSync('pr_number.txt', { encoding: 'utf8' }), 10);
            } catch (err) {
              const prs = context.payload.workflow_run.pull_requests || [];
              if (prs.length > 0 && prs[0].number) {
                prNumber = prs[0].number;
              }
            }
            if (!prNumber) {
              throw new Error('PR number not found in artifact or workflow_run payload');
            }
            // Update existing comment if one exists, otherwise create new
            const { data: comments } = await github.rest.issues.listComments({
              issue_number: prNumber,
              owner: context.repo.owner,
              repo: context.repo.repo,
            });
            const marker = '### Pyrefly Type Coverage';
            const existing = comments.find(c => c.body.startsWith(marker));
            if (existing) {
              await github.rest.issues.updateComment({
                comment_id: existing.id,
                owner: context.repo.owner,
                repo: context.repo.repo,
                body,
              });
            } else {
              await github.rest.issues.createComment({
                issue_number: prNumber,
                owner: context.repo.owner,
                repo: context.repo.repo,
                body,
              });
            }

View File

@ -0,0 +1,120 @@
# Computes Pyrefly type coverage for a PR and compares it against the base
# branch. Comments directly on same-repo PRs; for fork PRs it only uploads an
# artifact, which the privileged workflow_run companion workflow posts.
name: Pyrefly Type Coverage
on:
  pull_request:
    paths:
      - 'api/**/*.py'
permissions:
  contents: read
jobs:
  pyrefly-type-coverage:
    runs-on: ubuntu-latest
    permissions:
      contents: read
      issues: write
      pull-requests: write
    steps:
      - name: Checkout PR branch
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          fetch-depth: 0 # full history so the base branch can be checked out below
      - name: Setup Python & UV
        uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
        with:
          enable-cache: true
      - name: Install dependencies
        run: uv sync --project api --dev
      - name: Run pyrefly report on PR branch
        # Write to a temp file first and fall back to an empty report ({})
        # so a pyrefly failure never fails the job or leaves a partial file.
        run: |
          uv run --directory api --dev pyrefly report 2>/dev/null > /tmp/pyrefly_report_pr.tmp && \
          mv /tmp/pyrefly_report_pr.tmp /tmp/pyrefly_report_pr.json || \
          echo '{}' > /tmp/pyrefly_report_pr.json
      - name: Save helper script from base branch
        # Take the renderer from the base branch so a PR cannot tamper with
        # the script that formats its own report; fall back to the PR copy
        # (e.g. when the script does not exist on base yet).
        run: |
          git show ${{ github.event.pull_request.base.sha }}:api/libs/pyrefly_type_coverage.py > /tmp/pyrefly_type_coverage.py 2>/dev/null \
          || cp api/libs/pyrefly_type_coverage.py /tmp/pyrefly_type_coverage.py
      - name: Checkout base branch
        run: git checkout ${{ github.base_ref }}
      - name: Run pyrefly report on base branch
        run: |
          uv run --directory api --dev pyrefly report 2>/dev/null > /tmp/pyrefly_report_base.tmp && \
          mv /tmp/pyrefly_report_base.tmp /tmp/pyrefly_report_base.json || \
          echo '{}' > /tmp/pyrefly_report_base.json
      - name: Generate coverage comparison
        id: coverage
        run: |
          comment_body="$(uv run --directory api python /tmp/pyrefly_type_coverage.py \
          --base /tmp/pyrefly_report_base.json \
          < /tmp/pyrefly_report_pr.json)"
          {
            echo "### Pyrefly Type Coverage"
            echo ""
            echo "$comment_body"
          } | tee -a "$GITHUB_STEP_SUMMARY" > /tmp/type_coverage_comment.md
          # Save structured data for the fork-PR comment workflow
          cp /tmp/pyrefly_report_pr.json pr_report.json
          cp /tmp/pyrefly_report_base.json base_report.json
      - name: Save PR number
        run: |
          echo ${{ github.event.pull_request.number }} > pr_number.txt
      - name: Upload type coverage artifact
        uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
        with:
          name: pyrefly_type_coverage
          path: |
            pr_report.json
            base_report.json
            pr_number.txt
      - name: Comment PR with type coverage
        # Same-repo PRs only: a fork PR's GITHUB_TOKEN cannot write comments,
        # so fork PRs are handled by the workflow_run companion workflow.
        if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
        uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            const fs = require('fs');
            const marker = '### Pyrefly Type Coverage';
            let body;
            try {
              body = fs.readFileSync('/tmp/type_coverage_comment.md', { encoding: 'utf8' });
            } catch {
              body = `${marker}\n\n_Coverage report unavailable._`;
            }
            const prNumber = context.payload.pull_request.number;
            // Update existing comment if one exists, otherwise create new
            const { data: comments } = await github.rest.issues.listComments({
              issue_number: prNumber,
              owner: context.repo.owner,
              repo: context.repo.repo,
            });
            const existing = comments.find(c => c.body.startsWith(marker));
            if (existing) {
              await github.rest.issues.updateComment({
                comment_id: existing.id,
                owner: context.repo.owner,
                repo: context.repo.repo,
                body,
              });
            } else {
              await github.rest.issues.createComment({
                issue_number: prNumber,
                owner: context.repo.owner,
                repo: context.repo.repo,
                body,
              });
            }

View File

@ -97,8 +97,3 @@ Feel free to reach out if you encounter any issues during the setup process.
## Getting Help
If you ever get stuck or get a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat.
## Automated Agent Contributions
> [!NOTE]
> If you are an automated agent just add 🤖🤖🤖 to the end of the PR title to opt-in.

View File

@ -2,7 +2,6 @@ import base64
import secrets import secrets
import click import click
from sqlalchemy.orm import sessionmaker
from constants.languages import languages from constants.languages import languages
from extensions.ext_database import db from extensions.ext_database import db
@ -25,30 +24,31 @@ def reset_password(email, new_password, password_confirm):
return return
normalized_email = email.strip().lower() normalized_email = email.strip().lower()
with sessionmaker(db.engine, expire_on_commit=False).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(email.strip())
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
if not account: if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red")) click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return return
try: try:
valid_password(new_password) valid_password(new_password)
except: except:
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red")) click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
return return
# generate password salt # generate password salt
salt = secrets.token_bytes(16) salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode() base64_salt = base64.b64encode(salt).decode()
# encrypt password with salt # encrypt password with salt
password_hashed = hash_password(new_password, salt) password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode() base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed account = db.session.merge(account)
account.password_salt = base64_salt account.password = base64_password_hashed
AccountService.reset_login_error_rate_limit(normalized_email) account.password_salt = base64_salt
click.echo(click.style("Password reset successfully.", fg="green")) db.session.commit()
AccountService.reset_login_error_rate_limit(normalized_email)
click.echo(click.style("Password reset successfully.", fg="green"))
@click.command("reset-email", help="Reset the account email.") @click.command("reset-email", help="Reset the account email.")
@ -65,21 +65,22 @@ def reset_email(email, new_email, email_confirm):
return return
normalized_new_email = new_email.strip().lower() normalized_new_email = new_email.strip().lower()
with sessionmaker(db.engine, expire_on_commit=False).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(email.strip())
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
if not account: if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red")) click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return return
try: try:
email_validate(normalized_new_email) email_validate(normalized_new_email)
except: except:
click.echo(click.style(f"Invalid email: {new_email}", fg="red")) click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
return return
account.email = normalized_new_email account = db.session.merge(account)
click.echo(click.style("Email updated successfully.", fg="green")) account.email = normalized_new_email
db.session.commit()
click.echo(click.style("Email updated successfully.", fg="green"))
@click.command("create-tenant", help="Create account and tenant.") @click.command("create-tenant", help="Create account and tenant.")

View File

@ -1,7 +1,6 @@
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker
from configs import dify_config from configs import dify_config
from constants.languages import languages from constants.languages import languages
@ -14,7 +13,6 @@ from controllers.console.auth.error import (
InvalidTokenError, InvalidTokenError,
PasswordMismatchError, PasswordMismatchError,
) )
from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip from libs.helper import EmailStr, extract_remote_ip
from libs.password import valid_password from libs.password import valid_password
from models import Account from models import Account
@ -73,8 +71,7 @@ class EmailRegisterSendEmailApi(Resource):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email): if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountInFreezeError() raise AccountInFreezeError()
with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(args.email)
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language) token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
return {"result": "success", "data": token} return {"result": "success", "data": token}
@ -145,17 +142,16 @@ class EmailRegisterResetApi(Resource):
email = register_data.get("email", "") email = register_data.get("email", "")
normalized_email = email.lower() normalized_email = email.lower()
with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(email)
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account: if account:
raise EmailAlreadyInUseError() raise EmailAlreadyInUseError()
else: else:
account = self._create_new_account(normalized_email, args.password_confirm) account = self._create_new_account(normalized_email, args.password_confirm)
if not account: if not account:
raise AccountNotFoundError() raise AccountNotFoundError()
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(normalized_email) AccountService.reset_login_error_rate_limit(normalized_email)
return {"result": "success", "data": token_pair.model_dump()} return {"result": "success", "data": token_pair.model_dump()}

View File

@ -4,7 +4,6 @@ import secrets
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
@ -85,8 +84,7 @@ class ForgotPasswordSendEmailApi(Resource):
else: else:
language = "en-US" language = "en-US"
with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(args.email)
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
token = AccountService.send_reset_password_email( token = AccountService.send_reset_password_email(
account=account, account=account,
@ -184,17 +182,18 @@ class ForgotPasswordResetApi(Resource):
password_hashed = hash_password(args.new_password, salt) password_hashed = hash_password(args.new_password, salt)
email = reset_data.get("email", "") email = reset_data.get("email", "")
with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(email)
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account: if account:
self._update_existing_account(account, password_hashed, salt, session) account = db.session.merge(account)
else: self._update_existing_account(account, password_hashed, salt)
raise AccountNotFound() db.session.commit()
else:
raise AccountNotFound()
return {"result": "success"} return {"result": "success"}
def _update_existing_account(self, account, password_hashed, salt, session): def _update_existing_account(self, account, password_hashed, salt):
# Update existing account credentials # Update existing account credentials
account.password = base64.b64encode(password_hashed).decode() account.password = base64.b64encode(password_hashed).decode()
account.password_salt = base64.b64encode(salt).decode() account.password_salt = base64.b64encode(salt).decode()

View File

@ -4,7 +4,6 @@ import urllib.parse
import httpx import httpx
from flask import current_app, redirect, request from flask import current_app, redirect, request
from flask_restx import Resource from flask_restx import Resource
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Unauthorized from werkzeug.exceptions import Unauthorized
from configs import dify_config from configs import dify_config
@ -180,8 +179,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
account: Account | None = Account.get_by_openid(provider, user_info.id) account: Account | None = Account.get_by_openid(provider, user_info.id)
if not account: if not account:
with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(user_info.email)
account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)
return account return account

View File

@ -227,10 +227,11 @@ class ExternalApiUseCheckApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, external_knowledge_api_id): def get(self, external_knowledge_api_id):
_, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id = str(external_knowledge_api_id) external_knowledge_api_id = str(external_knowledge_api_id)
external_knowledge_api_is_using, count = ExternalDatasetService.external_knowledge_api_use_check( external_knowledge_api_is_using, count = ExternalDatasetService.external_knowledge_api_use_check(
external_knowledge_api_id external_knowledge_api_id, current_tenant_id
) )
return {"is_using": external_knowledge_api_is_using, "count": count}, 200 return {"is_using": external_knowledge_api_is_using, "count": count}, 200

View File

@ -9,7 +9,6 @@ from flask_restx import Resource, fields, marshal_with
from graphon.file import helpers as file_helpers from graphon.file import helpers as file_helpers
from pydantic import BaseModel, Field, field_validator, model_validator from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from configs import dify_config from configs import dify_config
from constants.languages import supported_language from constants.languages import supported_language
@ -580,8 +579,7 @@ class ChangeEmailSendEmailApi(Resource):
user_email = current_user.email user_email = current_user.email
else: else:
with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(args.email)
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
if account is None: if account is None:
raise AccountNotFound() raise AccountNotFound()
email_for_sending = account.email email_for_sending = account.email

View File

@ -3,7 +3,6 @@ import secrets
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models from controllers.common.schema import register_schema_models
from controllers.console.auth.error import ( from controllers.console.auth.error import (
@ -62,9 +61,7 @@ class ForgotPasswordSendEmailApi(Resource):
else: else:
language = "en-US" language = "en-US"
with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(request_email)
account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session)
token = None
if account is None: if account is None:
raise AuthenticationFailedError() raise AuthenticationFailedError()
else: else:
@ -161,13 +158,14 @@ class ForgotPasswordResetApi(Resource):
email = reset_data.get("email", "") email = reset_data.get("email", "")
with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(email)
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account: if account:
self._update_existing_account(account, password_hashed, salt) account = db.session.merge(account)
else: self._update_existing_account(account, password_hashed, salt)
raise AuthenticationFailedError() db.session.commit()
else:
raise AuthenticationFailedError()
return {"result": "success"} return {"result": "success"}

View File

@ -1,6 +1,6 @@
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from enum import StrEnum from enum import StrEnum
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any
from graphon.file import File, FileUploadConfig from graphon.file import File, FileUploadConfig
from graphon.model_runtime.entities.model_entities import AIModelEntity from graphon.model_runtime.entities.model_entities import AIModelEntity
@ -131,7 +131,7 @@ class AppGenerateEntity(BaseModel):
extras: dict[str, Any] = Field(default_factory=dict) extras: dict[str, Any] = Field(default_factory=dict)
# tracing instance # tracing instance
trace_manager: Optional["TraceQueueManager"] = Field(default=None, exclude=True, repr=False) trace_manager: "TraceQueueManager | None" = Field(default=None, exclude=True, repr=False)
class EasyUIBasedAppGenerateEntity(AppGenerateEntity): class EasyUIBasedAppGenerateEntity(AppGenerateEntity):

View File

@ -1,10 +1,10 @@
from typing import Literal, Optional from typing import Any, Literal, TypedDict
from graphon.model_runtime.utils.encoders import jsonable_encoder from graphon.model_runtime.utils.encoders import jsonable_encoder
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from core.datasource.entities.datasource_entities import DatasourceParameter from core.datasource.entities.datasource_entities import DatasourceParameter
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject, I18nObjectDict
class DatasourceApiEntity(BaseModel): class DatasourceApiEntity(BaseModel):
@ -17,7 +17,24 @@ class DatasourceApiEntity(BaseModel):
output_schema: dict | None = None output_schema: dict | None = None
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]] ToolProviderTypeApiLiteral = Literal["builtin", "api", "workflow"] | None
class DatasourceProviderApiEntityDict(TypedDict):
id: str
author: str
name: str
plugin_id: str | None
plugin_unique_identifier: str | None
description: I18nObjectDict
icon: str | dict
label: I18nObjectDict
type: str
team_credentials: dict | None
is_team_authorization: bool
allow_delete: bool
datasources: list[Any]
labels: list[str]
class DatasourceProviderApiEntity(BaseModel): class DatasourceProviderApiEntity(BaseModel):
@ -42,7 +59,7 @@ class DatasourceProviderApiEntity(BaseModel):
def convert_none_to_empty_list(cls, v): def convert_none_to_empty_list(cls, v):
return v if v is not None else [] return v if v is not None else []
def to_dict(self) -> dict: def to_dict(self) -> DatasourceProviderApiEntityDict:
# ------------- # -------------
# overwrite datasource parameter types for temp fix # overwrite datasource parameter types for temp fix
datasources = jsonable_encoder(self.datasources) datasources = jsonable_encoder(self.datasources)
@ -53,7 +70,7 @@ class DatasourceProviderApiEntity(BaseModel):
parameter["type"] = "files" parameter["type"] = "files"
# ------------- # -------------
return { result: DatasourceProviderApiEntityDict = {
"id": self.id, "id": self.id,
"author": self.author, "author": self.author,
"name": self.name, "name": self.name,
@ -69,3 +86,4 @@ class DatasourceProviderApiEntity(BaseModel):
"datasources": datasources, "datasources": datasources,
"labels": self.labels, "labels": self.labels,
} }
return result

View File

@ -71,8 +71,8 @@ class DatasourceFileMessageTransformer:
if not isinstance(message.message, DatasourceMessage.BlobMessage): if not isinstance(message.message, DatasourceMessage.BlobMessage):
raise ValueError("unexpected message type") raise ValueError("unexpected message type")
# FIXME: should do a type check here. if not isinstance(message.message.blob, bytes):
assert isinstance(message.message.blob, bytes) raise TypeError(f"Expected blob to be bytes, got {type(message.message.blob).__name__}")
tool_file_manager = ToolFileManager() tool_file_manager = ToolFileManager()
blob_tool_file: ToolFile | None = tool_file_manager.create_file_by_raw( blob_tool_file: ToolFile | None = tool_file_manager.create_file_by_raw(
user_id=user_id, user_id=user_id,

View File

@ -122,7 +122,7 @@ class MCPClientWithAuthRetry(MCPClient):
logger.exception("Authentication retry failed") logger.exception("Authentication retry failed")
raise MCPAuthError(f"Authentication retry failed: {e}") from e raise MCPAuthError(f"Authentication retry failed: {e}") from e
def _execute_with_retry(self, func: Callable[..., Any], *args, **kwargs) -> Any: def _execute_with_retry[**P, R](self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
""" """
Execute a function with authentication retry logic. Execute a function with authentication retry logic.

View File

@ -1,6 +1,6 @@
from dataclasses import dataclass from dataclasses import dataclass
from enum import StrEnum from enum import StrEnum
from typing import Any, TypeVar from typing import Any
from pydantic import BaseModel from pydantic import BaseModel
@ -9,12 +9,9 @@ from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAut
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION] SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
LifespanContextT = TypeVar("LifespanContextT")
@dataclass @dataclass
class RequestContext[SessionT: BaseSession[Any, Any, Any, Any, Any], LifespanContextT]: class RequestContext[SessionT: BaseSession, LifespanContextT]:
request_id: RequestId request_id: RequestId
meta: RequestParams.Meta | None meta: RequestParams.Meta | None
session: SessionT session: SessionT

View File

@ -55,7 +55,7 @@ class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResul
request: ReceiveRequestT request: ReceiveRequestT
_session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]" _session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]"
_on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any] _on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], object]
def __init__( def __init__(
self, self,
@ -63,7 +63,7 @@ class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResul
request_meta: RequestParams.Meta | None, request_meta: RequestParams.Meta | None,
request: ReceiveRequestT, request: ReceiveRequestT,
session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]", session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]",
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any], on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], object],
): ):
self.request_id = request_id self.request_id = request_id
self.request_meta = request_meta self.request_meta = request_meta

View File

@ -31,7 +31,6 @@ ProgressToken = str | int
Cursor = str Cursor = str
Role = Literal["user", "assistant"] Role = Literal["user", "assistant"]
RequestId = Annotated[int | str, Field(union_mode="left_to_right")] RequestId = Annotated[int | str, Field(union_mode="left_to_right")]
type AnyFunction = Callable[..., Any]
class RequestParams(BaseModel): class RequestParams(BaseModel):

View File

@ -6,7 +6,7 @@ from graphon.model_runtime.callbacks.base_callback import Callback
from graphon.model_runtime.entities.llm_entities import LLMResult from graphon.model_runtime.entities.llm_entities import LLMResult
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType
from graphon.model_runtime.entities.rerank_entities import RerankResult from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@ -172,10 +172,10 @@ class ModelInstance:
function=self.model_type_instance.invoke, function=self.model_type_instance.invoke,
model=self.model_name, model=self.model_name,
credentials=self.credentials, credentials=self.credentials,
prompt_messages=prompt_messages, prompt_messages=list(prompt_messages),
model_parameters=model_parameters, model_parameters=model_parameters,
tools=tools, tools=list(tools) if tools else None,
stop=stop, stop=list(stop) if stop else None,
stream=stream, stream=stream,
callbacks=callbacks, callbacks=callbacks,
), ),
@ -193,15 +193,12 @@ class ModelInstance:
""" """
if not isinstance(self.model_type_instance, LargeLanguageModel): if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel") raise Exception("Model type instance is not LargeLanguageModel")
return cast( return self._round_robin_invoke(
int, function=self.model_type_instance.get_num_tokens,
self._round_robin_invoke( model=self.model_name,
function=self.model_type_instance.get_num_tokens, credentials=self.credentials,
model=self.model_name, prompt_messages=list(prompt_messages),
credentials=self.credentials, tools=list(tools) if tools else None,
prompt_messages=prompt_messages,
tools=tools,
),
) )
def invoke_text_embedding( def invoke_text_embedding(
@ -216,15 +213,12 @@ class ModelInstance:
""" """
if not isinstance(self.model_type_instance, TextEmbeddingModel): if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel") raise Exception("Model type instance is not TextEmbeddingModel")
return cast( return self._round_robin_invoke(
EmbeddingResult, function=self.model_type_instance.invoke,
self._round_robin_invoke( model=self.model_name,
function=self.model_type_instance.invoke, credentials=self.credentials,
model=self.model_name, texts=texts,
credentials=self.credentials, input_type=input_type,
texts=texts,
input_type=input_type,
),
) )
def invoke_multimodal_embedding( def invoke_multimodal_embedding(
@ -241,15 +235,12 @@ class ModelInstance:
""" """
if not isinstance(self.model_type_instance, TextEmbeddingModel): if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel") raise Exception("Model type instance is not TextEmbeddingModel")
return cast( return self._round_robin_invoke(
EmbeddingResult, function=self.model_type_instance.invoke,
self._round_robin_invoke( model=self.model_name,
function=self.model_type_instance.invoke, credentials=self.credentials,
model=self.model_name, multimodel_documents=multimodel_documents,
credentials=self.credentials, input_type=input_type,
multimodel_documents=multimodel_documents,
input_type=input_type,
),
) )
def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]: def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]:
@ -261,14 +252,11 @@ class ModelInstance:
""" """
if not isinstance(self.model_type_instance, TextEmbeddingModel): if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel") raise Exception("Model type instance is not TextEmbeddingModel")
return cast( return self._round_robin_invoke(
list[int], function=self.model_type_instance.get_num_tokens,
self._round_robin_invoke( model=self.model_name,
function=self.model_type_instance.get_num_tokens, credentials=self.credentials,
model=self.model_name, texts=texts,
credentials=self.credentials,
texts=texts,
),
) )
def invoke_rerank( def invoke_rerank(
@ -289,23 +277,20 @@ class ModelInstance:
""" """
if not isinstance(self.model_type_instance, RerankModel): if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel") raise Exception("Model type instance is not RerankModel")
return cast( return self._round_robin_invoke(
RerankResult, function=self.model_type_instance.invoke,
self._round_robin_invoke( model=self.model_name,
function=self.model_type_instance.invoke, credentials=self.credentials,
model=self.model_name, query=query,
credentials=self.credentials, docs=docs,
query=query, score_threshold=score_threshold,
docs=docs, top_n=top_n,
score_threshold=score_threshold,
top_n=top_n,
),
) )
def invoke_multimodal_rerank( def invoke_multimodal_rerank(
self, self,
query: dict, query: MultimodalRerankInput,
docs: list[dict], docs: list[MultimodalRerankInput],
score_threshold: float | None = None, score_threshold: float | None = None,
top_n: int | None = None, top_n: int | None = None,
) -> RerankResult: ) -> RerankResult:
@ -320,17 +305,14 @@ class ModelInstance:
""" """
if not isinstance(self.model_type_instance, RerankModel): if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel") raise Exception("Model type instance is not RerankModel")
return cast( return self._round_robin_invoke(
RerankResult, function=self.model_type_instance.invoke_multimodal_rerank,
self._round_robin_invoke( model=self.model_name,
function=self.model_type_instance.invoke_multimodal_rerank, credentials=self.credentials,
model=self.model_name, query=query,
credentials=self.credentials, docs=docs,
query=query, score_threshold=score_threshold,
docs=docs, top_n=top_n,
score_threshold=score_threshold,
top_n=top_n,
),
) )
def invoke_moderation(self, text: str) -> bool: def invoke_moderation(self, text: str) -> bool:
@ -342,14 +324,11 @@ class ModelInstance:
""" """
if not isinstance(self.model_type_instance, ModerationModel): if not isinstance(self.model_type_instance, ModerationModel):
raise Exception("Model type instance is not ModerationModel") raise Exception("Model type instance is not ModerationModel")
return cast( return self._round_robin_invoke(
bool, function=self.model_type_instance.invoke,
self._round_robin_invoke( model=self.model_name,
function=self.model_type_instance.invoke, credentials=self.credentials,
model=self.model_name, text=text,
credentials=self.credentials,
text=text,
),
) )
def invoke_speech2text(self, file: IO[bytes]) -> str: def invoke_speech2text(self, file: IO[bytes]) -> str:
@ -361,14 +340,11 @@ class ModelInstance:
""" """
if not isinstance(self.model_type_instance, Speech2TextModel): if not isinstance(self.model_type_instance, Speech2TextModel):
raise Exception("Model type instance is not Speech2TextModel") raise Exception("Model type instance is not Speech2TextModel")
return cast( return self._round_robin_invoke(
str, function=self.model_type_instance.invoke,
self._round_robin_invoke( model=self.model_name,
function=self.model_type_instance.invoke, credentials=self.credentials,
model=self.model_name, file=file,
credentials=self.credentials,
file=file,
),
) )
def invoke_tts(self, content_text: str, voice: str = "") -> Iterable[bytes]: def invoke_tts(self, content_text: str, voice: str = "") -> Iterable[bytes]:
@ -381,18 +357,15 @@ class ModelInstance:
""" """
if not isinstance(self.model_type_instance, TTSModel): if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel") raise Exception("Model type instance is not TTSModel")
return cast( return self._round_robin_invoke(
Iterable[bytes], function=self.model_type_instance.invoke,
self._round_robin_invoke( model=self.model_name,
function=self.model_type_instance.invoke, credentials=self.credentials,
model=self.model_name, content_text=content_text,
credentials=self.credentials, voice=voice,
content_text=content_text,
voice=voice,
),
) )
def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs): def _round_robin_invoke[**P, R](self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
""" """
Round-robin invoke Round-robin invoke
:param function: function to invoke :param function: function to invoke
@ -430,9 +403,8 @@ class ModelInstance:
continue continue
try: try:
if "credentials" in kwargs: kwargs["credentials"] = lb_config.credentials
del kwargs["credentials"] return function(*args, **kwargs)
return function(*args, **kwargs, credentials=lb_config.credentials)
except InvokeRateLimitError as e: except InvokeRateLimitError as e:
# expire in 60 seconds # expire in 60 seconds
self.load_balancing_manager.cooldown(lb_config, expire=60) self.load_balancing_manager.cooldown(lb_config, expire=60)

View File

@ -1,6 +1,6 @@
import uuid import uuid
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping
from typing import Any, Union, cast from typing import Any, cast
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -207,7 +207,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
) )
@classmethod @classmethod
def _get_user(cls, user_id: str) -> Union[EndUser, Account]: def _get_user(cls, user_id: str) -> EndUser | Account:
""" """
get the user by user id get the user by user id
""" """

View File

@ -7,7 +7,7 @@ from pydantic import BaseModel, model_validator
from sqlalchemy import Column, String, Table, create_engine, insert from sqlalchemy import Column, String, Table, create_engine, insert
from sqlalchemy import text as sql_text from sqlalchemy import text as sql_text
from sqlalchemy.dialects.postgresql import JSON, TEXT from sqlalchemy.dialects.postgresql import JSON, TEXT
from sqlalchemy.orm import Session from sqlalchemy.orm import Session, sessionmaker
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
@ -79,7 +79,7 @@ class RelytVector(BaseVector):
if redis_client.get(collection_exist_cache_key): if redis_client.get(collection_exist_cache_key):
return return
index_name = f"{self._collection_name}_embedding_index" index_name = f"{self._collection_name}_embedding_index"
with Session(self.client) as session: with sessionmaker(bind=self.client).begin() as session:
drop_statement = sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}"; """) drop_statement = sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}"; """)
session.execute(drop_statement) session.execute(drop_statement)
create_statement = sql_text(f""" create_statement = sql_text(f"""
@ -104,7 +104,6 @@ class RelytVector(BaseVector):
$$); $$);
""") """)
session.execute(index_statement) session.execute(index_statement)
session.commit()
redis_client.set(collection_exist_cache_key, 1, ex=3600) redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
@ -208,9 +207,8 @@ class RelytVector(BaseVector):
self.delete_by_uuids(ids) self.delete_by_uuids(ids)
def delete(self): def delete(self):
with Session(self.client) as session: with sessionmaker(bind=self.client).begin() as session:
session.execute(sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}";""")) session.execute(sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}";"""))
session.commit()
def text_exists(self, id: str) -> bool: def text_exists(self, id: str) -> bool:
with Session(self.client) as session: with Session(self.client) as session:

View File

@ -6,7 +6,7 @@ import sqlalchemy
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
from sqlalchemy import text as sql_text from sqlalchemy import text as sql_text
from sqlalchemy.orm import Session, declarative_base from sqlalchemy.orm import Session, declarative_base, sessionmaker
from configs import dify_config from configs import dify_config
from core.rag.datasource.vdb.field import Field, parse_metadata_json from core.rag.datasource.vdb.field import Field, parse_metadata_json
@ -97,8 +97,7 @@ class TiDBVector(BaseVector):
if redis_client.get(collection_exist_cache_key): if redis_client.get(collection_exist_cache_key):
return return
tidb_dist_func = self._get_distance_func() tidb_dist_func = self._get_distance_func()
with Session(self._engine) as session: with sessionmaker(bind=self._engine).begin() as session:
session.begin()
create_statement = sql_text(f""" create_statement = sql_text(f"""
CREATE TABLE IF NOT EXISTS {self._collection_name} ( CREATE TABLE IF NOT EXISTS {self._collection_name} (
id CHAR(36) PRIMARY KEY, id CHAR(36) PRIMARY KEY,
@ -115,7 +114,6 @@ class TiDBVector(BaseVector):
); );
""") """)
session.execute(create_statement) session.execute(create_statement)
session.commit()
redis_client.set(collection_exist_cache_key, 1, ex=3600) redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
@ -238,9 +236,8 @@ class TiDBVector(BaseVector):
return [] return []
def delete(self): def delete(self):
with Session(self._engine) as session: with sessionmaker(bind=self._engine).begin() as session:
session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};""")) session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))
session.commit()
def _get_distance_func(self) -> str: def _get_distance_func(self) -> str:
match self._distance_func: match self._distance_func:

View File

@ -3,8 +3,7 @@
import logging import logging
import re import re
import uuid import uuid
from collections.abc import Mapping from typing import Any, TypedDict, cast
from typing import Any, cast
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -55,6 +54,12 @@ from services.summary_index_service import SummaryIndexService
_file_access_controller = DatabaseFileAccessController() _file_access_controller = DatabaseFileAccessController()
class ParagraphFormatPreviewDict(TypedDict):
chunk_structure: str
preview: list[dict[str, Any]]
total_segments: int
class ParagraphIndexProcessor(BaseIndexProcessor): class ParagraphIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract( text_docs = ExtractProcessor.extract(
@ -266,16 +271,17 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
keyword = Keyword(dataset) keyword = Keyword(dataset)
keyword.add_texts(documents) keyword.add_texts(documents)
def format_preview(self, chunks: Any) -> Mapping[str, Any]: def format_preview(self, chunks: Any) -> ParagraphFormatPreviewDict:
if isinstance(chunks, list): if isinstance(chunks, list):
preview = [] preview = []
for content in chunks: for content in chunks:
preview.append({"content": content}) preview.append({"content": content})
return { result: ParagraphFormatPreviewDict = {
"chunk_structure": IndexStructureType.PARAGRAPH_INDEX, "chunk_structure": IndexStructureType.PARAGRAPH_INDEX,
"preview": preview, "preview": preview,
"total_segments": len(chunks), "total_segments": len(chunks),
} }
return result
else: else:
raise ValueError("Chunks is not a list") raise ValueError("Chunks is not a list")

View File

@ -3,8 +3,7 @@
import json import json
import logging import logging
import uuid import uuid
from collections.abc import Mapping from typing import Any, TypedDict
from typing import Any
from sqlalchemy import delete, select from sqlalchemy import delete, select
@ -36,6 +35,13 @@ from services.summary_index_service import SummaryIndexService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ParentChildFormatPreviewDict(TypedDict):
chunk_structure: str
parent_mode: str
preview: list[dict[str, Any]]
total_segments: int
class ParentChildIndexProcessor(BaseIndexProcessor): class ParentChildIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract( text_docs = ExtractProcessor.extract(
@ -351,17 +357,18 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
if all_multimodal_documents and dataset.is_multimodal: if all_multimodal_documents and dataset.is_multimodal:
vector.create_multimodal(all_multimodal_documents) vector.create_multimodal(all_multimodal_documents)
def format_preview(self, chunks: Any) -> Mapping[str, Any]: def format_preview(self, chunks: Any) -> ParentChildFormatPreviewDict:
parent_childs = ParentChildStructureChunk.model_validate(chunks) parent_childs = ParentChildStructureChunk.model_validate(chunks)
preview = [] preview = []
for parent_child in parent_childs.parent_child_chunks: for parent_child in parent_childs.parent_child_chunks:
preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents}) preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents})
return { result: ParentChildFormatPreviewDict = {
"chunk_structure": IndexStructureType.PARENT_CHILD_INDEX, "chunk_structure": IndexStructureType.PARENT_CHILD_INDEX,
"parent_mode": parent_childs.parent_mode, "parent_mode": parent_childs.parent_mode,
"preview": preview, "preview": preview,
"total_segments": len(parent_childs.parent_child_chunks), "total_segments": len(parent_childs.parent_child_chunks),
} }
return result
def generate_summary_preview( def generate_summary_preview(
self, self,

View File

@ -4,8 +4,7 @@ import logging
import re import re
import threading import threading
import uuid import uuid
from collections.abc import Mapping from typing import Any, TypedDict
from typing import Any
import pandas as pd import pandas as pd
from flask import Flask, current_app from flask import Flask, current_app
@ -36,6 +35,12 @@ from services.summary_index_service import SummaryIndexService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class QAFormatPreviewDict(TypedDict):
chunk_structure: str
qa_preview: list[dict[str, Any]]
total_segments: int
class QAIndexProcessor(BaseIndexProcessor): class QAIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract( text_docs = ExtractProcessor.extract(
@ -230,16 +235,17 @@ class QAIndexProcessor(BaseIndexProcessor):
else: else:
raise ValueError("Indexing technique must be high quality.") raise ValueError("Indexing technique must be high quality.")
def format_preview(self, chunks: Any) -> Mapping[str, Any]: def format_preview(self, chunks: Any) -> QAFormatPreviewDict:
qa_chunks = QAStructureChunk.model_validate(chunks) qa_chunks = QAStructureChunk.model_validate(chunks)
preview = [] preview = []
for qa_chunk in qa_chunks.qa_chunks: for qa_chunk in qa_chunks.qa_chunks:
preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer}) preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer})
return { result: QAFormatPreviewDict = {
"chunk_structure": IndexStructureType.QA_INDEX, "chunk_structure": IndexStructureType.QA_INDEX,
"qa_preview": preview, "qa_preview": preview,
"total_segments": len(qa_chunks.qa_chunks), "total_segments": len(qa_chunks.qa_chunks),
} }
return result
def generate_summary_preview( def generate_summary_preview(
self, self,

View File

@ -1,7 +1,7 @@
import base64 import base64
from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.entities.rerank_entities import RerankResult from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from core.model_manager import ModelInstance, ModelManager from core.model_manager import ModelInstance, ModelManager
from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.doc_type import DocType
@ -123,7 +123,7 @@ class RerankModelRunner(BaseRerankRunner):
:param query_type: query type :param query_type: query type
:return: rerank result :return: rerank result
""" """
docs = [] docs: list[MultimodalRerankInput] = []
doc_ids = set() doc_ids = set()
unique_documents = [] unique_documents = []
for document in documents: for document in documents:
@ -138,26 +138,28 @@ class RerankModelRunner(BaseRerankRunner):
if upload_file: if upload_file:
blob = storage.load_once(upload_file.key) blob = storage.load_once(upload_file.key)
document_file_base64 = base64.b64encode(blob).decode() document_file_base64 = base64.b64encode(blob).decode()
document_file_dict = { docs.append(
"content": document_file_base64, MultimodalRerankInput(
"content_type": document.metadata["doc_type"], content=document_file_base64,
} content_type=document.metadata["doc_type"],
docs.append(document_file_dict) )
)
else: else:
document_text_dict = { docs.append(
"content": document.page_content, MultimodalRerankInput(
"content_type": document.metadata.get("doc_type") or DocType.TEXT, content=document.page_content,
} content_type=document.metadata.get("doc_type") or DocType.TEXT,
docs.append(document_text_dict) )
)
doc_ids.add(document.metadata["doc_id"]) doc_ids.add(document.metadata["doc_id"])
unique_documents.append(document) unique_documents.append(document)
elif document.provider == "external": elif document.provider == "external":
if document not in unique_documents: if document not in unique_documents:
docs.append( docs.append(
{ MultimodalRerankInput(
"content": document.page_content, content=document.page_content,
"content_type": document.metadata.get("doc_type") or DocType.TEXT, content_type=document.metadata.get("doc_type") or DocType.TEXT,
} )
) )
unique_documents.append(document) unique_documents.append(document)
@ -171,12 +173,12 @@ class RerankModelRunner(BaseRerankRunner):
if upload_file: if upload_file:
blob = storage.load_once(upload_file.key) blob = storage.load_once(upload_file.key)
file_query = base64.b64encode(blob).decode() file_query = base64.b64encode(blob).decode()
file_query_dict = { file_query_input = MultimodalRerankInput(
"content": file_query, content=file_query,
"content_type": DocType.IMAGE, content_type=DocType.IMAGE,
} )
rerank_result = self.rerank_model_instance.invoke_multimodal_rerank( rerank_result = self.rerank_model_instance.invoke_multimodal_rerank(
query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n query=file_query_input, docs=docs, score_threshold=score_threshold, top_n=top_n
) )
return rerank_result, unique_documents return rerank_result, unique_documents
else: else:

View File

@ -6,7 +6,6 @@ import os
import time import time
from collections.abc import Generator from collections.abc import Generator
from mimetypes import guess_extension, guess_type from mimetypes import guess_extension, guess_type
from typing import Union
from uuid import uuid4 from uuid import uuid4
import httpx import httpx
@ -158,7 +157,7 @@ class ToolFileManager:
return tool_file return tool_file
def get_file_binary(self, id: str) -> Union[tuple[bytes, str], None]: def get_file_binary(self, id: str) -> tuple[bytes, str] | None:
""" """
get file binary get file binary
@ -176,7 +175,7 @@ class ToolFileManager:
return blob, tool_file.mimetype return blob, tool_file.mimetype
def get_file_binary_by_message_file_id(self, id: str) -> Union[tuple[bytes, str], None]: def get_file_binary_by_message_file_id(self, id: str) -> tuple[bytes, str] | None:
""" """
get file binary get file binary

View File

@ -5,7 +5,7 @@ import time
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping
from os import listdir, path from os import listdir, path
from threading import Lock from threading import Lock
from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, Union, cast from typing import TYPE_CHECKING, Any, Literal, Protocol, cast
import sqlalchemy as sa import sqlalchemy as sa
from graphon.runtime import VariablePool from graphon.runtime import VariablePool
@ -100,7 +100,7 @@ class ToolManager:
_builtin_provider_lock = Lock() _builtin_provider_lock = Lock()
_hardcoded_providers: dict[str, BuiltinToolProviderController] = {} _hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
_builtin_providers_loaded = False _builtin_providers_loaded = False
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {} _builtin_tools_labels: dict[str, I18nObject | None] = {}
@classmethod @classmethod
def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController: def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
@ -190,7 +190,7 @@ class ToolManager:
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
credential_id: str | None = None, credential_id: str | None = None,
) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]: ) -> BuiltinTool | PluginTool | ApiTool | WorkflowTool | MCPTool:
""" """
get the tool runtime get the tool runtime
@ -398,7 +398,7 @@ class ToolManager:
agent_tool: AgentToolEntity, agent_tool: AgentToolEntity,
user_id: str | None = None, user_id: str | None = None,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
variable_pool: Optional["VariablePool"] = None, variable_pool: "VariablePool | None" = None,
) -> Tool: ) -> Tool:
""" """
get the agent tool runtime get the agent tool runtime
@ -442,7 +442,7 @@ class ToolManager:
workflow_tool: WorkflowToolRuntimeSpec, workflow_tool: WorkflowToolRuntimeSpec,
user_id: str | None = None, user_id: str | None = None,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
variable_pool: Optional["VariablePool"] = None, variable_pool: "VariablePool | None" = None,
) -> Tool: ) -> Tool:
""" """
get the workflow tool runtime get the workflow tool runtime
@ -634,7 +634,7 @@ class ToolManager:
cls._builtin_providers_loaded = False cls._builtin_providers_loaded = False
@classmethod @classmethod
def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]: def get_tool_label(cls, tool_name: str) -> I18nObject | None:
""" """
get the tool label get the tool label
@ -993,7 +993,7 @@ class ToolManager:
return {"background": "#252525", "content": "\ud83d\ude01"} return {"background": "#252525", "content": "\ud83d\ude01"}
@classmethod @classmethod
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | dict[str, str] | str: def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | str:
try: try:
with Session(db.engine) as session: with Session(db.engine) as session:
mcp_service = MCPToolManageService(session=session) mcp_service = MCPToolManageService(session=session)
@ -1001,7 +1001,7 @@ class ToolManager:
mcp_provider = mcp_service.get_provider_entity( mcp_provider = mcp_service.get_provider_entity(
provider_id=provider_id, tenant_id=tenant_id, by_server_id=True provider_id=provider_id, tenant_id=tenant_id, by_server_id=True
) )
return mcp_provider.provider_icon return cast(EmojiIconDict | str, mcp_provider.provider_icon)
except ValueError: except ValueError:
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found") raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
except Exception: except Exception:
@ -1013,7 +1013,7 @@ class ToolManager:
tenant_id: str, tenant_id: str,
provider_type: ToolProviderType, provider_type: ToolProviderType,
provider_id: str, provider_id: str,
) -> str | EmojiIconDict | dict[str, str]: ) -> str | EmojiIconDict:
""" """
get the tool icon get the tool icon
@ -1052,7 +1052,7 @@ class ToolManager:
def _convert_tool_parameters_type( def _convert_tool_parameters_type(
cls, cls,
parameters: list[ToolParameter], parameters: list[ToolParameter],
variable_pool: Optional["VariablePool"], variable_pool: "VariablePool | None",
tool_configurations: Mapping[str, Any], tool_configurations: Mapping[str, Any],
typ: Literal["agent", "workflow", "tool"] = "workflow", typ: Literal["agent", "workflow", "tool"] = "workflow",
) -> dict[str, Any]: ) -> dict[str, Any]:

View File

@ -118,7 +118,8 @@ class ToolFileMessageTransformer:
if not isinstance(message.message, ToolInvokeMessage.BlobMessage): if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
raise ValueError("unexpected message type") raise ValueError("unexpected message type")
assert isinstance(message.message.blob, bytes) if not isinstance(message.message.blob, bytes):
raise TypeError(f"Expected blob to be bytes, got {type(message.message.blob).__name__}")
tool_file_manager = ToolFileManager() tool_file_manager = ToolFileManager()
tool_file = tool_file_manager.create_file_by_raw( tool_file = tool_file_manager.create_file_by_raw(
user_id=user_id, user_id=user_id,

View File

@ -14,6 +14,7 @@ from redis.cluster import ClusterNode, RedisCluster
from redis.connection import Connection, SSLConnection from redis.connection import Connection, SSLConnection
from redis.retry import Retry from redis.retry import Retry
from redis.sentinel import Sentinel from redis.sentinel import Sentinel
from typing_extensions import TypedDict
from configs import dify_config from configs import dify_config
from dify_app import DifyApp from dify_app import DifyApp
@ -126,6 +127,35 @@ redis_client: RedisClientWrapper = RedisClientWrapper()
_pubsub_redis_client: redis.Redis | RedisCluster | None = None _pubsub_redis_client: redis.Redis | RedisCluster | None = None
class RedisSSLParamsDict(TypedDict):
ssl_cert_reqs: int
ssl_ca_certs: str | None
ssl_certfile: str | None
ssl_keyfile: str | None
class RedisHealthParamsDict(TypedDict):
retry: Retry
socket_timeout: float | None
socket_connect_timeout: float | None
health_check_interval: int | None
class RedisBaseParamsDict(TypedDict):
username: str | None
password: str | None
db: int
encoding: str
encoding_errors: str
decode_responses: bool
protocol: int
cache_config: CacheConfig | None
retry: Retry
socket_timeout: float | None
socket_connect_timeout: float | None
health_check_interval: int | None
def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]: def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]:
"""Get SSL configuration for Redis connection.""" """Get SSL configuration for Redis connection."""
if not dify_config.REDIS_USE_SSL: if not dify_config.REDIS_USE_SSL:
@ -171,14 +201,14 @@ def _get_retry_policy() -> Retry:
) )
def _get_connection_health_params() -> dict[str, Any]: def _get_connection_health_params() -> RedisHealthParamsDict:
"""Get connection health and retry parameters for standalone and Sentinel Redis clients.""" """Get connection health and retry parameters for standalone and Sentinel Redis clients."""
return { return RedisHealthParamsDict(
"retry": _get_retry_policy(), retry=_get_retry_policy(),
"socket_timeout": dify_config.REDIS_SOCKET_TIMEOUT, socket_timeout=dify_config.REDIS_SOCKET_TIMEOUT,
"socket_connect_timeout": dify_config.REDIS_SOCKET_CONNECT_TIMEOUT, socket_connect_timeout=dify_config.REDIS_SOCKET_CONNECT_TIMEOUT,
"health_check_interval": dify_config.REDIS_HEALTH_CHECK_INTERVAL, health_check_interval=dify_config.REDIS_HEALTH_CHECK_INTERVAL,
} )
def _get_cluster_connection_health_params() -> dict[str, Any]: def _get_cluster_connection_health_params() -> dict[str, Any]:
@ -189,26 +219,26 @@ def _get_cluster_connection_health_params() -> dict[str, Any]:
here. Only ``retry``, ``socket_timeout``, and ``socket_connect_timeout`` here. Only ``retry``, ``socket_timeout``, and ``socket_connect_timeout``
are passed through. are passed through.
""" """
params = _get_connection_health_params() params: dict[str, Any] = dict(_get_connection_health_params())
return {k: v for k, v in params.items() if k != "health_check_interval"} return {k: v for k, v in params.items() if k != "health_check_interval"}
def _get_base_redis_params() -> dict[str, Any]: def _get_base_redis_params() -> RedisBaseParamsDict:
"""Get base Redis connection parameters including retry and health policy.""" """Get base Redis connection parameters including retry and health policy."""
return { return RedisBaseParamsDict(
"username": dify_config.REDIS_USERNAME, username=dify_config.REDIS_USERNAME,
"password": dify_config.REDIS_PASSWORD or None, password=dify_config.REDIS_PASSWORD or None,
"db": dify_config.REDIS_DB, db=dify_config.REDIS_DB,
"encoding": "utf-8", encoding="utf-8",
"encoding_errors": "strict", encoding_errors="strict",
"decode_responses": False, decode_responses=False,
"protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL, protocol=dify_config.REDIS_SERIALIZATION_PROTOCOL,
"cache_config": _get_cache_configuration(), cache_config=_get_cache_configuration(),
**_get_connection_health_params(), **_get_connection_health_params(),
} )
def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]: def _create_sentinel_client(redis_params: RedisBaseParamsDict) -> Union[redis.Redis, RedisCluster]:
"""Create Redis client using Sentinel configuration.""" """Create Redis client using Sentinel configuration."""
if not dify_config.REDIS_SENTINELS: if not dify_config.REDIS_SENTINELS:
raise ValueError("REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True") raise ValueError("REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True")
@ -232,7 +262,8 @@ def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis,
sentinel_kwargs=sentinel_kwargs, sentinel_kwargs=sentinel_kwargs,
) )
master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params) params: dict[str, Any] = {**redis_params}
master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **params)
return master return master
@ -259,18 +290,16 @@ def _create_cluster_client() -> Union[redis.Redis, RedisCluster]:
return cluster return cluster
def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]: def _create_standalone_client(redis_params: RedisBaseParamsDict) -> Union[redis.Redis, RedisCluster]:
"""Create standalone Redis client.""" """Create standalone Redis client."""
connection_class, ssl_kwargs = _get_ssl_configuration() connection_class, ssl_kwargs = _get_ssl_configuration()
params = {**redis_params} params: dict[str, Any] = {
params.update( **redis_params,
{ "host": dify_config.REDIS_HOST,
"host": dify_config.REDIS_HOST, "port": dify_config.REDIS_PORT,
"port": dify_config.REDIS_PORT, "connection_class": connection_class,
"connection_class": connection_class, }
}
)
if dify_config.REDIS_MAX_CONNECTIONS: if dify_config.REDIS_MAX_CONNECTIONS:
params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS
@ -293,8 +322,8 @@ def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> redis.Redis |
kwargs["max_connections"] = max_conns kwargs["max_connections"] = max_conns
return RedisCluster.from_url(pubsub_url, **kwargs) return RedisCluster.from_url(pubsub_url, **kwargs)
health_params = _get_connection_health_params() standalone_health_params: dict[str, Any] = dict(_get_connection_health_params())
kwargs = {**health_params} kwargs = {**standalone_health_params}
if max_conns: if max_conns:
kwargs["max_connections"] = max_conns kwargs["max_connections"] = max_conns
return redis.Redis.from_url(pubsub_url, **kwargs) return redis.Redis.from_url(pubsub_url, **kwargs)

View File

@ -37,12 +37,7 @@ def trace_span[**P, R](handler_class: type[SpanHandler] | None = None) -> Callab
handler = _get_handler_instance(handler_class or SpanHandler) handler = _get_handler_instance(handler_class or SpanHandler)
tracer = get_tracer(__name__) tracer = get_tracer(__name__)
return handler.wrapper( return handler.wrapper(tracer, func, *args, **kwargs)
tracer=tracer,
wrapped=func,
args=args,
kwargs=kwargs,
)
return cast(Callable[P, R], wrapper) return cast(Callable[P, R], wrapper)

View File

@ -1,8 +1,8 @@
import inspect import inspect
from collections.abc import Callable, Mapping from collections.abc import Callable
from typing import Any from typing import Any
from opentelemetry.trace import SpanKind, Status, StatusCode from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer
class SpanHandler: class SpanHandler:
@ -16,9 +16,9 @@ class SpanHandler:
exceptions. Handlers can override the wrapper method to customize behavior. exceptions. Handlers can override the wrapper method to customize behavior.
""" """
_signature_cache: dict[Callable[..., Any], inspect.Signature] = {} _signature_cache: dict[Callable[..., object], inspect.Signature] = {}
def _build_span_name(self, wrapped: Callable[..., Any]) -> str: def _build_span_name[**P, R](self, wrapped: Callable[P, R]) -> str:
""" """
Build the span name from the wrapped function. Build the span name from the wrapped function.
@ -29,11 +29,11 @@ class SpanHandler:
""" """
return f"{wrapped.__module__}.{wrapped.__qualname__}" return f"{wrapped.__module__}.{wrapped.__qualname__}"
def _extract_arguments[T]( def _extract_arguments[**P, R](
self, self,
wrapped: Callable[..., T], wrapped: Callable[P, R],
args: tuple[object, ...], *args: P.args,
kwargs: Mapping[str, object], **kwargs: P.kwargs,
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
""" """
Extract function arguments using inspect.signature. Extract function arguments using inspect.signature.
@ -59,13 +59,13 @@ class SpanHandler:
except Exception: except Exception:
return None return None
def wrapper[T]( def wrapper[**P, R](
self, self,
tracer: Any, tracer: Tracer,
wrapped: Callable[..., T], wrapped: Callable[P, R],
args: tuple[object, ...], *args: P.args,
kwargs: Mapping[str, object], **kwargs: P.kwargs,
) -> T: ) -> R:
""" """
Fully control the wrapper behavior. Fully control the wrapper behavior.

View File

@ -1,8 +1,7 @@
import logging import logging
from collections.abc import Callable, Mapping from collections.abc import Callable
from typing import Any
from opentelemetry.trace import SpanKind, Status, StatusCode from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer
from opentelemetry.util.types import AttributeValue from opentelemetry.util.types import AttributeValue
from extensions.otel.decorators.handler import SpanHandler from extensions.otel.decorators.handler import SpanHandler
@ -15,15 +14,15 @@ logger = logging.getLogger(__name__)
class AppGenerateHandler(SpanHandler): class AppGenerateHandler(SpanHandler):
"""Span handler for ``AppGenerateService.generate``.""" """Span handler for ``AppGenerateService.generate``."""
def wrapper[T]( def wrapper[**P, R](
self, self,
tracer: Any, tracer: Tracer,
wrapped: Callable[..., T], wrapped: Callable[P, R],
args: tuple[object, ...], *args: P.args,
kwargs: Mapping[str, object], **kwargs: P.kwargs,
) -> T: ) -> R:
try: try:
arguments = self._extract_arguments(wrapped, args, kwargs) arguments = self._extract_arguments(wrapped, *args, **kwargs)
if not arguments: if not arguments:
return wrapped(*args, **kwargs) return wrapped(*args, **kwargs)

View File

@ -1,8 +1,7 @@
import logging import logging
from collections.abc import Callable, Mapping from collections.abc import Callable
from typing import Any
from opentelemetry.trace import SpanKind, Status, StatusCode from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer
from opentelemetry.util.types import AttributeValue from opentelemetry.util.types import AttributeValue
from extensions.otel.decorators.handler import SpanHandler from extensions.otel.decorators.handler import SpanHandler
@ -14,15 +13,15 @@ logger = logging.getLogger(__name__)
class WorkflowAppRunnerHandler(SpanHandler): class WorkflowAppRunnerHandler(SpanHandler):
"""Span handler for ``WorkflowAppRunner.run``.""" """Span handler for ``WorkflowAppRunner.run``."""
def wrapper( def wrapper[**P, R](
self, self,
tracer: Any, tracer: Tracer,
wrapped: Callable[..., Any], wrapped: Callable[P, R],
args: tuple[Any, ...], *args: P.args,
kwargs: Mapping[str, Any], **kwargs: P.kwargs,
) -> Any: ) -> R:
try: try:
arguments = self._extract_arguments(wrapped, args, kwargs) arguments = self._extract_arguments(wrapped, *args, **kwargs)
if not arguments: if not arguments:
return wrapped(*args, **kwargs) return wrapped(*args, **kwargs)

View File

@ -14,9 +14,15 @@ from __future__ import annotations
import logging import logging
import threading import threading
from typing import Any from typing import TYPE_CHECKING, Any
import redis
from redis.cluster import RedisCluster
from redis.exceptions import LockNotOwnedError, RedisError from redis.exceptions import LockNotOwnedError, RedisError
from redis.lock import Lock
if TYPE_CHECKING:
from extensions.ext_redis import RedisClientWrapper
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -38,21 +44,21 @@ class DbMigrationAutoRenewLock:
primary error/exit code. primary error/exit code.
""" """
_redis_client: Any _redis_client: redis.Redis | RedisCluster | RedisClientWrapper
_name: str _name: str
_ttl_seconds: float _ttl_seconds: float
_renew_interval_seconds: float _renew_interval_seconds: float
_log_context: str | None _log_context: str | None
_logger: logging.Logger _logger: logging.Logger
_lock: Any _lock: Lock | None
_stop_event: threading.Event | None _stop_event: threading.Event | None
_thread: threading.Thread | None _thread: threading.Thread | None
_acquired: bool _acquired: bool
def __init__( def __init__(
self, self,
redis_client: Any, redis_client: redis.Redis | RedisCluster | RedisClientWrapper,
name: str, name: str,
ttl_seconds: float = 60, ttl_seconds: float = 60,
renew_interval_seconds: float | None = None, renew_interval_seconds: float | None = None,
@ -127,7 +133,7 @@ class DbMigrationAutoRenewLock:
) )
self._thread.start() self._thread.start()
def _heartbeat_loop(self, lock: Any, stop_event: threading.Event) -> None: def _heartbeat_loop(self, lock: Lock, stop_event: threading.Event) -> None:
while not stop_event.wait(self._renew_interval_seconds): while not stop_event.wait(self._renew_interval_seconds):
try: try:
lock.reacquire() lock.reacquire()

View File

@ -10,7 +10,7 @@ import uuid
from collections.abc import Callable, Generator, Mapping from collections.abc import Callable, Generator, Mapping
from datetime import datetime from datetime import datetime
from hashlib import sha256 from hashlib import sha256
from typing import TYPE_CHECKING, Annotated, Any, Optional, Protocol, Union, cast from typing import TYPE_CHECKING, Annotated, Any, Protocol, cast
from uuid import UUID from uuid import UUID
from zoneinfo import available_timezones from zoneinfo import available_timezones
@ -81,7 +81,7 @@ def escape_like_pattern(pattern: str) -> str:
return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None: def extract_tenant_id(user: "Account | EndUser") -> str | None:
""" """
Extract tenant_id from Account or EndUser object. Extract tenant_id from Account or EndUser object.
@ -164,7 +164,10 @@ def email(email):
EmailStr = Annotated[str, AfterValidator(email)] EmailStr = Annotated[str, AfterValidator(email)]
def uuid_value(value: Any) -> str: def uuid_value(value: str | UUID) -> str:
if isinstance(value, UUID):
return str(value)
if value == "": if value == "":
return str(value) return str(value)
@ -405,7 +408,7 @@ class TokenManager:
def generate_token( def generate_token(
cls, cls,
token_type: str, token_type: str,
account: Optional["Account"] = None, account: "Account | None" = None,
email: str | None = None, email: str | None = None,
additional_data: dict | None = None, additional_data: dict | None = None,
) -> str: ) -> str:
@ -465,9 +468,7 @@ class TokenManager:
return current_token return current_token
@classmethod @classmethod
def _set_current_token_for_account( def _set_current_token_for_account(cls, account_id: str, token: str, token_type: str, expiry_minutes: int | float):
cls, account_id: str, token: str, token_type: str, expiry_minutes: Union[int, float]
):
key = cls._get_account_token_key(account_id, token_type) key = cls._get_account_token_key(account_id, token_type)
expiry_seconds = int(expiry_minutes * 60) expiry_seconds = int(expiry_minutes * 60)
redis_client.setex(key, expiry_seconds, token) redis_client.setex(key, expiry_seconds, token)

View File

@ -0,0 +1,145 @@
"""Helpers for generating type-coverage summaries from pyrefly report output."""
from __future__ import annotations
import json
import sys
from pathlib import Path
from typing import TypedDict
class CoverageSummary(TypedDict):
    """Shape of the ``summary`` section emitted by ``pyrefly report``."""

    n_modules: int
    n_typable: int
    n_typed: int
    n_any: int
    n_untyped: int
    coverage: float
    strict_coverage: float


# Keys that must all be present (with numeric values) for a summary to be usable.
_REQUIRED_KEYS = frozenset(CoverageSummary.__annotations__)

# Fallback returned whenever the report is missing, malformed, or incomplete.
_EMPTY_SUMMARY: CoverageSummary = {
    "n_modules": 0,
    "n_typable": 0,
    "n_typed": 0,
    "n_any": 0,
    "n_untyped": 0,
    "coverage": 0.0,
    "strict_coverage": 0.0,
}


def parse_summary(report_json: str) -> CoverageSummary:
    """Extract the summary section from ``pyrefly report`` JSON output.

    Returns an empty summary when *report_json* is empty, malformed, not a
    JSON object, or missing (or mistyping) any required key, so that the CI
    workflow can degrade gracefully instead of crashing.
    """
    if not report_json or not report_json.strip():
        return _EMPTY_SUMMARY.copy()
    try:
        data = json.loads(report_json)
    except json.JSONDecodeError:
        return _EMPTY_SUMMARY.copy()
    # json.loads may yield any JSON value (list, number, string, ...);
    # only an object can carry a "summary" mapping. Without this check,
    # data.get() raises AttributeError on e.g. a top-level array.
    if not isinstance(data, dict):
        return _EMPTY_SUMMARY.copy()
    summary = data.get("summary")
    if not isinstance(summary, dict) or not _REQUIRED_KEYS.issubset(summary):
        return _EMPTY_SUMMARY.copy()
    # Guard against well-formed JSON carrying wrong value types (e.g. strings),
    # which would otherwise crash the Markdown formatters' numeric format specs.
    if not all(isinstance(summary[key], (int, float)) for key in _REQUIRED_KEYS):
        return _EMPTY_SUMMARY.copy()
    return {
        "n_modules": summary["n_modules"],
        "n_typable": summary["n_typable"],
        "n_typed": summary["n_typed"],
        "n_any": summary["n_any"],
        "n_untyped": summary["n_untyped"],
        "coverage": summary["coverage"],
        "strict_coverage": summary["strict_coverage"],
    }
def format_summary_markdown(summary: CoverageSummary) -> str:
    """Render one coverage summary as a two-column Markdown table.

    Counts are thousands-separated; percentages are shown with two decimals.
    The returned string has no trailing newline.
    """
    rows = [
        "| Metric | Value |",
        "| --- | ---: |",
        f"| Modules | {summary['n_modules']} |",
        f"| Typable symbols | {summary['n_typable']:,} |",
        f"| Typed symbols | {summary['n_typed']:,} |",
        f"| Untyped symbols | {summary['n_untyped']:,} |",
        f"| Any symbols | {summary['n_any']:,} |",
        f"| **Type coverage** | **{summary['coverage']:.2f}%** |",
        f"| Strict coverage | {summary['strict_coverage']:.2f}% |",
    ]
    return "\n".join(rows)
def format_comparison_markdown(
    base: CoverageSummary,
    pr: CoverageSummary,
) -> str:
    """Render a base-vs-PR coverage comparison as a Markdown table.

    Deltas are PR minus base; positive deltas carry an explicit ``+`` prefix
    (negatives already render with ``-``). No trailing newline.
    """

    def signed(value: float, fmt: str = ".2f") -> str:
        # Only positive values need an explicit sign.
        prefix = "+" if value > 0 else ""
        return f"{prefix}{value:{fmt}}"

    delta_coverage = signed(pr["coverage"] - base["coverage"])
    delta_strict = signed(pr["strict_coverage"] - base["strict_coverage"])
    delta_typed = signed(pr["n_typed"] - base["n_typed"], ",")
    delta_untyped = signed(pr["n_untyped"] - base["n_untyped"], ",")
    delta_modules = signed(pr["n_modules"] - base["n_modules"], ",")

    table = [
        "| Metric | Base | PR | Delta |",
        "| --- | ---: | ---: | ---: |",
        f"| **Type coverage** | {base['coverage']:.2f}% | {pr['coverage']:.2f}% | {delta_coverage}% |",
        f"| Strict coverage | {base['strict_coverage']:.2f}% | {pr['strict_coverage']:.2f}% | {delta_strict}% |",
        f"| Typed symbols | {base['n_typed']:,} | {pr['n_typed']:,} | {delta_typed} |",
        f"| Untyped symbols | {base['n_untyped']:,} | {pr['n_untyped']:,} | {delta_untyped} |",
        f"| Modules | {base['n_modules']} | {pr['n_modules']} | {delta_modules} |",
    ]
    return "\n".join(table)
def main() -> int:
    """Read pyrefly report JSON from stdin and print a Markdown summary.

    Accepts an optional ``--base <file>`` argument. When provided, the output
    includes a base-vs-PR comparison table instead of a single-run summary.

    Returns a process exit code: 0 on success, 1 on bad arguments.
    """
    argv = sys.argv[1:]
    base_file: str | None = None
    if "--base" in argv:
        flag_pos = argv.index("--base")
        if flag_pos + 1 >= len(argv):
            sys.stderr.write("error: --base requires a file path\n")
            return 1
        base_file = argv[flag_pos + 1]

    pr_summary = parse_summary(sys.stdin.read())

    if base_file is None:
        markdown = format_summary_markdown(pr_summary)
    else:
        base_path = Path(base_file)
        # A missing base report degrades to an all-zero baseline rather than failing.
        base_text = base_path.read_text() if base_path.exists() else ""
        markdown = format_comparison_markdown(parse_summary(base_text), pr_summary)

    sys.stdout.write(markdown + "\n")
    return 0


if __name__ == "__main__":
    sys.exit(main())

View File

@ -24,6 +24,8 @@ class TypeBase(MappedAsDataclass, DeclarativeBase):
class DefaultFieldsMixin: class DefaultFieldsMixin:
"""Mixin for models that inherit from Base (non-dataclass)."""
id: Mapped[str] = mapped_column( id: Mapped[str] = mapped_column(
StringUUID, StringUUID,
primary_key=True, primary_key=True,
@ -53,6 +55,42 @@ class DefaultFieldsMixin:
return f"<{self.__class__.__name__}(id={self.id})>" return f"<{self.__class__.__name__}(id={self.id})>"
class DefaultFieldsDCMixin(MappedAsDataclass):
"""Mixin for models that inherit from TypeBase (MappedAsDataclass)."""
__abstract__ = True
id: Mapped[str] = mapped_column(
StringUUID,
primary_key=True,
insert_default=lambda: str(uuidv7()),
default_factory=lambda: str(uuidv7()),
init=False,
)
created_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
insert_default=naive_utc_now,
default_factory=naive_utc_now,
init=False,
server_default=func.current_timestamp(),
)
updated_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
insert_default=naive_utc_now,
default_factory=naive_utc_now,
init=False,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}(id={self.id})>"
def gen_uuidv4_string() -> str: def gen_uuidv4_string() -> str:
"""gen_uuidv4_string generate a UUIDv4 string. """gen_uuidv4_string generate a UUIDv4 string.

View File

@ -913,11 +913,7 @@ class TrialApp(TypeBase):
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
sa.DateTime, sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
nullable=False,
insert_default=func.current_timestamp(),
server_default=func.current_timestamp(),
init=False,
) )
trial_limit: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=3) trial_limit: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=3)
@ -941,11 +937,7 @@ class AccountTrialAppRecord(TypeBase):
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
sa.DateTime, sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
nullable=False,
insert_default=func.current_timestamp(),
server_default=func.current_timestamp(),
init=False,
) )
@property @property

View File

@ -4,7 +4,7 @@ import logging
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from datetime import datetime from datetime import datetime
from enum import StrEnum from enum import StrEnum
from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union, cast from typing import TYPE_CHECKING, Any, Optional, TypedDict, cast
from uuid import uuid4 from uuid import uuid4
import sqlalchemy as sa import sqlalchemy as sa
@ -121,7 +121,7 @@ class WorkflowType(StrEnum):
raise ValueError(f"invalid workflow type value {value}") raise ValueError(f"invalid workflow type value {value}")
@classmethod @classmethod
def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType": def from_app_mode(cls, app_mode: "str | AppMode") -> "WorkflowType":
""" """
Get workflow type from app mode. Get workflow type from app mode.
@ -1051,7 +1051,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
) )
return extras return extras
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]: def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> "WorkflowNodeExecutionOffload | None":
return next(iter([i for i in self.offload_data if i.type_ == type_]), None) return next(iter([i for i in self.offload_data if i.type_ == type_]), None)
@property @property

View File

@ -9,7 +9,8 @@ from typing import Any, TypedDict, cast
from pydantic import BaseModel, TypeAdapter from pydantic import BaseModel, TypeAdapter
from sqlalchemy import delete, func, select, update from sqlalchemy import delete, func, select, update
from sqlalchemy.orm import Session, sessionmaker
from core.db.session_factory import session_factory
class InvitationData(TypedDict): class InvitationData(TypedDict):
@ -800,19 +801,19 @@ class AccountService:
return token return token
@staticmethod @staticmethod
def get_account_by_email_with_case_fallback(email: str, session: Session | None = None) -> Account | None: def get_account_by_email_with_case_fallback(email: str) -> Account | None:
""" """
Retrieve an account by email and fall back to the lowercase email if the original lookup fails. Retrieve an account by email and fall back to the lowercase email if the original lookup fails.
This keeps backward compatibility for older records that stored uppercase emails while the This keeps backward compatibility for older records that stored uppercase emails while the
rest of the system gradually normalizes new inputs. rest of the system gradually normalizes new inputs.
""" """
query_session = session or db.session with session_factory.create_session() as session:
account = query_session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
if account or email == email.lower(): if account or email == email.lower():
return account return account
return query_session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none() return session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()
@classmethod @classmethod
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None: def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
@ -1516,8 +1517,7 @@ class RegisterService:
check_workspace_member_invite_permission(tenant.id) check_workspace_member_invite_permission(tenant.id)
with sessionmaker(db.engine, expire_on_commit=False).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(email)
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if not account: if not account:
TenantService.check_member_permission(tenant, inviter, None, "add") TenantService.check_member_permission(tenant, inviter, None, "add")

View File

@ -4,7 +4,7 @@ import logging
import threading import threading
import uuid import uuid
from collections.abc import Callable, Generator, Mapping from collections.abc import Callable, Generator, Mapping
from typing import TYPE_CHECKING, Any, Union from typing import TYPE_CHECKING, Any
from configs import dify_config from configs import dify_config
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
@ -88,7 +88,7 @@ class AppGenerateService:
def generate( def generate(
cls, cls,
app_model: App, app_model: App,
user: Union[Account, EndUser], user: Account | EndUser,
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool = True, streaming: bool = True,
@ -356,11 +356,11 @@ class AppGenerateService:
def generate_more_like_this( def generate_more_like_this(
cls, cls,
app_model: App, app_model: App,
user: Union[Account, EndUser], user: Account | EndUser,
message_id: str, message_id: str,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool = True, streaming: bool = True,
) -> Union[Mapping, Generator]: ) -> Mapping | Generator:
""" """
Generate more like this Generate more like this
:param app_model: app model :param app_model: app model

View File

@ -7,7 +7,7 @@ with support for different subscription tiers, rate limiting, and execution trac
import json import json
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any, Union from typing import Any
from celery.result import AsyncResult from celery.result import AsyncResult
from sqlalchemy import select from sqlalchemy import select
@ -50,7 +50,7 @@ class AsyncWorkflowService:
@classmethod @classmethod
def trigger_workflow_async( def trigger_workflow_async(
cls, session: Session, user: Union[Account, EndUser], trigger_data: TriggerData cls, session: Session, user: Account | EndUser, trigger_data: TriggerData
) -> AsyncTriggerResponse: ) -> AsyncTriggerResponse:
""" """
Universal entry point for async workflow execution - THIS METHOD WILL NOT BLOCK Universal entry point for async workflow execution - THIS METHOD WILL NOT BLOCK
@ -177,7 +177,7 @@ class AsyncWorkflowService:
@classmethod @classmethod
def reinvoke_trigger( def reinvoke_trigger(
cls, session: Session, user: Union[Account, EndUser], workflow_trigger_log_id: str cls, session: Session, user: Account | EndUser, workflow_trigger_log_id: str
) -> AsyncTriggerResponse: ) -> AsyncTriggerResponse:
""" """
Re-invoke a previously failed or rate-limited trigger - THIS METHOD WILL NOT BLOCK Re-invoke a previously failed or rate-limited trigger - THIS METHOD WILL NOT BLOCK

View File

@ -2822,6 +2822,10 @@ class DocumentService:
knowledge_config.process_rule.rules.pre_processing_rules = list(unique_pre_processing_rule_dicts.values()) knowledge_config.process_rule.rules.pre_processing_rules = list(unique_pre_processing_rule_dicts.values())
if knowledge_config.process_rule.mode == ProcessRuleMode.HIERARCHICAL:
if not knowledge_config.process_rule.rules.parent_mode:
knowledge_config.process_rule.rules.parent_mode = "paragraph"
if not knowledge_config.process_rule.rules.segmentation: if not knowledge_config.process_rule.rules.segmentation:
raise ValueError("Process rule segmentation is required") raise ValueError("Process rule segmentation is required")

View File

@ -1,6 +1,6 @@
import json import json
from copy import deepcopy from copy import deepcopy
from typing import Any, Union, cast from typing import Any, cast
from urllib.parse import urlparse from urllib.parse import urlparse
import httpx import httpx
@ -148,18 +148,23 @@ class ExternalDatasetService:
db.session.commit() db.session.commit()
@staticmethod @staticmethod
def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]: def external_knowledge_api_use_check(external_knowledge_api_id: str, tenant_id: str) -> tuple[bool, int]:
"""
Return usage for an external knowledge API within a single tenant.
The caller already scopes access by tenant, so this query must do the
same; otherwise the endpoint becomes a cross-tenant UUID oracle.
"""
count = ( count = (
db.session.scalar( db.session.scalar(
select(func.count(ExternalKnowledgeBindings.id)).where( select(func.count(ExternalKnowledgeBindings.id)).where(
ExternalKnowledgeBindings.external_knowledge_api_id == external_knowledge_api_id ExternalKnowledgeBindings.external_knowledge_api_id == external_knowledge_api_id,
ExternalKnowledgeBindings.tenant_id == tenant_id,
) )
) )
or 0 or 0
) )
if count > 0: return count > 0, count
return True, count
return False, 0
@staticmethod @staticmethod
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings: def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
@ -190,9 +195,7 @@ class ExternalDatasetService:
raise ValueError(f"{parameter.get('name')} is required") raise ValueError(f"{parameter.get('name')} is required")
@staticmethod @staticmethod
def process_external_api( def process_external_api(settings: ExternalKnowledgeApiSetting, files: dict[str, Any] | None) -> httpx.Response:
settings: ExternalKnowledgeApiSetting, files: Union[None, dict[str, Any]]
) -> httpx.Response:
""" """
do http request depending on api bundle do http request depending on api bundle
""" """

View File

@ -5,7 +5,7 @@ import uuid
from collections.abc import Iterator, Sequence from collections.abc import Iterator, Sequence
from contextlib import contextmanager, suppress from contextlib import contextmanager, suppress
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from typing import Literal, Union from typing import Literal
from zipfile import ZIP_DEFLATED, ZipFile from zipfile import ZIP_DEFLATED, ZipFile
from graphon.file import helpers as file_helpers from graphon.file import helpers as file_helpers
@ -52,7 +52,7 @@ class FileService:
filename: str, filename: str,
content: bytes, content: bytes,
mimetype: str, mimetype: str,
user: Union[Account, EndUser], user: Account | EndUser,
source: Literal["datasets"] | None = None, source: Literal["datasets"] | None = None,
source_url: str = "", source_url: str = "",
) -> UploadFile: ) -> UploadFile:

View File

@ -1,6 +1,6 @@
import json import json
import logging import logging
from typing import Any, TypedDict, Union from typing import Any, TypedDict
from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.entities.provider_entities import ( from graphon.model_runtime.entities.provider_entities import (
@ -626,7 +626,7 @@ class ModelLoadBalancingService:
def _get_credential_schema( def _get_credential_schema(
self, provider_configuration: ProviderConfiguration self, provider_configuration: ProviderConfiguration
) -> Union[ModelCredentialSchema, ProviderCredentialSchema]: ) -> ModelCredentialSchema | ProviderCredentialSchema:
"""Get form schemas.""" """Get form schemas."""
if provider_configuration.provider.model_credential_schema: if provider_configuration.provider.model_credential_schema:
return provider_configuration.provider.model_credential_schema return provider_configuration.provider.model_credential_schema

View File

@ -1,14 +1,13 @@
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from extensions.ext_database import db from core.db.session_factory import session_factory
from models.account import TenantPluginPermission from models.account import TenantPluginPermission
class PluginPermissionService: class PluginPermissionService:
@staticmethod @staticmethod
def get_permission(tenant_id: str) -> TenantPluginPermission | None: def get_permission(tenant_id: str) -> TenantPluginPermission | None:
with sessionmaker(bind=db.engine).begin() as session: with session_factory.create_session() as session:
return session.scalar( return session.scalar(
select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1) select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1)
) )
@ -19,7 +18,7 @@ class PluginPermissionService:
install_permission: TenantPluginPermission.InstallPermission, install_permission: TenantPluginPermission.InstallPermission,
debug_permission: TenantPluginPermission.DebugPermission, debug_permission: TenantPluginPermission.DebugPermission,
): ):
with sessionmaker(bind=db.engine).begin() as session: with session_factory.create_session() as session, session.begin():
permission = session.scalar( permission = session.scalar(
select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1) select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1)
) )

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, Union from typing import Any
from configs import dify_config from configs import dify_config
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
@ -17,7 +17,7 @@ class PipelineGenerateService:
def generate( def generate(
cls, cls,
pipeline: Pipeline, pipeline: Pipeline,
user: Union[Account, EndUser], user: Account | EndUser,
args: Mapping[str, Any], args: Mapping[str, Any],
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool = True, streaming: bool = True,

View File

@ -5,7 +5,7 @@ import threading
import time import time
from collections.abc import Callable, Generator, Mapping, Sequence from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any, Union, cast from typing import Any, cast
from uuid import uuid4 from uuid import uuid4
from flask_login import current_user from flask_login import current_user
@ -1387,7 +1387,7 @@ class RagPipelineService:
"uninstalled_recommended_plugins": uninstalled_plugin_list, "uninstalled_recommended_plugins": uninstalled_plugin_list,
} }
def retry_error_document(self, dataset: Dataset, document: Document, user: Union[Account, EndUser]): def retry_error_document(self, dataset: Dataset, document: Document, user: Account | EndUser):
""" """
Retry error document Retry error document
""" """

View File

@ -1,6 +1,6 @@
import logging import logging
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, Union from typing import Any
from pydantic import TypeAdapter, ValidationError from pydantic import TypeAdapter, ValidationError
from yarl import URL from yarl import URL
@ -69,7 +69,7 @@ class ToolTransformService:
return "" return ""
@staticmethod @staticmethod
def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity, PluginDatasourceProviderEntity]): def repack_provider(tenant_id: str, provider: dict | ToolProviderApiEntity | PluginDatasourceProviderEntity):
""" """
repack provider repack provider

View File

@ -7,15 +7,16 @@ with appropriate retry policies and error handling.
import logging import logging
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any from typing import Any, NotRequired
from celery import shared_task from celery import shared_task
from graphon.runtime import GraphRuntimeState from graphon.runtime import GraphRuntimeState
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Session, sessionmaker
from typing_extensions import TypedDict
from configs import dify_config from configs import dify_config
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext
from core.app.layers.timeslice_layer import TimeSliceLayer from core.app.layers.timeslice_layer import TimeSliceLayer
@ -42,6 +43,13 @@ from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue, AsyncWorkf
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class WorkflowGeneratorArgsDict(TypedDict):
inputs: dict[str, Any]
files: list[Any]
_skip_prepare_user_inputs: bool
workflow_id: NotRequired[str]
@shared_task(queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE) @shared_task(queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE)
def execute_workflow_professional(task_data_dict: dict[str, Any]): def execute_workflow_professional(task_data_dict: dict[str, Any]):
"""Execute workflow for professional tier with highest priority""" """Execute workflow for professional tier with highest priority"""
@ -90,15 +98,13 @@ def execute_workflow_sandbox(task_data_dict: dict[str, Any]):
) )
def _build_generator_args(trigger_data: TriggerData) -> dict[str, Any]: def _build_generator_args(trigger_data: TriggerData) -> WorkflowGeneratorArgsDict:
"""Build args passed into WorkflowAppGenerator.generate for Celery executions.""" """Build args passed into WorkflowAppGenerator.generate for Celery executions."""
return {
args: dict[str, Any] = {
"inputs": dict(trigger_data.inputs), "inputs": dict(trigger_data.inputs),
"files": list(trigger_data.files), "files": list(trigger_data.files),
SKIP_PREPARE_USER_INPUTS_KEY: True, "_skip_prepare_user_inputs": True,
} }
return args
def _execute_workflow_common( def _execute_workflow_common(

View File

@ -158,7 +158,10 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase():
second_result.scalar_one_or_none.return_value = expected_account second_result.scalar_one_or_none.return_value = expected_account
mock_session.execute.side_effect = [first_result, second_result] mock_session.execute.side_effect = [first_result, second_result]
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) with patch("services.account_service.session_factory") as mock_factory:
mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False)
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com")
assert result is expected_account assert result is expected_account
assert mock_session.execute.call_count == 2 assert mock_session.execute.call_count == 2

View File

@ -113,12 +113,14 @@ class TestForgotPasswordCheckApi:
class TestForgotPasswordResetApi: class TestForgotPasswordResetApi:
@patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account") @patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.forgot_password.db")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_reset_fetches_account_with_original_email( def test_reset_fetches_account_with_original_email(
self, self,
mock_get_reset_data, mock_get_reset_data,
mock_revoke_token, mock_revoke_token,
mock_db,
mock_get_account, mock_get_account,
mock_update_account, mock_update_account,
app, app,
@ -126,6 +128,7 @@ class TestForgotPasswordResetApi:
mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"} mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"}
mock_account = MagicMock() mock_account = MagicMock()
mock_get_account.return_value = mock_account mock_get_account.return_value = mock_account
mock_db.session.merge.return_value = mock_account
wraps_features = SimpleNamespace(enable_email_password_login=True) wraps_features = SimpleNamespace(enable_email_password_login=True)
with ( with (
@ -161,7 +164,10 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase():
second_result.scalar_one_or_none.return_value = expected_account second_result.scalar_one_or_none.return_value = expected_account
mock_session.execute.side_effect = [first_result, second_result] mock_session.execute.side_effect = [first_result, second_result]
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session) with patch("services.account_service.session_factory") as mock_factory:
mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False)
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com")
assert result is expected_account assert result is expected_account
assert mock_session.execute.call_count == 2 assert mock_session.execute.call_count == 2

View File

@ -437,7 +437,10 @@ class TestAccountGeneration:
second_result.scalar_one_or_none.return_value = expected_account second_result.scalar_one_or_none.return_value = expected_account
mock_session.execute.side_effect = [first_result, second_result] mock_session.execute.side_effect = [first_result, second_result]
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) with patch("services.account_service.session_factory") as mock_factory:
mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False)
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com")
assert result is expected_account assert result is expected_account
assert mock_session.execute.call_count == 2 assert mock_session.execute.call_count == 2

View File

@ -335,10 +335,12 @@ class TestForgotPasswordResetApi:
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.forgot_password.db")
@patch("controllers.console.auth.forgot_password.TenantService.get_join_tenants") @patch("controllers.console.auth.forgot_password.TenantService.get_join_tenants")
def test_reset_password_success( def test_reset_password_success(
self, self,
mock_get_tenants, mock_get_tenants,
mock_db,
mock_get_account, mock_get_account,
mock_revoke_token, mock_revoke_token,
mock_get_data, mock_get_data,
@ -356,6 +358,7 @@ class TestForgotPasswordResetApi:
# Arrange # Arrange
mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"} mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
mock_get_account.return_value = mock_account mock_get_account.return_value = mock_account
mock_db.session.merge.return_value = mock_account
mock_get_tenants.return_value = [MagicMock()] mock_get_tenants.return_value = [MagicMock()]
# Act # Act

View File

@ -37,10 +37,8 @@ class TestForgotPasswordSendEmailApi:
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False) @patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
@patch("controllers.web.forgot_password.extract_remote_ip", return_value="127.0.0.1") @patch("controllers.web.forgot_password.extract_remote_ip", return_value="127.0.0.1")
@patch("controllers.web.forgot_password.sessionmaker")
def test_should_normalize_email_before_sending( def test_should_normalize_email_before_sending(
self, self,
mock_session_cls,
mock_extract_ip, mock_extract_ip,
mock_rate_limit, mock_rate_limit,
mock_get_account, mock_get_account,
@ -50,19 +48,16 @@ class TestForgotPasswordSendEmailApi:
mock_account = MagicMock() mock_account = MagicMock()
mock_get_account.return_value = mock_account mock_get_account.return_value = mock_account
mock_send_mail.return_value = "token-123" mock_send_mail.return_value = "token-123"
mock_session = MagicMock()
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): with app.test_request_context(
with app.test_request_context( "/web/forgot-password",
"/web/forgot-password", method="POST",
method="POST", json={"email": "User@Example.com", "language": "zh-Hans"},
json={"email": "User@Example.com", "language": "zh-Hans"}, ):
): response = ForgotPasswordSendEmailApi().post()
response = ForgotPasswordSendEmailApi().post()
assert response == {"result": "success", "data": "token-123"} assert response == {"result": "success", "data": "token-123"}
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) mock_get_account.assert_called_once_with("User@Example.com")
mock_send_mail.assert_called_once_with(account=mock_account, email="user@example.com", language="zh-Hans") mock_send_mail.assert_called_once_with(account=mock_account, email="user@example.com", language="zh-Hans")
mock_extract_ip.assert_called_once() mock_extract_ip.assert_called_once()
mock_rate_limit.assert_called_once_with("127.0.0.1") mock_rate_limit.assert_called_once_with("127.0.0.1")
@ -153,14 +148,14 @@ class TestForgotPasswordResetApi:
@patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account") @patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account")
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.web.forgot_password.sessionmaker") @patch("controllers.web.forgot_password.db")
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data") @patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
def test_should_fetch_account_with_fallback( def test_should_fetch_account_with_fallback(
self, self,
mock_get_reset_data, mock_get_reset_data,
mock_revoke_token, mock_revoke_token,
mock_session_cls, mock_db,
mock_get_account, mock_get_account,
mock_update_account, mock_update_account,
app, app,
@ -168,29 +163,27 @@ class TestForgotPasswordResetApi:
mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com", "code": "1234"} mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com", "code": "1234"}
mock_account = MagicMock() mock_account = MagicMock()
mock_get_account.return_value = mock_account mock_get_account.return_value = mock_account
mock_session = MagicMock() mock_db.session.merge.return_value = mock_account
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): with app.test_request_context(
with app.test_request_context( "/web/forgot-password/resets",
"/web/forgot-password/resets", method="POST",
method="POST", json={
json={ "token": "token-123",
"token": "token-123", "new_password": "ValidPass123!",
"new_password": "ValidPass123!", "password_confirm": "ValidPass123!",
"password_confirm": "ValidPass123!", },
}, ):
): response = ForgotPasswordResetApi().post()
response = ForgotPasswordResetApi().post()
assert response == {"result": "success"} assert response == {"result": "success"}
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) mock_get_account.assert_called_once_with("User@Example.com")
mock_update_account.assert_called_once() mock_update_account.assert_called_once()
mock_revoke_token.assert_called_once_with("token-123") mock_revoke_token.assert_called_once_with("token-123")
@patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value") @patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value")
@patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef") @patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef")
@patch("controllers.web.forgot_password.sessionmaker") @patch("controllers.web.forgot_password.db")
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data") @patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@ -199,7 +192,7 @@ class TestForgotPasswordResetApi:
mock_get_account, mock_get_account,
mock_get_reset_data, mock_get_reset_data,
mock_revoke_token, mock_revoke_token,
mock_session_cls, mock_db,
mock_token_bytes, mock_token_bytes,
mock_hash_password, mock_hash_password,
app, app,
@ -207,20 +200,18 @@ class TestForgotPasswordResetApi:
mock_get_reset_data.return_value = {"phase": "reset", "email": "user@example.com"} mock_get_reset_data.return_value = {"phase": "reset", "email": "user@example.com"}
account = MagicMock() account = MagicMock()
mock_get_account.return_value = account mock_get_account.return_value = account
mock_session = MagicMock() mock_db.session.merge.return_value = account
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): with app.test_request_context(
with app.test_request_context( "/web/forgot-password/resets",
"/web/forgot-password/resets", method="POST",
method="POST", json={
json={ "token": "reset-token",
"token": "reset-token", "new_password": "StrongPass123!",
"new_password": "StrongPass123!", "password_confirm": "StrongPass123!",
"password_confirm": "StrongPass123!", },
}, ):
): response = ForgotPasswordResetApi().post()
response = ForgotPasswordResetApi().post()
assert response == {"result": "success"} assert response == {"result": "success"}
mock_get_reset_data.assert_called_once_with("reset-token") mock_get_reset_data.assert_called_once_with("reset-token")

View File

@ -1,239 +1,193 @@
from __future__ import annotations
import json import json
from typing import Any, cast from typing import Any, cast
from unittest.mock import ANY, MagicMock, patch from unittest.mock import ANY, MagicMock, patch
from uuid import uuid4
import pytest import pytest
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from core.rag.models.document import Document from core.rag.models.document import Document
from models.dataset import Dataset from models.dataset import Dataset, DatasetQuery
from services.hit_testing_service import HitTestingService from services.hit_testing_service import HitTestingService
class TestHitTestingService: def _create_dataset(db_session: Session, *, provider: str = "vendor", **kwargs: Any) -> Dataset:
"""Test suite for HitTestingService""" tenant_id = str(uuid4())
created_by = str(uuid4())
ds = Dataset(
tenant_id=kwargs.get("tenant_id", tenant_id),
name=kwargs.get("name", "test-dataset"),
created_by=kwargs.get("created_by", created_by),
provider=provider,
)
db_session.add(ds)
db_session.commit()
db_session.refresh(ds)
return ds
# ===== Utility Method Tests =====
class TestHitTestingService:
# ── Utility methods (pure logic, no DB) ────────────────────────────
def test_escape_query_for_search_should_escape_double_quotes(self): def test_escape_query_for_search_should_escape_double_quotes(self):
"""Test that escape_query_for_search escapes double quotes correctly"""
# Arrange
query = 'test "query" with quotes' query = 'test "query" with quotes'
expected = 'test \\"query\\" with quotes'
# Act
result = HitTestingService.escape_query_for_search(query) result = HitTestingService.escape_query_for_search(query)
assert result == 'test \\"query\\" with quotes'
# Assert
assert result == expected
def test_hit_testing_args_check_should_pass_with_valid_query(self): def test_hit_testing_args_check_should_pass_with_valid_query(self):
"""Test that hit_testing_args_check passes with a valid query""" HitTestingService.hit_testing_args_check({"query": "valid query"})
# Arrange
args = {"query": "valid query"}
# Act & Assert (should not raise)
HitTestingService.hit_testing_args_check(args)
def test_hit_testing_args_check_should_pass_with_valid_attachments(self): def test_hit_testing_args_check_should_pass_with_valid_attachments(self):
"""Test that hit_testing_args_check passes with valid attachment_ids""" HitTestingService.hit_testing_args_check({"attachment_ids": ["id1", "id2"]})
# Arrange
args = {"attachment_ids": ["id1", "id2"]}
# Act & Assert (should not raise)
HitTestingService.hit_testing_args_check(args)
def test_hit_testing_args_check_should_raise_error_when_no_query_or_attachments(self): def test_hit_testing_args_check_should_raise_error_when_no_query_or_attachments(self):
"""Test that hit_testing_args_check raises ValueError if both query and attachment_ids are missing""" with pytest.raises(ValueError, match="Query or attachment_ids is required"):
# Arrange HitTestingService.hit_testing_args_check({})
args = {}
# Act & Assert
with pytest.raises(ValueError) as exc_info:
HitTestingService.hit_testing_args_check(args)
assert "Query or attachment_ids is required" in str(exc_info.value)
def test_hit_testing_args_check_should_raise_error_when_query_too_long(self): def test_hit_testing_args_check_should_raise_error_when_query_too_long(self):
"""Test that hit_testing_args_check raises ValueError if query exceeds 250 characters""" with pytest.raises(ValueError, match="Query cannot exceed 250 characters"):
# Arrange HitTestingService.hit_testing_args_check({"query": "a" * 251})
args = {"query": "a" * 251}
# Act & Assert
with pytest.raises(ValueError) as exc_info:
HitTestingService.hit_testing_args_check(args)
assert "Query cannot exceed 250 characters" in str(exc_info.value)
def test_hit_testing_args_check_should_raise_error_when_attachments_not_list(self): def test_hit_testing_args_check_should_raise_error_when_attachments_not_list(self):
"""Test that hit_testing_args_check raises ValueError if attachment_ids is not a list""" with pytest.raises(ValueError, match="Attachment_ids must be a list"):
# Arrange HitTestingService.hit_testing_args_check({"attachment_ids": "not a list"})
args = {"attachment_ids": "not a list"}
# Act & Assert # ── Response formatting ────────────────────────────────────────────
with pytest.raises(ValueError) as exc_info:
HitTestingService.hit_testing_args_check(args)
assert "Attachment_ids must be a list" in str(exc_info.value)
# ===== Response Formatting Tests =====
@patch("core.rag.datasource.retrieval_service.RetrievalService.format_retrieval_documents") @patch("core.rag.datasource.retrieval_service.RetrievalService.format_retrieval_documents")
def test_compact_retrieve_response_should_format_correctly(self, mock_format): def test_compact_retrieve_response_should_format_correctly(self, mock_format):
"""Test that compact_retrieve_response formats the response correctly"""
# Arrange
query = "test query" query = "test query"
mock_doc = MagicMock(spec=Document) mock_doc = MagicMock(spec=Document)
documents = [mock_doc]
mock_record = MagicMock() mock_record = MagicMock()
mock_record.model_dump.return_value = {"content": "formatted content"} mock_record.model_dump.return_value = {"content": "formatted content"}
mock_format.return_value = [mock_record] mock_format.return_value = [mock_record]
# Act result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, [mock_doc]))
result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, documents))
# Assert
assert cast(dict[str, Any], result["query"])["content"] == query assert cast(dict[str, Any], result["query"])["content"] == query
assert len(result["records"]) == 1 assert len(result["records"]) == 1
assert cast(dict[str, Any], result["records"][0])["content"] == "formatted content" assert cast(dict[str, Any], result["records"][0])["content"] == "formatted content"
mock_format.assert_called_once_with(documents) mock_format.assert_called_once_with([mock_doc])
def test_compact_external_retrieve_response_should_return_records_for_external_provider(self): def test_compact_external_retrieve_response_should_return_records_for_external_provider(
"""Test that compact_external_retrieve_response returns records when dataset provider is external""" self, db_session_with_containers: Session
# Arrange ):
dataset = MagicMock(spec=Dataset) dataset = _create_dataset(db_session_with_containers, provider="external")
dataset.provider = "external"
query = "test query"
documents = [ documents = [
{"content": "c1", "title": "t1", "score": 0.9, "metadata": {"m1": "v1"}}, {"content": "c1", "title": "t1", "score": 0.9, "metadata": {"m1": "v1"}},
{"content": "c2", "title": "t2", "score": 0.8, "metadata": {"m2": "v2"}}, {"content": "c2", "title": "t2", "score": 0.8, "metadata": {"m2": "v2"}},
] ]
# Act result = cast(
result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents)) dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, "test query", documents)
)
# Assert assert cast(dict[str, Any], result["query"])["content"] == "test query"
assert cast(dict[str, Any], result["query"])["content"] == query
assert len(result["records"]) == 2 assert len(result["records"]) == 2
assert cast(dict[str, Any], result["records"][0])["content"] == "c1" assert cast(dict[str, Any], result["records"][0])["content"] == "c1"
assert cast(dict[str, Any], result["records"][1])["title"] == "t2" assert cast(dict[str, Any], result["records"][1])["title"] == "t2"
def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider(self): def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider(
"""Test that compact_external_retrieve_response returns empty records for non-external provider""" self, db_session_with_containers: Session
# Arrange ):
dataset = MagicMock(spec=Dataset) dataset = _create_dataset(db_session_with_containers, provider="vendor")
dataset.provider = "not_external"
query = "test query"
documents = [{"content": "c1"}]
# Act result = cast(
result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents)) dict[str, Any],
HitTestingService.compact_external_retrieve_response(dataset, "test query", [{"content": "c1"}]),
)
# Assert assert cast(dict[str, Any], result["query"])["content"] == "test query"
assert cast(dict[str, Any], result["query"])["content"] == query
assert result["records"] == [] assert result["records"] == []
# ===== External Retrieve Tests ===== # ── External retrieve (real DB) ────────────────────────────────────
@patch("core.rag.datasource.retrieval_service.RetrievalService.external_retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService.external_retrieve")
@patch("extensions.ext_database.db.session.add") def test_external_retrieve_should_succeed_for_external_provider(
@patch("extensions.ext_database.db.session.commit") self, mock_ext_retrieve, db_session_with_containers: Session
def test_external_retrieve_should_succeed_for_external_provider(self, mock_commit, mock_add, mock_ext_retrieve): ):
"""Test that external_retrieve successfully retrieves from external provider and commits query""" dataset = _create_dataset(db_session_with_containers, provider="external")
# Arrange account_id = str(uuid4())
dataset = MagicMock(spec=Dataset)
dataset.id = "dataset_id"
dataset.provider = "external"
query = 'test "query"'
account = MagicMock() account = MagicMock()
account.id = "account_id" account.id = account_id
mock_ext_retrieve.return_value = [{"content": "ext content", "score": 1.0}] mock_ext_retrieve.return_value = [{"content": "ext content", "score": 1.0}]
# Act before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
result = cast( result = cast(
dict[str, Any], dict[str, Any],
HitTestingService.external_retrieve( HitTestingService.external_retrieve(
dataset=dataset, dataset=dataset,
query=query, query='test "query"',
account=account, account=account,
external_retrieval_model={"model": "test"}, external_retrieval_model={"model": "test"},
metadata_filtering_conditions={"key": "val"}, metadata_filtering_conditions={"key": "val"},
), ),
) )
# Assert assert cast(dict[str, Any], result["query"])["content"] == 'test "query"'
assert cast(dict[str, Any], result["query"])["content"] == query
assert cast(dict[str, Any], result["records"][0])["content"] == "ext content" assert cast(dict[str, Any], result["records"][0])["content"] == "ext content"
# Verify call to RetrievalService.external_retrieve with escaped query
mock_ext_retrieve.assert_called_once_with( mock_ext_retrieve.assert_called_once_with(
dataset_id="dataset_id", dataset_id=dataset.id,
query='test \\"query\\"', query='test \\"query\\"',
external_retrieval_model={"model": "test"}, external_retrieval_model={"model": "test"},
metadata_filtering_conditions={"key": "val"}, metadata_filtering_conditions={"key": "val"},
) )
# Verify DatasetQuery record was added and committed db_session_with_containers.expire_all()
mock_add.assert_called_once() after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
mock_commit.assert_called_once() assert after_count == before_count + 1
def test_external_retrieve_should_return_empty_for_non_external_provider(self): def test_external_retrieve_should_return_empty_for_non_external_provider(self, db_session_with_containers: Session):
"""Test that external_retrieve returns empty results immediately if provider is not external""" dataset = _create_dataset(db_session_with_containers, provider="vendor")
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.provider = "not_external"
query = "test query"
account = MagicMock() account = MagicMock()
# Act result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, "test query", account))
result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, query, account))
# Assert assert cast(dict[str, Any], result["query"])["content"] == "test query"
assert cast(dict[str, Any], result["query"])["content"] == query
assert result["records"] == [] assert result["records"] == []
# ===== Retrieve Tests ===== # ── Retrieve (real DB) ─────────────────────────────────────────────
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
@patch("extensions.ext_database.db.session.add") def test_retrieve_should_use_default_model_when_none_provided(
@patch("extensions.ext_database.db.session.commit") self, mock_retrieve, db_session_with_containers: Session
def test_retrieve_should_use_default_model_when_none_provided(self, mock_commit, mock_add, mock_retrieve): ):
"""Test that retrieve uses default model when retrieval_model is not provided""" dataset = _create_dataset(db_session_with_containers)
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.id = "dataset_id"
dataset.retrieval_model = None dataset.retrieval_model = None
query = "test query"
account = MagicMock() account = MagicMock()
account.id = "account_id" account.id = str(uuid4())
mock_retrieve.return_value = [] mock_retrieve.return_value = []
# Act before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
result = cast( result = cast(
dict[str, Any], dict[str, Any],
HitTestingService.retrieve( HitTestingService.retrieve(
dataset=dataset, query=query, account=account, retrieval_model=None, external_retrieval_model={} dataset=dataset, query="test query", account=account, retrieval_model=None, external_retrieval_model={}
), ),
) )
# Assert assert cast(dict[str, Any], result["query"])["content"] == "test query"
assert cast(dict[str, Any], result["query"])["content"] == query
mock_retrieve.assert_called_once() mock_retrieve.assert_called_once()
# Verify top_k from default_retrieval_model (4)
assert mock_retrieve.call_args.kwargs["top_k"] == 4 assert mock_retrieve.call_args.kwargs["top_k"] == 4
mock_commit.assert_called_once()
db_session_with_containers.expire_all()
after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
assert after_count == before_count + 1
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition") @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition")
@patch("extensions.ext_database.db.session.add") def test_retrieve_should_handle_metadata_filtering(
@patch("extensions.ext_database.db.session.commit") self, mock_get_meta, mock_retrieve, db_session_with_containers: Session
def test_retrieve_should_handle_metadata_filtering(self, mock_commit, mock_add, mock_get_meta, mock_retrieve): ):
"""Test that retrieve correctly calls metadata filtering when conditions are present""" dataset = _create_dataset(db_session_with_containers)
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.id = "dataset_id"
query = "test query"
account = MagicMock() account = MagicMock()
account.id = "account_id" account.id = str(uuid4())
retrieval_model = { retrieval_model = {
"search_method": "semantic_search", "search_method": "semantic_search",
@ -242,29 +196,27 @@ class TestHitTestingService:
"reranking_enable": False, "reranking_enable": False,
"score_threshold_enabled": False, "score_threshold_enabled": False,
} }
mock_get_meta.return_value = ({dataset.id: ["doc_id1"]}, "condition_string")
# Mock metadata filtering response
mock_get_meta.return_value = ({"dataset_id": ["doc_id1"]}, "condition_string")
mock_retrieve.return_value = [] mock_retrieve.return_value = []
# Act
HitTestingService.retrieve( HitTestingService.retrieve(
dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={} dataset=dataset,
query="test query",
account=account,
retrieval_model=retrieval_model,
external_retrieval_model={},
) )
# Assert
mock_get_meta.assert_called_once() mock_get_meta.assert_called_once()
mock_retrieve.assert_called_once() mock_retrieve.assert_called_once()
assert mock_retrieve.call_args.kwargs["document_ids_filter"] == ["doc_id1"] assert mock_retrieve.call_args.kwargs["document_ids_filter"] == ["doc_id1"]
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition") @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition")
def test_retrieve_should_return_empty_if_metadata_filtering_fails(self, mock_get_meta, mock_retrieve): def test_retrieve_should_return_empty_if_metadata_filtering_fails(
"""Test that retrieve returns empty response if metadata filtering returns condition but no document IDs""" self, mock_get_meta, mock_retrieve, db_session_with_containers: Session
# Arrange ):
dataset = MagicMock(spec=Dataset) dataset = _create_dataset(db_session_with_containers)
dataset.id = "dataset_id"
query = "test query"
account = MagicMock() account = MagicMock()
retrieval_model = { retrieval_model = {
@ -274,37 +226,27 @@ class TestHitTestingService:
"reranking_enable": False, "reranking_enable": False,
"score_threshold_enabled": False, "score_threshold_enabled": False,
} }
# Mock metadata filtering response: condition returned but no IDs
mock_get_meta.return_value = ({}, "condition_string") mock_get_meta.return_value = ({}, "condition_string")
# Act
result = cast( result = cast(
dict[str, Any], dict[str, Any],
HitTestingService.retrieve( HitTestingService.retrieve(
dataset=dataset, dataset=dataset,
query=query, query="test query",
account=account, account=account,
retrieval_model=retrieval_model, retrieval_model=retrieval_model,
external_retrieval_model={}, external_retrieval_model={},
), ),
) )
# Assert
assert result["records"] == [] assert result["records"] == []
mock_retrieve.assert_not_called() mock_retrieve.assert_not_called()
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
@patch("extensions.ext_database.db.session.add") def test_retrieve_should_handle_attachments(self, mock_retrieve, db_session_with_containers: Session):
@patch("extensions.ext_database.db.session.commit") dataset = _create_dataset(db_session_with_containers)
def test_retrieve_should_handle_attachments(self, mock_commit, mock_add, mock_retrieve):
"""Test that retrieve handles attachment_ids and adds them to DatasetQuery"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.id = "dataset_id"
query = "test query"
account = MagicMock() account = MagicMock()
account.id = "account_id" account.id = str(uuid4())
attachment_ids = ["att1", "att2"] attachment_ids = ["att1", "att2"]
retrieval_model = { retrieval_model = {
@ -315,21 +257,19 @@ class TestHitTestingService:
} }
mock_retrieve.return_value = [] mock_retrieve.return_value = []
# Act
HitTestingService.retrieve( HitTestingService.retrieve(
dataset=dataset, dataset=dataset,
query=query, query="test query",
account=account, account=account,
retrieval_model=retrieval_model, retrieval_model=retrieval_model,
external_retrieval_model={}, external_retrieval_model={},
attachment_ids=attachment_ids, attachment_ids=attachment_ids,
) )
# Assert
mock_retrieve.assert_called_once_with( mock_retrieve.assert_called_once_with(
retrieval_method=ANY, retrieval_method=ANY,
dataset_id="dataset_id", dataset_id=dataset.id,
query=query, query="test query",
attachment_ids=attachment_ids, attachment_ids=attachment_ids,
top_k=4, top_k=4,
score_threshold=0.0, score_threshold=0.0,
@ -338,26 +278,27 @@ class TestHitTestingService:
weights=None, weights=None,
document_ids_filter=None, document_ids_filter=None,
) )
# Verify DatasetQuery record (there should be 2 queries: 1 text, 2 images)
# The content is json.dumps([{"content_type": "text_query", ...}, {"content_type": "image_query", ...}]) # Verify DatasetQuery was persisted with correct content structure
called_query = mock_add.call_args[0][0] db_session_with_containers.expire_all()
query_content = json.loads(called_query.content) latest = db_session_with_containers.scalar(
select(DatasetQuery)
.where(DatasetQuery.dataset_id == dataset.id)
.order_by(DatasetQuery.created_at.desc())
.limit(1)
)
assert latest is not None
query_content = json.loads(latest.content)
assert len(query_content) == 3 # 1 text + 2 images assert len(query_content) == 3 # 1 text + 2 images
assert query_content[0]["content_type"] == "text_query" assert query_content[0]["content_type"] == "text_query"
assert query_content[1]["content_type"] == "image_query" assert query_content[1]["content_type"] == "image_query"
assert query_content[1]["content"] == "att1" assert query_content[1]["content"] == "att1"
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
@patch("extensions.ext_database.db.session.add") def test_retrieve_should_handle_reranking_and_threshold(self, mock_retrieve, db_session_with_containers: Session):
@patch("extensions.ext_database.db.session.commit") dataset = _create_dataset(db_session_with_containers)
def test_retrieve_should_handle_reranking_and_threshold(self, mock_commit, mock_add, mock_retrieve):
"""Test that retrieve passes reranking and threshold parameters correctly"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.id = "dataset_id"
query = "test query"
account = MagicMock() account = MagicMock()
account.id = "account_id" account.id = str(uuid4())
retrieval_model = { retrieval_model = {
"search_method": "hybrid_search", "search_method": "hybrid_search",
@ -371,12 +312,14 @@ class TestHitTestingService:
} }
mock_retrieve.return_value = [] mock_retrieve.return_value = []
# Act
HitTestingService.retrieve( HitTestingService.retrieve(
dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={} dataset=dataset,
query="test query",
account=account,
retrieval_model=retrieval_model,
external_retrieval_model={},
) )
# Assert
mock_retrieve.assert_called_once() mock_retrieve.assert_called_once()
kwargs = mock_retrieve.call_args.kwargs kwargs = mock_retrieve.call_args.kwargs
assert kwargs["score_threshold"] == 0.5 assert kwargs["score_threshold"] == 0.5

View File

@ -0,0 +1,363 @@
from __future__ import annotations
import uuid
from unittest.mock import patch
import pytest
from faker import Faker
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.ops.entities.config_entity import TracingProviderEnum
from models.model import TraceAppConfig
from services.account_service import AccountService, TenantService
from services.app_service import AppService
from services.ops_service import OpsService
from tests.test_containers_integration_tests.helpers import generate_valid_password
class TestOpsService:
    """Integration tests for OpsService tracing-config CRUD operations.

    Persists real ``TraceAppConfig`` rows through the container-backed
    ``db_session_with_containers`` fixture, while mocking ``OpsTraceManager``
    and the external services touched during app/account creation.
    """

    @pytest.fixture
    def mock_external_service_dependencies(self):
        """Patch external services needed to create an account, tenant and app.

        Disables webapp auth, allows account registration, and stubs the
        model manager so ``AppService.create_app`` needs no real provider.
        """
        with (
            patch("services.app_service.FeatureService") as mock_feature_service,
            patch("services.app_service.EnterpriseService") as mock_enterprise_service,
            patch("services.app_service.ModelManager.for_tenant") as mock_model_manager,
            patch("services.account_service.FeatureService") as mock_account_feature_service,
        ):
            mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False
            mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None
            mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None
            mock_account_feature_service.get_system_features.return_value.is_allow_register = True
            mock_model_instance = mock_model_manager.return_value
            mock_model_instance.get_default_model_instance.return_value = None
            mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo")
            yield {
                "feature_service": mock_feature_service,
                "enterprise_service": mock_enterprise_service,
                "model_manager": mock_model_manager,
                "account_feature_service": mock_account_feature_service,
            }

    @pytest.fixture
    def mock_ops_trace_manager(self):
        """Patch ``OpsTraceManager`` as seen by ``services.ops_service``."""
        with patch("services.ops_service.OpsTraceManager") as mock:
            yield mock

    def _create_app(self, db_session_with_containers: Session, mock_external_service_dependencies):
        """Create a real account, owner tenant and chat app; return (app, account)."""
        fake = Faker()
        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=generate_valid_password(fake),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant
        app_service = AppService()
        app = app_service.create_app(
            tenant.id,
            {
                "name": fake.company(),
                "description": fake.text(max_nb_chars=100),
                "mode": "chat",
                "icon_type": "emoji",
                "icon": "🤖",
                "icon_background": "#FF6B6B",
            },
            account,
        )
        return app, account

    # Sentinel distinguishing "caller gave no tracing_config" (use the default
    # {"some": "config"}) from an explicit ``tracing_config=None``.
    _SENTINEL = object()

    def _insert_trace_config(
        self,
        db_session: Session,
        app_id: str,
        provider: str,
        tracing_config: dict | None | object = _SENTINEL,
    ) -> TraceAppConfig:
        """Insert and commit a ``TraceAppConfig`` row for the given app/provider."""
        trace_config = TraceAppConfig(
            app_id=app_id,
            tracing_provider=provider,
            tracing_config=tracing_config if tracing_config is not self._SENTINEL else {"some": "config"},
        )
        db_session.add(trace_config)
        db_session.commit()
        return trace_config

    # ── get_tracing_app_config ─────────────────────────────────────────
    def test_get_tracing_app_config_no_config(self, db_session_with_containers: Session, mock_ops_trace_manager):
        """Returns None when no TraceAppConfig row exists for the app."""
        result = OpsService.get_tracing_app_config(str(uuid.uuid4()), "arize")
        assert result is None

    def test_get_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager):
        """Returns None when a config row exists but the app itself does not."""
        fake_app_id = str(uuid.uuid4())
        self._insert_trace_config(db_session_with_containers, fake_app_id, "arize")
        result = OpsService.get_tracing_app_config(fake_app_id, "arize")
        assert result is None

    def test_get_tracing_app_config_none_config(
        self, db_session_with_containers: Session, mock_external_service_dependencies, mock_ops_trace_manager
    ):
        """Raises ValueError when the stored tracing_config column is None."""
        app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
        self._insert_trace_config(db_session_with_containers, app.id, "arize", tracing_config=None)
        with pytest.raises(ValueError, match="Tracing config cannot be None."):
            OpsService.get_tracing_app_config(app.id, "arize")

    @pytest.mark.parametrize(
        ("provider", "default_url"),
        [
            ("arize", "https://app.arize.com/"),
            ("phoenix", "https://app.phoenix.arize.com/projects/"),
            ("langsmith", "https://smith.langchain.com/"),
            ("opik", "https://www.comet.com/opik/"),
            ("weave", "https://wandb.ai/"),
            ("aliyun", "https://arms.console.aliyun.com/"),
            ("tencent", "https://console.cloud.tencent.com/apm"),
            ("mlflow", "http://localhost:5000/"),
            ("databricks", "https://www.databricks.com/"),
        ],
    )
    def test_get_tracing_app_config_providers_exception(
        self, db_session_with_containers: Session, mock_external_service_dependencies, provider, default_url
    ):
        """Falls back to each provider's default project_url when URL/key lookups raise."""
        with patch("services.ops_service.OpsTraceManager") as mock_otm:
            mock_otm.decrypt_tracing_config.return_value = {}
            mock_otm.obfuscated_decrypt_token.return_value = {}
            mock_otm.get_trace_config_project_url.side_effect = Exception("error")
            mock_otm.get_trace_config_project_key.side_effect = Exception("error")
            app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
            self._insert_trace_config(db_session_with_containers, app.id, provider)
            result = OpsService.get_tracing_app_config(app.id, provider)
            assert result is not None
            assert result["tracing_config"]["project_url"] == default_url

    @pytest.mark.parametrize(
        "provider",
        ["arize", "phoenix", "langsmith", "opik", "weave", "aliyun", "tencent", "mlflow", "databricks"],
    )
    def test_get_tracing_app_config_providers_success(
        self, db_session_with_containers: Session, mock_external_service_dependencies, provider
    ):
        """Uses the project URL resolved by OpsTraceManager when lookups succeed."""
        with patch("services.ops_service.OpsTraceManager") as mock_otm:
            mock_otm.decrypt_tracing_config.return_value = {}
            mock_otm.obfuscated_decrypt_token.return_value = {"project_url": "success_url"}
            mock_otm.get_trace_config_project_url.return_value = "success_url"
            app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
            self._insert_trace_config(db_session_with_containers, app.id, provider)
            result = OpsService.get_tracing_app_config(app.id, provider)
            assert result is not None
            assert result["tracing_config"]["project_url"] == "success_url"

    def test_get_tracing_app_config_langfuse_success(
        self, db_session_with_containers: Session, mock_external_service_dependencies
    ):
        """Langfuse: project URL is built from host plus the resolved project key."""
        with patch("services.ops_service.OpsTraceManager") as mock_otm:
            mock_otm.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"}
            mock_otm.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"}
            mock_otm.get_trace_config_project_key.return_value = "key"
            app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
            self._insert_trace_config(db_session_with_containers, app.id, "langfuse")
            result = OpsService.get_tracing_app_config(app.id, "langfuse")
            assert result is not None
            assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/project/key"

    def test_get_tracing_app_config_langfuse_exception(
        self, db_session_with_containers: Session, mock_external_service_dependencies
    ):
        """Langfuse: falls back to a host-only URL when project-key lookup raises."""
        with patch("services.ops_service.OpsTraceManager") as mock_otm:
            mock_otm.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"}
            mock_otm.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"}
            mock_otm.get_trace_config_project_key.side_effect = Exception("error")
            app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
            self._insert_trace_config(db_session_with_containers, app.id, "langfuse")
            result = OpsService.get_tracing_app_config(app.id, "langfuse")
            assert result is not None
            assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/"

    # ── create_tracing_app_config ──────────────────────────────────────
    def test_create_tracing_app_config_invalid_provider(self, db_session_with_containers: Session):
        """Unknown provider names are rejected with an error dict."""
        result = OpsService.create_tracing_app_config(str(uuid.uuid4()), "invalid_provider", {})
        assert result == {"error": "Invalid tracing provider: invalid_provider"}

    def test_create_tracing_app_config_invalid_credentials(
        self, db_session_with_containers: Session, mock_ops_trace_manager
    ):
        """Returns an error dict when the trace config credential check fails."""
        mock_ops_trace_manager.check_trace_config_is_effective.return_value = False
        result = OpsService.create_tracing_app_config(
            str(uuid.uuid4()), TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}
        )
        assert result == {"error": "Invalid Credentials"}

    @pytest.mark.parametrize(
        ("provider", "config"),
        [
            (TracingProviderEnum.ARIZE, {}),
            (TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}),
            (TracingProviderEnum.LANGSMITH, {"api_key": "k", "project": "p"}),
            (TracingProviderEnum.ALIYUN, {"license_key": "k", "endpoint": "https://aliyun.com"}),
        ],
    )
    def test_create_tracing_app_config_project_url_exception(
        self, db_session_with_containers: Session, mock_external_service_dependencies, provider, config
    ):
        # Existing config causes the service to return None before reaching the DB insert
        with patch("services.ops_service.OpsTraceManager") as mock_otm:
            mock_otm.check_trace_config_is_effective.return_value = True
            mock_otm.get_trace_config_project_url.side_effect = Exception("error")
            mock_otm.get_trace_config_project_key.side_effect = Exception("error")
            app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
            self._insert_trace_config(db_session_with_containers, app.id, str(provider))
            result = OpsService.create_tracing_app_config(app.id, provider, config)
            assert result is None

    def test_create_tracing_app_config_langfuse_success(
        self, db_session_with_containers: Session, mock_external_service_dependencies
    ):
        """Langfuse config with valid credentials is created successfully."""
        with patch("services.ops_service.OpsTraceManager") as mock_otm:
            mock_otm.check_trace_config_is_effective.return_value = True
            mock_otm.get_trace_config_project_key.return_value = "key"
            mock_otm.encrypt_tracing_config.return_value = {}
            app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
            result = OpsService.create_tracing_app_config(
                app.id,
                TracingProviderEnum.LANGFUSE,
                {"public_key": "p", "secret_key": "s", "host": "https://api.langfuse.com"},
            )
            assert result == {"result": "success"}

    def test_create_tracing_app_config_already_exists(
        self, db_session_with_containers: Session, mock_external_service_dependencies
    ):
        """Creating a config for an app/provider pair that already has one returns None."""
        with patch("services.ops_service.OpsTraceManager") as mock_otm:
            mock_otm.check_trace_config_is_effective.return_value = True
            app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
            self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE))
            result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {})
            assert result is None

    def test_create_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager):
        """Returns None when the target app does not exist."""
        mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
        result = OpsService.create_tracing_app_config(str(uuid.uuid4()), TracingProviderEnum.ARIZE, {})
        assert result is None

    def test_create_tracing_app_config_with_empty_other_keys(
        self, db_session_with_containers: Session, mock_external_service_dependencies
    ):
        # "project" is in other_keys for Arize; providing "" triggers default substitution
        with patch("services.ops_service.OpsTraceManager") as mock_otm:
            mock_otm.check_trace_config_is_effective.return_value = True
            mock_otm.get_trace_config_project_url.side_effect = Exception("no url")
            mock_otm.encrypt_tracing_config.return_value = {}
            app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
            result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {"project": ""})
            assert result == {"result": "success"}

    def test_create_tracing_app_config_success(
        self, db_session_with_containers: Session, mock_external_service_dependencies
    ):
        """Happy path: valid credentials and resolvable project URL yield success."""
        with patch("services.ops_service.OpsTraceManager") as mock_otm:
            mock_otm.check_trace_config_is_effective.return_value = True
            mock_otm.get_trace_config_project_url.return_value = "http://project_url"
            mock_otm.encrypt_tracing_config.return_value = {"encrypted": "config"}
            app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
            result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {})
            assert result == {"result": "success"}

    # ── update_tracing_app_config ──────────────────────────────────────
    def test_update_tracing_app_config_invalid_provider(self, db_session_with_containers: Session):
        """Unknown provider names raise ValueError on update."""
        with pytest.raises(ValueError, match="Invalid tracing provider: invalid_provider"):
            OpsService.update_tracing_app_config(str(uuid.uuid4()), "invalid_provider", {})

    def test_update_tracing_app_config_no_config(self, db_session_with_containers: Session, mock_ops_trace_manager):
        """Returns None when there is no existing config row to update."""
        result = OpsService.update_tracing_app_config(str(uuid.uuid4()), TracingProviderEnum.ARIZE, {})
        assert result is None

    def test_update_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager):
        """Returns None when a config row exists but the app itself does not."""
        fake_app_id = str(uuid.uuid4())
        self._insert_trace_config(db_session_with_containers, fake_app_id, str(TracingProviderEnum.ARIZE))
        mock_ops_trace_manager.encrypt_tracing_config.return_value = {}
        result = OpsService.update_tracing_app_config(fake_app_id, TracingProviderEnum.ARIZE, {})
        assert result is None

    def test_update_tracing_app_config_invalid_credentials(
        self, db_session_with_containers: Session, mock_external_service_dependencies
    ):
        """Raises ValueError when the updated credentials fail the effectiveness check."""
        with patch("services.ops_service.OpsTraceManager") as mock_otm:
            mock_otm.encrypt_tracing_config.return_value = {}
            mock_otm.decrypt_tracing_config.return_value = {}
            mock_otm.check_trace_config_is_effective.return_value = False
            app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
            self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE))
            with pytest.raises(ValueError, match="Invalid Credentials"):
                OpsService.update_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {})

    def test_update_tracing_app_config_success(
        self, db_session_with_containers: Session, mock_external_service_dependencies
    ):
        """Happy path: update returns the stored config for the right app."""
        with patch("services.ops_service.OpsTraceManager") as mock_otm:
            mock_otm.encrypt_tracing_config.return_value = {"updated": "config"}
            mock_otm.decrypt_tracing_config.return_value = {}
            mock_otm.check_trace_config_is_effective.return_value = True
            app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
            self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE))
            result = OpsService.update_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {})
            assert result is not None
            assert result["app_id"] == app.id

    # ── delete_tracing_app_config ──────────────────────────────────────
    def test_delete_tracing_app_config_no_config(self, db_session_with_containers: Session):
        """Returns None when there is nothing to delete."""
        result = OpsService.delete_tracing_app_config(str(uuid.uuid4()), "arize")
        assert result is None

    def test_delete_tracing_app_config_success(
        self, db_session_with_containers: Session, mock_external_service_dependencies
    ):
        """Deleting an existing config returns True and removes the DB row."""
        app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
        self._insert_trace_config(db_session_with_containers, app.id, "arize")
        result = OpsService.delete_tracing_app_config(app.id, "arize")
        assert result is True
        # Verify the row is actually gone from the database.
        remaining = db_session_with_containers.scalar(
            select(TraceAppConfig)
            .where(TraceAppConfig.app_id == app.id, TraceAppConfig.tracing_provider == "arize")
            .limit(1)
        )
        assert remaining is None

View File

@ -233,11 +233,10 @@ class TestWebAppAuthService:
assert result.status == AccountStatus.ACTIVE assert result.status == AccountStatus.ACTIVE
# Verify database state # Verify database state
refreshed = db_session_with_containers.get(Account, result.id)
db_session_with_containers.refresh(result) assert refreshed is not None
assert result.id is not None assert refreshed.password is not None
assert result.password is not None assert refreshed.password_salt is not None
assert result.password_salt is not None
def test_authenticate_account_not_found( def test_authenticate_account_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies
@ -414,9 +413,8 @@ class TestWebAppAuthService:
assert result.status == AccountStatus.ACTIVE assert result.status == AccountStatus.ACTIVE
# Verify database state # Verify database state
refreshed = db_session_with_containers.get(Account, result.id)
db_session_with_containers.refresh(result) assert refreshed is not None
assert result.id is not None
def test_get_user_through_email_not_found( def test_get_user_through_email_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies self, db_session_with_containers: Session, mock_external_service_dependencies

View File

@ -1,3 +1,4 @@
from importlib import import_module
from unittest.mock import MagicMock, PropertyMock, patch from unittest.mock import MagicMock, PropertyMock, patch
import pytest import pytest
@ -11,6 +12,7 @@ from controllers.console.datasets.external import (
BedrockRetrievalApi, BedrockRetrievalApi,
ExternalApiTemplateApi, ExternalApiTemplateApi,
ExternalApiTemplateListApi, ExternalApiTemplateListApi,
ExternalApiUseCheckApi,
ExternalDatasetCreateApi, ExternalDatasetCreateApi,
ExternalKnowledgeHitTestingApi, ExternalKnowledgeHitTestingApi,
) )
@ -19,6 +21,8 @@ from services.external_knowledge_service import ExternalDatasetService
from services.hit_testing_service import HitTestingService from services.hit_testing_service import HitTestingService
from services.knowledge_service import ExternalDatasetTestService from services.knowledge_service import ExternalDatasetTestService
external_controller = import_module("controllers.console.datasets.external")
def unwrap(func): def unwrap(func):
while hasattr(func, "__wrapped__"): while hasattr(func, "__wrapped__"):
@ -44,10 +48,11 @@ def current_user():
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def mock_auth(mocker, current_user): def mock_auth(monkeypatch, current_user):
mocker.patch( monkeypatch.setattr(
"controllers.console.datasets.external.current_account_with_tenant", external_controller,
return_value=(current_user, "tenant-1"), "current_account_with_tenant",
lambda: (current_user, "tenant-1"),
) )
@ -136,6 +141,26 @@ class TestExternalApiTemplateApi:
method(api, "api-id") method(api, "api-id")
class TestExternalApiUseCheckApi:
def test_get_scopes_usage_check_to_current_tenant(self, app):
api = ExternalApiUseCheckApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch.object(
ExternalDatasetService,
"external_knowledge_api_use_check",
return_value=(True, 2),
) as mock_use_check,
):
response, status = method(api, "api-id")
assert status == 200
assert response == {"is_using": True, "count": 2}
mock_use_check.assert_called_once_with("api-id", "tenant-1")
class TestExternalDatasetCreateApi: class TestExternalDatasetCreateApi:
def test_create_success(self, app): def test_create_success(self, app):
api = ExternalDatasetCreateApi() api = ExternalDatasetCreateApi()

View File

@ -233,15 +233,20 @@ class TestCheckEmailUnique:
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
session = MagicMock() mock_session = MagicMock()
first = MagicMock() first = MagicMock()
first.scalar_one_or_none.return_value = None first.scalar_one_or_none.return_value = None
second = MagicMock() second = MagicMock()
expected_account = MagicMock() expected_account = MagicMock()
second.scalar_one_or_none.return_value = expected_account second.scalar_one_or_none.return_value = expected_account
session.execute.side_effect = [first, second] mock_session.execute.side_effect = [first, second]
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=session) mock_factory = MagicMock()
mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False)
with patch("services.account_service.session_factory", mock_factory):
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com")
assert result is expected_account assert result is expected_account
assert session.execute.call_count == 2 assert mock_session.execute.call_count == 2

View File

@ -4,9 +4,7 @@ from unittest.mock import Mock
from core.mcp.entities import ( from core.mcp.entities import (
SUPPORTED_PROTOCOL_VERSIONS, SUPPORTED_PROTOCOL_VERSIONS,
LifespanContextT,
RequestContext, RequestContext,
SessionT,
) )
from core.mcp.session.base_session import BaseSession from core.mcp.session.base_session import BaseSession
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestParams from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestParams
@ -198,42 +196,3 @@ class TestRequestContext:
assert "RequestContext" in repr_str assert "RequestContext" in repr_str
assert "test-123" in repr_str assert "test-123" in repr_str
assert "MockSession" in repr_str assert "MockSession" in repr_str
class TestTypeVariables:
"""Test type variables defined in the module."""
def test_session_type_var(self):
"""Test SessionT type variable."""
# Create a custom session class
class CustomSession(BaseSession):
pass
# Use in generic context
def process_session(session: SessionT) -> SessionT:
return session
mock_session = Mock(spec=CustomSession)
result = process_session(mock_session)
assert result == mock_session
def test_lifespan_context_type_var(self):
"""Test LifespanContextT type variable."""
# Use in generic context
def process_lifespan(context: LifespanContextT) -> LifespanContextT:
return context
# Test with different types
str_context = "string-context"
assert process_lifespan(str_context) == str_context
dict_context = {"key": "value"}
assert process_lifespan(dict_context) == dict_context
class CustomContext:
pass
custom_context = CustomContext()
assert process_lifespan(custom_context) == custom_context

View File

@ -39,6 +39,25 @@ class _FakeSession:
return None return None
class _FakeBeginContext:
def __init__(self, session):
self._session = session
def __enter__(self):
return self._session
def __exit__(self, exc_type, exc, tb):
return None
def _patch_both(monkeypatch, module, session):
"""Patch both Session and sessionmaker on the module."""
monkeypatch.setattr(module, "Session", lambda _client: session)
monkeypatch.setattr(
module, "sessionmaker", lambda **kwargs: MagicMock(begin=MagicMock(return_value=_FakeBeginContext(session)))
)
@pytest.fixture @pytest.fixture
def relyt_module(monkeypatch): def relyt_module(monkeypatch):
for name, module in _build_fake_relyt_modules().items(): for name, module in _build_fake_relyt_modules().items():
@ -108,13 +127,13 @@ def test_create_collection_cache_and_sql_execution(relyt_module, monkeypatch):
monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=1)) monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=1))
session = _FakeSession() session = _FakeSession()
monkeypatch.setattr(relyt_module, "Session", lambda _client: session) _patch_both(monkeypatch, relyt_module, session)
vector.create_collection(3) vector.create_collection(3)
session.execute.assert_not_called() session.execute.assert_not_called()
monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=None)) monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=None))
session = _FakeSession() session = _FakeSession()
monkeypatch.setattr(relyt_module, "Session", lambda _client: session) _patch_both(monkeypatch, relyt_module, session)
vector.create_collection(3) vector.create_collection(3)
executed_sql = [str(call.args[0]) for call in session.execute.call_args_list] executed_sql = [str(call.args[0]) for call in session.execute.call_args_list]
assert any("DROP TABLE IF EXISTS" in sql for sql in executed_sql) assert any("DROP TABLE IF EXISTS" in sql for sql in executed_sql)
@ -265,15 +284,15 @@ def test_search_by_vector_filters_by_score_and_ids(relyt_module):
# 8. delete commits session # 8. delete commits session
def test_delete_commits_session(relyt_module, monkeypatch): def test_delete_drops_table(relyt_module, monkeypatch):
vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
vector._collection_name = "collection_1" vector._collection_name = "collection_1"
vector.client = MagicMock() vector.client = MagicMock()
vector.embedding_dimension = 3 vector.embedding_dimension = 3
session = _FakeSession() session = _FakeSession()
monkeypatch.setattr(relyt_module, "Session", lambda _client: session) _patch_both(monkeypatch, relyt_module, session)
vector.delete() vector.delete()
session.commit.assert_called_once() session.execute.assert_called_once()
def test_relyt_factory_existing_and_generated_collection(relyt_module, monkeypatch): def test_relyt_factory_existing_and_generated_collection(relyt_module, monkeypatch):

View File

@ -137,14 +137,15 @@ def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monke
session = MagicMock() session = MagicMock()
class _SessionCtx: class _BeginCtx:
def __enter__(self): def __enter__(self):
return session return session
def __exit__(self, exc_type, exc, tb): def __exit__(self, exc_type, exc, tb):
return False return False
monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx()) mock_sm = MagicMock(begin=MagicMock(return_value=_BeginCtx()))
monkeypatch.setattr(tidb_module, "sessionmaker", lambda **kwargs: mock_sm)
vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
vector._collection_name = "collection_1" vector._collection_name = "collection_1"
@ -153,11 +154,9 @@ def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monke
vector._create_collection(3) vector._create_collection(3)
session.begin.assert_called_once()
sql = str(session.execute.call_args.args[0]) sql = str(session.execute.call_args.args[0])
assert "VECTOR<FLOAT>(3)" in sql assert "VECTOR<FLOAT>(3)" in sql
assert "VEC_L2_DISTANCE" in sql assert "VEC_L2_DISTANCE" in sql
session.commit.assert_called_once()
tidb_module.redis_client.set.assert_called_once() tidb_module.redis_client.set.assert_called_once()
@ -396,23 +395,22 @@ def test_search_by_vector_filters_and_scores(tidb_module, monkeypatch):
def test_delete_drops_table(tidb_module, monkeypatch): def test_delete_drops_table(tidb_module, monkeypatch):
session = MagicMock() session = MagicMock()
session.execute.return_value = None session.execute.return_value = None
session.commit = MagicMock()
class _SessionCtx: class _BeginCtx:
def __enter__(self): def __enter__(self):
return session return session
def __exit__(self, exc_type, exc, tb): def __exit__(self, exc_type, exc, tb):
return False return False
monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx()) mock_sm = MagicMock(begin=MagicMock(return_value=_BeginCtx()))
monkeypatch.setattr(tidb_module, "sessionmaker", lambda **kwargs: mock_sm)
vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
vector._collection_name = "collection_1" vector._collection_name = "collection_1"
vector._engine = MagicMock() vector._engine = MagicMock()
vector.delete() vector.delete()
drop_sql = str(session.execute.call_args.args[0]) drop_sql = str(session.execute.call_args.args[0])
assert "DROP TABLE IF EXISTS collection_1" in drop_sql assert "DROP TABLE IF EXISTS collection_1" in drop_sql
session.commit.assert_called_once()
def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeypatch): def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeypatch):

View File

@ -39,7 +39,7 @@ class TestAppGenerateHandler:
"root_node_id": None, "root_node_id": None,
} }
arguments = handler._extract_arguments(AppGenerateService.generate, (), kwargs) arguments = handler._extract_arguments(AppGenerateService.generate, **kwargs)
assert arguments is not None, "Failed to extract arguments from AppGenerateService.generate" assert arguments is not None, "Failed to extract arguments from AppGenerateService.generate"
assert "app_model" in arguments, "Handler uses app_model but parameter is missing" assert "app_model" in arguments, "Handler uses app_model but parameter is missing"
@ -70,14 +70,11 @@ class TestAppGenerateHandler:
handler.wrapper( handler.wrapper(
tracer, tracer,
dummy_func, dummy_func,
(), app_model=mock_app_model,
{ user=mock_account_user,
"app_model": mock_app_model, args={"workflow_id": test_workflow_id},
"user": mock_account_user, invoke_from=InvokeFrom.DEBUGGER,
"args": {"workflow_id": test_workflow_id}, streaming=False,
"invoke_from": InvokeFrom.DEBUGGER,
"streaming": False,
},
) )
spans = memory_span_exporter.get_finished_spans() spans = memory_span_exporter.get_finished_spans()

View File

@ -63,7 +63,7 @@ class TestWorkflowAppRunnerHandler:
def runner_run(self): def runner_run(self):
return "result" return "result"
handler.wrapper(tracer, runner_run, (mock_workflow_runner,), {}) handler.wrapper(tracer, runner_run, mock_workflow_runner)
spans = memory_span_exporter.get_finished_spans() spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1 assert len(spans) == 1

View File

@ -28,7 +28,7 @@ class TestSpanHandlerExtractArguments:
args = (1, 2, 3) args = (1, 2, 3)
kwargs = {} kwargs = {}
result = handler._extract_arguments(func, args, kwargs) result = handler._extract_arguments(func, *args, **kwargs)
assert result is not None assert result is not None
assert result["a"] == 1 assert result["a"] == 1
@ -44,7 +44,7 @@ class TestSpanHandlerExtractArguments:
args = () args = ()
kwargs = {"a": 1, "b": 2, "c": 3} kwargs = {"a": 1, "b": 2, "c": 3}
result = handler._extract_arguments(func, args, kwargs) result = handler._extract_arguments(func, *args, **kwargs)
assert result is not None assert result is not None
assert result["a"] == 1 assert result["a"] == 1
@ -60,7 +60,7 @@ class TestSpanHandlerExtractArguments:
args = (1,) args = (1,)
kwargs = {"b": 2, "c": 3} kwargs = {"b": 2, "c": 3}
result = handler._extract_arguments(func, args, kwargs) result = handler._extract_arguments(func, *args, **kwargs)
assert result is not None assert result is not None
assert result["a"] == 1 assert result["a"] == 1
@ -76,7 +76,7 @@ class TestSpanHandlerExtractArguments:
args = (1,) args = (1,)
kwargs = {} kwargs = {}
result = handler._extract_arguments(func, args, kwargs) result = handler._extract_arguments(func, *args, **kwargs)
assert result is not None assert result is not None
assert result["a"] == 1 assert result["a"] == 1
@ -94,7 +94,7 @@ class TestSpanHandlerExtractArguments:
instance = MyClass() instance = MyClass()
args = (1, 2) args = (1, 2)
kwargs = {} kwargs = {}
result = handler._extract_arguments(instance.method, args, kwargs) result = handler._extract_arguments(instance.method, *args, **kwargs)
assert result is not None assert result is not None
assert result["a"] == 1 assert result["a"] == 1
@ -109,7 +109,7 @@ class TestSpanHandlerExtractArguments:
args = (1,) args = (1,)
kwargs = {} kwargs = {}
result = handler._extract_arguments(func, args, kwargs) result = handler._extract_arguments(func, *args, **kwargs)
assert result is None assert result is None
@ -122,11 +122,11 @@ class TestSpanHandlerExtractArguments:
assert func not in handler._signature_cache assert func not in handler._signature_cache
handler._extract_arguments(func, (1, 2), {}) handler._extract_arguments(func, 1, 2)
assert func in handler._signature_cache assert func in handler._signature_cache
cached_sig = handler._signature_cache[func] cached_sig = handler._signature_cache[func]
handler._extract_arguments(func, (3, 4), {}) handler._extract_arguments(func, 3, 4)
assert handler._signature_cache[func] is cached_sig assert handler._signature_cache[func] is cached_sig
@ -142,7 +142,7 @@ class TestSpanHandlerWrapper:
def test_func(): def test_func():
return "result" return "result"
result = handler.wrapper(tracer, test_func, (), {}) result = handler.wrapper(tracer, test_func)
assert result == "result" assert result == "result"
spans = memory_span_exporter.get_finished_spans() spans = memory_span_exporter.get_finished_spans()
@ -159,7 +159,7 @@ class TestSpanHandlerWrapper:
def test_func(): def test_func():
return "result" return "result"
handler.wrapper(tracer, test_func, (), {}) handler.wrapper(tracer, test_func)
spans = memory_span_exporter.get_finished_spans() spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1 assert len(spans) == 1
@ -174,7 +174,7 @@ class TestSpanHandlerWrapper:
def test_func(): def test_func():
return "result" return "result"
handler.wrapper(tracer, test_func, (), {}) handler.wrapper(tracer, test_func)
spans = memory_span_exporter.get_finished_spans() spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1 assert len(spans) == 1
@ -190,7 +190,7 @@ class TestSpanHandlerWrapper:
raise ValueError("test error") raise ValueError("test error")
with pytest.raises(ValueError, match="test error"): with pytest.raises(ValueError, match="test error"):
handler.wrapper(tracer, test_func, (), {}) handler.wrapper(tracer, test_func)
spans = memory_span_exporter.get_finished_spans() spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1 assert len(spans) == 1
@ -208,7 +208,7 @@ class TestSpanHandlerWrapper:
raise ValueError("test error") raise ValueError("test error")
with pytest.raises(ValueError): with pytest.raises(ValueError):
handler.wrapper(tracer, test_func, (), {}) handler.wrapper(tracer, test_func)
spans = memory_span_exporter.get_finished_spans() spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1 assert len(spans) == 1
@ -225,7 +225,7 @@ class TestSpanHandlerWrapper:
raise ValueError("test error") raise ValueError("test error")
with pytest.raises(ValueError, match="test error"): with pytest.raises(ValueError, match="test error"):
handler.wrapper(tracer, test_func, (), {}) handler.wrapper(tracer, test_func)
@patch("extensions.otel.decorators.base.dify_config.ENABLE_OTEL", True) @patch("extensions.otel.decorators.base.dify_config.ENABLE_OTEL", True)
def test_wrapper_passes_arguments_correctly(self, tracer_provider_with_memory_exporter, memory_span_exporter): def test_wrapper_passes_arguments_correctly(self, tracer_provider_with_memory_exporter, memory_span_exporter):
@ -236,7 +236,7 @@ class TestSpanHandlerWrapper:
def test_func(a, b, c=10): def test_func(a, b, c=10):
return a + b + c return a + b + c
result = handler.wrapper(tracer, test_func, (1, 2), {"c": 3}) result = handler.wrapper(tracer, test_func, 1, 2, c=3)
assert result == 6 assert result == 6
@ -249,7 +249,7 @@ class TestSpanHandlerWrapper:
def my_function(x): def my_function(x):
return x * 2 return x * 2
result = handler.wrapper(tracer, my_function, (5,), {}) result = handler.wrapper(tracer, my_function, 5)
assert result == 10 assert result == 10
spans = memory_span_exporter.get_finished_spans() spans = memory_span_exporter.get_finished_spans()

View File

@ -0,0 +1,138 @@
import json
from libs.pyrefly_type_coverage import (
CoverageSummary,
format_comparison_markdown,
format_summary_markdown,
parse_summary,
)
def _make_report(summary: dict) -> str:
return json.dumps({"module_reports": [], "summary": summary})
# Canonical pyrefly `summary` payload shared by the parse/format tests below.
# 400 typed + 50 Any + 550 untyped = 1000 typable symbols -> 45% coverage
# (40% strict).  Keys beyond the seven core metrics mirror extra fields the
# real report emits; presumably parse_summary ignores them — only the seven
# core fields are asserted in test_parse_summary_extracts_fields.
_SAMPLE_SUMMARY: dict = {
    "n_modules": 100,
    "n_typable": 1000,
    "n_typed": 400,
    "n_any": 50,
    "n_untyped": 550,
    "coverage": 45.0,
    "strict_coverage": 40.0,
    "n_functions": 200,
    "n_methods": 300,
    "n_function_params": 150,
    "n_method_params": 250,
    "n_classes": 80,
    "n_attrs": 40,
    "n_properties": 20,
    "n_type_ignores": 10,
}
def _make_summary(
    *,
    n_modules: int = 100,
    n_typable: int = 1000,
    n_typed: int = 400,
    n_any: int = 50,
    n_untyped: int = 550,
    coverage: float = 45.0,
    strict_coverage: float = 40.0,
) -> CoverageSummary:
    """Assemble a CoverageSummary with overridable fields.

    Defaults mirror the seven core metrics of ``_SAMPLE_SUMMARY`` so each
    test only needs to spell out the fields it perturbs.
    """
    summary = {
        "n_modules": n_modules,
        "n_typable": n_typable,
        "n_typed": n_typed,
        "n_any": n_any,
        "n_untyped": n_untyped,
        "coverage": coverage,
        "strict_coverage": strict_coverage,
    }
    return summary
def test_parse_summary_extracts_fields() -> None:
    """All seven core metrics survive a round trip through parse_summary."""
    result = parse_summary(_make_report(_SAMPLE_SUMMARY))

    expected = {
        "n_modules": 100,
        "n_typable": 1000,
        "n_typed": 400,
        "n_any": 50,
        "n_untyped": 550,
        "coverage": 45.0,
        "strict_coverage": 40.0,
    }
    for key, value in expected.items():
        assert result[key] == value
def test_parse_summary_handles_empty_input() -> None:
    """Empty or whitespace-only report text yields the zeroed fallback summary."""
    for raw in ("", " "):
        assert parse_summary(raw)["n_modules"] == 0
def test_parse_summary_handles_invalid_json() -> None:
    """Malformed JSON is treated the same as an absent report."""
    result = parse_summary("not json")
    assert result["n_modules"] == 0
def test_parse_summary_handles_missing_summary_key() -> None:
    """A JSON document without a 'summary' key falls back to zeros."""
    report = json.dumps({"other": 1})
    assert parse_summary(report)["n_modules"] == 0
def test_parse_summary_handles_incomplete_summary() -> None:
    """A 'summary' missing required fields is rejected wholesale, not merged:
    the partial n_modules=5 does not survive into the result."""
    report = json.dumps({"summary": {"n_modules": 5}})
    result = parse_summary(report)
    assert result["n_modules"] == 0
def test_format_summary_markdown_contains_key_metrics() -> None:
    """The rendered markdown surfaces both coverage figures and module count."""
    rendered = format_summary_markdown(_make_summary())

    for fragment in ("**Type coverage**", "45.00%", "40.00%", "| Modules | 100 |"):
        assert fragment in rendered
def test_format_comparison_markdown_shows_positive_delta() -> None:
    """Improvements show up as '+'-prefixed deltas in the comparison table."""
    base = _make_summary()
    head = _make_summary(
        n_modules=101,
        n_typable=1010,
        n_typed=420,
        n_untyped=540,
        coverage=46.53,
        strict_coverage=41.58,
    )

    rendered = format_comparison_markdown(base, head)

    assert "| Base | PR | Delta |" in rendered
    for delta in ("+1.53%", "+1.58%", "+20"):
        assert delta in rendered
def test_format_comparison_markdown_shows_negative_delta() -> None:
    """Regressions render as '-'-prefixed percentage and count deltas."""
    worse = _make_summary(
        n_typed=390,
        n_any=60,
        coverage=44.0,
        strict_coverage=39.0,
    )

    rendered = format_comparison_markdown(_make_summary(), worse)

    assert "-1.00%" in rendered
    assert "-10" in rendered
def test_format_comparison_markdown_shows_zero_delta() -> None:
    """Comparing a summary against itself yields all-zero deltas."""
    unchanged = _make_summary()
    rendered = format_comparison_markdown(unchanged, unchanged)

    assert "0.00%" in rendered
    assert "| 0 |" in rendered

View File

@ -396,10 +396,11 @@ class TestExternalDatasetServiceUsageAndBindings:
mock_db_session.scalar.return_value = 3 mock_db_session.scalar.return_value = 3
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1") in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1", "tenant-1")
assert in_use is True assert in_use is True
assert count == 3 assert count == 3
assert "tenant_id" in str(mock_db_session.scalar.call_args.args[0])
def test_external_knowledge_api_use_check_not_in_use(self, mock_db_session: MagicMock): def test_external_knowledge_api_use_check_not_in_use(self, mock_db_session: MagicMock):
""" """
@ -408,7 +409,7 @@ class TestExternalDatasetServiceUsageAndBindings:
mock_db_session.scalar.return_value = 0 mock_db_session.scalar.return_value = 0
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1") in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1", "tenant-1")
assert in_use is False assert in_use is False
assert count == 0 assert count == 0

View File

@ -6,23 +6,25 @@ MODULE = "services.plugin.plugin_permission_service"
def _patched_session(): def _patched_session():
"""Patch sessionmaker(bind=db.engine).begin() to return a mock session as context manager.""" """Patch session_factory.create_session() to return a mock session as context manager."""
session = MagicMock() session = MagicMock()
mock_sessionmaker = MagicMock() session.__enter__ = MagicMock(return_value=session)
mock_sessionmaker.return_value.begin.return_value.__enter__ = MagicMock(return_value=session) session.__exit__ = MagicMock(return_value=False)
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) session.begin.return_value.__enter__ = MagicMock(return_value=session)
patcher = patch(f"{MODULE}.sessionmaker", mock_sessionmaker) session.begin.return_value.__exit__ = MagicMock(return_value=False)
db_patcher = patch(f"{MODULE}.db") mock_factory = MagicMock()
return patcher, db_patcher, session mock_factory.create_session.return_value = session
patcher = patch(f"{MODULE}.session_factory", mock_factory)
return patcher, session
class TestGetPermission: class TestGetPermission:
def test_returns_permission_when_found(self): def test_returns_permission_when_found(self):
p1, p2, session = _patched_session() p1, session = _patched_session()
permission = MagicMock() permission = MagicMock()
session.scalar.return_value = permission session.scalar.return_value = permission
with p1, p2: with p1:
from services.plugin.plugin_permission_service import PluginPermissionService from services.plugin.plugin_permission_service import PluginPermissionService
result = PluginPermissionService.get_permission("t1") result = PluginPermissionService.get_permission("t1")
@ -30,10 +32,10 @@ class TestGetPermission:
assert result is permission assert result is permission
def test_returns_none_when_not_found(self): def test_returns_none_when_not_found(self):
p1, p2, session = _patched_session() p1, session = _patched_session()
session.scalar.return_value = None session.scalar.return_value = None
with p1, p2: with p1:
from services.plugin.plugin_permission_service import PluginPermissionService from services.plugin.plugin_permission_service import PluginPermissionService
result = PluginPermissionService.get_permission("t1") result = PluginPermissionService.get_permission("t1")
@ -43,10 +45,10 @@ class TestGetPermission:
class TestChangePermission: class TestChangePermission:
def test_creates_new_permission_when_not_exists(self): def test_creates_new_permission_when_not_exists(self):
p1, p2, session = _patched_session() p1, session = _patched_session()
session.scalar.return_value = None session.scalar.return_value = None
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginPermission") as perm_cls: with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginPermission") as perm_cls:
perm_cls.return_value = MagicMock() perm_cls.return_value = MagicMock()
from services.plugin.plugin_permission_service import PluginPermissionService from services.plugin.plugin_permission_service import PluginPermissionService
@ -54,20 +56,24 @@ class TestChangePermission:
"t1", TenantPluginPermission.InstallPermission.EVERYONE, TenantPluginPermission.DebugPermission.EVERYONE "t1", TenantPluginPermission.InstallPermission.EVERYONE, TenantPluginPermission.DebugPermission.EVERYONE
) )
assert result is True
session.begin.assert_called_once()
session.add.assert_called_once() session.add.assert_called_once()
def test_updates_existing_permission(self): def test_updates_existing_permission(self):
p1, p2, session = _patched_session() p1, session = _patched_session()
existing = MagicMock() existing = MagicMock()
session.scalar.return_value = existing session.scalar.return_value = existing
with p1, p2: with p1:
from services.plugin.plugin_permission_service import PluginPermissionService from services.plugin.plugin_permission_service import PluginPermissionService
result = PluginPermissionService.change_permission( result = PluginPermissionService.change_permission(
"t1", TenantPluginPermission.InstallPermission.ADMINS, TenantPluginPermission.DebugPermission.ADMINS "t1", TenantPluginPermission.InstallPermission.ADMINS, TenantPluginPermission.DebugPermission.ADMINS
) )
assert result is True
session.begin.assert_called_once()
assert existing.install_permission == TenantPluginPermission.InstallPermission.ADMINS assert existing.install_permission == TenantPluginPermission.InstallPermission.ADMINS
assert existing.debug_permission == TenantPluginPermission.DebugPermission.ADMINS assert existing.debug_permission == TenantPluginPermission.DebugPermission.ADMINS
session.add.assert_not_called() session.add.assert_not_called()

View File

@ -1427,16 +1427,7 @@ class TestRegisterService:
mock_tenant.name = "Test Workspace" mock_tenant.name = "Test Workspace"
mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter") mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter")
# Mock database queries - need to mock the sessionmaker query
mock_session = MagicMock()
mock_session.query.return_value.filter_by.return_value.first.return_value = None # No existing account
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
with ( with (
patch("services.account_service.sessionmaker", mock_sessionmaker),
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup, patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
): ):
mock_lookup.return_value = None mock_lookup.return_value = None
@ -1475,7 +1466,7 @@ class TestRegisterService:
status=AccountStatus.PENDING, status=AccountStatus.PENDING,
is_setup=True, is_setup=True,
) )
mock_lookup.assert_called_once_with("newuser@example.com", session=mock_session) mock_lookup.assert_called_once_with("newuser@example.com")
def test_invite_new_member_normalizes_new_account_email( def test_invite_new_member_normalizes_new_account_email(
self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies
@ -1486,13 +1477,7 @@ class TestRegisterService:
mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter") mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter")
mixed_email = "Invitee@Example.com" mixed_email = "Invitee@Example.com"
mock_session = MagicMock()
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
with ( with (
patch("services.account_service.sessionmaker", mock_sessionmaker),
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup, patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
): ):
mock_lookup.return_value = None mock_lookup.return_value = None
@ -1525,7 +1510,7 @@ class TestRegisterService:
status=AccountStatus.PENDING, status=AccountStatus.PENDING,
is_setup=True, is_setup=True,
) )
mock_lookup.assert_called_once_with(mixed_email, session=mock_session) mock_lookup.assert_called_once_with(mixed_email)
mock_check_permission.assert_called_once_with(mock_tenant, mock_inviter, None, "add") mock_check_permission.assert_called_once_with(mock_tenant, mock_inviter, None, "add")
mock_create_member.assert_called_once_with(mock_tenant, mock_new_account, "normal") mock_create_member.assert_called_once_with(mock_tenant, mock_new_account, "normal")
mock_switch_tenant.assert_called_once_with(mock_new_account, mock_tenant.id) mock_switch_tenant.assert_called_once_with(mock_new_account, mock_tenant.id)
@ -1545,16 +1530,7 @@ class TestRegisterService:
account_id="existing-user-456", email="existing@example.com", status="pending" account_id="existing-user-456", email="existing@example.com", status="pending"
) )
# Mock database queries - need to mock the sessionmaker query
mock_session = MagicMock()
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_existing_account
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
with ( with (
patch("services.account_service.sessionmaker", mock_sessionmaker),
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup, patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
): ):
mock_lookup.return_value = mock_existing_account mock_lookup.return_value = mock_existing_account
@ -1584,7 +1560,7 @@ class TestRegisterService:
mock_create_member.assert_called_once_with(mock_tenant, mock_existing_account, "normal") mock_create_member.assert_called_once_with(mock_tenant, mock_existing_account, "normal")
mock_generate_token.assert_called_once_with(mock_tenant, mock_existing_account) mock_generate_token.assert_called_once_with(mock_tenant, mock_existing_account)
mock_task_dependencies.delay.assert_called_once() mock_task_dependencies.delay.assert_called_once()
mock_lookup.assert_called_once_with("existing@example.com", session=mock_session) mock_lookup.assert_called_once_with("existing@example.com")
def test_invite_new_member_already_in_tenant(self, mock_db_dependencies, mock_redis_dependencies): def test_invite_new_member_already_in_tenant(self, mock_db_dependencies, mock_redis_dependencies):
"""Test inviting a member who is already in the tenant.""" """Test inviting a member who is already in the tenant."""

View File

@ -1069,6 +1069,33 @@ class TestDocumentServiceCreateValidation:
assert len(knowledge_config.process_rule.rules.pre_processing_rules) == 1 assert len(knowledge_config.process_rule.rules.pre_processing_rules) == 1
assert knowledge_config.process_rule.rules.pre_processing_rules[0].enabled is False assert knowledge_config.process_rule.rules.pre_processing_rules[0].enabled is False
def test_process_rule_args_validate_hierarchical_defaults_parent_mode_to_paragraph(self):
knowledge_config = KnowledgeConfig(
indexing_technique="economy",
data_source=DataSource(
info_list=InfoList(
data_source_type="upload_file",
file_info_list=FileInfo(file_ids=["file-1"]),
)
),
process_rule=ProcessRule(
mode="hierarchical",
rules=Rule(
pre_processing_rules=[
PreProcessingRule(id="remove_extra_spaces", enabled=True),
],
segmentation=Segmentation(separator="\n", max_tokens=1024),
subchunk_segmentation=Segmentation(separator="\n", max_tokens=512),
),
),
)
DocumentService.process_rule_args_validate(knowledge_config)
assert knowledge_config.process_rule is not None
assert knowledge_config.process_rule.rules is not None
assert knowledge_config.process_rule.rules.parent_mode == "paragraph"
class TestDocumentServiceSaveDocumentWithDatasetId: class TestDocumentServiceSaveDocumentWithDatasetId:
"""Unit tests for non-SQL validation branches in save_document_with_dataset_id.""" """Unit tests for non-SQL validation branches in save_document_with_dataset_id."""

View File

@ -974,26 +974,29 @@ class TestExternalDatasetServiceAPIUseCheck:
"""Test API use check when API has one binding.""" """Test API use check when API has one binding."""
# Arrange # Arrange
api_id = "api-123" api_id = "api-123"
tenant_id = "tenant-123"
mock_db.session.scalar.return_value = 1 mock_db.session.scalar.return_value = 1
# Act # Act
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id) in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id, tenant_id)
# Assert # Assert
assert in_use is True assert in_use is True
assert count == 1 assert count == 1
assert "tenant_id" in str(mock_db.session.scalar.call_args.args[0])
@patch("services.external_knowledge_service.db") @patch("services.external_knowledge_service.db")
def test_external_knowledge_api_use_check_in_use_multiple(self, mock_db, factory): def test_external_knowledge_api_use_check_in_use_multiple(self, mock_db, factory):
"""Test API use check with multiple bindings.""" """Test API use check with multiple bindings."""
# Arrange # Arrange
api_id = "api-123" api_id = "api-123"
tenant_id = "tenant-123"
mock_db.session.scalar.return_value = 10 mock_db.session.scalar.return_value = 10
# Act # Act
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id) in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id, tenant_id)
# Assert # Assert
assert in_use is True assert in_use is True
@ -1004,11 +1007,12 @@ class TestExternalDatasetServiceAPIUseCheck:
"""Test API use check when API is not in use.""" """Test API use check when API is not in use."""
# Arrange # Arrange
api_id = "api-123" api_id = "api-123"
tenant_id = "tenant-123"
mock_db.session.scalar.return_value = 0 mock_db.session.scalar.return_value = 0
# Act # Act
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id) in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id, tenant_id)
# Assert # Assert
assert in_use is False assert in_use is False

View File

@ -1,392 +0,0 @@
from unittest.mock import MagicMock, patch
import pytest
from core.ops.entities.config_entity import TracingProviderEnum
from models.model import App, TraceAppConfig
from services.ops_service import OpsService
class TestOpsService:
    @patch("services.ops_service.db")
    @patch("services.ops_service.OpsTraceManager")
    def test_get_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db):
        """Return None when no TraceAppConfig row exists for the app/provider."""
        # Arrange: the config lookup finds nothing
        mock_db.session.scalar.return_value = None

        # Act
        result = OpsService.get_tracing_app_config("app_id", "arize")

        # Assert
        assert result is None
    @patch("services.ops_service.db")
    @patch("services.ops_service.OpsTraceManager")
    def test_get_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db):
        """Return None when a trace config exists but its App row is gone."""
        # Arrange: config lookup succeeds, App lookup returns nothing
        trace_config = MagicMock(spec=TraceAppConfig)
        mock_db.session.scalar.return_value = trace_config
        mock_db.session.get.return_value = None

        # Act
        result = OpsService.get_tracing_app_config("app_id", "arize")

        # Assert
        assert result is None
    @patch("services.ops_service.db")
    @patch("services.ops_service.OpsTraceManager")
    def test_get_tracing_app_config_none_config(self, mock_ops_trace_manager, mock_db):
        """A stored row whose tracing_config is None must raise ValueError."""
        # Arrange: both lookups succeed but the stored config payload is None
        trace_config = MagicMock(spec=TraceAppConfig)
        trace_config.tracing_config = None
        app = MagicMock(spec=App)
        app.tenant_id = "tenant_id"
        mock_db.session.scalar.return_value = trace_config
        mock_db.session.get.return_value = app

        # Act & Assert
        with pytest.raises(ValueError, match="Tracing config cannot be None."):
            OpsService.get_tracing_app_config("app_id", "arize")
    @patch("services.ops_service.db")
    @patch("services.ops_service.OpsTraceManager")
    @pytest.mark.parametrize(
        ("provider", "default_url"),
        [
            ("arize", "https://app.arize.com/"),
            ("phoenix", "https://app.phoenix.arize.com/projects/"),
            ("langsmith", "https://smith.langchain.com/"),
            ("opik", "https://www.comet.com/opik/"),
            ("weave", "https://wandb.ai/"),
            ("aliyun", "https://arms.console.aliyun.com/"),
            ("tencent", "https://console.cloud.tencent.com/apm"),
            ("mlflow", "http://localhost:5000/"),
            ("databricks", "https://www.databricks.com/"),
        ],
    )
    def test_get_tracing_app_config_providers_exception(self, mock_ops_trace_manager, mock_db, provider, default_url):
        """When project URL/key resolution raises, each provider keeps its default project_url."""
        # Arrange: valid config/app, but both trace-manager helpers blow up
        trace_config = MagicMock(spec=TraceAppConfig)
        trace_config.tracing_config = {"some": "config"}
        trace_config.to_dict.return_value = {"tracing_config": {"project_url": default_url}}
        app = MagicMock(spec=App)
        app.tenant_id = "tenant_id"
        mock_db.session.scalar.return_value = trace_config
        mock_db.session.get.return_value = app
        mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
        mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {}
        mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error")
        mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error")

        # Act
        result = OpsService.get_tracing_app_config("app_id", provider)

        # Assert: the failure is swallowed and the default URL is preserved
        assert result["tracing_config"]["project_url"] == default_url
    @patch("services.ops_service.db")
    @patch("services.ops_service.OpsTraceManager")
    @pytest.mark.parametrize(
        "provider", ["arize", "phoenix", "langsmith", "opik", "weave", "aliyun", "tencent", "mlflow", "databricks"]
    )
    def test_get_tracing_app_config_providers_success(self, mock_ops_trace_manager, mock_db, provider):
        """When URL resolution succeeds, the resolved project_url is returned for every provider."""
        # Arrange: valid config/app and a trace manager that resolves a URL
        trace_config = MagicMock(spec=TraceAppConfig)
        trace_config.tracing_config = {"some": "config"}
        trace_config.to_dict.return_value = {"tracing_config": {"project_url": "success_url"}}
        app = MagicMock(spec=App)
        app.tenant_id = "tenant_id"
        mock_db.session.scalar.return_value = trace_config
        mock_db.session.get.return_value = app
        mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
        mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {}
        mock_ops_trace_manager.get_trace_config_project_url.return_value = "success_url"

        # Act
        result = OpsService.get_tracing_app_config("app_id", provider)

        # Assert
        assert result["tracing_config"]["project_url"] == "success_url"
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_get_tracing_app_config_langfuse_success(self, mock_ops_trace_manager, mock_db):
# Arrange
trace_config = MagicMock(spec=TraceAppConfig)
trace_config.tracing_config = {"some": "config"}
trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/project/key"}}
app = MagicMock(spec=App)
app.tenant_id = "tenant_id"
mock_db.session.scalar.return_value = trace_config
mock_db.session.get.return_value = app
mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"}
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"}
mock_ops_trace_manager.get_trace_config_project_key.return_value = "key"
# Act
result = OpsService.get_tracing_app_config("app_id", "langfuse")
# Assert
assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/project/key"
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_get_tracing_app_config_langfuse_exception(self, mock_ops_trace_manager, mock_db):
# Arrange
trace_config = MagicMock(spec=TraceAppConfig)
trace_config.tracing_config = {"some": "config"}
trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/"}}
app = MagicMock(spec=App)
app.tenant_id = "tenant_id"
mock_db.session.scalar.return_value = trace_config
mock_db.session.get.return_value = app
mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"}
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"}
mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error")
# Act
result = OpsService.get_tracing_app_config("app_id", "langfuse")
# Assert
assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/"
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_create_tracing_app_config_invalid_provider(self, mock_ops_trace_manager, mock_db):
# Act
result = OpsService.create_tracing_app_config("app_id", "invalid_provider", {})
# Assert
assert result == {"error": "Invalid tracing provider: invalid_provider"}
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_create_tracing_app_config_invalid_credentials(self, mock_ops_trace_manager, mock_db):
# Arrange
provider = TracingProviderEnum.LANGFUSE
mock_ops_trace_manager.check_trace_config_is_effective.return_value = False
# Act
result = OpsService.create_tracing_app_config("app_id", provider, {"public_key": "p", "secret_key": "s"})
# Assert
assert result == {"error": "Invalid Credentials"}
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
@pytest.mark.parametrize(
("provider", "config"),
[
(TracingProviderEnum.ARIZE, {}),
(TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}),
(TracingProviderEnum.LANGSMITH, {"api_key": "k", "project": "p"}),
(TracingProviderEnum.ALIYUN, {"license_key": "k", "endpoint": "https://aliyun.com"}),
],
)
def test_create_tracing_app_config_project_url_exception(self, mock_ops_trace_manager, mock_db, provider, config):
# Arrange
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error")
mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error")
mock_db.session.scalar.return_value = MagicMock(spec=TraceAppConfig)
# Act
result = OpsService.create_tracing_app_config("app_id", provider, config)
# Assert
assert result is None
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_create_tracing_app_config_langfuse_success(self, mock_ops_trace_manager, mock_db):
# Arrange
provider = TracingProviderEnum.LANGFUSE
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
mock_ops_trace_manager.get_trace_config_project_key.return_value = "key"
app = MagicMock(spec=App)
app.tenant_id = "tenant_id"
mock_db.session.scalar.return_value = None
mock_db.session.get.return_value = app
mock_ops_trace_manager.encrypt_tracing_config.return_value = {}
# Act
result = OpsService.create_tracing_app_config(
"app_id", provider, {"public_key": "p", "secret_key": "s", "host": "https://api.langfuse.com"}
)
# Assert
assert result == {"result": "success"}
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_create_tracing_app_config_already_exists(self, mock_ops_trace_manager, mock_db):
# Arrange
provider = TracingProviderEnum.ARIZE
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
mock_db.session.scalar.return_value = MagicMock(spec=TraceAppConfig)
# Act
result = OpsService.create_tracing_app_config("app_id", provider, {})
# Assert
assert result is None
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_create_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db):
# Arrange
provider = TracingProviderEnum.ARIZE
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
mock_db.session.scalar.return_value = None
mock_db.session.get.return_value = None
# Act
result = OpsService.create_tracing_app_config("app_id", provider, {})
# Assert
assert result is None
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_create_tracing_app_config_with_empty_other_keys(self, mock_ops_trace_manager, mock_db):
# Arrange
provider = TracingProviderEnum.ARIZE
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
app = MagicMock(spec=App)
app.tenant_id = "tenant_id"
mock_db.session.scalar.return_value = None
mock_db.session.get.return_value = app
mock_ops_trace_manager.encrypt_tracing_config.return_value = {}
# Act
# 'project' is in other_keys for Arize
# provide an empty string for the project in the tracing_config
# create_tracing_app_config will replace it with the default from the model
result = OpsService.create_tracing_app_config("app_id", provider, {"project": ""})
# Assert
assert result == {"result": "success"}
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_create_tracing_app_config_success(self, mock_ops_trace_manager, mock_db):
# Arrange
provider = TracingProviderEnum.ARIZE
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
mock_ops_trace_manager.get_trace_config_project_url.return_value = "http://project_url"
app = MagicMock(spec=App)
app.tenant_id = "tenant_id"
mock_db.session.scalar.return_value = None
mock_db.session.get.return_value = app
mock_ops_trace_manager.encrypt_tracing_config.return_value = {"encrypted": "config"}
# Act
result = OpsService.create_tracing_app_config("app_id", provider, {})
# Assert
assert result == {"result": "success"}
mock_db.session.add.assert_called()
mock_db.session.commit.assert_called()
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_update_tracing_app_config_invalid_provider(self, mock_ops_trace_manager, mock_db):
# Act & Assert
with pytest.raises(ValueError, match="Invalid tracing provider: invalid_provider"):
OpsService.update_tracing_app_config("app_id", "invalid_provider", {})
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_update_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db):
# Arrange
provider = TracingProviderEnum.ARIZE
mock_db.session.scalar.return_value = None
# Act
result = OpsService.update_tracing_app_config("app_id", provider, {})
# Assert
assert result is None
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_update_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db):
# Arrange
provider = TracingProviderEnum.ARIZE
current_config = MagicMock(spec=TraceAppConfig)
mock_db.session.scalar.return_value = current_config
mock_db.session.get.return_value = None
# Act
result = OpsService.update_tracing_app_config("app_id", provider, {})
# Assert
assert result is None
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_update_tracing_app_config_invalid_credentials(self, mock_ops_trace_manager, mock_db):
# Arrange
provider = TracingProviderEnum.ARIZE
current_config = MagicMock(spec=TraceAppConfig)
app = MagicMock(spec=App)
app.tenant_id = "tenant_id"
mock_db.session.scalar.return_value = current_config
mock_db.session.get.return_value = app
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
mock_ops_trace_manager.check_trace_config_is_effective.return_value = False
# Act & Assert
with pytest.raises(ValueError, match="Invalid Credentials"):
OpsService.update_tracing_app_config("app_id", provider, {})
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_update_tracing_app_config_success(self, mock_ops_trace_manager, mock_db):
# Arrange
provider = TracingProviderEnum.ARIZE
current_config = MagicMock(spec=TraceAppConfig)
current_config.to_dict.return_value = {"some": "data"}
app = MagicMock(spec=App)
app.tenant_id = "tenant_id"
mock_db.session.scalar.return_value = current_config
mock_db.session.get.return_value = app
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
# Act
result = OpsService.update_tracing_app_config("app_id", provider, {})
# Assert
assert result == {"some": "data"}
mock_db.session.commit.assert_called_once()
@patch("services.ops_service.db")
def test_delete_tracing_app_config_no_config(self, mock_db):
# Arrange
mock_db.session.scalar.return_value = None
# Act
result = OpsService.delete_tracing_app_config("app_id", "arize")
# Assert
assert result is None
@patch("services.ops_service.db")
def test_delete_tracing_app_config_success(self, mock_db):
# Arrange
trace_config = MagicMock(spec=TraceAppConfig)
mock_db.session.scalar.return_value = trace_config
# Act
result = OpsService.delete_tracing_app_config("app_id", "arize")
# Assert
assert result is True
mock_db.session.delete.assert_called_with(trace_config)
mock_db.session.commit.assert_called_once()

View File

@ -2,7 +2,7 @@
import type { Area } from 'react-easy-crop' import type { Area } from 'react-easy-crop'
import type { OnImageInput } from '@/app/components/base/app-icon-picker/ImageInput' import type { OnImageInput } from '@/app/components/base/app-icon-picker/ImageInput'
import type { AvatarProps } from '@/app/components/base/avatar' import type { AvatarProps } from '@/app/components/base/ui/avatar'
import type { ImageFile } from '@/types/app' import type { ImageFile } from '@/types/app'
import { RiDeleteBin5Line, RiPencilLine } from '@remixicon/react' import { RiDeleteBin5Line, RiPencilLine } from '@remixicon/react'
import * as React from 'react' import * as React from 'react'
@ -10,10 +10,10 @@ import { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import ImageInput from '@/app/components/base/app-icon-picker/ImageInput' import ImageInput from '@/app/components/base/app-icon-picker/ImageInput'
import getCroppedImg from '@/app/components/base/app-icon-picker/utils' import getCroppedImg from '@/app/components/base/app-icon-picker/utils'
import { Avatar } from '@/app/components/base/avatar'
import Button from '@/app/components/base/button' import Button from '@/app/components/base/button'
import Divider from '@/app/components/base/divider' import Divider from '@/app/components/base/divider'
import { useLocalFileUploader } from '@/app/components/base/image-uploader/hooks' import { useLocalFileUploader } from '@/app/components/base/image-uploader/hooks'
import { Avatar } from '@/app/components/base/ui/avatar'
import { Dialog, DialogContent } from '@/app/components/base/ui/dialog' import { Dialog, DialogContent } from '@/app/components/base/ui/dialog'
import { toast } from '@/app/components/base/ui/toast' import { toast } from '@/app/components/base/ui/toast'
import { DISABLE_UPLOAD_IMAGE_AS_ICON } from '@/config' import { DISABLE_UPLOAD_IMAGE_AS_ICON } from '@/config'

View File

@ -6,9 +6,9 @@ import {
import { Fragment } from 'react' import { Fragment } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import { resetUser } from '@/app/components/base/amplitude/utils' import { resetUser } from '@/app/components/base/amplitude/utils'
import { Avatar } from '@/app/components/base/avatar'
import { LogOut01 } from '@/app/components/base/icons/src/vender/line/general' import { LogOut01 } from '@/app/components/base/icons/src/vender/line/general'
import PremiumBadge from '@/app/components/base/premium-badge' import PremiumBadge from '@/app/components/base/premium-badge'
import { Avatar } from '@/app/components/base/ui/avatar'
import { useProviderContext } from '@/context/provider-context' import { useProviderContext } from '@/context/provider-context'
import { useRouter } from '@/next/navigation' import { useRouter } from '@/next/navigation'
import { useLogout, useUserProfile } from '@/service/use-common' import { useLogout, useUserProfile } from '@/service/use-common'

View File

@ -10,9 +10,9 @@ import {
import * as React from 'react' import * as React from 'react'
import { useEffect, useRef } from 'react' import { useEffect, useRef } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import { Avatar } from '@/app/components/base/avatar'
import Button from '@/app/components/base/button' import Button from '@/app/components/base/button'
import Loading from '@/app/components/base/loading' import Loading from '@/app/components/base/loading'
import { Avatar } from '@/app/components/base/ui/avatar'
import { toast } from '@/app/components/base/ui/toast' import { toast } from '@/app/components/base/ui/toast'
import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks' import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { setPostLoginRedirect } from '@/app/signin/utils/post-login-redirect' import { setPostLoginRedirect } from '@/app/signin/utils/post-login-redirect'

View File

@ -5,12 +5,12 @@ import { RiAddCircleFill, RiArrowRightSLine, RiOrganizationChart } from '@remixi
import { useDebounce } from 'ahooks' import { useDebounce } from 'ahooks'
import { useCallback, useEffect, useRef, useState } from 'react' import { useCallback, useEffect, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import { Avatar } from '@/app/components/base/ui/avatar'
import { useSelector } from '@/context/app-context' import { useSelector } from '@/context/app-context'
import { SubjectType } from '@/models/access-control' import { SubjectType } from '@/models/access-control'
import { useSearchForWhiteListCandidates } from '@/service/access-control' import { useSearchForWhiteListCandidates } from '@/service/access-control'
import { cn } from '@/utils/classnames' import { cn } from '@/utils/classnames'
import useAccessControlStore from '../../../../context/access-control-store' import useAccessControlStore from '../../../../context/access-control-store'
import { Avatar } from '../../base/avatar'
import Button from '../../base/button' import Button from '../../base/button'
import Checkbox from '../../base/checkbox' import Checkbox from '../../base/checkbox'
import Input from '../../base/input' import Input from '../../base/input'

View File

@ -3,10 +3,10 @@ import type { AccessControlAccount, AccessControlGroup } from '@/models/access-c
import { RiAlertFill, RiCloseCircleFill, RiLockLine, RiOrganizationChart } from '@remixicon/react' import { RiAlertFill, RiCloseCircleFill, RiLockLine, RiOrganizationChart } from '@remixicon/react'
import { useCallback, useEffect } from 'react' import { useCallback, useEffect } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import { Avatar } from '@/app/components/base/ui/avatar'
import { AccessMode } from '@/models/access-control' import { AccessMode } from '@/models/access-control'
import { useAppWhiteListSubjects } from '@/service/access-control' import { useAppWhiteListSubjects } from '@/service/access-control'
import useAccessControlStore from '../../../../context/access-control-store' import useAccessControlStore from '../../../../context/access-control-store'
import { Avatar } from '../../base/avatar'
import Loading from '../../base/loading' import Loading from '../../base/loading'
import Tooltip from '../../base/tooltip' import Tooltip from '../../base/tooltip'
import AddMemberOrGroupDialog from './add-member-or-group-pop' import AddMemberOrGroupDialog from './add-member-or-group-pop'

View File

@ -90,7 +90,7 @@ vi.mock('@/app/components/base/chat/chat', () => ({
}, },
})) }))
vi.mock('@/app/components/base/avatar', () => ({ vi.mock('@/app/components/base/ui/avatar', () => ({
Avatar: ({ name }: { name: string }) => <div data-testid="avatar">{name}</div>, Avatar: ({ name }: { name: string }) => <div data-testid="avatar">{name}</div>,
})) }))

View File

@ -7,11 +7,11 @@ import {
useCallback, useCallback,
useMemo, useMemo,
} from 'react' } from 'react'
import { Avatar } from '@/app/components/base/avatar'
import Chat from '@/app/components/base/chat/chat' import Chat from '@/app/components/base/chat/chat'
import { useChat } from '@/app/components/base/chat/chat/hooks' import { useChat } from '@/app/components/base/chat/chat/hooks'
import { getLastAnswer } from '@/app/components/base/chat/utils' import { getLastAnswer } from '@/app/components/base/chat/utils'
import { useFeatures } from '@/app/components/base/features/hooks' import { useFeatures } from '@/app/components/base/features/hooks'
import { Avatar } from '@/app/components/base/ui/avatar'
import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { useAppContext } from '@/context/app-context' import { useAppContext } from '@/context/app-context'
import { useDebugConfigurationContext } from '@/context/debug-configuration' import { useDebugConfigurationContext } from '@/context/debug-configuration'

View File

@ -3,11 +3,11 @@ import type { ChatConfig, ChatItem, OnSend } from '@/app/components/base/chat/ty
import type { FileEntity } from '@/app/components/base/file-uploader/types' import type { FileEntity } from '@/app/components/base/file-uploader/types'
import { memo, useCallback, useImperativeHandle, useMemo } from 'react' import { memo, useCallback, useImperativeHandle, useMemo } from 'react'
import { useStore as useAppStore } from '@/app/components/app/store' import { useStore as useAppStore } from '@/app/components/app/store'
import { Avatar } from '@/app/components/base/avatar'
import Chat from '@/app/components/base/chat/chat' import Chat from '@/app/components/base/chat/chat'
import { useChat } from '@/app/components/base/chat/chat/hooks' import { useChat } from '@/app/components/base/chat/chat/hooks'
import { getLastAnswer, isValidGeneratedAnswer } from '@/app/components/base/chat/utils' import { getLastAnswer, isValidGeneratedAnswer } from '@/app/components/base/chat/utils'
import { useFeatures } from '@/app/components/base/features/hooks' import { useFeatures } from '@/app/components/base/features/hooks'
import { Avatar } from '@/app/components/base/ui/avatar'
import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { useAppContext } from '@/context/app-context' import { useAppContext } from '@/context/app-context'
import { useDebugConfigurationContext } from '@/context/debug-configuration' import { useDebugConfigurationContext } from '@/context/debug-configuration'

View File

@ -11,6 +11,7 @@ import AppIcon from '@/app/components/base/app-icon'
import InputsForm from '@/app/components/base/chat/chat-with-history/inputs-form' import InputsForm from '@/app/components/base/chat/chat-with-history/inputs-form'
import SuggestedQuestions from '@/app/components/base/chat/chat/answer/suggested-questions' import SuggestedQuestions from '@/app/components/base/chat/chat/answer/suggested-questions'
import { Markdown } from '@/app/components/base/markdown' import { Markdown } from '@/app/components/base/markdown'
import { Avatar } from '@/app/components/base/ui/avatar'
import { InputVarType } from '@/app/components/workflow/types' import { InputVarType } from '@/app/components/workflow/types'
import { import {
AppSourceType, AppSourceType,
@ -23,7 +24,6 @@ import { submitHumanInputForm as submitHumanInputFormService } from '@/service/w
import { TransferMethod } from '@/types/app' import { TransferMethod } from '@/types/app'
import { cn } from '@/utils/classnames' import { cn } from '@/utils/classnames'
import { formatBooleanInputs } from '@/utils/model-config' import { formatBooleanInputs } from '@/utils/model-config'
import { Avatar } from '../../avatar'
import Chat from '../chat' import Chat from '../chat'
import { useChat } from '../chat/hooks' import { useChat } from '../chat/hooks'
import { getLastAnswer, isValidGeneratedAnswer } from '../utils' import { getLastAnswer, isValidGeneratedAnswer } from '../utils'

View File

@ -12,6 +12,7 @@ import SuggestedQuestions from '@/app/components/base/chat/chat/answer/suggested
import InputsForm from '@/app/components/base/chat/embedded-chatbot/inputs-form' import InputsForm from '@/app/components/base/chat/embedded-chatbot/inputs-form'
import LogoAvatar from '@/app/components/base/logo/logo-embedded-chat-avatar' import LogoAvatar from '@/app/components/base/logo/logo-embedded-chat-avatar'
import { Markdown } from '@/app/components/base/markdown' import { Markdown } from '@/app/components/base/markdown'
import { Avatar } from '@/app/components/base/ui/avatar'
import { InputVarType } from '@/app/components/workflow/types' import { InputVarType } from '@/app/components/workflow/types'
import { import {
AppSourceType, AppSourceType,
@ -23,7 +24,6 @@ import {
import { submitHumanInputForm as submitHumanInputFormService } from '@/service/workflow' import { submitHumanInputForm as submitHumanInputFormService } from '@/service/workflow'
import { TransferMethod } from '@/types/app' import { TransferMethod } from '@/types/app'
import { cn } from '@/utils/classnames' import { cn } from '@/utils/classnames'
import { Avatar } from '../../avatar'
import Chat from '../chat' import Chat from '../chat'
import { useChat } from '../chat/hooks' import { useChat } from '../chat/hooks'
import { getLastAnswer, isValidGeneratedAnswer } from '../utils' import { getLastAnswer, isValidGeneratedAnswer } from '../utils'

View File

@ -1,5 +1,5 @@
import { render, screen } from '@testing-library/react' import { render, screen } from '@testing-library/react'
import { Avatar } from '../index' import { Avatar } from '..'
describe('Avatar', () => { describe('Avatar', () => {
describe('Rendering', () => { describe('Rendering', () => {

View File

@ -53,8 +53,8 @@ function AvatarRoot({
return ( return (
<BaseAvatar.Root <BaseAvatar.Root
className={cn( className={cn(
'relative inline-flex shrink-0 select-none items-center justify-center overflow-hidden rounded-full bg-primary-600', 'relative inline-flex shrink-0 items-center justify-center overflow-hidden rounded-full bg-primary-600 select-none',
isAvatarPresetSize(size) && avatarSizeClasses[size].root, avatarSizeClasses[size].root,
className, className,
)} )}
style={resolvedStyle} style={resolvedStyle}
@ -104,7 +104,7 @@ function AvatarImage({
}: AvatarImageProps) { }: AvatarImageProps) {
return ( return (
<BaseAvatar.Image <BaseAvatar.Image
className={cn('absolute inset-0 size-full object-cover', className)} className={cn('inset-0 absolute size-full object-cover', className)}
{...props} {...props}
/> />
) )

View File

@ -4,13 +4,13 @@ import { useDebounceFn } from 'ahooks'
import * as React from 'react' import * as React from 'react'
import { useCallback, useMemo, useState } from 'react' import { useCallback, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import { Avatar } from '@/app/components/base/avatar'
import Input from '@/app/components/base/input' import Input from '@/app/components/base/input'
import { import {
PortalToFollowElem, PortalToFollowElem,
PortalToFollowElemContent, PortalToFollowElemContent,
PortalToFollowElemTrigger, PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem' } from '@/app/components/base/portal-to-follow-elem'
import { Avatar } from '@/app/components/base/ui/avatar'
import { useSelector as useAppContextWithSelector } from '@/context/app-context' import { useSelector as useAppContextWithSelector } from '@/context/app-context'
import { DatasetPermission } from '@/models/datasets' import { DatasetPermission } from '@/models/datasets'
import { cn } from '@/utils/classnames' import { cn } from '@/utils/classnames'

View File

@ -4,9 +4,9 @@ import type { MouseEventHandler, ReactNode } from 'react'
import { useState } from 'react' import { useState } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import { resetUser } from '@/app/components/base/amplitude/utils' import { resetUser } from '@/app/components/base/amplitude/utils'
import { Avatar } from '@/app/components/base/avatar'
import PremiumBadge from '@/app/components/base/premium-badge' import PremiumBadge from '@/app/components/base/premium-badge'
import ThemeSwitcher from '@/app/components/base/theme-switcher' import ThemeSwitcher from '@/app/components/base/theme-switcher'
import { Avatar } from '@/app/components/base/ui/avatar'
import { DropdownMenu, DropdownMenuContent, DropdownMenuGroup, DropdownMenuItem, DropdownMenuLinkItem, DropdownMenuSeparator, DropdownMenuTrigger } from '@/app/components/base/ui/dropdown-menu' import { DropdownMenu, DropdownMenuContent, DropdownMenuGroup, DropdownMenuItem, DropdownMenuLinkItem, DropdownMenuSeparator, DropdownMenuTrigger } from '@/app/components/base/ui/dropdown-menu'
import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants'
import { IS_CLOUD_EDITION } from '@/config' import { IS_CLOUD_EDITION } from '@/config'

View File

@ -2,7 +2,7 @@
import type { InvitationResult } from '@/models/common' import type { InvitationResult } from '@/models/common'
import { useState } from 'react' import { useState } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import { Avatar } from '@/app/components/base/avatar' import { Avatar } from '@/app/components/base/ui/avatar'
import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip' import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip'
import { NUM_INFINITE } from '@/app/components/billing/config' import { NUM_INFINITE } from '@/app/components/billing/config'
import { Plan } from '@/app/components/billing/type' import { Plan } from '@/app/components/billing/type'

View File

@ -3,9 +3,9 @@ import type { FC } from 'react'
import * as React from 'react' import * as React from 'react'
import { useMemo, useState } from 'react' import { useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import { Avatar } from '@/app/components/base/avatar'
import Input from '@/app/components/base/input' import Input from '@/app/components/base/input'
import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '@/app/components/base/portal-to-follow-elem' import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '@/app/components/base/portal-to-follow-elem'
import { Avatar } from '@/app/components/base/ui/avatar'
import { useMembers } from '@/service/use-common' import { useMembers } from '@/service/use-common'
import { cn } from '@/utils/classnames' import { cn } from '@/utils/classnames'

View File

@ -0,0 +1,46 @@
// Tests for the workflow BlockIcon component family: the default node icon
// container and the compact VarBlockIcon variant.
import { render } from '@testing-library/react'
import { API_PREFIX } from '@/config'
import BlockIcon, { VarBlockIcon } from '../block-icon'
import { BlockEnum } from '../types'
describe('BlockIcon', () => {
it('renders the default workflow icon container for regular nodes', () => {
const { container } = render(<BlockIcon type={BlockEnum.Start} size="xs" className="extra-class" />)
const iconContainer = container.firstElementChild
// The "xs" size maps to a 4x4 container; extra className props are merged in.
expect(iconContainer).toHaveClass('w-4', 'h-4', 'bg-util-colors-blue-brand-blue-brand-500', 'extra-class')
// The node type icon itself is rendered as an inline SVG.
expect(iconContainer?.querySelector('svg')).toBeInTheDocument()
})
it('normalizes protected plugin icon urls for tool-like nodes', () => {
const { container } = render(
<BlockIcon
type={BlockEnum.Tool}
toolIcon="/foo/workspaces/current/plugin/icon/plugin-tool.png"
/>,
)
const iconContainer = container.firstElementChild as HTMLElement
const backgroundIcon = iconContainer.querySelector('div') as HTMLElement
// Tool nodes with a custom icon drop the default colored background.
expect(iconContainer).not.toHaveClass('bg-util-colors-blue-blue-500')
// Relative plugin icon paths are rewritten to be served from the API host.
expect(backgroundIcon.style.backgroundImage).toContain(
`${API_PREFIX}/workspaces/current/plugin/icon/plugin-tool.png`,
)
})
})
describe('VarBlockIcon', () => {
it('renders the compact icon variant without the default container wrapper', () => {
const { container } = render(
<VarBlockIcon
type={BlockEnum.Answer}
className="custom-var-icon"
/>,
)
// Only the bare SVG icon with the caller's class is rendered —
// no sized/colored container element around it.
expect(container.querySelector('.custom-var-icon')).toBeInTheDocument()
expect(container.querySelector('svg')).toBeInTheDocument()
expect(container.querySelector('.bg-util-colors-warning-warning-500')).not.toBeInTheDocument()
})
})

View File

@ -0,0 +1,39 @@
// Tests that WorkflowContextProvider exposes the workflow zustand store to
// descendants and that the store instance is stable across rerenders.
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { WorkflowContextProvider } from '../context'
import { useStore, useWorkflowStore } from '../store'
// Helper component: reads one piece of store state and toggles it on click,
// so the button label reflects the current store value.
const StoreConsumer = () => {
const showSingleRunPanel = useStore(s => s.showSingleRunPanel)
const store = useWorkflowStore()
return (
<button onClick={() => store.getState().setShowSingleRunPanel(!showSingleRunPanel)}>
{showSingleRunPanel ? 'open' : 'closed'}
</button>
)
}
describe('WorkflowContextProvider', () => {
it('provides the workflow store to descendants and keeps the same store across rerenders', async () => {
const user = userEvent.setup()
const { rerender } = render(
<WorkflowContextProvider>
<StoreConsumer />
</WorkflowContextProvider>,
)
// Initial store state: the single-run panel is hidden.
expect(screen.getByRole('button', { name: 'closed' })).toBeInTheDocument()
await user.click(screen.getByRole('button', { name: 'closed' }))
expect(screen.getByRole('button', { name: 'open' })).toBeInTheDocument()
rerender(
<WorkflowContextProvider>
<StoreConsumer />
</WorkflowContextProvider>,
)
// The toggled state survives the rerender, proving the provider did not
// create a fresh store instance.
expect(screen.getByRole('button', { name: 'open' })).toBeInTheDocument()
})
})

View File

@ -0,0 +1,67 @@
import type { Edge, Node } from '../types'
import { render, screen } from '@testing-library/react'
import { useStoreApi } from 'reactflow'
import { useDatasetsDetailStore } from '../datasets-detail-store/store'
import WorkflowWithDefaultContext from '../index'
import { BlockEnum } from '../types'
import { useWorkflowHistoryStore } from '../workflow-history-store'
// Minimal graph fixture: a single Start node, enough for the providers to seed state.
const nodes: Node[] = [
  {
    id: 'node-start',
    type: 'custom',
    position: { x: 0, y: 0 },
    data: {
      title: 'Start',
      desc: '',
      type: BlockEnum.Start,
    },
  },
]
// One Start -> End edge. Its target ('node-end') is not present in `nodes`;
// the providers under test do not validate endpoint existence, so this is fine.
const edges: Edge[] = [
  {
    id: 'edge-1',
    source: 'node-start',
    target: 'node-end',
    sourceHandle: null,
    targetHandle: null,
    type: 'custom',
    data: {
      sourceType: BlockEnum.Start,
      targetType: BlockEnum.End,
    },
  },
]
// Test helper: surfaces one observable value from each provider so a single
// getByText assertion can confirm that all of them are wired up.
const ContextConsumer = () => {
  const { store, shortcutsEnabled } = useWorkflowHistoryStore()
  const datasetCount = useDatasetsDetailStore(state => Object.keys(state.datasetsDetail).length)
  const reactFlowStore = useStoreApi()
  const historyNodeCount = store.getState().nodes.length
  return (
    <div>
      {`history:${historyNodeCount}`}
      {` shortcuts:${String(shortcutsEnabled)}`}
      {` datasets:${datasetCount}`}
      {` reactflow:${String(!!reactFlowStore)}`}
    </div>
  )
}
describe('WorkflowWithDefaultContext', () => {
  it('wires the ReactFlow, workflow history, and datasets detail providers around its children', () => {
    render(
      <WorkflowWithDefaultContext nodes={nodes} edges={edges}>
        <ContextConsumer />
      </WorkflowWithDefaultContext>,
    )
    // One history node from the fixture, shortcuts on by default, no datasets
    // loaded yet, and a live ReactFlow store — all providers are in place.
    const providerSummary = 'history:1 shortcuts:true datasets:0 reactflow:true'
    expect(screen.getByText(providerSummary)).toBeInTheDocument()
  })
})

View File

@ -0,0 +1,51 @@
import { render, screen } from '@testing-library/react'
import ShortcutsName from '../shortcuts-name'
describe('ShortcutsName', () => {
  // Capture the environment's original `navigator` property descriptor (in
  // jsdom/browsers it is an accessor/non-writable property). Restoring the
  // descriptor — rather than re-defining `navigator` as a writable data
  // property holding the old value, as the previous version did — leaves the
  // global exactly as we found it once this suite finishes, instead of
  // leaking a reconfigured property into later test files.
  const originalNavigatorDescriptor = Object.getOwnPropertyDescriptor(globalThis, 'navigator')

  // Replace `navigator` with a minimal stub exposing only `userAgent`, which
  // is all ShortcutsName inspects to decide between mac and non-mac labels.
  const stubUserAgent = (userAgent: string) => {
    Object.defineProperty(globalThis, 'navigator', {
      value: { userAgent },
      writable: true,
      configurable: true,
    })
  }

  afterEach(() => {
    if (originalNavigatorDescriptor)
      Object.defineProperty(globalThis, 'navigator', originalNavigatorDescriptor)
    else
      delete (globalThis as { navigator?: unknown }).navigator
  })

  it('renders mac-friendly key labels and style variants', () => {
    stubUserAgent('Macintosh')
    const { container } = render(
      <ShortcutsName
        keys={['ctrl', 'shift', 's']}
        bgColor="white"
        textColor="secondary"
      />,
    )
    // On mac, modifier names are mapped to their symbols; plain keys stay as-is.
    expect(screen.getByText('⌘')).toBeInTheDocument()
    expect(screen.getByText('⇧')).toBeInTheDocument()
    expect(screen.getByText('s')).toBeInTheDocument()
    // bgColor/textColor props translate into the expected style classes.
    expect(container.querySelector('.system-kbd')).toHaveClass(
      'bg-components-kbd-bg-white',
      'text-text-tertiary',
    )
  })

  it('keeps raw key names on non-mac systems', () => {
    stubUserAgent('Windows NT')
    render(<ShortcutsName keys={['ctrl', 'alt']} />)
    expect(screen.getByText('ctrl')).toBeInTheDocument()
    expect(screen.getByText('alt')).toBeInTheDocument()
  })
})

View File

@ -0,0 +1,97 @@
import type { Edge, Node } from '../types'
import type { WorkflowHistoryState } from '../workflow-history-store'
import { render, renderHook, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { BlockEnum } from '../types'
import { useWorkflowHistoryStore, WorkflowHistoryProvider } from '../workflow-history-store'
// Fixture node deliberately flagged as selected — both on the node itself and
// inside its data — so the sanitization test below can observe the provider
// clearing those flags when history state is replaced.
const nodes: Node[] = [
  {
    id: 'node-1',
    type: 'custom',
    position: { x: 0, y: 0 },
    data: {
      title: 'Start',
      desc: '',
      type: BlockEnum.Start,
      selected: true,
    },
    selected: true,
  },
]
// Fixture edge, also marked selected for the same sanitization check. Its
// target ('node-2') is not present in `nodes`; the store does not validate it.
const edges: Edge[] = [
  {
    id: 'edge-1',
    source: 'node-1',
    target: 'node-2',
    sourceHandle: null,
    targetHandle: null,
    type: 'custom',
    selected: true,
    data: {
      sourceType: BlockEnum.Start,
      targetType: BlockEnum.End,
    },
  },
]
const HistoryConsumer = () => {
const { store, shortcutsEnabled, setShortcutsEnabled } = useWorkflowHistoryStore()
return (
<button onClick={() => setShortcutsEnabled(!shortcutsEnabled)}>
{`nodes:${store.getState().nodes.length} shortcuts:${String(shortcutsEnabled)}`}
</button>
)
}
describe('WorkflowHistoryProvider', () => {
  it('provides workflow history state and shortcut toggles', async () => {
    const user = userEvent.setup()
    render(
      <WorkflowHistoryProvider
        nodes={nodes}
        edges={edges}
      >
        <HistoryConsumer />
      </WorkflowHistoryProvider>,
    )
    // Initial state: seeded with the single fixture node, shortcuts enabled.
    expect(screen.getByRole('button', { name: 'nodes:1 shortcuts:true' })).toBeInTheDocument()
    // Clicking toggles shortcutsEnabled through the context setter.
    await user.click(screen.getByRole('button', { name: 'nodes:1 shortcuts:true' }))
    expect(screen.getByRole('button', { name: 'nodes:1 shortcuts:false' })).toBeInTheDocument()
  })
  it('sanitizes selected flags when history state is replaced through the exposed store api', () => {
    const wrapper = ({ children }: { children: React.ReactNode }) => (
      <WorkflowHistoryProvider
        nodes={nodes}
        edges={edges}
      >
        {children}
      </WorkflowHistoryProvider>
    )
    const { result } = renderHook(() => useWorkflowHistoryStore(), { wrapper })
    const nextState: WorkflowHistoryState = {
      workflowHistoryEvent: undefined,
      workflowHistoryEventMeta: undefined,
      nodes,
      edges,
    }
    // The fixtures carry selected=true everywhere; the store's setState path
    // is expected to strip selection flags before persisting history state.
    result.current.store.setState(nextState)
    expect(result.current.store.getState().nodes[0].data.selected).toBe(false)
    expect(result.current.store.getState().edges[0].selected).toBe(false)
  })
  it('throws when consumed outside the provider', () => {
    // Without the wrapper the hook must fail fast with a clear error message.
    expect(() => renderHook(() => useWorkflowHistoryStore())).toThrow(
      'useWorkflowHistoryStoreApi must be used within a WorkflowHistoryProvider',
    )
  })
})

View File

@ -0,0 +1,140 @@
import { render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { useMarketplacePlugins } from '@/app/components/plugins/marketplace/hooks'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { useGetLanguage } from '@/context/i18n'
import useTheme from '@/hooks/use-theme'
import { Theme } from '@/types/app'
import AllTools from '../all-tools'
import { createGlobalPublicStoreState, createToolProvider } from './factories'
// Mock the app-level contexts and hooks AllTools consumes. vi.mock calls are
// hoisted by vitest, so they take effect before the component module loads.
vi.mock('@/context/global-public-context', () => ({
  useGlobalPublicStore: vi.fn(),
}))
vi.mock('@/context/i18n', () => ({
  useGetLanguage: vi.fn(),
}))
vi.mock('@/hooks/use-theme', () => ({
  default: vi.fn(),
}))
vi.mock('@/app/components/plugins/marketplace/hooks', () => ({
  useMarketplacePlugins: vi.fn(),
}))
// Always report MCP tools as available so tab/search filtering alone decides
// what is rendered in these tests.
vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({
  useMCPToolAvailability: () => ({
    allowed: true,
  }),
}))
// Keep the real utils module but pin the marketplace URL to a stable value.
vi.mock('@/utils/var', async importOriginal => ({
  ...(await importOriginal<typeof import('@/utils/var')>()),
  getMarketplaceUrl: () => 'https://marketplace.test/tools',
}))
// Typed handles for configuring the mocked hooks per test.
const mockUseMarketplacePlugins = vi.mocked(useMarketplacePlugins)
const mockUseGlobalPublicStore = vi.mocked(useGlobalPublicStore)
const mockUseGetLanguage = vi.mocked(useGetLanguage)
const mockUseTheme = vi.mocked(useTheme)
// Default return shape for the mocked useMarketplacePlugins hook: an empty,
// idle marketplace query with fresh spy functions for each call.
const createMarketplacePluginsMock = () => {
  return {
    plugins: [],
    total: 0,
    resetPlugins: vi.fn(),
    queryPlugins: vi.fn(),
    queryPluginsWithDebounced: vi.fn(),
    cancelQueryPluginsWithDebounced: vi.fn(),
    isLoading: false,
    isFetchingNextPage: false,
    hasNextPage: false,
    fetchNextPage: vi.fn(),
    page: 0,
  }
}
describe('AllTools', () => {
  beforeEach(() => {
    vi.clearAllMocks()
    // Marketplace disabled by default; language, theme, and marketplace
    // plugins get stable baseline values for every test.
    mockUseGlobalPublicStore.mockImplementation(selector => selector(createGlobalPublicStoreState(false)))
    mockUseGetLanguage.mockReturnValue('en_US')
    mockUseTheme.mockReturnValue({ theme: Theme.light } as ReturnType<typeof useTheme>)
    mockUseMarketplacePlugins.mockReturnValue(createMarketplacePluginsMock())
  })
  it('filters tools by the active tab', async () => {
    const user = userEvent.setup()
    render(
      <AllTools
        searchText=""
        tags={[]}
        onSelect={vi.fn()}
        buildInTools={[createToolProvider({
          id: 'provider-built-in',
          label: { en_US: 'Built In Provider', zh_Hans: 'Built In Provider' },
        })]}
        customTools={[createToolProvider({
          id: 'provider-custom',
          type: 'custom',
          label: { en_US: 'Custom Provider', zh_Hans: 'Custom Provider' },
        })]}
        workflowTools={[]}
        mcpTools={[]}
      />,
    )
    // The default tab shows every provider type.
    expect(screen.getByText('Built In Provider')).toBeInTheDocument()
    expect(screen.getByText('Custom Provider')).toBeInTheDocument()
    // Switching to the custom-tool tab (i18n key) hides built-in providers.
    await user.click(screen.getByText('workflow.tabs.customTool'))
    expect(screen.getByText('Custom Provider')).toBeInTheDocument()
    expect(screen.queryByText('Built In Provider')).not.toBeInTheDocument()
  })
  it('filters the rendered tools by the search text', () => {
    render(
      <AllTools
        searchText="report"
        tags={[]}
        onSelect={vi.fn()}
        buildInTools={[
          createToolProvider({
            id: 'provider-report',
            label: { en_US: 'Report Toolkit', zh_Hans: 'Report Toolkit' },
          }),
          createToolProvider({
            id: 'provider-other',
            label: { en_US: 'Other Toolkit', zh_Hans: 'Other Toolkit' },
          }),
        ]}
        customTools={[]}
        workflowTools={[]}
        mcpTools={[]}
      />,
    )
    // Only the provider whose label matches the search text remains visible.
    expect(screen.getByText('Report Toolkit')).toBeInTheDocument()
    expect(screen.queryByText('Other Toolkit')).not.toBeInTheDocument()
  })
  it('shows the empty state when no tool matches the current filter', async () => {
    render(
      <AllTools
        searchText="missing"
        tags={[]}
        onSelect={vi.fn()}
        buildInTools={[]}
        customTools={[]}
        workflowTools={[]}
        mcpTools={[]}
      />,
    )
    // Empty state renders asynchronously, hence the waitFor.
    await waitFor(() => {
      expect(screen.getByText('workflow.tabs.noPluginsFound')).toBeInTheDocument()
    })
  })
})

View File

@ -0,0 +1,79 @@
import type { NodeDefault } from '../../types'
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { BlockEnum } from '../../types'
import Blocks from '../blocks'
import { BlockClassificationEnum } from '../types'
// Mutable per-test canvas state consumed by the reactflow mock below.
// vi.hoisted guarantees it exists before the hoisted vi.mock factory runs.
const runtimeState = vi.hoisted(() => ({
  nodes: [] as Array<{ data: { type?: BlockEnum } }>,
}))
// Stub only the store accessor Blocks uses to read the current canvas nodes.
vi.mock('reactflow', () => ({
  useStoreApi: () => ({
    getState: () => ({
      getNodes: () => runtimeState.nodes,
    }),
  }),
}))
// Factory for a picker block definition; classification defaults to the
// Default group unless a test asks for a specific one (e.g. Logic).
const createBlock = (
  type: BlockEnum,
  title: string,
  classification = BlockClassificationEnum.Default,
): NodeDefault => {
  return {
    metaData: {
      classification,
      sort: 0,
      type,
      title,
      author: 'Dify',
      description: `${title} description`,
    },
    defaultValue: {},
    checkValid: () => ({ isValid: true }),
  }
}
describe('Blocks', () => {
  beforeEach(() => {
    // Reset the mocked ReactFlow canvas to empty before every test.
    runtimeState.nodes = []
  })
  it('renders grouped blocks, filters duplicate knowledge-base nodes, and selects a block', async () => {
    const user = userEvent.setup()
    const onSelect = vi.fn()
    // A KnowledgeBase node already on the canvas should suppress the
    // KnowledgeBase entry in the picker list.
    runtimeState.nodes = [{ data: { type: BlockEnum.KnowledgeBase } }]
    render(
      <Blocks
        searchText=""
        onSelect={onSelect}
        availableBlocksTypes={[BlockEnum.LLM, BlockEnum.LoopEnd, BlockEnum.KnowledgeBase]}
        blocks={[
          createBlock(BlockEnum.LLM, 'LLM'),
          createBlock(BlockEnum.LoopEnd, 'Exit Loop', BlockClassificationEnum.Logic),
          createBlock(BlockEnum.KnowledgeBase, 'Knowledge Retrieval'),
        ]}
      />,
    )
    expect(screen.getByText('LLM')).toBeInTheDocument()
    expect(screen.getByText('Exit Loop')).toBeInTheDocument()
    // Logic-classified blocks render under their group heading (i18n key).
    expect(screen.getByText('workflow.nodes.loop.loopNode')).toBeInTheDocument()
    expect(screen.queryByText('Knowledge Retrieval')).not.toBeInTheDocument()
    // Clicking a block reports its type to the onSelect callback.
    await user.click(screen.getByText('LLM'))
    expect(onSelect).toHaveBeenCalledWith(BlockEnum.LLM)
  })
  it('shows the empty state when no block matches the search', () => {
    render(
      <Blocks
        searchText="missing"
        onSelect={vi.fn()}
        availableBlocksTypes={[BlockEnum.LLM]}
        blocks={[createBlock(BlockEnum.LLM, 'LLM')]}
      />,
    )
    expect(screen.getByText('workflow.tabs.noResult')).toBeInTheDocument()
  })
})

View File

@ -0,0 +1,101 @@
import type { ToolWithProvider } from '../../types'
import type { Plugin } from '@/app/components/plugins/types'
import type { Tool } from '@/app/components/tools/types'
import { PluginCategoryEnum } from '@/app/components/plugins/types'
import { CollectionType } from '@/app/components/tools/types'
import { defaultSystemFeatures } from '@/types/feature'
// Builds a minimal Tool fixture; label and description use the same text for
// both locales, and the description defaults to "<label> description".
export const createTool = (
  name: string,
  label: string,
  description = `${label} description`,
): Tool => {
  const bilingual = (text: string) => ({
    en_US: text,
    zh_Hans: text,
  })
  return {
    name,
    author: 'author',
    label: bilingual(label),
    description: bilingual(description),
    parameters: [],
    labels: [],
    output_schema: {},
  }
}
// Builds a fully-populated built-in tool provider fixture; pass `overrides`
// to adjust individual fields (id, type, label, …) per test.
export const createToolProvider = (
  overrides: Partial<ToolWithProvider> = {},
): ToolWithProvider => {
  const defaults: ToolWithProvider = {
    id: 'provider-1',
    name: 'provider-one',
    author: 'Provider Author',
    description: {
      en_US: 'Provider description',
      zh_Hans: 'Provider description',
    },
    icon: 'icon',
    icon_dark: 'icon-dark',
    label: {
      en_US: 'Provider One',
      zh_Hans: 'Provider One',
    },
    type: CollectionType.builtIn,
    team_credentials: {},
    is_team_authorization: false,
    allow_delete: false,
    labels: [],
    plugin_id: 'plugin-1',
    tools: [createTool('tool-a', 'Tool A')],
    meta: { version: '1.0.0' } as ToolWithProvider['meta'],
    plugin_unique_identifier: 'plugin-1@1.0.0',
  }
  return { ...defaults, ...overrides }
}
// Builds a verified community tool-plugin fixture installed from GitHub;
// pass `overrides` to change individual fields per test.
export const createPlugin = (overrides: Partial<Plugin> = {}): Plugin => {
  const defaults: Plugin = {
    type: 'plugin',
    org: 'org',
    author: 'author',
    name: 'Plugin One',
    plugin_id: 'plugin-1',
    version: '1.0.0',
    latest_version: '1.0.0',
    latest_package_identifier: 'plugin-1@1.0.0',
    icon: 'icon',
    verified: true,
    label: {
      en_US: 'Plugin One',
      zh_Hans: 'Plugin One',
    },
    brief: {
      en_US: 'Plugin description',
      zh_Hans: 'Plugin description',
    },
    description: {
      en_US: 'Plugin description',
      zh_Hans: 'Plugin description',
    },
    introduction: 'Plugin introduction',
    repository: 'https://example.com/plugin',
    category: PluginCategoryEnum.tool,
    tags: [],
    badges: [],
    install_count: 0,
    endpoint: {
      settings: [],
    },
    verification: {
      authorized_category: 'community',
    },
    from: 'github',
  }
  return { ...defaults, ...overrides }
}
export const createGlobalPublicStoreState = (enableMarketplace: boolean) => ({
systemFeatures: {
...defaultSystemFeatures,
enable_marketplace: enableMarketplace,
},
setSystemFeatures: vi.fn(),
})

Some files were not shown because too many files have changed in this diff Show More