mirror of
https://github.com/langgenius/dify.git
synced 2026-05-10 05:56:31 +08:00
Merge branch 'main' into tp
This commit is contained in:
commit
3d445e1d95
6
.github/workflows/autofix.yml
vendored
6
.github/workflows/autofix.yml
vendored
@ -116,6 +116,12 @@ jobs:
|
||||
if: github.event_name != 'merge_group'
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Generate API docs
|
||||
if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true'
|
||||
run: |
|
||||
cd api
|
||||
uv run dev/generate_swagger_markdown_docs.py --swagger-dir openapi --markdown-dir openapi/markdown
|
||||
|
||||
- name: ESLint autofix
|
||||
if: github.event_name != 'merge_group' && steps.web-changes.outputs.any_changed == 'true'
|
||||
run: |
|
||||
|
||||
4
Makefile
4
Makefile
@ -71,13 +71,13 @@ type-check:
|
||||
@echo "📝 Running type checks (basedpyright + pyrefly + mypy)..."
|
||||
@./dev/basedpyright-check $(PATH_TO_CHECK)
|
||||
@./dev/pyrefly-check-local
|
||||
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --exclude 'dev/generate_swagger_specs.py' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
@echo "✅ Type checks complete"
|
||||
|
||||
type-check-core:
|
||||
@echo "📝 Running core type checks (basedpyright + mypy)..."
|
||||
@./dev/basedpyright-check $(PATH_TO_CHECK)
|
||||
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --exclude 'dev/generate_swagger_specs.py' --exclude 'dev/generate_fastopenapi_specs.py' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
@echo "✅ Core type checks complete"
|
||||
|
||||
test:
|
||||
|
||||
@ -1,4 +1,10 @@
|
||||
"""Helpers for registering Pydantic models with Flask-RESTX namespaces."""
|
||||
"""Helpers for registering Pydantic models with Flask-RESTX namespaces.
|
||||
|
||||
Flask-RESTX treats `SchemaModel` bodies as opaque JSON schemas; it does not
|
||||
promote Pydantic's nested `$defs` into top-level Swagger `definitions`.
|
||||
These helpers keep that translation centralized so models registered through
|
||||
`register_schema_models` emit resolvable Swagger 2.0 references.
|
||||
"""
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
@ -8,10 +14,32 @@ from pydantic import BaseModel, TypeAdapter
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
def register_schema_model(namespace: Namespace, model: type[BaseModel]) -> None:
|
||||
"""Register a single BaseModel with a namespace for Swagger documentation."""
|
||||
def _register_json_schema(namespace: Namespace, name: str, schema: dict) -> None:
|
||||
"""Register a JSON schema and promote any nested Pydantic `$defs`."""
|
||||
|
||||
namespace.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
nested_definitions = schema.get("$defs")
|
||||
schema_to_register = dict(schema)
|
||||
if isinstance(nested_definitions, dict):
|
||||
schema_to_register.pop("$defs")
|
||||
|
||||
namespace.schema_model(name, schema_to_register)
|
||||
|
||||
if not isinstance(nested_definitions, dict):
|
||||
return
|
||||
|
||||
for nested_name, nested_schema in nested_definitions.items():
|
||||
if isinstance(nested_schema, dict):
|
||||
_register_json_schema(namespace, nested_name, nested_schema)
|
||||
|
||||
|
||||
def register_schema_model(namespace: Namespace, model: type[BaseModel]) -> None:
|
||||
"""Register a BaseModel and its nested schema definitions for Swagger documentation."""
|
||||
|
||||
_register_json_schema(
|
||||
namespace,
|
||||
model.__name__,
|
||||
model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> None:
|
||||
@ -34,8 +62,10 @@ def get_or_create_model(model_name: str, field_def):
|
||||
def register_enum_models(namespace: Namespace, *models: type[StrEnum]) -> None:
|
||||
"""Register multiple StrEnum with a namespace."""
|
||||
for model in models:
|
||||
namespace.schema_model(
|
||||
model.__name__, TypeAdapter(model).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
_register_json_schema(
|
||||
namespace,
|
||||
model.__name__,
|
||||
TypeAdapter(model).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -12,6 +12,7 @@ from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import only_edition_cloud
|
||||
from core.db.session_factory import session_factory
|
||||
@ -301,15 +302,7 @@ class BatchAddNotificationAccountsPayload(BaseModel):
|
||||
user_email: list[str] = Field(..., description="List of account email addresses")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
UpsertNotificationPayload.__name__,
|
||||
UpsertNotificationPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
BatchAddNotificationAccountsPayload.__name__,
|
||||
BatchAddNotificationAccountsPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
register_schema_models(console_ns, UpsertNotificationPayload, BatchAddNotificationAccountsPayload)
|
||||
|
||||
|
||||
@console_ns.route("/admin/upsert_notification")
|
||||
|
||||
@ -2,7 +2,7 @@ from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.schema import register_enum_models, register_schema_models
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
@ -33,6 +33,7 @@ class AppImportPayload(BaseModel):
|
||||
app_id: str | None = Field(None)
|
||||
|
||||
|
||||
register_enum_models(console_ns, ImportStatus)
|
||||
register_schema_models(console_ns, AppImportPayload, Import, CheckDependenciesResult)
|
||||
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@ from collections.abc import Sequence
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.schema import register_enum_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
@ -19,13 +20,12 @@ from core.helper.code_executor.python3.python3_code_provider import Python3CodeP
|
||||
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class InstructionGeneratePayload(BaseModel):
|
||||
flow_id: str = Field(..., description="Workflow/Flow ID")
|
||||
@ -41,16 +41,16 @@ class InstructionTemplatePayload(BaseModel):
|
||||
type: str = Field(..., description="Instruction template type")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(RuleGeneratePayload)
|
||||
reg(RuleCodeGeneratePayload)
|
||||
reg(RuleStructuredOutputPayload)
|
||||
reg(InstructionGeneratePayload)
|
||||
reg(InstructionTemplatePayload)
|
||||
reg(ModelConfig)
|
||||
register_enum_models(console_ns, LLMMode)
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
RuleGeneratePayload,
|
||||
RuleCodeGeneratePayload,
|
||||
RuleStructuredOutputPayload,
|
||||
InstructionGeneratePayload,
|
||||
InstructionTemplatePayload,
|
||||
ModelConfig,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rule-generate")
|
||||
|
||||
95
api/dev/generate_fastopenapi_specs.py
Normal file
95
api/dev/generate_fastopenapi_specs.py
Normal file
@ -0,0 +1,95 @@
|
||||
"""Generate FastOpenAPI OpenAPI 3.0 specs without booting the full backend."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
API_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(API_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(API_ROOT))
|
||||
|
||||
from dev.generate_swagger_specs import apply_runtime_defaults, drop_null_values, sort_openapi_arrays
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FastOpenApiSpecTarget:
|
||||
route: str
|
||||
filename: str
|
||||
|
||||
|
||||
FASTOPENAPI_SPEC_TARGETS: tuple[FastOpenApiSpecTarget, ...] = (
|
||||
FastOpenApiSpecTarget(route="/fastopenapi/openapi.json", filename="fastopenapi-console-openapi.json"),
|
||||
)
|
||||
|
||||
|
||||
def create_fastopenapi_spec_app():
|
||||
"""Build a minimal Flask app that only mounts FastOpenAPI docs routes."""
|
||||
|
||||
apply_runtime_defaults()
|
||||
|
||||
from app_factory import create_flask_app_with_configs
|
||||
from extensions import ext_fastopenapi
|
||||
|
||||
app = create_flask_app_with_configs()
|
||||
ext_fastopenapi.init_app(app)
|
||||
return app
|
||||
|
||||
|
||||
def generate_fastopenapi_specs(output_dir: Path) -> list[Path]:
|
||||
"""Write FastOpenAPI specs to `output_dir` and return the written paths."""
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
app = create_fastopenapi_spec_app()
|
||||
client = app.test_client()
|
||||
|
||||
written_paths: list[Path] = []
|
||||
for target in FASTOPENAPI_SPEC_TARGETS:
|
||||
response = client.get(target.route)
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(f"failed to fetch {target.route}: {response.status_code}")
|
||||
|
||||
payload = response.get_json()
|
||||
if not isinstance(payload, dict):
|
||||
raise RuntimeError(f"unexpected response payload for {target.route}")
|
||||
payload = drop_null_values(payload)
|
||||
payload = sort_openapi_arrays(payload)
|
||||
|
||||
output_path = output_dir / target.filename
|
||||
output_path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")
|
||||
written_paths.append(output_path)
|
||||
|
||||
return written_paths
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default=Path("openapi"),
|
||||
help="Directory where the OpenAPI JSON files will be written.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> int:
|
||||
args = parse_args()
|
||||
written_paths = generate_fastopenapi_specs(args.output_dir)
|
||||
|
||||
for path in written_paths:
|
||||
logger.debug(path)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
161
api/dev/generate_swagger_markdown_docs.py
Normal file
161
api/dev/generate_swagger_markdown_docs.py
Normal file
@ -0,0 +1,161 @@
|
||||
"""Generate OpenAPI JSON specs and split Markdown API docs.
|
||||
|
||||
The Markdown step uses `swagger-markdown`, the same converter family as the
|
||||
Swagger Markdown UI, so CI and local regeneration catch converter-incompatible
|
||||
OpenAPI output early.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
API_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(API_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(API_ROOT))
|
||||
|
||||
from dev.generate_fastopenapi_specs import FASTOPENAPI_SPEC_TARGETS, generate_fastopenapi_specs
|
||||
from dev.generate_swagger_specs import SPEC_TARGETS, generate_specs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SWAGGER_MARKDOWN_PACKAGE = "swagger-markdown@3.0.0"
|
||||
CONSOLE_SWAGGER_FILENAME = "console-swagger.json"
|
||||
STALE_COMBINED_MARKDOWN_FILENAME = "api-reference.md"
|
||||
|
||||
|
||||
def _convert_spec_to_markdown(spec_path: Path, markdown_path: Path) -> None:
|
||||
subprocess.run(
|
||||
[
|
||||
"npx",
|
||||
"--yes",
|
||||
SWAGGER_MARKDOWN_PACKAGE,
|
||||
"-i",
|
||||
str(spec_path),
|
||||
"-o",
|
||||
str(markdown_path),
|
||||
],
|
||||
check=True,
|
||||
)
|
||||
|
||||
|
||||
def _demote_markdown_headings(markdown: str, *, levels: int = 1) -> str:
|
||||
"""Nest generated Markdown under another Markdown section."""
|
||||
|
||||
heading_prefix = "#" * levels
|
||||
lines = []
|
||||
for line in markdown.splitlines():
|
||||
if line.startswith("#"):
|
||||
lines.append(f"{heading_prefix}{line}")
|
||||
else:
|
||||
lines.append(line)
|
||||
return "\n".join(lines).strip()
|
||||
|
||||
|
||||
def _append_fastopenapi_markdown(console_markdown_path: Path, fastopenapi_markdown_path: Path) -> None:
|
||||
"""Append FastOpenAPI console docs to the existing console API Markdown."""
|
||||
|
||||
console_markdown = console_markdown_path.read_text(encoding="utf-8").rstrip()
|
||||
fastopenapi_markdown = _demote_markdown_headings(
|
||||
fastopenapi_markdown_path.read_text(encoding="utf-8"),
|
||||
levels=2,
|
||||
)
|
||||
console_markdown_path.write_text(
|
||||
"\n\n".join(
|
||||
[
|
||||
console_markdown,
|
||||
"## FastOpenAPI Preview (OpenAPI 3.0)",
|
||||
fastopenapi_markdown,
|
||||
]
|
||||
)
|
||||
+ "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
|
||||
def generate_markdown_docs(
|
||||
swagger_dir: Path,
|
||||
markdown_dir: Path,
|
||||
*,
|
||||
keep_swagger_json: bool = False,
|
||||
) -> list[Path]:
|
||||
"""Generate intermediate specs, convert them to split Markdown API docs, and return Markdown paths."""
|
||||
|
||||
swagger_paths = generate_specs(swagger_dir)
|
||||
fastopenapi_paths = generate_fastopenapi_specs(swagger_dir)
|
||||
spec_paths = [*swagger_paths, *fastopenapi_paths]
|
||||
swagger_paths_by_name = {path.name: path for path in swagger_paths}
|
||||
fastopenapi_paths_by_name = {path.name: path for path in fastopenapi_paths}
|
||||
|
||||
markdown_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
written_paths: list[Path] = []
|
||||
try:
|
||||
with tempfile.TemporaryDirectory(prefix="dify-api-docs-") as temp_dir:
|
||||
temp_markdown_dir = Path(temp_dir)
|
||||
|
||||
for target in SPEC_TARGETS:
|
||||
swagger_path = swagger_paths_by_name[target.filename]
|
||||
markdown_path = markdown_dir / f"{swagger_path.stem}.md"
|
||||
_convert_spec_to_markdown(swagger_path, markdown_path)
|
||||
written_paths.append(markdown_path)
|
||||
|
||||
for target in FASTOPENAPI_SPEC_TARGETS: # type: ignore
|
||||
fastopenapi_path = fastopenapi_paths_by_name[target.filename]
|
||||
markdown_path = temp_markdown_dir / f"{fastopenapi_path.stem}.md"
|
||||
_convert_spec_to_markdown(fastopenapi_path, markdown_path)
|
||||
|
||||
console_markdown_path = markdown_dir / f"{Path(CONSOLE_SWAGGER_FILENAME).stem}.md"
|
||||
_append_fastopenapi_markdown(console_markdown_path, markdown_path)
|
||||
|
||||
(markdown_dir / STALE_COMBINED_MARKDOWN_FILENAME).unlink(missing_ok=True)
|
||||
finally:
|
||||
if not keep_swagger_json:
|
||||
for path in spec_paths:
|
||||
path.unlink(missing_ok=True)
|
||||
|
||||
return written_paths
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--swagger-dir",
|
||||
type=Path,
|
||||
default=Path("openapi"),
|
||||
help="Directory where intermediate JSON spec files will be written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--markdown-dir",
|
||||
type=Path,
|
||||
default=Path("openapi/markdown"),
|
||||
help="Directory where split Markdown API docs will be written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--keep-swagger-json",
|
||||
action="store_true",
|
||||
help="Keep intermediate JSON spec files after Markdown generation.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> int:
|
||||
args = parse_args()
|
||||
written_paths = generate_markdown_docs(
|
||||
args.swagger_dir,
|
||||
args.markdown_dir,
|
||||
keep_swagger_json=args.keep_swagger_json,
|
||||
)
|
||||
|
||||
for path in written_paths:
|
||||
logger.debug(path)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
@ -9,12 +9,15 @@ which is unnecessary when the goal is only to serialize the Flask-RESTX
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import MutableMapping
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Protocol, TypeGuard
|
||||
|
||||
from flask import Flask
|
||||
from flask_restx.swagger import Swagger
|
||||
@ -30,19 +33,110 @@ if str(API_ROOT) not in sys.path:
|
||||
class SpecTarget:
|
||||
route: str
|
||||
filename: str
|
||||
namespace: str
|
||||
|
||||
|
||||
class RestxApi(Protocol):
|
||||
models: MutableMapping[str, object]
|
||||
|
||||
def model(self, name: str, model: dict[object, object]) -> object: ...
|
||||
|
||||
|
||||
SPEC_TARGETS: tuple[SpecTarget, ...] = (
|
||||
SpecTarget(route="/console/api/swagger.json", filename="console-swagger.json"),
|
||||
SpecTarget(route="/api/swagger.json", filename="web-swagger.json"),
|
||||
SpecTarget(route="/v1/swagger.json", filename="service-swagger.json"),
|
||||
SpecTarget(route="/console/api/swagger.json", filename="console-swagger.json", namespace="console"),
|
||||
SpecTarget(route="/api/swagger.json", filename="web-swagger.json", namespace="web"),
|
||||
SpecTarget(route="/v1/swagger.json", filename="service-swagger.json", namespace="service"),
|
||||
)
|
||||
|
||||
_ORIGINAL_REGISTER_MODEL = Swagger.register_model
|
||||
_ORIGINAL_REGISTER_FIELD = Swagger.register_field
|
||||
|
||||
|
||||
def _apply_runtime_defaults() -> None:
|
||||
def _is_inline_field_map(value: object) -> TypeGuard[dict[object, object]]:
|
||||
"""Return whether a nested field map is an anonymous inline mapping."""
|
||||
|
||||
from flask_restx.model import Model, OrderedModel
|
||||
|
||||
return isinstance(value, dict) and not isinstance(value, (Model, OrderedModel))
|
||||
|
||||
|
||||
def _jsonable_schema_value(value: object) -> object:
|
||||
"""Return a deterministic JSON-serializable representation for schema fingerprints."""
|
||||
|
||||
if value is None or isinstance(value, str | int | float | bool):
|
||||
return value
|
||||
if isinstance(value, list | tuple):
|
||||
return [_jsonable_schema_value(item) for item in value]
|
||||
if isinstance(value, dict):
|
||||
return {str(key): _jsonable_schema_value(item) for key, item in value.items()}
|
||||
value_type = type(value)
|
||||
return f"<{value_type.__module__}.{value_type.__qualname__}>"
|
||||
|
||||
|
||||
def _field_signature(field: object) -> object:
|
||||
"""Build a stable signature for a Flask-RESTX field object."""
|
||||
|
||||
from flask_restx import fields
|
||||
from flask_restx.model import instance
|
||||
|
||||
field_instance = instance(field)
|
||||
signature: dict[str, object] = {
|
||||
"class": f"{field_instance.__class__.__module__}.{field_instance.__class__.__qualname__}"
|
||||
}
|
||||
|
||||
if isinstance(field_instance, fields.Nested):
|
||||
nested = getattr(field_instance, "nested", None)
|
||||
if _is_inline_field_map(nested):
|
||||
signature["nested"] = _inline_model_signature(nested)
|
||||
else:
|
||||
signature["nested"] = getattr(
|
||||
nested,
|
||||
"name",
|
||||
f"<{type(nested).__module__}.{type(nested).__qualname__}>",
|
||||
)
|
||||
elif hasattr(field_instance, "container"):
|
||||
signature["container"] = _field_signature(field_instance.container)
|
||||
else:
|
||||
schema = getattr(field_instance, "__schema__", None)
|
||||
if isinstance(schema, dict):
|
||||
signature["schema"] = _jsonable_schema_value(schema)
|
||||
|
||||
for attr_name in (
|
||||
"attribute",
|
||||
"default",
|
||||
"description",
|
||||
"example",
|
||||
"max",
|
||||
"min",
|
||||
"nullable",
|
||||
"readonly",
|
||||
"required",
|
||||
"title",
|
||||
):
|
||||
if hasattr(field_instance, attr_name):
|
||||
signature[attr_name] = _jsonable_schema_value(getattr(field_instance, attr_name))
|
||||
|
||||
return signature
|
||||
|
||||
|
||||
def _inline_model_signature(nested_fields: dict[object, object]) -> object:
|
||||
"""Build a stable signature for an anonymous inline model."""
|
||||
|
||||
return [
|
||||
(str(field_name), _field_signature(field))
|
||||
for field_name, field in sorted(nested_fields.items(), key=lambda item: str(item[0]))
|
||||
]
|
||||
|
||||
|
||||
def _inline_model_name(nested_fields: dict[object, object]) -> str:
|
||||
"""Return a stable Swagger model name for an anonymous inline field map."""
|
||||
|
||||
signature = json.dumps(_inline_model_signature(nested_fields), sort_keys=True, separators=(",", ":"))
|
||||
digest = hashlib.sha1(signature.encode("utf-8")).hexdigest()[:12]
|
||||
return f"_AnonymousInlineModel_{digest}"
|
||||
|
||||
|
||||
def apply_runtime_defaults() -> None:
|
||||
"""Force the small config surface required for Swagger generation."""
|
||||
|
||||
os.environ.setdefault("SECRET_KEY", "spec-export")
|
||||
@ -74,25 +168,26 @@ def _patch_swagger_for_inline_nested_dicts() -> None:
|
||||
anonymous_models = getattr(self, "_anonymous_inline_models", None)
|
||||
if anonymous_models is None:
|
||||
anonymous_models = {}
|
||||
self._anonymous_inline_models = anonymous_models
|
||||
self.__dict__["_anonymous_inline_models"] = anonymous_models
|
||||
|
||||
anonymous_name = anonymous_models.get(id(nested_fields))
|
||||
if anonymous_name is None:
|
||||
anonymous_name = f"_AnonymousInlineModel{len(anonymous_models) + 1}"
|
||||
anonymous_name = _inline_model_name(nested_fields)
|
||||
anonymous_models[id(nested_fields)] = anonymous_name
|
||||
self.api.model(anonymous_name, nested_fields)
|
||||
if anonymous_name not in self.api.models:
|
||||
self.api.model(anonymous_name, nested_fields)
|
||||
|
||||
return self.api.models[anonymous_name]
|
||||
|
||||
def register_model_with_inline_dict_support(self: Swagger, model: object) -> dict[str, str]:
|
||||
if isinstance(model, dict):
|
||||
if _is_inline_field_map(model):
|
||||
model = get_or_create_inline_model(self, model)
|
||||
|
||||
return _ORIGINAL_REGISTER_MODEL(self, model)
|
||||
|
||||
def register_field_with_inline_dict_support(self: Swagger, field: object) -> None:
|
||||
nested = getattr(field, "nested", None)
|
||||
if isinstance(nested, dict):
|
||||
if _is_inline_field_map(nested):
|
||||
field.model = get_or_create_inline_model(self, nested) # type: ignore
|
||||
|
||||
_ORIGINAL_REGISTER_FIELD(self, field)
|
||||
@ -105,22 +200,169 @@ def _patch_swagger_for_inline_nested_dicts() -> None:
|
||||
def create_spec_app() -> Flask:
|
||||
"""Build a minimal Flask app that only mounts the Swagger-producing blueprints."""
|
||||
|
||||
_apply_runtime_defaults()
|
||||
apply_runtime_defaults()
|
||||
_patch_swagger_for_inline_nested_dicts()
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
from controllers.console import bp as console_bp
|
||||
from controllers.console import console_ns
|
||||
from controllers.service_api import bp as service_api_bp
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.web import bp as web_bp
|
||||
from controllers.web import web_ns
|
||||
|
||||
app.register_blueprint(console_bp)
|
||||
app.register_blueprint(web_bp)
|
||||
app.register_blueprint(service_api_bp)
|
||||
|
||||
for namespace in (console_ns, web_ns, service_api_ns):
|
||||
for api in namespace.apis:
|
||||
_materialize_inline_model_definitions(api)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def _registered_models(namespace: str) -> dict[str, object]:
|
||||
"""Return the Flask-RESTX models registered for a Swagger namespace."""
|
||||
|
||||
if namespace == "console":
|
||||
from controllers.console import console_ns
|
||||
|
||||
models = dict(console_ns.models)
|
||||
for api in console_ns.apis:
|
||||
models.update(api.models)
|
||||
return models
|
||||
if namespace == "web":
|
||||
from controllers.web import web_ns
|
||||
|
||||
models = dict(web_ns.models)
|
||||
for api in web_ns.apis:
|
||||
models.update(api.models)
|
||||
return models
|
||||
if namespace == "service":
|
||||
from controllers.service_api import service_api_ns
|
||||
|
||||
models = dict(service_api_ns.models)
|
||||
for api in service_api_ns.apis:
|
||||
models.update(api.models)
|
||||
return models
|
||||
|
||||
raise ValueError(f"unknown Swagger namespace: {namespace}")
|
||||
|
||||
|
||||
def _materialize_inline_model_definitions(api: RestxApi) -> None:
|
||||
"""Convert inline `fields.Nested({...})` maps into named API models."""
|
||||
|
||||
from flask_restx import fields
|
||||
from flask_restx.model import Model, OrderedModel, instance
|
||||
|
||||
inline_models: dict[int, dict[object, object]] = {}
|
||||
inline_model_names: dict[int, str] = {}
|
||||
|
||||
def collect_field(field: object) -> None:
|
||||
field_instance = instance(field)
|
||||
if isinstance(field_instance, fields.Nested):
|
||||
nested = getattr(field_instance, "nested", None)
|
||||
if _is_inline_field_map(nested) and id(nested) not in inline_models:
|
||||
inline_models[id(nested)] = nested
|
||||
for nested_field in nested.values():
|
||||
collect_field(nested_field)
|
||||
|
||||
container = getattr(field_instance, "container", None)
|
||||
if container is not None:
|
||||
collect_field(container)
|
||||
|
||||
for model in list(api.models.values()):
|
||||
if isinstance(model, (Model, OrderedModel)):
|
||||
for field in model.values():
|
||||
collect_field(field)
|
||||
|
||||
for nested_fields in sorted(inline_models.values(), key=_inline_model_name):
|
||||
anonymous_name = _inline_model_name(nested_fields)
|
||||
inline_model_names[id(nested_fields)] = anonymous_name
|
||||
if anonymous_name not in api.models:
|
||||
api.model(anonymous_name, nested_fields)
|
||||
|
||||
def model_name_for(nested_fields: dict[object, object]) -> str:
|
||||
anonymous_name = inline_model_names.get(id(nested_fields))
|
||||
if anonymous_name is None:
|
||||
anonymous_name = _inline_model_name(nested_fields)
|
||||
inline_model_names[id(nested_fields)] = anonymous_name
|
||||
if anonymous_name not in api.models:
|
||||
api.model(anonymous_name, nested_fields)
|
||||
return anonymous_name
|
||||
|
||||
def materialize_field(field: object) -> None:
|
||||
field_instance = instance(field)
|
||||
if isinstance(field_instance, fields.Nested):
|
||||
nested = getattr(field_instance, "nested", None)
|
||||
if _is_inline_field_map(nested):
|
||||
field_instance.model = api.models[model_name_for(nested)] # type: ignore[attr-defined]
|
||||
|
||||
container = getattr(field_instance, "container", None)
|
||||
if container is not None:
|
||||
materialize_field(container)
|
||||
|
||||
index = 0
|
||||
while index < len(api.models):
|
||||
model = list(api.models.values())[index]
|
||||
index += 1
|
||||
if isinstance(model, (Model, OrderedModel)):
|
||||
for field in model.values():
|
||||
materialize_field(field)
|
||||
|
||||
|
||||
def drop_null_values(value: object) -> object:
|
||||
"""Remove JSON null values that make the Markdown converter crash."""
|
||||
|
||||
if isinstance(value, dict):
|
||||
return {key: drop_null_values(item) for key, item in value.items() if item is not None}
|
||||
if isinstance(value, list):
|
||||
return [drop_null_values(item) for item in value]
|
||||
return value
|
||||
|
||||
|
||||
def sort_openapi_arrays(value: object, *, parent_key: str | None = None) -> object:
|
||||
"""Sort order-insensitive Swagger arrays so generated Markdown is stable."""
|
||||
|
||||
if isinstance(value, dict):
|
||||
return {key: sort_openapi_arrays(item, parent_key=key) for key, item in value.items()}
|
||||
if not isinstance(value, list):
|
||||
return value
|
||||
|
||||
sorted_items = [sort_openapi_arrays(item, parent_key=parent_key) for item in value]
|
||||
if parent_key == "parameters":
|
||||
return sorted(
|
||||
sorted_items,
|
||||
key=lambda item: (
|
||||
item.get("in", "") if isinstance(item, dict) else "",
|
||||
item.get("name", "") if isinstance(item, dict) else "",
|
||||
json.dumps(item, sort_keys=True, default=str),
|
||||
),
|
||||
)
|
||||
if parent_key in {"enum", "required", "schemes", "tags"}:
|
||||
string_items = [item for item in sorted_items if isinstance(item, str)]
|
||||
if len(string_items) == len(sorted_items):
|
||||
return sorted(string_items)
|
||||
return sorted_items
|
||||
|
||||
|
||||
def _merge_registered_definitions(payload: dict[str, object], namespace: str) -> dict[str, object]:
|
||||
"""Include registered but route-indirect models in the exported Swagger definitions."""
|
||||
|
||||
definitions = payload.setdefault("definitions", {})
|
||||
if not isinstance(definitions, dict):
|
||||
raise RuntimeError("unexpected Swagger definitions payload")
|
||||
|
||||
for name, model in _registered_models(namespace).items():
|
||||
schema = getattr(model, "__schema__", None)
|
||||
if isinstance(schema, dict):
|
||||
definitions.setdefault(name, schema)
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
def generate_specs(output_dir: Path) -> list[Path]:
|
||||
"""Write all Swagger specs to `output_dir` and return the written paths."""
|
||||
|
||||
@ -138,6 +380,9 @@ def generate_specs(output_dir: Path) -> list[Path]:
|
||||
payload = response.get_json()
|
||||
if not isinstance(payload, dict):
|
||||
raise RuntimeError(f"unexpected response payload for {target.route}")
|
||||
payload = _merge_registered_definitions(payload, target.namespace)
|
||||
payload = drop_null_values(payload)
|
||||
payload = sort_openapi_arrays(payload)
|
||||
|
||||
output_path = output_dir / target.filename
|
||||
output_path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")
|
||||
|
||||
14766
api/openapi/markdown/console-swagger.md
Normal file
14766
api/openapi/markdown/console-swagger.md
Normal file
File diff suppressed because it is too large
Load Diff
2754
api/openapi/markdown/service-swagger.md
Normal file
2754
api/openapi/markdown/service-swagger.md
Normal file
File diff suppressed because it is too large
Load Diff
1224
api/openapi/markdown/web-swagger.md
Normal file
1224
api/openapi/markdown/web-swagger.md
Normal file
File diff suppressed because it is too large
Load Diff
@ -3,6 +3,7 @@ from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from dify_trace_aliyun.entities.semconv import (
|
||||
GEN_AI_FRAMEWORK,
|
||||
GEN_AI_SESSION_ID,
|
||||
@ -31,7 +32,7 @@ from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from models import EndUser
|
||||
|
||||
|
||||
def test_get_user_id_from_message_data_no_end_user(monkeypatch):
|
||||
def test_get_user_id_from_message_data_no_end_user(monkeypatch: pytest.MonkeyPatch):
|
||||
message_data = MagicMock()
|
||||
message_data.from_account_id = "account_id"
|
||||
message_data.from_end_user_id = None
|
||||
@ -39,7 +40,7 @@ def test_get_user_id_from_message_data_no_end_user(monkeypatch):
|
||||
assert get_user_id_from_message_data(message_data) == "account_id"
|
||||
|
||||
|
||||
def test_get_user_id_from_message_data_with_end_user(monkeypatch):
|
||||
def test_get_user_id_from_message_data_with_end_user(monkeypatch: pytest.MonkeyPatch):
|
||||
message_data = MagicMock()
|
||||
message_data.from_account_id = "account_id"
|
||||
message_data.from_end_user_id = "end_user_id"
|
||||
@ -57,7 +58,7 @@ def test_get_user_id_from_message_data_with_end_user(monkeypatch):
|
||||
assert get_user_id_from_message_data(message_data) == "session_id"
|
||||
|
||||
|
||||
def test_get_user_id_from_message_data_end_user_not_found(monkeypatch):
|
||||
def test_get_user_id_from_message_data_end_user_not_found(monkeypatch: pytest.MonkeyPatch):
|
||||
message_data = MagicMock()
|
||||
message_data.from_account_id = "account_id"
|
||||
message_data.from_end_user_id = "end_user_id"
|
||||
@ -111,7 +112,7 @@ def test_get_workflow_node_status():
|
||||
assert status.status_code == StatusCode.UNSET
|
||||
|
||||
|
||||
def test_create_links_from_trace_id(monkeypatch):
|
||||
def test_create_links_from_trace_id(monkeypatch: pytest.MonkeyPatch):
|
||||
# Mock create_link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
import dify_trace_aliyun.data_exporter.traceclient
|
||||
|
||||
@ -40,7 +40,7 @@ def langfuse_config():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trace_instance(langfuse_config, monkeypatch):
|
||||
def trace_instance(langfuse_config, monkeypatch: pytest.MonkeyPatch):
|
||||
# Mock Langfuse client to avoid network calls
|
||||
mock_client = MagicMock()
|
||||
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", lambda **kwargs: mock_client)
|
||||
@ -49,7 +49,7 @@ def trace_instance(langfuse_config, monkeypatch):
|
||||
return instance
|
||||
|
||||
|
||||
def test_init(langfuse_config, monkeypatch):
|
||||
def test_init(langfuse_config, monkeypatch: pytest.MonkeyPatch):
|
||||
mock_langfuse = MagicMock()
|
||||
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", mock_langfuse)
|
||||
monkeypatch.setenv("FILES_URL", "http://test.url")
|
||||
@ -64,7 +64,7 @@ def test_init(langfuse_config, monkeypatch):
|
||||
assert instance.file_base_url == "http://test.url"
|
||||
|
||||
|
||||
def test_trace_dispatch(trace_instance, monkeypatch):
|
||||
def test_trace_dispatch(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
methods = [
|
||||
"workflow_trace",
|
||||
"message_trace",
|
||||
@ -114,7 +114,7 @@ def test_trace_dispatch(trace_instance, monkeypatch):
|
||||
mocks["generate_name_trace"].assert_called_once_with(info)
|
||||
|
||||
|
||||
def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
|
||||
def test_workflow_trace_with_message_id(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
# Setup trace info
|
||||
trace_info = WorkflowTraceInfo(
|
||||
workflow_id="wf-1",
|
||||
@ -218,7 +218,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
|
||||
assert other_span.level == LevelEnum.ERROR
|
||||
|
||||
|
||||
def test_workflow_trace_no_message_id(trace_instance, monkeypatch):
|
||||
def test_workflow_trace_no_message_id(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
trace_info = WorkflowTraceInfo(
|
||||
workflow_id="wf-1",
|
||||
tenant_id="tenant-1",
|
||||
@ -259,7 +259,7 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch):
|
||||
assert trace_data.name == TraceTaskName.WORKFLOW_TRACE
|
||||
|
||||
|
||||
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
|
||||
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
trace_info = WorkflowTraceInfo(
|
||||
workflow_id="wf-1",
|
||||
tenant_id="tenant-1",
|
||||
@ -287,7 +287,7 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
|
||||
trace_instance.workflow_trace(trace_info)
|
||||
|
||||
|
||||
def test_message_trace_basic(trace_instance, monkeypatch):
|
||||
def test_message_trace_basic(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
message_data = MagicMock()
|
||||
message_data.id = "msg-1"
|
||||
message_data.from_account_id = "acc-1"
|
||||
@ -331,7 +331,7 @@ def test_message_trace_basic(trace_instance, monkeypatch):
|
||||
assert gen_data.usage.total == 30
|
||||
|
||||
|
||||
def test_message_trace_with_end_user(trace_instance, monkeypatch):
|
||||
def test_message_trace_with_end_user(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
message_data = MagicMock()
|
||||
message_data.id = "msg-1"
|
||||
message_data.from_account_id = "acc-1"
|
||||
@ -636,7 +636,7 @@ def test_langfuse_trace_entity_with_list_dict_input():
|
||||
assert data.input[0]["content"] == "hello"
|
||||
|
||||
|
||||
def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypatch, caplog):
|
||||
def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypatch: pytest.MonkeyPatch, caplog):
|
||||
# Setup trace info to trigger LLM node usage extraction
|
||||
trace_info = WorkflowTraceInfo(
|
||||
workflow_id="wf-1",
|
||||
|
||||
@ -35,7 +35,7 @@ def langsmith_config():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trace_instance(langsmith_config, monkeypatch):
|
||||
def trace_instance(langsmith_config, monkeypatch: pytest.MonkeyPatch):
|
||||
# Mock LangSmith client
|
||||
mock_client = MagicMock()
|
||||
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.Client", lambda **kwargs: mock_client)
|
||||
@ -44,7 +44,7 @@ def trace_instance(langsmith_config, monkeypatch):
|
||||
return instance
|
||||
|
||||
|
||||
def test_init(langsmith_config, monkeypatch):
|
||||
def test_init(langsmith_config, monkeypatch: pytest.MonkeyPatch):
|
||||
mock_client_class = MagicMock()
|
||||
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.Client", mock_client_class)
|
||||
monkeypatch.setenv("FILES_URL", "http://test.url")
|
||||
@ -57,7 +57,7 @@ def test_init(langsmith_config, monkeypatch):
|
||||
assert instance.file_base_url == "http://test.url"
|
||||
|
||||
|
||||
def test_trace_dispatch(trace_instance, monkeypatch):
|
||||
def test_trace_dispatch(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
methods = [
|
||||
"workflow_trace",
|
||||
"message_trace",
|
||||
@ -107,7 +107,7 @@ def test_trace_dispatch(trace_instance, monkeypatch):
|
||||
mocks["generate_name_trace"].assert_called_once_with(info)
|
||||
|
||||
|
||||
def test_workflow_trace(trace_instance, monkeypatch):
|
||||
def test_workflow_trace(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
# Setup trace info
|
||||
workflow_data = MagicMock()
|
||||
workflow_data.created_at = _dt()
|
||||
@ -223,7 +223,7 @@ def test_workflow_trace(trace_instance, monkeypatch):
|
||||
assert call_args[4].run_type == LangSmithRunType.retriever
|
||||
|
||||
|
||||
def test_workflow_trace_no_start_time(trace_instance, monkeypatch):
|
||||
def test_workflow_trace_no_start_time(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
workflow_data = MagicMock()
|
||||
workflow_data.created_at = _dt()
|
||||
workflow_data.finished_at = _dt() + timedelta(seconds=1)
|
||||
@ -266,7 +266,7 @@ def test_workflow_trace_no_start_time(trace_instance, monkeypatch):
|
||||
assert trace_instance.add_run.called
|
||||
|
||||
|
||||
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
|
||||
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_info.trace_id = "trace-1"
|
||||
trace_info.message_id = None
|
||||
@ -290,7 +290,7 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
|
||||
trace_instance.workflow_trace(trace_info)
|
||||
|
||||
|
||||
def test_message_trace(trace_instance, monkeypatch):
|
||||
def test_message_trace(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
message_data = MagicMock()
|
||||
message_data.id = "msg-1"
|
||||
message_data.from_account_id = "acc-1"
|
||||
@ -516,7 +516,7 @@ def test_update_run_error(trace_instance):
|
||||
trace_instance.update_run(update_data)
|
||||
|
||||
|
||||
def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch, caplog):
|
||||
def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch: pytest.MonkeyPatch, caplog):
|
||||
workflow_data = MagicMock()
|
||||
workflow_data.created_at = _dt()
|
||||
workflow_data.finished_at = _dt() + timedelta(seconds=1)
|
||||
|
||||
@ -614,7 +614,7 @@ class TestMessageTrace:
|
||||
span.set_status.assert_called_once()
|
||||
span.add_event.assert_called_once()
|
||||
|
||||
def test_message_trace_with_file_data(self, trace_instance, mock_tracing, mock_db, monkeypatch):
|
||||
def test_message_trace_with_file_data(self, trace_instance, mock_tracing, mock_db, monkeypatch: pytest.MonkeyPatch):
|
||||
span = MagicMock()
|
||||
mock_tracing["start"].return_value = span
|
||||
mock_tracing["set"].return_value = "token"
|
||||
|
||||
@ -35,7 +35,7 @@ def opik_config():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trace_instance(opik_config, monkeypatch):
|
||||
def trace_instance(opik_config, monkeypatch: pytest.MonkeyPatch):
|
||||
mock_client = MagicMock()
|
||||
monkeypatch.setattr("dify_trace_opik.opik_trace.Opik", lambda **kwargs: mock_client)
|
||||
|
||||
@ -65,7 +65,7 @@ def test_prepare_opik_uuid():
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_init(opik_config, monkeypatch):
|
||||
def test_init(opik_config, monkeypatch: pytest.MonkeyPatch):
|
||||
mock_opik = MagicMock()
|
||||
monkeypatch.setattr("dify_trace_opik.opik_trace.Opik", mock_opik)
|
||||
monkeypatch.setenv("FILES_URL", "http://test.url")
|
||||
@ -82,7 +82,7 @@ def test_init(opik_config, monkeypatch):
|
||||
assert instance.project == opik_config.project
|
||||
|
||||
|
||||
def test_trace_dispatch(trace_instance, monkeypatch):
|
||||
def test_trace_dispatch(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
methods = [
|
||||
"workflow_trace",
|
||||
"message_trace",
|
||||
@ -132,7 +132,7 @@ def test_trace_dispatch(trace_instance, monkeypatch):
|
||||
mocks["generate_name_trace"].assert_called_once_with(info)
|
||||
|
||||
|
||||
def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
|
||||
def test_workflow_trace_with_message_id(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
# Define constants for better readability
|
||||
WORKFLOW_ID = "fb05c7cd-6cec-4add-8a84-df03a408b4ce"
|
||||
WORKFLOW_RUN_ID = "33c67568-7a8a-450e-8916-a5f135baeaef"
|
||||
@ -221,7 +221,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
|
||||
assert trace_instance.add_span.call_count >= 1
|
||||
|
||||
|
||||
def test_workflow_trace_no_message_id(trace_instance, monkeypatch):
|
||||
def test_workflow_trace_no_message_id(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
# Define constants for better readability
|
||||
WORKFLOW_ID = "f0708b36-b1d7-42b3-a876-1d01b7d8f1a3"
|
||||
WORKFLOW_RUN_ID = "d42ec285-c2fd-4248-8866-5c9386b101ac"
|
||||
@ -265,7 +265,7 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch):
|
||||
trace_instance.add_trace.assert_called_once()
|
||||
|
||||
|
||||
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
|
||||
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
trace_info = WorkflowTraceInfo(
|
||||
workflow_id="5745f1b8-f8e6-4859-8110-996acb6c8d6a",
|
||||
tenant_id="tenant-1",
|
||||
@ -293,7 +293,7 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
|
||||
trace_instance.workflow_trace(trace_info)
|
||||
|
||||
|
||||
def test_message_trace_basic(trace_instance, monkeypatch):
|
||||
def test_message_trace_basic(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
# Define constants for better readability
|
||||
MESSAGE_DATA_ID = "e3a26712-8cac-4a25-94a4-a3bff21ee3ab"
|
||||
CONVERSATION_ID = "9d3f3751-7521-4c19-9307-20e3cf6789a3"
|
||||
@ -340,7 +340,7 @@ def test_message_trace_basic(trace_instance, monkeypatch):
|
||||
trace_instance.add_span.assert_called_once()
|
||||
|
||||
|
||||
def test_message_trace_with_end_user(trace_instance, monkeypatch):
|
||||
def test_message_trace_with_end_user(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
message_data = MagicMock()
|
||||
message_data.id = "85411059-79fb-4deb-a76c-c2e215f1b97e"
|
||||
message_data.from_account_id = "acc-1"
|
||||
@ -614,7 +614,7 @@ def test_get_project_url_error(trace_instance):
|
||||
trace_instance.get_project_url()
|
||||
|
||||
|
||||
def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch, caplog):
|
||||
def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch: pytest.MonkeyPatch, caplog):
|
||||
trace_info = WorkflowTraceInfo(
|
||||
workflow_id="86a52565-4a6b-4a1b-9bfd-98e4595e70de",
|
||||
tenant_id="66e8e918-472e-4b69-8051-12502c34fc07",
|
||||
|
||||
@ -267,14 +267,14 @@ class TestInit:
|
||||
with pytest.raises(ValueError, match="Weave login failed"):
|
||||
WeaveDataTrace(config)
|
||||
|
||||
def test_init_files_url_from_env(self, mock_wandb, mock_weave, monkeypatch):
|
||||
def test_init_files_url_from_env(self, mock_wandb, mock_weave, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test FILES_URL is read from environment."""
|
||||
monkeypatch.setenv("FILES_URL", "http://files.example.com")
|
||||
config = _make_weave_config()
|
||||
instance = WeaveDataTrace(config)
|
||||
assert instance.file_base_url == "http://files.example.com"
|
||||
|
||||
def test_init_files_url_default(self, mock_wandb, mock_weave, monkeypatch):
|
||||
def test_init_files_url_default(self, mock_wandb, mock_weave, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test FILES_URL defaults to http://127.0.0.1:5001."""
|
||||
monkeypatch.delenv("FILES_URL", raising=False)
|
||||
config = _make_weave_config()
|
||||
@ -302,7 +302,7 @@ class TestGetProjectUrl:
|
||||
url = instance.get_project_url()
|
||||
assert url == "https://wandb.ai/my-project"
|
||||
|
||||
def test_get_project_url_exception_raises(self, trace_instance, monkeypatch):
|
||||
def test_get_project_url_exception_raises(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Raises ValueError when exception occurs in get_project_url."""
|
||||
monkeypatch.setattr(trace_instance, "entity", None)
|
||||
monkeypatch.setattr(trace_instance, "project_name", None)
|
||||
@ -583,7 +583,7 @@ class TestFinishCall:
|
||||
|
||||
|
||||
class TestWorkflowTrace:
|
||||
def _setup_repo(self, monkeypatch, nodes=None):
|
||||
def _setup_repo(self, monkeypatch: pytest.MonkeyPatch, nodes=None):
|
||||
"""Helper to patch session/repo dependencies."""
|
||||
if nodes is None:
|
||||
nodes = []
|
||||
@ -599,7 +599,7 @@ class TestWorkflowTrace:
|
||||
monkeypatch.setattr("dify_trace_weave.weave_trace.db", MagicMock(engine="engine"))
|
||||
return repo
|
||||
|
||||
def test_workflow_trace_no_nodes_no_message_id(self, trace_instance, monkeypatch):
|
||||
def test_workflow_trace_no_nodes_no_message_id(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Workflow trace with no nodes and no message_id."""
|
||||
self._setup_repo(monkeypatch, nodes=[])
|
||||
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
|
||||
@ -614,7 +614,7 @@ class TestWorkflowTrace:
|
||||
assert trace_instance.start_call.call_count == 1
|
||||
assert trace_instance.finish_call.call_count == 1
|
||||
|
||||
def test_workflow_trace_with_message_id(self, trace_instance, monkeypatch):
|
||||
def test_workflow_trace_with_message_id(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Workflow trace with message_id creates both message and workflow runs."""
|
||||
self._setup_repo(monkeypatch, nodes=[])
|
||||
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
|
||||
@ -629,7 +629,7 @@ class TestWorkflowTrace:
|
||||
assert trace_instance.start_call.call_count == 2
|
||||
assert trace_instance.finish_call.call_count == 2
|
||||
|
||||
def test_workflow_trace_with_node_execution(self, trace_instance, monkeypatch):
|
||||
def test_workflow_trace_with_node_execution(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Workflow trace iterates node executions and creates node runs."""
|
||||
node = _make_node(
|
||||
id="node-1",
|
||||
@ -652,7 +652,7 @@ class TestWorkflowTrace:
|
||||
# workflow run + node run = 2 calls
|
||||
assert trace_instance.start_call.call_count == 2
|
||||
|
||||
def test_workflow_trace_with_llm_node(self, trace_instance, monkeypatch):
|
||||
def test_workflow_trace_with_llm_node(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""LLM node uses process_data prompts as inputs."""
|
||||
node = _make_node(
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
@ -680,7 +680,7 @@ class TestWorkflowTrace:
|
||||
# The key "messages" should be present (validator transforms the list)
|
||||
assert "messages" in node_run.inputs
|
||||
|
||||
def test_workflow_trace_with_non_llm_node_uses_inputs(self, trace_instance, monkeypatch):
|
||||
def test_workflow_trace_with_non_llm_node_uses_inputs(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Non-LLM node uses node_execution.inputs directly."""
|
||||
node = _make_node(
|
||||
node_type=BuiltinNodeTypes.TOOL,
|
||||
@ -701,7 +701,7 @@ class TestWorkflowTrace:
|
||||
node_run = node_call_args[0][0]
|
||||
assert node_run.inputs.get("tool_input") == "val"
|
||||
|
||||
def test_workflow_trace_missing_app_id_raises(self, trace_instance, monkeypatch):
|
||||
def test_workflow_trace_missing_app_id_raises(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Raises ValueError when app_id is missing from metadata."""
|
||||
monkeypatch.setattr("dify_trace_weave.weave_trace.sessionmaker", lambda bind: MagicMock())
|
||||
monkeypatch.setattr("dify_trace_weave.weave_trace.db", MagicMock(engine="engine"))
|
||||
@ -714,7 +714,7 @@ class TestWorkflowTrace:
|
||||
with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
|
||||
trace_instance.workflow_trace(trace_info)
|
||||
|
||||
def test_workflow_trace_start_time_none_defaults_to_now(self, trace_instance, monkeypatch):
|
||||
def test_workflow_trace_start_time_none_defaults_to_now(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""start_time defaults to datetime.now() when None."""
|
||||
self._setup_repo(monkeypatch, nodes=[])
|
||||
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
|
||||
@ -727,7 +727,7 @@ class TestWorkflowTrace:
|
||||
|
||||
assert trace_instance.start_call.call_count == 1
|
||||
|
||||
def test_workflow_trace_node_created_at_none(self, trace_instance, monkeypatch):
|
||||
def test_workflow_trace_node_created_at_none(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Node with created_at=None uses datetime.now()."""
|
||||
node = _make_node(created_at=None, elapsed_time=0.5)
|
||||
self._setup_repo(monkeypatch, nodes=[node])
|
||||
@ -740,7 +740,7 @@ class TestWorkflowTrace:
|
||||
trace_instance.workflow_trace(trace_info)
|
||||
assert trace_instance.start_call.call_count == 2
|
||||
|
||||
def test_workflow_trace_chat_mode_llm_node_adds_provider(self, trace_instance, monkeypatch):
|
||||
def test_workflow_trace_chat_mode_llm_node_adds_provider(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Chat mode LLM node adds ls_provider and ls_model_name to attributes."""
|
||||
node = _make_node(
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
@ -765,7 +765,7 @@ class TestWorkflowTrace:
|
||||
assert node_run.attributes.get("ls_provider") == "openai"
|
||||
assert node_run.attributes.get("ls_model_name") == "gpt-4"
|
||||
|
||||
def test_workflow_trace_nodes_sorted_by_created_at(self, trace_instance, monkeypatch):
|
||||
def test_workflow_trace_nodes_sorted_by_created_at(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Nodes are sorted by created_at before processing."""
|
||||
node1 = _make_node(id="node-b", created_at=_dt() + timedelta(seconds=2))
|
||||
node2 = _make_node(id="node-a", created_at=_dt())
|
||||
@ -799,7 +799,7 @@ class TestMessageTrace:
|
||||
trace_instance.message_trace(trace_info)
|
||||
trace_instance.start_call.assert_not_called()
|
||||
|
||||
def test_basic_message_trace(self, trace_instance, monkeypatch):
|
||||
def test_basic_message_trace(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""message_trace creates message run and llm child run."""
|
||||
monkeypatch.setattr(
|
||||
"dify_trace_weave.weave_trace.db.session.get",
|
||||
@ -816,7 +816,7 @@ class TestMessageTrace:
|
||||
assert trace_instance.start_call.call_count == 2
|
||||
assert trace_instance.finish_call.call_count == 2
|
||||
|
||||
def test_message_trace_with_file_data(self, trace_instance, monkeypatch):
|
||||
def test_message_trace_with_file_data(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""message_trace appends file URL to file_list."""
|
||||
file_data = MagicMock()
|
||||
file_data.url = "path/to/file.png"
|
||||
@ -839,7 +839,7 @@ class TestMessageTrace:
|
||||
message_run = trace_instance.start_call.call_args_list[0][0][0]
|
||||
assert "http://files.test/path/to/file.png" in message_run.file_list
|
||||
|
||||
def test_message_trace_with_end_user(self, trace_instance, monkeypatch):
|
||||
def test_message_trace_with_end_user(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""message_trace looks up end user and sets end_user_id attribute."""
|
||||
end_user = MagicMock()
|
||||
end_user.session_id = "session-xyz"
|
||||
@ -862,7 +862,7 @@ class TestMessageTrace:
|
||||
message_run = trace_instance.start_call.call_args_list[0][0][0]
|
||||
assert message_run.attributes.get("end_user_id") == "session-xyz"
|
||||
|
||||
def test_message_trace_no_end_user(self, trace_instance, monkeypatch):
|
||||
def test_message_trace_no_end_user(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""message_trace handles when from_end_user_id is None."""
|
||||
mock_db = MagicMock()
|
||||
mock_db.session.get.return_value = None
|
||||
@ -880,7 +880,7 @@ class TestMessageTrace:
|
||||
trace_instance.message_trace(trace_info)
|
||||
assert trace_instance.start_call.call_count == 2
|
||||
|
||||
def test_message_trace_trace_id_fallback_to_message_id(self, trace_instance, monkeypatch):
|
||||
def test_message_trace_trace_id_fallback_to_message_id(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""trace_id falls back to message_id when trace_id is None."""
|
||||
mock_db = MagicMock()
|
||||
mock_db.session.get.return_value = None
|
||||
@ -895,7 +895,7 @@ class TestMessageTrace:
|
||||
message_run = trace_instance.start_call.call_args_list[0][0][0]
|
||||
assert message_run.id == "msg-1"
|
||||
|
||||
def test_message_trace_file_list_none(self, trace_instance, monkeypatch):
|
||||
def test_message_trace_file_list_none(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""message_trace handles file_list=None gracefully."""
|
||||
mock_db = MagicMock()
|
||||
mock_db.session.get.return_value = None
|
||||
|
||||
@ -20,7 +20,7 @@ def test_validate_distance_function_rejects_unsupported_values():
|
||||
factory._validate_distance_function("dot_product")
|
||||
|
||||
|
||||
def test_factory_init_vector_uses_existing_index_struct_class_prefix(monkeypatch):
|
||||
def test_factory_init_vector_uses_existing_index_struct_class_prefix(monkeypatch: pytest.MonkeyPatch):
|
||||
factory = AlibabaCloudMySQLVectorFactory()
|
||||
dataset = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
@ -45,7 +45,7 @@ def test_factory_init_vector_uses_existing_index_struct_class_prefix(monkeypatch
|
||||
assert vector_cls.call_args.kwargs["collection_name"] == "existing_collection"
|
||||
|
||||
|
||||
def test_factory_init_vector_generates_collection_name_when_index_struct_is_missing(monkeypatch):
|
||||
def test_factory_init_vector_generates_collection_name_when_index_struct_is_missing(monkeypatch: pytest.MonkeyPatch):
|
||||
factory = AlibabaCloudMySQLVectorFactory()
|
||||
dataset = SimpleNamespace(
|
||||
id="dataset-2",
|
||||
|
||||
@ -83,7 +83,7 @@ def test_get_type_is_analyticdb():
|
||||
assert vector.get_type() == "analyticdb"
|
||||
|
||||
|
||||
def test_factory_builds_openapi_config_when_host_is_missing(monkeypatch):
|
||||
def test_factory_builds_openapi_config_when_host_is_missing(monkeypatch: pytest.MonkeyPatch):
|
||||
factory = AnalyticdbVectorFactory()
|
||||
dataset = SimpleNamespace(id="dataset-1", index_struct_dict=None, index_struct=None)
|
||||
|
||||
@ -109,7 +109,7 @@ def test_factory_builds_openapi_config_when_host_is_missing(monkeypatch):
|
||||
assert dataset.index_struct is not None
|
||||
|
||||
|
||||
def test_factory_builds_sql_config_when_host_is_present(monkeypatch):
|
||||
def test_factory_builds_sql_config_when_host_is_present(monkeypatch: pytest.MonkeyPatch):
|
||||
factory = AnalyticdbVectorFactory()
|
||||
dataset = SimpleNamespace(
|
||||
id="dataset-2", index_struct_dict={"vector_store": {"class_prefix": "EXISTING"}}, index_struct=None
|
||||
|
||||
@ -24,7 +24,7 @@ def _request_class(name: str):
|
||||
return _Request
|
||||
|
||||
|
||||
def _install_openapi_stubs(monkeypatch):
|
||||
def _install_openapi_stubs(monkeypatch: pytest.MonkeyPatch):
|
||||
gpdb_package = types.ModuleType("alibabacloud_gpdb20160503")
|
||||
gpdb_package.__path__ = []
|
||||
gpdb_models = types.ModuleType("alibabacloud_gpdb20160503.models")
|
||||
@ -130,7 +130,7 @@ def test_openapi_config_to_client_params():
|
||||
assert params["read_timeout"] == 60000
|
||||
|
||||
|
||||
def test_init_creates_openapi_client_and_runs_initialize(monkeypatch):
|
||||
def test_init_creates_openapi_client_and_runs_initialize(monkeypatch: pytest.MonkeyPatch):
|
||||
stubs = _install_openapi_stubs(monkeypatch)
|
||||
initialize_mock = MagicMock()
|
||||
monkeypatch.setattr(openapi_module.AnalyticdbVectorOpenAPI, "_initialize", initialize_mock)
|
||||
@ -145,7 +145,7 @@ def test_init_creates_openapi_client_and_runs_initialize(monkeypatch):
|
||||
initialize_mock.assert_called_once_with()
|
||||
|
||||
|
||||
def test_initialize_skips_when_cached(monkeypatch):
|
||||
def test_initialize_skips_when_cached(monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -164,7 +164,7 @@ def test_initialize_skips_when_cached(monkeypatch):
|
||||
vector._create_namespace_if_not_exists.assert_not_called()
|
||||
|
||||
|
||||
def test_initialize_runs_when_cache_is_missing(monkeypatch):
|
||||
def test_initialize_runs_when_cache_is_missing(monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -184,7 +184,7 @@ def test_initialize_runs_when_cache_is_missing(monkeypatch):
|
||||
openapi_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_initialize_vector_database_calls_openapi_client(monkeypatch):
|
||||
def test_initialize_vector_database_calls_openapi_client(monkeypatch: pytest.MonkeyPatch):
|
||||
_install_openapi_stubs(monkeypatch)
|
||||
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
|
||||
vector.config = _config()
|
||||
@ -199,7 +199,7 @@ def test_initialize_vector_database_calls_openapi_client(monkeypatch):
|
||||
assert request.manager_account_password == "password"
|
||||
|
||||
|
||||
def test_create_namespace_creates_when_namespace_not_found(monkeypatch):
|
||||
def test_create_namespace_creates_when_namespace_not_found(monkeypatch: pytest.MonkeyPatch):
|
||||
stubs = _install_openapi_stubs(monkeypatch)
|
||||
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
|
||||
vector.config = _config()
|
||||
@ -211,7 +211,7 @@ def test_create_namespace_creates_when_namespace_not_found(monkeypatch):
|
||||
vector._client.create_namespace.assert_called_once()
|
||||
|
||||
|
||||
def test_create_namespace_raises_on_unexpected_api_error(monkeypatch):
|
||||
def test_create_namespace_raises_on_unexpected_api_error(monkeypatch: pytest.MonkeyPatch):
|
||||
stubs = _install_openapi_stubs(monkeypatch)
|
||||
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
|
||||
vector.config = _config()
|
||||
@ -222,7 +222,7 @@ def test_create_namespace_raises_on_unexpected_api_error(monkeypatch):
|
||||
vector._create_namespace_if_not_exists()
|
||||
|
||||
|
||||
def test_create_namespace_noop_when_namespace_exists(monkeypatch):
|
||||
def test_create_namespace_noop_when_namespace_exists(monkeypatch: pytest.MonkeyPatch):
|
||||
_install_openapi_stubs(monkeypatch)
|
||||
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
|
||||
vector.config = _config()
|
||||
@ -234,7 +234,7 @@ def test_create_namespace_noop_when_namespace_exists(monkeypatch):
|
||||
vector._client.create_namespace.assert_not_called()
|
||||
|
||||
|
||||
def test_create_collection_if_not_exists_creates_when_missing(monkeypatch):
|
||||
def test_create_collection_if_not_exists_creates_when_missing(monkeypatch: pytest.MonkeyPatch):
|
||||
stubs = _install_openapi_stubs(monkeypatch)
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
@ -255,7 +255,7 @@ def test_create_collection_if_not_exists_creates_when_missing(monkeypatch):
|
||||
openapi_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_create_collection_if_not_exists_skips_when_cached(monkeypatch):
|
||||
def test_create_collection_if_not_exists_skips_when_cached(monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -274,7 +274,7 @@ def test_create_collection_if_not_exists_skips_when_cached(monkeypatch):
|
||||
vector._client.create_collection.assert_not_called()
|
||||
|
||||
|
||||
def test_create_collection_if_not_exists_raises_on_non_404_errors(monkeypatch):
|
||||
def test_create_collection_if_not_exists_raises_on_non_404_errors(monkeypatch: pytest.MonkeyPatch):
|
||||
stubs = _install_openapi_stubs(monkeypatch)
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
@ -293,7 +293,7 @@ def test_create_collection_if_not_exists_raises_on_non_404_errors(monkeypatch):
|
||||
vector.create_collection_if_not_exists(embedding_dimension=512)
|
||||
|
||||
|
||||
def test_openapi_add_delete_and_search_methods(monkeypatch):
|
||||
def test_openapi_add_delete_and_search_methods(monkeypatch: pytest.MonkeyPatch):
|
||||
_install_openapi_stubs(monkeypatch)
|
||||
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
|
||||
vector._collection_name = "collection_1"
|
||||
@ -348,7 +348,7 @@ def test_openapi_add_delete_and_search_methods(monkeypatch):
|
||||
assert docs_by_text[0].page_content == "high"
|
||||
|
||||
|
||||
def test_text_exists_returns_false_when_matches_empty(monkeypatch):
|
||||
def test_text_exists_returns_false_when_matches_empty(monkeypatch: pytest.MonkeyPatch):
|
||||
_install_openapi_stubs(monkeypatch)
|
||||
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
|
||||
vector._collection_name = "collection_1"
|
||||
@ -361,7 +361,7 @@ def test_text_exists_returns_false_when_matches_empty(monkeypatch):
|
||||
assert vector.text_exists("missing-id") is False
|
||||
|
||||
|
||||
def test_openapi_delete_success(monkeypatch):
|
||||
def test_openapi_delete_success(monkeypatch: pytest.MonkeyPatch):
|
||||
_install_openapi_stubs(monkeypatch)
|
||||
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
|
||||
vector._collection_name = "collection_1"
|
||||
@ -372,7 +372,7 @@ def test_openapi_delete_success(monkeypatch):
|
||||
vector._client.delete_collection.assert_called_once()
|
||||
|
||||
|
||||
def test_openapi_delete_propagates_errors(monkeypatch):
|
||||
def test_openapi_delete_propagates_errors(monkeypatch: pytest.MonkeyPatch):
|
||||
_install_openapi_stubs(monkeypatch)
|
||||
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
|
||||
vector._collection_name = "collection_1"
|
||||
|
||||
@ -53,7 +53,7 @@ def test_sql_config_rejects_min_connection_greater_than_max_connection():
|
||||
AnalyticdbVectorBySqlConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_initialize_skips_when_cache_exists(monkeypatch):
|
||||
def test_initialize_skips_when_cache_exists(monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -70,7 +70,7 @@ def test_initialize_skips_when_cache_exists(monkeypatch):
|
||||
vector._initialize_vector_database.assert_not_called()
|
||||
|
||||
|
||||
def test_initialize_runs_when_cache_is_missing(monkeypatch):
|
||||
def test_initialize_runs_when_cache_is_missing(monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -88,7 +88,7 @@ def test_initialize_runs_when_cache_is_missing(monkeypatch):
|
||||
sql_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_create_connection_pool_uses_psycopg2_pool(monkeypatch):
|
||||
def test_create_connection_pool_uses_psycopg2_pool(monkeypatch: pytest.MonkeyPatch):
|
||||
vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
|
||||
vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
|
||||
vector.databaseName = "knowledgebase"
|
||||
@ -119,7 +119,7 @@ def test_get_cursor_context_manager_handles_connection_lifecycle():
|
||||
pool.putconn.assert_called_once_with(connection)
|
||||
|
||||
|
||||
def test_add_texts_inserts_only_documents_with_metadata(monkeypatch):
|
||||
def test_add_texts_inserts_only_documents_with_metadata(monkeypatch: pytest.MonkeyPatch):
|
||||
vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
|
||||
vector.table_name = "dify.collection"
|
||||
|
||||
@ -273,7 +273,7 @@ def test_delete_drops_table():
|
||||
cursor.execute.assert_called_once()
|
||||
|
||||
|
||||
def test_init_normalizes_collection_name_and_creates_pool_when_missing(monkeypatch):
|
||||
def test_init_normalizes_collection_name_and_creates_pool_when_missing(monkeypatch: pytest.MonkeyPatch):
|
||||
config = AnalyticdbVectorBySqlConfig(**_config_values())
|
||||
created_pool = MagicMock()
|
||||
|
||||
@ -288,7 +288,7 @@ def test_init_normalizes_collection_name_and_creates_pool_when_missing(monkeypat
|
||||
assert vector.pool is created_pool
|
||||
|
||||
|
||||
def test_initialize_vector_database_handles_existing_database_and_search_config(monkeypatch):
|
||||
def test_initialize_vector_database_handles_existing_database_and_search_config(monkeypatch: pytest.MonkeyPatch):
|
||||
vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
|
||||
vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
|
||||
vector.databaseName = "knowledgebase"
|
||||
@ -326,7 +326,7 @@ def test_initialize_vector_database_handles_existing_database_and_search_config(
|
||||
assert any("CREATE SCHEMA IF NOT EXISTS dify" in call.args[0] for call in worker_cursor.execute.call_args_list)
|
||||
|
||||
|
||||
def test_initialize_vector_database_raises_runtime_error_when_zhparser_fails(monkeypatch):
|
||||
def test_initialize_vector_database_raises_runtime_error_when_zhparser_fails(monkeypatch: pytest.MonkeyPatch):
|
||||
vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
|
||||
vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
|
||||
vector.databaseName = "knowledgebase"
|
||||
@ -353,7 +353,7 @@ def test_initialize_vector_database_raises_runtime_error_when_zhparser_fails(mon
|
||||
worker_connection.rollback.assert_called_once()
|
||||
|
||||
|
||||
def test_create_collection_if_not_exists_creates_table_indexes_and_cache(monkeypatch):
|
||||
def test_create_collection_if_not_exists_creates_table_indexes_and_cache(monkeypatch: pytest.MonkeyPatch):
|
||||
vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
|
||||
vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
|
||||
vector._collection_name = "collection"
|
||||
@ -381,7 +381,7 @@ def test_create_collection_if_not_exists_creates_table_indexes_and_cache(monkeyp
|
||||
sql_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_create_collection_if_not_exists_raises_for_non_existing_error(monkeypatch):
|
||||
def test_create_collection_if_not_exists_raises_for_non_existing_error(monkeypatch: pytest.MonkeyPatch):
|
||||
vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
|
||||
vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
|
||||
vector._collection_name = "collection"
|
||||
|
||||
@ -121,7 +121,7 @@ def _build_fake_pymochow_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def baidu_module(monkeypatch):
|
||||
def baidu_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_pymochow_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
import dify_vdb_baidu.baidu_vector as module
|
||||
@ -254,7 +254,7 @@ def test_search_methods_delegate_to_database_table(baidu_module):
|
||||
assert vector._get_search_res.call_count == 2
|
||||
|
||||
|
||||
def test_factory_initializes_collection_name_and_index_struct(baidu_module, monkeypatch):
|
||||
def test_factory_initializes_collection_name_and_index_struct(baidu_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = baidu_module.BaiduVectorFactory()
|
||||
dataset = SimpleNamespace(id="dataset-1", index_struct_dict=None, index_struct=None)
|
||||
monkeypatch.setattr(baidu_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
|
||||
@ -279,7 +279,7 @@ def test_factory_initializes_collection_name_and_index_struct(baidu_module, monk
|
||||
assert dataset.index_struct is not None
|
||||
|
||||
|
||||
def test_init_get_type_to_index_struct_and_create_delegate(baidu_module, monkeypatch):
|
||||
def test_init_get_type_to_index_struct_and_create_delegate(baidu_module, monkeypatch: pytest.MonkeyPatch):
|
||||
init_client = MagicMock(return_value="client")
|
||||
init_database = MagicMock(return_value="database")
|
||||
monkeypatch.setattr(baidu_module.BaiduVector, "_init_client", init_client)
|
||||
@ -372,7 +372,7 @@ def test_get_search_result_handles_invalid_metadata_json(baidu_module):
|
||||
assert "document_id" not in docs[0].metadata
|
||||
|
||||
|
||||
def test_init_client_constructs_configuration_and_client(baidu_module, monkeypatch):
|
||||
def test_init_client_constructs_configuration_and_client(baidu_module, monkeypatch: pytest.MonkeyPatch):
|
||||
credentials = MagicMock(return_value="credentials")
|
||||
configuration = MagicMock(return_value="configuration")
|
||||
client_cls = MagicMock(return_value="client")
|
||||
@ -411,7 +411,7 @@ def test_init_database_raises_for_unknown_create_database_error(baidu_module):
|
||||
vector._init_database()
|
||||
|
||||
|
||||
def test_create_table_handles_cache_and_validation_paths(baidu_module, monkeypatch):
|
||||
def test_create_table_handles_cache_and_validation_paths(baidu_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
|
||||
vector._collection_name = "collection_1"
|
||||
vector._client_config = SimpleNamespace(
|
||||
@ -460,7 +460,7 @@ def test_create_table_handles_cache_and_validation_paths(baidu_module, monkeypat
|
||||
vector._wait_for_index_ready.assert_called_once_with(table, 3600)
|
||||
|
||||
|
||||
def test_create_table_raises_for_invalid_index_or_metric(baidu_module, monkeypatch):
|
||||
def test_create_table_raises_for_invalid_index_or_metric(baidu_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
|
||||
vector._collection_name = "collection_1"
|
||||
vector._db = MagicMock()
|
||||
@ -493,7 +493,7 @@ def test_create_table_raises_for_invalid_index_or_metric(baidu_module, monkeypat
|
||||
vector._create_table(3)
|
||||
|
||||
|
||||
def test_create_table_raises_timeout_if_table_never_becomes_normal(baidu_module, monkeypatch):
|
||||
def test_create_table_raises_timeout_if_table_never_becomes_normal(baidu_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
|
||||
vector._collection_name = "collection_1"
|
||||
vector._client_config = SimpleNamespace(
|
||||
@ -524,7 +524,9 @@ def test_create_table_raises_timeout_if_table_never_becomes_normal(baidu_module,
|
||||
vector._create_table(3)
|
||||
|
||||
|
||||
def test_factory_uses_existing_collection_prefix_when_index_struct_exists(baidu_module, monkeypatch):
|
||||
def test_factory_uses_existing_collection_prefix_when_index_struct_exists(
|
||||
baidu_module, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
factory = baidu_module.BaiduVectorFactory()
|
||||
dataset = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -44,7 +44,7 @@ def _build_fake_chroma_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def chroma_module(monkeypatch):
|
||||
def chroma_module(monkeypatch: pytest.MonkeyPatch):
|
||||
fake_chroma = _build_fake_chroma_modules()
|
||||
monkeypatch.setitem(sys.modules, "chromadb", fake_chroma)
|
||||
import dify_vdb_chroma.chroma_vector as module
|
||||
@ -73,7 +73,7 @@ def test_chroma_config_to_params_builds_expected_payload(chroma_module):
|
||||
assert params["settings"].chroma_client_auth_credentials == "credentials"
|
||||
|
||||
|
||||
def test_create_collection_uses_redis_lock_and_cache(chroma_module, monkeypatch):
|
||||
def test_create_collection_uses_redis_lock_and_cache(chroma_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -173,7 +173,7 @@ def test_search_by_full_text_returns_empty_list(chroma_module):
|
||||
assert vector.search_by_full_text("query") == []
|
||||
|
||||
|
||||
def test_factory_init_vector_uses_existing_or_generated_collection(chroma_module, monkeypatch):
|
||||
def test_factory_init_vector_uses_existing_or_generated_collection(chroma_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = chroma_module.ChromaVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1", index_struct_dict={"vector_store": {"class_prefix": "EXISTING"}}, index_struct=None
|
||||
|
||||
@ -45,7 +45,7 @@ def _build_fake_clickzetta_module():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def clickzetta_module(monkeypatch):
|
||||
def clickzetta_module(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setitem(sys.modules, "clickzetta", _build_fake_clickzetta_module())
|
||||
import dify_vdb_clickzetta.clickzetta_vector as module
|
||||
|
||||
@ -218,7 +218,7 @@ def test_search_by_like_returns_documents_with_default_score(clickzetta_module):
|
||||
assert docs[0].metadata["score"] == 0.5
|
||||
|
||||
|
||||
def test_factory_initializes_clickzetta_vector(clickzetta_module, monkeypatch):
|
||||
def test_factory_initializes_clickzetta_vector(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = clickzetta_module.ClickzettaVectorFactory()
|
||||
dataset = SimpleNamespace(id="dataset-1")
|
||||
|
||||
@ -243,7 +243,7 @@ def test_factory_initializes_clickzetta_vector(clickzetta_module, monkeypatch):
|
||||
assert vector_cls.call_args.kwargs["collection_name"] == "collection"
|
||||
|
||||
|
||||
def test_connection_pool_singleton_and_config_key(clickzetta_module, monkeypatch):
|
||||
def test_connection_pool_singleton_and_config_key(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
|
||||
clickzetta_module.ClickzettaConnectionPool._instance = None
|
||||
monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
|
||||
|
||||
@ -255,7 +255,7 @@ def test_connection_pool_singleton_and_config_key(clickzetta_module, monkeypatch
|
||||
assert "username:instance:service:workspace:cluster:dify" in key
|
||||
|
||||
|
||||
def test_connection_pool_create_connection_retries_and_configures(clickzetta_module, monkeypatch):
|
||||
def test_connection_pool_create_connection_retries_and_configures(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
|
||||
pool = clickzetta_module.ClickzettaConnectionPool()
|
||||
config = _config(clickzetta_module)
|
||||
@ -274,7 +274,7 @@ def test_connection_pool_create_connection_retries_and_configures(clickzetta_mod
|
||||
pool._configure_connection.assert_called_once_with(connection)
|
||||
|
||||
|
||||
def test_connection_pool_create_connection_raises_after_retries(clickzetta_module, monkeypatch):
|
||||
def test_connection_pool_create_connection_raises_after_retries(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
|
||||
pool = clickzetta_module.ClickzettaConnectionPool()
|
||||
config = _config(clickzetta_module)
|
||||
@ -318,7 +318,7 @@ def test_connection_pool_configure_connection_swallows_errors(clickzetta_module)
|
||||
monkeypatch.undo()
|
||||
|
||||
|
||||
def test_connection_pool_get_return_cleanup_and_shutdown(clickzetta_module, monkeypatch):
|
||||
def test_connection_pool_get_return_cleanup_and_shutdown(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
|
||||
pool = clickzetta_module.ClickzettaConnectionPool()
|
||||
config = _config(clickzetta_module)
|
||||
@ -360,7 +360,7 @@ def test_connection_pool_get_return_cleanup_and_shutdown(clickzetta_module, monk
|
||||
assert pool._shutdown is True
|
||||
|
||||
|
||||
def test_connection_pool_start_cleanup_thread_runs_worker_once(clickzetta_module, monkeypatch):
|
||||
def test_connection_pool_start_cleanup_thread_runs_worker_once(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
|
||||
pool = clickzetta_module.ClickzettaConnectionPool.__new__(clickzetta_module.ClickzettaConnectionPool)
|
||||
pool._shutdown = False
|
||||
pool._cleanup_expired_connections = MagicMock(side_effect=lambda: setattr(pool, "_shutdown", True))
|
||||
@ -384,7 +384,7 @@ def test_connection_pool_start_cleanup_thread_runs_worker_once(clickzetta_module
|
||||
pool._cleanup_expired_connections.assert_called_once()
|
||||
|
||||
|
||||
def test_vector_init_connection_context_and_helpers(clickzetta_module, monkeypatch):
|
||||
def test_vector_init_connection_context_and_helpers(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
|
||||
pool = MagicMock()
|
||||
pool.get_connection.return_value = "conn"
|
||||
monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "get_instance", MagicMock(return_value=pool))
|
||||
@ -405,7 +405,7 @@ def test_vector_init_connection_context_and_helpers(clickzetta_module, monkeypat
|
||||
assert vector._ensure_connection() == "conn"
|
||||
|
||||
|
||||
def test_write_queue_initialization_worker_and_execute_write(clickzetta_module, monkeypatch):
|
||||
def test_write_queue_initialization_worker_and_execute_write(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
|
||||
class _Thread:
|
||||
def __init__(self, target, daemon):
|
||||
self.target = target
|
||||
@ -579,7 +579,7 @@ def test_create_inverted_index_branches(clickzetta_module):
|
||||
vector._create_inverted_index(cursor)
|
||||
|
||||
|
||||
def test_add_texts_batches_and_insert_batch_behaviors(clickzetta_module, monkeypatch):
|
||||
def test_add_texts_batches_and_insert_batch_behaviors(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
|
||||
vector._config = _config(clickzetta_module)
|
||||
vector._config.batch_size = 2
|
||||
@ -811,7 +811,7 @@ def test_clickzetta_pool_cleanup_and_shutdown_edge_paths(clickzetta_module):
|
||||
assert pool._shutdown is True
|
||||
|
||||
|
||||
def test_clickzetta_pool_cleanup_thread_and_worker_exception_paths(clickzetta_module, monkeypatch):
|
||||
def test_clickzetta_pool_cleanup_thread_and_worker_exception_paths(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
|
||||
pool = clickzetta_module.ClickzettaConnectionPool.__new__(clickzetta_module.ClickzettaConnectionPool)
|
||||
pool._shutdown = False
|
||||
|
||||
|
||||
@ -150,7 +150,7 @@ def _build_fake_couchbase_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def couchbase_module(monkeypatch):
|
||||
def couchbase_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_couchbase_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -194,7 +194,7 @@ def test_init_sets_cluster_handles(couchbase_module):
|
||||
vector._cluster.wait_until_ready.assert_called_once()
|
||||
|
||||
|
||||
def test_create_and_create_collection_branches(couchbase_module, monkeypatch):
|
||||
def test_create_and_create_collection_branches(couchbase_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = couchbase_module.CouchbaseVector.__new__(couchbase_module.CouchbaseVector)
|
||||
vector._collection_name = "collection_1"
|
||||
vector._client_config = _config(couchbase_module)
|
||||
@ -319,7 +319,7 @@ def test_search_methods_and_format_metadata(couchbase_module):
|
||||
assert vector._format_metadata({"metadata.a": 1, "plain": 2}) == {"a": 1, "plain": 2}
|
||||
|
||||
|
||||
def test_delete_collection_and_factory(couchbase_module, monkeypatch):
|
||||
def test_delete_collection_and_factory(couchbase_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module))
|
||||
scopes = [
|
||||
SimpleNamespace(collections=[SimpleNamespace(name="other")]),
|
||||
|
||||
@ -28,7 +28,7 @@ def _build_fake_elasticsearch_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def elasticsearch_ja_module(monkeypatch):
|
||||
def elasticsearch_ja_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_elasticsearch_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -39,7 +39,7 @@ def elasticsearch_ja_module(monkeypatch):
|
||||
return importlib.reload(ja_module)
|
||||
|
||||
|
||||
def test_create_collection_cache_hit(elasticsearch_ja_module, monkeypatch):
|
||||
def test_create_collection_cache_hit(elasticsearch_ja_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -57,7 +57,7 @@ def test_create_collection_cache_hit(elasticsearch_ja_module, monkeypatch):
|
||||
elasticsearch_ja_module.redis_client.set.assert_not_called()
|
||||
|
||||
|
||||
def test_create_collection_create_and_exists_paths(elasticsearch_ja_module, monkeypatch):
|
||||
def test_create_collection_create_and_exists_paths(elasticsearch_ja_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -87,7 +87,7 @@ def test_create_collection_create_and_exists_paths(elasticsearch_ja_module, monk
|
||||
elasticsearch_ja_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_ja_factory_uses_existing_or_generated_collection(elasticsearch_ja_module, monkeypatch):
|
||||
def test_ja_factory_uses_existing_or_generated_collection(elasticsearch_ja_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = elasticsearch_ja_module.ElasticSearchJaVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -38,7 +38,7 @@ def _build_fake_elasticsearch_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def elasticsearch_module(monkeypatch):
|
||||
def elasticsearch_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_elasticsearch_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -287,7 +287,7 @@ def test_search_by_vector_and_full_text(elasticsearch_module):
|
||||
assert "bool" in query
|
||||
|
||||
|
||||
def test_create_and_create_collection_paths(elasticsearch_module, monkeypatch):
|
||||
def test_create_and_create_collection_paths(elasticsearch_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -331,7 +331,7 @@ def test_create_and_create_collection_paths(elasticsearch_module, monkeypatch):
|
||||
elasticsearch_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_elasticsearch_factory_branches(elasticsearch_module, monkeypatch):
|
||||
def test_elasticsearch_factory_branches(elasticsearch_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = elasticsearch_module.ElasticSearchVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -38,7 +38,7 @@ def _build_fake_hologres_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hologres_module(monkeypatch):
|
||||
def hologres_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_hologres_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -266,7 +266,7 @@ def test_delete_handles_existing_and_missing_tables(hologres_module):
|
||||
vector._client.drop_table.assert_called_once_with(vector.table_name)
|
||||
|
||||
|
||||
def test_create_collection_returns_early_when_cache_hits(hologres_module, monkeypatch):
|
||||
def test_create_collection_returns_early_when_cache_hits(hologres_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = False
|
||||
@ -281,7 +281,7 @@ def test_create_collection_returns_early_when_cache_hits(hologres_module, monkey
|
||||
hologres_module.redis_client.set.assert_not_called()
|
||||
|
||||
|
||||
def test_create_collection_creates_table_and_indexes(hologres_module, monkeypatch):
|
||||
def test_create_collection_creates_table_and_indexes(hologres_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = False
|
||||
@ -313,7 +313,7 @@ def test_create_collection_creates_table_and_indexes(hologres_module, monkeypatc
|
||||
hologres_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_create_collection_raises_when_table_never_becomes_ready(hologres_module, monkeypatch):
|
||||
def test_create_collection_raises_when_table_never_becomes_ready(hologres_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = False
|
||||
@ -331,7 +331,7 @@ def test_create_collection_raises_when_table_never_becomes_ready(hologres_module
|
||||
hologres_module.redis_client.set.assert_not_called()
|
||||
|
||||
|
||||
def test_hologres_factory_uses_existing_or_generated_collection(hologres_module, monkeypatch):
|
||||
def test_hologres_factory_uses_existing_or_generated_collection(hologres_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = hologres_module.HologresVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -29,7 +29,7 @@ def _build_fake_elasticsearch_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def huawei_module(monkeypatch):
|
||||
def huawei_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_elasticsearch_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -155,7 +155,7 @@ def test_search_by_vector_and_full_text(huawei_module):
|
||||
assert docs[0].page_content == "text-hit"
|
||||
|
||||
|
||||
def test_search_by_vector_skips_hits_without_metadata(huawei_module, monkeypatch):
|
||||
def test_search_by_vector_skips_hits_without_metadata(huawei_module, monkeypatch: pytest.MonkeyPatch):
|
||||
class FakeDocument:
|
||||
def __init__(self, page_content, vector, metadata):
|
||||
self.page_content = page_content
|
||||
@ -185,7 +185,7 @@ def test_search_by_vector_skips_hits_without_metadata(huawei_module, monkeypatch
|
||||
assert docs == []
|
||||
|
||||
|
||||
def test_create_and_create_collection_paths(huawei_module, monkeypatch):
|
||||
def test_create_and_create_collection_paths(huawei_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -218,7 +218,7 @@ def test_create_and_create_collection_paths(huawei_module, monkeypatch):
|
||||
huawei_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_huawei_factory_branches(huawei_module, monkeypatch):
|
||||
def test_huawei_factory_branches(huawei_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = huawei_module.HuaweiCloudVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -23,7 +23,7 @@ def _build_fake_iris_module():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def iris_module(monkeypatch):
|
||||
def iris_module(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setitem(sys.modules, "iris", _build_fake_iris_module())
|
||||
|
||||
import dify_vdb_iris.iris_vector as module
|
||||
@ -249,7 +249,7 @@ def test_iris_vector_init_get_cursor_and_create(iris_module):
|
||||
vector._create_collection.assert_called_once_with(2)
|
||||
|
||||
|
||||
def test_iris_vector_crud_and_vector_search(iris_module, monkeypatch):
|
||||
def test_iris_vector_crud_and_vector_search(iris_module, monkeypatch: pytest.MonkeyPatch):
|
||||
with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()):
|
||||
vector = iris_module.IrisVector("collection", _config(iris_module))
|
||||
|
||||
@ -297,7 +297,7 @@ def test_iris_vector_crud_and_vector_search(iris_module, monkeypatch):
|
||||
assert docs[0].metadata["score"] == pytest.approx(0.9)
|
||||
|
||||
|
||||
def test_iris_vector_full_text_search_paths(iris_module, monkeypatch):
|
||||
def test_iris_vector_full_text_search_paths(iris_module, monkeypatch: pytest.MonkeyPatch):
|
||||
cfg = _config(iris_module, IRIS_TEXT_INDEX=True)
|
||||
with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()):
|
||||
vector = iris_module.IrisVector("collection", cfg)
|
||||
@ -344,7 +344,7 @@ def test_iris_vector_full_text_search_paths(iris_module, monkeypatch):
|
||||
assert vector_like.search_by_full_text("100%", top_k=1) == []
|
||||
|
||||
|
||||
def test_iris_vector_delete_create_collection_and_factory(iris_module, monkeypatch):
|
||||
def test_iris_vector_delete_create_collection_and_factory(iris_module, monkeypatch: pytest.MonkeyPatch):
|
||||
with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()):
|
||||
vector = iris_module.IrisVector("collection", _config(iris_module, IRIS_TEXT_INDEX=True))
|
||||
|
||||
|
||||
@ -47,7 +47,7 @@ def _build_fake_opensearch_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def lindorm_module(monkeypatch):
|
||||
def lindorm_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_opensearch_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -100,7 +100,7 @@ def test_to_opensearch_params_and_init(lindorm_module):
|
||||
assert vector_ugc._routing == "route"
|
||||
|
||||
|
||||
def test_create_refresh_and_add_texts_success(lindorm_module, monkeypatch):
|
||||
def test_create_refresh_and_add_texts_success(lindorm_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = lindorm_module.LindormVectorStore(
|
||||
"collection", _config(lindorm_module), using_ugc=True, routing_value="route"
|
||||
)
|
||||
@ -301,7 +301,7 @@ def test_search_by_full_text_success_and_error(lindorm_module):
|
||||
vector.search_by_full_text("hello")
|
||||
|
||||
|
||||
def test_create_collection_paths(lindorm_module, monkeypatch):
|
||||
def test_create_collection_paths(lindorm_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = lindorm_module.LindormVectorStore("collection", _config(lindorm_module), using_ugc=False)
|
||||
|
||||
with pytest.raises(ValueError, match="cannot be empty"):
|
||||
@ -331,7 +331,7 @@ def test_create_collection_paths(lindorm_module, monkeypatch):
|
||||
vector._client.indices.create.assert_not_called()
|
||||
|
||||
|
||||
def test_lindorm_factory_branches(lindorm_module, monkeypatch):
|
||||
def test_lindorm_factory_branches(lindorm_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = lindorm_module.LindormVectorStoreFactory()
|
||||
|
||||
monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_URL", "http://localhost:9200")
|
||||
|
||||
@ -32,7 +32,7 @@ def _build_fake_mo_vector_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def matrixone_module(monkeypatch):
|
||||
def matrixone_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_mo_vector_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -70,7 +70,7 @@ def test_matrixone_config_validation(matrixone_module, field, value, message):
|
||||
matrixone_module.MatrixoneConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_get_client_creates_full_text_index_when_cache_misses(matrixone_module, monkeypatch):
|
||||
def test_get_client_creates_full_text_index_when_cache_misses(matrixone_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -86,7 +86,7 @@ def test_get_client_creates_full_text_index_when_cache_misses(matrixone_module,
|
||||
matrixone_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_get_client_skips_index_creation_when_cache_hits(matrixone_module, monkeypatch):
|
||||
def test_get_client_skips_index_creation_when_cache_hits(matrixone_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -146,7 +146,7 @@ def test_get_type_and_create_delegate_to_add_texts(matrixone_module):
|
||||
vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
|
||||
|
||||
|
||||
def test_get_client_handles_full_text_index_creation_error(matrixone_module, monkeypatch):
|
||||
def test_get_client_handles_full_text_index_creation_error(matrixone_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -165,7 +165,7 @@ def test_get_client_handles_full_text_index_creation_error(matrixone_module, mon
|
||||
matrixone_module.redis_client.set.assert_not_called()
|
||||
|
||||
|
||||
def test_add_texts_generates_ids_and_inserts(matrixone_module, monkeypatch):
|
||||
def test_add_texts_generates_ids_and_inserts(matrixone_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module))
|
||||
vector.client = MagicMock()
|
||||
monkeypatch.setattr(matrixone_module.uuid, "uuid4", lambda: "generated-uuid")
|
||||
@ -224,7 +224,7 @@ def test_search_by_vector_builds_documents(matrixone_module):
|
||||
assert vector.client.query.call_args.kwargs["filter"] == {"document_id": {"$in": ["d-1"]}}
|
||||
|
||||
|
||||
def test_matrixone_factory_uses_existing_or_generated_collection(matrixone_module, monkeypatch):
|
||||
def test_matrixone_factory_uses_existing_or_generated_collection(matrixone_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = matrixone_module.MatrixoneVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -99,7 +99,7 @@ def _build_fake_pymilvus_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def milvus_module(monkeypatch):
|
||||
def milvus_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_pymilvus_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -327,7 +327,7 @@ def test_process_search_results_and_search_methods(milvus_module):
|
||||
assert "document_id" in vector._client.search.call_args.kwargs["filter"]
|
||||
|
||||
|
||||
def test_create_collection_cache_and_existing_collection(milvus_module, monkeypatch):
|
||||
def test_create_collection_cache_and_existing_collection(milvus_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -351,7 +351,7 @@ def test_create_collection_cache_and_existing_collection(milvus_module, monkeypa
|
||||
milvus_module.redis_client.set.assert_called()
|
||||
|
||||
|
||||
def test_create_collection_builds_schema_and_indexes(milvus_module, monkeypatch):
|
||||
def test_create_collection_builds_schema_and_indexes(milvus_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -385,7 +385,7 @@ def test_create_collection_builds_schema_and_indexes(milvus_module, monkeypatch)
|
||||
assert call_kwargs["consistency_level"] == "Session"
|
||||
|
||||
|
||||
def test_factory_initializes_milvus_vector(milvus_module, monkeypatch):
|
||||
def test_factory_initializes_milvus_vector(milvus_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = milvus_module.MilvusVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -38,7 +38,7 @@ def _build_fake_clickhouse_connect_module():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def myscale_module(monkeypatch):
|
||||
def myscale_module(monkeypatch: pytest.MonkeyPatch):
|
||||
fake_module = _build_fake_clickhouse_connect_module()
|
||||
monkeypatch.setitem(sys.modules, "clickhouse_connect", fake_module)
|
||||
|
||||
@ -90,7 +90,7 @@ def test_delete_by_ids_short_circuits_on_empty_list(myscale_module):
|
||||
vector._client.command.assert_not_called()
|
||||
|
||||
|
||||
def test_factory_initializes_lower_case_collection_name(myscale_module, monkeypatch):
|
||||
def test_factory_initializes_lower_case_collection_name(myscale_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = myscale_module.MyScaleVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
@ -160,7 +160,7 @@ def test_create_collection_builds_expected_sql(myscale_module):
|
||||
assert "INDEX text_idx text TYPE fts('tokenizer=unicode')" in sql
|
||||
|
||||
|
||||
def test_add_texts_inserts_rows_and_returns_ids(myscale_module, monkeypatch):
|
||||
def test_add_texts_inserts_rows_and_returns_ids(myscale_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
|
||||
monkeypatch.setattr(myscale_module.uuid, "uuid4", lambda: "generated-uuid")
|
||||
docs = [
|
||||
|
||||
@ -53,7 +53,7 @@ def _build_fake_pyobvector_module():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oceanbase_module(monkeypatch):
|
||||
def oceanbase_module(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setitem(sys.modules, "pyobvector", _build_fake_pyobvector_module())
|
||||
|
||||
import dify_vdb_oceanbase.oceanbase_vector as module
|
||||
@ -208,7 +208,7 @@ def test_create_delegates_to_collection_and_insert(oceanbase_module):
|
||||
vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
|
||||
|
||||
|
||||
def test_create_collection_cache_and_existing_table_short_circuits(oceanbase_module, monkeypatch):
|
||||
def test_create_collection_cache_and_existing_table_short_circuits(oceanbase_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -234,7 +234,7 @@ def test_create_collection_cache_and_existing_table_short_circuits(oceanbase_mod
|
||||
vector.delete.assert_not_called()
|
||||
|
||||
|
||||
def test_create_collection_happy_path_with_hybrid_and_index(oceanbase_module, monkeypatch):
|
||||
def test_create_collection_happy_path_with_hybrid_and_index(oceanbase_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -271,7 +271,7 @@ def test_create_collection_happy_path_with_hybrid_and_index(oceanbase_module, mo
|
||||
oceanbase_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_create_collection_error_paths(oceanbase_module, monkeypatch):
|
||||
def test_create_collection_error_paths(oceanbase_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -308,7 +308,7 @@ def test_create_collection_error_paths(oceanbase_module, monkeypatch):
|
||||
vector._create_collection()
|
||||
|
||||
|
||||
def test_create_collection_fulltext_and_metadata_index_exceptions(oceanbase_module, monkeypatch):
|
||||
def test_create_collection_fulltext_and_metadata_index_exceptions(oceanbase_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -517,7 +517,7 @@ def test_delete_success_and_exception(oceanbase_module):
|
||||
vector.delete()
|
||||
|
||||
|
||||
def test_oceanbase_factory_uses_existing_or_generated_collection(oceanbase_module, monkeypatch):
|
||||
def test_oceanbase_factory_uses_existing_or_generated_collection(oceanbase_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = oceanbase_module.OceanBaseVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -37,7 +37,7 @@ def _build_fake_psycopg2_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def opengauss_module(monkeypatch):
|
||||
def opengauss_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_psycopg2_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -88,7 +88,7 @@ def test_opengauss_config_validation_rejects_min_greater_than_max(opengauss_modu
|
||||
opengauss_module.OpenGaussConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_init_sets_table_name_and_vector_type(opengauss_module, monkeypatch):
|
||||
def test_init_sets_table_name_and_vector_type(opengauss_module, monkeypatch: pytest.MonkeyPatch):
|
||||
pool = MagicMock()
|
||||
monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
|
||||
|
||||
@ -99,7 +99,7 @@ def test_init_sets_table_name_and_vector_type(opengauss_module, monkeypatch):
|
||||
assert vector.pool is pool
|
||||
|
||||
|
||||
def test_create_index_with_pq_executes_pq_sql(opengauss_module, monkeypatch):
|
||||
def test_create_index_with_pq_executes_pq_sql(opengauss_module, monkeypatch: pytest.MonkeyPatch):
|
||||
pool = MagicMock()
|
||||
monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
|
||||
|
||||
@ -126,7 +126,7 @@ def test_create_index_with_pq_executes_pq_sql(opengauss_module, monkeypatch):
|
||||
opengauss_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_create_index_skips_index_sql_for_large_dimension(opengauss_module, monkeypatch):
|
||||
def test_create_index_skips_index_sql_for_large_dimension(opengauss_module, monkeypatch: pytest.MonkeyPatch):
|
||||
pool = MagicMock()
|
||||
monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
|
||||
|
||||
@ -158,7 +158,7 @@ def test_search_by_vector_validates_top_k(opengauss_module):
|
||||
vector.search_by_vector([0.1, 0.2], top_k=0)
|
||||
|
||||
|
||||
def test_delete_by_ids_short_circuits_with_empty_input(opengauss_module, monkeypatch):
|
||||
def test_delete_by_ids_short_circuits_with_empty_input(opengauss_module, monkeypatch: pytest.MonkeyPatch):
|
||||
pool = MagicMock()
|
||||
monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
|
||||
vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module))
|
||||
@ -200,7 +200,7 @@ def test_create_calls_collection_insert_and_index(opengauss_module):
|
||||
vector._create_index.assert_called_once_with(2)
|
||||
|
||||
|
||||
def test_create_index_returns_early_on_cache_hit(opengauss_module, monkeypatch):
|
||||
def test_create_index_returns_early_on_cache_hit(opengauss_module, monkeypatch: pytest.MonkeyPatch):
|
||||
pool = MagicMock()
|
||||
monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
|
||||
|
||||
@ -220,7 +220,7 @@ def test_create_index_returns_early_on_cache_hit(opengauss_module, monkeypatch):
|
||||
opengauss_module.redis_client.set.assert_not_called()
|
||||
|
||||
|
||||
def test_create_index_without_pq_executes_standard_index_sql(opengauss_module, monkeypatch):
|
||||
def test_create_index_without_pq_executes_standard_index_sql(opengauss_module, monkeypatch: pytest.MonkeyPatch):
|
||||
pool = MagicMock()
|
||||
monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
|
||||
|
||||
@ -245,7 +245,7 @@ def test_create_index_without_pq_executes_standard_index_sql(opengauss_module, m
|
||||
assert any("embedding_cosine_embedding_collection_1_idx" in query for query in sql)
|
||||
|
||||
|
||||
def test_add_texts_uses_execute_values(opengauss_module, monkeypatch):
|
||||
def test_add_texts_uses_execute_values(opengauss_module, monkeypatch: pytest.MonkeyPatch):
|
||||
pool = MagicMock()
|
||||
monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
|
||||
vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module))
|
||||
@ -342,7 +342,7 @@ def test_search_by_full_text_validates_top_k(opengauss_module):
|
||||
vector.search_by_full_text("query", top_k=0)
|
||||
|
||||
|
||||
def test_create_collection_cache_and_create_path(opengauss_module, monkeypatch):
|
||||
def test_create_collection_cache_and_create_path(opengauss_module, monkeypatch: pytest.MonkeyPatch):
|
||||
pool = MagicMock()
|
||||
monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
|
||||
lock = MagicMock()
|
||||
@ -370,7 +370,7 @@ def test_create_collection_cache_and_create_path(opengauss_module, monkeypatch):
|
||||
opengauss_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_opengauss_factory_uses_existing_or_generated_collection(opengauss_module, monkeypatch):
|
||||
def test_opengauss_factory_uses_existing_or_generated_collection(opengauss_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = opengauss_module.OpenGaussFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -59,7 +59,7 @@ def _build_fake_opensearch_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def opensearch_module(monkeypatch):
|
||||
def opensearch_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_opensearch_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -95,7 +95,7 @@ class TestOpenSearchConfig:
|
||||
assert params["connection_class"].__name__ == "Urllib3HttpConnection"
|
||||
assert params["http_auth"] == ("admin", "password")
|
||||
|
||||
def test_to_opensearch_params_with_aws_managed_iam(self, opensearch_module, monkeypatch):
|
||||
def test_to_opensearch_params_with_aws_managed_iam(self, opensearch_module, monkeypatch: pytest.MonkeyPatch):
|
||||
class _Session:
|
||||
def get_credentials(self):
|
||||
return "creds"
|
||||
|
||||
@ -58,7 +58,7 @@ def _build_fake_opensearch_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def opensearch_module(monkeypatch):
|
||||
def opensearch_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_opensearch_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -116,7 +116,7 @@ def test_config_validation_for_aws_auth_and_https_fields(opensearch_module):
|
||||
opensearch_module.OpenSearchConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_create_aws_managed_iam_auth(opensearch_module, monkeypatch):
|
||||
def test_create_aws_managed_iam_auth(opensearch_module, monkeypatch: pytest.MonkeyPatch):
|
||||
class _Session:
|
||||
def get_credentials(self):
|
||||
return "creds"
|
||||
@ -167,7 +167,7 @@ def test_init_and_create_delegate_calls(opensearch_module):
|
||||
vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
|
||||
|
||||
|
||||
def test_add_texts_supports_regular_and_aoss_clients(opensearch_module, monkeypatch):
|
||||
def test_add_texts_supports_regular_and_aoss_clients(opensearch_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = opensearch_module.OpenSearchVector("Collection_1", _config(opensearch_module, aws_service="es"))
|
||||
docs = [
|
||||
Document(page_content="a", metadata={"doc_id": "1"}),
|
||||
@ -308,7 +308,7 @@ def test_search_by_full_text_and_filters(opensearch_module):
|
||||
assert query["query"]["bool"]["filter"] == [{"terms": {"metadata.document_id": ["d-1"]}}]
|
||||
|
||||
|
||||
def test_create_collection_cache_and_create_path(opensearch_module, monkeypatch):
|
||||
def test_create_collection_cache_and_create_path(opensearch_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -331,7 +331,7 @@ def test_create_collection_cache_and_create_path(opensearch_module, monkeypatch)
|
||||
opensearch_module.redis_client.set.assert_called()
|
||||
|
||||
|
||||
def test_opensearch_factory_initializes_expected_collection_name(opensearch_module, monkeypatch):
|
||||
def test_opensearch_factory_initializes_expected_collection_name(opensearch_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = opensearch_module.OpenSearchVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -51,7 +51,7 @@ def _connection_with_cursor(cursor):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oracle_module(monkeypatch):
|
||||
def oracle_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_oracle_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -94,7 +94,7 @@ def test_oracle_config_validation_autonomous_requirements(oracle_module):
|
||||
)
|
||||
|
||||
|
||||
def test_init_and_get_type(oracle_module, monkeypatch):
|
||||
def test_init_and_get_type(oracle_module, monkeypatch: pytest.MonkeyPatch):
|
||||
pool = MagicMock()
|
||||
monkeypatch.setattr(oracle_module.oracledb, "create_pool", MagicMock(return_value=pool))
|
||||
vector = oracle_module.OracleVector("collection_1", _config(oracle_module))
|
||||
@ -139,7 +139,7 @@ def test_numpy_converters_and_type_handlers(oracle_module):
|
||||
assert out_float64.dtype == numpy.float64
|
||||
|
||||
|
||||
def test_get_connection_supports_standard_and_autonomous_paths(oracle_module, monkeypatch):
|
||||
def test_get_connection_supports_standard_and_autonomous_paths(oracle_module, monkeypatch: pytest.MonkeyPatch):
|
||||
connect = MagicMock(return_value="connection")
|
||||
monkeypatch.setattr(oracle_module.oracledb, "connect", connect)
|
||||
|
||||
@ -173,7 +173,7 @@ def test_create_delegates_collection_and_insert(oracle_module):
|
||||
vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
|
||||
|
||||
|
||||
def test_add_texts_inserts_and_logs_on_failures(oracle_module, monkeypatch):
|
||||
def test_add_texts_inserts_and_logs_on_failures(oracle_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
|
||||
vector.table_name = "embedding_collection_1"
|
||||
vector.input_type_handler = MagicMock()
|
||||
@ -279,7 +279,7 @@ def _fake_nltk_module(*, missing_data=False):
|
||||
return nltk, nltk_corpus
|
||||
|
||||
|
||||
def test_search_by_full_text_chinese_and_english_paths(oracle_module, monkeypatch):
|
||||
def test_search_by_full_text_chinese_and_english_paths(oracle_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
|
||||
vector.table_name = "embedding_collection_1"
|
||||
|
||||
@ -305,7 +305,7 @@ def test_search_by_full_text_chinese_and_english_paths(oracle_module, monkeypatc
|
||||
assert "doc_id_0" in en_params
|
||||
|
||||
|
||||
def test_search_by_full_text_empty_query_and_missing_nltk(oracle_module, monkeypatch):
|
||||
def test_search_by_full_text_empty_query_and_missing_nltk(oracle_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
|
||||
vector.table_name = "embedding_collection_1"
|
||||
vector._get_connection = MagicMock()
|
||||
@ -320,7 +320,7 @@ def test_search_by_full_text_empty_query_and_missing_nltk(oracle_module, monkeyp
|
||||
vector.search_by_full_text("english query")
|
||||
|
||||
|
||||
def test_create_collection_cache_and_execute_path(oracle_module, monkeypatch):
|
||||
def test_create_collection_cache_and_execute_path(oracle_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -346,7 +346,9 @@ def test_create_collection_cache_and_execute_path(oracle_module, monkeypatch):
|
||||
oracle_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_oracle_factory_init_vector_uses_existing_or_generated_collection(oracle_module, monkeypatch):
|
||||
def test_oracle_factory_init_vector_uses_existing_or_generated_collection(
|
||||
oracle_module, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
factory = oracle_module.OracleVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -79,7 +79,7 @@ def _patch_both(monkeypatch, module, calls, execute_results=None):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pgvecto_module(monkeypatch):
|
||||
def pgvecto_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_pgvecto_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -126,7 +126,7 @@ def test_collection_base_has_expected_annotations(pgvecto_module):
|
||||
assert {"id", "text", "meta", "vector"} <= set(annotations)
|
||||
|
||||
|
||||
def test_init_get_type_and_create_delegate(pgvecto_module, monkeypatch):
|
||||
def test_init_get_type_and_create_delegate(pgvecto_module, monkeypatch: pytest.MonkeyPatch):
|
||||
module, _ = pgvecto_module
|
||||
session_calls = []
|
||||
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
|
||||
@ -145,7 +145,7 @@ def test_init_get_type_and_create_delegate(pgvecto_module, monkeypatch):
|
||||
vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
|
||||
|
||||
|
||||
def test_create_collection_cache_and_sql_execution(pgvecto_module, monkeypatch):
|
||||
def test_create_collection_cache_and_sql_execution(pgvecto_module, monkeypatch: pytest.MonkeyPatch):
|
||||
module, _ = pgvecto_module
|
||||
session_calls = []
|
||||
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
|
||||
@ -169,7 +169,7 @@ def test_create_collection_cache_and_sql_execution(pgvecto_module, monkeypatch):
|
||||
module.redis_client.set.assert_called()
|
||||
|
||||
|
||||
def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch):
|
||||
def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch: pytest.MonkeyPatch):
|
||||
module, _ = pgvecto_module
|
||||
init_calls = []
|
||||
runtime_calls = []
|
||||
@ -241,7 +241,7 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch):
|
||||
assert any("DROP TABLE IF EXISTS collection_1" in str(args[0]) for args, _ in runtime_calls)
|
||||
|
||||
|
||||
def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch):
|
||||
def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch: pytest.MonkeyPatch):
|
||||
module, _ = pgvecto_module
|
||||
init_calls = []
|
||||
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
|
||||
@ -313,7 +313,7 @@ def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch):
|
||||
assert vector.search_by_full_text("hello") == []
|
||||
|
||||
|
||||
def test_factory_uses_existing_or_generated_collection(pgvecto_module, monkeypatch):
|
||||
def test_factory_uses_existing_or_generated_collection(pgvecto_module, monkeypatch: pytest.MonkeyPatch):
|
||||
module, _ = pgvecto_module
|
||||
factory = module.PGVectoRSFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
|
||||
@ -336,7 +336,7 @@ def test_create_delegates_collection_creation_and_insert():
|
||||
vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
|
||||
|
||||
|
||||
def test_add_texts_uses_execute_values_and_returns_ids(monkeypatch):
|
||||
def test_add_texts_uses_execute_values_and_returns_ids(monkeypatch: pytest.MonkeyPatch):
|
||||
vector = PGVector.__new__(PGVector)
|
||||
vector.table_name = "embedding_collection_1"
|
||||
|
||||
@ -387,7 +387,7 @@ def test_text_get_and_delete_methods():
|
||||
assert any("DROP TABLE IF EXISTS embedding_collection_1" in sql for sql in executed_sql)
|
||||
|
||||
|
||||
def test_delete_by_ids_handles_empty_undefined_table_and_generic_exception(monkeypatch):
|
||||
def test_delete_by_ids_handles_empty_undefined_table_and_generic_exception(monkeypatch: pytest.MonkeyPatch):
|
||||
vector = PGVector.__new__(PGVector)
|
||||
vector.table_name = "embedding_collection_1"
|
||||
cursor = MagicMock()
|
||||
@ -464,7 +464,7 @@ def test_search_by_full_text_branches_for_bigm_and_standard():
|
||||
assert "bigm_similarity" in cursor.execute.call_args_list[1].args[0]
|
||||
|
||||
|
||||
def test_pgvector_factory_initializes_expected_collection_name(monkeypatch):
|
||||
def test_pgvector_factory_initializes_expected_collection_name(monkeypatch: pytest.MonkeyPatch):
|
||||
factory = pgvector_module.PGVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -121,7 +121,7 @@ def _build_fake_qdrant_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def qdrant_module(monkeypatch):
|
||||
def qdrant_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_qdrant_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -170,7 +170,7 @@ def test_init_and_basic_behaviour(qdrant_module):
|
||||
vector.add_texts.assert_called_once()
|
||||
|
||||
|
||||
def test_create_collection_and_add_texts(qdrant_module, monkeypatch):
|
||||
def test_create_collection_and_add_texts(qdrant_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module))
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
@ -288,7 +288,7 @@ def test_search_and_helper_methods(qdrant_module):
|
||||
assert doc.page_content == "doc"
|
||||
|
||||
|
||||
def test_qdrant_factory_paths(qdrant_module, monkeypatch):
|
||||
def test_qdrant_factory_paths(qdrant_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = qdrant_module.QdrantVectorFactory()
|
||||
dataset = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -59,7 +59,7 @@ def _patch_both(monkeypatch, module, session):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def relyt_module(monkeypatch):
|
||||
def relyt_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_relyt_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -97,7 +97,7 @@ def test_relyt_config_validation(relyt_module, field, value, message):
|
||||
relyt_module.RelytConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_init_get_type_and_create_delegate(relyt_module, monkeypatch):
|
||||
def test_init_get_type_and_create_delegate(relyt_module, monkeypatch: pytest.MonkeyPatch):
|
||||
engine = MagicMock()
|
||||
monkeypatch.setattr(relyt_module, "create_engine", MagicMock(return_value=engine))
|
||||
vector = relyt_module.RelytVector("collection_1", _config(relyt_module), group_id="group-1")
|
||||
@ -114,7 +114,7 @@ def test_init_get_type_and_create_delegate(relyt_module, monkeypatch):
|
||||
vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
|
||||
|
||||
|
||||
def test_create_collection_cache_and_sql_execution(relyt_module, monkeypatch):
|
||||
def test_create_collection_cache_and_sql_execution(relyt_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -142,7 +142,7 @@ def test_create_collection_cache_and_sql_execution(relyt_module, monkeypatch):
|
||||
relyt_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_add_texts_and_metadata_queries(relyt_module, monkeypatch):
|
||||
def test_add_texts_and_metadata_queries(relyt_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
|
||||
vector._collection_name = "collection_1"
|
||||
vector._group_id = "group-1"
|
||||
@ -212,7 +212,7 @@ def test_delete_by_metadata_field_calls_delete_by_uuids(relyt_module):
|
||||
|
||||
|
||||
# 3. delete_by_ids translates to uuids
|
||||
def test_delete_by_ids_translates_to_uuids(relyt_module, monkeypatch):
|
||||
def test_delete_by_ids_translates_to_uuids(relyt_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
|
||||
vector._collection_name = "collection_1"
|
||||
vector.client = MagicMock()
|
||||
@ -225,7 +225,7 @@ def test_delete_by_ids_translates_to_uuids(relyt_module, monkeypatch):
|
||||
|
||||
|
||||
# 4. text_exists True
|
||||
def test_text_exists_true(relyt_module, monkeypatch):
|
||||
def test_text_exists_true(relyt_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
|
||||
vector._collection_name = "collection_1"
|
||||
vector.client = MagicMock()
|
||||
@ -236,7 +236,7 @@ def test_text_exists_true(relyt_module, monkeypatch):
|
||||
|
||||
|
||||
# 5. text_exists False
|
||||
def test_text_exists_false(relyt_module, monkeypatch):
|
||||
def test_text_exists_false(relyt_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
|
||||
vector._collection_name = "collection_1"
|
||||
vector.client = MagicMock()
|
||||
@ -284,7 +284,7 @@ def test_search_by_vector_filters_by_score_and_ids(relyt_module):
|
||||
|
||||
|
||||
# 8. delete commits session
|
||||
def test_delete_drops_table(relyt_module, monkeypatch):
|
||||
def test_delete_drops_table(relyt_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
|
||||
vector._collection_name = "collection_1"
|
||||
vector.client = MagicMock()
|
||||
@ -295,7 +295,7 @@ def test_delete_drops_table(relyt_module, monkeypatch):
|
||||
session.execute.assert_called_once()
|
||||
|
||||
|
||||
def test_relyt_factory_existing_and_generated_collection(relyt_module, monkeypatch):
|
||||
def test_relyt_factory_existing_and_generated_collection(relyt_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = relyt_module.RelytVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -77,7 +77,7 @@ def _build_fake_tablestore_module():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tablestore_module(monkeypatch):
|
||||
def tablestore_module(monkeypatch: pytest.MonkeyPatch):
|
||||
fake_module = _build_fake_tablestore_module()
|
||||
monkeypatch.setitem(sys.modules, "tablestore", fake_module)
|
||||
|
||||
@ -177,7 +177,7 @@ def test_get_by_ids_text_exists_delete_and_wrappers(tablestore_module):
|
||||
vector._delete_table_if_exist.assert_called_once()
|
||||
|
||||
|
||||
def test_create_collection_and_table_index_lifecycle(tablestore_module, monkeypatch):
|
||||
def test_create_collection_and_table_index_lifecycle(tablestore_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module))
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
@ -289,7 +289,7 @@ def test_write_row_and_search_helpers(tablestore_module):
|
||||
assert "score" not in docs[0].metadata
|
||||
|
||||
|
||||
def test_tablestore_factory_uses_existing_or_generated_collection(tablestore_module, monkeypatch):
|
||||
def test_tablestore_factory_uses_existing_or_generated_collection(tablestore_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = tablestore_module.TableStoreVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -136,7 +136,7 @@ def _build_fake_tencent_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tencent_module(monkeypatch):
|
||||
def tencent_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_tencent_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -187,7 +187,7 @@ def test_config_and_init_paths(tencent_module):
|
||||
assert vector._enable_hybrid_search is False
|
||||
|
||||
|
||||
def test_create_collection_branches(tencent_module, monkeypatch):
|
||||
def test_create_collection_branches(tencent_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = tencent_module.TencentVector("collection_1", _config(tencent_module))
|
||||
|
||||
lock = MagicMock()
|
||||
@ -279,7 +279,7 @@ def test_create_add_delete_and_search_behaviour(tencent_module):
|
||||
vector._client.drop_collection.assert_called_once()
|
||||
|
||||
|
||||
def test_tencent_factory_existing_and_generated_collection(tencent_module, monkeypatch):
|
||||
def test_tencent_factory_existing_and_generated_collection(tencent_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = tencent_module.TencentVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -46,7 +46,7 @@ def test_tidb_config_validation(tidb_module, field, value, message):
|
||||
tidb_module.TiDBVectorConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_init_get_type_and_distance_func(tidb_module, monkeypatch):
|
||||
def test_init_get_type_and_distance_func(tidb_module, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(tidb_module, "create_engine", MagicMock(return_value="engine"))
|
||||
|
||||
vector = tidb_module.TiDBVector("collection_1", _config(tidb_module), distance_func="L2")
|
||||
@ -63,7 +63,7 @@ def test_init_get_type_and_distance_func(tidb_module, monkeypatch):
|
||||
assert vector._get_distance_func() == "VEC_COSINE_DISTANCE"
|
||||
|
||||
|
||||
def test_table_builds_columns_with_tidb_vector_type(tidb_module, monkeypatch):
|
||||
def test_table_builds_columns_with_tidb_vector_type(tidb_module, monkeypatch: pytest.MonkeyPatch):
|
||||
fake_tidb_vector = types.ModuleType("tidb_vector")
|
||||
fake_tidb_sqlalchemy = types.ModuleType("tidb_vector.sqlalchemy")
|
||||
|
||||
@ -107,7 +107,7 @@ def test_create_calls_collection_and_add_texts(tidb_module):
|
||||
assert vector._dimension == 2
|
||||
|
||||
|
||||
def test_create_collection_skips_when_cache_hit(tidb_module, monkeypatch):
|
||||
def test_create_collection_skips_when_cache_hit(tidb_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -127,7 +127,7 @@ def test_create_collection_skips_when_cache_hit(tidb_module, monkeypatch):
|
||||
tidb_module.redis_client.set.assert_not_called()
|
||||
|
||||
|
||||
def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monkeypatch):
|
||||
def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -160,7 +160,7 @@ def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monke
|
||||
tidb_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_add_texts_batches_inserts_and_returns_ids(tidb_module, monkeypatch):
|
||||
def test_add_texts_batches_inserts_and_returns_ids(tidb_module, monkeypatch: pytest.MonkeyPatch):
|
||||
class _InsertStmt:
|
||||
def __init__(self, table):
|
||||
self.table = table
|
||||
@ -198,7 +198,7 @@ def test_add_texts_batches_inserts_and_returns_ids(tidb_module, monkeypatch):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tidb_vector_with_session(tidb_module, monkeypatch):
|
||||
def tidb_vector_with_session(tidb_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
|
||||
vector._collection_name = "collection_1"
|
||||
vector._engine = MagicMock()
|
||||
@ -354,7 +354,7 @@ def test_delete_by_metadata_field_does_nothing_when_no_ids(tidb_module):
|
||||
|
||||
|
||||
# Test search_by_vector filters and scores
|
||||
def test_search_by_vector_filters_and_scores(tidb_module, monkeypatch):
|
||||
def test_search_by_vector_filters_and_scores(tidb_module, monkeypatch: pytest.MonkeyPatch):
|
||||
session = MagicMock()
|
||||
session.execute.return_value = [
|
||||
('{"doc_id":"id-1","document_id":"d-1"}', "text-1", 0.2),
|
||||
@ -392,7 +392,7 @@ def test_search_by_vector_filters_and_scores(tidb_module, monkeypatch):
|
||||
|
||||
|
||||
# Test delete drops table
|
||||
def test_delete_drops_table(tidb_module, monkeypatch):
|
||||
def test_delete_drops_table(tidb_module, monkeypatch: pytest.MonkeyPatch):
|
||||
session = MagicMock()
|
||||
session.execute.return_value = None
|
||||
|
||||
@ -413,7 +413,7 @@ def test_delete_drops_table(tidb_module, monkeypatch):
|
||||
assert "DROP TABLE IF EXISTS collection_1" in drop_sql
|
||||
|
||||
|
||||
def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeypatch):
|
||||
def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = tidb_module.TiDBVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -36,7 +36,7 @@ def _build_fake_upstash_module():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def upstash_module(monkeypatch):
|
||||
def upstash_module(monkeypatch: pytest.MonkeyPatch):
|
||||
# Remove patched modules if present
|
||||
for modname in ["upstash_vector", "dify_vdb_upstash.upstash_vector"]:
|
||||
if modname in sys.modules:
|
||||
@ -65,7 +65,7 @@ def test_upstash_config_validation(upstash_module, field, value, message):
|
||||
upstash_module.UpstashVectorConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_init_get_type_and_dimension(upstash_module, monkeypatch):
|
||||
def test_init_get_type_and_dimension(upstash_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = upstash_module.UpstashVector("collection_1", _config(upstash_module))
|
||||
|
||||
assert vector.get_type() == upstash_module.VectorType.UPSTASH
|
||||
@ -162,7 +162,7 @@ def test_search_by_vector_filter_threshold_and_delete(upstash_module):
|
||||
vector.index.reset.assert_called_once()
|
||||
|
||||
|
||||
def test_upstash_factory_uses_existing_or_generated_collection(upstash_module, monkeypatch):
|
||||
def test_upstash_factory_uses_existing_or_generated_collection(upstash_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = upstash_module.UpstashVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -37,7 +37,7 @@ def _build_fake_psycopg2_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vastbase_module(monkeypatch):
|
||||
def vastbase_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_psycopg2_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -93,7 +93,7 @@ def test_vastbase_config_rejects_invalid_connection_window(vastbase_module):
|
||||
)
|
||||
|
||||
|
||||
def test_init_and_get_cursor_context_manager(vastbase_module, monkeypatch):
|
||||
def test_init_and_get_cursor_context_manager(vastbase_module, monkeypatch: pytest.MonkeyPatch):
|
||||
pool = MagicMock()
|
||||
monkeypatch.setattr(vastbase_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
|
||||
|
||||
@ -114,7 +114,7 @@ def test_init_and_get_cursor_context_manager(vastbase_module, monkeypatch):
|
||||
pool.putconn.assert_called_once_with(conn)
|
||||
|
||||
|
||||
def test_create_and_add_texts(vastbase_module, monkeypatch):
|
||||
def test_create_and_add_texts(vastbase_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector)
|
||||
vector.table_name = "embedding_collection_1"
|
||||
vector._create_collection = MagicMock()
|
||||
@ -205,7 +205,7 @@ def test_search_by_vector_and_full_text(vastbase_module):
|
||||
assert full_docs[0].page_content == "full-text"
|
||||
|
||||
|
||||
def test_create_collection_cache_and_dimension_branches(vastbase_module, monkeypatch):
|
||||
def test_create_collection_cache_and_dimension_branches(vastbase_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -240,7 +240,7 @@ def test_create_collection_cache_and_dimension_branches(vastbase_module, monkeyp
|
||||
vastbase_module.redis_client.set.assert_called()
|
||||
|
||||
|
||||
def test_vastbase_factory_uses_existing_or_generated_collection(vastbase_module, monkeypatch):
|
||||
def test_vastbase_factory_uses_existing_or_generated_collection(vastbase_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = vastbase_module.VastbaseVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -79,7 +79,7 @@ def _build_fake_vikingdb_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vikingdb_module(monkeypatch):
|
||||
def vikingdb_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_vikingdb_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -117,7 +117,7 @@ def test_init_get_type_and_has_checks(vikingdb_module):
|
||||
assert vector._has_index() is False
|
||||
|
||||
|
||||
def test_create_collection_cache_and_creation_paths(vikingdb_module, monkeypatch):
|
||||
def test_create_collection_cache_and_creation_paths(vikingdb_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -253,7 +253,7 @@ def test_delete_drops_index_and_collection_when_present(vikingdb_module):
|
||||
vector._client.drop_collection.assert_not_called()
|
||||
|
||||
|
||||
def test_vikingdb_factory_validates_config_and_builds_vector(vikingdb_module, monkeypatch):
|
||||
def test_vikingdb_factory_validates_config_and_builds_vector(vikingdb_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = vikingdb_module.VikingDBVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
@ -293,7 +293,9 @@ def test_vikingdb_factory_validates_config_and_builds_vector(vikingdb_module, mo
|
||||
("VIKINGDB_SCHEME", "VIKINGDB_SCHEME should not be None"),
|
||||
],
|
||||
)
|
||||
def test_vikingdb_factory_raises_when_required_config_missing(vikingdb_module, monkeypatch, field, message):
|
||||
def test_vikingdb_factory_raises_when_required_config_missing(
|
||||
vikingdb_module, monkeypatch: pytest.MonkeyPatch, field, message
|
||||
):
|
||||
factory = vikingdb_module.VikingDBVectorFactory()
|
||||
dataset = SimpleNamespace(
|
||||
id="dataset-1", index_struct_dict={"vector_store": {"class_prefix": "existing"}}, index_struct=None
|
||||
|
||||
@ -102,6 +102,7 @@ graphon = { git = "https://github.com/QuantumGhost/graphon", branch = "hitl-form
|
||||
default-groups = ["storage", "tools", "vdb-all", "trace-all"]
|
||||
package = false
|
||||
override-dependencies = [
|
||||
"litellm>=1.83.7",
|
||||
"pyarrow>=18.0.0",
|
||||
]
|
||||
|
||||
|
||||
@ -107,15 +107,14 @@ class FileService:
|
||||
hash=hashlib.sha3_256(content).hexdigest(),
|
||||
source_url=source_url,
|
||||
)
|
||||
# The `UploadFile` ID is generated within its constructor, so flushing to retrieve the ID is unnecessary.
|
||||
# We can directly generate the `source_url` here before committing.
|
||||
if not upload_file.source_url:
|
||||
upload_file.source_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
|
||||
|
||||
with self._session_maker(expire_on_commit=False) as session:
|
||||
session.add(upload_file)
|
||||
session.commit()
|
||||
|
||||
if not upload_file.source_url:
|
||||
upload_file.source_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
|
||||
|
||||
return upload_file
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -13,7 +13,7 @@ from controllers.console.app import wraps
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import App, Tenant
|
||||
from models.account import Account, TenantAccountJoin, TenantAccountRole
|
||||
from models.enums import ConversationFromSource
|
||||
from models.enums import AppStatus, ConversationFromSource
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
||||
@ -28,7 +28,7 @@ class TestChatMessageApiPermissions:
|
||||
app.id = str(uuid.uuid4())
|
||||
app.mode = AppMode.CHAT
|
||||
app.tenant_id = str(uuid.uuid4())
|
||||
app.status = "normal"
|
||||
app.status = AppStatus.NORMAL
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
@ -78,7 +78,7 @@ class TestChatMessageApiPermissions:
|
||||
self,
|
||||
test_client: FlaskClient,
|
||||
auth_header,
|
||||
monkeypatch,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mock_app_model,
|
||||
mock_account,
|
||||
role: TenantAccountRole,
|
||||
@ -130,7 +130,7 @@ class TestChatMessageApiPermissions:
|
||||
self,
|
||||
test_client: FlaskClient,
|
||||
auth_header,
|
||||
monkeypatch,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mock_app_model,
|
||||
mock_account,
|
||||
role: TenantAccountRole,
|
||||
|
||||
@ -14,7 +14,7 @@ from controllers.console.app import wraps
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import App, Tenant
|
||||
from models.account import Account, TenantAccountJoin, TenantAccountRole
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
from models.enums import AppStatus, FeedbackFromSource, FeedbackRating
|
||||
from models.model import AppMode, MessageFeedback
|
||||
from services.feedback_service import FeedbackService
|
||||
|
||||
@ -29,7 +29,7 @@ class TestFeedbackExportApi:
|
||||
app.id = str(uuid.uuid4())
|
||||
app.mode = AppMode.CHAT
|
||||
app.tenant_id = str(uuid.uuid4())
|
||||
app.status = "normal"
|
||||
app.status = AppStatus.NORMAL
|
||||
app.name = "Test App"
|
||||
return app
|
||||
|
||||
@ -135,7 +135,7 @@ class TestFeedbackExportApi:
|
||||
self,
|
||||
test_client: FlaskClient,
|
||||
auth_header,
|
||||
monkeypatch,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mock_app_model,
|
||||
mock_account,
|
||||
role: TenantAccountRole,
|
||||
@ -167,7 +167,13 @@ class TestFeedbackExportApi:
|
||||
mock_export_feedbacks.assert_called_once()
|
||||
|
||||
def test_feedback_export_csv_format(
|
||||
self, test_client: FlaskClient, auth_header, monkeypatch, mock_app_model, mock_account, sample_feedback_data
|
||||
self,
|
||||
test_client: FlaskClient,
|
||||
auth_header,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mock_app_model,
|
||||
mock_account,
|
||||
sample_feedback_data,
|
||||
):
|
||||
"""Test feedback export in CSV format."""
|
||||
|
||||
@ -202,7 +208,13 @@ class TestFeedbackExportApi:
|
||||
assert "text/csv" in response.content_type
|
||||
|
||||
def test_feedback_export_json_format(
|
||||
self, test_client: FlaskClient, auth_header, monkeypatch, mock_app_model, mock_account, sample_feedback_data
|
||||
self,
|
||||
test_client: FlaskClient,
|
||||
auth_header,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mock_app_model,
|
||||
mock_account,
|
||||
sample_feedback_data,
|
||||
):
|
||||
"""Test feedback export in JSON format."""
|
||||
|
||||
@ -246,7 +258,7 @@ class TestFeedbackExportApi:
|
||||
assert "application/json" in response.content_type
|
||||
|
||||
def test_feedback_export_with_filters(
|
||||
self, test_client: FlaskClient, auth_header, monkeypatch, mock_app_model, mock_account
|
||||
self, test_client: FlaskClient, auth_header, monkeypatch: pytest.MonkeyPatch, mock_app_model, mock_account
|
||||
):
|
||||
"""Test feedback export with various filters."""
|
||||
|
||||
@ -287,7 +299,7 @@ class TestFeedbackExportApi:
|
||||
)
|
||||
|
||||
def test_feedback_export_invalid_date_format(
|
||||
self, test_client: FlaskClient, auth_header, monkeypatch, mock_app_model, mock_account
|
||||
self, test_client: FlaskClient, auth_header, monkeypatch: pytest.MonkeyPatch, mock_app_model, mock_account
|
||||
):
|
||||
"""Test feedback export with invalid date format."""
|
||||
|
||||
@ -312,7 +324,7 @@ class TestFeedbackExportApi:
|
||||
assert "Parameter validation error" in response_json["error"]
|
||||
|
||||
def test_feedback_export_server_error(
|
||||
self, test_client: FlaskClient, auth_header, monkeypatch, mock_app_model, mock_account
|
||||
self, test_client: FlaskClient, auth_header, monkeypatch: pytest.MonkeyPatch, mock_app_model, mock_account
|
||||
):
|
||||
"""Test feedback export with server error."""
|
||||
|
||||
|
||||
@ -11,6 +11,7 @@ from controllers.console.app import wraps
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import App, Tenant
|
||||
from models.account import Account, TenantAccountJoin, TenantAccountRole
|
||||
from models.enums import AppStatus
|
||||
from models.model import AppMode
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
|
||||
@ -25,7 +26,7 @@ class TestModelConfigResourcePermissions:
|
||||
app.id = str(uuid.uuid4())
|
||||
app.mode = AppMode.CHAT
|
||||
app.tenant_id = str(uuid.uuid4())
|
||||
app.status = "normal"
|
||||
app.status = AppStatus.NORMAL
|
||||
app.app_model_config_id = str(uuid.uuid4())
|
||||
return app
|
||||
|
||||
@ -73,7 +74,7 @@ class TestModelConfigResourcePermissions:
|
||||
self,
|
||||
test_client: FlaskClient,
|
||||
auth_header,
|
||||
monkeypatch,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mock_app_model,
|
||||
mock_account,
|
||||
role: TenantAccountRole,
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
from core.datasource.entities.datasource_entities import DatasourceMessage
|
||||
from graphon.node_events import StreamCompletedEvent
|
||||
@ -19,7 +21,7 @@ def _gen_var_stream() -> Generator[DatasourceMessage, None, None]:
|
||||
)
|
||||
|
||||
|
||||
def test_stream_node_events_accumulates_variables(mocker):
|
||||
def test_stream_node_events_accumulates_variables(mocker: MockerFixture):
|
||||
mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_var_stream())
|
||||
events = list(
|
||||
DatasourceManager.stream_node_events(
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY
|
||||
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
|
||||
from core.workflow.nodes.datasource.entities import DatasourceNodeData
|
||||
@ -44,7 +46,7 @@ class _GP:
|
||||
call_depth = 0
|
||||
|
||||
|
||||
def test_node_integration_minimal_stream(mocker):
|
||||
def test_node_integration_minimal_stream(mocker: MockerFixture):
|
||||
sys_d = {
|
||||
"sys": {
|
||||
"datasource_type": "online_document",
|
||||
|
||||
@ -2,6 +2,8 @@ import time
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
@ -71,7 +73,7 @@ def init_tool_node(config: dict):
|
||||
return node
|
||||
|
||||
|
||||
def test_tool_variable_invoke(monkeypatch):
|
||||
def test_tool_variable_invoke(monkeypatch: pytest.MonkeyPatch):
|
||||
node = init_tool_node(
|
||||
config={
|
||||
"id": "1",
|
||||
@ -106,7 +108,7 @@ def test_tool_variable_invoke(monkeypatch):
|
||||
assert item.node_run_result.outputs.get("text") is not None
|
||||
|
||||
|
||||
def test_tool_mixed_invoke(monkeypatch):
|
||||
def test_tool_mixed_invoke(monkeypatch: pytest.MonkeyPatch):
|
||||
node = init_tool_node(
|
||||
config={
|
||||
"id": "1",
|
||||
|
||||
@ -11,7 +11,7 @@ from libs import helper as helper_module
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("flask_app_with_containers")
|
||||
def test_rate_limiter_counts_multiple_attempts_in_same_second(monkeypatch):
|
||||
def test_rate_limiter_counts_multiple_attempts_in_same_second(monkeypatch: pytest.MonkeyPatch):
|
||||
prefix = f"test_rate_limit:{uuid.uuid4().hex}"
|
||||
limiter = helper_module.RateLimiter(prefix=prefix, max_attempts=2, time_window=60)
|
||||
key = limiter._get_key("203.0.113.10")
|
||||
|
||||
@ -6,7 +6,7 @@ from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from models import Account
|
||||
from models import Account, CreatorUserRole
|
||||
from models.enums import ConversationFromSource, MessageFileBelongsTo
|
||||
from models.model import AppModelConfig, Conversation, EndUser, Message, MessageAgentThought
|
||||
from services.account_service import AccountService, TenantService
|
||||
@ -246,7 +246,7 @@ class TestAgentService:
|
||||
tool_input=json.dumps({"test_tool": {"input": "test_input"}}),
|
||||
observation=json.dumps({"test_tool": {"output": "test_output"}}),
|
||||
tokens=50,
|
||||
created_by_role="account",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=message.from_account_id,
|
||||
)
|
||||
db_session_with_containers.add(thought1)
|
||||
@ -294,7 +294,7 @@ class TestAgentService:
|
||||
agent_thoughts = self._create_test_agent_thoughts(db_session_with_containers, message)
|
||||
|
||||
# Execute the method under test
|
||||
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
|
||||
result = AgentService.get_agent_logs(app, conversation.id, message.id)
|
||||
|
||||
# Verify the result structure
|
||||
assert result is not None
|
||||
@ -370,7 +370,7 @@ class TestAgentService:
|
||||
|
||||
# Execute the method under test with non-existent message
|
||||
with pytest.raises(ValueError, match="Message not found"):
|
||||
AgentService.get_agent_logs(app, str(conversation.id), fake.uuid4())
|
||||
AgentService.get_agent_logs(app, conversation.id, fake.uuid4())
|
||||
|
||||
def test_get_agent_logs_with_end_user(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
@ -451,7 +451,7 @@ class TestAgentService:
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Execute the method under test
|
||||
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
|
||||
result = AgentService.get_agent_logs(app, conversation.id, message.id)
|
||||
|
||||
# Verify the result
|
||||
assert result is not None
|
||||
@ -523,7 +523,7 @@ class TestAgentService:
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Execute the method under test
|
||||
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
|
||||
result = AgentService.get_agent_logs(app, conversation.id, message.id)
|
||||
|
||||
# Verify the result
|
||||
assert result is not None
|
||||
@ -561,14 +561,14 @@ class TestAgentService:
|
||||
tool_input=json.dumps({"error_tool": {"input": "test_input"}}),
|
||||
observation=json.dumps({"error_tool": {"output": "error_output"}}),
|
||||
tokens=50,
|
||||
created_by_role="account",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=message.from_account_id,
|
||||
)
|
||||
db_session_with_containers.add(thought_with_error)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Execute the method under test
|
||||
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
|
||||
result = AgentService.get_agent_logs(app, conversation.id, message.id)
|
||||
|
||||
# Verify the result
|
||||
assert result is not None
|
||||
@ -592,7 +592,7 @@ class TestAgentService:
|
||||
conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account)
|
||||
|
||||
# Execute the method under test
|
||||
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
|
||||
result = AgentService.get_agent_logs(app, conversation.id, message.id)
|
||||
|
||||
# Verify the result
|
||||
assert result is not None
|
||||
@ -654,7 +654,7 @@ class TestAgentService:
|
||||
|
||||
# Execute the method under test
|
||||
with pytest.raises(ValueError, match="App model config not found"):
|
||||
AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
|
||||
AgentService.get_agent_logs(app, conversation.id, message.id)
|
||||
|
||||
def test_get_agent_logs_agent_config_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
@ -673,7 +673,7 @@ class TestAgentService:
|
||||
|
||||
# Execute the method under test
|
||||
with pytest.raises(ValueError, match="Agent config not found"):
|
||||
AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
|
||||
AgentService.get_agent_logs(app, conversation.id, message.id)
|
||||
|
||||
def test_list_agent_providers_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
@ -687,7 +687,7 @@ class TestAgentService:
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Execute the method under test
|
||||
result = AgentService.list_agent_providers(str(account.id), str(app.tenant_id))
|
||||
result = AgentService.list_agent_providers(account.id, app.tenant_id)
|
||||
|
||||
# Verify the result
|
||||
assert result is not None
|
||||
@ -696,7 +696,7 @@ class TestAgentService:
|
||||
|
||||
# Verify the mock was called correctly
|
||||
mock_plugin_client = mock_external_service_dependencies["plugin_agent_client"].return_value
|
||||
mock_plugin_client.fetch_agent_strategy_providers.assert_called_once_with(str(app.tenant_id))
|
||||
mock_plugin_client.fetch_agent_strategy_providers.assert_called_once_with(app.tenant_id)
|
||||
|
||||
def test_get_agent_provider_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
@ -710,7 +710,7 @@ class TestAgentService:
|
||||
provider_name = "test_provider"
|
||||
|
||||
# Execute the method under test
|
||||
result = AgentService.get_agent_provider(str(account.id), str(app.tenant_id), provider_name)
|
||||
result = AgentService.get_agent_provider(account.id, app.tenant_id, provider_name)
|
||||
|
||||
# Verify the result
|
||||
assert result is not None
|
||||
@ -718,7 +718,7 @@ class TestAgentService:
|
||||
|
||||
# Verify the mock was called correctly
|
||||
mock_plugin_client = mock_external_service_dependencies["plugin_agent_client"].return_value
|
||||
mock_plugin_client.fetch_agent_strategy_provider.assert_called_once_with(str(app.tenant_id), provider_name)
|
||||
mock_plugin_client.fetch_agent_strategy_provider.assert_called_once_with(app.tenant_id, provider_name)
|
||||
|
||||
def test_get_agent_provider_plugin_error(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
@ -740,7 +740,7 @@ class TestAgentService:
|
||||
|
||||
# Execute the method under test
|
||||
with pytest.raises(ValueError, match=error_message):
|
||||
AgentService.get_agent_provider(str(account.id), str(app.tenant_id), provider_name)
|
||||
AgentService.get_agent_provider(account.id, app.tenant_id, provider_name)
|
||||
|
||||
def test_get_agent_logs_with_complex_tool_data(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
@ -796,14 +796,14 @@ class TestAgentService:
|
||||
{"tool1": {"output1": "result1"}, "tool2": {"output2": "result2"}, "tool3": {"output3": "result3"}}
|
||||
),
|
||||
tokens=100,
|
||||
created_by_role="account",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=message.from_account_id,
|
||||
)
|
||||
db_session_with_containers.add(complex_thought)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Execute the method under test
|
||||
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
|
||||
result = AgentService.get_agent_logs(app, conversation.id, message.id)
|
||||
|
||||
# Verify the result
|
||||
assert result is not None
|
||||
@ -891,14 +891,14 @@ class TestAgentService:
|
||||
observation=json.dumps({"file_tool": {"output": "test_output"}}),
|
||||
message_files=json.dumps(["file1", "file2"]),
|
||||
tokens=50,
|
||||
created_by_role="account",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=message.from_account_id,
|
||||
)
|
||||
db_session_with_containers.add(thought_with_files)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Execute the method under test
|
||||
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
|
||||
result = AgentService.get_agent_logs(app, conversation.id, message.id)
|
||||
|
||||
# Verify the result
|
||||
assert result is not None
|
||||
@ -926,7 +926,7 @@ class TestAgentService:
|
||||
mock_external_service_dependencies["current_user"].timezone = "Asia/Shanghai"
|
||||
|
||||
# Execute the method under test
|
||||
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
|
||||
result = AgentService.get_agent_logs(app, conversation.id, message.id)
|
||||
|
||||
# Verify the result
|
||||
assert result is not None
|
||||
@ -960,14 +960,14 @@ class TestAgentService:
|
||||
tool_input="", # Empty input
|
||||
observation="", # Empty observation
|
||||
tokens=50,
|
||||
created_by_role="account",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=message.from_account_id,
|
||||
)
|
||||
db_session_with_containers.add(empty_thought)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Execute the method under test
|
||||
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
|
||||
result = AgentService.get_agent_logs(app, conversation.id, message.id)
|
||||
|
||||
# Verify the result
|
||||
assert result is not None
|
||||
@ -1001,14 +1001,14 @@ class TestAgentService:
|
||||
tool_input="invalid json", # Malformed JSON
|
||||
observation="invalid json", # Malformed JSON
|
||||
tokens=50,
|
||||
created_by_role="account",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=message.from_account_id,
|
||||
)
|
||||
db_session_with_containers.add(malformed_thought)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Execute the method under test
|
||||
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
|
||||
result = AgentService.get_agent_logs(app, conversation.id, message.id)
|
||||
|
||||
# Verify the result - should handle malformed JSON gracefully
|
||||
assert result is not None
|
||||
|
||||
@ -198,7 +198,7 @@ class TestAppDslService:
|
||||
def test_check_version_compatibility_newer_version_returns_pending(self):
|
||||
assert _check_version_compatibility("99.0.0") == ImportStatus.PENDING
|
||||
|
||||
def test_check_version_compatibility_major_older_returns_pending(self, monkeypatch):
|
||||
def test_check_version_compatibility_major_older_returns_pending(self, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(app_dsl_service, "CURRENT_DSL_VERSION", "1.0.0")
|
||||
assert _check_version_compatibility("0.9.9") == ImportStatus.PENDING
|
||||
|
||||
@ -272,7 +272,9 @@ class TestAppDslService:
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert "Missing app data" in result.error
|
||||
|
||||
def test_import_app_yaml_error_returns_failed(self, db_session_with_containers: Session, monkeypatch):
|
||||
def test_import_app_yaml_error_returns_failed(
|
||||
self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
def bad_safe_load(_content: str):
|
||||
raise yaml.YAMLError("bad")
|
||||
|
||||
@ -287,7 +289,9 @@ class TestAppDslService:
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert result.error.startswith("Invalid YAML format:")
|
||||
|
||||
def test_import_app_unexpected_error_returns_failed(self, db_session_with_containers: Session, monkeypatch):
|
||||
def test_import_app_unexpected_error_returns_failed(
|
||||
self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
monkeypatch.setattr(
|
||||
AppDslService,
|
||||
"_create_or_update_app",
|
||||
@ -305,7 +309,9 @@ class TestAppDslService:
|
||||
|
||||
# ── Import: YAML URL ──────────────────────────────────────────────
|
||||
|
||||
def test_import_app_yaml_url_fetch_error_returns_failed(self, db_session_with_containers: Session, monkeypatch):
|
||||
def test_import_app_yaml_url_fetch_error_returns_failed(
|
||||
self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
monkeypatch.setattr(
|
||||
app_dsl_service.ssrf_proxy,
|
||||
"get",
|
||||
@ -321,7 +327,9 @@ class TestAppDslService:
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert "Error fetching YAML from URL: boom" in result.error
|
||||
|
||||
def test_import_app_yaml_url_empty_content_returns_failed(self, db_session_with_containers: Session, monkeypatch):
|
||||
def test_import_app_yaml_url_empty_content_returns_failed(
|
||||
self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
response = MagicMock()
|
||||
response.content = b""
|
||||
response.raise_for_status.return_value = None
|
||||
@ -336,7 +344,9 @@ class TestAppDslService:
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert "Empty content" in result.error
|
||||
|
||||
def test_import_app_yaml_url_file_too_large_returns_failed(self, db_session_with_containers: Session, monkeypatch):
|
||||
def test_import_app_yaml_url_file_too_large_returns_failed(
|
||||
self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
response = MagicMock()
|
||||
response.content = b"x" * (DSL_MAX_SIZE + 1)
|
||||
response.raise_for_status.return_value = None
|
||||
@ -379,7 +389,9 @@ class TestAppDslService:
|
||||
assert result.imported_dsl_version == "99.0.0"
|
||||
assert requested_urls == [yaml_url]
|
||||
|
||||
def test_import_app_yaml_url_github_blob_rewrites_to_raw(self, db_session_with_containers: Session, monkeypatch):
|
||||
def test_import_app_yaml_url_github_blob_rewrites_to_raw(
|
||||
self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
yaml_url = "https://github.com/acme/repo/blob/main/app.yml"
|
||||
raw_url = "https://raw.githubusercontent.com/acme/repo/main/app.yml"
|
||||
yaml_bytes = _pending_yaml_content()
|
||||
@ -491,7 +503,7 @@ class TestAppDslService:
|
||||
|
||||
@pytest.mark.parametrize("has_workflow", [True, False])
|
||||
def test_import_app_legacy_versions_extract_dependencies(
|
||||
self, db_session_with_containers: Session, monkeypatch, has_workflow: bool
|
||||
self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch, has_workflow: bool
|
||||
):
|
||||
monkeypatch.setattr(
|
||||
AppDslService,
|
||||
@ -554,7 +566,9 @@ class TestAppDslService:
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert "expired" in result.error
|
||||
|
||||
def test_confirm_import_success_deletes_redis_key(self, db_session_with_containers: Session, monkeypatch):
|
||||
def test_confirm_import_success_deletes_redis_key(
|
||||
self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
import_id = str(uuid4())
|
||||
redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}"
|
||||
|
||||
@ -614,7 +628,9 @@ class TestAppDslService:
|
||||
result = service.check_dependencies(app_model=app_model)
|
||||
assert result.leaked_dependencies == []
|
||||
|
||||
def test_check_dependencies_calls_analysis_service(self, db_session_with_containers: Session, monkeypatch):
|
||||
def test_check_dependencies_calls_analysis_service(
|
||||
self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
app_id = str(uuid4())
|
||||
pending = CheckDependenciesPendingData(dependencies=[], app_id=app_id)
|
||||
redis_client.setex(
|
||||
@ -665,7 +681,9 @@ class TestAppDslService:
|
||||
with pytest.raises(ValueError, match="loss app mode"):
|
||||
service._create_or_update_app(app=None, data={"app": {}}, account=_account_mock())
|
||||
|
||||
def test_create_or_update_app_existing_app_updates_fields(self, db_session_with_containers: Session, monkeypatch):
|
||||
def test_create_or_update_app_existing_app_updates_fields(
|
||||
self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
fixed_now = object()
|
||||
monkeypatch.setattr(app_dsl_service, "naive_utc_now", lambda: fixed_now)
|
||||
|
||||
@ -778,8 +796,8 @@ class TestAppDslService:
|
||||
service = AppDslService(db_session_with_containers)
|
||||
with pytest.raises(ValueError, match="Missing model_config"):
|
||||
service._create_or_update_app(
|
||||
app=_app_stub(mode=AppMode.CHAT.value),
|
||||
data={"app": {"mode": AppMode.CHAT.value}},
|
||||
app=_app_stub(mode=AppMode.CHAT),
|
||||
data={"app": {"mode": AppMode.CHAT}},
|
||||
account=_account_mock(),
|
||||
)
|
||||
|
||||
@ -794,7 +812,7 @@ class TestAppDslService:
|
||||
service._create_or_update_app(
|
||||
app=app,
|
||||
data={
|
||||
"app": {"mode": AppMode.CHAT.value},
|
||||
"app": {"mode": AppMode.CHAT},
|
||||
"model_config": {"model": {"provider": "openai"}},
|
||||
},
|
||||
account=account,
|
||||
@ -807,14 +825,14 @@ class TestAppDslService:
|
||||
service = AppDslService(db_session_with_containers)
|
||||
with pytest.raises(ValueError, match="Invalid app mode"):
|
||||
service._create_or_update_app(
|
||||
app=_app_stub(mode=AppMode.RAG_PIPELINE.value),
|
||||
data={"app": {"mode": AppMode.RAG_PIPELINE.value}},
|
||||
app=_app_stub(mode=AppMode.RAG_PIPELINE),
|
||||
data={"app": {"mode": AppMode.RAG_PIPELINE}},
|
||||
account=_account_mock(),
|
||||
)
|
||||
|
||||
# ── Export ─────────────────────────────────────────────────────────
|
||||
|
||||
def test_export_dsl_delegates_by_mode(self, monkeypatch):
|
||||
def test_export_dsl_delegates_by_mode(self, monkeypatch: pytest.MonkeyPatch):
|
||||
workflow_calls: list[bool] = []
|
||||
model_calls: list[bool] = []
|
||||
monkeypatch.setattr(
|
||||
@ -836,14 +854,14 @@ class TestAppDslService:
|
||||
assert workflow_calls == [True]
|
||||
|
||||
chat_app = _app_stub(
|
||||
mode=AppMode.CHAT.value,
|
||||
mode=AppMode.CHAT,
|
||||
icon_type="emoji",
|
||||
app_model_config=SimpleNamespace(to_dict=lambda: {"agent_mode": {"tools": []}}),
|
||||
)
|
||||
AppDslService.export_dsl(chat_app)
|
||||
assert model_calls == [True]
|
||||
|
||||
def test_export_dsl_preserves_icon_and_icon_type(self, monkeypatch):
|
||||
def test_export_dsl_preserves_icon_and_icon_type(self, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(
|
||||
AppDslService,
|
||||
"_append_workflow_export_data",
|
||||
@ -1011,7 +1029,7 @@ class TestAppDslService:
|
||||
|
||||
# ── Workflow Export Data ───────────────────────────────────────────
|
||||
|
||||
def test_append_workflow_export_data_filters_and_overrides(self, monkeypatch):
|
||||
def test_append_workflow_export_data_filters_and_overrides(self, monkeypatch: pytest.MonkeyPatch):
|
||||
workflow_dict = {
|
||||
"graph": {
|
||||
"nodes": [
|
||||
@ -1111,7 +1129,7 @@ class TestAppDslService:
|
||||
assert nodes[5]["data"]["subscription_id"] == ""
|
||||
assert export_data["dependencies"] == [{"tenant": _DEFAULT_TENANT_ID, "dep": "dep-1"}]
|
||||
|
||||
def test_append_workflow_export_data_missing_workflow_raises(self, monkeypatch):
|
||||
def test_append_workflow_export_data_missing_workflow_raises(self, monkeypatch: pytest.MonkeyPatch):
|
||||
workflow_service = MagicMock()
|
||||
workflow_service.get_draft_workflow.return_value = None
|
||||
monkeypatch.setattr(app_dsl_service, "WorkflowService", lambda: workflow_service)
|
||||
@ -1126,7 +1144,7 @@ class TestAppDslService:
|
||||
|
||||
# ── Model Config Export Data ──────────────────────────────────────
|
||||
|
||||
def test_append_model_config_export_data_filters_credential_id(self, monkeypatch):
|
||||
def test_append_model_config_export_data_filters_credential_id(self, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(
|
||||
AppDslService,
|
||||
"_extract_dependencies_from_model_config",
|
||||
@ -1160,7 +1178,7 @@ class TestAppDslService:
|
||||
|
||||
# ── Dependency Extraction ─────────────────────────────────────────
|
||||
|
||||
def test_extract_dependencies_from_workflow_graph_covers_all_node_types(self, monkeypatch):
|
||||
def test_extract_dependencies_from_workflow_graph_covers_all_node_types(self, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(
|
||||
app_dsl_service.DependenciesAnalysisService,
|
||||
"analyze_tool_dependency",
|
||||
@ -1230,7 +1248,7 @@ class TestAppDslService:
|
||||
"model:m4",
|
||||
]
|
||||
|
||||
def test_extract_dependencies_from_workflow_graph_handles_exceptions(self, monkeypatch):
|
||||
def test_extract_dependencies_from_workflow_graph_handles_exceptions(self, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(
|
||||
app_dsl_service.ToolNodeData,
|
||||
"model_validate",
|
||||
@ -1241,7 +1259,7 @@ class TestAppDslService:
|
||||
)
|
||||
assert deps == []
|
||||
|
||||
def test_extract_dependencies_from_model_config_parses_providers(self, monkeypatch):
|
||||
def test_extract_dependencies_from_model_config_parses_providers(self, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(
|
||||
app_dsl_service.DependenciesAnalysisService,
|
||||
"analyze_model_provider_dependency",
|
||||
@ -1264,7 +1282,7 @@ class TestAppDslService:
|
||||
)
|
||||
assert deps == ["model:p1", "model:p2", "tool:t1"]
|
||||
|
||||
def test_extract_dependencies_from_model_config_handles_exceptions(self, monkeypatch):
|
||||
def test_extract_dependencies_from_model_config_handles_exceptions(self, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(
|
||||
app_dsl_service.DependenciesAnalysisService,
|
||||
"analyze_model_provider_dependency",
|
||||
@ -1278,7 +1296,7 @@ class TestAppDslService:
|
||||
def test_get_leaked_dependencies_empty_returns_empty(self):
|
||||
assert AppDslService.get_leaked_dependencies(_DEFAULT_TENANT_ID, []) == []
|
||||
|
||||
def test_get_leaked_dependencies_delegates(self, monkeypatch):
|
||||
def test_get_leaked_dependencies_delegates(self, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(
|
||||
app_dsl_service.DependenciesAnalysisService,
|
||||
"get_leaked_dependencies",
|
||||
@ -1289,7 +1307,7 @@ class TestAppDslService:
|
||||
|
||||
# ── Encryption/Decryption ─────────────────────────────────────────
|
||||
|
||||
def test_encrypt_decrypt_dataset_id_respects_config(self, monkeypatch):
|
||||
def test_encrypt_decrypt_dataset_id_respects_config(self, monkeypatch: pytest.MonkeyPatch):
|
||||
tenant_id = _DEFAULT_TENANT_ID
|
||||
dataset_uuid = "00000000-0000-0000-0000-000000000000"
|
||||
|
||||
@ -1314,7 +1332,7 @@ class TestAppDslService:
|
||||
value = "00000000-0000-0000-0000-000000000000"
|
||||
assert AppDslService.decrypt_dataset_id(encrypted_data=value, tenant_id=_DEFAULT_TENANT_ID) == value
|
||||
|
||||
def test_decrypt_dataset_id_returns_none_on_invalid_data(self, monkeypatch):
|
||||
def test_decrypt_dataset_id_returns_none_on_invalid_data(self, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(
|
||||
app_dsl_service.dify_config,
|
||||
"DSL_EXPORT_ENCRYPT_DATASET_ID",
|
||||
@ -1322,7 +1340,7 @@ class TestAppDslService:
|
||||
)
|
||||
assert AppDslService.decrypt_dataset_id(encrypted_data="not-base64", tenant_id=_DEFAULT_TENANT_ID) is None
|
||||
|
||||
def test_decrypt_dataset_id_returns_none_when_decrypted_is_not_uuid(self, monkeypatch):
|
||||
def test_decrypt_dataset_id_returns_none_when_decrypted_is_not_uuid(self, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(
|
||||
app_dsl_service.dify_config,
|
||||
"DSL_EXPORT_ENCRYPT_DATASET_ID",
|
||||
|
||||
@ -6,6 +6,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from constants.model_template import default_app_templates
|
||||
from models import Account
|
||||
from models.enums import AppStatus, CustomizeTokenStrategy
|
||||
from models.model import App, IconType, Site
|
||||
from services.account_service import AccountService, TenantService
|
||||
from tests.test_containers_integration_tests.helpers import generate_valid_password
|
||||
@ -1079,9 +1080,9 @@ class TestAppService:
|
||||
site.app_id = app.id
|
||||
site.code = fake.postalcode()
|
||||
site.title = fake.company()
|
||||
site.status = "normal"
|
||||
site.status = AppStatus.NORMAL
|
||||
site.default_language = "en-US"
|
||||
site.customize_token_strategy = "uuid"
|
||||
site.customize_token_strategy = CustomizeTokenStrategy.UUID
|
||||
|
||||
db_session_with_containers.add(site)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
@ -10,6 +10,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models import TenantAccountRole
|
||||
from models.account import Account, Tenant, TenantAccountJoin
|
||||
from models.enums import ConversationFromSource
|
||||
from models.model import App, Conversation, EndUser, Message, MessageAnnotation
|
||||
@ -22,7 +23,7 @@ from services.message_service import MessageService
|
||||
|
||||
class ConversationServiceIntegrationTestDataFactory:
|
||||
@staticmethod
|
||||
def create_app_and_account(db_session_with_containers):
|
||||
def create_app_and_account(db_session_with_containers: Session):
|
||||
tenant = Tenant(name=f"Tenant {uuid4()}")
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.flush()
|
||||
@ -41,7 +42,7 @@ class ConversationServiceIntegrationTestDataFactory:
|
||||
tenant_join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role="owner",
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db_session_with_containers.add(tenant_join)
|
||||
@ -155,7 +156,7 @@ class ConversationServiceIntegrationTestDataFactory:
|
||||
total_price=Decimal(0),
|
||||
currency="USD",
|
||||
status="normal",
|
||||
invoke_from=InvokeFrom.WEB_APP.value,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
from_source=ConversationFromSource.API if isinstance(user, EndUser) else ConversationFromSource.CONSOLE,
|
||||
from_end_user_id=user.id if isinstance(user, EndUser) else None,
|
||||
from_account_id=user.id if isinstance(user, Account) else None,
|
||||
|
||||
@ -25,7 +25,7 @@ from services.errors.conversation import (
|
||||
|
||||
class ConversationServiceVariableIntegrationFactory:
|
||||
@staticmethod
|
||||
def create_app_and_account(db_session_with_containers):
|
||||
def create_app_and_account(db_session_with_containers: Session):
|
||||
tenant = Tenant(name=f"Tenant {uuid4()}")
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
@ -6,6 +6,7 @@ from unittest.mock import create_autospec, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
@ -119,13 +120,13 @@ def current_user_mock():
|
||||
yield current_user
|
||||
|
||||
|
||||
def test_get_document_returns_none_when_document_id_is_missing(db_session_with_containers):
|
||||
def test_get_document_returns_none_when_document_id_is_missing(db_session_with_containers: Session):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
|
||||
assert DocumentService.get_document(dataset.id, None) is None
|
||||
|
||||
|
||||
def test_get_document_queries_by_dataset_and_document_id(db_session_with_containers):
|
||||
def test_get_document_queries_by_dataset_and_document_id(db_session_with_containers: Session):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
document = DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset)
|
||||
|
||||
@ -135,7 +136,7 @@ def test_get_document_queries_by_dataset_and_document_id(db_session_with_contain
|
||||
assert result.id == document.id
|
||||
|
||||
|
||||
def test_get_documents_by_ids_returns_empty_for_empty_input(db_session_with_containers):
|
||||
def test_get_documents_by_ids_returns_empty_for_empty_input(db_session_with_containers: Session):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
|
||||
result = DocumentService.get_documents_by_ids(dataset.id, [])
|
||||
@ -143,7 +144,7 @@ def test_get_documents_by_ids_returns_empty_for_empty_input(db_session_with_cont
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_get_documents_by_ids_uses_single_batch_query(db_session_with_containers):
|
||||
def test_get_documents_by_ids_uses_single_batch_query(db_session_with_containers: Session):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
doc_a = DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset, name="a.txt")
|
||||
doc_b = DocumentServiceIntegrationFactory.create_document(
|
||||
@ -158,13 +159,13 @@ def test_get_documents_by_ids_uses_single_batch_query(db_session_with_containers
|
||||
assert {document.id for document in result} == {doc_a.id, doc_b.id}
|
||||
|
||||
|
||||
def test_update_documents_need_summary_returns_zero_for_empty_input(db_session_with_containers):
|
||||
def test_update_documents_need_summary_returns_zero_for_empty_input(db_session_with_containers: Session):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
|
||||
assert DocumentService.update_documents_need_summary(dataset.id, []) == 0
|
||||
|
||||
|
||||
def test_update_documents_need_summary_updates_matching_non_qa_documents(db_session_with_containers):
|
||||
def test_update_documents_need_summary_updates_matching_non_qa_documents(db_session_with_containers: Session):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
paragraph_doc = DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
@ -195,7 +196,7 @@ def test_update_documents_need_summary_updates_matching_non_qa_documents(db_sess
|
||||
assert refreshed_qa.need_summary is True
|
||||
|
||||
|
||||
def test_get_document_download_url_uses_signed_url_helper(db_session_with_containers):
|
||||
def test_get_document_download_url_uses_signed_url_helper(db_session_with_containers: Session):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
upload_file = DocumentServiceIntegrationFactory.create_upload_file(
|
||||
db_session_with_containers,
|
||||
@ -215,7 +216,7 @@ def test_get_document_download_url_uses_signed_url_helper(db_session_with_contai
|
||||
get_url.assert_called_once_with(upload_file_id=upload_file.id, as_attachment=True)
|
||||
|
||||
|
||||
def test_get_upload_file_id_for_upload_file_document_rejects_invalid_source_type(db_session_with_containers):
|
||||
def test_get_upload_file_id_for_upload_file_document_rejects_invalid_source_type(db_session_with_containers: Session):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
document = DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
@ -232,7 +233,9 @@ def test_get_upload_file_id_for_upload_file_document_rejects_invalid_source_type
|
||||
)
|
||||
|
||||
|
||||
def test_get_upload_file_id_for_upload_file_document_rejects_missing_upload_file_id(db_session_with_containers):
|
||||
def test_get_upload_file_id_for_upload_file_document_rejects_missing_upload_file_id(
|
||||
db_session_with_containers: Session,
|
||||
):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
document = DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
@ -248,7 +251,7 @@ def test_get_upload_file_id_for_upload_file_document_rejects_missing_upload_file
|
||||
)
|
||||
|
||||
|
||||
def test_get_upload_file_id_for_upload_file_document_returns_string_id(db_session_with_containers):
|
||||
def test_get_upload_file_id_for_upload_file_document_returns_string_id(db_session_with_containers: Session):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
document = DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
@ -265,7 +268,9 @@ def test_get_upload_file_id_for_upload_file_document_returns_string_id(db_sessio
|
||||
assert result == "99"
|
||||
|
||||
|
||||
def test_get_upload_file_for_upload_file_document_raises_when_file_service_returns_nothing(db_session_with_containers):
|
||||
def test_get_upload_file_for_upload_file_document_raises_when_file_service_returns_nothing(
|
||||
db_session_with_containers: Session,
|
||||
):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
document = DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
@ -278,7 +283,7 @@ def test_get_upload_file_for_upload_file_document_raises_when_file_service_retur
|
||||
DocumentService._get_upload_file_for_upload_file_document(document)
|
||||
|
||||
|
||||
def test_get_upload_file_for_upload_file_document_returns_upload_file(db_session_with_containers):
|
||||
def test_get_upload_file_for_upload_file_document_returns_upload_file(db_session_with_containers: Session):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
upload_file = DocumentServiceIntegrationFactory.create_upload_file(
|
||||
db_session_with_containers,
|
||||
@ -296,7 +301,9 @@ def test_get_upload_file_for_upload_file_document_returns_upload_file(db_session
|
||||
assert result.id == upload_file.id
|
||||
|
||||
|
||||
def test_get_upload_files_by_document_id_for_zip_download_raises_for_missing_documents(db_session_with_containers):
|
||||
def test_get_upload_files_by_document_id_for_zip_download_raises_for_missing_documents(
|
||||
db_session_with_containers: Session,
|
||||
):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
|
||||
with pytest.raises(NotFound, match="Document not found"):
|
||||
@ -307,7 +314,9 @@ def test_get_upload_files_by_document_id_for_zip_download_raises_for_missing_doc
|
||||
)
|
||||
|
||||
|
||||
def test_get_upload_files_by_document_id_for_zip_download_rejects_cross_tenant_access(db_session_with_containers):
|
||||
def test_get_upload_files_by_document_id_for_zip_download_rejects_cross_tenant_access(
|
||||
db_session_with_containers: Session,
|
||||
):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
upload_file = DocumentServiceIntegrationFactory.create_upload_file(
|
||||
db_session_with_containers,
|
||||
@ -329,7 +338,9 @@ def test_get_upload_files_by_document_id_for_zip_download_rejects_cross_tenant_a
|
||||
)
|
||||
|
||||
|
||||
def test_get_upload_files_by_document_id_for_zip_download_rejects_missing_upload_files(db_session_with_containers):
|
||||
def test_get_upload_files_by_document_id_for_zip_download_rejects_missing_upload_files(
|
||||
db_session_with_containers: Session,
|
||||
):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
document = DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
@ -345,7 +356,9 @@ def test_get_upload_files_by_document_id_for_zip_download_rejects_missing_upload
|
||||
)
|
||||
|
||||
|
||||
def test_get_upload_files_by_document_id_for_zip_download_returns_document_keyed_mapping(db_session_with_containers):
|
||||
def test_get_upload_files_by_document_id_for_zip_download_returns_document_keyed_mapping(
|
||||
db_session_with_containers: Session,
|
||||
):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
upload_file_a = DocumentServiceIntegrationFactory.create_upload_file(
|
||||
db_session_with_containers,
|
||||
@ -395,7 +408,7 @@ def test_prepare_document_batch_download_zip_raises_not_found_for_missing_datase
|
||||
|
||||
|
||||
def test_prepare_document_batch_download_zip_translates_permission_error_to_forbidden(
|
||||
db_session_with_containers,
|
||||
db_session_with_containers: Session,
|
||||
current_user_mock,
|
||||
):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(
|
||||
@ -418,7 +431,7 @@ def test_prepare_document_batch_download_zip_translates_permission_error_to_forb
|
||||
|
||||
|
||||
def test_prepare_document_batch_download_zip_returns_upload_files_in_requested_order(
|
||||
db_session_with_containers,
|
||||
db_session_with_containers: Session,
|
||||
current_user_mock,
|
||||
):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(
|
||||
@ -461,7 +474,7 @@ def test_prepare_document_batch_download_zip_returns_upload_files_in_requested_o
|
||||
assert download_name.endswith(".zip")
|
||||
|
||||
|
||||
def test_get_document_by_dataset_id_returns_enabled_documents(db_session_with_containers):
|
||||
def test_get_document_by_dataset_id_returns_enabled_documents(db_session_with_containers: Session):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
enabled_document = DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
@ -480,7 +493,9 @@ def test_get_document_by_dataset_id_returns_enabled_documents(db_session_with_co
|
||||
assert [document.id for document in result] == [enabled_document.id]
|
||||
|
||||
|
||||
def test_get_working_documents_by_dataset_id_returns_completed_enabled_unarchived_documents(db_session_with_containers):
|
||||
def test_get_working_documents_by_dataset_id_returns_completed_enabled_unarchived_documents(
|
||||
db_session_with_containers: Session,
|
||||
):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
available_document = DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
@ -501,7 +516,7 @@ def test_get_working_documents_by_dataset_id_returns_completed_enabled_unarchive
|
||||
assert [document.id for document in result] == [available_document.id]
|
||||
|
||||
|
||||
def test_get_error_documents_by_dataset_id_returns_error_and_paused_documents(db_session_with_containers):
|
||||
def test_get_error_documents_by_dataset_id_returns_error_and_paused_documents(db_session_with_containers: Session):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
error_document = DocumentServiceIntegrationFactory.create_document(
|
||||
db_session_with_containers,
|
||||
@ -526,7 +541,7 @@ def test_get_error_documents_by_dataset_id_returns_error_and_paused_documents(db
|
||||
assert {document.id for document in result} == {error_document.id, paused_document.id}
|
||||
|
||||
|
||||
def test_get_batch_documents_filters_by_current_user_tenant(db_session_with_containers):
|
||||
def test_get_batch_documents_filters_by_current_user_tenant(db_session_with_containers: Session):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
batch = f"batch-{uuid4()}"
|
||||
matching_document = DocumentServiceIntegrationFactory.create_document(
|
||||
@ -549,7 +564,7 @@ def test_get_batch_documents_filters_by_current_user_tenant(db_session_with_cont
|
||||
assert [document.id for document in result] == [matching_document.id]
|
||||
|
||||
|
||||
def test_get_document_file_detail_returns_upload_file(db_session_with_containers):
|
||||
def test_get_document_file_detail_returns_upload_file(db_session_with_containers: Session):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
upload_file = DocumentServiceIntegrationFactory.create_upload_file(
|
||||
db_session_with_containers,
|
||||
@ -563,7 +578,7 @@ def test_get_document_file_detail_returns_upload_file(db_session_with_containers
|
||||
assert result.id == upload_file.id
|
||||
|
||||
|
||||
def test_delete_document_emits_signal_and_commits(db_session_with_containers):
|
||||
def test_delete_document_emits_signal_and_commits(db_session_with_containers: Session):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
upload_file = DocumentServiceIntegrationFactory.create_upload_file(
|
||||
db_session_with_containers,
|
||||
@ -588,7 +603,7 @@ def test_delete_document_emits_signal_and_commits(db_session_with_containers):
|
||||
)
|
||||
|
||||
|
||||
def test_delete_documents_ignores_empty_input(db_session_with_containers):
|
||||
def test_delete_documents_ignores_empty_input(db_session_with_containers: Session):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
|
||||
with patch("services.dataset_service.batch_clean_document_task.delay") as delay:
|
||||
@ -597,7 +612,7 @@ def test_delete_documents_ignores_empty_input(db_session_with_containers):
|
||||
delay.assert_not_called()
|
||||
|
||||
|
||||
def test_delete_documents_deletes_rows_and_dispatches_cleanup_task(db_session_with_containers):
|
||||
def test_delete_documents_deletes_rows_and_dispatches_cleanup_task(db_session_with_containers: Session):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
dataset.chunk_structure = IndexStructureType.PARAGRAPH_INDEX
|
||||
db_session_with_containers.commit()
|
||||
@ -637,14 +652,14 @@ def test_delete_documents_deletes_rows_and_dispatches_cleanup_task(db_session_wi
|
||||
assert set(args[3]) == {upload_file_a.id, upload_file_b.id}
|
||||
|
||||
|
||||
def test_get_documents_position_returns_next_position_when_documents_exist(db_session_with_containers):
|
||||
def test_get_documents_position_returns_next_position_when_documents_exist(db_session_with_containers: Session):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset, position=3)
|
||||
|
||||
assert DocumentService.get_documents_position(dataset.id) == 4
|
||||
|
||||
|
||||
def test_get_documents_position_defaults_to_one_when_dataset_is_empty(db_session_with_containers):
|
||||
def test_get_documents_position_defaults_to_one_when_dataset_is_empty(db_session_with_containers: Session):
|
||||
dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
|
||||
|
||||
assert DocumentService.get_documents_position(dataset.id) == 1
|
||||
|
||||
@ -2,6 +2,7 @@ import datetime
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from models.dataset import Dataset, Document
|
||||
@ -58,7 +59,7 @@ def _create_document(
|
||||
return document
|
||||
|
||||
|
||||
def test_build_display_status_filters_available(db_session_with_containers):
|
||||
def test_build_display_status_filters_available(db_session_with_containers: Session):
|
||||
dataset = _create_dataset(db_session_with_containers)
|
||||
available_doc = _create_document(
|
||||
db_session_with_containers,
|
||||
@ -97,7 +98,7 @@ def test_build_display_status_filters_available(db_session_with_containers):
|
||||
assert [row.id for row in rows] == [available_doc.id]
|
||||
|
||||
|
||||
def test_apply_display_status_filter_applies_when_status_present(db_session_with_containers):
|
||||
def test_apply_display_status_filter_applies_when_status_present(db_session_with_containers: Session):
|
||||
dataset = _create_dataset(db_session_with_containers)
|
||||
waiting_doc = _create_document(
|
||||
db_session_with_containers,
|
||||
@ -121,7 +122,7 @@ def test_apply_display_status_filter_applies_when_status_present(db_session_with
|
||||
assert [row.id for row in rows] == [waiting_doc.id]
|
||||
|
||||
|
||||
def test_apply_display_status_filter_returns_same_when_invalid(db_session_with_containers):
|
||||
def test_apply_display_status_filter_returns_same_when_invalid(db_session_with_containers: Session):
|
||||
dataset = _create_dataset(db_session_with_containers)
|
||||
doc1 = _create_document(
|
||||
db_session_with_containers,
|
||||
|
||||
@ -7,6 +7,7 @@ import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models import TenantAccountRole
|
||||
from models.account import Account, Tenant, TenantAccountJoin
|
||||
from models.model import App, DefaultEndUserSessionID, EndUser
|
||||
from services.end_user_service import EndUserService
|
||||
@ -16,7 +17,7 @@ class TestEndUserServiceFactory:
|
||||
"""Factory class for creating test data and mock objects for end user service tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_app_and_account(db_session_with_containers):
|
||||
def create_app_and_account(db_session_with_containers: Session):
|
||||
tenant = Tenant(name=f"Tenant {uuid4()}")
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.flush()
|
||||
@ -35,7 +36,7 @@ class TestEndUserServiceFactory:
|
||||
tenant_join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role="owner",
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db_session_with_containers.add(tenant_join)
|
||||
|
||||
@ -644,7 +644,7 @@ class TestFeatureService:
|
||||
assert result.max_plugin_package_size == 15728640
|
||||
|
||||
# Verify default license status
|
||||
assert result.license.status.value == "none"
|
||||
assert result.license.status == "none"
|
||||
assert result.license.expired_at == ""
|
||||
assert result.license.workspaces.enabled is False
|
||||
|
||||
|
||||
@ -23,7 +23,7 @@ class TestFeedbackService:
|
||||
"""Test FeedbackService methods."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session(self, monkeypatch):
|
||||
def mock_db_session(self, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Mock database session."""
|
||||
mock_session = mock.Mock()
|
||||
monkeypatch.setattr(db, "session", mock_session)
|
||||
|
||||
@ -122,7 +122,7 @@ class TestEmailDeliveryTestHandler:
|
||||
with pytest.raises(DeliveryTestUnsupportedError):
|
||||
handler.send_test(context=MagicMock(), method=MagicMock())
|
||||
|
||||
def test_send_test_feature_disabled(self, monkeypatch):
|
||||
def test_send_test_feature_disabled(self, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(
|
||||
service_module.FeatureService,
|
||||
"get_features",
|
||||
@ -137,7 +137,7 @@ class TestEmailDeliveryTestHandler:
|
||||
with pytest.raises(DeliveryTestError, match="Email delivery is not available"):
|
||||
handler.send_test(context=context, method=method)
|
||||
|
||||
def test_send_test_mail_not_inited(self, monkeypatch):
|
||||
def test_send_test_mail_not_inited(self, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(
|
||||
service_module.FeatureService,
|
||||
"get_features",
|
||||
@ -154,7 +154,7 @@ class TestEmailDeliveryTestHandler:
|
||||
with pytest.raises(DeliveryTestError, match="Mail client is not initialized."):
|
||||
handler.send_test(context=context, method=method)
|
||||
|
||||
def test_send_test_no_recipients(self, monkeypatch):
|
||||
def test_send_test_no_recipients(self, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(
|
||||
service_module.FeatureService,
|
||||
"get_features",
|
||||
@ -173,7 +173,7 @@ class TestEmailDeliveryTestHandler:
|
||||
with pytest.raises(DeliveryTestError, match="No recipients configured"):
|
||||
handler.send_test(context=context, method=method)
|
||||
|
||||
def test_send_test_success(self, monkeypatch):
|
||||
def test_send_test_success(self, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(
|
||||
service_module.FeatureService,
|
||||
"get_features",
|
||||
@ -209,7 +209,7 @@ class TestEmailDeliveryTestHandler:
|
||||
assert kwargs["to"] == "test@example.com"
|
||||
assert "RENDERED_Subj" in kwargs["subject"]
|
||||
|
||||
def test_send_test_sanitizes_subject(self, monkeypatch):
|
||||
def test_send_test_sanitizes_subject(self, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(
|
||||
service_module.FeatureService,
|
||||
"get_features",
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from services.message_service import MessageService
|
||||
from tests.test_containers_integration_tests.helpers.execution_extra_content import (
|
||||
@ -9,7 +10,7 @@ from tests.test_containers_integration_tests.helpers.execution_extra_content imp
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("flask_req_ctx_with_containers")
|
||||
def test_pagination_returns_extra_contents(db_session_with_containers):
|
||||
def test_pagination_returns_extra_contents(db_session_with_containers: Session):
|
||||
fixture = create_human_input_message_fixture(db_session_with_containers)
|
||||
|
||||
pagination = MessageService.pagination_by_first_id(
|
||||
|
||||
@ -16,7 +16,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||
from extensions.ext_redis import redis_client
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
|
||||
from tasks.create_segment_to_index_task import create_segment_to_index_task
|
||||
@ -73,7 +73,7 @@ class TestCreateSegmentToIndexTask:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
status=AccountStatus.ACTIVE,
|
||||
)
|
||||
|
||||
db_session_with_containers.add(account)
|
||||
@ -82,7 +82,7 @@ class TestCreateSegmentToIndexTask:
|
||||
# Create tenant
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
status=TenantStatus.NORMAL,
|
||||
plan="basic",
|
||||
)
|
||||
db_session_with_containers.add(tenant)
|
||||
|
||||
@ -12,7 +12,7 @@ from sqlalchemy.orm import Session
|
||||
from core.indexing_runner import DocumentIsPausedError
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus
|
||||
from models.dataset import Dataset, Document
|
||||
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus
|
||||
from tasks.document_indexing_task import (
|
||||
@ -54,7 +54,7 @@ class _TrackedSessionContext:
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _ensure_testcontainers_db(db_session_with_containers):
|
||||
def _ensure_testcontainers_db(db_session_with_containers: Session):
|
||||
"""Ensure this suite always runs on testcontainers infrastructure."""
|
||||
return db_session_with_containers
|
||||
|
||||
@ -121,12 +121,12 @@ class TestDatasetIndexingTaskIntegration:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
status=AccountStatus.ACTIVE,
|
||||
)
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
tenant = Tenant(name=fake.company(), status="normal")
|
||||
tenant = Tenant(name=fake.company(), status=TenantStatus.NORMAL)
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from libs.email_i18n import EmailType
|
||||
from models import TenantStatus
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from tasks.mail_account_deletion_task import send_account_deletion_verification_code, send_deletion_success_task
|
||||
|
||||
@ -55,7 +56,7 @@ class TestMailAccountDeletionTask:
|
||||
# Create tenant
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
status=TenantStatus.NORMAL,
|
||||
)
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
@ -18,6 +18,7 @@ from sqlalchemy import delete
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from libs.email_i18n import EmailType
|
||||
from models import AccountStatus, TenantStatus
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from tasks.mail_email_code_login import send_email_code_login_mail_task
|
||||
|
||||
@ -91,7 +92,7 @@ class TestSendEmailCodeLoginMailTask:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
status=AccountStatus.ACTIVE,
|
||||
)
|
||||
|
||||
db_session_with_containers.add(account)
|
||||
@ -120,7 +121,7 @@ class TestSendEmailCodeLoginMailTask:
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
plan="basic",
|
||||
status="normal",
|
||||
status=TenantStatus.NORMAL,
|
||||
)
|
||||
|
||||
db_session_with_containers.add(tenant)
|
||||
|
||||
@ -31,7 +31,7 @@ from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_database(db_session_with_containers):
|
||||
def cleanup_database(db_session_with_containers: Session):
|
||||
db_session_with_containers.execute(delete(HumanInputFormRecipient))
|
||||
db_session_with_containers.execute(delete(HumanInputDelivery))
|
||||
db_session_with_containers.execute(delete(HumanInputForm))
|
||||
@ -43,7 +43,7 @@ def cleanup_database(db_session_with_containers):
|
||||
db_session_with_containers.commit()
|
||||
|
||||
|
||||
def _create_workspace_member(db_session_with_containers):
|
||||
def _create_workspace_member(db_session_with_containers: Session):
|
||||
account = Account(
|
||||
email="owner@example.com",
|
||||
name="Owner",
|
||||
|
||||
@ -21,7 +21,7 @@ from tasks.remove_app_and_related_data_task import (
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_database(db_session_with_containers):
|
||||
def cleanup_database(db_session_with_containers: Session):
|
||||
db_session_with_containers.execute(delete(WorkflowDraftVariable))
|
||||
db_session_with_containers.execute(delete(WorkflowDraftVariableFile))
|
||||
db_session_with_containers.execute(delete(UploadFile))
|
||||
@ -30,7 +30,7 @@ def cleanup_database(db_session_with_containers):
|
||||
db_session_with_containers.commit()
|
||||
|
||||
|
||||
def _create_tenant_and_app(db_session_with_containers):
|
||||
def _create_tenant_and_app(db_session_with_containers: Session):
|
||||
tenant = Tenant(name=f"test_tenant_{uuid.uuid4()}")
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
@ -0,0 +1,103 @@
|
||||
"""Unit tests for the Markdown API docs generator."""
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _load_generate_swagger_markdown_docs_module():
|
||||
api_dir = Path(__file__).resolve().parents[3]
|
||||
script_path = api_dir / "dev" / "generate_swagger_markdown_docs.py"
|
||||
|
||||
spec = importlib.util.spec_from_file_location("generate_swagger_markdown_docs", script_path)
|
||||
assert spec
|
||||
assert spec.loader
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[spec.name] = module
|
||||
spec.loader.exec_module(module) # type: ignore[attr-defined]
|
||||
return module
|
||||
|
||||
|
||||
def test_generate_markdown_docs_keeps_split_docs_and_merges_fastopenapi_into_console(tmp_path, monkeypatch):
|
||||
module = _load_generate_swagger_markdown_docs_module()
|
||||
swagger_dir = tmp_path / "openapi"
|
||||
markdown_dir = tmp_path / "markdown"
|
||||
stale_combined_doc = markdown_dir / "api-reference.md"
|
||||
markdown_dir.mkdir()
|
||||
stale_combined_doc.write_text("stale", encoding="utf-8")
|
||||
|
||||
def write_specs(output_dir: Path) -> list[Path]:
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
paths = []
|
||||
for target in module.SPEC_TARGETS:
|
||||
path = output_dir / target.filename
|
||||
path.write_text("{}", encoding="utf-8")
|
||||
paths.append(path)
|
||||
return paths
|
||||
|
||||
def write_fastopenapi_specs(output_dir: Path) -> list[Path]:
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = output_dir / module.FASTOPENAPI_SPEC_TARGETS[0].filename
|
||||
path.write_text("{}", encoding="utf-8")
|
||||
return [path]
|
||||
|
||||
def convert_spec_to_markdown(spec_path: Path, markdown_path: Path) -> None:
|
||||
markdown_path.write_text(f"# {spec_path.stem}\n\n## Routes\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr(module, "generate_specs", write_specs)
|
||||
monkeypatch.setattr(module, "generate_fastopenapi_specs", write_fastopenapi_specs)
|
||||
monkeypatch.setattr(module, "_convert_spec_to_markdown", convert_spec_to_markdown)
|
||||
|
||||
written_paths = module.generate_markdown_docs(swagger_dir, markdown_dir)
|
||||
|
||||
assert [path.name for path in written_paths] == [
|
||||
"console-swagger.md",
|
||||
"web-swagger.md",
|
||||
"service-swagger.md",
|
||||
]
|
||||
assert not stale_combined_doc.exists()
|
||||
assert not list(swagger_dir.glob("*.json"))
|
||||
|
||||
console_markdown = (markdown_dir / "console-swagger.md").read_text(encoding="utf-8")
|
||||
assert "## FastOpenAPI Preview (OpenAPI 3.0)" in console_markdown
|
||||
assert "### fastopenapi-console-openapi" in console_markdown
|
||||
assert "#### Routes" in console_markdown
|
||||
assert "FastOpenAPI Preview" not in (markdown_dir / "web-swagger.md").read_text(encoding="utf-8")
|
||||
assert "FastOpenAPI Preview" not in (markdown_dir / "service-swagger.md").read_text(encoding="utf-8")
|
||||
|
||||
|
||||
def test_generate_markdown_docs_only_removes_generated_specs_from_separate_swagger_dir(tmp_path, monkeypatch):
|
||||
module = _load_generate_swagger_markdown_docs_module()
|
||||
swagger_dir = tmp_path / "swagger"
|
||||
markdown_dir = tmp_path / "markdown"
|
||||
swagger_dir.mkdir()
|
||||
existing_file = swagger_dir / "existing.txt"
|
||||
existing_file.write_text("keep me", encoding="utf-8")
|
||||
|
||||
def write_specs(output_dir: Path) -> list[Path]:
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
paths = []
|
||||
for target in module.SPEC_TARGETS:
|
||||
path = output_dir / target.filename
|
||||
path.write_text("{}", encoding="utf-8")
|
||||
paths.append(path)
|
||||
return paths
|
||||
|
||||
def write_fastopenapi_specs(output_dir: Path) -> list[Path]:
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = output_dir / module.FASTOPENAPI_SPEC_TARGETS[0].filename
|
||||
path.write_text("{}", encoding="utf-8")
|
||||
return [path]
|
||||
|
||||
def convert_spec_to_markdown(spec_path: Path, markdown_path: Path) -> None:
|
||||
markdown_path.write_text(f"# {spec_path.stem}\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr(module, "generate_specs", write_specs)
|
||||
monkeypatch.setattr(module, "generate_fastopenapi_specs", write_fastopenapi_specs)
|
||||
monkeypatch.setattr(module, "_convert_spec_to_markdown", convert_spec_to_markdown)
|
||||
|
||||
module.generate_markdown_docs(swagger_dir, markdown_dir)
|
||||
|
||||
assert existing_file.read_text(encoding="utf-8") == "keep me"
|
||||
assert not list(swagger_dir.glob("*.json"))
|
||||
@ -6,6 +6,16 @@ import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _walk_values(value):
|
||||
yield value
|
||||
if isinstance(value, dict):
|
||||
for child in value.values():
|
||||
yield from _walk_values(child)
|
||||
elif isinstance(value, list):
|
||||
for child in value:
|
||||
yield from _walk_values(child)
|
||||
|
||||
|
||||
def _load_generate_swagger_specs_module():
|
||||
api_dir = Path(__file__).resolve().parents[3]
|
||||
script_path = api_dir / "dev" / "generate_swagger_specs.py"
|
||||
@ -35,3 +45,32 @@ def test_generate_specs_writes_console_web_and_service_swagger_files(tmp_path):
|
||||
payload = json.loads(path.read_text(encoding="utf-8"))
|
||||
assert payload["swagger"] == "2.0"
|
||||
assert "paths" in payload
|
||||
|
||||
|
||||
def test_generate_specs_writes_swagger_with_resolvable_references_and_no_nulls(tmp_path):
|
||||
module = _load_generate_swagger_specs_module()
|
||||
|
||||
written_paths = module.generate_specs(tmp_path)
|
||||
|
||||
for path in written_paths:
|
||||
payload = json.loads(path.read_text(encoding="utf-8"))
|
||||
definitions = payload["definitions"]
|
||||
refs = {
|
||||
item["$ref"].removeprefix("#/definitions/")
|
||||
for item in _walk_values(payload)
|
||||
if isinstance(item, dict) and isinstance(item.get("$ref"), str)
|
||||
}
|
||||
|
||||
assert refs <= set(definitions)
|
||||
assert all(value is not None for value in _walk_values(payload))
|
||||
|
||||
|
||||
def test_generate_specs_is_idempotent(tmp_path):
|
||||
module = _load_generate_swagger_specs_module()
|
||||
|
||||
first_paths = module.generate_specs(tmp_path / "first")
|
||||
second_paths = module.generate_specs(tmp_path / "second")
|
||||
|
||||
assert [path.name for path in first_paths] == [path.name for path in second_paths]
|
||||
for first_path, second_path in zip(first_paths, second_paths):
|
||||
assert first_path.read_text(encoding="utf-8") == second_path.read_text(encoding="utf-8")
|
||||
|
||||
@ -57,7 +57,7 @@ class TestGuessFileInfoFromResponse:
|
||||
(False, "bin"),
|
||||
],
|
||||
)
|
||||
def test_generated_filename_when_missing(self, monkeypatch, magic_available, expected_ext):
|
||||
def test_generated_filename_when_missing(self, monkeypatch: pytest.MonkeyPatch, magic_available, expected_ext):
|
||||
if magic_available:
|
||||
if helpers.magic is None:
|
||||
pytest.skip("python-magic is not installed, cannot run 'magic_available=True' test variant")
|
||||
@ -155,7 +155,7 @@ class TestMagicImportWarnings:
|
||||
)
|
||||
def test_magic_import_warning_per_platform(
|
||||
self,
|
||||
monkeypatch,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
platform_name,
|
||||
expected_message,
|
||||
):
|
||||
|
||||
@ -17,6 +17,14 @@ class ProductModel(BaseModel):
|
||||
price: float
|
||||
|
||||
|
||||
class ChildModel(BaseModel):
|
||||
value: str
|
||||
|
||||
|
||||
class ParentModel(BaseModel):
|
||||
child: ChildModel
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_console_ns():
|
||||
"""Mock the console_ns to avoid circular imports during test collection."""
|
||||
@ -64,6 +72,22 @@ def test_register_schema_model_passes_schema_from_pydantic():
|
||||
assert schema == expected_schema
|
||||
|
||||
|
||||
def test_register_schema_model_promotes_nested_pydantic_definitions():
|
||||
from controllers.common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0, register_schema_model
|
||||
|
||||
namespace = MagicMock(spec=Namespace)
|
||||
|
||||
register_schema_model(namespace, ParentModel)
|
||||
|
||||
called_schemas = {call.args[0]: call.args[1] for call in namespace.schema_model.call_args_list}
|
||||
parent_schema = ParentModel.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
|
||||
assert set(called_schemas) == {"ParentModel", "ChildModel"}
|
||||
assert "$defs" not in called_schemas["ParentModel"]
|
||||
assert called_schemas["ParentModel"]["properties"]["child"]["$ref"] == "#/definitions/ChildModel"
|
||||
assert called_schemas["ChildModel"] == parent_schema["$defs"]["ChildModel"]
|
||||
|
||||
|
||||
def test_register_schema_models_registers_multiple_models():
|
||||
from controllers.common.schema import register_schema_models
|
||||
|
||||
@ -77,7 +101,7 @@ def test_register_schema_models_registers_multiple_models():
|
||||
assert called_names == ["UserModel", "ProductModel"]
|
||||
|
||||
|
||||
def test_register_schema_models_calls_register_schema_model(monkeypatch):
|
||||
def test_register_schema_models_calls_register_schema_model(monkeypatch: pytest.MonkeyPatch):
|
||||
from controllers.common.schema import register_schema_models
|
||||
|
||||
namespace = MagicMock(spec=Namespace)
|
||||
|
||||
@ -68,7 +68,7 @@ def _segment():
|
||||
)
|
||||
|
||||
|
||||
def test_get_segment_with_summary(monkeypatch):
|
||||
def test_get_segment_with_summary(monkeypatch: pytest.MonkeyPatch):
|
||||
segment = _segment()
|
||||
summary = SimpleNamespace(summary_content="summary")
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@ from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from pytest_mock import MockerFixture
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
@ -35,7 +36,7 @@ def dataset():
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def bypass_decorators(mocker):
|
||||
def bypass_decorators(mocker: MockerFixture):
|
||||
"""Bypass all decorators on the API method."""
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.hit_testing.setup_required",
|
||||
@ -56,7 +57,7 @@ def bypass_decorators(mocker):
|
||||
|
||||
|
||||
class TestHitTestingApi:
|
||||
def test_hit_testing_success(self, app, dataset, dataset_id):
|
||||
def test_hit_testing_success(self, app: Flask, dataset, dataset_id):
|
||||
api = HitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -99,7 +100,7 @@ class TestHitTestingApi:
|
||||
assert "records" in result
|
||||
assert result["records"] == []
|
||||
|
||||
def test_hit_testing_success_with_optional_record_fields(self, app, dataset, dataset_id):
|
||||
def test_hit_testing_success_with_optional_record_fields(self, app: Flask, dataset, dataset_id):
|
||||
api = HitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -150,7 +151,7 @@ class TestHitTestingApi:
|
||||
assert result["query"] == payload["query"]
|
||||
assert result["records"] == records
|
||||
|
||||
def test_hit_testing_dataset_not_found(self, app, dataset_id):
|
||||
def test_hit_testing_dataset_not_found(self, app: Flask, dataset_id):
|
||||
api = HitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -175,7 +176,7 @@ class TestHitTestingApi:
|
||||
with pytest.raises(NotFound, match="Dataset not found"):
|
||||
method(api, dataset_id)
|
||||
|
||||
def test_hit_testing_invalid_args(self, app, dataset, dataset_id):
|
||||
def test_hit_testing_invalid_args(self, app: Flask, dataset, dataset_id):
|
||||
api = HitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@ from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from pytest_mock import MockerFixture
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
@ -60,7 +61,7 @@ def metadata_id():
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def bypass_decorators(mocker):
|
||||
def bypass_decorators(mocker: MockerFixture):
|
||||
"""Bypass setup/login/license decorators."""
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.metadata.setup_required",
|
||||
|
||||
@ -2,6 +2,7 @@ from unittest.mock import Mock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.error import WebsiteCrawlError
|
||||
@ -31,7 +32,7 @@ def app():
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def bypass_auth_and_setup(mocker):
|
||||
def bypass_auth_and_setup(mocker: MockerFixture):
|
||||
"""Bypass setup/login/account decorators."""
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.website.login_required",
|
||||
@ -48,7 +49,7 @@ def bypass_auth_and_setup(mocker):
|
||||
|
||||
|
||||
class TestWebsiteCrawlApi:
|
||||
def test_crawl_success(self, app, mocker):
|
||||
def test_crawl_success(self, app, mocker: MockerFixture):
|
||||
api = WebsiteCrawlApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -85,7 +86,7 @@ class TestWebsiteCrawlApi:
|
||||
assert status == 200
|
||||
assert result["job_id"] == "job-1"
|
||||
|
||||
def test_crawl_invalid_payload(self, app, mocker):
|
||||
def test_crawl_invalid_payload(self, app, mocker: MockerFixture):
|
||||
api = WebsiteCrawlApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -113,7 +114,7 @@ class TestWebsiteCrawlApi:
|
||||
with pytest.raises(WebsiteCrawlError, match="invalid payload"):
|
||||
method(api)
|
||||
|
||||
def test_crawl_service_error(self, app, mocker):
|
||||
def test_crawl_service_error(self, app, mocker: MockerFixture):
|
||||
api = WebsiteCrawlApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -150,7 +151,7 @@ class TestWebsiteCrawlApi:
|
||||
|
||||
|
||||
class TestWebsiteCrawlStatusApi:
|
||||
def test_get_status_success(self, app, mocker):
|
||||
def test_get_status_success(self, app, mocker: MockerFixture):
|
||||
api = WebsiteCrawlStatusApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -181,7 +182,7 @@ class TestWebsiteCrawlStatusApi:
|
||||
assert status == 200
|
||||
assert result["status"] == "completed"
|
||||
|
||||
def test_get_status_invalid_provider(self, app, mocker):
|
||||
def test_get_status_invalid_provider(self, app, mocker: MockerFixture):
|
||||
api = WebsiteCrawlStatusApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -203,7 +204,7 @@ class TestWebsiteCrawlStatusApi:
|
||||
with pytest.raises(WebsiteCrawlError, match="invalid provider"):
|
||||
method(api, job_id)
|
||||
|
||||
def test_get_status_service_error(self, app, mocker):
|
||||
def test_get_status_service_error(self, app, mocker: MockerFixture):
|
||||
api = WebsiteCrawlStatusApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from controllers.console.datasets.error import PipelineNotFoundError
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
@ -16,7 +17,7 @@ class TestGetRagPipeline:
|
||||
with pytest.raises(ValueError, match="missing pipeline_id"):
|
||||
dummy_view()
|
||||
|
||||
def test_pipeline_not_found(self, mocker):
|
||||
def test_pipeline_not_found(self, mocker: MockerFixture):
|
||||
@get_rag_pipeline
|
||||
def dummy_view(**kwargs):
|
||||
return "ok"
|
||||
@ -34,7 +35,7 @@ class TestGetRagPipeline:
|
||||
with pytest.raises(PipelineNotFoundError):
|
||||
dummy_view(pipeline_id="pipeline-1")
|
||||
|
||||
def test_pipeline_found_and_injected(self, mocker):
|
||||
def test_pipeline_found_and_injected(self, mocker: MockerFixture):
|
||||
pipeline = Mock(spec=Pipeline)
|
||||
pipeline.id = "pipeline-1"
|
||||
pipeline.tenant_id = "tenant-1"
|
||||
@ -57,7 +58,7 @@ class TestGetRagPipeline:
|
||||
|
||||
assert result is pipeline
|
||||
|
||||
def test_pipeline_id_removed_from_kwargs(self, mocker):
|
||||
def test_pipeline_id_removed_from_kwargs(self, mocker: MockerFixture):
|
||||
pipeline = Mock(spec=Pipeline)
|
||||
|
||||
@get_rag_pipeline
|
||||
@ -79,7 +80,7 @@ class TestGetRagPipeline:
|
||||
|
||||
assert result == "ok"
|
||||
|
||||
def test_pipeline_id_cast_to_string(self, mocker):
|
||||
def test_pipeline_id_cast_to_string(self, mocker: MockerFixture):
|
||||
pipeline = Mock(spec=Pipeline)
|
||||
|
||||
@get_rag_pipeline
|
||||
|
||||
@ -4,6 +4,7 @@ import uuid
|
||||
from unittest.mock import Mock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from controllers.console.admin import (
|
||||
@ -18,7 +19,7 @@ from models.model import App, InstalledApp, RecommendedApp
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def bypass_only_edition_cloud(mocker):
|
||||
def bypass_only_edition_cloud(mocker: MockerFixture):
|
||||
"""
|
||||
Bypass only_edition_cloud decorator by setting EDITION to "CLOUD".
|
||||
"""
|
||||
@ -29,7 +30,7 @@ def bypass_only_edition_cloud(mocker):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_admin_auth(mocker):
|
||||
def mock_admin_auth(mocker: MockerFixture):
|
||||
"""
|
||||
Provide valid admin authentication for controller tests.
|
||||
"""
|
||||
@ -44,7 +45,7 @@ def mock_admin_auth(mocker):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_console_payload(mocker):
|
||||
def mock_console_payload(mocker: MockerFixture):
|
||||
payload = {
|
||||
"app_id": str(uuid.uuid4()),
|
||||
"language": "en-US",
|
||||
@ -62,7 +63,7 @@ def mock_console_payload(mocker):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_banner_payload(mocker):
|
||||
def mock_banner_payload(mocker: MockerFixture):
|
||||
mocker.patch(
|
||||
"flask_restx.namespace.Namespace.payload",
|
||||
new_callable=PropertyMock,
|
||||
@ -78,7 +79,7 @@ def mock_banner_payload(mocker):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_factory(mocker):
|
||||
def mock_session_factory(mocker: MockerFixture):
|
||||
mock_session = Mock()
|
||||
mock_session.execute = Mock()
|
||||
mock_session.add = Mock()
|
||||
@ -97,7 +98,7 @@ class TestDeleteExploreBannerApi:
|
||||
def setup_method(self):
|
||||
self.api = DeleteExploreBannerApi()
|
||||
|
||||
def test_delete_banner_not_found(self, mocker, mock_admin_auth):
|
||||
def test_delete_banner_not_found(self, mocker: MockerFixture, mock_admin_auth):
|
||||
mocker.patch(
|
||||
"controllers.console.admin.db.session.execute",
|
||||
return_value=Mock(scalar_one_or_none=lambda: None),
|
||||
@ -106,7 +107,7 @@ class TestDeleteExploreBannerApi:
|
||||
with pytest.raises(NotFound, match="is not found"):
|
||||
self.api.delete(uuid.uuid4())
|
||||
|
||||
def test_delete_banner_success(self, mocker, mock_admin_auth):
|
||||
def test_delete_banner_success(self, mocker: MockerFixture, mock_admin_auth):
|
||||
mock_banner = Mock()
|
||||
|
||||
mocker.patch(
|
||||
@ -126,7 +127,7 @@ class TestInsertExploreBannerApi:
|
||||
def setup_method(self):
|
||||
self.api = InsertExploreBannerApi()
|
||||
|
||||
def test_insert_banner_success(self, mocker, mock_admin_auth, mock_banner_payload):
|
||||
def test_insert_banner_success(self, mocker: MockerFixture, mock_admin_auth, mock_banner_payload):
|
||||
mocker.patch("controllers.console.admin.db.session.add")
|
||||
mocker.patch("controllers.console.admin.db.session.commit")
|
||||
|
||||
@ -168,7 +169,7 @@ class TestInsertExploreAppApiDelete:
|
||||
def setup_method(self):
|
||||
self.api = InsertExploreAppApi()
|
||||
|
||||
def test_delete_when_not_in_explore(self, mocker, mock_admin_auth):
|
||||
def test_delete_when_not_in_explore(self, mocker: MockerFixture, mock_admin_auth):
|
||||
mocker.patch(
|
||||
"controllers.console.admin.session_factory.create_session",
|
||||
return_value=Mock(
|
||||
@ -183,7 +184,7 @@ class TestInsertExploreAppApiDelete:
|
||||
assert status == 204
|
||||
assert response["result"] == "success"
|
||||
|
||||
def test_delete_when_in_explore_with_trial_app(self, mocker, mock_admin_auth):
|
||||
def test_delete_when_in_explore_with_trial_app(self, mocker: MockerFixture, mock_admin_auth):
|
||||
"""Test deleting an app from explore that has a trial app."""
|
||||
app_id = uuid.uuid4()
|
||||
|
||||
@ -225,7 +226,7 @@ class TestInsertExploreAppApiDelete:
|
||||
assert response["result"] == "success"
|
||||
assert mock_app.is_public is False
|
||||
|
||||
def test_delete_with_installed_apps(self, mocker, mock_admin_auth):
|
||||
def test_delete_with_installed_apps(self, mocker: MockerFixture, mock_admin_auth):
|
||||
"""Test deleting an app that has installed apps in other tenants."""
|
||||
app_id = uuid.uuid4()
|
||||
|
||||
@ -270,7 +271,7 @@ class TestInsertExploreAppListApi:
|
||||
def setup_method(self):
|
||||
self.api = InsertExploreAppListApi()
|
||||
|
||||
def test_app_not_found(self, mocker, mock_admin_auth, mock_console_payload):
|
||||
def test_app_not_found(self, mocker: MockerFixture, mock_admin_auth, mock_console_payload):
|
||||
mocker.patch(
|
||||
"controllers.console.admin.db.session.execute",
|
||||
return_value=Mock(scalar_one_or_none=lambda: None),
|
||||
@ -281,7 +282,7 @@ class TestInsertExploreAppListApi:
|
||||
|
||||
def test_create_recommended_app(
|
||||
self,
|
||||
mocker,
|
||||
mocker: MockerFixture,
|
||||
mock_admin_auth,
|
||||
mock_console_payload,
|
||||
):
|
||||
@ -318,7 +319,9 @@ class TestInsertExploreAppListApi:
|
||||
assert response["result"] == "success"
|
||||
assert mock_app.is_public is True
|
||||
|
||||
def test_update_recommended_app(self, mocker, mock_admin_auth, mock_console_payload, mock_session_factory):
|
||||
def test_update_recommended_app(
|
||||
self, mocker: MockerFixture, mock_admin_auth, mock_console_payload, mock_session_factory
|
||||
):
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.id = "app-id"
|
||||
mock_app.site = None
|
||||
@ -344,7 +347,7 @@ class TestInsertExploreAppListApi:
|
||||
|
||||
def test_site_data_overrides_payload(
|
||||
self,
|
||||
mocker,
|
||||
mocker: MockerFixture,
|
||||
mock_admin_auth,
|
||||
mock_console_payload,
|
||||
mock_session_factory,
|
||||
@ -381,7 +384,7 @@ class TestInsertExploreAppListApi:
|
||||
|
||||
def test_create_trial_app_when_can_trial_enabled(
|
||||
self,
|
||||
mocker,
|
||||
mocker: MockerFixture,
|
||||
mock_admin_auth,
|
||||
mock_console_payload,
|
||||
mock_session_factory,
|
||||
@ -413,7 +416,7 @@ class TestInsertExploreAppListApi:
|
||||
|
||||
def test_update_recommended_app_with_trial(
|
||||
self,
|
||||
mocker,
|
||||
mocker: MockerFixture,
|
||||
mock_admin_auth,
|
||||
mock_console_payload,
|
||||
mock_session_factory,
|
||||
@ -450,7 +453,7 @@ class TestInsertExploreAppListApi:
|
||||
|
||||
def test_update_recommended_app_without_trial(
|
||||
self,
|
||||
mocker,
|
||||
mocker: MockerFixture,
|
||||
mock_admin_auth,
|
||||
mock_console_payload,
|
||||
mock_session_factory,
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from pytest_mock import MockerFixture
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
|
||||
@ -11,7 +12,7 @@ def unwrap(func):
|
||||
|
||||
|
||||
class TestFeatureApi:
|
||||
def test_get_tenant_features_success(self, mocker):
|
||||
def test_get_tenant_features_success(self, mocker: MockerFixture):
|
||||
from controllers.console.feature import FeatureApi
|
||||
|
||||
mocker.patch(
|
||||
@ -32,7 +33,7 @@ class TestFeatureApi:
|
||||
|
||||
|
||||
class TestSystemFeatureApi:
|
||||
def test_get_system_features_authenticated(self, mocker):
|
||||
def test_get_system_features_authenticated(self, mocker: MockerFixture):
|
||||
"""
|
||||
current_user.is_authenticated == True
|
||||
"""
|
||||
@ -56,7 +57,7 @@ class TestSystemFeatureApi:
|
||||
|
||||
assert result == {"features": {"sys_feature": True}}
|
||||
|
||||
def test_get_system_features_unauthenticated(self, mocker):
|
||||
def test_get_system_features_unauthenticated(self, mocker: MockerFixture):
|
||||
"""
|
||||
current_user.is_authenticated raises Unauthorized
|
||||
"""
|
||||
|
||||
@ -32,7 +32,7 @@ class TestDefaultModelApi:
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
query_string={"model_type": ModelType.LLM.value},
|
||||
query_string={"model_type": ModelType.LLM},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
@ -53,7 +53,7 @@ class TestDefaultModelApi:
|
||||
payload = {
|
||||
"model_settings": [
|
||||
{
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model_type": ModelType.LLM,
|
||||
"provider": "openai",
|
||||
"model": "gpt-4",
|
||||
}
|
||||
@ -77,7 +77,7 @@ class TestDefaultModelApi:
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", query_string={"model_type": ModelType.LLM.value}),
|
||||
app.test_request_context("/", query_string={"model_type": ModelType.LLM}),
|
||||
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service,
|
||||
):
|
||||
@ -113,7 +113,7 @@ class TestModelProviderModelApi:
|
||||
|
||||
payload = {
|
||||
"model": "gpt-4",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model_type": ModelType.LLM,
|
||||
"load_balancing": {
|
||||
"configs": [{"weight": 1}],
|
||||
"enabled": True,
|
||||
@ -139,7 +139,7 @@ class TestModelProviderModelApi:
|
||||
|
||||
payload = {
|
||||
"model": "gpt-4",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model_type": ModelType.LLM,
|
||||
}
|
||||
|
||||
with (
|
||||
@ -180,7 +180,7 @@ class TestModelProviderModelCredentialApi:
|
||||
"/",
|
||||
query_string={
|
||||
"model": "gpt-4",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model_type": ModelType.LLM,
|
||||
},
|
||||
),
|
||||
patch(
|
||||
@ -208,7 +208,7 @@ class TestModelProviderModelCredentialApi:
|
||||
|
||||
payload = {
|
||||
"model": "gpt-4",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model_type": ModelType.LLM,
|
||||
"credentials": {"key": "val"},
|
||||
}
|
||||
|
||||
@ -229,7 +229,7 @@ class TestModelProviderModelCredentialApi:
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", query_string={"model": "gpt", "model_type": ModelType.LLM.value}),
|
||||
app.test_request_context("/", query_string={"model": "gpt", "model_type": ModelType.LLM}),
|
||||
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service,
|
||||
patch("controllers.console.workspace.models.ModelLoadBalancingService") as lb,
|
||||
@ -248,7 +248,7 @@ class TestModelProviderModelCredentialApi:
|
||||
|
||||
payload = {
|
||||
"model": "gpt",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model_type": ModelType.LLM,
|
||||
"credential_id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
}
|
||||
|
||||
@ -269,7 +269,7 @@ class TestModelProviderModelCredentialSwitchApi:
|
||||
|
||||
payload = {
|
||||
"model": "gpt-4",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model_type": ModelType.LLM,
|
||||
"credential_id": "abc",
|
||||
}
|
||||
|
||||
@ -293,7 +293,7 @@ class TestModelEnableDisableApis:
|
||||
|
||||
payload = {
|
||||
"model": "gpt-4",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model_type": ModelType.LLM,
|
||||
}
|
||||
|
||||
with (
|
||||
@ -314,7 +314,7 @@ class TestModelEnableDisableApis:
|
||||
|
||||
payload = {
|
||||
"model": "gpt-4",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model_type": ModelType.LLM,
|
||||
}
|
||||
|
||||
with (
|
||||
@ -337,7 +337,7 @@ class TestModelProviderModelValidateApi:
|
||||
|
||||
payload = {
|
||||
"model": "gpt-4",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model_type": ModelType.LLM,
|
||||
"credentials": {"key": "val"},
|
||||
}
|
||||
|
||||
@ -360,7 +360,7 @@ class TestModelProviderModelValidateApi:
|
||||
|
||||
payload = {
|
||||
"model": model_name,
|
||||
"model_type": ModelType.LLM.value,
|
||||
"model_type": ModelType.LLM,
|
||||
"credentials": {},
|
||||
}
|
||||
|
||||
@ -412,7 +412,7 @@ class TestParameterAndAvailableModels:
|
||||
):
|
||||
service_mock.return_value.get_models_by_model_type.return_value = []
|
||||
|
||||
result = method(api, ModelType.LLM.value)
|
||||
result = method(api, ModelType.LLM)
|
||||
|
||||
assert "data" in result
|
||||
|
||||
@ -442,6 +442,6 @@ class TestParameterAndAvailableModels:
|
||||
):
|
||||
service.return_value.get_models_by_model_type.return_value = []
|
||||
|
||||
result = method(api, ModelType.LLM.value)
|
||||
result = method(api, ModelType.LLM)
|
||||
|
||||
assert result["data"] == []
|
||||
|
||||
@ -189,7 +189,7 @@ class TestGetUserTenant:
|
||||
"""Test get_user_tenant decorator"""
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.Tenant")
|
||||
def test_should_inject_tenant_and_user_models(self, mock_tenant_class, app: Flask, monkeypatch):
|
||||
def test_should_inject_tenant_and_user_models(self, mock_tenant_class, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test that decorator injects tenant_model and user_model into kwargs"""
|
||||
|
||||
# Arrange
|
||||
@ -244,7 +244,9 @@ class TestGetUserTenant:
|
||||
protected_view()
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.Tenant")
|
||||
def test_should_use_default_session_id_when_user_id_empty(self, mock_tenant_class, app: Flask, monkeypatch):
|
||||
def test_should_use_default_session_id_when_user_id_empty(
|
||||
self, mock_tenant_class, app: Flask, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
"""Test that default session ID is used when user_id is empty string"""
|
||||
|
||||
# Arrange
|
||||
|
||||
@ -340,7 +340,7 @@ class TestConversationAppModeValidation:
|
||||
@pytest.mark.parametrize(
|
||||
"mode",
|
||||
[
|
||||
AppMode.CHAT.value,
|
||||
AppMode.CHAT,
|
||||
AppMode.AGENT_CHAT.value,
|
||||
AppMode.ADVANCED_CHAT.value,
|
||||
],
|
||||
@ -365,7 +365,7 @@ class TestConversationAppModeValidation:
|
||||
app raises NotChatAppError.
|
||||
"""
|
||||
app = Mock(spec=App)
|
||||
app.mode = AppMode.COMPLETION.value
|
||||
app.mode = AppMode.COMPLETION
|
||||
|
||||
app_mode = AppMode.value_of(app.mode)
|
||||
assert app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
|
||||
@ -498,7 +498,7 @@ class TestConversationApiController:
|
||||
def test_list_not_chat(self, app) -> None:
|
||||
api = ConversationApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
|
||||
app_model = SimpleNamespace(mode=AppMode.COMPLETION)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/conversations", method="GET"):
|
||||
@ -531,7 +531,7 @@ class TestConversationApiController:
|
||||
|
||||
api = ConversationApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context(
|
||||
@ -546,7 +546,7 @@ class TestConversationDetailApiController:
|
||||
def test_delete_not_chat(self, app) -> None:
|
||||
api = ConversationDetailApi()
|
||||
handler = _unwrap(api.delete)
|
||||
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
|
||||
app_model = SimpleNamespace(mode=AppMode.COMPLETION)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/conversations/1", method="DELETE"):
|
||||
@ -562,7 +562,7 @@ class TestConversationDetailApiController:
|
||||
|
||||
api = ConversationDetailApi()
|
||||
handler = _unwrap(api.delete)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/conversations/1", method="DELETE"):
|
||||
@ -580,7 +580,7 @@ class TestConversationRenameApiController:
|
||||
|
||||
api = ConversationRenameApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context(
|
||||
@ -596,7 +596,7 @@ class TestConversationVariablesApiController:
|
||||
def test_not_chat(self, app) -> None:
|
||||
api = ConversationVariablesApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
|
||||
app_model = SimpleNamespace(mode=AppMode.COMPLETION)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/conversations/1/variables", method="GET"):
|
||||
@ -612,7 +612,7 @@ class TestConversationVariablesApiController:
|
||||
|
||||
api = ConversationVariablesApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context(
|
||||
@ -645,7 +645,7 @@ class TestConversationVariablesApiController:
|
||||
|
||||
api = ConversationVariablesApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context(
|
||||
@ -671,7 +671,7 @@ class TestConversationVariableDetailApiController:
|
||||
|
||||
api = ConversationVariableDetailApi()
|
||||
handler = _unwrap(api.put)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context(
|
||||
@ -697,7 +697,7 @@ class TestConversationVariableDetailApiController:
|
||||
|
||||
api = ConversationVariableDetailApi()
|
||||
handler = _unwrap(api.put)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context(
|
||||
@ -731,7 +731,7 @@ class TestConversationVariableDetailApiController:
|
||||
|
||||
api = ConversationVariableDetailApi()
|
||||
handler = _unwrap(api.put)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context(
|
||||
|
||||
@ -3,6 +3,7 @@ from unittest.mock import Mock
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from controllers.service_api.end_user.end_user import EndUserApi
|
||||
from controllers.service_api.end_user.error import EndUserNotFoundError
|
||||
@ -21,7 +22,9 @@ class TestEndUserApi:
|
||||
app.tenant_id = str(uuid4())
|
||||
return app
|
||||
|
||||
def test_get_end_user_returns_all_attributes(self, mocker, resource: EndUserApi, app_model: App) -> None:
|
||||
def test_get_end_user_returns_all_attributes(
|
||||
self, mocker: MockerFixture, resource: EndUserApi, app_model: App
|
||||
) -> None:
|
||||
end_user = Mock(spec=EndUser)
|
||||
end_user.id = str(uuid4())
|
||||
end_user.tenant_id = app_model.tenant_id
|
||||
@ -54,7 +57,7 @@ class TestEndUserApi:
|
||||
assert result["created_at"].startswith("2024-01-01T00:00:00")
|
||||
assert result["updated_at"].startswith("2024-01-02T00:00:00")
|
||||
|
||||
def test_get_end_user_not_found(self, mocker, resource: EndUserApi, app_model: App) -> None:
|
||||
def test_get_end_user_not_found(self, mocker: MockerFixture, resource: EndUserApi, app_model: App) -> None:
|
||||
mocker.patch("controllers.service_api.end_user.end_user.EndUserService.get_end_user_by_id", return_value=None)
|
||||
|
||||
with pytest.raises(EndUserNotFoundError):
|
||||
|
||||
@ -12,12 +12,13 @@ from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_action_class(mocker):
|
||||
def mock_action_class(mocker: MockerFixture):
|
||||
mock_action = MagicMock()
|
||||
mocker.patch(
|
||||
"core.agent.output_parser.cot_output_parser.AgentScratchpadUnit.Action",
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.agent.strategy.plugin import PluginAgentStrategy
|
||||
|
||||
@ -213,7 +214,9 @@ class TestInvoke:
|
||||
(None, None, "msg"),
|
||||
],
|
||||
)
|
||||
def test_invoke_optional_arguments(self, strategy, mocker, conversation_id, app_id, message_id) -> None:
|
||||
def test_invoke_optional_arguments(
|
||||
self, strategy, mocker: MockerFixture, conversation_id, app_id, message_id
|
||||
) -> None:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.invoke = MagicMock(return_value=iter([]))
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@ from decimal import Decimal
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
import core.agent.base_agent_runner as module
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
@ -13,7 +14,7 @@ from core.agent.base_agent_runner import BaseAgentRunner
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session(mocker):
|
||||
def mock_db_session(mocker: MockerFixture):
|
||||
session = mocker.MagicMock()
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
return session
|
||||
@ -41,13 +42,13 @@ def runner(mocker, mock_db_session):
|
||||
|
||||
|
||||
class TestRepack:
|
||||
def test_sets_empty_if_none(self, runner, mocker):
|
||||
def test_sets_empty_if_none(self, runner, mocker: MockerFixture):
|
||||
entity = mocker.MagicMock()
|
||||
entity.app_config.prompt_template.simple_prompt_template = None
|
||||
result = runner._repack_app_generate_entity(entity)
|
||||
assert result.app_config.prompt_template.simple_prompt_template == ""
|
||||
|
||||
def test_keeps_existing(self, runner, mocker):
|
||||
def test_keeps_existing(self, runner, mocker: MockerFixture):
|
||||
entity = mocker.MagicMock()
|
||||
entity.app_config.prompt_template.simple_prompt_template = "abc"
|
||||
result = runner._repack_app_generate_entity(entity)
|
||||
@ -60,7 +61,7 @@ class TestRepack:
|
||||
|
||||
|
||||
class TestUpdatePromptTool:
|
||||
def build_param(self, mocker, **kwargs):
|
||||
def build_param(self, mocker: MockerFixture, **kwargs):
|
||||
p = mocker.MagicMock()
|
||||
p.form = kwargs.get("form")
|
||||
|
||||
@ -75,7 +76,7 @@ class TestUpdatePromptTool:
|
||||
p.required = kwargs.get("required", False)
|
||||
return p
|
||||
|
||||
def test_skip_non_llm(self, runner, mocker):
|
||||
def test_skip_non_llm(self, runner, mocker: MockerFixture):
|
||||
tool = mocker.MagicMock()
|
||||
param = self.build_param(mocker, form="NOT_LLM")
|
||||
tool.get_runtime_parameters.return_value = [param]
|
||||
@ -86,7 +87,7 @@ class TestUpdatePromptTool:
|
||||
result = runner.update_prompt_message_tool(tool, prompt_tool)
|
||||
assert result.parameters["properties"] == {}
|
||||
|
||||
def test_enum_and_required(self, runner, mocker):
|
||||
def test_enum_and_required(self, runner, mocker: MockerFixture):
|
||||
option = mocker.MagicMock(value="opt1")
|
||||
param = self.build_param(
|
||||
mocker,
|
||||
@ -104,7 +105,7 @@ class TestUpdatePromptTool:
|
||||
result = runner.update_prompt_message_tool(tool, prompt_tool)
|
||||
assert "p1" in result.parameters["required"]
|
||||
|
||||
def test_skip_file_type_param(self, runner, mocker):
|
||||
def test_skip_file_type_param(self, runner, mocker: MockerFixture):
|
||||
tool = mocker.MagicMock()
|
||||
param = self.build_param(mocker, form=module.ToolParameter.ToolParameterForm.LLM)
|
||||
param.type = module.ToolParameter.ToolParameterType.FILE
|
||||
@ -116,7 +117,7 @@ class TestUpdatePromptTool:
|
||||
result = runner.update_prompt_message_tool(tool, prompt_tool)
|
||||
assert result.parameters["properties"] == {}
|
||||
|
||||
def test_duplicate_required_not_duplicated(self, runner, mocker):
|
||||
def test_duplicate_required_not_duplicated(self, runner, mocker: MockerFixture):
|
||||
tool = mocker.MagicMock()
|
||||
|
||||
param = self.build_param(
|
||||
@ -141,7 +142,7 @@ class TestUpdatePromptTool:
|
||||
|
||||
|
||||
class TestCreateAgentThought:
|
||||
def test_with_files(self, runner, mock_db_session, mocker):
|
||||
def test_with_files(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
mock_thought = mocker.MagicMock(id=10)
|
||||
mocker.patch.object(module, "MessageAgentThought", return_value=mock_thought)
|
||||
|
||||
@ -149,7 +150,7 @@ class TestCreateAgentThought:
|
||||
assert result == "10"
|
||||
assert runner.agent_thought_count == 1
|
||||
|
||||
def test_without_files(self, runner, mock_db_session, mocker):
|
||||
def test_without_files(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
mock_thought = mocker.MagicMock(id=11)
|
||||
mocker.patch.object(module, "MessageAgentThought", return_value=mock_thought)
|
||||
|
||||
@ -163,7 +164,7 @@ class TestCreateAgentThought:
|
||||
|
||||
|
||||
class TestSaveAgentThought:
|
||||
def setup_agent(self, mocker):
|
||||
def setup_agent(self, mocker: MockerFixture):
|
||||
agent = mocker.MagicMock()
|
||||
agent.tool = "tool1;tool2"
|
||||
agent.tool_labels = {}
|
||||
@ -175,7 +176,7 @@ class TestSaveAgentThought:
|
||||
with pytest.raises(ValueError):
|
||||
runner.save_agent_thought("id", None, None, None, None, None, None, [], None)
|
||||
|
||||
def test_full_update(self, runner, mock_db_session, mocker):
|
||||
def test_full_update(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
agent = self.setup_agent(mocker)
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
@ -210,7 +211,7 @@ class TestSaveAgentThought:
|
||||
assert agent.tokens == 3
|
||||
assert "tool1" in json.loads(agent.tool_labels_str)
|
||||
|
||||
def test_label_fallback_when_none(self, runner, mock_db_session, mocker):
|
||||
def test_label_fallback_when_none(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
agent = self.setup_agent(mocker)
|
||||
agent.tool = "unknown_tool"
|
||||
mock_db_session.scalar.return_value = agent
|
||||
@ -220,7 +221,7 @@ class TestSaveAgentThought:
|
||||
labels = json.loads(agent.tool_labels_str)
|
||||
assert "unknown_tool" in labels
|
||||
|
||||
def test_json_failure_paths(self, runner, mock_db_session, mocker):
|
||||
def test_json_failure_paths(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
agent = self.setup_agent(mocker)
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
@ -241,13 +242,13 @@ class TestSaveAgentThought:
|
||||
|
||||
assert mock_db_session.commit.called
|
||||
|
||||
def test_messages_ids_none(self, runner, mock_db_session, mocker):
|
||||
def test_messages_ids_none(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
agent = self.setup_agent(mocker)
|
||||
mock_db_session.scalar.return_value = agent
|
||||
runner.save_agent_thought("id", None, None, None, None, None, None, None, None)
|
||||
assert mock_db_session.commit.called
|
||||
|
||||
def test_success_dict_serialization(self, runner, mock_db_session, mocker):
|
||||
def test_success_dict_serialization(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
agent = self.setup_agent(mocker)
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
@ -273,19 +274,19 @@ class TestSaveAgentThought:
|
||||
|
||||
|
||||
class TestOrganizeUserPrompt:
|
||||
def test_no_files(self, runner, mock_db_session, mocker):
|
||||
def test_no_files(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
mock_db_session.scalars.return_value.all.return_value = []
|
||||
msg = mocker.MagicMock(id="1", query="hello", app_model_config=None)
|
||||
result = runner.organize_agent_user_prompt(msg)
|
||||
assert result.content == "hello"
|
||||
|
||||
def test_with_files_no_config(self, runner, mock_db_session, mocker):
|
||||
def test_with_files_no_config(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()]
|
||||
msg = mocker.MagicMock(id="1", query="hello", app_model_config=None)
|
||||
result = runner.organize_agent_user_prompt(msg)
|
||||
assert result.content == "hello"
|
||||
|
||||
def test_image_detail_low_fallback(self, runner, mock_db_session, mocker):
|
||||
def test_image_detail_low_fallback(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()]
|
||||
file_config = mocker.MagicMock()
|
||||
file_config.image_config = mocker.MagicMock(detail=None)
|
||||
@ -305,27 +306,27 @@ class TestOrganizeUserPrompt:
|
||||
|
||||
|
||||
class TestOrganizeHistory:
|
||||
def test_empty(self, runner, mock_db_session, mocker):
|
||||
def test_empty(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = []
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[])
|
||||
result = runner.organize_agent_history([])
|
||||
assert result == []
|
||||
|
||||
def test_with_answer_only(self, runner, mock_db_session, mocker):
|
||||
def test_with_answer_only(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
msg = mocker.MagicMock(id="m1", answer="ans", agent_thoughts=[], app_model_config=None)
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
result = runner.organize_agent_history([])
|
||||
assert any(isinstance(x, module.AssistantPromptMessage) for x in result)
|
||||
|
||||
def test_skip_current_message(self, runner, mock_db_session, mocker):
|
||||
def test_skip_current_message(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
msg = mocker.MagicMock(id="msg_current", agent_thoughts=[], answer="ans", app_model_config=None)
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
result = runner.organize_agent_history([])
|
||||
assert result == []
|
||||
|
||||
def test_with_tool_calls_invalid_json(self, runner, mock_db_session, mocker):
|
||||
def test_with_tool_calls_invalid_json(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
thought = mocker.MagicMock(
|
||||
tool="tool1",
|
||||
tool_input="invalid",
|
||||
@ -341,7 +342,7 @@ class TestOrganizeHistory:
|
||||
result = runner.organize_agent_history([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_empty_tool_name_split(self, runner, mock_db_session, mocker):
|
||||
def test_empty_tool_name_split(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
thought = mocker.MagicMock(tool=";", thought="thinking")
|
||||
msg = mocker.MagicMock(id="m5", agent_thoughts=[thought], answer=None, app_model_config=None)
|
||||
|
||||
@ -350,7 +351,7 @@ class TestOrganizeHistory:
|
||||
result = runner.organize_agent_history([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_valid_json_tool_flow(self, runner, mock_db_session, mocker):
|
||||
def test_valid_json_tool_flow(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
thought = mocker.MagicMock(
|
||||
tool="tool1",
|
||||
tool_input=json.dumps({"tool1": {"x": 1}}),
|
||||
@ -379,7 +380,7 @@ class TestOrganizeHistory:
|
||||
|
||||
|
||||
class TestConvertToolToPromptMessageTool:
|
||||
def test_basic_conversion(self, runner, mocker):
|
||||
def test_basic_conversion(self, runner, mocker: MockerFixture):
|
||||
tool = mocker.MagicMock(tool_name="tool1")
|
||||
|
||||
runtime_param = mocker.MagicMock()
|
||||
@ -404,7 +405,7 @@ class TestConvertToolToPromptMessageTool:
|
||||
prompt_tool, entity = runner._convert_tool_to_prompt_message_tool(tool)
|
||||
assert entity == tool_entity
|
||||
|
||||
def test_full_conversion_multiple_params(self, runner, mocker):
|
||||
def test_full_conversion_multiple_params(self, runner, mocker: MockerFixture):
|
||||
tool = mocker.MagicMock(tool_name="tool1")
|
||||
|
||||
# LLM param with input_schema override
|
||||
@ -441,7 +442,7 @@ class TestConvertToolToPromptMessageTool:
|
||||
|
||||
|
||||
class TestInitPromptToolsExtended:
|
||||
def test_agent_tool_branch(self, runner, mocker):
|
||||
def test_agent_tool_branch(self, runner, mocker: MockerFixture):
|
||||
agent_tool = mocker.MagicMock(tool_name="agent_tool")
|
||||
runner.app_config.agent = mocker.MagicMock(tools=[agent_tool])
|
||||
mocker.patch.object(runner, "_convert_tool_to_prompt_message_tool", return_value=(MagicMock(), "entity"))
|
||||
@ -449,7 +450,7 @@ class TestInitPromptToolsExtended:
|
||||
tools, prompts = runner._init_prompt_tools()
|
||||
assert "agent_tool" in tools
|
||||
|
||||
def test_exception_in_conversion(self, runner, mocker):
|
||||
def test_exception_in_conversion(self, runner, mocker: MockerFixture):
|
||||
agent_tool = mocker.MagicMock(tool_name="bad_tool")
|
||||
runner.app_config.agent = mocker.MagicMock(tools=[agent_tool])
|
||||
mocker.patch.object(runner, "_convert_tool_to_prompt_message_tool", side_effect=Exception)
|
||||
@ -464,7 +465,7 @@ class TestInitPromptToolsExtended:
|
||||
|
||||
|
||||
class TestAdditionalCoverage:
|
||||
def test_update_prompt_with_input_schema(self, runner, mocker):
|
||||
def test_update_prompt_with_input_schema(self, runner, mocker: MockerFixture):
|
||||
tool = mocker.MagicMock()
|
||||
|
||||
param = mocker.MagicMock()
|
||||
@ -487,7 +488,7 @@ class TestAdditionalCoverage:
|
||||
result = runner.update_prompt_message_tool(tool, prompt_tool)
|
||||
assert result.parameters["properties"]["p1"]["type"] == "number"
|
||||
|
||||
def test_save_agent_thought_existing_labels(self, runner, mock_db_session, mocker):
|
||||
def test_save_agent_thought_existing_labels(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
agent = mocker.MagicMock()
|
||||
agent.tool = "tool1"
|
||||
agent.tool_labels = {"tool1": {"en_US": "existing"}}
|
||||
@ -498,7 +499,7 @@ class TestAdditionalCoverage:
|
||||
labels = json.loads(agent.tool_labels_str)
|
||||
assert labels["tool1"]["en_US"] == "existing"
|
||||
|
||||
def test_save_agent_thought_tool_meta_string(self, runner, mock_db_session, mocker):
|
||||
def test_save_agent_thought_tool_meta_string(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
agent = mocker.MagicMock()
|
||||
agent.tool = "tool1"
|
||||
agent.tool_labels = {}
|
||||
@ -508,7 +509,7 @@ class TestAdditionalCoverage:
|
||||
runner.save_agent_thought("id", None, None, None, None, "meta_string", None, [], None)
|
||||
assert agent.tool_meta_str == "meta_string"
|
||||
|
||||
def test_convert_dataset_retriever_tool(self, runner, mocker):
|
||||
def test_convert_dataset_retriever_tool(self, runner, mocker: MockerFixture):
|
||||
ds_tool = mocker.MagicMock()
|
||||
ds_tool.entity.identity.name = "ds"
|
||||
ds_tool.entity.description.llm = "desc"
|
||||
@ -525,7 +526,7 @@ class TestAdditionalCoverage:
|
||||
prompt = runner._convert_dataset_retriever_tool_to_prompt_message_tool(ds_tool)
|
||||
assert prompt is not None
|
||||
|
||||
def test_organize_user_prompt_with_file_objects(self, runner, mock_db_session, mocker):
|
||||
def test_organize_user_prompt_with_file_objects(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()]
|
||||
|
||||
file_config = mocker.MagicMock()
|
||||
@ -544,7 +545,7 @@ class TestAdditionalCoverage:
|
||||
result = runner.organize_agent_user_prompt(msg)
|
||||
assert result is not None
|
||||
|
||||
def test_organize_history_without_tool_names(self, runner, mock_db_session, mocker):
|
||||
def test_organize_history_without_tool_names(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
thought = mocker.MagicMock(tool=None, thought="thinking")
|
||||
msg = mocker.MagicMock(id="m3", agent_thoughts=[thought], answer=None, app_model_config=None)
|
||||
|
||||
@ -554,7 +555,7 @@ class TestAdditionalCoverage:
|
||||
result = runner.organize_agent_history([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_organize_history_multiple_tools_split(self, runner, mock_db_session, mocker):
|
||||
def test_organize_history_multiple_tools_split(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
thought = mocker.MagicMock(
|
||||
tool="tool1;tool2",
|
||||
tool_input=json.dumps({"tool1": {}, "tool2": {}}),
|
||||
@ -572,7 +573,7 @@ class TestAdditionalCoverage:
|
||||
|
||||
# ================= Additional Surgical Coverage =================
|
||||
|
||||
def test_convert_tool_select_enum_branch(self, runner, mocker):
|
||||
def test_convert_tool_select_enum_branch(self, runner, mocker: MockerFixture):
|
||||
tool = mocker.MagicMock(tool_name="tool1")
|
||||
|
||||
param = mocker.MagicMock()
|
||||
@ -599,7 +600,7 @@ class TestAdditionalCoverage:
|
||||
|
||||
|
||||
class TestConvertDatasetRetrieverTool:
|
||||
def test_required_param_added(self, runner, mocker):
|
||||
def test_required_param_added(self, runner, mocker: MockerFixture):
|
||||
ds_tool = mocker.MagicMock()
|
||||
ds_tool.entity.identity.name = "ds"
|
||||
ds_tool.entity.description.llm = "desc"
|
||||
@ -619,7 +620,7 @@ class TestConvertDatasetRetrieverTool:
|
||||
|
||||
|
||||
class TestBaseAgentRunnerInit:
|
||||
def test_init_sets_stream_tool_call_and_files(self, mocker):
|
||||
def test_init_sets_stream_tool_call_and_files(self, mocker: MockerFixture):
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = 2
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
@ -662,7 +663,7 @@ class TestBaseAgentRunnerInit:
|
||||
|
||||
|
||||
class TestBaseAgentRunnerCoverage:
|
||||
def test_convert_tool_skips_non_llm_param(self, runner, mocker):
|
||||
def test_convert_tool_skips_non_llm_param(self, runner, mocker: MockerFixture):
|
||||
tool = mocker.MagicMock(tool_name="tool1")
|
||||
|
||||
param = mocker.MagicMock()
|
||||
@ -680,7 +681,7 @@ class TestBaseAgentRunnerCoverage:
|
||||
|
||||
assert prompt_tool.parameters["properties"] == {}
|
||||
|
||||
def test_init_prompt_tools_adds_dataset_tools(self, runner, mocker):
|
||||
def test_init_prompt_tools_adds_dataset_tools(self, runner, mocker: MockerFixture):
|
||||
dataset_tool = mocker.MagicMock()
|
||||
dataset_tool.entity.identity.name = "ds"
|
||||
runner.dataset_tools = [dataset_tool]
|
||||
@ -692,7 +693,7 @@ class TestBaseAgentRunnerCoverage:
|
||||
assert tools["ds"] == dataset_tool
|
||||
assert len(prompt_tools) == 1
|
||||
|
||||
def test_update_prompt_message_tool_select_enum(self, runner, mocker):
|
||||
def test_update_prompt_message_tool_select_enum(self, runner, mocker: MockerFixture):
|
||||
tool = mocker.MagicMock()
|
||||
|
||||
option1 = mocker.MagicMock(value="A")
|
||||
@ -716,7 +717,7 @@ class TestBaseAgentRunnerCoverage:
|
||||
|
||||
assert result.parameters["properties"]["select_param"]["enum"] == ["A", "B"]
|
||||
|
||||
def test_save_agent_thought_json_dumps_fallbacks(self, runner, mock_db_session, mocker):
|
||||
def test_save_agent_thought_json_dumps_fallbacks(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
agent = mocker.MagicMock()
|
||||
agent.tool = "tool1"
|
||||
agent.tool_labels = {}
|
||||
@ -754,7 +755,7 @@ class TestBaseAgentRunnerCoverage:
|
||||
assert isinstance(agent.observation, str)
|
||||
assert isinstance(agent.tool_meta_str, str)
|
||||
|
||||
def test_save_agent_thought_skips_empty_tool_name(self, runner, mock_db_session, mocker):
|
||||
def test_save_agent_thought_skips_empty_tool_name(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
agent = mocker.MagicMock()
|
||||
agent.tool = "tool1;;"
|
||||
agent.tool_labels = {}
|
||||
@ -768,7 +769,7 @@ class TestBaseAgentRunnerCoverage:
|
||||
labels = json.loads(agent.tool_labels_str)
|
||||
assert "" not in labels
|
||||
|
||||
def test_organize_history_includes_system_prompt(self, runner, mock_db_session, mocker):
|
||||
def test_organize_history_includes_system_prompt(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = []
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[])
|
||||
|
||||
@ -778,7 +779,7 @@ class TestBaseAgentRunnerCoverage:
|
||||
|
||||
assert system_message in result
|
||||
|
||||
def test_organize_history_tool_inputs_and_observation_none(self, runner, mock_db_session, mocker):
|
||||
def test_organize_history_tool_inputs_and_observation_none(self, runner, mock_db_session, mocker: MockerFixture):
|
||||
thought = mocker.MagicMock(
|
||||
tool="tool1",
|
||||
tool_input=None,
|
||||
|
||||
@ -2,6 +2,7 @@ import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.agent.cot_agent_runner import CotAgentRunner
|
||||
from core.agent.entities import AgentScratchpadUnit
|
||||
@ -25,7 +26,7 @@ class DummyRunner(CotAgentRunner):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner(mocker):
|
||||
def runner(mocker: MockerFixture):
|
||||
# Prevent BaseAgentRunner __init__ from hitting database
|
||||
mocker.patch(
|
||||
"core.agent.base_agent_runner.BaseAgentRunner.organize_agent_history",
|
||||
@ -165,7 +166,7 @@ class TestHandleInvokeAction:
|
||||
response, meta = runner._handle_invoke_action(action, {}, [])
|
||||
assert "there is not a tool named" in response
|
||||
|
||||
def test_tool_with_json_string_args(self, runner, mocker):
|
||||
def test_tool_with_json_string_args(self, runner, mocker: MockerFixture):
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input=json.dumps({"a": 1}))
|
||||
tool_instance = MagicMock()
|
||||
tool_instances = {"tool": tool_instance}
|
||||
@ -180,7 +181,7 @@ class TestHandleInvokeAction:
|
||||
|
||||
|
||||
class TestOrganizeHistoricPromptMessages:
|
||||
def test_empty_history(self, runner, mocker):
|
||||
def test_empty_history(self, runner, mocker: MockerFixture):
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.AgentHistoryPromptTransform.get_prompt",
|
||||
return_value=[],
|
||||
@ -190,7 +191,7 @@ class TestOrganizeHistoricPromptMessages:
|
||||
|
||||
|
||||
class TestRun:
|
||||
def test_run_handles_empty_parser_output(self, runner, mocker):
|
||||
def test_run_handles_empty_parser_output(self, runner, mocker: MockerFixture):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
@ -202,7 +203,7 @@ class TestRun:
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert isinstance(results, list)
|
||||
|
||||
def test_run_with_action_and_tool_invocation(self, runner, mocker):
|
||||
def test_run_with_action_and_tool_invocation(self, runner, mocker: MockerFixture):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
@ -223,7 +224,7 @@ class TestRun:
|
||||
with pytest.raises(AgentMaxIterationError):
|
||||
list(runner.run(message, "query", {"tool": MagicMock()}))
|
||||
|
||||
def test_run_respects_max_iteration_boundary(self, runner, mocker):
|
||||
def test_run_respects_max_iteration_boundary(self, runner, mocker: MockerFixture):
|
||||
runner.app_config.agent.max_iteration = 1
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
@ -245,7 +246,7 @@ class TestRun:
|
||||
with pytest.raises(AgentMaxIterationError):
|
||||
list(runner.run(message, "query", {"tool": MagicMock()}))
|
||||
|
||||
def test_run_basic_flow(self, runner, mocker):
|
||||
def test_run_basic_flow(self, runner, mocker: MockerFixture):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
@ -257,7 +258,7 @@ class TestRun:
|
||||
results = list(runner.run(message, "query", {"name": "John"}))
|
||||
assert results
|
||||
|
||||
def test_run_max_iteration_error(self, runner, mocker):
|
||||
def test_run_max_iteration_error(self, runner, mocker: MockerFixture):
|
||||
runner.app_config.agent.max_iteration = 0
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
@ -272,7 +273,7 @@ class TestRun:
|
||||
with pytest.raises(AgentMaxIterationError):
|
||||
list(runner.run(message, "query", {}))
|
||||
|
||||
def test_run_increase_usage_aggregation(self, runner, mocker):
|
||||
def test_run_increase_usage_aggregation(self, runner, mocker: MockerFixture):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
runner.app_config.agent.max_iteration = 2
|
||||
@ -329,7 +330,7 @@ class TestRun:
|
||||
assert final_usage.completion_price == 2
|
||||
assert final_usage.total_price == 4
|
||||
|
||||
def test_run_when_no_action_branch(self, runner, mocker):
|
||||
def test_run_when_no_action_branch(self, runner, mocker: MockerFixture):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
@ -341,7 +342,7 @@ class TestRun:
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert results[-1].delta.message.content == ""
|
||||
|
||||
def test_run_usage_missing_key_branch(self, runner, mocker):
|
||||
def test_run_usage_missing_key_branch(self, runner, mocker: MockerFixture):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
@ -354,7 +355,7 @@ class TestRun:
|
||||
|
||||
list(runner.run(message, "query", {}))
|
||||
|
||||
def test_run_prompt_tool_update_branch(self, runner, mocker):
|
||||
def test_run_prompt_tool_update_branch(self, runner, mocker: MockerFixture):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
@ -410,7 +411,7 @@ class TestRun:
|
||||
|
||||
|
||||
class TestInitReactState:
|
||||
def test_init_react_state_resets_state(self, runner, mocker):
|
||||
def test_init_react_state_resets_state(self, runner, mocker: MockerFixture):
|
||||
mocker.patch.object(runner, "_organize_historic_prompt_messages", return_value=["historic"])
|
||||
runner._agent_scratchpad = ["old"]
|
||||
runner._query = "old"
|
||||
@ -423,7 +424,7 @@ class TestInitReactState:
|
||||
|
||||
|
||||
class TestHandleInvokeActionExtended:
|
||||
def test_tool_with_invalid_json_string_args(self, runner, mocker):
|
||||
def test_tool_with_invalid_json_string_args(self, runner, mocker: MockerFixture):
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input="not-json")
|
||||
tool_instance = MagicMock()
|
||||
tool_instances = {"tool": tool_instance}
|
||||
@ -457,7 +458,7 @@ class TestFillInputsEdgeCases:
|
||||
|
||||
|
||||
class TestOrganizeHistoricPromptMessagesExtended:
|
||||
def test_user_message_flushes_scratchpad(self, runner, mocker):
|
||||
def test_user_message_flushes_scratchpad(self, runner, mocker: MockerFixture):
|
||||
from graphon.model_runtime.entities.message_entities import UserPromptMessage
|
||||
|
||||
user_message = UserPromptMessage(content="Hi")
|
||||
@ -480,7 +481,7 @@ class TestOrganizeHistoricPromptMessagesExtended:
|
||||
with pytest.raises(NotImplementedError):
|
||||
runner._organize_historic_prompt_messages([])
|
||||
|
||||
def test_agent_history_transform_invocation(self, runner, mocker):
|
||||
def test_agent_history_transform_invocation(self, runner, mocker: MockerFixture):
|
||||
mock_transform = MagicMock()
|
||||
mock_transform.get_prompt.return_value = []
|
||||
|
||||
@ -495,7 +496,7 @@ class TestOrganizeHistoricPromptMessagesExtended:
|
||||
|
||||
|
||||
class TestRunAdditionalBranches:
|
||||
def test_run_with_no_action_final_answer_empty(self, runner, mocker):
|
||||
def test_run_with_no_action_final_answer_empty(self, runner, mocker: MockerFixture):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
@ -507,7 +508,7 @@ class TestRunAdditionalBranches:
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert any(hasattr(r, "delta") for r in results)
|
||||
|
||||
def test_run_with_final_answer_action_string(self, runner, mocker):
|
||||
def test_run_with_final_answer_action_string(self, runner, mocker: MockerFixture):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
@ -521,7 +522,7 @@ class TestRunAdditionalBranches:
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert results[-1].delta.message.content == "done"
|
||||
|
||||
def test_run_with_final_answer_action_dict(self, runner, mocker):
|
||||
def test_run_with_final_answer_action_dict(self, runner, mocker: MockerFixture):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
@ -535,7 +536,7 @@ class TestRunAdditionalBranches:
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert json.loads(results[-1].delta.message.content) == {"a": 1}
|
||||
|
||||
def test_run_with_string_final_answer(self, runner, mocker):
|
||||
def test_run_with_string_final_answer(self, runner, mocker: MockerFixture):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
|
||||
from graphon.model_runtime.entities.message_entities import TextPromptMessageContent
|
||||
@ -55,7 +56,7 @@ def runner():
|
||||
|
||||
|
||||
class TestOrganizeSystemPrompt:
|
||||
def test_organize_system_prompt_success(self, runner, mocker):
|
||||
def test_organize_system_prompt_success(self, runner, mocker: MockerFixture):
|
||||
first_prompt = "Instruction: {{instruction}}, Tools: {{tools}}, Names: {{tool_names}}"
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt(first_prompt)))
|
||||
|
||||
@ -154,7 +155,7 @@ class TestOrganizeUserQuery:
|
||||
|
||||
|
||||
class TestOrganizePromptMessages:
|
||||
def test_no_scratchpad(self, runner, mocker):
|
||||
def test_no_scratchpad(self, runner, mocker: MockerFixture):
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
|
||||
runner._organize_system_prompt = MagicMock(return_value="system")
|
||||
runner._organize_user_query = MagicMock(return_value=["query"])
|
||||
@ -164,7 +165,7 @@ class TestOrganizePromptMessages:
|
||||
assert "query" in result
|
||||
runner._organize_historic_prompt_messages.assert_called_once()
|
||||
|
||||
def test_with_final_scratchpad(self, runner, mocker):
|
||||
def test_with_final_scratchpad(self, runner, mocker: MockerFixture):
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
|
||||
runner._organize_system_prompt = MagicMock(return_value="system")
|
||||
runner._organize_user_query = MagicMock(return_value=["query"])
|
||||
@ -177,7 +178,7 @@ class TestOrganizePromptMessages:
|
||||
combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)])
|
||||
assert "Final Answer: done" in combined
|
||||
|
||||
def test_with_thought_action_observation(self, runner, mocker):
|
||||
def test_with_thought_action_observation(self, runner, mocker: MockerFixture):
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
|
||||
runner._organize_system_prompt = MagicMock(return_value="system")
|
||||
runner._organize_user_query = MagicMock(return_value=["query"])
|
||||
@ -197,7 +198,7 @@ class TestOrganizePromptMessages:
|
||||
assert "Action: action" in combined
|
||||
assert "Observation: observe" in combined
|
||||
|
||||
def test_multiple_units_mixed(self, runner, mocker):
|
||||
def test_multiple_units_mixed(self, runner, mocker: MockerFixture):
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
|
||||
runner._organize_system_prompt = MagicMock(return_value="system")
|
||||
runner._organize_user_query = MagicMock(return_value=["query"])
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
|
||||
from graphon.model_runtime.entities.message_entities import (
|
||||
@ -74,7 +75,7 @@ class TestOrganizeInstructionPrompt:
|
||||
|
||||
|
||||
class TestOrganizeHistoricPrompt:
|
||||
def test_with_user_and_assistant_string(self, runner, mocker):
|
||||
def test_with_user_and_assistant_string(self, runner, mocker: MockerFixture):
|
||||
user_msg = UserPromptMessage(content="Hello")
|
||||
assistant_msg = AssistantPromptMessage(content="Hi there")
|
||||
|
||||
@ -89,7 +90,7 @@ class TestOrganizeHistoricPrompt:
|
||||
assert "Question: Hello" in result
|
||||
assert "Hi there" in result
|
||||
|
||||
def test_assistant_list_with_text_content(self, runner, mocker):
|
||||
def test_assistant_list_with_text_content(self, runner, mocker: MockerFixture):
|
||||
text_content = TextPromptMessageContent(data="Partial answer")
|
||||
assistant_msg = AssistantPromptMessage(content=[text_content])
|
||||
|
||||
@ -103,7 +104,7 @@ class TestOrganizeHistoricPrompt:
|
||||
|
||||
assert "Partial answer" in result
|
||||
|
||||
def test_assistant_list_with_non_text_content_ignored(self, runner, mocker):
|
||||
def test_assistant_list_with_non_text_content_ignored(self, runner, mocker: MockerFixture):
|
||||
non_text_content = ImagePromptMessageContent(format="url", mime_type="image/png")
|
||||
assistant_msg = AssistantPromptMessage(content=[non_text_content])
|
||||
|
||||
@ -116,7 +117,7 @@ class TestOrganizeHistoricPrompt:
|
||||
result = runner._organize_historic_prompt()
|
||||
assert result == ""
|
||||
|
||||
def test_empty_history(self, runner, mocker):
|
||||
def test_empty_history(self, runner, mocker: MockerFixture):
|
||||
mocker.patch.object(
|
||||
runner,
|
||||
"_organize_historic_prompt_messages",
|
||||
@ -136,7 +137,7 @@ class TestOrganizePromptMessages:
|
||||
def test_full_flow_with_scratchpad(
|
||||
self,
|
||||
runner,
|
||||
mocker,
|
||||
mocker: MockerFixture,
|
||||
dummy_app_config_factory,
|
||||
dummy_agent_config_factory,
|
||||
dummy_prompt_entity_factory,
|
||||
@ -171,7 +172,12 @@ class TestOrganizePromptMessages:
|
||||
assert "Question: What is Python?" in content
|
||||
|
||||
def test_no_scratchpad(
|
||||
self, runner, mocker, dummy_app_config_factory, dummy_agent_config_factory, dummy_prompt_entity_factory
|
||||
self,
|
||||
runner,
|
||||
mocker: MockerFixture,
|
||||
dummy_app_config_factory,
|
||||
dummy_agent_config_factory,
|
||||
dummy_prompt_entity_factory,
|
||||
):
|
||||
template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}"
|
||||
|
||||
@ -198,7 +204,7 @@ class TestOrganizePromptMessages:
|
||||
def test_partial_scratchpad_units(
|
||||
self,
|
||||
runner,
|
||||
mocker,
|
||||
mocker: MockerFixture,
|
||||
thought,
|
||||
action,
|
||||
observation,
|
||||
|
||||
@ -3,6 +3,7 @@ from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.agent.errors import AgentMaxIterationError
|
||||
from core.agent.fc_agent_runner import FunctionCallAgentRunner
|
||||
@ -68,7 +69,7 @@ class DummyResult:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner(mocker):
|
||||
def runner(mocker: MockerFixture):
|
||||
# Completely bypass BaseAgentRunner __init__ to avoid DB / Flask context
|
||||
mocker.patch(
|
||||
"core.agent.base_agent_runner.BaseAgentRunner.__init__",
|
||||
@ -230,7 +231,7 @@ class TestOrganizeUserQuery:
|
||||
result = runner._organize_user_query(None, [])
|
||||
assert len(result) == 1
|
||||
|
||||
def test_with_files_uses_image_detail_config(self, runner, mocker):
|
||||
def test_with_files_uses_image_detail_config(self, runner, mocker: MockerFixture):
|
||||
file_content = TextPromptMessageContent(data="file-content")
|
||||
mock_to_prompt = mocker.patch(
|
||||
"core.agent.fc_agent_runner.file_manager.to_prompt_message_content",
|
||||
@ -352,7 +353,7 @@ class TestRunMethod:
|
||||
assert len(outputs) == 1
|
||||
assert runner.save_agent_thought.call_args.kwargs["thought"] == "hi"
|
||||
|
||||
def test_run_streaming_tool_call_inputs_type_error(self, runner, mocker):
|
||||
def test_run_streaming_tool_call_inputs_type_error(self, runner, mocker: MockerFixture):
|
||||
message = MagicMock(id="m1")
|
||||
runner.stream_tool_call = True
|
||||
|
||||
@ -398,7 +399,7 @@ class TestRunMethod:
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) >= 1
|
||||
|
||||
def test_run_with_tool_instance_and_files(self, runner, mocker):
|
||||
def test_run_with_tool_instance_and_files(self, runner, mocker: MockerFixture):
|
||||
message = MagicMock(id="m1")
|
||||
|
||||
tool_call = MagicMock()
|
||||
|
||||
@ -9,6 +9,7 @@ mocking; ensure entity invariants and validation rules remain stable.
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.agent.plugin_entities import (
|
||||
AgentFeature,
|
||||
@ -28,12 +29,12 @@ from core.tools.entities.tool_entities import ToolIdentity, ToolProviderIdentity
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_identity(mocker):
|
||||
def mock_identity(mocker: MockerFixture):
|
||||
return mocker.MagicMock(spec=AgentStrategyIdentity)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider_identity(mocker):
|
||||
def mock_provider_identity(mocker: MockerFixture):
|
||||
return mocker.MagicMock(spec=AgentStrategyProviderIdentity)
|
||||
|
||||
|
||||
@ -47,7 +48,7 @@ class TestAgentStrategyParameterType:
|
||||
"enum_member",
|
||||
list(AgentStrategyParameter.AgentStrategyParameterType),
|
||||
)
|
||||
def test_as_normal_type_calls_external_function(self, mocker, enum_member) -> None:
|
||||
def test_as_normal_type_calls_external_function(self, mocker: MockerFixture, enum_member) -> None:
|
||||
mock_func = mocker.patch(
|
||||
"core.agent.plugin_entities.as_normal_type",
|
||||
return_value="normalized",
|
||||
@ -58,7 +59,7 @@ class TestAgentStrategyParameterType:
|
||||
mock_func.assert_called_once_with(enum_member)
|
||||
assert result == "normalized"
|
||||
|
||||
def test_as_normal_type_propagates_exception(self, mocker) -> None:
|
||||
def test_as_normal_type_propagates_exception(self, mocker: MockerFixture) -> None:
|
||||
enum_member = AgentStrategyParameter.AgentStrategyParameterType.STRING
|
||||
mocker.patch(
|
||||
"core.agent.plugin_entities.as_normal_type",
|
||||
@ -79,7 +80,7 @@ class TestAgentStrategyParameterType:
|
||||
(AgentStrategyParameter.AgentStrategyParameterType.FILES, []),
|
||||
],
|
||||
)
|
||||
def test_cast_value_calls_external_function(self, mocker, enum_member, value) -> None:
|
||||
def test_cast_value_calls_external_function(self, mocker: MockerFixture, enum_member, value) -> None:
|
||||
mock_func = mocker.patch(
|
||||
"core.agent.plugin_entities.cast_parameter_value",
|
||||
return_value="casted",
|
||||
@ -90,7 +91,7 @@ class TestAgentStrategyParameterType:
|
||||
mock_func.assert_called_once_with(enum_member, value)
|
||||
assert result == "casted"
|
||||
|
||||
def test_cast_value_propagates_exception(self, mocker) -> None:
|
||||
def test_cast_value_propagates_exception(self, mocker: MockerFixture) -> None:
|
||||
enum_member = AgentStrategyParameter.AgentStrategyParameterType.STRING
|
||||
mocker.patch(
|
||||
"core.agent.plugin_entities.cast_parameter_value",
|
||||
@ -136,7 +137,7 @@ class TestAgentStrategyParameter:
|
||||
|
||||
assert any(error["loc"] == ("type",) for error in exc_info.value.errors())
|
||||
|
||||
def test_init_frontend_parameter_calls_external(self, mocker) -> None:
|
||||
def test_init_frontend_parameter_calls_external(self, mocker: MockerFixture) -> None:
|
||||
mock_func = mocker.patch(
|
||||
"core.agent.plugin_entities.init_frontend_parameter",
|
||||
return_value="frontend",
|
||||
@ -153,7 +154,7 @@ class TestAgentStrategyParameter:
|
||||
mock_func.assert_called_once_with(param, param.type, "value")
|
||||
assert result == "frontend"
|
||||
|
||||
def test_init_frontend_parameter_propagates_exception(self, mocker) -> None:
|
||||
def test_init_frontend_parameter_propagates_exception(self, mocker: MockerFixture) -> None:
|
||||
mocker.patch(
|
||||
"core.agent.plugin_entities.init_frontend_parameter",
|
||||
side_effect=RuntimeError("error"),
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user