Merge branch 'main' into jzh

This commit is contained in:
JzoNg 2026-04-10 17:12:57 +08:00
commit d1ca468c1e
53 changed files with 1570 additions and 910 deletions

View File

@ -0,0 +1,118 @@
# Posts (or updates) a PR comment with the type-coverage report produced by
# the "Pyrefly Type Coverage" workflow. Runs via workflow_run in the default
# branch's trusted context so it can comment on pull requests from forks.
name: Comment with Pyrefly Type Coverage

on:
  workflow_run:
    workflows:
      - Pyrefly Type Coverage
    types:
      - completed

# Drop all default token permissions; the job re-grants only what it needs.
permissions: {}

jobs:
  comment:
    name: Comment PR with type coverage
    runs-on: ubuntu-latest
    permissions:
      actions: read # download the artifact from the triggering run
      contents: read
      issues: write # create/update the PR comment
      pull-requests: write
    # Only handle successful runs for fork PRs; same-repo PRs are commented
    # directly by the coverage workflow itself.
    # NOTE(review): workflow_run.pull_requests can be an empty array for fork
    # PRs, which would make pull_requests[0] resolve empty — confirm this
    # condition actually selects fork PRs as intended.
    if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }}
    steps:
      - name: Checkout default branch (trusted code)
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
      - name: Setup Python & UV
        uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
        with:
          enable-cache: true
      - name: Install dependencies
        run: uv sync --project api --dev
      # Fetch the zipped report data uploaded by the coverage workflow run.
      - name: Download type coverage artifact
        uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            const fs = require('fs');
            const artifacts = await github.rest.actions.listWorkflowRunArtifacts({
              owner: context.repo.owner,
              repo: context.repo.repo,
              run_id: ${{ github.event.workflow_run.id }},
            });
            const match = artifacts.data.artifacts.find((artifact) =>
              artifact.name === 'pyrefly_type_coverage'
            );
            if (!match) {
              throw new Error('pyrefly_type_coverage artifact not found');
            }
            const download = await github.rest.actions.downloadArtifact({
              owner: context.repo.owner,
              repo: context.repo.repo,
              artifact_id: match.id,
              archive_format: 'zip',
            });
            fs.writeFileSync('pyrefly_type_coverage.zip', Buffer.from(download.data));
      # The artifact originates from an untrusted PR run: treat its contents
      # strictly as data, never as code to execute.
      - name: Unzip artifact
        run: unzip -o pyrefly_type_coverage.zip
      - name: Render coverage markdown from structured data
        id: render
        run: |
          # NOTE(review): `uv run --directory api` changes the working directory
          # before running, so the `api/libs/...` script path and the report
          # file paths may not resolve as written — the sibling workflow uses
          # absolute /tmp paths instead; verify.
          comment_body="$(uv run --directory api python api/libs/pyrefly_type_coverage.py \
            --base base_report.json \
            < pr_report.json)"
          {
            echo "### Pyrefly Type Coverage"
            echo ""
            echo "$comment_body"
          } > /tmp/type_coverage_comment.md
      # Upsert the comment: PR number comes from the artifact, falling back to
      # the workflow_run payload when pr_number.txt is absent.
      - name: Post comment
        uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            const fs = require('fs');
            const body = fs.readFileSync('/tmp/type_coverage_comment.md', { encoding: 'utf8' });
            let prNumber = null;
            try {
              prNumber = parseInt(fs.readFileSync('pr_number.txt', { encoding: 'utf8' }), 10);
            } catch (err) {
              const prs = context.payload.workflow_run.pull_requests || [];
              if (prs.length > 0 && prs[0].number) {
                prNumber = prs[0].number;
              }
            }
            if (!prNumber) {
              throw new Error('PR number not found in artifact or workflow_run payload');
            }
            // Update existing comment if one exists, otherwise create new
            const { data: comments } = await github.rest.issues.listComments({
              issue_number: prNumber,
              owner: context.repo.owner,
              repo: context.repo.repo,
            });
            const marker = '### Pyrefly Type Coverage';
            const existing = comments.find(c => c.body.startsWith(marker));
            if (existing) {
              await github.rest.issues.updateComment({
                comment_id: existing.id,
                owner: context.repo.owner,
                repo: context.repo.repo,
                body,
              });
            } else {
              await github.rest.issues.createComment({
                issue_number: prNumber,
                owner: context.repo.owner,
                repo: context.repo.repo,
                body,
              });
            }

View File

@ -0,0 +1,120 @@
# Computes pyrefly type coverage for the PR branch and its base branch,
# uploads the structured reports as an artifact (consumed by the fork-PR
# comment workflow), and comments directly on same-repo PRs.
name: Pyrefly Type Coverage

on:
  pull_request:
    paths:
      - 'api/**/*.py'

permissions:
  contents: read

jobs:
  pyrefly-type-coverage:
    runs-on: ubuntu-latest
    permissions:
      contents: read
      issues: write # needed for the direct same-repo PR comment below
      pull-requests: write
    steps:
      - name: Checkout PR branch
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          # Full history so the base branch/sha can be checked out later.
          fetch-depth: 0
      - name: Setup Python & UV
        uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
        with:
          enable-cache: true
      - name: Install dependencies
        run: uv sync --project api --dev
      # Fall back to an empty report on failure so the comparison degrades
      # gracefully instead of failing the job.
      - name: Run pyrefly report on PR branch
        run: |
          uv run --directory api --dev pyrefly report 2>/dev/null > /tmp/pyrefly_report_pr.tmp && \
          mv /tmp/pyrefly_report_pr.tmp /tmp/pyrefly_report_pr.json || \
          echo '{}' > /tmp/pyrefly_report_pr.json
      # Prefer the helper from the trusted base sha; fall back to the PR copy
      # (e.g. when the PR introduces the helper for the first time).
      - name: Save helper script from base branch
        run: |
          git show ${{ github.event.pull_request.base.sha }}:api/libs/pyrefly_type_coverage.py > /tmp/pyrefly_type_coverage.py 2>/dev/null \
          || cp api/libs/pyrefly_type_coverage.py /tmp/pyrefly_type_coverage.py
      - name: Checkout base branch
        run: git checkout ${{ github.base_ref }}
      - name: Run pyrefly report on base branch
        run: |
          uv run --directory api --dev pyrefly report 2>/dev/null > /tmp/pyrefly_report_base.tmp && \
          mv /tmp/pyrefly_report_base.tmp /tmp/pyrefly_report_base.json || \
          echo '{}' > /tmp/pyrefly_report_base.json
      - name: Generate coverage comparison
        id: coverage
        run: |
          comment_body="$(uv run --directory api python /tmp/pyrefly_type_coverage.py \
            --base /tmp/pyrefly_report_base.json \
            < /tmp/pyrefly_report_pr.json)"
          {
            echo "### Pyrefly Type Coverage"
            echo ""
            echo "$comment_body"
          } | tee -a "$GITHUB_STEP_SUMMARY" > /tmp/type_coverage_comment.md
          # Save structured data for the fork-PR comment workflow
          cp /tmp/pyrefly_report_pr.json pr_report.json
          cp /tmp/pyrefly_report_base.json base_report.json
      - name: Save PR number
        run: |
          echo ${{ github.event.pull_request.number }} > pr_number.txt
      # The fork-PR comment workflow downloads this artifact via workflow_run.
      - name: Upload type coverage artifact
        uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
        with:
          name: pyrefly_type_coverage
          path: |
            pr_report.json
            base_report.json
            pr_number.txt
      # Same-repo PRs only: fork PRs lack a write token here and are handled
      # by the separate workflow_run-triggered comment workflow.
      - name: Comment PR with type coverage
        if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
        uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            const fs = require('fs');
            const marker = '### Pyrefly Type Coverage';
            let body;
            try {
              body = fs.readFileSync('/tmp/type_coverage_comment.md', { encoding: 'utf8' });
            } catch {
              body = `${marker}\n\n_Coverage report unavailable._`;
            }
            const prNumber = context.payload.pull_request.number;
            // Update existing comment if one exists, otherwise create new
            const { data: comments } = await github.rest.issues.listComments({
              issue_number: prNumber,
              owner: context.repo.owner,
              repo: context.repo.repo,
            });
            const existing = comments.find(c => c.body.startsWith(marker));
            if (existing) {
              await github.rest.issues.updateComment({
                comment_id: existing.id,
                owner: context.repo.owner,
                repo: context.repo.repo,
                body,
              });
            } else {
              await github.rest.issues.createComment({
                issue_number: prNumber,
                owner: context.repo.owner,
                repo: context.repo.repo,
                body,
              });
            }

View File

@ -2,7 +2,6 @@ import base64
import secrets
import click
from sqlalchemy.orm import sessionmaker
from constants.languages import languages
from extensions.ext_database import db
@ -25,30 +24,31 @@ def reset_password(email, new_password, password_confirm):
return
normalized_email = email.strip().lower()
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
account = AccountService.get_account_by_email_with_case_fallback(email.strip())
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
try:
valid_password(new_password)
except:
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
return
try:
valid_password(new_password)
except:
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
return
# generate password salt
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()
# generate password salt
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()
# encrypt password with salt
password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
AccountService.reset_login_error_rate_limit(normalized_email)
click.echo(click.style("Password reset successfully.", fg="green"))
# encrypt password with salt
password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account = db.session.merge(account)
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.commit()
AccountService.reset_login_error_rate_limit(normalized_email)
click.echo(click.style("Password reset successfully.", fg="green"))
@click.command("reset-email", help="Reset the account email.")
@ -65,21 +65,22 @@ def reset_email(email, new_email, email_confirm):
return
normalized_new_email = new_email.strip().lower()
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
account = AccountService.get_account_by_email_with_case_fallback(email.strip())
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
try:
email_validate(normalized_new_email)
except:
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
return
try:
email_validate(normalized_new_email)
except:
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
return
account.email = normalized_new_email
click.echo(click.style("Email updated successfully.", fg="green"))
account = db.session.merge(account)
account.email = normalized_new_email
db.session.commit()
click.echo(click.style("Email updated successfully.", fg="green"))
@click.command("create-tenant", help="Create account and tenant.")

View File

@ -1,7 +1,6 @@
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from constants.languages import languages
@ -14,7 +13,6 @@ from controllers.console.auth.error import (
InvalidTokenError,
PasswordMismatchError,
)
from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip
from libs.password import valid_password
from models import Account
@ -73,8 +71,7 @@ class EmailRegisterSendEmailApi(Resource):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountInFreezeError()
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
account = AccountService.get_account_by_email_with_case_fallback(args.email)
token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
return {"result": "success", "data": token}
@ -145,17 +142,16 @@ class EmailRegisterResetApi(Resource):
email = register_data.get("email", "")
normalized_email = email.lower()
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
account = AccountService.get_account_by_email_with_case_fallback(email)
if account:
raise EmailAlreadyInUseError()
else:
account = self._create_new_account(normalized_email, args.password_confirm)
if not account:
raise AccountNotFoundError()
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(normalized_email)
if account:
raise EmailAlreadyInUseError()
else:
account = self._create_new_account(normalized_email, args.password_confirm)
if not account:
raise AccountNotFoundError()
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(normalized_email)
return {"result": "success", "data": token_pair.model_dump()}

View File

@ -4,7 +4,6 @@ import secrets
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
@ -85,8 +84,7 @@ class ForgotPasswordSendEmailApi(Resource):
else:
language = "en-US"
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
account = AccountService.get_account_by_email_with_case_fallback(args.email)
token = AccountService.send_reset_password_email(
account=account,
@ -184,17 +182,18 @@ class ForgotPasswordResetApi(Resource):
password_hashed = hash_password(args.new_password, salt)
email = reset_data.get("email", "")
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
account = AccountService.get_account_by_email_with_case_fallback(email)
if account:
self._update_existing_account(account, password_hashed, salt, session)
else:
raise AccountNotFound()
if account:
account = db.session.merge(account)
self._update_existing_account(account, password_hashed, salt)
db.session.commit()
else:
raise AccountNotFound()
return {"result": "success"}
def _update_existing_account(self, account, password_hashed, salt, session):
def _update_existing_account(self, account, password_hashed, salt):
# Update existing account credentials
account.password = base64.b64encode(password_hashed).decode()
account.password_salt = base64.b64encode(salt).decode()

View File

@ -4,7 +4,6 @@ import urllib.parse
import httpx
from flask import current_app, redirect, request
from flask_restx import Resource
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Unauthorized
from configs import dify_config
@ -180,8 +179,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
account: Account | None = Account.get_by_openid(provider, user_info.id)
if not account:
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)
account = AccountService.get_account_by_email_with_case_fallback(user_info.email)
return account

View File

@ -8,7 +8,6 @@ from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from constants.languages import supported_language
@ -562,8 +561,7 @@ class ChangeEmailSendEmailApi(Resource):
user_email = current_user.email
else:
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
account = AccountService.get_account_by_email_with_case_fallback(args.email)
if account is None:
raise AccountNotFound()
email_for_sending = account.email

View File

@ -3,7 +3,6 @@ import secrets
from flask import request
from flask_restx import Resource
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
from controllers.console.auth.error import (
@ -62,9 +61,7 @@ class ForgotPasswordSendEmailApi(Resource):
else:
language = "en-US"
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session)
token = None
account = AccountService.get_account_by_email_with_case_fallback(request_email)
if account is None:
raise AuthenticationFailedError()
else:
@ -161,13 +158,14 @@ class ForgotPasswordResetApi(Resource):
email = reset_data.get("email", "")
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
account = AccountService.get_account_by_email_with_case_fallback(email)
if account:
self._update_existing_account(account, password_hashed, salt)
else:
raise AuthenticationFailedError()
if account:
account = db.session.merge(account)
self._update_existing_account(account, password_hashed, salt)
db.session.commit()
else:
raise AuthenticationFailedError()
return {"result": "success"}

View File

@ -55,7 +55,7 @@ class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResul
request: ReceiveRequestT
_session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]"
_on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any]
_on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], object]
def __init__(
self,
@ -63,7 +63,7 @@ class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResul
request_meta: RequestParams.Meta | None,
request: ReceiveRequestT,
session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]",
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], object],
):
self.request_id = request_id
self.request_meta = request_meta

View File

@ -31,7 +31,6 @@ ProgressToken = str | int
Cursor = str
Role = Literal["user", "assistant"]
RequestId = Annotated[int | str, Field(union_mode="left_to_right")]
type AnyFunction = Callable[..., Any]
class RequestParams(BaseModel):

View File

@ -1,6 +1,6 @@
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Union, cast
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -207,7 +207,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
)
@classmethod
def _get_user(cls, user_id: str) -> Union[EndUser, Account]:
def _get_user(cls, user_id: str) -> EndUser | Account:
"""
get the user by user id
"""

View File

@ -6,7 +6,6 @@ import os
import time
from collections.abc import Generator
from mimetypes import guess_extension, guess_type
from typing import Union
from uuid import uuid4
import httpx
@ -158,7 +157,7 @@ class ToolFileManager:
return tool_file
def get_file_binary(self, id: str) -> Union[tuple[bytes, str], None]:
def get_file_binary(self, id: str) -> tuple[bytes, str] | None:
"""
get file binary
@ -176,7 +175,7 @@ class ToolFileManager:
return blob, tool_file.mimetype
def get_file_binary_by_message_file_id(self, id: str) -> Union[tuple[bytes, str], None]:
def get_file_binary_by_message_file_id(self, id: str) -> tuple[bytes, str] | None:
"""
get file binary

View File

@ -5,7 +5,7 @@ import time
from collections.abc import Generator, Mapping
from os import listdir, path
from threading import Lock
from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, Union, cast
from typing import TYPE_CHECKING, Any, Literal, Protocol, cast
import sqlalchemy as sa
from graphon.runtime import VariablePool
@ -100,7 +100,7 @@ class ToolManager:
_builtin_provider_lock = Lock()
_hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
_builtin_providers_loaded = False
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
_builtin_tools_labels: dict[str, I18nObject | None] = {}
@classmethod
def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
@ -190,7 +190,7 @@ class ToolManager:
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
credential_id: str | None = None,
) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
) -> BuiltinTool | PluginTool | ApiTool | WorkflowTool | MCPTool:
"""
get the tool runtime
@ -398,7 +398,7 @@ class ToolManager:
agent_tool: AgentToolEntity,
user_id: str | None = None,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
variable_pool: Optional["VariablePool"] = None,
variable_pool: "VariablePool | None" = None,
) -> Tool:
"""
get the agent tool runtime
@ -442,7 +442,7 @@ class ToolManager:
workflow_tool: WorkflowToolRuntimeSpec,
user_id: str | None = None,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
variable_pool: Optional["VariablePool"] = None,
variable_pool: "VariablePool | None" = None,
) -> Tool:
"""
get the workflow tool runtime
@ -634,7 +634,7 @@ class ToolManager:
cls._builtin_providers_loaded = False
@classmethod
def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
def get_tool_label(cls, tool_name: str) -> I18nObject | None:
"""
get the tool label
@ -1052,7 +1052,7 @@ class ToolManager:
def _convert_tool_parameters_type(
cls,
parameters: list[ToolParameter],
variable_pool: Optional["VariablePool"],
variable_pool: "VariablePool | None",
tool_configurations: Mapping[str, Any],
typ: Literal["agent", "workflow", "tool"] = "workflow",
) -> dict[str, Any]:

View File

@ -37,12 +37,7 @@ def trace_span[**P, R](handler_class: type[SpanHandler] | None = None) -> Callab
handler = _get_handler_instance(handler_class or SpanHandler)
tracer = get_tracer(__name__)
return handler.wrapper(
tracer=tracer,
wrapped=func,
args=args,
kwargs=kwargs,
)
return handler.wrapper(tracer, func, *args, **kwargs)
return cast(Callable[P, R], wrapper)

View File

@ -1,8 +1,8 @@
import inspect
from collections.abc import Callable, Mapping
from collections.abc import Callable
from typing import Any
from opentelemetry.trace import SpanKind, Status, StatusCode
from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer
class SpanHandler:
@ -16,9 +16,9 @@ class SpanHandler:
exceptions. Handlers can override the wrapper method to customize behavior.
"""
_signature_cache: dict[Callable[..., Any], inspect.Signature] = {}
_signature_cache: dict[Callable[..., object], inspect.Signature] = {}
def _build_span_name(self, wrapped: Callable[..., Any]) -> str:
def _build_span_name[**P, R](self, wrapped: Callable[P, R]) -> str:
"""
Build the span name from the wrapped function.
@ -29,11 +29,11 @@ class SpanHandler:
"""
return f"{wrapped.__module__}.{wrapped.__qualname__}"
def _extract_arguments[T](
def _extract_arguments[**P, R](
self,
wrapped: Callable[..., T],
args: tuple[object, ...],
kwargs: Mapping[str, object],
wrapped: Callable[P, R],
*args: P.args,
**kwargs: P.kwargs,
) -> dict[str, Any] | None:
"""
Extract function arguments using inspect.signature.
@ -59,13 +59,13 @@ class SpanHandler:
except Exception:
return None
def wrapper[T](
def wrapper[**P, R](
self,
tracer: Any,
wrapped: Callable[..., T],
args: tuple[object, ...],
kwargs: Mapping[str, object],
) -> T:
tracer: Tracer,
wrapped: Callable[P, R],
*args: P.args,
**kwargs: P.kwargs,
) -> R:
"""
Fully control the wrapper behavior.

View File

@ -1,8 +1,7 @@
import logging
from collections.abc import Callable, Mapping
from typing import Any
from collections.abc import Callable
from opentelemetry.trace import SpanKind, Status, StatusCode
from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer
from opentelemetry.util.types import AttributeValue
from extensions.otel.decorators.handler import SpanHandler
@ -15,15 +14,15 @@ logger = logging.getLogger(__name__)
class AppGenerateHandler(SpanHandler):
"""Span handler for ``AppGenerateService.generate``."""
def wrapper[T](
def wrapper[**P, R](
self,
tracer: Any,
wrapped: Callable[..., T],
args: tuple[object, ...],
kwargs: Mapping[str, object],
) -> T:
tracer: Tracer,
wrapped: Callable[P, R],
*args: P.args,
**kwargs: P.kwargs,
) -> R:
try:
arguments = self._extract_arguments(wrapped, args, kwargs)
arguments = self._extract_arguments(wrapped, *args, **kwargs)
if not arguments:
return wrapped(*args, **kwargs)

View File

@ -1,8 +1,7 @@
import logging
from collections.abc import Callable, Mapping
from typing import Any
from collections.abc import Callable
from opentelemetry.trace import SpanKind, Status, StatusCode
from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer
from opentelemetry.util.types import AttributeValue
from extensions.otel.decorators.handler import SpanHandler
@ -14,15 +13,15 @@ logger = logging.getLogger(__name__)
class WorkflowAppRunnerHandler(SpanHandler):
"""Span handler for ``WorkflowAppRunner.run``."""
def wrapper(
def wrapper[**P, R](
self,
tracer: Any,
wrapped: Callable[..., Any],
args: tuple[Any, ...],
kwargs: Mapping[str, Any],
) -> Any:
tracer: Tracer,
wrapped: Callable[P, R],
*args: P.args,
**kwargs: P.kwargs,
) -> R:
try:
arguments = self._extract_arguments(wrapped, args, kwargs)
arguments = self._extract_arguments(wrapped, *args, **kwargs)
if not arguments:
return wrapped(*args, **kwargs)

View File

@ -0,0 +1,145 @@
"""Helpers for generating type-coverage summaries from pyrefly report output."""
from __future__ import annotations
import json
import sys
from pathlib import Path
from typing import TypedDict
class CoverageSummary(TypedDict):
n_modules: int
n_typable: int
n_typed: int
n_any: int
n_untyped: int
coverage: float
strict_coverage: float
_REQUIRED_KEYS = frozenset(CoverageSummary.__annotations__)
_EMPTY_SUMMARY: CoverageSummary = {
"n_modules": 0,
"n_typable": 0,
"n_typed": 0,
"n_any": 0,
"n_untyped": 0,
"coverage": 0.0,
"strict_coverage": 0.0,
}
def parse_summary(report_json: str) -> CoverageSummary:
"""Extract the summary section from ``pyrefly report`` JSON output.
Returns an empty summary when *report_json* is empty or malformed so that
the CI workflow can degrade gracefully instead of crashing.
"""
if not report_json or not report_json.strip():
return _EMPTY_SUMMARY.copy()
try:
data = json.loads(report_json)
except json.JSONDecodeError:
return _EMPTY_SUMMARY.copy()
summary = data.get("summary")
if not isinstance(summary, dict) or not _REQUIRED_KEYS.issubset(summary):
return _EMPTY_SUMMARY.copy()
return {
"n_modules": summary["n_modules"],
"n_typable": summary["n_typable"],
"n_typed": summary["n_typed"],
"n_any": summary["n_any"],
"n_untyped": summary["n_untyped"],
"coverage": summary["coverage"],
"strict_coverage": summary["strict_coverage"],
}
def format_summary_markdown(summary: CoverageSummary) -> str:
    """Render a single coverage summary as a two-column Markdown table."""
    # Counts use thousands separators; percentages are fixed to two decimals.
    rows = [
        "| Metric | Value |",
        "| --- | ---: |",
        f"| Modules | {summary['n_modules']} |",
        f"| Typable symbols | {summary['n_typable']:,} |",
        f"| Typed symbols | {summary['n_typed']:,} |",
        f"| Untyped symbols | {summary['n_untyped']:,} |",
        f"| Any symbols | {summary['n_any']:,} |",
        f"| **Type coverage** | **{summary['coverage']:.2f}%** |",
        f"| Strict coverage | {summary['strict_coverage']:.2f}% |",
    ]
    return "\n".join(rows)
def format_comparison_markdown(
    base: CoverageSummary,
    pr: CoverageSummary,
) -> str:
    """Render a base-vs-PR coverage comparison as a Markdown table."""

    def _signed(value: float, fmt: str = ".2f") -> str:
        # Prefix positive deltas with '+'; negatives already carry '-'.
        prefix = "+" if value > 0 else ""
        return f"{prefix}{value:{fmt}}"

    rows = ["| Metric | Base | PR | Delta |", "| --- | ---: | ---: | ---: |"]
    rows.append(
        f"| **Type coverage** | {base['coverage']:.2f}% | {pr['coverage']:.2f}% "
        f"| {_signed(pr['coverage'] - base['coverage'])}% |"
    )
    rows.append(
        f"| Strict coverage | {base['strict_coverage']:.2f}% "
        f"| {pr['strict_coverage']:.2f}% "
        f"| {_signed(pr['strict_coverage'] - base['strict_coverage'])}% |"
    )
    rows.append(
        f"| Typed symbols | {base['n_typed']:,} | {pr['n_typed']:,} "
        f"| {_signed(pr['n_typed'] - base['n_typed'], ',')} |"
    )
    rows.append(
        f"| Untyped symbols | {base['n_untyped']:,} | {pr['n_untyped']:,} "
        f"| {_signed(pr['n_untyped'] - base['n_untyped'], ',')} |"
    )
    rows.append(
        f"| Modules | {base['n_modules']} | {pr['n_modules']} "
        f"| {_signed(pr['n_modules'] - base['n_modules'], ',')} |"
    )
    return "\n".join(rows)
def main() -> int:
    """Read pyrefly report JSON from stdin and print a Markdown summary.

    Accepts an optional ``--base <file>`` argument. When provided, the output
    includes a base-vs-PR comparison table.

    Returns a process exit code (0 on success, 1 on usage error).
    """
    argv = sys.argv[1:]
    base_path: str | None = None
    if "--base" in argv:
        flag_index = argv.index("--base")
        # The flag must be followed by a file path.
        if flag_index + 1 >= len(argv):
            sys.stderr.write("error: --base requires a file path\n")
            return 1
        base_path = argv[flag_index + 1]

    pr_summary = parse_summary(sys.stdin.read())

    if base_path is None:
        sys.stdout.write(format_summary_markdown(pr_summary) + "\n")
    else:
        # A missing base report degrades to the empty summary via parse_summary.
        candidate = Path(base_path)
        base_text = candidate.read_text() if candidate.exists() else ""
        base_summary = parse_summary(base_text)
        sys.stdout.write(format_comparison_markdown(base_summary, pr_summary) + "\n")
    return 0


if __name__ == "__main__":
    sys.exit(main())

View File

@ -4,7 +4,7 @@ import logging
from collections.abc import Generator, Mapping, Sequence
from datetime import datetime
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union, cast
from typing import TYPE_CHECKING, Any, Optional, TypedDict, cast
from uuid import uuid4
import sqlalchemy as sa
@ -121,7 +121,7 @@ class WorkflowType(StrEnum):
raise ValueError(f"invalid workflow type value {value}")
@classmethod
def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType":
def from_app_mode(cls, app_mode: "str | AppMode") -> "WorkflowType":
"""
Get workflow type from app mode.
@ -1051,7 +1051,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
)
return extras
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]:
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> "WorkflowNodeExecutionOffload | None":
return next(iter([i for i in self.offload_data if i.type_ == type_]), None)
@property

View File

@ -9,7 +9,8 @@ from typing import Any, TypedDict, cast
from pydantic import BaseModel, TypeAdapter
from sqlalchemy import delete, func, select, update
from sqlalchemy.orm import Session, sessionmaker
from core.db.session_factory import session_factory
class InvitationData(TypedDict):
@ -800,19 +801,19 @@ class AccountService:
return token
@staticmethod
def get_account_by_email_with_case_fallback(email: str, session: Session | None = None) -> Account | None:
def get_account_by_email_with_case_fallback(email: str) -> Account | None:
"""
Retrieve an account by email and fall back to the lowercase email if the original lookup fails.
This keeps backward compatibility for older records that stored uppercase emails while the
rest of the system gradually normalizes new inputs.
"""
query_session = session or db.session
account = query_session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
if account or email == email.lower():
return account
with session_factory.create_session() as session:
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
if account or email == email.lower():
return account
return query_session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()
return session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()
@classmethod
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
@ -1516,8 +1517,7 @@ class RegisterService:
check_workspace_member_invite_permission(tenant.id)
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
account = AccountService.get_account_by_email_with_case_fallback(email)
if not account:
TenantService.check_member_permission(tenant, inviter, None, "add")

View File

@ -4,7 +4,7 @@ import logging
import threading
import uuid
from collections.abc import Callable, Generator, Mapping
from typing import TYPE_CHECKING, Any, Union
from typing import TYPE_CHECKING, Any
from configs import dify_config
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
@ -88,7 +88,7 @@ class AppGenerateService:
def generate(
cls,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
@ -356,11 +356,11 @@ class AppGenerateService:
def generate_more_like_this(
cls,
app_model: App,
user: Union[Account, EndUser],
user: Account | EndUser,
message_id: str,
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Union[Mapping, Generator]:
) -> Mapping | Generator:
"""
Generate more like this
:param app_model: app model

View File

@ -7,7 +7,7 @@ with support for different subscription tiers, rate limiting, and execution trac
import json
from datetime import UTC, datetime
from typing import Any, Union
from typing import Any
from celery.result import AsyncResult
from sqlalchemy import select
@ -50,7 +50,7 @@ class AsyncWorkflowService:
@classmethod
def trigger_workflow_async(
cls, session: Session, user: Union[Account, EndUser], trigger_data: TriggerData
cls, session: Session, user: Account | EndUser, trigger_data: TriggerData
) -> AsyncTriggerResponse:
"""
Universal entry point for async workflow execution - THIS METHOD WILL NOT BLOCK
@ -177,7 +177,7 @@ class AsyncWorkflowService:
@classmethod
def reinvoke_trigger(
cls, session: Session, user: Union[Account, EndUser], workflow_trigger_log_id: str
cls, session: Session, user: Account | EndUser, workflow_trigger_log_id: str
) -> AsyncTriggerResponse:
"""
Re-invoke a previously failed or rate-limited trigger - THIS METHOD WILL NOT BLOCK

View File

@ -1,6 +1,6 @@
import json
from copy import deepcopy
from typing import Any, Union, cast
from typing import Any, cast
from urllib.parse import urlparse
import httpx
@ -195,9 +195,7 @@ class ExternalDatasetService:
raise ValueError(f"{parameter.get('name')} is required")
@staticmethod
def process_external_api(
settings: ExternalKnowledgeApiSetting, files: Union[None, dict[str, Any]]
) -> httpx.Response:
def process_external_api(settings: ExternalKnowledgeApiSetting, files: dict[str, Any] | None) -> httpx.Response:
"""
do http request depending on api bundle
"""

View File

@ -5,7 +5,7 @@ import uuid
from collections.abc import Iterator, Sequence
from contextlib import contextmanager, suppress
from tempfile import NamedTemporaryFile
from typing import Literal, Union
from typing import Literal
from zipfile import ZIP_DEFLATED, ZipFile
from graphon.file import helpers as file_helpers
@ -52,7 +52,7 @@ class FileService:
filename: str,
content: bytes,
mimetype: str,
user: Union[Account, EndUser],
user: Account | EndUser,
source: Literal["datasets"] | None = None,
source_url: str = "",
) -> UploadFile:

View File

@ -1,6 +1,6 @@
import json
import logging
from typing import Any, TypedDict, Union
from typing import Any, TypedDict
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.entities.provider_entities import (
@ -626,7 +626,7 @@ class ModelLoadBalancingService:
def _get_credential_schema(
self, provider_configuration: ProviderConfiguration
) -> Union[ModelCredentialSchema, ProviderCredentialSchema]:
) -> ModelCredentialSchema | ProviderCredentialSchema:
"""Get form schemas."""
if provider_configuration.provider.model_credential_schema:
return provider_configuration.provider.model_credential_schema

View File

@ -1,14 +1,13 @@
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from extensions.ext_database import db
from core.db.session_factory import session_factory
from models.account import TenantPluginPermission
class PluginPermissionService:
@staticmethod
def get_permission(tenant_id: str) -> TenantPluginPermission | None:
with sessionmaker(bind=db.engine).begin() as session:
with session_factory.create_session() as session:
return session.scalar(
select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1)
)
@ -19,7 +18,7 @@ class PluginPermissionService:
install_permission: TenantPluginPermission.InstallPermission,
debug_permission: TenantPluginPermission.DebugPermission,
):
with sessionmaker(bind=db.engine).begin() as session:
with session_factory.create_session() as session, session.begin():
permission = session.scalar(
select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1)
)

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, Union
from typing import Any
from configs import dify_config
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
@ -17,7 +17,7 @@ class PipelineGenerateService:
def generate(
cls,
pipeline: Pipeline,
user: Union[Account, EndUser],
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,

View File

@ -5,7 +5,7 @@ import threading
import time
from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, Union, cast
from typing import Any, cast
from uuid import uuid4
from flask_login import current_user
@ -1387,7 +1387,7 @@ class RagPipelineService:
"uninstalled_recommended_plugins": uninstalled_plugin_list,
}
def retry_error_document(self, dataset: Dataset, document: Document, user: Union[Account, EndUser]):
def retry_error_document(self, dataset: Dataset, document: Document, user: Account | EndUser):
"""
Retry error document
"""

View File

@ -1,6 +1,6 @@
import logging
from collections.abc import Mapping
from typing import Any, Union
from typing import Any
from pydantic import TypeAdapter, ValidationError
from yarl import URL
@ -69,7 +69,7 @@ class ToolTransformService:
return ""
@staticmethod
def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity, PluginDatasourceProviderEntity]):
def repack_provider(tenant_id: str, provider: dict | ToolProviderApiEntity | PluginDatasourceProviderEntity):
"""
repack provider

View File

@ -7,15 +7,16 @@ with appropriate retry policies and error handling.
import logging
from datetime import UTC, datetime
from typing import Any
from typing import Any, NotRequired
from celery import shared_task
from graphon.runtime import GraphRuntimeState
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from typing_extensions import TypedDict
from configs import dify_config
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext
from core.app.layers.timeslice_layer import TimeSliceLayer
@ -42,6 +43,13 @@ from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue, AsyncWorkf
logger = logging.getLogger(__name__)
class WorkflowGeneratorArgsDict(TypedDict):
inputs: dict[str, Any]
files: list[Any]
_skip_prepare_user_inputs: bool
workflow_id: NotRequired[str]
@shared_task(queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE)
def execute_workflow_professional(task_data_dict: dict[str, Any]):
"""Execute workflow for professional tier with highest priority"""
@ -90,15 +98,13 @@ def execute_workflow_sandbox(task_data_dict: dict[str, Any]):
)
def _build_generator_args(trigger_data: TriggerData) -> dict[str, Any]:
def _build_generator_args(trigger_data: TriggerData) -> WorkflowGeneratorArgsDict:
"""Build args passed into WorkflowAppGenerator.generate for Celery executions."""
args: dict[str, Any] = {
return {
"inputs": dict(trigger_data.inputs),
"files": list(trigger_data.files),
SKIP_PREPARE_USER_INPUTS_KEY: True,
"_skip_prepare_user_inputs": True,
}
return args
def _execute_workflow_common(

View File

@ -158,7 +158,10 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase():
second_result.scalar_one_or_none.return_value = expected_account
mock_session.execute.side_effect = [first_result, second_result]
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
with patch("services.account_service.session_factory") as mock_factory:
mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False)
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com")
assert result is expected_account
assert mock_session.execute.call_count == 2

View File

@ -113,12 +113,14 @@ class TestForgotPasswordCheckApi:
class TestForgotPasswordResetApi:
@patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.forgot_password.db")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_reset_fetches_account_with_original_email(
self,
mock_get_reset_data,
mock_revoke_token,
mock_db,
mock_get_account,
mock_update_account,
app,
@ -126,6 +128,7 @@ class TestForgotPasswordResetApi:
mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"}
mock_account = MagicMock()
mock_get_account.return_value = mock_account
mock_db.session.merge.return_value = mock_account
wraps_features = SimpleNamespace(enable_email_password_login=True)
with (
@ -161,7 +164,10 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase():
second_result.scalar_one_or_none.return_value = expected_account
mock_session.execute.side_effect = [first_result, second_result]
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session)
with patch("services.account_service.session_factory") as mock_factory:
mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False)
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com")
assert result is expected_account
assert mock_session.execute.call_count == 2

View File

@ -437,7 +437,10 @@ class TestAccountGeneration:
second_result.scalar_one_or_none.return_value = expected_account
mock_session.execute.side_effect = [first_result, second_result]
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
with patch("services.account_service.session_factory") as mock_factory:
mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False)
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com")
assert result is expected_account
assert mock_session.execute.call_count == 2

View File

@ -335,10 +335,12 @@ class TestForgotPasswordResetApi:
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.forgot_password.db")
@patch("controllers.console.auth.forgot_password.TenantService.get_join_tenants")
def test_reset_password_success(
self,
mock_get_tenants,
mock_db,
mock_get_account,
mock_revoke_token,
mock_get_data,
@ -356,6 +358,7 @@ class TestForgotPasswordResetApi:
# Arrange
mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
mock_get_account.return_value = mock_account
mock_db.session.merge.return_value = mock_account
mock_get_tenants.return_value = [MagicMock()]
# Act

View File

@ -37,10 +37,8 @@ class TestForgotPasswordSendEmailApi:
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
@patch("controllers.web.forgot_password.extract_remote_ip", return_value="127.0.0.1")
@patch("controllers.web.forgot_password.sessionmaker")
def test_should_normalize_email_before_sending(
self,
mock_session_cls,
mock_extract_ip,
mock_rate_limit,
mock_get_account,
@ -50,19 +48,16 @@ class TestForgotPasswordSendEmailApi:
mock_account = MagicMock()
mock_get_account.return_value = mock_account
mock_send_mail.return_value = "token-123"
mock_session = MagicMock()
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
with app.test_request_context(
"/web/forgot-password",
method="POST",
json={"email": "User@Example.com", "language": "zh-Hans"},
):
response = ForgotPasswordSendEmailApi().post()
with app.test_request_context(
"/web/forgot-password",
method="POST",
json={"email": "User@Example.com", "language": "zh-Hans"},
):
response = ForgotPasswordSendEmailApi().post()
assert response == {"result": "success", "data": "token-123"}
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
mock_get_account.assert_called_once_with("User@Example.com")
mock_send_mail.assert_called_once_with(account=mock_account, email="user@example.com", language="zh-Hans")
mock_extract_ip.assert_called_once()
mock_rate_limit.assert_called_once_with("127.0.0.1")
@ -153,14 +148,14 @@ class TestForgotPasswordResetApi:
@patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account")
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.web.forgot_password.sessionmaker")
@patch("controllers.web.forgot_password.db")
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
def test_should_fetch_account_with_fallback(
self,
mock_get_reset_data,
mock_revoke_token,
mock_session_cls,
mock_db,
mock_get_account,
mock_update_account,
app,
@ -168,29 +163,27 @@ class TestForgotPasswordResetApi:
mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com", "code": "1234"}
mock_account = MagicMock()
mock_get_account.return_value = mock_account
mock_session = MagicMock()
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
mock_db.session.merge.return_value = mock_account
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
with app.test_request_context(
"/web/forgot-password/resets",
method="POST",
json={
"token": "token-123",
"new_password": "ValidPass123!",
"password_confirm": "ValidPass123!",
},
):
response = ForgotPasswordResetApi().post()
with app.test_request_context(
"/web/forgot-password/resets",
method="POST",
json={
"token": "token-123",
"new_password": "ValidPass123!",
"password_confirm": "ValidPass123!",
},
):
response = ForgotPasswordResetApi().post()
assert response == {"result": "success"}
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
mock_get_account.assert_called_once_with("User@Example.com")
mock_update_account.assert_called_once()
mock_revoke_token.assert_called_once_with("token-123")
@patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value")
@patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef")
@patch("controllers.web.forgot_password.sessionmaker")
@patch("controllers.web.forgot_password.db")
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@ -199,7 +192,7 @@ class TestForgotPasswordResetApi:
mock_get_account,
mock_get_reset_data,
mock_revoke_token,
mock_session_cls,
mock_db,
mock_token_bytes,
mock_hash_password,
app,
@ -207,20 +200,18 @@ class TestForgotPasswordResetApi:
mock_get_reset_data.return_value = {"phase": "reset", "email": "user@example.com"}
account = MagicMock()
mock_get_account.return_value = account
mock_session = MagicMock()
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
mock_db.session.merge.return_value = account
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
with app.test_request_context(
"/web/forgot-password/resets",
method="POST",
json={
"token": "reset-token",
"new_password": "StrongPass123!",
"password_confirm": "StrongPass123!",
},
):
response = ForgotPasswordResetApi().post()
with app.test_request_context(
"/web/forgot-password/resets",
method="POST",
json={
"token": "reset-token",
"new_password": "StrongPass123!",
"password_confirm": "StrongPass123!",
},
):
response = ForgotPasswordResetApi().post()
assert response == {"result": "success"}
mock_get_reset_data.assert_called_once_with("reset-token")

View File

@ -1,239 +1,193 @@
from __future__ import annotations
import json
from typing import Any, cast
from unittest.mock import ANY, MagicMock, patch
from uuid import uuid4
import pytest
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from core.rag.models.document import Document
from models.dataset import Dataset
from models.dataset import Dataset, DatasetQuery
from services.hit_testing_service import HitTestingService
class TestHitTestingService:
"""Test suite for HitTestingService"""
def _create_dataset(db_session: Session, *, provider: str = "vendor", **kwargs: Any) -> Dataset:
tenant_id = str(uuid4())
created_by = str(uuid4())
ds = Dataset(
tenant_id=kwargs.get("tenant_id", tenant_id),
name=kwargs.get("name", "test-dataset"),
created_by=kwargs.get("created_by", created_by),
provider=provider,
)
db_session.add(ds)
db_session.commit()
db_session.refresh(ds)
return ds
# ===== Utility Method Tests =====
class TestHitTestingService:
# ── Utility methods (pure logic, no DB) ────────────────────────────
def test_escape_query_for_search_should_escape_double_quotes(self):
"""Test that escape_query_for_search escapes double quotes correctly"""
# Arrange
query = 'test "query" with quotes'
expected = 'test \\"query\\" with quotes'
# Act
result = HitTestingService.escape_query_for_search(query)
# Assert
assert result == expected
assert result == 'test \\"query\\" with quotes'
def test_hit_testing_args_check_should_pass_with_valid_query(self):
"""Test that hit_testing_args_check passes with a valid query"""
# Arrange
args = {"query": "valid query"}
# Act & Assert (should not raise)
HitTestingService.hit_testing_args_check(args)
HitTestingService.hit_testing_args_check({"query": "valid query"})
def test_hit_testing_args_check_should_pass_with_valid_attachments(self):
"""Test that hit_testing_args_check passes with valid attachment_ids"""
# Arrange
args = {"attachment_ids": ["id1", "id2"]}
# Act & Assert (should not raise)
HitTestingService.hit_testing_args_check(args)
HitTestingService.hit_testing_args_check({"attachment_ids": ["id1", "id2"]})
def test_hit_testing_args_check_should_raise_error_when_no_query_or_attachments(self):
"""Test that hit_testing_args_check raises ValueError if both query and attachment_ids are missing"""
# Arrange
args = {}
# Act & Assert
with pytest.raises(ValueError) as exc_info:
HitTestingService.hit_testing_args_check(args)
assert "Query or attachment_ids is required" in str(exc_info.value)
with pytest.raises(ValueError, match="Query or attachment_ids is required"):
HitTestingService.hit_testing_args_check({})
def test_hit_testing_args_check_should_raise_error_when_query_too_long(self):
"""Test that hit_testing_args_check raises ValueError if query exceeds 250 characters"""
# Arrange
args = {"query": "a" * 251}
# Act & Assert
with pytest.raises(ValueError) as exc_info:
HitTestingService.hit_testing_args_check(args)
assert "Query cannot exceed 250 characters" in str(exc_info.value)
with pytest.raises(ValueError, match="Query cannot exceed 250 characters"):
HitTestingService.hit_testing_args_check({"query": "a" * 251})
def test_hit_testing_args_check_should_raise_error_when_attachments_not_list(self):
"""Test that hit_testing_args_check raises ValueError if attachment_ids is not a list"""
# Arrange
args = {"attachment_ids": "not a list"}
with pytest.raises(ValueError, match="Attachment_ids must be a list"):
HitTestingService.hit_testing_args_check({"attachment_ids": "not a list"})
# Act & Assert
with pytest.raises(ValueError) as exc_info:
HitTestingService.hit_testing_args_check(args)
assert "Attachment_ids must be a list" in str(exc_info.value)
# ===== Response Formatting Tests =====
# ── Response formatting ────────────────────────────────────────────
@patch("core.rag.datasource.retrieval_service.RetrievalService.format_retrieval_documents")
def test_compact_retrieve_response_should_format_correctly(self, mock_format):
"""Test that compact_retrieve_response formats the response correctly"""
# Arrange
query = "test query"
mock_doc = MagicMock(spec=Document)
documents = [mock_doc]
mock_record = MagicMock()
mock_record.model_dump.return_value = {"content": "formatted content"}
mock_format.return_value = [mock_record]
# Act
result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, documents))
result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, [mock_doc]))
# Assert
assert cast(dict[str, Any], result["query"])["content"] == query
assert len(result["records"]) == 1
assert cast(dict[str, Any], result["records"][0])["content"] == "formatted content"
mock_format.assert_called_once_with(documents)
mock_format.assert_called_once_with([mock_doc])
def test_compact_external_retrieve_response_should_return_records_for_external_provider(self):
"""Test that compact_external_retrieve_response returns records when dataset provider is external"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.provider = "external"
query = "test query"
def test_compact_external_retrieve_response_should_return_records_for_external_provider(
self, db_session_with_containers: Session
):
dataset = _create_dataset(db_session_with_containers, provider="external")
documents = [
{"content": "c1", "title": "t1", "score": 0.9, "metadata": {"m1": "v1"}},
{"content": "c2", "title": "t2", "score": 0.8, "metadata": {"m2": "v2"}},
]
# Act
result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents))
result = cast(
dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, "test query", documents)
)
# Assert
assert cast(dict[str, Any], result["query"])["content"] == query
assert cast(dict[str, Any], result["query"])["content"] == "test query"
assert len(result["records"]) == 2
assert cast(dict[str, Any], result["records"][0])["content"] == "c1"
assert cast(dict[str, Any], result["records"][1])["title"] == "t2"
def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider(self):
"""Test that compact_external_retrieve_response returns empty records for non-external provider"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.provider = "not_external"
query = "test query"
documents = [{"content": "c1"}]
def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider(
self, db_session_with_containers: Session
):
dataset = _create_dataset(db_session_with_containers, provider="vendor")
# Act
result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents))
result = cast(
dict[str, Any],
HitTestingService.compact_external_retrieve_response(dataset, "test query", [{"content": "c1"}]),
)
# Assert
assert cast(dict[str, Any], result["query"])["content"] == query
assert cast(dict[str, Any], result["query"])["content"] == "test query"
assert result["records"] == []
# ===== External Retrieve Tests =====
# ── External retrieve (real DB) ────────────────────────────────────
@patch("core.rag.datasource.retrieval_service.RetrievalService.external_retrieve")
@patch("extensions.ext_database.db.session.add")
@patch("extensions.ext_database.db.session.commit")
def test_external_retrieve_should_succeed_for_external_provider(self, mock_commit, mock_add, mock_ext_retrieve):
"""Test that external_retrieve successfully retrieves from external provider and commits query"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.id = "dataset_id"
dataset.provider = "external"
query = 'test "query"'
def test_external_retrieve_should_succeed_for_external_provider(
self, mock_ext_retrieve, db_session_with_containers: Session
):
dataset = _create_dataset(db_session_with_containers, provider="external")
account_id = str(uuid4())
account = MagicMock()
account.id = "account_id"
account.id = account_id
mock_ext_retrieve.return_value = [{"content": "ext content", "score": 1.0}]
# Act
before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
result = cast(
dict[str, Any],
HitTestingService.external_retrieve(
dataset=dataset,
query=query,
query='test "query"',
account=account,
external_retrieval_model={"model": "test"},
metadata_filtering_conditions={"key": "val"},
),
)
# Assert
assert cast(dict[str, Any], result["query"])["content"] == query
assert cast(dict[str, Any], result["query"])["content"] == 'test "query"'
assert cast(dict[str, Any], result["records"][0])["content"] == "ext content"
# Verify call to RetrievalService.external_retrieve with escaped query
mock_ext_retrieve.assert_called_once_with(
dataset_id="dataset_id",
dataset_id=dataset.id,
query='test \\"query\\"',
external_retrieval_model={"model": "test"},
metadata_filtering_conditions={"key": "val"},
)
# Verify DatasetQuery record was added and committed
mock_add.assert_called_once()
mock_commit.assert_called_once()
db_session_with_containers.expire_all()
after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
assert after_count == before_count + 1
def test_external_retrieve_should_return_empty_for_non_external_provider(self):
"""Test that external_retrieve returns empty results immediately if provider is not external"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.provider = "not_external"
query = "test query"
def test_external_retrieve_should_return_empty_for_non_external_provider(self, db_session_with_containers: Session):
dataset = _create_dataset(db_session_with_containers, provider="vendor")
account = MagicMock()
# Act
result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, query, account))
result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, "test query", account))
# Assert
assert cast(dict[str, Any], result["query"])["content"] == query
assert cast(dict[str, Any], result["query"])["content"] == "test query"
assert result["records"] == []
# ===== Retrieve Tests =====
# ── Retrieve (real DB) ─────────────────────────────────────────────
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
@patch("extensions.ext_database.db.session.add")
@patch("extensions.ext_database.db.session.commit")
def test_retrieve_should_use_default_model_when_none_provided(self, mock_commit, mock_add, mock_retrieve):
"""Test that retrieve uses default model when retrieval_model is not provided"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.id = "dataset_id"
def test_retrieve_should_use_default_model_when_none_provided(
self, mock_retrieve, db_session_with_containers: Session
):
dataset = _create_dataset(db_session_with_containers)
dataset.retrieval_model = None
query = "test query"
account = MagicMock()
account.id = "account_id"
account.id = str(uuid4())
mock_retrieve.return_value = []
# Act
before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
result = cast(
dict[str, Any],
HitTestingService.retrieve(
dataset=dataset, query=query, account=account, retrieval_model=None, external_retrieval_model={}
dataset=dataset, query="test query", account=account, retrieval_model=None, external_retrieval_model={}
),
)
# Assert
assert cast(dict[str, Any], result["query"])["content"] == query
assert cast(dict[str, Any], result["query"])["content"] == "test query"
mock_retrieve.assert_called_once()
# Verify top_k from default_retrieval_model (4)
assert mock_retrieve.call_args.kwargs["top_k"] == 4
mock_commit.assert_called_once()
db_session_with_containers.expire_all()
after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
assert after_count == before_count + 1
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition")
@patch("extensions.ext_database.db.session.add")
@patch("extensions.ext_database.db.session.commit")
def test_retrieve_should_handle_metadata_filtering(self, mock_commit, mock_add, mock_get_meta, mock_retrieve):
"""Test that retrieve correctly calls metadata filtering when conditions are present"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.id = "dataset_id"
query = "test query"
def test_retrieve_should_handle_metadata_filtering(
self, mock_get_meta, mock_retrieve, db_session_with_containers: Session
):
dataset = _create_dataset(db_session_with_containers)
account = MagicMock()
account.id = "account_id"
account.id = str(uuid4())
retrieval_model = {
"search_method": "semantic_search",
@ -242,29 +196,27 @@ class TestHitTestingService:
"reranking_enable": False,
"score_threshold_enabled": False,
}
# Mock metadata filtering response
mock_get_meta.return_value = ({"dataset_id": ["doc_id1"]}, "condition_string")
mock_get_meta.return_value = ({dataset.id: ["doc_id1"]}, "condition_string")
mock_retrieve.return_value = []
# Act
HitTestingService.retrieve(
dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={}
dataset=dataset,
query="test query",
account=account,
retrieval_model=retrieval_model,
external_retrieval_model={},
)
# Assert
mock_get_meta.assert_called_once()
mock_retrieve.assert_called_once()
assert mock_retrieve.call_args.kwargs["document_ids_filter"] == ["doc_id1"]
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition")
def test_retrieve_should_return_empty_if_metadata_filtering_fails(self, mock_get_meta, mock_retrieve):
"""Test that retrieve returns empty response if metadata filtering returns condition but no document IDs"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.id = "dataset_id"
query = "test query"
def test_retrieve_should_return_empty_if_metadata_filtering_fails(
self, mock_get_meta, mock_retrieve, db_session_with_containers: Session
):
dataset = _create_dataset(db_session_with_containers)
account = MagicMock()
retrieval_model = {
@ -274,37 +226,27 @@ class TestHitTestingService:
"reranking_enable": False,
"score_threshold_enabled": False,
}
# Mock metadata filtering response: condition returned but no IDs
mock_get_meta.return_value = ({}, "condition_string")
# Act
result = cast(
dict[str, Any],
HitTestingService.retrieve(
dataset=dataset,
query=query,
query="test query",
account=account,
retrieval_model=retrieval_model,
external_retrieval_model={},
),
)
# Assert
assert result["records"] == []
mock_retrieve.assert_not_called()
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
@patch("extensions.ext_database.db.session.add")
@patch("extensions.ext_database.db.session.commit")
def test_retrieve_should_handle_attachments(self, mock_commit, mock_add, mock_retrieve):
"""Test that retrieve handles attachment_ids and adds them to DatasetQuery"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.id = "dataset_id"
query = "test query"
def test_retrieve_should_handle_attachments(self, mock_retrieve, db_session_with_containers: Session):
dataset = _create_dataset(db_session_with_containers)
account = MagicMock()
account.id = "account_id"
account.id = str(uuid4())
attachment_ids = ["att1", "att2"]
retrieval_model = {
@ -315,21 +257,19 @@ class TestHitTestingService:
}
mock_retrieve.return_value = []
# Act
HitTestingService.retrieve(
dataset=dataset,
query=query,
query="test query",
account=account,
retrieval_model=retrieval_model,
external_retrieval_model={},
attachment_ids=attachment_ids,
)
# Assert
mock_retrieve.assert_called_once_with(
retrieval_method=ANY,
dataset_id="dataset_id",
query=query,
dataset_id=dataset.id,
query="test query",
attachment_ids=attachment_ids,
top_k=4,
score_threshold=0.0,
@ -338,26 +278,27 @@ class TestHitTestingService:
weights=None,
document_ids_filter=None,
)
# Verify DatasetQuery record (there should be 2 queries: 1 text, 2 images)
# The content is json.dumps([{"content_type": "text_query", ...}, {"content_type": "image_query", ...}])
called_query = mock_add.call_args[0][0]
query_content = json.loads(called_query.content)
# Verify DatasetQuery was persisted with correct content structure
db_session_with_containers.expire_all()
latest = db_session_with_containers.scalar(
select(DatasetQuery)
.where(DatasetQuery.dataset_id == dataset.id)
.order_by(DatasetQuery.created_at.desc())
.limit(1)
)
assert latest is not None
query_content = json.loads(latest.content)
assert len(query_content) == 3 # 1 text + 2 images
assert query_content[0]["content_type"] == "text_query"
assert query_content[1]["content_type"] == "image_query"
assert query_content[1]["content"] == "att1"
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
@patch("extensions.ext_database.db.session.add")
@patch("extensions.ext_database.db.session.commit")
def test_retrieve_should_handle_reranking_and_threshold(self, mock_commit, mock_add, mock_retrieve):
"""Test that retrieve passes reranking and threshold parameters correctly"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.id = "dataset_id"
query = "test query"
def test_retrieve_should_handle_reranking_and_threshold(self, mock_retrieve, db_session_with_containers: Session):
dataset = _create_dataset(db_session_with_containers)
account = MagicMock()
account.id = "account_id"
account.id = str(uuid4())
retrieval_model = {
"search_method": "hybrid_search",
@ -371,12 +312,14 @@ class TestHitTestingService:
}
mock_retrieve.return_value = []
# Act
HitTestingService.retrieve(
dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={}
dataset=dataset,
query="test query",
account=account,
retrieval_model=retrieval_model,
external_retrieval_model={},
)
# Assert
mock_retrieve.assert_called_once()
kwargs = mock_retrieve.call_args.kwargs
assert kwargs["score_threshold"] == 0.5

View File

@ -0,0 +1,363 @@
from __future__ import annotations
import uuid
from unittest.mock import patch
import pytest
from faker import Faker
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.ops.entities.config_entity import TracingProviderEnum
from models.model import TraceAppConfig
from services.account_service import AccountService, TenantService
from services.app_service import AppService
from services.ops_service import OpsService
from tests.test_containers_integration_tests.helpers import generate_valid_password
class TestOpsService:
    """Integration tests for OpsService tracing-config CRUD.

    Runs against a containerized database session (``db_session_with_containers``)
    while patching out ``OpsTraceManager`` and the external services touched
    during app/account creation.
    """

    @pytest.fixture
    def mock_external_service_dependencies(self):
        # Patch everything _create_app touches outside the DB so that account,
        # tenant, and app creation succeed without enterprise features or a
        # real model provider.
        with (
            patch("services.app_service.FeatureService") as mock_feature_service,
            patch("services.app_service.EnterpriseService") as mock_enterprise_service,
            patch("services.app_service.ModelManager.for_tenant") as mock_model_manager,
            patch("services.account_service.FeatureService") as mock_account_feature_service,
        ):
            mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False
            mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None
            mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None
            # Allow AccountService.create_account to proceed.
            mock_account_feature_service.get_system_features.return_value.is_allow_register = True
            mock_model_instance = mock_model_manager.return_value
            mock_model_instance.get_default_model_instance.return_value = None
            mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo")
            yield {
                "feature_service": mock_feature_service,
                "enterprise_service": mock_enterprise_service,
                "model_manager": mock_model_manager,
                "account_feature_service": mock_account_feature_service,
            }

    @pytest.fixture
    def mock_ops_trace_manager(self):
        # Blanket OpsTraceManager mock for tests that do not need to customize
        # individual manager methods.
        with patch("services.ops_service.OpsTraceManager") as mock:
            yield mock

    def _create_app(self, db_session_with_containers: Session, mock_external_service_dependencies):
        """Create a real account, owner tenant, and chat app in the test DB.

        Returns the ``(app, account)`` pair. Requires the
        ``mock_external_service_dependencies`` fixture to be active.
        """
        fake = Faker()
        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=generate_valid_password(fake),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant
        app_service = AppService()
        app = app_service.create_app(
            tenant.id,
            {
                "name": fake.company(),
                "description": fake.text(max_nb_chars=100),
                "mode": "chat",
                "icon_type": "emoji",
                "icon": "🤖",
                "icon_background": "#FF6B6B",
            },
            account,
        )
        return app, account

    # Sentinel that distinguishes "argument not passed" from an explicit
    # ``tracing_config=None`` in _insert_trace_config.
    _SENTINEL = object()

    def _insert_trace_config(
        self,
        db_session: Session,
        app_id: str,
        provider: str,
        tracing_config: dict | None | object = _SENTINEL,
    ) -> TraceAppConfig:
        """Persist a TraceAppConfig row and return it.

        When ``tracing_config`` is omitted, a non-empty placeholder config is
        stored; passing ``None`` explicitly stores NULL (used to exercise the
        "config cannot be None" error path).
        """
        trace_config = TraceAppConfig(
            app_id=app_id,
            tracing_provider=provider,
            tracing_config=tracing_config if tracing_config is not self._SENTINEL else {"some": "config"},
        )
        db_session.add(trace_config)
        db_session.commit()
        return trace_config

    # ── get_tracing_app_config ─────────────────────────────────────────
    def test_get_tracing_app_config_no_config(self, db_session_with_containers: Session, mock_ops_trace_manager):
        """No stored config for the app/provider pair -> None."""
        result = OpsService.get_tracing_app_config(str(uuid.uuid4()), "arize")
        assert result is None

    def test_get_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager):
        """Config row exists but the referenced app does not -> None."""
        fake_app_id = str(uuid.uuid4())
        self._insert_trace_config(db_session_with_containers, fake_app_id, "arize")
        result = OpsService.get_tracing_app_config(fake_app_id, "arize")
        assert result is None

    def test_get_tracing_app_config_none_config(
        self, db_session_with_containers: Session, mock_external_service_dependencies, mock_ops_trace_manager
    ):
        """A stored NULL tracing_config raises ValueError."""
        app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
        self._insert_trace_config(db_session_with_containers, app.id, "arize", tracing_config=None)
        with pytest.raises(ValueError, match="Tracing config cannot be None."):
            OpsService.get_tracing_app_config(app.id, "arize")

    @pytest.mark.parametrize(
        ("provider", "default_url"),
        [
            ("arize", "https://app.arize.com/"),
            ("phoenix", "https://app.phoenix.arize.com/projects/"),
            ("langsmith", "https://smith.langchain.com/"),
            ("opik", "https://www.comet.com/opik/"),
            ("weave", "https://wandb.ai/"),
            ("aliyun", "https://arms.console.aliyun.com/"),
            ("tencent", "https://console.cloud.tencent.com/apm"),
            ("mlflow", "http://localhost:5000/"),
            ("databricks", "https://www.databricks.com/"),
        ],
    )
    def test_get_tracing_app_config_providers_exception(
        self, db_session_with_containers: Session, mock_external_service_dependencies, provider, default_url
    ):
        """When project URL/key resolution raises, each provider falls back to
        its documented default project_url."""
        with patch("services.ops_service.OpsTraceManager") as mock_otm:
            mock_otm.decrypt_tracing_config.return_value = {}
            mock_otm.obfuscated_decrypt_token.return_value = {}
            mock_otm.get_trace_config_project_url.side_effect = Exception("error")
            mock_otm.get_trace_config_project_key.side_effect = Exception("error")
            app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
            self._insert_trace_config(db_session_with_containers, app.id, provider)
            result = OpsService.get_tracing_app_config(app.id, provider)
            assert result is not None
            assert result["tracing_config"]["project_url"] == default_url

    @pytest.mark.parametrize(
        "provider",
        ["arize", "phoenix", "langsmith", "opik", "weave", "aliyun", "tencent", "mlflow", "databricks"],
    )
    def test_get_tracing_app_config_providers_success(
        self, db_session_with_containers: Session, mock_external_service_dependencies, provider
    ):
        """When URL resolution succeeds, the resolved project_url is returned."""
        with patch("services.ops_service.OpsTraceManager") as mock_otm:
            mock_otm.decrypt_tracing_config.return_value = {}
            mock_otm.obfuscated_decrypt_token.return_value = {"project_url": "success_url"}
            mock_otm.get_trace_config_project_url.return_value = "success_url"
            app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
            self._insert_trace_config(db_session_with_containers, app.id, provider)
            result = OpsService.get_tracing_app_config(app.id, provider)
            assert result is not None
            assert result["tracing_config"]["project_url"] == "success_url"

    def test_get_tracing_app_config_langfuse_success(
        self, db_session_with_containers: Session, mock_external_service_dependencies
    ):
        """Langfuse builds its project_url from host + resolved project key."""
        with patch("services.ops_service.OpsTraceManager") as mock_otm:
            mock_otm.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"}
            mock_otm.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"}
            mock_otm.get_trace_config_project_key.return_value = "key"
            app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
            self._insert_trace_config(db_session_with_containers, app.id, "langfuse")
            result = OpsService.get_tracing_app_config(app.id, "langfuse")
            assert result is not None
            assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/project/key"

    def test_get_tracing_app_config_langfuse_exception(
        self, db_session_with_containers: Session, mock_external_service_dependencies
    ):
        """If the Langfuse project key lookup fails, the URL falls back to host/."""
        with patch("services.ops_service.OpsTraceManager") as mock_otm:
            mock_otm.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"}
            mock_otm.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"}
            mock_otm.get_trace_config_project_key.side_effect = Exception("error")
            app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
            self._insert_trace_config(db_session_with_containers, app.id, "langfuse")
            result = OpsService.get_tracing_app_config(app.id, "langfuse")
            assert result is not None
            assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/"

    # ── create_tracing_app_config ──────────────────────────────────────
    def test_create_tracing_app_config_invalid_provider(self, db_session_with_containers: Session):
        """Unknown provider names are rejected with an error payload."""
        result = OpsService.create_tracing_app_config(str(uuid.uuid4()), "invalid_provider", {})
        assert result == {"error": "Invalid tracing provider: invalid_provider"}

    def test_create_tracing_app_config_invalid_credentials(
        self, db_session_with_containers: Session, mock_ops_trace_manager
    ):
        """Ineffective trace config (bad credentials) is rejected."""
        mock_ops_trace_manager.check_trace_config_is_effective.return_value = False
        result = OpsService.create_tracing_app_config(
            str(uuid.uuid4()), TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}
        )
        assert result == {"error": "Invalid Credentials"}

    @pytest.mark.parametrize(
        ("provider", "config"),
        [
            (TracingProviderEnum.ARIZE, {}),
            (TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}),
            (TracingProviderEnum.LANGSMITH, {"api_key": "k", "project": "p"}),
            (TracingProviderEnum.ALIYUN, {"license_key": "k", "endpoint": "https://aliyun.com"}),
        ],
    )
    def test_create_tracing_app_config_project_url_exception(
        self, db_session_with_containers: Session, mock_external_service_dependencies, provider, config
    ):
        # Existing config causes the service to return None before reaching the DB insert
        with patch("services.ops_service.OpsTraceManager") as mock_otm:
            mock_otm.check_trace_config_is_effective.return_value = True
            mock_otm.get_trace_config_project_url.side_effect = Exception("error")
            mock_otm.get_trace_config_project_key.side_effect = Exception("error")
            app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
            self._insert_trace_config(db_session_with_containers, app.id, str(provider))
            result = OpsService.create_tracing_app_config(app.id, provider, config)
            assert result is None

    def test_create_tracing_app_config_langfuse_success(
        self, db_session_with_containers: Session, mock_external_service_dependencies
    ):
        """Creating a fresh Langfuse config succeeds end to end."""
        with patch("services.ops_service.OpsTraceManager") as mock_otm:
            mock_otm.check_trace_config_is_effective.return_value = True
            mock_otm.get_trace_config_project_key.return_value = "key"
            mock_otm.encrypt_tracing_config.return_value = {}
            app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
            result = OpsService.create_tracing_app_config(
                app.id,
                TracingProviderEnum.LANGFUSE,
                {"public_key": "p", "secret_key": "s", "host": "https://api.langfuse.com"},
            )
            assert result == {"result": "success"}

    def test_create_tracing_app_config_already_exists(
        self, db_session_with_containers: Session, mock_external_service_dependencies
    ):
        """A pre-existing config for the provider makes create a no-op (None)."""
        with patch("services.ops_service.OpsTraceManager") as mock_otm:
            mock_otm.check_trace_config_is_effective.return_value = True
            app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
            self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE))
            result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {})
            assert result is None

    def test_create_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager):
        """Creating config for a nonexistent app returns None."""
        mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
        result = OpsService.create_tracing_app_config(str(uuid.uuid4()), TracingProviderEnum.ARIZE, {})
        assert result is None

    def test_create_tracing_app_config_with_empty_other_keys(
        self, db_session_with_containers: Session, mock_external_service_dependencies
    ):
        # "project" is in other_keys for Arize; providing "" triggers default substitution
        with patch("services.ops_service.OpsTraceManager") as mock_otm:
            mock_otm.check_trace_config_is_effective.return_value = True
            mock_otm.get_trace_config_project_url.side_effect = Exception("no url")
            mock_otm.encrypt_tracing_config.return_value = {}
            app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
            result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {"project": ""})
            assert result == {"result": "success"}

    def test_create_tracing_app_config_success(
        self, db_session_with_containers: Session, mock_external_service_dependencies
    ):
        """Happy path: effective config + resolvable project URL -> success."""
        with patch("services.ops_service.OpsTraceManager") as mock_otm:
            mock_otm.check_trace_config_is_effective.return_value = True
            mock_otm.get_trace_config_project_url.return_value = "http://project_url"
            mock_otm.encrypt_tracing_config.return_value = {"encrypted": "config"}
            app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
            result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {})
            assert result == {"result": "success"}

    # ── update_tracing_app_config ──────────────────────────────────────
    def test_update_tracing_app_config_invalid_provider(self, db_session_with_containers: Session):
        """Unknown provider names raise ValueError on update."""
        with pytest.raises(ValueError, match="Invalid tracing provider: invalid_provider"):
            OpsService.update_tracing_app_config(str(uuid.uuid4()), "invalid_provider", {})

    def test_update_tracing_app_config_no_config(self, db_session_with_containers: Session, mock_ops_trace_manager):
        """Updating a nonexistent config returns None."""
        result = OpsService.update_tracing_app_config(str(uuid.uuid4()), TracingProviderEnum.ARIZE, {})
        assert result is None

    def test_update_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager):
        """Config row exists but the app does not -> None."""
        fake_app_id = str(uuid.uuid4())
        self._insert_trace_config(db_session_with_containers, fake_app_id, str(TracingProviderEnum.ARIZE))
        mock_ops_trace_manager.encrypt_tracing_config.return_value = {}
        result = OpsService.update_tracing_app_config(fake_app_id, TracingProviderEnum.ARIZE, {})
        assert result is None

    def test_update_tracing_app_config_invalid_credentials(
        self, db_session_with_containers: Session, mock_external_service_dependencies
    ):
        """Updating with ineffective credentials raises ValueError."""
        with patch("services.ops_service.OpsTraceManager") as mock_otm:
            mock_otm.encrypt_tracing_config.return_value = {}
            mock_otm.decrypt_tracing_config.return_value = {}
            mock_otm.check_trace_config_is_effective.return_value = False
            app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
            self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE))
            with pytest.raises(ValueError, match="Invalid Credentials"):
                OpsService.update_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {})

    def test_update_tracing_app_config_success(
        self, db_session_with_containers: Session, mock_external_service_dependencies
    ):
        """A valid update returns the persisted config for the same app."""
        with patch("services.ops_service.OpsTraceManager") as mock_otm:
            mock_otm.encrypt_tracing_config.return_value = {"updated": "config"}
            mock_otm.decrypt_tracing_config.return_value = {}
            mock_otm.check_trace_config_is_effective.return_value = True
            app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
            self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE))
            result = OpsService.update_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {})
            assert result is not None
            assert result["app_id"] == app.id

    # ── delete_tracing_app_config ──────────────────────────────────────
    def test_delete_tracing_app_config_no_config(self, db_session_with_containers: Session):
        """Deleting a nonexistent config returns None."""
        result = OpsService.delete_tracing_app_config(str(uuid.uuid4()), "arize")
        assert result is None

    def test_delete_tracing_app_config_success(
        self, db_session_with_containers: Session, mock_external_service_dependencies
    ):
        """Deleting an existing config returns True and removes the DB row."""
        app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
        self._insert_trace_config(db_session_with_containers, app.id, "arize")
        result = OpsService.delete_tracing_app_config(app.id, "arize")
        assert result is True
        # Verify the row is gone from the database, not just the return value.
        remaining = db_session_with_containers.scalar(
            select(TraceAppConfig)
            .where(TraceAppConfig.app_id == app.id, TraceAppConfig.tracing_provider == "arize")
            .limit(1)
        )
        assert remaining is None

View File

@ -233,11 +233,10 @@ class TestWebAppAuthService:
assert result.status == AccountStatus.ACTIVE
# Verify database state
db_session_with_containers.refresh(result)
assert result.id is not None
assert result.password is not None
assert result.password_salt is not None
refreshed = db_session_with_containers.get(Account, result.id)
assert refreshed is not None
assert refreshed.password is not None
assert refreshed.password_salt is not None
def test_authenticate_account_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
@ -414,9 +413,8 @@ class TestWebAppAuthService:
assert result.status == AccountStatus.ACTIVE
# Verify database state
db_session_with_containers.refresh(result)
assert result.id is not None
refreshed = db_session_with_containers.get(Account, result.id)
assert refreshed is not None
def test_get_user_through_email_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies

View File

@ -233,15 +233,20 @@ class TestCheckEmailUnique:
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
session = MagicMock()
mock_session = MagicMock()
first = MagicMock()
first.scalar_one_or_none.return_value = None
second = MagicMock()
expected_account = MagicMock()
second.scalar_one_or_none.return_value = expected_account
session.execute.side_effect = [first, second]
mock_session.execute.side_effect = [first, second]
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=session)
mock_factory = MagicMock()
mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False)
with patch("services.account_service.session_factory", mock_factory):
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com")
assert result is expected_account
assert session.execute.call_count == 2
assert mock_session.execute.call_count == 2

View File

@ -39,7 +39,7 @@ class TestAppGenerateHandler:
"root_node_id": None,
}
arguments = handler._extract_arguments(AppGenerateService.generate, (), kwargs)
arguments = handler._extract_arguments(AppGenerateService.generate, **kwargs)
assert arguments is not None, "Failed to extract arguments from AppGenerateService.generate"
assert "app_model" in arguments, "Handler uses app_model but parameter is missing"
@ -70,14 +70,11 @@ class TestAppGenerateHandler:
handler.wrapper(
tracer,
dummy_func,
(),
{
"app_model": mock_app_model,
"user": mock_account_user,
"args": {"workflow_id": test_workflow_id},
"invoke_from": InvokeFrom.DEBUGGER,
"streaming": False,
},
app_model=mock_app_model,
user=mock_account_user,
args={"workflow_id": test_workflow_id},
invoke_from=InvokeFrom.DEBUGGER,
streaming=False,
)
spans = memory_span_exporter.get_finished_spans()

View File

@ -63,7 +63,7 @@ class TestWorkflowAppRunnerHandler:
def runner_run(self):
return "result"
handler.wrapper(tracer, runner_run, (mock_workflow_runner,), {})
handler.wrapper(tracer, runner_run, mock_workflow_runner)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1

View File

@ -28,7 +28,7 @@ class TestSpanHandlerExtractArguments:
args = (1, 2, 3)
kwargs = {}
result = handler._extract_arguments(func, args, kwargs)
result = handler._extract_arguments(func, *args, **kwargs)
assert result is not None
assert result["a"] == 1
@ -44,7 +44,7 @@ class TestSpanHandlerExtractArguments:
args = ()
kwargs = {"a": 1, "b": 2, "c": 3}
result = handler._extract_arguments(func, args, kwargs)
result = handler._extract_arguments(func, *args, **kwargs)
assert result is not None
assert result["a"] == 1
@ -60,7 +60,7 @@ class TestSpanHandlerExtractArguments:
args = (1,)
kwargs = {"b": 2, "c": 3}
result = handler._extract_arguments(func, args, kwargs)
result = handler._extract_arguments(func, *args, **kwargs)
assert result is not None
assert result["a"] == 1
@ -76,7 +76,7 @@ class TestSpanHandlerExtractArguments:
args = (1,)
kwargs = {}
result = handler._extract_arguments(func, args, kwargs)
result = handler._extract_arguments(func, *args, **kwargs)
assert result is not None
assert result["a"] == 1
@ -94,7 +94,7 @@ class TestSpanHandlerExtractArguments:
instance = MyClass()
args = (1, 2)
kwargs = {}
result = handler._extract_arguments(instance.method, args, kwargs)
result = handler._extract_arguments(instance.method, *args, **kwargs)
assert result is not None
assert result["a"] == 1
@ -109,7 +109,7 @@ class TestSpanHandlerExtractArguments:
args = (1,)
kwargs = {}
result = handler._extract_arguments(func, args, kwargs)
result = handler._extract_arguments(func, *args, **kwargs)
assert result is None
@ -122,11 +122,11 @@ class TestSpanHandlerExtractArguments:
assert func not in handler._signature_cache
handler._extract_arguments(func, (1, 2), {})
handler._extract_arguments(func, 1, 2)
assert func in handler._signature_cache
cached_sig = handler._signature_cache[func]
handler._extract_arguments(func, (3, 4), {})
handler._extract_arguments(func, 3, 4)
assert handler._signature_cache[func] is cached_sig
@ -142,7 +142,7 @@ class TestSpanHandlerWrapper:
def test_func():
return "result"
result = handler.wrapper(tracer, test_func, (), {})
result = handler.wrapper(tracer, test_func)
assert result == "result"
spans = memory_span_exporter.get_finished_spans()
@ -159,7 +159,7 @@ class TestSpanHandlerWrapper:
def test_func():
return "result"
handler.wrapper(tracer, test_func, (), {})
handler.wrapper(tracer, test_func)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1
@ -174,7 +174,7 @@ class TestSpanHandlerWrapper:
def test_func():
return "result"
handler.wrapper(tracer, test_func, (), {})
handler.wrapper(tracer, test_func)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1
@ -190,7 +190,7 @@ class TestSpanHandlerWrapper:
raise ValueError("test error")
with pytest.raises(ValueError, match="test error"):
handler.wrapper(tracer, test_func, (), {})
handler.wrapper(tracer, test_func)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1
@ -208,7 +208,7 @@ class TestSpanHandlerWrapper:
raise ValueError("test error")
with pytest.raises(ValueError):
handler.wrapper(tracer, test_func, (), {})
handler.wrapper(tracer, test_func)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1
@ -225,7 +225,7 @@ class TestSpanHandlerWrapper:
raise ValueError("test error")
with pytest.raises(ValueError, match="test error"):
handler.wrapper(tracer, test_func, (), {})
handler.wrapper(tracer, test_func)
@patch("extensions.otel.decorators.base.dify_config.ENABLE_OTEL", True)
def test_wrapper_passes_arguments_correctly(self, tracer_provider_with_memory_exporter, memory_span_exporter):
@ -236,7 +236,7 @@ class TestSpanHandlerWrapper:
def test_func(a, b, c=10):
return a + b + c
result = handler.wrapper(tracer, test_func, (1, 2), {"c": 3})
result = handler.wrapper(tracer, test_func, 1, 2, c=3)
assert result == 6
@ -249,7 +249,7 @@ class TestSpanHandlerWrapper:
def my_function(x):
return x * 2
result = handler.wrapper(tracer, my_function, (5,), {})
result = handler.wrapper(tracer, my_function, 5)
assert result == 10
spans = memory_span_exporter.get_finished_spans()

View File

@ -0,0 +1,138 @@
import json
from libs.pyrefly_type_coverage import (
CoverageSummary,
format_comparison_markdown,
format_summary_markdown,
parse_summary,
)
def _make_report(summary: dict) -> str:
return json.dumps({"module_reports": [], "summary": summary})
# Canonical raw summary payload for parse_summary tests. The counts are
# arbitrary but internally consistent: n_typed + n_any + n_untyped == n_typable
# (400 + 50 + 550 == 1000), coverage == (n_typed + n_any) / n_typable and
# strict_coverage == n_typed / n_typable.
_SAMPLE_SUMMARY: dict = {
    "n_modules": 100,
    "n_typable": 1000,
    "n_typed": 400,
    "n_any": 50,
    "n_untyped": 550,
    "coverage": 45.0,
    "strict_coverage": 40.0,
    "n_functions": 200,
    "n_methods": 300,
    "n_function_params": 150,
    "n_method_params": 250,
    "n_classes": 80,
    "n_attrs": 40,
    "n_properties": 20,
    "n_type_ignores": 10,
}
def _make_summary(
    *,
    n_modules: int = 100,
    n_typable: int = 1000,
    n_typed: int = 400,
    n_any: int = 50,
    n_untyped: int = 550,
    coverage: float = 45.0,
    strict_coverage: float = 40.0,
) -> CoverageSummary:
    """Build a CoverageSummary, overriding individual fields via keyword args.

    Defaults mirror the core fields of ``_SAMPLE_SUMMARY``.
    """
    summary: CoverageSummary = {
        "n_modules": n_modules,
        "n_typable": n_typable,
        "n_typed": n_typed,
        "n_any": n_any,
        "n_untyped": n_untyped,
        "coverage": coverage,
        "strict_coverage": strict_coverage,
    }
    return summary
def test_parse_summary_extracts_fields() -> None:
    """Every coverage field is extracted from a well-formed report."""
    parsed = parse_summary(_make_report(_SAMPLE_SUMMARY))
    expected = {
        "n_modules": 100,
        "n_typable": 1000,
        "n_typed": 400,
        "n_any": 50,
        "n_untyped": 550,
        "coverage": 45.0,
        "strict_coverage": 40.0,
    }
    for field, value in expected.items():
        assert parsed[field] == value
def test_parse_summary_handles_empty_input() -> None:
    """Blank or whitespace-only input yields a zeroed summary."""
    for raw in ("", " "):
        assert parse_summary(raw)["n_modules"] == 0
def test_parse_summary_handles_invalid_json() -> None:
    """Malformed JSON does not raise; it yields the zeroed fallback summary."""
    result = parse_summary("not json")
    assert result["n_modules"] == 0
def test_parse_summary_handles_missing_summary_key() -> None:
    """A report with no 'summary' key yields the zeroed fallback summary."""
    report_without_summary = json.dumps({"other": 1})
    assert parse_summary(report_without_summary)["n_modules"] == 0
def test_parse_summary_handles_incomplete_summary() -> None:
    """A summary missing required fields is rejected wholesale (zeroed result)."""
    partial_report = json.dumps({"summary": {"n_modules": 5}})
    result = parse_summary(partial_report)
    assert result["n_modules"] == 0
def test_format_summary_markdown_contains_key_metrics() -> None:
    """The rendered markdown reports both coverage percentages and the module count."""
    rendered = format_summary_markdown(_make_summary())
    for fragment in ("**Type coverage**", "45.00%", "40.00%", "| Modules | 100 |"):
        assert fragment in rendered
def test_format_comparison_markdown_shows_positive_delta() -> None:
    """Improvements between base and PR render with a leading '+' delta."""
    base = _make_summary()
    improved = _make_summary(
        n_modules=101,
        n_typable=1010,
        n_typed=420,
        n_untyped=540,
        coverage=46.53,
        strict_coverage=41.58,
    )
    rendered = format_comparison_markdown(base, improved)
    for fragment in ("| Base | PR | Delta |", "+1.53%", "+1.58%", "+20"):
        assert fragment in rendered
def test_format_comparison_markdown_shows_negative_delta() -> None:
    """Regressions between base and PR render with a leading '-' delta."""
    regressed = _make_summary(
        n_typed=390,
        n_any=60,
        coverage=44.0,
        strict_coverage=39.0,
    )
    rendered = format_comparison_markdown(_make_summary(), regressed)
    assert "-1.00%" in rendered
    assert "-10" in rendered
def test_format_comparison_markdown_shows_zero_delta() -> None:
    """Comparing a summary against itself renders all-zero deltas."""
    unchanged = _make_summary()
    rendered = format_comparison_markdown(unchanged, unchanged)
    assert "0.00%" in rendered
    assert "| 0 |" in rendered

View File

@ -6,23 +6,25 @@ MODULE = "services.plugin.plugin_permission_service"
def _patched_session():
"""Patch sessionmaker(bind=db.engine).begin() to return a mock session as context manager."""
"""Patch session_factory.create_session() to return a mock session as context manager."""
session = MagicMock()
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__ = MagicMock(return_value=session)
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
patcher = patch(f"{MODULE}.sessionmaker", mock_sessionmaker)
db_patcher = patch(f"{MODULE}.db")
return patcher, db_patcher, session
session.__enter__ = MagicMock(return_value=session)
session.__exit__ = MagicMock(return_value=False)
session.begin.return_value.__enter__ = MagicMock(return_value=session)
session.begin.return_value.__exit__ = MagicMock(return_value=False)
mock_factory = MagicMock()
mock_factory.create_session.return_value = session
patcher = patch(f"{MODULE}.session_factory", mock_factory)
return patcher, session
class TestGetPermission:
def test_returns_permission_when_found(self):
p1, p2, session = _patched_session()
p1, session = _patched_session()
permission = MagicMock()
session.scalar.return_value = permission
with p1, p2:
with p1:
from services.plugin.plugin_permission_service import PluginPermissionService
result = PluginPermissionService.get_permission("t1")
@ -30,10 +32,10 @@ class TestGetPermission:
assert result is permission
def test_returns_none_when_not_found(self):
p1, p2, session = _patched_session()
p1, session = _patched_session()
session.scalar.return_value = None
with p1, p2:
with p1:
from services.plugin.plugin_permission_service import PluginPermissionService
result = PluginPermissionService.get_permission("t1")
@ -43,10 +45,10 @@ class TestGetPermission:
class TestChangePermission:
def test_creates_new_permission_when_not_exists(self):
p1, p2, session = _patched_session()
p1, session = _patched_session()
session.scalar.return_value = None
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginPermission") as perm_cls:
with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginPermission") as perm_cls:
perm_cls.return_value = MagicMock()
from services.plugin.plugin_permission_service import PluginPermissionService
@ -54,20 +56,24 @@ class TestChangePermission:
"t1", TenantPluginPermission.InstallPermission.EVERYONE, TenantPluginPermission.DebugPermission.EVERYONE
)
assert result is True
session.begin.assert_called_once()
session.add.assert_called_once()
def test_updates_existing_permission(self):
p1, p2, session = _patched_session()
p1, session = _patched_session()
existing = MagicMock()
session.scalar.return_value = existing
with p1, p2:
with p1:
from services.plugin.plugin_permission_service import PluginPermissionService
result = PluginPermissionService.change_permission(
"t1", TenantPluginPermission.InstallPermission.ADMINS, TenantPluginPermission.DebugPermission.ADMINS
)
assert result is True
session.begin.assert_called_once()
assert existing.install_permission == TenantPluginPermission.InstallPermission.ADMINS
assert existing.debug_permission == TenantPluginPermission.DebugPermission.ADMINS
session.add.assert_not_called()

View File

@ -1427,16 +1427,7 @@ class TestRegisterService:
mock_tenant.name = "Test Workspace"
mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter")
# Mock database queries - need to mock the sessionmaker query
mock_session = MagicMock()
mock_session.query.return_value.filter_by.return_value.first.return_value = None # No existing account
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
with (
patch("services.account_service.sessionmaker", mock_sessionmaker),
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
):
mock_lookup.return_value = None
@ -1475,7 +1466,7 @@ class TestRegisterService:
status=AccountStatus.PENDING,
is_setup=True,
)
mock_lookup.assert_called_once_with("newuser@example.com", session=mock_session)
mock_lookup.assert_called_once_with("newuser@example.com")
def test_invite_new_member_normalizes_new_account_email(
self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies
@ -1486,13 +1477,7 @@ class TestRegisterService:
mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter")
mixed_email = "Invitee@Example.com"
mock_session = MagicMock()
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
with (
patch("services.account_service.sessionmaker", mock_sessionmaker),
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
):
mock_lookup.return_value = None
@ -1525,7 +1510,7 @@ class TestRegisterService:
status=AccountStatus.PENDING,
is_setup=True,
)
mock_lookup.assert_called_once_with(mixed_email, session=mock_session)
mock_lookup.assert_called_once_with(mixed_email)
mock_check_permission.assert_called_once_with(mock_tenant, mock_inviter, None, "add")
mock_create_member.assert_called_once_with(mock_tenant, mock_new_account, "normal")
mock_switch_tenant.assert_called_once_with(mock_new_account, mock_tenant.id)
@ -1545,16 +1530,7 @@ class TestRegisterService:
account_id="existing-user-456", email="existing@example.com", status="pending"
)
# Mock database queries - need to mock the sessionmaker query
mock_session = MagicMock()
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_existing_account
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
with (
patch("services.account_service.sessionmaker", mock_sessionmaker),
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
):
mock_lookup.return_value = mock_existing_account
@ -1584,7 +1560,7 @@ class TestRegisterService:
mock_create_member.assert_called_once_with(mock_tenant, mock_existing_account, "normal")
mock_generate_token.assert_called_once_with(mock_tenant, mock_existing_account)
mock_task_dependencies.delay.assert_called_once()
mock_lookup.assert_called_once_with("existing@example.com", session=mock_session)
mock_lookup.assert_called_once_with("existing@example.com")
def test_invite_new_member_already_in_tenant(self, mock_db_dependencies, mock_redis_dependencies):
"""Test inviting a member who is already in the tenant."""

View File

@ -1,392 +0,0 @@
from unittest.mock import MagicMock, patch
import pytest
from core.ops.entities.config_entity import TracingProviderEnum
from models.model import App, TraceAppConfig
from services.ops_service import OpsService
class TestOpsService:
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_get_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db):
# Arrange
mock_db.session.scalar.return_value = None
# Act
result = OpsService.get_tracing_app_config("app_id", "arize")
# Assert
assert result is None
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_get_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db):
# Arrange
trace_config = MagicMock(spec=TraceAppConfig)
mock_db.session.scalar.return_value = trace_config
mock_db.session.get.return_value = None
# Act
result = OpsService.get_tracing_app_config("app_id", "arize")
# Assert
assert result is None
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_get_tracing_app_config_none_config(self, mock_ops_trace_manager, mock_db):
# Arrange
trace_config = MagicMock(spec=TraceAppConfig)
trace_config.tracing_config = None
app = MagicMock(spec=App)
app.tenant_id = "tenant_id"
mock_db.session.scalar.return_value = trace_config
mock_db.session.get.return_value = app
# Act & Assert
with pytest.raises(ValueError, match="Tracing config cannot be None."):
OpsService.get_tracing_app_config("app_id", "arize")
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
@pytest.mark.parametrize(
("provider", "default_url"),
[
("arize", "https://app.arize.com/"),
("phoenix", "https://app.phoenix.arize.com/projects/"),
("langsmith", "https://smith.langchain.com/"),
("opik", "https://www.comet.com/opik/"),
("weave", "https://wandb.ai/"),
("aliyun", "https://arms.console.aliyun.com/"),
("tencent", "https://console.cloud.tencent.com/apm"),
("mlflow", "http://localhost:5000/"),
("databricks", "https://www.databricks.com/"),
],
)
def test_get_tracing_app_config_providers_exception(self, mock_ops_trace_manager, mock_db, provider, default_url):
# Arrange
trace_config = MagicMock(spec=TraceAppConfig)
trace_config.tracing_config = {"some": "config"}
trace_config.to_dict.return_value = {"tracing_config": {"project_url": default_url}}
app = MagicMock(spec=App)
app.tenant_id = "tenant_id"
mock_db.session.scalar.return_value = trace_config
mock_db.session.get.return_value = app
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {}
mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error")
mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error")
# Act
result = OpsService.get_tracing_app_config("app_id", provider)
# Assert
assert result["tracing_config"]["project_url"] == default_url
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
@pytest.mark.parametrize(
"provider", ["arize", "phoenix", "langsmith", "opik", "weave", "aliyun", "tencent", "mlflow", "databricks"]
)
def test_get_tracing_app_config_providers_success(self, mock_ops_trace_manager, mock_db, provider):
# Arrange
trace_config = MagicMock(spec=TraceAppConfig)
trace_config.tracing_config = {"some": "config"}
trace_config.to_dict.return_value = {"tracing_config": {"project_url": "success_url"}}
app = MagicMock(spec=App)
app.tenant_id = "tenant_id"
mock_db.session.scalar.return_value = trace_config
mock_db.session.get.return_value = app
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {}
mock_ops_trace_manager.get_trace_config_project_url.return_value = "success_url"
# Act
result = OpsService.get_tracing_app_config("app_id", provider)
# Assert
assert result["tracing_config"]["project_url"] == "success_url"
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_get_tracing_app_config_langfuse_success(self, mock_ops_trace_manager, mock_db):
# Arrange
trace_config = MagicMock(spec=TraceAppConfig)
trace_config.tracing_config = {"some": "config"}
trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/project/key"}}
app = MagicMock(spec=App)
app.tenant_id = "tenant_id"
mock_db.session.scalar.return_value = trace_config
mock_db.session.get.return_value = app
mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"}
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"}
mock_ops_trace_manager.get_trace_config_project_key.return_value = "key"
# Act
result = OpsService.get_tracing_app_config("app_id", "langfuse")
# Assert
assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/project/key"
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_get_tracing_app_config_langfuse_exception(self, mock_ops_trace_manager, mock_db):
# Arrange
trace_config = MagicMock(spec=TraceAppConfig)
trace_config.tracing_config = {"some": "config"}
trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/"}}
app = MagicMock(spec=App)
app.tenant_id = "tenant_id"
mock_db.session.scalar.return_value = trace_config
mock_db.session.get.return_value = app
mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"}
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"}
mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error")
# Act
result = OpsService.get_tracing_app_config("app_id", "langfuse")
# Assert
assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/"
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_create_tracing_app_config_invalid_provider(self, mock_ops_trace_manager, mock_db):
# Act
result = OpsService.create_tracing_app_config("app_id", "invalid_provider", {})
# Assert
assert result == {"error": "Invalid tracing provider: invalid_provider"}
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_create_tracing_app_config_invalid_credentials(self, mock_ops_trace_manager, mock_db):
# Arrange
provider = TracingProviderEnum.LANGFUSE
mock_ops_trace_manager.check_trace_config_is_effective.return_value = False
# Act
result = OpsService.create_tracing_app_config("app_id", provider, {"public_key": "p", "secret_key": "s"})
# Assert
assert result == {"error": "Invalid Credentials"}
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
@pytest.mark.parametrize(
("provider", "config"),
[
(TracingProviderEnum.ARIZE, {}),
(TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}),
(TracingProviderEnum.LANGSMITH, {"api_key": "k", "project": "p"}),
(TracingProviderEnum.ALIYUN, {"license_key": "k", "endpoint": "https://aliyun.com"}),
],
)
def test_create_tracing_app_config_project_url_exception(self, mock_ops_trace_manager, mock_db, provider, config):
# Arrange
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error")
mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error")
mock_db.session.scalar.return_value = MagicMock(spec=TraceAppConfig)
# Act
result = OpsService.create_tracing_app_config("app_id", provider, config)
# Assert
assert result is None
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_create_tracing_app_config_langfuse_success(self, mock_ops_trace_manager, mock_db):
# Arrange
provider = TracingProviderEnum.LANGFUSE
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
mock_ops_trace_manager.get_trace_config_project_key.return_value = "key"
app = MagicMock(spec=App)
app.tenant_id = "tenant_id"
mock_db.session.scalar.return_value = None
mock_db.session.get.return_value = app
mock_ops_trace_manager.encrypt_tracing_config.return_value = {}
# Act
result = OpsService.create_tracing_app_config(
"app_id", provider, {"public_key": "p", "secret_key": "s", "host": "https://api.langfuse.com"}
)
# Assert
assert result == {"result": "success"}
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_create_tracing_app_config_already_exists(self, mock_ops_trace_manager, mock_db):
# Arrange
provider = TracingProviderEnum.ARIZE
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
mock_db.session.scalar.return_value = MagicMock(spec=TraceAppConfig)
# Act
result = OpsService.create_tracing_app_config("app_id", provider, {})
# Assert
assert result is None
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_create_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db):
# Arrange
provider = TracingProviderEnum.ARIZE
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
mock_db.session.scalar.return_value = None
mock_db.session.get.return_value = None
# Act
result = OpsService.create_tracing_app_config("app_id", provider, {})
# Assert
assert result is None
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_create_tracing_app_config_with_empty_other_keys(self, mock_ops_trace_manager, mock_db):
# Arrange
provider = TracingProviderEnum.ARIZE
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
app = MagicMock(spec=App)
app.tenant_id = "tenant_id"
mock_db.session.scalar.return_value = None
mock_db.session.get.return_value = app
mock_ops_trace_manager.encrypt_tracing_config.return_value = {}
# Act
# 'project' is in other_keys for Arize
# provide an empty string for the project in the tracing_config
# create_tracing_app_config will replace it with the default from the model
result = OpsService.create_tracing_app_config("app_id", provider, {"project": ""})
# Assert
assert result == {"result": "success"}
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_create_tracing_app_config_success(self, mock_ops_trace_manager, mock_db):
# Arrange
provider = TracingProviderEnum.ARIZE
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
mock_ops_trace_manager.get_trace_config_project_url.return_value = "http://project_url"
app = MagicMock(spec=App)
app.tenant_id = "tenant_id"
mock_db.session.scalar.return_value = None
mock_db.session.get.return_value = app
mock_ops_trace_manager.encrypt_tracing_config.return_value = {"encrypted": "config"}
# Act
result = OpsService.create_tracing_app_config("app_id", provider, {})
# Assert
assert result == {"result": "success"}
mock_db.session.add.assert_called()
mock_db.session.commit.assert_called()
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_update_tracing_app_config_invalid_provider(self, mock_ops_trace_manager, mock_db):
# Act & Assert
with pytest.raises(ValueError, match="Invalid tracing provider: invalid_provider"):
OpsService.update_tracing_app_config("app_id", "invalid_provider", {})
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_update_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db):
# Arrange
provider = TracingProviderEnum.ARIZE
mock_db.session.scalar.return_value = None
# Act
result = OpsService.update_tracing_app_config("app_id", provider, {})
# Assert
assert result is None
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_update_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db):
# Arrange
provider = TracingProviderEnum.ARIZE
current_config = MagicMock(spec=TraceAppConfig)
mock_db.session.scalar.return_value = current_config
mock_db.session.get.return_value = None
# Act
result = OpsService.update_tracing_app_config("app_id", provider, {})
# Assert
assert result is None
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_update_tracing_app_config_invalid_credentials(self, mock_ops_trace_manager, mock_db):
# Arrange
provider = TracingProviderEnum.ARIZE
current_config = MagicMock(spec=TraceAppConfig)
app = MagicMock(spec=App)
app.tenant_id = "tenant_id"
mock_db.session.scalar.return_value = current_config
mock_db.session.get.return_value = app
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
mock_ops_trace_manager.check_trace_config_is_effective.return_value = False
# Act & Assert
with pytest.raises(ValueError, match="Invalid Credentials"):
OpsService.update_tracing_app_config("app_id", provider, {})
@patch("services.ops_service.db")
@patch("services.ops_service.OpsTraceManager")
def test_update_tracing_app_config_success(self, mock_ops_trace_manager, mock_db):
# Arrange
provider = TracingProviderEnum.ARIZE
current_config = MagicMock(spec=TraceAppConfig)
current_config.to_dict.return_value = {"some": "data"}
app = MagicMock(spec=App)
app.tenant_id = "tenant_id"
mock_db.session.scalar.return_value = current_config
mock_db.session.get.return_value = app
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
# Act
result = OpsService.update_tracing_app_config("app_id", provider, {})
# Assert
assert result == {"some": "data"}
mock_db.session.commit.assert_called_once()
@patch("services.ops_service.db")
def test_delete_tracing_app_config_no_config(self, mock_db):
# Arrange
mock_db.session.scalar.return_value = None
# Act
result = OpsService.delete_tracing_app_config("app_id", "arize")
# Assert
assert result is None
@patch("services.ops_service.db")
def test_delete_tracing_app_config_success(self, mock_db):
# Arrange
trace_config = MagicMock(spec=TraceAppConfig)
mock_db.session.scalar.return_value = trace_config
# Act
result = OpsService.delete_tracing_app_config("app_id", "arize")
# Assert
assert result is True
mock_db.session.delete.assert_called_with(trace_config)
mock_db.session.commit.assert_called_once()

View File

@ -0,0 +1,126 @@
import type { ReactNode } from 'react'
import type { LoopVariableMap, NodeTracing } from '@/types/workflow'
import { fireEvent, render, screen } from '@testing-library/react'
import { BlockEnum } from '../../../types'
import LoopResultPanel from '../loop-result-panel'
const mockCodeEditor = vi.hoisted(() => vi.fn())
const mockTracingPanel = vi.hoisted(() => vi.fn())
vi.mock('@/app/components/workflow/nodes/_base/components/editor/code-editor', () => ({
__esModule: true,
default: (props: { title: ReactNode, value: unknown }) => {
mockCodeEditor(props)
return (
<section data-testid="code-editor">
<div>{props.title}</div>
<div>{JSON.stringify(props.value)}</div>
</section>
)
},
}))
vi.mock('@/app/components/workflow/run/tracing-panel', () => ({
__esModule: true,
default: (props: { list: NodeTracing[], className?: string }) => {
mockTracingPanel(props)
return <div data-testid="tracing-panel">{props.list.length}</div>
},
}))
const createNodeTracing = (id: string, executionMetadata?: NonNullable<NodeTracing['execution_metadata']>): NodeTracing => ({
id,
index: 0,
predecessor_node_id: '',
node_id: `node-${id}`,
node_type: BlockEnum.Code,
title: `Node ${id}`,
inputs: {},
inputs_truncated: false,
process_data: {},
process_data_truncated: false,
outputs: {},
outputs_truncated: false,
status: 'succeeded',
error: '',
elapsed_time: 0,
execution_metadata: executionMetadata,
metadata: {
iterator_length: 0,
iterator_index: 0,
loop_length: 0,
loop_index: 0,
},
created_at: 0,
created_by: {
id: 'user-1',
name: 'Tester',
email: 'tester@example.com',
},
finished_at: 0,
})
describe('LoopResultPanel', () => {
beforeEach(() => {
vi.clearAllMocks()
})
// Loop variables should be resolved by the actual run key, not the rendered row position.
describe('Loop Variable Resolution', () => {
it('should read loop variables by the actual loop index when rows are compacted', () => {
const loopVariableMap: LoopVariableMap = {
2: { item: 'alpha' },
}
render(
<LoopResultPanel
list={[[
createNodeTracing('loop-2-step-1', {
total_tokens: 0,
total_price: 0,
currency: 'USD',
loop_index: 2,
}),
]]}
onBack={vi.fn()}
loopVariableMap={loopVariableMap}
/>,
)
fireEvent.click(screen.getByText('workflow.singleRun.loop 1'))
expect(screen.getByTestId('code-editor')).toHaveTextContent('{"item":"alpha"}')
expect(mockCodeEditor).toHaveBeenCalledWith(expect.objectContaining({
value: loopVariableMap[2],
}))
})
it('should read loop variables by parallel run id when available', () => {
const loopVariableMap: LoopVariableMap = {
'parallel-1': { item: 'beta' },
}
render(
<LoopResultPanel
list={[[
createNodeTracing('parallel-step-1', {
total_tokens: 0,
total_price: 0,
currency: 'USD',
parallel_mode_run_id: 'parallel-1',
}),
]]}
onBack={vi.fn()}
loopVariableMap={loopVariableMap}
/>,
)
fireEvent.click(screen.getByText('workflow.singleRun.loop 1'))
expect(screen.getByTestId('code-editor')).toHaveTextContent('{"item":"beta"}')
expect(mockCodeEditor).toHaveBeenCalledWith(expect.objectContaining({
value: loopVariableMap['parallel-1'],
}))
})
})
})

View File

@ -19,6 +19,18 @@ import { cn } from '@/utils/classnames'
const i18nPrefix = 'singleRun'
const getLoopRunKey = (loop: NodeTracing[], fallbackIndex: number) => {
const executionMetadata = loop[0]?.execution_metadata
if (executionMetadata?.parallel_mode_run_id !== undefined)
return executionMetadata.parallel_mode_run_id
if (executionMetadata?.loop_index !== undefined)
return String(executionMetadata.loop_index)
return String(fallbackIndex)
}
type Props = {
list: NodeTracing[][]
onBack: () => void
@ -42,10 +54,8 @@ const LoopResultPanel: FC<Props> = ({
}))
}, [])
const countLoopDuration = (loop: NodeTracing[], loopDurationMap: LoopDurationMap): string => {
const loopRunIndex = loop[0]?.execution_metadata?.loop_index as number
const loopRunId = loop[0]?.execution_metadata?.parallel_mode_run_id
const loopItem = loopDurationMap[loopRunId || loopRunIndex]
const countLoopDuration = (loop: NodeTracing[], index: number, loopDurationMap: LoopDurationMap): string => {
const loopItem = loopDurationMap[getLoopRunKey(loop, index)]
const duration = loopItem
return `${(duration && duration > 0.01) ? duration.toFixed(2) : 0.01}s`
}
@ -59,13 +69,13 @@ const LoopResultPanel: FC<Props> = ({
return <RiErrorWarningLine className="h-4 w-4 text-text-destructive" />
if (isRunning)
return <RiLoader2Line className="h-3.5 w-3.5 animate-spin text-primary-600" />
return <RiLoader2Line className="text-primary-600 h-3.5 w-3.5 animate-spin" />
return (
<>
{hasDurationMap && (
<div className="system-xs-regular text-text-tertiary">
{countLoopDuration(loop, loopDurationMap)}
{countLoopDuration(loop, index, loopDurationMap)}
</div>
)}
<RiArrowRightSLine
@ -98,7 +108,7 @@ const LoopResultPanel: FC<Props> = ({
<div
className={cn(
'flex w-full cursor-pointer items-center justify-between px-3',
expandedLoops[index] ? 'pb-2 pt-3' : 'py-3',
expandedLoops[index] ? 'pt-3 pb-2' : 'py-3',
'rounded-xl text-left',
)}
onClick={() => toggleLoop(index)}
@ -107,7 +117,7 @@ const LoopResultPanel: FC<Props> = ({
<div className="flex h-4 w-4 shrink-0 items-center justify-center rounded-[5px] border-divider-subtle bg-util-colors-cyan-cyan-500">
<Loop className="h-3 w-3 text-text-primary-on-surface" />
</div>
<span className="system-sm-semibold-uppercase grow text-text-primary">
<span className="grow system-sm-semibold-uppercase text-text-primary">
{t(`${i18nPrefix}.loop`, { ns: 'workflow' })}
{' '}
{index + 1}
@ -129,14 +139,14 @@ const LoopResultPanel: FC<Props> = ({
)}
>
{
loopVariableMap?.[index] && (
loopVariableMap?.[getLoopRunKey(loop, index)] && (
<div className="p-2 pb-0">
<CodeEditor
readOnly
title={<div>{t('nodes.loop.loopVariables', { ns: 'workflow' }).toLocaleUpperCase()}</div>}
language={CodeLanguage.json}
height={112}
value={loopVariableMap[index]}
value={loopVariableMap[getLoopRunKey(loop, index)]}
isJSONStringifyBeauty
/>
</div>

View File

@ -1,6 +1,6 @@
import type { NodeTracing } from '@/types/workflow'
import { noop } from 'es-toolkit/function'
import format from '..'
import format, { addChildrenToIterationNode } from '..'
import graphToLogStruct from '../../graph-to-log-struct'
describe('iteration', () => {
@ -9,15 +9,48 @@ describe('iteration', () => {
it('result should have no nodes in iteration node', () => {
expect(result.find(item => !!item.execution_metadata?.iteration_id)).toBeUndefined()
})
// test('iteration should put nodes in details', () => {
// expect(result).toEqual([
// startNode,
// {
// ...iterationNode,
// details: [
// [iterations[0], iterations[1]],
// ],
// },
// ])
// })
it('should place the first child of a new iteration at a new record when its index is missing', () => {
const parent = { node_id: 'iter1', node_type: 'iteration', execution_metadata: {} } as unknown as NodeTracing
const child0 = { node_id: 'code', execution_metadata: { iteration_id: 'iter1', iteration_index: 0 } } as unknown as NodeTracing
const streaming = { node_id: 'code', execution_metadata: { iteration_id: 'iter1' } } as unknown as NodeTracing
const result = addChildrenToIterationNode(parent, [child0, streaming])
expect(result.details![0]).toEqual([child0])
expect(result.details![1]).toEqual([streaming])
})
it('should keep missing iteration_index items in the current record when the node has not restarted', () => {
const parent = {
node_id: 'iter1',
node_type: 'iteration',
execution_metadata: {
iteration_duration_map: { 0: 1.2, 1: 0.4 },
},
} as unknown as NodeTracing
const child0 = { node_id: 'code', execution_metadata: { iteration_id: 'iter1', iteration_index: 0 } } as unknown as NodeTracing
const child1 = { node_id: 'code', execution_metadata: { iteration_id: 'iter1', iteration_index: 1 } } as unknown as NodeTracing
const streaming = { node_id: 'tool', execution_metadata: { iteration_id: 'iter1' } } as unknown as NodeTracing
const result = addChildrenToIterationNode(parent, [child0, child1, streaming])
expect(result.details![0]).toEqual([child0])
expect(result.details![1]).toEqual([child1, streaming])
})
it('should not jump to the latest iteration when an earlier item is missing iteration_index', () => {
const parent = {
node_id: 'iter1',
node_type: 'iteration',
execution_metadata: {
iteration_duration_map: { 0: 1.2, 1: 0.4 },
},
} as unknown as NodeTracing
const code0 = { node_id: 'code', execution_metadata: { iteration_id: 'iter1', iteration_index: 0 } } as unknown as NodeTracing
const tool = { node_id: 'tool', execution_metadata: { iteration_id: 'iter1' } } as unknown as NodeTracing
const code1 = { node_id: 'code', execution_metadata: { iteration_id: 'iter1', iteration_index: 1 } } as unknown as NodeTracing
const result = addChildrenToIterationNode(parent, [code0, tool, code1])
expect(result.details![0]).toEqual([code0, tool])
expect(result.details![1]).toEqual([code1])
})
})

View File

@ -4,15 +4,31 @@ import formatParallelNode from '../parallel'
export function addChildrenToIterationNode(iterationNode: NodeTracing, childrenNodes: NodeTracing[]): NodeTracing {
const details: NodeTracing[][] = []
childrenNodes.forEach((item, index) => {
let lastResolvedIndex = -1
childrenNodes.forEach((item) => {
if (!item.execution_metadata)
return
const { iteration_index = 0 } = item.execution_metadata
const runIndex: number = iteration_index !== undefined ? iteration_index : index
const { iteration_index } = item.execution_metadata
let runIndex: number
if (iteration_index !== undefined) {
runIndex = iteration_index
}
else if (lastResolvedIndex >= 0) {
const currentGroup = details[lastResolvedIndex] || []
const seenSameNodeInCurrentGroup = currentGroup.some(node => node.node_id === item.node_id)
runIndex = seenSameNodeInCurrentGroup ? lastResolvedIndex + 1 : lastResolvedIndex
}
else {
runIndex = 0
}
if (!details[runIndex])
details[runIndex] = []
details[runIndex].push(item)
lastResolvedIndex = runIndex
})
return {
...iterationNode,

View File

@ -1,6 +1,6 @@
import type { NodeTracing } from '@/types/workflow'
import { noop } from 'es-toolkit/function'
import format from '..'
import format, { addChildrenToLoopNode } from '..'
import graphToLogStruct from '../../graph-to-log-struct'
describe('loop', () => {
@ -21,4 +21,48 @@ describe('loop', () => {
},
])
})
it('should place the first child of a new loop run at a new record when its index is missing', () => {
const parent = { node_id: 'loop1', node_type: 'loop', execution_metadata: {} } as unknown as NodeTracing
const child0 = { node_id: 'code', execution_metadata: { loop_id: 'loop1', loop_index: 0 } } as unknown as NodeTracing
const streaming = { node_id: 'code', execution_metadata: { loop_id: 'loop1' } } as unknown as NodeTracing
const result = addChildrenToLoopNode(parent, [child0, streaming])
expect(result.details![0]).toEqual([child0])
expect(result.details![1]).toEqual([streaming])
})
it('should keep missing loop_index items in the current record when the node has not restarted', () => {
const parent = {
node_id: 'loop1',
node_type: 'loop',
execution_metadata: {
loop_duration_map: { 0: 1.2, 1: 0.4 },
},
} as unknown as NodeTracing
const child0 = { node_id: 'code', execution_metadata: { loop_id: 'loop1', loop_index: 0 } } as unknown as NodeTracing
const child1 = { node_id: 'code', execution_metadata: { loop_id: 'loop1', loop_index: 1 } } as unknown as NodeTracing
const streaming = { node_id: 'tool', execution_metadata: { loop_id: 'loop1' } } as unknown as NodeTracing
const result = addChildrenToLoopNode(parent, [child0, child1, streaming])
expect(result.details![0]).toEqual([child0])
expect(result.details![1]).toEqual([child1, streaming])
})
it('should not jump to the latest loop when an earlier item is missing loop_index', () => {
const parent = {
node_id: 'loop1',
node_type: 'loop',
execution_metadata: {
loop_duration_map: { 0: 1.2, 1: 0.4 },
},
} as unknown as NodeTracing
const code0 = { node_id: 'code', execution_metadata: { loop_id: 'loop1', loop_index: 0 } } as unknown as NodeTracing
const tool = { node_id: 'tool', execution_metadata: { loop_id: 'loop1' } } as unknown as NodeTracing
const code1 = { node_id: 'code', execution_metadata: { loop_id: 'loop1', loop_index: 1 } } as unknown as NodeTracing
const result = addChildrenToLoopNode(parent, [code0, tool, code1])
expect(result.details![0]).toEqual([code0, tool])
expect(result.details![1]).toEqual([code1])
})
})

View File

@ -3,20 +3,49 @@ import { BlockEnum } from '@/app/components/workflow/types'
import formatParallelNode from '../parallel'
export function addChildrenToLoopNode(loopNode: NodeTracing, childrenNodes: NodeTracing[]): NodeTracing {
const details: NodeTracing[][] = []
const detailsByKey = new Map<string, NodeTracing[]>()
let lastResolvedIndex = -1
const order: string[] = []
const ensureGroup = (key: string) => {
const group = detailsByKey.get(key)
if (group)
return group
const newGroup: NodeTracing[] = []
detailsByKey.set(key, newGroup)
order.push(key)
return newGroup
}
childrenNodes.forEach((item) => {
if (!item.execution_metadata)
return
const { parallel_mode_run_id, loop_index = 0 } = item.execution_metadata
const runIndex: number = (parallel_mode_run_id || loop_index) as number
if (!details[runIndex])
details[runIndex] = []
const { parallel_mode_run_id, loop_index } = item.execution_metadata
let runIndex: number | string
details[runIndex].push(item)
if (parallel_mode_run_id !== undefined) {
runIndex = parallel_mode_run_id
}
else if (loop_index !== undefined) {
runIndex = loop_index
}
else if (lastResolvedIndex >= 0) {
const currentGroup = detailsByKey.get(String(lastResolvedIndex)) || []
const seenSameNodeInCurrentGroup = currentGroup.some(node => node.node_id === item.node_id)
runIndex = seenSameNodeInCurrentGroup ? lastResolvedIndex + 1 : lastResolvedIndex
}
else {
runIndex = 0
}
ensureGroup(String(runIndex)).push(item)
if (typeof runIndex === 'number')
lastResolvedIndex = runIndex
})
return {
...loopNode,
details,
details: order.map(key => detailsByKey.get(key) || []),
}
}

View File

@ -11088,11 +11088,6 @@
"count": 1
}
},
"app/components/workflow/run/loop-log/loop-result-panel.tsx": {
"tailwindcss/enforce-consistent-class-order": {
"count": 3
}
},
"app/components/workflow/run/loop-result-panel.tsx": {
"tailwindcss/enforce-consistent-class-order": {
"count": 4