diff --git a/.github/workflows/pyrefly-type-coverage-comment.yml b/.github/workflows/pyrefly-type-coverage-comment.yml new file mode 100644 index 0000000000..2df364953c --- /dev/null +++ b/.github/workflows/pyrefly-type-coverage-comment.yml @@ -0,0 +1,118 @@ +name: Comment with Pyrefly Type Coverage + +on: + workflow_run: + workflows: + - Pyrefly Type Coverage + types: + - completed + +permissions: {} + +jobs: + comment: + name: Comment PR with type coverage + runs-on: ubuntu-latest + permissions: + actions: read + contents: read + issues: write + pull-requests: write + if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }} + steps: + - name: Checkout default branch (trusted code) + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + + - name: Setup Python & UV + uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + with: + enable-cache: true + + - name: Install dependencies + run: uv sync --project api --dev + + - name: Download type coverage artifact + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const fs = require('fs'); + const artifacts = await github.rest.actions.listWorkflowRunArtifacts({ + owner: context.repo.owner, + repo: context.repo.repo, + run_id: ${{ github.event.workflow_run.id }}, + }); + const match = artifacts.data.artifacts.find((artifact) => + artifact.name === 'pyrefly_type_coverage' + ); + if (!match) { + throw new Error('pyrefly_type_coverage artifact not found'); + } + const download = await github.rest.actions.downloadArtifact({ + owner: context.repo.owner, + repo: context.repo.repo, + artifact_id: match.id, + archive_format: 'zip', + }); + fs.writeFileSync('pyrefly_type_coverage.zip', Buffer.from(download.data)); + + - name: Unzip artifact + run: unzip -o pyrefly_type_coverage.zip + + - name: Render coverage markdown from structured data + id: render + run: | + comment_body="$(uv run --directory api python api/libs/pyrefly_type_coverage.py \ + --base base_report.json \ + < pr_report.json)" + + { + echo "### Pyrefly Type Coverage" + echo "" + echo "$comment_body" + } > /tmp/type_coverage_comment.md + + - name: Post comment + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const fs = require('fs'); + const body = fs.readFileSync('/tmp/type_coverage_comment.md', { encoding: 'utf8' }); + let prNumber = null; + try { + prNumber = parseInt(fs.readFileSync('pr_number.txt', { encoding: 'utf8' }), 10); + } catch (err) { + const prs = context.payload.workflow_run.pull_requests || []; + if (prs.length > 0 && prs[0].number) { + prNumber = prs[0].number; + } + } + if (!prNumber) { + throw new Error('PR number not found in artifact or workflow_run payload'); + } + + // Update existing comment if one exists, otherwise create new + const { data: comments } = await github.rest.issues.listComments({ + issue_number: prNumber, + owner: context.repo.owner, + repo: context.repo.repo, + }); + const marker = '### Pyrefly Type Coverage'; + const existing = comments.find(c => c.body.startsWith(marker)); + + if (existing) { + await github.rest.issues.updateComment({ + comment_id: existing.id, + owner: context.repo.owner, + repo: context.repo.repo, + body, + }); + } else { + await github.rest.issues.createComment({ + issue_number: prNumber, + owner: context.repo.owner, + repo: context.repo.repo, + body, + }); + } diff --git a/.github/workflows/pyrefly-type-coverage.yml b/.github/workflows/pyrefly-type-coverage.yml new file mode 100644 index 0000000000..0c80c6a756 --- /dev/null +++ b/.github/workflows/pyrefly-type-coverage.yml @@ -0,0 +1,120 @@ +name: Pyrefly Type Coverage + +on: + pull_request: + paths: + - 'api/**/*.py' + +permissions: + contents: read + +jobs: + pyrefly-type-coverage: + runs-on: ubuntu-latest + permissions: + contents: read + issues: write + pull-requests: write + steps: + - name: Checkout PR branch + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + fetch-depth: 0 + + - name: Setup Python & UV + uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + with: + enable-cache: true + + - name: Install dependencies + run: uv sync --project api --dev + + - name: Run pyrefly report on PR branch + run: | + uv run --directory api --dev pyrefly report 2>/dev/null > /tmp/pyrefly_report_pr.tmp && \ + mv /tmp/pyrefly_report_pr.tmp /tmp/pyrefly_report_pr.json || \ + echo '{}' > /tmp/pyrefly_report_pr.json + + - name: Save helper script from base branch + run: | + git show ${{ github.event.pull_request.base.sha }}:api/libs/pyrefly_type_coverage.py > /tmp/pyrefly_type_coverage.py 2>/dev/null \ + || cp api/libs/pyrefly_type_coverage.py /tmp/pyrefly_type_coverage.py + + - name: Checkout base branch + run: git checkout ${{ github.base_ref }} + + - name: Run pyrefly report on base branch + run: | + uv run --directory api --dev pyrefly report 2>/dev/null > /tmp/pyrefly_report_base.tmp && \ + mv /tmp/pyrefly_report_base.tmp /tmp/pyrefly_report_base.json || \ + echo '{}' > /tmp/pyrefly_report_base.json + + - name: Generate coverage comparison + id: coverage + run: | + comment_body="$(uv run --directory api python /tmp/pyrefly_type_coverage.py \ + --base /tmp/pyrefly_report_base.json \ + < /tmp/pyrefly_report_pr.json)" + + { + echo "### Pyrefly Type Coverage" + echo "" + echo "$comment_body" + } | tee -a "$GITHUB_STEP_SUMMARY" > /tmp/type_coverage_comment.md + + # Save structured data for the fork-PR comment workflow + cp /tmp/pyrefly_report_pr.json pr_report.json + cp /tmp/pyrefly_report_base.json base_report.json + + - name: Save PR number + run: | + echo ${{ github.event.pull_request.number }} > pr_number.txt + + - name: Upload type coverage artifact + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + with: + name: pyrefly_type_coverage + path: | + pr_report.json + base_report.json + pr_number.txt + + - name: Comment PR with type coverage + if: ${{ github.event.pull_request.head.repo.full_name == github.repository }} + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const fs = require('fs'); + const marker = '### Pyrefly Type Coverage'; + let body; + try { + body = fs.readFileSync('/tmp/type_coverage_comment.md', { encoding: 'utf8' }); + } catch { + body = `${marker}\n\n_Coverage report unavailable._`; + } + const prNumber = context.payload.pull_request.number; + + // Update existing comment if one exists, otherwise create new + const { data: comments } = await github.rest.issues.listComments({ + issue_number: prNumber, + owner: context.repo.owner, + repo: context.repo.repo, + }); + const existing = comments.find(c => c.body.startsWith(marker)); + + if (existing) { + await github.rest.issues.updateComment({ + comment_id: existing.id, + owner: context.repo.owner, + repo: context.repo.repo, + body, + }); + } else { + await github.rest.issues.createComment({ + issue_number: prNumber, + owner: context.repo.owner, + repo: context.repo.repo, + body, + }); + } diff --git a/api/commands/account.py b/api/commands/account.py index 84af7a5ae6..6a2a2e0428 100644 --- a/api/commands/account.py +++ b/api/commands/account.py @@ -2,7 +2,6 @@ import base64 import secrets import click -from sqlalchemy.orm import sessionmaker from constants.languages import languages from extensions.ext_database import db @@ -25,30 +24,31 @@ def reset_password(email, new_password, password_confirm): return normalized_email = email.strip().lower() - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session) + account = AccountService.get_account_by_email_with_case_fallback(email.strip()) - if not account: - click.echo(click.style(f"Account not found for email: {email}", fg="red")) - return + if not account: + click.echo(click.style(f"Account not found for email: {email}", fg="red")) + return - try: - valid_password(new_password) - except: - click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red")) - return + try: + valid_password(new_password) + except: + click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red")) + return - # generate password salt - salt = secrets.token_bytes(16) - base64_salt = base64.b64encode(salt).decode() + # generate password salt + salt = secrets.token_bytes(16) + base64_salt = base64.b64encode(salt).decode() - # encrypt password with salt - password_hashed = hash_password(new_password, salt) - base64_password_hashed = base64.b64encode(password_hashed).decode() - account.password = base64_password_hashed - account.password_salt = base64_salt - AccountService.reset_login_error_rate_limit(normalized_email) - click.echo(click.style("Password reset successfully.", fg="green")) + # encrypt password with salt + password_hashed = hash_password(new_password, salt) + base64_password_hashed = base64.b64encode(password_hashed).decode() + account = db.session.merge(account) + account.password = base64_password_hashed + account.password_salt = base64_salt + db.session.commit() + AccountService.reset_login_error_rate_limit(normalized_email) + click.echo(click.style("Password reset successfully.", fg="green")) @click.command("reset-email", help="Reset the account email.") @@ -65,21 +65,22 @@ def reset_email(email, new_email, email_confirm): return normalized_new_email = new_email.strip().lower() - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session) + account = AccountService.get_account_by_email_with_case_fallback(email.strip()) - if not account: - click.echo(click.style(f"Account not found for email: {email}", fg="red")) - return + if not account: + click.echo(click.style(f"Account not found for email: {email}", fg="red")) + return - try: - email_validate(normalized_new_email) - except: - click.echo(click.style(f"Invalid email: {new_email}", fg="red")) - return + try: + email_validate(normalized_new_email) + except: + click.echo(click.style(f"Invalid email: {new_email}", fg="red")) + return - account.email = normalized_new_email - click.echo(click.style("Email updated successfully.", fg="green")) + account = db.session.merge(account) + account.email = normalized_new_email + db.session.commit() + click.echo(click.style("Email updated successfully.", fg="green")) @click.command("create-tenant", help="Create account and tenant.") diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py index 9e7faa09c5..1fd781b4fc 100644 --- a/api/controllers/console/auth/email_register.py +++ b/api/controllers/console/auth/email_register.py @@ -1,7 +1,6 @@ from flask import request from flask_restx import Resource from pydantic import BaseModel, Field, field_validator -from sqlalchemy.orm import sessionmaker from configs import dify_config from constants.languages import languages @@ -14,7 +13,6 @@ from controllers.console.auth.error import ( InvalidTokenError, PasswordMismatchError, ) -from extensions.ext_database import db from libs.helper import EmailStr, extract_remote_ip from libs.password import valid_password from models import Account @@ -73,8 +71,7 @@ class EmailRegisterSendEmailApi(Resource): if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email): raise AccountInFreezeError() - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(args.email) token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language) return {"result": "success", "data": token} @@ -145,17 +142,16 @@ class EmailRegisterResetApi(Resource): email = register_data.get("email", "") normalized_email = email.lower() - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(email) - if account: - raise EmailAlreadyInUseError() - else: - account = self._create_new_account(normalized_email, args.password_confirm) - if not account: - raise AccountNotFoundError() - token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) - AccountService.reset_login_error_rate_limit(normalized_email) + if account: + raise EmailAlreadyInUseError() + else: + account = self._create_new_account(normalized_email, args.password_confirm) + if not account: + raise AccountNotFoundError() + token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) + AccountService.reset_login_error_rate_limit(normalized_email) return {"result": "success", "data": token_pair.model_dump()} diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 63bc98b53f..ed390a5f89 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -4,7 +4,6 @@ import secrets from flask import request from flask_restx import Resource from pydantic import BaseModel, Field -from sqlalchemy.orm import sessionmaker from controllers.common.schema import register_schema_models from controllers.console import console_ns @@ -85,8 +84,7 @@ class ForgotPasswordSendEmailApi(Resource): else: language = "en-US" - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(args.email) token = AccountService.send_reset_password_email( account=account, @@ -184,17 +182,18 @@ class ForgotPasswordResetApi(Resource): password_hashed = hash_password(args.new_password, salt) email = reset_data.get("email", "") - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(email) - if account: - self._update_existing_account(account, password_hashed, salt, session) - else: - raise AccountNotFound() + if account: + account = db.session.merge(account) + self._update_existing_account(account, password_hashed, salt) + db.session.commit() + else: + raise AccountNotFound() return {"result": "success"} - def _update_existing_account(self, account, password_hashed, salt, session): + def _update_existing_account(self, account, password_hashed, salt): # Update existing account credentials account.password = base64.b64encode(password_hashed).decode() account.password_salt = base64.b64encode(salt).decode() diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 5c7011fd22..d31fb4a46c 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -4,7 +4,6 @@ import urllib.parse import httpx from flask import current_app, redirect, request from flask_restx import Resource -from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import Unauthorized from configs import dify_config @@ -180,8 +179,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> account: Account | None = Account.get_by_openid(provider, user_info.id) if not account: - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(user_info.email) return account diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 626d330e9d..af25669ae0 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -8,7 +8,6 @@ from flask import request from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel, Field, field_validator, model_validator from sqlalchemy import select -from sqlalchemy.orm import sessionmaker from configs import dify_config from constants.languages import supported_language @@ -562,8 +561,7 @@ class ChangeEmailSendEmailApi(Resource): user_email = current_user.email else: - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(args.email) if account is None: raise AccountNotFound() email_for_sending = account.email diff --git a/api/controllers/web/forgot_password.py b/api/controllers/web/forgot_password.py index 80c3289fb4..61fd794c22 100644 --- a/api/controllers/web/forgot_password.py +++ b/api/controllers/web/forgot_password.py @@ -3,7 +3,6 @@ import secrets from flask import request from flask_restx import Resource -from sqlalchemy.orm import sessionmaker from controllers.common.schema import register_schema_models from controllers.console.auth.error import ( @@ -62,9 +61,7 @@ class ForgotPasswordSendEmailApi(Resource): else: language = "en-US" - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session) - token = None + account = AccountService.get_account_by_email_with_case_fallback(request_email) if account is None: raise AuthenticationFailedError() else: @@ -161,13 +158,14 @@ class ForgotPasswordResetApi(Resource): email = reset_data.get("email", "") - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(email) - if account: - self._update_existing_account(account, password_hashed, salt) - else: - raise AuthenticationFailedError() + if account: + account = db.session.merge(account) + self._update_existing_account(account, password_hashed, salt) + db.session.commit() + else: + raise AuthenticationFailedError() return {"result": "success"} diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index 0b3aa79838..70d45b15c4 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -55,7 +55,7 @@ class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResul request: ReceiveRequestT _session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]" - _on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any] + _on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], object] def __init__( self, @@ -63,7 +63,7 @@ class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResul request_meta: RequestParams.Meta | None, request: ReceiveRequestT, session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]", - on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any], + on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], object], ): self.request_id = request_id self.request_meta = request_meta diff --git a/api/core/mcp/types.py b/api/core/mcp/types.py index 2653d20a7d..10e3082aa3 100644 --- a/api/core/mcp/types.py +++ b/api/core/mcp/types.py @@ -31,7 +31,6 @@ ProgressToken = str | int Cursor = str Role = Literal["user", "assistant"] RequestId = Annotated[int | str, Field(union_mode="left_to_right")] -type AnyFunction = Callable[..., Any] class RequestParams(BaseModel): diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index e2d2be92cb..c76cb865c3 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ b/api/core/plugin/backwards_invocation/app.py @@ -1,6 +1,6 @@ import uuid from collections.abc import Generator, Mapping -from typing import Any, Union, cast +from typing import Any, cast from sqlalchemy import select from sqlalchemy.orm import Session @@ -207,7 +207,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): ) @classmethod - def _get_user(cls, user_id: str) -> Union[EndUser, Account]: + def _get_user(cls, user_id: str) -> EndUser | Account: """ get the user by user id """ diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index a59d167a0a..d8674b3af9 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -6,7 +6,6 @@ import os import time from collections.abc import Generator from mimetypes import guess_extension, guess_type -from typing import Union from uuid import uuid4 import httpx @@ -158,7 +157,7 @@ class ToolFileManager: return tool_file - def get_file_binary(self, id: str) -> Union[tuple[bytes, str], None]: + def get_file_binary(self, id: str) -> tuple[bytes, str] | None: """ get file binary @@ -176,7 +175,7 @@ class ToolFileManager: return blob, tool_file.mimetype - def get_file_binary_by_message_file_id(self, id: str) -> Union[tuple[bytes, str], None]: + def get_file_binary_by_message_file_id(self, id: str) -> tuple[bytes, str] | None: """ get file binary diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 2593e381cf..f8f07369d0 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -5,7 +5,7 @@ import time from collections.abc import Generator, Mapping from os import listdir, path from threading import Lock -from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, Union, cast +from typing import TYPE_CHECKING, Any, Literal, Protocol, cast import sqlalchemy as sa from graphon.runtime import VariablePool @@ -100,7 +100,7 @@ class ToolManager: _builtin_provider_lock = Lock() _hardcoded_providers: dict[str, BuiltinToolProviderController] = {} _builtin_providers_loaded = False - _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {} + _builtin_tools_labels: dict[str, I18nObject | None] = {} @classmethod def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController: @@ -190,7 +190,7 @@ class ToolManager: invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, credential_id: str | None = None, - ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]: + ) -> BuiltinTool | PluginTool | ApiTool | WorkflowTool | MCPTool: """ get the tool runtime @@ -398,7 +398,7 @@ class ToolManager: agent_tool: AgentToolEntity, user_id: str | None = None, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, - variable_pool: Optional["VariablePool"] = None, + variable_pool: "VariablePool | None" = None, ) -> Tool: """ get the agent tool runtime @@ -442,7 +442,7 @@ class ToolManager: workflow_tool: WorkflowToolRuntimeSpec, user_id: str | None = None, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, - variable_pool: Optional["VariablePool"] = None, + variable_pool: "VariablePool | None" = None, ) -> Tool: """ get the workflow tool runtime @@ -634,7 +634,7 @@ class ToolManager: cls._builtin_providers_loaded = False @classmethod - def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]: + def get_tool_label(cls, tool_name: str) -> I18nObject | None: """ get the tool label @@ -1052,7 +1052,7 @@ class ToolManager: def _convert_tool_parameters_type( cls, parameters: list[ToolParameter], - variable_pool: Optional["VariablePool"], + variable_pool: "VariablePool | None", tool_configurations: Mapping[str, Any], typ: Literal["agent", "workflow", "tool"] = "workflow", ) -> dict[str, Any]: diff --git a/api/extensions/otel/decorators/base.py b/api/extensions/otel/decorators/base.py index 1dd92caeae..ad83826427 100644 --- a/api/extensions/otel/decorators/base.py +++ b/api/extensions/otel/decorators/base.py @@ -37,12 +37,7 @@ def trace_span[**P, R](handler_class: type[SpanHandler] | None = None) -> Callab handler = _get_handler_instance(handler_class or SpanHandler) tracer = get_tracer(__name__) - return handler.wrapper( - tracer=tracer, - wrapped=func, - args=args, - kwargs=kwargs, - ) + return handler.wrapper(tracer, func, *args, **kwargs) return cast(Callable[P, R], wrapper) diff --git a/api/extensions/otel/decorators/handler.py b/api/extensions/otel/decorators/handler.py index e465a615a6..b0d9fa7af6 100644 --- a/api/extensions/otel/decorators/handler.py +++ b/api/extensions/otel/decorators/handler.py @@ -1,8 +1,8 @@ import inspect -from collections.abc import Callable, Mapping +from collections.abc import Callable from typing import Any -from opentelemetry.trace import SpanKind, Status, StatusCode +from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer class SpanHandler: @@ -16,9 +16,9 @@ class SpanHandler: exceptions. Handlers can override the wrapper method to customize behavior. """ - _signature_cache: dict[Callable[..., Any], inspect.Signature] = {} + _signature_cache: dict[Callable[..., object], inspect.Signature] = {} - def _build_span_name(self, wrapped: Callable[..., Any]) -> str: + def _build_span_name[**P, R](self, wrapped: Callable[P, R]) -> str: """ Build the span name from the wrapped function. @@ -29,11 +29,11 @@ class SpanHandler: """ return f"{wrapped.__module__}.{wrapped.__qualname__}" - def _extract_arguments[T]( + def _extract_arguments[**P, R]( self, - wrapped: Callable[..., T], - args: tuple[object, ...], - kwargs: Mapping[str, object], + wrapped: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, ) -> dict[str, Any] | None: """ Extract function arguments using inspect.signature. @@ -59,13 +59,13 @@ class SpanHandler: except Exception: return None - def wrapper[T]( + def wrapper[**P, R]( self, - tracer: Any, - wrapped: Callable[..., T], - args: tuple[object, ...], - kwargs: Mapping[str, object], - ) -> T: + tracer: Tracer, + wrapped: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> R: """ Fully control the wrapper behavior. diff --git a/api/extensions/otel/decorators/handlers/generate_handler.py b/api/extensions/otel/decorators/handlers/generate_handler.py index cc6c75304f..df5142c310 100644 --- a/api/extensions/otel/decorators/handlers/generate_handler.py +++ b/api/extensions/otel/decorators/handlers/generate_handler.py @@ -1,8 +1,7 @@ import logging -from collections.abc import Callable, Mapping -from typing import Any +from collections.abc import Callable -from opentelemetry.trace import SpanKind, Status, StatusCode +from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer from opentelemetry.util.types import AttributeValue from extensions.otel.decorators.handler import SpanHandler @@ -15,15 +14,15 @@ logger = logging.getLogger(__name__) class AppGenerateHandler(SpanHandler): """Span handler for ``AppGenerateService.generate``.""" - def wrapper[T]( + def wrapper[**P, R]( self, - tracer: Any, - wrapped: Callable[..., T], - args: tuple[object, ...], - kwargs: Mapping[str, object], - ) -> T: + tracer: Tracer, + wrapped: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> R: try: - arguments = self._extract_arguments(wrapped, args, kwargs) + arguments = self._extract_arguments(wrapped, *args, **kwargs) if not arguments: return wrapped(*args, **kwargs) diff --git a/api/extensions/otel/decorators/handlers/workflow_app_runner_handler.py b/api/extensions/otel/decorators/handlers/workflow_app_runner_handler.py index 8abd60197c..6b2112ceb2 100644 --- a/api/extensions/otel/decorators/handlers/workflow_app_runner_handler.py +++ b/api/extensions/otel/decorators/handlers/workflow_app_runner_handler.py @@ -1,8 +1,7 @@ import logging -from collections.abc import Callable, Mapping -from typing import Any +from collections.abc import Callable -from opentelemetry.trace import SpanKind, Status, StatusCode +from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer from opentelemetry.util.types import AttributeValue from extensions.otel.decorators.handler import SpanHandler @@ -14,15 +13,15 @@ logger = logging.getLogger(__name__) class WorkflowAppRunnerHandler(SpanHandler): """Span handler for ``WorkflowAppRunner.run``.""" - def wrapper( + def wrapper[**P, R]( self, - tracer: Any, - wrapped: Callable[..., Any], - args: tuple[Any, ...], - kwargs: Mapping[str, Any], - ) -> Any: + tracer: Tracer, + wrapped: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> R: try: - arguments = self._extract_arguments(wrapped, args, kwargs) + arguments = self._extract_arguments(wrapped, *args, **kwargs) if not arguments: return wrapped(*args, **kwargs) diff --git a/api/libs/pyrefly_type_coverage.py b/api/libs/pyrefly_type_coverage.py new file mode 100644 index 0000000000..369b8dff3c --- /dev/null +++ b/api/libs/pyrefly_type_coverage.py @@ -0,0 +1,145 @@ +"""Helpers for generating type-coverage summaries from pyrefly report output.""" + +from __future__ import annotations + +import json +import sys +from pathlib import Path +from typing import TypedDict + + +class CoverageSummary(TypedDict): + n_modules: int + n_typable: int + n_typed: int + n_any: int + n_untyped: int + coverage: float + strict_coverage: float + + +_REQUIRED_KEYS = frozenset(CoverageSummary.__annotations__) + +_EMPTY_SUMMARY: CoverageSummary = { + "n_modules": 0, + "n_typable": 0, + "n_typed": 0, + "n_any": 0, + "n_untyped": 0, + "coverage": 0.0, + "strict_coverage": 0.0, +} + + +def parse_summary(report_json: str) -> CoverageSummary: + """Extract the summary section from ``pyrefly report`` JSON output. + + Returns an empty summary when *report_json* is empty or malformed so that + the CI workflow can degrade gracefully instead of crashing. + """ + if not report_json or not report_json.strip(): + return _EMPTY_SUMMARY.copy() + + try: + data = json.loads(report_json) + except json.JSONDecodeError: + return _EMPTY_SUMMARY.copy() + + summary = data.get("summary") + if not isinstance(summary, dict) or not _REQUIRED_KEYS.issubset(summary): + return _EMPTY_SUMMARY.copy() + + return { + "n_modules": summary["n_modules"], + "n_typable": summary["n_typable"], + "n_typed": summary["n_typed"], + "n_any": summary["n_any"], + "n_untyped": summary["n_untyped"], + "coverage": summary["coverage"], + "strict_coverage": summary["strict_coverage"], + } + + +def format_summary_markdown(summary: CoverageSummary) -> str: + """Format a single coverage summary as a Markdown table.""" + + return ( + "| Metric | Value |\n" + "| --- | ---: |\n" + f"| Modules | {summary['n_modules']} |\n" + f"| Typable symbols | {summary['n_typable']:,} |\n" + f"| Typed symbols | {summary['n_typed']:,} |\n" + f"| Untyped symbols | {summary['n_untyped']:,} |\n" + f"| Any symbols | {summary['n_any']:,} |\n" + f"| **Type coverage** | **{summary['coverage']:.2f}%** |\n" + f"| Strict coverage | {summary['strict_coverage']:.2f}% |" + ) + + +def format_comparison_markdown( + base: CoverageSummary, + pr: CoverageSummary, +) -> str: + """Format a comparison between base and PR coverage as Markdown.""" + + coverage_delta = pr["coverage"] - base["coverage"] + strict_delta = pr["strict_coverage"] - base["strict_coverage"] + typed_delta = pr["n_typed"] - base["n_typed"] + untyped_delta = pr["n_untyped"] - base["n_untyped"] + + def _fmt_delta(value: float, fmt: str = ".2f") -> str: + sign = "+" if value > 0 else "" + return f"{sign}{value:{fmt}}" + + lines = [ + "| Metric | Base | PR | Delta |", + "| --- | ---: | ---: | ---: |", + (f"| **Type coverage** | {base['coverage']:.2f}% | {pr['coverage']:.2f}% | {_fmt_delta(coverage_delta)}% |"), + ( + f"| Strict coverage | {base['strict_coverage']:.2f}% " + f"| {pr['strict_coverage']:.2f}% " + f"| {_fmt_delta(strict_delta)}% |" + ), + (f"| Typed symbols | {base['n_typed']:,} | {pr['n_typed']:,} | {_fmt_delta(typed_delta, ',')} |"), + (f"| Untyped symbols | {base['n_untyped']:,} | {pr['n_untyped']:,} | {_fmt_delta(untyped_delta, ',')} |"), + ( + f"| Modules | {base['n_modules']} " + f"| {pr['n_modules']} " + f"| {_fmt_delta(pr['n_modules'] - base['n_modules'], ',')} |" + ), + ] + return "\n".join(lines) + + +def main() -> int: + """Read pyrefly report JSON from stdin and print a Markdown summary. + + Accepts an optional ``--base `` argument. When provided, the output + includes a base-vs-PR comparison table. + """ + + args = sys.argv[1:] + + base_file: str | None = None + if "--base" in args: + idx = args.index("--base") + if idx + 1 >= len(args): + sys.stderr.write("error: --base requires a file path\n") + return 1 + base_file = args[idx + 1] + + pr_report = sys.stdin.read() + pr_summary = parse_summary(pr_report) + + if base_file is not None: + base_text = Path(base_file).read_text() if Path(base_file).exists() else "" + base_summary = parse_summary(base_text) + sys.stdout.write(format_comparison_markdown(base_summary, pr_summary) + "\n") + else: + sys.stdout.write(format_summary_markdown(pr_summary) + "\n") + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/api/models/workflow.py b/api/models/workflow.py index 8e8d2e6fd9..bb4d6a7ec9 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -4,7 +4,7 @@ import logging from collections.abc import Generator, Mapping, Sequence from datetime import datetime from enum import StrEnum -from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union, cast +from typing import TYPE_CHECKING, Any, Optional, TypedDict, cast from uuid import uuid4 import sqlalchemy as sa @@ -121,7 +121,7 @@ class WorkflowType(StrEnum): raise ValueError(f"invalid workflow type value {value}") @classmethod - def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType": + def from_app_mode(cls, app_mode: "str | AppMode") -> "WorkflowType": """ Get workflow type from app mode. @@ -1051,7 +1051,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo ) return extras - def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]: + def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> "WorkflowNodeExecutionOffload | None": return next(iter([i for i in self.offload_data if i.type_ == type_]), None) @property diff --git a/api/services/account_service.py b/api/services/account_service.py index 4b58b3b697..1f5f81e5bd 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -9,7 +9,8 @@ from typing import Any, TypedDict, cast from pydantic import BaseModel, TypeAdapter from sqlalchemy import delete, func, select, update -from sqlalchemy.orm import Session, sessionmaker + +from core.db.session_factory import session_factory class InvitationData(TypedDict): @@ -800,19 +801,19 @@ class AccountService: return token @staticmethod - def get_account_by_email_with_case_fallback(email: str, session: Session | None = None) -> Account | None: + def get_account_by_email_with_case_fallback(email: str) -> Account | None: """ Retrieve an account by email and fall back to the lowercase email if the original lookup fails. This keeps backward compatibility for older records that stored uppercase emails while the rest of the system gradually normalizes new inputs. """ - query_session = session or db.session - account = query_session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() - if account or email == email.lower(): - return account + with session_factory.create_session() as session: + account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() + if account or email == email.lower(): + return account - return query_session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none() + return session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none() @classmethod def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None: @@ -1516,8 +1517,7 @@ class RegisterService: check_workspace_member_invite_permission(tenant.id) - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(email) if not account: TenantService.check_member_permission(tenant, inviter, None, "add") diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 17ed98d301..5e8c7aa337 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -4,7 +4,7 @@ import logging import threading import uuid from collections.abc import Callable, Generator, Mapping -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any from configs import dify_config from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator @@ -88,7 +88,7 @@ class AppGenerateService: def generate( cls, app_model: App, - user: Union[Account, EndUser], + user: Account | EndUser, args: Mapping[str, Any], invoke_from: InvokeFrom, streaming: bool = True, @@ -356,11 +356,11 @@ class AppGenerateService: def generate_more_like_this( cls, app_model: App, - user: Union[Account, EndUser], + user: Account | EndUser, message_id: str, invoke_from: InvokeFrom, streaming: bool = True, - ) -> Union[Mapping, Generator]: + ) -> Mapping | Generator: """ Generate more like this :param app_model: app model diff --git a/api/services/async_workflow_service.py b/api/services/async_workflow_service.py index 55ae1e03b1..a731d5c048 100644 --- a/api/services/async_workflow_service.py +++ b/api/services/async_workflow_service.py @@ -7,7 +7,7 @@ with support for different subscription tiers, rate limiting, and execution trac import json from datetime import UTC, datetime -from typing import Any, Union +from typing import Any from celery.result import AsyncResult from sqlalchemy import select @@ -50,7 +50,7 @@ class AsyncWorkflowService: @classmethod def trigger_workflow_async( - cls, session: Session, user: Union[Account, EndUser], trigger_data: TriggerData + cls, session: Session, user: Account | EndUser, trigger_data: TriggerData ) -> AsyncTriggerResponse: """ Universal entry point for async workflow execution - THIS METHOD WILL NOT BLOCK @@ -177,7 +177,7 @@ class AsyncWorkflowService: @classmethod def reinvoke_trigger( - cls, session: Session, user: Union[Account, EndUser], workflow_trigger_log_id: str + cls, session: Session, user: Account | EndUser, workflow_trigger_log_id: str ) -> AsyncTriggerResponse: """ Re-invoke a previously failed or rate-limited trigger - THIS METHOD WILL NOT BLOCK diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 2bf1afba3e..ce5dee4943 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -1,6 +1,6 @@ import json from copy import deepcopy -from typing import Any, Union, cast +from typing import Any, cast from urllib.parse import urlparse import httpx @@ -195,9 +195,7 @@ class ExternalDatasetService: raise ValueError(f"{parameter.get('name')} is required") @staticmethod - def process_external_api( - settings: ExternalKnowledgeApiSetting, files: Union[None, dict[str, Any]] - ) -> httpx.Response: + def process_external_api(settings: ExternalKnowledgeApiSetting, files: dict[str, Any] | None) -> httpx.Response: """ do http request depending on api bundle """ diff --git a/api/services/file_service.py b/api/services/file_service.py index 7443ca3271..79a935de4b 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -5,7 +5,7 @@ import uuid from collections.abc import Iterator, Sequence from contextlib import contextmanager, suppress from tempfile import NamedTemporaryFile -from typing import Literal, Union +from typing import Literal from zipfile import ZIP_DEFLATED, ZipFile from graphon.file import helpers as file_helpers @@ -52,7 +52,7 @@ class FileService: filename: str, content: bytes, mimetype: str, - user: Union[Account, EndUser], + user: Account | EndUser, source: Literal["datasets"] | None = None, source_url: str = "", ) -> UploadFile: diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index 3cce83a975..41b6b885b2 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, TypedDict, Union +from typing import Any, TypedDict from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.entities.provider_entities import ( @@ -626,7 +626,7 @@ class ModelLoadBalancingService: def _get_credential_schema( self, provider_configuration: ProviderConfiguration - ) -> Union[ModelCredentialSchema, ProviderCredentialSchema]: + ) -> ModelCredentialSchema | ProviderCredentialSchema: """Get form schemas.""" if provider_configuration.provider.model_credential_schema: return provider_configuration.provider.model_credential_schema diff --git a/api/services/plugin/plugin_permission_service.py b/api/services/plugin/plugin_permission_service.py index 0d2a70acbd..3cca4268d0 100644 --- a/api/services/plugin/plugin_permission_service.py +++ b/api/services/plugin/plugin_permission_service.py @@ -1,14 +1,13 @@ from sqlalchemy import select -from sqlalchemy.orm import sessionmaker -from extensions.ext_database import db +from core.db.session_factory import session_factory from models.account import TenantPluginPermission class PluginPermissionService: @staticmethod def get_permission(tenant_id: str) -> TenantPluginPermission | None: - with sessionmaker(bind=db.engine).begin() as session: + with session_factory.create_session() as session: return session.scalar( select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1) ) @@ -19,7 +18,7 @@ class PluginPermissionService: install_permission: TenantPluginPermission.InstallPermission, debug_permission: TenantPluginPermission.DebugPermission, ): - with sessionmaker(bind=db.engine).begin() as session: + with session_factory.create_session() as session, session.begin(): permission = session.scalar( select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1) ) diff --git a/api/services/rag_pipeline/pipeline_generate_service.py b/api/services/rag_pipeline/pipeline_generate_service.py index 10e89b1dba..56bc785958 100644 --- a/api/services/rag_pipeline/pipeline_generate_service.py +++ b/api/services/rag_pipeline/pipeline_generate_service.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Union +from typing import Any from configs import dify_config from core.app.apps.pipeline.pipeline_generator import PipelineGenerator @@ -17,7 +17,7 @@ class PipelineGenerateService: def generate( cls, pipeline: Pipeline, - user: Union[Account, EndUser], + user: Account | EndUser, args: Mapping[str, Any], invoke_from: InvokeFrom, streaming: bool = True, diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index f6d80f9a6e..5fc5b412b3 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -5,7 +5,7 @@ import threading import time from collections.abc import Callable, Generator, Mapping, Sequence from datetime import UTC, datetime -from typing import Any, Union, cast +from typing import Any, cast from uuid import uuid4 from flask_login import current_user @@ -1387,7 +1387,7 @@ class RagPipelineService: "uninstalled_recommended_plugins": uninstalled_plugin_list, } - def retry_error_document(self, dataset: Dataset, document: Document, user: Union[Account, EndUser]): + def retry_error_document(self, dataset: Dataset, document: Document, user: Account | EndUser): """ Retry error document """ diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 4fd2ea1628..72954a3102 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -1,6 +1,6 @@ import logging from collections.abc import Mapping -from typing import Any, Union +from typing import Any from pydantic import TypeAdapter, ValidationError from yarl import URL @@ -69,7 +69,7 @@ class ToolTransformService: return "" @staticmethod - def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity, PluginDatasourceProviderEntity]): + def repack_provider(tenant_id: str, provider: dict | ToolProviderApiEntity | PluginDatasourceProviderEntity): """ repack provider diff --git a/api/tasks/async_workflow_tasks.py b/api/tasks/async_workflow_tasks.py index 0a73c91279..45e1f80e35 100644 --- a/api/tasks/async_workflow_tasks.py +++ b/api/tasks/async_workflow_tasks.py @@ -7,15 +7,16 @@ with appropriate retry policies and error handling. import logging from datetime import UTC, datetime -from typing import Any +from typing import Any, NotRequired from celery import shared_task from graphon.runtime import GraphRuntimeState from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker +from typing_extensions import TypedDict from configs import dify_config -from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator +from core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext from core.app.layers.timeslice_layer import TimeSliceLayer @@ -42,6 +43,13 @@ from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue, AsyncWorkf logger = logging.getLogger(__name__) +class WorkflowGeneratorArgsDict(TypedDict): + inputs: dict[str, Any] + files: list[Any] + _skip_prepare_user_inputs: bool + workflow_id: NotRequired[str] + + @shared_task(queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE) def execute_workflow_professional(task_data_dict: dict[str, Any]): """Execute workflow for professional tier with highest priority""" @@ -90,15 +98,13 @@ def execute_workflow_sandbox(task_data_dict: dict[str, Any]): ) -def _build_generator_args(trigger_data: TriggerData) -> dict[str, Any]: +def _build_generator_args(trigger_data: TriggerData) -> WorkflowGeneratorArgsDict: """Build args passed into WorkflowAppGenerator.generate for Celery executions.""" - - args: dict[str, Any] = { + return { "inputs": dict(trigger_data.inputs), "files": list(trigger_data.files), - SKIP_PREPARE_USER_INPUTS_KEY: True, + "_skip_prepare_user_inputs": True, } - return args def _execute_workflow_common( diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py index 879c337319..320da85b60 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py @@ -158,7 +158,10 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(): second_result.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_result, second_result] - result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) + with patch("services.account_service.session_factory") as mock_factory: + mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) + result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py index 7b7393dade..d2703ed5cc 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py @@ -113,12 +113,14 @@ class TestForgotPasswordCheckApi: class TestForgotPasswordResetApi: @patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.console.auth.forgot_password.db") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") def test_reset_fetches_account_with_original_email( self, mock_get_reset_data, mock_revoke_token, + mock_db, mock_get_account, mock_update_account, app, @@ -126,6 +128,7 @@ class TestForgotPasswordResetApi: mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"} mock_account = MagicMock() mock_get_account.return_value = mock_account + mock_db.session.merge.return_value = mock_account wraps_features = SimpleNamespace(enable_email_password_login=True) with ( @@ -161,7 +164,10 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(): second_result.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_result, second_result] - result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session) + with patch("services.account_service.session_factory") as mock_factory: + mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) + result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py index a2f1328579..1eabb45422 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py @@ -437,7 +437,10 @@ class TestAccountGeneration: second_result.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_result, second_result] - result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) + with patch("services.account_service.session_factory") as mock_factory: + mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) + result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py index 8f9db287e3..50249bcd74 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py @@ -335,10 +335,12 @@ class TestForgotPasswordResetApi: @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.console.auth.forgot_password.db") @patch("controllers.console.auth.forgot_password.TenantService.get_join_tenants") def test_reset_password_success( self, mock_get_tenants, + mock_db, mock_get_account, mock_revoke_token, mock_get_data, @@ -356,6 +358,7 @@ class TestForgotPasswordResetApi: # Arrange mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"} mock_get_account.return_value = mock_account + mock_db.session.merge.return_value = mock_account mock_get_tenants.return_value = [MagicMock()] # Act diff --git a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py index 04ad143103..f14b2c0ae5 100644 --- a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py @@ -37,10 +37,8 @@ class TestForgotPasswordSendEmailApi: @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False) @patch("controllers.web.forgot_password.extract_remote_ip", return_value="127.0.0.1") - @patch("controllers.web.forgot_password.sessionmaker") def test_should_normalize_email_before_sending( self, - mock_session_cls, mock_extract_ip, mock_rate_limit, mock_get_account, @@ -50,19 +48,16 @@ class TestForgotPasswordSendEmailApi: mock_account = MagicMock() mock_get_account.return_value = mock_account mock_send_mail.return_value = "token-123" - mock_session = MagicMock() - mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session - with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): - with app.test_request_context( - "/web/forgot-password", - method="POST", - json={"email": "User@Example.com", "language": "zh-Hans"}, - ): - response = ForgotPasswordSendEmailApi().post() + with app.test_request_context( + "/web/forgot-password", + method="POST", + json={"email": "User@Example.com", "language": "zh-Hans"}, + ): + response = ForgotPasswordSendEmailApi().post() assert response == {"result": "success", "data": "token-123"} - mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) + mock_get_account.assert_called_once_with("User@Example.com") mock_send_mail.assert_called_once_with(account=mock_account, email="user@example.com", language="zh-Hans") mock_extract_ip.assert_called_once() mock_rate_limit.assert_called_once_with("127.0.0.1") @@ -153,14 +148,14 @@ class TestForgotPasswordResetApi: @patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account") @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") - @patch("controllers.web.forgot_password.sessionmaker") + @patch("controllers.web.forgot_password.db") @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.web.forgot_password.AccountService.get_reset_password_data") def test_should_fetch_account_with_fallback( self, mock_get_reset_data, mock_revoke_token, - mock_session_cls, + mock_db, mock_get_account, mock_update_account, app, @@ -168,29 +163,27 @@ class TestForgotPasswordResetApi: mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com", "code": "1234"} mock_account = MagicMock() mock_get_account.return_value = mock_account - mock_session = MagicMock() - mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session + mock_db.session.merge.return_value = mock_account - with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): - with app.test_request_context( - "/web/forgot-password/resets", - method="POST", - json={ - "token": "token-123", - "new_password": "ValidPass123!", - "password_confirm": "ValidPass123!", - }, - ): - response = ForgotPasswordResetApi().post() + with app.test_request_context( + "/web/forgot-password/resets", + method="POST", + json={ + "token": "token-123", + "new_password": "ValidPass123!", + "password_confirm": "ValidPass123!", + }, + ): + response = ForgotPasswordResetApi().post() assert response == {"result": "success"} - mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) + mock_get_account.assert_called_once_with("User@Example.com") mock_update_account.assert_called_once() mock_revoke_token.assert_called_once_with("token-123") @patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value") @patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef") - @patch("controllers.web.forgot_password.sessionmaker") + @patch("controllers.web.forgot_password.db") @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.web.forgot_password.AccountService.get_reset_password_data") @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") @@ -199,7 +192,7 @@ class TestForgotPasswordResetApi: mock_get_account, mock_get_reset_data, mock_revoke_token, - mock_session_cls, + mock_db, mock_token_bytes, mock_hash_password, app, @@ -207,20 +200,18 @@ class TestForgotPasswordResetApi: mock_get_reset_data.return_value = {"phase": "reset", "email": "user@example.com"} account = MagicMock() mock_get_account.return_value = account - mock_session = MagicMock() - mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session + mock_db.session.merge.return_value = account - with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): - with app.test_request_context( - "/web/forgot-password/resets", - method="POST", - json={ - "token": "reset-token", - "new_password": "StrongPass123!", - "password_confirm": "StrongPass123!", - }, - ): - response = ForgotPasswordResetApi().post() + with app.test_request_context( + "/web/forgot-password/resets", + method="POST", + json={ + "token": "reset-token", + "new_password": "StrongPass123!", + "password_confirm": "StrongPass123!", + }, + ): + response = ForgotPasswordResetApi().post() assert response == {"result": "success"} mock_get_reset_data.assert_called_once_with("reset-token") diff --git a/api/tests/unit_tests/services/test_hit_testing_service.py b/api/tests/test_containers_integration_tests/services/test_hit_testing_service.py similarity index 51% rename from api/tests/unit_tests/services/test_hit_testing_service.py rename to api/tests/test_containers_integration_tests/services/test_hit_testing_service.py index 80e9729f5b..f332ba05ec 100644 --- a/api/tests/unit_tests/services/test_hit_testing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_hit_testing_service.py @@ -1,239 +1,193 @@ +from __future__ import annotations + import json from typing import Any, cast from unittest.mock import ANY, MagicMock, patch +from uuid import uuid4 import pytest +from sqlalchemy import func, select +from sqlalchemy.orm import Session from core.rag.models.document import Document -from models.dataset import Dataset +from models.dataset import Dataset, DatasetQuery from services.hit_testing_service import HitTestingService -class TestHitTestingService: - """Test suite for HitTestingService""" +def _create_dataset(db_session: Session, *, provider: str = "vendor", **kwargs: Any) -> Dataset: + tenant_id = str(uuid4()) + created_by = str(uuid4()) + ds = Dataset( + tenant_id=kwargs.get("tenant_id", tenant_id), + name=kwargs.get("name", "test-dataset"), + created_by=kwargs.get("created_by", created_by), + provider=provider, + ) + db_session.add(ds) + db_session.commit() + db_session.refresh(ds) + return ds - # ===== Utility Method Tests ===== + +class TestHitTestingService: + # ── Utility methods (pure logic, no DB) ──────────────────────────── def test_escape_query_for_search_should_escape_double_quotes(self): - """Test that escape_query_for_search escapes double quotes correctly""" - # Arrange query = 'test "query" with quotes' - expected = 'test \\"query\\" with quotes' - - # Act result = HitTestingService.escape_query_for_search(query) - - # Assert - assert result == expected + assert result == 'test \\"query\\" with quotes' def test_hit_testing_args_check_should_pass_with_valid_query(self): - """Test that hit_testing_args_check passes with a valid query""" - # Arrange - args = {"query": "valid query"} - - # Act & Assert (should not raise) - HitTestingService.hit_testing_args_check(args) + HitTestingService.hit_testing_args_check({"query": "valid query"}) def test_hit_testing_args_check_should_pass_with_valid_attachments(self): - """Test that hit_testing_args_check passes with valid attachment_ids""" - # Arrange - args = {"attachment_ids": ["id1", "id2"]} - - # Act & Assert (should not raise) - HitTestingService.hit_testing_args_check(args) + HitTestingService.hit_testing_args_check({"attachment_ids": ["id1", "id2"]}) def test_hit_testing_args_check_should_raise_error_when_no_query_or_attachments(self): - """Test that hit_testing_args_check raises ValueError if both query and attachment_ids are missing""" - # Arrange - args = {} - - # Act & Assert - with pytest.raises(ValueError) as exc_info: - HitTestingService.hit_testing_args_check(args) - assert "Query or attachment_ids is required" in str(exc_info.value) + with pytest.raises(ValueError, match="Query or attachment_ids is required"): + HitTestingService.hit_testing_args_check({}) def test_hit_testing_args_check_should_raise_error_when_query_too_long(self): - """Test that hit_testing_args_check raises ValueError if query exceeds 250 characters""" - # Arrange - args = {"query": "a" * 251} - - # Act & Assert - with pytest.raises(ValueError) as exc_info: - HitTestingService.hit_testing_args_check(args) - assert "Query cannot exceed 250 characters" in str(exc_info.value) + with pytest.raises(ValueError, match="Query cannot exceed 250 characters"): + HitTestingService.hit_testing_args_check({"query": "a" * 251}) def test_hit_testing_args_check_should_raise_error_when_attachments_not_list(self): - """Test that hit_testing_args_check raises ValueError if attachment_ids is not a list""" - # Arrange - args = {"attachment_ids": "not a list"} + with pytest.raises(ValueError, match="Attachment_ids must be a list"): + HitTestingService.hit_testing_args_check({"attachment_ids": "not a list"}) - # Act & Assert - with pytest.raises(ValueError) as exc_info: - HitTestingService.hit_testing_args_check(args) - assert "Attachment_ids must be a list" in str(exc_info.value) - - # ===== Response Formatting Tests ===== + # ── Response formatting ──────────────────────────────────────────── @patch("core.rag.datasource.retrieval_service.RetrievalService.format_retrieval_documents") def test_compact_retrieve_response_should_format_correctly(self, mock_format): - """Test that compact_retrieve_response formats the response correctly""" - # Arrange query = "test query" mock_doc = MagicMock(spec=Document) - documents = [mock_doc] mock_record = MagicMock() mock_record.model_dump.return_value = {"content": "formatted content"} mock_format.return_value = [mock_record] - # Act - result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, documents)) + result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, [mock_doc])) - # Assert assert cast(dict[str, Any], result["query"])["content"] == query assert len(result["records"]) == 1 assert cast(dict[str, Any], result["records"][0])["content"] == "formatted content" - mock_format.assert_called_once_with(documents) + mock_format.assert_called_once_with([mock_doc]) - def test_compact_external_retrieve_response_should_return_records_for_external_provider(self): - """Test that compact_external_retrieve_response returns records when dataset provider is external""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.provider = "external" - query = "test query" + def test_compact_external_retrieve_response_should_return_records_for_external_provider( + self, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers, provider="external") documents = [ {"content": "c1", "title": "t1", "score": 0.9, "metadata": {"m1": "v1"}}, {"content": "c2", "title": "t2", "score": 0.8, "metadata": {"m2": "v2"}}, ] - # Act - result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents)) + result = cast( + dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, "test query", documents) + ) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == "test query" assert len(result["records"]) == 2 assert cast(dict[str, Any], result["records"][0])["content"] == "c1" assert cast(dict[str, Any], result["records"][1])["title"] == "t2" - def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider(self): - """Test that compact_external_retrieve_response returns empty records for non-external provider""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.provider = "not_external" - query = "test query" - documents = [{"content": "c1"}] + def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider( + self, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers, provider="vendor") - # Act - result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents)) + result = cast( + dict[str, Any], + HitTestingService.compact_external_retrieve_response(dataset, "test query", [{"content": "c1"}]), + ) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == "test query" assert result["records"] == [] - # ===== External Retrieve Tests ===== + # ── External retrieve (real DB) ──────────────────────────────────── @patch("core.rag.datasource.retrieval_service.RetrievalService.external_retrieve") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_external_retrieve_should_succeed_for_external_provider(self, mock_commit, mock_add, mock_ext_retrieve): - """Test that external_retrieve successfully retrieves from external provider and commits query""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - dataset.provider = "external" - query = 'test "query"' + def test_external_retrieve_should_succeed_for_external_provider( + self, mock_ext_retrieve, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers, provider="external") + account_id = str(uuid4()) account = MagicMock() - account.id = "account_id" - + account.id = account_id mock_ext_retrieve.return_value = [{"content": "ext content", "score": 1.0}] - # Act + before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 + result = cast( dict[str, Any], HitTestingService.external_retrieve( dataset=dataset, - query=query, + query='test "query"', account=account, external_retrieval_model={"model": "test"}, metadata_filtering_conditions={"key": "val"}, ), ) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == 'test "query"' assert cast(dict[str, Any], result["records"][0])["content"] == "ext content" - - # Verify call to RetrievalService.external_retrieve with escaped query mock_ext_retrieve.assert_called_once_with( - dataset_id="dataset_id", + dataset_id=dataset.id, query='test \\"query\\"', external_retrieval_model={"model": "test"}, metadata_filtering_conditions={"key": "val"}, ) - # Verify DatasetQuery record was added and committed - mock_add.assert_called_once() - mock_commit.assert_called_once() + db_session_with_containers.expire_all() + after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 + assert after_count == before_count + 1 - def test_external_retrieve_should_return_empty_for_non_external_provider(self): - """Test that external_retrieve returns empty results immediately if provider is not external""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.provider = "not_external" - query = "test query" + def test_external_retrieve_should_return_empty_for_non_external_provider(self, db_session_with_containers: Session): + dataset = _create_dataset(db_session_with_containers, provider="vendor") account = MagicMock() - # Act - result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, query, account)) + result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, "test query", account)) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == "test query" assert result["records"] == [] - # ===== Retrieve Tests ===== + # ── Retrieve (real DB) ───────────────────────────────────────────── @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_retrieve_should_use_default_model_when_none_provided(self, mock_commit, mock_add, mock_retrieve): - """Test that retrieve uses default model when retrieval_model is not provided""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" + def test_retrieve_should_use_default_model_when_none_provided( + self, mock_retrieve, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers) dataset.retrieval_model = None - query = "test query" account = MagicMock() - account.id = "account_id" - + account.id = str(uuid4()) mock_retrieve.return_value = [] - # Act + before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 + result = cast( dict[str, Any], HitTestingService.retrieve( - dataset=dataset, query=query, account=account, retrieval_model=None, external_retrieval_model={} + dataset=dataset, query="test query", account=account, retrieval_model=None, external_retrieval_model={} ), ) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == "test query" mock_retrieve.assert_called_once() - # Verify top_k from default_retrieval_model (4) assert mock_retrieve.call_args.kwargs["top_k"] == 4 - mock_commit.assert_called_once() + + db_session_with_containers.expire_all() + after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 + assert after_count == before_count + 1 @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_retrieve_should_handle_metadata_filtering(self, mock_commit, mock_add, mock_get_meta, mock_retrieve): - """Test that retrieve correctly calls metadata filtering when conditions are present""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - query = "test query" + def test_retrieve_should_handle_metadata_filtering( + self, mock_get_meta, mock_retrieve, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers) account = MagicMock() - account.id = "account_id" + account.id = str(uuid4()) retrieval_model = { "search_method": "semantic_search", @@ -242,29 +196,27 @@ class TestHitTestingService: "reranking_enable": False, "score_threshold_enabled": False, } - - # Mock metadata filtering response - mock_get_meta.return_value = ({"dataset_id": ["doc_id1"]}, "condition_string") + mock_get_meta.return_value = ({dataset.id: ["doc_id1"]}, "condition_string") mock_retrieve.return_value = [] - # Act HitTestingService.retrieve( - dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={} + dataset=dataset, + query="test query", + account=account, + retrieval_model=retrieval_model, + external_retrieval_model={}, ) - # Assert mock_get_meta.assert_called_once() mock_retrieve.assert_called_once() assert mock_retrieve.call_args.kwargs["document_ids_filter"] == ["doc_id1"] @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition") - def test_retrieve_should_return_empty_if_metadata_filtering_fails(self, mock_get_meta, mock_retrieve): - """Test that retrieve returns empty response if metadata filtering returns condition but no document IDs""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - query = "test query" + def test_retrieve_should_return_empty_if_metadata_filtering_fails( + self, mock_get_meta, mock_retrieve, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers) account = MagicMock() retrieval_model = { @@ -274,37 +226,27 @@ class TestHitTestingService: "reranking_enable": False, "score_threshold_enabled": False, } - - # Mock metadata filtering response: condition returned but no IDs mock_get_meta.return_value = ({}, "condition_string") - # Act result = cast( dict[str, Any], HitTestingService.retrieve( dataset=dataset, - query=query, + query="test query", account=account, retrieval_model=retrieval_model, external_retrieval_model={}, ), ) - # Assert assert result["records"] == [] mock_retrieve.assert_not_called() @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_retrieve_should_handle_attachments(self, mock_commit, mock_add, mock_retrieve): - """Test that retrieve handles attachment_ids and adds them to DatasetQuery""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - query = "test query" + def test_retrieve_should_handle_attachments(self, mock_retrieve, db_session_with_containers: Session): + dataset = _create_dataset(db_session_with_containers) account = MagicMock() - account.id = "account_id" + account.id = str(uuid4()) attachment_ids = ["att1", "att2"] retrieval_model = { @@ -315,21 +257,19 @@ class TestHitTestingService: } mock_retrieve.return_value = [] - # Act HitTestingService.retrieve( dataset=dataset, - query=query, + query="test query", account=account, retrieval_model=retrieval_model, external_retrieval_model={}, attachment_ids=attachment_ids, ) - # Assert mock_retrieve.assert_called_once_with( retrieval_method=ANY, - dataset_id="dataset_id", - query=query, + dataset_id=dataset.id, + query="test query", attachment_ids=attachment_ids, top_k=4, score_threshold=0.0, @@ -338,26 +278,27 @@ class TestHitTestingService: weights=None, document_ids_filter=None, ) - # Verify DatasetQuery record (there should be 2 queries: 1 text, 2 images) - # The content is json.dumps([{"content_type": "text_query", ...}, {"content_type": "image_query", ...}]) - called_query = mock_add.call_args[0][0] - query_content = json.loads(called_query.content) + + # Verify DatasetQuery was persisted with correct content structure + db_session_with_containers.expire_all() + latest = db_session_with_containers.scalar( + select(DatasetQuery) + .where(DatasetQuery.dataset_id == dataset.id) + .order_by(DatasetQuery.created_at.desc()) + .limit(1) + ) + assert latest is not None + query_content = json.loads(latest.content) assert len(query_content) == 3 # 1 text + 2 images assert query_content[0]["content_type"] == "text_query" assert query_content[1]["content_type"] == "image_query" assert query_content[1]["content"] == "att1" @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_retrieve_should_handle_reranking_and_threshold(self, mock_commit, mock_add, mock_retrieve): - """Test that retrieve passes reranking and threshold parameters correctly""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - query = "test query" + def test_retrieve_should_handle_reranking_and_threshold(self, mock_retrieve, db_session_with_containers: Session): + dataset = _create_dataset(db_session_with_containers) account = MagicMock() - account.id = "account_id" + account.id = str(uuid4()) retrieval_model = { "search_method": "hybrid_search", @@ -371,12 +312,14 @@ class TestHitTestingService: } mock_retrieve.return_value = [] - # Act HitTestingService.retrieve( - dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={} + dataset=dataset, + query="test query", + account=account, + retrieval_model=retrieval_model, + external_retrieval_model={}, ) - # Assert mock_retrieve.assert_called_once() kwargs = mock_retrieve.call_args.kwargs assert kwargs["score_threshold"] == 0.5 diff --git a/api/tests/test_containers_integration_tests/services/test_ops_service.py b/api/tests/test_containers_integration_tests/services/test_ops_service.py new file mode 100644 index 0000000000..e2e1a228b2 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_ops_service.py @@ -0,0 +1,363 @@ +from __future__ import annotations + +import uuid +from unittest.mock import patch + +import pytest +from faker import Faker +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.ops.entities.config_entity import TracingProviderEnum +from models.model import TraceAppConfig +from services.account_service import AccountService, TenantService +from services.app_service import AppService +from services.ops_service import OpsService +from tests.test_containers_integration_tests.helpers import generate_valid_password + + +class TestOpsService: + @pytest.fixture + def mock_external_service_dependencies(self): + with ( + patch("services.app_service.FeatureService") as mock_feature_service, + patch("services.app_service.EnterpriseService") as mock_enterprise_service, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, + patch("services.account_service.FeatureService") as mock_account_feature_service, + ): + mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False + mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None + mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None + mock_account_feature_service.get_system_features.return_value.is_allow_register = True + mock_model_instance = mock_model_manager.return_value + mock_model_instance.get_default_model_instance.return_value = None + mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo") + yield { + "feature_service": mock_feature_service, + "enterprise_service": mock_enterprise_service, + "model_manager": mock_model_manager, + "account_feature_service": mock_account_feature_service, + } + + @pytest.fixture + def mock_ops_trace_manager(self): + with patch("services.ops_service.OpsTraceManager") as mock: + yield mock + + def _create_app(self, db_session_with_containers: Session, mock_external_service_dependencies): + fake = Faker() + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=generate_valid_password(fake), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + app_service = AppService() + app = app_service.create_app( + tenant.id, + { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + }, + account, + ) + return app, account + + _SENTINEL = object() + + def _insert_trace_config( + self, + db_session: Session, + app_id: str, + provider: str, + tracing_config: dict | None | object = _SENTINEL, + ) -> TraceAppConfig: + trace_config = TraceAppConfig( + app_id=app_id, + tracing_provider=provider, + tracing_config=tracing_config if tracing_config is not self._SENTINEL else {"some": "config"}, + ) + db_session.add(trace_config) + db_session.commit() + return trace_config + + # ── get_tracing_app_config ───────────────────────────────────────── + + def test_get_tracing_app_config_no_config(self, db_session_with_containers: Session, mock_ops_trace_manager): + result = OpsService.get_tracing_app_config(str(uuid.uuid4()), "arize") + assert result is None + + def test_get_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager): + fake_app_id = str(uuid.uuid4()) + self._insert_trace_config(db_session_with_containers, fake_app_id, "arize") + result = OpsService.get_tracing_app_config(fake_app_id, "arize") + assert result is None + + def test_get_tracing_app_config_none_config( + self, db_session_with_containers: Session, mock_external_service_dependencies, mock_ops_trace_manager + ): + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, "arize", tracing_config=None) + + with pytest.raises(ValueError, match="Tracing config cannot be None."): + OpsService.get_tracing_app_config(app.id, "arize") + + @pytest.mark.parametrize( + ("provider", "default_url"), + [ + ("arize", "https://app.arize.com/"), + ("phoenix", "https://app.phoenix.arize.com/projects/"), + ("langsmith", "https://smith.langchain.com/"), + ("opik", "https://www.comet.com/opik/"), + ("weave", "https://wandb.ai/"), + ("aliyun", "https://arms.console.aliyun.com/"), + ("tencent", "https://console.cloud.tencent.com/apm"), + ("mlflow", "http://localhost:5000/"), + ("databricks", "https://www.databricks.com/"), + ], + ) + def test_get_tracing_app_config_providers_exception( + self, db_session_with_containers: Session, mock_external_service_dependencies, provider, default_url + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.decrypt_tracing_config.return_value = {} + mock_otm.obfuscated_decrypt_token.return_value = {} + mock_otm.get_trace_config_project_url.side_effect = Exception("error") + mock_otm.get_trace_config_project_key.side_effect = Exception("error") + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, provider) + + result = OpsService.get_tracing_app_config(app.id, provider) + + assert result is not None + assert result["tracing_config"]["project_url"] == default_url + + @pytest.mark.parametrize( + "provider", + ["arize", "phoenix", "langsmith", "opik", "weave", "aliyun", "tencent", "mlflow", "databricks"], + ) + def test_get_tracing_app_config_providers_success( + self, db_session_with_containers: Session, mock_external_service_dependencies, provider + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.decrypt_tracing_config.return_value = {} + mock_otm.obfuscated_decrypt_token.return_value = {"project_url": "success_url"} + mock_otm.get_trace_config_project_url.return_value = "success_url" + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, provider) + + result = OpsService.get_tracing_app_config(app.id, provider) + + assert result is not None + assert result["tracing_config"]["project_url"] == "success_url" + + def test_get_tracing_app_config_langfuse_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"} + mock_otm.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"} + mock_otm.get_trace_config_project_key.return_value = "key" + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, "langfuse") + + result = OpsService.get_tracing_app_config(app.id, "langfuse") + + assert result is not None + assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/project/key" + + def test_get_tracing_app_config_langfuse_exception( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"} + mock_otm.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"} + mock_otm.get_trace_config_project_key.side_effect = Exception("error") + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, "langfuse") + + result = OpsService.get_tracing_app_config(app.id, "langfuse") + + assert result is not None + assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/" + + # ── create_tracing_app_config ────────────────────────────────────── + + def test_create_tracing_app_config_invalid_provider(self, db_session_with_containers: Session): + result = OpsService.create_tracing_app_config(str(uuid.uuid4()), "invalid_provider", {}) + assert result == {"error": "Invalid tracing provider: invalid_provider"} + + def test_create_tracing_app_config_invalid_credentials( + self, db_session_with_containers: Session, mock_ops_trace_manager + ): + mock_ops_trace_manager.check_trace_config_is_effective.return_value = False + result = OpsService.create_tracing_app_config( + str(uuid.uuid4()), TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"} + ) + assert result == {"error": "Invalid Credentials"} + + @pytest.mark.parametrize( + ("provider", "config"), + [ + (TracingProviderEnum.ARIZE, {}), + (TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}), + (TracingProviderEnum.LANGSMITH, {"api_key": "k", "project": "p"}), + (TracingProviderEnum.ALIYUN, {"license_key": "k", "endpoint": "https://aliyun.com"}), + ], + ) + def test_create_tracing_app_config_project_url_exception( + self, db_session_with_containers: Session, mock_external_service_dependencies, provider, config + ): + # Existing config causes the service to return None before reaching the DB insert + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.check_trace_config_is_effective.return_value = True + mock_otm.get_trace_config_project_url.side_effect = Exception("error") + mock_otm.get_trace_config_project_key.side_effect = Exception("error") + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, str(provider)) + + result = OpsService.create_tracing_app_config(app.id, provider, config) + + assert result is None + + def test_create_tracing_app_config_langfuse_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.check_trace_config_is_effective.return_value = True + mock_otm.get_trace_config_project_key.return_value = "key" + mock_otm.encrypt_tracing_config.return_value = {} + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + result = OpsService.create_tracing_app_config( + app.id, + TracingProviderEnum.LANGFUSE, + {"public_key": "p", "secret_key": "s", "host": "https://api.langfuse.com"}, + ) + + assert result == {"result": "success"} + + def test_create_tracing_app_config_already_exists( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.check_trace_config_is_effective.return_value = True + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE)) + + result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {}) + + assert result is None + + def test_create_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager): + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + result = OpsService.create_tracing_app_config(str(uuid.uuid4()), TracingProviderEnum.ARIZE, {}) + assert result is None + + def test_create_tracing_app_config_with_empty_other_keys( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + # "project" is in other_keys for Arize; providing "" triggers default substitution + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.check_trace_config_is_effective.return_value = True + mock_otm.get_trace_config_project_url.side_effect = Exception("no url") + mock_otm.encrypt_tracing_config.return_value = {} + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {"project": ""}) + + assert result == {"result": "success"} + + def test_create_tracing_app_config_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.check_trace_config_is_effective.return_value = True + mock_otm.get_trace_config_project_url.return_value = "http://project_url" + mock_otm.encrypt_tracing_config.return_value = {"encrypted": "config"} + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {}) + + assert result == {"result": "success"} + + # ── update_tracing_app_config ────────────────────────────────────── + + def test_update_tracing_app_config_invalid_provider(self, db_session_with_containers: Session): + with pytest.raises(ValueError, match="Invalid tracing provider: invalid_provider"): + OpsService.update_tracing_app_config(str(uuid.uuid4()), "invalid_provider", {}) + + def test_update_tracing_app_config_no_config(self, db_session_with_containers: Session, mock_ops_trace_manager): + result = OpsService.update_tracing_app_config(str(uuid.uuid4()), TracingProviderEnum.ARIZE, {}) + assert result is None + + def test_update_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager): + fake_app_id = str(uuid.uuid4()) + self._insert_trace_config(db_session_with_containers, fake_app_id, str(TracingProviderEnum.ARIZE)) + mock_ops_trace_manager.encrypt_tracing_config.return_value = {} + result = OpsService.update_tracing_app_config(fake_app_id, TracingProviderEnum.ARIZE, {}) + assert result is None + + def test_update_tracing_app_config_invalid_credentials( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.encrypt_tracing_config.return_value = {} + mock_otm.decrypt_tracing_config.return_value = {} + mock_otm.check_trace_config_is_effective.return_value = False + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE)) + + with pytest.raises(ValueError, match="Invalid Credentials"): + OpsService.update_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {}) + + def test_update_tracing_app_config_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.encrypt_tracing_config.return_value = {"updated": "config"} + mock_otm.decrypt_tracing_config.return_value = {} + mock_otm.check_trace_config_is_effective.return_value = True + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE)) + + result = OpsService.update_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {}) + + assert result is not None + assert result["app_id"] == app.id + + # ── delete_tracing_app_config ────────────────────────────────────── + + def test_delete_tracing_app_config_no_config(self, db_session_with_containers: Session): + result = OpsService.delete_tracing_app_config(str(uuid.uuid4()), "arize") + assert result is None + + def test_delete_tracing_app_config_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, "arize") + + result = OpsService.delete_tracing_app_config(app.id, "arize") + + assert result is True + remaining = db_session_with_containers.scalar( + select(TraceAppConfig) + .where(TraceAppConfig.app_id == app.id, TraceAppConfig.tracing_provider == "arize") + .limit(1) + ) + assert remaining is None diff --git a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py index 4fe65d5803..7825f502f7 100644 --- a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py @@ -233,11 +233,10 @@ class TestWebAppAuthService: assert result.status == AccountStatus.ACTIVE # Verify database state - - db_session_with_containers.refresh(result) - assert result.id is not None - assert result.password is not None - assert result.password_salt is not None + refreshed = db_session_with_containers.get(Account, result.id) + assert refreshed is not None + assert refreshed.password is not None + assert refreshed.password_salt is not None def test_authenticate_account_not_found( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -414,9 +413,8 @@ class TestWebAppAuthService: assert result.status == AccountStatus.ACTIVE # Verify database state - - db_session_with_containers.refresh(result) - assert result.id is not None + refreshed = db_session_with_containers.get(Account, result.id) + assert refreshed is not None def test_get_user_through_email_not_found( self, db_session_with_containers: Session, mock_external_service_dependencies diff --git a/api/tests/unit_tests/controllers/console/test_workspace_account.py b/api/tests/unit_tests/controllers/console/test_workspace_account.py index 7f9fe9cbf9..dd643faac9 100644 --- a/api/tests/unit_tests/controllers/console/test_workspace_account.py +++ b/api/tests/unit_tests/controllers/console/test_workspace_account.py @@ -233,15 +233,20 @@ class TestCheckEmailUnique: def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): - session = MagicMock() + mock_session = MagicMock() first = MagicMock() first.scalar_one_or_none.return_value = None second = MagicMock() expected_account = MagicMock() second.scalar_one_or_none.return_value = expected_account - session.execute.side_effect = [first, second] + mock_session.execute.side_effect = [first, second] - result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=session) + mock_factory = MagicMock() + mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) + + with patch("services.account_service.session_factory", mock_factory): + result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com") assert result is expected_account - assert session.execute.call_count == 2 + assert mock_session.execute.call_count == 2 diff --git a/api/tests/unit_tests/extensions/otel/decorators/handlers/test_generate_handler.py b/api/tests/unit_tests/extensions/otel/decorators/handlers/test_generate_handler.py index f7475f2239..12e91f190f 100644 --- a/api/tests/unit_tests/extensions/otel/decorators/handlers/test_generate_handler.py +++ b/api/tests/unit_tests/extensions/otel/decorators/handlers/test_generate_handler.py @@ -39,7 +39,7 @@ class TestAppGenerateHandler: "root_node_id": None, } - arguments = handler._extract_arguments(AppGenerateService.generate, (), kwargs) + arguments = handler._extract_arguments(AppGenerateService.generate, **kwargs) assert arguments is not None, "Failed to extract arguments from AppGenerateService.generate" assert "app_model" in arguments, "Handler uses app_model but parameter is missing" @@ -70,14 +70,11 @@ class TestAppGenerateHandler: handler.wrapper( tracer, dummy_func, - (), - { - "app_model": mock_app_model, - "user": mock_account_user, - "args": {"workflow_id": test_workflow_id}, - "invoke_from": InvokeFrom.DEBUGGER, - "streaming": False, - }, + app_model=mock_app_model, + user=mock_account_user, + args={"workflow_id": test_workflow_id}, + invoke_from=InvokeFrom.DEBUGGER, + streaming=False, ) spans = memory_span_exporter.get_finished_spans() diff --git a/api/tests/unit_tests/extensions/otel/decorators/handlers/test_workflow_app_runner_handler.py b/api/tests/unit_tests/extensions/otel/decorators/handlers/test_workflow_app_runner_handler.py index 500f80fc3c..842e7f55e2 100644 --- a/api/tests/unit_tests/extensions/otel/decorators/handlers/test_workflow_app_runner_handler.py +++ b/api/tests/unit_tests/extensions/otel/decorators/handlers/test_workflow_app_runner_handler.py @@ -63,7 +63,7 @@ class TestWorkflowAppRunnerHandler: def runner_run(self): return "result" - handler.wrapper(tracer, runner_run, (mock_workflow_runner,), {}) + handler.wrapper(tracer, runner_run, mock_workflow_runner) spans = memory_span_exporter.get_finished_spans() assert len(spans) == 1 diff --git a/api/tests/unit_tests/extensions/otel/decorators/test_handler.py b/api/tests/unit_tests/extensions/otel/decorators/test_handler.py index 44788bab9a..bf861e3ef7 100644 --- a/api/tests/unit_tests/extensions/otel/decorators/test_handler.py +++ b/api/tests/unit_tests/extensions/otel/decorators/test_handler.py @@ -28,7 +28,7 @@ class TestSpanHandlerExtractArguments: args = (1, 2, 3) kwargs = {} - result = handler._extract_arguments(func, args, kwargs) + result = handler._extract_arguments(func, *args, **kwargs) assert result is not None assert result["a"] == 1 @@ -44,7 +44,7 @@ class TestSpanHandlerExtractArguments: args = () kwargs = {"a": 1, "b": 2, "c": 3} - result = handler._extract_arguments(func, args, kwargs) + result = handler._extract_arguments(func, *args, **kwargs) assert result is not None assert result["a"] == 1 @@ -60,7 +60,7 @@ class TestSpanHandlerExtractArguments: args = (1,) kwargs = {"b": 2, "c": 3} - result = handler._extract_arguments(func, args, kwargs) + result = handler._extract_arguments(func, *args, **kwargs) assert result is not None assert result["a"] == 1 @@ -76,7 +76,7 @@ class TestSpanHandlerExtractArguments: args = (1,) kwargs = {} - result = handler._extract_arguments(func, args, kwargs) + result = handler._extract_arguments(func, *args, **kwargs) assert result is not None assert result["a"] == 1 @@ -94,7 +94,7 @@ class TestSpanHandlerExtractArguments: instance = MyClass() args = (1, 2) kwargs = {} - result = handler._extract_arguments(instance.method, args, kwargs) + result = handler._extract_arguments(instance.method, *args, **kwargs) assert result is not None assert result["a"] == 1 @@ -109,7 +109,7 @@ class TestSpanHandlerExtractArguments: args = (1,) kwargs = {} - result = handler._extract_arguments(func, args, kwargs) + result = handler._extract_arguments(func, *args, **kwargs) assert result is None @@ -122,11 +122,11 @@ class TestSpanHandlerExtractArguments: assert func not in handler._signature_cache - handler._extract_arguments(func, (1, 2), {}) + handler._extract_arguments(func, 1, 2) assert func in handler._signature_cache cached_sig = handler._signature_cache[func] - handler._extract_arguments(func, (3, 4), {}) + handler._extract_arguments(func, 3, 4) assert handler._signature_cache[func] is cached_sig @@ -142,7 +142,7 @@ class TestSpanHandlerWrapper: def test_func(): return "result" - result = handler.wrapper(tracer, test_func, (), {}) + result = handler.wrapper(tracer, test_func) assert result == "result" spans = memory_span_exporter.get_finished_spans() @@ -159,7 +159,7 @@ class TestSpanHandlerWrapper: def test_func(): return "result" - handler.wrapper(tracer, test_func, (), {}) + handler.wrapper(tracer, test_func) spans = memory_span_exporter.get_finished_spans() assert len(spans) == 1 @@ -174,7 +174,7 @@ class TestSpanHandlerWrapper: def test_func(): return "result" - handler.wrapper(tracer, test_func, (), {}) + handler.wrapper(tracer, test_func) spans = memory_span_exporter.get_finished_spans() assert len(spans) == 1 @@ -190,7 +190,7 @@ class TestSpanHandlerWrapper: raise ValueError("test error") with pytest.raises(ValueError, match="test error"): - handler.wrapper(tracer, test_func, (), {}) + handler.wrapper(tracer, test_func) spans = memory_span_exporter.get_finished_spans() assert len(spans) == 1 @@ -208,7 +208,7 @@ class TestSpanHandlerWrapper: raise ValueError("test error") with pytest.raises(ValueError): - handler.wrapper(tracer, test_func, (), {}) + handler.wrapper(tracer, test_func) spans = memory_span_exporter.get_finished_spans() assert len(spans) == 1 @@ -225,7 +225,7 @@ class TestSpanHandlerWrapper: raise ValueError("test error") with pytest.raises(ValueError, match="test error"): - handler.wrapper(tracer, test_func, (), {}) + handler.wrapper(tracer, test_func) @patch("extensions.otel.decorators.base.dify_config.ENABLE_OTEL", True) def test_wrapper_passes_arguments_correctly(self, tracer_provider_with_memory_exporter, memory_span_exporter): @@ -236,7 +236,7 @@ class TestSpanHandlerWrapper: def test_func(a, b, c=10): return a + b + c - result = handler.wrapper(tracer, test_func, (1, 2), {"c": 3}) + result = handler.wrapper(tracer, test_func, 1, 2, c=3) assert result == 6 @@ -249,7 +249,7 @@ class TestSpanHandlerWrapper: def my_function(x): return x * 2 - result = handler.wrapper(tracer, my_function, (5,), {}) + result = handler.wrapper(tracer, my_function, 5) assert result == 10 spans = memory_span_exporter.get_finished_spans() diff --git a/api/tests/unit_tests/libs/test_pyrefly_type_coverage.py b/api/tests/unit_tests/libs/test_pyrefly_type_coverage.py new file mode 100644 index 0000000000..7087490845 --- /dev/null +++ b/api/tests/unit_tests/libs/test_pyrefly_type_coverage.py @@ -0,0 +1,138 @@ +import json + +from libs.pyrefly_type_coverage import ( + CoverageSummary, + format_comparison_markdown, + format_summary_markdown, + parse_summary, +) + + +def _make_report(summary: dict) -> str: + return json.dumps({"module_reports": [], "summary": summary}) + + +_SAMPLE_SUMMARY: dict = { + "n_modules": 100, + "n_typable": 1000, + "n_typed": 400, + "n_any": 50, + "n_untyped": 550, + "coverage": 45.0, + "strict_coverage": 40.0, + "n_functions": 200, + "n_methods": 300, + "n_function_params": 150, + "n_method_params": 250, + "n_classes": 80, + "n_attrs": 40, + "n_properties": 20, + "n_type_ignores": 10, +} + + +def _make_summary( + *, + n_modules: int = 100, + n_typable: int = 1000, + n_typed: int = 400, + n_any: int = 50, + n_untyped: int = 550, + coverage: float = 45.0, + strict_coverage: float = 40.0, +) -> CoverageSummary: + return { + "n_modules": n_modules, + "n_typable": n_typable, + "n_typed": n_typed, + "n_any": n_any, + "n_untyped": n_untyped, + "coverage": coverage, + "strict_coverage": strict_coverage, + } + + +def test_parse_summary_extracts_fields() -> None: + report_json = _make_report(_SAMPLE_SUMMARY) + + result = parse_summary(report_json) + + assert result["n_modules"] == 100 + assert result["n_typable"] == 1000 + assert result["n_typed"] == 400 + assert result["n_any"] == 50 + assert result["n_untyped"] == 550 + assert result["coverage"] == 45.0 + assert result["strict_coverage"] == 40.0 + + +def test_parse_summary_handles_empty_input() -> None: + assert parse_summary("")["n_modules"] == 0 + assert parse_summary(" ")["n_modules"] == 0 + + +def test_parse_summary_handles_invalid_json() -> None: + assert parse_summary("not json")["n_modules"] == 0 + + +def test_parse_summary_handles_missing_summary_key() -> None: + assert parse_summary(json.dumps({"other": 1}))["n_modules"] == 0 + + +def test_parse_summary_handles_incomplete_summary() -> None: + partial = json.dumps({"summary": {"n_modules": 5}}) + assert parse_summary(partial)["n_modules"] == 0 + + +def test_format_summary_markdown_contains_key_metrics() -> None: + summary = _make_summary() + + result = format_summary_markdown(summary) + + assert "**Type coverage**" in result + assert "45.00%" in result + assert "40.00%" in result + assert "| Modules | 100 |" in result + + +def test_format_comparison_markdown_shows_positive_delta() -> None: + base = _make_summary() + pr = _make_summary( + n_modules=101, + n_typable=1010, + n_typed=420, + n_untyped=540, + coverage=46.53, + strict_coverage=41.58, + ) + + result = format_comparison_markdown(base, pr) + + assert "| Base | PR | Delta |" in result + assert "+1.53%" in result + assert "+1.58%" in result + assert "+20" in result + + +def test_format_comparison_markdown_shows_negative_delta() -> None: + base = _make_summary() + pr = _make_summary( + n_typed=390, + n_any=60, + coverage=44.0, + strict_coverage=39.0, + ) + + result = format_comparison_markdown(base, pr) + + assert "-1.00%" in result + assert "-10" in result + + +def test_format_comparison_markdown_shows_zero_delta() -> None: + summary = _make_summary() + + result = format_comparison_markdown(summary, summary) + + assert "0.00%" in result + assert "| 0 |" in result diff --git a/api/tests/unit_tests/services/plugin/test_plugin_permission_service.py b/api/tests/unit_tests/services/plugin/test_plugin_permission_service.py index 20f132c015..53a9e6210c 100644 --- a/api/tests/unit_tests/services/plugin/test_plugin_permission_service.py +++ b/api/tests/unit_tests/services/plugin/test_plugin_permission_service.py @@ -6,23 +6,25 @@ MODULE = "services.plugin.plugin_permission_service" def _patched_session(): - """Patch sessionmaker(bind=db.engine).begin() to return a mock session as context manager.""" + """Patch session_factory.create_session() to return a mock session as context manager.""" session = MagicMock() - mock_sessionmaker = MagicMock() - mock_sessionmaker.return_value.begin.return_value.__enter__ = MagicMock(return_value=session) - mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) - patcher = patch(f"{MODULE}.sessionmaker", mock_sessionmaker) - db_patcher = patch(f"{MODULE}.db") - return patcher, db_patcher, session + session.__enter__ = MagicMock(return_value=session) + session.__exit__ = MagicMock(return_value=False) + session.begin.return_value.__enter__ = MagicMock(return_value=session) + session.begin.return_value.__exit__ = MagicMock(return_value=False) + mock_factory = MagicMock() + mock_factory.create_session.return_value = session + patcher = patch(f"{MODULE}.session_factory", mock_factory) + return patcher, session class TestGetPermission: def test_returns_permission_when_found(self): - p1, p2, session = _patched_session() + p1, session = _patched_session() permission = MagicMock() session.scalar.return_value = permission - with p1, p2: + with p1: from services.plugin.plugin_permission_service import PluginPermissionService result = PluginPermissionService.get_permission("t1") @@ -30,10 +32,10 @@ class TestGetPermission: assert result is permission def test_returns_none_when_not_found(self): - p1, p2, session = _patched_session() + p1, session = _patched_session() session.scalar.return_value = None - with p1, p2: + with p1: from services.plugin.plugin_permission_service import PluginPermissionService result = PluginPermissionService.get_permission("t1") @@ -43,10 +45,10 @@ class TestGetPermission: class TestChangePermission: def test_creates_new_permission_when_not_exists(self): - p1, p2, session = _patched_session() + p1, session = _patched_session() session.scalar.return_value = None - with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginPermission") as perm_cls: + with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginPermission") as perm_cls: perm_cls.return_value = MagicMock() from services.plugin.plugin_permission_service import PluginPermissionService @@ -54,20 +56,24 @@ class TestChangePermission: "t1", TenantPluginPermission.InstallPermission.EVERYONE, TenantPluginPermission.DebugPermission.EVERYONE ) + assert result is True + session.begin.assert_called_once() session.add.assert_called_once() def test_updates_existing_permission(self): - p1, p2, session = _patched_session() + p1, session = _patched_session() existing = MagicMock() session.scalar.return_value = existing - with p1, p2: + with p1: from services.plugin.plugin_permission_service import PluginPermissionService result = PluginPermissionService.change_permission( "t1", TenantPluginPermission.InstallPermission.ADMINS, TenantPluginPermission.DebugPermission.ADMINS ) + assert result is True + session.begin.assert_called_once() assert existing.install_permission == TenantPluginPermission.InstallPermission.ADMINS assert existing.debug_permission == TenantPluginPermission.DebugPermission.ADMINS session.add.assert_not_called() diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index d15074e7a6..eeb5d178ec 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -1427,16 +1427,7 @@ class TestRegisterService: mock_tenant.name = "Test Workspace" mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter") - # Mock database queries - need to mock the sessionmaker query - mock_session = MagicMock() - mock_session.query.return_value.filter_by.return_value.first.return_value = None # No existing account - - mock_sessionmaker = MagicMock() - mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None - with ( - patch("services.account_service.sessionmaker", mock_sessionmaker), patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup, ): mock_lookup.return_value = None @@ -1475,7 +1466,7 @@ class TestRegisterService: status=AccountStatus.PENDING, is_setup=True, ) - mock_lookup.assert_called_once_with("newuser@example.com", session=mock_session) + mock_lookup.assert_called_once_with("newuser@example.com") def test_invite_new_member_normalizes_new_account_email( self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies @@ -1486,13 +1477,7 @@ class TestRegisterService: mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter") mixed_email = "Invitee@Example.com" - mock_session = MagicMock() - mock_sessionmaker = MagicMock() - mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None - with ( - patch("services.account_service.sessionmaker", mock_sessionmaker), patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup, ): mock_lookup.return_value = None @@ -1525,7 +1510,7 @@ class TestRegisterService: status=AccountStatus.PENDING, is_setup=True, ) - mock_lookup.assert_called_once_with(mixed_email, session=mock_session) + mock_lookup.assert_called_once_with(mixed_email) mock_check_permission.assert_called_once_with(mock_tenant, mock_inviter, None, "add") mock_create_member.assert_called_once_with(mock_tenant, mock_new_account, "normal") mock_switch_tenant.assert_called_once_with(mock_new_account, mock_tenant.id) @@ -1545,16 +1530,7 @@ class TestRegisterService: account_id="existing-user-456", email="existing@example.com", status="pending" ) - # Mock database queries - need to mock the sessionmaker query - mock_session = MagicMock() - mock_session.query.return_value.filter_by.return_value.first.return_value = mock_existing_account - - mock_sessionmaker = MagicMock() - mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None - with ( - patch("services.account_service.sessionmaker", mock_sessionmaker), patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup, ): mock_lookup.return_value = mock_existing_account @@ -1584,7 +1560,7 @@ class TestRegisterService: mock_create_member.assert_called_once_with(mock_tenant, mock_existing_account, "normal") mock_generate_token.assert_called_once_with(mock_tenant, mock_existing_account) mock_task_dependencies.delay.assert_called_once() - mock_lookup.assert_called_once_with("existing@example.com", session=mock_session) + mock_lookup.assert_called_once_with("existing@example.com") def test_invite_new_member_already_in_tenant(self, mock_db_dependencies, mock_redis_dependencies): """Test inviting a member who is already in the tenant.""" diff --git a/api/tests/unit_tests/services/test_ops_service.py b/api/tests/unit_tests/services/test_ops_service.py deleted file mode 100644 index 7067e3b3dd..0000000000 --- a/api/tests/unit_tests/services/test_ops_service.py +++ /dev/null @@ -1,392 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest - -from core.ops.entities.config_entity import TracingProviderEnum -from models.model import App, TraceAppConfig -from services.ops_service import OpsService - - -class TestOpsService: - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_get_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db): - # Arrange - mock_db.session.scalar.return_value = None - - # Act - result = OpsService.get_tracing_app_config("app_id", "arize") - - # Assert - assert result is None - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_get_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db): - # Arrange - trace_config = MagicMock(spec=TraceAppConfig) - mock_db.session.scalar.return_value = trace_config - mock_db.session.get.return_value = None - - # Act - result = OpsService.get_tracing_app_config("app_id", "arize") - - # Assert - assert result is None - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_get_tracing_app_config_none_config(self, mock_ops_trace_manager, mock_db): - # Arrange - trace_config = MagicMock(spec=TraceAppConfig) - trace_config.tracing_config = None - app = MagicMock(spec=App) - app.tenant_id = "tenant_id" - mock_db.session.scalar.return_value = trace_config - mock_db.session.get.return_value = app - - # Act & Assert - with pytest.raises(ValueError, match="Tracing config cannot be None."): - OpsService.get_tracing_app_config("app_id", "arize") - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - @pytest.mark.parametrize( - ("provider", "default_url"), - [ - ("arize", "https://app.arize.com/"), - ("phoenix", "https://app.phoenix.arize.com/projects/"), - ("langsmith", "https://smith.langchain.com/"), - ("opik", "https://www.comet.com/opik/"), - ("weave", "https://wandb.ai/"), - ("aliyun", "https://arms.console.aliyun.com/"), - ("tencent", "https://console.cloud.tencent.com/apm"), - ("mlflow", "http://localhost:5000/"), - ("databricks", "https://www.databricks.com/"), - ], - ) - def test_get_tracing_app_config_providers_exception(self, mock_ops_trace_manager, mock_db, provider, default_url): - # Arrange - trace_config = MagicMock(spec=TraceAppConfig) - trace_config.tracing_config = {"some": "config"} - trace_config.to_dict.return_value = {"tracing_config": {"project_url": default_url}} - app = MagicMock(spec=App) - app.tenant_id = "tenant_id" - mock_db.session.scalar.return_value = trace_config - mock_db.session.get.return_value = app - - mock_ops_trace_manager.decrypt_tracing_config.return_value = {} - mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {} - mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error") - mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error") - - # Act - result = OpsService.get_tracing_app_config("app_id", provider) - - # Assert - assert result["tracing_config"]["project_url"] == default_url - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - @pytest.mark.parametrize( - "provider", ["arize", "phoenix", "langsmith", "opik", "weave", "aliyun", "tencent", "mlflow", "databricks"] - ) - def test_get_tracing_app_config_providers_success(self, mock_ops_trace_manager, mock_db, provider): - # Arrange - trace_config = MagicMock(spec=TraceAppConfig) - trace_config.tracing_config = {"some": "config"} - trace_config.to_dict.return_value = {"tracing_config": {"project_url": "success_url"}} - app = MagicMock(spec=App) - app.tenant_id = "tenant_id" - mock_db.session.scalar.return_value = trace_config - mock_db.session.get.return_value = app - - mock_ops_trace_manager.decrypt_tracing_config.return_value = {} - mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {} - mock_ops_trace_manager.get_trace_config_project_url.return_value = "success_url" - - # Act - result = OpsService.get_tracing_app_config("app_id", provider) - - # Assert - assert result["tracing_config"]["project_url"] == "success_url" - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_get_tracing_app_config_langfuse_success(self, mock_ops_trace_manager, mock_db): - # Arrange - trace_config = MagicMock(spec=TraceAppConfig) - trace_config.tracing_config = {"some": "config"} - trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/project/key"}} - app = MagicMock(spec=App) - app.tenant_id = "tenant_id" - mock_db.session.scalar.return_value = trace_config - mock_db.session.get.return_value = app - - mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"} - mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"} - mock_ops_trace_manager.get_trace_config_project_key.return_value = "key" - - # Act - result = OpsService.get_tracing_app_config("app_id", "langfuse") - - # Assert - assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/project/key" - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_get_tracing_app_config_langfuse_exception(self, mock_ops_trace_manager, mock_db): - # Arrange - trace_config = MagicMock(spec=TraceAppConfig) - trace_config.tracing_config = {"some": "config"} - trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/"}} - app = MagicMock(spec=App) - app.tenant_id = "tenant_id" - mock_db.session.scalar.return_value = trace_config - mock_db.session.get.return_value = app - - mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"} - mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"} - mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error") - - # Act - result = OpsService.get_tracing_app_config("app_id", "langfuse") - - # Assert - assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/" - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_create_tracing_app_config_invalid_provider(self, mock_ops_trace_manager, mock_db): - # Act - result = OpsService.create_tracing_app_config("app_id", "invalid_provider", {}) - - # Assert - assert result == {"error": "Invalid tracing provider: invalid_provider"} - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_create_tracing_app_config_invalid_credentials(self, mock_ops_trace_manager, mock_db): - # Arrange - provider = TracingProviderEnum.LANGFUSE - mock_ops_trace_manager.check_trace_config_is_effective.return_value = False - - # Act - result = OpsService.create_tracing_app_config("app_id", provider, {"public_key": "p", "secret_key": "s"}) - - # Assert - assert result == {"error": "Invalid Credentials"} - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - @pytest.mark.parametrize( - ("provider", "config"), - [ - (TracingProviderEnum.ARIZE, {}), - (TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}), - (TracingProviderEnum.LANGSMITH, {"api_key": "k", "project": "p"}), - (TracingProviderEnum.ALIYUN, {"license_key": "k", "endpoint": "https://aliyun.com"}), - ], - ) - def test_create_tracing_app_config_project_url_exception(self, mock_ops_trace_manager, mock_db, provider, config): - # Arrange - mock_ops_trace_manager.check_trace_config_is_effective.return_value = True - mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error") - mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error") - mock_db.session.scalar.return_value = MagicMock(spec=TraceAppConfig) - - # Act - result = OpsService.create_tracing_app_config("app_id", provider, config) - - # Assert - assert result is None - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_create_tracing_app_config_langfuse_success(self, mock_ops_trace_manager, mock_db): - # Arrange - provider = TracingProviderEnum.LANGFUSE - mock_ops_trace_manager.check_trace_config_is_effective.return_value = True - mock_ops_trace_manager.get_trace_config_project_key.return_value = "key" - app = MagicMock(spec=App) - app.tenant_id = "tenant_id" - mock_db.session.scalar.return_value = None - mock_db.session.get.return_value = app - mock_ops_trace_manager.encrypt_tracing_config.return_value = {} - - # Act - result = OpsService.create_tracing_app_config( - "app_id", provider, {"public_key": "p", "secret_key": "s", "host": "https://api.langfuse.com"} - ) - - # Assert - assert result == {"result": "success"} - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_create_tracing_app_config_already_exists(self, mock_ops_trace_manager, mock_db): - # Arrange - provider = TracingProviderEnum.ARIZE - mock_ops_trace_manager.check_trace_config_is_effective.return_value = True - mock_db.session.scalar.return_value = MagicMock(spec=TraceAppConfig) - - # Act - result = OpsService.create_tracing_app_config("app_id", provider, {}) - - # Assert - assert result is None - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_create_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db): - # Arrange - provider = TracingProviderEnum.ARIZE - mock_ops_trace_manager.check_trace_config_is_effective.return_value = True - mock_db.session.scalar.return_value = None - mock_db.session.get.return_value = None - - # Act - result = OpsService.create_tracing_app_config("app_id", provider, {}) - - # Assert - assert result is None - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_create_tracing_app_config_with_empty_other_keys(self, mock_ops_trace_manager, mock_db): - # Arrange - provider = TracingProviderEnum.ARIZE - mock_ops_trace_manager.check_trace_config_is_effective.return_value = True - app = MagicMock(spec=App) - app.tenant_id = "tenant_id" - mock_db.session.scalar.return_value = None - mock_db.session.get.return_value = app - mock_ops_trace_manager.encrypt_tracing_config.return_value = {} - - # Act - # 'project' is in other_keys for Arize - # provide an empty string for the project in the tracing_config - # create_tracing_app_config will replace it with the default from the model - result = OpsService.create_tracing_app_config("app_id", provider, {"project": ""}) - - # Assert - assert result == {"result": "success"} - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_create_tracing_app_config_success(self, mock_ops_trace_manager, mock_db): - # Arrange - provider = TracingProviderEnum.ARIZE - mock_ops_trace_manager.check_trace_config_is_effective.return_value = True - mock_ops_trace_manager.get_trace_config_project_url.return_value = "http://project_url" - app = MagicMock(spec=App) - app.tenant_id = "tenant_id" - mock_db.session.scalar.return_value = None - mock_db.session.get.return_value = app - mock_ops_trace_manager.encrypt_tracing_config.return_value = {"encrypted": "config"} - - # Act - result = OpsService.create_tracing_app_config("app_id", provider, {}) - - # Assert - assert result == {"result": "success"} - mock_db.session.add.assert_called() - mock_db.session.commit.assert_called() - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_update_tracing_app_config_invalid_provider(self, mock_ops_trace_manager, mock_db): - # Act & Assert - with pytest.raises(ValueError, match="Invalid tracing provider: invalid_provider"): - OpsService.update_tracing_app_config("app_id", "invalid_provider", {}) - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_update_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db): - # Arrange - provider = TracingProviderEnum.ARIZE - mock_db.session.scalar.return_value = None - - # Act - result = OpsService.update_tracing_app_config("app_id", provider, {}) - - # Assert - assert result is None - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_update_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db): - # Arrange - provider = TracingProviderEnum.ARIZE - current_config = MagicMock(spec=TraceAppConfig) - mock_db.session.scalar.return_value = current_config - mock_db.session.get.return_value = None - - # Act - result = OpsService.update_tracing_app_config("app_id", provider, {}) - - # Assert - assert result is None - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_update_tracing_app_config_invalid_credentials(self, mock_ops_trace_manager, mock_db): - # Arrange - provider = TracingProviderEnum.ARIZE - current_config = MagicMock(spec=TraceAppConfig) - app = MagicMock(spec=App) - app.tenant_id = "tenant_id" - mock_db.session.scalar.return_value = current_config - mock_db.session.get.return_value = app - mock_ops_trace_manager.decrypt_tracing_config.return_value = {} - mock_ops_trace_manager.check_trace_config_is_effective.return_value = False - - # Act & Assert - with pytest.raises(ValueError, match="Invalid Credentials"): - OpsService.update_tracing_app_config("app_id", provider, {}) - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_update_tracing_app_config_success(self, mock_ops_trace_manager, mock_db): - # Arrange - provider = TracingProviderEnum.ARIZE - current_config = MagicMock(spec=TraceAppConfig) - current_config.to_dict.return_value = {"some": "data"} - app = MagicMock(spec=App) - app.tenant_id = "tenant_id" - mock_db.session.scalar.return_value = current_config - mock_db.session.get.return_value = app - mock_ops_trace_manager.decrypt_tracing_config.return_value = {} - mock_ops_trace_manager.check_trace_config_is_effective.return_value = True - - # Act - result = OpsService.update_tracing_app_config("app_id", provider, {}) - - # Assert - assert result == {"some": "data"} - mock_db.session.commit.assert_called_once() - - @patch("services.ops_service.db") - def test_delete_tracing_app_config_no_config(self, mock_db): - # Arrange - mock_db.session.scalar.return_value = None - - # Act - result = OpsService.delete_tracing_app_config("app_id", "arize") - - # Assert - assert result is None - - @patch("services.ops_service.db") - def test_delete_tracing_app_config_success(self, mock_db): - # Arrange - trace_config = MagicMock(spec=TraceAppConfig) - mock_db.session.scalar.return_value = trace_config - - # Act - result = OpsService.delete_tracing_app_config("app_id", "arize") - - # Assert - assert result is True - mock_db.session.delete.assert_called_with(trace_config) - mock_db.session.commit.assert_called_once() diff --git a/web/app/components/workflow/run/loop-log/__tests__/loop-result-panel.spec.tsx b/web/app/components/workflow/run/loop-log/__tests__/loop-result-panel.spec.tsx new file mode 100644 index 0000000000..9c2f74a02b --- /dev/null +++ b/web/app/components/workflow/run/loop-log/__tests__/loop-result-panel.spec.tsx @@ -0,0 +1,126 @@ +import type { ReactNode } from 'react' +import type { LoopVariableMap, NodeTracing } from '@/types/workflow' +import { fireEvent, render, screen } from '@testing-library/react' +import { BlockEnum } from '../../../types' +import LoopResultPanel from '../loop-result-panel' + +const mockCodeEditor = vi.hoisted(() => vi.fn()) +const mockTracingPanel = vi.hoisted(() => vi.fn()) + +vi.mock('@/app/components/workflow/nodes/_base/components/editor/code-editor', () => ({ + __esModule: true, + default: (props: { title: ReactNode, value: unknown }) => { + mockCodeEditor(props) + return ( +
+
{props.title}
+
{JSON.stringify(props.value)}
+
+ ) + }, +})) + +vi.mock('@/app/components/workflow/run/tracing-panel', () => ({ + __esModule: true, + default: (props: { list: NodeTracing[], className?: string }) => { + mockTracingPanel(props) + return
{props.list.length}
+ }, +})) + +const createNodeTracing = (id: string, executionMetadata?: NonNullable): NodeTracing => ({ + id, + index: 0, + predecessor_node_id: '', + node_id: `node-${id}`, + node_type: BlockEnum.Code, + title: `Node ${id}`, + inputs: {}, + inputs_truncated: false, + process_data: {}, + process_data_truncated: false, + outputs: {}, + outputs_truncated: false, + status: 'succeeded', + error: '', + elapsed_time: 0, + execution_metadata: executionMetadata, + metadata: { + iterator_length: 0, + iterator_index: 0, + loop_length: 0, + loop_index: 0, + }, + created_at: 0, + created_by: { + id: 'user-1', + name: 'Tester', + email: 'tester@example.com', + }, + finished_at: 0, +}) + +describe('LoopResultPanel', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + // Loop variables should be resolved by the actual run key, not the rendered row position. + describe('Loop Variable Resolution', () => { + it('should read loop variables by the actual loop index when rows are compacted', () => { + const loopVariableMap: LoopVariableMap = { + 2: { item: 'alpha' }, + } + + render( + , + ) + + fireEvent.click(screen.getByText('workflow.singleRun.loop 1')) + + expect(screen.getByTestId('code-editor')).toHaveTextContent('{"item":"alpha"}') + expect(mockCodeEditor).toHaveBeenCalledWith(expect.objectContaining({ + value: loopVariableMap[2], + })) + }) + + it('should read loop variables by parallel run id when available', () => { + const loopVariableMap: LoopVariableMap = { + 'parallel-1': { item: 'beta' }, + } + + render( + , + ) + + fireEvent.click(screen.getByText('workflow.singleRun.loop 1')) + + expect(screen.getByTestId('code-editor')).toHaveTextContent('{"item":"beta"}') + expect(mockCodeEditor).toHaveBeenCalledWith(expect.objectContaining({ + value: loopVariableMap['parallel-1'], + })) + }) + }) +}) diff --git a/web/app/components/workflow/run/loop-log/loop-result-panel.tsx b/web/app/components/workflow/run/loop-log/loop-result-panel.tsx index d69ba80e89..b2d627fb01 100644 --- a/web/app/components/workflow/run/loop-log/loop-result-panel.tsx +++ b/web/app/components/workflow/run/loop-log/loop-result-panel.tsx @@ -19,6 +19,18 @@ import { cn } from '@/utils/classnames' const i18nPrefix = 'singleRun' +const getLoopRunKey = (loop: NodeTracing[], fallbackIndex: number) => { + const executionMetadata = loop[0]?.execution_metadata + + if (executionMetadata?.parallel_mode_run_id !== undefined) + return executionMetadata.parallel_mode_run_id + + if (executionMetadata?.loop_index !== undefined) + return String(executionMetadata.loop_index) + + return String(fallbackIndex) +} + type Props = { list: NodeTracing[][] onBack: () => void @@ -42,10 +54,8 @@ const LoopResultPanel: FC = ({ })) }, []) - const countLoopDuration = (loop: NodeTracing[], loopDurationMap: LoopDurationMap): string => { - const loopRunIndex = loop[0]?.execution_metadata?.loop_index as number - const loopRunId = loop[0]?.execution_metadata?.parallel_mode_run_id - const loopItem = loopDurationMap[loopRunId || loopRunIndex] + const countLoopDuration = (loop: NodeTracing[], index: number, loopDurationMap: LoopDurationMap): string => { + const loopItem = loopDurationMap[getLoopRunKey(loop, index)] const duration = loopItem return `${(duration && duration > 0.01) ? duration.toFixed(2) : 0.01}s` } @@ -59,13 +69,13 @@ const LoopResultPanel: FC = ({ return if (isRunning) - return + return return ( <> {hasDurationMap && (
- {countLoopDuration(loop, loopDurationMap)} + {countLoopDuration(loop, index, loopDurationMap)}
)} = ({
toggleLoop(index)} @@ -107,7 +117,7 @@ const LoopResultPanel: FC = ({
- + {t(`${i18nPrefix}.loop`, { ns: 'workflow' })} {' '} {index + 1} @@ -129,14 +139,14 @@ const LoopResultPanel: FC = ({ )} > { - loopVariableMap?.[index] && ( + loopVariableMap?.[getLoopRunKey(loop, index)] && (
{t('nodes.loop.loopVariables', { ns: 'workflow' }).toLocaleUpperCase()}
} language={CodeLanguage.json} height={112} - value={loopVariableMap[index]} + value={loopVariableMap[getLoopRunKey(loop, index)]} isJSONStringifyBeauty />
diff --git a/web/app/components/workflow/run/utils/format-log/iteration/__tests__/index.spec.ts b/web/app/components/workflow/run/utils/format-log/iteration/__tests__/index.spec.ts index 5b427bd9cf..8f30f6723c 100644 --- a/web/app/components/workflow/run/utils/format-log/iteration/__tests__/index.spec.ts +++ b/web/app/components/workflow/run/utils/format-log/iteration/__tests__/index.spec.ts @@ -1,6 +1,6 @@ import type { NodeTracing } from '@/types/workflow' import { noop } from 'es-toolkit/function' -import format from '..' +import format, { addChildrenToIterationNode } from '..' import graphToLogStruct from '../../graph-to-log-struct' describe('iteration', () => { @@ -9,15 +9,48 @@ describe('iteration', () => { it('result should have no nodes in iteration node', () => { expect(result.find(item => !!item.execution_metadata?.iteration_id)).toBeUndefined() }) - // test('iteration should put nodes in details', () => { - // expect(result).toEqual([ - // startNode, - // { - // ...iterationNode, - // details: [ - // [iterations[0], iterations[1]], - // ], - // }, - // ]) - // }) + + it('should place the first child of a new iteration at a new record when its index is missing', () => { + const parent = { node_id: 'iter1', node_type: 'iteration', execution_metadata: {} } as unknown as NodeTracing + const child0 = { node_id: 'code', execution_metadata: { iteration_id: 'iter1', iteration_index: 0 } } as unknown as NodeTracing + const streaming = { node_id: 'code', execution_metadata: { iteration_id: 'iter1' } } as unknown as NodeTracing + + const result = addChildrenToIterationNode(parent, [child0, streaming]) + expect(result.details![0]).toEqual([child0]) + expect(result.details![1]).toEqual([streaming]) + }) + + it('should keep missing iteration_index items in the current record when the node has not restarted', () => { + const parent = { + node_id: 'iter1', + node_type: 'iteration', + execution_metadata: { + iteration_duration_map: { 0: 1.2, 1: 0.4 }, + }, + } as unknown as NodeTracing + const child0 = { node_id: 'code', execution_metadata: { iteration_id: 'iter1', iteration_index: 0 } } as unknown as NodeTracing + const child1 = { node_id: 'code', execution_metadata: { iteration_id: 'iter1', iteration_index: 1 } } as unknown as NodeTracing + const streaming = { node_id: 'tool', execution_metadata: { iteration_id: 'iter1' } } as unknown as NodeTracing + + const result = addChildrenToIterationNode(parent, [child0, child1, streaming]) + expect(result.details![0]).toEqual([child0]) + expect(result.details![1]).toEqual([child1, streaming]) + }) + + it('should not jump to the latest iteration when an earlier item is missing iteration_index', () => { + const parent = { + node_id: 'iter1', + node_type: 'iteration', + execution_metadata: { + iteration_duration_map: { 0: 1.2, 1: 0.4 }, + }, + } as unknown as NodeTracing + const code0 = { node_id: 'code', execution_metadata: { iteration_id: 'iter1', iteration_index: 0 } } as unknown as NodeTracing + const tool = { node_id: 'tool', execution_metadata: { iteration_id: 'iter1' } } as unknown as NodeTracing + const code1 = { node_id: 'code', execution_metadata: { iteration_id: 'iter1', iteration_index: 1 } } as unknown as NodeTracing + + const result = addChildrenToIterationNode(parent, [code0, tool, code1]) + expect(result.details![0]).toEqual([code0, tool]) + expect(result.details![1]).toEqual([code1]) + }) }) diff --git a/web/app/components/workflow/run/utils/format-log/iteration/index.ts b/web/app/components/workflow/run/utils/format-log/iteration/index.ts index fbb81118a1..5bd6b822e0 100644 --- a/web/app/components/workflow/run/utils/format-log/iteration/index.ts +++ b/web/app/components/workflow/run/utils/format-log/iteration/index.ts @@ -4,15 +4,31 @@ import formatParallelNode from '../parallel' export function addChildrenToIterationNode(iterationNode: NodeTracing, childrenNodes: NodeTracing[]): NodeTracing { const details: NodeTracing[][] = [] - childrenNodes.forEach((item, index) => { + let lastResolvedIndex = -1 + + childrenNodes.forEach((item) => { if (!item.execution_metadata) return - const { iteration_index = 0 } = item.execution_metadata - const runIndex: number = iteration_index !== undefined ? iteration_index : index + const { iteration_index } = item.execution_metadata + let runIndex: number + + if (iteration_index !== undefined) { + runIndex = iteration_index + } + else if (lastResolvedIndex >= 0) { + const currentGroup = details[lastResolvedIndex] || [] + const seenSameNodeInCurrentGroup = currentGroup.some(node => node.node_id === item.node_id) + runIndex = seenSameNodeInCurrentGroup ? lastResolvedIndex + 1 : lastResolvedIndex + } + else { + runIndex = 0 + } + if (!details[runIndex]) details[runIndex] = [] details[runIndex].push(item) + lastResolvedIndex = runIndex }) return { ...iterationNode, diff --git a/web/app/components/workflow/run/utils/format-log/loop/__tests__/index.spec.ts b/web/app/components/workflow/run/utils/format-log/loop/__tests__/index.spec.ts index f352598943..00380361ed 100644 --- a/web/app/components/workflow/run/utils/format-log/loop/__tests__/index.spec.ts +++ b/web/app/components/workflow/run/utils/format-log/loop/__tests__/index.spec.ts @@ -1,6 +1,6 @@ import type { NodeTracing } from '@/types/workflow' import { noop } from 'es-toolkit/function' -import format from '..' +import format, { addChildrenToLoopNode } from '..' import graphToLogStruct from '../../graph-to-log-struct' describe('loop', () => { @@ -21,4 +21,48 @@ describe('loop', () => { }, ]) }) + + it('should place the first child of a new loop run at a new record when its index is missing', () => { + const parent = { node_id: 'loop1', node_type: 'loop', execution_metadata: {} } as unknown as NodeTracing + const child0 = { node_id: 'code', execution_metadata: { loop_id: 'loop1', loop_index: 0 } } as unknown as NodeTracing + const streaming = { node_id: 'code', execution_metadata: { loop_id: 'loop1' } } as unknown as NodeTracing + + const result = addChildrenToLoopNode(parent, [child0, streaming]) + expect(result.details![0]).toEqual([child0]) + expect(result.details![1]).toEqual([streaming]) + }) + + it('should keep missing loop_index items in the current record when the node has not restarted', () => { + const parent = { + node_id: 'loop1', + node_type: 'loop', + execution_metadata: { + loop_duration_map: { 0: 1.2, 1: 0.4 }, + }, + } as unknown as NodeTracing + const child0 = { node_id: 'code', execution_metadata: { loop_id: 'loop1', loop_index: 0 } } as unknown as NodeTracing + const child1 = { node_id: 'code', execution_metadata: { loop_id: 'loop1', loop_index: 1 } } as unknown as NodeTracing + const streaming = { node_id: 'tool', execution_metadata: { loop_id: 'loop1' } } as unknown as NodeTracing + + const result = addChildrenToLoopNode(parent, [child0, child1, streaming]) + expect(result.details![0]).toEqual([child0]) + expect(result.details![1]).toEqual([child1, streaming]) + }) + + it('should not jump to the latest loop when an earlier item is missing loop_index', () => { + const parent = { + node_id: 'loop1', + node_type: 'loop', + execution_metadata: { + loop_duration_map: { 0: 1.2, 1: 0.4 }, + }, + } as unknown as NodeTracing + const code0 = { node_id: 'code', execution_metadata: { loop_id: 'loop1', loop_index: 0 } } as unknown as NodeTracing + const tool = { node_id: 'tool', execution_metadata: { loop_id: 'loop1' } } as unknown as NodeTracing + const code1 = { node_id: 'code', execution_metadata: { loop_id: 'loop1', loop_index: 1 } } as unknown as NodeTracing + + const result = addChildrenToLoopNode(parent, [code0, tool, code1]) + expect(result.details![0]).toEqual([code0, tool]) + expect(result.details![1]).toEqual([code1]) + }) }) diff --git a/web/app/components/workflow/run/utils/format-log/loop/index.ts b/web/app/components/workflow/run/utils/format-log/loop/index.ts index fd26c3916e..be77626786 100644 --- a/web/app/components/workflow/run/utils/format-log/loop/index.ts +++ b/web/app/components/workflow/run/utils/format-log/loop/index.ts @@ -3,20 +3,49 @@ import { BlockEnum } from '@/app/components/workflow/types' import formatParallelNode from '../parallel' export function addChildrenToLoopNode(loopNode: NodeTracing, childrenNodes: NodeTracing[]): NodeTracing { - const details: NodeTracing[][] = [] + const detailsByKey = new Map() + let lastResolvedIndex = -1 + const order: string[] = [] + + const ensureGroup = (key: string) => { + const group = detailsByKey.get(key) + if (group) + return group + + const newGroup: NodeTracing[] = [] + detailsByKey.set(key, newGroup) + order.push(key) + return newGroup + } + childrenNodes.forEach((item) => { if (!item.execution_metadata) return - const { parallel_mode_run_id, loop_index = 0 } = item.execution_metadata - const runIndex: number = (parallel_mode_run_id || loop_index) as number - if (!details[runIndex]) - details[runIndex] = [] + const { parallel_mode_run_id, loop_index } = item.execution_metadata + let runIndex: number | string - details[runIndex].push(item) + if (parallel_mode_run_id !== undefined) { + runIndex = parallel_mode_run_id + } + else if (loop_index !== undefined) { + runIndex = loop_index + } + else if (lastResolvedIndex >= 0) { + const currentGroup = detailsByKey.get(String(lastResolvedIndex)) || [] + const seenSameNodeInCurrentGroup = currentGroup.some(node => node.node_id === item.node_id) + runIndex = seenSameNodeInCurrentGroup ? lastResolvedIndex + 1 : lastResolvedIndex + } + else { + runIndex = 0 + } + + ensureGroup(String(runIndex)).push(item) + if (typeof runIndex === 'number') + lastResolvedIndex = runIndex }) return { ...loopNode, - details, + details: order.map(key => detailsByKey.get(key) || []), } } diff --git a/web/eslint-suppressions.json b/web/eslint-suppressions.json index 36684eae39..4797b57c12 100644 --- a/web/eslint-suppressions.json +++ b/web/eslint-suppressions.json @@ -11088,11 +11088,6 @@ "count": 1 } }, - "app/components/workflow/run/loop-log/loop-result-panel.tsx": { - "tailwindcss/enforce-consistent-class-order": { - "count": 3 - } - }, "app/components/workflow/run/loop-result-panel.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 4