mirror of
https://github.com/langgenius/dify.git
synced 2026-05-11 23:18:39 +08:00
Merge branch 'main' into jzh
This commit is contained in:
commit
d1ca468c1e
118
.github/workflows/pyrefly-type-coverage-comment.yml
vendored
Normal file
118
.github/workflows/pyrefly-type-coverage-comment.yml
vendored
Normal file
@ -0,0 +1,118 @@
|
||||
name: Comment with Pyrefly Type Coverage
|
||||
|
||||
on:
|
||||
workflow_run:
|
||||
workflows:
|
||||
- Pyrefly Type Coverage
|
||||
types:
|
||||
- completed
|
||||
|
||||
permissions: {}
|
||||
|
||||
jobs:
|
||||
comment:
|
||||
name: Comment PR with type coverage
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
issues: write
|
||||
pull-requests: write
|
||||
if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }}
|
||||
steps:
|
||||
- name: Checkout default branch (trusted code)
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Setup Python & UV
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Download type coverage artifact
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
const fs = require('fs');
|
||||
const artifacts = await github.rest.actions.listWorkflowRunArtifacts({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
run_id: ${{ github.event.workflow_run.id }},
|
||||
});
|
||||
const match = artifacts.data.artifacts.find((artifact) =>
|
||||
artifact.name === 'pyrefly_type_coverage'
|
||||
);
|
||||
if (!match) {
|
||||
throw new Error('pyrefly_type_coverage artifact not found');
|
||||
}
|
||||
const download = await github.rest.actions.downloadArtifact({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
artifact_id: match.id,
|
||||
archive_format: 'zip',
|
||||
});
|
||||
fs.writeFileSync('pyrefly_type_coverage.zip', Buffer.from(download.data));
|
||||
|
||||
- name: Unzip artifact
|
||||
run: unzip -o pyrefly_type_coverage.zip
|
||||
|
||||
- name: Render coverage markdown from structured data
|
||||
id: render
|
||||
run: |
|
||||
comment_body="$(uv run --directory api python api/libs/pyrefly_type_coverage.py \
|
||||
--base base_report.json \
|
||||
< pr_report.json)"
|
||||
|
||||
{
|
||||
echo "### Pyrefly Type Coverage"
|
||||
echo ""
|
||||
echo "$comment_body"
|
||||
} > /tmp/type_coverage_comment.md
|
||||
|
||||
- name: Post comment
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
const fs = require('fs');
|
||||
const body = fs.readFileSync('/tmp/type_coverage_comment.md', { encoding: 'utf8' });
|
||||
let prNumber = null;
|
||||
try {
|
||||
prNumber = parseInt(fs.readFileSync('pr_number.txt', { encoding: 'utf8' }), 10);
|
||||
} catch (err) {
|
||||
const prs = context.payload.workflow_run.pull_requests || [];
|
||||
if (prs.length > 0 && prs[0].number) {
|
||||
prNumber = prs[0].number;
|
||||
}
|
||||
}
|
||||
if (!prNumber) {
|
||||
throw new Error('PR number not found in artifact or workflow_run payload');
|
||||
}
|
||||
|
||||
// Update existing comment if one exists, otherwise create new
|
||||
const { data: comments } = await github.rest.issues.listComments({
|
||||
issue_number: prNumber,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
});
|
||||
const marker = '### Pyrefly Type Coverage';
|
||||
const existing = comments.find(c => c.body.startsWith(marker));
|
||||
|
||||
if (existing) {
|
||||
await github.rest.issues.updateComment({
|
||||
comment_id: existing.id,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body,
|
||||
});
|
||||
} else {
|
||||
await github.rest.issues.createComment({
|
||||
issue_number: prNumber,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body,
|
||||
});
|
||||
}
|
||||
120
.github/workflows/pyrefly-type-coverage.yml
vendored
Normal file
120
.github/workflows/pyrefly-type-coverage.yml
vendored
Normal file
@ -0,0 +1,120 @@
|
||||
name: Pyrefly Type Coverage
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- 'api/**/*.py'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
pyrefly-type-coverage:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Checkout PR branch
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python & UV
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Run pyrefly report on PR branch
|
||||
run: |
|
||||
uv run --directory api --dev pyrefly report 2>/dev/null > /tmp/pyrefly_report_pr.tmp && \
|
||||
mv /tmp/pyrefly_report_pr.tmp /tmp/pyrefly_report_pr.json || \
|
||||
echo '{}' > /tmp/pyrefly_report_pr.json
|
||||
|
||||
- name: Save helper script from base branch
|
||||
run: |
|
||||
git show ${{ github.event.pull_request.base.sha }}:api/libs/pyrefly_type_coverage.py > /tmp/pyrefly_type_coverage.py 2>/dev/null \
|
||||
|| cp api/libs/pyrefly_type_coverage.py /tmp/pyrefly_type_coverage.py
|
||||
|
||||
- name: Checkout base branch
|
||||
run: git checkout ${{ github.base_ref }}
|
||||
|
||||
- name: Run pyrefly report on base branch
|
||||
run: |
|
||||
uv run --directory api --dev pyrefly report 2>/dev/null > /tmp/pyrefly_report_base.tmp && \
|
||||
mv /tmp/pyrefly_report_base.tmp /tmp/pyrefly_report_base.json || \
|
||||
echo '{}' > /tmp/pyrefly_report_base.json
|
||||
|
||||
- name: Generate coverage comparison
|
||||
id: coverage
|
||||
run: |
|
||||
comment_body="$(uv run --directory api python /tmp/pyrefly_type_coverage.py \
|
||||
--base /tmp/pyrefly_report_base.json \
|
||||
< /tmp/pyrefly_report_pr.json)"
|
||||
|
||||
{
|
||||
echo "### Pyrefly Type Coverage"
|
||||
echo ""
|
||||
echo "$comment_body"
|
||||
} | tee -a "$GITHUB_STEP_SUMMARY" > /tmp/type_coverage_comment.md
|
||||
|
||||
# Save structured data for the fork-PR comment workflow
|
||||
cp /tmp/pyrefly_report_pr.json pr_report.json
|
||||
cp /tmp/pyrefly_report_base.json base_report.json
|
||||
|
||||
- name: Save PR number
|
||||
run: |
|
||||
echo ${{ github.event.pull_request.number }} > pr_number.txt
|
||||
|
||||
- name: Upload type coverage artifact
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
with:
|
||||
name: pyrefly_type_coverage
|
||||
path: |
|
||||
pr_report.json
|
||||
base_report.json
|
||||
pr_number.txt
|
||||
|
||||
- name: Comment PR with type coverage
|
||||
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
const fs = require('fs');
|
||||
const marker = '### Pyrefly Type Coverage';
|
||||
let body;
|
||||
try {
|
||||
body = fs.readFileSync('/tmp/type_coverage_comment.md', { encoding: 'utf8' });
|
||||
} catch {
|
||||
body = `${marker}\n\n_Coverage report unavailable._`;
|
||||
}
|
||||
const prNumber = context.payload.pull_request.number;
|
||||
|
||||
// Update existing comment if one exists, otherwise create new
|
||||
const { data: comments } = await github.rest.issues.listComments({
|
||||
issue_number: prNumber,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
});
|
||||
const existing = comments.find(c => c.body.startsWith(marker));
|
||||
|
||||
if (existing) {
|
||||
await github.rest.issues.updateComment({
|
||||
comment_id: existing.id,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body,
|
||||
});
|
||||
} else {
|
||||
await github.rest.issues.createComment({
|
||||
issue_number: prNumber,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body,
|
||||
});
|
||||
}
|
||||
@ -2,7 +2,6 @@ import base64
|
||||
import secrets
|
||||
|
||||
import click
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from constants.languages import languages
|
||||
from extensions.ext_database import db
|
||||
@ -25,30 +24,31 @@ def reset_password(email, new_password, password_confirm):
|
||||
return
|
||||
normalized_email = email.strip().lower()
|
||||
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email.strip())
|
||||
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
|
||||
try:
|
||||
valid_password(new_password)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
|
||||
return
|
||||
try:
|
||||
valid_password(new_password)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
|
||||
return
|
||||
|
||||
# generate password salt
|
||||
salt = secrets.token_bytes(16)
|
||||
base64_salt = base64.b64encode(salt).decode()
|
||||
# generate password salt
|
||||
salt = secrets.token_bytes(16)
|
||||
base64_salt = base64.b64encode(salt).decode()
|
||||
|
||||
# encrypt password with salt
|
||||
password_hashed = hash_password(new_password, salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
click.echo(click.style("Password reset successfully.", fg="green"))
|
||||
# encrypt password with salt
|
||||
password_hashed = hash_password(new_password, salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account = db.session.merge(account)
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
db.session.commit()
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
click.echo(click.style("Password reset successfully.", fg="green"))
|
||||
|
||||
|
||||
@click.command("reset-email", help="Reset the account email.")
|
||||
@ -65,21 +65,22 @@ def reset_email(email, new_email, email_confirm):
|
||||
return
|
||||
normalized_new_email = new_email.strip().lower()
|
||||
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email.strip())
|
||||
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
|
||||
try:
|
||||
email_validate(normalized_new_email)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
|
||||
return
|
||||
try:
|
||||
email_validate(normalized_new_email)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
|
||||
return
|
||||
|
||||
account.email = normalized_new_email
|
||||
click.echo(click.style("Email updated successfully.", fg="green"))
|
||||
account = db.session.merge(account)
|
||||
account.email = normalized_new_email
|
||||
db.session.commit()
|
||||
click.echo(click.style("Email updated successfully.", fg="green"))
|
||||
|
||||
|
||||
@click.command("create-tenant", help="Create account and tenant.")
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
@ -14,7 +13,6 @@ from controllers.console.auth.error import (
|
||||
InvalidTokenError,
|
||||
PasswordMismatchError,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.password import valid_password
|
||||
from models import Account
|
||||
@ -73,8 +71,7 @@ class EmailRegisterSendEmailApi(Resource):
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email)
|
||||
token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
@ -145,17 +142,16 @@ class EmailRegisterResetApi(Resource):
|
||||
email = register_data.get("email", "")
|
||||
normalized_email = email.lower()
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||
|
||||
if account:
|
||||
raise EmailAlreadyInUseError()
|
||||
else:
|
||||
account = self._create_new_account(normalized_email, args.password_confirm)
|
||||
if not account:
|
||||
raise AccountNotFoundError()
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
if account:
|
||||
raise EmailAlreadyInUseError()
|
||||
else:
|
||||
account = self._create_new_account(normalized_email, args.password_confirm)
|
||||
if not account:
|
||||
raise AccountNotFoundError()
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
|
||||
|
||||
@ -4,7 +4,6 @@ import secrets
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
@ -85,8 +84,7 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
else:
|
||||
language = "en-US"
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email)
|
||||
|
||||
token = AccountService.send_reset_password_email(
|
||||
account=account,
|
||||
@ -184,17 +182,18 @@ class ForgotPasswordResetApi(Resource):
|
||||
password_hashed = hash_password(args.new_password, salt)
|
||||
|
||||
email = reset_data.get("email", "")
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||
|
||||
if account:
|
||||
self._update_existing_account(account, password_hashed, salt, session)
|
||||
else:
|
||||
raise AccountNotFound()
|
||||
if account:
|
||||
account = db.session.merge(account)
|
||||
self._update_existing_account(account, password_hashed, salt)
|
||||
db.session.commit()
|
||||
else:
|
||||
raise AccountNotFound()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
def _update_existing_account(self, account, password_hashed, salt, session):
|
||||
def _update_existing_account(self, account, password_hashed, salt):
|
||||
# Update existing account credentials
|
||||
account.password = base64.b64encode(password_hashed).decode()
|
||||
account.password_salt = base64.b64encode(salt).decode()
|
||||
|
||||
@ -4,7 +4,6 @@ import urllib.parse
|
||||
import httpx
|
||||
from flask import current_app, redirect, request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
@ -180,8 +179,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
|
||||
account: Account | None = Account.get_by_openid(provider, user_info.id)
|
||||
|
||||
if not account:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(user_info.email)
|
||||
|
||||
return account
|
||||
|
||||
|
||||
@ -8,7 +8,6 @@ from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
@ -562,8 +561,7 @@ class ChangeEmailSendEmailApi(Resource):
|
||||
|
||||
user_email = current_user.email
|
||||
else:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email)
|
||||
if account is None:
|
||||
raise AccountNotFound()
|
||||
email_for_sending = account.email
|
||||
|
||||
@ -3,7 +3,6 @@ import secrets
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.auth.error import (
|
||||
@ -62,9 +61,7 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
else:
|
||||
language = "en-US"
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session)
|
||||
token = None
|
||||
account = AccountService.get_account_by_email_with_case_fallback(request_email)
|
||||
if account is None:
|
||||
raise AuthenticationFailedError()
|
||||
else:
|
||||
@ -161,13 +158,14 @@ class ForgotPasswordResetApi(Resource):
|
||||
|
||||
email = reset_data.get("email", "")
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||
|
||||
if account:
|
||||
self._update_existing_account(account, password_hashed, salt)
|
||||
else:
|
||||
raise AuthenticationFailedError()
|
||||
if account:
|
||||
account = db.session.merge(account)
|
||||
self._update_existing_account(account, password_hashed, salt)
|
||||
db.session.commit()
|
||||
else:
|
||||
raise AuthenticationFailedError()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@ -55,7 +55,7 @@ class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResul
|
||||
|
||||
request: ReceiveRequestT
|
||||
_session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]"
|
||||
_on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any]
|
||||
_on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], object]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -63,7 +63,7 @@ class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResul
|
||||
request_meta: RequestParams.Meta | None,
|
||||
request: ReceiveRequestT,
|
||||
session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]",
|
||||
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
|
||||
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], object],
|
||||
):
|
||||
self.request_id = request_id
|
||||
self.request_meta = request_meta
|
||||
|
||||
@ -31,7 +31,6 @@ ProgressToken = str | int
|
||||
Cursor = str
|
||||
Role = Literal["user", "assistant"]
|
||||
RequestId = Annotated[int | str, Field(union_mode="left_to_right")]
|
||||
type AnyFunction = Callable[..., Any]
|
||||
|
||||
|
||||
class RequestParams(BaseModel):
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Union, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@ -207,7 +207,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_user(cls, user_id: str) -> Union[EndUser, Account]:
|
||||
def _get_user(cls, user_id: str) -> EndUser | Account:
|
||||
"""
|
||||
get the user by user id
|
||||
"""
|
||||
|
||||
@ -6,7 +6,6 @@ import os
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from mimetypes import guess_extension, guess_type
|
||||
from typing import Union
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
@ -158,7 +157,7 @@ class ToolFileManager:
|
||||
|
||||
return tool_file
|
||||
|
||||
def get_file_binary(self, id: str) -> Union[tuple[bytes, str], None]:
|
||||
def get_file_binary(self, id: str) -> tuple[bytes, str] | None:
|
||||
"""
|
||||
get file binary
|
||||
|
||||
@ -176,7 +175,7 @@ class ToolFileManager:
|
||||
|
||||
return blob, tool_file.mimetype
|
||||
|
||||
def get_file_binary_by_message_file_id(self, id: str) -> Union[tuple[bytes, str], None]:
|
||||
def get_file_binary_by_message_file_id(self, id: str) -> tuple[bytes, str] | None:
|
||||
"""
|
||||
get file binary
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ import time
|
||||
from collections.abc import Generator, Mapping
|
||||
from os import listdir, path
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, Protocol, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from graphon.runtime import VariablePool
|
||||
@ -100,7 +100,7 @@ class ToolManager:
|
||||
_builtin_provider_lock = Lock()
|
||||
_hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
|
||||
_builtin_providers_loaded = False
|
||||
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
|
||||
_builtin_tools_labels: dict[str, I18nObject | None] = {}
|
||||
|
||||
@classmethod
|
||||
def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
|
||||
@ -190,7 +190,7 @@ class ToolManager:
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
|
||||
credential_id: str | None = None,
|
||||
) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
|
||||
) -> BuiltinTool | PluginTool | ApiTool | WorkflowTool | MCPTool:
|
||||
"""
|
||||
get the tool runtime
|
||||
|
||||
@ -398,7 +398,7 @@ class ToolManager:
|
||||
agent_tool: AgentToolEntity,
|
||||
user_id: str | None = None,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
variable_pool: Optional["VariablePool"] = None,
|
||||
variable_pool: "VariablePool | None" = None,
|
||||
) -> Tool:
|
||||
"""
|
||||
get the agent tool runtime
|
||||
@ -442,7 +442,7 @@ class ToolManager:
|
||||
workflow_tool: WorkflowToolRuntimeSpec,
|
||||
user_id: str | None = None,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
variable_pool: Optional["VariablePool"] = None,
|
||||
variable_pool: "VariablePool | None" = None,
|
||||
) -> Tool:
|
||||
"""
|
||||
get the workflow tool runtime
|
||||
@ -634,7 +634,7 @@ class ToolManager:
|
||||
cls._builtin_providers_loaded = False
|
||||
|
||||
@classmethod
|
||||
def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
|
||||
def get_tool_label(cls, tool_name: str) -> I18nObject | None:
|
||||
"""
|
||||
get the tool label
|
||||
|
||||
@ -1052,7 +1052,7 @@ class ToolManager:
|
||||
def _convert_tool_parameters_type(
|
||||
cls,
|
||||
parameters: list[ToolParameter],
|
||||
variable_pool: Optional["VariablePool"],
|
||||
variable_pool: "VariablePool | None",
|
||||
tool_configurations: Mapping[str, Any],
|
||||
typ: Literal["agent", "workflow", "tool"] = "workflow",
|
||||
) -> dict[str, Any]:
|
||||
|
||||
@ -37,12 +37,7 @@ def trace_span[**P, R](handler_class: type[SpanHandler] | None = None) -> Callab
|
||||
handler = _get_handler_instance(handler_class or SpanHandler)
|
||||
tracer = get_tracer(__name__)
|
||||
|
||||
return handler.wrapper(
|
||||
tracer=tracer,
|
||||
wrapped=func,
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
return handler.wrapper(tracer, func, *args, **kwargs)
|
||||
|
||||
return cast(Callable[P, R], wrapper)
|
||||
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
import inspect
|
||||
from collections.abc import Callable, Mapping
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from opentelemetry.trace import SpanKind, Status, StatusCode
|
||||
from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer
|
||||
|
||||
|
||||
class SpanHandler:
|
||||
@ -16,9 +16,9 @@ class SpanHandler:
|
||||
exceptions. Handlers can override the wrapper method to customize behavior.
|
||||
"""
|
||||
|
||||
_signature_cache: dict[Callable[..., Any], inspect.Signature] = {}
|
||||
_signature_cache: dict[Callable[..., object], inspect.Signature] = {}
|
||||
|
||||
def _build_span_name(self, wrapped: Callable[..., Any]) -> str:
|
||||
def _build_span_name[**P, R](self, wrapped: Callable[P, R]) -> str:
|
||||
"""
|
||||
Build the span name from the wrapped function.
|
||||
|
||||
@ -29,11 +29,11 @@ class SpanHandler:
|
||||
"""
|
||||
return f"{wrapped.__module__}.{wrapped.__qualname__}"
|
||||
|
||||
def _extract_arguments[T](
|
||||
def _extract_arguments[**P, R](
|
||||
self,
|
||||
wrapped: Callable[..., T],
|
||||
args: tuple[object, ...],
|
||||
kwargs: Mapping[str, object],
|
||||
wrapped: Callable[P, R],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Extract function arguments using inspect.signature.
|
||||
@ -59,13 +59,13 @@ class SpanHandler:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def wrapper[T](
|
||||
def wrapper[**P, R](
|
||||
self,
|
||||
tracer: Any,
|
||||
wrapped: Callable[..., T],
|
||||
args: tuple[object, ...],
|
||||
kwargs: Mapping[str, object],
|
||||
) -> T:
|
||||
tracer: Tracer,
|
||||
wrapped: Callable[P, R],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> R:
|
||||
"""
|
||||
Fully control the wrapper behavior.
|
||||
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Callable, Mapping
|
||||
from typing import Any
|
||||
from collections.abc import Callable
|
||||
|
||||
from opentelemetry.trace import SpanKind, Status, StatusCode
|
||||
from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer
|
||||
from opentelemetry.util.types import AttributeValue
|
||||
|
||||
from extensions.otel.decorators.handler import SpanHandler
|
||||
@ -15,15 +14,15 @@ logger = logging.getLogger(__name__)
|
||||
class AppGenerateHandler(SpanHandler):
|
||||
"""Span handler for ``AppGenerateService.generate``."""
|
||||
|
||||
def wrapper[T](
|
||||
def wrapper[**P, R](
|
||||
self,
|
||||
tracer: Any,
|
||||
wrapped: Callable[..., T],
|
||||
args: tuple[object, ...],
|
||||
kwargs: Mapping[str, object],
|
||||
) -> T:
|
||||
tracer: Tracer,
|
||||
wrapped: Callable[P, R],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> R:
|
||||
try:
|
||||
arguments = self._extract_arguments(wrapped, args, kwargs)
|
||||
arguments = self._extract_arguments(wrapped, *args, **kwargs)
|
||||
if not arguments:
|
||||
return wrapped(*args, **kwargs)
|
||||
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Callable, Mapping
|
||||
from typing import Any
|
||||
from collections.abc import Callable
|
||||
|
||||
from opentelemetry.trace import SpanKind, Status, StatusCode
|
||||
from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer
|
||||
from opentelemetry.util.types import AttributeValue
|
||||
|
||||
from extensions.otel.decorators.handler import SpanHandler
|
||||
@ -14,15 +13,15 @@ logger = logging.getLogger(__name__)
|
||||
class WorkflowAppRunnerHandler(SpanHandler):
|
||||
"""Span handler for ``WorkflowAppRunner.run``."""
|
||||
|
||||
def wrapper(
|
||||
def wrapper[**P, R](
|
||||
self,
|
||||
tracer: Any,
|
||||
wrapped: Callable[..., Any],
|
||||
args: tuple[Any, ...],
|
||||
kwargs: Mapping[str, Any],
|
||||
) -> Any:
|
||||
tracer: Tracer,
|
||||
wrapped: Callable[P, R],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> R:
|
||||
try:
|
||||
arguments = self._extract_arguments(wrapped, args, kwargs)
|
||||
arguments = self._extract_arguments(wrapped, *args, **kwargs)
|
||||
if not arguments:
|
||||
return wrapped(*args, **kwargs)
|
||||
|
||||
|
||||
145
api/libs/pyrefly_type_coverage.py
Normal file
145
api/libs/pyrefly_type_coverage.py
Normal file
@ -0,0 +1,145 @@
|
||||
"""Helpers for generating type-coverage summaries from pyrefly report output."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import TypedDict
|
||||
|
||||
|
||||
class CoverageSummary(TypedDict):
|
||||
n_modules: int
|
||||
n_typable: int
|
||||
n_typed: int
|
||||
n_any: int
|
||||
n_untyped: int
|
||||
coverage: float
|
||||
strict_coverage: float
|
||||
|
||||
|
||||
_REQUIRED_KEYS = frozenset(CoverageSummary.__annotations__)
|
||||
|
||||
_EMPTY_SUMMARY: CoverageSummary = {
|
||||
"n_modules": 0,
|
||||
"n_typable": 0,
|
||||
"n_typed": 0,
|
||||
"n_any": 0,
|
||||
"n_untyped": 0,
|
||||
"coverage": 0.0,
|
||||
"strict_coverage": 0.0,
|
||||
}
|
||||
|
||||
|
||||
def parse_summary(report_json: str) -> CoverageSummary:
|
||||
"""Extract the summary section from ``pyrefly report`` JSON output.
|
||||
|
||||
Returns an empty summary when *report_json* is empty or malformed so that
|
||||
the CI workflow can degrade gracefully instead of crashing.
|
||||
"""
|
||||
if not report_json or not report_json.strip():
|
||||
return _EMPTY_SUMMARY.copy()
|
||||
|
||||
try:
|
||||
data = json.loads(report_json)
|
||||
except json.JSONDecodeError:
|
||||
return _EMPTY_SUMMARY.copy()
|
||||
|
||||
summary = data.get("summary")
|
||||
if not isinstance(summary, dict) or not _REQUIRED_KEYS.issubset(summary):
|
||||
return _EMPTY_SUMMARY.copy()
|
||||
|
||||
return {
|
||||
"n_modules": summary["n_modules"],
|
||||
"n_typable": summary["n_typable"],
|
||||
"n_typed": summary["n_typed"],
|
||||
"n_any": summary["n_any"],
|
||||
"n_untyped": summary["n_untyped"],
|
||||
"coverage": summary["coverage"],
|
||||
"strict_coverage": summary["strict_coverage"],
|
||||
}
|
||||
|
||||
|
||||
def format_summary_markdown(summary: CoverageSummary) -> str:
|
||||
"""Format a single coverage summary as a Markdown table."""
|
||||
|
||||
return (
|
||||
"| Metric | Value |\n"
|
||||
"| --- | ---: |\n"
|
||||
f"| Modules | {summary['n_modules']} |\n"
|
||||
f"| Typable symbols | {summary['n_typable']:,} |\n"
|
||||
f"| Typed symbols | {summary['n_typed']:,} |\n"
|
||||
f"| Untyped symbols | {summary['n_untyped']:,} |\n"
|
||||
f"| Any symbols | {summary['n_any']:,} |\n"
|
||||
f"| **Type coverage** | **{summary['coverage']:.2f}%** |\n"
|
||||
f"| Strict coverage | {summary['strict_coverage']:.2f}% |"
|
||||
)
|
||||
|
||||
|
||||
def format_comparison_markdown(
|
||||
base: CoverageSummary,
|
||||
pr: CoverageSummary,
|
||||
) -> str:
|
||||
"""Format a comparison between base and PR coverage as Markdown."""
|
||||
|
||||
coverage_delta = pr["coverage"] - base["coverage"]
|
||||
strict_delta = pr["strict_coverage"] - base["strict_coverage"]
|
||||
typed_delta = pr["n_typed"] - base["n_typed"]
|
||||
untyped_delta = pr["n_untyped"] - base["n_untyped"]
|
||||
|
||||
def _fmt_delta(value: float, fmt: str = ".2f") -> str:
|
||||
sign = "+" if value > 0 else ""
|
||||
return f"{sign}{value:{fmt}}"
|
||||
|
||||
lines = [
|
||||
"| Metric | Base | PR | Delta |",
|
||||
"| --- | ---: | ---: | ---: |",
|
||||
(f"| **Type coverage** | {base['coverage']:.2f}% | {pr['coverage']:.2f}% | {_fmt_delta(coverage_delta)}% |"),
|
||||
(
|
||||
f"| Strict coverage | {base['strict_coverage']:.2f}% "
|
||||
f"| {pr['strict_coverage']:.2f}% "
|
||||
f"| {_fmt_delta(strict_delta)}% |"
|
||||
),
|
||||
(f"| Typed symbols | {base['n_typed']:,} | {pr['n_typed']:,} | {_fmt_delta(typed_delta, ',')} |"),
|
||||
(f"| Untyped symbols | {base['n_untyped']:,} | {pr['n_untyped']:,} | {_fmt_delta(untyped_delta, ',')} |"),
|
||||
(
|
||||
f"| Modules | {base['n_modules']} "
|
||||
f"| {pr['n_modules']} "
|
||||
f"| {_fmt_delta(pr['n_modules'] - base['n_modules'], ',')} |"
|
||||
),
|
||||
]
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
"""Read pyrefly report JSON from stdin and print a Markdown summary.
|
||||
|
||||
Accepts an optional ``--base <file>`` argument. When provided, the output
|
||||
includes a base-vs-PR comparison table.
|
||||
"""
|
||||
|
||||
args = sys.argv[1:]
|
||||
|
||||
base_file: str | None = None
|
||||
if "--base" in args:
|
||||
idx = args.index("--base")
|
||||
if idx + 1 >= len(args):
|
||||
sys.stderr.write("error: --base requires a file path\n")
|
||||
return 1
|
||||
base_file = args[idx + 1]
|
||||
|
||||
pr_report = sys.stdin.read()
|
||||
pr_summary = parse_summary(pr_report)
|
||||
|
||||
if base_file is not None:
|
||||
base_text = Path(base_file).read_text() if Path(base_file).exists() else ""
|
||||
base_summary = parse_summary(base_text)
|
||||
sys.stdout.write(format_comparison_markdown(base_summary, pr_summary) + "\n")
|
||||
else:
|
||||
sys.stdout.write(format_summary_markdown(pr_summary) + "\n")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@ -4,7 +4,7 @@ import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypedDict, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
@ -121,7 +121,7 @@ class WorkflowType(StrEnum):
|
||||
raise ValueError(f"invalid workflow type value {value}")
|
||||
|
||||
@classmethod
|
||||
def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType":
|
||||
def from_app_mode(cls, app_mode: "str | AppMode") -> "WorkflowType":
|
||||
"""
|
||||
Get workflow type from app mode.
|
||||
|
||||
@ -1051,7 +1051,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
||||
)
|
||||
return extras
|
||||
|
||||
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]:
|
||||
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> "WorkflowNodeExecutionOffload | None":
|
||||
return next(iter([i for i in self.offload_data if i.type_ == type_]), None)
|
||||
|
||||
@property
|
||||
|
||||
@ -9,7 +9,8 @@ from typing import Any, TypedDict, cast
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from sqlalchemy import delete, func, select, update
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
|
||||
|
||||
class InvitationData(TypedDict):
|
||||
@ -800,19 +801,19 @@ class AccountService:
|
||||
return token
|
||||
|
||||
@staticmethod
|
||||
def get_account_by_email_with_case_fallback(email: str, session: Session | None = None) -> Account | None:
|
||||
def get_account_by_email_with_case_fallback(email: str) -> Account | None:
|
||||
"""
|
||||
Retrieve an account by email and fall back to the lowercase email if the original lookup fails.
|
||||
|
||||
This keeps backward compatibility for older records that stored uppercase emails while the
|
||||
rest of the system gradually normalizes new inputs.
|
||||
"""
|
||||
query_session = session or db.session
|
||||
account = query_session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
|
||||
if account or email == email.lower():
|
||||
return account
|
||||
with session_factory.create_session() as session:
|
||||
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
|
||||
if account or email == email.lower():
|
||||
return account
|
||||
|
||||
return query_session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()
|
||||
return session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()
|
||||
|
||||
@classmethod
|
||||
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
|
||||
@ -1516,8 +1517,7 @@ class RegisterService:
|
||||
|
||||
check_workspace_member_invite_permission(tenant.id)
|
||||
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||
|
||||
if not account:
|
||||
TenantService.check_member_permission(tenant, inviter, None, "add")
|
||||
|
||||
@ -4,7 +4,7 @@ import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Callable, Generator, Mapping
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
@ -88,7 +88,7 @@ class AppGenerateService:
|
||||
def generate(
|
||||
cls,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
@ -356,11 +356,11 @@ class AppGenerateService:
|
||||
def generate_more_like_this(
|
||||
cls,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
message_id: str,
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
) -> Union[Mapping, Generator]:
|
||||
) -> Mapping | Generator:
|
||||
"""
|
||||
Generate more like this
|
||||
:param app_model: app model
|
||||
|
||||
@ -7,7 +7,7 @@ with support for different subscription tiers, rate limiting, and execution trac
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Union
|
||||
from typing import Any
|
||||
|
||||
from celery.result import AsyncResult
|
||||
from sqlalchemy import select
|
||||
@ -50,7 +50,7 @@ class AsyncWorkflowService:
|
||||
|
||||
@classmethod
|
||||
def trigger_workflow_async(
|
||||
cls, session: Session, user: Union[Account, EndUser], trigger_data: TriggerData
|
||||
cls, session: Session, user: Account | EndUser, trigger_data: TriggerData
|
||||
) -> AsyncTriggerResponse:
|
||||
"""
|
||||
Universal entry point for async workflow execution - THIS METHOD WILL NOT BLOCK
|
||||
@ -177,7 +177,7 @@ class AsyncWorkflowService:
|
||||
|
||||
@classmethod
|
||||
def reinvoke_trigger(
|
||||
cls, session: Session, user: Union[Account, EndUser], workflow_trigger_log_id: str
|
||||
cls, session: Session, user: Account | EndUser, workflow_trigger_log_id: str
|
||||
) -> AsyncTriggerResponse:
|
||||
"""
|
||||
Re-invoke a previously failed or rate-limited trigger - THIS METHOD WILL NOT BLOCK
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
from copy import deepcopy
|
||||
from typing import Any, Union, cast
|
||||
from typing import Any, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
@ -195,9 +195,7 @@ class ExternalDatasetService:
|
||||
raise ValueError(f"{parameter.get('name')} is required")
|
||||
|
||||
@staticmethod
|
||||
def process_external_api(
|
||||
settings: ExternalKnowledgeApiSetting, files: Union[None, dict[str, Any]]
|
||||
) -> httpx.Response:
|
||||
def process_external_api(settings: ExternalKnowledgeApiSetting, files: dict[str, Any] | None) -> httpx.Response:
|
||||
"""
|
||||
do http request depending on api bundle
|
||||
"""
|
||||
|
||||
@ -5,7 +5,7 @@ import uuid
|
||||
from collections.abc import Iterator, Sequence
|
||||
from contextlib import contextmanager, suppress
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Literal, Union
|
||||
from typing import Literal
|
||||
from zipfile import ZIP_DEFLATED, ZipFile
|
||||
|
||||
from graphon.file import helpers as file_helpers
|
||||
@ -52,7 +52,7 @@ class FileService:
|
||||
filename: str,
|
||||
content: bytes,
|
||||
mimetype: str,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
source: Literal["datasets"] | None = None,
|
||||
source_url: str = "",
|
||||
) -> UploadFile:
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, TypedDict, Union
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.entities.provider_entities import (
|
||||
@ -626,7 +626,7 @@ class ModelLoadBalancingService:
|
||||
|
||||
def _get_credential_schema(
|
||||
self, provider_configuration: ProviderConfiguration
|
||||
) -> Union[ModelCredentialSchema, ProviderCredentialSchema]:
|
||||
) -> ModelCredentialSchema | ProviderCredentialSchema:
|
||||
"""Get form schemas."""
|
||||
if provider_configuration.provider.model_credential_schema:
|
||||
return provider_configuration.provider.model_credential_schema
|
||||
|
||||
@ -1,14 +1,13 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from extensions.ext_database import db
|
||||
from core.db.session_factory import session_factory
|
||||
from models.account import TenantPluginPermission
|
||||
|
||||
|
||||
class PluginPermissionService:
|
||||
@staticmethod
|
||||
def get_permission(tenant_id: str) -> TenantPluginPermission | None:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
with session_factory.create_session() as session:
|
||||
return session.scalar(
|
||||
select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1)
|
||||
)
|
||||
@ -19,7 +18,7 @@ class PluginPermissionService:
|
||||
install_permission: TenantPluginPermission.InstallPermission,
|
||||
debug_permission: TenantPluginPermission.DebugPermission,
|
||||
):
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
permission = session.scalar(
|
||||
select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1)
|
||||
)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Union
|
||||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
|
||||
@ -17,7 +17,7 @@ class PipelineGenerateService:
|
||||
def generate(
|
||||
cls,
|
||||
pipeline: Pipeline,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
|
||||
@ -5,7 +5,7 @@ import threading
|
||||
import time
|
||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Union, cast
|
||||
from typing import Any, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from flask_login import current_user
|
||||
@ -1387,7 +1387,7 @@ class RagPipelineService:
|
||||
"uninstalled_recommended_plugins": uninstalled_plugin_list,
|
||||
}
|
||||
|
||||
def retry_error_document(self, dataset: Dataset, document: Document, user: Union[Account, EndUser]):
|
||||
def retry_error_document(self, dataset: Dataset, document: Document, user: Account | EndUser):
|
||||
"""
|
||||
Retry error document
|
||||
"""
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Union
|
||||
from typing import Any
|
||||
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
from yarl import URL
|
||||
@ -69,7 +69,7 @@ class ToolTransformService:
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity, PluginDatasourceProviderEntity]):
|
||||
def repack_provider(tenant_id: str, provider: dict | ToolProviderApiEntity | PluginDatasourceProviderEntity):
|
||||
"""
|
||||
repack provider
|
||||
|
||||
|
||||
@ -7,15 +7,16 @@ with appropriate retry policies and error handling.
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from typing import Any, NotRequired
|
||||
|
||||
from celery import shared_task
|
||||
from graphon.runtime import GraphRuntimeState
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext
|
||||
from core.app.layers.timeslice_layer import TimeSliceLayer
|
||||
@ -42,6 +43,13 @@ from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue, AsyncWorkf
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowGeneratorArgsDict(TypedDict):
|
||||
inputs: dict[str, Any]
|
||||
files: list[Any]
|
||||
_skip_prepare_user_inputs: bool
|
||||
workflow_id: NotRequired[str]
|
||||
|
||||
|
||||
@shared_task(queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE)
|
||||
def execute_workflow_professional(task_data_dict: dict[str, Any]):
|
||||
"""Execute workflow for professional tier with highest priority"""
|
||||
@ -90,15 +98,13 @@ def execute_workflow_sandbox(task_data_dict: dict[str, Any]):
|
||||
)
|
||||
|
||||
|
||||
def _build_generator_args(trigger_data: TriggerData) -> dict[str, Any]:
|
||||
def _build_generator_args(trigger_data: TriggerData) -> WorkflowGeneratorArgsDict:
|
||||
"""Build args passed into WorkflowAppGenerator.generate for Celery executions."""
|
||||
|
||||
args: dict[str, Any] = {
|
||||
return {
|
||||
"inputs": dict(trigger_data.inputs),
|
||||
"files": list(trigger_data.files),
|
||||
SKIP_PREPARE_USER_INPUTS_KEY: True,
|
||||
"_skip_prepare_user_inputs": True,
|
||||
}
|
||||
return args
|
||||
|
||||
|
||||
def _execute_workflow_common(
|
||||
|
||||
@ -158,7 +158,10 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase():
|
||||
second_result.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first_result, second_result]
|
||||
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
|
||||
with patch("services.account_service.session_factory") as mock_factory:
|
||||
mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com")
|
||||
|
||||
assert result is expected_account
|
||||
assert mock_session.execute.call_count == 2
|
||||
|
||||
@ -113,12 +113,14 @@ class TestForgotPasswordCheckApi:
|
||||
class TestForgotPasswordResetApi:
|
||||
@patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.forgot_password.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_reset_fetches_account_with_original_email(
|
||||
self,
|
||||
mock_get_reset_data,
|
||||
mock_revoke_token,
|
||||
mock_db,
|
||||
mock_get_account,
|
||||
mock_update_account,
|
||||
app,
|
||||
@ -126,6 +128,7 @@ class TestForgotPasswordResetApi:
|
||||
mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"}
|
||||
mock_account = MagicMock()
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_db.session.merge.return_value = mock_account
|
||||
|
||||
wraps_features = SimpleNamespace(enable_email_password_login=True)
|
||||
with (
|
||||
@ -161,7 +164,10 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase():
|
||||
second_result.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first_result, second_result]
|
||||
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session)
|
||||
with patch("services.account_service.session_factory") as mock_factory:
|
||||
mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com")
|
||||
|
||||
assert result is expected_account
|
||||
assert mock_session.execute.call_count == 2
|
||||
|
||||
@ -437,7 +437,10 @@ class TestAccountGeneration:
|
||||
second_result.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first_result, second_result]
|
||||
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
|
||||
with patch("services.account_service.session_factory") as mock_factory:
|
||||
mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com")
|
||||
|
||||
assert result is expected_account
|
||||
assert mock_session.execute.call_count == 2
|
||||
|
||||
@ -335,10 +335,12 @@ class TestForgotPasswordResetApi:
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.forgot_password.db")
|
||||
@patch("controllers.console.auth.forgot_password.TenantService.get_join_tenants")
|
||||
def test_reset_password_success(
|
||||
self,
|
||||
mock_get_tenants,
|
||||
mock_db,
|
||||
mock_get_account,
|
||||
mock_revoke_token,
|
||||
mock_get_data,
|
||||
@ -356,6 +358,7 @@ class TestForgotPasswordResetApi:
|
||||
# Arrange
|
||||
mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_db.session.merge.return_value = mock_account
|
||||
mock_get_tenants.return_value = [MagicMock()]
|
||||
|
||||
# Act
|
||||
|
||||
@ -37,10 +37,8 @@ class TestForgotPasswordSendEmailApi:
|
||||
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.web.forgot_password.extract_remote_ip", return_value="127.0.0.1")
|
||||
@patch("controllers.web.forgot_password.sessionmaker")
|
||||
def test_should_normalize_email_before_sending(
|
||||
self,
|
||||
mock_session_cls,
|
||||
mock_extract_ip,
|
||||
mock_rate_limit,
|
||||
mock_get_account,
|
||||
@ -50,19 +48,16 @@ class TestForgotPasswordSendEmailApi:
|
||||
mock_account = MagicMock()
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_send_mail.return_value = "token-123"
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
|
||||
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
|
||||
with app.test_request_context(
|
||||
"/web/forgot-password",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "language": "zh-Hans"},
|
||||
):
|
||||
response = ForgotPasswordSendEmailApi().post()
|
||||
with app.test_request_context(
|
||||
"/web/forgot-password",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "language": "zh-Hans"},
|
||||
):
|
||||
response = ForgotPasswordSendEmailApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": "token-123"}
|
||||
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
|
||||
mock_get_account.assert_called_once_with("User@Example.com")
|
||||
mock_send_mail.assert_called_once_with(account=mock_account, email="user@example.com", language="zh-Hans")
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_rate_limit.assert_called_once_with("127.0.0.1")
|
||||
@ -153,14 +148,14 @@ class TestForgotPasswordResetApi:
|
||||
|
||||
@patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.web.forgot_password.sessionmaker")
|
||||
@patch("controllers.web.forgot_password.db")
|
||||
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_should_fetch_account_with_fallback(
|
||||
self,
|
||||
mock_get_reset_data,
|
||||
mock_revoke_token,
|
||||
mock_session_cls,
|
||||
mock_db,
|
||||
mock_get_account,
|
||||
mock_update_account,
|
||||
app,
|
||||
@ -168,29 +163,27 @@ class TestForgotPasswordResetApi:
|
||||
mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com", "code": "1234"}
|
||||
mock_account = MagicMock()
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_db.session.merge.return_value = mock_account
|
||||
|
||||
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
|
||||
with app.test_request_context(
|
||||
"/web/forgot-password/resets",
|
||||
method="POST",
|
||||
json={
|
||||
"token": "token-123",
|
||||
"new_password": "ValidPass123!",
|
||||
"password_confirm": "ValidPass123!",
|
||||
},
|
||||
):
|
||||
response = ForgotPasswordResetApi().post()
|
||||
with app.test_request_context(
|
||||
"/web/forgot-password/resets",
|
||||
method="POST",
|
||||
json={
|
||||
"token": "token-123",
|
||||
"new_password": "ValidPass123!",
|
||||
"password_confirm": "ValidPass123!",
|
||||
},
|
||||
):
|
||||
response = ForgotPasswordResetApi().post()
|
||||
|
||||
assert response == {"result": "success"}
|
||||
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
|
||||
mock_get_account.assert_called_once_with("User@Example.com")
|
||||
mock_update_account.assert_called_once()
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
|
||||
@patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value")
|
||||
@patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef")
|
||||
@patch("controllers.web.forgot_password.sessionmaker")
|
||||
@patch("controllers.web.forgot_password.db")
|
||||
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@ -199,7 +192,7 @@ class TestForgotPasswordResetApi:
|
||||
mock_get_account,
|
||||
mock_get_reset_data,
|
||||
mock_revoke_token,
|
||||
mock_session_cls,
|
||||
mock_db,
|
||||
mock_token_bytes,
|
||||
mock_hash_password,
|
||||
app,
|
||||
@ -207,20 +200,18 @@ class TestForgotPasswordResetApi:
|
||||
mock_get_reset_data.return_value = {"phase": "reset", "email": "user@example.com"}
|
||||
account = MagicMock()
|
||||
mock_get_account.return_value = account
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_db.session.merge.return_value = account
|
||||
|
||||
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
|
||||
with app.test_request_context(
|
||||
"/web/forgot-password/resets",
|
||||
method="POST",
|
||||
json={
|
||||
"token": "reset-token",
|
||||
"new_password": "StrongPass123!",
|
||||
"password_confirm": "StrongPass123!",
|
||||
},
|
||||
):
|
||||
response = ForgotPasswordResetApi().post()
|
||||
with app.test_request_context(
|
||||
"/web/forgot-password/resets",
|
||||
method="POST",
|
||||
json={
|
||||
"token": "reset-token",
|
||||
"new_password": "StrongPass123!",
|
||||
"password_confirm": "StrongPass123!",
|
||||
},
|
||||
):
|
||||
response = ForgotPasswordResetApi().post()
|
||||
|
||||
assert response == {"result": "success"}
|
||||
mock_get_reset_data.assert_called_once_with("reset-token")
|
||||
|
||||
@ -1,239 +1,193 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, cast
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.rag.models.document import Document
|
||||
from models.dataset import Dataset
|
||||
from models.dataset import Dataset, DatasetQuery
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
|
||||
class TestHitTestingService:
|
||||
"""Test suite for HitTestingService"""
|
||||
def _create_dataset(db_session: Session, *, provider: str = "vendor", **kwargs: Any) -> Dataset:
|
||||
tenant_id = str(uuid4())
|
||||
created_by = str(uuid4())
|
||||
ds = Dataset(
|
||||
tenant_id=kwargs.get("tenant_id", tenant_id),
|
||||
name=kwargs.get("name", "test-dataset"),
|
||||
created_by=kwargs.get("created_by", created_by),
|
||||
provider=provider,
|
||||
)
|
||||
db_session.add(ds)
|
||||
db_session.commit()
|
||||
db_session.refresh(ds)
|
||||
return ds
|
||||
|
||||
# ===== Utility Method Tests =====
|
||||
|
||||
class TestHitTestingService:
|
||||
# ── Utility methods (pure logic, no DB) ────────────────────────────
|
||||
|
||||
def test_escape_query_for_search_should_escape_double_quotes(self):
|
||||
"""Test that escape_query_for_search escapes double quotes correctly"""
|
||||
# Arrange
|
||||
query = 'test "query" with quotes'
|
||||
expected = 'test \\"query\\" with quotes'
|
||||
|
||||
# Act
|
||||
result = HitTestingService.escape_query_for_search(query)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
assert result == 'test \\"query\\" with quotes'
|
||||
|
||||
def test_hit_testing_args_check_should_pass_with_valid_query(self):
|
||||
"""Test that hit_testing_args_check passes with a valid query"""
|
||||
# Arrange
|
||||
args = {"query": "valid query"}
|
||||
|
||||
# Act & Assert (should not raise)
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
HitTestingService.hit_testing_args_check({"query": "valid query"})
|
||||
|
||||
def test_hit_testing_args_check_should_pass_with_valid_attachments(self):
|
||||
"""Test that hit_testing_args_check passes with valid attachment_ids"""
|
||||
# Arrange
|
||||
args = {"attachment_ids": ["id1", "id2"]}
|
||||
|
||||
# Act & Assert (should not raise)
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
HitTestingService.hit_testing_args_check({"attachment_ids": ["id1", "id2"]})
|
||||
|
||||
def test_hit_testing_args_check_should_raise_error_when_no_query_or_attachments(self):
|
||||
"""Test that hit_testing_args_check raises ValueError if both query and attachment_ids are missing"""
|
||||
# Arrange
|
||||
args = {}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
assert "Query or attachment_ids is required" in str(exc_info.value)
|
||||
with pytest.raises(ValueError, match="Query or attachment_ids is required"):
|
||||
HitTestingService.hit_testing_args_check({})
|
||||
|
||||
def test_hit_testing_args_check_should_raise_error_when_query_too_long(self):
|
||||
"""Test that hit_testing_args_check raises ValueError if query exceeds 250 characters"""
|
||||
# Arrange
|
||||
args = {"query": "a" * 251}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
assert "Query cannot exceed 250 characters" in str(exc_info.value)
|
||||
with pytest.raises(ValueError, match="Query cannot exceed 250 characters"):
|
||||
HitTestingService.hit_testing_args_check({"query": "a" * 251})
|
||||
|
||||
def test_hit_testing_args_check_should_raise_error_when_attachments_not_list(self):
|
||||
"""Test that hit_testing_args_check raises ValueError if attachment_ids is not a list"""
|
||||
# Arrange
|
||||
args = {"attachment_ids": "not a list"}
|
||||
with pytest.raises(ValueError, match="Attachment_ids must be a list"):
|
||||
HitTestingService.hit_testing_args_check({"attachment_ids": "not a list"})
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
assert "Attachment_ids must be a list" in str(exc_info.value)
|
||||
|
||||
# ===== Response Formatting Tests =====
|
||||
# ── Response formatting ────────────────────────────────────────────
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.format_retrieval_documents")
|
||||
def test_compact_retrieve_response_should_format_correctly(self, mock_format):
|
||||
"""Test that compact_retrieve_response formats the response correctly"""
|
||||
# Arrange
|
||||
query = "test query"
|
||||
mock_doc = MagicMock(spec=Document)
|
||||
documents = [mock_doc]
|
||||
|
||||
mock_record = MagicMock()
|
||||
mock_record.model_dump.return_value = {"content": "formatted content"}
|
||||
mock_format.return_value = [mock_record]
|
||||
|
||||
# Act
|
||||
result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, documents))
|
||||
result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, [mock_doc]))
|
||||
|
||||
# Assert
|
||||
assert cast(dict[str, Any], result["query"])["content"] == query
|
||||
assert len(result["records"]) == 1
|
||||
assert cast(dict[str, Any], result["records"][0])["content"] == "formatted content"
|
||||
mock_format.assert_called_once_with(documents)
|
||||
mock_format.assert_called_once_with([mock_doc])
|
||||
|
||||
def test_compact_external_retrieve_response_should_return_records_for_external_provider(self):
|
||||
"""Test that compact_external_retrieve_response returns records when dataset provider is external"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.provider = "external"
|
||||
query = "test query"
|
||||
def test_compact_external_retrieve_response_should_return_records_for_external_provider(
|
||||
self, db_session_with_containers: Session
|
||||
):
|
||||
dataset = _create_dataset(db_session_with_containers, provider="external")
|
||||
documents = [
|
||||
{"content": "c1", "title": "t1", "score": 0.9, "metadata": {"m1": "v1"}},
|
||||
{"content": "c2", "title": "t2", "score": 0.8, "metadata": {"m2": "v2"}},
|
||||
]
|
||||
|
||||
# Act
|
||||
result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents))
|
||||
result = cast(
|
||||
dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, "test query", documents)
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert cast(dict[str, Any], result["query"])["content"] == query
|
||||
assert cast(dict[str, Any], result["query"])["content"] == "test query"
|
||||
assert len(result["records"]) == 2
|
||||
assert cast(dict[str, Any], result["records"][0])["content"] == "c1"
|
||||
assert cast(dict[str, Any], result["records"][1])["title"] == "t2"
|
||||
|
||||
def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider(self):
|
||||
"""Test that compact_external_retrieve_response returns empty records for non-external provider"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.provider = "not_external"
|
||||
query = "test query"
|
||||
documents = [{"content": "c1"}]
|
||||
def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider(
|
||||
self, db_session_with_containers: Session
|
||||
):
|
||||
dataset = _create_dataset(db_session_with_containers, provider="vendor")
|
||||
|
||||
# Act
|
||||
result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents))
|
||||
result = cast(
|
||||
dict[str, Any],
|
||||
HitTestingService.compact_external_retrieve_response(dataset, "test query", [{"content": "c1"}]),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert cast(dict[str, Any], result["query"])["content"] == query
|
||||
assert cast(dict[str, Any], result["query"])["content"] == "test query"
|
||||
assert result["records"] == []
|
||||
|
||||
# ===== External Retrieve Tests =====
|
||||
# ── External retrieve (real DB) ────────────────────────────────────
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.external_retrieve")
|
||||
@patch("extensions.ext_database.db.session.add")
|
||||
@patch("extensions.ext_database.db.session.commit")
|
||||
def test_external_retrieve_should_succeed_for_external_provider(self, mock_commit, mock_add, mock_ext_retrieve):
|
||||
"""Test that external_retrieve successfully retrieves from external provider and commits query"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.id = "dataset_id"
|
||||
dataset.provider = "external"
|
||||
query = 'test "query"'
|
||||
def test_external_retrieve_should_succeed_for_external_provider(
|
||||
self, mock_ext_retrieve, db_session_with_containers: Session
|
||||
):
|
||||
dataset = _create_dataset(db_session_with_containers, provider="external")
|
||||
account_id = str(uuid4())
|
||||
account = MagicMock()
|
||||
account.id = "account_id"
|
||||
|
||||
account.id = account_id
|
||||
mock_ext_retrieve.return_value = [{"content": "ext content", "score": 1.0}]
|
||||
|
||||
# Act
|
||||
before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
|
||||
|
||||
result = cast(
|
||||
dict[str, Any],
|
||||
HitTestingService.external_retrieve(
|
||||
dataset=dataset,
|
||||
query=query,
|
||||
query='test "query"',
|
||||
account=account,
|
||||
external_retrieval_model={"model": "test"},
|
||||
metadata_filtering_conditions={"key": "val"},
|
||||
),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert cast(dict[str, Any], result["query"])["content"] == query
|
||||
assert cast(dict[str, Any], result["query"])["content"] == 'test "query"'
|
||||
assert cast(dict[str, Any], result["records"][0])["content"] == "ext content"
|
||||
|
||||
# Verify call to RetrievalService.external_retrieve with escaped query
|
||||
mock_ext_retrieve.assert_called_once_with(
|
||||
dataset_id="dataset_id",
|
||||
dataset_id=dataset.id,
|
||||
query='test \\"query\\"',
|
||||
external_retrieval_model={"model": "test"},
|
||||
metadata_filtering_conditions={"key": "val"},
|
||||
)
|
||||
|
||||
# Verify DatasetQuery record was added and committed
|
||||
mock_add.assert_called_once()
|
||||
mock_commit.assert_called_once()
|
||||
db_session_with_containers.expire_all()
|
||||
after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
|
||||
assert after_count == before_count + 1
|
||||
|
||||
def test_external_retrieve_should_return_empty_for_non_external_provider(self):
|
||||
"""Test that external_retrieve returns empty results immediately if provider is not external"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.provider = "not_external"
|
||||
query = "test query"
|
||||
def test_external_retrieve_should_return_empty_for_non_external_provider(self, db_session_with_containers: Session):
|
||||
dataset = _create_dataset(db_session_with_containers, provider="vendor")
|
||||
account = MagicMock()
|
||||
|
||||
# Act
|
||||
result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, query, account))
|
||||
result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, "test query", account))
|
||||
|
||||
# Assert
|
||||
assert cast(dict[str, Any], result["query"])["content"] == query
|
||||
assert cast(dict[str, Any], result["query"])["content"] == "test query"
|
||||
assert result["records"] == []
|
||||
|
||||
# ===== Retrieve Tests =====
|
||||
# ── Retrieve (real DB) ─────────────────────────────────────────────
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
||||
@patch("extensions.ext_database.db.session.add")
|
||||
@patch("extensions.ext_database.db.session.commit")
|
||||
def test_retrieve_should_use_default_model_when_none_provided(self, mock_commit, mock_add, mock_retrieve):
|
||||
"""Test that retrieve uses default model when retrieval_model is not provided"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.id = "dataset_id"
|
||||
def test_retrieve_should_use_default_model_when_none_provided(
|
||||
self, mock_retrieve, db_session_with_containers: Session
|
||||
):
|
||||
dataset = _create_dataset(db_session_with_containers)
|
||||
dataset.retrieval_model = None
|
||||
query = "test query"
|
||||
account = MagicMock()
|
||||
account.id = "account_id"
|
||||
|
||||
account.id = str(uuid4())
|
||||
mock_retrieve.return_value = []
|
||||
|
||||
# Act
|
||||
before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
|
||||
|
||||
result = cast(
|
||||
dict[str, Any],
|
||||
HitTestingService.retrieve(
|
||||
dataset=dataset, query=query, account=account, retrieval_model=None, external_retrieval_model={}
|
||||
dataset=dataset, query="test query", account=account, retrieval_model=None, external_retrieval_model={}
|
||||
),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert cast(dict[str, Any], result["query"])["content"] == query
|
||||
assert cast(dict[str, Any], result["query"])["content"] == "test query"
|
||||
mock_retrieve.assert_called_once()
|
||||
# Verify top_k from default_retrieval_model (4)
|
||||
assert mock_retrieve.call_args.kwargs["top_k"] == 4
|
||||
mock_commit.assert_called_once()
|
||||
|
||||
db_session_with_containers.expire_all()
|
||||
after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
|
||||
assert after_count == before_count + 1
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
||||
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition")
|
||||
@patch("extensions.ext_database.db.session.add")
|
||||
@patch("extensions.ext_database.db.session.commit")
|
||||
def test_retrieve_should_handle_metadata_filtering(self, mock_commit, mock_add, mock_get_meta, mock_retrieve):
|
||||
"""Test that retrieve correctly calls metadata filtering when conditions are present"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.id = "dataset_id"
|
||||
query = "test query"
|
||||
def test_retrieve_should_handle_metadata_filtering(
|
||||
self, mock_get_meta, mock_retrieve, db_session_with_containers: Session
|
||||
):
|
||||
dataset = _create_dataset(db_session_with_containers)
|
||||
account = MagicMock()
|
||||
account.id = "account_id"
|
||||
account.id = str(uuid4())
|
||||
|
||||
retrieval_model = {
|
||||
"search_method": "semantic_search",
|
||||
@ -242,29 +196,27 @@ class TestHitTestingService:
|
||||
"reranking_enable": False,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
# Mock metadata filtering response
|
||||
mock_get_meta.return_value = ({"dataset_id": ["doc_id1"]}, "condition_string")
|
||||
mock_get_meta.return_value = ({dataset.id: ["doc_id1"]}, "condition_string")
|
||||
mock_retrieve.return_value = []
|
||||
|
||||
# Act
|
||||
HitTestingService.retrieve(
|
||||
dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={}
|
||||
dataset=dataset,
|
||||
query="test query",
|
||||
account=account,
|
||||
retrieval_model=retrieval_model,
|
||||
external_retrieval_model={},
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_get_meta.assert_called_once()
|
||||
mock_retrieve.assert_called_once()
|
||||
assert mock_retrieve.call_args.kwargs["document_ids_filter"] == ["doc_id1"]
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
||||
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition")
|
||||
def test_retrieve_should_return_empty_if_metadata_filtering_fails(self, mock_get_meta, mock_retrieve):
|
||||
"""Test that retrieve returns empty response if metadata filtering returns condition but no document IDs"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.id = "dataset_id"
|
||||
query = "test query"
|
||||
def test_retrieve_should_return_empty_if_metadata_filtering_fails(
|
||||
self, mock_get_meta, mock_retrieve, db_session_with_containers: Session
|
||||
):
|
||||
dataset = _create_dataset(db_session_with_containers)
|
||||
account = MagicMock()
|
||||
|
||||
retrieval_model = {
|
||||
@ -274,37 +226,27 @@ class TestHitTestingService:
|
||||
"reranking_enable": False,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
# Mock metadata filtering response: condition returned but no IDs
|
||||
mock_get_meta.return_value = ({}, "condition_string")
|
||||
|
||||
# Act
|
||||
result = cast(
|
||||
dict[str, Any],
|
||||
HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=query,
|
||||
query="test query",
|
||||
account=account,
|
||||
retrieval_model=retrieval_model,
|
||||
external_retrieval_model={},
|
||||
),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result["records"] == []
|
||||
mock_retrieve.assert_not_called()
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
||||
@patch("extensions.ext_database.db.session.add")
|
||||
@patch("extensions.ext_database.db.session.commit")
|
||||
def test_retrieve_should_handle_attachments(self, mock_commit, mock_add, mock_retrieve):
|
||||
"""Test that retrieve handles attachment_ids and adds them to DatasetQuery"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.id = "dataset_id"
|
||||
query = "test query"
|
||||
def test_retrieve_should_handle_attachments(self, mock_retrieve, db_session_with_containers: Session):
|
||||
dataset = _create_dataset(db_session_with_containers)
|
||||
account = MagicMock()
|
||||
account.id = "account_id"
|
||||
account.id = str(uuid4())
|
||||
attachment_ids = ["att1", "att2"]
|
||||
|
||||
retrieval_model = {
|
||||
@ -315,21 +257,19 @@ class TestHitTestingService:
|
||||
}
|
||||
mock_retrieve.return_value = []
|
||||
|
||||
# Act
|
||||
HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=query,
|
||||
query="test query",
|
||||
account=account,
|
||||
retrieval_model=retrieval_model,
|
||||
external_retrieval_model={},
|
||||
attachment_ids=attachment_ids,
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_retrieve.assert_called_once_with(
|
||||
retrieval_method=ANY,
|
||||
dataset_id="dataset_id",
|
||||
query=query,
|
||||
dataset_id=dataset.id,
|
||||
query="test query",
|
||||
attachment_ids=attachment_ids,
|
||||
top_k=4,
|
||||
score_threshold=0.0,
|
||||
@ -338,26 +278,27 @@ class TestHitTestingService:
|
||||
weights=None,
|
||||
document_ids_filter=None,
|
||||
)
|
||||
# Verify DatasetQuery record (there should be 2 queries: 1 text, 2 images)
|
||||
# The content is json.dumps([{"content_type": "text_query", ...}, {"content_type": "image_query", ...}])
|
||||
called_query = mock_add.call_args[0][0]
|
||||
query_content = json.loads(called_query.content)
|
||||
|
||||
# Verify DatasetQuery was persisted with correct content structure
|
||||
db_session_with_containers.expire_all()
|
||||
latest = db_session_with_containers.scalar(
|
||||
select(DatasetQuery)
|
||||
.where(DatasetQuery.dataset_id == dataset.id)
|
||||
.order_by(DatasetQuery.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
assert latest is not None
|
||||
query_content = json.loads(latest.content)
|
||||
assert len(query_content) == 3 # 1 text + 2 images
|
||||
assert query_content[0]["content_type"] == "text_query"
|
||||
assert query_content[1]["content_type"] == "image_query"
|
||||
assert query_content[1]["content"] == "att1"
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
||||
@patch("extensions.ext_database.db.session.add")
|
||||
@patch("extensions.ext_database.db.session.commit")
|
||||
def test_retrieve_should_handle_reranking_and_threshold(self, mock_commit, mock_add, mock_retrieve):
|
||||
"""Test that retrieve passes reranking and threshold parameters correctly"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.id = "dataset_id"
|
||||
query = "test query"
|
||||
def test_retrieve_should_handle_reranking_and_threshold(self, mock_retrieve, db_session_with_containers: Session):
|
||||
dataset = _create_dataset(db_session_with_containers)
|
||||
account = MagicMock()
|
||||
account.id = "account_id"
|
||||
account.id = str(uuid4())
|
||||
|
||||
retrieval_model = {
|
||||
"search_method": "hybrid_search",
|
||||
@ -371,12 +312,14 @@ class TestHitTestingService:
|
||||
}
|
||||
mock_retrieve.return_value = []
|
||||
|
||||
# Act
|
||||
HitTestingService.retrieve(
|
||||
dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={}
|
||||
dataset=dataset,
|
||||
query="test query",
|
||||
account=account,
|
||||
retrieval_model=retrieval_model,
|
||||
external_retrieval_model={},
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_retrieve.assert_called_once()
|
||||
kwargs = mock_retrieve.call_args.kwargs
|
||||
assert kwargs["score_threshold"] == 0.5
|
||||
@ -0,0 +1,363 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.ops.entities.config_entity import TracingProviderEnum
|
||||
from models.model import TraceAppConfig
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.app_service import AppService
|
||||
from services.ops_service import OpsService
|
||||
from tests.test_containers_integration_tests.helpers import generate_valid_password
|
||||
|
||||
|
||||
class TestOpsService:
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
with (
|
||||
patch("services.app_service.FeatureService") as mock_feature_service,
|
||||
patch("services.app_service.EnterpriseService") as mock_enterprise_service,
|
||||
patch("services.app_service.ModelManager.for_tenant") as mock_model_manager,
|
||||
patch("services.account_service.FeatureService") as mock_account_feature_service,
|
||||
):
|
||||
mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False
|
||||
mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None
|
||||
mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None
|
||||
mock_account_feature_service.get_system_features.return_value.is_allow_register = True
|
||||
mock_model_instance = mock_model_manager.return_value
|
||||
mock_model_instance.get_default_model_instance.return_value = None
|
||||
mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo")
|
||||
yield {
|
||||
"feature_service": mock_feature_service,
|
||||
"enterprise_service": mock_enterprise_service,
|
||||
"model_manager": mock_model_manager,
|
||||
"account_feature_service": mock_account_feature_service,
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ops_trace_manager(self):
|
||||
with patch("services.ops_service.OpsTraceManager") as mock:
|
||||
yield mock
|
||||
|
||||
def _create_app(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
fake = Faker()
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(
|
||||
tenant.id,
|
||||
{
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🤖",
|
||||
"icon_background": "#FF6B6B",
|
||||
},
|
||||
account,
|
||||
)
|
||||
return app, account
|
||||
|
||||
_SENTINEL = object()
|
||||
|
||||
def _insert_trace_config(
|
||||
self,
|
||||
db_session: Session,
|
||||
app_id: str,
|
||||
provider: str,
|
||||
tracing_config: dict | None | object = _SENTINEL,
|
||||
) -> TraceAppConfig:
|
||||
trace_config = TraceAppConfig(
|
||||
app_id=app_id,
|
||||
tracing_provider=provider,
|
||||
tracing_config=tracing_config if tracing_config is not self._SENTINEL else {"some": "config"},
|
||||
)
|
||||
db_session.add(trace_config)
|
||||
db_session.commit()
|
||||
return trace_config
|
||||
|
||||
# ── get_tracing_app_config ─────────────────────────────────────────
|
||||
|
||||
def test_get_tracing_app_config_no_config(self, db_session_with_containers: Session, mock_ops_trace_manager):
|
||||
result = OpsService.get_tracing_app_config(str(uuid.uuid4()), "arize")
|
||||
assert result is None
|
||||
|
||||
def test_get_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager):
|
||||
fake_app_id = str(uuid.uuid4())
|
||||
self._insert_trace_config(db_session_with_containers, fake_app_id, "arize")
|
||||
result = OpsService.get_tracing_app_config(fake_app_id, "arize")
|
||||
assert result is None
|
||||
|
||||
def test_get_tracing_app_config_none_config(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies, mock_ops_trace_manager
|
||||
):
|
||||
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||
self._insert_trace_config(db_session_with_containers, app.id, "arize", tracing_config=None)
|
||||
|
||||
with pytest.raises(ValueError, match="Tracing config cannot be None."):
|
||||
OpsService.get_tracing_app_config(app.id, "arize")
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("provider", "default_url"),
|
||||
[
|
||||
("arize", "https://app.arize.com/"),
|
||||
("phoenix", "https://app.phoenix.arize.com/projects/"),
|
||||
("langsmith", "https://smith.langchain.com/"),
|
||||
("opik", "https://www.comet.com/opik/"),
|
||||
("weave", "https://wandb.ai/"),
|
||||
("aliyun", "https://arms.console.aliyun.com/"),
|
||||
("tencent", "https://console.cloud.tencent.com/apm"),
|
||||
("mlflow", "http://localhost:5000/"),
|
||||
("databricks", "https://www.databricks.com/"),
|
||||
],
|
||||
)
|
||||
def test_get_tracing_app_config_providers_exception(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies, provider, default_url
|
||||
):
|
||||
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||
mock_otm.decrypt_tracing_config.return_value = {}
|
||||
mock_otm.obfuscated_decrypt_token.return_value = {}
|
||||
mock_otm.get_trace_config_project_url.side_effect = Exception("error")
|
||||
mock_otm.get_trace_config_project_key.side_effect = Exception("error")
|
||||
|
||||
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||
self._insert_trace_config(db_session_with_containers, app.id, provider)
|
||||
|
||||
result = OpsService.get_tracing_app_config(app.id, provider)
|
||||
|
||||
assert result is not None
|
||||
assert result["tracing_config"]["project_url"] == default_url
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider",
|
||||
["arize", "phoenix", "langsmith", "opik", "weave", "aliyun", "tencent", "mlflow", "databricks"],
|
||||
)
|
||||
def test_get_tracing_app_config_providers_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies, provider
|
||||
):
|
||||
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||
mock_otm.decrypt_tracing_config.return_value = {}
|
||||
mock_otm.obfuscated_decrypt_token.return_value = {"project_url": "success_url"}
|
||||
mock_otm.get_trace_config_project_url.return_value = "success_url"
|
||||
|
||||
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||
self._insert_trace_config(db_session_with_containers, app.id, provider)
|
||||
|
||||
result = OpsService.get_tracing_app_config(app.id, provider)
|
||||
|
||||
assert result is not None
|
||||
assert result["tracing_config"]["project_url"] == "success_url"
|
||||
|
||||
def test_get_tracing_app_config_langfuse_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||
mock_otm.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"}
|
||||
mock_otm.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"}
|
||||
mock_otm.get_trace_config_project_key.return_value = "key"
|
||||
|
||||
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||
self._insert_trace_config(db_session_with_containers, app.id, "langfuse")
|
||||
|
||||
result = OpsService.get_tracing_app_config(app.id, "langfuse")
|
||||
|
||||
assert result is not None
|
||||
assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/project/key"
|
||||
|
||||
def test_get_tracing_app_config_langfuse_exception(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||
mock_otm.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"}
|
||||
mock_otm.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"}
|
||||
mock_otm.get_trace_config_project_key.side_effect = Exception("error")
|
||||
|
||||
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||
self._insert_trace_config(db_session_with_containers, app.id, "langfuse")
|
||||
|
||||
result = OpsService.get_tracing_app_config(app.id, "langfuse")
|
||||
|
||||
assert result is not None
|
||||
assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/"
|
||||
|
||||
# ── create_tracing_app_config ──────────────────────────────────────
|
||||
|
||||
def test_create_tracing_app_config_invalid_provider(self, db_session_with_containers: Session):
|
||||
result = OpsService.create_tracing_app_config(str(uuid.uuid4()), "invalid_provider", {})
|
||||
assert result == {"error": "Invalid tracing provider: invalid_provider"}
|
||||
|
||||
def test_create_tracing_app_config_invalid_credentials(
|
||||
self, db_session_with_containers: Session, mock_ops_trace_manager
|
||||
):
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = False
|
||||
result = OpsService.create_tracing_app_config(
|
||||
str(uuid.uuid4()), TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}
|
||||
)
|
||||
assert result == {"error": "Invalid Credentials"}
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("provider", "config"),
|
||||
[
|
||||
(TracingProviderEnum.ARIZE, {}),
|
||||
(TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}),
|
||||
(TracingProviderEnum.LANGSMITH, {"api_key": "k", "project": "p"}),
|
||||
(TracingProviderEnum.ALIYUN, {"license_key": "k", "endpoint": "https://aliyun.com"}),
|
||||
],
|
||||
)
|
||||
def test_create_tracing_app_config_project_url_exception(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies, provider, config
|
||||
):
|
||||
# Existing config causes the service to return None before reaching the DB insert
|
||||
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||
mock_otm.check_trace_config_is_effective.return_value = True
|
||||
mock_otm.get_trace_config_project_url.side_effect = Exception("error")
|
||||
mock_otm.get_trace_config_project_key.side_effect = Exception("error")
|
||||
|
||||
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||
self._insert_trace_config(db_session_with_containers, app.id, str(provider))
|
||||
|
||||
result = OpsService.create_tracing_app_config(app.id, provider, config)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_create_tracing_app_config_langfuse_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||
mock_otm.check_trace_config_is_effective.return_value = True
|
||||
mock_otm.get_trace_config_project_key.return_value = "key"
|
||||
mock_otm.encrypt_tracing_config.return_value = {}
|
||||
|
||||
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||
result = OpsService.create_tracing_app_config(
|
||||
app.id,
|
||||
TracingProviderEnum.LANGFUSE,
|
||||
{"public_key": "p", "secret_key": "s", "host": "https://api.langfuse.com"},
|
||||
)
|
||||
|
||||
assert result == {"result": "success"}
|
||||
|
||||
def test_create_tracing_app_config_already_exists(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||
mock_otm.check_trace_config_is_effective.return_value = True
|
||||
|
||||
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||
self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE))
|
||||
|
||||
result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {})
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_create_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager):
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
result = OpsService.create_tracing_app_config(str(uuid.uuid4()), TracingProviderEnum.ARIZE, {})
|
||||
assert result is None
|
||||
|
||||
def test_create_tracing_app_config_with_empty_other_keys(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
# "project" is in other_keys for Arize; providing "" triggers default substitution
|
||||
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||
mock_otm.check_trace_config_is_effective.return_value = True
|
||||
mock_otm.get_trace_config_project_url.side_effect = Exception("no url")
|
||||
mock_otm.encrypt_tracing_config.return_value = {}
|
||||
|
||||
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||
result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {"project": ""})
|
||||
|
||||
assert result == {"result": "success"}
|
||||
|
||||
def test_create_tracing_app_config_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||
mock_otm.check_trace_config_is_effective.return_value = True
|
||||
mock_otm.get_trace_config_project_url.return_value = "http://project_url"
|
||||
mock_otm.encrypt_tracing_config.return_value = {"encrypted": "config"}
|
||||
|
||||
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||
result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {})
|
||||
|
||||
assert result == {"result": "success"}
|
||||
|
||||
# ── update_tracing_app_config ──────────────────────────────────────
|
||||
|
||||
def test_update_tracing_app_config_invalid_provider(self, db_session_with_containers: Session):
|
||||
with pytest.raises(ValueError, match="Invalid tracing provider: invalid_provider"):
|
||||
OpsService.update_tracing_app_config(str(uuid.uuid4()), "invalid_provider", {})
|
||||
|
||||
def test_update_tracing_app_config_no_config(self, db_session_with_containers: Session, mock_ops_trace_manager):
|
||||
result = OpsService.update_tracing_app_config(str(uuid.uuid4()), TracingProviderEnum.ARIZE, {})
|
||||
assert result is None
|
||||
|
||||
def test_update_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager):
|
||||
fake_app_id = str(uuid.uuid4())
|
||||
self._insert_trace_config(db_session_with_containers, fake_app_id, str(TracingProviderEnum.ARIZE))
|
||||
mock_ops_trace_manager.encrypt_tracing_config.return_value = {}
|
||||
result = OpsService.update_tracing_app_config(fake_app_id, TracingProviderEnum.ARIZE, {})
|
||||
assert result is None
|
||||
|
||||
def test_update_tracing_app_config_invalid_credentials(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||
mock_otm.encrypt_tracing_config.return_value = {}
|
||||
mock_otm.decrypt_tracing_config.return_value = {}
|
||||
mock_otm.check_trace_config_is_effective.return_value = False
|
||||
|
||||
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||
self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE))
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid Credentials"):
|
||||
OpsService.update_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {})
|
||||
|
||||
def test_update_tracing_app_config_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
with patch("services.ops_service.OpsTraceManager") as mock_otm:
|
||||
mock_otm.encrypt_tracing_config.return_value = {"updated": "config"}
|
||||
mock_otm.decrypt_tracing_config.return_value = {}
|
||||
mock_otm.check_trace_config_is_effective.return_value = True
|
||||
|
||||
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||
self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE))
|
||||
|
||||
result = OpsService.update_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {})
|
||||
|
||||
assert result is not None
|
||||
assert result["app_id"] == app.id
|
||||
|
||||
# ── delete_tracing_app_config ──────────────────────────────────────
|
||||
|
||||
def test_delete_tracing_app_config_no_config(self, db_session_with_containers: Session):
|
||||
result = OpsService.delete_tracing_app_config(str(uuid.uuid4()), "arize")
|
||||
assert result is None
|
||||
|
||||
def test_delete_tracing_app_config_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies)
|
||||
self._insert_trace_config(db_session_with_containers, app.id, "arize")
|
||||
|
||||
result = OpsService.delete_tracing_app_config(app.id, "arize")
|
||||
|
||||
assert result is True
|
||||
remaining = db_session_with_containers.scalar(
|
||||
select(TraceAppConfig)
|
||||
.where(TraceAppConfig.app_id == app.id, TraceAppConfig.tracing_provider == "arize")
|
||||
.limit(1)
|
||||
)
|
||||
assert remaining is None
|
||||
@ -233,11 +233,10 @@ class TestWebAppAuthService:
|
||||
assert result.status == AccountStatus.ACTIVE
|
||||
|
||||
# Verify database state
|
||||
|
||||
db_session_with_containers.refresh(result)
|
||||
assert result.id is not None
|
||||
assert result.password is not None
|
||||
assert result.password_salt is not None
|
||||
refreshed = db_session_with_containers.get(Account, result.id)
|
||||
assert refreshed is not None
|
||||
assert refreshed.password is not None
|
||||
assert refreshed.password_salt is not None
|
||||
|
||||
def test_authenticate_account_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
@ -414,9 +413,8 @@ class TestWebAppAuthService:
|
||||
assert result.status == AccountStatus.ACTIVE
|
||||
|
||||
# Verify database state
|
||||
|
||||
db_session_with_containers.refresh(result)
|
||||
assert result.id is not None
|
||||
refreshed = db_session_with_containers.get(Account, result.id)
|
||||
assert refreshed is not None
|
||||
|
||||
def test_get_user_through_email_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
|
||||
@ -233,15 +233,20 @@ class TestCheckEmailUnique:
|
||||
|
||||
|
||||
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
|
||||
session = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
first = MagicMock()
|
||||
first.scalar_one_or_none.return_value = None
|
||||
second = MagicMock()
|
||||
expected_account = MagicMock()
|
||||
second.scalar_one_or_none.return_value = expected_account
|
||||
session.execute.side_effect = [first, second]
|
||||
mock_session.execute.side_effect = [first, second]
|
||||
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=session)
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with patch("services.account_service.session_factory", mock_factory):
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com")
|
||||
|
||||
assert result is expected_account
|
||||
assert session.execute.call_count == 2
|
||||
assert mock_session.execute.call_count == 2
|
||||
|
||||
@ -39,7 +39,7 @@ class TestAppGenerateHandler:
|
||||
"root_node_id": None,
|
||||
}
|
||||
|
||||
arguments = handler._extract_arguments(AppGenerateService.generate, (), kwargs)
|
||||
arguments = handler._extract_arguments(AppGenerateService.generate, **kwargs)
|
||||
|
||||
assert arguments is not None, "Failed to extract arguments from AppGenerateService.generate"
|
||||
assert "app_model" in arguments, "Handler uses app_model but parameter is missing"
|
||||
@ -70,14 +70,11 @@ class TestAppGenerateHandler:
|
||||
handler.wrapper(
|
||||
tracer,
|
||||
dummy_func,
|
||||
(),
|
||||
{
|
||||
"app_model": mock_app_model,
|
||||
"user": mock_account_user,
|
||||
"args": {"workflow_id": test_workflow_id},
|
||||
"invoke_from": InvokeFrom.DEBUGGER,
|
||||
"streaming": False,
|
||||
},
|
||||
app_model=mock_app_model,
|
||||
user=mock_account_user,
|
||||
args={"workflow_id": test_workflow_id},
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
spans = memory_span_exporter.get_finished_spans()
|
||||
|
||||
@ -63,7 +63,7 @@ class TestWorkflowAppRunnerHandler:
|
||||
def runner_run(self):
|
||||
return "result"
|
||||
|
||||
handler.wrapper(tracer, runner_run, (mock_workflow_runner,), {})
|
||||
handler.wrapper(tracer, runner_run, mock_workflow_runner)
|
||||
|
||||
spans = memory_span_exporter.get_finished_spans()
|
||||
assert len(spans) == 1
|
||||
|
||||
@ -28,7 +28,7 @@ class TestSpanHandlerExtractArguments:
|
||||
|
||||
args = (1, 2, 3)
|
||||
kwargs = {}
|
||||
result = handler._extract_arguments(func, args, kwargs)
|
||||
result = handler._extract_arguments(func, *args, **kwargs)
|
||||
|
||||
assert result is not None
|
||||
assert result["a"] == 1
|
||||
@ -44,7 +44,7 @@ class TestSpanHandlerExtractArguments:
|
||||
|
||||
args = ()
|
||||
kwargs = {"a": 1, "b": 2, "c": 3}
|
||||
result = handler._extract_arguments(func, args, kwargs)
|
||||
result = handler._extract_arguments(func, *args, **kwargs)
|
||||
|
||||
assert result is not None
|
||||
assert result["a"] == 1
|
||||
@ -60,7 +60,7 @@ class TestSpanHandlerExtractArguments:
|
||||
|
||||
args = (1,)
|
||||
kwargs = {"b": 2, "c": 3}
|
||||
result = handler._extract_arguments(func, args, kwargs)
|
||||
result = handler._extract_arguments(func, *args, **kwargs)
|
||||
|
||||
assert result is not None
|
||||
assert result["a"] == 1
|
||||
@ -76,7 +76,7 @@ class TestSpanHandlerExtractArguments:
|
||||
|
||||
args = (1,)
|
||||
kwargs = {}
|
||||
result = handler._extract_arguments(func, args, kwargs)
|
||||
result = handler._extract_arguments(func, *args, **kwargs)
|
||||
|
||||
assert result is not None
|
||||
assert result["a"] == 1
|
||||
@ -94,7 +94,7 @@ class TestSpanHandlerExtractArguments:
|
||||
instance = MyClass()
|
||||
args = (1, 2)
|
||||
kwargs = {}
|
||||
result = handler._extract_arguments(instance.method, args, kwargs)
|
||||
result = handler._extract_arguments(instance.method, *args, **kwargs)
|
||||
|
||||
assert result is not None
|
||||
assert result["a"] == 1
|
||||
@ -109,7 +109,7 @@ class TestSpanHandlerExtractArguments:
|
||||
|
||||
args = (1,)
|
||||
kwargs = {}
|
||||
result = handler._extract_arguments(func, args, kwargs)
|
||||
result = handler._extract_arguments(func, *args, **kwargs)
|
||||
|
||||
assert result is None
|
||||
|
||||
@ -122,11 +122,11 @@ class TestSpanHandlerExtractArguments:
|
||||
|
||||
assert func not in handler._signature_cache
|
||||
|
||||
handler._extract_arguments(func, (1, 2), {})
|
||||
handler._extract_arguments(func, 1, 2)
|
||||
assert func in handler._signature_cache
|
||||
|
||||
cached_sig = handler._signature_cache[func]
|
||||
handler._extract_arguments(func, (3, 4), {})
|
||||
handler._extract_arguments(func, 3, 4)
|
||||
assert handler._signature_cache[func] is cached_sig
|
||||
|
||||
|
||||
@ -142,7 +142,7 @@ class TestSpanHandlerWrapper:
|
||||
def test_func():
|
||||
return "result"
|
||||
|
||||
result = handler.wrapper(tracer, test_func, (), {})
|
||||
result = handler.wrapper(tracer, test_func)
|
||||
|
||||
assert result == "result"
|
||||
spans = memory_span_exporter.get_finished_spans()
|
||||
@ -159,7 +159,7 @@ class TestSpanHandlerWrapper:
|
||||
def test_func():
|
||||
return "result"
|
||||
|
||||
handler.wrapper(tracer, test_func, (), {})
|
||||
handler.wrapper(tracer, test_func)
|
||||
|
||||
spans = memory_span_exporter.get_finished_spans()
|
||||
assert len(spans) == 1
|
||||
@ -174,7 +174,7 @@ class TestSpanHandlerWrapper:
|
||||
def test_func():
|
||||
return "result"
|
||||
|
||||
handler.wrapper(tracer, test_func, (), {})
|
||||
handler.wrapper(tracer, test_func)
|
||||
|
||||
spans = memory_span_exporter.get_finished_spans()
|
||||
assert len(spans) == 1
|
||||
@ -190,7 +190,7 @@ class TestSpanHandlerWrapper:
|
||||
raise ValueError("test error")
|
||||
|
||||
with pytest.raises(ValueError, match="test error"):
|
||||
handler.wrapper(tracer, test_func, (), {})
|
||||
handler.wrapper(tracer, test_func)
|
||||
|
||||
spans = memory_span_exporter.get_finished_spans()
|
||||
assert len(spans) == 1
|
||||
@ -208,7 +208,7 @@ class TestSpanHandlerWrapper:
|
||||
raise ValueError("test error")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
handler.wrapper(tracer, test_func, (), {})
|
||||
handler.wrapper(tracer, test_func)
|
||||
|
||||
spans = memory_span_exporter.get_finished_spans()
|
||||
assert len(spans) == 1
|
||||
@ -225,7 +225,7 @@ class TestSpanHandlerWrapper:
|
||||
raise ValueError("test error")
|
||||
|
||||
with pytest.raises(ValueError, match="test error"):
|
||||
handler.wrapper(tracer, test_func, (), {})
|
||||
handler.wrapper(tracer, test_func)
|
||||
|
||||
@patch("extensions.otel.decorators.base.dify_config.ENABLE_OTEL", True)
|
||||
def test_wrapper_passes_arguments_correctly(self, tracer_provider_with_memory_exporter, memory_span_exporter):
|
||||
@ -236,7 +236,7 @@ class TestSpanHandlerWrapper:
|
||||
def test_func(a, b, c=10):
|
||||
return a + b + c
|
||||
|
||||
result = handler.wrapper(tracer, test_func, (1, 2), {"c": 3})
|
||||
result = handler.wrapper(tracer, test_func, 1, 2, c=3)
|
||||
|
||||
assert result == 6
|
||||
|
||||
@ -249,7 +249,7 @@ class TestSpanHandlerWrapper:
|
||||
def my_function(x):
|
||||
return x * 2
|
||||
|
||||
result = handler.wrapper(tracer, my_function, (5,), {})
|
||||
result = handler.wrapper(tracer, my_function, 5)
|
||||
|
||||
assert result == 10
|
||||
spans = memory_span_exporter.get_finished_spans()
|
||||
|
||||
138
api/tests/unit_tests/libs/test_pyrefly_type_coverage.py
Normal file
138
api/tests/unit_tests/libs/test_pyrefly_type_coverage.py
Normal file
@ -0,0 +1,138 @@
|
||||
import json
|
||||
|
||||
from libs.pyrefly_type_coverage import (
|
||||
CoverageSummary,
|
||||
format_comparison_markdown,
|
||||
format_summary_markdown,
|
||||
parse_summary,
|
||||
)
|
||||
|
||||
|
||||
def _make_report(summary: dict) -> str:
|
||||
return json.dumps({"module_reports": [], "summary": summary})
|
||||
|
||||
|
||||
_SAMPLE_SUMMARY: dict = {
|
||||
"n_modules": 100,
|
||||
"n_typable": 1000,
|
||||
"n_typed": 400,
|
||||
"n_any": 50,
|
||||
"n_untyped": 550,
|
||||
"coverage": 45.0,
|
||||
"strict_coverage": 40.0,
|
||||
"n_functions": 200,
|
||||
"n_methods": 300,
|
||||
"n_function_params": 150,
|
||||
"n_method_params": 250,
|
||||
"n_classes": 80,
|
||||
"n_attrs": 40,
|
||||
"n_properties": 20,
|
||||
"n_type_ignores": 10,
|
||||
}
|
||||
|
||||
|
||||
def _make_summary(
|
||||
*,
|
||||
n_modules: int = 100,
|
||||
n_typable: int = 1000,
|
||||
n_typed: int = 400,
|
||||
n_any: int = 50,
|
||||
n_untyped: int = 550,
|
||||
coverage: float = 45.0,
|
||||
strict_coverage: float = 40.0,
|
||||
) -> CoverageSummary:
|
||||
return {
|
||||
"n_modules": n_modules,
|
||||
"n_typable": n_typable,
|
||||
"n_typed": n_typed,
|
||||
"n_any": n_any,
|
||||
"n_untyped": n_untyped,
|
||||
"coverage": coverage,
|
||||
"strict_coverage": strict_coverage,
|
||||
}
|
||||
|
||||
|
||||
def test_parse_summary_extracts_fields() -> None:
|
||||
report_json = _make_report(_SAMPLE_SUMMARY)
|
||||
|
||||
result = parse_summary(report_json)
|
||||
|
||||
assert result["n_modules"] == 100
|
||||
assert result["n_typable"] == 1000
|
||||
assert result["n_typed"] == 400
|
||||
assert result["n_any"] == 50
|
||||
assert result["n_untyped"] == 550
|
||||
assert result["coverage"] == 45.0
|
||||
assert result["strict_coverage"] == 40.0
|
||||
|
||||
|
||||
def test_parse_summary_handles_empty_input() -> None:
|
||||
assert parse_summary("")["n_modules"] == 0
|
||||
assert parse_summary(" ")["n_modules"] == 0
|
||||
|
||||
|
||||
def test_parse_summary_handles_invalid_json() -> None:
|
||||
assert parse_summary("not json")["n_modules"] == 0
|
||||
|
||||
|
||||
def test_parse_summary_handles_missing_summary_key() -> None:
|
||||
assert parse_summary(json.dumps({"other": 1}))["n_modules"] == 0
|
||||
|
||||
|
||||
def test_parse_summary_handles_incomplete_summary() -> None:
|
||||
partial = json.dumps({"summary": {"n_modules": 5}})
|
||||
assert parse_summary(partial)["n_modules"] == 0
|
||||
|
||||
|
||||
def test_format_summary_markdown_contains_key_metrics() -> None:
|
||||
summary = _make_summary()
|
||||
|
||||
result = format_summary_markdown(summary)
|
||||
|
||||
assert "**Type coverage**" in result
|
||||
assert "45.00%" in result
|
||||
assert "40.00%" in result
|
||||
assert "| Modules | 100 |" in result
|
||||
|
||||
|
||||
def test_format_comparison_markdown_shows_positive_delta() -> None:
|
||||
base = _make_summary()
|
||||
pr = _make_summary(
|
||||
n_modules=101,
|
||||
n_typable=1010,
|
||||
n_typed=420,
|
||||
n_untyped=540,
|
||||
coverage=46.53,
|
||||
strict_coverage=41.58,
|
||||
)
|
||||
|
||||
result = format_comparison_markdown(base, pr)
|
||||
|
||||
assert "| Base | PR | Delta |" in result
|
||||
assert "+1.53%" in result
|
||||
assert "+1.58%" in result
|
||||
assert "+20" in result
|
||||
|
||||
|
||||
def test_format_comparison_markdown_shows_negative_delta() -> None:
|
||||
base = _make_summary()
|
||||
pr = _make_summary(
|
||||
n_typed=390,
|
||||
n_any=60,
|
||||
coverage=44.0,
|
||||
strict_coverage=39.0,
|
||||
)
|
||||
|
||||
result = format_comparison_markdown(base, pr)
|
||||
|
||||
assert "-1.00%" in result
|
||||
assert "-10" in result
|
||||
|
||||
|
||||
def test_format_comparison_markdown_shows_zero_delta() -> None:
|
||||
summary = _make_summary()
|
||||
|
||||
result = format_comparison_markdown(summary, summary)
|
||||
|
||||
assert "0.00%" in result
|
||||
assert "| 0 |" in result
|
||||
@ -6,23 +6,25 @@ MODULE = "services.plugin.plugin_permission_service"
|
||||
|
||||
|
||||
def _patched_session():
|
||||
"""Patch sessionmaker(bind=db.engine).begin() to return a mock session as context manager."""
|
||||
"""Patch session_factory.create_session() to return a mock session as context manager."""
|
||||
session = MagicMock()
|
||||
mock_sessionmaker = MagicMock()
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__ = MagicMock(return_value=session)
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
patcher = patch(f"{MODULE}.sessionmaker", mock_sessionmaker)
|
||||
db_patcher = patch(f"{MODULE}.db")
|
||||
return patcher, db_patcher, session
|
||||
session.__enter__ = MagicMock(return_value=session)
|
||||
session.__exit__ = MagicMock(return_value=False)
|
||||
session.begin.return_value.__enter__ = MagicMock(return_value=session)
|
||||
session.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.create_session.return_value = session
|
||||
patcher = patch(f"{MODULE}.session_factory", mock_factory)
|
||||
return patcher, session
|
||||
|
||||
|
||||
class TestGetPermission:
|
||||
def test_returns_permission_when_found(self):
|
||||
p1, p2, session = _patched_session()
|
||||
p1, session = _patched_session()
|
||||
permission = MagicMock()
|
||||
session.scalar.return_value = permission
|
||||
|
||||
with p1, p2:
|
||||
with p1:
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
|
||||
result = PluginPermissionService.get_permission("t1")
|
||||
@ -30,10 +32,10 @@ class TestGetPermission:
|
||||
assert result is permission
|
||||
|
||||
def test_returns_none_when_not_found(self):
|
||||
p1, p2, session = _patched_session()
|
||||
p1, session = _patched_session()
|
||||
session.scalar.return_value = None
|
||||
|
||||
with p1, p2:
|
||||
with p1:
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
|
||||
result = PluginPermissionService.get_permission("t1")
|
||||
@ -43,10 +45,10 @@ class TestGetPermission:
|
||||
|
||||
class TestChangePermission:
|
||||
def test_creates_new_permission_when_not_exists(self):
|
||||
p1, p2, session = _patched_session()
|
||||
p1, session = _patched_session()
|
||||
session.scalar.return_value = None
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginPermission") as perm_cls:
|
||||
with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginPermission") as perm_cls:
|
||||
perm_cls.return_value = MagicMock()
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
|
||||
@ -54,20 +56,24 @@ class TestChangePermission:
|
||||
"t1", TenantPluginPermission.InstallPermission.EVERYONE, TenantPluginPermission.DebugPermission.EVERYONE
|
||||
)
|
||||
|
||||
assert result is True
|
||||
session.begin.assert_called_once()
|
||||
session.add.assert_called_once()
|
||||
|
||||
def test_updates_existing_permission(self):
|
||||
p1, p2, session = _patched_session()
|
||||
p1, session = _patched_session()
|
||||
existing = MagicMock()
|
||||
session.scalar.return_value = existing
|
||||
|
||||
with p1, p2:
|
||||
with p1:
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
|
||||
result = PluginPermissionService.change_permission(
|
||||
"t1", TenantPluginPermission.InstallPermission.ADMINS, TenantPluginPermission.DebugPermission.ADMINS
|
||||
)
|
||||
|
||||
assert result is True
|
||||
session.begin.assert_called_once()
|
||||
assert existing.install_permission == TenantPluginPermission.InstallPermission.ADMINS
|
||||
assert existing.debug_permission == TenantPluginPermission.DebugPermission.ADMINS
|
||||
session.add.assert_not_called()
|
||||
|
||||
@ -1427,16 +1427,7 @@ class TestRegisterService:
|
||||
mock_tenant.name = "Test Workspace"
|
||||
mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter")
|
||||
|
||||
# Mock database queries - need to mock the sessionmaker query
|
||||
mock_session = MagicMock()
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None # No existing account
|
||||
|
||||
mock_sessionmaker = MagicMock()
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
patch("services.account_service.sessionmaker", mock_sessionmaker),
|
||||
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
||||
):
|
||||
mock_lookup.return_value = None
|
||||
@ -1475,7 +1466,7 @@ class TestRegisterService:
|
||||
status=AccountStatus.PENDING,
|
||||
is_setup=True,
|
||||
)
|
||||
mock_lookup.assert_called_once_with("newuser@example.com", session=mock_session)
|
||||
mock_lookup.assert_called_once_with("newuser@example.com")
|
||||
|
||||
def test_invite_new_member_normalizes_new_account_email(
|
||||
self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies
|
||||
@ -1486,13 +1477,7 @@ class TestRegisterService:
|
||||
mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter")
|
||||
mixed_email = "Invitee@Example.com"
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_sessionmaker = MagicMock()
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
patch("services.account_service.sessionmaker", mock_sessionmaker),
|
||||
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
||||
):
|
||||
mock_lookup.return_value = None
|
||||
@ -1525,7 +1510,7 @@ class TestRegisterService:
|
||||
status=AccountStatus.PENDING,
|
||||
is_setup=True,
|
||||
)
|
||||
mock_lookup.assert_called_once_with(mixed_email, session=mock_session)
|
||||
mock_lookup.assert_called_once_with(mixed_email)
|
||||
mock_check_permission.assert_called_once_with(mock_tenant, mock_inviter, None, "add")
|
||||
mock_create_member.assert_called_once_with(mock_tenant, mock_new_account, "normal")
|
||||
mock_switch_tenant.assert_called_once_with(mock_new_account, mock_tenant.id)
|
||||
@ -1545,16 +1530,7 @@ class TestRegisterService:
|
||||
account_id="existing-user-456", email="existing@example.com", status="pending"
|
||||
)
|
||||
|
||||
# Mock database queries - need to mock the sessionmaker query
|
||||
mock_session = MagicMock()
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_existing_account
|
||||
|
||||
mock_sessionmaker = MagicMock()
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
patch("services.account_service.sessionmaker", mock_sessionmaker),
|
||||
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
||||
):
|
||||
mock_lookup.return_value = mock_existing_account
|
||||
@ -1584,7 +1560,7 @@ class TestRegisterService:
|
||||
mock_create_member.assert_called_once_with(mock_tenant, mock_existing_account, "normal")
|
||||
mock_generate_token.assert_called_once_with(mock_tenant, mock_existing_account)
|
||||
mock_task_dependencies.delay.assert_called_once()
|
||||
mock_lookup.assert_called_once_with("existing@example.com", session=mock_session)
|
||||
mock_lookup.assert_called_once_with("existing@example.com")
|
||||
|
||||
def test_invite_new_member_already_in_tenant(self, mock_db_dependencies, mock_redis_dependencies):
|
||||
"""Test inviting a member who is already in the tenant."""
|
||||
|
||||
@ -1,392 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.ops.entities.config_entity import TracingProviderEnum
|
||||
from models.model import App, TraceAppConfig
|
||||
from services.ops_service import OpsService
|
||||
|
||||
|
||||
class TestOpsService:
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_get_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
result = OpsService.get_tracing_app_config("app_id", "arize")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_get_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
trace_config = MagicMock(spec=TraceAppConfig)
|
||||
mock_db.session.scalar.return_value = trace_config
|
||||
mock_db.session.get.return_value = None
|
||||
|
||||
# Act
|
||||
result = OpsService.get_tracing_app_config("app_id", "arize")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_get_tracing_app_config_none_config(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
trace_config = MagicMock(spec=TraceAppConfig)
|
||||
trace_config.tracing_config = None
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.scalar.return_value = trace_config
|
||||
mock_db.session.get.return_value = app
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Tracing config cannot be None."):
|
||||
OpsService.get_tracing_app_config("app_id", "arize")
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
@pytest.mark.parametrize(
|
||||
("provider", "default_url"),
|
||||
[
|
||||
("arize", "https://app.arize.com/"),
|
||||
("phoenix", "https://app.phoenix.arize.com/projects/"),
|
||||
("langsmith", "https://smith.langchain.com/"),
|
||||
("opik", "https://www.comet.com/opik/"),
|
||||
("weave", "https://wandb.ai/"),
|
||||
("aliyun", "https://arms.console.aliyun.com/"),
|
||||
("tencent", "https://console.cloud.tencent.com/apm"),
|
||||
("mlflow", "http://localhost:5000/"),
|
||||
("databricks", "https://www.databricks.com/"),
|
||||
],
|
||||
)
|
||||
def test_get_tracing_app_config_providers_exception(self, mock_ops_trace_manager, mock_db, provider, default_url):
|
||||
# Arrange
|
||||
trace_config = MagicMock(spec=TraceAppConfig)
|
||||
trace_config.tracing_config = {"some": "config"}
|
||||
trace_config.to_dict.return_value = {"tracing_config": {"project_url": default_url}}
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.scalar.return_value = trace_config
|
||||
mock_db.session.get.return_value = app
|
||||
|
||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
|
||||
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {}
|
||||
mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error")
|
||||
mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error")
|
||||
|
||||
# Act
|
||||
result = OpsService.get_tracing_app_config("app_id", provider)
|
||||
|
||||
# Assert
|
||||
assert result["tracing_config"]["project_url"] == default_url
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
@pytest.mark.parametrize(
|
||||
"provider", ["arize", "phoenix", "langsmith", "opik", "weave", "aliyun", "tencent", "mlflow", "databricks"]
|
||||
)
|
||||
def test_get_tracing_app_config_providers_success(self, mock_ops_trace_manager, mock_db, provider):
|
||||
# Arrange
|
||||
trace_config = MagicMock(spec=TraceAppConfig)
|
||||
trace_config.tracing_config = {"some": "config"}
|
||||
trace_config.to_dict.return_value = {"tracing_config": {"project_url": "success_url"}}
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.scalar.return_value = trace_config
|
||||
mock_db.session.get.return_value = app
|
||||
|
||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
|
||||
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {}
|
||||
mock_ops_trace_manager.get_trace_config_project_url.return_value = "success_url"
|
||||
|
||||
# Act
|
||||
result = OpsService.get_tracing_app_config("app_id", provider)
|
||||
|
||||
# Assert
|
||||
assert result["tracing_config"]["project_url"] == "success_url"
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_get_tracing_app_config_langfuse_success(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
trace_config = MagicMock(spec=TraceAppConfig)
|
||||
trace_config.tracing_config = {"some": "config"}
|
||||
trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/project/key"}}
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.scalar.return_value = trace_config
|
||||
mock_db.session.get.return_value = app
|
||||
|
||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"}
|
||||
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"}
|
||||
mock_ops_trace_manager.get_trace_config_project_key.return_value = "key"
|
||||
|
||||
# Act
|
||||
result = OpsService.get_tracing_app_config("app_id", "langfuse")
|
||||
|
||||
# Assert
|
||||
assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/project/key"
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_get_tracing_app_config_langfuse_exception(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
trace_config = MagicMock(spec=TraceAppConfig)
|
||||
trace_config.tracing_config = {"some": "config"}
|
||||
trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/"}}
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.scalar.return_value = trace_config
|
||||
mock_db.session.get.return_value = app
|
||||
|
||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"}
|
||||
mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"}
|
||||
mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error")
|
||||
|
||||
# Act
|
||||
result = OpsService.get_tracing_app_config("app_id", "langfuse")
|
||||
|
||||
# Assert
|
||||
assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/"
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_create_tracing_app_config_invalid_provider(self, mock_ops_trace_manager, mock_db):
|
||||
# Act
|
||||
result = OpsService.create_tracing_app_config("app_id", "invalid_provider", {})
|
||||
|
||||
# Assert
|
||||
assert result == {"error": "Invalid tracing provider: invalid_provider"}
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_create_tracing_app_config_invalid_credentials(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.LANGFUSE
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = False
|
||||
|
||||
# Act
|
||||
result = OpsService.create_tracing_app_config("app_id", provider, {"public_key": "p", "secret_key": "s"})
|
||||
|
||||
# Assert
|
||||
assert result == {"error": "Invalid Credentials"}
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
@pytest.mark.parametrize(
|
||||
("provider", "config"),
|
||||
[
|
||||
(TracingProviderEnum.ARIZE, {}),
|
||||
(TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}),
|
||||
(TracingProviderEnum.LANGSMITH, {"api_key": "k", "project": "p"}),
|
||||
(TracingProviderEnum.ALIYUN, {"license_key": "k", "endpoint": "https://aliyun.com"}),
|
||||
],
|
||||
)
|
||||
def test_create_tracing_app_config_project_url_exception(self, mock_ops_trace_manager, mock_db, provider, config):
|
||||
# Arrange
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error")
|
||||
mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error")
|
||||
mock_db.session.scalar.return_value = MagicMock(spec=TraceAppConfig)
|
||||
|
||||
# Act
|
||||
result = OpsService.create_tracing_app_config("app_id", provider, config)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_create_tracing_app_config_langfuse_success(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.LANGFUSE
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
mock_ops_trace_manager.get_trace_config_project_key.return_value = "key"
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.scalar.return_value = None
|
||||
mock_db.session.get.return_value = app
|
||||
mock_ops_trace_manager.encrypt_tracing_config.return_value = {}
|
||||
|
||||
# Act
|
||||
result = OpsService.create_tracing_app_config(
|
||||
"app_id", provider, {"public_key": "p", "secret_key": "s", "host": "https://api.langfuse.com"}
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "success"}
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_create_tracing_app_config_already_exists(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
mock_db.session.scalar.return_value = MagicMock(spec=TraceAppConfig)
|
||||
|
||||
# Act
|
||||
result = OpsService.create_tracing_app_config("app_id", provider, {})
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_create_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
mock_db.session.scalar.return_value = None
|
||||
mock_db.session.get.return_value = None
|
||||
|
||||
# Act
|
||||
result = OpsService.create_tracing_app_config("app_id", provider, {})
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_create_tracing_app_config_with_empty_other_keys(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.scalar.return_value = None
|
||||
mock_db.session.get.return_value = app
|
||||
mock_ops_trace_manager.encrypt_tracing_config.return_value = {}
|
||||
|
||||
# Act
|
||||
# 'project' is in other_keys for Arize
|
||||
# provide an empty string for the project in the tracing_config
|
||||
# create_tracing_app_config will replace it with the default from the model
|
||||
result = OpsService.create_tracing_app_config("app_id", provider, {"project": ""})
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "success"}
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_create_tracing_app_config_success(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
mock_ops_trace_manager.get_trace_config_project_url.return_value = "http://project_url"
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.scalar.return_value = None
|
||||
mock_db.session.get.return_value = app
|
||||
mock_ops_trace_manager.encrypt_tracing_config.return_value = {"encrypted": "config"}
|
||||
|
||||
# Act
|
||||
result = OpsService.create_tracing_app_config("app_id", provider, {})
|
||||
|
||||
# Assert
|
||||
assert result == {"result": "success"}
|
||||
mock_db.session.add.assert_called()
|
||||
mock_db.session.commit.assert_called()
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_update_tracing_app_config_invalid_provider(self, mock_ops_trace_manager, mock_db):
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Invalid tracing provider: invalid_provider"):
|
||||
OpsService.update_tracing_app_config("app_id", "invalid_provider", {})
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_update_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
result = OpsService.update_tracing_app_config("app_id", provider, {})
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_update_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
current_config = MagicMock(spec=TraceAppConfig)
|
||||
mock_db.session.scalar.return_value = current_config
|
||||
mock_db.session.get.return_value = None
|
||||
|
||||
# Act
|
||||
result = OpsService.update_tracing_app_config("app_id", provider, {})
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_update_tracing_app_config_invalid_credentials(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
current_config = MagicMock(spec=TraceAppConfig)
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.scalar.return_value = current_config
|
||||
mock_db.session.get.return_value = app
|
||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = False
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Invalid Credentials"):
|
||||
OpsService.update_tracing_app_config("app_id", provider, {})
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
@patch("services.ops_service.OpsTraceManager")
|
||||
def test_update_tracing_app_config_success(self, mock_ops_trace_manager, mock_db):
|
||||
# Arrange
|
||||
provider = TracingProviderEnum.ARIZE
|
||||
current_config = MagicMock(spec=TraceAppConfig)
|
||||
current_config.to_dict.return_value = {"some": "data"}
|
||||
app = MagicMock(spec=App)
|
||||
app.tenant_id = "tenant_id"
|
||||
mock_db.session.scalar.return_value = current_config
|
||||
mock_db.session.get.return_value = app
|
||||
mock_ops_trace_manager.decrypt_tracing_config.return_value = {}
|
||||
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True
|
||||
|
||||
# Act
|
||||
result = OpsService.update_tracing_app_config("app_id", provider, {})
|
||||
|
||||
# Assert
|
||||
assert result == {"some": "data"}
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
def test_delete_tracing_app_config_no_config(self, mock_db):
|
||||
# Arrange
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
result = OpsService.delete_tracing_app_config("app_id", "arize")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.ops_service.db")
|
||||
def test_delete_tracing_app_config_success(self, mock_db):
|
||||
# Arrange
|
||||
trace_config = MagicMock(spec=TraceAppConfig)
|
||||
mock_db.session.scalar.return_value = trace_config
|
||||
|
||||
# Act
|
||||
result = OpsService.delete_tracing_app_config("app_id", "arize")
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_db.session.delete.assert_called_with(trace_config)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
@ -0,0 +1,126 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import type { LoopVariableMap, NodeTracing } from '@/types/workflow'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { BlockEnum } from '../../../types'
|
||||
import LoopResultPanel from '../loop-result-panel'
|
||||
|
||||
const mockCodeEditor = vi.hoisted(() => vi.fn())
|
||||
const mockTracingPanel = vi.hoisted(() => vi.fn())
|
||||
|
||||
vi.mock('@/app/components/workflow/nodes/_base/components/editor/code-editor', () => ({
|
||||
__esModule: true,
|
||||
default: (props: { title: ReactNode, value: unknown }) => {
|
||||
mockCodeEditor(props)
|
||||
return (
|
||||
<section data-testid="code-editor">
|
||||
<div>{props.title}</div>
|
||||
<div>{JSON.stringify(props.value)}</div>
|
||||
</section>
|
||||
)
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/workflow/run/tracing-panel', () => ({
|
||||
__esModule: true,
|
||||
default: (props: { list: NodeTracing[], className?: string }) => {
|
||||
mockTracingPanel(props)
|
||||
return <div data-testid="tracing-panel">{props.list.length}</div>
|
||||
},
|
||||
}))
|
||||
|
||||
const createNodeTracing = (id: string, executionMetadata?: NonNullable<NodeTracing['execution_metadata']>): NodeTracing => ({
|
||||
id,
|
||||
index: 0,
|
||||
predecessor_node_id: '',
|
||||
node_id: `node-${id}`,
|
||||
node_type: BlockEnum.Code,
|
||||
title: `Node ${id}`,
|
||||
inputs: {},
|
||||
inputs_truncated: false,
|
||||
process_data: {},
|
||||
process_data_truncated: false,
|
||||
outputs: {},
|
||||
outputs_truncated: false,
|
||||
status: 'succeeded',
|
||||
error: '',
|
||||
elapsed_time: 0,
|
||||
execution_metadata: executionMetadata,
|
||||
metadata: {
|
||||
iterator_length: 0,
|
||||
iterator_index: 0,
|
||||
loop_length: 0,
|
||||
loop_index: 0,
|
||||
},
|
||||
created_at: 0,
|
||||
created_by: {
|
||||
id: 'user-1',
|
||||
name: 'Tester',
|
||||
email: 'tester@example.com',
|
||||
},
|
||||
finished_at: 0,
|
||||
})
|
||||
|
||||
describe('LoopResultPanel', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
// Loop variables should be resolved by the actual run key, not the rendered row position.
|
||||
describe('Loop Variable Resolution', () => {
|
||||
it('should read loop variables by the actual loop index when rows are compacted', () => {
|
||||
const loopVariableMap: LoopVariableMap = {
|
||||
2: { item: 'alpha' },
|
||||
}
|
||||
|
||||
render(
|
||||
<LoopResultPanel
|
||||
list={[[
|
||||
createNodeTracing('loop-2-step-1', {
|
||||
total_tokens: 0,
|
||||
total_price: 0,
|
||||
currency: 'USD',
|
||||
loop_index: 2,
|
||||
}),
|
||||
]]}
|
||||
onBack={vi.fn()}
|
||||
loopVariableMap={loopVariableMap}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByText('workflow.singleRun.loop 1'))
|
||||
|
||||
expect(screen.getByTestId('code-editor')).toHaveTextContent('{"item":"alpha"}')
|
||||
expect(mockCodeEditor).toHaveBeenCalledWith(expect.objectContaining({
|
||||
value: loopVariableMap[2],
|
||||
}))
|
||||
})
|
||||
|
||||
it('should read loop variables by parallel run id when available', () => {
|
||||
const loopVariableMap: LoopVariableMap = {
|
||||
'parallel-1': { item: 'beta' },
|
||||
}
|
||||
|
||||
render(
|
||||
<LoopResultPanel
|
||||
list={[[
|
||||
createNodeTracing('parallel-step-1', {
|
||||
total_tokens: 0,
|
||||
total_price: 0,
|
||||
currency: 'USD',
|
||||
parallel_mode_run_id: 'parallel-1',
|
||||
}),
|
||||
]]}
|
||||
onBack={vi.fn()}
|
||||
loopVariableMap={loopVariableMap}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByText('workflow.singleRun.loop 1'))
|
||||
|
||||
expect(screen.getByTestId('code-editor')).toHaveTextContent('{"item":"beta"}')
|
||||
expect(mockCodeEditor).toHaveBeenCalledWith(expect.objectContaining({
|
||||
value: loopVariableMap['parallel-1'],
|
||||
}))
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -19,6 +19,18 @@ import { cn } from '@/utils/classnames'
|
||||
|
||||
const i18nPrefix = 'singleRun'
|
||||
|
||||
const getLoopRunKey = (loop: NodeTracing[], fallbackIndex: number) => {
|
||||
const executionMetadata = loop[0]?.execution_metadata
|
||||
|
||||
if (executionMetadata?.parallel_mode_run_id !== undefined)
|
||||
return executionMetadata.parallel_mode_run_id
|
||||
|
||||
if (executionMetadata?.loop_index !== undefined)
|
||||
return String(executionMetadata.loop_index)
|
||||
|
||||
return String(fallbackIndex)
|
||||
}
|
||||
|
||||
type Props = {
|
||||
list: NodeTracing[][]
|
||||
onBack: () => void
|
||||
@ -42,10 +54,8 @@ const LoopResultPanel: FC<Props> = ({
|
||||
}))
|
||||
}, [])
|
||||
|
||||
const countLoopDuration = (loop: NodeTracing[], loopDurationMap: LoopDurationMap): string => {
|
||||
const loopRunIndex = loop[0]?.execution_metadata?.loop_index as number
|
||||
const loopRunId = loop[0]?.execution_metadata?.parallel_mode_run_id
|
||||
const loopItem = loopDurationMap[loopRunId || loopRunIndex]
|
||||
const countLoopDuration = (loop: NodeTracing[], index: number, loopDurationMap: LoopDurationMap): string => {
|
||||
const loopItem = loopDurationMap[getLoopRunKey(loop, index)]
|
||||
const duration = loopItem
|
||||
return `${(duration && duration > 0.01) ? duration.toFixed(2) : 0.01}s`
|
||||
}
|
||||
@ -59,13 +69,13 @@ const LoopResultPanel: FC<Props> = ({
|
||||
return <RiErrorWarningLine className="h-4 w-4 text-text-destructive" />
|
||||
|
||||
if (isRunning)
|
||||
return <RiLoader2Line className="h-3.5 w-3.5 animate-spin text-primary-600" />
|
||||
return <RiLoader2Line className="text-primary-600 h-3.5 w-3.5 animate-spin" />
|
||||
|
||||
return (
|
||||
<>
|
||||
{hasDurationMap && (
|
||||
<div className="system-xs-regular text-text-tertiary">
|
||||
{countLoopDuration(loop, loopDurationMap)}
|
||||
{countLoopDuration(loop, index, loopDurationMap)}
|
||||
</div>
|
||||
)}
|
||||
<RiArrowRightSLine
|
||||
@ -98,7 +108,7 @@ const LoopResultPanel: FC<Props> = ({
|
||||
<div
|
||||
className={cn(
|
||||
'flex w-full cursor-pointer items-center justify-between px-3',
|
||||
expandedLoops[index] ? 'pb-2 pt-3' : 'py-3',
|
||||
expandedLoops[index] ? 'pt-3 pb-2' : 'py-3',
|
||||
'rounded-xl text-left',
|
||||
)}
|
||||
onClick={() => toggleLoop(index)}
|
||||
@ -107,7 +117,7 @@ const LoopResultPanel: FC<Props> = ({
|
||||
<div className="flex h-4 w-4 shrink-0 items-center justify-center rounded-[5px] border-divider-subtle bg-util-colors-cyan-cyan-500">
|
||||
<Loop className="h-3 w-3 text-text-primary-on-surface" />
|
||||
</div>
|
||||
<span className="system-sm-semibold-uppercase grow text-text-primary">
|
||||
<span className="grow system-sm-semibold-uppercase text-text-primary">
|
||||
{t(`${i18nPrefix}.loop`, { ns: 'workflow' })}
|
||||
{' '}
|
||||
{index + 1}
|
||||
@ -129,14 +139,14 @@ const LoopResultPanel: FC<Props> = ({
|
||||
)}
|
||||
>
|
||||
{
|
||||
loopVariableMap?.[index] && (
|
||||
loopVariableMap?.[getLoopRunKey(loop, index)] && (
|
||||
<div className="p-2 pb-0">
|
||||
<CodeEditor
|
||||
readOnly
|
||||
title={<div>{t('nodes.loop.loopVariables', { ns: 'workflow' }).toLocaleUpperCase()}</div>}
|
||||
language={CodeLanguage.json}
|
||||
height={112}
|
||||
value={loopVariableMap[index]}
|
||||
value={loopVariableMap[getLoopRunKey(loop, index)]}
|
||||
isJSONStringifyBeauty
|
||||
/>
|
||||
</div>
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import type { NodeTracing } from '@/types/workflow'
|
||||
import { noop } from 'es-toolkit/function'
|
||||
import format from '..'
|
||||
import format, { addChildrenToIterationNode } from '..'
|
||||
import graphToLogStruct from '../../graph-to-log-struct'
|
||||
|
||||
describe('iteration', () => {
|
||||
@ -9,15 +9,48 @@ describe('iteration', () => {
|
||||
it('result should have no nodes in iteration node', () => {
|
||||
expect(result.find(item => !!item.execution_metadata?.iteration_id)).toBeUndefined()
|
||||
})
|
||||
// test('iteration should put nodes in details', () => {
|
||||
// expect(result).toEqual([
|
||||
// startNode,
|
||||
// {
|
||||
// ...iterationNode,
|
||||
// details: [
|
||||
// [iterations[0], iterations[1]],
|
||||
// ],
|
||||
// },
|
||||
// ])
|
||||
// })
|
||||
|
||||
it('should place the first child of a new iteration at a new record when its index is missing', () => {
|
||||
const parent = { node_id: 'iter1', node_type: 'iteration', execution_metadata: {} } as unknown as NodeTracing
|
||||
const child0 = { node_id: 'code', execution_metadata: { iteration_id: 'iter1', iteration_index: 0 } } as unknown as NodeTracing
|
||||
const streaming = { node_id: 'code', execution_metadata: { iteration_id: 'iter1' } } as unknown as NodeTracing
|
||||
|
||||
const result = addChildrenToIterationNode(parent, [child0, streaming])
|
||||
expect(result.details![0]).toEqual([child0])
|
||||
expect(result.details![1]).toEqual([streaming])
|
||||
})
|
||||
|
||||
it('should keep missing iteration_index items in the current record when the node has not restarted', () => {
|
||||
const parent = {
|
||||
node_id: 'iter1',
|
||||
node_type: 'iteration',
|
||||
execution_metadata: {
|
||||
iteration_duration_map: { 0: 1.2, 1: 0.4 },
|
||||
},
|
||||
} as unknown as NodeTracing
|
||||
const child0 = { node_id: 'code', execution_metadata: { iteration_id: 'iter1', iteration_index: 0 } } as unknown as NodeTracing
|
||||
const child1 = { node_id: 'code', execution_metadata: { iteration_id: 'iter1', iteration_index: 1 } } as unknown as NodeTracing
|
||||
const streaming = { node_id: 'tool', execution_metadata: { iteration_id: 'iter1' } } as unknown as NodeTracing
|
||||
|
||||
const result = addChildrenToIterationNode(parent, [child0, child1, streaming])
|
||||
expect(result.details![0]).toEqual([child0])
|
||||
expect(result.details![1]).toEqual([child1, streaming])
|
||||
})
|
||||
|
||||
it('should not jump to the latest iteration when an earlier item is missing iteration_index', () => {
|
||||
const parent = {
|
||||
node_id: 'iter1',
|
||||
node_type: 'iteration',
|
||||
execution_metadata: {
|
||||
iteration_duration_map: { 0: 1.2, 1: 0.4 },
|
||||
},
|
||||
} as unknown as NodeTracing
|
||||
const code0 = { node_id: 'code', execution_metadata: { iteration_id: 'iter1', iteration_index: 0 } } as unknown as NodeTracing
|
||||
const tool = { node_id: 'tool', execution_metadata: { iteration_id: 'iter1' } } as unknown as NodeTracing
|
||||
const code1 = { node_id: 'code', execution_metadata: { iteration_id: 'iter1', iteration_index: 1 } } as unknown as NodeTracing
|
||||
|
||||
const result = addChildrenToIterationNode(parent, [code0, tool, code1])
|
||||
expect(result.details![0]).toEqual([code0, tool])
|
||||
expect(result.details![1]).toEqual([code1])
|
||||
})
|
||||
})
|
||||
|
||||
@ -4,15 +4,31 @@ import formatParallelNode from '../parallel'
|
||||
|
||||
export function addChildrenToIterationNode(iterationNode: NodeTracing, childrenNodes: NodeTracing[]): NodeTracing {
|
||||
const details: NodeTracing[][] = []
|
||||
childrenNodes.forEach((item, index) => {
|
||||
let lastResolvedIndex = -1
|
||||
|
||||
childrenNodes.forEach((item) => {
|
||||
if (!item.execution_metadata)
|
||||
return
|
||||
const { iteration_index = 0 } = item.execution_metadata
|
||||
const runIndex: number = iteration_index !== undefined ? iteration_index : index
|
||||
const { iteration_index } = item.execution_metadata
|
||||
let runIndex: number
|
||||
|
||||
if (iteration_index !== undefined) {
|
||||
runIndex = iteration_index
|
||||
}
|
||||
else if (lastResolvedIndex >= 0) {
|
||||
const currentGroup = details[lastResolvedIndex] || []
|
||||
const seenSameNodeInCurrentGroup = currentGroup.some(node => node.node_id === item.node_id)
|
||||
runIndex = seenSameNodeInCurrentGroup ? lastResolvedIndex + 1 : lastResolvedIndex
|
||||
}
|
||||
else {
|
||||
runIndex = 0
|
||||
}
|
||||
|
||||
if (!details[runIndex])
|
||||
details[runIndex] = []
|
||||
|
||||
details[runIndex].push(item)
|
||||
lastResolvedIndex = runIndex
|
||||
})
|
||||
return {
|
||||
...iterationNode,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import type { NodeTracing } from '@/types/workflow'
|
||||
import { noop } from 'es-toolkit/function'
|
||||
import format from '..'
|
||||
import format, { addChildrenToLoopNode } from '..'
|
||||
import graphToLogStruct from '../../graph-to-log-struct'
|
||||
|
||||
describe('loop', () => {
|
||||
@ -21,4 +21,48 @@ describe('loop', () => {
|
||||
},
|
||||
])
|
||||
})
|
||||
|
||||
it('should place the first child of a new loop run at a new record when its index is missing', () => {
|
||||
const parent = { node_id: 'loop1', node_type: 'loop', execution_metadata: {} } as unknown as NodeTracing
|
||||
const child0 = { node_id: 'code', execution_metadata: { loop_id: 'loop1', loop_index: 0 } } as unknown as NodeTracing
|
||||
const streaming = { node_id: 'code', execution_metadata: { loop_id: 'loop1' } } as unknown as NodeTracing
|
||||
|
||||
const result = addChildrenToLoopNode(parent, [child0, streaming])
|
||||
expect(result.details![0]).toEqual([child0])
|
||||
expect(result.details![1]).toEqual([streaming])
|
||||
})
|
||||
|
||||
it('should keep missing loop_index items in the current record when the node has not restarted', () => {
|
||||
const parent = {
|
||||
node_id: 'loop1',
|
||||
node_type: 'loop',
|
||||
execution_metadata: {
|
||||
loop_duration_map: { 0: 1.2, 1: 0.4 },
|
||||
},
|
||||
} as unknown as NodeTracing
|
||||
const child0 = { node_id: 'code', execution_metadata: { loop_id: 'loop1', loop_index: 0 } } as unknown as NodeTracing
|
||||
const child1 = { node_id: 'code', execution_metadata: { loop_id: 'loop1', loop_index: 1 } } as unknown as NodeTracing
|
||||
const streaming = { node_id: 'tool', execution_metadata: { loop_id: 'loop1' } } as unknown as NodeTracing
|
||||
|
||||
const result = addChildrenToLoopNode(parent, [child0, child1, streaming])
|
||||
expect(result.details![0]).toEqual([child0])
|
||||
expect(result.details![1]).toEqual([child1, streaming])
|
||||
})
|
||||
|
||||
it('should not jump to the latest loop when an earlier item is missing loop_index', () => {
|
||||
const parent = {
|
||||
node_id: 'loop1',
|
||||
node_type: 'loop',
|
||||
execution_metadata: {
|
||||
loop_duration_map: { 0: 1.2, 1: 0.4 },
|
||||
},
|
||||
} as unknown as NodeTracing
|
||||
const code0 = { node_id: 'code', execution_metadata: { loop_id: 'loop1', loop_index: 0 } } as unknown as NodeTracing
|
||||
const tool = { node_id: 'tool', execution_metadata: { loop_id: 'loop1' } } as unknown as NodeTracing
|
||||
const code1 = { node_id: 'code', execution_metadata: { loop_id: 'loop1', loop_index: 1 } } as unknown as NodeTracing
|
||||
|
||||
const result = addChildrenToLoopNode(parent, [code0, tool, code1])
|
||||
expect(result.details![0]).toEqual([code0, tool])
|
||||
expect(result.details![1]).toEqual([code1])
|
||||
})
|
||||
})
|
||||
|
||||
@ -3,20 +3,49 @@ import { BlockEnum } from '@/app/components/workflow/types'
|
||||
import formatParallelNode from '../parallel'
|
||||
|
||||
export function addChildrenToLoopNode(loopNode: NodeTracing, childrenNodes: NodeTracing[]): NodeTracing {
|
||||
const details: NodeTracing[][] = []
|
||||
const detailsByKey = new Map<string, NodeTracing[]>()
|
||||
let lastResolvedIndex = -1
|
||||
const order: string[] = []
|
||||
|
||||
const ensureGroup = (key: string) => {
|
||||
const group = detailsByKey.get(key)
|
||||
if (group)
|
||||
return group
|
||||
|
||||
const newGroup: NodeTracing[] = []
|
||||
detailsByKey.set(key, newGroup)
|
||||
order.push(key)
|
||||
return newGroup
|
||||
}
|
||||
|
||||
childrenNodes.forEach((item) => {
|
||||
if (!item.execution_metadata)
|
||||
return
|
||||
const { parallel_mode_run_id, loop_index = 0 } = item.execution_metadata
|
||||
const runIndex: number = (parallel_mode_run_id || loop_index) as number
|
||||
if (!details[runIndex])
|
||||
details[runIndex] = []
|
||||
const { parallel_mode_run_id, loop_index } = item.execution_metadata
|
||||
let runIndex: number | string
|
||||
|
||||
details[runIndex].push(item)
|
||||
if (parallel_mode_run_id !== undefined) {
|
||||
runIndex = parallel_mode_run_id
|
||||
}
|
||||
else if (loop_index !== undefined) {
|
||||
runIndex = loop_index
|
||||
}
|
||||
else if (lastResolvedIndex >= 0) {
|
||||
const currentGroup = detailsByKey.get(String(lastResolvedIndex)) || []
|
||||
const seenSameNodeInCurrentGroup = currentGroup.some(node => node.node_id === item.node_id)
|
||||
runIndex = seenSameNodeInCurrentGroup ? lastResolvedIndex + 1 : lastResolvedIndex
|
||||
}
|
||||
else {
|
||||
runIndex = 0
|
||||
}
|
||||
|
||||
ensureGroup(String(runIndex)).push(item)
|
||||
if (typeof runIndex === 'number')
|
||||
lastResolvedIndex = runIndex
|
||||
})
|
||||
return {
|
||||
...loopNode,
|
||||
details,
|
||||
details: order.map(key => detailsByKey.get(key) || []),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -11088,11 +11088,6 @@
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"app/components/workflow/run/loop-log/loop-result-panel.tsx": {
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 3
|
||||
}
|
||||
},
|
||||
"app/components/workflow/run/loop-result-panel.tsx": {
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 4
|
||||
|
||||
Loading…
Reference in New Issue
Block a user