Merge remote-tracking branch 'myori/main' into feat/collaboration2

This commit is contained in:
hjlarry 2026-04-10 22:47:40 +08:00
commit ee2b021395
380 changed files with 28622 additions and 3985 deletions

View File

@ -7,6 +7,7 @@
## Summary
<!-- Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change. -->
<!-- If this PR was created by an automated agent, add `From <Tool Name>` as the final line of the description. Example: `From Codex`. -->
## Screenshots

View File

@ -0,0 +1,118 @@
name: Comment with Pyrefly Type Coverage
on:
workflow_run:
workflows:
- Pyrefly Type Coverage
types:
- completed
permissions: {}
jobs:
comment:
name: Comment PR with type coverage
runs-on: ubuntu-latest
permissions:
actions: read
contents: read
issues: write
pull-requests: write
if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }}
steps:
- name: Checkout default branch (trusted code)
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: Setup Python & UV
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
with:
enable-cache: true
- name: Install dependencies
run: uv sync --project api --dev
- name: Download type coverage artifact
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
const fs = require('fs');
const artifacts = await github.rest.actions.listWorkflowRunArtifacts({
owner: context.repo.owner,
repo: context.repo.repo,
run_id: ${{ github.event.workflow_run.id }},
});
const match = artifacts.data.artifacts.find((artifact) =>
artifact.name === 'pyrefly_type_coverage'
);
if (!match) {
throw new Error('pyrefly_type_coverage artifact not found');
}
const download = await github.rest.actions.downloadArtifact({
owner: context.repo.owner,
repo: context.repo.repo,
artifact_id: match.id,
archive_format: 'zip',
});
fs.writeFileSync('pyrefly_type_coverage.zip', Buffer.from(download.data));
- name: Unzip artifact
run: unzip -o pyrefly_type_coverage.zip
- name: Render coverage markdown from structured data
id: render
run: |
comment_body="$(uv run --directory api python api/libs/pyrefly_type_coverage.py \
--base base_report.json \
< pr_report.json)"
{
echo "### Pyrefly Type Coverage"
echo ""
echo "$comment_body"
} > /tmp/type_coverage_comment.md
- name: Post comment
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
const fs = require('fs');
const body = fs.readFileSync('/tmp/type_coverage_comment.md', { encoding: 'utf8' });
let prNumber = null;
try {
prNumber = parseInt(fs.readFileSync('pr_number.txt', { encoding: 'utf8' }), 10);
} catch (err) {
const prs = context.payload.workflow_run.pull_requests || [];
if (prs.length > 0 && prs[0].number) {
prNumber = prs[0].number;
}
}
if (!prNumber) {
throw new Error('PR number not found in artifact or workflow_run payload');
}
// Update existing comment if one exists, otherwise create new
const { data: comments } = await github.rest.issues.listComments({
issue_number: prNumber,
owner: context.repo.owner,
repo: context.repo.repo,
});
const marker = '### Pyrefly Type Coverage';
const existing = comments.find(c => c.body.startsWith(marker));
if (existing) {
await github.rest.issues.updateComment({
comment_id: existing.id,
owner: context.repo.owner,
repo: context.repo.repo,
body,
});
} else {
await github.rest.issues.createComment({
issue_number: prNumber,
owner: context.repo.owner,
repo: context.repo.repo,
body,
});
}

View File

@ -0,0 +1,120 @@
name: Pyrefly Type Coverage
on:
pull_request:
paths:
- 'api/**/*.py'
permissions:
contents: read
jobs:
pyrefly-type-coverage:
runs-on: ubuntu-latest
permissions:
contents: read
issues: write
pull-requests: write
steps:
- name: Checkout PR branch
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
- name: Setup Python & UV
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
with:
enable-cache: true
- name: Install dependencies
run: uv sync --project api --dev
- name: Run pyrefly report on PR branch
run: |
uv run --directory api --dev pyrefly report 2>/dev/null > /tmp/pyrefly_report_pr.tmp && \
mv /tmp/pyrefly_report_pr.tmp /tmp/pyrefly_report_pr.json || \
echo '{}' > /tmp/pyrefly_report_pr.json
- name: Save helper script from base branch
run: |
git show ${{ github.event.pull_request.base.sha }}:api/libs/pyrefly_type_coverage.py > /tmp/pyrefly_type_coverage.py 2>/dev/null \
|| cp api/libs/pyrefly_type_coverage.py /tmp/pyrefly_type_coverage.py
- name: Checkout base branch
run: git checkout ${{ github.base_ref }}
- name: Run pyrefly report on base branch
run: |
uv run --directory api --dev pyrefly report 2>/dev/null > /tmp/pyrefly_report_base.tmp && \
mv /tmp/pyrefly_report_base.tmp /tmp/pyrefly_report_base.json || \
echo '{}' > /tmp/pyrefly_report_base.json
- name: Generate coverage comparison
id: coverage
run: |
comment_body="$(uv run --directory api python /tmp/pyrefly_type_coverage.py \
--base /tmp/pyrefly_report_base.json \
< /tmp/pyrefly_report_pr.json)"
{
echo "### Pyrefly Type Coverage"
echo ""
echo "$comment_body"
} | tee -a "$GITHUB_STEP_SUMMARY" > /tmp/type_coverage_comment.md
# Save structured data for the fork-PR comment workflow
cp /tmp/pyrefly_report_pr.json pr_report.json
cp /tmp/pyrefly_report_base.json base_report.json
- name: Save PR number
run: |
echo ${{ github.event.pull_request.number }} > pr_number.txt
- name: Upload type coverage artifact
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: pyrefly_type_coverage
path: |
pr_report.json
base_report.json
pr_number.txt
- name: Comment PR with type coverage
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
const fs = require('fs');
const marker = '### Pyrefly Type Coverage';
let body;
try {
body = fs.readFileSync('/tmp/type_coverage_comment.md', { encoding: 'utf8' });
} catch {
body = `${marker}\n\n_Coverage report unavailable._`;
}
const prNumber = context.payload.pull_request.number;
// Update existing comment if one exists, otherwise create new
const { data: comments } = await github.rest.issues.listComments({
issue_number: prNumber,
owner: context.repo.owner,
repo: context.repo.repo,
});
const existing = comments.find(c => c.body.startsWith(marker));
if (existing) {
await github.rest.issues.updateComment({
comment_id: existing.id,
owner: context.repo.owner,
repo: context.repo.repo,
body,
});
} else {
await github.rest.issues.createComment({
issue_number: prNumber,
owner: context.repo.owner,
repo: context.repo.repo,
body,
});
}

View File

@ -97,8 +97,3 @@ Feel free to reach out if you encounter any issues during the setup process.
## Getting Help
If you ever get stuck or get a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat.
## Automated Agent Contributions
> [!NOTE]
> If you are an automated agent just add 🤖🤖🤖 to the end of the PR title to opt-in.

View File

@ -2,7 +2,6 @@ import base64
import secrets
import click
from sqlalchemy.orm import sessionmaker
from constants.languages import languages
from extensions.ext_database import db
@ -25,30 +24,31 @@ def reset_password(email, new_password, password_confirm):
return
normalized_email = email.strip().lower()
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
account = AccountService.get_account_by_email_with_case_fallback(email.strip())
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
try:
valid_password(new_password)
except:
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
return
try:
valid_password(new_password)
except:
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
return
# generate password salt
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()
# generate password salt
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()
# encrypt password with salt
password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
AccountService.reset_login_error_rate_limit(normalized_email)
click.echo(click.style("Password reset successfully.", fg="green"))
# encrypt password with salt
password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account = db.session.merge(account)
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.commit()
AccountService.reset_login_error_rate_limit(normalized_email)
click.echo(click.style("Password reset successfully.", fg="green"))
@click.command("reset-email", help="Reset the account email.")
@ -65,21 +65,22 @@ def reset_email(email, new_email, email_confirm):
return
normalized_new_email = new_email.strip().lower()
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
account = AccountService.get_account_by_email_with_case_fallback(email.strip())
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
try:
email_validate(normalized_new_email)
except:
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
return
try:
email_validate(normalized_new_email)
except:
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
return
account.email = normalized_new_email
click.echo(click.style("Email updated successfully.", fg="green"))
account = db.session.merge(account)
account.email = normalized_new_email
db.session.commit()
click.echo(click.style("Email updated successfully.", fg="green"))
@click.command("create-tenant", help="Create account and tenant.")

View File

@ -1,7 +1,6 @@
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from constants.languages import languages
@ -14,7 +13,6 @@ from controllers.console.auth.error import (
InvalidTokenError,
PasswordMismatchError,
)
from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip
from libs.password import valid_password
from models import Account
@ -73,8 +71,7 @@ class EmailRegisterSendEmailApi(Resource):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountInFreezeError()
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
account = AccountService.get_account_by_email_with_case_fallback(args.email)
token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
return {"result": "success", "data": token}
@ -145,17 +142,16 @@ class EmailRegisterResetApi(Resource):
email = register_data.get("email", "")
normalized_email = email.lower()
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
account = AccountService.get_account_by_email_with_case_fallback(email)
if account:
raise EmailAlreadyInUseError()
else:
account = self._create_new_account(normalized_email, args.password_confirm)
if not account:
raise AccountNotFoundError()
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(normalized_email)
if account:
raise EmailAlreadyInUseError()
else:
account = self._create_new_account(normalized_email, args.password_confirm)
if not account:
raise AccountNotFoundError()
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(normalized_email)
return {"result": "success", "data": token_pair.model_dump()}

View File

@ -4,7 +4,6 @@ import secrets
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
@ -85,8 +84,7 @@ class ForgotPasswordSendEmailApi(Resource):
else:
language = "en-US"
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
account = AccountService.get_account_by_email_with_case_fallback(args.email)
token = AccountService.send_reset_password_email(
account=account,
@ -184,17 +182,18 @@ class ForgotPasswordResetApi(Resource):
password_hashed = hash_password(args.new_password, salt)
email = reset_data.get("email", "")
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
account = AccountService.get_account_by_email_with_case_fallback(email)
if account:
self._update_existing_account(account, password_hashed, salt, session)
else:
raise AccountNotFound()
if account:
account = db.session.merge(account)
self._update_existing_account(account, password_hashed, salt)
db.session.commit()
else:
raise AccountNotFound()
return {"result": "success"}
def _update_existing_account(self, account, password_hashed, salt, session):
def _update_existing_account(self, account, password_hashed, salt):
# Update existing account credentials
account.password = base64.b64encode(password_hashed).decode()
account.password_salt = base64.b64encode(salt).decode()

View File

@ -4,7 +4,6 @@ import urllib.parse
import httpx
from flask import current_app, redirect, request
from flask_restx import Resource
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Unauthorized
from configs import dify_config
@ -180,8 +179,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
account: Account | None = Account.get_by_openid(provider, user_info.id)
if not account:
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)
account = AccountService.get_account_by_email_with_case_fallback(user_info.email)
return account

View File

@ -227,10 +227,11 @@ class ExternalApiUseCheckApi(Resource):
@login_required
@account_initialization_required
def get(self, external_knowledge_api_id):
_, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id = str(external_knowledge_api_id)
external_knowledge_api_is_using, count = ExternalDatasetService.external_knowledge_api_use_check(
external_knowledge_api_id
external_knowledge_api_id, current_tenant_id
)
return {"is_using": external_knowledge_api_is_using, "count": count}, 200

View File

@ -9,7 +9,6 @@ from flask_restx import Resource, fields, marshal_with
from graphon.file import helpers as file_helpers
from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from constants.languages import supported_language
@ -580,8 +579,7 @@ class ChangeEmailSendEmailApi(Resource):
user_email = current_user.email
else:
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
account = AccountService.get_account_by_email_with_case_fallback(args.email)
if account is None:
raise AccountNotFound()
email_for_sending = account.email

View File

@ -3,7 +3,6 @@ import secrets
from flask import request
from flask_restx import Resource
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
from controllers.console.auth.error import (
@ -62,9 +61,7 @@ class ForgotPasswordSendEmailApi(Resource):
else:
language = "en-US"
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session)
token = None
account = AccountService.get_account_by_email_with_case_fallback(request_email)
if account is None:
raise AuthenticationFailedError()
else:
@ -161,13 +158,14 @@ class ForgotPasswordResetApi(Resource):
email = reset_data.get("email", "")
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
account = AccountService.get_account_by_email_with_case_fallback(email)
if account:
self._update_existing_account(account, password_hashed, salt)
else:
raise AuthenticationFailedError()
if account:
account = db.session.merge(account)
self._update_existing_account(account, password_hashed, salt)
db.session.commit()
else:
raise AuthenticationFailedError()
return {"result": "success"}

View File

@ -1,6 +1,6 @@
from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
from graphon.file import File, FileUploadConfig
from graphon.model_runtime.entities.model_entities import AIModelEntity
@ -131,7 +131,7 @@ class AppGenerateEntity(BaseModel):
extras: dict[str, Any] = Field(default_factory=dict)
# tracing instance
trace_manager: Optional["TraceQueueManager"] = Field(default=None, exclude=True, repr=False)
trace_manager: "TraceQueueManager | None" = Field(default=None, exclude=True, repr=False)
class EasyUIBasedAppGenerateEntity(AppGenerateEntity):

View File

@ -1,10 +1,10 @@
from typing import Literal, Optional
from typing import Any, Literal, TypedDict
from graphon.model_runtime.utils.encoders import jsonable_encoder
from pydantic import BaseModel, Field, field_validator
from core.datasource.entities.datasource_entities import DatasourceParameter
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.common_entities import I18nObject, I18nObjectDict
class DatasourceApiEntity(BaseModel):
@ -17,7 +17,24 @@ class DatasourceApiEntity(BaseModel):
output_schema: dict | None = None
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]]
ToolProviderTypeApiLiteral = Literal["builtin", "api", "workflow"] | None
class DatasourceProviderApiEntityDict(TypedDict):
id: str
author: str
name: str
plugin_id: str | None
plugin_unique_identifier: str | None
description: I18nObjectDict
icon: str | dict
label: I18nObjectDict
type: str
team_credentials: dict | None
is_team_authorization: bool
allow_delete: bool
datasources: list[Any]
labels: list[str]
class DatasourceProviderApiEntity(BaseModel):
@ -42,7 +59,7 @@ class DatasourceProviderApiEntity(BaseModel):
def convert_none_to_empty_list(cls, v):
return v if v is not None else []
def to_dict(self) -> dict:
def to_dict(self) -> DatasourceProviderApiEntityDict:
# -------------
# overwrite datasource parameter types for temp fix
datasources = jsonable_encoder(self.datasources)
@ -53,7 +70,7 @@ class DatasourceProviderApiEntity(BaseModel):
parameter["type"] = "files"
# -------------
return {
result: DatasourceProviderApiEntityDict = {
"id": self.id,
"author": self.author,
"name": self.name,
@ -69,3 +86,4 @@ class DatasourceProviderApiEntity(BaseModel):
"datasources": datasources,
"labels": self.labels,
}
return result

View File

@ -71,8 +71,8 @@ class DatasourceFileMessageTransformer:
if not isinstance(message.message, DatasourceMessage.BlobMessage):
raise ValueError("unexpected message type")
# FIXME: should do a type check here.
assert isinstance(message.message.blob, bytes)
if not isinstance(message.message.blob, bytes):
raise TypeError(f"Expected blob to be bytes, got {type(message.message.blob).__name__}")
tool_file_manager = ToolFileManager()
blob_tool_file: ToolFile | None = tool_file_manager.create_file_by_raw(
user_id=user_id,

View File

@ -122,7 +122,7 @@ class MCPClientWithAuthRetry(MCPClient):
logger.exception("Authentication retry failed")
raise MCPAuthError(f"Authentication retry failed: {e}") from e
def _execute_with_retry(self, func: Callable[..., Any], *args, **kwargs) -> Any:
def _execute_with_retry[**P, R](self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
"""
Execute a function with authentication retry logic.

View File

@ -1,6 +1,6 @@
from dataclasses import dataclass
from enum import StrEnum
from typing import Any, TypeVar
from typing import Any
from pydantic import BaseModel
@ -9,12 +9,9 @@ from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAut
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
LifespanContextT = TypeVar("LifespanContextT")
@dataclass
class RequestContext[SessionT: BaseSession[Any, Any, Any, Any, Any], LifespanContextT]:
class RequestContext[SessionT: BaseSession, LifespanContextT]:
request_id: RequestId
meta: RequestParams.Meta | None
session: SessionT

View File

@ -55,7 +55,7 @@ class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResul
request: ReceiveRequestT
_session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]"
_on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any]
_on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], object]
def __init__(
self,
@ -63,7 +63,7 @@ class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResul
request_meta: RequestParams.Meta | None,
request: ReceiveRequestT,
session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]",
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], object],
):
self.request_id = request_id
self.request_meta = request_meta

View File

@ -31,7 +31,6 @@ ProgressToken = str | int
Cursor = str
Role = Literal["user", "assistant"]
RequestId = Annotated[int | str, Field(union_mode="left_to_right")]
type AnyFunction = Callable[..., Any]
class RequestParams(BaseModel):

View File

@ -6,7 +6,7 @@ from graphon.model_runtime.callbacks.base_callback import Callback
from graphon.model_runtime.entities.llm_entities import LLMResult
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType
from graphon.model_runtime.entities.rerank_entities import RerankResult
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@ -172,10 +172,10 @@ class ModelInstance:
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
prompt_messages=prompt_messages,
prompt_messages=list(prompt_messages),
model_parameters=model_parameters,
tools=tools,
stop=stop,
tools=list(tools) if tools else None,
stop=list(stop) if stop else None,
stream=stream,
callbacks=callbacks,
),
@ -193,15 +193,12 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")
return cast(
int,
self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model_name,
credentials=self.credentials,
prompt_messages=prompt_messages,
tools=tools,
),
return self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model_name,
credentials=self.credentials,
prompt_messages=list(prompt_messages),
tools=list(tools) if tools else None,
)
def invoke_text_embedding(
@ -216,15 +213,12 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return cast(
EmbeddingResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
texts=texts,
input_type=input_type,
),
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
texts=texts,
input_type=input_type,
)
def invoke_multimodal_embedding(
@ -241,15 +235,12 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return cast(
EmbeddingResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
multimodel_documents=multimodel_documents,
input_type=input_type,
),
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
multimodel_documents=multimodel_documents,
input_type=input_type,
)
def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]:
@ -261,14 +252,11 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return cast(
list[int],
self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model_name,
credentials=self.credentials,
texts=texts,
),
return self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model_name,
credentials=self.credentials,
texts=texts,
)
def invoke_rerank(
@ -289,23 +277,20 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
return cast(
RerankResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
),
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
)
def invoke_multimodal_rerank(
self,
query: dict,
docs: list[dict],
query: MultimodalRerankInput,
docs: list[MultimodalRerankInput],
score_threshold: float | None = None,
top_n: int | None = None,
) -> RerankResult:
@ -320,17 +305,14 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
return cast(
RerankResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke_multimodal_rerank,
model=self.model_name,
credentials=self.credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
),
return self._round_robin_invoke(
function=self.model_type_instance.invoke_multimodal_rerank,
model=self.model_name,
credentials=self.credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
)
def invoke_moderation(self, text: str) -> bool:
@ -342,14 +324,11 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, ModerationModel):
raise Exception("Model type instance is not ModerationModel")
return cast(
bool,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
text=text,
),
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
text=text,
)
def invoke_speech2text(self, file: IO[bytes]) -> str:
@ -361,14 +340,11 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, Speech2TextModel):
raise Exception("Model type instance is not Speech2TextModel")
return cast(
str,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
file=file,
),
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
file=file,
)
def invoke_tts(self, content_text: str, voice: str = "") -> Iterable[bytes]:
@ -381,18 +357,15 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
return cast(
Iterable[bytes],
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
content_text=content_text,
voice=voice,
),
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
content_text=content_text,
voice=voice,
)
def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs):
def _round_robin_invoke[**P, R](self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
"""
Round-robin invoke
:param function: function to invoke
@ -430,9 +403,8 @@ class ModelInstance:
continue
try:
if "credentials" in kwargs:
del kwargs["credentials"]
return function(*args, **kwargs, credentials=lb_config.credentials)
kwargs["credentials"] = lb_config.credentials
return function(*args, **kwargs)
except InvokeRateLimitError as e:
# expire in 60 seconds
self.load_balancing_manager.cooldown(lb_config, expire=60)

View File

@ -1,6 +1,6 @@
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Union, cast
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -207,7 +207,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
)
@classmethod
def _get_user(cls, user_id: str) -> Union[EndUser, Account]:
def _get_user(cls, user_id: str) -> EndUser | Account:
"""
get the user by user id
"""

View File

@ -7,7 +7,7 @@ from pydantic import BaseModel, model_validator
from sqlalchemy import Column, String, Table, create_engine, insert
from sqlalchemy import text as sql_text
from sqlalchemy.dialects.postgresql import JSON, TEXT
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
@ -79,7 +79,7 @@ class RelytVector(BaseVector):
if redis_client.get(collection_exist_cache_key):
return
index_name = f"{self._collection_name}_embedding_index"
with Session(self.client) as session:
with sessionmaker(bind=self.client).begin() as session:
drop_statement = sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}"; """)
session.execute(drop_statement)
create_statement = sql_text(f"""
@ -104,7 +104,6 @@ class RelytVector(BaseVector):
$$);
""")
session.execute(index_statement)
session.commit()
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
@ -208,9 +207,8 @@ class RelytVector(BaseVector):
self.delete_by_uuids(ids)
def delete(self):
with Session(self.client) as session:
with sessionmaker(bind=self.client).begin() as session:
session.execute(sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}";"""))
session.commit()
def text_exists(self, id: str) -> bool:
with Session(self.client) as session:

View File

@ -6,7 +6,7 @@ import sqlalchemy
from pydantic import BaseModel, model_validator
from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
from sqlalchemy import text as sql_text
from sqlalchemy.orm import Session, declarative_base
from sqlalchemy.orm import Session, declarative_base, sessionmaker
from configs import dify_config
from core.rag.datasource.vdb.field import Field, parse_metadata_json
@ -97,8 +97,7 @@ class TiDBVector(BaseVector):
if redis_client.get(collection_exist_cache_key):
return
tidb_dist_func = self._get_distance_func()
with Session(self._engine) as session:
session.begin()
with sessionmaker(bind=self._engine).begin() as session:
create_statement = sql_text(f"""
CREATE TABLE IF NOT EXISTS {self._collection_name} (
id CHAR(36) PRIMARY KEY,
@ -115,7 +114,6 @@ class TiDBVector(BaseVector):
);
""")
session.execute(create_statement)
session.commit()
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
@ -238,9 +236,8 @@ class TiDBVector(BaseVector):
return []
def delete(self):
with Session(self._engine) as session:
with sessionmaker(bind=self._engine).begin() as session:
session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))
session.commit()
def _get_distance_func(self) -> str:
match self._distance_func:

View File

@ -3,8 +3,7 @@
import logging
import re
import uuid
from collections.abc import Mapping
from typing import Any, cast
from typing import Any, TypedDict, cast
logger = logging.getLogger(__name__)
@ -55,6 +54,12 @@ from services.summary_index_service import SummaryIndexService
_file_access_controller = DatabaseFileAccessController()
class ParagraphFormatPreviewDict(TypedDict):
chunk_structure: str
preview: list[dict[str, Any]]
total_segments: int
class ParagraphIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract(
@ -266,16 +271,17 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
keyword = Keyword(dataset)
keyword.add_texts(documents)
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
def format_preview(self, chunks: Any) -> ParagraphFormatPreviewDict:
if isinstance(chunks, list):
preview = []
for content in chunks:
preview.append({"content": content})
return {
result: ParagraphFormatPreviewDict = {
"chunk_structure": IndexStructureType.PARAGRAPH_INDEX,
"preview": preview,
"total_segments": len(chunks),
}
return result
else:
raise ValueError("Chunks is not a list")

View File

@ -3,8 +3,7 @@
import json
import logging
import uuid
from collections.abc import Mapping
from typing import Any
from typing import Any, TypedDict
from sqlalchemy import delete, select
@ -36,6 +35,13 @@ from services.summary_index_service import SummaryIndexService
logger = logging.getLogger(__name__)
class ParentChildFormatPreviewDict(TypedDict):
chunk_structure: str
parent_mode: str
preview: list[dict[str, Any]]
total_segments: int
class ParentChildIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract(
@ -351,17 +357,18 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
if all_multimodal_documents and dataset.is_multimodal:
vector.create_multimodal(all_multimodal_documents)
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
def format_preview(self, chunks: Any) -> ParentChildFormatPreviewDict:
parent_childs = ParentChildStructureChunk.model_validate(chunks)
preview = []
for parent_child in parent_childs.parent_child_chunks:
preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents})
return {
result: ParentChildFormatPreviewDict = {
"chunk_structure": IndexStructureType.PARENT_CHILD_INDEX,
"parent_mode": parent_childs.parent_mode,
"preview": preview,
"total_segments": len(parent_childs.parent_child_chunks),
}
return result
def generate_summary_preview(
self,

View File

@ -4,8 +4,7 @@ import logging
import re
import threading
import uuid
from collections.abc import Mapping
from typing import Any
from typing import Any, TypedDict
import pandas as pd
from flask import Flask, current_app
@ -36,6 +35,12 @@ from services.summary_index_service import SummaryIndexService
logger = logging.getLogger(__name__)
class QAFormatPreviewDict(TypedDict):
chunk_structure: str
qa_preview: list[dict[str, Any]]
total_segments: int
class QAIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract(
@ -230,16 +235,17 @@ class QAIndexProcessor(BaseIndexProcessor):
else:
raise ValueError("Indexing technique must be high quality.")
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
def format_preview(self, chunks: Any) -> QAFormatPreviewDict:
qa_chunks = QAStructureChunk.model_validate(chunks)
preview = []
for qa_chunk in qa_chunks.qa_chunks:
preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer})
return {
result: QAFormatPreviewDict = {
"chunk_structure": IndexStructureType.QA_INDEX,
"qa_preview": preview,
"total_segments": len(qa_chunks.qa_chunks),
}
return result
def generate_summary_preview(
self,

View File

@ -1,7 +1,7 @@
import base64
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.entities.rerank_entities import RerankResult
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from core.model_manager import ModelInstance, ModelManager
from core.rag.index_processor.constant.doc_type import DocType
@ -123,7 +123,7 @@ class RerankModelRunner(BaseRerankRunner):
:param query_type: query type
:return: rerank result
"""
docs = []
docs: list[MultimodalRerankInput] = []
doc_ids = set()
unique_documents = []
for document in documents:
@ -138,26 +138,28 @@ class RerankModelRunner(BaseRerankRunner):
if upload_file:
blob = storage.load_once(upload_file.key)
document_file_base64 = base64.b64encode(blob).decode()
document_file_dict = {
"content": document_file_base64,
"content_type": document.metadata["doc_type"],
}
docs.append(document_file_dict)
docs.append(
MultimodalRerankInput(
content=document_file_base64,
content_type=document.metadata["doc_type"],
)
)
else:
document_text_dict = {
"content": document.page_content,
"content_type": document.metadata.get("doc_type") or DocType.TEXT,
}
docs.append(document_text_dict)
docs.append(
MultimodalRerankInput(
content=document.page_content,
content_type=document.metadata.get("doc_type") or DocType.TEXT,
)
)
doc_ids.add(document.metadata["doc_id"])
unique_documents.append(document)
elif document.provider == "external":
if document not in unique_documents:
docs.append(
{
"content": document.page_content,
"content_type": document.metadata.get("doc_type") or DocType.TEXT,
}
MultimodalRerankInput(
content=document.page_content,
content_type=document.metadata.get("doc_type") or DocType.TEXT,
)
)
unique_documents.append(document)
@ -171,12 +173,12 @@ class RerankModelRunner(BaseRerankRunner):
if upload_file:
blob = storage.load_once(upload_file.key)
file_query = base64.b64encode(blob).decode()
file_query_dict = {
"content": file_query,
"content_type": DocType.IMAGE,
}
file_query_input = MultimodalRerankInput(
content=file_query,
content_type=DocType.IMAGE,
)
rerank_result = self.rerank_model_instance.invoke_multimodal_rerank(
query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n
query=file_query_input, docs=docs, score_threshold=score_threshold, top_n=top_n
)
return rerank_result, unique_documents
else:

View File

@ -6,7 +6,6 @@ import os
import time
from collections.abc import Generator
from mimetypes import guess_extension, guess_type
from typing import Union
from uuid import uuid4
import httpx
@ -158,7 +157,7 @@ class ToolFileManager:
return tool_file
def get_file_binary(self, id: str) -> Union[tuple[bytes, str], None]:
def get_file_binary(self, id: str) -> tuple[bytes, str] | None:
"""
get file binary
@ -176,7 +175,7 @@ class ToolFileManager:
return blob, tool_file.mimetype
def get_file_binary_by_message_file_id(self, id: str) -> Union[tuple[bytes, str], None]:
def get_file_binary_by_message_file_id(self, id: str) -> tuple[bytes, str] | None:
"""
get file binary

View File

@ -5,7 +5,7 @@ import time
from collections.abc import Generator, Mapping
from os import listdir, path
from threading import Lock
from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, Union, cast
from typing import TYPE_CHECKING, Any, Literal, Protocol, cast
import sqlalchemy as sa
from graphon.runtime import VariablePool
@ -100,7 +100,7 @@ class ToolManager:
_builtin_provider_lock = Lock()
_hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
_builtin_providers_loaded = False
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
_builtin_tools_labels: dict[str, I18nObject | None] = {}
@classmethod
def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
@ -190,7 +190,7 @@ class ToolManager:
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
credential_id: str | None = None,
) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
) -> BuiltinTool | PluginTool | ApiTool | WorkflowTool | MCPTool:
"""
get the tool runtime
@ -398,7 +398,7 @@ class ToolManager:
agent_tool: AgentToolEntity,
user_id: str | None = None,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
variable_pool: Optional["VariablePool"] = None,
variable_pool: "VariablePool | None" = None,
) -> Tool:
"""
get the agent tool runtime
@ -442,7 +442,7 @@ class ToolManager:
workflow_tool: WorkflowToolRuntimeSpec,
user_id: str | None = None,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
variable_pool: Optional["VariablePool"] = None,
variable_pool: "VariablePool | None" = None,
) -> Tool:
"""
get the workflow tool runtime
@ -634,7 +634,7 @@ class ToolManager:
cls._builtin_providers_loaded = False
@classmethod
def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
def get_tool_label(cls, tool_name: str) -> I18nObject | None:
"""
get the tool label
@ -993,7 +993,7 @@ class ToolManager:
return {"background": "#252525", "content": "\ud83d\ude01"}
@classmethod
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | dict[str, str] | str:
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | str:
try:
with Session(db.engine) as session:
mcp_service = MCPToolManageService(session=session)
@ -1001,7 +1001,7 @@ class ToolManager:
mcp_provider = mcp_service.get_provider_entity(
provider_id=provider_id, tenant_id=tenant_id, by_server_id=True
)
return mcp_provider.provider_icon
return cast(EmojiIconDict | str, mcp_provider.provider_icon)
except ValueError:
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
except Exception:
@ -1013,7 +1013,7 @@ class ToolManager:
tenant_id: str,
provider_type: ToolProviderType,
provider_id: str,
) -> str | EmojiIconDict | dict[str, str]:
) -> str | EmojiIconDict:
"""
get the tool icon
@ -1052,7 +1052,7 @@ class ToolManager:
def _convert_tool_parameters_type(
cls,
parameters: list[ToolParameter],
variable_pool: Optional["VariablePool"],
variable_pool: "VariablePool | None",
tool_configurations: Mapping[str, Any],
typ: Literal["agent", "workflow", "tool"] = "workflow",
) -> dict[str, Any]:

View File

@ -118,7 +118,8 @@ class ToolFileMessageTransformer:
if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
raise ValueError("unexpected message type")
assert isinstance(message.message.blob, bytes)
if not isinstance(message.message.blob, bytes):
raise TypeError(f"Expected blob to be bytes, got {type(message.message.blob).__name__}")
tool_file_manager = ToolFileManager()
tool_file = tool_file_manager.create_file_by_raw(
user_id=user_id,

View File

@ -14,6 +14,7 @@ from redis.cluster import ClusterNode, RedisCluster
from redis.connection import Connection, SSLConnection
from redis.retry import Retry
from redis.sentinel import Sentinel
from typing_extensions import TypedDict
from configs import dify_config
from dify_app import DifyApp
@ -126,6 +127,35 @@ redis_client: RedisClientWrapper = RedisClientWrapper()
_pubsub_redis_client: redis.Redis | RedisCluster | None = None
class RedisSSLParamsDict(TypedDict):
ssl_cert_reqs: int
ssl_ca_certs: str | None
ssl_certfile: str | None
ssl_keyfile: str | None
class RedisHealthParamsDict(TypedDict):
retry: Retry
socket_timeout: float | None
socket_connect_timeout: float | None
health_check_interval: int | None
class RedisBaseParamsDict(TypedDict):
username: str | None
password: str | None
db: int
encoding: str
encoding_errors: str
decode_responses: bool
protocol: int
cache_config: CacheConfig | None
retry: Retry
socket_timeout: float | None
socket_connect_timeout: float | None
health_check_interval: int | None
def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]:
"""Get SSL configuration for Redis connection."""
if not dify_config.REDIS_USE_SSL:
@ -171,14 +201,14 @@ def _get_retry_policy() -> Retry:
)
def _get_connection_health_params() -> dict[str, Any]:
def _get_connection_health_params() -> RedisHealthParamsDict:
"""Get connection health and retry parameters for standalone and Sentinel Redis clients."""
return {
"retry": _get_retry_policy(),
"socket_timeout": dify_config.REDIS_SOCKET_TIMEOUT,
"socket_connect_timeout": dify_config.REDIS_SOCKET_CONNECT_TIMEOUT,
"health_check_interval": dify_config.REDIS_HEALTH_CHECK_INTERVAL,
}
return RedisHealthParamsDict(
retry=_get_retry_policy(),
socket_timeout=dify_config.REDIS_SOCKET_TIMEOUT,
socket_connect_timeout=dify_config.REDIS_SOCKET_CONNECT_TIMEOUT,
health_check_interval=dify_config.REDIS_HEALTH_CHECK_INTERVAL,
)
def _get_cluster_connection_health_params() -> dict[str, Any]:
@ -189,26 +219,26 @@ def _get_cluster_connection_health_params() -> dict[str, Any]:
here. Only ``retry``, ``socket_timeout``, and ``socket_connect_timeout``
are passed through.
"""
params = _get_connection_health_params()
params: dict[str, Any] = dict(_get_connection_health_params())
return {k: v for k, v in params.items() if k != "health_check_interval"}
def _get_base_redis_params() -> dict[str, Any]:
def _get_base_redis_params() -> RedisBaseParamsDict:
"""Get base Redis connection parameters including retry and health policy."""
return {
"username": dify_config.REDIS_USERNAME,
"password": dify_config.REDIS_PASSWORD or None,
"db": dify_config.REDIS_DB,
"encoding": "utf-8",
"encoding_errors": "strict",
"decode_responses": False,
"protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL,
"cache_config": _get_cache_configuration(),
return RedisBaseParamsDict(
username=dify_config.REDIS_USERNAME,
password=dify_config.REDIS_PASSWORD or None,
db=dify_config.REDIS_DB,
encoding="utf-8",
encoding_errors="strict",
decode_responses=False,
protocol=dify_config.REDIS_SERIALIZATION_PROTOCOL,
cache_config=_get_cache_configuration(),
**_get_connection_health_params(),
}
)
def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]:
def _create_sentinel_client(redis_params: RedisBaseParamsDict) -> Union[redis.Redis, RedisCluster]:
"""Create Redis client using Sentinel configuration."""
if not dify_config.REDIS_SENTINELS:
raise ValueError("REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True")
@ -232,7 +262,8 @@ def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis,
sentinel_kwargs=sentinel_kwargs,
)
master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params)
params: dict[str, Any] = {**redis_params}
master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **params)
return master
@ -259,18 +290,16 @@ def _create_cluster_client() -> Union[redis.Redis, RedisCluster]:
return cluster
def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]:
def _create_standalone_client(redis_params: RedisBaseParamsDict) -> Union[redis.Redis, RedisCluster]:
"""Create standalone Redis client."""
connection_class, ssl_kwargs = _get_ssl_configuration()
params = {**redis_params}
params.update(
{
"host": dify_config.REDIS_HOST,
"port": dify_config.REDIS_PORT,
"connection_class": connection_class,
}
)
params: dict[str, Any] = {
**redis_params,
"host": dify_config.REDIS_HOST,
"port": dify_config.REDIS_PORT,
"connection_class": connection_class,
}
if dify_config.REDIS_MAX_CONNECTIONS:
params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS
@ -293,8 +322,8 @@ def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> redis.Redis |
kwargs["max_connections"] = max_conns
return RedisCluster.from_url(pubsub_url, **kwargs)
health_params = _get_connection_health_params()
kwargs = {**health_params}
standalone_health_params: dict[str, Any] = dict(_get_connection_health_params())
kwargs = {**standalone_health_params}
if max_conns:
kwargs["max_connections"] = max_conns
return redis.Redis.from_url(pubsub_url, **kwargs)

View File

@ -37,12 +37,7 @@ def trace_span[**P, R](handler_class: type[SpanHandler] | None = None) -> Callab
handler = _get_handler_instance(handler_class or SpanHandler)
tracer = get_tracer(__name__)
return handler.wrapper(
tracer=tracer,
wrapped=func,
args=args,
kwargs=kwargs,
)
return handler.wrapper(tracer, func, *args, **kwargs)
return cast(Callable[P, R], wrapper)

View File

@ -1,8 +1,8 @@
import inspect
from collections.abc import Callable, Mapping
from collections.abc import Callable
from typing import Any
from opentelemetry.trace import SpanKind, Status, StatusCode
from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer
class SpanHandler:
@ -16,9 +16,9 @@ class SpanHandler:
exceptions. Handlers can override the wrapper method to customize behavior.
"""
_signature_cache: dict[Callable[..., Any], inspect.Signature] = {}
_signature_cache: dict[Callable[..., object], inspect.Signature] = {}
def _build_span_name(self, wrapped: Callable[..., Any]) -> str:
def _build_span_name[**P, R](self, wrapped: Callable[P, R]) -> str:
"""
Build the span name from the wrapped function.
@ -29,11 +29,11 @@ class SpanHandler:
"""
return f"{wrapped.__module__}.{wrapped.__qualname__}"
def _extract_arguments[T](
def _extract_arguments[**P, R](
self,
wrapped: Callable[..., T],
args: tuple[object, ...],
kwargs: Mapping[str, object],
wrapped: Callable[P, R],
*args: P.args,
**kwargs: P.kwargs,
) -> dict[str, Any] | None:
"""
Extract function arguments using inspect.signature.
@ -59,13 +59,13 @@ class SpanHandler:
except Exception:
return None
def wrapper[T](
def wrapper[**P, R](
self,
tracer: Any,
wrapped: Callable[..., T],
args: tuple[object, ...],
kwargs: Mapping[str, object],
) -> T:
tracer: Tracer,
wrapped: Callable[P, R],
*args: P.args,
**kwargs: P.kwargs,
) -> R:
"""
Fully control the wrapper behavior.

View File

@ -1,8 +1,7 @@
import logging
from collections.abc import Callable, Mapping
from typing import Any
from collections.abc import Callable
from opentelemetry.trace import SpanKind, Status, StatusCode
from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer
from opentelemetry.util.types import AttributeValue
from extensions.otel.decorators.handler import SpanHandler
@ -15,15 +14,15 @@ logger = logging.getLogger(__name__)
class AppGenerateHandler(SpanHandler):
"""Span handler for ``AppGenerateService.generate``."""
def wrapper[T](
def wrapper[**P, R](
self,
tracer: Any,
wrapped: Callable[..., T],
args: tuple[object, ...],
kwargs: Mapping[str, object],
) -> T:
tracer: Tracer,
wrapped: Callable[P, R],
*args: P.args,
**kwargs: P.kwargs,
) -> R:
try:
arguments = self._extract_arguments(wrapped, args, kwargs)
arguments = self._extract_arguments(wrapped, *args, **kwargs)
if not arguments:
return wrapped(*args, **kwargs)

View File

@ -1,8 +1,7 @@
import logging
from collections.abc import Callable, Mapping
from typing import Any
from collections.abc import Callable
from opentelemetry.trace import SpanKind, Status, StatusCode
from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer
from opentelemetry.util.types import AttributeValue
from extensions.otel.decorators.handler import SpanHandler
@ -14,15 +13,15 @@ logger = logging.getLogger(__name__)
class WorkflowAppRunnerHandler(SpanHandler):
"""Span handler for ``WorkflowAppRunner.run``."""
def wrapper(
def wrapper[**P, R](
self,
tracer: Any,
wrapped: Callable[..., Any],
args: tuple[Any, ...],
kwargs: Mapping[str, Any],
) -> Any:
tracer: Tracer,
wrapped: Callable[P, R],
*args: P.args,
**kwargs: P.kwargs,
) -> R:
try:
arguments = self._extract_arguments(wrapped, args, kwargs)
arguments = self._extract_arguments(wrapped, *args, **kwargs)
if not arguments:
return wrapped(*args, **kwargs)

View File

@ -14,9 +14,15 @@ from __future__ import annotations
import logging
import threading
from typing import Any
from typing import TYPE_CHECKING, Any
import redis
from redis.cluster import RedisCluster
from redis.exceptions import LockNotOwnedError, RedisError
from redis.lock import Lock
if TYPE_CHECKING:
from extensions.ext_redis import RedisClientWrapper
logger = logging.getLogger(__name__)
@ -38,21 +44,21 @@ class DbMigrationAutoRenewLock:
primary error/exit code.
"""
_redis_client: Any
_redis_client: redis.Redis | RedisCluster | RedisClientWrapper
_name: str
_ttl_seconds: float
_renew_interval_seconds: float
_log_context: str | None
_logger: logging.Logger
_lock: Any
_lock: Lock | None
_stop_event: threading.Event | None
_thread: threading.Thread | None
_acquired: bool
def __init__(
self,
redis_client: Any,
redis_client: redis.Redis | RedisCluster | RedisClientWrapper,
name: str,
ttl_seconds: float = 60,
renew_interval_seconds: float | None = None,
@ -127,7 +133,7 @@ class DbMigrationAutoRenewLock:
)
self._thread.start()
def _heartbeat_loop(self, lock: Any, stop_event: threading.Event) -> None:
def _heartbeat_loop(self, lock: Lock, stop_event: threading.Event) -> None:
while not stop_event.wait(self._renew_interval_seconds):
try:
lock.reacquire()

View File

@ -10,7 +10,7 @@ import uuid
from collections.abc import Callable, Generator, Mapping
from datetime import datetime
from hashlib import sha256
from typing import TYPE_CHECKING, Annotated, Any, Optional, Protocol, Union, cast
from typing import TYPE_CHECKING, Annotated, Any, Protocol, cast
from uuid import UUID
from zoneinfo import available_timezones
@ -81,7 +81,7 @@ def escape_like_pattern(pattern: str) -> str:
return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None:
def extract_tenant_id(user: "Account | EndUser") -> str | None:
"""
Extract tenant_id from Account or EndUser object.
@ -164,7 +164,10 @@ def email(email):
EmailStr = Annotated[str, AfterValidator(email)]
def uuid_value(value: Any) -> str:
def uuid_value(value: str | UUID) -> str:
if isinstance(value, UUID):
return str(value)
if value == "":
return str(value)
@ -405,7 +408,7 @@ class TokenManager:
def generate_token(
cls,
token_type: str,
account: Optional["Account"] = None,
account: "Account | None" = None,
email: str | None = None,
additional_data: dict | None = None,
) -> str:
@ -465,9 +468,7 @@ class TokenManager:
return current_token
@classmethod
def _set_current_token_for_account(
cls, account_id: str, token: str, token_type: str, expiry_minutes: Union[int, float]
):
def _set_current_token_for_account(cls, account_id: str, token: str, token_type: str, expiry_minutes: int | float):
key = cls._get_account_token_key(account_id, token_type)
expiry_seconds = int(expiry_minutes * 60)
redis_client.setex(key, expiry_seconds, token)

View File

@ -0,0 +1,145 @@
"""Helpers for generating type-coverage summaries from pyrefly report output."""
from __future__ import annotations
import json
import sys
from pathlib import Path
from typing import TypedDict
class CoverageSummary(TypedDict):
n_modules: int
n_typable: int
n_typed: int
n_any: int
n_untyped: int
coverage: float
strict_coverage: float
_REQUIRED_KEYS = frozenset(CoverageSummary.__annotations__)
_EMPTY_SUMMARY: CoverageSummary = {
"n_modules": 0,
"n_typable": 0,
"n_typed": 0,
"n_any": 0,
"n_untyped": 0,
"coverage": 0.0,
"strict_coverage": 0.0,
}
def parse_summary(report_json: str) -> CoverageSummary:
"""Extract the summary section from ``pyrefly report`` JSON output.
Returns an empty summary when *report_json* is empty or malformed so that
the CI workflow can degrade gracefully instead of crashing.
"""
if not report_json or not report_json.strip():
return _EMPTY_SUMMARY.copy()
try:
data = json.loads(report_json)
except json.JSONDecodeError:
return _EMPTY_SUMMARY.copy()
summary = data.get("summary")
if not isinstance(summary, dict) or not _REQUIRED_KEYS.issubset(summary):
return _EMPTY_SUMMARY.copy()
return {
"n_modules": summary["n_modules"],
"n_typable": summary["n_typable"],
"n_typed": summary["n_typed"],
"n_any": summary["n_any"],
"n_untyped": summary["n_untyped"],
"coverage": summary["coverage"],
"strict_coverage": summary["strict_coverage"],
}
def format_summary_markdown(summary: CoverageSummary) -> str:
"""Format a single coverage summary as a Markdown table."""
return (
"| Metric | Value |\n"
"| --- | ---: |\n"
f"| Modules | {summary['n_modules']} |\n"
f"| Typable symbols | {summary['n_typable']:,} |\n"
f"| Typed symbols | {summary['n_typed']:,} |\n"
f"| Untyped symbols | {summary['n_untyped']:,} |\n"
f"| Any symbols | {summary['n_any']:,} |\n"
f"| **Type coverage** | **{summary['coverage']:.2f}%** |\n"
f"| Strict coverage | {summary['strict_coverage']:.2f}% |"
)
def format_comparison_markdown(
base: CoverageSummary,
pr: CoverageSummary,
) -> str:
"""Format a comparison between base and PR coverage as Markdown."""
coverage_delta = pr["coverage"] - base["coverage"]
strict_delta = pr["strict_coverage"] - base["strict_coverage"]
typed_delta = pr["n_typed"] - base["n_typed"]
untyped_delta = pr["n_untyped"] - base["n_untyped"]
def _fmt_delta(value: float, fmt: str = ".2f") -> str:
sign = "+" if value > 0 else ""
return f"{sign}{value:{fmt}}"
lines = [
"| Metric | Base | PR | Delta |",
"| --- | ---: | ---: | ---: |",
(f"| **Type coverage** | {base['coverage']:.2f}% | {pr['coverage']:.2f}% | {_fmt_delta(coverage_delta)}% |"),
(
f"| Strict coverage | {base['strict_coverage']:.2f}% "
f"| {pr['strict_coverage']:.2f}% "
f"| {_fmt_delta(strict_delta)}% |"
),
(f"| Typed symbols | {base['n_typed']:,} | {pr['n_typed']:,} | {_fmt_delta(typed_delta, ',')} |"),
(f"| Untyped symbols | {base['n_untyped']:,} | {pr['n_untyped']:,} | {_fmt_delta(untyped_delta, ',')} |"),
(
f"| Modules | {base['n_modules']} "
f"| {pr['n_modules']} "
f"| {_fmt_delta(pr['n_modules'] - base['n_modules'], ',')} |"
),
]
return "\n".join(lines)
def main() -> int:
"""Read pyrefly report JSON from stdin and print a Markdown summary.
Accepts an optional ``--base <file>`` argument. When provided, the output
includes a base-vs-PR comparison table.
"""
args = sys.argv[1:]
base_file: str | None = None
if "--base" in args:
idx = args.index("--base")
if idx + 1 >= len(args):
sys.stderr.write("error: --base requires a file path\n")
return 1
base_file = args[idx + 1]
pr_report = sys.stdin.read()
pr_summary = parse_summary(pr_report)
if base_file is not None:
base_text = Path(base_file).read_text() if Path(base_file).exists() else ""
base_summary = parse_summary(base_text)
sys.stdout.write(format_comparison_markdown(base_summary, pr_summary) + "\n")
else:
sys.stdout.write(format_summary_markdown(pr_summary) + "\n")
return 0
if __name__ == "__main__":
sys.exit(main())

View File

@ -24,6 +24,8 @@ class TypeBase(MappedAsDataclass, DeclarativeBase):
class DefaultFieldsMixin:
"""Mixin for models that inherit from Base (non-dataclass)."""
id: Mapped[str] = mapped_column(
StringUUID,
primary_key=True,
@ -53,6 +55,42 @@ class DefaultFieldsMixin:
return f"<{self.__class__.__name__}(id={self.id})>"
class DefaultFieldsDCMixin(MappedAsDataclass):
"""Mixin for models that inherit from TypeBase (MappedAsDataclass)."""
__abstract__ = True
id: Mapped[str] = mapped_column(
StringUUID,
primary_key=True,
insert_default=lambda: str(uuidv7()),
default_factory=lambda: str(uuidv7()),
init=False,
)
created_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
insert_default=naive_utc_now,
default_factory=naive_utc_now,
init=False,
server_default=func.current_timestamp(),
)
updated_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
insert_default=naive_utc_now,
default_factory=naive_utc_now,
init=False,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}(id={self.id})>"
def gen_uuidv4_string() -> str:
"""gen_uuidv4_string generate a UUIDv4 string.

View File

@ -913,11 +913,7 @@ class TrialApp(TypeBase):
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime,
nullable=False,
insert_default=func.current_timestamp(),
server_default=func.current_timestamp(),
init=False,
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
trial_limit: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=3)
@ -941,11 +937,7 @@ class AccountTrialAppRecord(TypeBase):
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime,
nullable=False,
insert_default=func.current_timestamp(),
server_default=func.current_timestamp(),
init=False,
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
@property

View File

@ -4,7 +4,7 @@ import logging
from collections.abc import Generator, Mapping, Sequence
from datetime import datetime
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union, cast
from typing import TYPE_CHECKING, Any, Optional, TypedDict, cast
from uuid import uuid4
import sqlalchemy as sa
@ -121,7 +121,7 @@ class WorkflowType(StrEnum):
raise ValueError(f"invalid workflow type value {value}")
@classmethod
def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType":
def from_app_mode(cls, app_mode: "str | AppMode") -> "WorkflowType":
"""
Get workflow type from app mode.
@ -1051,7 +1051,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
)
return extras
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]:
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> "WorkflowNodeExecutionOffload | None":
return next(iter([i for i in self.offload_data if i.type_ == type_]), None)
@property

View File

@ -9,7 +9,8 @@ from typing import Any, TypedDict, cast
from pydantic import BaseModel, TypeAdapter
from sqlalchemy import delete, func, select, update
from sqlalchemy.orm import Session, sessionmaker
from core.db.session_factory import session_factory
class InvitationData(TypedDict):
@ -800,19 +801,19 @@ class AccountService:
return token
@staticmethod
def get_account_by_email_with_case_fallback(email: str, session: Session | None = None) -> Account | None:
def get_account_by_email_with_case_fallback(email: str) -> Account | None:
"""
Retrieve an account by email and fall back to the lowercase email if the original lookup fails.
This keeps backward compatibility for older records that stored uppercase emails while the
rest of the system gradually normalizes new inputs.
"""
query_session = session or db.session
account = query_session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
if account or email == email.lower():
return account
with session_factory.create_session() as session:
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
if account or email == email.lower():
return account
return query_session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()
return session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()
@classmethod
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
@ -1516,8 +1517,7 @@ class RegisterService:
check_workspace_member_invite_permission(tenant.id)
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
account = AccountService.get_account_by_email_with_case_fallback(email)
if not account:
TenantService.check_member_permission(tenant, inviter, None, "add")

View File

@ -4,7 +4,7 @@ import logging
import threading
import uuid
from collections.abc import Callable, Generator, Mapping
from typing import TYPE_CHECKING, Any, Union
from typing import TYPE_CHECKING, Any
from configs import dify_config
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
@ -88,7 +88,7 @@ class AppGenerateService:
def generate(
cls,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
@ -356,11 +356,11 @@ class AppGenerateService:
def generate_more_like_this(
cls,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
message_id: str,
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Union[Mapping, Generator]:
) -> Mapping | Generator:
"""
Generate more like this
:param app_model: app model

View File

@ -7,7 +7,7 @@ with support for different subscription tiers, rate limiting, and execution trac
import json
from datetime import UTC, datetime
from typing import Any, Union
from typing import Any
from celery.result import AsyncResult
from sqlalchemy import select
@ -50,7 +50,7 @@ class AsyncWorkflowService:
@classmethod
def trigger_workflow_async(
cls, session: Session, user: Union[Account, EndUser], trigger_data: TriggerData
cls, session: Session, user: Account | EndUser, trigger_data: TriggerData
) -> AsyncTriggerResponse:
"""
Universal entry point for async workflow execution - THIS METHOD WILL NOT BLOCK
@ -177,7 +177,7 @@ class AsyncWorkflowService:
@classmethod
def reinvoke_trigger(
cls, session: Session, user: Union[Account, EndUser], workflow_trigger_log_id: str
cls, session: Session, user: Account | EndUser, workflow_trigger_log_id: str
) -> AsyncTriggerResponse:
"""
Re-invoke a previously failed or rate-limited trigger - THIS METHOD WILL NOT BLOCK

View File

@ -2822,6 +2822,10 @@ class DocumentService:
knowledge_config.process_rule.rules.pre_processing_rules = list(unique_pre_processing_rule_dicts.values())
if knowledge_config.process_rule.mode == ProcessRuleMode.HIERARCHICAL:
if not knowledge_config.process_rule.rules.parent_mode:
knowledge_config.process_rule.rules.parent_mode = "paragraph"
if not knowledge_config.process_rule.rules.segmentation:
raise ValueError("Process rule segmentation is required")

View File

@ -1,6 +1,6 @@
import json
from copy import deepcopy
from typing import Any, Union, cast
from typing import Any, cast
from urllib.parse import urlparse
import httpx
@ -148,18 +148,23 @@ class ExternalDatasetService:
db.session.commit()
@staticmethod
def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]:
def external_knowledge_api_use_check(external_knowledge_api_id: str, tenant_id: str) -> tuple[bool, int]:
"""
Return usage for an external knowledge API within a single tenant.
The caller already scopes access by tenant, so this query must do the
same; otherwise the endpoint becomes a cross-tenant UUID oracle.
"""
count = (
db.session.scalar(
select(func.count(ExternalKnowledgeBindings.id)).where(
ExternalKnowledgeBindings.external_knowledge_api_id == external_knowledge_api_id
ExternalKnowledgeBindings.external_knowledge_api_id == external_knowledge_api_id,
ExternalKnowledgeBindings.tenant_id == tenant_id,
)
)
or 0
)
if count > 0:
return True, count
return False, 0
return count > 0, count
@staticmethod
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
@ -190,9 +195,7 @@ class ExternalDatasetService:
raise ValueError(f"{parameter.get('name')} is required")
@staticmethod
def process_external_api(
settings: ExternalKnowledgeApiSetting, files: Union[None, dict[str, Any]]
) -> httpx.Response:
def process_external_api(settings: ExternalKnowledgeApiSetting, files: dict[str, Any] | None) -> httpx.Response:
"""
do http request depending on api bundle
"""

View File

@ -5,7 +5,7 @@ import uuid
from collections.abc import Iterator, Sequence
from contextlib import contextmanager, suppress
from tempfile import NamedTemporaryFile
from typing import Literal, Union
from typing import Literal
from zipfile import ZIP_DEFLATED, ZipFile
from graphon.file import helpers as file_helpers
@ -52,7 +52,7 @@ class FileService:
filename: str,
content: bytes,
mimetype: str,
user: Union[Account, EndUser],
user: Account | EndUser,
source: Literal["datasets"] | None = None,
source_url: str = "",
) -> UploadFile:

View File

@ -1,6 +1,6 @@
import json
import logging
from typing import Any, TypedDict, Union
from typing import Any, TypedDict
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.entities.provider_entities import (
@ -626,7 +626,7 @@ class ModelLoadBalancingService:
def _get_credential_schema(
self, provider_configuration: ProviderConfiguration
) -> Union[ModelCredentialSchema, ProviderCredentialSchema]:
) -> ModelCredentialSchema | ProviderCredentialSchema:
"""Get form schemas."""
if provider_configuration.provider.model_credential_schema:
return provider_configuration.provider.model_credential_schema

View File

@ -1,14 +1,13 @@
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from extensions.ext_database import db
from core.db.session_factory import session_factory
from models.account import TenantPluginPermission
class PluginPermissionService:
@staticmethod
def get_permission(tenant_id: str) -> TenantPluginPermission | None:
with sessionmaker(bind=db.engine).begin() as session:
with session_factory.create_session() as session:
return session.scalar(
select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1)
)
@ -19,7 +18,7 @@ class PluginPermissionService:
install_permission: TenantPluginPermission.InstallPermission,
debug_permission: TenantPluginPermission.DebugPermission,
):
with sessionmaker(bind=db.engine).begin() as session:
with session_factory.create_session() as session, session.begin():
permission = session.scalar(
select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1)
)

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, Union
from typing import Any
from configs import dify_config
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
@ -17,7 +17,7 @@ class PipelineGenerateService:
def generate(
cls,
pipeline: Pipeline,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,

View File

@ -5,7 +5,7 @@ import threading
import time
from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, Union, cast
from typing import Any, cast
from uuid import uuid4
from flask_login import current_user
@ -1387,7 +1387,7 @@ class RagPipelineService:
"uninstalled_recommended_plugins": uninstalled_plugin_list,
}
def retry_error_document(self, dataset: Dataset, document: Document, user: Union[Account, EndUser]):
def retry_error_document(self, dataset: Dataset, document: Document, user: Account | EndUser):
"""
Retry error document
"""

View File

@ -1,6 +1,6 @@
import logging
from collections.abc import Mapping
from typing import Any, Union
from typing import Any
from pydantic import TypeAdapter, ValidationError
from yarl import URL
@ -69,7 +69,7 @@ class ToolTransformService:
return ""
@staticmethod
def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity, PluginDatasourceProviderEntity]):
def repack_provider(tenant_id: str, provider: dict | ToolProviderApiEntity | PluginDatasourceProviderEntity):
"""
repack provider

View File

@ -7,15 +7,16 @@ with appropriate retry policies and error handling.
import logging
from datetime import UTC, datetime
from typing import Any
from typing import Any, NotRequired
from celery import shared_task
from graphon.runtime import GraphRuntimeState
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from typing_extensions import TypedDict
from configs import dify_config
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext
from core.app.layers.timeslice_layer import TimeSliceLayer
@ -42,6 +43,13 @@ from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue, AsyncWorkf
logger = logging.getLogger(__name__)
class WorkflowGeneratorArgsDict(TypedDict):
inputs: dict[str, Any]
files: list[Any]
_skip_prepare_user_inputs: bool
workflow_id: NotRequired[str]
@shared_task(queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE)
def execute_workflow_professional(task_data_dict: dict[str, Any]):
"""Execute workflow for professional tier with highest priority"""
@ -90,15 +98,13 @@ def execute_workflow_sandbox(task_data_dict: dict[str, Any]):
)
def _build_generator_args(trigger_data: TriggerData) -> dict[str, Any]:
def _build_generator_args(trigger_data: TriggerData) -> WorkflowGeneratorArgsDict:
"""Build args passed into WorkflowAppGenerator.generate for Celery executions."""
args: dict[str, Any] = {
return {
"inputs": dict(trigger_data.inputs),
"files": list(trigger_data.files),
SKIP_PREPARE_USER_INPUTS_KEY: True,
"_skip_prepare_user_inputs": True,
}
return args
def _execute_workflow_common(

View File

@ -158,7 +158,10 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase():
second_result.scalar_one_or_none.return_value = expected_account
mock_session.execute.side_effect = [first_result, second_result]
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
with patch("services.account_service.session_factory") as mock_factory:
mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False)
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com")
assert result is expected_account
assert mock_session.execute.call_count == 2

View File

@ -113,12 +113,14 @@ class TestForgotPasswordCheckApi:
class TestForgotPasswordResetApi:
@patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.forgot_password.db")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_reset_fetches_account_with_original_email(
self,
mock_get_reset_data,
mock_revoke_token,
mock_db,
mock_get_account,
mock_update_account,
app,
@ -126,6 +128,7 @@ class TestForgotPasswordResetApi:
mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"}
mock_account = MagicMock()
mock_get_account.return_value = mock_account
mock_db.session.merge.return_value = mock_account
wraps_features = SimpleNamespace(enable_email_password_login=True)
with (
@ -161,7 +164,10 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase():
second_result.scalar_one_or_none.return_value = expected_account
mock_session.execute.side_effect = [first_result, second_result]
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session)
with patch("services.account_service.session_factory") as mock_factory:
mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False)
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com")
assert result is expected_account
assert mock_session.execute.call_count == 2

View File

@ -437,7 +437,10 @@ class TestAccountGeneration:
second_result.scalar_one_or_none.return_value = expected_account
mock_session.execute.side_effect = [first_result, second_result]
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
with patch("services.account_service.session_factory") as mock_factory:
mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False)
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com")
assert result is expected_account
assert mock_session.execute.call_count == 2

View File

@ -335,10 +335,12 @@ class TestForgotPasswordResetApi:
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.forgot_password.db")
@patch("controllers.console.auth.forgot_password.TenantService.get_join_tenants")
def test_reset_password_success(
self,
mock_get_tenants,
mock_db,
mock_get_account,
mock_revoke_token,
mock_get_data,
@ -356,6 +358,7 @@ class TestForgotPasswordResetApi:
# Arrange
mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
mock_get_account.return_value = mock_account
mock_db.session.merge.return_value = mock_account
mock_get_tenants.return_value = [MagicMock()]
# Act

View File

@ -37,10 +37,8 @@ class TestForgotPasswordSendEmailApi:
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
@patch("controllers.web.forgot_password.extract_remote_ip", return_value="127.0.0.1")
@patch("controllers.web.forgot_password.sessionmaker")
def test_should_normalize_email_before_sending(
self,
mock_session_cls,
mock_extract_ip,
mock_rate_limit,
mock_get_account,
@ -50,19 +48,16 @@ class TestForgotPasswordSendEmailApi:
mock_account = MagicMock()
mock_get_account.return_value = mock_account
mock_send_mail.return_value = "token-123"
mock_session = MagicMock()
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
with app.test_request_context(
"/web/forgot-password",
method="POST",
json={"email": "User@Example.com", "language": "zh-Hans"},
):
response = ForgotPasswordSendEmailApi().post()
with app.test_request_context(
"/web/forgot-password",
method="POST",
json={"email": "User@Example.com", "language": "zh-Hans"},
):
response = ForgotPasswordSendEmailApi().post()
assert response == {"result": "success", "data": "token-123"}
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
mock_get_account.assert_called_once_with("User@Example.com")
mock_send_mail.assert_called_once_with(account=mock_account, email="user@example.com", language="zh-Hans")
mock_extract_ip.assert_called_once()
mock_rate_limit.assert_called_once_with("127.0.0.1")
@ -153,14 +148,14 @@ class TestForgotPasswordResetApi:
@patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account")
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.web.forgot_password.sessionmaker")
@patch("controllers.web.forgot_password.db")
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
def test_should_fetch_account_with_fallback(
self,
mock_get_reset_data,
mock_revoke_token,
mock_session_cls,
mock_db,
mock_get_account,
mock_update_account,
app,
@ -168,29 +163,27 @@ class TestForgotPasswordResetApi:
mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com", "code": "1234"}
mock_account = MagicMock()
mock_get_account.return_value = mock_account
mock_session = MagicMock()
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
mock_db.session.merge.return_value = mock_account
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
with app.test_request_context(
"/web/forgot-password/resets",
method="POST",
json={
"token": "token-123",
"new_password": "ValidPass123!",
"password_confirm": "ValidPass123!",
},
):
response = ForgotPasswordResetApi().post()
with app.test_request_context(
"/web/forgot-password/resets",
method="POST",
json={
"token": "token-123",
"new_password": "ValidPass123!",
"password_confirm": "ValidPass123!",
},
):
response = ForgotPasswordResetApi().post()
assert response == {"result": "success"}
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
mock_get_account.assert_called_once_with("User@Example.com")
mock_update_account.assert_called_once()
mock_revoke_token.assert_called_once_with("token-123")
@patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value")
@patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef")
@patch("controllers.web.forgot_password.sessionmaker")
@patch("controllers.web.forgot_password.db")
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@ -199,7 +192,7 @@ class TestForgotPasswordResetApi:
mock_get_account,
mock_get_reset_data,
mock_revoke_token,
mock_session_cls,
mock_db,
mock_token_bytes,
mock_hash_password,
app,
@ -207,20 +200,18 @@ class TestForgotPasswordResetApi:
mock_get_reset_data.return_value = {"phase": "reset", "email": "user@example.com"}
account = MagicMock()
mock_get_account.return_value = account
mock_session = MagicMock()
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
mock_db.session.merge.return_value = account
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
with app.test_request_context(
"/web/forgot-password/resets",
method="POST",
json={
"token": "reset-token",
"new_password": "StrongPass123!",
"password_confirm": "StrongPass123!",
},
):
response = ForgotPasswordResetApi().post()
with app.test_request_context(
"/web/forgot-password/resets",
method="POST",
json={
"token": "reset-token",
"new_password": "StrongPass123!",
"password_confirm": "StrongPass123!",
},
):
response = ForgotPasswordResetApi().post()
assert response == {"result": "success"}
mock_get_reset_data.assert_called_once_with("reset-token")

View File

@ -1,239 +1,193 @@
from __future__ import annotations
import json
from typing import Any, cast
from unittest.mock import ANY, MagicMock, patch
from uuid import uuid4
import pytest
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from core.rag.models.document import Document
from models.dataset import Dataset
from models.dataset import Dataset, DatasetQuery
from services.hit_testing_service import HitTestingService
class TestHitTestingService:
"""Test suite for HitTestingService"""
def _create_dataset(db_session: Session, *, provider: str = "vendor", **kwargs: Any) -> Dataset:
tenant_id = str(uuid4())
created_by = str(uuid4())
ds = Dataset(
tenant_id=kwargs.get("tenant_id", tenant_id),
name=kwargs.get("name", "test-dataset"),
created_by=kwargs.get("created_by", created_by),
provider=provider,
)
db_session.add(ds)
db_session.commit()
db_session.refresh(ds)
return ds
# ===== Utility Method Tests =====
class TestHitTestingService:
# ── Utility methods (pure logic, no DB) ────────────────────────────
def test_escape_query_for_search_should_escape_double_quotes(self):
"""Test that escape_query_for_search escapes double quotes correctly"""
# Arrange
query = 'test "query" with quotes'
expected = 'test \\"query\\" with quotes'
# Act
result = HitTestingService.escape_query_for_search(query)
# Assert
assert result == expected
assert result == 'test \\"query\\" with quotes'
def test_hit_testing_args_check_should_pass_with_valid_query(self):
"""Test that hit_testing_args_check passes with a valid query"""
# Arrange
args = {"query": "valid query"}
# Act & Assert (should not raise)
HitTestingService.hit_testing_args_check(args)
HitTestingService.hit_testing_args_check({"query": "valid query"})
def test_hit_testing_args_check_should_pass_with_valid_attachments(self):
"""Test that hit_testing_args_check passes with valid attachment_ids"""
# Arrange
args = {"attachment_ids": ["id1", "id2"]}
# Act & Assert (should not raise)
HitTestingService.hit_testing_args_check(args)
HitTestingService.hit_testing_args_check({"attachment_ids": ["id1", "id2"]})
def test_hit_testing_args_check_should_raise_error_when_no_query_or_attachments(self):
"""Test that hit_testing_args_check raises ValueError if both query and attachment_ids are missing"""
# Arrange
args = {}
# Act & Assert
with pytest.raises(ValueError) as exc_info:
HitTestingService.hit_testing_args_check(args)
assert "Query or attachment_ids is required" in str(exc_info.value)
with pytest.raises(ValueError, match="Query or attachment_ids is required"):
HitTestingService.hit_testing_args_check({})
def test_hit_testing_args_check_should_raise_error_when_query_too_long(self):
"""Test that hit_testing_args_check raises ValueError if query exceeds 250 characters"""
# Arrange
args = {"query": "a" * 251}
# Act & Assert
with pytest.raises(ValueError) as exc_info:
HitTestingService.hit_testing_args_check(args)
assert "Query cannot exceed 250 characters" in str(exc_info.value)
with pytest.raises(ValueError, match="Query cannot exceed 250 characters"):
HitTestingService.hit_testing_args_check({"query": "a" * 251})
def test_hit_testing_args_check_should_raise_error_when_attachments_not_list(self):
"""Test that hit_testing_args_check raises ValueError if attachment_ids is not a list"""
# Arrange
args = {"attachment_ids": "not a list"}
with pytest.raises(ValueError, match="Attachment_ids must be a list"):
HitTestingService.hit_testing_args_check({"attachment_ids": "not a list"})
# Act & Assert
with pytest.raises(ValueError) as exc_info:
HitTestingService.hit_testing_args_check(args)
assert "Attachment_ids must be a list" in str(exc_info.value)
# ===== Response Formatting Tests =====
# ── Response formatting ────────────────────────────────────────────
@patch("core.rag.datasource.retrieval_service.RetrievalService.format_retrieval_documents")
def test_compact_retrieve_response_should_format_correctly(self, mock_format):
"""Test that compact_retrieve_response formats the response correctly"""
# Arrange
query = "test query"
mock_doc = MagicMock(spec=Document)
documents = [mock_doc]
mock_record = MagicMock()
mock_record.model_dump.return_value = {"content": "formatted content"}
mock_format.return_value = [mock_record]
# Act
result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, documents))
result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, [mock_doc]))
# Assert
assert cast(dict[str, Any], result["query"])["content"] == query
assert len(result["records"]) == 1
assert cast(dict[str, Any], result["records"][0])["content"] == "formatted content"
mock_format.assert_called_once_with(documents)
mock_format.assert_called_once_with([mock_doc])
def test_compact_external_retrieve_response_should_return_records_for_external_provider(self):
"""Test that compact_external_retrieve_response returns records when dataset provider is external"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.provider = "external"
query = "test query"
def test_compact_external_retrieve_response_should_return_records_for_external_provider(
self, db_session_with_containers: Session
):
dataset = _create_dataset(db_session_with_containers, provider="external")
documents = [
{"content": "c1", "title": "t1", "score": 0.9, "metadata": {"m1": "v1"}},
{"content": "c2", "title": "t2", "score": 0.8, "metadata": {"m2": "v2"}},
]
# Act
result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents))
result = cast(
dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, "test query", documents)
)
# Assert
assert cast(dict[str, Any], result["query"])["content"] == query
assert cast(dict[str, Any], result["query"])["content"] == "test query"
assert len(result["records"]) == 2
assert cast(dict[str, Any], result["records"][0])["content"] == "c1"
assert cast(dict[str, Any], result["records"][1])["title"] == "t2"
def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider(self):
"""Test that compact_external_retrieve_response returns empty records for non-external provider"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.provider = "not_external"
query = "test query"
documents = [{"content": "c1"}]
def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider(
self, db_session_with_containers: Session
):
dataset = _create_dataset(db_session_with_containers, provider="vendor")
# Act
result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents))
result = cast(
dict[str, Any],
HitTestingService.compact_external_retrieve_response(dataset, "test query", [{"content": "c1"}]),
)
# Assert
assert cast(dict[str, Any], result["query"])["content"] == query
assert cast(dict[str, Any], result["query"])["content"] == "test query"
assert result["records"] == []
# ===== External Retrieve Tests =====
# ── External retrieve (real DB) ────────────────────────────────────
@patch("core.rag.datasource.retrieval_service.RetrievalService.external_retrieve")
@patch("extensions.ext_database.db.session.add")
@patch("extensions.ext_database.db.session.commit")
def test_external_retrieve_should_succeed_for_external_provider(self, mock_commit, mock_add, mock_ext_retrieve):
"""Test that external_retrieve successfully retrieves from external provider and commits query"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.id = "dataset_id"
dataset.provider = "external"
query = 'test "query"'
def test_external_retrieve_should_succeed_for_external_provider(
self, mock_ext_retrieve, db_session_with_containers: Session
):
dataset = _create_dataset(db_session_with_containers, provider="external")
account_id = str(uuid4())
account = MagicMock()
account.id = "account_id"
account.id = account_id
mock_ext_retrieve.return_value = [{"content": "ext content", "score": 1.0}]
# Act
before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
result = cast(
dict[str, Any],
HitTestingService.external_retrieve(
dataset=dataset,
query=query,
query='test "query"',
account=account,
external_retrieval_model={"model": "test"},
metadata_filtering_conditions={"key": "val"},
),
)
# Assert
assert cast(dict[str, Any], result["query"])["content"] == query
assert cast(dict[str, Any], result["query"])["content"] == 'test "query"'
assert cast(dict[str, Any], result["records"][0])["content"] == "ext content"
# Verify call to RetrievalService.external_retrieve with escaped query
mock_ext_retrieve.assert_called_once_with(
dataset_id="dataset_id",
dataset_id=dataset.id,
query='test \\"query\\"',
external_retrieval_model={"model": "test"},
metadata_filtering_conditions={"key": "val"},
)
# Verify DatasetQuery record was added and committed
mock_add.assert_called_once()
mock_commit.assert_called_once()
db_session_with_containers.expire_all()
after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
assert after_count == before_count + 1
def test_external_retrieve_should_return_empty_for_non_external_provider(self):
"""Test that external_retrieve returns empty results immediately if provider is not external"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.provider = "not_external"
query = "test query"
def test_external_retrieve_should_return_empty_for_non_external_provider(self, db_session_with_containers: Session):
dataset = _create_dataset(db_session_with_containers, provider="vendor")
account = MagicMock()
# Act
result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, query, account))
result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, "test query", account))
# Assert
assert cast(dict[str, Any], result["query"])["content"] == query
assert cast(dict[str, Any], result["query"])["content"] == "test query"
assert result["records"] == []
# ===== Retrieve Tests =====
# ── Retrieve (real DB) ─────────────────────────────────────────────
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
@patch("extensions.ext_database.db.session.add")
@patch("extensions.ext_database.db.session.commit")
def test_retrieve_should_use_default_model_when_none_provided(self, mock_commit, mock_add, mock_retrieve):
"""Test that retrieve uses default model when retrieval_model is not provided"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.id = "dataset_id"
def test_retrieve_should_use_default_model_when_none_provided(
self, mock_retrieve, db_session_with_containers: Session
):
dataset = _create_dataset(db_session_with_containers)
dataset.retrieval_model = None
query = "test query"
account = MagicMock()
account.id = "account_id"
account.id = str(uuid4())
mock_retrieve.return_value = []
# Act
before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
result = cast(
dict[str, Any],
HitTestingService.retrieve(
dataset=dataset, query=query, account=account, retrieval_model=None, external_retrieval_model={}
dataset=dataset, query="test query", account=account, retrieval_model=None, external_retrieval_model={}
),
)
# Assert
assert cast(dict[str, Any], result["query"])["content"] == query
assert cast(dict[str, Any], result["query"])["content"] == "test query"
mock_retrieve.assert_called_once()
# Verify top_k from default_retrieval_model (4)
assert mock_retrieve.call_args.kwargs["top_k"] == 4
mock_commit.assert_called_once()
db_session_with_containers.expire_all()
after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
assert after_count == before_count + 1
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition")
@patch("extensions.ext_database.db.session.add")
@patch("extensions.ext_database.db.session.commit")
def test_retrieve_should_handle_metadata_filtering(self, mock_commit, mock_add, mock_get_meta, mock_retrieve):
"""Test that retrieve correctly calls metadata filtering when conditions are present"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.id = "dataset_id"
query = "test query"
def test_retrieve_should_handle_metadata_filtering(
self, mock_get_meta, mock_retrieve, db_session_with_containers: Session
):
dataset = _create_dataset(db_session_with_containers)
account = MagicMock()
account.id = "account_id"
account.id = str(uuid4())
retrieval_model = {
"search_method": "semantic_search",
@ -242,29 +196,27 @@ class TestHitTestingService:
"reranking_enable": False,
"score_threshold_enabled": False,
}
# Mock metadata filtering response
mock_get_meta.return_value = ({"dataset_id": ["doc_id1"]}, "condition_string")
mock_get_meta.return_value = ({dataset.id: ["doc_id1"]}, "condition_string")
mock_retrieve.return_value = []
# Act
HitTestingService.retrieve(
dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={}
dataset=dataset,
query="test query",
account=account,
retrieval_model=retrieval_model,
external_retrieval_model={},
)
# Assert
mock_get_meta.assert_called_once()
mock_retrieve.assert_called_once()
assert mock_retrieve.call_args.kwargs["document_ids_filter"] == ["doc_id1"]
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition")
def test_retrieve_should_return_empty_if_metadata_filtering_fails(self, mock_get_meta, mock_retrieve):
"""Test that retrieve returns empty response if metadata filtering returns condition but no document IDs"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.id = "dataset_id"
query = "test query"
def test_retrieve_should_return_empty_if_metadata_filtering_fails(
self, mock_get_meta, mock_retrieve, db_session_with_containers: Session
):
dataset = _create_dataset(db_session_with_containers)
account = MagicMock()
retrieval_model = {
@ -274,37 +226,27 @@ class TestHitTestingService:
"reranking_enable": False,
"score_threshold_enabled": False,
}
# Mock metadata filtering response: condition returned but no IDs
mock_get_meta.return_value = ({}, "condition_string")
# Act
result = cast(
dict[str, Any],
HitTestingService.retrieve(
dataset=dataset,
query=query,
query="test query",
account=account,
retrieval_model=retrieval_model,
external_retrieval_model={},
),
)
# Assert
assert result["records"] == []
mock_retrieve.assert_not_called()
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
@patch("extensions.ext_database.db.session.add")
@patch("extensions.ext_database.db.session.commit")
def test_retrieve_should_handle_attachments(self, mock_commit, mock_add, mock_retrieve):
"""Test that retrieve handles attachment_ids and adds them to DatasetQuery"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.id = "dataset_id"
query = "test query"
def test_retrieve_should_handle_attachments(self, mock_retrieve, db_session_with_containers: Session):
dataset = _create_dataset(db_session_with_containers)
account = MagicMock()
account.id = "account_id"
account.id = str(uuid4())
attachment_ids = ["att1", "att2"]
retrieval_model = {
@ -315,21 +257,19 @@ class TestHitTestingService:
}
mock_retrieve.return_value = []
# Act
HitTestingService.retrieve(
dataset=dataset,
query=query,
query="test query",
account=account,
retrieval_model=retrieval_model,
external_retrieval_model={},
attachment_ids=attachment_ids,
)
# Assert
mock_retrieve.assert_called_once_with(
retrieval_method=ANY,
dataset_id="dataset_id",
query=query,
dataset_id=dataset.id,
query="test query",
attachment_ids=attachment_ids,
top_k=4,
score_threshold=0.0,
@ -338,26 +278,27 @@ class TestHitTestingService:
weights=None,
document_ids_filter=None,
)
# Verify DatasetQuery record (there should be 2 queries: 1 text, 2 images)
# The content is json.dumps([{"content_type": "text_query", ...}, {"content_type": "image_query", ...}])
called_query = mock_add.call_args[0][0]
query_content = json.loads(called_query.content)
# Verify DatasetQuery was persisted with correct content structure
db_session_with_containers.expire_all()
latest = db_session_with_containers.scalar(
select(DatasetQuery)
.where(DatasetQuery.dataset_id == dataset.id)
.order_by(DatasetQuery.created_at.desc())
.limit(1)
)
assert latest is not None
query_content = json.loads(latest.content)
assert len(query_content) == 3 # 1 text + 2 images
assert query_content[0]["content_type"] == "text_query"
assert query_content[1]["content_type"] == "image_query"
assert query_content[1]["content"] == "att1"
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
@patch("extensions.ext_database.db.session.add")
@patch("extensions.ext_database.db.session.commit")
def test_retrieve_should_handle_reranking_and_threshold(self, mock_commit, mock_add, mock_retrieve):
"""Test that retrieve passes reranking and threshold parameters correctly"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.id = "dataset_id"
query = "test query"
def test_retrieve_should_handle_reranking_and_threshold(self, mock_retrieve, db_session_with_containers: Session):
dataset = _create_dataset(db_session_with_containers)
account = MagicMock()
account.id = "account_id"
account.id = str(uuid4())
retrieval_model = {
"search_method": "hybrid_search",
@ -371,12 +312,14 @@ class TestHitTestingService:
}
mock_retrieve.return_value = []
# Act
HitTestingService.retrieve(
dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={}
dataset=dataset,
query="test query",
account=account,
retrieval_model=retrieval_model,
external_retrieval_model={},
)
# Assert
mock_retrieve.assert_called_once()
kwargs = mock_retrieve.call_args.kwargs
assert kwargs["score_threshold"] == 0.5

View File

@ -0,0 +1,363 @@
from __future__ import annotations
import uuid
from unittest.mock import patch
import pytest
from faker import Faker
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.ops.entities.config_entity import TracingProviderEnum
from models.model import TraceAppConfig
from services.account_service import AccountService, TenantService
from services.app_service import AppService
from services.ops_service import OpsService
from tests.test_containers_integration_tests.helpers import generate_valid_password
class TestOpsService:
@pytest.fixture
def mock_external_service_dependencies(self):
with (
patch("services.app_service.FeatureService") as mock_feature_service,
patch("services.app_service.EnterpriseService") as mock_enterprise_service,
patch("services.app_service.ModelManager.for_tenant") as mock_model_manager,
patch("services.account_service.FeatureService") as mock_account_feature_service,
):
mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False
mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None
mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None
mock_account_feature_service.get_system_features.return_value.is_allow_register = True
mock_model_instance = mock_model_manager.return_value
mock_model_instance.get_default_model_instance.return_value = None
mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo")
yield {
"feature_service": mock_feature_service,
"enterprise_service": mock_enterprise_service,
"model_manager": mock_model_manager,
"account_feature_service": mock_account_feature_service,
}
@pytest.fixture
def mock_ops_trace_manager(self):
with patch("services.ops_service.OpsTraceManager") as mock:
yield mock
def _create_app(self, db_session_with_containers: Session, mock_external_service_dependencies):
fake = Faker()
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
app_service = AppService()
app = app_service.create_app(
tenant.id,
{
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🤖",
"icon_background": "#FF6B6B",
},
account,
)
return app, account
_SENTINEL = object()
def _insert_trace_config(
self,
db_session: Session,
app_id: str,
provider: str,
tracing_config: dict | None | object = _SENTINEL,
) -> TraceAppConfig:
trace_config = TraceAppConfig(
app_id=app_id,
tracing_provider=provider,
tracing_config=tracing_config if tracing_config is not self._SENTINEL else {"some": "config"},
)
db_session.add(trace_config)
db_session.commit()
return trace_config
# ── get_tracing_app_config ─────────────────────────────────────────
def test_get_tracing_app_config_no_config(self, db_session_with_containers: Session, mock_ops_trace_manager):
result = OpsService.get_tracing_app_config(str(uuid.uuid4()), "arize")
assert result is None
def test_get_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager):
fake_app_id = str(uuid.uuid4())
self._insert_trace_config(db_session_with_containers, fake_app_id, "arize")
result = OpsService.get_tracing_app_config(fake_app_id, "arize")
assert result is None
def test_get_tracing_app_config_none_config(
self, db_session_with_containers: Session, mock_external_service_dependencies, mock_ops_trace_manager
):
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
self._insert_trace_config(db_session_with_containers, app.id, "arize", tracing_config=None)
with pytest.raises(ValueError, match="Tracing config cannot be None."):
OpsService.get_tracing_app_config(app.id, "arize")
@pytest.mark.parametrize(
("provider", "default_url"),
[
("arize", "https://app.arize.com/"),
("phoenix", "https://app.phoenix.arize.com/projects/"),
("langsmith", "https://smith.langchain.com/"),
("opik", "https://www.comet.com/opik/"),
("weave", "https://wandb.ai/"),
("aliyun", "https://arms.console.aliyun.com/"),
("tencent", "https://console.cloud.tencent.com/apm"),
("mlflow", "http://localhost:5000/"),
("databricks", "https://www.databricks.com/"),
],
)
def test_get_tracing_app_config_providers_exception(
self, db_session_with_containers: Session, mock_external_service_dependencies, provider, default_url
):
with patch("services.ops_service.OpsTraceManager") as mock_otm:
mock_otm.decrypt_tracing_config.return_value = {}
mock_otm.obfuscated_decrypt_token.return_value = {}
mock_otm.get_trace_config_project_url.side_effect = Exception("error")
mock_otm.get_trace_config_project_key.side_effect = Exception("error")
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
self._insert_trace_config(db_session_with_containers, app.id, provider)
result = OpsService.get_tracing_app_config(app.id, provider)
assert result is not None
assert result["tracing_config"]["project_url"] == default_url
@pytest.mark.parametrize(
"provider",
["arize", "phoenix", "langsmith", "opik", "weave", "aliyun", "tencent", "mlflow", "databricks"],
)
def test_get_tracing_app_config_providers_success(
self, db_session_with_containers: Session, mock_external_service_dependencies, provider
):
with patch("services.ops_service.OpsTraceManager") as mock_otm:
mock_otm.decrypt_tracing_config.return_value = {}
mock_otm.obfuscated_decrypt_token.return_value = {"project_url": "success_url"}
mock_otm.get_trace_config_project_url.return_value = "success_url"
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
self._insert_trace_config(db_session_with_containers, app.id, provider)
result = OpsService.get_tracing_app_config(app.id, provider)
assert result is not None
assert result["tracing_config"]["project_url"] == "success_url"
def test_get_tracing_app_config_langfuse_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
with patch("services.ops_service.OpsTraceManager") as mock_otm:
mock_otm.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"}
mock_otm.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"}
mock_otm.get_trace_config_project_key.return_value = "key"
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
self._insert_trace_config(db_session_with_containers, app.id, "langfuse")
result = OpsService.get_tracing_app_config(app.id, "langfuse")
assert result is not None
assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/project/key"
def test_get_tracing_app_config_langfuse_exception(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
with patch("services.ops_service.OpsTraceManager") as mock_otm:
mock_otm.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"}
mock_otm.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"}
mock_otm.get_trace_config_project_key.side_effect = Exception("error")
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
self._insert_trace_config(db_session_with_containers, app.id, "langfuse")
result = OpsService.get_tracing_app_config(app.id, "langfuse")
assert result is not None
assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/"
# ── create_tracing_app_config ──────────────────────────────────────
def test_create_tracing_app_config_invalid_provider(self, db_session_with_containers: Session):
result = OpsService.create_tracing_app_config(str(uuid.uuid4()), "invalid_provider", {})
assert result == {"error": "Invalid tracing provider: invalid_provider"}
def test_create_tracing_app_config_invalid_credentials(
self, db_session_with_containers: Session, mock_ops_trace_manager
):
mock_ops_trace_manager.check_trace_config_is_effective.return_value = False
result = OpsService.create_tracing_app_config(
str(uuid.uuid4()), TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}
)
assert result == {"error": "Invalid Credentials"}
@pytest.mark.parametrize(
("provider", "config"),
[
(TracingProviderEnum.ARIZE, {}),
(TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}),
(TracingProviderEnum.LANGSMITH, {"api_key": "k", "project": "p"}),
(TracingProviderEnum.ALIYUN, {"license_key": "k", "endpoint": "https://aliyun.com"}),
],
)
def test_create_tracing_app_config_project_url_exception(
self, db_session_with_containers: Session, mock_external_service_dependencies, provider, config
):
# Existing config causes the service to return None before reaching the DB insert
with patch("services.ops_service.OpsTraceManager") as mock_otm:
mock_otm.check_trace_config_is_effective.return_value = True
mock_otm.get_trace_config_project_url.side_effect = Exception("error")
mock_otm.get_trace_config_project_key.side_effect = Exception("error")
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
self._insert_trace_config(db_session_with_containers, app.id, str(provider))
result = OpsService.create_tracing_app_config(app.id, provider, config)
assert result is None
def test_create_tracing_app_config_langfuse_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
with patch("services.ops_service.OpsTraceManager") as mock_otm:
mock_otm.check_trace_config_is_effective.return_value = True
mock_otm.get_trace_config_project_key.return_value = "key"
mock_otm.encrypt_tracing_config.return_value = {}
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
result = OpsService.create_tracing_app_config(
app.id,
TracingProviderEnum.LANGFUSE,
{"public_key": "p", "secret_key": "s", "host": "https://api.langfuse.com"},
)
assert result == {"result": "success"}
def test_create_tracing_app_config_already_exists(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
with patch("services.ops_service.OpsTraceManager") as mock_otm:
mock_otm.check_trace_config_is_effective.return_value = True
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE))
result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {})
assert result is None
def test_create_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager):
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
result = OpsService.create_tracing_app_config(str(uuid.uuid4()), TracingProviderEnum.ARIZE, {})
assert result is None
def test_create_tracing_app_config_with_empty_other_keys(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
# "project" is in other_keys for Arize; providing "" triggers default substitution
with patch("services.ops_service.OpsTraceManager") as mock_otm:
mock_otm.check_trace_config_is_effective.return_value = True
mock_otm.get_trace_config_project_url.side_effect = Exception("no url")
mock_otm.encrypt_tracing_config.return_value = {}
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {"project": ""})
assert result == {"result": "success"}
def test_create_tracing_app_config_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
with patch("services.ops_service.OpsTraceManager") as mock_otm:
mock_otm.check_trace_config_is_effective.return_value = True
mock_otm.get_trace_config_project_url.return_value = "http://project_url"
mock_otm.encrypt_tracing_config.return_value = {"encrypted": "config"}
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {})
assert result == {"result": "success"}
# ── update_tracing_app_config ──────────────────────────────────────
def test_update_tracing_app_config_invalid_provider(self, db_session_with_containers: Session):
with pytest.raises(ValueError, match="Invalid tracing provider: invalid_provider"):
OpsService.update_tracing_app_config(str(uuid.uuid4()), "invalid_provider", {})
def test_update_tracing_app_config_no_config(self, db_session_with_containers: Session, mock_ops_trace_manager):
result = OpsService.update_tracing_app_config(str(uuid.uuid4()), TracingProviderEnum.ARIZE, {})
assert result is None
def test_update_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager):
fake_app_id = str(uuid.uuid4())
self._insert_trace_config(db_session_with_containers, fake_app_id, str(TracingProviderEnum.ARIZE))
mock_ops_trace_manager.encrypt_tracing_config.return_value = {}
result = OpsService.update_tracing_app_config(fake_app_id, TracingProviderEnum.ARIZE, {})
assert result is None
def test_update_tracing_app_config_invalid_credentials(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
with patch("services.ops_service.OpsTraceManager") as mock_otm:
mock_otm.encrypt_tracing_config.return_value = {}
mock_otm.decrypt_tracing_config.return_value = {}
mock_otm.check_trace_config_is_effective.return_value = False
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE))
with pytest.raises(ValueError, match="Invalid Credentials"):
OpsService.update_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {})
def test_update_tracing_app_config_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
with patch("services.ops_service.OpsTraceManager") as mock_otm:
mock_otm.encrypt_tracing_config.return_value = {"updated": "config"}
mock_otm.decrypt_tracing_config.return_value = {}
mock_otm.check_trace_config_is_effective.return_value = True
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE))
result = OpsService.update_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {})
assert result is not None
assert result["app_id"] == app.id
# ── delete_tracing_app_config ──────────────────────────────────────
def test_delete_tracing_app_config_no_config(self, db_session_with_containers: Session):
result = OpsService.delete_tracing_app_config(str(uuid.uuid4()), "arize")
assert result is None
def test_delete_tracing_app_config_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
self._insert_trace_config(db_session_with_containers, app.id, "arize")
result = OpsService.delete_tracing_app_config(app.id, "arize")
assert result is True
remaining = db_session_with_containers.scalar(
select(TraceAppConfig)
.where(TraceAppConfig.app_id == app.id, TraceAppConfig.tracing_provider == "arize")
.limit(1)
)
assert remaining is None

View File

@ -233,11 +233,10 @@ class TestWebAppAuthService:
assert result.status == AccountStatus.ACTIVE
# Verify database state
db_session_with_containers.refresh(result)
assert result.id is not None
assert result.password is not None
assert result.password_salt is not None
refreshed = db_session_with_containers.get(Account, result.id)
assert refreshed is not None
assert refreshed.password is not None
assert refreshed.password_salt is not None
def test_authenticate_account_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
@ -414,9 +413,8 @@ class TestWebAppAuthService:
assert result.status == AccountStatus.ACTIVE
# Verify database state
db_session_with_containers.refresh(result)
assert result.id is not None
refreshed = db_session_with_containers.get(Account, result.id)
assert refreshed is not None
def test_get_user_through_email_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies

View File

@ -1,3 +1,4 @@
from importlib import import_module
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
@ -11,6 +12,7 @@ from controllers.console.datasets.external import (
BedrockRetrievalApi,
ExternalApiTemplateApi,
ExternalApiTemplateListApi,
ExternalApiUseCheckApi,
ExternalDatasetCreateApi,
ExternalKnowledgeHitTestingApi,
)
@ -19,6 +21,8 @@ from services.external_knowledge_service import ExternalDatasetService
from services.hit_testing_service import HitTestingService
from services.knowledge_service import ExternalDatasetTestService
external_controller = import_module("controllers.console.datasets.external")
def unwrap(func):
while hasattr(func, "__wrapped__"):
@ -44,10 +48,11 @@ def current_user():
@pytest.fixture(autouse=True)
def mock_auth(mocker, current_user):
mocker.patch(
"controllers.console.datasets.external.current_account_with_tenant",
return_value=(current_user, "tenant-1"),
def mock_auth(monkeypatch, current_user):
monkeypatch.setattr(
external_controller,
"current_account_with_tenant",
lambda: (current_user, "tenant-1"),
)
@ -136,6 +141,26 @@ class TestExternalApiTemplateApi:
method(api, "api-id")
class TestExternalApiUseCheckApi:
def test_get_scopes_usage_check_to_current_tenant(self, app):
api = ExternalApiUseCheckApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch.object(
ExternalDatasetService,
"external_knowledge_api_use_check",
return_value=(True, 2),
) as mock_use_check,
):
response, status = method(api, "api-id")
assert status == 200
assert response == {"is_using": True, "count": 2}
mock_use_check.assert_called_once_with("api-id", "tenant-1")
class TestExternalDatasetCreateApi:
def test_create_success(self, app):
api = ExternalDatasetCreateApi()

View File

@ -233,15 +233,20 @@ class TestCheckEmailUnique:
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
session = MagicMock()
mock_session = MagicMock()
first = MagicMock()
first.scalar_one_or_none.return_value = None
second = MagicMock()
expected_account = MagicMock()
second.scalar_one_or_none.return_value = expected_account
session.execute.side_effect = [first, second]
mock_session.execute.side_effect = [first, second]
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=session)
mock_factory = MagicMock()
mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False)
with patch("services.account_service.session_factory", mock_factory):
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com")
assert result is expected_account
assert session.execute.call_count == 2
assert mock_session.execute.call_count == 2

View File

@ -4,9 +4,7 @@ from unittest.mock import Mock
from core.mcp.entities import (
SUPPORTED_PROTOCOL_VERSIONS,
LifespanContextT,
RequestContext,
SessionT,
)
from core.mcp.session.base_session import BaseSession
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestParams
@ -198,42 +196,3 @@ class TestRequestContext:
assert "RequestContext" in repr_str
assert "test-123" in repr_str
assert "MockSession" in repr_str
class TestTypeVariables:
"""Test type variables defined in the module."""
def test_session_type_var(self):
"""Test SessionT type variable."""
# Create a custom session class
class CustomSession(BaseSession):
pass
# Use in generic context
def process_session(session: SessionT) -> SessionT:
return session
mock_session = Mock(spec=CustomSession)
result = process_session(mock_session)
assert result == mock_session
def test_lifespan_context_type_var(self):
"""Test LifespanContextT type variable."""
# Use in generic context
def process_lifespan(context: LifespanContextT) -> LifespanContextT:
return context
# Test with different types
str_context = "string-context"
assert process_lifespan(str_context) == str_context
dict_context = {"key": "value"}
assert process_lifespan(dict_context) == dict_context
class CustomContext:
pass
custom_context = CustomContext()
assert process_lifespan(custom_context) == custom_context

View File

@ -39,6 +39,25 @@ class _FakeSession:
return None
class _FakeBeginContext:
def __init__(self, session):
self._session = session
def __enter__(self):
return self._session
def __exit__(self, exc_type, exc, tb):
return None
def _patch_both(monkeypatch, module, session):
"""Patch both Session and sessionmaker on the module."""
monkeypatch.setattr(module, "Session", lambda _client: session)
monkeypatch.setattr(
module, "sessionmaker", lambda **kwargs: MagicMock(begin=MagicMock(return_value=_FakeBeginContext(session)))
)
@pytest.fixture
def relyt_module(monkeypatch):
for name, module in _build_fake_relyt_modules().items():
@ -108,13 +127,13 @@ def test_create_collection_cache_and_sql_execution(relyt_module, monkeypatch):
monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=1))
session = _FakeSession()
monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
_patch_both(monkeypatch, relyt_module, session)
vector.create_collection(3)
session.execute.assert_not_called()
monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=None))
session = _FakeSession()
monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
_patch_both(monkeypatch, relyt_module, session)
vector.create_collection(3)
executed_sql = [str(call.args[0]) for call in session.execute.call_args_list]
assert any("DROP TABLE IF EXISTS" in sql for sql in executed_sql)
@ -265,15 +284,15 @@ def test_search_by_vector_filters_by_score_and_ids(relyt_module):
# 8. delete commits session
def test_delete_commits_session(relyt_module, monkeypatch):
def test_delete_drops_table(relyt_module, monkeypatch):
vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
vector._collection_name = "collection_1"
vector.client = MagicMock()
vector.embedding_dimension = 3
session = _FakeSession()
monkeypatch.setattr(relyt_module, "Session", lambda _client: session)
_patch_both(monkeypatch, relyt_module, session)
vector.delete()
session.commit.assert_called_once()
session.execute.assert_called_once()
def test_relyt_factory_existing_and_generated_collection(relyt_module, monkeypatch):

View File

@ -137,14 +137,15 @@ def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monke
session = MagicMock()
class _SessionCtx:
class _BeginCtx:
def __enter__(self):
return session
def __exit__(self, exc_type, exc, tb):
return False
monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx())
mock_sm = MagicMock(begin=MagicMock(return_value=_BeginCtx()))
monkeypatch.setattr(tidb_module, "sessionmaker", lambda **kwargs: mock_sm)
vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
vector._collection_name = "collection_1"
@ -153,11 +154,9 @@ def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monke
vector._create_collection(3)
session.begin.assert_called_once()
sql = str(session.execute.call_args.args[0])
assert "VECTOR<FLOAT>(3)" in sql
assert "VEC_L2_DISTANCE" in sql
session.commit.assert_called_once()
tidb_module.redis_client.set.assert_called_once()
@ -396,23 +395,22 @@ def test_search_by_vector_filters_and_scores(tidb_module, monkeypatch):
def test_delete_drops_table(tidb_module, monkeypatch):
session = MagicMock()
session.execute.return_value = None
session.commit = MagicMock()
class _SessionCtx:
class _BeginCtx:
def __enter__(self):
return session
def __exit__(self, exc_type, exc, tb):
return False
monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx())
mock_sm = MagicMock(begin=MagicMock(return_value=_BeginCtx()))
monkeypatch.setattr(tidb_module, "sessionmaker", lambda **kwargs: mock_sm)
vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
vector._collection_name = "collection_1"
vector._engine = MagicMock()
vector.delete()
drop_sql = str(session.execute.call_args.args[0])
assert "DROP TABLE IF EXISTS collection_1" in drop_sql
session.commit.assert_called_once()
def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeypatch):

View File

@ -39,7 +39,7 @@ class TestAppGenerateHandler:
"root_node_id": None,
}
arguments = handler._extract_arguments(AppGenerateService.generate, (), kwargs)
arguments = handler._extract_arguments(AppGenerateService.generate, **kwargs)
assert arguments is not None, "Failed to extract arguments from AppGenerateService.generate"
assert "app_model" in arguments, "Handler uses app_model but parameter is missing"
@ -70,14 +70,11 @@ class TestAppGenerateHandler:
handler.wrapper(
tracer,
dummy_func,
(),
{
"app_model": mock_app_model,
"user": mock_account_user,
"args": {"workflow_id": test_workflow_id},
"invoke_from": InvokeFrom.DEBUGGER,
"streaming": False,
},
app_model=mock_app_model,
user=mock_account_user,
args={"workflow_id": test_workflow_id},
invoke_from=InvokeFrom.DEBUGGER,
streaming=False,
)
spans = memory_span_exporter.get_finished_spans()

View File

@ -63,7 +63,7 @@ class TestWorkflowAppRunnerHandler:
def runner_run(self):
return "result"
handler.wrapper(tracer, runner_run, (mock_workflow_runner,), {})
handler.wrapper(tracer, runner_run, mock_workflow_runner)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1

View File

@ -28,7 +28,7 @@ class TestSpanHandlerExtractArguments:
args = (1, 2, 3)
kwargs = {}
result = handler._extract_arguments(func, args, kwargs)
result = handler._extract_arguments(func, *args, **kwargs)
assert result is not None
assert result["a"] == 1
@ -44,7 +44,7 @@ class TestSpanHandlerExtractArguments:
args = ()
kwargs = {"a": 1, "b": 2, "c": 3}
result = handler._extract_arguments(func, args, kwargs)
result = handler._extract_arguments(func, *args, **kwargs)
assert result is not None
assert result["a"] == 1
@ -60,7 +60,7 @@ class TestSpanHandlerExtractArguments:
args = (1,)
kwargs = {"b": 2, "c": 3}
result = handler._extract_arguments(func, args, kwargs)
result = handler._extract_arguments(func, *args, **kwargs)
assert result is not None
assert result["a"] == 1
@ -76,7 +76,7 @@ class TestSpanHandlerExtractArguments:
args = (1,)
kwargs = {}
result = handler._extract_arguments(func, args, kwargs)
result = handler._extract_arguments(func, *args, **kwargs)
assert result is not None
assert result["a"] == 1
@ -94,7 +94,7 @@ class TestSpanHandlerExtractArguments:
instance = MyClass()
args = (1, 2)
kwargs = {}
result = handler._extract_arguments(instance.method, args, kwargs)
result = handler._extract_arguments(instance.method, *args, **kwargs)
assert result is not None
assert result["a"] == 1
@ -109,7 +109,7 @@ class TestSpanHandlerExtractArguments:
args = (1,)
kwargs = {}
result = handler._extract_arguments(func, args, kwargs)
result = handler._extract_arguments(func, *args, **kwargs)
assert result is None
@ -122,11 +122,11 @@ class TestSpanHandlerExtractArguments:
assert func not in handler._signature_cache
handler._extract_arguments(func, (1, 2), {})
handler._extract_arguments(func, 1, 2)
assert func in handler._signature_cache
cached_sig = handler._signature_cache[func]
handler._extract_arguments(func, (3, 4), {})
handler._extract_arguments(func, 3, 4)
assert handler._signature_cache[func] is cached_sig
@ -142,7 +142,7 @@ class TestSpanHandlerWrapper:
def test_func():
return "result"
result = handler.wrapper(tracer, test_func, (), {})
result = handler.wrapper(tracer, test_func)
assert result == "result"
spans = memory_span_exporter.get_finished_spans()
@ -159,7 +159,7 @@ class TestSpanHandlerWrapper:
def test_func():
return "result"
handler.wrapper(tracer, test_func, (), {})
handler.wrapper(tracer, test_func)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1
@ -174,7 +174,7 @@ class TestSpanHandlerWrapper:
def test_func():
return "result"
handler.wrapper(tracer, test_func, (), {})
handler.wrapper(tracer, test_func)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1
@ -190,7 +190,7 @@ class TestSpanHandlerWrapper:
raise ValueError("test error")
with pytest.raises(ValueError, match="test error"):
handler.wrapper(tracer, test_func, (), {})
handler.wrapper(tracer, test_func)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1
@ -208,7 +208,7 @@ class TestSpanHandlerWrapper:
raise ValueError("test error")
with pytest.raises(ValueError):
handler.wrapper(tracer, test_func, (), {})
handler.wrapper(tracer, test_func)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1
@ -225,7 +225,7 @@ class TestSpanHandlerWrapper:
raise ValueError("test error")
with pytest.raises(ValueError, match="test error"):
handler.wrapper(tracer, test_func, (), {})
handler.wrapper(tracer, test_func)
@patch("extensions.otel.decorators.base.dify_config.ENABLE_OTEL", True)
def test_wrapper_passes_arguments_correctly(self, tracer_provider_with_memory_exporter, memory_span_exporter):
@ -236,7 +236,7 @@ class TestSpanHandlerWrapper:
def test_func(a, b, c=10):
return a + b + c
result = handler.wrapper(tracer, test_func, (1, 2), {"c": 3})
result = handler.wrapper(tracer, test_func, 1, 2, c=3)
assert result == 6
@ -249,7 +249,7 @@ class TestSpanHandlerWrapper:
def my_function(x):
return x * 2
result = handler.wrapper(tracer, my_function, (5,), {})
result = handler.wrapper(tracer, my_function, 5)
assert result == 10
spans = memory_span_exporter.get_finished_spans()

View File

@ -0,0 +1,138 @@
import json
from libs.pyrefly_type_coverage import (
CoverageSummary,
format_comparison_markdown,
format_summary_markdown,
parse_summary,
)
def _make_report(summary: dict) -> str:
return json.dumps({"module_reports": [], "summary": summary})
_SAMPLE_SUMMARY: dict = {
"n_modules": 100,
"n_typable": 1000,
"n_typed": 400,
"n_any": 50,
"n_untyped": 550,
"coverage": 45.0,
"strict_coverage": 40.0,
"n_functions": 200,
"n_methods": 300,
"n_function_params": 150,
"n_method_params": 250,
"n_classes": 80,
"n_attrs": 40,
"n_properties": 20,
"n_type_ignores": 10,
}
def _make_summary(
*,
n_modules: int = 100,
n_typable: int = 1000,
n_typed: int = 400,
n_any: int = 50,
n_untyped: int = 550,
coverage: float = 45.0,
strict_coverage: float = 40.0,
) -> CoverageSummary:
return {
"n_modules": n_modules,
"n_typable": n_typable,
"n_typed": n_typed,
"n_any": n_any,
"n_untyped": n_untyped,
"coverage": coverage,
"strict_coverage": strict_coverage,
}
def test_parse_summary_extracts_fields() -> None:
report_json = _make_report(_SAMPLE_SUMMARY)
result = parse_summary(report_json)
assert result["n_modules"] == 100
assert result["n_typable"] == 1000
assert result["n_typed"] == 400
assert result["n_any"] == 50
assert result["n_untyped"] == 550
assert result["coverage"] == 45.0
assert result["strict_coverage"] == 40.0
def test_parse_summary_handles_empty_input() -> None:
assert parse_summary("")["n_modules"] == 0
assert parse_summary(" ")["n_modules"] == 0
def test_parse_summary_handles_invalid_json() -> None:
assert parse_summary("not json")["n_modules"] == 0
def test_parse_summary_handles_missing_summary_key() -> None:
assert parse_summary(json.dumps({"other": 1}))["n_modules"] == 0
def test_parse_summary_handles_incomplete_summary() -> None:
partial = json.dumps({"summary": {"n_modules": 5}})
assert parse_summary(partial)["n_modules"] == 0
def test_format_summary_markdown_contains_key_metrics() -> None:
summary = _make_summary()
result = format_summary_markdown(summary)
assert "**Type coverage**" in result
assert "45.00%" in result
assert "40.00%" in result
assert "| Modules | 100 |" in result
def test_format_comparison_markdown_shows_positive_delta() -> None:
base = _make_summary()
pr = _make_summary(
n_modules=101,
n_typable=1010,
n_typed=420,
n_untyped=540,
coverage=46.53,
strict_coverage=41.58,
)
result = format_comparison_markdown(base, pr)
assert "| Base | PR | Delta |" in result
assert "+1.53%" in result
assert "+1.58%" in result
assert "+20" in result
def test_format_comparison_markdown_shows_negative_delta() -> None:
base = _make_summary()
pr = _make_summary(
n_typed=390,
n_any=60,
coverage=44.0,
strict_coverage=39.0,
)
result = format_comparison_markdown(base, pr)
assert "-1.00%" in result
assert "-10" in result
def test_format_comparison_markdown_shows_zero_delta() -> None:
summary = _make_summary()
result = format_comparison_markdown(summary, summary)
assert "0.00%" in result
assert "| 0 |" in result

View File

@ -396,10 +396,11 @@ class TestExternalDatasetServiceUsageAndBindings:
mock_db_session.scalar.return_value = 3
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1")
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1", "tenant-1")
assert in_use is True
assert count == 3
assert "tenant_id" in str(mock_db_session.scalar.call_args.args[0])
def test_external_knowledge_api_use_check_not_in_use(self, mock_db_session: MagicMock):
"""
@ -408,7 +409,7 @@ class TestExternalDatasetServiceUsageAndBindings:
mock_db_session.scalar.return_value = 0
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1")
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1", "tenant-1")
assert in_use is False
assert count == 0

View File

@ -6,23 +6,25 @@ MODULE = "services.plugin.plugin_permission_service"
def _patched_session():
"""Patch sessionmaker(bind=db.engine).begin() to return a mock session as context manager."""
"""Patch session_factory.create_session() to return a mock session as context manager."""
session = MagicMock()
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__ = MagicMock(return_value=session)
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
patcher = patch(f"{MODULE}.sessionmaker", mock_sessionmaker)
db_patcher = patch(f"{MODULE}.db")
return patcher, db_patcher, session
session.__enter__ = MagicMock(return_value=session)
session.__exit__ = MagicMock(return_value=False)
session.begin.return_value.__enter__ = MagicMock(return_value=session)
session.begin.return_value.__exit__ = MagicMock(return_value=False)
mock_factory = MagicMock()
mock_factory.create_session.return_value = session
patcher = patch(f"{MODULE}.session_factory", mock_factory)
return patcher, session
class TestGetPermission:
def test_returns_permission_when_found(self):
p1, p2, session = _patched_session()
p1, session = _patched_session()
permission = MagicMock()
session.scalar.return_value = permission
with p1, p2:
with p1:
from services.plugin.plugin_permission_service import PluginPermissionService
result = PluginPermissionService.get_permission("t1")
@ -30,10 +32,10 @@ class TestGetPermission:
assert result is permission
def test_returns_none_when_not_found(self):
p1, p2, session = _patched_session()
p1, session = _patched_session()
session.scalar.return_value = None
with p1, p2:
with p1:
from services.plugin.plugin_permission_service import PluginPermissionService
result = PluginPermissionService.get_permission("t1")
@ -43,10 +45,10 @@ class TestGetPermission:
class TestChangePermission:
def test_creates_new_permission_when_not_exists(self):
p1, p2, session = _patched_session()
p1, session = _patched_session()
session.scalar.return_value = None
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginPermission") as perm_cls:
with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginPermission") as perm_cls:
perm_cls.return_value = MagicMock()
from services.plugin.plugin_permission_service import PluginPermissionService
@ -54,20 +56,24 @@ class TestChangePermission:
"t1", TenantPluginPermission.InstallPermission.EVERYONE, TenantPluginPermission.DebugPermission.EVERYONE
)
assert result is True
session.begin.assert_called_once()
session.add.assert_called_once()
def test_updates_existing_permission(self):
p1, p2, session = _patched_session()
p1, session = _patched_session()
existing = MagicMock()
session.scalar.return_value = existing
with p1, p2:
with p1:
from services.plugin.plugin_permission_service import PluginPermissionService
result = PluginPermissionService.change_permission(
"t1", TenantPluginPermission.InstallPermission.ADMINS, TenantPluginPermission.DebugPermission.ADMINS
)
assert result is True
session.begin.assert_called_once()
assert existing.install_permission == TenantPluginPermission.InstallPermission.ADMINS
assert existing.debug_permission == TenantPluginPermission.DebugPermission.ADMINS
session.add.assert_not_called()

View File

@ -1427,16 +1427,7 @@ class TestRegisterService:
mock_tenant.name = "Test Workspace"
mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter")
# Mock database queries - need to mock the sessionmaker query
mock_session = MagicMock()
mock_session.query.return_value.filter_by.return_value.first.return_value = None # No existing account
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
with (
patch("services.account_service.sessionmaker", mock_sessionmaker),
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
):
mock_lookup.return_value = None
@ -1475,7 +1466,7 @@ class TestRegisterService:
status=AccountStatus.PENDING,
is_setup=True,
)
mock_lookup.assert_called_once_with("newuser@example.com", session=mock_session)
mock_lookup.assert_called_once_with("newuser@example.com")
def test_invite_new_member_normalizes_new_account_email(
self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies
@ -1486,13 +1477,7 @@ class TestRegisterService:
mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter")
mixed_email = "Invitee@Example.com"
mock_session = MagicMock()
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
with (
patch("services.account_service.sessionmaker", mock_sessionmaker),
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
):
mock_lookup.return_value = None
@ -1525,7 +1510,7 @@ class TestRegisterService:
status=AccountStatus.PENDING,
is_setup=True,
)
mock_lookup.assert_called_once_with(mixed_email, session=mock_session)
mock_lookup.assert_called_once_with(mixed_email)
mock_check_permission.assert_called_once_with(mock_tenant, mock_inviter, None, "add")
mock_create_member.assert_called_once_with(mock_tenant, mock_new_account, "normal")
mock_switch_tenant.assert_called_once_with(mock_new_account, mock_tenant.id)
@ -1545,16 +1530,7 @@ class TestRegisterService:
account_id="existing-user-456", email="existing@example.com", status="pending"
)
# Mock database queries - need to mock the sessionmaker query
mock_session = MagicMock()
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_existing_account
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
with (
patch("services.account_service.sessionmaker", mock_sessionmaker),
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
):
mock_lookup.return_value = mock_existing_account
@ -1584,7 +1560,7 @@ class TestRegisterService:
mock_create_member.assert_called_once_with(mock_tenant, mock_existing_account, "normal")
mock_generate_token.assert_called_once_with(mock_tenant, mock_existing_account)
mock_task_dependencies.delay.assert_called_once()
mock_lookup.assert_called_once_with("existing@example.com", session=mock_session)
mock_lookup.assert_called_once_with("existing@example.com")
def test_invite_new_member_already_in_tenant(self, mock_db_dependencies, mock_redis_dependencies):
"""Test inviting a member who is already in the tenant."""

View File

@ -1069,6 +1069,33 @@ class TestDocumentServiceCreateValidation:
assert len(knowledge_config.process_rule.rules.pre_processing_rules) == 1
assert knowledge_config.process_rule.rules.pre_processing_rules[0].enabled is False
def test_process_rule_args_validate_hierarchical_defaults_parent_mode_to_paragraph(self):
knowledge_config = KnowledgeConfig(
indexing_technique="economy",
data_source=DataSource(
info_list=InfoList(
data_source_type="upload_file",
file_info_list=FileInfo(file_ids=["file-1"]),
)
),
process_rule=ProcessRule(
mode="hierarchical",
rules=Rule(
pre_processing_rules=[
PreProcessingRule(id="remove_extra_spaces", enabled=True),
],
segmentation=Segmentation(separator="\n", max_tokens=1024),
subchunk_segmentation=Segmentation(separator="\n", max_tokens=512),
),
),
)
DocumentService.process_rule_args_validate(knowledge_config)
assert knowledge_config.process_rule is not None
assert knowledge_config.process_rule.rules is not None
assert knowledge_config.process_rule.rules.parent_mode == "paragraph"
class TestDocumentServiceSaveDocumentWithDatasetId:
"""Unit tests for non-SQL validation branches in save_document_with_dataset_id."""

View File

@ -974,26 +974,29 @@ class TestExternalDatasetServiceAPIUseCheck:
"""Test API use check when API has one binding."""
# Arrange
api_id = "api-123"
tenant_id = "tenant-123"
mock_db.session.scalar.return_value = 1
# Act
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id, tenant_id)
# Assert
assert in_use is True
assert count == 1
assert "tenant_id" in str(mock_db.session.scalar.call_args.args[0])
@patch("services.external_knowledge_service.db")
def test_external_knowledge_api_use_check_in_use_multiple(self, mock_db, factory):
"""Test API use check with multiple bindings."""
# Arrange
api_id = "api-123"
tenant_id = "tenant-123"
mock_db.session.scalar.return_value = 10
# Act
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id, tenant_id)
# Assert
assert in_use is True
@ -1004,11 +1007,12 @@ class TestExternalDatasetServiceAPIUseCheck:
"""Test API use check when API is not in use."""
# Arrange
api_id = "api-123"
tenant_id = "tenant-123"
mock_db.session.scalar.return_value = 0
# Act
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id, tenant_id)
# Assert
assert in_use is False

View File

@ -1,392 +0,0 @@
from unittest.mock import MagicMock, patch
import pytest
from core.ops.entities.config_entity import TracingProviderEnum
from models.model import App, TraceAppConfig
from services.ops_service import OpsService
class TestOpsService:
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_get_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db):
# Arrange
mock_db.session.scalar.return_value = None
# Act
result = OpsService.get_tracing_app_config("app_id", "arize")
# Assert
assert result is None
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_get_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db):
# Arrange
trace_config = MagicMock(spec=TraceAppConfig)
mock_db.session.scalar.return_value = trace_config
mock_db.session.get.return_value = None
# Act
result = OpsService.get_tracing_app_config("app_id", "arize")
# Assert
assert result is None
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_get_tracing_app_config_none_config(self, mock_ops_trace_manager, mock_db):
# Arrange
trace_config = MagicMock(spec=TraceAppConfig)
trace_config.tracing_config = None
app = MagicMock(spec=App)
app.tenant_id = "tenant_id"
mock_db.session.scalar.return_value = trace_config
mock_db.session.get.return_value = app
# Act & Assert
with pytest.raises(ValueError, match="Tracing config cannot be None."):
OpsService.get_tracing_app_config("app_id", "arize")
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
@pytest.mark.parametrize(
("provider", "default_url"),
[
("arize", "https://app.arize.com/"),
("phoenix", "https://app.phoenix.arize.com/projects/"),
("langsmith", "https://smith.langchain.com/"),
("opik", "https://www.comet.com/opik/"),
("weave", "https://wandb.ai/"),
("aliyun", "https://arms.console.aliyun.com/"),
("tencent", "https://console.cloud.tencent.com/apm"),
("mlflow", "http://localhost:5000/"),
("databricks", "https://www.databricks.com/"),
],
)
def test_get_tracing_app_config_providers_exception(self, mock_ops_trace_manager, mock_db, provider, default_url):
# Arrange
trace_config = MagicMock(spec=TraceAppConfig)
trace_config.tracing_config = {"some": "config"}
trace_config.to_dict.return_value = {"tracing_config": {"project_url": default_url}}
app = MagicMock(spec=App)
app.tenant_id = "tenant_id"
mock_db.session.scalar.return_value = trace_config
mock_db.session.get.return_value = app
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {}
mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error")
mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error")
# Act
result = OpsService.get_tracing_app_config("app_id", provider)
# Assert
assert result["tracing_config"]["project_url"] == default_url
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
@pytest.mark.parametrize(
"provider", ["arize", "phoenix", "langsmith", "opik", "weave", "aliyun", "tencent", "mlflow", "databricks"]
)
def test_get_tracing_app_config_providers_success(self, mock_ops_trace_manager, mock_db, provider):
# Arrange
trace_config = MagicMock(spec=TraceAppConfig)
trace_config.tracing_config = {"some": "config"}
trace_config.to_dict.return_value = {"tracing_config": {"project_url": "success_url"}}
app = MagicMock(spec=App)
app.tenant_id = "tenant_id"
mock_db.session.scalar.return_value = trace_config
mock_db.session.get.return_value = app
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {}
mock_ops_trace_manager.get_trace_config_project_url.return_value = "success_url"
# Act
result = OpsService.get_tracing_app_config("app_id", provider)
# Assert
assert result["tracing_config"]["project_url"] == "success_url"
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_get_tracing_app_config_langfuse_success(self, mock_ops_trace_manager, mock_db):
# Arrange
trace_config = MagicMock(spec=TraceAppConfig)
trace_config.tracing_config = {"some": "config"}
trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/project/key"}}
app = MagicMock(spec=App)
app.tenant_id = "tenant_id"
mock_db.session.scalar.return_value = trace_config
mock_db.session.get.return_value = app
mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"}
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"}
mock_ops_trace_manager.get_trace_config_project_key.return_value = "key"
# Act
result = OpsService.get_tracing_app_config("app_id", "langfuse")
# Assert
assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/project/key"
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_get_tracing_app_config_langfuse_exception(self, mock_ops_trace_manager, mock_db):
# Arrange
trace_config = MagicMock(spec=TraceAppConfig)
trace_config.tracing_config = {"some": "config"}
trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/"}}
app = MagicMock(spec=App)
app.tenant_id = "tenant_id"
mock_db.session.scalar.return_value = trace_config
mock_db.session.get.return_value = app
mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"}
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"}
mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error")
# Act
result = OpsService.get_tracing_app_config("app_id", "langfuse")
# Assert
assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/"
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_create_tracing_app_config_invalid_provider(self, mock_ops_trace_manager, mock_db):
# Act
result = OpsService.create_tracing_app_config("app_id", "invalid_provider", {})
# Assert
assert result == {"error": "Invalid tracing provider: invalid_provider"}
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_create_tracing_app_config_invalid_credentials(self, mock_ops_trace_manager, mock_db):
# Arrange
provider = TracingProviderEnum.LANGFUSE
mock_ops_trace_manager.check_trace_config_is_effective.return_value = False
# Act
result = OpsService.create_tracing_app_config("app_id", provider, {"public_key": "p", "secret_key": "s"})
# Assert
assert result == {"error": "Invalid Credentials"}
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
@pytest.mark.parametrize(
("provider", "config"),
[
(TracingProviderEnum.ARIZE, {}),
(TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}),
(TracingProviderEnum.LANGSMITH, {"api_key": "k", "project": "p"}),
(TracingProviderEnum.ALIYUN, {"license_key": "k", "endpoint": "https://aliyun.com"}),
],
)
def test_create_tracing_app_config_project_url_exception(self, mock_ops_trace_manager, mock_db, provider, config):
# Arrange
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error")
mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error")
mock_db.session.scalar.return_value = MagicMock(spec=TraceAppConfig)
# Act
result = OpsService.create_tracing_app_config("app_id", provider, config)
# Assert
assert result is None
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_create_tracing_app_config_langfuse_success(self, mock_ops_trace_manager, mock_db):
# Arrange
provider = TracingProviderEnum.LANGFUSE
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
mock_ops_trace_manager.get_trace_config_project_key.return_value = "key"
app = MagicMock(spec=App)
app.tenant_id = "tenant_id"
mock_db.session.scalar.return_value = None
mock_db.session.get.return_value = app
mock_ops_trace_manager.encrypt_tracing_config.return_value = {}
# Act
result = OpsService.create_tracing_app_config(
"app_id", provider, {"public_key": "p", "secret_key": "s", "host": "https://api.langfuse.com"}
)
# Assert
assert result == {"result": "success"}
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_create_tracing_app_config_already_exists(self, mock_ops_trace_manager, mock_db):
# Arrange
provider = TracingProviderEnum.ARIZE
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
mock_db.session.scalar.return_value = MagicMock(spec=TraceAppConfig)
# Act
result = OpsService.create_tracing_app_config("app_id", provider, {})
# Assert
assert result is None
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_create_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db):
# Arrange
provider = TracingProviderEnum.ARIZE
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
mock_db.session.scalar.return_value = None
mock_db.session.get.return_value = None
# Act
result = OpsService.create_tracing_app_config("app_id", provider, {})
# Assert
assert result is None
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_create_tracing_app_config_with_empty_other_keys(self, mock_ops_trace_manager, mock_db):
# Arrange
provider = TracingProviderEnum.ARIZE
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
app = MagicMock(spec=App)
app.tenant_id = "tenant_id"
mock_db.session.scalar.return_value = None
mock_db.session.get.return_value = app
mock_ops_trace_manager.encrypt_tracing_config.return_value = {}
# Act
# 'project' is in other_keys for Arize
# provide an empty string for the project in the tracing_config
# create_tracing_app_config will replace it with the default from the model
result = OpsService.create_tracing_app_config("app_id", provider, {"project": ""})
# Assert
assert result == {"result": "success"}
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_create_tracing_app_config_success(self, mock_ops_trace_manager, mock_db):
# Arrange
provider = TracingProviderEnum.ARIZE
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
mock_ops_trace_manager.get_trace_config_project_url.return_value = "http://project_url"
app = MagicMock(spec=App)
app.tenant_id = "tenant_id"
mock_db.session.scalar.return_value = None
mock_db.session.get.return_value = app
mock_ops_trace_manager.encrypt_tracing_config.return_value = {"encrypted": "config"}
# Act
result = OpsService.create_tracing_app_config("app_id", provider, {})
# Assert
assert result == {"result": "success"}
mock_db.session.add.assert_called()
mock_db.session.commit.assert_called()
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_update_tracing_app_config_invalid_provider(self, mock_ops_trace_manager, mock_db):
# Act & Assert
with pytest.raises(ValueError, match="Invalid tracing provider: invalid_provider"):
OpsService.update_tracing_app_config("app_id", "invalid_provider", {})
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_update_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db):
# Arrange
provider = TracingProviderEnum.ARIZE
mock_db.session.scalar.return_value = None
# Act
result = OpsService.update_tracing_app_config("app_id", provider, {})
# Assert
assert result is None
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_update_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db):
# Arrange
provider = TracingProviderEnum.ARIZE
current_config = MagicMock(spec=TraceAppConfig)
mock_db.session.scalar.return_value = current_config
mock_db.session.get.return_value = None
# Act
result = OpsService.update_tracing_app_config("app_id", provider, {})
# Assert
assert result is None
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_update_tracing_app_config_invalid_credentials(self, mock_ops_trace_manager, mock_db):
# Arrange
provider = TracingProviderEnum.ARIZE
current_config = MagicMock(spec=TraceAppConfig)
app = MagicMock(spec=App)
app.tenant_id = "tenant_id"
mock_db.session.scalar.return_value = current_config
mock_db.session.get.return_value = app
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
mock_ops_trace_manager.check_trace_config_is_effective.return_value = False
# Act & Assert
with pytest.raises(ValueError, match="Invalid Credentials"):
OpsService.update_tracing_app_config("app_id", provider, {})
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_update_tracing_app_config_success(self, mock_ops_trace_manager, mock_db):
# Arrange
provider = TracingProviderEnum.ARIZE
current_config = MagicMock(spec=TraceAppConfig)
current_config.to_dict.return_value = {"some": "data"}
app = MagicMock(spec=App)
app.tenant_id = "tenant_id"
mock_db.session.scalar.return_value = current_config
mock_db.session.get.return_value = app
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
# Act
result = OpsService.update_tracing_app_config("app_id", provider, {})
# Assert
assert result == {"some": "data"}
mock_db.session.commit.assert_called_once()
@patch("services.ops_service.db")
def test_delete_tracing_app_config_no_config(self, mock_db):
# Arrange
mock_db.session.scalar.return_value = None
# Act
result = OpsService.delete_tracing_app_config("app_id", "arize")
# Assert
assert result is None
@patch("services.ops_service.db")
def test_delete_tracing_app_config_success(self, mock_db):
# Arrange
trace_config = MagicMock(spec=TraceAppConfig)
mock_db.session.scalar.return_value = trace_config
# Act
result = OpsService.delete_tracing_app_config("app_id", "arize")
# Assert
assert result is True
mock_db.session.delete.assert_called_with(trace_config)
mock_db.session.commit.assert_called_once()

View File

@ -2,7 +2,7 @@
import type { Area } from 'react-easy-crop'
import type { OnImageInput } from '@/app/components/base/app-icon-picker/ImageInput'
import type { AvatarProps } from '@/app/components/base/avatar'
import type { AvatarProps } from '@/app/components/base/ui/avatar'
import type { ImageFile } from '@/types/app'
import { RiDeleteBin5Line, RiPencilLine } from '@remixicon/react'
import * as React from 'react'
@ -10,10 +10,10 @@ import { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import ImageInput from '@/app/components/base/app-icon-picker/ImageInput'
import getCroppedImg from '@/app/components/base/app-icon-picker/utils'
import { Avatar } from '@/app/components/base/avatar'
import Button from '@/app/components/base/button'
import Divider from '@/app/components/base/divider'
import { useLocalFileUploader } from '@/app/components/base/image-uploader/hooks'
import { Avatar } from '@/app/components/base/ui/avatar'
import { Dialog, DialogContent } from '@/app/components/base/ui/dialog'
import { toast } from '@/app/components/base/ui/toast'
import { DISABLE_UPLOAD_IMAGE_AS_ICON } from '@/config'

View File

@ -6,9 +6,9 @@ import {
import { Fragment } from 'react'
import { useTranslation } from 'react-i18next'
import { resetUser } from '@/app/components/base/amplitude/utils'
import { Avatar } from '@/app/components/base/avatar'
import { LogOut01 } from '@/app/components/base/icons/src/vender/line/general'
import PremiumBadge from '@/app/components/base/premium-badge'
import { Avatar } from '@/app/components/base/ui/avatar'
import { useProviderContext } from '@/context/provider-context'
import { useRouter } from '@/next/navigation'
import { useLogout, useUserProfile } from '@/service/use-common'

View File

@ -10,9 +10,9 @@ import {
import * as React from 'react'
import { useEffect, useRef } from 'react'
import { useTranslation } from 'react-i18next'
import { Avatar } from '@/app/components/base/avatar'
import Button from '@/app/components/base/button'
import Loading from '@/app/components/base/loading'
import { Avatar } from '@/app/components/base/ui/avatar'
import { toast } from '@/app/components/base/ui/toast'
import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { setPostLoginRedirect } from '@/app/signin/utils/post-login-redirect'

View File

@ -5,12 +5,12 @@ import { RiAddCircleFill, RiArrowRightSLine, RiOrganizationChart } from '@remixi
import { useDebounce } from 'ahooks'
import { useCallback, useEffect, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { Avatar } from '@/app/components/base/ui/avatar'
import { useSelector } from '@/context/app-context'
import { SubjectType } from '@/models/access-control'
import { useSearchForWhiteListCandidates } from '@/service/access-control'
import { cn } from '@/utils/classnames'
import useAccessControlStore from '../../../../context/access-control-store'
import { Avatar } from '../../base/avatar'
import Button from '../../base/button'
import Checkbox from '../../base/checkbox'
import Input from '../../base/input'

View File

@ -3,10 +3,10 @@ import type { AccessControlAccount, AccessControlGroup } from '@/models/access-c
import { RiAlertFill, RiCloseCircleFill, RiLockLine, RiOrganizationChart } from '@remixicon/react'
import { useCallback, useEffect } from 'react'
import { useTranslation } from 'react-i18next'
import { Avatar } from '@/app/components/base/ui/avatar'
import { AccessMode } from '@/models/access-control'
import { useAppWhiteListSubjects } from '@/service/access-control'
import useAccessControlStore from '../../../../context/access-control-store'
import { Avatar } from '../../base/avatar'
import Loading from '../../base/loading'
import Tooltip from '../../base/tooltip'
import AddMemberOrGroupDialog from './add-member-or-group-pop'

View File

@ -90,7 +90,7 @@ vi.mock('@/app/components/base/chat/chat', () => ({
},
}))
vi.mock('@/app/components/base/avatar', () => ({
vi.mock('@/app/components/base/ui/avatar', () => ({
Avatar: ({ name }: { name: string }) => <div data-testid="avatar">{name}</div>,
}))

View File

@ -7,11 +7,11 @@ import {
useCallback,
useMemo,
} from 'react'
import { Avatar } from '@/app/components/base/avatar'
import Chat from '@/app/components/base/chat/chat'
import { useChat } from '@/app/components/base/chat/chat/hooks'
import { getLastAnswer } from '@/app/components/base/chat/utils'
import { useFeatures } from '@/app/components/base/features/hooks'
import { Avatar } from '@/app/components/base/ui/avatar'
import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { useAppContext } from '@/context/app-context'
import { useDebugConfigurationContext } from '@/context/debug-configuration'

View File

@ -3,11 +3,11 @@ import type { ChatConfig, ChatItem, OnSend } from '@/app/components/base/chat/ty
import type { FileEntity } from '@/app/components/base/file-uploader/types'
import { memo, useCallback, useImperativeHandle, useMemo } from 'react'
import { useStore as useAppStore } from '@/app/components/app/store'
import { Avatar } from '@/app/components/base/avatar'
import Chat from '@/app/components/base/chat/chat'
import { useChat } from '@/app/components/base/chat/chat/hooks'
import { getLastAnswer, isValidGeneratedAnswer } from '@/app/components/base/chat/utils'
import { useFeatures } from '@/app/components/base/features/hooks'
import { Avatar } from '@/app/components/base/ui/avatar'
import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { useAppContext } from '@/context/app-context'
import { useDebugConfigurationContext } from '@/context/debug-configuration'

View File

@ -11,6 +11,7 @@ import AppIcon from '@/app/components/base/app-icon'
import InputsForm from '@/app/components/base/chat/chat-with-history/inputs-form'
import SuggestedQuestions from '@/app/components/base/chat/chat/answer/suggested-questions'
import { Markdown } from '@/app/components/base/markdown'
import { Avatar } from '@/app/components/base/ui/avatar'
import { InputVarType } from '@/app/components/workflow/types'
import {
AppSourceType,
@ -23,7 +24,6 @@ import { submitHumanInputForm as submitHumanInputFormService } from '@/service/w
import { TransferMethod } from '@/types/app'
import { cn } from '@/utils/classnames'
import { formatBooleanInputs } from '@/utils/model-config'
import { Avatar } from '../../avatar'
import Chat from '../chat'
import { useChat } from '../chat/hooks'
import { getLastAnswer, isValidGeneratedAnswer } from '../utils'

View File

@ -12,6 +12,7 @@ import SuggestedQuestions from '@/app/components/base/chat/chat/answer/suggested
import InputsForm from '@/app/components/base/chat/embedded-chatbot/inputs-form'
import LogoAvatar from '@/app/components/base/logo/logo-embedded-chat-avatar'
import { Markdown } from '@/app/components/base/markdown'
import { Avatar } from '@/app/components/base/ui/avatar'
import { InputVarType } from '@/app/components/workflow/types'
import {
AppSourceType,
@ -23,7 +24,6 @@ import {
import { submitHumanInputForm as submitHumanInputFormService } from '@/service/workflow'
import { TransferMethod } from '@/types/app'
import { cn } from '@/utils/classnames'
import { Avatar } from '../../avatar'
import Chat from '../chat'
import { useChat } from '../chat/hooks'
import { getLastAnswer, isValidGeneratedAnswer } from '../utils'

View File

@ -1,5 +1,5 @@
import { render, screen } from '@testing-library/react'
import { Avatar } from '../index'
import { Avatar } from '..'
describe('Avatar', () => {
describe('Rendering', () => {

View File

@ -53,8 +53,8 @@ function AvatarRoot({
return (
<BaseAvatar.Root
className={cn(
'relative inline-flex shrink-0 select-none items-center justify-center overflow-hidden rounded-full bg-primary-600',
isAvatarPresetSize(size) && avatarSizeClasses[size].root,
'relative inline-flex shrink-0 items-center justify-center overflow-hidden rounded-full bg-primary-600 select-none',
avatarSizeClasses[size].root,
className,
)}
style={resolvedStyle}
@ -104,7 +104,7 @@ function AvatarImage({
}: AvatarImageProps) {
return (
<BaseAvatar.Image
className={cn('absolute inset-0 size-full object-cover', className)}
className={cn('inset-0 absolute size-full object-cover', className)}
{...props}
/>
)

View File

@ -4,13 +4,13 @@ import { useDebounceFn } from 'ahooks'
import * as React from 'react'
import { useCallback, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { Avatar } from '@/app/components/base/avatar'
import Input from '@/app/components/base/input'
import {
PortalToFollowElem,
PortalToFollowElemContent,
PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'
import { Avatar } from '@/app/components/base/ui/avatar'
import { useSelector as useAppContextWithSelector } from '@/context/app-context'
import { DatasetPermission } from '@/models/datasets'
import { cn } from '@/utils/classnames'

View File

@ -4,9 +4,9 @@ import type { MouseEventHandler, ReactNode } from 'react'
import { useState } from 'react'
import { useTranslation } from 'react-i18next'
import { resetUser } from '@/app/components/base/amplitude/utils'
import { Avatar } from '@/app/components/base/avatar'
import PremiumBadge from '@/app/components/base/premium-badge'
import ThemeSwitcher from '@/app/components/base/theme-switcher'
import { Avatar } from '@/app/components/base/ui/avatar'
import { DropdownMenu, DropdownMenuContent, DropdownMenuGroup, DropdownMenuItem, DropdownMenuLinkItem, DropdownMenuSeparator, DropdownMenuTrigger } from '@/app/components/base/ui/dropdown-menu'
import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants'
import { IS_CLOUD_EDITION } from '@/config'

View File

@ -2,7 +2,7 @@
import type { InvitationResult } from '@/models/common'
import { useState } from 'react'
import { useTranslation } from 'react-i18next'
import { Avatar } from '@/app/components/base/avatar'
import { Avatar } from '@/app/components/base/ui/avatar'
import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip'
import { NUM_INFINITE } from '@/app/components/billing/config'
import { Plan } from '@/app/components/billing/type'

View File

@ -3,9 +3,9 @@ import type { FC } from 'react'
import * as React from 'react'
import { useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { Avatar } from '@/app/components/base/avatar'
import Input from '@/app/components/base/input'
import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '@/app/components/base/portal-to-follow-elem'
import { Avatar } from '@/app/components/base/ui/avatar'
import { useMembers } from '@/service/use-common'
import { cn } from '@/utils/classnames'

View File

@ -0,0 +1,46 @@
import { render } from '@testing-library/react'
import { API_PREFIX } from '@/config'
import BlockIcon, { VarBlockIcon } from '../block-icon'
import { BlockEnum } from '../types'
describe('BlockIcon', () => {
it('renders the default workflow icon container for regular nodes', () => {
const { container } = render(<BlockIcon type={BlockEnum.Start} size="xs" className="extra-class" />)
const iconContainer = container.firstElementChild
expect(iconContainer).toHaveClass('w-4', 'h-4', 'bg-util-colors-blue-brand-blue-brand-500', 'extra-class')
expect(iconContainer?.querySelector('svg')).toBeInTheDocument()
})
it('normalizes protected plugin icon urls for tool-like nodes', () => {
const { container } = render(
<BlockIcon
type={BlockEnum.Tool}
toolIcon="/foo/workspaces/current/plugin/icon/plugin-tool.png"
/>,
)
const iconContainer = container.firstElementChild as HTMLElement
const backgroundIcon = iconContainer.querySelector('div') as HTMLElement
expect(iconContainer).not.toHaveClass('bg-util-colors-blue-blue-500')
expect(backgroundIcon.style.backgroundImage).toContain(
`${API_PREFIX}/workspaces/current/plugin/icon/plugin-tool.png`,
)
})
})
describe('VarBlockIcon', () => {
it('renders the compact icon variant without the default container wrapper', () => {
const { container } = render(
<VarBlockIcon
type={BlockEnum.Answer}
className="custom-var-icon"
/>,
)
expect(container.querySelector('.custom-var-icon')).toBeInTheDocument()
expect(container.querySelector('svg')).toBeInTheDocument()
expect(container.querySelector('.bg-util-colors-warning-warning-500')).not.toBeInTheDocument()
})
})

View File

@ -0,0 +1,39 @@
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { WorkflowContextProvider } from '../context'
import { useStore, useWorkflowStore } from '../store'
const StoreConsumer = () => {
const showSingleRunPanel = useStore(s => s.showSingleRunPanel)
const store = useWorkflowStore()
return (
<button onClick={() => store.getState().setShowSingleRunPanel(!showSingleRunPanel)}>
{showSingleRunPanel ? 'open' : 'closed'}
</button>
)
}
describe('WorkflowContextProvider', () => {
it('provides the workflow store to descendants and keeps the same store across rerenders', async () => {
const user = userEvent.setup()
const { rerender } = render(
<WorkflowContextProvider>
<StoreConsumer />
</WorkflowContextProvider>,
)
expect(screen.getByRole('button', { name: 'closed' })).toBeInTheDocument()
await user.click(screen.getByRole('button', { name: 'closed' }))
expect(screen.getByRole('button', { name: 'open' })).toBeInTheDocument()
rerender(
<WorkflowContextProvider>
<StoreConsumer />
</WorkflowContextProvider>,
)
expect(screen.getByRole('button', { name: 'open' })).toBeInTheDocument()
})
})

View File

@ -0,0 +1,67 @@
import type { Edge, Node } from '../types'
import { render, screen } from '@testing-library/react'
import { useStoreApi } from 'reactflow'
import { useDatasetsDetailStore } from '../datasets-detail-store/store'
import WorkflowWithDefaultContext from '../index'
import { BlockEnum } from '../types'
import { useWorkflowHistoryStore } from '../workflow-history-store'
const nodes: Node[] = [
{
id: 'node-start',
type: 'custom',
position: { x: 0, y: 0 },
data: {
title: 'Start',
desc: '',
type: BlockEnum.Start,
},
},
]
const edges: Edge[] = [
{
id: 'edge-1',
source: 'node-start',
target: 'node-end',
sourceHandle: null,
targetHandle: null,
type: 'custom',
data: {
sourceType: BlockEnum.Start,
targetType: BlockEnum.End,
},
},
]
const ContextConsumer = () => {
const { store, shortcutsEnabled } = useWorkflowHistoryStore()
const datasetCount = useDatasetsDetailStore(state => Object.keys(state.datasetsDetail).length)
const reactFlowStore = useStoreApi()
return (
<div>
{`history:${store.getState().nodes.length}`}
{` shortcuts:${String(shortcutsEnabled)}`}
{` datasets:${datasetCount}`}
{` reactflow:${String(!!reactFlowStore)}`}
</div>
)
}
describe('WorkflowWithDefaultContext', () => {
it('wires the ReactFlow, workflow history, and datasets detail providers around its children', () => {
render(
<WorkflowWithDefaultContext
nodes={nodes}
edges={edges}
>
<ContextConsumer />
</WorkflowWithDefaultContext>,
)
expect(
screen.getByText('history:1 shortcuts:true datasets:0 reactflow:true'),
).toBeInTheDocument()
})
})

View File

@ -0,0 +1,51 @@
import { render, screen } from '@testing-library/react'
import ShortcutsName from '../shortcuts-name'
describe('ShortcutsName', () => {
const originalNavigator = globalThis.navigator
afterEach(() => {
Object.defineProperty(globalThis, 'navigator', {
value: originalNavigator,
writable: true,
configurable: true,
})
})
it('renders mac-friendly key labels and style variants', () => {
Object.defineProperty(globalThis, 'navigator', {
value: { userAgent: 'Macintosh' },
writable: true,
configurable: true,
})
const { container } = render(
<ShortcutsName
keys={['ctrl', 'shift', 's']}
bgColor="white"
textColor="secondary"
/>,
)
expect(screen.getByText('⌘')).toBeInTheDocument()
expect(screen.getByText('⇧')).toBeInTheDocument()
expect(screen.getByText('s')).toBeInTheDocument()
expect(container.querySelector('.system-kbd')).toHaveClass(
'bg-components-kbd-bg-white',
'text-text-tertiary',
)
})
it('keeps raw key names on non-mac systems', () => {
Object.defineProperty(globalThis, 'navigator', {
value: { userAgent: 'Windows NT' },
writable: true,
configurable: true,
})
render(<ShortcutsName keys={['ctrl', 'alt']} />)
expect(screen.getByText('ctrl')).toBeInTheDocument()
expect(screen.getByText('alt')).toBeInTheDocument()
})
})

View File

@ -0,0 +1,97 @@
import type { Edge, Node } from '../types'
import type { WorkflowHistoryState } from '../workflow-history-store'
import { render, renderHook, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { BlockEnum } from '../types'
import { useWorkflowHistoryStore, WorkflowHistoryProvider } from '../workflow-history-store'
const nodes: Node[] = [
{
id: 'node-1',
type: 'custom',
position: { x: 0, y: 0 },
data: {
title: 'Start',
desc: '',
type: BlockEnum.Start,
selected: true,
},
selected: true,
},
]
const edges: Edge[] = [
{
id: 'edge-1',
source: 'node-1',
target: 'node-2',
sourceHandle: null,
targetHandle: null,
type: 'custom',
selected: true,
data: {
sourceType: BlockEnum.Start,
targetType: BlockEnum.End,
},
},
]
const HistoryConsumer = () => {
const { store, shortcutsEnabled, setShortcutsEnabled } = useWorkflowHistoryStore()
return (
<button onClick={() => setShortcutsEnabled(!shortcutsEnabled)}>
{`nodes:${store.getState().nodes.length} shortcuts:${String(shortcutsEnabled)}`}
</button>
)
}
describe('WorkflowHistoryProvider', () => {
it('provides workflow history state and shortcut toggles', async () => {
const user = userEvent.setup()
render(
<WorkflowHistoryProvider
nodes={nodes}
edges={edges}
>
<HistoryConsumer />
</WorkflowHistoryProvider>,
)
expect(screen.getByRole('button', { name: 'nodes:1 shortcuts:true' })).toBeInTheDocument()
await user.click(screen.getByRole('button', { name: 'nodes:1 shortcuts:true' }))
expect(screen.getByRole('button', { name: 'nodes:1 shortcuts:false' })).toBeInTheDocument()
})
it('sanitizes selected flags when history state is replaced through the exposed store api', () => {
const wrapper = ({ children }: { children: React.ReactNode }) => (
<WorkflowHistoryProvider
nodes={nodes}
edges={edges}
>
{children}
</WorkflowHistoryProvider>
)
const { result } = renderHook(() => useWorkflowHistoryStore(), { wrapper })
const nextState: WorkflowHistoryState = {
workflowHistoryEvent: undefined,
workflowHistoryEventMeta: undefined,
nodes,
edges,
}
result.current.store.setState(nextState)
expect(result.current.store.getState().nodes[0].data.selected).toBe(false)
expect(result.current.store.getState().edges[0].selected).toBe(false)
})
it('throws when consumed outside the provider', () => {
expect(() => renderHook(() => useWorkflowHistoryStore())).toThrow(
'useWorkflowHistoryStoreApi must be used within a WorkflowHistoryProvider',
)
})
})

View File

@ -0,0 +1,140 @@
import { render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { useMarketplacePlugins } from '@/app/components/plugins/marketplace/hooks'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { useGetLanguage } from '@/context/i18n'
import useTheme from '@/hooks/use-theme'
import { Theme } from '@/types/app'
import AllTools from '../all-tools'
import { createGlobalPublicStoreState, createToolProvider } from './factories'
vi.mock('@/context/global-public-context', () => ({
useGlobalPublicStore: vi.fn(),
}))
vi.mock('@/context/i18n', () => ({
useGetLanguage: vi.fn(),
}))
vi.mock('@/hooks/use-theme', () => ({
default: vi.fn(),
}))
vi.mock('@/app/components/plugins/marketplace/hooks', () => ({
useMarketplacePlugins: vi.fn(),
}))
vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({
useMCPToolAvailability: () => ({
allowed: true,
}),
}))
vi.mock('@/utils/var', async importOriginal => ({
...(await importOriginal<typeof import('@/utils/var')>()),
getMarketplaceUrl: () => 'https://marketplace.test/tools',
}))
const mockUseMarketplacePlugins = vi.mocked(useMarketplacePlugins)
const mockUseGlobalPublicStore = vi.mocked(useGlobalPublicStore)
const mockUseGetLanguage = vi.mocked(useGetLanguage)
const mockUseTheme = vi.mocked(useTheme)
const createMarketplacePluginsMock = () => ({
plugins: [],
total: 0,
resetPlugins: vi.fn(),
queryPlugins: vi.fn(),
queryPluginsWithDebounced: vi.fn(),
cancelQueryPluginsWithDebounced: vi.fn(),
isLoading: false,
isFetchingNextPage: false,
hasNextPage: false,
fetchNextPage: vi.fn(),
page: 0,
})
describe('AllTools', () => {
beforeEach(() => {
vi.clearAllMocks()
mockUseGlobalPublicStore.mockImplementation(selector => selector(createGlobalPublicStoreState(false)))
mockUseGetLanguage.mockReturnValue('en_US')
mockUseTheme.mockReturnValue({ theme: Theme.light } as ReturnType<typeof useTheme>)
mockUseMarketplacePlugins.mockReturnValue(createMarketplacePluginsMock())
})
it('filters tools by the active tab', async () => {
const user = userEvent.setup()
render(
<AllTools
searchText=""
tags={[]}
onSelect={vi.fn()}
buildInTools={[createToolProvider({
id: 'provider-built-in',
label: { en_US: 'Built In Provider', zh_Hans: 'Built In Provider' },
})]}
customTools={[createToolProvider({
id: 'provider-custom',
type: 'custom',
label: { en_US: 'Custom Provider', zh_Hans: 'Custom Provider' },
})]}
workflowTools={[]}
mcpTools={[]}
/>,
)
expect(screen.getByText('Built In Provider')).toBeInTheDocument()
expect(screen.getByText('Custom Provider')).toBeInTheDocument()
await user.click(screen.getByText('workflow.tabs.customTool'))
expect(screen.getByText('Custom Provider')).toBeInTheDocument()
expect(screen.queryByText('Built In Provider')).not.toBeInTheDocument()
})
it('filters the rendered tools by the search text', () => {
render(
<AllTools
searchText="report"
tags={[]}
onSelect={vi.fn()}
buildInTools={[
createToolProvider({
id: 'provider-report',
label: { en_US: 'Report Toolkit', zh_Hans: 'Report Toolkit' },
}),
createToolProvider({
id: 'provider-other',
label: { en_US: 'Other Toolkit', zh_Hans: 'Other Toolkit' },
}),
]}
customTools={[]}
workflowTools={[]}
mcpTools={[]}
/>,
)
expect(screen.getByText('Report Toolkit')).toBeInTheDocument()
expect(screen.queryByText('Other Toolkit')).not.toBeInTheDocument()
})
it('shows the empty state when no tool matches the current filter', async () => {
render(
<AllTools
searchText="missing"
tags={[]}
onSelect={vi.fn()}
buildInTools={[]}
customTools={[]}
workflowTools={[]}
mcpTools={[]}
/>,
)
await waitFor(() => {
expect(screen.getByText('workflow.tabs.noPluginsFound')).toBeInTheDocument()
})
})
})

View File

@ -0,0 +1,79 @@
import type { NodeDefault } from '../../types'
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { BlockEnum } from '../../types'
import Blocks from '../blocks'
import { BlockClassificationEnum } from '../types'
const runtimeState = vi.hoisted(() => ({
nodes: [] as Array<{ data: { type?: BlockEnum } }>,
}))
vi.mock('reactflow', () => ({
useStoreApi: () => ({
getState: () => ({
getNodes: () => runtimeState.nodes,
}),
}),
}))
const createBlock = (type: BlockEnum, title: string, classification = BlockClassificationEnum.Default): NodeDefault => ({
metaData: {
classification,
sort: 0,
type,
title,
author: 'Dify',
description: `${title} description`,
},
defaultValue: {},
checkValid: () => ({ isValid: true }),
})
describe('Blocks', () => {
beforeEach(() => {
runtimeState.nodes = []
})
it('renders grouped blocks, filters duplicate knowledge-base nodes, and selects a block', async () => {
const user = userEvent.setup()
const onSelect = vi.fn()
runtimeState.nodes = [{ data: { type: BlockEnum.KnowledgeBase } }]
render(
<Blocks
searchText=""
onSelect={onSelect}
availableBlocksTypes={[BlockEnum.LLM, BlockEnum.LoopEnd, BlockEnum.KnowledgeBase]}
blocks={[
createBlock(BlockEnum.LLM, 'LLM'),
createBlock(BlockEnum.LoopEnd, 'Exit Loop', BlockClassificationEnum.Logic),
createBlock(BlockEnum.KnowledgeBase, 'Knowledge Retrieval'),
]}
/>,
)
expect(screen.getByText('LLM')).toBeInTheDocument()
expect(screen.getByText('Exit Loop')).toBeInTheDocument()
expect(screen.getByText('workflow.nodes.loop.loopNode')).toBeInTheDocument()
expect(screen.queryByText('Knowledge Retrieval')).not.toBeInTheDocument()
await user.click(screen.getByText('LLM'))
expect(onSelect).toHaveBeenCalledWith(BlockEnum.LLM)
})
it('shows the empty state when no block matches the search', () => {
render(
<Blocks
searchText="missing"
onSelect={vi.fn()}
availableBlocksTypes={[BlockEnum.LLM]}
blocks={[createBlock(BlockEnum.LLM, 'LLM')]}
/>,
)
expect(screen.getByText('workflow.tabs.noResult')).toBeInTheDocument()
})
})

View File

@ -0,0 +1,101 @@
import type { ToolWithProvider } from '../../types'
import type { Plugin } from '@/app/components/plugins/types'
import type { Tool } from '@/app/components/tools/types'
import { PluginCategoryEnum } from '@/app/components/plugins/types'
import { CollectionType } from '@/app/components/tools/types'
import { defaultSystemFeatures } from '@/types/feature'
export const createTool = (
name: string,
label: string,
description = `${label} description`,
): Tool => ({
name,
author: 'author',
label: {
en_US: label,
zh_Hans: label,
},
description: {
en_US: description,
zh_Hans: description,
},
parameters: [],
labels: [],
output_schema: {},
})
export const createToolProvider = (
overrides: Partial<ToolWithProvider> = {},
): ToolWithProvider => ({
id: 'provider-1',
name: 'provider-one',
author: 'Provider Author',
description: {
en_US: 'Provider description',
zh_Hans: 'Provider description',
},
icon: 'icon',
icon_dark: 'icon-dark',
label: {
en_US: 'Provider One',
zh_Hans: 'Provider One',
},
type: CollectionType.builtIn,
team_credentials: {},
is_team_authorization: false,
allow_delete: false,
labels: [],
plugin_id: 'plugin-1',
tools: [createTool('tool-a', 'Tool A')],
meta: { version: '1.0.0' } as ToolWithProvider['meta'],
plugin_unique_identifier: 'plugin-1@1.0.0',
...overrides,
})
export const createPlugin = (overrides: Partial<Plugin> = {}): Plugin => ({
type: 'plugin',
org: 'org',
author: 'author',
name: 'Plugin One',
plugin_id: 'plugin-1',
version: '1.0.0',
latest_version: '1.0.0',
latest_package_identifier: 'plugin-1@1.0.0',
icon: 'icon',
verified: true,
label: {
en_US: 'Plugin One',
zh_Hans: 'Plugin One',
},
brief: {
en_US: 'Plugin description',
zh_Hans: 'Plugin description',
},
description: {
en_US: 'Plugin description',
zh_Hans: 'Plugin description',
},
introduction: 'Plugin introduction',
repository: 'https://example.com/plugin',
category: PluginCategoryEnum.tool,
tags: [],
badges: [],
install_count: 0,
endpoint: {
settings: [],
},
verification: {
authorized_category: 'community',
},
from: 'github',
...overrides,
})
export const createGlobalPublicStoreState = (enableMarketplace: boolean) => ({
systemFeatures: {
...defaultSystemFeatures,
enable_marketplace: enableMarketplace,
},
setSystemFeatures: vi.fn(),
})

Some files were not shown because too many files have changed in this diff Show More